Check service ttl expiration

Now, every services are cached on MdnsServiceCache, so the
remaining TTL should be checked when retrieving services from
the MdnsServiceCache and have a callback to notify the
MdnsServiceTypeClient about expired services.

Bug: 265787401
Test: atest FrameworksNetTests CtsNetTestCases
Change-Id: I99da6cc79bdf5df3c899e642e067907501bc9d4f
This commit is contained in:
Paul Hu
2023-10-18 17:07:31 +08:00
parent 5532b8884c
commit 596a500607
7 changed files with 344 additions and 19 deletions

View File

@@ -1645,8 +1645,8 @@ public class NsdService extends INsdManager.Stub {
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD)) mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
.setIncludeInetAddressRecordsInProbing(mDeps.isFeatureEnabled( .setIncludeInetAddressRecordsInProbing(mDeps.isFeatureEnabled(
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING)) mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled( .setIsExpiredServicesRemovalEnabled(mDeps.isFeatureEnabled(
MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL)) mContext, MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
.setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut( .setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut(
mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT)) mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
.build(); .build();

View File

@@ -86,7 +86,7 @@ public class MdnsFeatureFlags {
public Builder() { public Builder() {
mIsMdnsOffloadFeatureEnabled = false; mIsMdnsOffloadFeatureEnabled = false;
mIncludeInetAddressRecordsInProbing = false; mIncludeInetAddressRecordsInProbing = false;
mIsExpiredServicesRemovalEnabled = true; // Default enabled. mIsExpiredServicesRemovalEnabled = false;
mIsLabelCountLimitEnabled = true; // Default enabled. mIsLabelCountLimitEnabled = true; // Default enabled.
} }

View File

@@ -33,6 +33,7 @@ import java.util.Objects;
/** An mDNS response. */ /** An mDNS response. */
public class MdnsResponse { public class MdnsResponse {
public static final long EXPIRATION_NEVER = Long.MAX_VALUE;
private final List<MdnsRecord> records; private final List<MdnsRecord> records;
private final List<MdnsPointerRecord> pointerRecords; private final List<MdnsPointerRecord> pointerRecords;
private MdnsServiceRecord serviceRecord; private MdnsServiceRecord serviceRecord;
@@ -349,6 +350,21 @@ public class MdnsResponse {
return serviceName; return serviceName;
} }
/** Get the min remaining ttl time from received records */
public long getMinRemainingTtl(long now) {
long minRemainingTtl = EXPIRATION_NEVER;
// TODO: Check other records(A, AAAA, TXT) ttl time.
if (!hasServiceRecord()) {
return EXPIRATION_NEVER;
}
// Check ttl time.
long remainingTtl = serviceRecord.getRemainingTTL(now);
if (remainingTtl < minRemainingTtl) {
minRemainingTtl = remainingTtl;
}
return minRemainingTtl;
}
/** /**
* Tests if this response is a goodbye message. This will be true if a service record is present * Tests if this response is a goodbye message. This will be true if a service record is present
* and any of the records have a TTL of 0. * and any of the records have a TTL of 0.

View File

@@ -16,16 +16,22 @@
package com.android.server.connectivity.mdns; package com.android.server.connectivity.mdns;
import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER;
import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread; import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase; import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase;
import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase; import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase;
import static java.lang.Math.min;
import android.annotation.NonNull; import android.annotation.NonNull;
import android.annotation.Nullable; import android.annotation.Nullable;
import android.os.Handler; import android.os.Handler;
import android.os.Looper; import android.os.Looper;
import android.util.ArrayMap; import android.util.ArrayMap;
import com.android.internal.annotations.VisibleForTesting;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
@@ -67,8 +73,11 @@ public class MdnsServiceCache {
} }
} }
/** /**
* A map of cached services. Key is composed of service name, type and socket. Value is the * A map of cached services. Key is composed of service type and socket. Value is the list of
* service which use the service type to discover from each socket. * services which are discovered from the given CacheKey.
* When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted
* by expiration time, with the earliest entries appearing first. This sorting allows the
* removal process to progress through the expiration check efficiently.
*/ */
@NonNull @NonNull
private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>(); private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
@@ -82,10 +91,20 @@ public class MdnsServiceCache {
private final Handler mHandler; private final Handler mHandler;
@NonNull @NonNull
private final MdnsFeatureFlags mMdnsFeatureFlags; private final MdnsFeatureFlags mMdnsFeatureFlags;
@NonNull
private final MdnsUtils.Clock mClock;
private long mNextExpirationTime = EXPIRATION_NEVER;
public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) { public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this(looper, mdnsFeatureFlags, new MdnsUtils.Clock());
}
@VisibleForTesting
MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags,
@NonNull MdnsUtils.Clock clock) {
mHandler = new Handler(looper); mHandler = new Handler(looper);
mMdnsFeatureFlags = mdnsFeatureFlags; mMdnsFeatureFlags = mdnsFeatureFlags;
mClock = clock;
} }
/** /**
@@ -97,6 +116,9 @@ public class MdnsServiceCache {
@NonNull @NonNull
public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) { public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
}
return mCachedServices.containsKey(cacheKey) return mCachedServices.containsKey(cacheKey)
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey))) ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
: Collections.emptyList(); : Collections.emptyList();
@@ -129,6 +151,9 @@ public class MdnsServiceCache {
@Nullable @Nullable
public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
}
final List<MdnsResponse> responses = mCachedServices.get(cacheKey); final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
if (responses == null) { if (responses == null) {
return null; return null;
@@ -137,6 +162,16 @@ public class MdnsServiceCache {
return response != null ? new MdnsResponse(response) : null; return response != null ? new MdnsResponse(response) : null;
} }
static void insertResponseAndSortList(
List<MdnsResponse> responses, MdnsResponse response, long now) {
// binarySearch returns "the index of the search key, if it is contained in the list;
// otherwise, (-(insertion point) - 1)"
final int searchRes = Collections.binarySearch(responses, response,
// Sort the list by ttl.
(o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now)));
responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response);
}
/** /**
* Add or update a service. * Add or update a service.
* *
@@ -151,7 +186,15 @@ public class MdnsServiceCache {
final MdnsResponse existing = final MdnsResponse existing =
findMatchedResponse(responses, response.getServiceInstanceName()); findMatchedResponse(responses, response.getServiceInstanceName());
responses.remove(existing); responses.remove(existing);
responses.add(response); if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
final long now = mClock.elapsedRealtime();
// Insert and sort service
insertResponseAndSortList(responses, response, now);
// Update the next expiration check time when a new service is added.
mNextExpirationTime = getNextExpirationTime(now);
} else {
responses.add(response);
}
} }
/** /**
@@ -168,14 +211,25 @@ public class MdnsServiceCache {
return null; return null;
} }
final Iterator<MdnsResponse> iterator = responses.iterator(); final Iterator<MdnsResponse> iterator = responses.iterator();
MdnsResponse removedResponse = null;
while (iterator.hasNext()) { while (iterator.hasNext()) {
final MdnsResponse response = iterator.next(); final MdnsResponse response = iterator.next();
if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
iterator.remove(); iterator.remove();
return response; removedResponse = response;
break;
} }
} }
return null;
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
// Remove the serviceType if no response.
if (responses.isEmpty()) {
mCachedServices.remove(cacheKey);
}
// Update the next expiration check time when a service is removed.
mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
}
return removedResponse;
} }
/** /**
@@ -203,6 +257,87 @@ public class MdnsServiceCache {
mCallbacks.remove(cacheKey); mCallbacks.remove(cacheKey);
} }
private void notifyServiceExpired(@NonNull CacheKey cacheKey,
@NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) {
final ServiceExpiredCallback callback = mCallbacks.get(cacheKey);
if (callback == null) {
// The cached service is no listener.
return;
}
mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse));
}
static List<MdnsResponse> removeExpiredServices(@NonNull List<MdnsResponse> responses,
long now) {
final List<MdnsResponse> removedResponses = new ArrayList<>();
final Iterator<MdnsResponse> iterator = responses.iterator();
while (iterator.hasNext()) {
final MdnsResponse response = iterator.next();
// TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's
// expired. Then send service update notification.
if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) {
// The responses are sorted by the service record ttl time. Break out of loop
// early if service is not expired or no service record.
break;
}
// Remove the ttl expired service.
iterator.remove();
removedResponses.add(response);
}
return removedResponses;
}
private long getNextExpirationTime(long now) {
if (mCachedServices.isEmpty()) {
return EXPIRATION_NEVER;
}
long minRemainingTtl = EXPIRATION_NEVER;
for (int i = 0; i < mCachedServices.size(); i++) {
minRemainingTtl = min(minRemainingTtl,
// The empty lists are not kept in the map, so there's always at least one
// element in the list. Therefore, it's fine to get the first element without a
// null check.
mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now));
}
return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl;
}
/**
* Check whether the ttl time is expired on each service and notify to the listeners
*/
private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) {
ensureRunningOnHandlerThread(mHandler);
if (now < mNextExpirationTime) {
// Skip the check if ttl time is not expired.
return;
}
final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
if (responses == null) {
// No such services.
return;
}
final List<MdnsResponse> removedResponses = removeExpiredServices(responses, now);
if (removedResponses.isEmpty()) {
// No expired services.
return;
}
for (MdnsResponse previousResponse : removedResponses) {
notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */);
}
// Remove the serviceType if no response.
if (responses.isEmpty()) {
mCachedServices.remove(cacheKey);
}
// Update next expiration time.
mNextExpirationTime = getNextExpirationTime(now);
}
/*** Callbacks for listening service expiration */ /*** Callbacks for listening service expiration */
public interface ServiceExpiredCallback { public interface ServiceExpiredCallback {
/*** Notify the service is expired */ /*** Notify the service is expired */
@@ -210,5 +345,5 @@ public class MdnsServiceCache {
@Nullable MdnsResponse newResponse); @Nullable MdnsResponse newResponse);
} }
// TODO: check ttl expiration for each service and notify to the clients. // TODO: Schedule a job to check ttl expiration for all services and notify to the clients.
} }

View File

@@ -312,8 +312,7 @@ public class MdnsServiceTypeClient {
this.searchOptions = searchOptions; this.searchOptions = searchOptions;
boolean hadReply = false; boolean hadReply = false;
if (listeners.put(listener, searchOptions) == null) { if (listeners.put(listener, searchOptions) == null) {
for (MdnsResponse existingResponse : for (MdnsResponse existingResponse : serviceCache.getCachedServices(cacheKey)) {
serviceCache.getCachedServices(cacheKey)) {
if (!responseMatchesOptions(existingResponse, searchOptions)) continue; if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
final MdnsServiceInfo info = final MdnsServiceInfo info =
buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels); buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);

View File

@@ -19,7 +19,10 @@ package com.android.server.connectivity.mdns
import android.os.Build import android.os.Build
import android.os.Handler import android.os.Handler
import android.os.HandlerThread import android.os.HandlerThread
import com.android.net.module.util.ArrayTrackRecord
import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
import com.android.server.connectivity.mdns.MdnsServiceCacheTest.ExpiredRecord.ExpiredEvent.ServiceRecordExpired
import com.android.server.connectivity.mdns.util.MdnsUtils
import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner import com.android.testutils.DevSdkIgnoreRunner
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
@@ -32,13 +35,19 @@ import org.junit.Assert.assertTrue
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
private const val SERVICE_NAME_1 = "service-instance-1" private const val SERVICE_NAME_1 = "service-instance-1"
private const val SERVICE_NAME_2 = "service-instance-2" private const val SERVICE_NAME_2 = "service-instance-2"
private const val SERVICE_NAME_3 = "service-instance-3"
private const val SERVICE_TYPE_1 = "_test1._tcp.local" private const val SERVICE_TYPE_1 = "_test1._tcp.local"
private const val SERVICE_TYPE_2 = "_test2._tcp.local" private const val SERVICE_TYPE_2 = "_test2._tcp.local"
private const val INTERFACE_INDEX = 999 private const val INTERFACE_INDEX = 999
private const val DEFAULT_TIMEOUT_MS = 2000L private const val DEFAULT_TIMEOUT_MS = 2000L
private const val NO_CALLBACK_TIMEOUT_MS = 200L
private const val TEST_ELAPSED_REALTIME_MS = 123L
private const val DEFAULT_TTL_TIME_MS = 120000L
@RunWith(DevSdkIgnoreRunner::class) @RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
@@ -47,10 +56,46 @@ class MdnsServiceCacheTest {
private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey) private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey)
private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey) private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName) private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val clock = mock(MdnsUtils.Clock::class.java)
private val handler by lazy { private val handler by lazy {
Handler(thread.looper) Handler(thread.looper)
} }
private class ExpiredRecord : MdnsServiceCache.ServiceExpiredCallback {
val history = ArrayTrackRecord<ExpiredEvent>().newReadHead()
sealed class ExpiredEvent {
abstract val previousResponse: MdnsResponse
abstract val newResponse: MdnsResponse?
data class ServiceRecordExpired(
override val previousResponse: MdnsResponse,
override val newResponse: MdnsResponse?
) : ExpiredEvent()
}
override fun onServiceRecordExpired(
previousResponse: MdnsResponse,
newResponse: MdnsResponse?
) {
history.add(ServiceRecordExpired(previousResponse, newResponse))
}
fun expectedServiceRecordExpired(
serviceName: String,
timeoutMs: Long = DEFAULT_TIMEOUT_MS
) {
val event = history.poll(timeoutMs)
assertNotNull(event)
assertTrue(event is ServiceRecordExpired)
assertEquals(serviceName, event.previousResponse.serviceInstanceName)
}
fun assertNoCallback() {
val cb = history.poll(NO_CALLBACK_TIMEOUT_MS)
assertNull("Expected no callback but got $cb", cb)
}
}
@Before @Before
fun setUp() { fun setUp() {
thread.start() thread.start()
@@ -89,19 +134,27 @@ class MdnsServiceCacheTest {
private fun getService( private fun getService(
serviceCache: MdnsServiceCache, serviceCache: MdnsServiceCache,
serviceName: String, serviceName: String,
cacheKey: CacheKey, cacheKey: CacheKey
): MdnsResponse? = runningOnHandlerAndReturn { ): MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, cacheKey) serviceCache.getCachedService(serviceName, cacheKey)
} }
private fun getServices( private fun getServices(
serviceCache: MdnsServiceCache, serviceCache: MdnsServiceCache,
cacheKey: CacheKey, cacheKey: CacheKey
): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) } ): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
private fun registerServiceExpiredCallback(
serviceCache: MdnsServiceCache,
cacheKey: CacheKey,
callback: MdnsServiceCache.ServiceExpiredCallback
) = runningOnHandlerAndReturn {
serviceCache.registerServiceExpiredCallback(cacheKey, callback)
}
@Test @Test
fun testAddAndRemoveService() { fun testAddAndRemoveService() {
val serviceCache = MdnsServiceCache(thread.looper, makeFlags()) val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1) var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
assertNotNull(response) assertNotNull(response)
@@ -113,7 +166,7 @@ class MdnsServiceCacheTest {
@Test @Test
fun testGetCachedServices_multipleServiceTypes() { fun testGetCachedServices_multipleServiceTypes() {
val serviceCache = MdnsServiceCache(thread.looper, makeFlags()) val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1)) addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2)) addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
@@ -145,7 +198,127 @@ class MdnsServiceCacheTest {
}) })
} }
private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse( @Test
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(), fun testServiceExpiredAndSendCallbacks() {
socketKey.interfaceIndex, socketKey.network) val serviceCache = MdnsServiceCache(
thread.looper, makeFlags(isExpiredServicesRemovalEnabled = true), clock)
// Register service expired callbacks
val callback1 = ExpiredRecord()
val callback2 = ExpiredRecord()
registerServiceExpiredCallback(serviceCache, cacheKey1, callback1)
registerServiceExpiredCallback(serviceCache, cacheKey2, callback2)
doReturn(TEST_ELAPSED_REALTIME_MS).`when`(clock).elapsedRealtime()
// Add multiple services with different ttl time.
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1,
DEFAULT_TTL_TIME_MS))
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1,
DEFAULT_TTL_TIME_MS + 20L))
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_3, SERVICE_TYPE_2,
DEFAULT_TTL_TIME_MS + 10L))
// Check the service expiration immediately. Should be no callback.
assertEquals(2, getServices(serviceCache, cacheKey1).size)
assertEquals(1, getServices(serviceCache, cacheKey2).size)
callback1.assertNoCallback()
callback2.assertNoCallback()
// Simulate the case where the response is after TTL then check expired services.
// Expect SERVICE_NAME_1 expired.
doReturn(TEST_ELAPSED_REALTIME_MS + DEFAULT_TTL_TIME_MS).`when`(clock).elapsedRealtime()
assertEquals(1, getServices(serviceCache, cacheKey1).size)
assertEquals(1, getServices(serviceCache, cacheKey2).size)
callback1.expectedServiceRecordExpired(SERVICE_NAME_1)
callback2.assertNoCallback()
// Simulate the case where the response is after TTL then check expired services.
// Expect SERVICE_NAME_3 expired.
doReturn(TEST_ELAPSED_REALTIME_MS + DEFAULT_TTL_TIME_MS + 11L)
.`when`(clock).elapsedRealtime()
assertEquals(1, getServices(serviceCache, cacheKey1).size)
assertEquals(0, getServices(serviceCache, cacheKey2).size)
callback1.assertNoCallback()
callback2.expectedServiceRecordExpired(SERVICE_NAME_3)
}
@Test
fun testRemoveExpiredServiceWhenGetting() {
val serviceCache = MdnsServiceCache(
thread.looper, makeFlags(isExpiredServicesRemovalEnabled = true), clock)
doReturn(TEST_ELAPSED_REALTIME_MS).`when`(clock).elapsedRealtime()
addOrUpdateService(serviceCache, cacheKey1,
createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 1L /* ttlTime */))
doReturn(TEST_ELAPSED_REALTIME_MS + 2L).`when`(clock).elapsedRealtime()
assertNull(getService(serviceCache, SERVICE_NAME_1, cacheKey1))
addOrUpdateService(serviceCache, cacheKey2,
createResponse(SERVICE_NAME_2, SERVICE_TYPE_2, 3L /* ttlTime */))
doReturn(TEST_ELAPSED_REALTIME_MS + 4L).`when`(clock).elapsedRealtime()
assertEquals(0, getServices(serviceCache, cacheKey2).size)
}
@Test
fun testInsertResponseAndSortList() {
val responses = ArrayList<MdnsResponse>()
val response1 = createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response1, TEST_ELAPSED_REALTIME_MS)
assertEquals(1, responses.size)
assertEquals(response1, responses[0])
val response2 = createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response2, TEST_ELAPSED_REALTIME_MS)
assertEquals(2, responses.size)
assertEquals(response2, responses[0])
assertEquals(response1, responses[1])
val response3 = createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response3, TEST_ELAPSED_REALTIME_MS)
assertEquals(3, responses.size)
assertEquals(response2, responses[0])
assertEquals(response3, responses[1])
assertEquals(response1, responses[2])
val response4 = createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response4, TEST_ELAPSED_REALTIME_MS)
assertEquals(4, responses.size)
assertEquals(response2, responses[0])
assertEquals(response3, responses[1])
assertEquals(response1, responses[2])
assertEquals(response4, responses[3])
}
private fun createResponse(
serviceInstanceName: String,
serviceType: String,
ttlTime: Long = 120000L
): MdnsResponse {
val serviceName = "$serviceInstanceName.$serviceType".split(".").toTypedArray()
val response = MdnsResponse(
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
socketKey.interfaceIndex, socketKey.network)
// Set PTR record
val pointerRecord = MdnsPointerRecord(
serviceType.split(".").toTypedArray(),
TEST_ELAPSED_REALTIME_MS /* receiptTimeMillis */,
false /* cacheFlush */,
ttlTime /* ttlMillis */,
serviceName)
response.addPointerRecord(pointerRecord)
// Set SRV record.
val serviceRecord = MdnsServiceRecord(
serviceName,
TEST_ELAPSED_REALTIME_MS /* receiptTimeMillis */,
false /* cacheFlush */,
ttlTime /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
12345 /* port */,
arrayOf("hostname"))
response.serviceRecord = serviceRecord
return response
}
} }

View File

@@ -194,7 +194,9 @@ public class MdnsServiceTypeClientTests {
thread.start(); thread.start();
handler = new Handler(thread.getLooper()); handler = new Handler(thread.getLooper());
serviceCache = new MdnsServiceCache( serviceCache = new MdnsServiceCache(
thread.getLooper(), MdnsFeatureFlags.newBuilder().build()); thread.getLooper(),
MdnsFeatureFlags.newBuilder().setIsExpiredServicesRemovalEnabled(false).build(),
mockDecoderClock);
doAnswer(inv -> { doAnswer(inv -> {
latestDelayMs = 0; latestDelayMs = 0;