Merge "Add test for downstream tethering" into main

This commit is contained in:
Treehugger Robot
2023-09-13 01:01:32 +00:00
committed by Gerrit Code Review
9 changed files with 485 additions and 278 deletions

View File

@@ -28,7 +28,7 @@ java_defaults {
"DhcpPacketLib",
"androidx.test.rules",
"cts-net-utils",
"mockito-target-extended-minus-junit4",
"mockito-target-minus-junit4",
"net-tests-utils",
"net-utils-device-common",
"net-utils-device-common-bpf",
@@ -40,11 +40,6 @@ java_defaults {
"android.test.base",
"android.test.mock",
],
jni_libs: [
// For mockito extended
"libdexmakerjvmtiagent",
"libstaticjvmtiagent",
],
}
android_library {
@@ -54,6 +49,7 @@ android_library {
defaults: ["TetheringIntegrationTestsDefaults"],
visibility: [
"//packages/modules/Connectivity/Tethering/tests/mts",
"//packages/modules/Connectivity/tests/cts/net",
]
}

View File

@@ -31,14 +31,12 @@ import static android.net.TetheringTester.isAddressIpv4;
import static android.net.TetheringTester.isExpectedIcmpPacket;
import static android.net.TetheringTester.isExpectedTcpPacket;
import static android.net.TetheringTester.isExpectedUdpPacket;
import static com.android.net.module.util.HexDump.dumpHexString;
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
import static com.android.net.module.util.NetworkStackConstants.TCPHDR_ACK;
import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN;
import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
import static com.android.testutils.TestPermissionUtil.runAsShell;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
@@ -164,6 +162,10 @@ public abstract class EthernetTetheringTestBase {
private TapPacketReader mDownstreamReader;
private MyTetheringEventCallback mTetheringEventCallback;
public Context getContext() {
return mContext;
}
@BeforeClass
public static void setUpOnce() throws Exception {
// The first test case may experience tethering restart with IP conflict handling.

View File

@@ -27,12 +27,9 @@ import static android.system.OsConstants.IPPROTO_IP;
import static android.system.OsConstants.IPPROTO_IPV6;
import static android.system.OsConstants.IPPROTO_TCP;
import static android.system.OsConstants.IPPROTO_UDP;
import static com.android.net.module.util.DnsPacket.ANSECTION;
import static com.android.net.module.util.DnsPacket.ARSECTION;
import static com.android.net.module.util.DnsPacket.DnsHeader;
import static com.android.net.module.util.DnsPacket.DnsRecord;
import static com.android.net.module.util.DnsPacket.NSSECTION;
import static com.android.net.module.util.DnsPacket.QDSECTION;
import static com.android.net.module.util.HexDump.dumpHexString;
import static com.android.net.module.util.IpUtils.icmpChecksum;
@@ -56,7 +53,6 @@ import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_ALL_NO
import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_OVERRIDE;
import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED;
import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;

View File

@@ -77,6 +77,15 @@ public final class DnsResolver {
@interface QueryType {}
public static final int TYPE_A = 1;
public static final int TYPE_AAAA = 28;
// TODO: add below constants as part of QueryType and the public API
/** @hide */
public static final int TYPE_PTR = 12;
/** @hide */
public static final int TYPE_TXT = 16;
/** @hide */
public static final int TYPE_SRV = 33;
/** @hide */
public static final int TYPE_ANY = 255;
@IntDef(prefix = { "FLAG_" }, value = {
FLAG_EMPTY,

View File

@@ -1684,7 +1684,10 @@ public class NsdService extends INsdManager.Stub {
mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper(),
LOGGER.forSubComponent("MdnsSocketProvider"), new SocketRequestMonitor());
// Netlink monitor starts on boot, and intentionally never stopped, to ensure that all
// address events are received.
// address events are received. When the netlink monitor starts, any IP addresses already
// on the interfaces will not be seen. In practice, the network will not connect at boot
// time As a result, all the netlink message should be observed if the netlink monitor
// starts here.
handler.post(mMdnsSocketProvider::startNetLinkMonitor);
// NsdService is started after ActivityManager (startOtherServices in SystemServer, vs.

View File

@@ -56,6 +56,7 @@ java_defaults {
"modules-utils-build",
"net-utils-framework-common",
"truth-prebuilt",
"TetheringIntegrationTestsBaseLib",
],
// uncomment when b/13249961 is fixed

View File

@@ -0,0 +1,295 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.net.cts
import android.net.DnsResolver
import android.net.Network
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.os.Process
import com.android.net.module.util.ArrayTrackRecord
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
import com.android.net.module.util.TrackRecord
import com.android.testutils.IPv6UdpFilter
import com.android.testutils.TapPacketReader
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
private const val MDNS_REGISTRATION_TIMEOUT_MS = 10_000L
private const val MDNS_PORT = 5353.toShort()
const val MDNS_CALLBACK_TIMEOUT = 2000L
const val MDNS_NO_CALLBACK_TIMEOUT_MS = 200L
interface NsdEvent
open class NsdRecord<T : NsdEvent> private constructor(
private val history: ArrayTrackRecord<T>,
private val expectedThreadId: Int? = null
) : TrackRecord<T> by history {
constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)
val nextEvents = history.newReadHead()
override fun add(e: T): Boolean {
if (expectedThreadId != null) {
assertEquals(
expectedThreadId, Process.myTid(),
"Callback is running on the wrong thread"
)
}
return history.add(e)
}
inline fun <reified V : NsdEvent> expectCallbackEventually(
timeoutMs: Long = MDNS_CALLBACK_TIMEOUT,
crossinline predicate: (V) -> Boolean = { true }
): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")
inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V {
val nextEvent = nextEvents.poll(timeoutMs)
assertNotNull(
nextEvent, "No callback received after $timeoutMs ms, expected " +
"${V::class.java.simpleName}"
)
assertTrue(
nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
nextEvent.javaClass.simpleName
)
return nextEvent
}
inline fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) {
val cb = nextEvents.poll(timeoutMs)
assertNull(cb, "Expected no callback but got $cb")
}
}
class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
NsdManager.DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
sealed class DiscoveryEvent : NsdEvent {
data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
DiscoveryEvent()
data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
DiscoveryEvent()
data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
}
override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
add(DiscoveryEvent.StartDiscoveryFailed(serviceType, err))
}
override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
add(DiscoveryEvent.StopDiscoveryFailed(serviceType, err))
}
override fun onDiscoveryStarted(serviceType: String) {
add(DiscoveryEvent.DiscoveryStarted(serviceType))
}
override fun onDiscoveryStopped(serviceType: String) {
add(DiscoveryEvent.DiscoveryStopped(serviceType))
}
override fun onServiceFound(si: NsdServiceInfo) {
add(DiscoveryEvent.ServiceFound(si))
}
override fun onServiceLost(si: NsdServiceInfo) {
add(DiscoveryEvent.ServiceLost(si))
}
fun waitForServiceDiscovered(
serviceName: String,
serviceType: String,
expectedNetwork: Network? = null
): NsdServiceInfo {
val serviceFound = expectCallbackEventually<DiscoveryEvent.ServiceFound> {
it.serviceInfo.serviceName == serviceName &&
(expectedNetwork == null ||
expectedNetwork == it.serviceInfo.network)
}.serviceInfo
// Discovered service types have a dot at the end
assertEquals("$serviceType.", serviceFound.serviceType)
return serviceFound
}
}
class NsdRegistrationRecord(expectedThreadId: Int? = null) : NsdManager.RegistrationListener,
NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
sealed class RegistrationEvent : NsdEvent {
abstract val serviceInfo: NsdServiceInfo
data class RegistrationFailed(
override val serviceInfo: NsdServiceInfo,
val errorCode: Int
) : RegistrationEvent()
data class UnregistrationFailed(
override val serviceInfo: NsdServiceInfo,
val errorCode: Int
) : RegistrationEvent()
data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) :
RegistrationEvent()
data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) :
RegistrationEvent()
}
override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
add(RegistrationEvent.RegistrationFailed(si, err))
}
override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
add(RegistrationEvent.UnregistrationFailed(si, err))
}
override fun onServiceRegistered(si: NsdServiceInfo) {
add(RegistrationEvent.ServiceRegistered(si))
}
override fun onServiceUnregistered(si: NsdServiceInfo) {
add(RegistrationEvent.ServiceUnregistered(si))
}
}
class NsdResolveRecord : NsdManager.ResolveListener,
NsdRecord<NsdResolveRecord.ResolveEvent>() {
sealed class ResolveEvent : NsdEvent {
data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
ResolveEvent()
data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
ResolveEvent()
}
override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
add(ResolveEvent.ResolveFailed(si, err))
}
override fun onServiceResolved(si: NsdServiceInfo) {
add(ResolveEvent.ServiceResolved(si))
}
override fun onResolutionStopped(si: NsdServiceInfo) {
add(ResolveEvent.ResolutionStopped(si))
}
override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
super.onStopResolutionFailed(si, err)
add(ResolveEvent.StopResolutionFailed(si, err))
}
}
class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {
sealed class ServiceInfoCallbackEvent : NsdEvent {
data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()
object ServiceUpdatedLost : ServiceInfoCallbackEvent()
object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
}
override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
add(ServiceInfoCallbackEvent.RegisterCallbackFailed(err))
}
override fun onServiceUpdated(si: NsdServiceInfo) {
add(ServiceInfoCallbackEvent.ServiceUpdated(si))
}
override fun onServiceLost() {
add(ServiceInfoCallbackEvent.ServiceUpdatedLost)
}
override fun onServiceInfoCallbackUnregistered() {
add(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded)
}
}
fun TapPacketReader.pollForMdnsPacket(
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
predicate: (TestDnsPacket) -> Boolean
): ByteArray? {
val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
val mdnsPayload = it.copyOfRange(
ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size
)
try {
predicate(TestDnsPacket(mdnsPayload))
} catch (e: DnsPacket.ParseException) {
false
}
}
return poll(timeoutMs, mdnsProbeFilter)
}
fun TapPacketReader.pollForProbe(
serviceName: String,
serviceType: String,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
fun TapPacketReader.pollForAdvertisement(
serviceName: String,
serviceType: String,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
fun TapPacketReader.pollForQuery(
recordName: String,
recordType: Int,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) }
fun TapPacketReader.pollForReply(
serviceName: String,
serviceType: String,
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) {
it.isReplyFor("$serviceName.$serviceType.local")
}
class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
val header: DnsHeader
get() = mHeader
val records: Array<List<DnsRecord>>
get() = mRecords
fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
it.dName == name && it.nsType == DnsResolver.TYPE_ANY
}
fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
it.dName == name && it.nsType == DnsResolver.TYPE_SRV
}
fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any {
it.dName == name && it.nsType == type
}
}

View File

@@ -0,0 +1,150 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.net.cts
import android.net.EthernetTetheringTestBase
import android.net.LinkAddress
import android.net.TestNetworkInterface
import android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL
import android.net.TetheringManager.TETHERING_ETHERNET
import android.net.TetheringManager.TetheringRequest
import android.net.nsd.NsdManager
import android.os.Build
import androidx.test.filters.SmallTest
import com.android.testutils.ConnectivityModuleTest
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.TapPacketReader
import com.android.testutils.tryTest
import java.util.Random
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import org.junit.After
import org.junit.Assume.assumeFalse
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@ConnectivityModuleTest
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class NsdManagerDownstreamTetheringTest : EthernetTetheringTestBase() {
private val nsdManager by lazy { context.getSystemService(NsdManager::class.java)!! }
private val serviceType = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
@Before
override fun setUp() {
super.setUp()
setIncludeTestInterfaces(true)
}
@After
override fun tearDown() {
super.tearDown()
setIncludeTestInterfaces(false)
}
@Test
fun testMdnsDiscoveryCanSendPacketOnLocalOnlyDownstreamTetheringInterface() {
assumeFalse(isInterfaceForTetheringAvailable)
var downstreamIface: TestNetworkInterface? = null
var tetheringEventCallback: MyTetheringEventCallback? = null
var downstreamReader: TapPacketReader? = null
val discoveryRecord = NsdDiscoveryRecord()
tryTest {
downstreamIface = createTestInterface()
val iface = tetheredInterface
assertEquals(iface, downstreamIface?.interfaceName)
val request = TetheringRequest.Builder(TETHERING_ETHERNET)
.setConnectivityScope(CONNECTIVITY_SCOPE_LOCAL).build()
tetheringEventCallback = enableEthernetTethering(
iface, request,
null /* any upstream */
).apply {
awaitInterfaceLocalOnly()
}
// This shouldn't be flaky because the TAP interface will buffer all packets even
// before the reader is started.
downstreamReader = makePacketReader(downstreamIface)
waitForRouterAdvertisement(downstreamReader, iface, WAIT_RA_TIMEOUT_MS)
nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted>()
assertNotNull(downstreamReader?.pollForQuery("$serviceType.local", 12 /* type PTR */))
} cleanupStep {
nsdManager.stopServiceDiscovery(discoveryRecord)
discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped>()
} cleanupStep {
maybeStopTapPacketReader(downstreamReader)
} cleanupStep {
maybeCloseTestInterface(downstreamIface)
} cleanup {
maybeUnregisterTetheringEventCallback(tetheringEventCallback)
}
}
@Test
fun testMdnsDiscoveryWorkOnTetheringInterface() {
assumeFalse(isInterfaceForTetheringAvailable)
setIncludeTestInterfaces(true)
var downstreamIface: TestNetworkInterface? = null
var tetheringEventCallback: MyTetheringEventCallback? = null
var downstreamReader: TapPacketReader? = null
val discoveryRecord = NsdDiscoveryRecord()
tryTest {
downstreamIface = createTestInterface()
val iface = tetheredInterface
assertEquals(iface, downstreamIface?.interfaceName)
val localAddr = LinkAddress("192.0.2.3/28")
val clientAddr = LinkAddress("192.0.2.2/28")
val request = TetheringRequest.Builder(TETHERING_ETHERNET)
.setStaticIpv4Addresses(localAddr, clientAddr)
.setShouldShowEntitlementUi(false).build()
tetheringEventCallback = enableEthernetTethering(
iface, request,
null /* any upstream */
).apply {
awaitInterfaceTethered()
}
val fd = downstreamIface?.fileDescriptor?.fileDescriptor
assertNotNull(fd)
downstreamReader = makePacketReader(fd, getMTU(downstreamIface))
nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted>()
assertNotNull(downstreamReader?.pollForQuery("$serviceType.local", 12 /* type PTR */))
// TODO: Add another test to check packet reply can trigger serviceFound.
} cleanupStep {
nsdManager.stopServiceDiscovery(discoveryRecord)
discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped>()
} cleanupStep {
maybeStopTapPacketReader(downstreamReader)
} cleanupStep {
maybeCloseTestInterface(downstreamIface)
} cleanup {
maybeUnregisterTetheringEventCallback(tetheringEventCallback)
}
}
}

View File

@@ -38,36 +38,26 @@ import android.net.TestNetworkInterface
import android.net.TestNetworkManager
import android.net.TestNetworkSpecifier
import android.net.connectivity.ConnectivityCompatChanges
import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted
import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped
import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound
import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost
import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StartDiscoveryFailed
import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StopDiscoveryFailed
import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.RegistrationFailed
import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered
import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.UnregistrationFailed
import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolutionStopped
import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveFailed
import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ServiceResolved
import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.StopResolutionFailed
import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.RegisterCallbackFailed
import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated
import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost
import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded
import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted
import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped
import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound
import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost
import android.net.cts.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered
import android.net.cts.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
import android.net.cts.NsdResolveRecord.ResolveEvent.ResolutionStopped
import android.net.cts.NsdResolveRecord.ResolveEvent.ServiceResolved
import android.net.cts.NsdResolveRecord.ResolveEvent.StopResolutionFailed
import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated
import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost
import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded
import android.net.cts.util.CtsNetUtils
import android.net.nsd.NsdManager
import android.net.nsd.NsdManager.DiscoveryListener
import android.net.nsd.NsdManager.RegistrationListener
import android.net.nsd.NsdManager.ResolveListener
import android.net.nsd.NsdServiceInfo
import android.net.nsd.OffloadEngine
import android.net.nsd.OffloadServiceInfo
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import android.os.Process.myTid
import android.platform.test.annotations.AppModeFull
import android.system.ErrnoException
import android.system.Os
@@ -84,19 +74,13 @@ import androidx.test.platform.app.InstrumentationRegistry
import com.android.compatibility.common.util.PollingCheck
import com.android.compatibility.common.util.PropertyUtil
import com.android.modules.utils.build.SdkLevel.isAtLeastU
import com.android.net.module.util.ArrayTrackRecord
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.HexDump
import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
import com.android.net.module.util.PacketBuilder
import com.android.net.module.util.TrackRecord
import com.android.testutils.ConnectivityModuleTest
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.IPv6UdpFilter
import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
import com.android.testutils.TapPacketReader
@@ -123,7 +107,6 @@ import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
import org.junit.After
import org.junit.Assert.assertArrayEquals
@@ -137,7 +120,6 @@ import org.junit.runner.RunWith
private const val TAG = "NsdManagerTest"
private const val TIMEOUT_MS = 2000L
private const val NO_CALLBACK_TIMEOUT_MS = 200L
// Registration may take a long time if there are devices with the same hostname on the network,
// as the device needs to try another name and probe again. This is especially true since when using
// mdnsresponder the usual hostname is "Android", and on conflict "Android-2", "Android-3", ... are
@@ -159,7 +141,9 @@ class NsdManagerTest {
val ignoreRule = DevSdkIgnoreRule()
private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
private val nsdManager by lazy { context.getSystemService(NsdManager::class.java)!! }
private val nsdManager by lazy {
context.getSystemService(NsdManager::class.java) ?: fail("Could not get NsdManager service")
}
private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
@@ -185,192 +169,6 @@ class NsdManagerTest {
}
}
private interface NsdEvent
private open class NsdRecord<T : NsdEvent> private constructor(
private val history: ArrayTrackRecord<T>,
private val expectedThreadId: Int? = null
) : TrackRecord<T> by history {
constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)
val nextEvents = history.newReadHead()
override fun add(e: T): Boolean {
if (expectedThreadId != null) {
assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread")
}
return history.add(e)
}
inline fun <reified V : NsdEvent> expectCallbackEventually(
timeoutMs: Long = TIMEOUT_MS,
crossinline predicate: (V) -> Boolean = { true }
): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")
inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V {
val nextEvent = nextEvents.poll(timeoutMs)
assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " +
"expected ${V::class.java.simpleName}")
assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
nextEvent.javaClass.simpleName)
return nextEvent
}
inline fun assertNoCallback(timeoutMs: Long = NO_CALLBACK_TIMEOUT_MS) {
val cb = nextEvents.poll(timeoutMs)
assertNull(cb, "Expected no callback but got $cb")
}
}
private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener,
NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
sealed class RegistrationEvent : NsdEvent {
abstract val serviceInfo: NsdServiceInfo
data class RegistrationFailed(
override val serviceInfo: NsdServiceInfo,
val errorCode: Int
) : RegistrationEvent()
data class UnregistrationFailed(
override val serviceInfo: NsdServiceInfo,
val errorCode: Int
) : RegistrationEvent()
data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) :
RegistrationEvent()
data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) :
RegistrationEvent()
}
override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
add(RegistrationFailed(si, err))
}
override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
add(UnregistrationFailed(si, err))
}
override fun onServiceRegistered(si: NsdServiceInfo) {
add(ServiceRegistered(si))
}
override fun onServiceUnregistered(si: NsdServiceInfo) {
add(ServiceUnregistered(si))
}
}
private class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
sealed class DiscoveryEvent : NsdEvent {
data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
DiscoveryEvent()
data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
DiscoveryEvent()
data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
}
override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
add(StartDiscoveryFailed(serviceType, err))
}
override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
add(StopDiscoveryFailed(serviceType, err))
}
override fun onDiscoveryStarted(serviceType: String) {
add(DiscoveryStarted(serviceType))
}
override fun onDiscoveryStopped(serviceType: String) {
add(DiscoveryStopped(serviceType))
}
override fun onServiceFound(si: NsdServiceInfo) {
add(ServiceFound(si))
}
override fun onServiceLost(si: NsdServiceInfo) {
add(ServiceLost(si))
}
fun waitForServiceDiscovered(
serviceName: String,
serviceType: String,
expectedNetwork: Network? = null
): NsdServiceInfo {
val serviceFound = expectCallbackEventually<ServiceFound> {
it.serviceInfo.serviceName == serviceName &&
(expectedNetwork == null ||
expectedNetwork == it.serviceInfo.network)
}.serviceInfo
// Discovered service types have a dot at the end
assertEquals("$serviceType.", serviceFound.serviceType)
return serviceFound
}
}
private class NsdResolveRecord : ResolveListener,
NsdRecord<NsdResolveRecord.ResolveEvent>() {
sealed class ResolveEvent : NsdEvent {
data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
ResolveEvent()
data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
ResolveEvent()
}
override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
add(ResolveFailed(si, err))
}
override fun onServiceResolved(si: NsdServiceInfo) {
add(ServiceResolved(si))
}
override fun onResolutionStopped(si: NsdServiceInfo) {
add(ResolutionStopped(si))
}
override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
super.onStopResolutionFailed(si, err)
add(StopResolutionFailed(si, err))
}
}
private class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {
sealed class ServiceInfoCallbackEvent : NsdEvent {
data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()
object ServiceUpdatedLost : ServiceInfoCallbackEvent()
object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
}
override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
add(RegisterCallbackFailed(err))
}
override fun onServiceUpdated(si: NsdServiceInfo) {
add(ServiceUpdated(si))
}
override fun onServiceLost() {
add(ServiceUpdatedLost)
}
override fun onServiceInfoCallbackUnregistered() {
add(UnregisterCallbackSucceeded)
}
}
private class TestNsdOffloadEngine : OffloadEngine,
NsdRecord<TestNsdOffloadEngine.OffloadEvent>() {
sealed class OffloadEvent : NsdEvent {
@@ -1414,54 +1212,6 @@ class NsdManagerTest {
}
}
private fun TapPacketReader.pollForMdnsPacket(
timeoutMs: Long = REGISTRATION_TIMEOUT_MS,
predicate: (TestDnsPacket) -> Boolean
): ByteArray? {
val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
val mdnsPayload = it.copyOfRange(
ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size)
try {
predicate(TestDnsPacket(mdnsPayload))
} catch (e: DnsPacket.ParseException) {
false
}
}
return poll(timeoutMs, mdnsProbeFilter)
}
private fun TapPacketReader.pollForProbe(
serviceName: String,
serviceType: String,
timeoutMs: Long = REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
private fun TapPacketReader.pollForAdvertisement(
serviceName: String,
serviceType: String,
timeoutMs: Long = REGISTRATION_TIMEOUT_MS
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
private class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
val header: DnsHeader
get() = mHeader
val records: Array<List<DnsRecord>>
get() = mRecords
fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
it.dName == name && it.nsType == 0xff /* ANY */
}
fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
it.dName == name && it.nsType == 0x21 /* SRV */
}
}
private fun ByteArray?.utf8ToString(): String {
if (this == null) return ""
return String(this, StandardCharsets.UTF_8)
}
private fun ByteArray.indexOf(sub: ByteArray): Int {
var subIndex = 0
forEachIndexed { i, b ->
@@ -1481,3 +1231,8 @@ private fun ByteArray.indexOf(sub: ByteArray): Int {
}
return -1
}
private fun ByteArray?.utf8ToString(): String {
if (this == null) return ""
return String(this, StandardCharsets.UTF_8)
}