diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt index bc1344237d..eef3f8724b 100644 --- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt +++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt @@ -233,46 +233,51 @@ class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback, } } +private fun getMdnsPayload(packet: ByteArray) = packet.copyOfRange( + ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, packet.size) + fun TapPacketReader.pollForMdnsPacket( timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS, predicate: (TestDnsPacket) -> Boolean -): ByteArray? { +): TestDnsPacket? { val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { - val mdnsPayload = it.copyOfRange( - ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size - ) + val mdnsPayload = getMdnsPayload(it) try { predicate(TestDnsPacket(mdnsPayload)) } catch (e: DnsPacket.ParseException) { false } } - return poll(timeoutMs, mdnsProbeFilter) + return poll(timeoutMs, mdnsProbeFilter)?.let { TestDnsPacket(getMdnsPayload(it)) } } fun TapPacketReader.pollForProbe( serviceName: String, serviceType: String, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") } +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { + it.isProbeFor("$serviceName.$serviceType.local") +} fun TapPacketReader.pollForAdvertisement( serviceName: String, serviceType: String, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { + it.isReplyFor("$serviceName.$serviceType.local") +} fun TapPacketReader.pollForQuery( recordName: String, - recordType: Int, + vararg requiredTypes: Int, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) } +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, *requiredTypes) } fun TapPacketReader.pollForReply( serviceName: String, serviceType: String, timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { +): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } @@ -289,7 +294,9 @@ class TestDnsPacket(data: ByteArray) : DnsPacket(data) { it.dName == name && it.nsType == DnsResolver.TYPE_SRV } - fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any { - it.dName == name && it.nsType == type + fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean = requiredTypes.all { type -> + mRecords[QDSECTION].any { + it.dName == name && it.nsType == type + } } } diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index 27bd5d32b2..9c44a3ead3 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -20,6 +20,7 @@ import android.Manifest.permission.NETWORK_SETTINGS import android.app.compat.CompatChanges import android.net.ConnectivityManager import android.net.ConnectivityManager.NetworkCallback +import android.net.DnsResolver import android.net.InetAddresses.parseNumericAddress import android.net.LinkAddress import android.net.LinkProperties @@ -87,6 +88,7 @@ import com.android.testutils.TapPacketReader import com.android.testutils.TestableNetworkAgent import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated import com.android.testutils.TestableNetworkCallback +import com.android.testutils.assertEmpty import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk30 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk33 import com.android.testutils.runAsShell @@ -424,11 +426,7 @@ class NsdManagerTest { @Test fun testNsdManager_DiscoverOnNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val registrationRecord = NsdRegistrationRecord() val registeredInfo = registerService(registrationRecord, si) @@ -455,11 +453,7 @@ class NsdManagerTest { @Test fun testNsdManager_DiscoverWithNetworkRequest() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val handler = Handler(handlerThread.looper) val executor = Executor { handler.post(it) } @@ -524,11 +518,6 @@ class NsdManagerTest { @Test fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - val handler = Handler(handlerThread.looper) val executor = Executor { handler.post(it) } @@ -568,11 +557,7 @@ class NsdManagerTest { @Test fun testNsdManager_ResolveOnNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val registrationRecord = NsdRegistrationRecord() val registeredInfo = registerService(registrationRecord, si) tryTest { @@ -610,12 +595,7 @@ class NsdManagerTest { @Test fun testNsdManager_RegisterOnNetwork() { - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = this.serviceName - si.network = testNetwork1.network - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo(testNetwork1.network) // Register service on testNetwork1 val registrationRecord = NsdRegistrationRecord() registerService(registrationRecord, si) @@ -889,11 +869,7 @@ class NsdManagerTest { @Test fun testStopServiceResolution() { - val si = NsdServiceInfo() - si.serviceType = this@NsdManagerTest.serviceType - si.serviceName = this@NsdManagerTest.serviceName - si.port = 12345 // Test won't try to connect so port does not matter - + val si = makeTestServiceInfo() val resolveRecord = NsdResolveRecord() // Try to resolve an unknown service then stop it immediately. // Expected ResolutionStopped callback. @@ -911,12 +887,7 @@ class NsdManagerTest { val addresses = lp.addresses assertFalse(addresses.isEmpty()) - val si = NsdServiceInfo().apply { - serviceType = this@NsdManagerTest.serviceType - serviceName = this@NsdManagerTest.serviceName - network = testNetwork1.network - port = 12345 // Test won't try to connect so port does not matter - } + val si = makeTestServiceInfo(testNetwork1.network) // Register service on the network val registrationRecord = NsdRegistrationRecord() @@ -1022,11 +993,7 @@ class NsdManagerTest { // This test requires shims supporting T+ APIs (NsdServiceInfo.network) assumeTrue(TestUtils.shouldTestTApis()) - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = serviceName - si.network = testNetwork1.network - si.port = 12345 // Test won't try to connect so port does not matter + val si = makeTestServiceInfo(testNetwork1.network) val packetReader = TapPacketReader(Handler(handlerThread.looper), testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) @@ -1063,11 +1030,7 @@ class NsdManagerTest { // This test requires shims supporting T+ APIs (NsdServiceInfo.network) assumeTrue(TestUtils.shouldTestTApis()) - val si = NsdServiceInfo() - si.serviceType = serviceType - si.serviceName = serviceName - si.network = testNetwork1.network - si.port = 12345 // Test won't try to connect so port does not matter + val si = makeTestServiceInfo(testNetwork1.network) // Register service on testNetwork1 val registrationRecord = NsdRegistrationRecord() @@ -1137,6 +1100,127 @@ class NsdManagerTest { } } + // Test that even if only a PTR record is received as a reply when discovering, without the + // SRV, TXT, address records as recommended (but not mandated) by RFC 6763 12, the service can + // still be discovered. + @Test + fun testDiscoveryWithPtrOnlyResponse_ServiceIsFound() { + // Register service on testNetwork1 + val discoveryRecord = NsdDiscoveryRecord() + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, { it.run() }, discoveryRecord) + + tryTest { + discoveryRecord.expectCallback() + assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR)) + /* + Generated with: + scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120, + rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex() + */ + val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" + + "6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" + + "6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00") + + replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload) + packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload)) + + val serviceFound = discoveryRecord.expectCallback() + serviceFound.serviceInfo.let { + assertEquals(serviceName, it.serviceName) + // Discovered service types have a dot at the end + assertEquals("$serviceType.", it.serviceType) + assertEquals(testNetwork1.network, it.network) + // ServiceFound does not provide port, address or attributes (only information + // available in the PTR record is included in that callback, regardless of whether + // other records exist). + assertEquals(0, it.port) + assertEmpty(it.hostAddresses) + assertEquals(0, it.attributes.size) + } + } cleanup { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback() + } + } + + // Test RFC 6763 12. "Clients MUST be capable of functioning correctly with DNS servers [...] + // that fail to generate these additional records automatically, by issuing subsequent queries + // for any further record(s) they require" + @Test + fun testResolveWhenServerSendsNoAdditionalRecord() { + // Resolve service on testNetwork1 + val resolveRecord = NsdResolveRecord() + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + val si = makeTestServiceInfo(testNetwork1.network) + nsdManager.resolveService(si, { it.run() }, resolveRecord) + + val serviceFullName = "$serviceName.$serviceType.local" + // The query should ask for ANY, since both SRV and TXT are requested. Note legacy + // mdnsresponder will ask for SRV and TXT separately, and will not proceed to asking for + // address records without an answer for both. + val srvTxtQuery = packetReader.pollForQuery(serviceFullName, DnsResolver.TYPE_ANY) + assertNotNull(srvTxtQuery) + + /* + Generated with: + scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local', + rclass=0x8001, port=31234, target='testhost.local', ttl=120) / + scapy.DNSRR(rrname='NsdTest123456789._nmt123456789._tcp.local', type='TXT', ttl=120, + rdata='testkey=testvalue') + ))).hex() + */ + val srvTxtResponsePayload = HexDump.hexStringToByteArray("000084000000000200000000104" + + "e7364546573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f6" + + "3616c0000218001000000780011000000007a020874657374686f7374c030c00c00100001000" + + "00078001211746573746b65793d7465737476616c7565") + replaceServiceNameAndTypeWithTestSuffix(srvTxtResponsePayload) + packetReader.sendResponse(buildMdnsPacket(srvTxtResponsePayload)) + + val testHostname = "testhost.local" + val addressQuery = packetReader.pollForQuery(testHostname, + DnsResolver.TYPE_A, DnsResolver.TYPE_AAAA) + assertNotNull(addressQuery) + + /* + Generated with: + scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRR(rrname='testhost.local', type='A', ttl=120, + rdata='192.0.2.123') / + scapy.DNSRR(rrname='testhost.local', type='AAAA', ttl=120, + rdata='2001:db8::123') + ))).hex() + */ + val addressPayload = HexDump.hexStringToByteArray("0000840000000002000000000874657374" + + "686f7374056c6f63616c0000010001000000780004c000027bc00c001c000100000078001020" + + "010db8000000000000000000000123") + packetReader.sendResponse(buildMdnsPacket(addressPayload)) + + val serviceResolved = resolveRecord.expectCallback() + serviceResolved.serviceInfo.let { + assertEquals(serviceName, it.serviceName) + assertEquals(".$serviceType", it.serviceType) + assertEquals(testNetwork1.network, it.network) + assertEquals(31234, it.port) + assertEquals(1, it.attributes.size) + assertArrayEquals("testvalue".encodeToByteArray(), it.attributes["testkey"]) + } + assertEquals( + setOf(parseNumericAddress("192.0.2.123"), parseNumericAddress("2001:db8::123")), + serviceResolved.serviceInfo.hostAddresses.toSet()) + } + private fun buildConflictingAnnouncement(): ByteBuffer { /* Generated with: @@ -1148,21 +1232,37 @@ class NsdManagerTest { val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" + "3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" + "18001000000780016000000007a0208636f6e666c696374056c6f63616c00") - val packetBuffer = ByteBuffer.wrap(mdnsPayload) - // Replace service name and types in the packet with the random ones used in the test. + replaceServiceNameAndTypeWithTestSuffix(mdnsPayload) + + return buildMdnsPacket(mdnsPayload) + } + + /** + * Replaces occurrences of "NsdTest123456789" and "_nmt123456789" in mDNS payload with the + * actual random name and type that are used by the test. + */ + private fun replaceServiceNameAndTypeWithTestSuffix(mdnsPayload: ByteArray) { // Test service name and types have consistent length and are always ASCII val testPacketName = "NsdTest123456789".encodeToByteArray() val testPacketTypePrefix = "_nmt123456789".encodeToByteArray() val encodedServiceName = serviceName.encodeToByteArray() val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray() - assertEquals(testPacketName.size, encodedServiceName.size) - assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size) - packetBuffer.position(mdnsPayload.indexOf(testPacketName)) - packetBuffer.put(encodedServiceName) - packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix)) - packetBuffer.put(encodedTypePrefix) - return buildMdnsPacket(mdnsPayload) + val packetBuffer = ByteBuffer.wrap(mdnsPayload) + replaceAll(packetBuffer, testPacketName, encodedServiceName) + replaceAll(packetBuffer, testPacketTypePrefix, encodedTypePrefix) + } + + private tailrec fun replaceAll(buffer: ByteBuffer, source: ByteArray, replacement: ByteArray) { + assertEquals(source.size, replacement.size) + val index = buffer.array().indexOf(source) + if (index < 0) return + + val origPosition = buffer.position() + buffer.position(index) + buffer.put(replacement) + buffer.position(origPosition) + replaceAll(buffer, source, replacement) } private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {