diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java new file mode 100644 index 0000000000..bfda535466 --- /dev/null +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns; + +import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread; +import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase; +import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase; + +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.net.Network; +import android.os.Handler; +import android.os.Looper; +import android.util.ArrayMap; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * The {@link MdnsServiceCache} manages the service which discovers from each socket and cache these + * services to reduce duplicated queries. + * + *

This class is not thread safe, it is intended to be used only from the looper thread. + * However, the constructor is an exception, as it is called on another thread; + * therefore for thread safety all members of this class MUST either be final or initialized + * to their default value (0, false or null). + */ +public class MdnsServiceCache { + private static class CacheKey { + @NonNull final String mLowercaseServiceType; + @Nullable final Network mNetwork; + + CacheKey(@NonNull String serviceType, @Nullable Network network) { + mLowercaseServiceType = toDnsLowerCase(serviceType); + mNetwork = network; + } + + @Override public int hashCode() { + return Objects.hash(mLowercaseServiceType, mNetwork); + } + + @Override public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof CacheKey)) { + return false; + } + return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType) + && Objects.equals(mNetwork, ((CacheKey) other).mNetwork); + } + } + /** + * A map of cached services. Key is composed of service name, type and network. Value is the + * service which use the service type to discover from each socket. + */ + @NonNull + private final ArrayMap> mCachedServices = new ArrayMap<>(); + @NonNull + private final Handler mHandler; + + public MdnsServiceCache(@NonNull Looper looper) { + mHandler = new Handler(looper); + } + + /** + * Get the cache services which are queried from given service type and network. + * + * @param serviceType the target service type. + * @param network the target network + * @return the set of services which matches the given service type. + */ + @NonNull + public List getCachedServices(@NonNull String serviceType, + @Nullable Network network) { + ensureRunningOnHandlerThread(mHandler); + final CacheKey key = new CacheKey(serviceType, network); + return mCachedServices.containsKey(key) + ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key))) + : Collections.emptyList(); + } + + private MdnsResponse findMatchedResponse(@NonNull List responses, + @NonNull String serviceName) { + for (MdnsResponse response : responses) { + if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { + return response; + } + } + return null; + } + + /** + * Get the cache service. + * + * @param serviceName the target service name. + * @param serviceType the target service type. + * @param network the target network + * @return the service which matches given conditions. + */ + @Nullable + public MdnsResponse getCachedService(@NonNull String serviceName, + @NonNull String serviceType, @Nullable Network network) { + ensureRunningOnHandlerThread(mHandler); + final List responses = + mCachedServices.get(new CacheKey(serviceType, network)); + if (responses == null) { + return null; + } + final MdnsResponse response = findMatchedResponse(responses, serviceName); + return response != null ? new MdnsResponse(response) : null; + } + + /** + * Add or update a service. + * + * @param serviceType the service type. + * @param network the target network + * @param response the response of the discovered service. + */ + public void addOrUpdateService(@NonNull String serviceType, @Nullable Network network, + @NonNull MdnsResponse response) { + ensureRunningOnHandlerThread(mHandler); + final List responses = mCachedServices.computeIfAbsent( + new CacheKey(serviceType, network), key -> new ArrayList<>()); + // Remove existing service if present. + final MdnsResponse existing = + findMatchedResponse(responses, response.getServiceInstanceName()); + responses.remove(existing); + responses.add(response); + } + + /** + * Remove a service which matches the given service name, type and network. + * + * @param serviceName the target service name. + * @param serviceType the target service type. + * @param network the target network. + */ + @Nullable + public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType, + @Nullable Network network) { + ensureRunningOnHandlerThread(mHandler); + final List responses = + mCachedServices.get(new CacheKey(serviceType, network)); + if (responses == null) { + return null; + } + final Iterator iterator = responses.iterator(); + while (iterator.hasNext()) { + final MdnsResponse response = iterator.next(); + if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { + iterator.remove(); + return response; + } + } + return null; + } + + // TODO: check ttl expiration for each service and notify to the clients. +} diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java new file mode 100644 index 0000000000..4b0f2a4e19 --- /dev/null +++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns.util; + +import android.annotation.NonNull; + +/** + * Mdns utility functions. + */ +public class MdnsUtils { + + private MdnsUtils() { } + + /** + * Convert the string to DNS case-insensitive lowercase + * + * Per rfc6762#page-46, accented characters are not defined to be automatically equivalent to + * their unaccented counterparts. So the "DNS lowercase" should be if character is A-Z then they + * transform into a-z. Otherwise, they are kept as-is. + */ + public static String toDnsLowerCase(@NonNull String string) { + final char[] outChars = new char[string.length()]; + for (int i = 0; i < string.length(); i++) { + outChars[i] = toDnsLowerCase(string.charAt(i)); + } + return new String(outChars); + } + + /** + * Compare two strings by DNS case-insensitive lowercase. + */ + public static boolean equalsIgnoreDnsCase(@NonNull String a, @NonNull String b) { + if (a.length() != b.length()) return false; + for (int i = 0; i < a.length(); i++) { + if (toDnsLowerCase(a.charAt(i)) != toDnsLowerCase(b.charAt(i))) { + return false; + } + } + return true; + } + + private static char toDnsLowerCase(char a) { + return a >= 'A' && a <= 'Z' ? (char) (a + ('a' - 'A')) : a; + } +} diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt new file mode 100644 index 0000000000..f091eead74 --- /dev/null +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License") + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns + +import android.net.Network +import android.os.Build +import android.os.Handler +import android.os.HandlerThread +import com.android.testutils.DevSdkIgnoreRule +import com.android.testutils.DevSdkIgnoreRunner +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import kotlin.test.assertNotNull +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Mockito.mock + +private const val SERVICE_NAME_1 = "service-instance-1" +private const val SERVICE_NAME_2 = "service-instance-2" +private const val SERVICE_TYPE_1 = "_test1._tcp.local" +private const val SERVICE_TYPE_2 = "_test2._tcp.local" +private const val INTERFACE_INDEX = 999 +private const val DEFAULT_TIMEOUT_MS = 2000L + +@RunWith(DevSdkIgnoreRunner::class) +@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) +class MdnsServiceCacheTest { + private val network = mock(Network::class.java) + private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName) + private val handler by lazy { + Handler(thread.looper) + } + private val serviceCache by lazy { + MdnsServiceCache(thread.looper) + } + + @Before + fun setUp() { + thread.start() + } + + @After + fun tearDown() { + thread.quitSafely() + } + + private fun runningOnHandlerAndReturn(functor: (() -> T)): T { + val future = CompletableFuture() + handler.post { + future.complete(functor()) + } + return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + + private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse): + Unit = runningOnHandlerAndReturn { + serviceCache.addOrUpdateService(serviceType, network, service) } + + private fun removeService(serviceName: String, serviceType: String, network: Network): + Unit = runningOnHandlerAndReturn { + serviceCache.removeService(serviceName, serviceType, network) } + + private fun getService(serviceName: String, serviceType: String, network: Network): + MdnsResponse? = runningOnHandlerAndReturn { + serviceCache.getCachedService(serviceName, serviceType, network) } + + private fun getServices(serviceType: String, network: Network): List = + runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) } + + @Test + fun testAddAndRemoveService() { + addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) + var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network) + assertNotNull(response) + assertEquals(SERVICE_NAME_1, response.serviceInstanceName) + removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network) + response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network) + assertNull(response) + } + + @Test + fun testGetCachedServices_multipleServiceTypes() { + addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) + addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1)) + addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2)) + + val responses1 = getServices(SERVICE_TYPE_1, network) + assertEquals(2, responses1.size) + assertTrue(responses1.stream().anyMatch { response -> + response.serviceInstanceName == SERVICE_NAME_1 + }) + assertTrue(responses1.any { response -> + response.serviceInstanceName == SERVICE_NAME_2 + }) + val responses2 = getServices(SERVICE_TYPE_2, network) + assertEquals(1, responses2.size) + assertTrue(responses2.any { response -> + response.serviceInstanceName == SERVICE_NAME_2 + }) + + removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network) + val responses3 = getServices(SERVICE_TYPE_1, network) + assertEquals(1, responses3.size) + assertTrue(responses3.any { response -> + response.serviceInstanceName == SERVICE_NAME_1 + }) + val responses4 = getServices(SERVICE_TYPE_2, network) + assertEquals(1, responses4.size) + assertTrue(responses4.any { response -> + response.serviceInstanceName == SERVICE_NAME_2 + }) + } + + private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse( + 0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(), + INTERFACE_INDEX, network) +} diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt new file mode 100644 index 0000000000..f584ed5434 --- /dev/null +++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License") + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns.util + +import android.os.Build +import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase +import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase +import com.android.testutils.DevSdkIgnoreRule +import com.android.testutils.DevSdkIgnoreRunner +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(DevSdkIgnoreRunner::class) +@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) +class MdnsUtilsTest { + @Test + fun testToDnsLowerCase() { + assertEquals("test", toDnsLowerCase("TEST")) + assertEquals("test", toDnsLowerCase("TeSt")) + assertEquals("test", toDnsLowerCase("test")) + assertEquals("tÉst", toDnsLowerCase("TÉST")) + assertEquals("ţést", toDnsLowerCase("ţést")) + // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁) + // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged. + assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ", + toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ")) + // Also test some characters where the first surrogate is not \ud800 + assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" + + "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<", + toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" + + "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<")) + } + + @Test + fun testEqualsIgnoreDnsCase() { + assertTrue(equalsIgnoreDnsCase("TEST", "Test")) + assertTrue(equalsIgnoreDnsCase("TEST", "test")) + assertTrue(equalsIgnoreDnsCase("test", "TeSt")) + assertTrue(equalsIgnoreDnsCase("Tést", "tést")) + assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést")) + // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁) + // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged. + assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ", + "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ")) + // Also test some characters where the first surrogate is not \ud800 + assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" + + "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<", + "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" + + "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<")) + } +}