diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index e4ee8de031..7be4f78c11 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -25,6 +25,7 @@ import android.net.LinkAddress import android.net.LinkProperties import android.net.LocalSocket import android.net.LocalSocketAddress +import android.net.MacAddress import android.net.Network import android.net.NetworkAgentConfig import android.net.NetworkCapabilities @@ -73,6 +74,8 @@ import android.system.Os import android.system.OsConstants.AF_INET6 import android.system.OsConstants.EADDRNOTAVAIL import android.system.OsConstants.ENETUNREACH +import android.system.OsConstants.ETH_P_IPV6 +import android.system.OsConstants.IPPROTO_IPV6 import android.system.OsConstants.IPPROTO_UDP import android.system.OsConstants.SOCK_DGRAM import android.util.Log @@ -82,13 +85,21 @@ import com.android.compatibility.common.util.PollingCheck import com.android.compatibility.common.util.PropertyUtil import com.android.modules.utils.build.SdkLevel.isAtLeastU import com.android.net.module.util.ArrayTrackRecord +import com.android.net.module.util.DnsPacket +import com.android.net.module.util.HexDump +import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN +import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN +import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN +import com.android.net.module.util.PacketBuilder import com.android.net.module.util.TrackRecord import com.android.testutils.ConnectivityModuleTest import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRunner +import com.android.testutils.IPv6UdpFilter import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged +import com.android.testutils.TapPacketReader import com.android.testutils.TestableNetworkAgent import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated import com.android.testutils.TestableNetworkCallback @@ -103,6 +114,7 @@ import java.net.Inet6Address import java.net.InetAddress import java.net.NetworkInterface import java.net.ServerSocket +import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.util.Random import java.util.concurrent.Executor @@ -133,6 +145,8 @@ private const val NO_CALLBACK_TIMEOUT_MS = 200L private const val REGISTRATION_TIMEOUT_MS = 10_000L private const val DBG = false private const val TEST_PORT = 12345 +private const val MDNS_PORT = 5353.toShort() +private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address @AppModeFull(reason = "Socket cannot bind in instant app mode") @RunWith(DevSdkIgnoreRunner::class) @@ -194,8 +208,8 @@ class NsdManagerTest { inline fun expectCallback(timeoutMs: Long = TIMEOUT_MS): V { val nextEvent = nextEvents.poll(timeoutMs) - assertNotNull(nextEvent, "No callback received after $timeoutMs ms, expected " + - "${V::class.java.simpleName}") + assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " + + "expected ${V::class.java.simpleName}") assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " + nextEvent.javaClass.simpleName) return nextEvent @@ -411,7 +425,6 @@ class NsdManagerTest { val lp = LinkProperties().apply { interfaceName = ifaceName } - val agent = TestableNetworkAgent(context, handlerThread.looper, NetworkCapabilities().apply { removeCapability(NET_CAPABILITY_TRUSTED) @@ -1144,6 +1157,176 @@ class NsdManagerTest { } } + @Test + fun testRegisterWithConflictDuringProbing() { + // This test requires shims supporting T+ APIs (NsdServiceInfo.network) + assumeTrue(TestUtils.shouldTestTApis()) + + val si = NsdServiceInfo() + si.serviceType = serviceType + si.serviceName = serviceName + si.network = testNetwork1.network + si.port = 12345 // Test won't try to connect so port does not matter + + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + // Register service on testNetwork1 + val registrationRecord = NsdRegistrationRecord() + nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() }, + registrationRecord) + + tryTest { + assertNotNull(packetReader.pollForProbe(serviceName, serviceType), + "Did not find a probe for the service") + packetReader.sendResponse(buildConflictingAnnouncement()) + + // Registration must use an updated name to avoid the conflict + val cb = registrationRecord.expectCallback(REGISTRATION_TIMEOUT_MS) + cb.serviceInfo.serviceName.let { + assertTrue("Unexpected registered name: $it", + it.startsWith(serviceName) && it != serviceName) + } + } cleanupStep { + nsdManager.unregisterService(registrationRecord) + registrationRecord.expectCallback() + } cleanup { + packetReader.handler.post { packetReader.stop() } + handlerThread.waitForIdle(TIMEOUT_MS) + } + } + + @Test + fun testRegisterWithConflictAfterProbing() { + // This test requires shims supporting T+ APIs (NsdServiceInfo.network) + assumeTrue(TestUtils.shouldTestTApis()) + + val si = NsdServiceInfo() + si.serviceType = serviceType + si.serviceName = serviceName + si.network = testNetwork1.network + si.port = 12345 // Test won't try to connect so port does not matter + + // Register service on testNetwork1 + val registrationRecord = NsdRegistrationRecord() + val discoveryRecord = NsdDiscoveryRecord() + val registeredService = registerService(registrationRecord, si) + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + tryTest { + assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType), + "No announcements sent after initial probing") + + assertEquals(si.serviceName, registeredService.serviceName) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, { it.run() }, discoveryRecord) + discoveryRecord.waitForServiceDiscovered(si.serviceName, serviceType) + + // Send a conflicting announcement + val conflictingAnnouncement = buildConflictingAnnouncement() + packetReader.sendResponse(conflictingAnnouncement) + + // Expect to see probes (RFC6762 9., service is reset to probing state) + assertNotNull(packetReader.pollForProbe(serviceName, serviceType), + "Probe not received within timeout after conflict") + + // Send the conflicting packet again to reply to the probe + packetReader.sendResponse(conflictingAnnouncement) + + // Note the legacy mdnsresponder would send an exit announcement here (a 0-lifetime + // advertisement just for the PTR record), but not the new advertiser. This probably + // follows RFC 6762 8.4, saying that when a record rdata changed, "In the case of shared + // records, a host MUST send a "goodbye" announcement with RR TTL zero [...] for the old + // rdata, to cause it to be deleted from peer caches, before announcing the new rdata". + // + // This should be implemented by the new advertiser, but in the case of conflicts it is + // not very valuable since an identical PTR record would be used by the conflicting + // service (except for subtypes). In that case the exit announcement may be + // counter-productive as it conflicts with announcements done by the conflicting + // service. + + // Note that before sending the following ServiceRegistered callback for the renamed + // service, the legacy mdnsresponder-based implementation would first send a + // Service*Registered* callback for the original service name being *unregistered*; it + // should have been a ServiceUnregistered callback instead (bug in NsdService + // interpretation of the callback). + val newRegistration = registrationRecord.expectCallbackEventually( + REGISTRATION_TIMEOUT_MS) { + it.serviceInfo.serviceName.startsWith(serviceName) && + it.serviceInfo.serviceName != serviceName + } + + discoveryRecord.expectCallbackEventually { + it.serviceInfo.serviceName == newRegistration.serviceInfo.serviceName + } + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback() + } cleanupStep { + nsdManager.unregisterService(registrationRecord) + registrationRecord.expectCallback() + } cleanup { + packetReader.handler.post { packetReader.stop() } + handlerThread.waitForIdle(TIMEOUT_MS) + } + } + + private fun buildConflictingAnnouncement(): ByteBuffer { + /* + Generated with: + scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = + scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local', + rclass=0x8001, port=31234, target='conflict.local', ttl=120) + )).hex() + */ + val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" + + "3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" + + "18001000000780016000000007a0208636f6e666c696374056c6f63616c00") + val packetBuffer = ByteBuffer.wrap(mdnsPayload) + // Replace service name and types in the packet with the random ones used in the test. + // Test service name and types have consistent length and are always ASCII + val testPacketName = "NsdTest123456789".encodeToByteArray() + val testPacketTypePrefix = "_nmt123456789".encodeToByteArray() + val encodedServiceName = serviceName.encodeToByteArray() + val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray() + assertEquals(testPacketName.size, encodedServiceName.size) + assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size) + packetBuffer.position(mdnsPayload.indexOf(testPacketName)) + packetBuffer.put(encodedServiceName) + packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix)) + packetBuffer.put(encodedTypePrefix) + + return buildMdnsPacket(mdnsPayload) + } + + private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer { + val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6, + IPPROTO_UDP, mdnsPayload.size) + val packetBuilder = PacketBuilder(packetBuffer) + // Multicast ethernet address for IPv6 to ff02::fb + val multicastEthAddr = MacAddress.fromBytes( + byteArrayOf(0x33, 0x33, 0, 0, 0, 0xfb.toByte())) + packetBuilder.writeL2Header( + MacAddress.fromBytes(byteArrayOf(1, 2, 3, 4, 5, 6)) /* srcMac */, + multicastEthAddr, + ETH_P_IPV6.toShort()) + packetBuilder.writeIpv6Header( + 0x60000000, // version=6, traffic class=0x0, flowlabel=0x0 + IPPROTO_UDP.toByte(), + 64 /* hop limit */, + parseNumericAddress("2001:db8::123") as Inet6Address /* srcIp */, + multicastIpv6Addr /* dstIp */) + packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */) + packetBuffer.put(mdnsPayload) + return packetBuilder.finalizePacket() + } + /** * Register a service and return its registration record. */ @@ -1169,7 +1352,65 @@ class NsdManagerTest { } } +private fun TapPacketReader.pollForMdnsPacket( + timeoutMs: Long = REGISTRATION_TIMEOUT_MS, + predicate: (TestDnsPacket) -> Boolean +): ByteArray? { + val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { + val mdnsPayload = it.copyOfRange( + ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size) + try { + predicate(TestDnsPacket(mdnsPayload)) + } catch (e: DnsPacket.ParseException) { + false + } + } + return poll(timeoutMs, mdnsProbeFilter) +} + +private fun TapPacketReader.pollForProbe( + serviceName: String, + serviceType: String, + timeoutMs: Long = REGISTRATION_TIMEOUT_MS +): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") } + +private fun TapPacketReader.pollForAdvertisement( + serviceName: String, + serviceType: String, + timeoutMs: Long = REGISTRATION_TIMEOUT_MS +): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } + +private class TestDnsPacket(data: ByteArray) : DnsPacket(data) { + fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any { + it.dName == name && it.nsType == 0xff /* ANY */ + } + + fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any { + it.dName == name && it.nsType == 0x21 /* SRV */ + } +} + private fun ByteArray?.utf8ToString(): String { if (this == null) return "" return String(this, StandardCharsets.UTF_8) } + +private fun ByteArray.indexOf(sub: ByteArray): Int { + var subIndex = 0 + forEachIndexed { i, b -> + when (b) { + // Still matching: continue comparing with next byte + sub[subIndex] -> { + subIndex++ + if (subIndex == sub.size) { + return i - sub.size + 1 + } + } + // Not matching next byte but matches first byte: continue comparing with 2nd byte + sub[0] -> subIndex = 1 + // No matches: continue comparing from first byte + else -> subIndex = 0 + } + } + return -1 +}