Merge "Remove PacketRepeater destinationsSupplier logic"

This commit is contained in:
Remi NGUYEN VAN
2023-01-13 01:51:07 +00:00
committed by Gerrit Code Review
6 changed files with 31 additions and 50 deletions

View File

@@ -22,10 +22,8 @@ import android.os.Looper;
import com.android.internal.annotations.VisibleForTesting; import com.android.internal.annotations.VisibleForTesting;
import java.net.SocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
/** /**
* Sends mDns announcements when a service registration changes and at regular intervals. * Sends mDns announcements when a service registration changes and at regular intervals.
@@ -43,11 +41,8 @@ public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.Announcement
static class AnnouncementInfo implements MdnsPacketRepeater.Request { static class AnnouncementInfo implements MdnsPacketRepeater.Request {
@NonNull @NonNull
private final MdnsPacket mPacket; private final MdnsPacket mPacket;
@NonNull
private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords, AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
Supplier<Iterable<SocketAddress>> destinationsSupplier) {
// Records to announce (as answers) // Records to announce (as answers)
// Records to place in the "Additional records", with NSEC negative responses // Records to place in the "Additional records", with NSEC negative responses
// to mark records that have been verified unique // to mark records that have been verified unique
@@ -57,7 +52,6 @@ public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.Announcement
announcedRecords, announcedRecords,
Collections.emptyList() /* authorityRecords */, Collections.emptyList() /* authorityRecords */,
additionalRecords); additionalRecords);
mDestinationsSupplier = destinationsSupplier;
} }
@Override @Override
@@ -65,11 +59,6 @@ public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.Announcement
return mPacket; return mPacket;
} }
@Override
public Iterable<SocketAddress> getDestinations(int index) {
return mDestinationsSupplier.get();
}
@Override @Override
public long getDelayMs(int nextIndex) { public long getDelayMs(int nextIndex) {
// Delay is doubled for each announcement // Delay is doubled for each announcement

View File

@@ -24,7 +24,7 @@ import android.os.Message;
import android.util.Log; import android.util.Log;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress; import java.net.InetSocketAddress;
/** /**
* A class used to send several packets at given time intervals. * A class used to send several packets at given time intervals.
@@ -32,6 +32,14 @@ import java.net.SocketAddress;
*/ */
public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> { public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
private static final boolean DBG = MdnsAdvertiser.DBG; private static final boolean DBG = MdnsAdvertiser.DBG;
private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
IPV4_ADDR, IPV6_ADDR
};
@NonNull @NonNull
private final MdnsReplySender mReplySender; private final MdnsReplySender mReplySender;
@NonNull @NonNull
@@ -69,12 +77,6 @@ public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
@NonNull @NonNull
MdnsPacket getPacket(int index); MdnsPacket getPacket(int index);
/**
* Get a set of destinations for the packet for one iteration.
*/
@NonNull
Iterable<SocketAddress> getDestinations(int index);
/** /**
* Get the delay in milliseconds until the next packet transmission. * Get the delay in milliseconds until the next packet transmission.
*/ */
@@ -110,12 +112,13 @@ public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
} }
final MdnsPacket packet = request.getPacket(index); final MdnsPacket packet = request.getPacket(index);
final Iterable<SocketAddress> destinations = request.getDestinations(index);
if (DBG) { if (DBG) {
Log.v(getTag(), "Sending packets to " + destinations + " for iteration " Log.v(getTag(), "Sending packets for iteration " + index + " out of "
+ index + " out of " + request.getNumSends()); + request.getNumSends());
} }
for (SocketAddress destination : destinations) { // Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
// send when the socket has not joined the relevant group.
for (InetSocketAddress destination : ALL_ADDRS) {
try { try {
mReplySender.sendNow(packet, destination); mReplySender.sendNow(packet, destination);
} catch (IOException e) { } catch (IOException e) {

View File

@@ -22,12 +22,10 @@ import android.os.Looper;
import com.android.internal.annotations.VisibleForTesting; import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.CollectionUtils; import com.android.net.module.util.CollectionUtils;
import java.net.SocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
/** /**
* Sends mDns probe requests to verify service records are unique on the network. * Sends mDns probe requests to verify service records are unique on the network.
@@ -51,21 +49,15 @@ public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
private final int mServiceId; private final int mServiceId;
@NonNull @NonNull
private final MdnsPacket mPacket; private final MdnsPacket mPacket;
@NonNull
private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
/** /**
* Create a new ProbingInfo * Create a new ProbingInfo
* @param serviceId Service to probe for. * @param serviceId Service to probe for.
* @param probeRecords Records to be probed for uniqueness. * @param probeRecords Records to be probed for uniqueness.
* @param destinationsSupplier Supplier for the probe destinations. Will be called on the
* probe handler thread for each probe.
*/ */
ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords, ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords) {
@NonNull Supplier<Iterable<SocketAddress>> destinationsSupplier) {
mServiceId = serviceId; mServiceId = serviceId;
mPacket = makePacket(probeRecords); mPacket = makePacket(probeRecords);
mDestinationsSupplier = destinationsSupplier;
} }
public int getServiceId() { public int getServiceId() {
@@ -78,12 +70,6 @@ public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
return mPacket; return mPacket;
} }
@NonNull
@Override
public Iterable<SocketAddress> getDestinations(int index) {
return mDestinationsSupplier.get();
}
@Override @Override
public long getDelayMs(int nextIndex) { public long getDelayMs(int nextIndex) {
// As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 // As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1

View File

@@ -21,8 +21,10 @@ import android.os.Looper;
import java.io.IOException; import java.io.IOException;
import java.net.DatagramPacket; import java.net.DatagramPacket;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetSocketAddress;
import java.net.MulticastSocket; import java.net.MulticastSocket;
import java.net.SocketAddress;
/** /**
* A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -50,11 +52,16 @@ public class MdnsReplySender {
* *
* Must be called on the looper thread used by the {@link MdnsReplySender}. * Must be called on the looper thread used by the {@link MdnsReplySender}.
*/ */
public void sendNow(@NonNull MdnsPacket packet, @NonNull SocketAddress destination) public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
throws IOException { throws IOException {
if (Thread.currentThread() != mLooper.getThread()) { if (Thread.currentThread() != mLooper.getThread()) {
throw new IllegalStateException("sendNow must be called in the handler thread"); throw new IllegalStateException("sendNow must be called in the handler thread");
} }
if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
|| (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
// Skip sending if the socket has not joined the v4/v6 group (there was no address)
return;
}
// TODO: support packets over size (send in multiple packets with TC bit set) // TODO: support packets over size (send in multiple packets with TC bit set)
final MdnsPacketWriter writer = new MdnsPacketWriter(mPacketCreationBuffer); final MdnsPacketWriter writer = new MdnsPacketWriter(mPacketCreationBuffer);

View File

@@ -27,7 +27,6 @@ import com.android.testutils.DevSdkIgnoreRunner
import java.net.DatagramPacket import java.net.DatagramPacket
import java.net.Inet6Address import java.net.Inet6Address
import java.net.InetAddress import java.net.InetAddress
import java.net.InetSocketAddress
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
import org.junit.After import org.junit.After
@@ -37,6 +36,7 @@ import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any import org.mockito.Mockito.any
import org.mockito.Mockito.atLeast import org.mockito.Mockito.atLeast
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock import org.mockito.Mockito.mock
import org.mockito.Mockito.timeout import org.mockito.Mockito.timeout
import org.mockito.Mockito.verify import org.mockito.Mockito.verify
@@ -46,9 +46,6 @@ private const val FIRST_ANNOUNCES_COUNT = 2
private const val NEXT_ANNOUNCES_DELAY = 1L private const val NEXT_ANNOUNCES_DELAY = 1L
private const val TEST_TIMEOUT_MS = 1000L private const val TEST_TIMEOUT_MS = 1000L
private val destinationsSupplier = {
listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
@RunWith(DevSdkIgnoreRunner::class) @RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2) @IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsAnnouncerTest { class MdnsAnnouncerTest {
@@ -59,6 +56,7 @@ class MdnsAnnouncerTest {
@Before @Before
fun setUp() { fun setUp() {
doReturn(true).`when`(socket).hasJoinedIpv6()
thread.start() thread.start()
} }
@@ -70,7 +68,7 @@ class MdnsAnnouncerTest {
private class TestAnnouncementInfo( private class TestAnnouncementInfo(
announcedRecords: List<MdnsRecord>, announcedRecords: List<MdnsRecord>,
additionalRecords: List<MdnsRecord> additionalRecords: List<MdnsRecord>
) : AnnouncementInfo(announcedRecords, additionalRecords, destinationsSupplier) { ) : AnnouncementInfo(announcedRecords, additionalRecords) {
override fun getDelayMs(nextIndex: Int) = override fun getDelayMs(nextIndex: Int) =
if (nextIndex < FIRST_ANNOUNCES_COUNT) { if (nextIndex < FIRST_ANNOUNCES_COUNT) {
FIRST_ANNOUNCES_DELAY FIRST_ANNOUNCES_DELAY

View File

@@ -25,7 +25,6 @@ import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner import com.android.testutils.DevSdkIgnoreRunner
import java.net.DatagramPacket import java.net.DatagramPacket
import java.net.InetSocketAddress
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals import kotlin.test.assertEquals
@@ -37,15 +36,13 @@ import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any import org.mockito.Mockito.any
import org.mockito.Mockito.atLeast import org.mockito.Mockito.atLeast
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock import org.mockito.Mockito.mock
import org.mockito.Mockito.never import org.mockito.Mockito.never
import org.mockito.Mockito.timeout import org.mockito.Mockito.timeout
import org.mockito.Mockito.times import org.mockito.Mockito.times
import org.mockito.Mockito.verify import org.mockito.Mockito.verify
private val destinationsSupplier = {
listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
private const val TEST_TIMEOUT_MS = 10_000L private const val TEST_TIMEOUT_MS = 10_000L
private const val SHORT_TIMEOUT_MS = 200L private const val SHORT_TIMEOUT_MS = 200L
@@ -64,6 +61,7 @@ class MdnsProberTest {
@Before @Before
fun setUp() { fun setUp() {
doReturn(true).`when`(socket).hasJoinedIpv6()
thread.start() thread.start()
} }
@@ -73,7 +71,7 @@ class MdnsProberTest {
} }
private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) : private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) :
ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) { ProbingInfo(1 /* serviceId */, probeRecords) {
// Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already // Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already
// done in MdnsAnnouncerTest. // done in MdnsAnnouncerTest.
override fun getDelayMs(nextIndex: Int) = delayMs override fun getDelayMs(nextIndex: Int) = delayMs