Merge "Add expired services removal flag" into main

This commit is contained in:
Paul Hu
2023-10-19 06:19:43 +00:00
committed by Gerrit Code Review
9 changed files with 117 additions and 50 deletions

View File

@@ -13,3 +13,10 @@ flag {
description: "This flag controls the forbidden capability API"
bug: "302997505"
}
flag {
name: "nsd_expired_services_removal"
namespace: "android_core_networking"
description: "Remove expired services from MdnsServiceCache"
bug: "304649384"
}

View File

@@ -1703,20 +1703,20 @@ public class NsdService extends INsdManager.Stub {
am.addOnUidImportanceListener(new UidImportanceListener(handler),
mRunningAppActiveImportanceCutoff);
final MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder()
.setIsMdnsOffloadFeatureEnabled(mDeps.isTetheringFeatureNotChickenedOut(
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
.setIncludeInetAddressRecordsInProbing(mDeps.isFeatureEnabled(
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled(
MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
.build();
mMdnsSocketClient =
new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
LOGGER.forSubComponent("MdnsMultinetworkSocketClient"));
mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"));
mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"), flags);
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder()
.setIsMdnsOffloadFeatureEnabled(
mDeps.isTetheringFeatureNotChickenedOut(
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
.setIncludeInetAddressRecordsInProbing(
mDeps.isFeatureEnabled(
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.build();
mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser"), flags);
mClock = deps.makeClock();
@@ -1773,13 +1773,22 @@ public class NsdService extends INsdManager.Stub {
return DeviceConfigUtils.isTetheringFeatureNotChickenedOut(context, feature);
}
/**
* @see DeviceConfigUtils#isTrunkStableFeatureEnabled
*/
public boolean isTrunkStableFeatureEnabled(String feature) {
return DeviceConfigUtils.isTrunkStableFeatureEnabled(feature);
}
/**
* @see MdnsDiscoveryManager
*/
public MdnsDiscoveryManager makeMdnsDiscoveryManager(
@NonNull ExecutorProvider executorProvider,
@NonNull MdnsMultinetworkSocketClient socketClient, @NonNull SharedLog sharedLog) {
return new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog);
@NonNull MdnsMultinetworkSocketClient socketClient, @NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags featureFlags) {
return new MdnsDiscoveryManager(
executorProvider, socketClient, sharedLog, featureFlags);
}
/**

View File

@@ -53,6 +53,7 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback {
@NonNull private final Handler handler;
@Nullable private final HandlerThread handlerThread;
@NonNull private final MdnsServiceCache serviceCache;
@NonNull private final MdnsFeatureFlags mdnsFeatureFlags;
private static class PerSocketServiceTypeClients {
private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
@@ -117,20 +118,22 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback {
}
public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider,
@NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog) {
@NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this.executorProvider = executorProvider;
this.socketClient = socketClient;
this.sharedLog = sharedLog;
this.perSocketServiceTypeClients = new PerSocketServiceTypeClients();
this.mdnsFeatureFlags = mdnsFeatureFlags;
if (socketClient.getLooper() != null) {
this.handlerThread = null;
this.handler = new Handler(socketClient.getLooper());
this.serviceCache = new MdnsServiceCache(socketClient.getLooper());
this.serviceCache = new MdnsServiceCache(socketClient.getLooper(), mdnsFeatureFlags);
} else {
this.handlerThread = new HandlerThread(MdnsDiscoveryManager.class.getSimpleName());
this.handlerThread.start();
this.handler = new Handler(handlerThread.getLooper());
this.serviceCache = new MdnsServiceCache(handlerThread.getLooper());
this.serviceCache = new MdnsServiceCache(handlerThread.getLooper(), mdnsFeatureFlags);
}
}

View File

@@ -20,16 +20,21 @@ package com.android.server.connectivity.mdns;
*/
public class MdnsFeatureFlags {
/**
* The feature flag for control whether the mDNS offload is enabled or not.
* A feature flag to control whether the mDNS offload is enabled or not.
*/
public static final String NSD_FORCE_DISABLE_MDNS_OFFLOAD = "nsd_force_disable_mdns_offload";
/**
* The feature flag for controlling whether the probing question should include
* A feature flag to control whether the probing question should include
* InetAddressRecords or not.
*/
public static final String INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING =
"include_inet_address_records_in_probing";
/**
* A feature flag to control whether expired services removal should be enabled.
*/
public static final String NSD_EXPIRED_SERVICES_REMOVAL =
"nsd_expired_services_removal";
// Flag for offload feature
public final boolean mIsMdnsOffloadFeatureEnabled;
@@ -37,13 +42,17 @@ public class MdnsFeatureFlags {
// Flag for including InetAddressRecords in probing questions.
public final boolean mIncludeInetAddressRecordsInProbing;
// Flag for expired services removal
public final boolean mIsExpiredServicesRemovalEnabled;
/**
* The constructor for {@link MdnsFeatureFlags}.
*/
public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
boolean includeInetAddressRecordsInProbing) {
boolean includeInetAddressRecordsInProbing, boolean isExpiredServicesRemovalEnabled) {
mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
}
@@ -57,6 +66,7 @@ public class MdnsFeatureFlags {
private boolean mIsMdnsOffloadFeatureEnabled;
private boolean mIncludeInetAddressRecordsInProbing;
private boolean mIsExpiredServicesRemovalEnabled;
/**
* The constructor for {@link Builder}.
@@ -64,10 +74,13 @@ public class MdnsFeatureFlags {
public Builder() {
mIsMdnsOffloadFeatureEnabled = false;
mIncludeInetAddressRecordsInProbing = false;
mIsExpiredServicesRemovalEnabled = true; // Default enabled.
}
/**
* Set if the mDNS offload feature is enabled.
* Set whether the mDNS offload feature is enabled.
*
* @see #NSD_FORCE_DISABLE_MDNS_OFFLOAD
*/
public Builder setIsMdnsOffloadFeatureEnabled(boolean isMdnsOffloadFeatureEnabled) {
mIsMdnsOffloadFeatureEnabled = isMdnsOffloadFeatureEnabled;
@@ -75,7 +88,9 @@ public class MdnsFeatureFlags {
}
/**
* Set if the probing question should include InetAddressRecords.
* Set whether the probing question should include InetAddressRecords.
*
* @see #INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING
*/
public Builder setIncludeInetAddressRecordsInProbing(
boolean includeInetAddressRecordsInProbing) {
@@ -83,13 +98,22 @@ public class MdnsFeatureFlags {
return this;
}
/**
* Set whether the expired services removal is enabled.
*
* @see #NSD_EXPIRED_SERVICES_REMOVAL
*/
public Builder setIsExpiredServicesRemovalEnabled(boolean isExpiredServicesRemovalEnabled) {
mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
return this;
}
/**
* Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
*/
public MdnsFeatureFlags build() {
return new MdnsFeatureFlags(
mIsMdnsOffloadFeatureEnabled, mIncludeInetAddressRecordsInProbing);
}
return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled,
mIncludeInetAddressRecordsInProbing, mIsExpiredServicesRemovalEnabled);
}
}
}

View File

@@ -80,9 +80,12 @@ public class MdnsServiceCache {
private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>();
@NonNull
private final Handler mHandler;
@NonNull
private final MdnsFeatureFlags mMdnsFeatureFlags;
public MdnsServiceCache(@NonNull Looper looper) {
public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mHandler = new Handler(looper);
mMdnsFeatureFlags = mdnsFeatureFlags;
}
/**

View File

@@ -35,14 +35,17 @@ import static android.net.connectivity.ConnectivityCompatChanges.RUN_NATIVE_NSD_
import static android.net.nsd.NsdManager.FAILURE_BAD_PARAMETERS;
import static android.net.nsd.NsdManager.FAILURE_INTERNAL_ERROR;
import static android.net.nsd.NsdManager.FAILURE_OPERATION_NOT_RUNNING;
import static com.android.networkstack.apishim.api33.ConstantsShim.REGISTER_NSD_OFFLOAD_ENGINE;
import static com.android.server.NsdService.DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF;
import static com.android.server.NsdService.MdnsListener;
import static com.android.server.NsdService.NO_TRANSACTION;
import static com.android.server.NsdService.parseTypeAndSubtype;
import static com.android.testutils.ContextUtils.mockService;
import static libcore.junit.util.compat.CoreCompatChangeRule.DisableCompatChanges;
import static libcore.junit.util.compat.CoreCompatChangeRule.EnableCompatChanges;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -220,7 +223,7 @@ public class NsdServiceTest {
anyInt(), anyString(), anyString(), anyString(), anyInt());
doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
doReturn(mDiscoveryManager).when(mDeps)
.makeMdnsDiscoveryManager(any(), any(), any());
.makeMdnsDiscoveryManager(any(), any(), any(), any());
doReturn(mMulticastLock).when(mWifiManager).createMulticastLock(any());
doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any(), any(), any());
doReturn(DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF).when(mDeps).getDeviceConfigInt(

View File

@@ -106,7 +106,7 @@ public class MdnsDiscoveryManagerTests {
doReturn(thread.getLooper()).when(socketClient).getLooper();
doReturn(true).when(socketClient).supportsRequestingSpecificNetworks();
discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient,
sharedLog) {
sharedLog, MdnsFeatureFlags.newBuilder().build()) {
@Override
MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
@NonNull SocketKey socketKey) {

View File

@@ -50,9 +50,6 @@ class MdnsServiceCacheTest {
private val handler by lazy {
Handler(thread.looper)
}
private val serviceCache by lazy {
MdnsServiceCache(thread.looper)
}
@Before
fun setUp() {
@@ -64,6 +61,11 @@ class MdnsServiceCacheTest {
thread.quitSafely()
}
private fun makeFlags(isExpiredServicesRemovalEnabled: Boolean = false) =
MdnsFeatureFlags.Builder()
.setIsExpiredServicesRemovalEnabled(isExpiredServicesRemovalEnabled)
.build()
private fun <T> runningOnHandlerAndReturn(functor: (() -> T)): T {
val future = CompletableFuture<T>()
handler.post {
@@ -72,36 +74,51 @@ class MdnsServiceCacheTest {
return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
private fun addOrUpdateService(cacheKey: CacheKey, service: MdnsResponse): Unit =
runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) }
private fun addOrUpdateService(
serviceCache: MdnsServiceCache,
cacheKey: CacheKey,
service: MdnsResponse
): Unit = runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) }
private fun removeService(serviceName: String, cacheKey: CacheKey): Unit =
runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) }
private fun removeService(
serviceCache: MdnsServiceCache,
serviceName: String,
cacheKey: CacheKey
): Unit = runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) }
private fun getService(serviceName: String, cacheKey: CacheKey): MdnsResponse? =
runningOnHandlerAndReturn { serviceCache.getCachedService(serviceName, cacheKey) }
private fun getService(
serviceCache: MdnsServiceCache,
serviceName: String,
cacheKey: CacheKey,
): MdnsResponse? = runningOnHandlerAndReturn {
serviceCache.getCachedService(serviceName, cacheKey)
}
private fun getServices(cacheKey: CacheKey): List<MdnsResponse> =
runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
private fun getServices(
serviceCache: MdnsServiceCache,
cacheKey: CacheKey,
): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
@Test
fun testAddAndRemoveService() {
addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(SERVICE_NAME_1, cacheKey1)
val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
assertNotNull(response)
assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
removeService(SERVICE_NAME_1, cacheKey1)
response = getService(SERVICE_NAME_1, cacheKey1)
removeService(serviceCache, SERVICE_NAME_1, cacheKey1)
response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
assertNull(response)
}
@Test
fun testGetCachedServices_multipleServiceTypes() {
addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
val responses1 = getServices(cacheKey1)
val responses1 = getServices(serviceCache, cacheKey1)
assertEquals(2, responses1.size)
assertTrue(responses1.stream().anyMatch { response ->
response.serviceInstanceName == SERVICE_NAME_1
@@ -109,19 +126,19 @@ class MdnsServiceCacheTest {
assertTrue(responses1.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
val responses2 = getServices(cacheKey2)
val responses2 = getServices(serviceCache, cacheKey2)
assertEquals(1, responses2.size)
assertTrue(responses2.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
removeService(SERVICE_NAME_2, cacheKey1)
val responses3 = getServices(cacheKey1)
removeService(serviceCache, SERVICE_NAME_2, cacheKey1)
val responses3 = getServices(serviceCache, cacheKey1)
assertEquals(1, responses3.size)
assertTrue(responses3.any { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
val responses4 = getServices(cacheKey2)
val responses4 = getServices(serviceCache, cacheKey2)
assertEquals(1, responses4.size)
assertTrue(responses4.any { response ->
response.serviceInstanceName == SERVICE_NAME_2

View File

@@ -193,7 +193,8 @@ public class MdnsServiceTypeClientTests {
thread = new HandlerThread("MdnsServiceTypeClientTests");
thread.start();
handler = new Handler(thread.getLooper());
serviceCache = new MdnsServiceCache(thread.getLooper());
serviceCache = new MdnsServiceCache(
thread.getLooper(), MdnsFeatureFlags.newBuilder().build());
doAnswer(inv -> {
latestDelayMs = 0;