Merge changes I69128db9,I13db22f8

* changes:
  Implement onServiceConflict
  Add replying to queries
This commit is contained in:
Remi NGUYEN VAN
2023-01-18 01:13:35 +00:00
committed by Gerrit Code Review
14 changed files with 682 additions and 107 deletions

View File

@@ -29,10 +29,10 @@ import android.util.SparseArray;
import com.android.internal.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.function.BiPredicate;
import java.util.function.Consumer;
/**
* MdnsAdvertiser manages advertising services per {@link com.android.server.NsdService} requests.
@@ -85,7 +85,7 @@ public class MdnsAdvertiser {
public void onRegisterServiceSucceeded(
@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
// Wait for all current interfaces to be done probing before notifying of success.
if (anyAdvertiser(a -> a.isProbing(serviceId))) return;
if (any(mAllAdvertisers, (k, a) -> a.isProbing(serviceId))) return;
// The service may still be unregistered/renamed if a conflict is found on a later added
// interface, or if a conflicting announcement/reply is detected (RFC6762 9.)
@@ -102,7 +102,37 @@ public class MdnsAdvertiser {
@Override
public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
// TODO: handle conflicts found after registration (during or after probing)
if (DBG) {
Log.v(TAG, "Found conflict, restarted probing for service " + serviceId);
}
final Registration registration = mRegistrations.get(serviceId);
if (registration == null) return;
if (registration.mNotifiedRegistrationSuccess) {
// TODO: consider notifying clients that the service is no longer registered with
// the old name (back to probing). The legacy implementation did not send any
// callback though; it only sent onServiceRegistered after re-probing finishes
// (with the old, conflicting, actually not used name as argument... The new
// implementation will send callbacks with the new name).
registration.mNotifiedRegistrationSuccess = false;
// The service was done probing, just reset it to probing state (RFC6762 9.)
forAllAdvertisers(a -> a.restartProbingForConflict(serviceId));
return;
}
// Conflict was found during probing; rename once to find a name that has no conflict
registration.updateForConflict(
registration.makeNewServiceInfoForConflict(1 /* renameCount */),
1 /* renameCount */);
// Keep renaming if the new name conflicts in local registrations
updateRegistrationUntilNoConflict((net, adv) -> adv.hasRegistration(registration),
registration);
// Update advertisers to use the new name
forAllAdvertisers(a -> a.renameServiceForConflict(
serviceId, registration.getServiceInfo()));
}
@Override
@@ -116,6 +146,25 @@ public class MdnsAdvertiser {
}
};
private boolean hasAnyConflict(
@NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
@NonNull NsdServiceInfo newInfo) {
return any(mAdvertiserRequests, (network, adv) ->
applicableAdvertiserFilter.test(network, adv) && adv.hasConflict(newInfo));
}
private void updateRegistrationUntilNoConflict(
@NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
@NonNull Registration registration) {
int renameCount = 0;
NsdServiceInfo newInfo = registration.getServiceInfo();
while (hasAnyConflict(applicableAdvertiserFilter, newInfo)) {
renameCount++;
newInfo = registration.makeNewServiceInfoForConflict(renameCount);
}
registration.updateForConflict(newInfo, renameCount);
}
/**
* A request for a {@link MdnsInterfaceAdvertiser}.
*
@@ -152,6 +201,21 @@ public class MdnsAdvertiser {
return false;
}
/**
* Return whether this {@link InterfaceAdvertiserRequest} has the given registration.
*/
boolean hasRegistration(@NonNull Registration registration) {
return mPendingRegistrations.indexOfValue(registration) >= 0;
}
/**
* Return whether using the proposed new {@link NsdServiceInfo} to add a registration would
* cause a conflict in this {@link InterfaceAdvertiserRequest}.
*/
boolean hasConflict(@NonNull NsdServiceInfo newInfo) {
return getConflictingService(newInfo) >= 0;
}
/**
* Get the ID of a conflicting service, or -1 if none.
*/
@@ -166,16 +230,19 @@ public class MdnsAdvertiser {
return -1;
}
void addService(int id, Registration registration)
throws NameConflictException {
final int conflicting = getConflictingService(registration.getServiceInfo());
if (conflicting >= 0) {
throw new NameConflictException(conflicting);
}
/**
* Add a service.
*
* Conflicts must be checked via {@link #getConflictingService} before attempting to add.
*/
void addService(int id, Registration registration) {
mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) {
mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
try {
mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
} catch (NameConflictException e) {
Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
}
}
}
@@ -239,32 +306,42 @@ public class MdnsAdvertiser {
/**
* Update the registration to use a different service name, after a conflict was found.
*
* @param newInfo New service info to use.
* @param renameCount How many renames were done before reaching the current name.
*/
private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) {
mConflictCount += renameCount;
mServiceInfo = newInfo;
}
/**
* Make a new service name for the registration, after a conflict was found.
*
* If a name conflict was found during probing or because different advertising requests
* used the same name, the registration is attempted again with a new name (here using
* a number suffix, (1), (2) etc). Registration success is notified once probing succeeds
* with a new name. This matches legacy behavior based on mdnsresponder, and appendix D of
* RFC6763.
* @return The new service info with the updated name.
*
* @param renameCount How much to increase the number suffix for this conflict.
*/
@NonNull
private NsdServiceInfo updateForConflict() {
mConflictCount++;
public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) {
// In case of conflict choose a different service name. After the first conflict use
// "Name (2)", then "Name (3)" etc.
// TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t
final NsdServiceInfo newInfo = new NsdServiceInfo();
newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + 1) + ")");
newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")");
newInfo.setServiceType(mServiceInfo.getServiceType());
for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) {
newInfo.setAttribute(attr.getKey(), attr.getValue());
newInfo.setAttribute(attr.getKey(),
attr.getValue() == null ? null : new String(attr.getValue()));
}
newInfo.setHost(mServiceInfo.getHost());
newInfo.setPort(mServiceInfo.getPort());
newInfo.setNetwork(mServiceInfo.getNetwork());
// interfaceIndex is not set when registering
mServiceInfo = newInfo;
return mServiceInfo;
return newInfo;
}
@NonNull
@@ -338,55 +415,27 @@ public class MdnsAdvertiser {
Log.i(TAG, "Adding service " + service + " with ID " + id);
}
try {
final Registration registration = new Registration(service);
while (!tryAddRegistration(id, registration)) {
registration.updateForConflict();
}
mRegistrations.put(id, registration);
} catch (IOException e) {
Log.e(TAG, "Error adding service " + service, e);
removeService(id);
// TODO (b/264986328): add a more specific error code
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
}
}
private boolean tryAddRegistration(int id, @NonNull Registration registration)
throws IOException {
final NsdServiceInfo serviceInfo = registration.getServiceInfo();
final Network network = serviceInfo.getNetwork();
try {
InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
if (advertiser == null) {
advertiser = new InterfaceAdvertiserRequest(network);
mAdvertiserRequests.put(network, advertiser);
}
advertiser.addService(id, registration);
} catch (NameConflictException e) {
if (DBG) {
Log.i(TAG, "Service name conflicts: " + serviceInfo.getServiceName());
}
removeService(id);
return false;
final Network network = service.getNetwork();
final Registration registration = new Registration(service);
final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
if (network == null) {
// If registering on all networks, no advertiser must have conflicts
checkConflictFilter = (net, adv) -> true;
} else {
// If registering on one network, the matching network advertiser and the one for all
// networks must not have conflicts
checkConflictFilter = (net, adv) -> net == null || network.equals(net);
}
// When adding a service to a specific network, check that it does not conflict with other
// registrations advertising on all networks
final InterfaceAdvertiserRequest allNetworksAdvertiser = mAdvertiserRequests.get(null);
if (network != null && allNetworksAdvertiser != null
&& allNetworksAdvertiser.getConflictingService(serviceInfo) >= 0) {
if (DBG) {
Log.i(TAG, "Service conflicts with advertisement on all networks: "
+ serviceInfo.getServiceName());
}
removeService(id);
return false;
}
updateRegistrationUntilNoConflict(checkConflictFilter, registration);
InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
if (advertiser == null) {
advertiser = new InterfaceAdvertiserRequest(network);
mAdvertiserRequests.put(network, advertiser);
}
advertiser.addService(id, registration);
mRegistrations.put(id, registration);
return true;
}
/**
@@ -406,12 +455,20 @@ public class MdnsAdvertiser {
mRegistrations.remove(id);
}
private boolean anyAdvertiser(@NonNull Predicate<MdnsInterfaceAdvertiser> predicate) {
for (int i = 0; i < mAllAdvertisers.size(); i++) {
if (predicate.test(mAllAdvertisers.valueAt(i))) {
private static <K, V> boolean any(@NonNull ArrayMap<K, V> map,
@NonNull BiPredicate<K, V> predicate) {
for (int i = 0; i < map.size(); i++) {
if (predicate.test(map.keyAt(i), map.valueAt(i))) {
return true;
}
}
return false;
}
private void forAllAdvertisers(@NonNull Consumer<MdnsInterfaceAdvertiser> consumer) {
any(mAllAdvertisers, (socket, advertiser) -> {
consumer.accept(advertiser);
return false;
});
}
}

View File

@@ -37,6 +37,7 @@ public final class MdnsConstants {
public static final int FLAGS_QUERY = 0x0000;
public static final int FLAGS_RESPONSE_MASK = 0xF80F;
public static final int FLAGS_RESPONSE = 0x8000;
public static final int FLAG_TRUNCATED = 0x0200;
public static final int QCLASS_INTERNET = 0x0001;
public static final int QCLASS_UNICAST = 0x8000;
public static final String SUBTYPE_LABEL = "_sub";

View File

@@ -25,16 +25,18 @@ import android.os.Looper;
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.HexDump;
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.List;
/**
* A class that handles advertising services on a {@link MdnsInterfaceSocket} tied to an interface.
*/
public class MdnsInterfaceAdvertiser {
public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler {
private static final boolean DBG = MdnsAdvertiser.DBG;
@VisibleForTesting
public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
@@ -145,9 +147,9 @@ public class MdnsInterfaceAdvertiser {
/** @see MdnsReplySender */
@NonNull
public MdnsReplySender makeReplySender(@NonNull Looper looper,
public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
return new MdnsReplySender(looper, socket, packetCreationBuffer);
return new MdnsReplySender(interfaceTag, looper, socket, packetCreationBuffer);
}
/** @see MdnsAnnouncer */
@@ -182,7 +184,7 @@ public class MdnsInterfaceAdvertiser {
mSocket = socket;
mCb = cb;
mCbHandler = new Handler(looper);
mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
mReplySender = deps.makeReplySender(logTag, looper, socket, packetCreationBuffer);
mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
mAnnouncingCallback);
mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
@@ -196,7 +198,7 @@ public class MdnsInterfaceAdvertiser {
* {@link #destroyNow()}.
*/
public void start() {
// TODO: start receiving packets
mSocket.addPacketHandler(this);
}
/**
@@ -267,8 +269,8 @@ public class MdnsInterfaceAdvertiser {
mProber.stop(serviceId);
mAnnouncer.stop(serviceId);
}
// TODO: stop receiving packets
mReplySender.cancelAll();
mSocket.removePacketHandler(this);
mCbHandler.post(() -> mCb.onDestroyed(mSocket));
}
@@ -294,4 +296,33 @@ public class MdnsInterfaceAdvertiser {
public boolean isProbing(int serviceId) {
return mRecordRepository.isProbing(serviceId);
}
@Override
public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
final MdnsPacket packet;
try {
packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length));
} catch (MdnsPacket.ParseException e) {
Log.e(mTag, "Error parsing mDNS packet", e);
if (DBG) {
Log.v(
mTag, "Packet: " + HexDump.toHexString(recvbuf, 0, length));
}
return;
}
if (DBG) {
Log.v(mTag,
"Parsed packet with " + packet.questions.size() + " questions, "
+ packet.answers.size() + " answers, "
+ packet.authorityRecords.size() + " authority, "
+ packet.additionalRecords.size() + " additional from " + src);
}
final MdnsRecordRepository.ReplyInfo answers =
mRecordRepository.getReply(packet, src);
if (answers == null) return;
mReplySender.queueReply(answers);
}
}

View File

@@ -161,6 +161,14 @@ public class MdnsInterfaceSocket {
mPacketReader.addPacketHandler(handler);
}
/**
* Remove a handler added via {@link #addPacketHandler}. If the handler is not present, this is
* a no-op.
*/
public void removePacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
mPacketReader.removePacketHandler(handler);
}
/**
* Returns the network interface that this socket is bound to.
*

View File

@@ -16,6 +16,9 @@
package com.android.server.connectivity.mdns;
import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV4_ADDR;
import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV6_ADDR;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.os.Handler;
@@ -32,10 +35,6 @@ import java.net.InetSocketAddress;
*/
public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
private static final boolean DBG = MdnsAdvertiser.DBG;
private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
IPV4_ADDR, IPV6_ADDR
};
@@ -114,7 +113,7 @@ public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
final MdnsPacket packet = request.getPacket(index);
if (DBG) {
Log.v(getTag(), "Sending packets for iteration " + index + " out of "
+ request.getNumSends());
+ request.getNumSends() + " for ID " + msg.what);
}
// Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
// send when the socket has not joined the relevant group.

View File

@@ -45,6 +45,7 @@ public abstract class MdnsRecord {
private static final int FLAG_CACHE_FLUSH = 0x8000;
public static final long RECEIPT_TIME_NOT_SENT = 0L;
public static final int CLASS_ANY = 0x00ff;
/** Status indicating that the record is current. */
public static final int STATUS_OK = 0;
@@ -317,4 +318,4 @@ public abstract class MdnsRecord {
return (recordType * 31) + Arrays.hashCode(recordName);
}
}
}
}

View File

@@ -34,6 +34,7 @@ import com.android.net.module.util.HexDump;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.util.ArrayList;
import java.util.Arrays;
@@ -42,6 +43,7 @@ import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;
@@ -54,6 +56,9 @@ import java.util.concurrent.TimeUnit;
*/
@TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
public class MdnsRecordRepository {
// RFC6762 p.15
private static final long MIN_MULTICAST_REPLY_INTERVAL_MS = 1_000L;
// TTLs as per RFC6762 10.
// TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
// host name contained within the resource record's rdata (e.g., SRV, reverse mapping PTR
@@ -69,6 +74,13 @@ public class MdnsRecordRepository {
private static final String[] DNS_SD_SERVICE_TYPE =
new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD };
public static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
public static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
@NonNull
private final Random mDelayGenerator = new Random();
// Map of service unique ID -> records for service
@NonNull
private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
@@ -138,6 +150,11 @@ public class MdnsRecordRepository {
*/
public boolean isProbing;
/**
* Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never
*/
public long lastAdvertisedTimeMs;
/**
* Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
* 0 if never
@@ -390,6 +407,212 @@ public class MdnsRecordRepository {
return ret;
}
/**
* Info about a reply to be sent.
*/
public static class ReplyInfo {
@NonNull
public final List<MdnsRecord> answers;
@NonNull
public final List<MdnsRecord> additionalAnswers;
public final long sendDelayMs;
@NonNull
public final InetSocketAddress destination;
public ReplyInfo(
@NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> additionalAnswers,
long sendDelayMs,
@NonNull InetSocketAddress destination) {
this.answers = answers;
this.additionalAnswers = additionalAnswers;
this.sendDelayMs = sendDelayMs;
this.destination = destination;
}
@Override
public String toString() {
return "{ReplyInfo to " + destination + ", answers: " + answers.size()
+ ", additionalAnswers: " + additionalAnswers.size()
+ ", sendDelayMs " + sendDelayMs + "}";
}
}
/**
* Get the reply to send to an incoming packet.
*
* @param packet The incoming packet.
* @param src The source address of the incoming packet.
*/
@Nullable
public ReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
final long now = SystemClock.elapsedRealtime();
final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
final ArrayList<MdnsRecord> additionalAnswerRecords = new ArrayList<>();
final ArrayList<RecordInfo<?>> answerInfo = new ArrayList<>();
for (MdnsRecord question : packet.questions) {
// Add answers from general records
addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
answerInfo, additionalAnswerRecords);
// Add answers from each service
for (int i = 0; i < mServices.size(); i++) {
final ServiceRegistration registration = mServices.valueAt(i);
if (registration.exiting) continue;
addReplyFromService(question, registration.allRecords, registration.ptrRecord,
registration.srvRecord, registration.txtRecord, replyUnicast, now,
answerInfo, additionalAnswerRecords);
}
}
if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) {
return null;
}
// Determine the send delay
final long delayMs;
if ((packet.flags & MdnsConstants.FLAG_TRUNCATED) != 0) {
// RFC 6762 6.: 400-500ms delay if TC bit is set
delayMs = 400L + mDelayGenerator.nextInt(100);
} else if (packet.questions.size() > 1
|| CollectionUtils.any(answerInfo, a -> a.isSharedName)) {
// 20-120ms if there may be responses from other hosts (not a fully owned
// name) (RFC 6762 6.), or if there are multiple questions (6.3).
// TODO: this should be 0 if this is a probe query ("can be distinguished from a
// normal query by the fact that a probe query contains a proposed record in the
// Authority Section that answers the question" in 6.), and the reply is for a fully
// owned record.
delayMs = 20L + mDelayGenerator.nextInt(100);
} else {
delayMs = 0L;
}
// Determine the send destination
final InetSocketAddress dest;
if (replyUnicast) {
dest = src;
} else if (src.getAddress() instanceof Inet4Address) {
dest = IPV4_ADDR;
} else {
dest = IPV6_ADDR;
}
// Build the list of answer records from their RecordInfo
final ArrayList<MdnsRecord> answerRecords = new ArrayList<>(answerInfo.size());
for (RecordInfo<?> info : answerInfo) {
// TODO: consider actual packet send delay after response aggregation
info.lastSentTimeMs = now + delayMs;
if (!replyUnicast) {
info.lastAdvertisedTimeMs = info.lastSentTimeMs;
}
answerRecords.add(info.record);
}
return new ReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest);
}
/**
* Add answers and additional answers for a question, from a ServiceRegistration.
*/
private void addReplyFromService(@NonNull MdnsRecord question,
@NonNull List<RecordInfo<?>> serviceRecords,
@Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
@Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
@Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
@NonNull List<MdnsRecord> additionalAnswerRecords) {
boolean hasDnsSdPtrRecordAnswer = false;
boolean hasDnsSdSrvRecordAnswer = false;
boolean hasFullyOwnedNameMatch = false;
boolean hasKnownAnswer = false;
final int answersStartIndex = answerInfo.size();
for (RecordInfo<?> info : serviceRecords) {
if (info.isProbing) continue;
/* RFC6762 6.: the record name must match the question name, the record rrtype
must match the question qtype unless the qtype is "ANY" (255) or the rrtype is
"CNAME" (5), and the record rrclass must match the question qclass unless the
qclass is "ANY" (255) */
if (!Arrays.equals(info.record.getName(), question.getName())) continue;
hasFullyOwnedNameMatch |= !info.isSharedName;
// The repository does not store CNAME records
if (question.getType() != MdnsRecord.TYPE_ANY
&& question.getType() != info.record.getType()) {
continue;
}
if (question.getRecordClass() != MdnsRecord.CLASS_ANY
&& question.getRecordClass() != info.record.getRecordClass()) {
continue;
}
hasKnownAnswer = true;
hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
// TODO: responses to probe queries should bypass this check and only ensure the
// reply is sent 250ms after the last sent time (RFC 6762 p.15)
if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
&& now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
continue;
}
// TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half
answerInfo.add(info);
}
// RFC6762 6.1:
// "Any time a responder receives a query for a name for which it has verified exclusive
// ownership, for a type for which that name has no records, the responder MUST [...]
// respond asserting the nonexistence of that record"
if (hasFullyOwnedNameMatch && !hasKnownAnswer) {
additionalAnswerRecords.add(new MdnsNsecRecord(
question.getName(),
0L /* receiptTimeMillis */,
true /* cacheFlush */,
// TODO: RFC6762 6.1: "In general, the TTL given for an NSEC record SHOULD
// be the same as the TTL that the record would have had, had it existed."
NAME_RECORDS_TTL_MILLIS,
question.getName(),
new int[] { question.getType() }));
}
// No more records to add if no answer
if (answerInfo.size() == answersStartIndex) return;
final List<RecordInfo<?>> additionalAnswerInfo = new ArrayList<>();
// RFC6763 12.1: if including PTR record, include the SRV and TXT records it names
if (hasDnsSdPtrRecordAnswer) {
if (serviceTxtRecord != null) {
additionalAnswerInfo.add(serviceTxtRecord);
}
if (serviceSrvRecord != null) {
additionalAnswerInfo.add(serviceSrvRecord);
}
}
// RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names
if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) {
for (RecordInfo<?> record : mGeneralRecords) {
if (record.record instanceof MdnsInetAddressRecord) {
additionalAnswerInfo.add(record);
}
}
}
for (RecordInfo<?> info : additionalAnswerInfo) {
additionalAnswerRecords.add(info.record);
}
// RFC6762 6.1: negative responses
addNsecRecordsForUniqueNames(additionalAnswerRecords,
answerInfo.listIterator(answersStartIndex),
additionalAnswerInfo.listIterator());
}
/**
* Add NSEC records indicating that the response records are unique.
*
@@ -540,6 +763,7 @@ public class MdnsRecordRepository {
final long now = SystemClock.elapsedRealtime();
for (RecordInfo<?> record : registration.allRecords) {
record.lastSentTimeMs = now;
record.lastAdvertisedTimeMs = now;
}
}

View File

@@ -16,8 +16,15 @@
package com.android.server.connectivity.mdns;
import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
import android.annotation.NonNull;
import android.os.Handler;
import android.os.Looper;
import android.os.Message;
import android.util.Log;
import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo;
import java.io.IOException;
import java.net.DatagramPacket;
@@ -25,6 +32,7 @@ import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
import java.util.Collections;
/**
* A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -33,20 +41,38 @@ import java.net.MulticastSocket;
* TODO: implement sending after a delay, combining queued replies and duplicate answer suppression
*/
public class MdnsReplySender {
private static final boolean DBG = MdnsAdvertiser.DBG;
private static final int MSG_SEND = 1;
private final String mLogTag;
@NonNull
private final MdnsInterfaceSocket mSocket;
@NonNull
private final Looper mLooper;
private final Handler mHandler;
@NonNull
private final byte[] mPacketCreationBuffer;
public MdnsReplySender(@NonNull Looper looper,
public MdnsReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
mLooper = looper;
mHandler = new SendHandler(looper);
mLogTag = MdnsReplySender.class.getSimpleName() + "/" + interfaceTag;
mSocket = socket;
mPacketCreationBuffer = packetCreationBuffer;
}
/**
* Queue a reply to be sent when its send delay expires.
*/
public void queueReply(@NonNull ReplyInfo reply) {
ensureRunningOnHandlerThread(mHandler);
// TODO: implement response aggregation (RFC 6762 6.4)
mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
if (DBG) {
Log.v(mLogTag, "Scheduling " + reply);
}
}
/**
* Send a packet immediately.
*
@@ -54,9 +80,7 @@ public class MdnsReplySender {
*/
public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
throws IOException {
if (Thread.currentThread() != mLooper.getThread()) {
throw new IllegalStateException("sendNow must be called in the handler thread");
}
ensureRunningOnHandlerThread(mHandler);
if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
|| (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
// Skip sending if the socket has not joined the v4/v6 group (there was no address)
@@ -93,4 +117,37 @@ public class MdnsReplySender {
mSocket.send(new DatagramPacket(outBuffer, 0, len, destination));
}
/**
* Cancel all pending sends.
*/
public void cancelAll() {
ensureRunningOnHandlerThread(mHandler);
mHandler.removeMessages(MSG_SEND);
}
private class SendHandler extends Handler {
SendHandler(@NonNull Looper looper) {
super(looper);
}
@Override
public void handleMessage(@NonNull Message msg) {
final ReplyInfo replyInfo = (ReplyInfo) msg.obj;
if (DBG) Log.v(mLogTag, "Sending " + replyInfo);
final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
final MdnsPacket packet = new MdnsPacket(flags,
Collections.emptyList() /* questions */,
replyInfo.answers,
Collections.emptyList() /* authorityRecords */,
replyInfo.additionalAnswers);
try {
sendNow(packet, replyInfo.destination);
} catch (IOException e) {
Log.e(mLogTag, "Error sending MDNS response", e);
}
}
}
}

View File

@@ -107,5 +107,14 @@ public class MulticastPacketReader extends FdEventsReader<MulticastPacketReader.
ensureRunningOnHandlerThread(mHandler);
mPacketHandlers.add(handler);
}
/**
* Remove a packet handler added via {@link #addPacketHandler}. If the handler was not set,
* this is a no-op.
*/
public void removePacketHandler(@NonNull PacketHandler handler) {
ensureRunningOnHandlerThread(mHandler);
mPacketHandlers.remove(handler);
}
}

