From 46676497eb8a530078ad7f7e76ee736ddaa4adfd Mon Sep 17 00:00:00 2001 From: Remi NGUYEN VAN Date: Thu, 12 Jan 2023 23:12:05 +0900 Subject: [PATCH 1/2] Add replying to queries MdnsInterfaceAdvertiser registers to receive incoming packets, and sends replies to queries as built by MdnsRecordRepository. Bug: 241738458 Test: atest Change-Id: I13db22f8efc870b6e0747d105f6bc8f759910f81 --- .../android/server/mdns/MdnsConstants.java | 1 + .../server/mdns/MdnsInterfaceAdvertiser.java | 45 +++- .../server/mdns/MdnsInterfaceSocket.java | 8 + .../server/mdns/MdnsPacketRepeater.java | 9 +- .../com/android/server/mdns/MdnsRecord.java | 3 +- .../server/mdns/MdnsRecordRepository.java | 224 ++++++++++++++++++ .../android/server/mdns/MdnsReplySender.java | 69 +++++- .../server/mdns/MulticastPacketReader.java | 9 + .../connectivity/mdns/MdnsAnnouncerTest.kt | 2 +- .../mdns/MdnsInterfaceAdvertiserTest.kt | 52 +++- .../connectivity/mdns/MdnsProberTest.kt | 6 +- .../mdns/MdnsRecordRepositoryTest.kt | 113 ++++++++- 12 files changed, 502 insertions(+), 39 deletions(-) diff --git a/service-t/src/com/android/server/mdns/MdnsConstants.java b/service-t/src/com/android/server/mdns/MdnsConstants.java index 396be5f065..f0e1717df5 100644 --- a/service-t/src/com/android/server/mdns/MdnsConstants.java +++ b/service-t/src/com/android/server/mdns/MdnsConstants.java @@ -37,6 +37,7 @@ public final class MdnsConstants { public static final int FLAGS_QUERY = 0x0000; public static final int FLAGS_RESPONSE_MASK = 0xF80F; public static final int FLAGS_RESPONSE = 0x8000; + public static final int FLAG_TRUNCATED = 0x0200; public static final int QCLASS_INTERNET = 0x0001; public static final int QCLASS_UNICAST = 0x8000; public static final String SUBTYPE_LABEL = "_sub"; diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java index 790e69a860..a14b5ad8b3 100644 --- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java +++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java @@ -25,16 +25,18 @@ import android.os.Looper; import android.util.Log; import com.android.internal.annotations.VisibleForTesting; +import com.android.net.module.util.HexDump; import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo; import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback; import java.io.IOException; +import java.net.InetSocketAddress; import java.util.List; /** * A class that handles advertising services on a {@link MdnsInterfaceSocket} tied to an interface. */ -public class MdnsInterfaceAdvertiser { +public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler { private static final boolean DBG = MdnsAdvertiser.DBG; @VisibleForTesting public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L; @@ -145,9 +147,9 @@ public class MdnsInterfaceAdvertiser { /** @see MdnsReplySender */ @NonNull - public MdnsReplySender makeReplySender(@NonNull Looper looper, + public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper, @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) { - return new MdnsReplySender(looper, socket, packetCreationBuffer); + return new MdnsReplySender(interfaceTag, looper, socket, packetCreationBuffer); } /** @see MdnsAnnouncer */ @@ -182,7 +184,7 @@ public class MdnsInterfaceAdvertiser { mSocket = socket; mCb = cb; mCbHandler = new Handler(looper); - mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer); + mReplySender = deps.makeReplySender(logTag, looper, socket, packetCreationBuffer); mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender, mAnnouncingCallback); mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback); @@ -196,7 +198,7 @@ public class MdnsInterfaceAdvertiser { * {@link #destroyNow()}. */ public void start() { - // TODO: start receiving packets + mSocket.addPacketHandler(this); } /** @@ -267,8 +269,8 @@ public class MdnsInterfaceAdvertiser { mProber.stop(serviceId); mAnnouncer.stop(serviceId); } - - // TODO: stop receiving packets + mReplySender.cancelAll(); + mSocket.removePacketHandler(this); mCbHandler.post(() -> mCb.onDestroyed(mSocket)); } @@ -294,4 +296,33 @@ public class MdnsInterfaceAdvertiser { public boolean isProbing(int serviceId) { return mRecordRepository.isProbing(serviceId); } + + @Override + public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) { + final MdnsPacket packet; + try { + packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length)); + } catch (MdnsPacket.ParseException e) { + Log.e(mTag, "Error parsing mDNS packet", e); + if (DBG) { + Log.v( + mTag, "Packet: " + HexDump.toHexString(recvbuf, 0, length)); + } + return; + } + + if (DBG) { + Log.v(mTag, + "Parsed packet with " + packet.questions.size() + " questions, " + + packet.answers.size() + " answers, " + + packet.authorityRecords.size() + " authority, " + + packet.additionalRecords.size() + " additional from " + src); + } + + final MdnsRecordRepository.ReplyInfo answers = + mRecordRepository.getReply(packet, src); + + if (answers == null) return; + mReplySender.queueReply(answers); + } } diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java index d1290b6ea9..119c7a8ee4 100644 --- a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java +++ b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java @@ -161,6 +161,14 @@ public class MdnsInterfaceSocket { mPacketReader.addPacketHandler(handler); } + /** + * Remove a handler added via {@link #addPacketHandler}. If the handler is not present, this is + * a no-op. + */ + public void removePacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) { + mPacketReader.removePacketHandler(handler); + } + /** * Returns the network interface that this socket is bound to. * diff --git a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java index ae54e702da..4c385da436 100644 --- a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java +++ b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java @@ -16,6 +16,9 @@ package com.android.server.connectivity.mdns; +import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV4_ADDR; +import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV6_ADDR; + import android.annotation.NonNull; import android.annotation.Nullable; import android.os.Handler; @@ -32,10 +35,6 @@ import java.net.InetSocketAddress; */ public abstract class MdnsPacketRepeater { private static final boolean DBG = MdnsAdvertiser.DBG; - private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress( - MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT); - private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress( - MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT); private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] { IPV4_ADDR, IPV6_ADDR }; @@ -114,7 +113,7 @@ public abstract class MdnsPacketRepeater { final MdnsPacket packet = request.getPacket(index); if (DBG) { Log.v(getTag(), "Sending packets for iteration " + index + " out of " - + request.getNumSends()); + + request.getNumSends() + " for ID " + msg.what); } // Send to both v4 and v6 addresses; the reply sender will take care of ignoring the // send when the socket has not joined the relevant group. diff --git a/service-t/src/com/android/server/mdns/MdnsRecord.java b/service-t/src/com/android/server/mdns/MdnsRecord.java index 00871ea2d4..bcee9d1514 100644 --- a/service-t/src/com/android/server/mdns/MdnsRecord.java +++ b/service-t/src/com/android/server/mdns/MdnsRecord.java @@ -45,6 +45,7 @@ public abstract class MdnsRecord { private static final int FLAG_CACHE_FLUSH = 0x8000; public static final long RECEIPT_TIME_NOT_SENT = 0L; + public static final int CLASS_ANY = 0x00ff; /** Status indicating that the record is current. */ public static final int STATUS_OK = 0; @@ -317,4 +318,4 @@ public abstract class MdnsRecord { return (recordType * 31) + Arrays.hashCode(recordName); } } -} \ No newline at end of file +} diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java index dd00212186..4b2f553acc 100644 --- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java @@ -34,6 +34,7 @@ import com.android.net.module.util.HexDump; import java.io.IOException; import java.net.Inet4Address; import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.NetworkInterface; import java.util.ArrayList; import java.util.Arrays; @@ -42,6 +43,7 @@ import java.util.Enumeration; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.Set; import java.util.TreeMap; import java.util.UUID; @@ -54,6 +56,9 @@ import java.util.concurrent.TimeUnit; */ @TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+ public class MdnsRecordRepository { + // RFC6762 p.15 + private static final long MIN_MULTICAST_REPLY_INTERVAL_MS = 1_000L; + // TTLs as per RFC6762 10. // TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a // host name contained within the resource record's rdata (e.g., SRV, reverse mapping PTR @@ -69,6 +74,13 @@ public class MdnsRecordRepository { private static final String[] DNS_SD_SERVICE_TYPE = new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD }; + public static final InetSocketAddress IPV6_ADDR = new InetSocketAddress( + MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT); + public static final InetSocketAddress IPV4_ADDR = new InetSocketAddress( + MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT); + + @NonNull + private final Random mDelayGenerator = new Random(); // Map of service unique ID -> records for service @NonNull private final SparseArray mServices = new SparseArray<>(); @@ -138,6 +150,11 @@ public class MdnsRecordRepository { */ public boolean isProbing; + /** + * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never + */ + public long lastAdvertisedTimeMs; + /** * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast, * 0 if never @@ -390,6 +407,212 @@ public class MdnsRecordRepository { return ret; } + /** + * Info about a reply to be sent. + */ + public static class ReplyInfo { + @NonNull + public final List answers; + @NonNull + public final List additionalAnswers; + public final long sendDelayMs; + @NonNull + public final InetSocketAddress destination; + + public ReplyInfo( + @NonNull List answers, + @NonNull List additionalAnswers, + long sendDelayMs, + @NonNull InetSocketAddress destination) { + this.answers = answers; + this.additionalAnswers = additionalAnswers; + this.sendDelayMs = sendDelayMs; + this.destination = destination; + } + + @Override + public String toString() { + return "{ReplyInfo to " + destination + ", answers: " + answers.size() + + ", additionalAnswers: " + additionalAnswers.size() + + ", sendDelayMs " + sendDelayMs + "}"; + } + } + + /** + * Get the reply to send to an incoming packet. + * + * @param packet The incoming packet. + * @param src The source address of the incoming packet. + */ + @Nullable + public ReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) { + final long now = SystemClock.elapsedRealtime(); + final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0; + final ArrayList additionalAnswerRecords = new ArrayList<>(); + final ArrayList> answerInfo = new ArrayList<>(); + for (MdnsRecord question : packet.questions) { + // Add answers from general records + addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */, + null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now, + answerInfo, additionalAnswerRecords); + + // Add answers from each service + for (int i = 0; i < mServices.size(); i++) { + final ServiceRegistration registration = mServices.valueAt(i); + if (registration.exiting) continue; + addReplyFromService(question, registration.allRecords, registration.ptrRecord, + registration.srvRecord, registration.txtRecord, replyUnicast, now, + answerInfo, additionalAnswerRecords); + } + } + + if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) { + return null; + } + + // Determine the send delay + final long delayMs; + if ((packet.flags & MdnsConstants.FLAG_TRUNCATED) != 0) { + // RFC 6762 6.: 400-500ms delay if TC bit is set + delayMs = 400L + mDelayGenerator.nextInt(100); + } else if (packet.questions.size() > 1 + || CollectionUtils.any(answerInfo, a -> a.isSharedName)) { + // 20-120ms if there may be responses from other hosts (not a fully owned + // name) (RFC 6762 6.), or if there are multiple questions (6.3). + // TODO: this should be 0 if this is a probe query ("can be distinguished from a + // normal query by the fact that a probe query contains a proposed record in the + // Authority Section that answers the question" in 6.), and the reply is for a fully + // owned record. + delayMs = 20L + mDelayGenerator.nextInt(100); + } else { + delayMs = 0L; + } + + // Determine the send destination + final InetSocketAddress dest; + if (replyUnicast) { + dest = src; + } else if (src.getAddress() instanceof Inet4Address) { + dest = IPV4_ADDR; + } else { + dest = IPV6_ADDR; + } + + // Build the list of answer records from their RecordInfo + final ArrayList answerRecords = new ArrayList<>(answerInfo.size()); + for (RecordInfo info : answerInfo) { + // TODO: consider actual packet send delay after response aggregation + info.lastSentTimeMs = now + delayMs; + if (!replyUnicast) { + info.lastAdvertisedTimeMs = info.lastSentTimeMs; + } + answerRecords.add(info.record); + } + + return new ReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest); + } + + /** + * Add answers and additional answers for a question, from a ServiceRegistration. + */ + private void addReplyFromService(@NonNull MdnsRecord question, + @NonNull List> serviceRecords, + @Nullable RecordInfo servicePtrRecord, + @Nullable RecordInfo serviceSrvRecord, + @Nullable RecordInfo serviceTxtRecord, + boolean replyUnicast, long now, @NonNull List> answerInfo, + @NonNull List additionalAnswerRecords) { + boolean hasDnsSdPtrRecordAnswer = false; + boolean hasDnsSdSrvRecordAnswer = false; + boolean hasFullyOwnedNameMatch = false; + boolean hasKnownAnswer = false; + + final int answersStartIndex = answerInfo.size(); + for (RecordInfo info : serviceRecords) { + if (info.isProbing) continue; + + /* RFC6762 6.: the record name must match the question name, the record rrtype + must match the question qtype unless the qtype is "ANY" (255) or the rrtype is + "CNAME" (5), and the record rrclass must match the question qclass unless the + qclass is "ANY" (255) */ + if (!Arrays.equals(info.record.getName(), question.getName())) continue; + hasFullyOwnedNameMatch |= !info.isSharedName; + + // The repository does not store CNAME records + if (question.getType() != MdnsRecord.TYPE_ANY + && question.getType() != info.record.getType()) { + continue; + } + if (question.getRecordClass() != MdnsRecord.CLASS_ANY + && question.getRecordClass() != info.record.getRecordClass()) { + continue; + } + + hasKnownAnswer = true; + hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord); + hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord); + + // TODO: responses to probe queries should bypass this check and only ensure the + // reply is sent 250ms after the last sent time (RFC 6762 p.15) + if (!replyUnicast && info.lastAdvertisedTimeMs > 0L + && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) { + continue; + } + + // TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half + + answerInfo.add(info); + } + + // RFC6762 6.1: + // "Any time a responder receives a query for a name for which it has verified exclusive + // ownership, for a type for which that name has no records, the responder MUST [...] + // respond asserting the nonexistence of that record" + if (hasFullyOwnedNameMatch && !hasKnownAnswer) { + additionalAnswerRecords.add(new MdnsNsecRecord( + question.getName(), + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + // TODO: RFC6762 6.1: "In general, the TTL given for an NSEC record SHOULD + // be the same as the TTL that the record would have had, had it existed." + NAME_RECORDS_TTL_MILLIS, + question.getName(), + new int[] { question.getType() })); + } + + // No more records to add if no answer + if (answerInfo.size() == answersStartIndex) return; + + final List> additionalAnswerInfo = new ArrayList<>(); + // RFC6763 12.1: if including PTR record, include the SRV and TXT records it names + if (hasDnsSdPtrRecordAnswer) { + if (serviceTxtRecord != null) { + additionalAnswerInfo.add(serviceTxtRecord); + } + if (serviceSrvRecord != null) { + additionalAnswerInfo.add(serviceSrvRecord); + } + } + + // RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names + if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) { + for (RecordInfo record : mGeneralRecords) { + if (record.record instanceof MdnsInetAddressRecord) { + additionalAnswerInfo.add(record); + } + } + } + + for (RecordInfo info : additionalAnswerInfo) { + additionalAnswerRecords.add(info.record); + } + + // RFC6762 6.1: negative responses + addNsecRecordsForUniqueNames(additionalAnswerRecords, + answerInfo.listIterator(answersStartIndex), + additionalAnswerInfo.listIterator()); + } + /** * Add NSEC records indicating that the response records are unique. * @@ -540,6 +763,7 @@ public class MdnsRecordRepository { final long now = SystemClock.elapsedRealtime(); for (RecordInfo record : registration.allRecords) { record.lastSentTimeMs = now; + record.lastAdvertisedTimeMs = now; } } diff --git a/service-t/src/com/android/server/mdns/MdnsReplySender.java b/service-t/src/com/android/server/mdns/MdnsReplySender.java index c6b8f47be6..f1389cab6c 100644 --- a/service-t/src/com/android/server/mdns/MdnsReplySender.java +++ b/service-t/src/com/android/server/mdns/MdnsReplySender.java @@ -16,8 +16,15 @@ package com.android.server.connectivity.mdns; +import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread; + import android.annotation.NonNull; +import android.os.Handler; import android.os.Looper; +import android.os.Message; +import android.util.Log; + +import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo; import java.io.IOException; import java.net.DatagramPacket; @@ -25,6 +32,7 @@ import java.net.Inet4Address; import java.net.Inet6Address; import java.net.InetSocketAddress; import java.net.MulticastSocket; +import java.util.Collections; /** * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them @@ -33,20 +41,38 @@ import java.net.MulticastSocket; * TODO: implement sending after a delay, combining queued replies and duplicate answer suppression */ public class MdnsReplySender { + private static final boolean DBG = MdnsAdvertiser.DBG; + private static final int MSG_SEND = 1; + + private final String mLogTag; @NonNull private final MdnsInterfaceSocket mSocket; @NonNull - private final Looper mLooper; + private final Handler mHandler; @NonNull private final byte[] mPacketCreationBuffer; - public MdnsReplySender(@NonNull Looper looper, + public MdnsReplySender(@NonNull String interfaceTag, @NonNull Looper looper, @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) { - mLooper = looper; + mHandler = new SendHandler(looper); + mLogTag = MdnsReplySender.class.getSimpleName() + "/" + interfaceTag; mSocket = socket; mPacketCreationBuffer = packetCreationBuffer; } + /** + * Queue a reply to be sent when its send delay expires. + */ + public void queueReply(@NonNull ReplyInfo reply) { + ensureRunningOnHandlerThread(mHandler); + // TODO: implement response aggregation (RFC 6762 6.4) + mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs); + + if (DBG) { + Log.v(mLogTag, "Scheduling " + reply); + } + } + /** * Send a packet immediately. * @@ -54,9 +80,7 @@ public class MdnsReplySender { */ public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination) throws IOException { - if (Thread.currentThread() != mLooper.getThread()) { - throw new IllegalStateException("sendNow must be called in the handler thread"); - } + ensureRunningOnHandlerThread(mHandler); if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6()) || (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) { // Skip sending if the socket has not joined the v4/v6 group (there was no address) @@ -93,4 +117,37 @@ public class MdnsReplySender { mSocket.send(new DatagramPacket(outBuffer, 0, len, destination)); } + + /** + * Cancel all pending sends. + */ + public void cancelAll() { + ensureRunningOnHandlerThread(mHandler); + mHandler.removeMessages(MSG_SEND); + } + + private class SendHandler extends Handler { + SendHandler(@NonNull Looper looper) { + super(looper); + } + + @Override + public void handleMessage(@NonNull Message msg) { + final ReplyInfo replyInfo = (ReplyInfo) msg.obj; + if (DBG) Log.v(mLogTag, "Sending " + replyInfo); + + final int flags = 0x8400; // Response, authoritative (rfc6762 18.4) + final MdnsPacket packet = new MdnsPacket(flags, + Collections.emptyList() /* questions */, + replyInfo.answers, + Collections.emptyList() /* authorityRecords */, + replyInfo.additionalAnswers); + + try { + sendNow(packet, replyInfo.destination); + } catch (IOException e) { + Log.e(mLogTag, "Error sending MDNS response", e); + } + } + } } diff --git a/service-t/src/com/android/server/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/mdns/MulticastPacketReader.java index 20cc47f7bd..b597f0a831 100644 --- a/service-t/src/com/android/server/mdns/MulticastPacketReader.java +++ b/service-t/src/com/android/server/mdns/MulticastPacketReader.java @@ -107,5 +107,14 @@ public class MulticastPacketReader extends FdEventsReader diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt index 2cb0850c46..02b39767c4 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt @@ -21,6 +21,7 @@ import android.net.LinkAddress import android.net.nsd.NsdServiceInfo import android.os.Build import android.os.HandlerThread +import com.android.net.module.util.HexDump import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo @@ -30,6 +31,10 @@ import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRunner import com.android.testutils.waitForIdle +import java.net.InetSocketAddress +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertTrue import org.junit.After import org.junit.Before import org.junit.Test @@ -37,8 +42,10 @@ import org.junit.runner.RunWith import org.mockito.ArgumentCaptor import org.mockito.Mockito.any import org.mockito.Mockito.anyInt +import org.mockito.Mockito.anyString import org.mockito.Mockito.doAnswer import org.mockito.Mockito.doReturn +import org.mockito.Mockito.eq import org.mockito.Mockito.mock import org.mockito.Mockito.times import org.mockito.Mockito.verify @@ -67,13 +74,18 @@ class MdnsInterfaceAdvertiserTest { private val replySender = mock(MdnsReplySender::class.java) private val announcer = mock(MdnsAnnouncer::class.java) private val prober = mock(MdnsProber::class.java) + @Suppress("UNCHECKED_CAST") private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java) as ArgumentCaptor> + @Suppress("UNCHECKED_CAST") private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java) as ArgumentCaptor> + private val packetHandlerCaptor = ArgumentCaptor.forClass( + MulticastPacketReader.PacketHandler::class.java) private val probeCb get() = probeCbCaptor.value private val announceCb get() = announceCbCaptor.value + private val packetHandler get() = packetHandlerCaptor.value private val advertiser by lazy { MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps) @@ -82,9 +94,9 @@ class MdnsInterfaceAdvertiserTest { @Before fun setUp() { doReturn(repository).`when`(deps).makeRecordRepository(any()) - doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any()) - doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any()) - doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any()) + doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any()) + doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any()) + doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any()) val knownServices = mutableSetOf() doAnswer { inv -> @@ -104,6 +116,7 @@ class MdnsInterfaceAdvertiserTest { thread.start() advertiser.start() + verify(socket).addPacketHandler(packetHandlerCaptor.capture()) verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture()) verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture()) } @@ -157,6 +170,39 @@ class MdnsInterfaceAdvertiserTest { verify(announcer, times(1)).stop(TEST_SERVICE_ID_1) } + @Test + fun testReplyToQuery() { + addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1) + + val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java) + doReturn(mockReply).`when`(repository).getReply(any(), any()) + + // Query obtained with: + // scapy.raw(scapy.DNS( + // qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local')) + // ).hex().upper() + val query = HexDump.hexStringToByteArray( + "0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001" + ) + val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT) + packetHandler.handlePacket(query, query.size, src) + + val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java) + verify(repository).getReply(packetCaptor.capture(), eq(src)) + + packetCaptor.value.let { + assertEquals(1, it.questions.size) + assertEquals(0, it.answers.size) + assertEquals(0, it.authorityRecords.size) + assertEquals(0, it.additionalRecords.size) + + assertTrue(it.questions[0] is MdnsPointerRecord) + assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name) + } + + verify(replySender).queueReply(mockReply) + } + private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo): AnnouncementInfo { val testProbingInfo = mock(ProbingInfo::class.java) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt index 3caa97dbca..a2dbbc65ff 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt @@ -114,7 +114,7 @@ class MdnsProberTest { @Test fun testProbe() { - val replySender = MdnsReplySender(thread.looper, socket, buffer) + val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer) val prober = TestProber(thread.looper, replySender, cb) val probeInfo = TestProbeInfo( listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890))) @@ -129,7 +129,7 @@ class MdnsProberTest { @Test fun testProbeMultipleRecords() { - val replySender = MdnsReplySender(thread.looper, socket, buffer) + val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer) val prober = TestProber(thread.looper, replySender, cb) val probeInfo = TestProbeInfo(listOf( makeServiceRecord(TEST_SERVICE_NAME_1, 37890), @@ -167,7 +167,7 @@ class MdnsProberTest { @Test fun testStopProbing() { - val replySender = MdnsReplySender(thread.looper, socket, buffer) + val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer) val prober = TestProber(thread.looper, replySender, cb) val probeInfo = TestProbeInfo( listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)), diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index 29d0854e22..597663c6b5 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -21,10 +21,12 @@ import android.net.LinkAddress import android.net.nsd.NsdServiceInfo import android.os.Build import android.os.HandlerThread +import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRunner +import java.net.InetSocketAddress import java.net.NetworkInterface import java.util.Collections import kotlin.test.assertContentEquals @@ -150,11 +152,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitAnnouncements() { val repository = MdnsRecordRepository(thread.looper, deps) - repository.updateAddresses(TEST_ADDRESSES) - - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) - val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) - repository.onProbingSucceeded(probingInfo) + repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1) val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1) @@ -183,9 +181,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitingServiceReAdded() { val repository = MdnsRecordRepository(thread.looper, deps) - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) - val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) - repository.onProbingSucceeded(probingInfo) + repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1) repository.exitService(TEST_SERVICE_ID_1) @@ -199,11 +195,8 @@ class MdnsRecordRepositoryTest { @Test fun testOnProbingSucceeded() { val repository = MdnsRecordRepository(thread.looper, deps) - repository.updateAddresses(TEST_ADDRESSES) - - repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1) - val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) - val announcementInfo = repository.onProbingSucceeded(probingInfo) + val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + repository.onAdvertisementSent(TEST_SERVICE_ID_1) val packet = announcementInfo.getPacket(0) assertEquals(0x8400 /* response, authoritative */, packet.flags) @@ -322,4 +315,98 @@ class MdnsRecordRepositoryTest { val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray() assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123"))) } + + @Test + fun testGetReply() { + val repository = MdnsRecordRepository(thread.looper, deps) + repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + // TTL and data is empty for a question + 0L /* ttlMillis */, + null /* pointer */)) + val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */, + listOf() /* authorityRecords */, listOf() /* additionalRecords */) + val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353) + val reply = repository.getReply(query, src) + + assertNotNull(reply) + // Source address is IPv4 + assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address) + assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port) + + // TTLs as per RFC6762 10. + val longTtl = 4_500_000L + val shortTtl = 120_000L + val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + + assertEquals(listOf( + MdnsPointerRecord( + arrayOf("_testservice", "_tcp", "local"), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + longTtl, + serviceName), + ), reply.answers) + + assertEquals(listOf( + MdnsTextRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + longTtl, + listOf() /* entries */), + MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + shortTtl, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_HOSTNAME), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + shortTtl, + TEST_ADDRESSES[0].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + shortTtl, + TEST_ADDRESSES[1].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + shortTtl, + TEST_ADDRESSES[2].address), + MdnsNsecRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + longTtl, + serviceName /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)), + MdnsNsecRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + shortTtl, + TEST_HOSTNAME /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)), + ), reply.additionalAnswers) + } +} + +private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo): + AnnouncementInfo { + updateAddresses(TEST_ADDRESSES) + addService(serviceId, serviceInfo) + val probingInfo = setServiceProbing(serviceId) + assertNotNull(probingInfo) + return onProbingSucceeded(probingInfo) } From b1b7fab156dc7d8bf3046d85f7a63003d5bb8ac0 Mon Sep 17 00:00:00 2001 From: Remi NGUYEN VAN Date: Fri, 13 Jan 2023 20:46:58 +0900 Subject: [PATCH 2/2] Implement onServiceConflict Implement the onServiceConflict callback in MdnsAdvertiser, refactoring the conflict detection to reuse it both in onServiceConflict (when a conflict is detected on the network after add) and at service add time. Bug: 241738458 Test: atest MdnsAdvertiserTest Change-Id: I69128db936296bd2c5e90e9f00df19fd881e1748 --- .../android/server/mdns/MdnsAdvertiser.java | 193 ++++++++++++------ .../connectivity/mdns/MdnsAdvertiserTest.kt | 55 +++++ 2 files changed, 180 insertions(+), 68 deletions(-) diff --git a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java index 4e40efe4ec..977478adba 100644 --- a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java +++ b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java @@ -29,10 +29,10 @@ import android.util.SparseArray; import com.android.internal.annotations.VisibleForTesting; -import java.io.IOException; import java.util.List; import java.util.Map; -import java.util.function.Predicate; +import java.util.function.BiPredicate; +import java.util.function.Consumer; /** * MdnsAdvertiser manages advertising services per {@link com.android.server.NsdService} requests. @@ -85,7 +85,7 @@ public class MdnsAdvertiser { public void onRegisterServiceSucceeded( @NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) { // Wait for all current interfaces to be done probing before notifying of success. - if (anyAdvertiser(a -> a.isProbing(serviceId))) return; + if (any(mAllAdvertisers, (k, a) -> a.isProbing(serviceId))) return; // The service may still be unregistered/renamed if a conflict is found on a later added // interface, or if a conflicting announcement/reply is detected (RFC6762 9.) @@ -102,7 +102,37 @@ public class MdnsAdvertiser { @Override public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) { - // TODO: handle conflicts found after registration (during or after probing) + if (DBG) { + Log.v(TAG, "Found conflict, restarted probing for service " + serviceId); + } + + final Registration registration = mRegistrations.get(serviceId); + if (registration == null) return; + if (registration.mNotifiedRegistrationSuccess) { + // TODO: consider notifying clients that the service is no longer registered with + // the old name (back to probing). The legacy implementation did not send any + // callback though; it only sent onServiceRegistered after re-probing finishes + // (with the old, conflicting, actually not used name as argument... The new + // implementation will send callbacks with the new name). + registration.mNotifiedRegistrationSuccess = false; + + // The service was done probing, just reset it to probing state (RFC6762 9.) + forAllAdvertisers(a -> a.restartProbingForConflict(serviceId)); + return; + } + + // Conflict was found during probing; rename once to find a name that has no conflict + registration.updateForConflict( + registration.makeNewServiceInfoForConflict(1 /* renameCount */), + 1 /* renameCount */); + + // Keep renaming if the new name conflicts in local registrations + updateRegistrationUntilNoConflict((net, adv) -> adv.hasRegistration(registration), + registration); + + // Update advertisers to use the new name + forAllAdvertisers(a -> a.renameServiceForConflict( + serviceId, registration.getServiceInfo())); } @Override @@ -116,6 +146,25 @@ public class MdnsAdvertiser { } }; + private boolean hasAnyConflict( + @NonNull BiPredicate applicableAdvertiserFilter, + @NonNull NsdServiceInfo newInfo) { + return any(mAdvertiserRequests, (network, adv) -> + applicableAdvertiserFilter.test(network, adv) && adv.hasConflict(newInfo)); + } + + private void updateRegistrationUntilNoConflict( + @NonNull BiPredicate applicableAdvertiserFilter, + @NonNull Registration registration) { + int renameCount = 0; + NsdServiceInfo newInfo = registration.getServiceInfo(); + while (hasAnyConflict(applicableAdvertiserFilter, newInfo)) { + renameCount++; + newInfo = registration.makeNewServiceInfoForConflict(renameCount); + } + registration.updateForConflict(newInfo, renameCount); + } + /** * A request for a {@link MdnsInterfaceAdvertiser}. * @@ -152,6 +201,21 @@ public class MdnsAdvertiser { return false; } + /** + * Return whether this {@link InterfaceAdvertiserRequest} has the given registration. + */ + boolean hasRegistration(@NonNull Registration registration) { + return mPendingRegistrations.indexOfValue(registration) >= 0; + } + + /** + * Return whether using the proposed new {@link NsdServiceInfo} to add a registration would + * cause a conflict in this {@link InterfaceAdvertiserRequest}. + */ + boolean hasConflict(@NonNull NsdServiceInfo newInfo) { + return getConflictingService(newInfo) >= 0; + } + /** * Get the ID of a conflicting service, or -1 if none. */ @@ -166,16 +230,19 @@ public class MdnsAdvertiser { return -1; } - void addService(int id, Registration registration) - throws NameConflictException { - final int conflicting = getConflictingService(registration.getServiceInfo()); - if (conflicting >= 0) { - throw new NameConflictException(conflicting); - } - + /** + * Add a service. + * + * Conflicts must be checked via {@link #getConflictingService} before attempting to add. + */ + void addService(int id, Registration registration) { mPendingRegistrations.put(id, registration); for (int i = 0; i < mAdvertisers.size(); i++) { - mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo()); + try { + mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo()); + } catch (NameConflictException e) { + Log.wtf(TAG, "Name conflict adding services that should have unique names", e); + } } } @@ -239,32 +306,42 @@ public class MdnsAdvertiser { /** * Update the registration to use a different service name, after a conflict was found. * + * @param newInfo New service info to use. + * @param renameCount How many renames were done before reaching the current name. + */ + private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) { + mConflictCount += renameCount; + mServiceInfo = newInfo; + } + + /** + * Make a new service name for the registration, after a conflict was found. + * * If a name conflict was found during probing or because different advertising requests * used the same name, the registration is attempted again with a new name (here using * a number suffix, (1), (2) etc). Registration success is notified once probing succeeds * with a new name. This matches legacy behavior based on mdnsresponder, and appendix D of * RFC6763. - * @return The new service info with the updated name. + * + * @param renameCount How much to increase the number suffix for this conflict. */ @NonNull - private NsdServiceInfo updateForConflict() { - mConflictCount++; + public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) { // In case of conflict choose a different service name. After the first conflict use // "Name (2)", then "Name (3)" etc. // TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t final NsdServiceInfo newInfo = new NsdServiceInfo(); - newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + 1) + ")"); + newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")"); newInfo.setServiceType(mServiceInfo.getServiceType()); for (Map.Entry attr : mServiceInfo.getAttributes().entrySet()) { - newInfo.setAttribute(attr.getKey(), attr.getValue()); + newInfo.setAttribute(attr.getKey(), + attr.getValue() == null ? null : new String(attr.getValue())); } newInfo.setHost(mServiceInfo.getHost()); newInfo.setPort(mServiceInfo.getPort()); newInfo.setNetwork(mServiceInfo.getNetwork()); // interfaceIndex is not set when registering - - mServiceInfo = newInfo; - return mServiceInfo; + return newInfo; } @NonNull @@ -338,55 +415,27 @@ public class MdnsAdvertiser { Log.i(TAG, "Adding service " + service + " with ID " + id); } - try { - final Registration registration = new Registration(service); - while (!tryAddRegistration(id, registration)) { - registration.updateForConflict(); - } - - mRegistrations.put(id, registration); - } catch (IOException e) { - Log.e(TAG, "Error adding service " + service, e); - removeService(id); - // TODO (b/264986328): add a more specific error code - mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR); - } - } - - private boolean tryAddRegistration(int id, @NonNull Registration registration) - throws IOException { - final NsdServiceInfo serviceInfo = registration.getServiceInfo(); - final Network network = serviceInfo.getNetwork(); - try { - InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network); - if (advertiser == null) { - advertiser = new InterfaceAdvertiserRequest(network); - mAdvertiserRequests.put(network, advertiser); - } - advertiser.addService(id, registration); - } catch (NameConflictException e) { - if (DBG) { - Log.i(TAG, "Service name conflicts: " + serviceInfo.getServiceName()); - } - removeService(id); - return false; + final Network network = service.getNetwork(); + final Registration registration = new Registration(service); + final BiPredicate checkConflictFilter; + if (network == null) { + // If registering on all networks, no advertiser must have conflicts + checkConflictFilter = (net, adv) -> true; + } else { + // If registering on one network, the matching network advertiser and the one for all + // networks must not have conflicts + checkConflictFilter = (net, adv) -> net == null || network.equals(net); } - // When adding a service to a specific network, check that it does not conflict with other - // registrations advertising on all networks - final InterfaceAdvertiserRequest allNetworksAdvertiser = mAdvertiserRequests.get(null); - if (network != null && allNetworksAdvertiser != null - && allNetworksAdvertiser.getConflictingService(serviceInfo) >= 0) { - if (DBG) { - Log.i(TAG, "Service conflicts with advertisement on all networks: " - + serviceInfo.getServiceName()); - } - removeService(id); - return false; - } + updateRegistrationUntilNoConflict(checkConflictFilter, registration); + InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network); + if (advertiser == null) { + advertiser = new InterfaceAdvertiserRequest(network); + mAdvertiserRequests.put(network, advertiser); + } + advertiser.addService(id, registration); mRegistrations.put(id, registration); - return true; } /** @@ -406,12 +455,20 @@ public class MdnsAdvertiser { mRegistrations.remove(id); } - private boolean anyAdvertiser(@NonNull Predicate predicate) { - for (int i = 0; i < mAllAdvertisers.size(); i++) { - if (predicate.test(mAllAdvertisers.valueAt(i))) { + private static boolean any(@NonNull ArrayMap map, + @NonNull BiPredicate predicate) { + for (int i = 0; i < map.size(); i++) { + if (predicate.test(map.keyAt(i), map.valueAt(i))) { return true; } } return false; } + + private void forAllAdvertisers(@NonNull Consumer consumer) { + any(mAllAdvertisers, (socket, advertiser) -> { + consumer.accept(advertiser); + return false; + }); + } } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt index e2babb175d..1febe6dd00 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt @@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers.eq import org.mockito.Mockito.any import org.mockito.Mockito.anyInt import org.mockito.Mockito.argThat +import org.mockito.Mockito.atLeastOnce import org.mockito.Mockito.doReturn import org.mockito.Mockito.mock import org.mockito.Mockito.never @@ -161,6 +162,60 @@ class MdnsAdvertiserTest { verify(socketProvider).unrequestSocket(socketCb) } + @Test + fun testAddService_Conflicts() { + val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps) + postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) } + + val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) + verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture()) + val oneNetSocketCb = oneNetSocketCbCaptor.value + + // Register a service with the same name on all networks (name conflict) + postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) } + val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) + verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) + val allNetSocketCb = allNetSocketCbCaptor.value + + // Callbacks for matching network and all networks both get the socket + postSync { + oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) + allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) + } + + val expectedRenamed = NsdServiceInfo( + "${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply { + port = ALL_NETWORKS_SERVICE.port + host = ALL_NETWORKS_SERVICE.host + network = ALL_NETWORKS_SERVICE.network + } + + val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) + verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), + eq(thread.looper), any(), intAdvCbCaptor.capture()) + verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1), + argThat { it.matches(SERVICE_1) }) + verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2), + argThat { it.matches(expectedRenamed) }) + + doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) + postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded( + mockInterfaceAdvertiser1, SERVICE_ID_1) } + verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) }) + + doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2) + postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded( + mockInterfaceAdvertiser1, SERVICE_ID_2) } + verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2), + argThat { it.matches(expectedRenamed) }) + + postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) } + postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) } + + // destroyNow can be called multiple times + verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow() + } + private fun postSync(r: () -> Unit) { handler.post(r) handler.waitForIdle(TIMEOUT_MS)