Use MdnsDiscoveryManager for discovery

Register/Unregister the listener to/from MdnsDiscoveryManager
when discovery started/stopped.

Bug: 254166302
Test: atest FrameworksNetTests CtsNetTestsCases
Change-Id: Ibd782029826ac5856c608165928cd942e46dd9a4
This commit is contained in:
Paul Hu
2023-01-13 22:57:24 +08:00
parent 4bd98ef68e
commit 23fa202478
3 changed files with 285 additions and 20 deletions

View File

@@ -242,6 +242,9 @@ public final class NsdManager {
/** @hide */ /** @hide */
public static final int UNREGISTER_CLIENT = 22; public static final int UNREGISTER_CLIENT = 22;
/** @hide */
public static final int MDNS_MONITORING_SOCKETS_CLEANUP = 23;
/** Dns based service discovery protocol */ /** Dns based service discovery protocol */
public static final int PROTOCOL_DNS_SD = 0x0001; public static final int PROTOCOL_DNS_SD = 0x0001;

View File

@@ -20,6 +20,7 @@ import static android.net.ConnectivityManager.NETID_UNSET;
import static android.net.nsd.NsdManager.MDNS_SERVICE_EVENT; import static android.net.nsd.NsdManager.MDNS_SERVICE_EVENT;
import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY; import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
import android.annotation.NonNull;
import android.annotation.Nullable; import android.annotation.Nullable;
import android.content.Context; import android.content.Context;
import android.content.Intent; import android.content.Intent;
@@ -45,6 +46,7 @@ import android.os.Looper;
import android.os.Message; import android.os.Message;
import android.os.RemoteException; import android.os.RemoteException;
import android.os.UserHandle; import android.os.UserHandle;
import android.text.TextUtils;
import android.util.Log; import android.util.Log;
import android.util.Pair; import android.util.Pair;
import android.util.SparseArray; import android.util.SparseArray;
@@ -58,6 +60,9 @@ import com.android.net.module.util.PermissionUtils;
import com.android.server.connectivity.mdns.ExecutorProvider; import com.android.server.connectivity.mdns.ExecutorProvider;
import com.android.server.connectivity.mdns.MdnsDiscoveryManager; import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
import com.android.server.connectivity.mdns.MdnsMultinetworkSocketClient; import com.android.server.connectivity.mdns.MdnsMultinetworkSocketClient;
import com.android.server.connectivity.mdns.MdnsSearchOptions;
import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
import com.android.server.connectivity.mdns.MdnsServiceInfo;
import com.android.server.connectivity.mdns.MdnsSocketClientBase; import com.android.server.connectivity.mdns.MdnsSocketClientBase;
import com.android.server.connectivity.mdns.MdnsSocketProvider; import com.android.server.connectivity.mdns.MdnsSocketProvider;
@@ -68,6 +73,9 @@ import java.net.NetworkInterface;
import java.net.SocketException; import java.net.SocketException;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/** /**
* Network Service Discovery Service handles remote service discovery operation requests by * Network Service Discovery Service handles remote service discovery operation requests by
@@ -79,6 +87,7 @@ public class NsdService extends INsdManager.Stub {
private static final String TAG = "NsdService"; private static final String TAG = "NsdService";
private static final String MDNS_TAG = "mDnsConnector"; private static final String MDNS_TAG = "mDnsConnector";
private static final String MDNS_DISCOVERY_MANAGER_VERSION = "mdns_discovery_manager_version"; private static final String MDNS_DISCOVERY_MANAGER_VERSION = "mdns_discovery_manager_version";
private static final String LOCAL_DOMAIN_NAME = "local";
private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG); private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
private static final long CLEANUP_DELAY_MS = 10000; private static final long CLEANUP_DELAY_MS = 10000;
@@ -94,10 +103,11 @@ public class NsdService extends INsdManager.Stub {
private final MdnsDiscoveryManager mMdnsDiscoveryManager; private final MdnsDiscoveryManager mMdnsDiscoveryManager;
@Nullable @Nullable
private final MdnsSocketProvider mMdnsSocketProvider; private final MdnsSocketProvider mMdnsSocketProvider;
// WARNING : Accessing this value in any thread is not safe, it must only be changed in the // WARNING : Accessing these values in any thread is not safe, it must only be changed in the
// state machine thread. If change this outside state machine, it will need to introduce // state machine thread. If change this outside state machine, it will need to introduce
// synchronization. // synchronization.
private boolean mIsDaemonStarted = false; private boolean mIsDaemonStarted = false;
private boolean mIsMonitoringSocketsStarted = false;
/** /**
* Clients receiving asynchronous messages * Clients receiving asynchronous messages
@@ -114,6 +124,73 @@ public class NsdService extends INsdManager.Stub {
// The count of the connected legacy clients. // The count of the connected legacy clients.
private int mLegacyClientCount = 0; private int mLegacyClientCount = 0;
private static class MdnsListener implements MdnsServiceBrowserListener {
protected final int mClientId;
protected final int mTransactionId;
@NonNull
protected final NsdServiceInfo mReqServiceInfo;
@NonNull
protected final String mListenedServiceType;
MdnsListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
@NonNull String listenedServiceType) {
mClientId = clientId;
mTransactionId = transactionId;
mReqServiceInfo = reqServiceInfo;
mListenedServiceType = listenedServiceType;
}
@NonNull
public String getListenedServiceType() {
return mListenedServiceType;
}
@Override
public void onServiceFound(@NonNull MdnsServiceInfo serviceInfo) { }
@Override
public void onServiceUpdated(@NonNull MdnsServiceInfo serviceInfo) { }
@Override
public void onServiceRemoved(@NonNull MdnsServiceInfo serviceInfo) { }
@Override
public void onServiceNameDiscovered(@NonNull MdnsServiceInfo serviceInfo) { }
@Override
public void onServiceNameRemoved(@NonNull MdnsServiceInfo serviceInfo) { }
@Override
public void onSearchStoppedWithError(int error) { }
@Override
public void onSearchFailedToStart() { }
@Override
public void onDiscoveryQuerySent(@NonNull List<String> subtypes, int transactionId) { }
@Override
public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode) { }
}
private class DiscoveryListener extends MdnsListener {
DiscoveryListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
@NonNull String listenServiceType) {
super(clientId, transactionId, reqServiceInfo, listenServiceType);
}
@Override
public void onServiceNameDiscovered(@NonNull MdnsServiceInfo serviceInfo) {
// TODO: implement service name discovered callback.
}
@Override
public void onServiceNameRemoved(@NonNull MdnsServiceInfo serviceInfo) {
// TODO: implement service name removed callback.
}
}
private class NsdStateMachine extends StateMachine { private class NsdStateMachine extends StateMachine {
private final DefaultState mDefaultState = new DefaultState(); private final DefaultState mDefaultState = new DefaultState();
@@ -164,6 +241,31 @@ public class NsdService extends INsdManager.Stub {
this.removeMessages(NsdManager.DAEMON_CLEANUP); this.removeMessages(NsdManager.DAEMON_CLEANUP);
} }
private void maybeStartMonitoringSockets() {
if (mIsMonitoringSocketsStarted) {
if (DBG) Log.d(TAG, "Socket monitoring is already started.");
return;
}
mMdnsSocketProvider.startMonitoringSockets();
mIsMonitoringSocketsStarted = true;
}
private void maybeStopMonitoringSockets() {
if (!mIsMonitoringSocketsStarted) {
if (DBG) Log.d(TAG, "Socket monitoring has not been started.");
return;
}
mMdnsSocketProvider.stopMonitoringSockets();
mIsMonitoringSocketsStarted = false;
}
private void maybeStopMonitoringSocketsIfNoActiveRequest() {
if (!isAnyRequestActive()) {
maybeStopMonitoringSockets();
}
}
NsdStateMachine(String name, Handler handler) { NsdStateMachine(String name, Handler handler) {
super(name, handler); super(name, handler);
addState(mDefaultState); addState(mDefaultState);
@@ -195,11 +297,17 @@ public class NsdService extends INsdManager.Stub {
final NsdServiceConnector connector = (NsdServiceConnector) msg.obj; final NsdServiceConnector connector = (NsdServiceConnector) msg.obj;
cInfo = mClients.remove(connector); cInfo = mClients.remove(connector);
if (cInfo != null) { if (cInfo != null) {
if (mMdnsDiscoveryManager != null) {
cInfo.unregisterAllListeners();
}
cInfo.expungeAllRequests(); cInfo.expungeAllRequests();
if (cInfo.isLegacy()) { if (cInfo.isLegacy()) {
mLegacyClientCount -= 1; mLegacyClientCount -= 1;
} }
} }
if (mMdnsDiscoveryManager != null) {
maybeStopMonitoringSocketsIfNoActiveRequest();
}
maybeScheduleStop(); maybeScheduleStop();
break; break;
case NsdManager.DISCOVER_SERVICES: case NsdManager.DISCOVER_SERVICES:
@@ -251,6 +359,9 @@ public class NsdService extends INsdManager.Stub {
maybeStartDaemon(); maybeStartDaemon();
} }
break; break;
case NsdManager.MDNS_MONITORING_SOCKETS_CLEANUP:
maybeStopMonitoringSockets();
break;
default: default:
Log.e(TAG, "Unhandled " + msg); Log.e(TAG, "Unhandled " + msg);
return NOT_HANDLED; return NOT_HANDLED;
@@ -300,6 +411,47 @@ public class NsdService extends INsdManager.Stub {
maybeScheduleStop(); maybeScheduleStop();
} }
private void storeListenerMap(int clientId, int transactionId, MdnsListener listener,
ClientInfo clientInfo) {
clientInfo.mClientIds.put(clientId, transactionId);
clientInfo.mListeners.put(clientId, listener);
mIdToClientInfoMap.put(transactionId, clientInfo);
removeMessages(NsdManager.MDNS_MONITORING_SOCKETS_CLEANUP);
}
private void removeListenerMap(int clientId, int transactionId, ClientInfo clientInfo) {
clientInfo.mClientIds.delete(clientId);
clientInfo.mListeners.delete(clientId);
mIdToClientInfoMap.remove(transactionId);
maybeStopMonitoringSocketsIfNoActiveRequest();
}
/**
* Check the given service type is valid and construct it to a service type
* which can use for discovery / resolution service.
*
* <p> The valid service type should be 2 labels, or 3 labels if the query is for a
* subtype (see RFC6763 7.1). Each label is up to 63 characters and must start with an
* underscore; they are alphanumerical characters or dashes or underscore, except the
* last one that is just alphanumerical. The last label must be _tcp or _udp.
*
* @param serviceType the request service type for discovery / resolution service
* @return constructed service type or null if the given service type is invalid.
*/
@Nullable
private String constructServiceType(String serviceType) {
if (TextUtils.isEmpty(serviceType)) return null;
final Pattern serviceTypePattern = Pattern.compile(
"^(_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]\\.)?"
+ "(_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]\\._(?:tcp|udp))$");
final Matcher matcher = serviceTypePattern.matcher(serviceType);
if (!matcher.matches()) return null;
return matcher.group(1) == null
? serviceType + ".local"
: matcher.group(1) + "._sub" + matcher.group(2) + ".local";
}
@Override @Override
public boolean processMessage(Message msg) { public boolean processMessage(Message msg) {
final ClientInfo clientInfo; final ClientInfo clientInfo;
@@ -325,19 +477,40 @@ public class NsdService extends INsdManager.Stub {
break; break;
} }
maybeStartDaemon(); final NsdServiceInfo info = args.serviceInfo;
id = getUniqueId(); id = getUniqueId();
if (discoverServices(id, args.serviceInfo)) { if (mMdnsDiscoveryManager != null) {
if (DBG) { final String serviceType = constructServiceType(info.getServiceType());
Log.d(TAG, "Discover " + msg.arg2 + " " + id if (serviceType == null) {
+ args.serviceInfo.getServiceType()); clientInfo.onDiscoverServicesFailed(clientId,
NsdManager.FAILURE_INTERNAL_ERROR);
break;
} }
storeRequestMap(clientId, id, clientInfo, msg.what);
clientInfo.onDiscoverServicesStarted(clientId, args.serviceInfo); maybeStartMonitoringSockets();
final MdnsListener listener =
new DiscoveryListener(clientId, id, info, serviceType);
final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
.setNetwork(info.getNetwork())
.setIsPassiveMode(true)
.build();
mMdnsDiscoveryManager.registerListener(serviceType, listener, options);
storeListenerMap(clientId, id, listener, clientInfo);
clientInfo.onDiscoverServicesStarted(clientId, info);
} else { } else {
stopServiceDiscovery(id); maybeStartDaemon();
clientInfo.onDiscoverServicesFailed(clientId, if (discoverServices(id, info)) {
NsdManager.FAILURE_INTERNAL_ERROR); if (DBG) {
Log.d(TAG, "Discover " + msg.arg2 + " " + id
+ info.getServiceType());
}
storeRequestMap(clientId, id, clientInfo, msg.what);
clientInfo.onDiscoverServicesStarted(clientId, info);
} else {
stopServiceDiscovery(id);
clientInfo.onDiscoverServicesFailed(clientId,
NsdManager.FAILURE_INTERNAL_ERROR);
}
} }
break; break;
case NsdManager.STOP_DISCOVERY: case NsdManager.STOP_DISCOVERY:
@@ -359,12 +532,25 @@ public class NsdService extends INsdManager.Stub {
clientId, NsdManager.FAILURE_INTERNAL_ERROR); clientId, NsdManager.FAILURE_INTERNAL_ERROR);
break; break;
} }
removeRequestMap(clientId, id, clientInfo); if (mMdnsDiscoveryManager != null) {
if (stopServiceDiscovery(id)) { final MdnsListener listener = clientInfo.mListeners.get(clientId);
if (listener == null) {
clientInfo.onStopDiscoveryFailed(
clientId, NsdManager.FAILURE_INTERNAL_ERROR);
break;
}
mMdnsDiscoveryManager.unregisterListener(
listener.getListenedServiceType(), listener);
removeListenerMap(clientId, id, clientInfo);
clientInfo.onStopDiscoverySucceeded(clientId); clientInfo.onStopDiscoverySucceeded(clientId);
} else { } else {
clientInfo.onStopDiscoveryFailed( removeRequestMap(clientId, id, clientInfo);
clientId, NsdManager.FAILURE_INTERNAL_ERROR); if (stopServiceDiscovery(id)) {
clientInfo.onStopDiscoverySucceeded(clientId);
} else {
clientInfo.onStopDiscoveryFailed(
clientId, NsdManager.FAILURE_INTERNAL_ERROR);
}
} }
break; break;
case NsdManager.REGISTER_SERVICE: case NsdManager.REGISTER_SERVICE:
@@ -982,6 +1168,9 @@ public class NsdService extends INsdManager.Stub {
/* A map from client id to the type of the request we had received */ /* A map from client id to the type of the request we had received */
private final SparseIntArray mClientRequests = new SparseIntArray(); private final SparseIntArray mClientRequests = new SparseIntArray();
/* A map from client id to the MdnsListener */
private final SparseArray<MdnsListener> mListeners = new SparseArray<>();
// The target SDK of this client < Build.VERSION_CODES.S // The target SDK of this client < Build.VERSION_CODES.S
private boolean mIsLegacy = false; private boolean mIsLegacy = false;
@@ -1043,6 +1232,15 @@ public class NsdService extends INsdManager.Stub {
mClientRequests.clear(); mClientRequests.clear();
} }
void unregisterAllListeners() {
for (int i = 0; i < mListeners.size(); i++) {
final MdnsListener listener = mListeners.valueAt(i);
mMdnsDiscoveryManager.unregisterListener(
listener.getListenedServiceType(), listener);
}
mListeners.clear();
}
// mClientIds is a sparse array of listener id -> mDnsClient id. For a given mDnsClient id, // mClientIds is a sparse array of listener id -> mDnsClient id. For a given mDnsClient id,
// return the corresponding listener id. mDnsClient id is also called a global id. // return the corresponding listener id. mDnsClient id is also called a global id.
private int getClientId(final int globalId) { private int getClientId(final int globalId) {

View File

@@ -28,6 +28,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
@@ -45,6 +46,7 @@ import android.content.ContentResolver;
import android.content.Context; import android.content.Context;
import android.net.INetd; import android.net.INetd;
import android.net.InetAddresses; import android.net.InetAddresses;
import android.net.Network;
import android.net.mdns.aidl.DiscoveryInfo; import android.net.mdns.aidl.DiscoveryInfo;
import android.net.mdns.aidl.GetAddressInfo; import android.net.mdns.aidl.GetAddressInfo;
import android.net.mdns.aidl.IMDnsEventListener; import android.net.mdns.aidl.IMDnsEventListener;
@@ -70,6 +72,8 @@ import androidx.annotation.NonNull;
import androidx.test.filters.SmallTest; import androidx.test.filters.SmallTest;
import com.android.server.NsdService.Dependencies; import com.android.server.NsdService.Dependencies;
import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
import com.android.server.connectivity.mdns.MdnsSocketProvider;
import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRule;
import com.android.testutils.DevSdkIgnoreRunner; import com.android.testutils.DevSdkIgnoreRunner;
import com.android.testutils.HandlerUtils; import com.android.testutils.HandlerUtils;
@@ -99,7 +103,7 @@ public class NsdServiceTest {
private static final long CLEANUP_DELAY_MS = 500; private static final long CLEANUP_DELAY_MS = 500;
private static final long TIMEOUT_MS = 500; private static final long TIMEOUT_MS = 500;
private static final String SERVICE_NAME = "a_name"; private static final String SERVICE_NAME = "a_name";
private static final String SERVICE_TYPE = "a_type"; private static final String SERVICE_TYPE = "_test._tcp";
private static final String SERVICE_FULL_NAME = SERVICE_NAME + "." + SERVICE_TYPE; private static final String SERVICE_FULL_NAME = SERVICE_NAME + "." + SERVICE_TYPE;
private static final String DOMAIN_NAME = "mytestdevice.local"; private static final String DOMAIN_NAME = "mytestdevice.local";
private static final int PORT = 2201; private static final int PORT = 2201;
@@ -116,6 +120,8 @@ public class NsdServiceTest {
@Mock ContentResolver mResolver; @Mock ContentResolver mResolver;
@Mock MDnsManager mMockMDnsM; @Mock MDnsManager mMockMDnsM;
@Mock Dependencies mDeps; @Mock Dependencies mDeps;
@Mock MdnsDiscoveryManager mDiscoveryManager;
@Mock MdnsSocketProvider mSocketProvider;
HandlerThread mThread; HandlerThread mThread;
TestHandler mHandler; TestHandler mHandler;
NsdService mService; NsdService mService;
@@ -558,6 +564,16 @@ public class NsdServiceTest {
anyInt()/* interfaceIdx */); anyInt()/* interfaceIdx */);
} }
private void makeServiceWithMdnsDiscoveryManagerEnabled() {
doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
mService = makeService();
verify(mDeps).makeMdnsDiscoveryManager(any(), any());
verify(mDeps).makeMdnsSocketProvider(any(), any());
}
@Test @Test
public void testMdnsDiscoveryManagerFeature() { public void testMdnsDiscoveryManagerFeature() {
// Create NsdService w/o feature enabled. // Create NsdService w/o feature enabled.
@@ -566,12 +582,60 @@ public class NsdServiceTest {
verify(mDeps, never()).makeMdnsSocketProvider(any(), any()); verify(mDeps, never()).makeMdnsSocketProvider(any(), any());
// Create NsdService again w/ feature enabled. // Create NsdService again w/ feature enabled.
doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class)); makeServiceWithMdnsDiscoveryManagerEnabled();
makeService();
verify(mDeps).makeMdnsDiscoveryManager(any(), any());
verify(mDeps).makeMdnsSocketProvider(any(), any());
} }
@Test
public void testDiscoveryWithMdnsDiscoveryManager() {
makeServiceWithMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService);
final DiscoveryListener discListener = mock(DiscoveryListener.class);
final Network network = new Network(999);
final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
// Verify the discovery start / stop.
client.discoverServices(SERVICE_TYPE, PROTOCOL, network, r -> r.run(), discListener);
waitForIdle();
verify(mSocketProvider).startMonitoringSockets();
verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain), any(),
argThat(options -> network.equals(options.getNetwork())));
verify(discListener, timeout(TIMEOUT_MS)).onDiscoveryStarted(SERVICE_TYPE);
client.stopServiceDiscovery(discListener);
waitForIdle();
verify(mDiscoveryManager).unregisterListener(eq(serviceTypeWithLocalDomain), any());
verify(discListener, timeout(TIMEOUT_MS)).onDiscoveryStopped(SERVICE_TYPE);
verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets();
}
@Test
public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() {
makeServiceWithMdnsDiscoveryManagerEnabled();
final NsdManager client = connectClient(mService);
final DiscoveryListener discListener = mock(DiscoveryListener.class);
final Network network = new Network(999);
final String invalidServiceType = "a_service";
client.discoverServices(
invalidServiceType, PROTOCOL, network, r -> r.run(), discListener);
waitForIdle();
verify(discListener, timeout(TIMEOUT_MS))
.onStartDiscoveryFailed(invalidServiceType, FAILURE_INTERNAL_ERROR);
final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
client.discoverServices(
serviceTypeWithLocalDomain, PROTOCOL, network, r -> r.run(), discListener);
waitForIdle();
verify(discListener, timeout(TIMEOUT_MS))
.onStartDiscoveryFailed(serviceTypeWithLocalDomain, FAILURE_INTERNAL_ERROR);
final String serviceTypeWithoutTcpOrUdpEnding = "_test._com";
client.discoverServices(
serviceTypeWithoutTcpOrUdpEnding, PROTOCOL, network, r -> r.run(), discListener);
waitForIdle();
verify(discListener, timeout(TIMEOUT_MS))
.onStartDiscoveryFailed(serviceTypeWithoutTcpOrUdpEnding, FAILURE_INTERNAL_ERROR);
}
private void waitForIdle() { private void waitForIdle() {
HandlerUtils.waitForIdle(mHandler, TIMEOUT_MS); HandlerUtils.waitForIdle(mHandler, TIMEOUT_MS);