Add MdnsServiceCache

Currently, the mDns discovery services will clear if no listener
registers to that MdnsServiceTypeClient. If an app does discover,
stop discover, resolve, at this point the listener was
unregistered, so the MdnsServiceTypeClient was deleted, and the
service is gone from cache. So this will actually restart
discovery without returning previous results from cache. Thus,
add MdnsServiceCache to store all services and reduce the
duplicated queries.

Bug: 265787401
Test: atest FrameworksNetTests
Change-Id: If3d4eb4e3dc5455f6f97cb782aa1b99b2a00f6e0
This commit is contained in:
Paul Hu
2023-04-19 17:26:27 +08:00
parent 700de306ef
commit ade3f45956
4 changed files with 441 additions and 0 deletions

View File

@@ -0,0 +1,136 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License")
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.net.Network
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import kotlin.test.assertNotNull
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.mock
private const val SERVICE_NAME_1 = "service-instance-1"
private const val SERVICE_NAME_2 = "service-instance-2"
private const val SERVICE_TYPE_1 = "_test1._tcp.local"
private const val SERVICE_TYPE_2 = "_test2._tcp.local"
private const val INTERFACE_INDEX = 999
private const val DEFAULT_TIMEOUT_MS = 2000L
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsServiceCacheTest {
private val network = mock(Network::class.java)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val handler by lazy {
Handler(thread.looper)
}
private val serviceCache by lazy {
MdnsServiceCache(thread.looper)
}
@Before
fun setUp() {
thread.start()
}
@After
fun tearDown() {
thread.quitSafely()
}
private fun <T> runningOnHandlerAndReturn(functor: (() -> T)): T {
val future = CompletableFuture<T>()
handler.post {
future.complete(functor())
}
return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse):
Unit = runningOnHandlerAndReturn {
serviceCache.addOrUpdateService(serviceType, network, service) }
private fun removeService(serviceName: String, serviceType: String, network: Network):
Unit = runningOnHandlerAndReturn {
serviceCache.removeService(serviceName, serviceType, network) }
private fun getService(serviceName: String, serviceType: String, network: Network):
MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, serviceType, network) }
private fun getServices(serviceType: String, network: Network): List<MdnsResponse> =
runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) }
@Test
fun testAddAndRemoveService() {
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
assertNotNull(response)
assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
assertNull(response)
}
@Test
fun testGetCachedServices_multipleServiceTypes() {
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
val responses1 = getServices(SERVICE_TYPE_1, network)
assertEquals(2, responses1.size)
assertTrue(responses1.stream().anyMatch { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
assertTrue(responses1.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
val responses2 = getServices(SERVICE_TYPE_2, network)
assertEquals(1, responses2.size)
assertTrue(responses2.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network)
val responses3 = getServices(SERVICE_TYPE_1, network)
assertEquals(1, responses3.size)
assertTrue(responses3.any { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
val responses4 = getServices(SERVICE_TYPE_2, network)
assertEquals(1, responses4.size)
assertTrue(responses4.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
}
private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
INTERFACE_INDEX, network)
}

View File

@@ -0,0 +1,68 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License")
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns.util
import android.os.Build
import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsUtilsTest {
@Test
fun testToDnsLowerCase() {
assertEquals("test", toDnsLowerCase("TEST"))
assertEquals("test", toDnsLowerCase("TeSt"))
assertEquals("test", toDnsLowerCase("test"))
assertEquals("tÉst", toDnsLowerCase("TÉST"))
assertEquals("ţést", toDnsLowerCase("ţést"))
// Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
// Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
// Also test some characters where the first surrogate is not \ud800
assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
}
@Test
fun testEqualsIgnoreDnsCase() {
assertTrue(equalsIgnoreDnsCase("TEST", "Test"))
assertTrue(equalsIgnoreDnsCase("TEST", "test"))
assertTrue(equalsIgnoreDnsCase("test", "TeSt"))
assertTrue(equalsIgnoreDnsCase("Tést", "tést"))
assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést"))
// Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
// Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
"Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
// Also test some characters where the first surrogate is not \ud800
assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
"Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
}
}