Implement proper subtype advertising

Implement subtype advertising by advertising an additional PTR record if
subtype advertising was requested.

For an advertiser a subtype is just an additional PTR record, so this
just involves plumbing the subtype down to MdnsRecordRepository, which
includes the additional record in the service registration.

Bug: 266167702
Test: atest
Change-Id: I09e780af25149162f16bd75410ddc50f160a0dab
This commit is contained in:
Remi NGUYEN VAN
2023-05-11 20:42:26 +09:00
parent f2d064112c
commit ce44beb7aa
8 changed files with 183 additions and 71 deletions

View File

@@ -727,7 +727,7 @@ public class NsdService extends INsdManager.Stub {
// service type would generate service instance names like // service type would generate service instance names like
// Name._subtype._sub._type._tcp, which is incorrect // Name._subtype._sub._type._tcp, which is incorrect
// (it should be Name._type._tcp). // (it should be Name._type._tcp).
mAdvertiser.addService(id, serviceInfo); mAdvertiser.addService(id, serviceInfo, typeSubtype.second);
storeAdvertiserRequestMap(clientId, id, clientInfo); storeAdvertiserRequestMap(clientId, id, clientInfo);
} else { } else {
maybeStartDaemon(); maybeStartDaemon();

View File

@@ -270,7 +270,8 @@ public class MdnsAdvertiser {
mPendingRegistrations.put(id, registration); mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) { for (int i = 0; i < mAdvertisers.size(); i++) {
try { try {
mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo()); mAdvertisers.valueAt(i).addService(
id, registration.getServiceInfo(), registration.getSubtype());
} catch (NameConflictException e) { } catch (NameConflictException e) {
Log.wtf(TAG, "Name conflict adding services that should have unique names", e); Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
} }
@@ -298,9 +299,10 @@ public class MdnsAdvertiser {
} }
mAdvertisers.put(socket, advertiser); mAdvertisers.put(socket, advertiser);
for (int i = 0; i < mPendingRegistrations.size(); i++) { for (int i = 0; i < mPendingRegistrations.size(); i++) {
final Registration registration = mPendingRegistrations.valueAt(i);
try { try {
advertiser.addService(mPendingRegistrations.keyAt(i), advertiser.addService(mPendingRegistrations.keyAt(i),
mPendingRegistrations.valueAt(i).getServiceInfo()); registration.getServiceInfo(), registration.getSubtype());
} catch (NameConflictException e) { } catch (NameConflictException e) {
Log.wtf(TAG, "Name conflict adding services that should have unique names", e); Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
} }
@@ -329,10 +331,13 @@ public class MdnsAdvertiser {
private int mConflictCount; private int mConflictCount;
@NonNull @NonNull
private NsdServiceInfo mServiceInfo; private NsdServiceInfo mServiceInfo;
@Nullable
private final String mSubtype;
private Registration(@NonNull NsdServiceInfo serviceInfo) { private Registration(@NonNull NsdServiceInfo serviceInfo, @Nullable String subtype) {
this.mOriginalName = serviceInfo.getServiceName(); this.mOriginalName = serviceInfo.getServiceName();
this.mServiceInfo = serviceInfo; this.mServiceInfo = serviceInfo;
this.mSubtype = subtype;
} }
/** /**
@@ -387,6 +392,11 @@ public class MdnsAdvertiser {
public NsdServiceInfo getServiceInfo() { public NsdServiceInfo getServiceInfo() {
return mServiceInfo; return mServiceInfo;
} }
@Nullable
public String getSubtype() {
return mSubtype;
}
} }
/** /**
@@ -443,8 +453,9 @@ public class MdnsAdvertiser {
* Add a service to advertise. * Add a service to advertise.
* @param id A unique ID for the service. * @param id A unique ID for the service.
* @param service The service info to advertise. * @param service The service info to advertise.
* @param subtype An optional subtype to advertise the service with.
*/ */
public void addService(int id, NsdServiceInfo service) { public void addService(int id, NsdServiceInfo service, @Nullable String subtype) {
checkThread(); checkThread();
if (mRegistrations.get(id) != null) { if (mRegistrations.get(id) != null) {
Log.e(TAG, "Adding duplicate registration for " + service); Log.e(TAG, "Adding duplicate registration for " + service);
@@ -453,10 +464,10 @@ public class MdnsAdvertiser {
return; return;
} }
mSharedLog.i("Adding service " + service + " with ID " + id); mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype);
final Network network = service.getNetwork(); final Network network = service.getNetwork();
final Registration registration = new Registration(service); final Registration registration = new Registration(service, subtype);
final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter; final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
if (network == null) { if (network == null) {
// If registering on all networks, no advertiser must have conflicts // If registering on all networks, no advertiser must have conflicts

View File

@@ -212,8 +212,9 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
* *
* @throws NameConflictException There is already a service being advertised with that name. * @throws NameConflictException There is already a service being advertised with that name.
*/ */
public void addService(int id, NsdServiceInfo service) throws NameConflictException { public void addService(int id, NsdServiceInfo service, @Nullable String subtype)
final int replacedExitingService = mRecordRepository.addService(id, service); throws NameConflictException {
final int replacedExitingService = mRecordRepository.addService(id, service, subtype);
// Cancel announcements for the existing service. This only happens for exiting services // Cancel announcements for the existing service. This only happens for exiting services
// (so cancelling exiting announcements), as per RecordRepository.addService. // (so cancelling exiting announcements), as per RecordRepository.addService.
if (replacedExitingService >= 0) { if (replacedExitingService >= 0) {

View File

@@ -69,6 +69,8 @@ public class MdnsRecordRepository {
// Top-level domain for link-local queries, as per RFC6762 3. // Top-level domain for link-local queries, as per RFC6762 3.
private static final String LOCAL_TLD = "local"; private static final String LOCAL_TLD = "local";
// Subtype separator as per RFC6763 7.1 (_printer._sub._http._tcp.local)
private static final String SUBTYPE_SEPARATOR = "_sub";
// Service type for service enumeration (RFC6763 9.) // Service type for service enumeration (RFC6763 9.)
private static final String[] DNS_SD_SERVICE_TYPE = private static final String[] DNS_SD_SERVICE_TYPE =
@@ -156,13 +158,15 @@ public class MdnsRecordRepository {
@NonNull @NonNull
public final List<RecordInfo<?>> allRecords; public final List<RecordInfo<?>> allRecords;
@NonNull @NonNull
public final RecordInfo<MdnsPointerRecord> ptrRecord; public final List<RecordInfo<MdnsPointerRecord>> ptrRecords;
@NonNull @NonNull
public final RecordInfo<MdnsServiceRecord> srvRecord; public final RecordInfo<MdnsServiceRecord> srvRecord;
@NonNull @NonNull
public final RecordInfo<MdnsTextRecord> txtRecord; public final RecordInfo<MdnsTextRecord> txtRecord;
@NonNull @NonNull
public final NsdServiceInfo serviceInfo; public final NsdServiceInfo serviceInfo;
@Nullable
public final String subtype;
/** /**
* Whether the service is sending exit announcements and will be destroyed soon. * Whether the service is sending exit announcements and will be destroyed soon.
@@ -175,14 +179,16 @@ public class MdnsRecordRepository {
* @param deviceHostname Hostname of the device (for the interface used) * @param deviceHostname Hostname of the device (for the interface used)
* @param serviceInfo Service to advertise * @param serviceInfo Service to advertise
*/ */
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo) { ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
@Nullable String subtype) {
this.serviceInfo = serviceInfo; this.serviceInfo = serviceInfo;
this.subtype = subtype;
final String[] serviceType = splitServiceType(serviceInfo); final String[] serviceType = splitServiceType(serviceInfo);
final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType); final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType);
// Service PTR record // Service PTR record
ptrRecord = new RecordInfo<>( final RecordInfo<MdnsPointerRecord> ptrRecord = new RecordInfo<>(
serviceInfo, serviceInfo,
new MdnsPointerRecord( new MdnsPointerRecord(
serviceType, serviceType,
@@ -192,6 +198,26 @@ public class MdnsRecordRepository {
serviceName), serviceName),
true /* sharedName */, true /* probing */); true /* sharedName */, true /* probing */);
if (subtype == null) {
this.ptrRecords = Collections.singletonList(ptrRecord);
} else {
final String[] subtypeName = new String[serviceType.length + 2];
System.arraycopy(serviceType, 0, subtypeName, 2, serviceType.length);
subtypeName[0] = subtype;
subtypeName[1] = SUBTYPE_SEPARATOR;
final RecordInfo<MdnsPointerRecord> subtypeRecord = new RecordInfo<>(
serviceInfo,
new MdnsPointerRecord(
subtypeName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
NON_NAME_RECORDS_TTL_MILLIS,
serviceName),
true /* sharedName */, true /* probing */);
this.ptrRecords = List.of(ptrRecord, subtypeRecord);
}
srvRecord = new RecordInfo<>( srvRecord = new RecordInfo<>(
serviceInfo, serviceInfo,
new MdnsServiceRecord(serviceName, new MdnsServiceRecord(serviceName,
@@ -211,8 +237,8 @@ public class MdnsRecordRepository {
attrsToTextEntries(serviceInfo.getAttributes())), attrsToTextEntries(serviceInfo.getAttributes())),
false /* sharedName */, true /* probing */); false /* sharedName */, true /* probing */);
final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(4); final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(5);
allRecords.add(ptrRecord); allRecords.addAll(ptrRecords);
allRecords.add(srvRecord); allRecords.add(srvRecord);
allRecords.add(txtRecord); allRecords.add(txtRecord);
// Service type enumeration record (RFC6763 9.) // Service type enumeration record (RFC6763 9.)
@@ -275,7 +301,8 @@ public class MdnsRecordRepository {
* ID of the replaced service. * ID of the replaced service.
* @throws NameConflictException There is already a (non-exiting) service using the name. * @throws NameConflictException There is already a (non-exiting) service using the name.
*/ */
public int addService(int serviceId, NsdServiceInfo serviceInfo) throws NameConflictException { public int addService(int serviceId, NsdServiceInfo serviceInfo, @Nullable String subtype)
throws NameConflictException {
if (mServices.contains(serviceId)) { if (mServices.contains(serviceId)) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Service ID must not be reused across registrations: " + serviceId); "Service ID must not be reused across registrations: " + serviceId);
@@ -288,7 +315,7 @@ public class MdnsRecordRepository {
} }
final ServiceRegistration registration = new ServiceRegistration( final ServiceRegistration registration = new ServiceRegistration(
mDeviceHostname, serviceInfo); mDeviceHostname, serviceInfo, subtype);
mServices.put(serviceId, registration); mServices.put(serviceId, registration);
// Remove existing exiting service // Remove existing exiting service
@@ -344,24 +371,25 @@ public class MdnsRecordRepository {
if (registration == null) return null; if (registration == null) return null;
if (registration.exiting) return null; if (registration.exiting) return null;
// Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send // Send exit (TTL 0) for the PTR records, if at least one was sent (in particular don't send
// if still probing) // if still probing)
if (registration.ptrRecord.lastSentTimeMs == 0L) { if (CollectionUtils.all(registration.ptrRecords, r -> r.lastSentTimeMs == 0L)) {
return null; return null;
} }
registration.exiting = true; registration.exiting = true;
final MdnsPointerRecord expiredRecord = new MdnsPointerRecord( final List<MdnsRecord> expiredRecords = CollectionUtils.map(registration.ptrRecords,
registration.ptrRecord.record.getName(), r -> new MdnsPointerRecord(
r.record.getName(),
0L /* receiptTimeMillis */, 0L /* receiptTimeMillis */,
true /* cacheFlush */, true /* cacheFlush */,
0L /* ttlMillis */, 0L /* ttlMillis */,
registration.ptrRecord.record.getPointer()); r.record.getPointer()));
// Exit should be skipped if the record is still advertised by another service, but that // Exit should be skipped if the record is still advertised by another service, but that
// would be a conflict (2 service registrations with the same service name), so it would // would be a conflict (2 service registrations with the same service name), so it would
// not have been allowed by the repository. // not have been allowed by the repository.
return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord)); return new MdnsAnnouncer.ExitAnnouncementInfo(id, expiredRecords);
} }
public void removeService(int id) { public void removeService(int id) {
@@ -442,7 +470,7 @@ public class MdnsRecordRepository {
for (int i = 0; i < mServices.size(); i++) { for (int i = 0; i < mServices.size(); i++) {
final ServiceRegistration registration = mServices.valueAt(i); final ServiceRegistration registration = mServices.valueAt(i);
if (registration.exiting) continue; if (registration.exiting) continue;
addReplyFromService(question, registration.allRecords, registration.ptrRecord, addReplyFromService(question, registration.allRecords, registration.ptrRecords,
registration.srvRecord, registration.txtRecord, replyUnicast, now, registration.srvRecord, registration.txtRecord, replyUnicast, now,
answerInfo, additionalAnswerRecords); answerInfo, additionalAnswerRecords);
} }
@@ -499,7 +527,7 @@ public class MdnsRecordRepository {
*/ */
private void addReplyFromService(@NonNull MdnsRecord question, private void addReplyFromService(@NonNull MdnsRecord question,
@NonNull List<RecordInfo<?>> serviceRecords, @NonNull List<RecordInfo<?>> serviceRecords,
@Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord, @Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords,
@Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord, @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
@Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord, @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo, boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
@@ -531,7 +559,8 @@ public class MdnsRecordRepository {
} }
hasKnownAnswer = true; hasKnownAnswer = true;
hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord); hasDnsSdPtrRecordAnswer |= (servicePtrRecords != null
&& CollectionUtils.any(servicePtrRecords, r -> info == r));
hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord); hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
// TODO: responses to probe queries should bypass this check and only ensure the // TODO: responses to probe queries should bypass this check and only ensure the
@@ -791,10 +820,11 @@ public class MdnsRecordRepository {
*/ */
@Nullable @Nullable
public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) { public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
if (!mServices.contains(serviceId)) return null; final ServiceRegistration existing = mServices.get(serviceId);
if (existing == null) return null;
final ServiceRegistration newService = new ServiceRegistration( final ServiceRegistration newService = new ServiceRegistration(
mDeviceHostname, newInfo); mDeviceHostname, newInfo, existing.subtype);
mServices.put(serviceId, newService); mServices.put(serviceId, newService);
return makeProbingInfo(serviceId, newService.srvRecord.record); return makeProbingInfo(serviceId, newService.srvRecord.record);
} }

View File

@@ -985,10 +985,9 @@ public class NsdServiceTest {
final RegistrationListener regListener = mock(RegistrationListener.class); final RegistrationListener regListener = mock(RegistrationListener.class);
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
// TODO: also pass the subtype to MdnsAdvertiser
verify(mAdvertiser).addService(anyInt(), argThat(s -> verify(mAdvertiser).addService(anyInt(), argThat(s ->
"Instance".equals(s.getServiceName()) "Instance".equals(s.getServiceName())
&& SERVICE_TYPE.equals(s.getServiceType()))); && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"));
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener); client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
@@ -1090,7 +1089,7 @@ public class NsdServiceTest {
final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(serviceIdCaptor.capture(), verify(mAdvertiser).addService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo))); argThat(info -> matches(info, regInfo)), eq(null) /* subtype */);
client.unregisterService(regListenerWithoutFeature); client.unregisterService(regListenerWithoutFeature);
waitForIdle(); waitForIdle();
@@ -1147,8 +1146,10 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
// The advertiser is enabled for _type2 but not _type1 // The advertiser is enabled for _type2 but not _type1
verify(mAdvertiser, never()).addService(anyInt(), argThat(info -> matches(info, service1))); verify(mAdvertiser, never()).addService(
verify(mAdvertiser).addService(anyInt(), argThat(info -> matches(info, service2))); anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */);
verify(mAdvertiser).addService(
anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */);
} }
@Test @Test
@@ -1173,7 +1174,7 @@ public class NsdServiceTest {
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(idCaptor.capture(), argThat(info -> verify(mAdvertiser).addService(idCaptor.capture(), argThat(info ->
matches(info, regInfo))); matches(info, regInfo)), eq(null) /* subtype */);
// Verify onServiceRegistered callback // Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1209,7 +1210,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mAdvertiser, never()).addService(anyInt(), any()); verify(mAdvertiser, never()).addService(anyInt(), any(), any());
verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed( verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR)); argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1237,7 +1238,8 @@ public class NsdServiceTest {
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
// Service name is truncated to 63 characters // Service name is truncated to 63 characters
verify(mAdvertiser).addService(idCaptor.capture(), verify(mAdvertiser).addService(idCaptor.capture(),
argThat(info -> info.getServiceName().equals("a".repeat(63)))); argThat(info -> info.getServiceName().equals("a".repeat(63))),
eq(null) /* subtype */);
// Verify onServiceRegistered callback // Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1319,7 +1321,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any()); verify(mAdvertiser).addService(anyInt(), any(), any());
// Verify the discovery uses MdnsDiscoveryManager // Verify the discovery uses MdnsDiscoveryManager
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);

View File

@@ -57,6 +57,7 @@ private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
private val TEST_NETWORK_1 = mock(Network::class.java) private val TEST_NETWORK_1 = mock(Network::class.java)
private val TEST_NETWORK_2 = mock(Network::class.java) private val TEST_NETWORK_2 = mock(Network::class.java)
private val TEST_HOSTNAME = arrayOf("Android_test", "local") private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SUBTYPE = "_subtype"
private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345 port = 12345
@@ -130,7 +131,7 @@ class MdnsAdvertiserTest {
@Test @Test
fun testAddService_OneNetwork() { fun testAddService_OneNetwork() {
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -161,7 +162,7 @@ class MdnsAdvertiserTest {
@Test @Test
fun testAddService_AllNetworks() { fun testAddService_AllNetworks() {
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE) } postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network), verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -179,6 +180,10 @@ class MdnsAdvertiserTest {
verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)), verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any() eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any()
) )
verify(mockInterfaceAdvertiser1).addService(
anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
verify(mockInterfaceAdvertiser2).addService(
anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded( postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded(
@@ -207,20 +212,21 @@ class MdnsAdvertiserTest {
@Test @Test
fun testAddService_Conflicts() { fun testAddService_Conflicts() {
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
val oneNetSocketCb = oneNetSocketCbCaptor.value val oneNetSocketCb = oneNetSocketCbCaptor.value
// Register a service with the same name on all networks (name conflict) // Register a service with the same name on all networks (name conflict)
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) } postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) }
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value val allNetSocketCb = allNetSocketCbCaptor.value
postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1) } postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) }
postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE) } postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
null /* subtype */) }
// Callbacks for matching network and all networks both get the socket // Callbacks for matching network and all networks both get the socket
postSync { postSync {
@@ -248,13 +254,13 @@ class MdnsAdvertiserTest {
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any() eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any()
) )
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1), verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(SERVICE_1) }) argThat { it.matches(SERVICE_1) }, eq(null))
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2), verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) }) argThat { it.matches(expectedRenamed) }, eq(null))
verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_1), verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_1),
argThat { it.matches(LONG_SERVICE_1) }) argThat { it.matches(LONG_SERVICE_1) }, eq(null))
verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_2), verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_2),
argThat { it.matches(expectedLongRenamed) }) argThat { it.matches(expectedLongRenamed) }, eq(null))
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded( postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
@@ -278,7 +284,7 @@ class MdnsAdvertiserTest {
fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() { fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog) val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
verify(mockDeps, times(1)).generateHostname() verify(mockDeps, times(1)).generateHostname()
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
postSync { advertiser.removeService(SERVICE_ID_1) } postSync { advertiser.removeService(SERVICE_ID_1) }
verify(mockDeps, times(2)).generateHostname() verify(mockDeps, times(2)).generateHostname()
} }

