Merge "Include A/AAAA records in probing packet" into main

This commit is contained in:
Paul Hu
2023-10-17 02:24:10 +00:00
committed by Gerrit Code Review
12 changed files with 180 additions and 53 deletions

View File

@@ -1709,9 +1709,14 @@ public class NsdService extends INsdManager.Stub {
mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"));
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder().setIsMdnsOffloadFeatureEnabled(
mDeps.isTetheringFeatureNotChickenedOut(mContext,
MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD)).build();
MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder()
.setIsMdnsOffloadFeatureEnabled(
mDeps.isTetheringFeatureNotChickenedOut(
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
.setIncludeInetAddressRecordsInProbing(
mDeps.isFeatureEnabled(
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.build();
mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser"), flags);
mClock = deps.makeClock();

View File

@@ -96,10 +96,11 @@ public class MdnsAdvertiser {
@NonNull Looper looper, @NonNull byte[] packetCreationBuffer,
@NonNull MdnsInterfaceAdvertiser.Callback cb,
@NonNull String[] deviceHostName,
@NonNull SharedLog sharedLog) {
@NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
// Note NetworkInterface is final and not mockable
return new MdnsInterfaceAdvertiser(socket, initialAddresses, looper,
packetCreationBuffer, cb, deviceHostName, sharedLog);
packetCreationBuffer, cb, deviceHostName, sharedLog, mdnsFeatureFlags);
}
/**
@@ -394,7 +395,8 @@ public class MdnsAdvertiser {
if (advertiser == null) {
advertiser = mDeps.makeAdvertiser(socket, addresses, mLooper, mPacketCreationBuffer,
mInterfaceAdvertiserCb, mDeviceHostName,
mSharedLog.forSubComponent(socket.getInterface().getName()));
mSharedLog.forSubComponent(socket.getInterface().getName()),
mMdnsFeatureFlags);
mAllAdvertisers.put(socket, advertiser);
advertiser.start();
}

View File

@@ -24,14 +24,26 @@ public class MdnsFeatureFlags {
*/
public static final String NSD_FORCE_DISABLE_MDNS_OFFLOAD = "nsd_force_disable_mdns_offload";
/**
* The feature flag for controlling whether the probing question should include
* InetAddressRecords or not.
*/
public static final String INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING =
"include_inet_address_records_in_probing";
// Flag for offload feature
public final boolean mIsMdnsOffloadFeatureEnabled;
// Flag for including InetAddressRecords in probing questions.
public final boolean mIncludeInetAddressRecordsInProbing;
/**
* The constructor for {@link MdnsFeatureFlags}.
*/
public MdnsFeatureFlags(boolean isOffloadFeatureEnabled) {
public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
boolean includeInetAddressRecordsInProbing) {
mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
}
@@ -44,12 +56,14 @@ public class MdnsFeatureFlags {
public static final class Builder {
private boolean mIsMdnsOffloadFeatureEnabled;
private boolean mIncludeInetAddressRecordsInProbing;
/**
* The constructor for {@link Builder}.
*/
public Builder() {
mIsMdnsOffloadFeatureEnabled = false;
mIncludeInetAddressRecordsInProbing = false;
}
/**
@@ -60,11 +74,21 @@ public class MdnsFeatureFlags {
return this;
}
/**
* Set if the probing question should include InetAddressRecords.
*/
public Builder setIncludeInetAddressRecordsInProbing(
boolean includeInetAddressRecordsInProbing) {
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
return this;
}
/**
* Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
*/
public MdnsFeatureFlags build() {
return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled);
return new MdnsFeatureFlags(
mIsMdnsOffloadFeatureEnabled, mIncludeInetAddressRecordsInProbing);
}
}

View File

@@ -18,7 +18,7 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting;
import androidx.annotation.VisibleForTesting;
import java.io.IOException;
import java.net.Inet4Address;
@@ -29,7 +29,7 @@ import java.util.Locale;
import java.util.Objects;
/** An mDNS "AAAA" or "A" record, which holds an IPv6 or IPv4 address. */
@VisibleForTesting
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsInetAddressRecord extends MdnsRecord {
@Nullable private Inet6Address inet6Address;
@Nullable private Inet4Address inet4Address;

View File

@@ -150,8 +150,8 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
/** @see MdnsRecordRepository */
@NonNull
public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper,
@NonNull String[] deviceHostName) {
return new MdnsRecordRepository(looper, deviceHostName);
@NonNull String[] deviceHostName, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
return new MdnsRecordRepository(looper, deviceHostName, mdnsFeatureFlags);
}
/** @see MdnsReplySender */
@@ -187,16 +187,18 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper,
@NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
@NonNull String[] deviceHostName, @NonNull SharedLog sharedLog) {
@NonNull String[] deviceHostName, @NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this(socket, initialAddresses, looper, packetCreationBuffer, cb,
new Dependencies(), deviceHostName, sharedLog);
new Dependencies(), deviceHostName, sharedLog, mdnsFeatureFlags);
}
public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper,
@NonNull byte[] packetCreationBuffer, @NonNull Callback cb, @NonNull Dependencies deps,
@NonNull String[] deviceHostName, @NonNull SharedLog sharedLog) {
mRecordRepository = deps.makeRecordRepository(looper, deviceHostName);
@NonNull String[] deviceHostName, @NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mRecordRepository = deps.makeRecordRepository(looper, deviceHostName, mdnsFeatureFlags);
mRecordRepository.updateAddresses(initialAddresses);
mSocket = socket;
mCb = cb;

View File

@@ -18,14 +18,15 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting;
import androidx.annotation.VisibleForTesting;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException;
import java.util.Arrays;
/** An mDNS "PTR" record, which holds a name (the "pointer"). */
@VisibleForTesting
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsPointerRecord extends MdnsRecord {
private String[] pointer;

View File

@@ -92,16 +92,19 @@ public class MdnsRecordRepository {
private final Looper mLooper;
@NonNull
private final String[] mDeviceHostname;
private final MdnsFeatureFlags mMdnsFeatureFlags;
public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname) {
this(looper, new Dependencies(), deviceHostname);
public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this(looper, new Dependencies(), deviceHostname, mdnsFeatureFlags);
}
@VisibleForTesting
public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps,
@NonNull String[] deviceHostname) {
@NonNull String[] deviceHostname, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mDeviceHostname = deviceHostname;
mLooper = looper;
mMdnsFeatureFlags = mdnsFeatureFlags;
}
/**
@@ -351,7 +354,8 @@ public class MdnsRecordRepository {
}
private MdnsProber.ProbingInfo makeProbingInfo(int serviceId,
@NonNull MdnsServiceRecord srvRecord) {
@NonNull MdnsServiceRecord srvRecord,
@NonNull List<MdnsInetAddressRecord> inetAddressRecords) {
final List<MdnsRecord> probingRecords = new ArrayList<>();
// Probe with cacheFlush cleared; it is set when announcing, as it was verified unique:
// RFC6762 10.2
@@ -363,6 +367,15 @@ public class MdnsRecordRepository {
srvRecord.getServicePort(),
srvRecord.getServiceHost()));
for (MdnsInetAddressRecord inetAddressRecord : inetAddressRecords) {
probingRecords.add(new MdnsInetAddressRecord(inetAddressRecord.getName(),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
inetAddressRecord.getTtl(),
inetAddressRecord.getInet4Address() == null
? inetAddressRecord.getInet6Address()
: inetAddressRecord.getInet4Address()));
}
return new MdnsProber.ProbingInfo(serviceId, probingRecords);
}
@@ -824,6 +837,18 @@ public class MdnsRecordRepository {
return conflicting;
}
private List<MdnsInetAddressRecord> makeProbingInetAddressRecords() {
final List<MdnsInetAddressRecord> records = new ArrayList<>();
if (mMdnsFeatureFlags.mIncludeInetAddressRecordsInProbing) {
for (RecordInfo<?> record : mGeneralRecords) {
if (record.record instanceof MdnsInetAddressRecord) {
records.add((MdnsInetAddressRecord) record.record);
}
}
}
return records;
}
/**
* (Re)set a service to the probing state.
* @return The {@link MdnsProber.ProbingInfo} to send for probing.
@@ -834,7 +859,8 @@ public class MdnsRecordRepository {
if (registration == null) return null;
registration.setProbing(true);
return makeProbingInfo(serviceId, registration.srvRecord.record);
return makeProbingInfo(
serviceId, registration.srvRecord.record, makeProbingInetAddressRecords());
}
/**
@@ -870,7 +896,8 @@ public class MdnsRecordRepository {
final ServiceRegistration newService = new ServiceRegistration(mDeviceHostname, newInfo,
existing.subtype, existing.repliedServiceCount, existing.sentPacketCount);
mServices.put(serviceId, newService);
return makeProbingInfo(serviceId, newService.srvRecord.record);
return makeProbingInfo(
serviceId, newService.srvRecord.record, makeProbingInetAddressRecords());
}
/**

View File

@@ -18,7 +18,8 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting;
import androidx.annotation.VisibleForTesting;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException;
@@ -27,7 +28,7 @@ import java.util.Locale;
import java.util.Objects;
/** An mDNS "SRV" record, which contains service information. */
@VisibleForTesting
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsServiceRecord extends MdnsRecord {
public static final int PROTO_NONE = 0;
public static final int PROTO_TCP = 1;

View File

@@ -18,7 +18,8 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting;
import androidx.annotation.VisibleForTesting;
import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
import java.io.IOException;
@@ -28,7 +29,7 @@ import java.util.List;
import java.util.Objects;
/** An mDNS "TXT" record, which contains a list of {@link TextEntry}. */
@VisibleForTesting
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsTextRecord extends MdnsRecord {
private List<TextEntry> entries;

View File

@@ -153,10 +153,10 @@ class MdnsAdvertiserTest {
thread.start()
doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname()
doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1),
any(), any(), any(), any(), eq(TEST_HOSTNAME), any()
any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any()
)
doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2),
any(), any(), any(), any(), eq(TEST_HOSTNAME), any()
any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any()
)
doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt())
doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt())
@@ -202,6 +202,7 @@ class MdnsAdvertiserTest {
any(),
intAdvCbCaptor.capture(),
eq(TEST_HOSTNAME),
any(),
any()
)
@@ -259,10 +260,10 @@ class MdnsAdvertiserTest {
val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any()
eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any(), any()
)
verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any()
eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any(), any()
)
verify(mockInterfaceAdvertiser1).addService(
anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
@@ -367,7 +368,7 @@ class MdnsAdvertiserTest {
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any()
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any(), any()
)
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(SERVICE_1) }, eq(null))

