diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java index 977478adba..ec3e997938 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java @@ -31,6 +31,7 @@ import com.android.internal.annotations.VisibleForTesting; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.function.BiPredicate; import java.util.function.Consumer; @@ -43,6 +44,9 @@ public class MdnsAdvertiser { private static final String TAG = MdnsAdvertiser.class.getSimpleName(); static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG); + // Top-level domain for link-local queries, as per RFC6762 3. + private static final String LOCAL_TLD = "local"; + private final Looper mLooper; private final AdvertiserCallback mCb; @@ -60,6 +64,8 @@ public class MdnsAdvertiser { private final SparseArray mRegistrations = new SparseArray<>(); private final Dependencies mDeps; + private String[] mDeviceHostName; + /** * Dependencies for {@link MdnsAdvertiser}, useful for testing. */ @@ -71,11 +77,32 @@ public class MdnsAdvertiser { public MdnsInterfaceAdvertiser makeAdvertiser(@NonNull MdnsInterfaceSocket socket, @NonNull List initialAddresses, @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, - @NonNull MdnsInterfaceAdvertiser.Callback cb) { + @NonNull MdnsInterfaceAdvertiser.Callback cb, + @NonNull String[] deviceHostName) { // Note NetworkInterface is final and not mockable final String logTag = socket.getInterface().getName(); return new MdnsInterfaceAdvertiser(logTag, socket, initialAddresses, looper, - packetCreationBuffer, cb); + packetCreationBuffer, cb, deviceHostName); + } + + /** + * Generates a unique hostname to be used by the device. + */ + @NonNull + public String[] generateHostname() { + // Generate a very-probably-unique hostname. This allows minimizing possible conflicts + // to the point that probing for it is no longer necessary (as per RFC6762 8.1 last + // paragraph), and does not leak more information than what could already be obtained by + // looking at the mDNS packets source address. + // This differs from historical behavior that just used "Android.local" for many + // devices, creating a lot of conflicts. + // Having a different hostname per interface is an acceptable option as per RFC6762 14. + // This hostname will change every time the interface is reconnected, so this does not + // allow tracking the device. + // TODO: consider deriving a hostname from other sources, such as the IPv6 addresses + // (reusing the same privacy-protecting mechanics). + return new String[] { + "Android_" + UUID.randomUUID().toString().replace("-", ""), LOCAL_TLD }; } } @@ -260,7 +287,7 @@ public class MdnsAdvertiser { MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket); if (advertiser == null) { advertiser = mDeps.makeAdvertiser(socket, addresses, mLooper, mPacketCreationBuffer, - mInterfaceAdvertiserCb); + mInterfaceAdvertiserCb, mDeviceHostName); mAllAdvertisers.put(socket, advertiser); advertiser.start(); } @@ -389,6 +416,7 @@ public class MdnsAdvertiser { mCb = cb; mSocketProvider = socketProvider; mDeps = deps; + mDeviceHostName = deps.generateHostname(); } private void checkThread() { @@ -453,6 +481,10 @@ public class MdnsAdvertiser { advertiser.removeService(id); } mRegistrations.remove(id); + // Regenerates host name when registrations removed. + if (mRegistrations.size() == 0) { + mDeviceHostName = mDeps.generateHostname(); + } } private static boolean any(@NonNull ArrayMap map, diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java index c616e01df6..79cddce0c8 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java @@ -141,8 +141,9 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand public static class Dependencies { /** @see MdnsRecordRepository */ @NonNull - public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper) { - return new MdnsRecordRepository(looper); + public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper, + @NonNull String[] deviceHostName) { + return new MdnsRecordRepository(looper, deviceHostName); } /** @see MdnsReplySender */ @@ -169,17 +170,18 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand public MdnsInterfaceAdvertiser(@NonNull String logTag, @NonNull MdnsInterfaceSocket socket, @NonNull List initialAddresses, - @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb) { + @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb, + @NonNull String[] deviceHostName) { this(logTag, socket, initialAddresses, looper, packetCreationBuffer, cb, - new Dependencies()); + new Dependencies(), deviceHostName); } public MdnsInterfaceAdvertiser(@NonNull String logTag, @NonNull MdnsInterfaceSocket socket, @NonNull List initialAddresses, @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb, - @NonNull Dependencies deps) { + @NonNull Dependencies deps, @NonNull String[] deviceHostName) { mTag = MdnsInterfaceAdvertiser.class.getSimpleName() + "/" + logTag; - mRecordRepository = deps.makeRecordRepository(looper); + mRecordRepository = deps.makeRecordRepository(looper, deviceHostName); mRecordRepository.updateAddresses(initialAddresses); mSocket = socket; mCb = cb; diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java index e975ab417b..13291721a7 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java @@ -47,7 +47,6 @@ import java.util.Objects; import java.util.Random; import java.util.Set; import java.util.TreeMap; -import java.util.UUID; import java.util.concurrent.TimeUnit; /** @@ -90,15 +89,16 @@ public class MdnsRecordRepository { @NonNull private final Looper mLooper; @NonNull - private String[] mDeviceHostname; + private final String[] mDeviceHostname; - public MdnsRecordRepository(@NonNull Looper looper) { - this(looper, new Dependencies()); + public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname) { + this(looper, new Dependencies(), deviceHostname); } @VisibleForTesting - public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps) { - mDeviceHostname = deps.getHostname(); + public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps, + @NonNull String[] deviceHostname) { + mDeviceHostname = deviceHostname; mLooper = looper; } @@ -107,25 +107,6 @@ public class MdnsRecordRepository { */ @VisibleForTesting public static class Dependencies { - /** - * Get a unique hostname to be used by the device. - */ - @NonNull - public String[] getHostname() { - // Generate a very-probably-unique hostname. This allows minimizing possible conflicts - // to the point that probing for it is no longer necessary (as per RFC6762 8.1 last - // paragraph), and does not leak more information than what could already be obtained by - // looking at the mDNS packets source address. - // This differs from historical behavior that just used "Android.local" for many - // devices, creating a lot of conflicts. - // Having a different hostname per interface is an acceptable option as per RFC6762 14. - // This hostname will change every time the interface is reconnected, so this does not - // allow tracking the device. - // TODO: consider deriving a hostname from other sources, such as the IPv6 addresses - // (reusing the same privacy-protecting mechanics). - return new String[] { - "Android_" + UUID.randomUUID().toString().replace("-", ""), LOCAL_TLD }; - } /** * @see NetworkInterface#getInetAddresses(). diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt index 1febe6dd00..375c15004e 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt @@ -42,6 +42,7 @@ import org.mockito.Mockito.atLeastOnce import org.mockito.Mockito.doReturn import org.mockito.Mockito.mock import org.mockito.Mockito.never +import org.mockito.Mockito.times import org.mockito.Mockito.verify private const val SERVICE_ID_1 = 1 @@ -51,6 +52,7 @@ private val TEST_ADDR = parseNumericAddress("2001:db8::123") private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */) private val TEST_NETWORK_1 = mock(Network::class.java) private val TEST_NETWORK_2 = mock(Network::class.java) +private val TEST_HOSTNAME = arrayOf("Android_test", "local") private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { port = 12345 @@ -81,10 +83,13 @@ class MdnsAdvertiserTest { @Before fun setUp() { thread.start() + doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname() doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1), - any(), any(), any(), any()) + any(), any(), any(), any(), eq(TEST_HOSTNAME) + ) doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2), - any(), any(), any(), any()) + any(), any(), any(), any(), eq(TEST_HOSTNAME) + ) doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt()) doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt()) } @@ -106,8 +111,14 @@ class MdnsAdvertiserTest { postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) } val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) - verify(mockDeps).makeAdvertiser(eq(mockSocket1), - eq(listOf(TEST_LINKADDR)), eq(thread.looper), any(), intAdvCbCaptor.capture()) + verify(mockDeps).makeAdvertiser( + eq(mockSocket1), + eq(listOf(TEST_LINKADDR)), + eq(thread.looper), + any(), + intAdvCbCaptor.capture(), + eq(TEST_HOSTNAME) + ) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded( @@ -134,9 +145,11 @@ class MdnsAdvertiserTest { val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), - eq(thread.looper), any(), intAdvCbCaptor1.capture()) + eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME) + ) verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)), - eq(thread.looper), any(), intAdvCbCaptor2.capture()) + eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME) + ) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded( @@ -192,7 +205,8 @@ class MdnsAdvertiserTest { val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), - eq(thread.looper), any(), intAdvCbCaptor.capture()) + eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME) + ) verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) }) verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2), @@ -216,6 +230,15 @@ class MdnsAdvertiserTest { verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow() } + @Test + fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() { + val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps) + verify(mockDeps, times(1)).generateHostname() + postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } + postSync { advertiser.removeService(SERVICE_ID_1) } + verify(mockDeps, times(2)).generateHostname() + } + private fun postSync(r: () -> Unit) { handler.post(r) handler.waitForIdle(TIMEOUT_MS) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt index 4a806b1ec8..2d8d8f306e 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt @@ -55,6 +55,7 @@ private const val TIMEOUT_MS = 10_000L private val TEST_ADDRS = listOf(LinkAddress(parseNumericAddress("2001:db8::123"), 64)) private val TEST_BUFFER = ByteArray(1300) +private val TEST_HOSTNAME = arrayOf("Android_test", "local") private const val TEST_SERVICE_ID_1 = 42 private val TEST_SERVICE_1 = NsdServiceInfo().apply { @@ -88,12 +89,23 @@ class MdnsInterfaceAdvertiserTest { private val packetHandler get() = packetHandlerCaptor.value private val advertiser by lazy { - MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps) + MdnsInterfaceAdvertiser( + LOG_TAG, + socket, + TEST_ADDRS, + thread.looper, + TEST_BUFFER, + cb, + deps, + TEST_HOSTNAME + ) } @Before fun setUp() { - doReturn(repository).`when`(deps).makeRecordRepository(any()) + doReturn(repository).`when`(deps).makeRecordRepository(any(), + eq(TEST_HOSTNAME) + ) doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any()) doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any()) doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any()) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index ecc11ec6af..5665091a98 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -67,7 +67,6 @@ private val TEST_SERVICE_2 = NsdServiceInfo().apply { class MdnsRecordRepositoryTest { private val thread = HandlerThread(MdnsRecordRepositoryTest::class.simpleName) private val deps = object : Dependencies() { - override fun getHostname() = TEST_HOSTNAME override fun getInterfaceInetAddresses(iface: NetworkInterface) = Collections.enumeration(TEST_ADDRESSES.map { it.address }) } @@ -84,7 +83,7 @@ class MdnsRecordRepositoryTest { @Test fun testAddServiceAndProbe() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) assertEquals(0, repository.servicesCount) assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)) assertEquals(1, repository.servicesCount) @@ -117,7 +116,7 @@ class MdnsRecordRepositoryTest { @Test fun testAddAndConflicts() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) assertFailsWith(NameConflictException::class) { repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1) @@ -126,7 +125,7 @@ class MdnsRecordRepositoryTest { @Test fun testInvalidReuseOfServiceId() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) assertFailsWith(IllegalArgumentException::class) { repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2) @@ -135,7 +134,7 @@ class MdnsRecordRepositoryTest { @Test fun testHasActiveService() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) @@ -152,7 +151,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitAnnouncements() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1) @@ -181,7 +180,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitingServiceReAdded() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1) repository.exitService(TEST_SERVICE_ID_1) @@ -195,7 +194,7 @@ class MdnsRecordRepositoryTest { @Test fun testOnProbingSucceeded() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1) val packet = announcementInfo.getPacket(0) @@ -319,7 +318,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetReply() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), 0L /* receiptTimeMillis */, @@ -404,7 +403,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServices() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2) @@ -432,7 +431,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServices_IdenticalService() { - val repository = MdnsRecordRepository(thread.looper, deps) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)