Merge changes I69128db9,I13db22f8
* changes: Implement onServiceConflict Add replying to queries
This commit is contained in:
@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers.eq
|
||||
import org.mockito.Mockito.any
|
||||
import org.mockito.Mockito.anyInt
|
||||
import org.mockito.Mockito.argThat
|
||||
import org.mockito.Mockito.atLeastOnce
|
||||
import org.mockito.Mockito.doReturn
|
||||
import org.mockito.Mockito.mock
|
||||
import org.mockito.Mockito.never
|
||||
@@ -161,6 +162,60 @@ class MdnsAdvertiserTest {
|
||||
verify(socketProvider).unrequestSocket(socketCb)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAddService_Conflicts() {
|
||||
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
|
||||
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
|
||||
|
||||
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
|
||||
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
|
||||
val oneNetSocketCb = oneNetSocketCbCaptor.value
|
||||
|
||||
// Register a service with the same name on all networks (name conflict)
|
||||
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
|
||||
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
|
||||
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
|
||||
val allNetSocketCb = allNetSocketCbCaptor.value
|
||||
|
||||
// Callbacks for matching network and all networks both get the socket
|
||||
postSync {
|
||||
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
|
||||
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
|
||||
}
|
||||
|
||||
val expectedRenamed = NsdServiceInfo(
|
||||
"${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
|
||||
port = ALL_NETWORKS_SERVICE.port
|
||||
host = ALL_NETWORKS_SERVICE.host
|
||||
network = ALL_NETWORKS_SERVICE.network
|
||||
}
|
||||
|
||||
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
|
||||
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
|
||||
eq(thread.looper), any(), intAdvCbCaptor.capture())
|
||||
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
|
||||
argThat { it.matches(SERVICE_1) })
|
||||
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
|
||||
argThat { it.matches(expectedRenamed) })
|
||||
|
||||
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
|
||||
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
|
||||
mockInterfaceAdvertiser1, SERVICE_ID_1) }
|
||||
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
|
||||
|
||||
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
|
||||
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
|
||||
mockInterfaceAdvertiser1, SERVICE_ID_2) }
|
||||
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
|
||||
argThat { it.matches(expectedRenamed) })
|
||||
|
||||
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
|
||||
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
|
||||
|
||||
// destroyNow can be called multiple times
|
||||
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
|
||||
}
|
||||
|
||||
private fun postSync(r: () -> Unit) {
|
||||
handler.post(r)
|
||||
handler.waitForIdle(TIMEOUT_MS)
|
||||
|
||||
@@ -79,7 +79,7 @@ class MdnsAnnouncerTest {
|
||||
|
||||
@Test
|
||||
fun testAnnounce() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
|
||||
as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
|
||||
|
||||
@@ -21,6 +21,7 @@ import android.net.LinkAddress
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.Build
|
||||
import android.os.HandlerThread
|
||||
import com.android.net.module.util.HexDump
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
|
||||
@@ -30,6 +31,10 @@ import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
|
||||
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
|
||||
import com.android.testutils.DevSdkIgnoreRunner
|
||||
import com.android.testutils.waitForIdle
|
||||
import java.net.InetSocketAddress
|
||||
import kotlin.test.assertContentEquals
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
@@ -37,8 +42,10 @@ import org.junit.runner.RunWith
|
||||
import org.mockito.ArgumentCaptor
|
||||
import org.mockito.Mockito.any
|
||||
import org.mockito.Mockito.anyInt
|
||||
import org.mockito.Mockito.anyString
|
||||
import org.mockito.Mockito.doAnswer
|
||||
import org.mockito.Mockito.doReturn
|
||||
import org.mockito.Mockito.eq
|
||||
import org.mockito.Mockito.mock
|
||||
import org.mockito.Mockito.times
|
||||
import org.mockito.Mockito.verify
|
||||
@@ -67,13 +74,18 @@ class MdnsInterfaceAdvertiserTest {
|
||||
private val replySender = mock(MdnsReplySender::class.java)
|
||||
private val announcer = mock(MdnsAnnouncer::class.java)
|
||||
private val prober = mock(MdnsProber::class.java)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
|
||||
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
|
||||
as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
|
||||
private val packetHandlerCaptor = ArgumentCaptor.forClass(
|
||||
MulticastPacketReader.PacketHandler::class.java)
|
||||
|
||||
private val probeCb get() = probeCbCaptor.value
|
||||
private val announceCb get() = announceCbCaptor.value
|
||||
private val packetHandler get() = packetHandlerCaptor.value
|
||||
|
||||
private val advertiser by lazy {
|
||||
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
|
||||
@@ -82,9 +94,9 @@ class MdnsInterfaceAdvertiserTest {
|
||||
@Before
|
||||
fun setUp() {
|
||||
doReturn(repository).`when`(deps).makeRecordRepository(any())
|
||||
doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
|
||||
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
|
||||
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
|
||||
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
|
||||
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
|
||||
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
|
||||
|
||||
val knownServices = mutableSetOf<Int>()
|
||||
doAnswer { inv ->
|
||||
@@ -104,6 +116,7 @@ class MdnsInterfaceAdvertiserTest {
|
||||
thread.start()
|
||||
advertiser.start()
|
||||
|
||||
verify(socket).addPacketHandler(packetHandlerCaptor.capture())
|
||||
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
|
||||
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
|
||||
}
|
||||
@@ -157,6 +170,39 @@ class MdnsInterfaceAdvertiserTest {
|
||||
verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testReplyToQuery() {
|
||||
addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
|
||||
val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java)
|
||||
doReturn(mockReply).`when`(repository).getReply(any(), any())
|
||||
|
||||
// Query obtained with:
|
||||
// scapy.raw(scapy.DNS(
|
||||
// qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
|
||||
// ).hex().upper()
|
||||
val query = HexDump.hexStringToByteArray(
|
||||
"0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
|
||||
)
|
||||
val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
|
||||
packetHandler.handlePacket(query, query.size, src)
|
||||
|
||||
val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
|
||||
verify(repository).getReply(packetCaptor.capture(), eq(src))
|
||||
|
||||
packetCaptor.value.let {
|
||||
assertEquals(1, it.questions.size)
|
||||
assertEquals(0, it.answers.size)
|
||||
assertEquals(0, it.authorityRecords.size)
|
||||
assertEquals(0, it.additionalRecords.size)
|
||||
|
||||
assertTrue(it.questions[0] is MdnsPointerRecord)
|
||||
assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
|
||||
}
|
||||
|
||||
verify(replySender).queueReply(mockReply)
|
||||
}
|
||||
|
||||
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
|
||||
AnnouncementInfo {
|
||||
val testProbingInfo = mock(ProbingInfo::class.java)
|
||||
|
||||
@@ -114,7 +114,7 @@ class MdnsProberTest {
|
||||
|
||||
@Test
|
||||
fun testProbe() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
val prober = TestProber(thread.looper, replySender, cb)
|
||||
val probeInfo = TestProbeInfo(
|
||||
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
|
||||
@@ -129,7 +129,7 @@ class MdnsProberTest {
|
||||
|
||||
@Test
|
||||
fun testProbeMultipleRecords() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
val prober = TestProber(thread.looper, replySender, cb)
|
||||
val probeInfo = TestProbeInfo(listOf(
|
||||
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
|
||||
@@ -167,7 +167,7 @@ class MdnsProberTest {
|
||||
|
||||
@Test
|
||||
fun testStopProbing() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
val prober = TestProber(thread.looper, replySender, cb)
|
||||
val probeInfo = TestProbeInfo(
|
||||
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
|
||||
|
||||
@@ -21,10 +21,12 @@ import android.net.LinkAddress
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.Build
|
||||
import android.os.HandlerThread
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
|
||||
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
|
||||
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
|
||||
import com.android.testutils.DevSdkIgnoreRule
|
||||
import com.android.testutils.DevSdkIgnoreRunner
|
||||
import java.net.InetSocketAddress
|
||||
import java.net.NetworkInterface
|
||||
import java.util.Collections
|
||||
import kotlin.test.assertContentEquals
|
||||
@@ -150,11 +152,7 @@ class MdnsRecordRepositoryTest {
|
||||
@Test
|
||||
fun testExitAnnouncements() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.updateAddresses(TEST_ADDRESSES)
|
||||
|
||||
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
|
||||
repository.onProbingSucceeded(probingInfo)
|
||||
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
|
||||
|
||||
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
|
||||
@@ -183,9 +181,7 @@ class MdnsRecordRepositoryTest {
|
||||
@Test
|
||||
fun testExitingServiceReAdded() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
|
||||
repository.onProbingSucceeded(probingInfo)
|
||||
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
|
||||
repository.exitService(TEST_SERVICE_ID_1)
|
||||
|
||||
@@ -199,11 +195,8 @@ class MdnsRecordRepositoryTest {
|
||||
@Test
|
||||
fun testOnProbingSucceeded() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.updateAddresses(TEST_ADDRESSES)
|
||||
|
||||
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
|
||||
val announcementInfo = repository.onProbingSucceeded(probingInfo)
|
||||
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
|
||||
val packet = announcementInfo.getPacket(0)
|
||||
|
||||
assertEquals(0x8400 /* response, authoritative */, packet.flags)
|
||||
@@ -322,4 +315,98 @@ class MdnsRecordRepositoryTest {
|
||||
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
|
||||
assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetReply() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
|
||||
0L /* receiptTimeMillis */,
|
||||
false /* cacheFlush */,
|
||||
// TTL and data is empty for a question
|
||||
0L /* ttlMillis */,
|
||||
null /* pointer */))
|
||||
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
|
||||
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
|
||||
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
|
||||
val reply = repository.getReply(query, src)
|
||||
|
||||
assertNotNull(reply)
|
||||
// Source address is IPv4
|
||||
assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
|
||||
assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
|
||||
|
||||
// TTLs as per RFC6762 10.
|
||||
val longTtl = 4_500_000L
|
||||
val shortTtl = 120_000L
|
||||
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
|
||||
|
||||
assertEquals(listOf(
|
||||
MdnsPointerRecord(
|
||||
arrayOf("_testservice", "_tcp", "local"),
|
||||
0L /* receiptTimeMillis */,
|
||||
false /* cacheFlush */,
|
||||
longTtl,
|
||||
serviceName),
|
||||
), reply.answers)
|
||||
|
||||
assertEquals(listOf(
|
||||
MdnsTextRecord(
|
||||
serviceName,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
longTtl,
|
||||
listOf() /* entries */),
|
||||
MdnsServiceRecord(
|
||||
serviceName,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
0 /* servicePriority */,
|
||||
0 /* serviceWeight */,
|
||||
TEST_PORT,
|
||||
TEST_HOSTNAME),
|
||||
MdnsInetAddressRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_ADDRESSES[0].address),
|
||||
MdnsInetAddressRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_ADDRESSES[1].address),
|
||||
MdnsInetAddressRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_ADDRESSES[2].address),
|
||||
MdnsNsecRecord(
|
||||
serviceName,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
longTtl,
|
||||
serviceName /* nextDomain */,
|
||||
intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
|
||||
MdnsNsecRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_HOSTNAME /* nextDomain */,
|
||||
intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
|
||||
), reply.additionalAnswers)
|
||||
}
|
||||
}
|
||||
|
||||
private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
|
||||
AnnouncementInfo {
|
||||
updateAddresses(TEST_ADDRESSES)
|
||||
addService(serviceId, serviceInfo)
|
||||
val probingInfo = setServiceProbing(serviceId)
|
||||
assertNotNull(probingInfo)
|
||||
return onProbingSucceeded(probingInfo)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user