View File

@@ -117,7 +117,7 @@ class MdnsInterfaceAdvertiserTest {
knownServices.add(inv.getArgument(0)) knownServices.add(inv.getArgument(0))
-1 -1
}.`when`(repository).addService(anyInt(), any()) }.`when`(repository).addService(anyInt(), any(), any())
doAnswer { inv -> doAnswer { inv ->
knownServices.remove(inv.getArgument(0)) knownServices.remove(inv.getArgument(0))
null null
@@ -278,8 +278,8 @@ class MdnsInterfaceAdvertiserTest {
doReturn(serviceId).`when`(testProbingInfo).serviceId doReturn(serviceId).`when`(testProbingInfo).serviceId
doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId) doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
advertiser.addService(serviceId, serviceInfo) advertiser.addService(serviceId, serviceInfo, null /* subtype */)
verify(repository).addService(serviceId, serviceInfo) verify(repository).addService(serviceId, serviceInfo, null /* subtype */)
verify(prober).startProbing(testProbingInfo) verify(prober).startProbing(testProbingInfo)
// Simulate probing success: continues to announcing // Simulate probing success: continues to announcing

View File

@@ -44,6 +44,7 @@ import org.junit.runner.RunWith
private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_2 = 43 private const val TEST_SERVICE_ID_2 = 43
private const val TEST_PORT = 12345 private const val TEST_PORT = 12345
private const val TEST_SUBTYPE = "_subtype"
private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local") private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
private val TEST_ADDRESSES = listOf( private val TEST_ADDRESSES = listOf(
LinkAddress(parseNumericAddress("192.0.2.111"), 24), LinkAddress(parseNumericAddress("192.0.2.111"), 24),
@@ -86,7 +87,8 @@ class MdnsRecordRepositoryTest {
fun testAddServiceAndProbe() { fun testAddServiceAndProbe() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
assertEquals(0, repository.servicesCount) assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)) assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */))
assertEquals(1, repository.servicesCount) assertEquals(1, repository.servicesCount)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
@@ -118,18 +120,18 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testAddAndConflicts() { fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(NameConflictException::class) { assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
} }
} }
@Test @Test
fun testInvalidReuseOfServiceId() { fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(IllegalArgumentException::class) { assertFailsWith(IllegalArgumentException::class) {
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */)
} }
} }
@@ -138,7 +140,7 @@ class MdnsRecordRepositoryTest {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1)) assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
@@ -179,6 +181,41 @@ class MdnsRecordRepositoryTest {
assertEquals(0, repository.servicesCount) assertEquals(0, repository.servicesCount)
} }
@Test
fun testExitAnnouncements_WithSubtype() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
assertNotNull(exitAnnouncement)
assertEquals(1, repository.servicesCount)
val packet = exitAnnouncement.getPacket(0)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size)
assertEquals(0, packet.additionalRecords.size)
assertContentEquals(listOf(
MdnsPointerRecord(
arrayOf("_testservice", "_tcp", "local"),
0L /* receiptTimeMillis */,
true /* cacheFlush */,
0L /* ttlMillis */,
arrayOf("MyTestService", "_testservice", "_tcp", "local")),
MdnsPointerRecord(
arrayOf("_subtype", "_sub", "_testservice", "_tcp", "local"),
0L /* receiptTimeMillis */,
true /* cacheFlush */,
0L /* ttlMillis */,
arrayOf("MyTestService", "_testservice", "_tcp", "local")),
), packet.answers)
repository.removeService(TEST_SERVICE_ID_1)
assertEquals(0, repository.servicesCount)
}
@Test @Test
fun testExitingServiceReAdded() { fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
@@ -186,7 +223,8 @@ class MdnsRecordRepositoryTest {
repository.onAdvertisementSent(TEST_SERVICE_ID_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1) repository.exitService(TEST_SERVICE_ID_1)
assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)) assertEquals(TEST_SERVICE_ID_1,
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */))
assertEquals(1, repository.servicesCount) assertEquals(1, repository.servicesCount)
repository.removeService(TEST_SERVICE_ID_2) repository.removeService(TEST_SERVICE_ID_2)
@@ -196,7 +234,8 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testOnProbingSucceeded() { fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val packet = announcementInfo.getPacket(0) val packet = announcementInfo.getPacket(0)
@@ -205,6 +244,7 @@ class MdnsRecordRepositoryTest {
assertEquals(0, packet.authorityRecords.size) assertEquals(0, packet.authorityRecords.size)
val serviceType = arrayOf("_testservice", "_tcp", "local") val serviceType = arrayOf("_testservice", "_tcp", "local")
val serviceSubtype = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address) val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address) val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
@@ -250,6 +290,13 @@ class MdnsRecordRepositoryTest {
false /* cacheFlush */, false /* cacheFlush */,
4500000L /* ttlMillis */, 4500000L /* ttlMillis */,
serviceName), serviceName),
MdnsPointerRecord(
serviceSubtype,
0L /* receiptTimeMillis */,
// Not a unique name owned by the announcer, so cacheFlush=false
false /* cacheFlush */,
4500000L /* ttlMillis */,
serviceName),
MdnsServiceRecord( MdnsServiceRecord(
serviceName, serviceName,
0L /* receiptTimeMillis */, 0L /* receiptTimeMillis */,
@@ -319,9 +366,21 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetReply() { fun testGetReply() {
doGetReplyTest(subtype = null)
}
@Test
fun testGetReply_WithSubtype() {
doGetReplyTest(TEST_SUBTYPE)
}
private fun doGetReplyTest(subtype: String?) {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype)
val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local")
else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local")
val questions = listOf(MdnsPointerRecord(queriedName,
0L /* receiptTimeMillis */, 0L /* receiptTimeMillis */,
false /* cacheFlush */, false /* cacheFlush */,
// TTL and data is empty for a question // TTL and data is empty for a question
@@ -344,7 +403,7 @@ class MdnsRecordRepositoryTest {
assertEquals(listOf( assertEquals(listOf(
MdnsPointerRecord( MdnsPointerRecord(
arrayOf("_testservice", "_tcp", "local"), queriedName,
0L /* receiptTimeMillis */, 0L /* receiptTimeMillis */,
false /* cacheFlush */, false /* cacheFlush */,
longTtl, longTtl,
@@ -405,8 +464,8 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetConflictingServices() { fun testGetConflictingServices() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
val packet = MdnsPacket( val packet = MdnsPacket(
0 /* flags */, 0 /* flags */,
@@ -433,8 +492,8 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetConflictingServices_IdenticalService() { fun testGetConflictingServices_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
val otherTtlMillis = 1234L val otherTtlMillis = 1234L
val packet = MdnsPacket( val packet = MdnsPacket(
@@ -460,10 +519,13 @@ class MdnsRecordRepositoryTest {
} }
} }
private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo): private fun MdnsRecordRepository.initWithService(
AnnouncementInfo { serviceId: Int,
serviceInfo: NsdServiceInfo,
subtype: String? = null
): AnnouncementInfo {
updateAddresses(TEST_ADDRESSES) updateAddresses(TEST_ADDRESSES)
addService(serviceId, serviceInfo) addService(serviceId, serviceInfo, subtype)
val probingInfo = setServiceProbing(serviceId) val probingInfo = setServiceProbing(serviceId)
assertNotNull(probingInfo) assertNotNull(probingInfo)
return onProbingSucceeded(probingInfo) return onProbingSucceeded(probingInfo)