diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java index dfaec751ae..f386dd487d 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java @@ -51,6 +51,7 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients; @NonNull private final Handler handler; @Nullable private final HandlerThread handlerThread; + @NonNull private final MdnsServiceCache serviceCache; private static class PerSocketServiceTypeClients { private final ArrayMap, MdnsServiceTypeClient> clients = @@ -119,10 +120,12 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { if (socketClient.getLooper() != null) { this.handlerThread = null; this.handler = new Handler(socketClient.getLooper()); + this.serviceCache = new MdnsServiceCache(socketClient.getLooper()); } else { this.handlerThread = new HandlerThread(MdnsDiscoveryManager.class.getSimpleName()); this.handlerThread.start(); this.handler = new Handler(handlerThread.getLooper()); + this.serviceCache = new MdnsServiceCache(handlerThread.getLooper()); } } @@ -289,6 +292,6 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { return new MdnsServiceTypeClient( serviceType, socketClient, executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey, - sharedLog.forSubComponent(tag), handler.getLooper()); + sharedLog.forSubComponent(tag), handler.getLooper(), serviceCache); } } \ No newline at end of file diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java index dc99e494f7..ec6af9b6d6 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java @@ -96,7 +96,14 @@ public class MdnsServiceCache { : Collections.emptyList(); } - private MdnsResponse findMatchedResponse(@NonNull List responses, + /** + * Find a matched response for given service name + * + * @param responses the responses to be searched. + * @param serviceName the target service name + * @return the response which matches the given service name or null if not found. + */ + public static MdnsResponse findMatchedResponse(@NonNull List responses, @NonNull String serviceName) { for (MdnsResponse response : responses) { if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java index 7035c90ae8..3b5193abc4 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java @@ -16,6 +16,7 @@ package com.android.server.connectivity.mdns; +import static com.android.server.connectivity.mdns.MdnsServiceCache.findMatchedResponse; import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread; import android.annotation.NonNull; @@ -40,10 +41,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.concurrent.ScheduledExecutorService; /** @@ -67,10 +66,12 @@ public class MdnsServiceTypeClient { @NonNull private final SharedLog sharedLog; @NonNull private final Handler handler; @NonNull private final Dependencies dependencies; + /** + * The service caches for each socket. It should be accessed from looper thread only. + */ + @NonNull private final MdnsServiceCache serviceCache; private final ArrayMap listeners = new ArrayMap<>(); - // TODO: change instanceNameToResponse to TreeMap with case insensitive comparator. - private final Map instanceNameToResponse = new HashMap<>(); private final boolean removeServiceAfterTtlExpires = MdnsConfigs.removeServiceAfterTtlExpires(); private final MdnsResponseDecoder.Clock clock; @@ -190,9 +191,10 @@ public class MdnsServiceTypeClient { @NonNull ScheduledExecutorService executor, @NonNull SocketKey socketKey, @NonNull SharedLog sharedLog, - @NonNull Looper looper) { + @NonNull Looper looper, + @NonNull MdnsServiceCache serviceCache) { this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey, - sharedLog, looper, new Dependencies()); + sharedLog, looper, new Dependencies(), serviceCache); } @VisibleForTesting @@ -204,7 +206,8 @@ public class MdnsServiceTypeClient { @NonNull SocketKey socketKey, @NonNull SharedLog sharedLog, @NonNull Looper looper, - @NonNull Dependencies dependencies) { + @NonNull Dependencies dependencies, + @NonNull MdnsServiceCache serviceCache) { this.serviceType = serviceType; this.socketClient = socketClient; this.executor = executor; @@ -215,6 +218,7 @@ public class MdnsServiceTypeClient { this.sharedLog = sharedLog; this.handler = new QueryTaskHandler(looper); this.dependencies = dependencies; + this.serviceCache = serviceCache; } private static MdnsServiceInfo buildMdnsServiceInfoFromResponse( @@ -281,7 +285,8 @@ public class MdnsServiceTypeClient { this.searchOptions = searchOptions; boolean hadReply = false; if (listeners.put(listener, searchOptions) == null) { - for (MdnsResponse existingResponse : instanceNameToResponse.values()) { + for (MdnsResponse existingResponse : + serviceCache.getCachedServices(serviceType, socketKey)) { if (!responseMatchesOptions(existingResponse, searchOptions)) continue; final MdnsServiceInfo info = buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels); @@ -377,11 +382,13 @@ public class MdnsServiceTypeClient { ensureRunningOnHandlerThread(handler); // Augment the list of current known responses, and generated responses for resolve // requests if there is no known response - final List currentList = new ArrayList<>(instanceNameToResponse.values()); + final List cachedList = + serviceCache.getCachedServices(serviceType, socketKey); + final List currentList = new ArrayList<>(cachedList); List additionalResponses = makeResponsesForResolve(socketKey); for (MdnsResponse additionalResponse : additionalResponses) { - if (!instanceNameToResponse.containsKey( - additionalResponse.getServiceInstanceName())) { + if (findMatchedResponse( + cachedList, additionalResponse.getServiceInstanceName()) == null) { currentList.add(additionalResponse); } } @@ -393,16 +400,17 @@ public class MdnsServiceTypeClient { final ArrayList allResponses = augmentedResult.second; for (MdnsResponse response : allResponses) { + final String serviceInstanceName = response.getServiceInstanceName(); if (modifiedResponse.contains(response)) { if (response.isGoodbye()) { - onGoodbyeReceived(response.getServiceInstanceName()); + onGoodbyeReceived(serviceInstanceName); } else { onResponseModified(response); } - } else if (instanceNameToResponse.containsKey(response.getServiceInstanceName())) { + } else if (findMatchedResponse(cachedList, serviceInstanceName) != null) { // If the response is not modified and already in the cache. The cache will // need to be updated to refresh the last receipt time. - instanceNameToResponse.put(response.getServiceInstanceName(), response); + serviceCache.addOrUpdateService(serviceType, socketKey, response); } } if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK) @@ -431,7 +439,7 @@ public class MdnsServiceTypeClient { /** Notify all services are removed because the socket is destroyed. */ public void notifySocketDestroyed() { ensureRunningOnHandlerThread(handler); - for (MdnsResponse response : instanceNameToResponse.values()) { + for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) { final String name = response.getServiceInstanceName(); if (name == null) continue; for (int i = 0; i < listeners.size(); i++) { @@ -453,18 +461,18 @@ public class MdnsServiceTypeClient { private void onResponseModified(@NonNull MdnsResponse response) { final String serviceInstanceName = response.getServiceInstanceName(); final MdnsResponse currentResponse = - instanceNameToResponse.get(serviceInstanceName); + serviceCache.getCachedService(serviceInstanceName, serviceType, socketKey); boolean newServiceFound = false; boolean serviceBecomesComplete = false; if (currentResponse == null) { newServiceFound = true; if (serviceInstanceName != null) { - instanceNameToResponse.put(serviceInstanceName, response); + serviceCache.addOrUpdateService(serviceType, socketKey, response); } } else { boolean before = currentResponse.isComplete(); - instanceNameToResponse.put(serviceInstanceName, response); + serviceCache.addOrUpdateService(serviceType, socketKey, response); boolean after = response.isComplete(); serviceBecomesComplete = !before && after; } @@ -497,7 +505,8 @@ public class MdnsServiceTypeClient { } private void onGoodbyeReceived(@Nullable String serviceInstanceName) { - final MdnsResponse response = instanceNameToResponse.remove(serviceInstanceName); + final MdnsResponse response = + serviceCache.removeService(serviceInstanceName, serviceType, socketKey); if (response == null) { return; } @@ -673,7 +682,8 @@ public class MdnsServiceTypeClient { if (resolveName == null) { continue; } - MdnsResponse knownResponse = instanceNameToResponse.get(resolveName); + MdnsResponse knownResponse = + serviceCache.getCachedService(resolveName, serviceType, socketKey); if (knownResponse == null) { final ArrayList instanceFullName = new ArrayList<>( serviceTypeLabels.length + 1); @@ -691,19 +701,21 @@ public class MdnsServiceTypeClient { private void tryRemoveServiceAfterTtlExpires() { if (!shouldRemoveServiceAfterTtlExpires()) return; - Iterator iter = instanceNameToResponse.values().iterator(); + Iterator iter = + serviceCache.getCachedServices(serviceType, socketKey).iterator(); while (iter.hasNext()) { MdnsResponse existingResponse = iter.next(); + final String serviceInstanceName = existingResponse.getServiceInstanceName(); if (existingResponse.hasServiceRecord() && existingResponse.getServiceRecord() .getRemainingTTL(clock.elapsedRealtime()) == 0) { - iter.remove(); + serviceCache.removeService(serviceInstanceName, serviceType, socketKey); for (int i = 0; i < listeners.size(); i++) { if (!responseMatchesOptions(existingResponse, listeners.valueAt(i))) { continue; } final MdnsServiceBrowserListener listener = listeners.keyAt(i); - if (existingResponse.getServiceInstanceName() != null) { + if (serviceInstanceName != null) { final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse( existingResponse, serviceTypeLabels); if (existingResponse.isComplete()) { @@ -812,7 +824,7 @@ public class MdnsServiceTypeClient { private long getMinRemainingTtl(long now) { long minRemainingTtl = Long.MAX_VALUE; - for (MdnsResponse response : instanceNameToResponse.values()) { + for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) { if (!response.isComplete()) { continue; } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java index 4328053503..11c9653280 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java @@ -131,6 +131,7 @@ public class MdnsServiceTypeClientTests { private SocketKey socketKey; private HandlerThread thread; private Handler handler; + private MdnsServiceCache serviceCache; private long latestDelayMs = 0; private Message delayMessage = null; private Handler realHandler = null; @@ -190,6 +191,7 @@ public class MdnsServiceTypeClientTests { thread = new HandlerThread("MdnsServiceTypeClientTests"); thread.start(); handler = new Handler(thread.getLooper()); + serviceCache = new MdnsServiceCache(thread.getLooper()); doAnswer(inv -> { latestDelayMs = 0; @@ -213,7 +215,8 @@ public class MdnsServiceTypeClientTests { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) { + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -908,7 +911,8 @@ public class MdnsServiceTypeClientTests { final String serviceInstanceName = "service-instance-1"; client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) { + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -953,7 +957,8 @@ public class MdnsServiceTypeClientTests { final String serviceInstanceName = "service-instance-1"; client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) { + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -986,7 +991,8 @@ public class MdnsServiceTypeClientTests { final String serviceInstanceName = "service-instance-1"; client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) { + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -1106,7 +1112,8 @@ public class MdnsServiceTypeClientTests { @Test public void testProcessResponse_Resolve() throws Exception { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps); + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache); final String instanceName = "service-instance"; final String[] hostname = new String[] { "testhost "}; @@ -1199,7 +1206,8 @@ public class MdnsServiceTypeClientTests { @Test public void testRenewTxtSrvInResolve() throws Exception { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps); + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache); final String instanceName = "service-instance"; final String[] hostname = new String[] { "testhost "}; @@ -1312,7 +1320,8 @@ public class MdnsServiceTypeClientTests { @Test public void testProcessResponse_ResolveExcludesOtherServices() { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps); + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache); final String requestedInstance = "instance1"; final String otherInstance = "instance2"; @@ -1376,7 +1385,8 @@ public class MdnsServiceTypeClientTests { @Test public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps); + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache); final String matchingInstance = "instance1"; final String subtype = "_subtype"; @@ -1457,7 +1467,8 @@ public class MdnsServiceTypeClientTests { @Test public void testNotifySocketDestroyed() throws Exception { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps); + mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps, + serviceCache); final String requestedInstance = "instance1"; final String otherInstance = "instance2"; @@ -1512,16 +1523,109 @@ public class MdnsServiceTypeClientTests { verify(mockListenerOne, never()).onServiceNameRemoved(matchServiceName(otherInstance)); // mockListenerTwo gets notified for both though - final InOrder inOrder2 = inOrder(mockListenerTwo); - inOrder2.verify(mockListenerTwo).onServiceNameDiscovered( + verify(mockListenerTwo).onServiceNameDiscovered( matchServiceName(requestedInstance)); - inOrder2.verify(mockListenerTwo).onServiceFound(matchServiceName(requestedInstance)); - inOrder2.verify(mockListenerTwo).onServiceNameDiscovered(matchServiceName(otherInstance)); - inOrder2.verify(mockListenerTwo).onServiceFound(matchServiceName(otherInstance)); - inOrder2.verify(mockListenerTwo).onServiceRemoved(matchServiceName(otherInstance)); - inOrder2.verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(otherInstance)); - inOrder2.verify(mockListenerTwo).onServiceRemoved(matchServiceName(requestedInstance)); - inOrder2.verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(requestedInstance)); + verify(mockListenerTwo).onServiceFound(matchServiceName(requestedInstance)); + verify(mockListenerTwo).onServiceNameDiscovered(matchServiceName(otherInstance)); + verify(mockListenerTwo).onServiceFound(matchServiceName(otherInstance)); + verify(mockListenerTwo).onServiceRemoved(matchServiceName(otherInstance)); + verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(otherInstance)); + verify(mockListenerTwo).onServiceRemoved(matchServiceName(requestedInstance)); + verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(requestedInstance)); + } + + @Test + public void testServicesAreCached() throws Exception { + final String serviceName = "service-instance"; + final String ipV4Address = "192.0.2.0"; + // Register a listener + startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions()); + verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK)); + InOrder inOrder = inOrder(mockListenerOne); + + // Process a response which has ip address to make response become complete. + final String subtype = "ABCDE"; + processResponse(createResponse( + serviceName, ipV4Address, 5353, subtype, + Collections.emptyMap(), TEST_TTL), + socketKey); + + // Verify that onServiceNameDiscovered is called. + inOrder.verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture()); + verifyServiceInfo(serviceInfoCaptor.getAllValues().get(0), + serviceName, + SERVICE_TYPE_LABELS, + List.of(ipV4Address) /* ipv4Address */, + List.of() /* ipv6Address */, + 5353 /* port */, + Collections.singletonList(subtype) /* subTypes */, + Collections.singletonMap("key", null) /* attributes */, + socketKey); + + // Verify that onServiceFound is called. + inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture()); + verifyServiceInfo(serviceInfoCaptor.getAllValues().get(1), + serviceName, + SERVICE_TYPE_LABELS, + List.of(ipV4Address) /* ipv4Address */, + List.of() /* ipv6Address */, + 5353 /* port */, + Collections.singletonList(subtype) /* subTypes */, + Collections.singletonMap("key", null) /* attributes */, + socketKey); + + // Unregister the listener + stopSendAndReceive(mockListenerOne); + verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK)); + + // Register another listener. + startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions()); + verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK)); + InOrder inOrder2 = inOrder(mockListenerTwo); + + // The services are cached in MdnsServiceCache, verify that onServiceNameDiscovered is + // called immediately. + inOrder2.verify(mockListenerTwo).onServiceNameDiscovered(serviceInfoCaptor.capture()); + verifyServiceInfo(serviceInfoCaptor.getAllValues().get(2), + serviceName, + SERVICE_TYPE_LABELS, + List.of(ipV4Address) /* ipv4Address */, + List.of() /* ipv6Address */, + 5353 /* port */, + Collections.singletonList(subtype) /* subTypes */, + Collections.singletonMap("key", null) /* attributes */, + socketKey); + + // The services are cached in MdnsServiceCache, verify that onServiceFound is + // called immediately. + inOrder2.verify(mockListenerTwo).onServiceFound(serviceInfoCaptor.capture()); + verifyServiceInfo(serviceInfoCaptor.getAllValues().get(3), + serviceName, + SERVICE_TYPE_LABELS, + List.of(ipV4Address) /* ipv4Address */, + List.of() /* ipv6Address */, + 5353 /* port */, + Collections.singletonList(subtype) /* subTypes */, + Collections.singletonMap("key", null) /* attributes */, + socketKey); + + // Process a response with a different ip address, port and updated text attributes. + final String ipV6Address = "2001:db8::"; + processResponse(createResponse( + serviceName, ipV6Address, 5354, subtype, + Collections.singletonMap("key", "value"), TEST_TTL), socketKey); + + // Verify the onServiceUpdated is called. + inOrder2.verify(mockListenerTwo).onServiceUpdated(serviceInfoCaptor.capture()); + verifyServiceInfo(serviceInfoCaptor.getAllValues().get(4), + serviceName, + SERVICE_TYPE_LABELS, + List.of(ipV4Address) /* ipv4Address */, + List.of(ipV6Address) /* ipv6Address */, + 5354 /* port */, + Collections.singletonList(subtype) /* subTypes */, + Collections.singletonMap("key", "value") /* attributes */, + socketKey); } private static MdnsServiceInfo matchServiceName(String name) {