Merge changes I69128db9,I13db22f8

* changes:
  Implement onServiceConflict
  Add replying to queries
This commit is contained in:
Remi NGUYEN VAN
2023-01-18 01:13:35 +00:00
committed by Gerrit Code Review
14 changed files with 682 additions and 107 deletions

View File

@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers.eq
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.argThat
import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
@@ -161,6 +162,60 @@ class MdnsAdvertiserTest {
verify(socketProvider).unrequestSocket(socketCb)
}
@Test
fun testAddService_Conflicts() {
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
val oneNetSocketCb = oneNetSocketCbCaptor.value
// Register a service with the same name on all networks (name conflict)
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value
// Callbacks for matching network and all networks both get the socket
postSync {
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
}
val expectedRenamed = NsdServiceInfo(
"${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
port = ALL_NETWORKS_SERVICE.port
host = ALL_NETWORKS_SERVICE.host
network = ALL_NETWORKS_SERVICE.network
}
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor.capture())
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(SERVICE_1) })
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) })
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
mockInterfaceAdvertiser1, SERVICE_ID_1) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
mockInterfaceAdvertiser1, SERVICE_ID_2) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) })
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
// destroyNow can be called multiple times
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
}
private fun postSync(r: () -> Unit) {
handler.post(r)
handler.waitForIdle(TIMEOUT_MS)

View File

@@ -79,7 +79,7 @@ class MdnsAnnouncerTest {
@Test
fun testAnnounce() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
@Suppress("UNCHECKED_CAST")
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>

View File

@@ -21,6 +21,7 @@ import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.net.module.util.HexDump
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
@@ -30,6 +31,10 @@ import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.waitForIdle
import java.net.InetSocketAddress
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
import org.junit.Test
@@ -37,8 +42,10 @@ import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.anyString
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq
import org.mockito.Mockito.mock
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
@@ -67,13 +74,18 @@ class MdnsInterfaceAdvertiserTest {
private val replySender = mock(MdnsReplySender::class.java)
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
@Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
@Suppress("UNCHECKED_CAST")
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
private val packetHandlerCaptor = ArgumentCaptor.forClass(
MulticastPacketReader.PacketHandler::class.java)
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
private val packetHandler get() = packetHandlerCaptor.value
private val advertiser by lazy {
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
@@ -82,9 +94,9 @@ class MdnsInterfaceAdvertiserTest {
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any())
doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
val knownServices = mutableSetOf<Int>()
doAnswer { inv ->
@@ -104,6 +116,7 @@ class MdnsInterfaceAdvertiserTest {
thread.start()
advertiser.start()
verify(socket).addPacketHandler(packetHandlerCaptor.capture())
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
}
@@ -157,6 +170,39 @@ class MdnsInterfaceAdvertiserTest {
verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
}
@Test
fun testReplyToQuery() {
addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java)
doReturn(mockReply).`when`(repository).getReply(any(), any())
// Query obtained with:
// scapy.raw(scapy.DNS(
// qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
// ).hex().upper()
val query = HexDump.hexStringToByteArray(
"0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
)
val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
packetHandler.handlePacket(query, query.size, src)
val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
verify(repository).getReply(packetCaptor.capture(), eq(src))
packetCaptor.value.let {
assertEquals(1, it.questions.size)
assertEquals(0, it.answers.size)
assertEquals(0, it.authorityRecords.size)
assertEquals(0, it.additionalRecords.size)
assertTrue(it.questions[0] is MdnsPointerRecord)
assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
}
verify(replySender).queueReply(mockReply)
}
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java)

View File

@@ -114,7 +114,7 @@ class MdnsProberTest {
@Test
fun testProbe() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
@@ -129,7 +129,7 @@ class MdnsProberTest {
@Test
fun testProbeMultipleRecords() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(listOf(
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
@@ -167,7 +167,7 @@ class MdnsProberTest {
@Test
fun testStopProbing() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),

View File

@@ -21,10 +21,12 @@ import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.net.InetSocketAddress
import java.net.NetworkInterface
import java.util.Collections
import kotlin.test.assertContentEquals
@@ -150,11 +152,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.updateAddresses(TEST_ADDRESSES)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
repository.onProbingSucceeded(probingInfo)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
@@ -183,9 +181,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
repository.onProbingSucceeded(probingInfo)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1)
@@ -199,11 +195,8 @@ class MdnsRecordRepositoryTest {
@Test
fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.updateAddresses(TEST_ADDRESSES)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
val announcementInfo = repository.onProbingSucceeded(probingInfo)
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val packet = announcementInfo.getPacket(0)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
@@ -322,4 +315,98 @@ class MdnsRecordRepositoryTest {
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
}
@Test
fun testGetReply() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
// TTL and data is empty for a question
0L /* ttlMillis */,
null /* pointer */))
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val reply = repository.getReply(query, src)
assertNotNull(reply)
// Source address is IPv4
assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
// TTLs as per RFC6762 10.
val longTtl = 4_500_000L
val shortTtl = 120_000L
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertEquals(listOf(
MdnsPointerRecord(
arrayOf("_testservice", "_tcp", "local"),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
longTtl,
serviceName),
), reply.answers)
assertEquals(listOf(
MdnsTextRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
longTtl,
listOf() /* entries */),
MdnsServiceRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT,
TEST_HOSTNAME),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_ADDRESSES[0].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_ADDRESSES[1].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_ADDRESSES[2].address),
MdnsNsecRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
longTtl,
serviceName /* nextDomain */,
intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
MdnsNsecRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_HOSTNAME /* nextDomain */,
intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
), reply.additionalAnswers)
}
}
private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
updateAddresses(TEST_ADDRESSES)
addService(serviceId, serviceInfo)
val probingInfo = setServiceProbing(serviceId)
assertNotNull(probingInfo)
return onProbingSucceeded(probingInfo)
}