View File

@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers.eq
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.argThat
import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
@@ -161,6 +162,60 @@ class MdnsAdvertiserTest {
verify(socketProvider).unrequestSocket(socketCb)
}
@Test
fun testAddService_Conflicts() {
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
val oneNetSocketCb = oneNetSocketCbCaptor.value
// Register a service with the same name on all networks (name conflict)
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value
// Callbacks for matching network and all networks both get the socket
postSync {
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
}
val expectedRenamed = NsdServiceInfo(
"${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
port = ALL_NETWORKS_SERVICE.port
host = ALL_NETWORKS_SERVICE.host
network = ALL_NETWORKS_SERVICE.network
}
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor.capture())
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(SERVICE_1) })
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) })
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
mockInterfaceAdvertiser1, SERVICE_ID_1) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
mockInterfaceAdvertiser1, SERVICE_ID_2) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) })
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
// destroyNow can be called multiple times
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
}
private fun postSync(r: () -> Unit) {
handler.post(r)
handler.waitForIdle(TIMEOUT_MS)

View File

@@ -79,7 +79,7 @@ class MdnsAnnouncerTest {
@Test
fun testAnnounce() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
@Suppress("UNCHECKED_CAST")
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>

View File

@@ -21,6 +21,7 @@ import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.net.module.util.HexDump
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
@@ -30,6 +31,10 @@ import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.waitForIdle
import java.net.InetSocketAddress
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
import org.junit.Test
@@ -37,8 +42,10 @@ import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.anyString
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq
import org.mockito.Mockito.mock
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
@@ -67,13 +74,18 @@ class MdnsInterfaceAdvertiserTest {
private val replySender = mock(MdnsReplySender::class.java)
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
@Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
@Suppress("UNCHECKED_CAST")
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
private val packetHandlerCaptor = ArgumentCaptor.forClass(
MulticastPacketReader.PacketHandler::class.java)
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
private val packetHandler get() = packetHandlerCaptor.value
private val advertiser by lazy {
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
@@ -82,9 +94,9 @@ class MdnsInterfaceAdvertiserTest {
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any())
doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
val knownServices = mutableSetOf<Int>()
doAnswer { inv ->
@@ -104,6 +116,7 @@ class MdnsInterfaceAdvertiserTest {
thread.start()
advertiser.start()
verify(socket).addPacketHandler(packetHandlerCaptor.capture())
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
}
@@ -157,6 +170,39 @@ class MdnsInterfaceAdvertiserTest {
verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
}
@Test
fun testReplyToQuery() {
addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java)
doReturn(mockReply).`when`(repository).getReply(any(), any())
// Query obtained with:
// scapy.raw(scapy.DNS(
// qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
// ).hex().upper()
val query = HexDump.hexStringToByteArray(
"0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
)
val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
packetHandler.handlePacket(query, query.size, src)
val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
verify(repository).getReply(packetCaptor.capture(), eq(src))
packetCaptor.value.let {
assertEquals(1, it.questions.size)
assertEquals(0, it.answers.size)
assertEquals(0, it.authorityRecords.size)
assertEquals(0, it.additionalRecords.size)
assertTrue(it.questions[0] is MdnsPointerRecord)
assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
}
verify(replySender).queueReply(mockReply)
}
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java)

View File

@@ -114,7 +114,7 @@ class MdnsProberTest {
@Test
fun testProbe() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
@@ -129,7 +129,7 @@ class MdnsProberTest {
@Test
fun testProbeMultipleRecords() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(listOf(
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
@@ -167,7 +167,7 @@ class MdnsProberTest {
@Test
fun testStopProbing() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),

View File

@@ -21,10 +21,12 @@ import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.net.InetSocketAddress
import java.net.NetworkInterface
import java.util.Collections
import kotlin.test.assertContentEquals
@@ -150,11 +152,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.updateAddresses(TEST_ADDRESSES)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
repository.onProbingSucceeded(probingInfo)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
@@ -183,9 +181,7 @@ class MdnsRecordRepositoryTest {
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
repository.onProbingSucceeded(probingInfo)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1)
@@ -199,11 +195,8 @@ class MdnsRecordRepositoryTest {
@Test
fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.updateAddresses(TEST_ADDRESSES)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
val announcementInfo = repository.onProbingSucceeded(probingInfo)
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val packet = announcementInfo.getPacket(0)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
@@ -322,4 +315,98 @@ class MdnsRecordRepositoryTest {
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
}
@Test
fun testGetReply() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
// TTL and data is empty for a question
0L /* ttlMillis */,
null /* pointer */))
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val reply = repository.getReply(query, src)
assertNotNull(reply)
// Source address is IPv4
assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
// TTLs as per RFC6762 10.
val longTtl = 4_500_000L
val shortTtl = 120_000L
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertEquals(listOf(
MdnsPointerRecord(
arrayOf("_testservice", "_tcp", "local"),
0L /* receiptTimeMillis */,
false /* cacheFlush */,
longTtl,
serviceName),
), reply.answers)
assertEquals(listOf(
MdnsTextRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
longTtl,
listOf() /* entries */),
MdnsServiceRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT,
TEST_HOSTNAME),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_ADDRESSES[0].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_ADDRESSES[1].address),
MdnsInetAddressRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_ADDRESSES[2].address),
MdnsNsecRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
longTtl,
serviceName /* nextDomain */,
intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
MdnsNsecRecord(
TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
shortTtl,
TEST_HOSTNAME /* nextDomain */,
intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
), reply.additionalAnswers)
}
}
private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
updateAddresses(TEST_ADDRESSES)
addService(serviceId, serviceInfo)
val probingInfo = setServiceProbing(serviceId)
assertNotNull(probingInfo)
return onProbingSucceeded(probingInfo)
}