diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index 9a2cc5f6b6..4af4c6abb8 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -727,7 +727,7 @@ public class NsdService extends INsdManager.Stub { // service type would generate service instance names like // Name._subtype._sub._type._tcp, which is incorrect // (it should be Name._type._tcp). - mAdvertiser.addService(id, serviceInfo); + mAdvertiser.addService(id, serviceInfo, typeSubtype.second); storeAdvertiserRequestMap(clientId, id, clientInfo); } else { maybeStartDaemon(); diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java index 655c36442b..cc08ea1acb 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java @@ -270,7 +270,8 @@ public class MdnsAdvertiser { mPendingRegistrations.put(id, registration); for (int i = 0; i < mAdvertisers.size(); i++) { try { - mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo()); + mAdvertisers.valueAt(i).addService( + id, registration.getServiceInfo(), registration.getSubtype()); } catch (NameConflictException e) { Log.wtf(TAG, "Name conflict adding services that should have unique names", e); } @@ -298,9 +299,10 @@ public class MdnsAdvertiser { } mAdvertisers.put(socket, advertiser); for (int i = 0; i < mPendingRegistrations.size(); i++) { + final Registration registration = mPendingRegistrations.valueAt(i); try { advertiser.addService(mPendingRegistrations.keyAt(i), - mPendingRegistrations.valueAt(i).getServiceInfo()); + registration.getServiceInfo(), registration.getSubtype()); } catch (NameConflictException e) { Log.wtf(TAG, "Name conflict adding services that should have unique names", e); } @@ -329,10 +331,13 @@ public class MdnsAdvertiser { private int mConflictCount; @NonNull private NsdServiceInfo mServiceInfo; + @Nullable + private final String mSubtype; - private Registration(@NonNull NsdServiceInfo serviceInfo) { + private Registration(@NonNull NsdServiceInfo serviceInfo, @Nullable String subtype) { this.mOriginalName = serviceInfo.getServiceName(); this.mServiceInfo = serviceInfo; + this.mSubtype = subtype; } /** @@ -387,6 +392,11 @@ public class MdnsAdvertiser { public NsdServiceInfo getServiceInfo() { return mServiceInfo; } + + @Nullable + public String getSubtype() { + return mSubtype; + } } /** @@ -443,8 +453,9 @@ public class MdnsAdvertiser { * Add a service to advertise. * @param id A unique ID for the service. * @param service The service info to advertise. + * @param subtype An optional subtype to advertise the service with. */ - public void addService(int id, NsdServiceInfo service) { + public void addService(int id, NsdServiceInfo service, @Nullable String subtype) { checkThread(); if (mRegistrations.get(id) != null) { Log.e(TAG, "Adding duplicate registration for " + service); @@ -453,10 +464,10 @@ public class MdnsAdvertiser { return; } - mSharedLog.i("Adding service " + service + " with ID " + id); + mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype); final Network network = service.getNetwork(); - final Registration registration = new Registration(service); + final Registration registration = new Registration(service, subtype); final BiPredicate checkConflictFilter; if (network == null) { // If registering on all networks, no advertiser must have conflicts diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java index 4e09515a81..724a7045d1 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java @@ -212,8 +212,9 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand * * @throws NameConflictException There is already a service being advertised with that name. */ - public void addService(int id, NsdServiceInfo service) throws NameConflictException { - final int replacedExitingService = mRecordRepository.addService(id, service); + public void addService(int id, NsdServiceInfo service, @Nullable String subtype) + throws NameConflictException { + final int replacedExitingService = mRecordRepository.addService(id, service, subtype); // Cancel announcements for the existing service. This only happens for exiting services // (so cancelling exiting announcements), as per RecordRepository.addService. if (replacedExitingService >= 0) { diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java index 13291721a7..f756459623 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java @@ -69,6 +69,8 @@ public class MdnsRecordRepository { // Top-level domain for link-local queries, as per RFC6762 3. private static final String LOCAL_TLD = "local"; + // Subtype separator as per RFC6763 7.1 (_printer._sub._http._tcp.local) + private static final String SUBTYPE_SEPARATOR = "_sub"; // Service type for service enumeration (RFC6763 9.) private static final String[] DNS_SD_SERVICE_TYPE = @@ -156,13 +158,15 @@ public class MdnsRecordRepository { @NonNull public final List> allRecords; @NonNull - public final RecordInfo ptrRecord; + public final List> ptrRecords; @NonNull public final RecordInfo srvRecord; @NonNull public final RecordInfo txtRecord; @NonNull public final NsdServiceInfo serviceInfo; + @Nullable + public final String subtype; /** * Whether the service is sending exit announcements and will be destroyed soon. @@ -175,14 +179,16 @@ public class MdnsRecordRepository { * @param deviceHostname Hostname of the device (for the interface used) * @param serviceInfo Service to advertise */ - ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo) { + ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo, + @Nullable String subtype) { this.serviceInfo = serviceInfo; + this.subtype = subtype; final String[] serviceType = splitServiceType(serviceInfo); final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType); // Service PTR record - ptrRecord = new RecordInfo<>( + final RecordInfo ptrRecord = new RecordInfo<>( serviceInfo, new MdnsPointerRecord( serviceType, @@ -192,6 +198,26 @@ public class MdnsRecordRepository { serviceName), true /* sharedName */, true /* probing */); + if (subtype == null) { + this.ptrRecords = Collections.singletonList(ptrRecord); + } else { + final String[] subtypeName = new String[serviceType.length + 2]; + System.arraycopy(serviceType, 0, subtypeName, 2, serviceType.length); + subtypeName[0] = subtype; + subtypeName[1] = SUBTYPE_SEPARATOR; + final RecordInfo subtypeRecord = new RecordInfo<>( + serviceInfo, + new MdnsPointerRecord( + subtypeName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + NON_NAME_RECORDS_TTL_MILLIS, + serviceName), + true /* sharedName */, true /* probing */); + + this.ptrRecords = List.of(ptrRecord, subtypeRecord); + } + srvRecord = new RecordInfo<>( serviceInfo, new MdnsServiceRecord(serviceName, @@ -211,8 +237,8 @@ public class MdnsRecordRepository { attrsToTextEntries(serviceInfo.getAttributes())), false /* sharedName */, true /* probing */); - final ArrayList> allRecords = new ArrayList<>(4); - allRecords.add(ptrRecord); + final ArrayList> allRecords = new ArrayList<>(5); + allRecords.addAll(ptrRecords); allRecords.add(srvRecord); allRecords.add(txtRecord); // Service type enumeration record (RFC6763 9.) @@ -275,7 +301,8 @@ public class MdnsRecordRepository { * ID of the replaced service. * @throws NameConflictException There is already a (non-exiting) service using the name. */ - public int addService(int serviceId, NsdServiceInfo serviceInfo) throws NameConflictException { + public int addService(int serviceId, NsdServiceInfo serviceInfo, @Nullable String subtype) + throws NameConflictException { if (mServices.contains(serviceId)) { throw new IllegalArgumentException( "Service ID must not be reused across registrations: " + serviceId); @@ -288,7 +315,7 @@ public class MdnsRecordRepository { } final ServiceRegistration registration = new ServiceRegistration( - mDeviceHostname, serviceInfo); + mDeviceHostname, serviceInfo, subtype); mServices.put(serviceId, registration); // Remove existing exiting service @@ -344,24 +371,25 @@ public class MdnsRecordRepository { if (registration == null) return null; if (registration.exiting) return null; - // Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send + // Send exit (TTL 0) for the PTR records, if at least one was sent (in particular don't send // if still probing) - if (registration.ptrRecord.lastSentTimeMs == 0L) { + if (CollectionUtils.all(registration.ptrRecords, r -> r.lastSentTimeMs == 0L)) { return null; } registration.exiting = true; - final MdnsPointerRecord expiredRecord = new MdnsPointerRecord( - registration.ptrRecord.record.getName(), - 0L /* receiptTimeMillis */, - true /* cacheFlush */, - 0L /* ttlMillis */, - registration.ptrRecord.record.getPointer()); + final List expiredRecords = CollectionUtils.map(registration.ptrRecords, + r -> new MdnsPointerRecord( + r.record.getName(), + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 0L /* ttlMillis */, + r.record.getPointer())); // Exit should be skipped if the record is still advertised by another service, but that // would be a conflict (2 service registrations with the same service name), so it would // not have been allowed by the repository. - return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord)); + return new MdnsAnnouncer.ExitAnnouncementInfo(id, expiredRecords); } public void removeService(int id) { @@ -442,7 +470,7 @@ public class MdnsRecordRepository { for (int i = 0; i < mServices.size(); i++) { final ServiceRegistration registration = mServices.valueAt(i); if (registration.exiting) continue; - addReplyFromService(question, registration.allRecords, registration.ptrRecord, + addReplyFromService(question, registration.allRecords, registration.ptrRecords, registration.srvRecord, registration.txtRecord, replyUnicast, now, answerInfo, additionalAnswerRecords); } @@ -499,7 +527,7 @@ public class MdnsRecordRepository { */ private void addReplyFromService(@NonNull MdnsRecord question, @NonNull List> serviceRecords, - @Nullable RecordInfo servicePtrRecord, + @Nullable List> servicePtrRecords, @Nullable RecordInfo serviceSrvRecord, @Nullable RecordInfo serviceTxtRecord, boolean replyUnicast, long now, @NonNull List> answerInfo, @@ -531,7 +559,8 @@ public class MdnsRecordRepository { } hasKnownAnswer = true; - hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord); + hasDnsSdPtrRecordAnswer |= (servicePtrRecords != null + && CollectionUtils.any(servicePtrRecords, r -> info == r)); hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord); // TODO: responses to probe queries should bypass this check and only ensure the @@ -791,10 +820,11 @@ public class MdnsRecordRepository { */ @Nullable public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) { - if (!mServices.contains(serviceId)) return null; + final ServiceRegistration existing = mServices.get(serviceId); + if (existing == null) return null; final ServiceRegistration newService = new ServiceRegistration( - mDeviceHostname, newInfo); + mDeviceHostname, newInfo, existing.subtype); mServices.put(serviceId, newService); return makeProbingInfo(serviceId, newService.srvRecord.record); } diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java index 322b4d293a..b3e8cc844b 100644 --- a/tests/unit/java/com/android/server/NsdServiceTest.java +++ b/tests/unit/java/com/android/server/NsdServiceTest.java @@ -985,10 +985,9 @@ public class NsdServiceTest { final RegistrationListener regListener = mock(RegistrationListener.class); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); waitForIdle(); - // TODO: also pass the subtype to MdnsAdvertiser verify(mAdvertiser).addService(anyInt(), argThat(s -> "Instance".equals(s.getServiceName()) - && SERVICE_TYPE.equals(s.getServiceType()))); + && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype")); final DiscoveryListener discListener = mock(DiscoveryListener.class); client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener); @@ -1090,7 +1089,7 @@ public class NsdServiceTest { final ArgumentCaptor serviceIdCaptor = ArgumentCaptor.forClass(Integer.class); verify(mAdvertiser).addService(serviceIdCaptor.capture(), - argThat(info -> matches(info, regInfo))); + argThat(info -> matches(info, regInfo)), eq(null) /* subtype */); client.unregisterService(regListenerWithoutFeature); waitForIdle(); @@ -1147,8 +1146,10 @@ public class NsdServiceTest { waitForIdle(); // The advertiser is enabled for _type2 but not _type1 - verify(mAdvertiser, never()).addService(anyInt(), argThat(info -> matches(info, service1))); - verify(mAdvertiser).addService(anyInt(), argThat(info -> matches(info, service2))); + verify(mAdvertiser, never()).addService( + anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */); + verify(mAdvertiser).addService( + anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */); } @Test @@ -1173,7 +1174,7 @@ public class NsdServiceTest { verify(mSocketProvider).startMonitoringSockets(); final ArgumentCaptor idCaptor = ArgumentCaptor.forClass(Integer.class); verify(mAdvertiser).addService(idCaptor.capture(), argThat(info -> - matches(info, regInfo))); + matches(info, regInfo)), eq(null) /* subtype */); // Verify onServiceRegistered callback final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); @@ -1209,7 +1210,7 @@ public class NsdServiceTest { client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); waitForIdle(); - verify(mAdvertiser, never()).addService(anyInt(), any()); + verify(mAdvertiser, never()).addService(anyInt(), any(), any()); verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed( argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR)); @@ -1237,7 +1238,8 @@ public class NsdServiceTest { final ArgumentCaptor idCaptor = ArgumentCaptor.forClass(Integer.class); // Service name is truncated to 63 characters verify(mAdvertiser).addService(idCaptor.capture(), - argThat(info -> info.getServiceName().equals("a".repeat(63)))); + argThat(info -> info.getServiceName().equals("a".repeat(63))), + eq(null) /* subtype */); // Verify onServiceRegistered callback final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); @@ -1319,7 +1321,7 @@ public class NsdServiceTest { client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); waitForIdle(); verify(mSocketProvider).startMonitoringSockets(); - verify(mAdvertiser).addService(anyInt(), any()); + verify(mAdvertiser).addService(anyInt(), any(), any()); // Verify the discovery uses MdnsDiscoveryManager final DiscoveryListener discListener = mock(DiscoveryListener.class); diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt index 3bb08a668e..b539fe0a83 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt @@ -57,6 +57,7 @@ private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */) private val TEST_NETWORK_1 = mock(Network::class.java) private val TEST_NETWORK_2 = mock(Network::class.java) private val TEST_HOSTNAME = arrayOf("Android_test", "local") +private const val TEST_SUBTYPE = "_subtype" private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { port = 12345 @@ -130,7 +131,7 @@ class MdnsAdvertiserTest { @Test fun testAddService_OneNetwork() { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) - postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } + postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture()) @@ -161,7 +162,7 @@ class MdnsAdvertiserTest { @Test fun testAddService_AllNetworks() { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) - postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE) } + postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) } val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network), @@ -179,6 +180,10 @@ class MdnsAdvertiserTest { verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)), eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any() ) + verify(mockInterfaceAdvertiser1).addService( + anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE)) + verify(mockInterfaceAdvertiser2).addService( + anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE)) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded( @@ -207,20 +212,21 @@ class MdnsAdvertiserTest { @Test fun testAddService_Conflicts() { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) - postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } + postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture()) val oneNetSocketCb = oneNetSocketCbCaptor.value // Register a service with the same name on all networks (name conflict) - postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) } + postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) } val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) val allNetSocketCb = allNetSocketCbCaptor.value - postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1) } - postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE) } + postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) } + postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE, + null /* subtype */) } // Callbacks for matching network and all networks both get the socket postSync { @@ -248,13 +254,13 @@ class MdnsAdvertiserTest { eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any() ) verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1), - argThat { it.matches(SERVICE_1) }) + argThat { it.matches(SERVICE_1) }, eq(null)) verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2), - argThat { it.matches(expectedRenamed) }) + argThat { it.matches(expectedRenamed) }, eq(null)) verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_1), - argThat { it.matches(LONG_SERVICE_1) }) + argThat { it.matches(LONG_SERVICE_1) }, eq(null)) verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_2), - argThat { it.matches(expectedLongRenamed) }) + argThat { it.matches(expectedLongRenamed) }, eq(null)) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded( @@ -278,7 +284,7 @@ class MdnsAdvertiserTest { fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) verify(mockDeps, times(1)).generateHostname() - postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } + postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.removeService(SERVICE_ID_1) } verify(mockDeps, times(2)).generateHostname() } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt index ee190af14d..dd458b812f 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt @@ -117,7 +117,7 @@ class MdnsInterfaceAdvertiserTest { knownServices.add(inv.getArgument(0)) -1 - }.`when`(repository).addService(anyInt(), any()) + }.`when`(repository).addService(anyInt(), any(), any()) doAnswer { inv -> knownServices.remove(inv.getArgument(0)) null @@ -278,8 +278,8 @@ class MdnsInterfaceAdvertiserTest { doReturn(serviceId).`when`(testProbingInfo).serviceId doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId) - advertiser.addService(serviceId, serviceInfo) - verify(repository).addService(serviceId, serviceInfo) + advertiser.addService(serviceId, serviceInfo, null /* subtype */) + verify(repository).addService(serviceId, serviceInfo, null /* subtype */) verify(prober).startProbing(testProbingInfo) // Simulate probing success: continues to announcing diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index 44e0d08648..4a39b935f3 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -44,6 +44,7 @@ import org.junit.runner.RunWith private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_2 = 43 private const val TEST_PORT = 12345 +private const val TEST_SUBTYPE = "_subtype" private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local") private val TEST_ADDRESSES = listOf( LinkAddress(parseNumericAddress("192.0.2.111"), 24), @@ -86,7 +87,8 @@ class MdnsRecordRepositoryTest { fun testAddServiceAndProbe() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) assertEquals(0, repository.servicesCount) - assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)) + assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, + null /* subtype */)) assertEquals(1, repository.servicesCount) val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) @@ -118,18 +120,18 @@ class MdnsRecordRepositoryTest { @Test fun testAddAndConflicts() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) assertFailsWith(NameConflictException::class) { - repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1) + repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */) } } @Test fun testInvalidReuseOfServiceId() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) assertFailsWith(IllegalArgumentException::class) { - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2) + repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */) } } @@ -138,7 +140,7 @@ class MdnsRecordRepositoryTest { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1)) val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) @@ -179,6 +181,41 @@ class MdnsRecordRepositoryTest { assertEquals(0, repository.servicesCount) } + @Test + fun testExitAnnouncements_WithSubtype() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE) + repository.onAdvertisementSent(TEST_SERVICE_ID_1) + + val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1) + assertNotNull(exitAnnouncement) + assertEquals(1, repository.servicesCount) + val packet = exitAnnouncement.getPacket(0) + + assertEquals(0x8400 /* response, authoritative */, packet.flags) + assertEquals(0, packet.questions.size) + assertEquals(0, packet.authorityRecords.size) + assertEquals(0, packet.additionalRecords.size) + + assertContentEquals(listOf( + MdnsPointerRecord( + arrayOf("_testservice", "_tcp", "local"), + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 0L /* ttlMillis */, + arrayOf("MyTestService", "_testservice", "_tcp", "local")), + MdnsPointerRecord( + arrayOf("_subtype", "_sub", "_testservice", "_tcp", "local"), + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + 0L /* ttlMillis */, + arrayOf("MyTestService", "_testservice", "_tcp", "local")), + ), packet.answers) + + repository.removeService(TEST_SERVICE_ID_1) + assertEquals(0, repository.servicesCount) + } + @Test fun testExitingServiceReAdded() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) @@ -186,7 +223,8 @@ class MdnsRecordRepositoryTest { repository.onAdvertisementSent(TEST_SERVICE_ID_1) repository.exitService(TEST_SERVICE_ID_1) - assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)) + assertEquals(TEST_SERVICE_ID_1, + repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)) assertEquals(1, repository.servicesCount) repository.removeService(TEST_SERVICE_ID_2) @@ -196,7 +234,8 @@ class MdnsRecordRepositoryTest { @Test fun testOnProbingSucceeded() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) - val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, + TEST_SUBTYPE) repository.onAdvertisementSent(TEST_SERVICE_ID_1) val packet = announcementInfo.getPacket(0) @@ -205,6 +244,7 @@ class MdnsRecordRepositoryTest { assertEquals(0, packet.authorityRecords.size) val serviceType = arrayOf("_testservice", "_tcp", "local") + val serviceSubtype = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local") val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address) val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address) @@ -250,6 +290,13 @@ class MdnsRecordRepositoryTest { false /* cacheFlush */, 4500000L /* ttlMillis */, serviceName), + MdnsPointerRecord( + serviceSubtype, + 0L /* receiptTimeMillis */, + // Not a unique name owned by the announcer, so cacheFlush=false + false /* cacheFlush */, + 4500000L /* ttlMillis */, + serviceName), MdnsServiceRecord( serviceName, 0L /* receiptTimeMillis */, @@ -319,9 +366,21 @@ class MdnsRecordRepositoryTest { @Test fun testGetReply() { + doGetReplyTest(subtype = null) + } + + @Test + fun testGetReply_WithSubtype() { + doGetReplyTest(TEST_SUBTYPE) + } + + private fun doGetReplyTest(subtype: String?) { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) - repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) - val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), + repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype) + val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local") + else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local") + + val questions = listOf(MdnsPointerRecord(queriedName, 0L /* receiptTimeMillis */, false /* cacheFlush */, // TTL and data is empty for a question @@ -344,7 +403,7 @@ class MdnsRecordRepositoryTest { assertEquals(listOf( MdnsPointerRecord( - arrayOf("_testservice", "_tcp", "local"), + queriedName, 0L /* receiptTimeMillis */, false /* cacheFlush */, longTtl, @@ -405,8 +464,8 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServices() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) - repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2) + repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) + repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) val packet = MdnsPacket( 0 /* flags */, @@ -433,8 +492,8 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServices_IdenticalService() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) - repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2) + repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) + repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) val otherTtlMillis = 1234L val packet = MdnsPacket( @@ -460,10 +519,13 @@ class MdnsRecordRepositoryTest { } } -private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo): - AnnouncementInfo { +private fun MdnsRecordRepository.initWithService( + serviceId: Int, + serviceInfo: NsdServiceInfo, + subtype: String? = null +): AnnouncementInfo { updateAddresses(TEST_ADDRESSES) - addService(serviceId, serviceInfo) + addService(serviceId, serviceInfo, subtype) val probingInfo = setServiceProbing(serviceId) assertNotNull(probingInfo) return onProbingSucceeded(probingInfo)