diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index 4af4c6abb8..29a03d163d 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -1394,7 +1394,8 @@ public class NsdService extends INsdManager.Stub { mMdnsSocketClient = new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider); mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(), - mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager")); + mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"), + handler.getLooper()); handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager)); mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider, new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser")); @@ -1452,8 +1453,9 @@ public class NsdService extends INsdManager.Stub { */ public MdnsDiscoveryManager makeMdnsDiscoveryManager( @NonNull ExecutorProvider executorProvider, - @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog) { - return new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog); + @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog, + @NonNull Looper looper) { + return new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog, looper); } /** diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java index c7b93e5041..01548275a2 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java @@ -16,18 +16,21 @@ package com.android.server.connectivity.mdns; +import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread; import static com.android.server.connectivity.mdns.util.MdnsUtils.isNetworkMatched; +import static com.android.server.connectivity.mdns.util.MdnsUtils.isRunningOnHandlerThread; import android.Manifest.permission; import android.annotation.NonNull; import android.annotation.Nullable; import android.annotation.RequiresPermission; import android.net.Network; +import android.os.Handler; +import android.os.Looper; import android.util.ArrayMap; import android.util.Log; import android.util.Pair; -import com.android.internal.annotations.GuardedBy; import com.android.internal.annotations.VisibleForTesting; import com.android.net.module.util.SharedLog; import com.android.server.connectivity.mdns.util.MdnsUtils; @@ -48,8 +51,8 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { private final MdnsSocketClientBase socketClient; @NonNull private final SharedLog sharedLog; - @GuardedBy("this") @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients; + @NonNull private final Handler handler; private static class PerNetworkServiceTypeClients { private final ArrayMap, MdnsServiceTypeClient> clients = @@ -109,11 +112,21 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { } public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider, - @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog) { + @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog, + @NonNull Looper looper) { this.executorProvider = executorProvider; this.socketClient = socketClient; this.sharedLog = sharedLog; perNetworkServiceTypeClients = new PerNetworkServiceTypeClients(); + handler = new Handler(looper); + } + + private void checkAndRunOnHandlerThread(@NonNull Runnable function) { + if (isRunningOnHandlerThread(handler)) { + function.run(); + } else { + handler.post(function); + } } /** @@ -126,11 +139,19 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { * serviceType}. */ @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE) - public synchronized void registerListener( + public void registerListener( @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener, @NonNull MdnsSearchOptions searchOptions) { sharedLog.i("Registering listener for serviceType: " + serviceType); + checkAndRunOnHandlerThread(() -> + handleRegisterListener(serviceType, listener, searchOptions)); + } + + private void handleRegisterListener( + @NonNull String serviceType, + @NonNull MdnsServiceBrowserListener listener, + @NonNull MdnsSearchOptions searchOptions) { if (perNetworkServiceTypeClients.isEmpty()) { // First listener. Starts the socket client. try { @@ -145,30 +166,28 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { new MdnsSocketClientBase.SocketCreationCallback() { @Override public void onSocketCreated(@Nullable Network network) { - synchronized (MdnsDiscoveryManager.this) { - // All listeners of the same service types shares the same - // MdnsServiceTypeClient. - MdnsServiceTypeClient serviceTypeClient = - perNetworkServiceTypeClients.get(serviceType, network); - if (serviceTypeClient == null) { - serviceTypeClient = createServiceTypeClient(serviceType, network); - perNetworkServiceTypeClients.put(serviceType, network, - serviceTypeClient); - } - serviceTypeClient.startSendAndReceive(listener, searchOptions); + ensureRunningOnHandlerThread(handler); + // All listeners of the same service types shares the same + // MdnsServiceTypeClient. + MdnsServiceTypeClient serviceTypeClient = + perNetworkServiceTypeClients.get(serviceType, network); + if (serviceTypeClient == null) { + serviceTypeClient = createServiceTypeClient(serviceType, network); + perNetworkServiceTypeClients.put(serviceType, network, + serviceTypeClient); } + serviceTypeClient.startSendAndReceive(listener, searchOptions); } @Override public void onAllSocketsDestroyed(@Nullable Network network) { - synchronized (MdnsDiscoveryManager.this) { - final MdnsServiceTypeClient serviceTypeClient = - perNetworkServiceTypeClients.get(serviceType, network); - if (serviceTypeClient == null) return; - // Notify all listeners that all services are removed from this socket. - serviceTypeClient.notifySocketDestroyed(); - perNetworkServiceTypeClients.remove(serviceTypeClient); - } + ensureRunningOnHandlerThread(handler); + final MdnsServiceTypeClient serviceTypeClient = + perNetworkServiceTypeClients.get(serviceType, network); + if (serviceTypeClient == null) return; + // Notify all listeners that all services are removed from this socket. + serviceTypeClient.notifySocketDestroyed(); + perNetworkServiceTypeClients.remove(serviceTypeClient); } }); } @@ -181,9 +200,14 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { * @param listener The {@link MdnsServiceBrowserListener} listener. */ @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE) - public synchronized void unregisterListener( + public void unregisterListener( @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) { sharedLog.i("Unregistering listener for serviceType:" + serviceType); + checkAndRunOnHandlerThread(() -> handleUnregisterListener(serviceType, listener)); + } + + private void handleUnregisterListener( + @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) { final List serviceTypeClients = perNetworkServiceTypeClients.getByServiceType(serviceType); if (serviceTypeClients.isEmpty()) { @@ -206,8 +230,14 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { } @Override - public synchronized void onResponseReceived(@NonNull MdnsPacket packet, - int interfaceIndex, Network network) { + public void onResponseReceived(@NonNull MdnsPacket packet, + int interfaceIndex, @Nullable Network network) { + checkAndRunOnHandlerThread(() -> + handleOnResponseReceived(packet, interfaceIndex, network)); + } + + private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex, + @Nullable Network network) { for (MdnsServiceTypeClient serviceTypeClient : perNetworkServiceTypeClients.getByMatchingNetwork(network)) { serviceTypeClient.processResponse(packet, interfaceIndex, network); @@ -215,8 +245,14 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { } @Override - public synchronized void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode, - Network network) { + public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode, + @Nullable Network network) { + checkAndRunOnHandlerThread(() -> + handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network)); + } + + private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode, + @Nullable Network network) { for (MdnsServiceTypeClient serviceTypeClient : perNetworkServiceTypeClients.getByMatchingNetwork(network)) { serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode); diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java index 63d1a504fe..bc948699a8 100644 --- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java +++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java @@ -115,12 +115,20 @@ public class MdnsUtils { /*** Ensure that current running thread is same as given handler thread */ public static void ensureRunningOnHandlerThread(@NonNull Handler handler) { - if (handler.getLooper().getThread() != Thread.currentThread()) { + if (!isRunningOnHandlerThread(handler)) { throw new IllegalStateException( "Not running on Handler thread: " + Thread.currentThread().getName()); } } + /*** Check that current running thread is same as given handler thread */ + public static boolean isRunningOnHandlerThread(@NonNull Handler handler) { + if (handler.getLooper().getThread() == Thread.currentThread()) { + return true; + } + return false; + } + /*** Check whether the target network is matched current network */ public static boolean isNetworkMatched(@Nullable Network targetNetwork, @Nullable Network currentNetwork) { diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java index b3e8cc844b..955be1246d 100644 --- a/tests/unit/java/com/android/server/NsdServiceTest.java +++ b/tests/unit/java/com/android/server/NsdServiceTest.java @@ -178,10 +178,10 @@ public class NsdServiceTest { doReturn(true).when(mMockMDnsM).resolve( anyInt(), anyString(), anyString(), anyString(), anyInt()); doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class)); - doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any(), any()); + doReturn(mDiscoveryManager).when(mDeps) + .makeMdnsDiscoveryManager(any(), any(), any(), any()); doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any(), any()); doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any(), any()); - mService = makeService(); } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java index 45da874139..89776e20ca 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java @@ -27,6 +27,8 @@ import static org.mockito.Mockito.when; import android.annotation.NonNull; import android.annotation.Nullable; import android.net.Network; +import android.os.Handler; +import android.os.HandlerThread; import android.text.TextUtils; import android.util.Pair; @@ -34,7 +36,9 @@ import com.android.net.module.util.SharedLog; import com.android.server.connectivity.mdns.MdnsSocketClientBase.SocketCreationCallback; import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRunner; +import com.android.testutils.HandlerUtils; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -53,7 +57,7 @@ import java.util.List; @RunWith(DevSdkIgnoreRunner.class) @DevSdkIgnoreRule.IgnoreUpTo(SC_V2) public class MdnsDiscoveryManagerTests { - + private static final long DEFAULT_TIMEOUT = 2000L; private static final String SERVICE_TYPE_1 = "_googlecast._tcp.local"; private static final String SERVICE_TYPE_2 = "_test._tcp.local"; private static final Network NETWORK_1 = Mockito.mock(Network.class); @@ -78,12 +82,18 @@ public class MdnsDiscoveryManagerTests { @Mock MdnsServiceBrowserListener mockListenerTwo; @Mock SharedLog sharedLog; private MdnsDiscoveryManager discoveryManager; + private HandlerThread thread; + private Handler handler; @Before public void setUp() { MockitoAnnotations.initMocks(this); - discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog) { + thread = new HandlerThread("MdnsDiscoveryManagerTests"); + thread.start(); + handler = new Handler(thread.getLooper()); + discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog, + thread.getLooper()) { @Override MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType, @Nullable Network network) { @@ -103,11 +113,23 @@ public class MdnsDiscoveryManagerTests { }; } + @After + public void tearDown() { + if (thread != null) { + thread.quitSafely(); + } + } + + private void runOnHandler(Runnable r) { + handler.post(r); + HandlerUtils.waitForIdle(handler, DEFAULT_TIMEOUT); + } + private SocketCreationCallback expectSocketCreationCallback(String serviceType, MdnsServiceBrowserListener listener, MdnsSearchOptions options) throws IOException { final ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SocketCreationCallback.class); - discoveryManager.registerListener(serviceType, listener, options); + runOnHandler(() -> discoveryManager.registerListener(serviceType, listener, options)); verify(socketClient).startDiscovery(); verify(socketClient).notifyNetworkRequested( eq(listener), eq(options.getNetwork()), callbackCaptor.capture()); @@ -120,11 +142,11 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options); - callback.onSocketCreated(null /* network */); + runOnHandler(() -> callback.onSocketCreated(null /* network */)); verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options); when(mockServiceTypeClientOne.stopSendAndReceive(mockListenerOne)).thenReturn(true); - discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne); + runOnHandler(() -> discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne)); verify(mockServiceTypeClientOne).stopSendAndReceive(mockListenerOne); verify(socketClient).stopDiscovery(); } @@ -135,16 +157,16 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options); - callback.onSocketCreated(null /* network */); + runOnHandler(() -> callback.onSocketCreated(null /* network */)); verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options); - callback.onSocketCreated(NETWORK_1); + runOnHandler(() -> callback.onSocketCreated(NETWORK_1)); verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options); final SocketCreationCallback callback2 = expectSocketCreationCallback( SERVICE_TYPE_2, mockListenerTwo, options); - callback2.onSocketCreated(null /* network */); + runOnHandler(() -> callback2.onSocketCreated(null /* network */)); verify(mockServiceTypeClientTwo).startSendAndReceive(mockListenerTwo, options); - callback2.onSocketCreated(NETWORK_2); + runOnHandler(() -> callback2.onSocketCreated(NETWORK_2)); verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options); } @@ -154,21 +176,22 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options1); - callback.onSocketCreated(null /* network */); + runOnHandler(() -> callback.onSocketCreated(null /* network */)); verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options1); - callback.onSocketCreated(NETWORK_1); + runOnHandler(() -> callback.onSocketCreated(NETWORK_1)); verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options1); final MdnsSearchOptions options2 = MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build(); final SocketCreationCallback callback2 = expectSocketCreationCallback( SERVICE_TYPE_2, mockListenerTwo, options2); - callback2.onSocketCreated(NETWORK_2); + runOnHandler(() -> callback2.onSocketCreated(NETWORK_2)); verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options2); final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1); final int ifIndex = 1; - discoveryManager.onResponseReceived(responseForServiceTypeOne, ifIndex, null /* network */); + runOnHandler(() -> discoveryManager.onResponseReceived( + responseForServiceTypeOne, ifIndex, null /* network */)); verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeOne, ifIndex, null /* network */); verify(mockServiceTypeClientOne1).processResponse(responseForServiceTypeOne, ifIndex, @@ -177,7 +200,8 @@ public class MdnsDiscoveryManagerTests { null /* network */); final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2); - discoveryManager.onResponseReceived(responseForServiceTypeTwo, ifIndex, NETWORK_1); + runOnHandler(() -> discoveryManager.onResponseReceived( + responseForServiceTypeTwo, ifIndex, NETWORK_1)); verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeTwo, ifIndex, NETWORK_1); verify(mockServiceTypeClientOne1).processResponse(responseForServiceTypeTwo, ifIndex, @@ -187,7 +211,8 @@ public class MdnsDiscoveryManagerTests { final MdnsPacket responseForSubtype = createMdnsPacket("subtype._sub._googlecast._tcp.local"); - discoveryManager.onResponseReceived(responseForSubtype, ifIndex, NETWORK_2); + runOnHandler(() -> discoveryManager.onResponseReceived( + responseForSubtype, ifIndex, NETWORK_2)); verify(mockServiceTypeClientOne).processResponse(responseForSubtype, ifIndex, NETWORK_2); verify(mockServiceTypeClientOne1, never()).processResponse( responseForSubtype, ifIndex, NETWORK_2); @@ -201,7 +226,7 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options1); - callback.onSocketCreated(NETWORK_1); + runOnHandler(() -> callback.onSocketCreated(NETWORK_1)); verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options1); // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_2 @@ -209,26 +234,28 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build(); final SocketCreationCallback callback2 = expectSocketCreationCallback( SERVICE_TYPE_2, mockListenerTwo, options2); - callback2.onSocketCreated(NETWORK_2); + runOnHandler(() -> callback2.onSocketCreated(NETWORK_2)); verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options2); // Receive a response, it should be processed on both clients. final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1); final int ifIndex = 1; - discoveryManager.onResponseReceived(response, ifIndex, null /* network */); + runOnHandler(() -> discoveryManager.onResponseReceived( + response, ifIndex, null /* network */)); verify(mockServiceTypeClientOne1).processResponse(response, ifIndex, null /* network */); verify(mockServiceTypeClientTwo2).processResponse(response, ifIndex, null /* network */); // The client for NETWORK_1 receives the callback that the NETWORK_1 has been destroyed, // mockServiceTypeClientOne1 should send service removed notifications and remove from the // list of clients. - callback.onAllSocketsDestroyed(NETWORK_1); + runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1)); verify(mockServiceTypeClientOne1).notifySocketDestroyed(); // Receive a response again, it should be processed only on mockServiceTypeClientTwo2. // Because the mockServiceTypeClientOne1 is removed from the list of clients, it is no // longer able to process responses. - discoveryManager.onResponseReceived(response, ifIndex, null /* network */); + runOnHandler(() -> discoveryManager.onResponseReceived( + response, ifIndex, null /* network */)); verify(mockServiceTypeClientOne1, times(1)) .processResponse(response, ifIndex, null /* network */); verify(mockServiceTypeClientTwo2, times(2)) @@ -236,12 +263,13 @@ public class MdnsDiscoveryManagerTests { // The client for NETWORK_2 receives the callback that the NETWORK_1 has been destroyed, // mockServiceTypeClientTwo2 shouldn't send any notifications. - callback2.onAllSocketsDestroyed(NETWORK_1); + runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_1)); verify(mockServiceTypeClientTwo2, never()).notifySocketDestroyed(); // Receive a response again, mockServiceTypeClientTwo2 is still in the list of clients, it's // still able to process responses. - discoveryManager.onResponseReceived(response, ifIndex, null /* network */); + runOnHandler(() -> discoveryManager.onResponseReceived( + response, ifIndex, null /* network */)); verify(mockServiceTypeClientOne1, times(1)) .processResponse(response, ifIndex, null /* network */); verify(mockServiceTypeClientTwo2, times(3))