Merge "Add test for partial responses" into main

This commit is contained in:
Treehugger Robot
2023-09-20 09:07:15 +00:00
committed by Gerrit Code Review
2 changed files with 175 additions and 68 deletions

View File

@@ -233,46 +233,51 @@ class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
}
}
private fun getMdnsPayload(packet: ByteArray) = packet.copyOfRange(
ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, packet.size)
fun TapPacketReader.pollForMdnsPacket(
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
predicate: (TestDnsPacket) -> Boolean
): ByteArray? {
): TestDnsPacket? {
val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
val mdnsPayload = it.copyOfRange(
ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size
)
val mdnsPayload = getMdnsPayload(it)
try {
predicate(TestDnsPacket(mdnsPayload))
} catch (e: DnsPacket.ParseException) {
false
}
}
return poll(timeoutMs, mdnsProbeFilter)
return poll(timeoutMs, mdnsProbeFilter)?.let { TestDnsPacket(getMdnsPayload(it)) }
}
fun TapPacketReader.pollForProbe(
serviceName: String,
serviceType: String,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
): TestDnsPacket? = pollForMdnsPacket(timeoutMs) {
it.isProbeFor("$serviceName.$serviceType.local")
}
fun TapPacketReader.pollForAdvertisement(
serviceName: String,
serviceType: String,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
): TestDnsPacket? = pollForMdnsPacket(timeoutMs) {
it.isReplyFor("$serviceName.$serviceType.local")
}
fun TapPacketReader.pollForQuery(
recordName: String,
recordType: Int,
vararg requiredTypes: Int,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) }
): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, *requiredTypes) }
fun TapPacketReader.pollForReply(
serviceName: String,
serviceType: String,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) {
): TestDnsPacket? = pollForMdnsPacket(timeoutMs) {
it.isReplyFor("$serviceName.$serviceType.local")
}
@@ -289,7 +294,9 @@ class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
it.dName == name && it.nsType == DnsResolver.TYPE_SRV
}
fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any {
fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean = requiredTypes.all { type ->
mRecords[QDSECTION].any {
it.dName == name && it.nsType == type
}
}
}

View File

