Merge "Ensure callbacks are run properly on executor"

This commit is contained in:
Remi NGUYEN VAN
2022-06-03 09:04:29 +00:00
committed by Gerrit Code Review
2 changed files with 103 additions and 36 deletions

View File

@@ -306,9 +306,12 @@ public final class NsdManager {
@Override @Override
public void onAvailable(@NonNull Network network) { public void onAvailable(@NonNull Network network) {
final DelegatingDiscoveryListener wrappedListener = new DelegatingDiscoveryListener( final DelegatingDiscoveryListener wrappedListener = new DelegatingDiscoveryListener(
network, mBaseListener); network, mBaseListener, mBaseExecutor);
mPerNetworkListeners.put(network, wrappedListener); mPerNetworkListeners.put(network, wrappedListener);
discoverServices(mServiceType, mProtocolType, network, mBaseExecutor, // Run discovery callbacks inline on the service handler thread, which is the
// same thread used by this NetworkCallback, but DelegatingDiscoveryListener will
// use the base executor to run the wrapped callbacks.
discoverServices(mServiceType, mProtocolType, network, Runnable::run,
wrappedListener); wrappedListener);
} }
@@ -328,7 +331,8 @@ public final class NsdManager {
public void start(@NonNull NetworkRequest request) { public void start(@NonNull NetworkRequest request) {
final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class); final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
cm.registerNetworkCallback(request, mNetworkCb, mHandler); cm.registerNetworkCallback(request, mNetworkCb, mHandler);
mHandler.post(() -> mBaseListener.onDiscoveryStarted(mServiceType)); mHandler.post(() -> mBaseExecutor.execute(() ->
mBaseListener.onDiscoveryStarted(mServiceType)));
} }
/** /**
@@ -345,7 +349,7 @@ public final class NsdManager {
final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class); final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
cm.unregisterNetworkCallback(mNetworkCb); cm.unregisterNetworkCallback(mNetworkCb);
if (mPerNetworkListeners.size() == 0) { if (mPerNetworkListeners.size() == 0) {
mBaseListener.onDiscoveryStopped(mServiceType); mBaseExecutor.execute(() -> mBaseListener.onDiscoveryStopped(mServiceType));
return; return;
} }
for (int i = 0; i < mPerNetworkListeners.size(); i++) { for (int i = 0; i < mPerNetworkListeners.size(); i++) {
@@ -393,14 +397,23 @@ public final class NsdManager {
} }
} }
/**
* A listener wrapping calls to an app-provided listener, while keeping track of found
* services, so they can all be reported lost when the underlying network is lost.
*
* This should be registered to run on the service handler.
*/
private class DelegatingDiscoveryListener implements DiscoveryListener { private class DelegatingDiscoveryListener implements DiscoveryListener {
private final Network mNetwork; private final Network mNetwork;
private final DiscoveryListener mWrapped; private final DiscoveryListener mWrapped;
private final Executor mWrappedExecutor;
private final ArraySet<TrackedNsdInfo> mFoundInfo = new ArraySet<>(); private final ArraySet<TrackedNsdInfo> mFoundInfo = new ArraySet<>();
private DelegatingDiscoveryListener(Network network, DiscoveryListener listener) { private DelegatingDiscoveryListener(Network network, DiscoveryListener listener,
Executor executor) {
mNetwork = network; mNetwork = network;
mWrapped = listener; mWrapped = listener;
mWrappedExecutor = executor;
} }
void notifyAllServicesLost() { void notifyAllServicesLost() {
@@ -409,7 +422,7 @@ public final class NsdManager {
final NsdServiceInfo serviceInfo = new NsdServiceInfo( final NsdServiceInfo serviceInfo = new NsdServiceInfo(
trackedInfo.mServiceName, trackedInfo.mServiceType); trackedInfo.mServiceName, trackedInfo.mServiceType);
serviceInfo.setNetwork(mNetwork); serviceInfo.setNetwork(mNetwork);
mWrapped.onServiceLost(serviceInfo); mWrappedExecutor.execute(() -> mWrapped.onServiceLost(serviceInfo));
} }
} }
@@ -438,7 +451,7 @@ public final class NsdManager {
// Do not report onStopDiscoveryFailed when some underlying listeners failed: // Do not report onStopDiscoveryFailed when some underlying listeners failed:
// this does not mean that all listeners did, and onStopDiscoveryFailed is not // this does not mean that all listeners did, and onStopDiscoveryFailed is not
// actionable anyway. Just report that discovery stopped. // actionable anyway. Just report that discovery stopped.
mWrapped.onDiscoveryStopped(serviceType); mWrappedExecutor.execute(() -> mWrapped.onDiscoveryStopped(serviceType));
} }
} }
@@ -446,20 +459,20 @@ public final class NsdManager {
public void onDiscoveryStopped(String serviceType) { public void onDiscoveryStopped(String serviceType) {
mPerNetworkListeners.remove(mNetwork); mPerNetworkListeners.remove(mNetwork);
if (mStopRequested && mPerNetworkListeners.size() == 0) { if (mStopRequested && mPerNetworkListeners.size() == 0) {
mWrapped.onDiscoveryStopped(serviceType); mWrappedExecutor.execute(() -> mWrapped.onDiscoveryStopped(serviceType));
} }
} }
@Override @Override
public void onServiceFound(NsdServiceInfo serviceInfo) { public void onServiceFound(NsdServiceInfo serviceInfo) {
mFoundInfo.add(new TrackedNsdInfo(serviceInfo)); mFoundInfo.add(new TrackedNsdInfo(serviceInfo));
mWrapped.onServiceFound(serviceInfo); mWrappedExecutor.execute(() -> mWrapped.onServiceFound(serviceInfo));
} }
@Override @Override
public void onServiceLost(NsdServiceInfo serviceInfo) { public void onServiceLost(NsdServiceInfo serviceInfo) {
mFoundInfo.remove(new TrackedNsdInfo(serviceInfo)); mFoundInfo.remove(new TrackedNsdInfo(serviceInfo));
mWrapped.onServiceLost(serviceInfo); mWrappedExecutor.execute(() -> mWrapped.onServiceLost(serviceInfo));
} }
} }
} }
@@ -642,8 +655,12 @@ public final class NsdManager {
@Override @Override
public void handleMessage(Message message) { public void handleMessage(Message message) {
// Do not use message in the executor lambdas, as it will be recycled once this method
// returns. Keep references to its content instead.
final int what = message.what; final int what = message.what;
final int errorCode = message.arg1;
final int key = message.arg2; final int key = message.arg2;
final Object obj = message.obj;
final Object listener; final Object listener;
final NsdServiceInfo ns; final NsdServiceInfo ns;
final Executor executor; final Executor executor;
@@ -653,7 +670,7 @@ public final class NsdManager {
executor = mExecutorMap.get(key); executor = mExecutorMap.get(key);
} }
if (listener == null) { if (listener == null) {
Log.d(TAG, "Stale key " + message.arg2); Log.d(TAG, "Stale key " + key);
return; return;
} }
if (DBG) { if (DBG) {
@@ -661,28 +678,28 @@ public final class NsdManager {
} }
switch (what) { switch (what) {
case DISCOVER_SERVICES_STARTED: case DISCOVER_SERVICES_STARTED:
final String s = getNsdServiceInfoType((NsdServiceInfo) message.obj); final String s = getNsdServiceInfoType((NsdServiceInfo) obj);
executor.execute(() -> ((DiscoveryListener) listener).onDiscoveryStarted(s)); executor.execute(() -> ((DiscoveryListener) listener).onDiscoveryStarted(s));
break; break;
case DISCOVER_SERVICES_FAILED: case DISCOVER_SERVICES_FAILED:
removeListener(key); removeListener(key);
executor.execute(() -> ((DiscoveryListener) listener).onStartDiscoveryFailed( executor.execute(() -> ((DiscoveryListener) listener).onStartDiscoveryFailed(
getNsdServiceInfoType(ns), message.arg1)); getNsdServiceInfoType(ns), errorCode));
break; break;
case SERVICE_FOUND: case SERVICE_FOUND:
executor.execute(() -> ((DiscoveryListener) listener).onServiceFound( executor.execute(() -> ((DiscoveryListener) listener).onServiceFound(
(NsdServiceInfo) message.obj)); (NsdServiceInfo) obj));
break; break;
case SERVICE_LOST: case SERVICE_LOST:
executor.execute(() -> ((DiscoveryListener) listener).onServiceLost( executor.execute(() -> ((DiscoveryListener) listener).onServiceLost(
(NsdServiceInfo) message.obj)); (NsdServiceInfo) obj));
break; break;
case STOP_DISCOVERY_FAILED: case STOP_DISCOVERY_FAILED:
// TODO: failure to stop discovery should be internal and retried internally, as // TODO: failure to stop discovery should be internal and retried internally, as
// the effect for the client is indistinguishable from STOP_DISCOVERY_SUCCEEDED // the effect for the client is indistinguishable from STOP_DISCOVERY_SUCCEEDED
removeListener(key); removeListener(key);
executor.execute(() -> ((DiscoveryListener) listener).onStopDiscoveryFailed( executor.execute(() -> ((DiscoveryListener) listener).onStopDiscoveryFailed(
getNsdServiceInfoType(ns), message.arg1)); getNsdServiceInfoType(ns), errorCode));
break; break;
case STOP_DISCOVERY_SUCCEEDED: case STOP_DISCOVERY_SUCCEEDED:
removeListener(key); removeListener(key);
@@ -692,33 +709,33 @@ public final class NsdManager {
case REGISTER_SERVICE_FAILED: case REGISTER_SERVICE_FAILED:
removeListener(key); removeListener(key);
executor.execute(() -> ((RegistrationListener) listener).onRegistrationFailed( executor.execute(() -> ((RegistrationListener) listener).onRegistrationFailed(
ns, message.arg1)); ns, errorCode));
break; break;
case REGISTER_SERVICE_SUCCEEDED: case REGISTER_SERVICE_SUCCEEDED:
executor.execute(() -> ((RegistrationListener) listener).onServiceRegistered( executor.execute(() -> ((RegistrationListener) listener).onServiceRegistered(
(NsdServiceInfo) message.obj)); (NsdServiceInfo) obj));
break; break;
case UNREGISTER_SERVICE_FAILED: case UNREGISTER_SERVICE_FAILED:
removeListener(key); removeListener(key);
executor.execute(() -> ((RegistrationListener) listener).onUnregistrationFailed( executor.execute(() -> ((RegistrationListener) listener).onUnregistrationFailed(
ns, message.arg1)); ns, errorCode));
break; break;
case UNREGISTER_SERVICE_SUCCEEDED: case UNREGISTER_SERVICE_SUCCEEDED:
// TODO: do not unregister listener until service is unregistered, or provide // TODO: do not unregister listener until service is unregistered, or provide
// alternative way for unregistering ? // alternative way for unregistering ?
removeListener(message.arg2); removeListener(key);
executor.execute(() -> ((RegistrationListener) listener).onServiceUnregistered( executor.execute(() -> ((RegistrationListener) listener).onServiceUnregistered(
ns)); ns));
break; break;
case RESOLVE_SERVICE_FAILED: case RESOLVE_SERVICE_FAILED:
removeListener(key); removeListener(key);
executor.execute(() -> ((ResolveListener) listener).onResolveFailed( executor.execute(() -> ((ResolveListener) listener).onResolveFailed(
ns, message.arg1)); ns, errorCode));
break; break;
case RESOLVE_SERVICE_SUCCEEDED: case RESOLVE_SERVICE_SUCCEEDED:
removeListener(key); removeListener(key);
executor.execute(() -> ((ResolveListener) listener).onServiceResolved( executor.execute(() -> ((ResolveListener) listener).onServiceResolved(
(NsdServiceInfo) message.obj)); (NsdServiceInfo) obj));
break; break;
default: default:
Log.d(TAG, "Ignored " + message); Log.d(TAG, "Ignored " + message);

View File

@@ -22,6 +22,7 @@ import android.net.LinkProperties
import android.net.Network import android.net.Network
import android.net.NetworkAgentConfig import android.net.NetworkAgentConfig
import android.net.NetworkCapabilities import android.net.NetworkCapabilities
import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
import android.net.NetworkCapabilities.TRANSPORT_TEST import android.net.NetworkCapabilities.TRANSPORT_TEST
import android.net.NetworkRequest import android.net.NetworkRequest
@@ -45,7 +46,9 @@ import android.net.nsd.NsdManager.DiscoveryListener
import android.net.nsd.NsdManager.RegistrationListener import android.net.nsd.NsdManager.RegistrationListener
import android.net.nsd.NsdManager.ResolveListener import android.net.nsd.NsdManager.ResolveListener
import android.net.nsd.NsdServiceInfo import android.net.nsd.NsdServiceInfo
import android.os.Handler
import android.os.HandlerThread import android.os.HandlerThread
import android.os.Process.myTid
import android.platform.test.annotations.AppModeFull import android.platform.test.annotations.AppModeFull
import android.util.Log import android.util.Log
import androidx.test.platform.app.InstrumentationRegistry import androidx.test.platform.app.InstrumentationRegistry
@@ -111,12 +114,20 @@ class NsdManagerTest {
private interface NsdEvent private interface NsdEvent
private open class NsdRecord<T : NsdEvent> private constructor( private open class NsdRecord<T : NsdEvent> private constructor(
private val history: ArrayTrackRecord<T> private val history: ArrayTrackRecord<T>,
private val expectedThreadId: Int? = null
) : TrackRecord<T> by history { ) : TrackRecord<T> by history {
constructor() : this(ArrayTrackRecord()) constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)
val nextEvents = history.newReadHead() val nextEvents = history.newReadHead()
override fun add(e: T): Boolean {
if (expectedThreadId != null) {
assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread")
}
return history.add(e)
}
inline fun <reified V : NsdEvent> expectCallbackEventually( inline fun <reified V : NsdEvent> expectCallbackEventually(
crossinline predicate: (V) -> Boolean = { true } crossinline predicate: (V) -> Boolean = { true }
): V = nextEvents.poll(TIMEOUT_MS) { e -> e is V && predicate(e) } as V? ): V = nextEvents.poll(TIMEOUT_MS) { e -> e is V && predicate(e) } as V?
@@ -136,8 +147,8 @@ class NsdManagerTest {
} }
} }
private class NsdRegistrationRecord : RegistrationListener, private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener,
NsdRecord<NsdRegistrationRecord.RegistrationEvent>() { NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
sealed class RegistrationEvent : NsdEvent { sealed class RegistrationEvent : NsdEvent {
abstract val serviceInfo: NsdServiceInfo abstract val serviceInfo: NsdServiceInfo
@@ -174,8 +185,8 @@ class NsdManagerTest {
} }
} }
private class NsdDiscoveryRecord : DiscoveryListener, private class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>() { DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
sealed class DiscoveryEvent : NsdEvent { sealed class DiscoveryEvent : NsdEvent {
data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int)
: DiscoveryEvent() : DiscoveryEvent()
@@ -462,9 +473,12 @@ class NsdManagerTest {
si.serviceName = this.serviceName si.serviceName = this.serviceName
si.port = 12345 // Test won't try to connect so port does not matter si.port = 12345 // Test won't try to connect so port does not matter
val registrationRecord = NsdRegistrationRecord() val handler = Handler(handlerThread.looper)
val registeredInfo1 = registerService(registrationRecord, si) val executor = Executor { handler.post(it) }
val discoveryRecord = NsdDiscoveryRecord()
val registrationRecord = NsdRegistrationRecord(expectedThreadId = handlerThread.threadId)
val registeredInfo1 = registerService(registrationRecord, si, executor)
val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId)
tryTest { tryTest {
val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName) val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
@@ -474,7 +488,7 @@ class NsdManagerTest {
.addTransportType(TRANSPORT_TEST) .addTransportType(TRANSPORT_TEST)
.setNetworkSpecifier(specifier) .setNetworkSpecifier(specifier)
.build(), .build(),
Executor { it.run() }, discoveryRecord) executor, discoveryRecord)
val discoveryStarted = discoveryRecord.expectCallback<DiscoveryStarted>() val discoveryStarted = discoveryRecord.expectCallback<DiscoveryStarted>()
assertEquals(SERVICE_TYPE, discoveryStarted.serviceType) assertEquals(SERVICE_TYPE, discoveryStarted.serviceType)
@@ -490,7 +504,7 @@ class NsdManagerTest {
assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost1.serviceInfo)) assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost1.serviceInfo))
registrationRecord.expectCallback<ServiceUnregistered>() registrationRecord.expectCallback<ServiceUnregistered>()
val registeredInfo2 = registerService(registrationRecord, si) val registeredInfo2 = registerService(registrationRecord, si, executor)
val serviceDiscovered2 = discoveryRecord.expectCallback<ServiceFound>() val serviceDiscovered2 = discoveryRecord.expectCallback<ServiceFound>()
assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName) assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName)
assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo)) assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo))
@@ -517,6 +531,39 @@ class NsdManagerTest {
} }
} }
@Test
fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() {
// This test requires shims supporting T+ APIs (discovering on network request)
assumeTrue(TestUtils.shouldTestTApis())
val si = NsdServiceInfo()
si.serviceType = SERVICE_TYPE
si.serviceName = this.serviceName
si.port = 12345 // Test won't try to connect so port does not matter
val handler = Handler(handlerThread.looper)
val executor = Executor { handler.post(it) }
val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId)
val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
tryTest {
nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
NetworkRequest.Builder()
.removeCapability(NET_CAPABILITY_TRUSTED)
.addTransportType(TRANSPORT_TEST)
// Specified network does not have this capability
.addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
.setNetworkSpecifier(specifier)
.build(),
executor, discoveryRecord)
discoveryRecord.expectCallback<DiscoveryStarted>()
} cleanup {
nsdManager.stopServiceDiscovery(discoveryRecord)
discoveryRecord.expectCallback<DiscoveryStopped>()
}
}
@Test @Test
fun testNsdManager_ResolveOnNetwork() { fun testNsdManager_ResolveOnNetwork() {
// This test requires shims supporting T+ APIs (NsdServiceInfo.network) // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
@@ -648,9 +695,12 @@ class NsdManagerTest {
/** /**
* Register a service and return its registration record. * Register a service and return its registration record.
*/ */
private fun registerService(record: NsdRegistrationRecord, si: NsdServiceInfo): NsdServiceInfo { private fun registerService(
nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, Executor { it.run() }, record: NsdRegistrationRecord,
record) si: NsdServiceInfo,
executor: Executor = Executor { it.run() }
): NsdServiceInfo {
nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, executor, record)
// We may not always get the name that we tried to register; // We may not always get the name that we tried to register;
// This events tells us the name that was registered. // This events tells us the name that was registered.
val cb = record.expectCallback<ServiceRegistered>() val cb = record.expectCallback<ServiceRegistered>()