Add MdnsServiceCache

Currently, the mDns discovery services will clear if no listener
registers to that MdnsServiceTypeClient. If an app does discover,
stop discover, resolve, at this point the listener was
unregistered, so the MdnsServiceTypeClient was deleted, and the
service is gone from cache. So this will actually restart
discovery without returning previous results from cache. Thus,
add MdnsServiceCache to store all services and reduce the
duplicated queries.

Bug: 265787401
Test: atest FrameworksNetTests
Change-Id: If3d4eb4e3dc5455f6f97cb782aa1b99b2a00f6e0
This commit is contained in:
Paul Hu
2023-04-19 17:26:27 +08:00
parent 700de306ef
commit ade3f45956
4 changed files with 441 additions and 0 deletions

View File

@@ -0,0 +1,178 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase;
import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.Network;
import android.os.Handler;
import android.os.Looper;
import android.util.ArrayMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
/**
* The {@link MdnsServiceCache} manages the service which discovers from each socket and cache these
* services to reduce duplicated queries.
*
* <p>This class is not thread safe, it is intended to be used only from the looper thread.
* However, the constructor is an exception, as it is called on another thread;
* therefore for thread safety all members of this class MUST either be final or initialized
* to their default value (0, false or null).
*/
public class MdnsServiceCache {
private static class CacheKey {
@NonNull final String mLowercaseServiceType;
@Nullable final Network mNetwork;
CacheKey(@NonNull String serviceType, @Nullable Network network) {
mLowercaseServiceType = toDnsLowerCase(serviceType);
mNetwork = network;
}
@Override public int hashCode() {
return Objects.hash(mLowercaseServiceType, mNetwork);
}
@Override public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof CacheKey)) {
return false;
}
return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType)
&& Objects.equals(mNetwork, ((CacheKey) other).mNetwork);
}
}
/**
* A map of cached services. Key is composed of service name, type and network. Value is the
* service which use the service type to discover from each socket.
*/
@NonNull
private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
@NonNull
private final Handler mHandler;
public MdnsServiceCache(@NonNull Looper looper) {
mHandler = new Handler(looper);
}
/**
* Get the cache services which are queried from given service type and network.
*
* @param serviceType the target service type.
* @param network the target network
* @return the set of services which matches the given service type.
*/
@NonNull
public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
@Nullable Network network) {
ensureRunningOnHandlerThread(mHandler);
final CacheKey key = new CacheKey(serviceType, network);
return mCachedServices.containsKey(key)
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
: Collections.emptyList();
}
private MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses,
@NonNull String serviceName) {
for (MdnsResponse response : responses) {
if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
return response;
}
}
return null;
}
/**
* Get the cache service.
*
* @param serviceName the target service name.
* @param serviceType the target service type.
* @param network the target network
* @return the service which matches given conditions.
*/
@Nullable
public MdnsResponse getCachedService(@NonNull String serviceName,
@NonNull String serviceType, @Nullable Network network) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses =
mCachedServices.get(new CacheKey(serviceType, network));
if (responses == null) {
return null;
}
final MdnsResponse response = findMatchedResponse(responses, serviceName);
return response != null ? new MdnsResponse(response) : null;
}
/**
* Add or update a service.
*
* @param serviceType the service type.
* @param network the target network
* @param response the response of the discovered service.
*/
public void addOrUpdateService(@NonNull String serviceType, @Nullable Network network,
@NonNull MdnsResponse response) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
new CacheKey(serviceType, network), key -> new ArrayList<>());
// Remove existing service if present.
final MdnsResponse existing =
findMatchedResponse(responses, response.getServiceInstanceName());
responses.remove(existing);
responses.add(response);
}
/**
* Remove a service which matches the given service name, type and network.
*
* @param serviceName the target service name.
* @param serviceType the target service type.
* @param network the target network.
*/
@Nullable
public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
@Nullable Network network) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses =
mCachedServices.get(new CacheKey(serviceType, network));
if (responses == null) {
return null;
}
final Iterator<MdnsResponse> iterator = responses.iterator();
while (iterator.hasNext()) {
final MdnsResponse response = iterator.next();
if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
iterator.remove();
return response;
}
}
return null;
}
// TODO: check ttl expiration for each service and notify to the clients.
}

View File