@@ -20,6 +20,7 @@ import android.Manifest.permission.NETWORK_SETTINGS
import android.app.compat.CompatChanges
import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.DnsResolver
import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress
import android.net.LinkProperties
@@ -87,6 +88,7 @@ import com.android.testutils.TapPacketReader
import com.android.testutils.TestableNetworkAgent
import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
import com.android.testutils.TestableNetworkCallback
import com.android.testutils.assertEmpty
import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk30
import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk33
import com.android.testutils.runAsShell
@@ -424,11 +426,7 @@ class NsdManagerTest {
@Test
fun testNsdManager_DiscoverOnNetwork() {
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = this.serviceName
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo()
val registrationRecord = NsdRegistrationRecord()
val registeredInfo = registerService(registrationRecord, si)
@@ -455,11 +453,7 @@ class NsdManagerTest {
@Test
fun testNsdManager_DiscoverWithNetworkRequest() {
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = this.serviceName
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo()
val handler = Handler(handlerThread.looper)
val executor = Executor { handler.post(it) }
@@ -524,11 +518,6 @@ class NsdManagerTest {
@Test
fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() {
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = this.serviceName
si.port = 12345 // Test won't try to connect so port does not matter
val handler = Handler(handlerThread.looper)
val executor = Executor { handler.post(it) }
@@ -568,11 +557,7 @@ class NsdManagerTest {
@Test
fun testNsdManager_ResolveOnNetwork() {
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = this.serviceName
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo()
val registrationRecord = NsdRegistrationRecord()
val registeredInfo = registerService(registrationRecord, si)
tryTest {
@@ -610,12 +595,7 @@ class NsdManagerTest {
@Test
fun testNsdManager_RegisterOnNetwork() {
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = this.serviceName
si.network = testNetwork1.network
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo(testNetwork1.network)
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
registerService(registrationRecord, si)
@@ -889,11 +869,7 @@ class NsdManagerTest {
@Test
fun testStopServiceResolution() {
val si = NsdServiceInfo()
si.serviceType = this@NsdManagerTest.serviceType
si.serviceName = this@NsdManagerTest.serviceName
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo()
val resolveRecord = NsdResolveRecord()
// Try to resolve an unknown service then stop it immediately.
// Expected ResolutionStopped callback.
@@ -911,12 +887,7 @@ class NsdManagerTest {
val addresses = lp.addresses
assertFalse(addresses.isEmpty())
val si = NsdServiceInfo().apply {
serviceType = this@NsdManagerTest.serviceType
serviceName = this@NsdManagerTest.serviceName
network = testNetwork1.network
port = 12345 // Test won't try to connect so port does not matter
}
val si = makeTestServiceInfo(testNetwork1.network)
// Register service on the network
val registrationRecord = NsdRegistrationRecord()
@@ -1022,11 +993,7 @@ class NsdManagerTest {
// This test requires shims supporting T+ APIs (NsdServiceInfo.network)
assumeTrue(TestUtils.shouldTestTApis())
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = serviceName
si.network = testNetwork1.network
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo(testNetwork1.network)
val packetReader = TapPacketReader(Handler(handlerThread.looper),
testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
@@ -1063,11 +1030,7 @@ class NsdManagerTest {
// This test requires shims supporting T+ APIs (NsdServiceInfo.network)
assumeTrue(TestUtils.shouldTestTApis())
val si = NsdServiceInfo()
si.serviceType = serviceType
si.serviceName = serviceName
si.network = testNetwork1.network
si.port = 12345 // Test won't try to connect so port does not matter
val si = makeTestServiceInfo(testNetwork1.network)
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
@@ -1137,6 +1100,127 @@ class NsdManagerTest {
}
}
// Test that even if only a PTR record is received as a reply when discovering, without the
// SRV, TXT, address records as recommended (but not mandated) by RFC 6763 12, the service can
// still be discovered.
@Test
fun testDiscoveryWithPtrOnlyResponse_ServiceIsFound() {
// Register service on testNetwork1
val discoveryRecord = NsdDiscoveryRecord()
val packetReader = TapPacketReader(Handler(handlerThread.looper),
testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
testNetwork1.network, { it.run() }, discoveryRecord)
tryTest {
discoveryRecord.expectCallback<DiscoveryStarted>()
assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR))
/*
Generated with:
scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120,
rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex()
*/
val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" +
"6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" +
"6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload)
packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload))
val serviceFound = discoveryRecord.expectCallback<ServiceFound>()
serviceFound.serviceInfo.let {
assertEquals(serviceName, it.serviceName)
// Discovered service types have a dot at the end
assertEquals("$serviceType.", it.serviceType)
assertEquals(testNetwork1.network, it.network)
// ServiceFound does not provide port, address or attributes (only information
// available in the PTR record is included in that callback, regardless of whether
// other records exist).
assertEquals(0, it.port)
assertEmpty(it.hostAddresses)
assertEquals(0, it.attributes.size)
}
} cleanup {
nsdManager.stopServiceDiscovery(discoveryRecord)
discoveryRecord.expectCallback<DiscoveryStopped>()
}
}
// Test RFC 6763 12. "Clients MUST be capable of functioning correctly with DNS servers [...]
// that fail to generate these additional records automatically, by issuing subsequent queries
// for any further record(s) they require"
@Test
fun testResolveWhenServerSendsNoAdditionalRecord() {
// Resolve service on testNetwork1
val resolveRecord = NsdResolveRecord()
val packetReader = TapPacketReader(Handler(handlerThread.looper),
testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
val si = makeTestServiceInfo(testNetwork1.network)
nsdManager.resolveService(si, { it.run() }, resolveRecord)
val serviceFullName = "$serviceName.$serviceType.local"
// The query should ask for ANY, since both SRV and TXT are requested. Note legacy
// mdnsresponder will ask for SRV and TXT separately, and will not proceed to asking for
// address records without an answer for both.
val srvTxtQuery = packetReader.pollForQuery(serviceFullName, DnsResolver.TYPE_ANY)
assertNotNull(srvTxtQuery)
/*
Generated with:
scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local',
rclass=0x8001, port=31234, target='testhost.local', ttl=120) /
scapy.DNSRR(rrname='NsdTest123456789._nmt123456789._tcp.local', type='TXT', ttl=120,
rdata='testkey=testvalue')
))).hex()
*/
val srvTxtResponsePayload = HexDump.hexStringToByteArray("000084000000000200000000104" +
"e7364546573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f6" +
"3616c0000218001000000780011000000007a020874657374686f7374c030c00c00100001000" +
"00078001211746573746b65793d7465737476616c7565")
replaceServiceNameAndTypeWithTestSuffix(srvTxtResponsePayload)
packetReader.sendResponse(buildMdnsPacket(srvTxtResponsePayload))
val testHostname = "testhost.local"
val addressQuery = packetReader.pollForQuery(testHostname,
DnsResolver.TYPE_A, DnsResolver.TYPE_AAAA)
assertNotNull(addressQuery)
/*
Generated with:
scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
scapy.DNSRR(rrname='testhost.local', type='A', ttl=120,
rdata='192.0.2.123') /
scapy.DNSRR(rrname='testhost.local', type='AAAA', ttl=120,
rdata='2001:db8::123')
))).hex()
*/
val addressPayload = HexDump.hexStringToByteArray("0000840000000002000000000874657374" +
"686f7374056c6f63616c0000010001000000780004c000027bc00c001c000100000078001020" +
"010db8000000000000000000000123")
packetReader.sendResponse(buildMdnsPacket(addressPayload))
val serviceResolved = resolveRecord.expectCallback<ServiceResolved>()
serviceResolved.serviceInfo.let {
assertEquals(serviceName, it.serviceName)
assertEquals(".$serviceType", it.serviceType)
assertEquals(testNetwork1.network, it.network)
assertEquals(31234, it.port)
assertEquals(1, it.attributes.size)
assertArrayEquals("testvalue".encodeToByteArray(), it.attributes["testkey"])
}
assertEquals(
setOf(parseNumericAddress("192.0.2.123"), parseNumericAddress("2001:db8::123")),
serviceResolved.serviceInfo.hostAddresses.toSet())
}
private fun buildConflictingAnnouncement(): ByteBuffer {
/*
Generated with:
@@ -1148,21 +1232,37 @@ class NsdManagerTest {
val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" +
"3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" +
"18001000000780016000000007a0208636f6e666c696374056c6f63616c00")
val packetBuffer = ByteBuffer.wrap(mdnsPayload)
// Replace service name and types in the packet with the random ones used in the test.
replaceServiceNameAndTypeWithTestSuffix(mdnsPayload)
return buildMdnsPacket(mdnsPayload)
}
/**
* Replaces occurrences of "NsdTest123456789" and "_nmt123456789" in mDNS payload with the
* actual random name and type that are used by the test.
*/
private fun replaceServiceNameAndTypeWithTestSuffix(mdnsPayload: ByteArray) {
// Test service name and types have consistent length and are always ASCII
val testPacketName = "NsdTest123456789".encodeToByteArray()
val testPacketTypePrefix = "_nmt123456789".encodeToByteArray()
val encodedServiceName = serviceName.encodeToByteArray()
val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray()
assertEquals(testPacketName.size, encodedServiceName.size)
assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size)
packetBuffer.position(mdnsPayload.indexOf(testPacketName))
packetBuffer.put(encodedServiceName)
packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix))
packetBuffer.put(encodedTypePrefix)
return buildMdnsPacket(mdnsPayload)
val packetBuffer = ByteBuffer.wrap(mdnsPayload)
replaceAll(packetBuffer, testPacketName, encodedServiceName)
replaceAll(packetBuffer, testPacketTypePrefix, encodedTypePrefix)
}
private tailrec fun replaceAll(buffer: ByteBuffer, source: ByteArray, replacement: ByteArray) {
assertEquals(source.size, replacement.size)
val index = buffer.array().indexOf(source)
if (index < 0) return
val origPosition = buffer.position()
buffer.position(index)
buffer.put(replacement)
buffer.position(origPosition)
replaceAll(buffer, source, replacement)
}
private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {