Merge "Include A/AAAA records in probing packet" into main

This commit is contained in:
Paul Hu
2023-10-17 02:24:10 +00:00
committed by Gerrit Code Review
12 changed files with 180 additions and 53 deletions

View File

@@ -1709,9 +1709,14 @@ public class NsdService extends INsdManager.Stub {
mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager")); mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"));
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager)); handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder().setIsMdnsOffloadFeatureEnabled( MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder()
mDeps.isTetheringFeatureNotChickenedOut(mContext, .setIsMdnsOffloadFeatureEnabled(
MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD)).build(); mDeps.isTetheringFeatureNotChickenedOut(
mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
.setIncludeInetAddressRecordsInProbing(
mDeps.isFeatureEnabled(
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.build();
mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider, mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser"), flags); new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser"), flags);
mClock = deps.makeClock(); mClock = deps.makeClock();

View File

@@ -96,10 +96,11 @@ public class MdnsAdvertiser {
@NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Looper looper, @NonNull byte[] packetCreationBuffer,
@NonNull MdnsInterfaceAdvertiser.Callback cb, @NonNull MdnsInterfaceAdvertiser.Callback cb,
@NonNull String[] deviceHostName, @NonNull String[] deviceHostName,
@NonNull SharedLog sharedLog) { @NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
// Note NetworkInterface is final and not mockable // Note NetworkInterface is final and not mockable
return new MdnsInterfaceAdvertiser(socket, initialAddresses, looper, return new MdnsInterfaceAdvertiser(socket, initialAddresses, looper,
packetCreationBuffer, cb, deviceHostName, sharedLog); packetCreationBuffer, cb, deviceHostName, sharedLog, mdnsFeatureFlags);
} }
/** /**
@@ -394,7 +395,8 @@ public class MdnsAdvertiser {
if (advertiser == null) { if (advertiser == null) {
advertiser = mDeps.makeAdvertiser(socket, addresses, mLooper, mPacketCreationBuffer, advertiser = mDeps.makeAdvertiser(socket, addresses, mLooper, mPacketCreationBuffer,
mInterfaceAdvertiserCb, mDeviceHostName, mInterfaceAdvertiserCb, mDeviceHostName,
mSharedLog.forSubComponent(socket.getInterface().getName())); mSharedLog.forSubComponent(socket.getInterface().getName()),
mMdnsFeatureFlags);
mAllAdvertisers.put(socket, advertiser); mAllAdvertisers.put(socket, advertiser);
advertiser.start(); advertiser.start();
} }

View File

@@ -24,14 +24,26 @@ public class MdnsFeatureFlags {
*/ */
public static final String NSD_FORCE_DISABLE_MDNS_OFFLOAD = "nsd_force_disable_mdns_offload"; public static final String NSD_FORCE_DISABLE_MDNS_OFFLOAD = "nsd_force_disable_mdns_offload";
/**
* The feature flag for controlling whether the probing question should include
* InetAddressRecords or not.
*/
public static final String INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING =
"include_inet_address_records_in_probing";
// Flag for offload feature // Flag for offload feature
public final boolean mIsMdnsOffloadFeatureEnabled; public final boolean mIsMdnsOffloadFeatureEnabled;
// Flag for including InetAddressRecords in probing questions.
public final boolean mIncludeInetAddressRecordsInProbing;
/** /**
* The constructor for {@link MdnsFeatureFlags}. * The constructor for {@link MdnsFeatureFlags}.
*/ */
public MdnsFeatureFlags(boolean isOffloadFeatureEnabled) { public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
boolean includeInetAddressRecordsInProbing) {
mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled; mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
} }
@@ -44,12 +56,14 @@ public class MdnsFeatureFlags {
public static final class Builder { public static final class Builder {
private boolean mIsMdnsOffloadFeatureEnabled; private boolean mIsMdnsOffloadFeatureEnabled;
private boolean mIncludeInetAddressRecordsInProbing;
/** /**
* The constructor for {@link Builder}. * The constructor for {@link Builder}.
*/ */
public Builder() { public Builder() {
mIsMdnsOffloadFeatureEnabled = false; mIsMdnsOffloadFeatureEnabled = false;
mIncludeInetAddressRecordsInProbing = false;
} }
/** /**
@@ -60,11 +74,21 @@ public class MdnsFeatureFlags {
return this; return this;
} }
/**
* Set if the probing question should include InetAddressRecords.
*/
public Builder setIncludeInetAddressRecordsInProbing(
boolean includeInetAddressRecordsInProbing) {
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
return this;
}
/** /**
* Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder. * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
*/ */
public MdnsFeatureFlags build() { public MdnsFeatureFlags build() {
return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled); return new MdnsFeatureFlags(
mIsMdnsOffloadFeatureEnabled, mIncludeInetAddressRecordsInProbing);
} }
} }

View File

@@ -18,7 +18,7 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable; import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting; import androidx.annotation.VisibleForTesting;
import java.io.IOException; import java.io.IOException;
import java.net.Inet4Address; import java.net.Inet4Address;
@@ -29,7 +29,7 @@ import java.util.Locale;
import java.util.Objects; import java.util.Objects;
/** An mDNS "AAAA" or "A" record, which holds an IPv6 or IPv4 address. */ /** An mDNS "AAAA" or "A" record, which holds an IPv6 or IPv4 address. */
@VisibleForTesting @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsInetAddressRecord extends MdnsRecord { public class MdnsInetAddressRecord extends MdnsRecord {
@Nullable private Inet6Address inet6Address; @Nullable private Inet6Address inet6Address;
@Nullable private Inet4Address inet4Address; @Nullable private Inet4Address inet4Address;

View File

@@ -150,8 +150,8 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
/** @see MdnsRecordRepository */ /** @see MdnsRecordRepository */
@NonNull @NonNull
public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper, public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper,
@NonNull String[] deviceHostName) { @NonNull String[] deviceHostName, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
return new MdnsRecordRepository(looper, deviceHostName); return new MdnsRecordRepository(looper, deviceHostName, mdnsFeatureFlags);
} }
/** @see MdnsReplySender */ /** @see MdnsReplySender */
@@ -187,16 +187,18 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket, public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper, @NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper,
@NonNull byte[] packetCreationBuffer, @NonNull Callback cb, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
@NonNull String[] deviceHostName, @NonNull SharedLog sharedLog) { @NonNull String[] deviceHostName, @NonNull SharedLog sharedLog,
@NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this(socket, initialAddresses, looper, packetCreationBuffer, cb, this(socket, initialAddresses, looper, packetCreationBuffer, cb,
new Dependencies(), deviceHostName, sharedLog); new Dependencies(), deviceHostName, sharedLog, mdnsFeatureFlags);
} }
public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket, public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper, @NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper,
@NonNull byte[] packetCreationBuffer, @NonNull Callback cb, @NonNull Dependencies deps, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb, @NonNull Dependencies deps,
@NonNull String[] deviceHostName, @NonNull SharedLog sharedLog) { @NonNull String[] deviceHostName, @NonNull SharedLog sharedLog,
mRecordRepository = deps.makeRecordRepository(looper, deviceHostName); @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mRecordRepository = deps.makeRecordRepository(looper, deviceHostName, mdnsFeatureFlags);
mRecordRepository.updateAddresses(initialAddresses); mRecordRepository.updateAddresses(initialAddresses);
mSocket = socket; mSocket = socket;
mCb = cb; mCb = cb;

View File

@@ -18,14 +18,15 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable; import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting; import androidx.annotation.VisibleForTesting;
import com.android.server.connectivity.mdns.util.MdnsUtils; import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
/** An mDNS "PTR" record, which holds a name (the "pointer"). */ /** An mDNS "PTR" record, which holds a name (the "pointer"). */
@VisibleForTesting @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsPointerRecord extends MdnsRecord { public class MdnsPointerRecord extends MdnsRecord {
private String[] pointer; private String[] pointer;

View File

@@ -92,16 +92,19 @@ public class MdnsRecordRepository {
private final Looper mLooper; private final Looper mLooper;
@NonNull @NonNull
private final String[] mDeviceHostname; private final String[] mDeviceHostname;
private final MdnsFeatureFlags mMdnsFeatureFlags;
public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname) { public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname,
this(looper, new Dependencies(), deviceHostname); @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this(looper, new Dependencies(), deviceHostname, mdnsFeatureFlags);
} }
@VisibleForTesting @VisibleForTesting
public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps, public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps,
@NonNull String[] deviceHostname) { @NonNull String[] deviceHostname, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mDeviceHostname = deviceHostname; mDeviceHostname = deviceHostname;
mLooper = looper; mLooper = looper;
mMdnsFeatureFlags = mdnsFeatureFlags;
} }
/** /**
@@ -351,7 +354,8 @@ public class MdnsRecordRepository {
} }
private MdnsProber.ProbingInfo makeProbingInfo(int serviceId, private MdnsProber.ProbingInfo makeProbingInfo(int serviceId,
@NonNull MdnsServiceRecord srvRecord) { @NonNull MdnsServiceRecord srvRecord,
@NonNull List<MdnsInetAddressRecord> inetAddressRecords) {
final List<MdnsRecord> probingRecords = new ArrayList<>(); final List<MdnsRecord> probingRecords = new ArrayList<>();
// Probe with cacheFlush cleared; it is set when announcing, as it was verified unique: // Probe with cacheFlush cleared; it is set when announcing, as it was verified unique:
// RFC6762 10.2 // RFC6762 10.2
@@ -363,6 +367,15 @@ public class MdnsRecordRepository {
srvRecord.getServicePort(), srvRecord.getServicePort(),
srvRecord.getServiceHost())); srvRecord.getServiceHost()));
for (MdnsInetAddressRecord inetAddressRecord : inetAddressRecords) {
probingRecords.add(new MdnsInetAddressRecord(inetAddressRecord.getName(),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
inetAddressRecord.getTtl(),
inetAddressRecord.getInet4Address() == null
? inetAddressRecord.getInet6Address()
: inetAddressRecord.getInet4Address()));
}
return new MdnsProber.ProbingInfo(serviceId, probingRecords); return new MdnsProber.ProbingInfo(serviceId, probingRecords);
} }
@@ -824,6 +837,18 @@ public class MdnsRecordRepository {
return conflicting; return conflicting;
} }
private List<MdnsInetAddressRecord> makeProbingInetAddressRecords() {
final List<MdnsInetAddressRecord> records = new ArrayList<>();
if (mMdnsFeatureFlags.mIncludeInetAddressRecordsInProbing) {
for (RecordInfo<?> record : mGeneralRecords) {
if (record.record instanceof MdnsInetAddressRecord) {
records.add((MdnsInetAddressRecord) record.record);
}
}
}
return records;
}
/** /**
* (Re)set a service to the probing state. * (Re)set a service to the probing state.
* @return The {@link MdnsProber.ProbingInfo} to send for probing. * @return The {@link MdnsProber.ProbingInfo} to send for probing.
@@ -834,7 +859,8 @@ public class MdnsRecordRepository {
if (registration == null) return null; if (registration == null) return null;
registration.setProbing(true); registration.setProbing(true);
return makeProbingInfo(serviceId, registration.srvRecord.record); return makeProbingInfo(
serviceId, registration.srvRecord.record, makeProbingInetAddressRecords());
} }
/** /**
@@ -870,7 +896,8 @@ public class MdnsRecordRepository {
final ServiceRegistration newService = new ServiceRegistration(mDeviceHostname, newInfo, final ServiceRegistration newService = new ServiceRegistration(mDeviceHostname, newInfo,
existing.subtype, existing.repliedServiceCount, existing.sentPacketCount); existing.subtype, existing.repliedServiceCount, existing.sentPacketCount);
mServices.put(serviceId, newService); mServices.put(serviceId, newService);
return makeProbingInfo(serviceId, newService.srvRecord.record); return makeProbingInfo(
serviceId, newService.srvRecord.record, makeProbingInetAddressRecords());
} }
/** /**

View File

@@ -18,7 +18,8 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable; import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting; import androidx.annotation.VisibleForTesting;
import com.android.server.connectivity.mdns.util.MdnsUtils; import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException; import java.io.IOException;
@@ -27,7 +28,7 @@ import java.util.Locale;
import java.util.Objects; import java.util.Objects;
/** An mDNS "SRV" record, which contains service information. */ /** An mDNS "SRV" record, which contains service information. */
@VisibleForTesting @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsServiceRecord extends MdnsRecord { public class MdnsServiceRecord extends MdnsRecord {
public static final int PROTO_NONE = 0; public static final int PROTO_NONE = 0;
public static final int PROTO_TCP = 1; public static final int PROTO_TCP = 1;

View File

@@ -18,7 +18,8 @@ package com.android.server.connectivity.mdns;
import android.annotation.Nullable; import android.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting; import androidx.annotation.VisibleForTesting;
import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry; import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
import java.io.IOException; import java.io.IOException;
@@ -28,7 +29,7 @@ import java.util.List;
import java.util.Objects; import java.util.Objects;
/** An mDNS "TXT" record, which contains a list of {@link TextEntry}. */ /** An mDNS "TXT" record, which contains a list of {@link TextEntry}. */
@VisibleForTesting @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public class MdnsTextRecord extends MdnsRecord { public class MdnsTextRecord extends MdnsRecord {
private List<TextEntry> entries; private List<TextEntry> entries;

View File

@@ -153,10 +153,10 @@ class MdnsAdvertiserTest {
thread.start() thread.start()
doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname() doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname()
doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1), doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1),
any(), any(), any(), any(), eq(TEST_HOSTNAME), any() any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any()
) )
doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2), doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2),
any(), any(), any(), any(), eq(TEST_HOSTNAME), any() any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any()
) )
doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt()) doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt())
doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt()) doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt())
@@ -202,6 +202,7 @@ class MdnsAdvertiserTest {
any(), any(),
intAdvCbCaptor.capture(), intAdvCbCaptor.capture(),
eq(TEST_HOSTNAME), eq(TEST_HOSTNAME),
any(),
any() any()
) )
@@ -259,10 +260,10 @@ class MdnsAdvertiserTest {
val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any() eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any(), any()
) )
verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)), verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any() eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any(), any()
) )
verify(mockInterfaceAdvertiser1).addService( verify(mockInterfaceAdvertiser1).addService(
anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE)) anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
@@ -367,7 +368,7 @@ class MdnsAdvertiserTest {
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any() eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any(), any()
) )
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1), verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(SERVICE_1) }, eq(null)) argThat { it.matches(SERVICE_1) }, eq(null))

View File

@@ -77,6 +77,7 @@ class MdnsInterfaceAdvertiserTest {
private val announcer = mock(MdnsAnnouncer::class.java) private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java) private val prober = mock(MdnsProber::class.java)
private val sharedlog = SharedLog("MdnsInterfaceAdvertiserTest") private val sharedlog = SharedLog("MdnsInterfaceAdvertiserTest")
private val flags = MdnsFeatureFlags.newBuilder().build()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java) private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>> as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
@@ -99,15 +100,14 @@ class MdnsInterfaceAdvertiserTest {
cb, cb,
deps, deps,
TEST_HOSTNAME, TEST_HOSTNAME,
sharedlog sharedlog,
flags
) )
} }
@Before @Before
fun setUp() { fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any(), doReturn(repository).`when`(deps).makeRecordRepository(any(), eq(TEST_HOSTNAME), any())
eq(TEST_HOSTNAME)
)
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any(), any()) doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any(), any()) doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any(), any()) doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any(), any())

View File

@@ -78,6 +78,7 @@ class MdnsRecordRepositoryTest {
override fun getInterfaceInetAddresses(iface: NetworkInterface) = override fun getInterfaceInetAddresses(iface: NetworkInterface) =
Collections.enumeration(TEST_ADDRESSES.map { it.address }) Collections.enumeration(TEST_ADDRESSES.map { it.address })
} }
private val flags = MdnsFeatureFlags.newBuilder().build()
@Before @Before
fun setUp() { fun setUp() {
@@ -92,7 +93,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testAddServiceAndProbe() { fun testAddServiceAndProbe() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
assertEquals(0, repository.servicesCount) assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */)) null /* subtype */))
@@ -127,7 +128,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testAddAndConflicts() { fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(NameConflictException::class) { assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
@@ -139,7 +140,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testInvalidReuseOfServiceId() { fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
assertFailsWith(IllegalArgumentException::class) { assertFailsWith(IllegalArgumentException::class) {
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */)
@@ -148,7 +149,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testHasActiveService() { fun testHasActiveService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
@@ -165,7 +166,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testExitAnnouncements() { fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -195,7 +196,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testExitAnnouncements_WithSubtype() { fun testExitAnnouncements_WithSubtype() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -231,7 +232,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testExitingServiceReAdded() { fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
repository.exitService(TEST_SERVICE_ID_1) repository.exitService(TEST_SERVICE_ID_1)
@@ -246,7 +247,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testOnProbingSucceeded() { fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
TEST_SUBTYPE) TEST_SUBTYPE)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -371,7 +372,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetOffloadPacket() { fun testGetOffloadPacket() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local") val serviceType = arrayOf("_testservice", "_tcp", "local")
@@ -433,7 +434,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetReplyCaseInsensitive() { fun testGetReplyCaseInsensitive() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val questionsCaseInSensitive = val questionsCaseInSensitive =
listOf(MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"), listOf(MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"),
@@ -463,7 +464,7 @@ class MdnsRecordRepositoryTest {
} }
private fun doGetReplyTest(subtype: String?) { private fun doGetReplyTest(subtype: String?) {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype)
val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local") val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local")
else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local") else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local")
@@ -551,7 +552,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetConflictingServices() { fun testGetConflictingServices() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -579,7 +580,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetConflictingServicesCaseInsensitive() { fun testGetConflictingServicesCaseInsensitive() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -607,7 +608,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetConflictingServices_IdenticalService() { fun testGetConflictingServices_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -636,7 +637,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetConflictingServicesCaseInsensitive_IdenticalService() { fun testGetConflictingServicesCaseInsensitive_IdenticalService() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
@@ -665,7 +666,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testGetServiceRepliedRequestsCount() { fun testGetServiceRepliedRequestsCount() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
// Verify that there is no packet replied. // Verify that there is no packet replied.
assertEquals(MdnsConstants.NO_PACKET, assertEquals(MdnsConstants.NO_PACKET,
@@ -690,6 +691,68 @@ class MdnsRecordRepositoryTest {
assertEquals(MdnsConstants.NO_PACKET, assertEquals(MdnsConstants.NO_PACKET,
repository.getServiceRepliedRequestsCount(TEST_SERVICE_ID_2)) repository.getServiceRepliedRequestsCount(TEST_SERVICE_ID_2))
} }
@Test
fun testIncludeInetAddressRecordsInProbing() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
MdnsFeatureFlags.newBuilder().setIncludeInetAddressRecordsInProbing(true).build())
repository.updateAddresses(TEST_ADDRESSES)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
null /* subtype */))
assertEquals(1, repository.servicesCount)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
assertNotNull(probingInfo)
assertTrue(repository.isProbing(TEST_SERVICE_ID_1))
assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size)
assertEquals(2, packet.questions.size)
val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertContentEquals(listOf(
MdnsAnyRecord(expectedName, false /* unicast */),
MdnsAnyRecord(TEST_HOSTNAME, false /* unicast */),
), packet.questions)
assertEquals(4, packet.authorityRecords.size)
assertContentEquals(listOf(
MdnsServiceRecord(
expectedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT,
TEST_HOSTNAME),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[0].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[1].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
TEST_ADDRESSES[2].address)
), packet.authorityRecords)
assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices())
}
} }
private fun MdnsRecordRepository.initWithService( private fun MdnsRecordRepository.initWithService(