Check service ttl expiration

Now, every services are cached on MdnsServiceCache, so the
remaining TTL should be checked when retrieving services from
the MdnsServiceCache and have a callback to notify the
MdnsServiceTypeClient about expired services.

Bug: 265787401
Test: atest FrameworksNetTests CtsNetTestCases
Change-Id: I99da6cc79bdf5df3c899e642e067907501bc9d4f
This commit is contained in:
Paul Hu
2023-10-18 17:07:31 +08:00
parent 5532b8884c
commit 596a500607
7 changed files with 344 additions and 19 deletions

View File

@@ -19,7 +19,10 @@ package com.android.server.connectivity.mdns
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import com.android.net.module.util.ArrayTrackRecord
import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
import com.android.server.connectivity.mdns.MdnsServiceCacheTest.ExpiredRecord.ExpiredEvent.ServiceRecordExpired
import com.android.server.connectivity.mdns.util.MdnsUtils
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.util.concurrent.CompletableFuture
@@ -32,13 +35,19 @@ import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
private const val SERVICE_NAME_1 = "service-instance-1"
private const val SERVICE_NAME_2 = "service-instance-2"
private const val SERVICE_NAME_3 = "service-instance-3"
private const val SERVICE_TYPE_1 = "_test1._tcp.local"
private const val SERVICE_TYPE_2 = "_test2._tcp.local"
private const val INTERFACE_INDEX = 999
private const val DEFAULT_TIMEOUT_MS = 2000L
private const val NO_CALLBACK_TIMEOUT_MS = 200L
private const val TEST_ELAPSED_REALTIME_MS = 123L
private const val DEFAULT_TTL_TIME_MS = 120000L
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
@@ -47,10 +56,46 @@ class MdnsServiceCacheTest {
private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey)
private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val clock = mock(MdnsUtils.Clock::class.java)
private val handler by lazy {
Handler(thread.looper)
}
private class ExpiredRecord : MdnsServiceCache.ServiceExpiredCallback {
val history = ArrayTrackRecord<ExpiredEvent>().newReadHead()
sealed class ExpiredEvent {
abstract val previousResponse: MdnsResponse
abstract val newResponse: MdnsResponse?
data class ServiceRecordExpired(
override val previousResponse: MdnsResponse,
override val newResponse: MdnsResponse?
) : ExpiredEvent()
}
override fun onServiceRecordExpired(
previousResponse: MdnsResponse,
newResponse: MdnsResponse?
) {
history.add(ServiceRecordExpired(previousResponse, newResponse))
}
fun expectedServiceRecordExpired(
serviceName: String,
timeoutMs: Long = DEFAULT_TIMEOUT_MS
) {
val event = history.poll(timeoutMs)
assertNotNull(event)
assertTrue(event is ServiceRecordExpired)
assertEquals(serviceName, event.previousResponse.serviceInstanceName)
}
fun assertNoCallback() {
val cb = history.poll(NO_CALLBACK_TIMEOUT_MS)
assertNull("Expected no callback but got $cb", cb)
}
}
@Before
fun setUp() {
thread.start()
@@ -89,19 +134,27 @@ class MdnsServiceCacheTest {
private fun getService(
serviceCache: MdnsServiceCache,
serviceName: String,
cacheKey: CacheKey,
cacheKey: CacheKey
): MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, cacheKey)
}
private fun getServices(
serviceCache: MdnsServiceCache,
cacheKey: CacheKey,
cacheKey: CacheKey
): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
private fun registerServiceExpiredCallback(
serviceCache: MdnsServiceCache,
cacheKey: CacheKey,
callback: MdnsServiceCache.ServiceExpiredCallback
) = runningOnHandlerAndReturn {
serviceCache.registerServiceExpiredCallback(cacheKey, callback)
}
@Test
fun testAddAndRemoveService() {
val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
assertNotNull(response)
@@ -113,7 +166,7 @@ class MdnsServiceCacheTest {
@Test
fun testGetCachedServices_multipleServiceTypes() {
val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
@@ -145,7 +198,127 @@ class MdnsServiceCacheTest {
})
}
private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
socketKey.interfaceIndex, socketKey.network)
@Test
fun testServiceExpiredAndSendCallbacks() {
val serviceCache = MdnsServiceCache(
thread.looper, makeFlags(isExpiredServicesRemovalEnabled = true), clock)
// Register service expired callbacks
val callback1 = ExpiredRecord()
val callback2 = ExpiredRecord()
registerServiceExpiredCallback(serviceCache, cacheKey1, callback1)
registerServiceExpiredCallback(serviceCache, cacheKey2, callback2)
doReturn(TEST_ELAPSED_REALTIME_MS).`when`(clock).elapsedRealtime()
// Add multiple services with different ttl time.
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1,
DEFAULT_TTL_TIME_MS))
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1,
DEFAULT_TTL_TIME_MS + 20L))
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_3, SERVICE_TYPE_2,
DEFAULT_TTL_TIME_MS + 10L))
// Check the service expiration immediately. Should be no callback.
assertEquals(2, getServices(serviceCache, cacheKey1).size)
assertEquals(1, getServices(serviceCache, cacheKey2).size)
callback1.assertNoCallback()
callback2.assertNoCallback()
// Simulate the case where the response is after TTL then check expired services.
// Expect SERVICE_NAME_1 expired.
doReturn(TEST_ELAPSED_REALTIME_MS + DEFAULT_TTL_TIME_MS).`when`(clock).elapsedRealtime()
assertEquals(1, getServices(serviceCache, cacheKey1).size)
assertEquals(1, getServices(serviceCache, cacheKey2).size)
callback1.expectedServiceRecordExpired(SERVICE_NAME_1)
callback2.assertNoCallback()
// Simulate the case where the response is after TTL then check expired services.
// Expect SERVICE_NAME_3 expired.
doReturn(TEST_ELAPSED_REALTIME_MS + DEFAULT_TTL_TIME_MS + 11L)
.`when`(clock).elapsedRealtime()
assertEquals(1, getServices(serviceCache, cacheKey1).size)
assertEquals(0, getServices(serviceCache, cacheKey2).size)
callback1.assertNoCallback()
callback2.expectedServiceRecordExpired(SERVICE_NAME_3)
}
@Test
fun testRemoveExpiredServiceWhenGetting() {
val serviceCache = MdnsServiceCache(
thread.looper, makeFlags(isExpiredServicesRemovalEnabled = true), clock)
doReturn(TEST_ELAPSED_REALTIME_MS).`when`(clock).elapsedRealtime()
addOrUpdateService(serviceCache, cacheKey1,
createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 1L /* ttlTime */))
doReturn(TEST_ELAPSED_REALTIME_MS + 2L).`when`(clock).elapsedRealtime()
assertNull(getService(serviceCache, SERVICE_NAME_1, cacheKey1))
addOrUpdateService(serviceCache, cacheKey2,
createResponse(SERVICE_NAME_2, SERVICE_TYPE_2, 3L /* ttlTime */))
doReturn(TEST_ELAPSED_REALTIME_MS + 4L).`when`(clock).elapsedRealtime()
assertEquals(0, getServices(serviceCache, cacheKey2).size)
}
@Test
fun testInsertResponseAndSortList() {
val responses = ArrayList<MdnsResponse>()
val response1 = createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response1, TEST_ELAPSED_REALTIME_MS)
assertEquals(1, responses.size)
assertEquals(response1, responses[0])
val response2 = createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response2, TEST_ELAPSED_REALTIME_MS)
assertEquals(2, responses.size)
assertEquals(response2, responses[0])
assertEquals(response1, responses[1])
val response3 = createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response3, TEST_ELAPSED_REALTIME_MS)
assertEquals(3, responses.size)
assertEquals(response2, responses[0])
assertEquals(response3, responses[1])
assertEquals(response1, responses[2])
val response4 = createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
MdnsServiceCache.insertResponseAndSortList(responses, response4, TEST_ELAPSED_REALTIME_MS)
assertEquals(4, responses.size)
assertEquals(response2, responses[0])
assertEquals(response3, responses[1])
assertEquals(response1, responses[2])
assertEquals(response4, responses[3])
}
private fun createResponse(
serviceInstanceName: String,
serviceType: String,
ttlTime: Long = 120000L
): MdnsResponse {
val serviceName = "$serviceInstanceName.$serviceType".split(".").toTypedArray()
val response = MdnsResponse(
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
socketKey.interfaceIndex, socketKey.network)
// Set PTR record
val pointerRecord = MdnsPointerRecord(
serviceType.split(".").toTypedArray(),
TEST_ELAPSED_REALTIME_MS /* receiptTimeMillis */,
false /* cacheFlush */,
ttlTime /* ttlMillis */,
serviceName)
response.addPointerRecord(pointerRecord)
// Set SRV record.
val serviceRecord = MdnsServiceRecord(
serviceName,
TEST_ELAPSED_REALTIME_MS /* receiptTimeMillis */,
false /* cacheFlush */,
ttlTime /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
12345 /* port */,
arrayOf("hostname"))
response.serviceRecord = serviceRecord
return response
}
}

View File

@@ -194,7 +194,9 @@ public class MdnsServiceTypeClientTests {
thread.start();
handler = new Handler(thread.getLooper());
serviceCache = new MdnsServiceCache(
thread.getLooper(), MdnsFeatureFlags.newBuilder().build());
thread.getLooper(),
MdnsFeatureFlags.newBuilder().setIsExpiredServicesRemovalEnabled(false).build(),
mockDecoderClock);
doAnswer(inv -> {
latestDelayMs = 0;