Allow Advertiser, DiscoveryManager runtime toggle

Allow toggling MdnsAdvertiser and MdnsDiscoveryManager at runtime, by
always creating them in NsdService constructor, but only using them when
the flag is on when starting discovery, resolve or registration.

When stopping, based on the type of the stored request, stop the
corresponding backend.

Bug: 265891278
Test: atest NsdServiceTest
Change-Id: I7cb2f9fe9e1ed3dc77616689a8e3ffa00f5bc269
This commit is contained in:
Remi NGUYEN VAN
2023-01-18 18:57:41 +09:00
parent 8f453b9c9c
commit a8a777bbbd
2 changed files with 104 additions and 72 deletions

View File

@@ -118,13 +118,15 @@ public class NsdService extends INsdManager.Stub {
private final NsdStateMachine mNsdStateMachine; private final NsdStateMachine mNsdStateMachine;
private final MDnsManager mMDnsManager; private final MDnsManager mMDnsManager;
private final MDnsEventCallback mMDnsEventCallback; private final MDnsEventCallback mMDnsEventCallback;
@Nullable @NonNull
private final Dependencies mDeps;
@NonNull
private final MdnsMultinetworkSocketClient mMdnsSocketClient; private final MdnsMultinetworkSocketClient mMdnsSocketClient;
@Nullable @NonNull
private final MdnsDiscoveryManager mMdnsDiscoveryManager; private final MdnsDiscoveryManager mMdnsDiscoveryManager;
@Nullable @NonNull
private final MdnsSocketProvider mMdnsSocketProvider; private final MdnsSocketProvider mMdnsSocketProvider;
@Nullable @NonNull
private final MdnsAdvertiser mAdvertiser; private final MdnsAdvertiser mAdvertiser;
// WARNING : Accessing these values in any thread is not safe, it must only be changed in the // WARNING : Accessing these values in any thread is not safe, it must only be changed in the
// state machine thread. If change this outside state machine, it will need to introduce // state machine thread. If change this outside state machine, it will need to introduce
@@ -311,21 +313,14 @@ public class NsdService extends INsdManager.Stub {
mIsMonitoringSocketsStarted = true; mIsMonitoringSocketsStarted = true;
} }
private void maybeStopMonitoringSockets() { private void maybeStopMonitoringSocketsIfNoActiveRequest() {
if (!mIsMonitoringSocketsStarted) { if (!mIsMonitoringSocketsStarted) return;
if (DBG) Log.d(TAG, "Socket monitoring has not been started."); if (isAnyRequestActive()) return;
return;
}
mMdnsSocketProvider.stopMonitoringSockets(); mMdnsSocketProvider.stopMonitoringSockets();
mIsMonitoringSocketsStarted = false; mIsMonitoringSocketsStarted = false;
} }
private void maybeStopMonitoringSocketsIfNoActiveRequest() {
if (!isAnyRequestActive()) {
maybeStopMonitoringSockets();
}
}
NsdStateMachine(String name, Handler handler) { NsdStateMachine(String name, Handler handler) {
super(name, handler); super(name, handler);
addState(mDefaultState); addState(mDefaultState);
@@ -362,9 +357,7 @@ public class NsdService extends INsdManager.Stub {
mLegacyClientCount -= 1; mLegacyClientCount -= 1;
} }
} }
if (mMdnsDiscoveryManager != null || mAdvertiser != null) { maybeStopMonitoringSocketsIfNoActiveRequest();
maybeStopMonitoringSocketsIfNoActiveRequest();
}
maybeScheduleStop(); maybeScheduleStop();
break; break;
case NsdManager.DISCOVER_SERVICES: case NsdManager.DISCOVER_SERVICES:
@@ -579,7 +572,7 @@ public class NsdService extends INsdManager.Stub {
final NsdServiceInfo info = args.serviceInfo; final NsdServiceInfo info = args.serviceInfo;
id = getUniqueId(); id = getUniqueId();
if (mMdnsDiscoveryManager != null) { if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
final String serviceType = constructServiceType(info.getServiceType()); final String serviceType = constructServiceType(info.getServiceType());
if (serviceType == null) { if (serviceType == null) {
clientInfo.onDiscoverServicesFailed(clientId, clientInfo.onDiscoverServicesFailed(clientId,
@@ -634,6 +627,9 @@ public class NsdService extends INsdManager.Stub {
break; break;
} }
id = request.mGlobalId; id = request.mGlobalId;
// Note isMdnsDiscoveryManagerEnabled may have changed to false at this
// point, so this needs to check the type of the original request to
// unregister instead of looking at the flag value.
if (request instanceof DiscoveryManagerRequest) { if (request instanceof DiscoveryManagerRequest) {
final MdnsListener listener = final MdnsListener listener =
((DiscoveryManagerRequest) request).mListener; ((DiscoveryManagerRequest) request).mListener;
@@ -671,7 +667,7 @@ public class NsdService extends INsdManager.Stub {
} }
id = getUniqueId(); id = getUniqueId();
if (mAdvertiser != null) { if (mDeps.isMdnsAdvertiserEnabled(mContext)) {
final NsdServiceInfo serviceInfo = args.serviceInfo; final NsdServiceInfo serviceInfo = args.serviceInfo;
final String serviceType = serviceInfo.getServiceType(); final String serviceType = serviceInfo.getServiceType();
final String registerServiceType = constructServiceType(serviceType); final String registerServiceType = constructServiceType(serviceType);
@@ -722,7 +718,10 @@ public class NsdService extends INsdManager.Stub {
id = request.mGlobalId; id = request.mGlobalId;
removeRequestMap(clientId, id, clientInfo); removeRequestMap(clientId, id, clientInfo);
if (mAdvertiser != null) { // Note isMdnsAdvertiserEnabled may have changed to false at this point,
// so this needs to check the type of the original request to unregister
// instead of looking at the flag value.
if (request instanceof AdvertiserClientRequest) {
mAdvertiser.removeService(id); mAdvertiser.removeService(id);
clientInfo.onUnregisterServiceSucceeded(clientId); clientInfo.onUnregisterServiceSucceeded(clientId);
} else { } else {
@@ -749,7 +748,7 @@ public class NsdService extends INsdManager.Stub {
final NsdServiceInfo info = args.serviceInfo; final NsdServiceInfo info = args.serviceInfo;
id = getUniqueId(); id = getUniqueId();
if (mMdnsDiscoveryManager != null) { if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
final String serviceType = constructServiceType(info.getServiceType()); final String serviceType = constructServiceType(info.getServiceType());
if (serviceType == null) { if (serviceType == null) {
clientInfo.onResolveServiceFailed(clientId, clientInfo.onResolveServiceFailed(clientId,
@@ -1241,32 +1240,16 @@ public class NsdService extends INsdManager.Stub {
mNsdStateMachine.start(); mNsdStateMachine.start();
mMDnsManager = ctx.getSystemService(MDnsManager.class); mMDnsManager = ctx.getSystemService(MDnsManager.class);
mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine); mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine);
mDeps = deps;
final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx); mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx); mMdnsSocketClient =
if (discoveryManagerEnabled || advertiserEnabled) { new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper()); mMdnsDiscoveryManager =
} else { deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
mMdnsSocketProvider = null; handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
} mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
new AdvertiserCallback());
if (discoveryManagerEnabled) {
mMdnsSocketClient =
new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
mMdnsDiscoveryManager =
deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
} else {
mMdnsSocketClient = null;
mMdnsDiscoveryManager = null;
}
if (advertiserEnabled) {
mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
new AdvertiserCallback());
} else {
mAdvertiser = null;
}
} }
/** /**

View File

@@ -45,6 +45,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import android.compat.testing.PlatformCompatChangeRule; import android.compat.testing.PlatformCompatChangeRule;
@@ -170,6 +171,9 @@ public class NsdServiceTest {
doReturn(true).when(mMockMDnsM).resolve( doReturn(true).when(mMockMDnsM).resolve(
anyInt(), anyString(), anyString(), anyString(), anyInt()); anyInt(), anyString(), anyString(), anyString(), anyInt());
doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class)); doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
mService = makeService(); mService = makeService();
} }
@@ -824,40 +828,50 @@ public class NsdServiceTest {
client.unregisterServiceInfoCallback(serviceInfoCallback)); client.unregisterServiceInfoCallback(serviceInfoCallback));
} }
private void makeServiceWithMdnsDiscoveryManagerEnabled() { private void setMdnsDiscoveryManagerEnabled() {
doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class)); doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
mService = makeService();
verify(mDeps).makeMdnsDiscoveryManager(any(), any());
verify(mDeps).makeMdnsSocketProvider(any(), any());
} }
private void makeServiceWithMdnsAdvertiserEnabled() { private void setMdnsAdvertiserEnabled() {
doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class)); doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class));
doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
mService = makeService();
verify(mDeps).makeMdnsAdvertiser(any(), any(), any());
verify(mDeps).makeMdnsSocketProvider(any(), any());
} }
@Test @Test
public void testMdnsDiscoveryManagerFeature() { public void testMdnsDiscoveryManagerFeature() {
// Create NsdService w/o feature enabled. // Create NsdService w/o feature enabled.
connectClient(mService); final NsdManager client = connectClient(mService);
verify(mDeps, never()).makeMdnsDiscoveryManager(any(), any()); final DiscoveryListener discListenerWithoutFeature = mock(DiscoveryListener.class);
verify(mDeps, never()).makeMdnsSocketProvider(any(), any()); client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithoutFeature);
waitForIdle();
// Create NsdService again w/ feature enabled. final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
makeServiceWithMdnsDiscoveryManagerEnabled(); verify(mMockMDnsM).discover(legacyIdCaptor.capture(), any(), anyInt());
verifyNoMoreInteractions(mDiscoveryManager);
setMdnsDiscoveryManagerEnabled();
final DiscoveryListener discListenerWithFeature = mock(DiscoveryListener.class);
client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithFeature);
waitForIdle();
final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
listenerCaptor.capture(), any());
client.stopServiceDiscovery(discListenerWithoutFeature);
waitForIdle();
verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
client.stopServiceDiscovery(discListenerWithFeature);
waitForIdle();
verify(mDiscoveryManager).unregisterListener(serviceTypeWithLocalDomain,
listenerCaptor.getValue());
} }
@Test @Test
public void testDiscoveryWithMdnsDiscoveryManager() { public void testDiscoveryWithMdnsDiscoveryManager() {
makeServiceWithMdnsDiscoveryManagerEnabled(); setMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService); final NsdManager client = connectClient(mService);
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -922,7 +936,7 @@ public class NsdServiceTest {
@Test @Test
public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() { public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() {
makeServiceWithMdnsDiscoveryManagerEnabled(); setMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService); final NsdManager client = connectClient(mService);
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -951,7 +965,7 @@ public class NsdServiceTest {
@Test @Test
public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException { public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException {
makeServiceWithMdnsDiscoveryManagerEnabled(); setMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService); final NsdManager client = connectClient(mService);
final ResolveListener resolveListener = mock(ResolveListener.class); final ResolveListener resolveListener = mock(ResolveListener.class);
@@ -1004,9 +1018,44 @@ public class NsdServiceTest {
verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets(); verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets();
} }
@Test
public void testMdnsAdvertiserFeatureFlagging() {
// Create NsdService w/o feature enabled.
final NsdManager client = connectClient(mService);
final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
regInfo.setHost(parseNumericAddress("192.0.2.123"));
regInfo.setPort(12345);
final RegistrationListener regListenerWithoutFeature = mock(RegistrationListener.class);
client.registerService(regInfo, PROTOCOL, regListenerWithoutFeature);
waitForIdle();
final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mMockMDnsM).registerService(legacyIdCaptor.capture(), any(), any(), anyInt(),
any(), anyInt());
verifyNoMoreInteractions(mAdvertiser);
setMdnsAdvertiserEnabled();
final RegistrationListener regListenerWithFeature = mock(RegistrationListener.class);
client.registerService(regInfo, PROTOCOL, regListenerWithFeature);
waitForIdle();
final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo)));
client.unregisterService(regListenerWithoutFeature);
waitForIdle();
verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
verify(mAdvertiser, never()).removeService(anyInt());
client.unregisterService(regListenerWithFeature);
waitForIdle();
verify(mAdvertiser).removeService(serviceIdCaptor.getValue());
}
@Test @Test
public void testAdvertiseWithMdnsAdvertiser() { public void testAdvertiseWithMdnsAdvertiser() {
makeServiceWithMdnsAdvertiserEnabled(); setMdnsAdvertiserEnabled();
final NsdManager client = connectClient(mService); final NsdManager client = connectClient(mService);
final RegistrationListener regListener = mock(RegistrationListener.class); final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1045,7 +1094,7 @@ public class NsdServiceTest {
@Test @Test
public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() { public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() {
makeServiceWithMdnsAdvertiserEnabled(); setMdnsAdvertiserEnabled();
final NsdManager client = connectClient(mService); final NsdManager client = connectClient(mService);
final RegistrationListener regListener = mock(RegistrationListener.class); final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1070,7 +1119,7 @@ public class NsdServiceTest {
@Test @Test
public void testAdvertiseWithMdnsAdvertiser_LongServiceName() { public void testAdvertiseWithMdnsAdvertiser_LongServiceName() {
makeServiceWithMdnsAdvertiserEnabled(); setMdnsAdvertiserEnabled();
final NsdManager client = connectClient(mService); final NsdManager client = connectClient(mService);
final RegistrationListener regListener = mock(RegistrationListener.class); final RegistrationListener regListener = mock(RegistrationListener.class);