Add NSD_LIMIT_LABEL_COUNT flag

Bug: 307475137
Test: atest FrameworksNetTests android.net.cts.NsdManagerTest
Merged-In: I48b1dc26f41549ec4afc71f87e98a02ac773430f
Change-Id: Ic4c2e4c0d61b76b1afd556560c18171bdb7a088e
This commit is contained in:
Paul Hu
2023-11-01 16:32:45 +08:00
parent 95cf7f9550
commit fd357ef695
12 changed files with 71 additions and 26 deletions

View File

@@ -1647,10 +1647,12 @@ public class NsdService extends INsdManager.Stub {
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING)) mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled( .setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled(
MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL)) MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
.setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut(
mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
.build(); .build();
mMdnsSocketClient = mMdnsSocketClient =
new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider, new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
LOGGER.forSubComponent("MdnsMultinetworkSocketClient")); LOGGER.forSubComponent("MdnsMultinetworkSocketClient"), flags);
mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"), flags); mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"), flags);
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager)); handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));

View File

@@ -36,6 +36,11 @@ public class MdnsFeatureFlags {
public static final String NSD_EXPIRED_SERVICES_REMOVAL = public static final String NSD_EXPIRED_SERVICES_REMOVAL =
"nsd_expired_services_removal"; "nsd_expired_services_removal";
/**
* A feature flag to control whether the label count limit should be enabled.
*/
public static final String NSD_LIMIT_LABEL_COUNT = "nsd_limit_label_count";
// Flag for offload feature // Flag for offload feature
public final boolean mIsMdnsOffloadFeatureEnabled; public final boolean mIsMdnsOffloadFeatureEnabled;
@@ -45,14 +50,20 @@ public class MdnsFeatureFlags {
// Flag for expired services removal // Flag for expired services removal
public final boolean mIsExpiredServicesRemovalEnabled; public final boolean mIsExpiredServicesRemovalEnabled;
// Flag for label count limit
public final boolean mIsLabelCountLimitEnabled;
/** /**
* The constructor for {@link MdnsFeatureFlags}. * The constructor for {@link MdnsFeatureFlags}.
*/ */
public MdnsFeatureFlags(boolean isOffloadFeatureEnabled, public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
boolean includeInetAddressRecordsInProbing, boolean isExpiredServicesRemovalEnabled) { boolean includeInetAddressRecordsInProbing,
boolean isExpiredServicesRemovalEnabled,
boolean isLabelCountLimitEnabled) {
mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled; mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing; mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled; mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
} }
@@ -67,6 +78,7 @@ public class MdnsFeatureFlags {
private boolean mIsMdnsOffloadFeatureEnabled; private boolean mIsMdnsOffloadFeatureEnabled;
private boolean mIncludeInetAddressRecordsInProbing; private boolean mIncludeInetAddressRecordsInProbing;
private boolean mIsExpiredServicesRemovalEnabled; private boolean mIsExpiredServicesRemovalEnabled;
private boolean mIsLabelCountLimitEnabled;
/** /**
* The constructor for {@link Builder}. * The constructor for {@link Builder}.
@@ -75,6 +87,7 @@ public class MdnsFeatureFlags {
mIsMdnsOffloadFeatureEnabled = false; mIsMdnsOffloadFeatureEnabled = false;
mIncludeInetAddressRecordsInProbing = false; mIncludeInetAddressRecordsInProbing = false;
mIsExpiredServicesRemovalEnabled = true; // Default enabled. mIsExpiredServicesRemovalEnabled = true; // Default enabled.
mIsLabelCountLimitEnabled = true; // Default enabled.
} }
/** /**
@@ -108,12 +121,24 @@ public class MdnsFeatureFlags {
return this; return this;
} }
/**
* Set whether the label count limit is enabled.
*
* @see #NSD_LIMIT_LABEL_COUNT
*/
public Builder setIsLabelCountLimitEnabled(boolean isLabelCountLimitEnabled) {
mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
return this;
}
/** /**
* Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder. * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
*/ */
public MdnsFeatureFlags build() { public MdnsFeatureFlags build() {
return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled, return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled,
mIncludeInetAddressRecordsInProbing, mIsExpiredServicesRemovalEnabled); mIncludeInetAddressRecordsInProbing,
mIsExpiredServicesRemovalEnabled,
mIsLabelCountLimitEnabled);
} }
} }
} }

View File

