diff --git a/Tethering/tests/integration/Android.bp b/Tethering/tests/integration/Android.bp index 20f0bc6fb8..2594a5e13c 100644 --- a/Tethering/tests/integration/Android.bp +++ b/Tethering/tests/integration/Android.bp @@ -28,7 +28,7 @@ java_defaults { "DhcpPacketLib", "androidx.test.rules", "cts-net-utils", - "mockito-target-extended-minus-junit4", + "mockito-target-minus-junit4", "net-tests-utils", "net-utils-device-common", "net-utils-device-common-bpf", @@ -40,11 +40,6 @@ java_defaults { "android.test.base", "android.test.mock", ], - jni_libs: [ - // For mockito extended - "libdexmakerjvmtiagent", - "libstaticjvmtiagent", - ], } android_library { @@ -54,6 +49,7 @@ android_library { defaults: ["TetheringIntegrationTestsDefaults"], visibility: [ "//packages/modules/Connectivity/Tethering/tests/mts", + "//packages/modules/Connectivity/tests/cts/net", ] } diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java index 83fc3e4ae1..0702aa75da 100644 --- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java +++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java @@ -31,14 +31,12 @@ import static android.net.TetheringTester.isAddressIpv4; import static android.net.TetheringTester.isExpectedIcmpPacket; import static android.net.TetheringTester.isExpectedTcpPacket; import static android.net.TetheringTester.isExpectedUdpPacket; - import static com.android.net.module.util.HexDump.dumpHexString; import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT; import static com.android.net.module.util.NetworkStackConstants.TCPHDR_ACK; import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN; import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork; import static com.android.testutils.TestPermissionUtil.runAsShell; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -164,6 +162,10 @@ public abstract class EthernetTetheringTestBase { private TapPacketReader mDownstreamReader; private MyTetheringEventCallback mTetheringEventCallback; + public Context getContext() { + return mContext; + } + @BeforeClass public static void setUpOnce() throws Exception { // The first test case may experience tethering restart with IP conflict handling. diff --git a/Tethering/tests/integration/base/android/net/TetheringTester.java b/Tethering/tests/integration/base/android/net/TetheringTester.java index 4f3c6e7400..ae4ae55098 100644 --- a/Tethering/tests/integration/base/android/net/TetheringTester.java +++ b/Tethering/tests/integration/base/android/net/TetheringTester.java @@ -27,12 +27,9 @@ import static android.system.OsConstants.IPPROTO_IP; import static android.system.OsConstants.IPPROTO_IPV6; import static android.system.OsConstants.IPPROTO_TCP; import static android.system.OsConstants.IPPROTO_UDP; - import static com.android.net.module.util.DnsPacket.ANSECTION; -import static com.android.net.module.util.DnsPacket.ARSECTION; import static com.android.net.module.util.DnsPacket.DnsHeader; import static com.android.net.module.util.DnsPacket.DnsRecord; -import static com.android.net.module.util.DnsPacket.NSSECTION; import static com.android.net.module.util.DnsPacket.QDSECTION; import static com.android.net.module.util.HexDump.dumpHexString; import static com.android.net.module.util.IpUtils.icmpChecksum; @@ -56,7 +53,6 @@ import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_ALL_NO import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_OVERRIDE; import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED; import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN; - import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; diff --git a/framework/src/android/net/DnsResolver.java b/framework/src/android/net/DnsResolver.java index c6034f1f63..5fefcd6770 100644 --- a/framework/src/android/net/DnsResolver.java +++ b/framework/src/android/net/DnsResolver.java @@ -77,6 +77,15 @@ public final class DnsResolver { @interface QueryType {} public static final int TYPE_A = 1; public static final int TYPE_AAAA = 28; + // TODO: add below constants as part of QueryType and the public API + /** @hide */ + public static final int TYPE_PTR = 12; + /** @hide */ + public static final int TYPE_TXT = 16; + /** @hide */ + public static final int TYPE_SRV = 33; + /** @hide */ + public static final int TYPE_ANY = 255; @IntDef(prefix = { "FLAG_" }, value = { FLAG_EMPTY, diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index c951e9840c..e4d5b812b2 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -1684,7 +1684,10 @@ public class NsdService extends INsdManager.Stub { mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper(), LOGGER.forSubComponent("MdnsSocketProvider"), new SocketRequestMonitor()); // Netlink monitor starts on boot, and intentionally never stopped, to ensure that all - // address events are received. + // address events are received. When the netlink monitor starts, any IP addresses already + // on the interfaces will not be seen. In practice, the network will not connect at boot + // time As a result, all the netlink message should be observed if the netlink monitor + // starts here. handler.post(mMdnsSocketProvider::startNetLinkMonitor); // NsdService is started after ActivityManager (startOtherServices in SystemServer, vs. diff --git a/tests/cts/net/Android.bp b/tests/cts/net/Android.bp index 1276d59e55..6de663af42 100644 --- a/tests/cts/net/Android.bp +++ b/tests/cts/net/Android.bp @@ -56,6 +56,7 @@ java_defaults { "modules-utils-build", "net-utils-framework-common", "truth-prebuilt", + "TetheringIntegrationTestsBaseLib", ], // uncomment when b/13249961 is fixed diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt new file mode 100644 index 0000000000..bc1344237d --- /dev/null +++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt @@ -0,0 +1,295 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.net.cts + +import android.net.DnsResolver +import android.net.Network +import android.net.nsd.NsdManager +import android.net.nsd.NsdServiceInfo +import android.os.Process +import com.android.net.module.util.ArrayTrackRecord +import com.android.net.module.util.DnsPacket +import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN +import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN +import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN +import com.android.net.module.util.TrackRecord +import com.android.testutils.IPv6UdpFilter +import com.android.testutils.TapPacketReader +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail + +private const val MDNS_REGISTRATION_TIMEOUT_MS = 10_000L +private const val MDNS_PORT = 5353.toShort() +const val MDNS_CALLBACK_TIMEOUT = 2000L +const val MDNS_NO_CALLBACK_TIMEOUT_MS = 200L + +interface NsdEvent +open class NsdRecord private constructor( + private val history: ArrayTrackRecord, + private val expectedThreadId: Int? = null +) : TrackRecord by history { + constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId) + + val nextEvents = history.newReadHead() + + override fun add(e: T): Boolean { + if (expectedThreadId != null) { + assertEquals( + expectedThreadId, Process.myTid(), + "Callback is running on the wrong thread" + ) + } + return history.add(e) + } + + inline fun expectCallbackEventually( + timeoutMs: Long = MDNS_CALLBACK_TIMEOUT, + crossinline predicate: (V) -> Boolean = { true } + ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V? + ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms") + + inline fun expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V { + val nextEvent = nextEvents.poll(timeoutMs) + assertNotNull( + nextEvent, "No callback received after $timeoutMs ms, expected " + + "${V::class.java.simpleName}" + ) + assertTrue( + nextEvent is V, "Expected ${V::class.java.simpleName} but got " + + nextEvent.javaClass.simpleName + ) + return nextEvent + } + + inline fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) { + val cb = nextEvents.poll(timeoutMs) + assertNull(cb, "Expected no callback but got $cb") + } +} + +class NsdDiscoveryRecord(expectedThreadId: Int? = null) : + NsdManager.DiscoveryListener, NsdRecord(expectedThreadId) { + sealed class DiscoveryEvent : NsdEvent { + data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) : + DiscoveryEvent() + + data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) : + DiscoveryEvent() + + data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent() + data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent() + data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent() + data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent() + } + + override fun onStartDiscoveryFailed(serviceType: String, err: Int) { + add(DiscoveryEvent.StartDiscoveryFailed(serviceType, err)) + } + + override fun onStopDiscoveryFailed(serviceType: String, err: Int) { + add(DiscoveryEvent.StopDiscoveryFailed(serviceType, err)) + } + + override fun onDiscoveryStarted(serviceType: String) { + add(DiscoveryEvent.DiscoveryStarted(serviceType)) + } + + override fun onDiscoveryStopped(serviceType: String) { + add(DiscoveryEvent.DiscoveryStopped(serviceType)) + } + + override fun onServiceFound(si: NsdServiceInfo) { + add(DiscoveryEvent.ServiceFound(si)) + } + + override fun onServiceLost(si: NsdServiceInfo) { + add(DiscoveryEvent.ServiceLost(si)) + } + + fun waitForServiceDiscovered( + serviceName: String, + serviceType: String, + expectedNetwork: Network? = null + ): NsdServiceInfo { + val serviceFound = expectCallbackEventually { + it.serviceInfo.serviceName == serviceName && + (expectedNetwork == null || + expectedNetwork == it.serviceInfo.network) + }.serviceInfo + // Discovered service types have a dot at the end + assertEquals("$serviceType.", serviceFound.serviceType) + return serviceFound + } +} + +class NsdRegistrationRecord(expectedThreadId: Int? = null) : NsdManager.RegistrationListener, + NsdRecord(expectedThreadId) { + sealed class RegistrationEvent : NsdEvent { + abstract val serviceInfo: NsdServiceInfo + + data class RegistrationFailed( + override val serviceInfo: NsdServiceInfo, + val errorCode: Int + ) : RegistrationEvent() + + data class UnregistrationFailed( + override val serviceInfo: NsdServiceInfo, + val errorCode: Int + ) : RegistrationEvent() + + data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) : + RegistrationEvent() + + data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) : + RegistrationEvent() + } + + override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) { + add(RegistrationEvent.RegistrationFailed(si, err)) + } + + override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) { + add(RegistrationEvent.UnregistrationFailed(si, err)) + } + + override fun onServiceRegistered(si: NsdServiceInfo) { + add(RegistrationEvent.ServiceRegistered(si)) + } + + override fun onServiceUnregistered(si: NsdServiceInfo) { + add(RegistrationEvent.ServiceUnregistered(si)) + } +} + +class NsdResolveRecord : NsdManager.ResolveListener, + NsdRecord() { + sealed class ResolveEvent : NsdEvent { + data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) : + ResolveEvent() + + data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent() + data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent() + data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) : + ResolveEvent() + } + + override fun onResolveFailed(si: NsdServiceInfo, err: Int) { + add(ResolveEvent.ResolveFailed(si, err)) + } + + override fun onServiceResolved(si: NsdServiceInfo) { + add(ResolveEvent.ServiceResolved(si)) + } + + override fun onResolutionStopped(si: NsdServiceInfo) { + add(ResolveEvent.ResolutionStopped(si)) + } + + override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) { + super.onStopResolutionFailed(si, err) + add(ResolveEvent.StopResolutionFailed(si, err)) + } +} + +class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback, + NsdRecord() { + sealed class ServiceInfoCallbackEvent : NsdEvent { + data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent() + data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent() + object ServiceUpdatedLost : ServiceInfoCallbackEvent() + object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent() + } + + override fun onServiceInfoCallbackRegistrationFailed(err: Int) { + add(ServiceInfoCallbackEvent.RegisterCallbackFailed(err)) + } + + override fun onServiceUpdated(si: NsdServiceInfo) { + add(ServiceInfoCallbackEvent.ServiceUpdated(si)) + } + + override fun onServiceLost() { + add(ServiceInfoCallbackEvent.ServiceUpdatedLost) + } + + override fun onServiceInfoCallbackUnregistered() { + add(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded) + } +} + +fun TapPacketReader.pollForMdnsPacket( + timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS, + predicate: (TestDnsPacket) -> Boolean +): ByteArray? { + val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { + val mdnsPayload = it.copyOfRange( + ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size + ) + try { + predicate(TestDnsPacket(mdnsPayload)) + } catch (e: DnsPacket.ParseException) { + false + } + } + return poll(timeoutMs, mdnsProbeFilter) +} + +fun TapPacketReader.pollForProbe( + serviceName: String, + serviceType: String, + timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS +): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") } + +fun TapPacketReader.pollForAdvertisement( + serviceName: String, + serviceType: String, + timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS +): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } + +fun TapPacketReader.pollForQuery( + recordName: String, + recordType: Int, + timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS +): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) } + +fun TapPacketReader.pollForReply( + serviceName: String, + serviceType: String, + timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS +): ByteArray? = pollForMdnsPacket(timeoutMs) { + it.isReplyFor("$serviceName.$serviceType.local") +} + +class TestDnsPacket(data: ByteArray) : DnsPacket(data) { + val header: DnsHeader + get() = mHeader + val records: Array> + get() = mRecords + fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any { + it.dName == name && it.nsType == DnsResolver.TYPE_ANY + } + + fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any { + it.dName == name && it.nsType == DnsResolver.TYPE_SRV + } + + fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any { + it.dName == name && it.nsType == type + } +} diff --git a/tests/cts/net/src/android/net/cts/NsdManagerDownstreamTetheringTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerDownstreamTetheringTest.kt new file mode 100644 index 0000000000..c2bb7cd67e --- /dev/null +++ b/tests/cts/net/src/android/net/cts/NsdManagerDownstreamTetheringTest.kt @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.net.cts + +import android.net.EthernetTetheringTestBase +import android.net.LinkAddress +import android.net.TestNetworkInterface +import android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL +import android.net.TetheringManager.TETHERING_ETHERNET +import android.net.TetheringManager.TetheringRequest +import android.net.nsd.NsdManager +import android.os.Build +import androidx.test.filters.SmallTest +import com.android.testutils.ConnectivityModuleTest +import com.android.testutils.DevSdkIgnoreRule +import com.android.testutils.DevSdkIgnoreRunner +import com.android.testutils.TapPacketReader +import com.android.testutils.tryTest +import java.util.Random +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import org.junit.After +import org.junit.Assume.assumeFalse +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(DevSdkIgnoreRunner::class) +@SmallTest +@ConnectivityModuleTest +@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) +class NsdManagerDownstreamTetheringTest : EthernetTetheringTestBase() { + private val nsdManager by lazy { context.getSystemService(NsdManager::class.java)!! } + private val serviceType = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000)) + + @Before + override fun setUp() { + super.setUp() + setIncludeTestInterfaces(true) + } + + @After + override fun tearDown() { + super.tearDown() + setIncludeTestInterfaces(false) + } + + @Test + fun testMdnsDiscoveryCanSendPacketOnLocalOnlyDownstreamTetheringInterface() { + assumeFalse(isInterfaceForTetheringAvailable) + + var downstreamIface: TestNetworkInterface? = null + var tetheringEventCallback: MyTetheringEventCallback? = null + var downstreamReader: TapPacketReader? = null + + val discoveryRecord = NsdDiscoveryRecord() + + tryTest { + downstreamIface = createTestInterface() + val iface = tetheredInterface + assertEquals(iface, downstreamIface?.interfaceName) + val request = TetheringRequest.Builder(TETHERING_ETHERNET) + .setConnectivityScope(CONNECTIVITY_SCOPE_LOCAL).build() + tetheringEventCallback = enableEthernetTethering( + iface, request, + null /* any upstream */ + ).apply { + awaitInterfaceLocalOnly() + } + // This shouldn't be flaky because the TAP interface will buffer all packets even + // before the reader is started. + downstreamReader = makePacketReader(downstreamIface) + waitForRouterAdvertisement(downstreamReader, iface, WAIT_RA_TIMEOUT_MS) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord) + discoveryRecord.expectCallback() + assertNotNull(downstreamReader?.pollForQuery("$serviceType.local", 12 /* type PTR */)) + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback() + } cleanupStep { + maybeStopTapPacketReader(downstreamReader) + } cleanupStep { + maybeCloseTestInterface(downstreamIface) + } cleanup { + maybeUnregisterTetheringEventCallback(tetheringEventCallback) + } + } + + @Test + fun testMdnsDiscoveryWorkOnTetheringInterface() { + assumeFalse(isInterfaceForTetheringAvailable) + setIncludeTestInterfaces(true) + + var downstreamIface: TestNetworkInterface? = null + var tetheringEventCallback: MyTetheringEventCallback? = null + var downstreamReader: TapPacketReader? = null + + val discoveryRecord = NsdDiscoveryRecord() + + tryTest { + downstreamIface = createTestInterface() + val iface = tetheredInterface + assertEquals(iface, downstreamIface?.interfaceName) + + val localAddr = LinkAddress("192.0.2.3/28") + val clientAddr = LinkAddress("192.0.2.2/28") + val request = TetheringRequest.Builder(TETHERING_ETHERNET) + .setStaticIpv4Addresses(localAddr, clientAddr) + .setShouldShowEntitlementUi(false).build() + tetheringEventCallback = enableEthernetTethering( + iface, request, + null /* any upstream */ + ).apply { + awaitInterfaceTethered() + } + + val fd = downstreamIface?.fileDescriptor?.fileDescriptor + assertNotNull(fd) + downstreamReader = makePacketReader(fd, getMTU(downstreamIface)) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord) + discoveryRecord.expectCallback() + assertNotNull(downstreamReader?.pollForQuery("$serviceType.local", 12 /* type PTR */)) + // TODO: Add another test to check packet reply can trigger serviceFound. + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback() + } cleanupStep { + maybeStopTapPacketReader(downstreamReader) + } cleanupStep { + maybeCloseTestInterface(downstreamIface) + } cleanup { + maybeUnregisterTetheringEventCallback(tetheringEventCallback) + } + } +} diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index 17a135a7a7..27bd5d32b2 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -38,36 +38,26 @@ import android.net.TestNetworkInterface import android.net.TestNetworkManager import android.net.TestNetworkSpecifier import android.net.connectivity.ConnectivityCompatChanges -import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted -import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped -import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound -import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost -import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StartDiscoveryFailed -import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StopDiscoveryFailed -import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.RegistrationFailed -import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered -import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered -import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.UnregistrationFailed -import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolutionStopped -import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveFailed -import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ServiceResolved -import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.StopResolutionFailed -import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.RegisterCallbackFailed -import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated -import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost -import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded +import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted +import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped +import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound +import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost +import android.net.cts.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered +import android.net.cts.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered +import android.net.cts.NsdResolveRecord.ResolveEvent.ResolutionStopped +import android.net.cts.NsdResolveRecord.ResolveEvent.ServiceResolved +import android.net.cts.NsdResolveRecord.ResolveEvent.StopResolutionFailed +import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated +import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost +import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded import android.net.cts.util.CtsNetUtils import android.net.nsd.NsdManager -import android.net.nsd.NsdManager.DiscoveryListener -import android.net.nsd.NsdManager.RegistrationListener -import android.net.nsd.NsdManager.ResolveListener import android.net.nsd.NsdServiceInfo import android.net.nsd.OffloadEngine import android.net.nsd.OffloadServiceInfo import android.os.Build import android.os.Handler import android.os.HandlerThread -import android.os.Process.myTid import android.platform.test.annotations.AppModeFull import android.system.ErrnoException import android.system.Os @@ -84,19 +74,13 @@ import androidx.test.platform.app.InstrumentationRegistry import com.android.compatibility.common.util.PollingCheck import com.android.compatibility.common.util.PropertyUtil import com.android.modules.utils.build.SdkLevel.isAtLeastU -import com.android.net.module.util.ArrayTrackRecord import com.android.net.module.util.DnsPacket import com.android.net.module.util.HexDump -import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN -import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN -import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN import com.android.net.module.util.PacketBuilder -import com.android.net.module.util.TrackRecord import com.android.testutils.ConnectivityModuleTest import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRunner -import com.android.testutils.IPv6UdpFilter import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged import com.android.testutils.TapPacketReader @@ -123,7 +107,6 @@ import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertNull -import kotlin.test.assertTrue import kotlin.test.fail import org.junit.After import org.junit.Assert.assertArrayEquals @@ -137,7 +120,6 @@ import org.junit.runner.RunWith private const val TAG = "NsdManagerTest" private const val TIMEOUT_MS = 2000L -private const val NO_CALLBACK_TIMEOUT_MS = 200L // Registration may take a long time if there are devices with the same hostname on the network, // as the device needs to try another name and probe again. This is especially true since when using // mdnsresponder the usual hostname is "Android", and on conflict "Android-2", "Android-3", ... are @@ -159,7 +141,9 @@ class NsdManagerTest { val ignoreRule = DevSdkIgnoreRule() private val context by lazy { InstrumentationRegistry.getInstrumentation().context } - private val nsdManager by lazy { context.getSystemService(NsdManager::class.java)!! } + private val nsdManager by lazy { + context.getSystemService(NsdManager::class.java) ?: fail("Could not get NsdManager service") + } private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! } private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000)) @@ -185,192 +169,6 @@ class NsdManagerTest { } } - private interface NsdEvent - private open class NsdRecord private constructor( - private val history: ArrayTrackRecord, - private val expectedThreadId: Int? = null - ) : TrackRecord by history { - constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId) - - val nextEvents = history.newReadHead() - - override fun add(e: T): Boolean { - if (expectedThreadId != null) { - assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread") - } - return history.add(e) - } - - inline fun expectCallbackEventually( - timeoutMs: Long = TIMEOUT_MS, - crossinline predicate: (V) -> Boolean = { true } - ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V? - ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms") - - inline fun expectCallback(timeoutMs: Long = TIMEOUT_MS): V { - val nextEvent = nextEvents.poll(timeoutMs) - assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " + - "expected ${V::class.java.simpleName}") - assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " + - nextEvent.javaClass.simpleName) - return nextEvent - } - - inline fun assertNoCallback(timeoutMs: Long = NO_CALLBACK_TIMEOUT_MS) { - val cb = nextEvents.poll(timeoutMs) - assertNull(cb, "Expected no callback but got $cb") - } - } - - private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener, - NsdRecord(expectedThreadId) { - sealed class RegistrationEvent : NsdEvent { - abstract val serviceInfo: NsdServiceInfo - - data class RegistrationFailed( - override val serviceInfo: NsdServiceInfo, - val errorCode: Int - ) : RegistrationEvent() - - data class UnregistrationFailed( - override val serviceInfo: NsdServiceInfo, - val errorCode: Int - ) : RegistrationEvent() - - data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) : - RegistrationEvent() - data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) : - RegistrationEvent() - } - - override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) { - add(RegistrationFailed(si, err)) - } - - override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) { - add(UnregistrationFailed(si, err)) - } - - override fun onServiceRegistered(si: NsdServiceInfo) { - add(ServiceRegistered(si)) - } - - override fun onServiceUnregistered(si: NsdServiceInfo) { - add(ServiceUnregistered(si)) - } - } - - private class NsdDiscoveryRecord(expectedThreadId: Int? = null) : - DiscoveryListener, NsdRecord(expectedThreadId) { - sealed class DiscoveryEvent : NsdEvent { - data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) : - DiscoveryEvent() - - data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) : - DiscoveryEvent() - - data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent() - data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent() - data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent() - data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent() - } - - override fun onStartDiscoveryFailed(serviceType: String, err: Int) { - add(StartDiscoveryFailed(serviceType, err)) - } - - override fun onStopDiscoveryFailed(serviceType: String, err: Int) { - add(StopDiscoveryFailed(serviceType, err)) - } - - override fun onDiscoveryStarted(serviceType: String) { - add(DiscoveryStarted(serviceType)) - } - - override fun onDiscoveryStopped(serviceType: String) { - add(DiscoveryStopped(serviceType)) - } - - override fun onServiceFound(si: NsdServiceInfo) { - add(ServiceFound(si)) - } - - override fun onServiceLost(si: NsdServiceInfo) { - add(ServiceLost(si)) - } - - fun waitForServiceDiscovered( - serviceName: String, - serviceType: String, - expectedNetwork: Network? = null - ): NsdServiceInfo { - val serviceFound = expectCallbackEventually { - it.serviceInfo.serviceName == serviceName && - (expectedNetwork == null || - expectedNetwork == it.serviceInfo.network) - }.serviceInfo - // Discovered service types have a dot at the end - assertEquals("$serviceType.", serviceFound.serviceType) - return serviceFound - } - } - - private class NsdResolveRecord : ResolveListener, - NsdRecord() { - sealed class ResolveEvent : NsdEvent { - data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) : - ResolveEvent() - - data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent() - data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent() - data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) : - ResolveEvent() - } - - override fun onResolveFailed(si: NsdServiceInfo, err: Int) { - add(ResolveFailed(si, err)) - } - - override fun onServiceResolved(si: NsdServiceInfo) { - add(ServiceResolved(si)) - } - - override fun onResolutionStopped(si: NsdServiceInfo) { - add(ResolutionStopped(si)) - } - - override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) { - super.onStopResolutionFailed(si, err) - add(StopResolutionFailed(si, err)) - } - } - - private class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback, - NsdRecord() { - sealed class ServiceInfoCallbackEvent : NsdEvent { - data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent() - data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent() - object ServiceUpdatedLost : ServiceInfoCallbackEvent() - object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent() - } - - override fun onServiceInfoCallbackRegistrationFailed(err: Int) { - add(RegisterCallbackFailed(err)) - } - - override fun onServiceUpdated(si: NsdServiceInfo) { - add(ServiceUpdated(si)) - } - - override fun onServiceLost() { - add(ServiceUpdatedLost) - } - - override fun onServiceInfoCallbackUnregistered() { - add(UnregisterCallbackSucceeded) - } - } - private class TestNsdOffloadEngine : OffloadEngine, NsdRecord() { sealed class OffloadEvent : NsdEvent { @@ -1414,54 +1212,6 @@ class NsdManagerTest { } } -private fun TapPacketReader.pollForMdnsPacket( - timeoutMs: Long = REGISTRATION_TIMEOUT_MS, - predicate: (TestDnsPacket) -> Boolean -): ByteArray? { - val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { - val mdnsPayload = it.copyOfRange( - ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size) - try { - predicate(TestDnsPacket(mdnsPayload)) - } catch (e: DnsPacket.ParseException) { - false - } - } - return poll(timeoutMs, mdnsProbeFilter) -} - -private fun TapPacketReader.pollForProbe( - serviceName: String, - serviceType: String, - timeoutMs: Long = REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") } - -private fun TapPacketReader.pollForAdvertisement( - serviceName: String, - serviceType: String, - timeoutMs: Long = REGISTRATION_TIMEOUT_MS -): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") } - -private class TestDnsPacket(data: ByteArray) : DnsPacket(data) { - val header: DnsHeader - get() = mHeader - val records: Array> - get() = mRecords - - fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any { - it.dName == name && it.nsType == 0xff /* ANY */ - } - - fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any { - it.dName == name && it.nsType == 0x21 /* SRV */ - } -} - -private fun ByteArray?.utf8ToString(): String { - if (this == null) return "" - return String(this, StandardCharsets.UTF_8) -} - private fun ByteArray.indexOf(sub: ByteArray): Int { var subIndex = 0 forEachIndexed { i, b -> @@ -1481,3 +1231,8 @@ private fun ByteArray.indexOf(sub: ByteArray): Int { } return -1 } + +private fun ByteArray?.utf8ToString(): String { + if (this == null) return "" + return String(this, StandardCharsets.UTF_8) +}