Merge "Add support to update the registered service in place" into main

This commit is contained in:
Yuyang Huang
2023-12-02 08:49:23 +00:00
committed by Gerrit Code Review
9 changed files with 372 additions and 62 deletions

View File

@@ -1115,9 +1115,9 @@ public class NsdServiceTest {
final RegistrationListener regListener = mock(RegistrationListener.class);
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle();
verify(mAdvertiser).addService(anyInt(), argThat(s ->
verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(s ->
"Instance".equals(s.getServiceName())
&& SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"));
&& SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"), any());
final DiscoveryListener discListener = mock(DiscoveryListener.class);
client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
@@ -1222,8 +1222,8 @@ public class NsdServiceTest {
waitForIdle();
final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo)), eq(null) /* subtype */);
verify(mAdvertiser).addOrUpdateService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo)), eq(null) /* subtype */, any());
client.unregisterService(regListenerWithoutFeature);
waitForIdle();
@@ -1282,10 +1282,10 @@ public class NsdServiceTest {
waitForIdle();
// The advertiser is enabled for _type2 but not _type1
verify(mAdvertiser, never()).addService(
anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */);
verify(mAdvertiser).addService(
anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */);
verify(mAdvertiser, never()).addOrUpdateService(anyInt(),
argThat(info -> matches(info, service1)), eq(null) /* subtype */, any());
verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(info -> matches(info, service2)),
eq(null) /* subtype */, any());
}
@Test
@@ -1309,8 +1309,8 @@ public class NsdServiceTest {
waitForIdle();
verify(mSocketProvider).startMonitoringSockets();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(idCaptor.capture(), argThat(info ->
matches(info, regInfo)), eq(null) /* subtype */);
verify(mAdvertiser).addOrUpdateService(idCaptor.capture(), argThat(info ->
matches(info, regInfo)), eq(null) /* subtype */, any());
// Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1358,7 +1358,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle();
verify(mAdvertiser, never()).addService(anyInt(), any(), any());
verify(mAdvertiser, never()).addOrUpdateService(anyInt(), any(), any(), any());
verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1387,9 +1387,9 @@ public class NsdServiceTest {
waitForIdle();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
// Service name is truncated to 63 characters
verify(mAdvertiser).addService(idCaptor.capture(),
verify(mAdvertiser).addOrUpdateService(idCaptor.capture(),
argThat(info -> info.getServiceName().equals("a".repeat(63))),
eq(null) /* subtype */);
eq(null) /* subtype */, any());
// Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1479,7 +1479,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle();
verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any(), any());
verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
// Verify the discovery uses MdnsDiscoveryManager
final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -1512,7 +1512,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle();
verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any(), any());
verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
final Network wifiNetwork1 = new Network(123);
final Network wifiNetwork2 = new Network(124);

View File

