Merge "Add ServiceExpiredCallback" into main

This commit is contained in:
Paul Hu
2023-10-16 01:41:49 +00:00
committed by Gerrit Code Review
3 changed files with 132 additions and 116 deletions

View File

@@ -42,7 +42,7 @@ import java.util.Objects;
* to their default value (0, false or null).
*/
public class MdnsServiceCache {
private static class CacheKey {
static class CacheKey {
@NonNull final String mLowercaseServiceType;
@NonNull final SocketKey mSocketKey;
@@ -72,6 +72,12 @@ public class MdnsServiceCache {
*/
@NonNull
private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
/**
* A map of service expire callbacks. Key is composed of service type and socket and value is
* the callback listener.
*/
@NonNull
private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>();
@NonNull
private final Handler mHandler;
@@ -82,17 +88,14 @@ public class MdnsServiceCache {
/**
* Get the cache services which are queried from given service type and socket.
*
* @param serviceType the target service type.
* @param socketKey the target socket
* @param cacheKey the target CacheKey.
* @return the set of services which matches the given service type.
*/
@NonNull
public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
@NonNull SocketKey socketKey) {
public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
final CacheKey key = new CacheKey(serviceType, socketKey);
return mCachedServices.containsKey(key)
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
return mCachedServices.containsKey(cacheKey)
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
: Collections.emptyList();
}
@@ -117,16 +120,13 @@ public class MdnsServiceCache {
* Get the cache service.
*
* @param serviceName the target service name.
* @param serviceType the target service type.
* @param socketKey the target socket
* @param cacheKey the target CacheKey.
* @return the service which matches given conditions.
*/
@Nullable
public MdnsResponse getCachedService(@NonNull String serviceName,
@NonNull String serviceType, @NonNull SocketKey socketKey) {
public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses =
mCachedServices.get(new CacheKey(serviceType, socketKey));
final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
if (responses == null) {
return null;
}
@@ -137,15 +137,13 @@ public class MdnsServiceCache {
/**
* Add or update a service.
*
* @param serviceType the service type.
* @param socketKey the target socket
* @param cacheKey the target CacheKey.
* @param response the response of the discovered service.
*/
public void addOrUpdateService(@NonNull String serviceType, @NonNull SocketKey socketKey,
@NonNull MdnsResponse response) {
public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
new CacheKey(serviceType, socketKey), key -> new ArrayList<>());
cacheKey, key -> new ArrayList<>());
// Remove existing service if present.
final MdnsResponse existing =
findMatchedResponse(responses, response.getServiceInstanceName());
@@ -157,15 +155,12 @@ public class MdnsServiceCache {
* Remove a service which matches the given service name, type and socket.
*
* @param serviceName the target service name.
* @param serviceType the target service type.
* @param socketKey the target socket.
* @param cacheKey the target CacheKey.
*/
@Nullable
public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
@NonNull SocketKey socketKey) {
public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses =
mCachedServices.get(new CacheKey(serviceType, socketKey));
final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
if (responses == null) {
return null;
}
@@ -180,5 +175,37 @@ public class MdnsServiceCache {
return null;
}
/**
* Register a callback to listen to service expiration.
*
* <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient
* relies on this.
*
* @param cacheKey the target CacheKey.
* @param callback the callback that notify the service is expired.
*/
public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey,
@NonNull ServiceExpiredCallback callback) {
ensureRunningOnHandlerThread(mHandler);
mCallbacks.put(cacheKey, callback);
}
/**
* Unregister the service expired callback.
*
* @param cacheKey the CacheKey that is registered to listen service expiration before.
*/
public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
mCallbacks.remove(cacheKey);
}
/*** Callbacks for listening service expiration */
public interface ServiceExpiredCallback {
/*** Notify the service is expired */
void onServiceRecordExpired(@NonNull MdnsResponse previousResponse,
@Nullable MdnsResponse newResponse);
}
// TODO: check ttl expiration for each service and notify to the clients.
}

View File

@@ -16,6 +16,7 @@
package com.android.server.connectivity.mdns;
import static com.android.server.connectivity.mdns.MdnsServiceCache.ServiceExpiredCallback;
import static com.android.server.connectivity.mdns.MdnsServiceCache.findMatchedResponse;
import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock;
import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
@@ -71,6 +72,15 @@ public class MdnsServiceTypeClient {
* The service caches for each socket. It should be accessed from looper thread only.
*/
@NonNull private final MdnsServiceCache serviceCache;
@NonNull private final MdnsServiceCache.CacheKey cacheKey;
@NonNull private final ServiceExpiredCallback serviceExpiredCallback =
new ServiceExpiredCallback() {
@Override
public void onServiceRecordExpired(@NonNull MdnsResponse previousResponse,
@Nullable MdnsResponse newResponse) {
notifyRemovedServiceToListeners(previousResponse, "Service record expired");
}
};
private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
new ArrayMap<>();
private final boolean removeServiceAfterTtlExpires =
@@ -225,6 +235,16 @@ public class MdnsServiceTypeClient {
this.dependencies = dependencies;
this.serviceCache = serviceCache;
this.mdnsQueryScheduler = new MdnsQueryScheduler();
this.cacheKey = new MdnsServiceCache.CacheKey(serviceType, socketKey);
}
/**
* Do the cleanup of the MdnsServiceTypeClient
*/
private void shutDown() {
removeScheduledTask();
mdnsQueryScheduler.cancelScheduledRun();
serviceCache.unregisterServiceExpiredCallback(cacheKey);
}
private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -293,7 +313,7 @@ public class MdnsServiceTypeClient {
boolean hadReply = false;
if (listeners.put(listener, searchOptions) == null) {
for (MdnsResponse existingResponse :
serviceCache.getCachedServices(serviceType, socketKey)) {
serviceCache.getCachedServices(cacheKey)) {
if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
final MdnsServiceInfo info =
buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
@@ -341,6 +361,8 @@ public class MdnsServiceTypeClient {
servicesToResolve.size() < listeners.size() /* sendDiscoveryQueries */);
executor.submit(queryTask);
}
serviceCache.registerServiceExpiredCallback(cacheKey, serviceExpiredCallback);
}
/**
@@ -390,8 +412,7 @@ public class MdnsServiceTypeClient {
return listeners.isEmpty();
}
if (listeners.isEmpty()) {
removeScheduledTask();
mdnsQueryScheduler.cancelScheduledRun();
shutDown();
}
return listeners.isEmpty();
}
@@ -404,8 +425,7 @@ public class MdnsServiceTypeClient {
ensureRunningOnHandlerThread(handler);
// Augment the list of current known responses, and generated responses for resolve
// requests if there is no known response
final List<MdnsResponse> cachedList =
serviceCache.getCachedServices(serviceType, socketKey);
final List<MdnsResponse> cachedList = serviceCache.getCachedServices(cacheKey);
final List<MdnsResponse> currentList = new ArrayList<>(cachedList);
List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
for (MdnsResponse additionalResponse : additionalResponses) {
@@ -432,7 +452,7 @@ public class MdnsServiceTypeClient {
} else if (findMatchedResponse(cachedList, serviceInstanceName) != null) {
// If the response is not modified and already in the cache. The cache will
// need to be updated to refresh the last receipt time.
serviceCache.addOrUpdateService(serviceType, socketKey, response);
serviceCache.addOrUpdateService(cacheKey, response);
}
}
if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)) {
@@ -458,44 +478,50 @@ public class MdnsServiceTypeClient {
}
}
/** Notify all services are removed because the socket is destroyed. */
public void notifySocketDestroyed() {
ensureRunningOnHandlerThread(handler);
for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
final String name = response.getServiceInstanceName();
if (name == null) continue;
private void notifyRemovedServiceToListeners(@NonNull MdnsResponse response,
@NonNull String message) {
for (int i = 0; i < listeners.size(); i++) {
if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
final MdnsServiceBrowserListener listener = listeners.keyAt(i);
final MdnsServiceInfo serviceInfo =
buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
if (response.getServiceInstanceName() != null) {
final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse(
response, serviceTypeLabels);
if (response.isComplete()) {
sharedLog.log("Socket destroyed. onServiceRemoved: " + name);
sharedLog.log(message + ". onServiceRemoved: " + serviceInfo);
listener.onServiceRemoved(serviceInfo);
}
sharedLog.log("Socket destroyed. onServiceNameRemoved: " + name);
sharedLog.log(message + ". onServiceNameRemoved: " + serviceInfo);
listener.onServiceNameRemoved(serviceInfo);
}
}
removeScheduledTask();
mdnsQueryScheduler.cancelScheduledRun();
}
/** Notify all services are removed because the socket is destroyed. */
public void notifySocketDestroyed() {
ensureRunningOnHandlerThread(handler);
for (MdnsResponse response : serviceCache.getCachedServices(cacheKey)) {
final String name = response.getServiceInstanceName();
if (name == null) continue;
notifyRemovedServiceToListeners(response, "Socket destroyed");
}
shutDown();
}
private void onResponseModified(@NonNull MdnsResponse response) {
final String serviceInstanceName = response.getServiceInstanceName();
final MdnsResponse currentResponse =
serviceCache.getCachedService(serviceInstanceName, serviceType, socketKey);
serviceCache.getCachedService(serviceInstanceName, cacheKey);
boolean newServiceFound = false;
boolean serviceBecomesComplete = false;
if (currentResponse == null) {
newServiceFound = true;
if (serviceInstanceName != null) {
serviceCache.addOrUpdateService(serviceType, socketKey, response);
serviceCache.addOrUpdateService(cacheKey, response);
}
} else {
boolean before = currentResponse.isComplete();
serviceCache.addOrUpdateService(serviceType, socketKey, response);
serviceCache.addOrUpdateService(cacheKey, response);
boolean after = response.isComplete();
serviceBecomesComplete = !before && after;
}
@@ -529,22 +555,11 @@ public class MdnsServiceTypeClient {
private void onGoodbyeReceived(@Nullable String serviceInstanceName) {
final MdnsResponse response =
serviceCache.removeService(serviceInstanceName, serviceType, socketKey);
serviceCache.removeService(serviceInstanceName, cacheKey);
if (response == null) {
return;
}
for (int i = 0; i < listeners.size(); i++) {
if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
final MdnsServiceBrowserListener listener = listeners.keyAt(i);
final MdnsServiceInfo serviceInfo =
buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
if (response.isComplete()) {
sharedLog.log("onServiceRemoved: " + serviceInfo);
listener.onServiceRemoved(serviceInfo);
}
sharedLog.log("onServiceNameRemoved: " + serviceInfo);
listener.onServiceNameRemoved(serviceInfo);
}
notifyRemovedServiceToListeners(response, "Goodbye received");
}
private boolean shouldRemoveServiceAfterTtlExpires() {
@@ -567,7 +582,7 @@ public class MdnsServiceTypeClient {
continue;
}
MdnsResponse knownResponse =
serviceCache.getCachedService(resolveName, serviceType, socketKey);
serviceCache.getCachedService(resolveName, cacheKey);
if (knownResponse == null) {
final ArrayList<String> instanceFullName = new ArrayList<>(
serviceTypeLabels.length + 1);
@@ -585,35 +600,17 @@ public class MdnsServiceTypeClient {
private void tryRemoveServiceAfterTtlExpires() {
if (!shouldRemoveServiceAfterTtlExpires()) return;
Iterator<MdnsResponse> iter =
serviceCache.getCachedServices(serviceType, socketKey).iterator();
final Iterator<MdnsResponse> iter = serviceCache.getCachedServices(cacheKey).iterator();
while (iter.hasNext()) {
MdnsResponse existingResponse = iter.next();
final String serviceInstanceName = existingResponse.getServiceInstanceName();
if (existingResponse.hasServiceRecord()
&& existingResponse.getServiceRecord()
.getRemainingTTL(clock.elapsedRealtime()) == 0) {
serviceCache.removeService(serviceInstanceName, serviceType, socketKey);
for (int i = 0; i < listeners.size(); i++) {
if (!responseMatchesOptions(existingResponse, listeners.valueAt(i))) {
continue;
}
final MdnsServiceBrowserListener listener = listeners.keyAt(i);
if (serviceInstanceName != null) {
final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse(
existingResponse, serviceTypeLabels);
if (existingResponse.isComplete()) {
sharedLog.log("TTL expired. onServiceRemoved: " + serviceInfo);
listener.onServiceRemoved(serviceInfo);
}
sharedLog.log("TTL expired. onServiceNameRemoved: " + serviceInfo);
listener.onServiceNameRemoved(serviceInfo);
serviceCache.removeService(existingResponse.getServiceInstanceName(), cacheKey);
notifyRemovedServiceToListeners(existingResponse, "TTL expired");
}
}
}
}
}
private static class QuerySentArguments {
private final int transactionId;
@@ -672,7 +669,7 @@ public class MdnsServiceTypeClient {
private long getMinRemainingTtl(long now) {
long minRemainingTtl = Long.MAX_VALUE;
for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
for (MdnsResponse response : serviceCache.getCachedServices(cacheKey)) {
if (!response.isComplete()) {
continue;
}

View File

@@ -19,6 +19,7 @@ package com.android.server.connectivity.mdns
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.util.concurrent.CompletableFuture
@@ -43,6 +44,8 @@ private const val DEFAULT_TIMEOUT_MS = 2000L
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsServiceCacheTest {
private val socketKey = SocketKey(null /* network */, INTERFACE_INDEX)
private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey)
private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val handler by lazy {
Handler(thread.looper)
@@ -69,47 +72,36 @@ class MdnsServiceCacheTest {
return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
private fun addOrUpdateService(
serviceType: String,
socketKey: SocketKey,
service: MdnsResponse
): Unit = runningOnHandlerAndReturn {
serviceCache.addOrUpdateService(serviceType, socketKey, service)
}
private fun addOrUpdateService(cacheKey: CacheKey, service: MdnsResponse): Unit =
runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) }
private fun removeService(serviceName: String, serviceType: String, socketKey: SocketKey):
Unit = runningOnHandlerAndReturn {
serviceCache.removeService(serviceName, serviceType, socketKey) }
private fun removeService(serviceName: String, cacheKey: CacheKey): Unit =
runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) }
private fun getService(serviceName: String, serviceType: String, socketKey: SocketKey):
MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, serviceType, socketKey) }
private fun getService(serviceName: String, cacheKey: CacheKey): MdnsResponse? =
runningOnHandlerAndReturn { serviceCache.getCachedService(serviceName, cacheKey) }
private fun getServices(serviceType: String, socketKey: SocketKey): List<MdnsResponse> =
runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, socketKey) }
private fun getServices(cacheKey: CacheKey): List<MdnsResponse> =
runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
@Test
fun testAddAndRemoveService() {
addOrUpdateService(
SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(SERVICE_NAME_1, cacheKey1)
assertNotNull(response)
assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
removeService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
removeService(SERVICE_NAME_1, cacheKey1)
response = getService(SERVICE_NAME_1, cacheKey1)
assertNull(response)
}
@Test
fun testGetCachedServices_multipleServiceTypes() {
addOrUpdateService(
SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(
SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(
SERVICE_TYPE_2, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
val responses1 = getServices(SERVICE_TYPE_1, socketKey)
val responses1 = getServices(cacheKey1)
assertEquals(2, responses1.size)
assertTrue(responses1.stream().anyMatch { response ->
response.serviceInstanceName == SERVICE_NAME_1
@@ -117,19 +109,19 @@ class MdnsServiceCacheTest {
assertTrue(responses1.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
val responses2 = getServices(SERVICE_TYPE_2, socketKey)
val responses2 = getServices(cacheKey2)
assertEquals(1, responses2.size)
assertTrue(responses2.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
removeService(SERVICE_NAME_2, SERVICE_TYPE_1, socketKey)
val responses3 = getServices(SERVICE_TYPE_1, socketKey)
removeService(SERVICE_NAME_2, cacheKey1)
val responses3 = getServices(cacheKey1)
assertEquals(1, responses3.size)
assertTrue(responses3.any { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
val responses4 = getServices(SERVICE_TYPE_2, socketKey)
val responses4 = getServices(cacheKey2)
assertEquals(1, responses4.size)
assertTrue(responses4.any { response ->
response.serviceInstanceName == SERVICE_NAME_2