@@ -65,11 +65,12 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
private final MdnsProber mProber; private final MdnsProber mProber;
@NonNull @NonNull
private final MdnsReplySender mReplySender; private final MdnsReplySender mReplySender;
@NonNull @NonNull
private final SharedLog mSharedLog; private final SharedLog mSharedLog;
@NonNull @NonNull
private final byte[] mPacketCreationBuffer; private final byte[] mPacketCreationBuffer;
@NonNull
private final MdnsFeatureFlags mMdnsFeatureFlags;
/** /**
* Callbacks called by {@link MdnsInterfaceAdvertiser} to report status updates. * Callbacks called by {@link MdnsInterfaceAdvertiser} to report status updates.
@@ -213,6 +214,7 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
mProber = deps.makeMdnsProber(sharedLog.getTag(), looper, mReplySender, mProbingCallback, mProber = deps.makeMdnsProber(sharedLog.getTag(), looper, mReplySender, mProbingCallback,
sharedLog); sharedLog);
mSharedLog = sharedLog; mSharedLog = sharedLog;
mMdnsFeatureFlags = mdnsFeatureFlags;
} }
/** /**
@@ -351,7 +353,7 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) { public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
final MdnsPacket packet; final MdnsPacket packet;
try { try {
packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length)); packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length, mMdnsFeatureFlags));
} catch (MdnsPacket.ParseException e) { } catch (MdnsPacket.ParseException e) {
mSharedLog.e("Error parsing mDNS packet", e); mSharedLog.e("Error parsing mDNS packet", e);
if (DBG) { if (DBG) {

View File

@@ -50,6 +50,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
@NonNull private final Handler mHandler; @NonNull private final Handler mHandler;
@NonNull private final MdnsSocketProvider mSocketProvider; @NonNull private final MdnsSocketProvider mSocketProvider;
@NonNull private final SharedLog mSharedLog; @NonNull private final SharedLog mSharedLog;
@NonNull private final MdnsFeatureFlags mMdnsFeatureFlags;
private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mSocketRequests = private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mSocketRequests =
new ArrayMap<>(); new ArrayMap<>();
@@ -58,11 +59,12 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
private int mReceivedPacketNumber = 0; private int mReceivedPacketNumber = 0;
public MdnsMultinetworkSocketClient(@NonNull Looper looper, public MdnsMultinetworkSocketClient(@NonNull Looper looper,
@NonNull MdnsSocketProvider provider, @NonNull MdnsSocketProvider provider, @NonNull SharedLog sharedLog,
@NonNull SharedLog sharedLog) { @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mHandler = new Handler(looper); mHandler = new Handler(looper);
mSocketProvider = provider; mSocketProvider = provider;
mSharedLog = sharedLog; mSharedLog = sharedLog;
mMdnsFeatureFlags = mdnsFeatureFlags;
} }
private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback { private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
@@ -239,7 +241,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
final MdnsPacket response; final MdnsPacket response;
try { try {
response = MdnsResponseDecoder.parseResponse(recvbuf, length); response = MdnsResponseDecoder.parseResponse(recvbuf, length, mMdnsFeatureFlags);
} catch (MdnsPacket.ParseException e) { } catch (MdnsPacket.ParseException e) {
if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) { if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
mSharedLog.e(e.getMessage(), e); mSharedLog.e(e.getMessage(), e);

View File

@@ -16,6 +16,7 @@
package com.android.server.connectivity.mdns; package com.android.server.connectivity.mdns;
import android.annotation.NonNull;
import android.annotation.Nullable; import android.annotation.Nullable;
import android.util.SparseArray; import android.util.SparseArray;
@@ -33,21 +34,23 @@ public class MdnsPacketReader {
private final byte[] buf; private final byte[] buf;
private final int count; private final int count;
private final SparseArray<LabelEntry> labelDictionary; private final SparseArray<LabelEntry> labelDictionary;
private final MdnsFeatureFlags mMdnsFeatureFlags;
private int pos; private int pos;
private int limit; private int limit;
/** Constructs a reader for the given packet. */ /** Constructs a reader for the given packet. */
public MdnsPacketReader(DatagramPacket packet) { public MdnsPacketReader(DatagramPacket packet) {
this(packet.getData(), packet.getLength()); this(packet.getData(), packet.getLength(), MdnsFeatureFlags.newBuilder().build());
} }
/** Constructs a reader for the given packet. */ /** Constructs a reader for the given packet. */
public MdnsPacketReader(byte[] buffer, int length) { public MdnsPacketReader(byte[] buffer, int length, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
buf = buffer; buf = buffer;
count = length; count = length;
pos = 0; pos = 0;
limit = -1; limit = -1;
labelDictionary = new SparseArray<>(16); labelDictionary = new SparseArray<>(16);
mMdnsFeatureFlags = mdnsFeatureFlags;
} }
/** /**
@@ -269,4 +272,4 @@ public class MdnsPacketReader {
this.label = label; this.label = label;
} }
} }
} }

View File

@@ -84,9 +84,9 @@ public class MdnsResponseDecoder {
* @throws MdnsPacket.ParseException if a response packet could not be parsed. * @throws MdnsPacket.ParseException if a response packet could not be parsed.
*/ */
@NonNull @NonNull
public static MdnsPacket parseResponse(@NonNull byte[] recvbuf, int length) public static MdnsPacket parseResponse(@NonNull byte[] recvbuf, int length,
throws MdnsPacket.ParseException { @NonNull MdnsFeatureFlags mdnsFeatureFlags) throws MdnsPacket.ParseException {
MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length); final MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length, mdnsFeatureFlags);
final MdnsPacket mdnsPacket; final MdnsPacket mdnsPacket;
try { try {

View File

@@ -105,9 +105,10 @@ public class MdnsSocketClient implements MdnsSocketClientBase {
private AtomicInteger packetsCount; private AtomicInteger packetsCount;
@Nullable private Timer checkMulticastResponseTimer; @Nullable private Timer checkMulticastResponseTimer;
private final SharedLog sharedLog; private final SharedLog sharedLog;
@NonNull private final MdnsFeatureFlags mdnsFeatureFlags;
public MdnsSocketClient(@NonNull Context context, @NonNull MulticastLock multicastLock, public MdnsSocketClient(@NonNull Context context, @NonNull MulticastLock multicastLock,
SharedLog sharedLog) { SharedLog sharedLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this.sharedLog = sharedLog; this.sharedLog = sharedLog;
this.context = context; this.context = context;
this.multicastLock = multicastLock; this.multicastLock = multicastLock;
@@ -116,6 +117,7 @@ public class MdnsSocketClient implements MdnsSocketClientBase {
} else { } else {
unicastReceiverBuffer = null; unicastReceiverBuffer = null;
} }
this.mdnsFeatureFlags = mdnsFeatureFlags;
} }
@Override @Override
@@ -454,7 +456,8 @@ public class MdnsSocketClient implements MdnsSocketClientBase {
final MdnsPacket response; final MdnsPacket response;
try { try {
response = MdnsResponseDecoder.parseResponse(packet.getData(), packet.getLength()); response = MdnsResponseDecoder.parseResponse(
packet.getData(), packet.getLength(), mdnsFeatureFlags);
} catch (MdnsPacket.ParseException e) { } catch (MdnsPacket.ParseException e) {
sharedLog.w(String.format("Error while decoding %s packet (%d): %d", sharedLog.w(String.format("Error while decoding %s packet (%d): %d",
responseType, packetNumber, e.code)); responseType, packetNumber, e.code));

View File

@@ -82,8 +82,8 @@ public class MdnsMultinetworkSocketClientTest {
mHandlerThread.start(); mHandlerThread.start();
mHandler = new Handler(mHandlerThread.getLooper()); mHandler = new Handler(mHandlerThread.getLooper());
mSocketKey = new SocketKey(1000 /* interfaceIndex */); mSocketKey = new SocketKey(1000 /* interfaceIndex */);
mSocketClient = new MdnsMultinetworkSocketClient( mSocketClient = new MdnsMultinetworkSocketClient(mHandlerThread.getLooper(), mProvider,
mHandlerThread.getLooper(), mProvider, mSharedLog); mSharedLog, MdnsFeatureFlags.newBuilder().build());
mHandler.post(() -> mSocketClient.setCallback(mCallback)); mHandler.post(() -> mSocketClient.setCallback(mCallback));
} }

View File

@@ -75,7 +75,7 @@ public class MdnsPacketReaderTests {
+ "the packet length"); + "the packet length");
} catch (IOException e) { } catch (IOException e) {
// Expected // Expected
} catch (Exception e) { } catch (RuntimeException e) {
fail(String.format( fail(String.format(
Locale.ROOT, Locale.ROOT,
"Should not have thrown any other exception except " + "for IOException: %s", "Should not have thrown any other exception except " + "for IOException: %s",
@@ -83,4 +83,4 @@ public class MdnsPacketReaderTests {
} }
assertEquals(data.length, packetReader.getRemaining()); assertEquals(data.length, packetReader.getRemaining());
} }
} }

View File

@@ -27,6 +27,9 @@ import org.junit.runner.RunWith
@RunWith(DevSdkIgnoreRunner::class) @RunWith(DevSdkIgnoreRunner::class)
class MdnsPacketTest { class MdnsPacketTest {
private fun makeFlags(isLabelCountLimitEnabled: Boolean = false): MdnsFeatureFlags =
MdnsFeatureFlags.newBuilder()
.setIsLabelCountLimitEnabled(isLabelCountLimitEnabled).build()
@Test @Test
fun testParseQuery() { fun testParseQuery() {
// Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses // Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses
@@ -38,7 +41,7 @@ class MdnsPacketTest {
"010db8000000000000000000000789" "010db8000000000000000000000789"
val bytes = HexDump.hexStringToByteArray(packetHex) val bytes = HexDump.hexStringToByteArray(packetHex)
val reader = MdnsPacketReader(bytes, bytes.size) val reader = MdnsPacketReader(bytes, bytes.size, makeFlags())
val packet = MdnsPacket.parse(reader) val packet = MdnsPacket.parse(reader)
assertEquals(123, packet.transactionId) assertEquals(123, packet.transactionId)

View File

@@ -17,8 +17,10 @@
package com.android.server.connectivity.mdns; package com.android.server.connectivity.mdns;
import static android.net.InetAddresses.parseNumericAddress; import static android.net.InetAddresses.parseNumericAddress;
import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock; import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock;
import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2; import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@@ -337,7 +339,8 @@ public class MdnsResponseDecoderTests {
packet.setSocketAddress( packet.setSocketAddress(
new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)); new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(data6, data6.length); final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(
data6, data6.length, MdnsFeatureFlags.newBuilder().build());
assertNotNull(parsedPacket); assertNotNull(parsedPacket);
final Network network = mock(Network.class); final Network network = mock(Network.class);
@@ -636,7 +639,8 @@ public class MdnsResponseDecoderTests {
private ArraySet<MdnsResponse> decode(MdnsResponseDecoder decoder, byte[] data, private ArraySet<MdnsResponse> decode(MdnsResponseDecoder decoder, byte[] data,
Collection<MdnsResponse> existingResponses) throws MdnsPacket.ParseException { Collection<MdnsResponse> existingResponses) throws MdnsPacket.ParseException {
final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(data, data.length); final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(
data, data.length, MdnsFeatureFlags.newBuilder().build());
assertNotNull(parsedPacket); assertNotNull(parsedPacket);
return new ArraySet<>(decoder.augmentResponses(parsedPacket, return new ArraySet<>(decoder.augmentResponses(parsedPacket,

View File

@@ -78,6 +78,7 @@ public class MdnsSocketClientTests {
@Mock private SharedLog sharedLog; @Mock private SharedLog sharedLog;
private MdnsSocketClient mdnsClient; private MdnsSocketClient mdnsClient;
private MdnsFeatureFlags flags = MdnsFeatureFlags.newBuilder().build();
@Before @Before
public void setup() throws RuntimeException, IOException { public void setup() throws RuntimeException, IOException {
@@ -86,7 +87,7 @@ public class MdnsSocketClientTests {
when(mockWifiManager.createMulticastLock(ArgumentMatchers.anyString())) when(mockWifiManager.createMulticastLock(ArgumentMatchers.anyString()))
.thenReturn(mockMulticastLock); .thenReturn(mockMulticastLock);
mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog) { mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog, flags) {
@Override @Override
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) throws IOException { MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) throws IOException {
if (port == MdnsConstants.MDNS_PORT) { if (port == MdnsConstants.MDNS_PORT) {
@@ -515,7 +516,7 @@ public class MdnsSocketClientTests {
//MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(true); //MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(true);
when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21); when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog) { mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog, flags) {
@Override @Override
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) { MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) {
if (port == MdnsConstants.MDNS_PORT) { if (port == MdnsConstants.MDNS_PORT) {
@@ -538,7 +539,7 @@ public class MdnsSocketClientTests {
//MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(false); //MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(false);
when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21); when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog) { mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog, flags) {
@Override @Override
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) { MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) {
if (port == MdnsConstants.MDNS_PORT) { if (port == MdnsConstants.MDNS_PORT) {