@@ -19,6 +19,7 @@ package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress
import android.net.Network
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.net.nsd.OffloadEngine
import android.net.nsd.OffloadServiceInfo
@@ -71,6 +72,7 @@ private val TEST_INTERFACE1 = "test_iface1"
private val TEST_INTERFACE2 = "test_iface2"
private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03)
private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04)
private val DEFAULT_ADVERTISING_OPTION = MdnsAdvertisingOptions.getDefaultOptions()
private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345
@@ -186,7 +188,8 @@ class MdnsAdvertiserTest {
fun testAddService_OneNetwork() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -247,7 +250,8 @@ class MdnsAdvertiserTest {
fun testAddService_AllNetworks() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) }
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
TEST_SUBTYPE, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -318,24 +322,27 @@ class MdnsAdvertiserTest {
fun testAddService_Conflicts() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
val oneNetSocketCb = oneNetSocketCbCaptor.value
// Register a service with the same name on all networks (name conflict)
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) }
postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value
postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) }
postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
null /* subtype */) }
postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_1, LONG_SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.addService(CASE_INSENSITIVE_TEST_SERVICE_ID, ALL_NETWORKS_SERVICE_2,
null /* subtype */) }
postSync { advertiser.addOrUpdateService(CASE_INSENSITIVE_TEST_SERVICE_ID,
ALL_NETWORKS_SERVICE_2, null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
// Callbacks for matching network and all networks both get the socket
postSync {
@@ -399,12 +406,52 @@ class MdnsAdvertiserTest {
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
}
@Test
fun testAddOrUpdateService_Updates() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(null))
val updateOptions = MdnsAdvertisingOptions.newBuilder().setIsOnlyUpdate(true).build()
// Update with serviceId that is not registered yet should fail
postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
updateOptions) }
verify(cb).onRegisterServiceFailed(SERVICE_ID_2, NsdManager.FAILURE_INTERNAL_ERROR)
// Update service with different NsdServiceInfo should fail
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, TEST_SUBTYPE,
updateOptions) }
verify(cb).onRegisterServiceFailed(SERVICE_ID_1, NsdManager.FAILURE_INTERNAL_ERROR)
// Update service with same NsdServiceInfo but different subType should succeed
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
updateOptions) }
verify(mockInterfaceAdvertiser1).updateService(eq(SERVICE_ID_1), eq(TEST_SUBTYPE))
// Newly created MdnsInterfaceAdvertiser will get addService() call.
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR2)) }
verify(mockInterfaceAdvertiser2).addService(eq(SERVICE_ID_1),
argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(TEST_SUBTYPE))
}
@Test
fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
verify(mockDeps, times(1)).generateHostname()
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.removeService(SERVICE_ID_1) }
verify(mockDeps, times(2)).generateHostname()
}

View File

@@ -48,6 +48,7 @@ import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
@@ -59,6 +60,7 @@ private val TEST_BUFFER = ByteArray(1300)
private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_DUPLICATE = 43
private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
serviceName = "MyTestService"
@@ -272,6 +274,28 @@ class MdnsInterfaceAdvertiserTest {
verify(prober).restartForConflict(mockProbingInfo)
}
@Test
fun testReplaceExitingService() {
doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
.addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
val subType = "_sub"
advertiser.addService(TEST_SERVICE_ID_DUPLICATE, TEST_SERVICE_1, subType)
verify(repository).addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
verify(announcer).stop(TEST_SERVICE_ID_DUPLICATE)
verify(prober).startProbing(any())
}
@Test
fun testUpdateExistingService() {
doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
.addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
val subType = "_sub"
advertiser.updateService(TEST_SERVICE_ID_DUPLICATE, subType)
verify(repository).updateService(eq(TEST_SERVICE_ID_DUPLICATE), any())
verify(announcer, never()).stop(TEST_SERVICE_ID_DUPLICATE)
verify(prober, never()).startProbing(any())
}
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java)

View File

@@ -129,7 +129,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
}
@@ -138,6 +138,45 @@ class MdnsRecordRepositoryTest {
}
}
@Test
fun testAddAndUpdates() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(IllegalArgumentException::class) {
repository.updateService(TEST_SERVICE_ID_2, null /* subtype */)
}
repository.updateService(TEST_SERVICE_ID_1, TEST_SUBTYPE)
val queriedName = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
val questions = listOf(MdnsPointerRecord(queriedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
// TTL and data is empty for a question
0L /* ttlMillis */,
null /* pointer */))
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val reply = repository.getReply(query, src)
assertNotNull(reply)
// TTLs as per RFC6762 10.
val longTtl = 4_500_000L
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertEquals(listOf(
MdnsPointerRecord(
queriedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
longTtl,
serviceName),
), reply.answers)
}
@Test
fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
@@ -758,7 +797,7 @@ class MdnsRecordRepositoryTest {
private fun MdnsRecordRepository.initWithService(
serviceId: Int,
serviceInfo: NsdServiceInfo,
subtype: String? = null
subtype: String? = null,
): AnnouncementInfo {
updateAddresses(TEST_ADDRESSES)
addService(serviceId, serviceInfo, subtype)