diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index 49c6ef0d08..4ad39e1b41 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -118,13 +118,15 @@ public class NsdService extends INsdManager.Stub { private final NsdStateMachine mNsdStateMachine; private final MDnsManager mMDnsManager; private final MDnsEventCallback mMDnsEventCallback; - @Nullable + @NonNull + private final Dependencies mDeps; + @NonNull private final MdnsMultinetworkSocketClient mMdnsSocketClient; - @Nullable + @NonNull private final MdnsDiscoveryManager mMdnsDiscoveryManager; - @Nullable + @NonNull private final MdnsSocketProvider mMdnsSocketProvider; - @Nullable + @NonNull private final MdnsAdvertiser mAdvertiser; // WARNING : Accessing these values in any thread is not safe, it must only be changed in the // state machine thread. If change this outside state machine, it will need to introduce @@ -311,21 +313,14 @@ public class NsdService extends INsdManager.Stub { mIsMonitoringSocketsStarted = true; } - private void maybeStopMonitoringSockets() { - if (!mIsMonitoringSocketsStarted) { - if (DBG) Log.d(TAG, "Socket monitoring has not been started."); - return; - } + private void maybeStopMonitoringSocketsIfNoActiveRequest() { + if (!mIsMonitoringSocketsStarted) return; + if (isAnyRequestActive()) return; + mMdnsSocketProvider.stopMonitoringSockets(); mIsMonitoringSocketsStarted = false; } - private void maybeStopMonitoringSocketsIfNoActiveRequest() { - if (!isAnyRequestActive()) { - maybeStopMonitoringSockets(); - } - } - NsdStateMachine(String name, Handler handler) { super(name, handler); addState(mDefaultState); @@ -362,9 +357,7 @@ public class NsdService extends INsdManager.Stub { mLegacyClientCount -= 1; } } - if (mMdnsDiscoveryManager != null || mAdvertiser != null) { - maybeStopMonitoringSocketsIfNoActiveRequest(); - } + maybeStopMonitoringSocketsIfNoActiveRequest(); maybeScheduleStop(); break; case NsdManager.DISCOVER_SERVICES: @@ -579,7 +572,7 @@ public class NsdService extends INsdManager.Stub { final NsdServiceInfo info = args.serviceInfo; id = getUniqueId(); - if (mMdnsDiscoveryManager != null) { + if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) { final String serviceType = constructServiceType(info.getServiceType()); if (serviceType == null) { clientInfo.onDiscoverServicesFailed(clientId, @@ -634,6 +627,9 @@ public class NsdService extends INsdManager.Stub { break; } id = request.mGlobalId; + // Note isMdnsDiscoveryManagerEnabled may have changed to false at this + // point, so this needs to check the type of the original request to + // unregister instead of looking at the flag value. if (request instanceof DiscoveryManagerRequest) { final MdnsListener listener = ((DiscoveryManagerRequest) request).mListener; @@ -671,7 +667,7 @@ public class NsdService extends INsdManager.Stub { } id = getUniqueId(); - if (mAdvertiser != null) { + if (mDeps.isMdnsAdvertiserEnabled(mContext)) { final NsdServiceInfo serviceInfo = args.serviceInfo; final String serviceType = serviceInfo.getServiceType(); final String registerServiceType = constructServiceType(serviceType); @@ -722,7 +718,10 @@ public class NsdService extends INsdManager.Stub { id = request.mGlobalId; removeRequestMap(clientId, id, clientInfo); - if (mAdvertiser != null) { + // Note isMdnsAdvertiserEnabled may have changed to false at this point, + // so this needs to check the type of the original request to unregister + // instead of looking at the flag value. + if (request instanceof AdvertiserClientRequest) { mAdvertiser.removeService(id); clientInfo.onUnregisterServiceSucceeded(clientId); } else { @@ -749,7 +748,7 @@ public class NsdService extends INsdManager.Stub { final NsdServiceInfo info = args.serviceInfo; id = getUniqueId(); - if (mMdnsDiscoveryManager != null) { + if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) { final String serviceType = constructServiceType(info.getServiceType()); if (serviceType == null) { clientInfo.onResolveServiceFailed(clientId, @@ -1241,32 +1240,16 @@ public class NsdService extends INsdManager.Stub { mNsdStateMachine.start(); mMDnsManager = ctx.getSystemService(MDnsManager.class); mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine); + mDeps = deps; - final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx); - final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx); - if (discoveryManagerEnabled || advertiserEnabled) { - mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper()); - } else { - mMdnsSocketProvider = null; - } - - if (discoveryManagerEnabled) { - mMdnsSocketClient = - new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider); - mMdnsDiscoveryManager = - deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient); - handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager)); - } else { - mMdnsSocketClient = null; - mMdnsDiscoveryManager = null; - } - - if (advertiserEnabled) { - mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider, - new AdvertiserCallback()); - } else { - mAdvertiser = null; - } + mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper()); + mMdnsSocketClient = + new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider); + mMdnsDiscoveryManager = + deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient); + handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager)); + mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider, + new AdvertiserCallback()); } /** diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java index 98a8ed258e..a2c4b9be83 100644 --- a/tests/unit/java/com/android/server/NsdServiceTest.java +++ b/tests/unit/java/com/android/server/NsdServiceTest.java @@ -45,6 +45,7 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import android.compat.testing.PlatformCompatChangeRule; @@ -170,6 +171,9 @@ public class NsdServiceTest { doReturn(true).when(mMockMDnsM).resolve( anyInt(), anyString(), anyString(), anyString(), anyInt()); doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class)); + doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any()); + doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any()); + doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any()); mService = makeService(); } @@ -824,40 +828,50 @@ public class NsdServiceTest { client.unregisterServiceInfoCallback(serviceInfoCallback)); } - private void makeServiceWithMdnsDiscoveryManagerEnabled() { + private void setMdnsDiscoveryManagerEnabled() { doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class)); - doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any()); - doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any()); - - mService = makeService(); - verify(mDeps).makeMdnsDiscoveryManager(any(), any()); - verify(mDeps).makeMdnsSocketProvider(any(), any()); } - private void makeServiceWithMdnsAdvertiserEnabled() { + private void setMdnsAdvertiserEnabled() { doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class)); - doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any()); - doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any()); - - mService = makeService(); - verify(mDeps).makeMdnsAdvertiser(any(), any(), any()); - verify(mDeps).makeMdnsSocketProvider(any(), any()); } @Test public void testMdnsDiscoveryManagerFeature() { // Create NsdService w/o feature enabled. - connectClient(mService); - verify(mDeps, never()).makeMdnsDiscoveryManager(any(), any()); - verify(mDeps, never()).makeMdnsSocketProvider(any(), any()); + final NsdManager client = connectClient(mService); + final DiscoveryListener discListenerWithoutFeature = mock(DiscoveryListener.class); + client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithoutFeature); + waitForIdle(); - // Create NsdService again w/ feature enabled. - makeServiceWithMdnsDiscoveryManagerEnabled(); + final ArgumentCaptor legacyIdCaptor = ArgumentCaptor.forClass(Integer.class); + verify(mMockMDnsM).discover(legacyIdCaptor.capture(), any(), anyInt()); + verifyNoMoreInteractions(mDiscoveryManager); + + setMdnsDiscoveryManagerEnabled(); + final DiscoveryListener discListenerWithFeature = mock(DiscoveryListener.class); + client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithFeature); + waitForIdle(); + + final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local"; + final ArgumentCaptor listenerCaptor = + ArgumentCaptor.forClass(MdnsServiceBrowserListener.class); + verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain), + listenerCaptor.capture(), any()); + + client.stopServiceDiscovery(discListenerWithoutFeature); + waitForIdle(); + verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue()); + + client.stopServiceDiscovery(discListenerWithFeature); + waitForIdle(); + verify(mDiscoveryManager).unregisterListener(serviceTypeWithLocalDomain, + listenerCaptor.getValue()); } @Test public void testDiscoveryWithMdnsDiscoveryManager() { - makeServiceWithMdnsDiscoveryManagerEnabled(); + setMdnsDiscoveryManagerEnabled(); final NsdManager client = connectClient(mService); final DiscoveryListener discListener = mock(DiscoveryListener.class); @@ -922,7 +936,7 @@ public class NsdServiceTest { @Test public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() { - makeServiceWithMdnsDiscoveryManagerEnabled(); + setMdnsDiscoveryManagerEnabled(); final NsdManager client = connectClient(mService); final DiscoveryListener discListener = mock(DiscoveryListener.class); @@ -951,7 +965,7 @@ public class NsdServiceTest { @Test public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException { - makeServiceWithMdnsDiscoveryManagerEnabled(); + setMdnsDiscoveryManagerEnabled(); final NsdManager client = connectClient(mService); final ResolveListener resolveListener = mock(ResolveListener.class); @@ -1004,9 +1018,44 @@ public class NsdServiceTest { verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets(); } + @Test + public void testMdnsAdvertiserFeatureFlagging() { + // Create NsdService w/o feature enabled. + final NsdManager client = connectClient(mService); + final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE); + regInfo.setHost(parseNumericAddress("192.0.2.123")); + regInfo.setPort(12345); + final RegistrationListener regListenerWithoutFeature = mock(RegistrationListener.class); + client.registerService(regInfo, PROTOCOL, regListenerWithoutFeature); + waitForIdle(); + + final ArgumentCaptor legacyIdCaptor = ArgumentCaptor.forClass(Integer.class); + verify(mMockMDnsM).registerService(legacyIdCaptor.capture(), any(), any(), anyInt(), + any(), anyInt()); + verifyNoMoreInteractions(mAdvertiser); + + setMdnsAdvertiserEnabled(); + final RegistrationListener regListenerWithFeature = mock(RegistrationListener.class); + client.registerService(regInfo, PROTOCOL, regListenerWithFeature); + waitForIdle(); + + final ArgumentCaptor serviceIdCaptor = ArgumentCaptor.forClass(Integer.class); + verify(mAdvertiser).addService(serviceIdCaptor.capture(), + argThat(info -> matches(info, regInfo))); + + client.unregisterService(regListenerWithoutFeature); + waitForIdle(); + verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue()); + verify(mAdvertiser, never()).removeService(anyInt()); + + client.unregisterService(regListenerWithFeature); + waitForIdle(); + verify(mAdvertiser).removeService(serviceIdCaptor.getValue()); + } + @Test public void testAdvertiseWithMdnsAdvertiser() { - makeServiceWithMdnsAdvertiserEnabled(); + setMdnsAdvertiserEnabled(); final NsdManager client = connectClient(mService); final RegistrationListener regListener = mock(RegistrationListener.class); @@ -1045,7 +1094,7 @@ public class NsdServiceTest { @Test public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() { - makeServiceWithMdnsAdvertiserEnabled(); + setMdnsAdvertiserEnabled(); final NsdManager client = connectClient(mService); final RegistrationListener regListener = mock(RegistrationListener.class); @@ -1070,7 +1119,7 @@ public class NsdServiceTest { @Test public void testAdvertiseWithMdnsAdvertiser_LongServiceName() { - makeServiceWithMdnsAdvertiserEnabled(); + setMdnsAdvertiserEnabled(); final NsdManager client = connectClient(mService); final RegistrationListener regListener = mock(RegistrationListener.class);