Merge "Include A/AAAA records in probing packet" into main

This commit is contained in:
Paul Hu
2023-10-17 02:24:10 +00:00
committed by Gerrit Code Review
12 changed files with 180 additions and 53 deletions

View File

@@ -153,10 +153,10 @@ class MdnsAdvertiserTest {
thread.start()
doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname()
doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1),
any(), any(), any(), any(), eq(TEST_HOSTNAME), any()
any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any()
)
doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2),
any(), any(), any(), any(), eq(TEST_HOSTNAME), any()
any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any()
)
doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt())
doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt())
@@ -202,6 +202,7 @@ class MdnsAdvertiserTest {
any(),
intAdvCbCaptor.capture(),
eq(TEST_HOSTNAME),
any(),
any()
)
@@ -259,10 +260,10 @@ class MdnsAdvertiserTest {
val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any()
eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any(), any()
)
verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any()
eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any(), any()
)
verify(mockInterfaceAdvertiser1).addService(
anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
@@ -367,7 +368,7 @@ class MdnsAdvertiserTest {
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any()
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any(), any()
)
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(SERVICE_1) }, eq(null))

View File

@@ -77,6 +77,7 @@ class MdnsInterfaceAdvertiserTest {
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
private val sharedlog = SharedLog("MdnsInterfaceAdvertiserTest")
private val flags = MdnsFeatureFlags.newBuilder().build()
@Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
@@ -99,15 +100,14 @@ class MdnsInterfaceAdvertiserTest {
cb,
deps,
TEST_HOSTNAME,
sharedlog
sharedlog,
flags
)
}
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any(),
eq(TEST_HOSTNAME)
)
doReturn(repository).`when`(deps).makeRecordRepository(any(), eq(TEST_HOSTNAME), any())
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any(), any())

View File

@@ -78,6 +78,7 @@ class MdnsRecordRepositoryTest {
override fun getInterfaceInetAddresses(iface: NetworkInterface) =
Collections.enumeration(TEST_ADDRESSES.map { it.address })
}
private val flags = MdnsFeatureFlags.newBuilder().build()
@Before
fun setUp() {
@@ -92,7 +93,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testAddServiceAndProbe() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */))
@@ -127,7 +128,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
@@ -139,7 +140,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(IllegalArgumentException::class) {
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */)
@@ -148,7 +149,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testHasActiveService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
@@ -165,7 +166,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -195,7 +196,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitAnnouncements_WithSubtype() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -231,7 +232,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
repository.exitService(TEST_SERVICE_ID_1)
@@ -246,7 +247,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -371,7 +372,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetOffloadPacket() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local")
@@ -433,7 +434,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetReplyCaseInsensitive() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val questionsCaseInSensitive =
listOf(MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"),
@@ -463,7 +464,7 @@ class MdnsRecordRepositoryTest {
}
private fun doGetReplyTest(subtype: String?) {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype)
val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local")
else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local")
@@ -551,7 +552,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServices() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -579,7 +580,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServicesCaseInsensitive() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -607,7 +608,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServices_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -636,7 +637,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServicesCaseInsensitive_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -665,7 +666,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetServiceRepliedRequestsCount() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
// Verify that there is no packet replied.
assertEquals(MdnsConstants.NO_PACKET,
@@ -690,6 +691,68 @@ class MdnsRecordRepositoryTest {
assertEquals(MdnsConstants.NO_PACKET,
repository.getServiceRepliedRequestsCount(TEST_SERVICE_ID_2))
}
@Test
fun testIncludeInetAddressRecordsInProbing() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
MdnsFeatureFlags.newBuilder().setIncludeInetAddressRecordsInProbing(true).build())
repository.updateAddresses(TEST_ADDRESSES)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */))
assertEquals(1, repository.servicesCount)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
assertNotNull(probingInfo)
assertTrue(repository.isProbing(TEST_SERVICE_ID_1))
assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size)
assertEquals(2, packet.questions.size)
val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertContentEquals(listOf(
MdnsAnyRecord(expectedName, false /* unicast */),
MdnsAnyRecord(TEST_HOSTNAME, false /* unicast */),
), packet.questions)
assertEquals(4, packet.authorityRecords.size)
assertContentEquals(listOf(
MdnsServiceRecord(
expectedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT,
TEST_HOSTNAME),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[0].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[1].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[2].address)
), packet.authorityRecords)
assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices())
}
}
private fun MdnsRecordRepository.initWithService(