Merge "Add tests for conflicts" into main

This commit is contained in:
Remi NGUYEN VAN
2023-08-16 04:01:30 +00:00
committed by Gerrit Code Review

View File

@@ -25,6 +25,7 @@ import android.net.LinkAddress
import android.net.LinkProperties import android.net.LinkProperties
import android.net.LocalSocket import android.net.LocalSocket
import android.net.LocalSocketAddress import android.net.LocalSocketAddress
import android.net.MacAddress
import android.net.Network import android.net.Network
import android.net.NetworkAgentConfig import android.net.NetworkAgentConfig
import android.net.NetworkCapabilities import android.net.NetworkCapabilities
@@ -73,6 +74,8 @@ import android.system.Os
import android.system.OsConstants.AF_INET6 import android.system.OsConstants.AF_INET6
import android.system.OsConstants.EADDRNOTAVAIL import android.system.OsConstants.EADDRNOTAVAIL
import android.system.OsConstants.ENETUNREACH import android.system.OsConstants.ENETUNREACH
import android.system.OsConstants.ETH_P_IPV6
import android.system.OsConstants.IPPROTO_IPV6
import android.system.OsConstants.IPPROTO_UDP import android.system.OsConstants.IPPROTO_UDP
import android.system.OsConstants.SOCK_DGRAM import android.system.OsConstants.SOCK_DGRAM
import android.util.Log import android.util.Log
@@ -82,13 +85,21 @@ import com.android.compatibility.common.util.PollingCheck
import com.android.compatibility.common.util.PropertyUtil import com.android.compatibility.common.util.PropertyUtil
import com.android.modules.utils.build.SdkLevel.isAtLeastU import com.android.modules.utils.build.SdkLevel.isAtLeastU
import com.android.net.module.util.ArrayTrackRecord import com.android.net.module.util.ArrayTrackRecord
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.HexDump
import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
import com.android.net.module.util.PacketBuilder
import com.android.net.module.util.TrackRecord import com.android.net.module.util.TrackRecord
import com.android.testutils.ConnectivityModuleTest import com.android.testutils.ConnectivityModuleTest
import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.IPv6UdpFilter
import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
import com.android.testutils.TapPacketReader
import com.android.testutils.TestableNetworkAgent import com.android.testutils.TestableNetworkAgent
import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
import com.android.testutils.TestableNetworkCallback import com.android.testutils.TestableNetworkCallback
@@ -103,6 +114,7 @@ import java.net.Inet6Address
import java.net.InetAddress import java.net.InetAddress
import java.net.NetworkInterface import java.net.NetworkInterface
import java.net.ServerSocket import java.net.ServerSocket
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.util.Random import java.util.Random
import java.util.concurrent.Executor import java.util.concurrent.Executor
@@ -133,6 +145,8 @@ private const val NO_CALLBACK_TIMEOUT_MS = 200L
private const val REGISTRATION_TIMEOUT_MS = 10_000L private const val REGISTRATION_TIMEOUT_MS = 10_000L
private const val DBG = false private const val DBG = false
private const val TEST_PORT = 12345 private const val TEST_PORT = 12345
private const val MDNS_PORT = 5353.toShort()
private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
@AppModeFull(reason = "Socket cannot bind in instant app mode") @AppModeFull(reason = "Socket cannot bind in instant app mode")
@RunWith(DevSdkIgnoreRunner::class) @RunWith(DevSdkIgnoreRunner::class)
@@ -194,8 +208,8 @@ class NsdManagerTest {
inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V { inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V {
val nextEvent = nextEvents.poll(timeoutMs) val nextEvent = nextEvents.poll(timeoutMs)
assertNotNull(nextEvent, "No callback received after $timeoutMs ms, expected " + assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " +
"${V::class.java.simpleName}") "expected ${V::class.java.simpleName}")
assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " + assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
nextEvent.javaClass.simpleName) nextEvent.javaClass.simpleName)
return nextEvent return nextEvent
@@ -411,7 +425,6 @@ class NsdManagerTest {
val lp = LinkProperties().apply { val lp = LinkProperties().apply {
interfaceName = ifaceName interfaceName = ifaceName
} }
val agent = TestableNetworkAgent(context, handlerThread.looper, val agent = TestableNetworkAgent(context, handlerThread.looper,
NetworkCapabilities().apply { NetworkCapabilities().apply {
removeCapability(NET_CAPABILITY_TRUSTED) removeCapability(NET_CAPABILITY_TRUSTED)
@@ -1144,6 +1157,176 @@ class NsdManagerTest {
} }
} }
@Test
fun testRegisterWithConflictDuringProbing() {
// This test requires shims supporting T+ APIs (NsdServiceInfo.network)
assumeTrue(TestUtils.shouldTestTApis())
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = serviceName
si.network = testNetwork1.network
si.port = 12345 // Test won't try to connect so port does not matter
val packetReader = TapPacketReader(Handler(handlerThread.looper),
testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
registrationRecord)
tryTest {
assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
"Did not find a probe for the service")
packetReader.sendResponse(buildConflictingAnnouncement())
// Registration must use an updated name to avoid the conflict
val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS)
cb.serviceInfo.serviceName.let {
assertTrue("Unexpected registered name: $it",
it.startsWith(serviceName) && it != serviceName)
}
} cleanupStep {
nsdManager.unregisterService(registrationRecord)
registrationRecord.expectCallback<ServiceUnregistered>()
} cleanup {
packetReader.handler.post { packetReader.stop() }
handlerThread.waitForIdle(TIMEOUT_MS)
}
}
@Test
fun testRegisterWithConflictAfterProbing() {
// This test requires shims supporting T+ APIs (NsdServiceInfo.network)
assumeTrue(TestUtils.shouldTestTApis())
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = serviceName
si.network = testNetwork1.network
si.port = 12345 // Test won't try to connect so port does not matter
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
val discoveryRecord = NsdDiscoveryRecord()
val registeredService = registerService(registrationRecord, si)
val packetReader = TapPacketReader(Handler(handlerThread.looper),
testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
tryTest {
assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType),
"No announcements sent after initial probing")
assertEquals(si.serviceName, registeredService.serviceName)
nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
testNetwork1.network, { it.run() }, discoveryRecord)
discoveryRecord.waitForServiceDiscovered(si.serviceName, serviceType)
// Send a conflicting announcement
val conflictingAnnouncement = buildConflictingAnnouncement()
packetReader.sendResponse(conflictingAnnouncement)
// Expect to see probes (RFC6762 9., service is reset to probing state)
assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
"Probe not received within timeout after conflict")
// Send the conflicting packet again to reply to the probe
packetReader.sendResponse(conflictingAnnouncement)
// Note the legacy mdnsresponder would send an exit announcement here (a 0-lifetime
// advertisement just for the PTR record), but not the new advertiser. This probably
// follows RFC 6762 8.4, saying that when a record rdata changed, "In the case of shared
// records, a host MUST send a "goodbye" announcement with RR TTL zero [...] for the old
// rdata, to cause it to be deleted from peer caches, before announcing the new rdata".
//
// This should be implemented by the new advertiser, but in the case of conflicts it is
// not very valuable since an identical PTR record would be used by the conflicting
// service (except for subtypes). In that case the exit announcement may be
// counter-productive as it conflicts with announcements done by the conflicting
// service.
// Note that before sending the following ServiceRegistered callback for the renamed
// service, the legacy mdnsresponder-based implementation would first send a
// Service*Registered* callback for the original service name being *unregistered*; it
// should have been a ServiceUnregistered callback instead (bug in NsdService
// interpretation of the callback).
val newRegistration = registrationRecord.expectCallbackEventually<ServiceRegistered>(
REGISTRATION_TIMEOUT_MS) {
it.serviceInfo.serviceName.startsWith(serviceName) &&
it.serviceInfo.serviceName != serviceName
}
discoveryRecord.expectCallbackEventually<ServiceFound> {
it.serviceInfo.serviceName == newRegistration.serviceInfo.serviceName
}
} cleanupStep {
nsdManager.stopServiceDiscovery(discoveryRecord)
discoveryRecord.expectCallback<DiscoveryStopped>()
} cleanupStep {
nsdManager.unregisterService(registrationRecord)
registrationRecord.expectCallback<ServiceUnregistered>()
} cleanup {
packetReader.handler.post { packetReader.stop() }
handlerThread.waitForIdle(TIMEOUT_MS)
}
}
private fun buildConflictingAnnouncement(): ByteBuffer {
/*
Generated with:
scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local',
rclass=0x8001, port=31234, target='conflict.local', ttl=120)
)).hex()
*/
val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" +
"3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" +
"18001000000780016000000007a0208636f6e666c696374056c6f63616c00")
val packetBuffer = ByteBuffer.wrap(mdnsPayload)
// Replace service name and types in the packet with the random ones used in the test.
// Test service name and types have consistent length and are always ASCII
val testPacketName = "NsdTest123456789".encodeToByteArray()
val testPacketTypePrefix = "_nmt123456789".encodeToByteArray()
val encodedServiceName = serviceName.encodeToByteArray()
val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray()
assertEquals(testPacketName.size, encodedServiceName.size)
assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size)
packetBuffer.position(mdnsPayload.indexOf(testPacketName))
packetBuffer.put(encodedServiceName)
packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix))
packetBuffer.put(encodedTypePrefix)
return buildMdnsPacket(mdnsPayload)
}
private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {
val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6,
IPPROTO_UDP, mdnsPayload.size)
val packetBuilder = PacketBuilder(packetBuffer)
// Multicast ethernet address for IPv6 to ff02::fb
val multicastEthAddr = MacAddress.fromBytes(
byteArrayOf(0x33, 0x33, 0, 0, 0, 0xfb.toByte()))
packetBuilder.writeL2Header(
MacAddress.fromBytes(byteArrayOf(1, 2, 3, 4, 5, 6)) /* srcMac */,
multicastEthAddr,
ETH_P_IPV6.toShort())
packetBuilder.writeIpv6Header(
0x60000000, // version=6, traffic class=0x0, flowlabel=0x0
IPPROTO_UDP.toByte(),
64 /* hop limit */,
parseNumericAddress("2001:db8::123") as Inet6Address /* srcIp */,
multicastIpv6Addr /* dstIp */)
packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */)
packetBuffer.put(mdnsPayload)
return packetBuilder.finalizePacket()
}
/** /**
* Register a service and return its registration record. * Register a service and return its registration record.
*/ */
@@ -1169,7 +1352,65 @@ class NsdManagerTest {
} }
} }
private fun TapPacketReader.pollForMdnsPacket(
timeoutMs: Long = REGISTRATION_TIMEOUT_MS,
predicate: (TestDnsPacket) -> Boolean
): ByteArray? {
val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
val mdnsPayload = it.copyOfRange(
ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size)
try {
predicate(TestDnsPacket(mdnsPayload))
} catch (e: DnsPacket.ParseException) {
false
}
}
return poll(timeoutMs, mdnsProbeFilter)
}
private fun TapPacketReader.pollForProbe(
serviceName: String,
serviceType: String,
timeoutMs: Long = REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
private fun TapPacketReader.pollForAdvertisement(
serviceName: String,
serviceType: String,
timeoutMs: Long = REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
private class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
it.dName == name && it.nsType == 0xff /* ANY */
}
fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
it.dName == name && it.nsType == 0x21 /* SRV */
}
}
private fun ByteArray?.utf8ToString(): String { private fun ByteArray?.utf8ToString(): String {
if (this == null) return "" if (this == null) return ""
return String(this, StandardCharsets.UTF_8) return String(this, StandardCharsets.UTF_8)
} }
private fun ByteArray.indexOf(sub: ByteArray): Int {
var subIndex = 0
forEachIndexed { i, b ->
when (b) {
// Still matching: continue comparing with next byte
sub[subIndex] -> {
subIndex++
if (subIndex == sub.size) {
return i - sub.size + 1
}
}
// Not matching next byte but matches first byte: continue comparing with 2nd byte
sub[0] -> subIndex = 1
// No matches: continue comparing from first byte
else -> subIndex = 0
}
}
return -1
}