Check service ttl expiration
Now, every services are cached on MdnsServiceCache, so the remaining TTL should be checked when retrieving services from the MdnsServiceCache and have a callback to notify the MdnsServiceTypeClient about expired services. Bug: 265787401 Test: atest FrameworksNetTests CtsNetTestCases Change-Id: I99da6cc79bdf5df3c899e642e067907501bc9d4f
This commit is contained in:
@@ -1645,8 +1645,8 @@ public class NsdService extends INsdManager.Stub {
|
|||||||
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
|
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
|
||||||
.setIncludeInetAddressRecordsInProbing(mDeps.isFeatureEnabled(
|
.setIncludeInetAddressRecordsInProbing(mDeps.isFeatureEnabled(
|
||||||
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
|
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
|
||||||
.setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled(
|
.setIsExpiredServicesRemovalEnabled(mDeps.isFeatureEnabled(
|
||||||
MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
|
mContext, MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
|
||||||
.setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut(
|
.setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut(
|
||||||
mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
|
mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
|
||||||
.build();
|
.build();
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ public class MdnsFeatureFlags {
|
|||||||
public Builder() {
|
public Builder() {
|
||||||
mIsMdnsOffloadFeatureEnabled = false;
|
mIsMdnsOffloadFeatureEnabled = false;
|
||||||
mIncludeInetAddressRecordsInProbing = false;
|
mIncludeInetAddressRecordsInProbing = false;
|
||||||
mIsExpiredServicesRemovalEnabled = true; // Default enabled.
|
mIsExpiredServicesRemovalEnabled = false;
|
||||||
mIsLabelCountLimitEnabled = true; // Default enabled.
|
mIsLabelCountLimitEnabled = true; // Default enabled.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ import java.util.Objects;
|
|||||||
|
|
||||||
/** An mDNS response. */
|
/** An mDNS response. */
|
||||||
public class MdnsResponse {
|
public class MdnsResponse {
|
||||||
|
public static final long EXPIRATION_NEVER = Long.MAX_VALUE;
|
||||||
private final List<MdnsRecord> records;
|
private final List<MdnsRecord> records;
|
||||||
private final List<MdnsPointerRecord> pointerRecords;
|
private final List<MdnsPointerRecord> pointerRecords;
|
||||||
private MdnsServiceRecord serviceRecord;
|
private MdnsServiceRecord serviceRecord;
|
||||||
@@ -349,6 +350,21 @@ public class MdnsResponse {
|
|||||||
return serviceName;
|
return serviceName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Get the min remaining ttl time from received records */
|
||||||
|
public long getMinRemainingTtl(long now) {
|
||||||
|
long minRemainingTtl = EXPIRATION_NEVER;
|
||||||
|
// TODO: Check other records(A, AAAA, TXT) ttl time.
|
||||||
|
if (!hasServiceRecord()) {
|
||||||
|
return EXPIRATION_NEVER;
|
||||||
|
}
|
||||||
|
// Check ttl time.
|
||||||
|
long remainingTtl = serviceRecord.getRemainingTTL(now);
|
||||||
|
if (remainingTtl < minRemainingTtl) {
|
||||||
|
minRemainingTtl = remainingTtl;
|
||||||
|
}
|
||||||
|
return minRemainingTtl;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests if this response is a goodbye message. This will be true if a service record is present
|
* Tests if this response is a goodbye message. This will be true if a service record is present
|
||||||
* and any of the records have a TTL of 0.
|
* and any of the records have a TTL of 0.
|
||||||
|
|||||||
@@ -16,16 +16,22 @@
|
|||||||
|
|
||||||
package com.android.server.connectivity.mdns;
|
package com.android.server.connectivity.mdns;
|
||||||
|
|
||||||
|
import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER;
|
||||||
import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
|
import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
|
||||||
import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase;
|
import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase;
|
||||||
import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase;
|
import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase;
|
||||||
|
|
||||||
|
import static java.lang.Math.min;
|
||||||
|
|
||||||
import android.annotation.NonNull;
|
import android.annotation.NonNull;
|
||||||
import android.annotation.Nullable;
|
import android.annotation.Nullable;
|
||||||
import android.os.Handler;
|
import android.os.Handler;
|
||||||
import android.os.Looper;
|
import android.os.Looper;
|
||||||
import android.util.ArrayMap;
|
import android.util.ArrayMap;
|
||||||
|
|
||||||
|
import com.android.internal.annotations.VisibleForTesting;
|
||||||
|
import com.android.server.connectivity.mdns.util.MdnsUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
@@ -67,8 +73,11 @@ public class MdnsServiceCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* A map of cached services. Key is composed of service name, type and socket. Value is the
|
* A map of cached services. Key is composed of service type and socket. Value is the list of
|
||||||
* service which use the service type to discover from each socket.
|
* services which are discovered from the given CacheKey.
|
||||||
|
* When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted
|
||||||
|
* by expiration time, with the earliest entries appearing first. This sorting allows the
|
||||||
|
* removal process to progress through the expiration check efficiently.
|
||||||
*/
|
*/
|
||||||
@NonNull
|
@NonNull
|
||||||
private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
|
private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
|
||||||
@@ -82,10 +91,20 @@ public class MdnsServiceCache {
|
|||||||
private final Handler mHandler;
|
private final Handler mHandler;
|
||||||
@NonNull
|
@NonNull
|
||||||
private final MdnsFeatureFlags mMdnsFeatureFlags;
|
private final MdnsFeatureFlags mMdnsFeatureFlags;
|
||||||
|
@NonNull
|
||||||
|
private final MdnsUtils.Clock mClock;
|
||||||
|
private long mNextExpirationTime = EXPIRATION_NEVER;
|
||||||
|
|
||||||
public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
|
public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
|
||||||
|
this(looper, mdnsFeatureFlags, new MdnsUtils.Clock());
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags,
|
||||||
|
@NonNull MdnsUtils.Clock clock) {
|
||||||
mHandler = new Handler(looper);
|
mHandler = new Handler(looper);
|
||||||
mMdnsFeatureFlags = mdnsFeatureFlags;
|
mMdnsFeatureFlags = mdnsFeatureFlags;
|
||||||
|
mClock = clock;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -97,6 +116,9 @@ public class MdnsServiceCache {
|
|||||||
@NonNull
|
@NonNull
|
||||||
public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
|
public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
|
||||||
ensureRunningOnHandlerThread(mHandler);
|
ensureRunningOnHandlerThread(mHandler);
|
||||||
|
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
|
||||||
|
maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
|
||||||
|
}
|
||||||
return mCachedServices.containsKey(cacheKey)
|
return mCachedServices.containsKey(cacheKey)
|
||||||
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
|
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
|
||||||
: Collections.emptyList();
|
: Collections.emptyList();
|
||||||
@@ -129,6 +151,9 @@ public class MdnsServiceCache {
|
|||||||
@Nullable
|
@Nullable
|
||||||
public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
|
public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
|
||||||
ensureRunningOnHandlerThread(mHandler);
|
ensureRunningOnHandlerThread(mHandler);
|
||||||
|
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
|
||||||
|
maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
|
||||||
|
}
|
||||||
final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
|
final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
|
||||||
if (responses == null) {
|
if (responses == null) {
|
||||||
return null;
|
return null;
|
||||||
@@ -137,6 +162,16 @@ public class MdnsServiceCache {
|
|||||||
return response != null ? new MdnsResponse(response) : null;
|
return response != null ? new MdnsResponse(response) : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void insertResponseAndSortList(
|
||||||
|
List<MdnsResponse> responses, MdnsResponse response, long now) {
|
||||||
|
// binarySearch returns "the index of the search key, if it is contained in the list;
|
||||||
|
// otherwise, (-(insertion point) - 1)"
|
||||||
|
final int searchRes = Collections.binarySearch(responses, response,
|
||||||
|
// Sort the list by ttl.
|
||||||
|
(o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now)));
|
||||||
|
responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add or update a service.
|
* Add or update a service.
|
||||||
*
|
*
|
||||||
@@ -151,7 +186,15 @@ public class MdnsServiceCache {
|
|||||||
final MdnsResponse existing =
|
final MdnsResponse existing =
|
||||||
findMatchedResponse(responses, response.getServiceInstanceName());
|
findMatchedResponse(responses, response.getServiceInstanceName());
|
||||||
responses.remove(existing);
|
responses.remove(existing);
|
||||||
responses.add(response);
|
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
|
||||||
|
final long now = mClock.elapsedRealtime();
|
||||||
|
// Insert and sort service
|
||||||
|
insertResponseAndSortList(responses, response, now);
|
||||||
|
// Update the next expiration check time when a new service is added.
|
||||||
|
mNextExpirationTime = getNextExpirationTime(now);
|
||||||
|
} else {
|
||||||
|
responses.add(response);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -168,14 +211,25 @@ public class MdnsServiceCache {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
final Iterator<MdnsResponse> iterator = responses.iterator();
|
final Iterator<MdnsResponse> iterator = responses.iterator();
|
||||||
|
MdnsResponse removedResponse = null;
|
||||||
while (iterator.hasNext()) {
|
while (iterator.hasNext()) {
|
||||||
final MdnsResponse response = iterator.next();
|
final MdnsResponse response = iterator.next();
|
||||||
if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
|
if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
|
||||||
iterator.remove();
|
iterator.remove();
|
||||||
return response;
|
removedResponse = response;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
|
||||||
|
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
|
||||||
|
// Remove the serviceType if no response.
|
||||||
|
if (responses.isEmpty()) {
|
||||||
|
mCachedServices.remove(cacheKey);
|
||||||
|
}
|
||||||
|
// Update the next expiration check time when a service is removed.
|
||||||
|
mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
|
||||||
|
}
|
||||||
|
return removedResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -203,6 +257,87 @@ public class MdnsServiceCache {
|
|||||||
mCallbacks.remove(cacheKey);
|
mCallbacks.remove(cacheKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void notifyServiceExpired(@NonNull CacheKey cacheKey,
|
||||||
|
@NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) {
|
||||||
|
final ServiceExpiredCallback callback = mCallbacks.get(cacheKey);
|
||||||
|
if (callback == null) {
|
||||||
|
// The cached service is no listener.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse));
|
||||||
|
}
|
||||||
|
|
||||||
|
static List<MdnsResponse> removeExpiredServices(@NonNull List<MdnsResponse> responses,
|
||||||
|
long now) {
|
||||||
|
final List<MdnsResponse> removedResponses = new ArrayList<>();
|
||||||
|
final Iterator<MdnsResponse> iterator = responses.iterator();
|
||||||
|
while (iterator.hasNext()) {
|
||||||
|
final MdnsResponse response = iterator.next();
|
||||||
|
// TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's
|
||||||
|
// expired. Then send service update notification.
|
||||||
|
if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) {
|
||||||
|
// The responses are sorted by the service record ttl time. Break out of loop
|
||||||
|
// early if service is not expired or no service record.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Remove the ttl expired service.
|
||||||
|
iterator.remove();
|
||||||
|
removedResponses.add(response);
|
||||||
|
}
|
||||||
|
return removedResponses;
|
||||||
|
}
|
||||||
|
|
||||||
|
private long getNextExpirationTime(long now) {
|
||||||
|
if (mCachedServices.isEmpty()) {
|
||||||
|
return EXPIRATION_NEVER;
|
||||||
|
}
|
||||||
|
|
||||||
|
long minRemainingTtl = EXPIRATION_NEVER;
|
||||||
|
for (int i = 0; i < mCachedServices.size(); i++) {
|
||||||
|
minRemainingTtl = min(minRemainingTtl,
|
||||||
|
// The empty lists are not kept in the map, so there's always at least one
|
||||||
|
// element in the list. Therefore, it's fine to get the first element without a
|
||||||
|
// null check.
|
||||||
|
mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now));
|
||||||
|
}
|
||||||
|
return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether the ttl time is expired on each service and notify to the listeners
|
||||||
|
*/
|
||||||
|
private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) {
|
||||||
|
ensureRunningOnHandlerThread(mHandler);
|
||||||
|
if (now < mNextExpirationTime) {
|
||||||
|
// Skip the check if ttl time is not expired.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
|
||||||
|
if (responses == null) {
|
||||||
|
// No such services.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
final List<MdnsResponse> removedResponses = removeExpiredServices(responses, now);
|
||||||
|
if (removedResponses.isEmpty()) {
|
||||||
|
// No expired services.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (MdnsResponse previousResponse : removedResponses) {
|
||||||
|
notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the serviceType if no response.
|
||||||
|
if (responses.isEmpty()) {
|
||||||
|
mCachedServices.remove(cacheKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update next expiration time.
|
||||||
|
mNextExpirationTime = getNextExpirationTime(now);
|
||||||
|
}
|
||||||
|
|
||||||
/*** Callbacks for listening service expiration */
|
/*** Callbacks for listening service expiration */
|
||||||
public interface ServiceExpiredCallback {
|
public interface ServiceExpiredCallback {
|
||||||
/*** Notify the service is expired */
|
/*** Notify the service is expired */
|
||||||
@@ -210,5 +345,5 @@ public class MdnsServiceCache {
|
|||||||
@Nullable MdnsResponse newResponse);
|
@Nullable MdnsResponse newResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check ttl expiration for each service and notify to the clients.
|
// TODO: Schedule a job to check ttl expiration for all services and notify to the clients.
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -312,8 +312,7 @@ public class MdnsServiceTypeClient {
|
|||||||
this.searchOptions = searchOptions;
|
this.searchOptions = searchOptions;
|
||||||
boolean hadReply = false;
|
boolean hadReply = false;
|
||||||
if (listeners.put(listener, searchOptions) == null) {
|
if (listeners.put(listener, searchOptions) == null) {
|
||||||
for (MdnsResponse existingResponse :
|
for (MdnsResponse existingResponse : serviceCache.getCachedServices(cacheKey)) {
|
||||||
serviceCache.getCachedServices(cacheKey)) {
|
|
||||||
if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
|
if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
|
||||||
final MdnsServiceInfo info =
|
final MdnsServiceInfo info =
|
||||||
buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
|
buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ package com.android.server.connectivity.mdns
|
|||||||
import android.os.Build
|
import android.os.Build
|
||||||
import android.os.Handler
|
import android.os.Handler
|
||||||
import android.os.HandlerThread
|
import android.os.HandlerThread
|
||||||
|
import com.android.net.module.util.ArrayTrackRecord
|
||||||
import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
|
import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
|
||||||
|
import com.android.server.connectivity.mdns.MdnsServiceCacheTest.ExpiredRecord.ExpiredEvent.ServiceRecordExpired
|
||||||
|
import com.android.server.connectivity.mdns.util.MdnsUtils
|
||||||
import com.android.testutils.DevSdkIgnoreRule
|
import com.android.testutils.DevSdkIgnoreRule
|
||||||
import com.android.testutils.DevSdkIgnoreRunner
|
import com.android.testutils.DevSdkIgnoreRunner
|
||||||
import java.util.concurrent.CompletableFuture
|
import java.util.concurrent.CompletableFuture
|
||||||
@@ -32,13 +35,19 @@ import org.junit.Assert.assertTrue
|
|||||||
import org.junit.Before
|
import org.junit.Before
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import org.junit.runner.RunWith
|
import org.junit.runner.RunWith
|
||||||
|
import org.mockito.Mockito.doReturn
|
||||||
|
import org.mockito.Mockito.mock
|
||||||
|
|
||||||
private const val SERVICE_NAME_1 = "service-instance-1"
|
private const val SERVICE_NAME_1 = "service-instance-1"
|
||||||
private const val SERVICE_NAME_2 = "service-instance-2"
|
private const val SERVICE_NAME_2 = "service-instance-2"
|
||||||
|
private const val SERVICE_NAME_3 = "service-instance-3"
|
||||||
private const val SERVICE_TYPE_1 = "_test1._tcp.local"
|
private const val SERVICE_TYPE_1 = "_test1._tcp.local"
|
||||||
private const val SERVICE_TYPE_2 = "_test2._tcp.local"
|
private const val SERVICE_TYPE_2 = "_test2._tcp.local"
|
||||||
private const val INTERFACE_INDEX = 999
|
private const val INTERFACE_INDEX = 999
|
||||||
private const val DEFAULT_TIMEOUT_MS = 2000L
|
private const val DEFAULT_TIMEOUT_MS = 2000L
|
||||||
|
private const val NO_CALLBACK_TIMEOUT_MS = 200L
|
||||||
|
private const val TEST_ELAPSED_REALTIME_MS = 123L
|
||||||
|
private const val DEFAULT_TTL_TIME_MS = 120000L
|
||||||
|
|
||||||
@RunWith(DevSdkIgnoreRunner::class)
|
@RunWith(DevSdkIgnoreRunner::class)
|
||||||
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
|
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
|
||||||
@@ -47,10 +56,46 @@ class MdnsServiceCacheTest {
|
|||||||
private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey)
|
private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey)
|
||||||
private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey)
|
private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey)
|
||||||
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
|
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
|
||||||
|
private val clock = mock(MdnsUtils.Clock::class.java)
|
||||||
private val handler by lazy {
|
private val handler by lazy {
|
||||||
Handler(thread.looper)
|
Handler(thread.looper)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class ExpiredRecord : MdnsServiceCache.ServiceExpiredCallback {
|
||||||
|
val history = ArrayTrackRecord<ExpiredEvent>().newReadHead()
|
||||||
|
|
||||||
|
sealed class ExpiredEvent {
|
||||||
|
abstract val previousResponse: MdnsResponse
|
||||||
|
abstract val newResponse: MdnsResponse?
|
||||||
|
data class ServiceRecordExpired(
|
||||||
|
override val previousResponse: MdnsResponse,
|
||||||
|
override val newResponse: MdnsResponse?
|
||||||
|
) : ExpiredEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onServiceRecordExpired(
|
||||||
|
previousResponse: MdnsResponse,
|
||||||
|
newResponse: MdnsResponse?
|
||||||
|
) {
|
||||||
|
history.add(ServiceRecordExpired(previousResponse, newResponse))
|
||||||
|
}
|
||||||
|
|
||||||
|
fun expectedServiceRecordExpired(
|
||||||
|
serviceName: String,
|
||||||
|
timeoutMs: Long = DEFAULT_TIMEOUT_MS
|
||||||
|
) {
|
||||||
|
val event = history.poll(timeoutMs)
|
||||||
|
assertNotNull(event)
|
||||||
|
assertTrue(event is ServiceRecordExpired)
|
||||||
|
assertEquals(serviceName, event.previousResponse.serviceInstanceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun assertNoCallback() {
|
||||||
|
val cb = history.poll(NO_CALLBACK_TIMEOUT_MS)
|
||||||
|
assertNull("Expected no callback but got $cb", cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
fun setUp() {
|
fun setUp() {
|
||||||
thread.start()
|
thread.start()
|
||||||
@@ -89,19 +134,27 @@ class MdnsServiceCacheTest {
|
|||||||
private fun getService(
|
private fun getService(
|
||||||
serviceCache: MdnsServiceCache,
|
serviceCache: MdnsServiceCache,
|
||||||
serviceName: String,
|
serviceName: String,
|
||||||
cacheKey: CacheKey,
|
cacheKey: CacheKey
|
||||||
): MdnsResponse? = runningOnHandlerAndReturn {
|
): MdnsResponse? = runningOnHandlerAndReturn {
|
||||||
serviceCache.getCachedService(serviceName, cacheKey)
|
serviceCache.getCachedService(serviceName, cacheKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun getServices(
|
private fun getServices(
|
||||||
serviceCache: MdnsServiceCache,
|
serviceCache: MdnsServiceCache,
|
||||||
cacheKey: CacheKey,
|
cacheKey: CacheKey
|
||||||
): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
|
): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
|
||||||
|
|
||||||
|
private fun registerServiceExpiredCallback(
|
||||||
|
serviceCache: MdnsServiceCache,
|
||||||
|
cacheKey: CacheKey,
|
||||||
|
callback: MdnsServiceCache.ServiceExpiredCallback
|
||||||
|
) = runningOnHandlerAndReturn {
|
||||||
|
serviceCache.registerServiceExpiredCallback(cacheKey, callback)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAddAndRemoveService() {
|
fun testAddAndRemoveService() {
|
||||||
val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
|
val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
|
||||||
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
|
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
|
||||||
var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
|
var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
|
||||||
assertNotNull(response)
|
assertNotNull(response)
|
||||||
@@ -113,7 +166,7 @@ class MdnsServiceCacheTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testGetCachedServices_multipleServiceTypes() {
|
fun testGetCachedServices_multipleServiceTypes() {
|
||||||
val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
|
val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
|
||||||
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
|
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
|
||||||
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
|
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
|
||||||
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
|
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
|
||||||
@@ -145,7 +198,127 @@ class MdnsServiceCacheTest {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
|
@Test
|
||||||
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
|
fun testServiceExpiredAndSendCallbacks() {
|
||||||
socketKey.interfaceIndex, socketKey.network)
|
val serviceCache = MdnsServiceCache(
|
||||||
|
thread.looper, makeFlags(isExpiredServicesRemovalEnabled = true), clock)
|
||||||
|
// Register service expired callbacks
|
||||||
|
val callback1 = ExpiredRecord()
|
||||||
|
val callback2 = ExpiredRecord()
|
||||||
|
registerServiceExpiredCallback(serviceCache, cacheKey1, callback1)
|
||||||
|
registerServiceExpiredCallback(serviceCache, cacheKey2, callback2)
|
||||||
|
|
||||||
|
doReturn(TEST_ELAPSED_REALTIME_MS).`when`(clock).elapsedRealtime()
|
||||||
|
|
||||||
|
// Add multiple services with different ttl time.
|
||||||
|
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1,
|
||||||
|
DEFAULT_TTL_TIME_MS))
|
||||||
|
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1,
|
||||||
|
DEFAULT_TTL_TIME_MS + 20L))
|
||||||
|
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_3, SERVICE_TYPE_2,
|
||||||
|
DEFAULT_TTL_TIME_MS + 10L))
|
||||||
|
|
||||||
|
// Check the service expiration immediately. Should be no callback.
|
||||||
|
assertEquals(2, getServices(serviceCache, cacheKey1).size)
|
||||||
|
assertEquals(1, getServices(serviceCache, cacheKey2).size)
|
||||||
|
callback1.assertNoCallback()
|
||||||
|
callback2.assertNoCallback()
|
||||||
|
|
||||||
|
// Simulate the case where the response is after TTL then check expired services.
|
||||||
|
// Expect SERVICE_NAME_1 expired.
|
||||||
|
doReturn(TEST_ELAPSED_REALTIME_MS + DEFAULT_TTL_TIME_MS).`when`(clock).elapsedRealtime()
|
||||||
|
assertEquals(1, getServices(serviceCache, cacheKey1).size)
|
||||||
|
assertEquals(1, getServices(serviceCache, cacheKey2).size)
|
||||||
|
callback1.expectedServiceRecordExpired(SERVICE_NAME_1)
|
||||||
|
callback2.assertNoCallback()
|
||||||
|
|
||||||
|
// Simulate the case where the response is after TTL then check expired services.
|
||||||
|
// Expect SERVICE_NAME_3 expired.
|
||||||
|
doReturn(TEST_ELAPSED_REALTIME_MS + DEFAULT_TTL_TIME_MS + 11L)
|
||||||
|
.`when`(clock).elapsedRealtime()
|
||||||
|
assertEquals(1, getServices(serviceCache, cacheKey1).size)
|
||||||
|
assertEquals(0, getServices(serviceCache, cacheKey2).size)
|
||||||
|
callback1.assertNoCallback()
|
||||||
|
callback2.expectedServiceRecordExpired(SERVICE_NAME_3)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testRemoveExpiredServiceWhenGetting() {
|
||||||
|
val serviceCache = MdnsServiceCache(
|
||||||
|
thread.looper, makeFlags(isExpiredServicesRemovalEnabled = true), clock)
|
||||||
|
|
||||||
|
doReturn(TEST_ELAPSED_REALTIME_MS).`when`(clock).elapsedRealtime()
|
||||||
|
addOrUpdateService(serviceCache, cacheKey1,
|
||||||
|
createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 1L /* ttlTime */))
|
||||||
|
doReturn(TEST_ELAPSED_REALTIME_MS + 2L).`when`(clock).elapsedRealtime()
|
||||||
|
assertNull(getService(serviceCache, SERVICE_NAME_1, cacheKey1))
|
||||||
|
|
||||||
|
addOrUpdateService(serviceCache, cacheKey2,
|
||||||
|
createResponse(SERVICE_NAME_2, SERVICE_TYPE_2, 3L /* ttlTime */))
|
||||||
|
doReturn(TEST_ELAPSED_REALTIME_MS + 4L).`when`(clock).elapsedRealtime()
|
||||||
|
assertEquals(0, getServices(serviceCache, cacheKey2).size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInsertResponseAndSortList() {
|
||||||
|
val responses = ArrayList<MdnsResponse>()
|
||||||
|
val response1 = createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
|
||||||
|
MdnsServiceCache.insertResponseAndSortList(responses, response1, TEST_ELAPSED_REALTIME_MS)
|
||||||
|
assertEquals(1, responses.size)
|
||||||
|
assertEquals(response1, responses[0])
|
||||||
|
|
||||||
|
val response2 = createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
|
||||||
|
MdnsServiceCache.insertResponseAndSortList(responses, response2, TEST_ELAPSED_REALTIME_MS)
|
||||||
|
assertEquals(2, responses.size)
|
||||||
|
assertEquals(response2, responses[0])
|
||||||
|
assertEquals(response1, responses[1])
|
||||||
|
|
||||||
|
val response3 = createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
|
||||||
|
MdnsServiceCache.insertResponseAndSortList(responses, response3, TEST_ELAPSED_REALTIME_MS)
|
||||||
|
assertEquals(3, responses.size)
|
||||||
|
assertEquals(response2, responses[0])
|
||||||
|
assertEquals(response3, responses[1])
|
||||||
|
assertEquals(response1, responses[2])
|
||||||
|
|
||||||
|
val response4 = createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
|
||||||
|
MdnsServiceCache.insertResponseAndSortList(responses, response4, TEST_ELAPSED_REALTIME_MS)
|
||||||
|
assertEquals(4, responses.size)
|
||||||
|
assertEquals(response2, responses[0])
|
||||||
|
assertEquals(response3, responses[1])
|
||||||
|
assertEquals(response1, responses[2])
|
||||||
|
assertEquals(response4, responses[3])
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun createResponse(
|
||||||
|
serviceInstanceName: String,
|
||||||
|
serviceType: String,
|
||||||
|
ttlTime: Long = 120000L
|
||||||
|
): MdnsResponse {
|
||||||
|
val serviceName = "$serviceInstanceName.$serviceType".split(".").toTypedArray()
|
||||||
|
val response = MdnsResponse(
|
||||||
|
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
|
||||||
|
socketKey.interfaceIndex, socketKey.network)
|
||||||
|
|
||||||
|
// Set PTR record
|
||||||
|
val pointerRecord = MdnsPointerRecord(
|
||||||
|
serviceType.split(".").toTypedArray(),
|
||||||
|
TEST_ELAPSED_REALTIME_MS /* receiptTimeMillis */,
|
||||||
|
false /* cacheFlush */,
|
||||||
|
ttlTime /* ttlMillis */,
|
||||||
|
serviceName)
|
||||||
|
response.addPointerRecord(pointerRecord)
|
||||||
|
|
||||||
|
// Set SRV record.
|
||||||
|
val serviceRecord = MdnsServiceRecord(
|
||||||
|
serviceName,
|
||||||
|
TEST_ELAPSED_REALTIME_MS /* receiptTimeMillis */,
|
||||||
|
false /* cacheFlush */,
|
||||||
|
ttlTime /* ttlMillis */,
|
||||||
|
0 /* servicePriority */,
|
||||||
|
0 /* serviceWeight */,
|
||||||
|
12345 /* port */,
|
||||||
|
arrayOf("hostname"))
|
||||||
|
response.serviceRecord = serviceRecord
|
||||||
|
return response
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -194,7 +194,9 @@ public class MdnsServiceTypeClientTests {
|
|||||||
thread.start();
|
thread.start();
|
||||||
handler = new Handler(thread.getLooper());
|
handler = new Handler(thread.getLooper());
|
||||||
serviceCache = new MdnsServiceCache(
|
serviceCache = new MdnsServiceCache(
|
||||||
thread.getLooper(), MdnsFeatureFlags.newBuilder().build());
|
thread.getLooper(),
|
||||||
|
MdnsFeatureFlags.newBuilder().setIsExpiredServicesRemovalEnabled(false).build(),
|
||||||
|
mockDecoderClock);
|
||||||
|
|
||||||
doAnswer(inv -> {
|
doAnswer(inv -> {
|
||||||
latestDelayMs = 0;
|
latestDelayMs = 0;
|
||||||
|
|||||||
Reference in New Issue
Block a user