Allow Advertiser, DiscoveryManager runtime toggle

Allow toggling MdnsAdvertiser and MdnsDiscoveryManager at runtime, by
always creating them in NsdService constructor, but only using them when
the flag is on when starting discovery, resolve or registration.

When stopping, based on the type of the stored request, stop the
corresponding backend.

Bug: 265891278
Test: atest NsdServiceTest
Change-Id: I7cb2f9fe9e1ed3dc77616689a8e3ffa00f5bc269
This commit is contained in:
Remi NGUYEN VAN
2023-01-18 18:57:41 +09:00
parent 8f453b9c9c
commit a8a777bbbd
2 changed files with 104 additions and 72 deletions

View File

@@ -118,13 +118,15 @@ public class NsdService extends INsdManager.Stub {
private final NsdStateMachine mNsdStateMachine;
private final MDnsManager mMDnsManager;
private final MDnsEventCallback mMDnsEventCallback;
@Nullable
@NonNull
private final Dependencies mDeps;
@NonNull
private final MdnsMultinetworkSocketClient mMdnsSocketClient;
@Nullable
@NonNull
private final MdnsDiscoveryManager mMdnsDiscoveryManager;
@Nullable
@NonNull
private final MdnsSocketProvider mMdnsSocketProvider;
@Nullable
@NonNull
private final MdnsAdvertiser mAdvertiser;
// WARNING : Accessing these values in any thread is not safe, it must only be changed in the
// state machine thread. If change this outside state machine, it will need to introduce
@@ -311,21 +313,14 @@ public class NsdService extends INsdManager.Stub {
mIsMonitoringSocketsStarted = true;
}
private void maybeStopMonitoringSockets() {
if (!mIsMonitoringSocketsStarted) {
if (DBG) Log.d(TAG, "Socket monitoring has not been started.");
return;
}
private void maybeStopMonitoringSocketsIfNoActiveRequest() {
if (!mIsMonitoringSocketsStarted) return;
if (isAnyRequestActive()) return;
mMdnsSocketProvider.stopMonitoringSockets();
mIsMonitoringSocketsStarted = false;
}
private void maybeStopMonitoringSocketsIfNoActiveRequest() {
if (!isAnyRequestActive()) {
maybeStopMonitoringSockets();
}
}
NsdStateMachine(String name, Handler handler) {
super(name, handler);
addState(mDefaultState);
@@ -362,9 +357,7 @@ public class NsdService extends INsdManager.Stub {
mLegacyClientCount -= 1;
}
}
if (mMdnsDiscoveryManager != null || mAdvertiser != null) {
maybeStopMonitoringSocketsIfNoActiveRequest();
}
maybeScheduleStop();
break;
case NsdManager.DISCOVER_SERVICES:
@@ -579,7 +572,7 @@ public class NsdService extends INsdManager.Stub {
final NsdServiceInfo info = args.serviceInfo;
id = getUniqueId();
if (mMdnsDiscoveryManager != null) {
if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
final String serviceType = constructServiceType(info.getServiceType());
if (serviceType == null) {
clientInfo.onDiscoverServicesFailed(clientId,
@@ -634,6 +627,9 @@ public class NsdService extends INsdManager.Stub {
break;
}
id = request.mGlobalId;
// Note isMdnsDiscoveryManagerEnabled may have changed to false at this
// point, so this needs to check the type of the original request to
// unregister instead of looking at the flag value.
if (request instanceof DiscoveryManagerRequest) {
final MdnsListener listener =
((DiscoveryManagerRequest) request).mListener;
@@ -671,7 +667,7 @@ public class NsdService extends INsdManager.Stub {
}
id = getUniqueId();
if (mAdvertiser != null) {
if (mDeps.isMdnsAdvertiserEnabled(mContext)) {
final NsdServiceInfo serviceInfo = args.serviceInfo;
final String serviceType = serviceInfo.getServiceType();
final String registerServiceType = constructServiceType(serviceType);
@@ -722,7 +718,10 @@ public class NsdService extends INsdManager.Stub {
id = request.mGlobalId;
removeRequestMap(clientId, id, clientInfo);
if (mAdvertiser != null) {
// Note isMdnsAdvertiserEnabled may have changed to false at this point,
// so this needs to check the type of the original request to unregister
// instead of looking at the flag value.
if (request instanceof AdvertiserClientRequest) {
mAdvertiser.removeService(id);
clientInfo.onUnregisterServiceSucceeded(clientId);
} else {
@@ -749,7 +748,7 @@ public class NsdService extends INsdManager.Stub {
final NsdServiceInfo info = args.serviceInfo;
id = getUniqueId();
if (mMdnsDiscoveryManager != null) {
if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
final String serviceType = constructServiceType(info.getServiceType());
if (serviceType == null) {
clientInfo.onResolveServiceFailed(clientId,
@@ -1241,32 +1240,16 @@ public class NsdService extends INsdManager.Stub {
mNsdStateMachine.start();
mMDnsManager = ctx.getSystemService(MDnsManager.class);
mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine);
mDeps = deps;
final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx);
final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx);
if (discoveryManagerEnabled || advertiserEnabled) {
mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
} else {
mMdnsSocketProvider = null;
}
if (discoveryManagerEnabled) {
mMdnsSocketClient =
new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
mMdnsDiscoveryManager =
deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
} else {
mMdnsSocketClient = null;
mMdnsDiscoveryManager = null;
}
if (advertiserEnabled) {
mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
new AdvertiserCallback());
} else {
mAdvertiser = null;
}
}
/**

View File

@@ -45,6 +45,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import android.compat.testing.PlatformCompatChangeRule;
@@ -170,6 +171,9 @@ public class NsdServiceTest {
doReturn(true).when(mMockMDnsM).resolve(
anyInt(), anyString(), anyString(), anyString(), anyInt());
doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
mService = makeService();
}
@@ -824,40 +828,50 @@ public class NsdServiceTest {
client.unregisterServiceInfoCallback(serviceInfoCallback));
}
private void makeServiceWithMdnsDiscoveryManagerEnabled() {
private void setMdnsDiscoveryManagerEnabled() {
doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
mService = makeService();
verify(mDeps).makeMdnsDiscoveryManager(any(), any());
verify(mDeps).makeMdnsSocketProvider(any(), any());
}
private void makeServiceWithMdnsAdvertiserEnabled() {
private void setMdnsAdvertiserEnabled() {
doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class));
doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
mService = makeService();
verify(mDeps).makeMdnsAdvertiser(any(), any(), any());
verify(mDeps).makeMdnsSocketProvider(any(), any());
}
@Test
public void testMdnsDiscoveryManagerFeature() {
// Create NsdService w/o feature enabled.
connectClient(mService);
verify(mDeps, never()).makeMdnsDiscoveryManager(any(), any());
verify(mDeps, never()).makeMdnsSocketProvider(any(), any());
final NsdManager client = connectClient(mService);
final DiscoveryListener discListenerWithoutFeature = mock(DiscoveryListener.class);
client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithoutFeature);
waitForIdle();
// Create NsdService again w/ feature enabled.
makeServiceWithMdnsDiscoveryManagerEnabled();
final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mMockMDnsM).discover(legacyIdCaptor.capture(), any(), anyInt());
verifyNoMoreInteractions(mDiscoveryManager);
setMdnsDiscoveryManagerEnabled();
final DiscoveryListener discListenerWithFeature = mock(DiscoveryListener.class);
client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithFeature);
waitForIdle();
final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
listenerCaptor.capture(), any());
client.stopServiceDiscovery(discListenerWithoutFeature);
waitForIdle();
verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
client.stopServiceDiscovery(discListenerWithFeature);
waitForIdle();
verify(mDiscoveryManager).unregisterListener(serviceTypeWithLocalDomain,
listenerCaptor.getValue());
}
@Test
public void testDiscoveryWithMdnsDiscoveryManager() {
makeServiceWithMdnsDiscoveryManagerEnabled();
setMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService);
final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -922,7 +936,7 @@ public class NsdServiceTest {
@Test
public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() {
makeServiceWithMdnsDiscoveryManagerEnabled();
setMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService);
final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -951,7 +965,7 @@ public class NsdServiceTest {
@Test
public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException {
makeServiceWithMdnsDiscoveryManagerEnabled();
setMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService);
final ResolveListener resolveListener = mock(ResolveListener.class);
@@ -1004,9 +1018,44 @@ public class NsdServiceTest {
verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets();
}
@Test
public void testMdnsAdvertiserFeatureFlagging() {
// Create NsdService w/o feature enabled.
final NsdManager client = connectClient(mService);
final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
regInfo.setHost(parseNumericAddress("192.0.2.123"));
regInfo.setPort(12345);
final RegistrationListener regListenerWithoutFeature = mock(RegistrationListener.class);
client.registerService(regInfo, PROTOCOL, regListenerWithoutFeature);
waitForIdle();
final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mMockMDnsM).registerService(legacyIdCaptor.capture(), any(), any(), anyInt(),
any(), anyInt());
verifyNoMoreInteractions(mAdvertiser);
setMdnsAdvertiserEnabled();
final RegistrationListener regListenerWithFeature = mock(RegistrationListener.class);
client.registerService(regInfo, PROTOCOL, regListenerWithFeature);
waitForIdle();
final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo)));
client.unregisterService(regListenerWithoutFeature);
waitForIdle();
verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
verify(mAdvertiser, never()).removeService(anyInt());
client.unregisterService(regListenerWithFeature);
waitForIdle();
verify(mAdvertiser).removeService(serviceIdCaptor.getValue());
}
@Test
public void testAdvertiseWithMdnsAdvertiser() {
makeServiceWithMdnsAdvertiserEnabled();
setMdnsAdvertiserEnabled();
final NsdManager client = connectClient(mService);
final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1045,7 +1094,7 @@ public class NsdServiceTest {
@Test
public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() {
makeServiceWithMdnsAdvertiserEnabled();
setMdnsAdvertiserEnabled();
final NsdManager client = connectClient(mService);
final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1070,7 +1119,7 @@ public class NsdServiceTest {
@Test
public void testAdvertiseWithMdnsAdvertiser_LongServiceName() {
makeServiceWithMdnsAdvertiserEnabled();
setMdnsAdvertiserEnabled();
final NsdManager client = connectClient(mService);
final RegistrationListener regListener = mock(RegistrationListener.class);