From 08521315205d7d759ed756a3749f5f97c860e86f Mon Sep 17 00:00:00 2001 From: Remi NGUYEN VAN Date: Wed, 15 Feb 2023 13:10:11 +0900 Subject: [PATCH] Add test for partial responses Test that when a responder only responds with the exact records that were queried, so only reply for PTR in discovery, only send SRV, TXT, A, AAAA when asked explicitly, service resolve succeeds. This ensures that the querier sends followup queries for each record. See RFC6763 12., especially the last paragraph. Bug: 267570781 Bug: 267371243 Test: atest NsdManagerTest Change-Id: Ia392e80c1e27b479c6177d19f6b4be6032dcb1cd --- .../net/src/android/net/cts/MdnsTestUtils.kt | 31 ++- .../net/src/android/net/cts/NsdManagerTest.kt | 212 +++++++++++++----- 2 files changed, 175 insertions(+), 68 deletions(-) diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt index bc1344237d..eef3f8724b 100644 --- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt +++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt @@ -233,46 +233,51 @@ class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback, } } +private fun getMdnsPayload(packet: ByteArray) = packet.copyOfRange( + ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, packet.size) + fun TapPacketReader.pollForMdnsPacket( timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS, predicate: (TestDnsPacket) -> Boolean -): ByteArray? { +): TestDnsPacket? { val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { - val mdnsPayload = it.copyOfRange( - ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size - ) + val mdnsPayload = getMdnsPayload(it) try { predicate(TestDnsPacket(mdnsPayload)) } catch (e: DnsPacket.ParseException) { false } } - return poll(timeoutMs, mdnsProbeFilter) + return poll(timeoutMs, mdnsProbeFilter)?.let { TestDnsPacket(getMdnsPayload(it)) } } fun TapPacketReader.pollForProbe( serviceName: String, serviceType: String, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") } +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { + it.isProbeFor("$serviceName.$serviceType.local") +} fun TapPacketReader.pollForAdvertisement( serviceName: String, serviceType: String, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { + it.isReplyFor("$serviceName.$serviceType.local") +} fun TapPacketReader.pollForQuery( recordName: String, - recordType: Int, + vararg requiredTypes: Int, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) } +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, *requiredTypes) } fun TapPacketReader.pollForReply( serviceName: String, serviceType: String, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } @@ -289,7 +294,9 @@ class TestDnsPacket(data: ByteArray) : DnsPacket(data) { it.dName == name && it.nsType == DnsResolver.TYPE_SRV } - fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any { - it.dName == name && it.nsType == type + fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean = requiredTypes.all { type -> + mRecords[QDSECTION].any { + it.dName == name && it.nsType == type + } } } diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index 27bd5d32b2..9c44a3ead3 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -20,6 +20,7 @@ import android.Manifest.permission.NETWORK_SETTINGS import android.app.compat.CompatChanges import android.net.ConnectivityManager import android.net.ConnectivityManager.NetworkCallback +import android.net.DnsResolver import android.net.InetAddresses.parseNumericAddress import android.net.LinkAddress import android.net.LinkProperties @@ -87,6 +88,7 @@ import com.android.testutils.TapPacketReader import com.android.testutils.TestableNetworkAgent import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated import com.android.testutils.TestableNetworkCallback +import com.android.testutils.assertEmpty import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk30 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk33 import com.android.testutils.runAsShell @@ -424,11 +426,7 @@ class NsdManagerTest { @Test fun testNsdManager_DiscoverOnNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val registrationRecord = NsdRegistrationRecord() val registeredInfo = registerService(registrationRecord, si) @@ -455,11 +453,7 @@ class NsdManagerTest { @Test fun testNsdManager_DiscoverWithNetworkRequest() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val handler = Handler(handlerThread.looper) val executor = Executor { handler.post(it) } @@ -524,11 +518,6 @@ class NsdManagerTest { @Test fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - val handler = Handler(handlerThread.looper) val executor = Executor { handler.post(it) } @@ -568,11 +557,7 @@ class NsdManagerTest { @Test fun testNsdManager_ResolveOnNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val registrationRecord = NsdRegistrationRecord() val registeredInfo = registerService(registrationRecord, si) tryTest { @@ -610,12 +595,7 @@ class NsdManagerTest { @Test fun testNsdManager_RegisterOnNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.network = testNetwork1.network - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo(testNetwork1.network) // Register service on testNetwork1 val registrationRecord = NsdRegistrationRecord() registerService(registrationRecord, si) @@ -889,11 +869,7 @@ class NsdManagerTest { @Test fun testStopServiceResolution() { - val si = NsdServiceInfo() - si.serviceType = this@NsdManagerTest.serviceType - si.serviceName = this@NsdManagerTest.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val resolveRecord = NsdResolveRecord() // Try to resolve an unknown service then stop it immediately. // Expected ResolutionStopped callback. @@ -911,12 +887,7 @@ class NsdManagerTest { val addresses = lp.addresses assertFalse(addresses.isEmpty()) - val si = NsdServiceInfo().apply { - serviceType = this@NsdManagerTest.serviceType - serviceName = this@NsdManagerTest.serviceName - network = testNetwork1.network - port = 12345 // Test won't try to connect so port does not matter - } + val si = makeTestServiceInfo(testNetwork1.network) // Register service on the network val registrationRecord = NsdRegistrationRecord() @@ -1022,11 +993,7 @@ class NsdManagerTest { // This test requires shims supporting T+ APIs (NsdServiceInfo.network) assumeTrue(TestUtils.shouldTestTApis()) - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = serviceName - si.network = testNetwork1.network - si.port = 12345 // Test won't try to connect so port does not matter + val si = makeTestServiceInfo(testNetwork1.network) val packetReader = TapPacketReader(Handler(handlerThread.looper), testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) @@ -1063,11 +1030,7 @@ class NsdManagerTest { // This test requires shims supporting T+ APIs (NsdServiceInfo.network) assumeTrue(TestUtils.shouldTestTApis()) - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = serviceName - si.network = testNetwork1.network - si.port = 12345 // Test won't try to connect so port does not matter + val si = makeTestServiceInfo(testNetwork1.network) // Register service on testNetwork1 val registrationRecord = NsdRegistrationRecord() @@ -1137,6 +1100,127 @@ class NsdManagerTest { } } + // Test that even if only a PTR record is received as a reply when discovering, without the + // SRV, TXT, address records as recommended (but not mandated) by RFC 6763 12, the service can + // still be discovered. + @Test + fun testDiscoveryWithPtrOnlyResponse_ServiceIsFound() { + // Register service on testNetwork1 + val discoveryRecord = NsdDiscoveryRecord() + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, { it.run() }, discoveryRecord) + + tryTest { + discoveryRecord.expectCallback() + assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR)) + /* + Generated with: + scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120, + rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex() + */ + val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" + + "6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" + + "6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00") + + replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload) + packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload)) + + val serviceFound = discoveryRecord.expectCallback() + serviceFound.serviceInfo.let { + assertEquals(serviceName, it.serviceName) + // Discovered service types have a dot at the end + assertEquals("$serviceType.", it.serviceType) + assertEquals(testNetwork1.network, it.network) + // ServiceFound does not provide port, address or attributes (only information + // available in the PTR record is included in that callback, regardless of whether + // other records exist). + assertEquals(0, it.port) + assertEmpty(it.hostAddresses) + assertEquals(0, it.attributes.size) + } + } cleanup { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback() + } + } + + // Test RFC 6763 12. "Clients MUST be capable of functioning correctly with DNS servers [...] + // that fail to generate these additional records automatically, by issuing subsequent queries + // for any further record(s) they require" + @Test + fun testResolveWhenServerSendsNoAdditionalRecord() { + // Resolve service on testNetwork1 + val resolveRecord = NsdResolveRecord() + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + val si = makeTestServiceInfo(testNetwork1.network) + nsdManager.resolveService(si, { it.run() }, resolveRecord) + + val serviceFullName = "$serviceName.$serviceType.local" + // The query should ask for ANY, since both SRV and TXT are requested. Note legacy + // mdnsresponder will ask for SRV and TXT separately, and will not proceed to asking for + // address records without an answer for both. + val srvTxtQuery = packetReader.pollForQuery(serviceFullName, DnsResolver.TYPE_ANY) + assertNotNull(srvTxtQuery) + + /* + Generated with: + scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local', + rclass=0x8001, port=31234, target='testhost.local', ttl=120) / + scapy.DNSRR(rrname='NsdTest123456789._nmt123456789._tcp.local', type='TXT', ttl=120, + rdata='testkey=testvalue') + ))).hex() + */ + val srvTxtResponsePayload = HexDump.hexStringToByteArray("000084000000000200000000104" + + "e7364546573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f6" + + "3616c0000218001000000780011000000007a020874657374686f7374c030c00c00100001000" + + "00078001211746573746b65793d7465737476616c7565") + replaceServiceNameAndTypeWithTestSuffix(srvTxtResponsePayload) + packetReader.sendResponse(buildMdnsPacket(srvTxtResponsePayload)) + + val testHostname = "testhost.local" + val addressQuery = packetReader.pollForQuery(testHostname, + DnsResolver.TYPE_A, DnsResolver.TYPE_AAAA) + assertNotNull(addressQuery) + + /* + Generated with: + scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRR(rrname='testhost.local', type='A', ttl=120, + rdata='192.0.2.123') / + scapy.DNSRR(rrname='testhost.local', type='AAAA', ttl=120, + rdata='2001:db8::123') + ))).hex() + */ + val addressPayload = HexDump.hexStringToByteArray("0000840000000002000000000874657374" + + "686f7374056c6f63616c0000010001000000780004c000027bc00c001c000100000078001020" + + "010db8000000000000000000000123") + packetReader.sendResponse(buildMdnsPacket(addressPayload)) + + val serviceResolved = resolveRecord.expectCallback() + serviceResolved.serviceInfo.let { + assertEquals(serviceName, it.serviceName) + assertEquals(".$serviceType", it.serviceType) + assertEquals(testNetwork1.network, it.network) + assertEquals(31234, it.port) + assertEquals(1, it.attributes.size) + assertArrayEquals("testvalue".encodeToByteArray(), it.attributes["testkey"]) + } + assertEquals( + setOf(parseNumericAddress("192.0.2.123"), parseNumericAddress("2001:db8::123")), + serviceResolved.serviceInfo.hostAddresses.toSet()) + } + private fun buildConflictingAnnouncement(): ByteBuffer { /* Generated with: @@ -1148,21 +1232,37 @@ class NsdManagerTest { val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" + "3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" + "18001000000780016000000007a0208636f6e666c696374056c6f63616c00") - val packetBuffer = ByteBuffer.wrap(mdnsPayload) - // Replace service name and types in the packet with the random ones used in the test. + replaceServiceNameAndTypeWithTestSuffix(mdnsPayload) + + return buildMdnsPacket(mdnsPayload) + } + + /** + * Replaces occurrences of "NsdTest123456789" and "_nmt123456789" in mDNS payload with the + * actual random name and type that are used by the test. + */ + private fun replaceServiceNameAndTypeWithTestSuffix(mdnsPayload: ByteArray) { // Test service name and types have consistent length and are always ASCII val testPacketName = "NsdTest123456789".encodeToByteArray() val testPacketTypePrefix = "_nmt123456789".encodeToByteArray() val encodedServiceName = serviceName.encodeToByteArray() val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray() - assertEquals(testPacketName.size, encodedServiceName.size) - assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size) - packetBuffer.position(mdnsPayload.indexOf(testPacketName)) - packetBuffer.put(encodedServiceName) - packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix)) - packetBuffer.put(encodedTypePrefix) - return buildMdnsPacket(mdnsPayload) + val packetBuffer = ByteBuffer.wrap(mdnsPayload) + replaceAll(packetBuffer, testPacketName, encodedServiceName) + replaceAll(packetBuffer, testPacketTypePrefix, encodedTypePrefix) + } + + private tailrec fun replaceAll(buffer: ByteBuffer, source: ByteArray, replacement: ByteArray) { + assertEquals(source.size, replacement.size) + val index = buffer.array().indexOf(source) + if (index < 0) return + + val origPosition = buffer.position() + buffer.position(index) + buffer.put(replacement) + buffer.position(origPosition) + replaceAll(buffer, source, replacement) } private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {