From 9133888e0c25208f7907cf4353bd4eaedc8ef9cd Mon Sep 17 00:00:00 2001 From: Remi NGUYEN VAN Date: Tue, 31 May 2022 18:39:18 +0900 Subject: [PATCH] Ensure callbacks are run properly on executor NsdManager callbacks were run on a provided executor by capturing the handler message in a lambda, but the message will be recycled immediately after handleMessage returns. This means that any non-inline executor would see bogus callbacks, as they have an empty Message. Fix it by not capturing the Message in the lambda, but capturing its contents instead. This was broken when updating the class to support executors in change ID: I4c31e2d7ae601ea808b1fd64df32d116c6fff97f; before that, callbacks were all run on the NsdManager handler. Also, DelegatingDiscoveryListener is being run on the NsdManager handler thread for notifyAllServicesLost, causing onServiceLost to be run there, but other methods are run on the provided Executor, even though they access maps maintained on the handler thread, like mPerNetworkListeners. Revert DelegatingDiscoveryListener to run on the handler thread as before, and only use the provided executor to execute any app-facing callback instead. Bug: 234419509 Test: atest NsdManagerTest Change-Id: Icca64511b02dad2f725a2849d2a1e871135b3286 --- .../src/android/net/nsd/NsdManager.java | 61 +++++++++------ .../net/src/android/net/cts/NsdManagerTest.kt | 78 +++++++++++++++---- 2 files changed, 103 insertions(+), 36 deletions(-) diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java index 33b44c8210..f19bf4a6fb 100644 --- a/framework-t/src/android/net/nsd/NsdManager.java +++ b/framework-t/src/android/net/nsd/NsdManager.java @@ -312,9 +312,12 @@ public final class NsdManager { @Override public void onAvailable(@NonNull Network network) { final DelegatingDiscoveryListener wrappedListener = new DelegatingDiscoveryListener( - network, mBaseListener); + network, mBaseListener, mBaseExecutor); mPerNetworkListeners.put(network, wrappedListener); - discoverServices(mServiceType, mProtocolType, network, mBaseExecutor, + // Run discovery callbacks inline on the service handler thread, which is the + // same thread used by this NetworkCallback, but DelegatingDiscoveryListener will + // use the base executor to run the wrapped callbacks. + discoverServices(mServiceType, mProtocolType, network, Runnable::run, wrappedListener); } @@ -334,7 +337,8 @@ public final class NsdManager { public void start(@NonNull NetworkRequest request) { final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class); cm.registerNetworkCallback(request, mNetworkCb, mHandler); - mHandler.post(() -> mBaseListener.onDiscoveryStarted(mServiceType)); + mHandler.post(() -> mBaseExecutor.execute(() -> + mBaseListener.onDiscoveryStarted(mServiceType))); } /** @@ -351,7 +355,7 @@ public final class NsdManager { final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class); cm.unregisterNetworkCallback(mNetworkCb); if (mPerNetworkListeners.size() == 0) { - mBaseListener.onDiscoveryStopped(mServiceType); + mBaseExecutor.execute(() -> mBaseListener.onDiscoveryStopped(mServiceType)); return; } for (int i = 0; i < mPerNetworkListeners.size(); i++) { @@ -399,14 +403,23 @@ public final class NsdManager { } } + /** + * A listener wrapping calls to an app-provided listener, while keeping track of found + * services, so they can all be reported lost when the underlying network is lost. + * + * This should be registered to run on the service handler. + */ private class DelegatingDiscoveryListener implements DiscoveryListener { private final Network mNetwork; private final DiscoveryListener mWrapped; + private final Executor mWrappedExecutor; private final ArraySet mFoundInfo = new ArraySet<>(); - private DelegatingDiscoveryListener(Network network, DiscoveryListener listener) { + private DelegatingDiscoveryListener(Network network, DiscoveryListener listener, + Executor executor) { mNetwork = network; mWrapped = listener; + mWrappedExecutor = executor; } void notifyAllServicesLost() { @@ -415,7 +428,7 @@ public final class NsdManager { final NsdServiceInfo serviceInfo = new NsdServiceInfo( trackedInfo.mServiceName, trackedInfo.mServiceType); serviceInfo.setNetwork(mNetwork); - mWrapped.onServiceLost(serviceInfo); + mWrappedExecutor.execute(() -> mWrapped.onServiceLost(serviceInfo)); } } @@ -444,7 +457,7 @@ public final class NsdManager { // Do not report onStopDiscoveryFailed when some underlying listeners failed: // this does not mean that all listeners did, and onStopDiscoveryFailed is not // actionable anyway. Just report that discovery stopped. - mWrapped.onDiscoveryStopped(serviceType); + mWrappedExecutor.execute(() -> mWrapped.onDiscoveryStopped(serviceType)); } } @@ -452,20 +465,20 @@ public final class NsdManager { public void onDiscoveryStopped(String serviceType) { mPerNetworkListeners.remove(mNetwork); if (mStopRequested && mPerNetworkListeners.size() == 0) { - mWrapped.onDiscoveryStopped(serviceType); + mWrappedExecutor.execute(() -> mWrapped.onDiscoveryStopped(serviceType)); } } @Override public void onServiceFound(NsdServiceInfo serviceInfo) { mFoundInfo.add(new TrackedNsdInfo(serviceInfo)); - mWrapped.onServiceFound(serviceInfo); + mWrappedExecutor.execute(() -> mWrapped.onServiceFound(serviceInfo)); } @Override public void onServiceLost(NsdServiceInfo serviceInfo) { mFoundInfo.remove(new TrackedNsdInfo(serviceInfo)); - mWrapped.onServiceLost(serviceInfo); + mWrappedExecutor.execute(() -> mWrapped.onServiceLost(serviceInfo)); } } } @@ -648,8 +661,12 @@ public final class NsdManager { @Override public void handleMessage(Message message) { + // Do not use message in the executor lambdas, as it will be recycled once this method + // returns. Keep references to its content instead. final int what = message.what; + final int errorCode = message.arg1; final int key = message.arg2; + final Object obj = message.obj; final Object listener; final NsdServiceInfo ns; final Executor executor; @@ -659,7 +676,7 @@ public final class NsdManager { executor = mExecutorMap.get(key); } if (listener == null) { - Log.d(TAG, "Stale key " + message.arg2); + Log.d(TAG, "Stale key " + key); return; } if (DBG) { @@ -667,28 +684,28 @@ public final class NsdManager { } switch (what) { case DISCOVER_SERVICES_STARTED: - final String s = getNsdServiceInfoType((NsdServiceInfo) message.obj); + final String s = getNsdServiceInfoType((NsdServiceInfo) obj); executor.execute(() -> ((DiscoveryListener) listener).onDiscoveryStarted(s)); break; case DISCOVER_SERVICES_FAILED: removeListener(key); executor.execute(() -> ((DiscoveryListener) listener).onStartDiscoveryFailed( - getNsdServiceInfoType(ns), message.arg1)); + getNsdServiceInfoType(ns), errorCode)); break; case SERVICE_FOUND: executor.execute(() -> ((DiscoveryListener) listener).onServiceFound( - (NsdServiceInfo) message.obj)); + (NsdServiceInfo) obj)); break; case SERVICE_LOST: executor.execute(() -> ((DiscoveryListener) listener).onServiceLost( - (NsdServiceInfo) message.obj)); + (NsdServiceInfo) obj)); break; case STOP_DISCOVERY_FAILED: // TODO: failure to stop discovery should be internal and retried internally, as // the effect for the client is indistinguishable from STOP_DISCOVERY_SUCCEEDED removeListener(key); executor.execute(() -> ((DiscoveryListener) listener).onStopDiscoveryFailed( - getNsdServiceInfoType(ns), message.arg1)); + getNsdServiceInfoType(ns), errorCode)); break; case STOP_DISCOVERY_SUCCEEDED: removeListener(key); @@ -698,33 +715,33 @@ public final class NsdManager { case REGISTER_SERVICE_FAILED: removeListener(key); executor.execute(() -> ((RegistrationListener) listener).onRegistrationFailed( - ns, message.arg1)); + ns, errorCode)); break; case REGISTER_SERVICE_SUCCEEDED: executor.execute(() -> ((RegistrationListener) listener).onServiceRegistered( - (NsdServiceInfo) message.obj)); + (NsdServiceInfo) obj)); break; case UNREGISTER_SERVICE_FAILED: removeListener(key); executor.execute(() -> ((RegistrationListener) listener).onUnregistrationFailed( - ns, message.arg1)); + ns, errorCode)); break; case UNREGISTER_SERVICE_SUCCEEDED: // TODO: do not unregister listener until service is unregistered, or provide // alternative way for unregistering ? - removeListener(message.arg2); + removeListener(key); executor.execute(() -> ((RegistrationListener) listener).onServiceUnregistered( ns)); break; case RESOLVE_SERVICE_FAILED: removeListener(key); executor.execute(() -> ((ResolveListener) listener).onResolveFailed( - ns, message.arg1)); + ns, errorCode)); break; case RESOLVE_SERVICE_SUCCEEDED: removeListener(key); executor.execute(() -> ((ResolveListener) listener).onServiceResolved( - (NsdServiceInfo) message.obj)); + (NsdServiceInfo) obj)); break; default: Log.d(TAG, "Ignored " + message); diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index 33a0a83b5a..64cc97d265 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -22,6 +22,7 @@ import android.net.LinkProperties import android.net.Network import android.net.NetworkAgentConfig import android.net.NetworkCapabilities +import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED import android.net.NetworkCapabilities.TRANSPORT_TEST import android.net.NetworkRequest @@ -45,7 +46,9 @@ import android.net.nsd.NsdManager.DiscoveryListener import android.net.nsd.NsdManager.RegistrationListener import android.net.nsd.NsdManager.ResolveListener import android.net.nsd.NsdServiceInfo +import android.os.Handler import android.os.HandlerThread +import android.os.Process.myTid import android.platform.test.annotations.AppModeFull import android.util.Log import androidx.test.platform.app.InstrumentationRegistry @@ -111,12 +114,20 @@ class NsdManagerTest { private interface NsdEvent private open class NsdRecord private constructor( - private val history: ArrayTrackRecord + private val history: ArrayTrackRecord, + private val expectedThreadId: Int? = null ) : TrackRecord by history { - constructor() : this(ArrayTrackRecord()) + constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId) val nextEvents = history.newReadHead() + override fun add(e: T): Boolean { + if (expectedThreadId != null) { + assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread") + } + return history.add(e) + } + inline fun expectCallbackEventually( crossinline predicate: (V) -> Boolean = { true } ): V = nextEvents.poll(TIMEOUT_MS) { e -> e is V && predicate(e) } as V? @@ -136,8 +147,8 @@ class NsdManagerTest { } } - private class NsdRegistrationRecord : RegistrationListener, - NsdRecord() { + private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener, + NsdRecord(expectedThreadId) { sealed class RegistrationEvent : NsdEvent { abstract val serviceInfo: NsdServiceInfo @@ -174,8 +185,8 @@ class NsdManagerTest { } } - private class NsdDiscoveryRecord : DiscoveryListener, - NsdRecord() { + private class NsdDiscoveryRecord(expectedThreadId: Int? = null) : + DiscoveryListener, NsdRecord(expectedThreadId) { sealed class DiscoveryEvent : NsdEvent { data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) : DiscoveryEvent() @@ -462,9 +473,12 @@ class NsdManagerTest { si.serviceName = this.serviceName si.port = 12345 // Test won't try to connect so port does not matter - val registrationRecord = NsdRegistrationRecord() - val registeredInfo1 = registerService(registrationRecord, si) - val discoveryRecord = NsdDiscoveryRecord() + val handler = Handler(handlerThread.looper) + val executor = Executor { handler.post(it) } + + val registrationRecord = NsdRegistrationRecord(expectedThreadId = handlerThread.threadId) + val registeredInfo1 = registerService(registrationRecord, si, executor) + val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId) tryTest { val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName) @@ -474,7 +488,7 @@ class NsdManagerTest { .addTransportType(TRANSPORT_TEST) .setNetworkSpecifier(specifier) .build(), - Executor { it.run() }, discoveryRecord) + executor, discoveryRecord) val discoveryStarted = discoveryRecord.expectCallback() assertEquals(SERVICE_TYPE, discoveryStarted.serviceType) @@ -490,7 +504,7 @@ class NsdManagerTest { assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost1.serviceInfo)) registrationRecord.expectCallback() - val registeredInfo2 = registerService(registrationRecord, si) + val registeredInfo2 = registerService(registrationRecord, si, executor) val serviceDiscovered2 = discoveryRecord.expectCallback() assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName) assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo)) @@ -517,6 +531,39 @@ class NsdManagerTest { } } + @Test + fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() { + // This test requires shims supporting T+ APIs (discovering on network request) + assumeTrue(TestUtils.shouldTestTApis()) + + val si = NsdServiceInfo() + si.serviceType = SERVICE_TYPE + si.serviceName = this.serviceName + si.port = 12345 // Test won't try to connect so port does not matter + + val handler = Handler(handlerThread.looper) + val executor = Executor { handler.post(it) } + + val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId) + val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName) + + tryTest { + nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, + NetworkRequest.Builder() + .removeCapability(NET_CAPABILITY_TRUSTED) + .addTransportType(TRANSPORT_TEST) + // Specified network does not have this capability + .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED) + .setNetworkSpecifier(specifier) + .build(), + executor, discoveryRecord) + discoveryRecord.expectCallback() + } cleanup { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback() + } + } + @Test fun testNsdManager_ResolveOnNetwork() { // This test requires shims supporting T+ APIs (NsdServiceInfo.network) @@ -648,9 +695,12 @@ class NsdManagerTest { /** * Register a service and return its registration record. */ - private fun registerService(record: NsdRegistrationRecord, si: NsdServiceInfo): NsdServiceInfo { - nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, Executor { it.run() }, - record) + private fun registerService( + record: NsdRegistrationRecord, + si: NsdServiceInfo, + executor: Executor = Executor { it.run() } + ): NsdServiceInfo { + nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, executor, record) // We may not always get the name that we tried to register; // This events tells us the name that was registered. val cb = record.expectCallback()