@@ -0,0 +1,59 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns.util;
import android.annotation.NonNull;
/**
* Mdns utility functions.
*/
public class MdnsUtils {
private MdnsUtils() { }
/**
* Convert the string to DNS case-insensitive lowercase
*
* Per rfc6762#page-46, accented characters are not defined to be automatically equivalent to
* their unaccented counterparts. So the "DNS lowercase" should be if character is A-Z then they
* transform into a-z. Otherwise, they are kept as-is.
*/
public static String toDnsLowerCase(@NonNull String string) {
final char[] outChars = new char[string.length()];
for (int i = 0; i < string.length(); i++) {
outChars[i] = toDnsLowerCase(string.charAt(i));
}
return new String(outChars);
}
/**
* Compare two strings by DNS case-insensitive lowercase.
*/
public static boolean equalsIgnoreDnsCase(@NonNull String a, @NonNull String b) {
if (a.length() != b.length()) return false;
for (int i = 0; i < a.length(); i++) {
if (toDnsLowerCase(a.charAt(i)) != toDnsLowerCase(b.charAt(i))) {
return false;
}
}
return true;
}
private static char toDnsLowerCase(char a) {
return a >= 'A' && a <= 'Z' ? (char) (a + ('a' - 'A')) : a;
}
}

View File

@@ -0,0 +1,136 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License")
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.net.Network
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import kotlin.test.assertNotNull
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.mock
private const val SERVICE_NAME_1 = "service-instance-1"
private const val SERVICE_NAME_2 = "service-instance-2"
private const val SERVICE_TYPE_1 = "_test1._tcp.local"
private const val SERVICE_TYPE_2 = "_test2._tcp.local"
private const val INTERFACE_INDEX = 999
private const val DEFAULT_TIMEOUT_MS = 2000L
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsServiceCacheTest {
private val network = mock(Network::class.java)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val handler by lazy {
Handler(thread.looper)
}
private val serviceCache by lazy {
MdnsServiceCache(thread.looper)
}
@Before
fun setUp() {
thread.start()
}
@After
fun tearDown() {
thread.quitSafely()
}
private fun <T> runningOnHandlerAndReturn(functor: (() -> T)): T {
val future = CompletableFuture<T>()
handler.post {
future.complete(functor())
}
return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse):
Unit = runningOnHandlerAndReturn {
serviceCache.addOrUpdateService(serviceType, network, service) }
private fun removeService(serviceName: String, serviceType: String, network: Network):
Unit = runningOnHandlerAndReturn {
serviceCache.removeService(serviceName, serviceType, network) }
private fun getService(serviceName: String, serviceType: String, network: Network):
MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, serviceType, network) }
private fun getServices(serviceType: String, network: Network): List<MdnsResponse> =
runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) }
@Test
fun testAddAndRemoveService() {
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
assertNotNull(response)
assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
assertNull(response)
}
@Test
fun testGetCachedServices_multipleServiceTypes() {
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
val responses1 = getServices(SERVICE_TYPE_1, network)
assertEquals(2, responses1.size)
assertTrue(responses1.stream().anyMatch { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
assertTrue(responses1.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
val responses2 = getServices(SERVICE_TYPE_2, network)
assertEquals(1, responses2.size)
assertTrue(responses2.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network)
val responses3 = getServices(SERVICE_TYPE_1, network)
assertEquals(1, responses3.size)
assertTrue(responses3.any { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
val responses4 = getServices(SERVICE_TYPE_2, network)
assertEquals(1, responses4.size)
assertTrue(responses4.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
}
private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
INTERFACE_INDEX, network)
}

View File

@@ -0,0 +1,68 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License")
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns.util
import android.os.Build
import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsUtilsTest {
@Test
fun testToDnsLowerCase() {
assertEquals("test", toDnsLowerCase("TEST"))
assertEquals("test", toDnsLowerCase("TeSt"))
assertEquals("test", toDnsLowerCase("test"))
assertEquals("tÉst", toDnsLowerCase("TÉST"))
assertEquals("ţést", toDnsLowerCase("ţést"))
// Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
// Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
// Also test some characters where the first surrogate is not \ud800
assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
}
@Test
fun testEqualsIgnoreDnsCase() {
assertTrue(equalsIgnoreDnsCase("TEST", "Test"))
assertTrue(equalsIgnoreDnsCase("TEST", "test"))
assertTrue(equalsIgnoreDnsCase("test", "TeSt"))
assertTrue(equalsIgnoreDnsCase("Tést", "tést"))
assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést"))
// Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
// Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
"Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
// Also test some characters where the first surrogate is not \ud800
assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
"Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
}
}