Use SocketKey in MdnsServiceCache

The MdnsServiceTypeClient is now created using a SocketKey, so
the MdnsServiceCache should also use the SocketKey to deal with
the caching services.

Bug: 265787401
Test: atest FrameworksNetTests
Change-Id: I6165ffd420a39e750c06778b4851142a3ba3cf44
This commit is contained in:
Paul Hu
2023-07-03 09:22:35 +00:00
parent e59d30bc2a
commit 775840e1b4
2 changed files with 51 additions and 46 deletions

View File

@@ -22,7 +22,6 @@ import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase
import android.annotation.NonNull; import android.annotation.NonNull;
import android.annotation.Nullable; import android.annotation.Nullable;
import android.net.Network;
import android.os.Handler; import android.os.Handler;
import android.os.Looper; import android.os.Looper;
import android.util.ArrayMap; import android.util.ArrayMap;
@@ -45,15 +44,15 @@ import java.util.Objects;
public class MdnsServiceCache { public class MdnsServiceCache {
private static class CacheKey { private static class CacheKey {
@NonNull final String mLowercaseServiceType; @NonNull final String mLowercaseServiceType;
@Nullable final Network mNetwork; @NonNull final SocketKey mSocketKey;
CacheKey(@NonNull String serviceType, @Nullable Network network) { CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) {
mLowercaseServiceType = toDnsLowerCase(serviceType); mLowercaseServiceType = toDnsLowerCase(serviceType);
mNetwork = network; mSocketKey = socketKey;
} }
@Override public int hashCode() { @Override public int hashCode() {
return Objects.hash(mLowercaseServiceType, mNetwork); return Objects.hash(mLowercaseServiceType, mSocketKey);
} }
@Override public boolean equals(Object other) { @Override public boolean equals(Object other) {
@@ -64,11 +63,11 @@ public class MdnsServiceCache {
return false; return false;
} }
return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType) return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType)
&& Objects.equals(mNetwork, ((CacheKey) other).mNetwork); && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey);
} }
} }
/** /**
* A map of cached services. Key is composed of service name, type and network. Value is the * A map of cached services. Key is composed of service name, type and socket. Value is the
* service which use the service type to discover from each socket. * service which use the service type to discover from each socket.
*/ */
@NonNull @NonNull
@@ -81,17 +80,17 @@ public class MdnsServiceCache {
} }
/** /**
* Get the cache services which are queried from given service type and network. * Get the cache services which are queried from given service type and socket.
* *
* @param serviceType the target service type. * @param serviceType the target service type.
* @param network the target network * @param socketKey the target socket
* @return the set of services which matches the given service type. * @return the set of services which matches the given service type.
*/ */
@NonNull @NonNull
public List<MdnsResponse> getCachedServices(@NonNull String serviceType, public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
@Nullable Network network) { @NonNull SocketKey socketKey) {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
final CacheKey key = new CacheKey(serviceType, network); final CacheKey key = new CacheKey(serviceType, socketKey);
return mCachedServices.containsKey(key) return mCachedServices.containsKey(key)
? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key))) ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
: Collections.emptyList(); : Collections.emptyList();
@@ -112,15 +111,15 @@ public class MdnsServiceCache {
* *
* @param serviceName the target service name. * @param serviceName the target service name.
* @param serviceType the target service type. * @param serviceType the target service type.
* @param network the target network * @param socketKey the target socket
* @return the service which matches given conditions. * @return the service which matches given conditions.
*/ */
@Nullable @Nullable
public MdnsResponse getCachedService(@NonNull String serviceName, public MdnsResponse getCachedService(@NonNull String serviceName,
@NonNull String serviceType, @Nullable Network network) { @NonNull String serviceType, @NonNull SocketKey socketKey) {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses = final List<MdnsResponse> responses =
mCachedServices.get(new CacheKey(serviceType, network)); mCachedServices.get(new CacheKey(serviceType, socketKey));
if (responses == null) { if (responses == null) {
return null; return null;
} }
@@ -132,14 +131,14 @@ public class MdnsServiceCache {
* Add or update a service. * Add or update a service.
* *
* @param serviceType the service type. * @param serviceType the service type.
* @param network the target network * @param socketKey the target socket
* @param response the response of the discovered service. * @param response the response of the discovered service.
*/ */
public void addOrUpdateService(@NonNull String serviceType, @Nullable Network network, public void addOrUpdateService(@NonNull String serviceType, @NonNull SocketKey socketKey,
@NonNull MdnsResponse response) { @NonNull MdnsResponse response) {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses = mCachedServices.computeIfAbsent( final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
new CacheKey(serviceType, network), key -> new ArrayList<>()); new CacheKey(serviceType, socketKey), key -> new ArrayList<>());
// Remove existing service if present. // Remove existing service if present.
final MdnsResponse existing = final MdnsResponse existing =
findMatchedResponse(responses, response.getServiceInstanceName()); findMatchedResponse(responses, response.getServiceInstanceName());
@@ -148,18 +147,18 @@ public class MdnsServiceCache {
} }
/** /**
* Remove a service which matches the given service name, type and network. * Remove a service which matches the given service name, type and socket.
* *
* @param serviceName the target service name. * @param serviceName the target service name.
* @param serviceType the target service type. * @param serviceType the target service type.
* @param network the target network. * @param socketKey the target socket.
*/ */
@Nullable @Nullable
public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType, public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
@Nullable Network network) { @NonNull SocketKey socketKey) {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses = final List<MdnsResponse> responses =
mCachedServices.get(new CacheKey(serviceType, network)); mCachedServices.get(new CacheKey(serviceType, socketKey));
if (responses == null) { if (responses == null) {
return null; return null;
} }

View File

@@ -16,7 +16,6 @@
package com.android.server.connectivity.mdns package com.android.server.connectivity.mdns
import android.net.Network
import android.os.Build import android.os.Build
import android.os.Handler import android.os.Handler
import android.os.HandlerThread import android.os.HandlerThread
@@ -32,7 +31,6 @@ import org.junit.Assert.assertTrue
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.mockito.Mockito.mock
private const val SERVICE_NAME_1 = "service-instance-1" private const val SERVICE_NAME_1 = "service-instance-1"
private const val SERVICE_NAME_2 = "service-instance-2" private const val SERVICE_NAME_2 = "service-instance-2"
@@ -44,7 +42,7 @@ private const val DEFAULT_TIMEOUT_MS = 2000L
@RunWith(DevSdkIgnoreRunner::class) @RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsServiceCacheTest { class MdnsServiceCacheTest {
private val network = mock(Network::class.java) private val socketKey = SocketKey(null /* network */, INTERFACE_INDEX)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName) private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val handler by lazy { private val handler by lazy {
Handler(thread.looper) Handler(thread.looper)
@@ -71,39 +69,47 @@ class MdnsServiceCacheTest {
return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS) return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
} }
private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse): private fun addOrUpdateService(
Unit = runningOnHandlerAndReturn { serviceType: String,
serviceCache.addOrUpdateService(serviceType, network, service) } socketKey: SocketKey,
service: MdnsResponse
): Unit = runningOnHandlerAndReturn {
serviceCache.addOrUpdateService(serviceType, socketKey, service)
}
private fun removeService(serviceName: String, serviceType: String, network: Network): private fun removeService(serviceName: String, serviceType: String, socketKey: SocketKey):
Unit = runningOnHandlerAndReturn { Unit = runningOnHandlerAndReturn {
serviceCache.removeService(serviceName, serviceType, network) } serviceCache.removeService(serviceName, serviceType, socketKey) }
private fun getService(serviceName: String, serviceType: String, network: Network): private fun getService(serviceName: String, serviceType: String, socketKey: SocketKey):
MdnsResponse? = runningOnHandlerAndReturn { MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, serviceType, network) } serviceCache.getCachedService(serviceName, serviceType, socketKey) }
private fun getServices(serviceType: String, network: Network): List<MdnsResponse> = private fun getServices(serviceType: String, socketKey: SocketKey): List<MdnsResponse> =
runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) } runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, socketKey) }
@Test @Test
fun testAddAndRemoveService() { fun testAddAndRemoveService() {
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) addOrUpdateService(
var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network) SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
assertNotNull(response) assertNotNull(response)
assertEquals(SERVICE_NAME_1, response.serviceInstanceName) assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network) removeService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network) response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
assertNull(response) assertNull(response)
} }
@Test @Test
fun testGetCachedServices_multipleServiceTypes() { fun testGetCachedServices_multipleServiceTypes() {
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) addOrUpdateService(
addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1)) SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2)) addOrUpdateService(
SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(
SERVICE_TYPE_2, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
val responses1 = getServices(SERVICE_TYPE_1, network) val responses1 = getServices(SERVICE_TYPE_1, socketKey)
assertEquals(2, responses1.size) assertEquals(2, responses1.size)
assertTrue(responses1.stream().anyMatch { response -> assertTrue(responses1.stream().anyMatch { response ->
response.serviceInstanceName == SERVICE_NAME_1 response.serviceInstanceName == SERVICE_NAME_1
@@ -111,19 +117,19 @@ class MdnsServiceCacheTest {
assertTrue(responses1.any { response -> assertTrue(responses1.any { response ->
response.serviceInstanceName == SERVICE_NAME_2 response.serviceInstanceName == SERVICE_NAME_2
}) })
val responses2 = getServices(SERVICE_TYPE_2, network) val responses2 = getServices(SERVICE_TYPE_2, socketKey)
assertEquals(1, responses2.size) assertEquals(1, responses2.size)
assertTrue(responses2.any { response -> assertTrue(responses2.any { response ->
response.serviceInstanceName == SERVICE_NAME_2 response.serviceInstanceName == SERVICE_NAME_2
}) })
removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network) removeService(SERVICE_NAME_2, SERVICE_TYPE_1, socketKey)
val responses3 = getServices(SERVICE_TYPE_1, network) val responses3 = getServices(SERVICE_TYPE_1, socketKey)
assertEquals(1, responses3.size) assertEquals(1, responses3.size)
assertTrue(responses3.any { response -> assertTrue(responses3.any { response ->
response.serviceInstanceName == SERVICE_NAME_1 response.serviceInstanceName == SERVICE_NAME_1
}) })
val responses4 = getServices(SERVICE_TYPE_2, network) val responses4 = getServices(SERVICE_TYPE_2, socketKey)
assertEquals(1, responses4.size) assertEquals(1, responses4.size)
assertTrue(responses4.any { response -> assertTrue(responses4.any { response ->
response.serviceInstanceName == SERVICE_NAME_2 response.serviceInstanceName == SERVICE_NAME_2
@@ -132,5 +138,5 @@ class MdnsServiceCacheTest {
private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse( private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(), 0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
INTERFACE_INDEX, network) socketKey.interfaceIndex, socketKey.network)
} }