View File

@@ -77,6 +77,7 @@ class MdnsInterfaceAdvertiserTest {
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
private val sharedlog = SharedLog("MdnsInterfaceAdvertiserTest")
private val flags = MdnsFeatureFlags.newBuilder().build()
@Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
@@ -99,15 +100,14 @@ class MdnsInterfaceAdvertiserTest {
cb,
deps,
TEST_HOSTNAME,
sharedlog
sharedlog,
flags
)
}
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any(),
eq(TEST_HOSTNAME)
)
doReturn(repository).`when`(deps).makeRecordRepository(any(), eq(TEST_HOSTNAME), any())
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any(), any())

View File

@@ -78,6 +78,7 @@ class MdnsRecordRepositoryTest {
override fun getInterfaceInetAddresses(iface: NetworkInterface) =
Collections.enumeration(TEST_ADDRESSES.map { it.address })
}
private val flags = MdnsFeatureFlags.newBuilder().build()
@Before
fun setUp() {
@@ -92,7 +93,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testAddServiceAndProbe() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */))
@@ -127,7 +128,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
@@ -139,7 +140,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(IllegalArgumentException::class) {
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */)
@@ -148,7 +149,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testHasActiveService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
@@ -165,7 +166,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -195,7 +196,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitAnnouncements_WithSubtype() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -231,7 +232,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
repository.exitService(TEST_SERVICE_ID_1)
@@ -246,7 +247,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -371,7 +372,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetOffloadPacket() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local")
@@ -433,7 +434,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetReplyCaseInsensitive() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val questionsCaseInSensitive =
listOf(MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"),
@@ -463,7 +464,7 @@ class MdnsRecordRepositoryTest {
}
private fun doGetReplyTest(subtype: String?) {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype)
val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local")
else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local")
@@ -551,7 +552,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServices() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -579,7 +580,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServicesCaseInsensitive() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -607,7 +608,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServices_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -636,7 +637,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetConflictingServicesCaseInsensitive_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -665,7 +666,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testGetServiceRepliedRequestsCount() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
// Verify that there is no packet replied.
assertEquals(MdnsConstants.NO_PACKET,
@@ -690,6 +691,68 @@ class MdnsRecordRepositoryTest {
assertEquals(MdnsConstants.NO_PACKET,
repository.getServiceRepliedRequestsCount(TEST_SERVICE_ID_2))
}
@Test
fun testIncludeInetAddressRecordsInProbing() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
MdnsFeatureFlags.newBuilder().setIncludeInetAddressRecordsInProbing(true).build())
repository.updateAddresses(TEST_ADDRESSES)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */))
assertEquals(1, repository.servicesCount)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
assertNotNull(probingInfo)
assertTrue(repository.isProbing(TEST_SERVICE_ID_1))
assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size)
assertEquals(2, packet.questions.size)
val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertContentEquals(listOf(
MdnsAnyRecord(expectedName, false /* unicast */),
MdnsAnyRecord(TEST_HOSTNAME, false /* unicast */),
), packet.questions)
assertEquals(4, packet.authorityRecords.size)
assertContentEquals(listOf(
MdnsServiceRecord(
expectedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT,
TEST_HOSTNAME),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[0].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[1].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[2].address)
), packet.authorityRecords)
assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices())
}
}
private fun MdnsRecordRepository.initWithService(