Merge changes I69128db9,I13db22f8
* changes: Implement onServiceConflict Add replying to queries
This commit is contained in:
@@ -29,10 +29,10 @@ import android.util.SparseArray;
|
||||
|
||||
import com.android.internal.annotations.VisibleForTesting;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.function.BiPredicate;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* MdnsAdvertiser manages advertising services per {@link com.android.server.NsdService} requests.
|
||||
@@ -85,7 +85,7 @@ public class MdnsAdvertiser {
|
||||
public void onRegisterServiceSucceeded(
|
||||
@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
|
||||
// Wait for all current interfaces to be done probing before notifying of success.
|
||||
if (anyAdvertiser(a -> a.isProbing(serviceId))) return;
|
||||
if (any(mAllAdvertisers, (k, a) -> a.isProbing(serviceId))) return;
|
||||
// The service may still be unregistered/renamed if a conflict is found on a later added
|
||||
// interface, or if a conflicting announcement/reply is detected (RFC6762 9.)
|
||||
|
||||
@@ -102,7 +102,37 @@ public class MdnsAdvertiser {
|
||||
|
||||
@Override
|
||||
public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
|
||||
// TODO: handle conflicts found after registration (during or after probing)
|
||||
if (DBG) {
|
||||
Log.v(TAG, "Found conflict, restarted probing for service " + serviceId);
|
||||
}
|
||||
|
||||
final Registration registration = mRegistrations.get(serviceId);
|
||||
if (registration == null) return;
|
||||
if (registration.mNotifiedRegistrationSuccess) {
|
||||
// TODO: consider notifying clients that the service is no longer registered with
|
||||
// the old name (back to probing). The legacy implementation did not send any
|
||||
// callback though; it only sent onServiceRegistered after re-probing finishes
|
||||
// (with the old, conflicting, actually not used name as argument... The new
|
||||
// implementation will send callbacks with the new name).
|
||||
registration.mNotifiedRegistrationSuccess = false;
|
||||
|
||||
// The service was done probing, just reset it to probing state (RFC6762 9.)
|
||||
forAllAdvertisers(a -> a.restartProbingForConflict(serviceId));
|
||||
return;
|
||||
}
|
||||
|
||||
// Conflict was found during probing; rename once to find a name that has no conflict
|
||||
registration.updateForConflict(
|
||||
registration.makeNewServiceInfoForConflict(1 /* renameCount */),
|
||||
1 /* renameCount */);
|
||||
|
||||
// Keep renaming if the new name conflicts in local registrations
|
||||
updateRegistrationUntilNoConflict((net, adv) -> adv.hasRegistration(registration),
|
||||
registration);
|
||||
|
||||
// Update advertisers to use the new name
|
||||
forAllAdvertisers(a -> a.renameServiceForConflict(
|
||||
serviceId, registration.getServiceInfo()));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -116,6 +146,25 @@ public class MdnsAdvertiser {
|
||||
}
|
||||
};
|
||||
|
||||
private boolean hasAnyConflict(
|
||||
@NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
|
||||
@NonNull NsdServiceInfo newInfo) {
|
||||
return any(mAdvertiserRequests, (network, adv) ->
|
||||
applicableAdvertiserFilter.test(network, adv) && adv.hasConflict(newInfo));
|
||||
}
|
||||
|
||||
private void updateRegistrationUntilNoConflict(
|
||||
@NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
|
||||
@NonNull Registration registration) {
|
||||
int renameCount = 0;
|
||||
NsdServiceInfo newInfo = registration.getServiceInfo();
|
||||
while (hasAnyConflict(applicableAdvertiserFilter, newInfo)) {
|
||||
renameCount++;
|
||||
newInfo = registration.makeNewServiceInfoForConflict(renameCount);
|
||||
}
|
||||
registration.updateForConflict(newInfo, renameCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* A request for a {@link MdnsInterfaceAdvertiser}.
|
||||
*
|
||||
@@ -152,6 +201,21 @@ public class MdnsAdvertiser {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return whether this {@link InterfaceAdvertiserRequest} has the given registration.
|
||||
*/
|
||||
boolean hasRegistration(@NonNull Registration registration) {
|
||||
return mPendingRegistrations.indexOfValue(registration) >= 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return whether using the proposed new {@link NsdServiceInfo} to add a registration would
|
||||
* cause a conflict in this {@link InterfaceAdvertiserRequest}.
|
||||
*/
|
||||
boolean hasConflict(@NonNull NsdServiceInfo newInfo) {
|
||||
return getConflictingService(newInfo) >= 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the ID of a conflicting service, or -1 if none.
|
||||
*/
|
||||
@@ -166,16 +230,19 @@ public class MdnsAdvertiser {
|
||||
return -1;
|
||||
}
|
||||
|
||||
void addService(int id, Registration registration)
|
||||
throws NameConflictException {
|
||||
final int conflicting = getConflictingService(registration.getServiceInfo());
|
||||
if (conflicting >= 0) {
|
||||
throw new NameConflictException(conflicting);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a service.
|
||||
*
|
||||
* Conflicts must be checked via {@link #getConflictingService} before attempting to add.
|
||||
*/
|
||||
void addService(int id, Registration registration) {
|
||||
mPendingRegistrations.put(id, registration);
|
||||
for (int i = 0; i < mAdvertisers.size(); i++) {
|
||||
mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
|
||||
try {
|
||||
mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
|
||||
} catch (NameConflictException e) {
|
||||
Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,32 +306,42 @@ public class MdnsAdvertiser {
|
||||
/**
|
||||
* Update the registration to use a different service name, after a conflict was found.
|
||||
*
|
||||
* @param newInfo New service info to use.
|
||||
* @param renameCount How many renames were done before reaching the current name.
|
||||
*/
|
||||
private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) {
|
||||
mConflictCount += renameCount;
|
||||
mServiceInfo = newInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a new service name for the registration, after a conflict was found.
|
||||
*
|
||||
* If a name conflict was found during probing or because different advertising requests
|
||||
* used the same name, the registration is attempted again with a new name (here using
|
||||
* a number suffix, (1), (2) etc). Registration success is notified once probing succeeds
|
||||
* with a new name. This matches legacy behavior based on mdnsresponder, and appendix D of
|
||||
* RFC6763.
|
||||
* @return The new service info with the updated name.
|
||||
*
|
||||
* @param renameCount How much to increase the number suffix for this conflict.
|
||||
*/
|
||||
@NonNull
|
||||
private NsdServiceInfo updateForConflict() {
|
||||
mConflictCount++;
|
||||
public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) {
|
||||
// In case of conflict choose a different service name. After the first conflict use
|
||||
// "Name (2)", then "Name (3)" etc.
|
||||
// TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t
|
||||
final NsdServiceInfo newInfo = new NsdServiceInfo();
|
||||
newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + 1) + ")");
|
||||
newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")");
|
||||
newInfo.setServiceType(mServiceInfo.getServiceType());
|
||||
for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) {
|
||||
newInfo.setAttribute(attr.getKey(), attr.getValue());
|
||||
newInfo.setAttribute(attr.getKey(),
|
||||
attr.getValue() == null ? null : new String(attr.getValue()));
|
||||
}
|
||||
newInfo.setHost(mServiceInfo.getHost());
|
||||
newInfo.setPort(mServiceInfo.getPort());
|
||||
newInfo.setNetwork(mServiceInfo.getNetwork());
|
||||
// interfaceIndex is not set when registering
|
||||
|
||||
mServiceInfo = newInfo;
|
||||
return mServiceInfo;
|
||||
return newInfo;
|
||||
}
|
||||
|
||||
@NonNull
|
||||
@@ -338,55 +415,27 @@ public class MdnsAdvertiser {
|
||||
Log.i(TAG, "Adding service " + service + " with ID " + id);
|
||||
}
|
||||
|
||||
try {
|
||||
final Registration registration = new Registration(service);
|
||||
while (!tryAddRegistration(id, registration)) {
|
||||
registration.updateForConflict();
|
||||
}
|
||||
|
||||
mRegistrations.put(id, registration);
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Error adding service " + service, e);
|
||||
removeService(id);
|
||||
// TODO (b/264986328): add a more specific error code
|
||||
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean tryAddRegistration(int id, @NonNull Registration registration)
|
||||
throws IOException {
|
||||
final NsdServiceInfo serviceInfo = registration.getServiceInfo();
|
||||
final Network network = serviceInfo.getNetwork();
|
||||
try {
|
||||
InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
|
||||
if (advertiser == null) {
|
||||
advertiser = new InterfaceAdvertiserRequest(network);
|
||||
mAdvertiserRequests.put(network, advertiser);
|
||||
}
|
||||
advertiser.addService(id, registration);
|
||||
} catch (NameConflictException e) {
|
||||
if (DBG) {
|
||||
Log.i(TAG, "Service name conflicts: " + serviceInfo.getServiceName());
|
||||
}
|
||||
removeService(id);
|
||||
return false;
|
||||
final Network network = service.getNetwork();
|
||||
final Registration registration = new Registration(service);
|
||||
final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
|
||||
if (network == null) {
|
||||
// If registering on all networks, no advertiser must have conflicts
|
||||
checkConflictFilter = (net, adv) -> true;
|
||||
} else {
|
||||
// If registering on one network, the matching network advertiser and the one for all
|
||||
// networks must not have conflicts
|
||||
checkConflictFilter = (net, adv) -> net == null || network.equals(net);
|
||||
}
|
||||
|
||||
// When adding a service to a specific network, check that it does not conflict with other
|
||||
// registrations advertising on all networks
|
||||
final InterfaceAdvertiserRequest allNetworksAdvertiser = mAdvertiserRequests.get(null);
|
||||
if (network != null && allNetworksAdvertiser != null
|
||||
&& allNetworksAdvertiser.getConflictingService(serviceInfo) >= 0) {
|
||||
if (DBG) {
|
||||
Log.i(TAG, "Service conflicts with advertisement on all networks: "
|
||||
+ serviceInfo.getServiceName());
|
||||
}
|
||||
removeService(id);
|
||||
return false;
|
||||
}
|
||||
updateRegistrationUntilNoConflict(checkConflictFilter, registration);
|
||||
|
||||
InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
|
||||
if (advertiser == null) {
|
||||
advertiser = new InterfaceAdvertiserRequest(network);
|
||||
mAdvertiserRequests.put(network, advertiser);
|
||||
}
|
||||
advertiser.addService(id, registration);
|
||||
mRegistrations.put(id, registration);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -406,12 +455,20 @@ public class MdnsAdvertiser {
|
||||
mRegistrations.remove(id);
|
||||
}
|
||||
|
||||
private boolean anyAdvertiser(@NonNull Predicate<MdnsInterfaceAdvertiser> predicate) {
|
||||
for (int i = 0; i < mAllAdvertisers.size(); i++) {
|
||||
if (predicate.test(mAllAdvertisers.valueAt(i))) {
|
||||
private static <K, V> boolean any(@NonNull ArrayMap<K, V> map,
|
||||
@NonNull BiPredicate<K, V> predicate) {
|
||||
for (int i = 0; i < map.size(); i++) {
|
||||
if (predicate.test(map.keyAt(i), map.valueAt(i))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private void forAllAdvertisers(@NonNull Consumer<MdnsInterfaceAdvertiser> consumer) {
|
||||
any(mAllAdvertisers, (socket, advertiser) -> {
|
||||
consumer.accept(advertiser);
|
||||
return false;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ public final class MdnsConstants {
|
||||
public static final int FLAGS_QUERY = 0x0000;
|
||||
public static final int FLAGS_RESPONSE_MASK = 0xF80F;
|
||||
public static final int FLAGS_RESPONSE = 0x8000;
|
||||
public static final int FLAG_TRUNCATED = 0x0200;
|
||||
public static final int QCLASS_INTERNET = 0x0001;
|
||||
public static final int QCLASS_UNICAST = 0x8000;
|
||||
public static final String SUBTYPE_LABEL = "_sub";
|
||||
|
||||
@@ -25,16 +25,18 @@ import android.os.Looper;
|
||||
import android.util.Log;
|
||||
|
||||
import com.android.internal.annotations.VisibleForTesting;
|
||||
import com.android.net.module.util.HexDump;
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
|
||||
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A class that handles advertising services on a {@link MdnsInterfaceSocket} tied to an interface.
|
||||
*/
|
||||
public class MdnsInterfaceAdvertiser {
|
||||
public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler {
|
||||
private static final boolean DBG = MdnsAdvertiser.DBG;
|
||||
@VisibleForTesting
|
||||
public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
|
||||
@@ -145,9 +147,9 @@ public class MdnsInterfaceAdvertiser {
|
||||
|
||||
/** @see MdnsReplySender */
|
||||
@NonNull
|
||||
public MdnsReplySender makeReplySender(@NonNull Looper looper,
|
||||
public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
|
||||
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
|
||||
return new MdnsReplySender(looper, socket, packetCreationBuffer);
|
||||
return new MdnsReplySender(interfaceTag, looper, socket, packetCreationBuffer);
|
||||
}
|
||||
|
||||
/** @see MdnsAnnouncer */
|
||||
@@ -182,7 +184,7 @@ public class MdnsInterfaceAdvertiser {
|
||||
mSocket = socket;
|
||||
mCb = cb;
|
||||
mCbHandler = new Handler(looper);
|
||||
mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
|
||||
mReplySender = deps.makeReplySender(logTag, looper, socket, packetCreationBuffer);
|
||||
mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
|
||||
mAnnouncingCallback);
|
||||
mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
|
||||
@@ -196,7 +198,7 @@ public class MdnsInterfaceAdvertiser {
|
||||
* {@link #destroyNow()}.
|
||||
*/
|
||||
public void start() {
|
||||
// TODO: start receiving packets
|
||||
mSocket.addPacketHandler(this);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -267,8 +269,8 @@ public class MdnsInterfaceAdvertiser {
|
||||
mProber.stop(serviceId);
|
||||
mAnnouncer.stop(serviceId);
|
||||
}
|
||||
|
||||
// TODO: stop receiving packets
|
||||
mReplySender.cancelAll();
|
||||
mSocket.removePacketHandler(this);
|
||||
mCbHandler.post(() -> mCb.onDestroyed(mSocket));
|
||||
}
|
||||
|
||||
@@ -294,4 +296,33 @@ public class MdnsInterfaceAdvertiser {
|
||||
public boolean isProbing(int serviceId) {
|
||||
return mRecordRepository.isProbing(serviceId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
|
||||
final MdnsPacket packet;
|
||||
try {
|
||||
packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length));
|
||||
} catch (MdnsPacket.ParseException e) {
|
||||
Log.e(mTag, "Error parsing mDNS packet", e);
|
||||
if (DBG) {
|
||||
Log.v(
|
||||
mTag, "Packet: " + HexDump.toHexString(recvbuf, 0, length));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (DBG) {
|
||||
Log.v(mTag,
|
||||
"Parsed packet with " + packet.questions.size() + " questions, "
|
||||
+ packet.answers.size() + " answers, "
|
||||
+ packet.authorityRecords.size() + " authority, "
|
||||
+ packet.additionalRecords.size() + " additional from " + src);
|
||||
}
|
||||
|
||||
final MdnsRecordRepository.ReplyInfo answers =
|
||||
mRecordRepository.getReply(packet, src);
|
||||
|
||||
if (answers == null) return;
|
||||
mReplySender.queueReply(answers);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,6 +161,14 @@ public class MdnsInterfaceSocket {
|
||||
mPacketReader.addPacketHandler(handler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a handler added via {@link #addPacketHandler}. If the handler is not present, this is
|
||||
* a no-op.
|
||||
*/
|
||||
public void removePacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
|
||||
mPacketReader.removePacketHandler(handler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the network interface that this socket is bound to.
|
||||
*
|
||||
|
||||
@@ -16,6 +16,9 @@
|
||||
|
||||
package com.android.server.connectivity.mdns;
|
||||
|
||||
import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV4_ADDR;
|
||||
import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV6_ADDR;
|
||||
|
||||
import android.annotation.NonNull;
|
||||
import android.annotation.Nullable;
|
||||
import android.os.Handler;
|
||||
@@ -32,10 +35,6 @@ import java.net.InetSocketAddress;
|
||||
*/
|
||||
public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
|
||||
private static final boolean DBG = MdnsAdvertiser.DBG;
|
||||
private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
|
||||
MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
|
||||
private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
|
||||
MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
|
||||
private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
|
||||
IPV4_ADDR, IPV6_ADDR
|
||||
};
|
||||
@@ -114,7 +113,7 @@ public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
|
||||
final MdnsPacket packet = request.getPacket(index);
|
||||
if (DBG) {
|
||||
Log.v(getTag(), "Sending packets for iteration " + index + " out of "
|
||||
+ request.getNumSends());
|
||||
+ request.getNumSends() + " for ID " + msg.what);
|
||||
}
|
||||
// Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
|
||||
// send when the socket has not joined the relevant group.
|
||||
|
||||
@@ -45,6 +45,7 @@ public abstract class MdnsRecord {
|
||||
private static final int FLAG_CACHE_FLUSH = 0x8000;
|
||||
|
||||
public static final long RECEIPT_TIME_NOT_SENT = 0L;
|
||||
public static final int CLASS_ANY = 0x00ff;
|
||||
|
||||
/** Status indicating that the record is current. */
|
||||
public static final int STATUS_OK = 0;
|
||||
@@ -317,4 +318,4 @@ public abstract class MdnsRecord {
|
||||
return (recordType * 31) + Arrays.hashCode(recordName);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ import com.android.net.module.util.HexDump;
|
||||
import java.io.IOException;
|
||||
import java.net.Inet4Address;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.NetworkInterface;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -42,6 +43,7 @@ import java.util.Enumeration;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.UUID;
|
||||
@@ -54,6 +56,9 @@ import java.util.concurrent.TimeUnit;
|
||||
*/
|
||||
@TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
|
||||
public class MdnsRecordRepository {
|
||||
// RFC6762 p.15
|
||||
private static final long MIN_MULTICAST_REPLY_INTERVAL_MS = 1_000L;
|
||||
|
||||
// TTLs as per RFC6762 10.
|
||||
// TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
|
||||
// host name contained within the resource record's rdata (e.g., SRV, reverse mapping PTR
|
||||
@@ -69,6 +74,13 @@ public class MdnsRecordRepository {
|
||||
private static final String[] DNS_SD_SERVICE_TYPE =
|
||||
new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD };
|
||||
|
||||
public static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
|
||||
MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
|
||||
public static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
|
||||
MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
|
||||
|
||||
@NonNull
|
||||
private final Random mDelayGenerator = new Random();
|
||||
// Map of service unique ID -> records for service
|
||||
@NonNull
|
||||
private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
|
||||
@@ -138,6 +150,11 @@ public class MdnsRecordRepository {
|
||||
*/
|
||||
public boolean isProbing;
|
||||
|
||||
/**
|
||||
* Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never
|
||||
*/
|
||||
public long lastAdvertisedTimeMs;
|
||||
|
||||
/**
|
||||
* Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
|
||||
* 0 if never
|
||||
@@ -390,6 +407,212 @@ public class MdnsRecordRepository {
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Info about a reply to be sent.
|
||||
*/
|
||||
public static class ReplyInfo {
|
||||
@NonNull
|
||||
public final List<MdnsRecord> answers;
|
||||
@NonNull
|
||||
public final List<MdnsRecord> additionalAnswers;
|
||||
public final long sendDelayMs;
|
||||
@NonNull
|
||||
public final InetSocketAddress destination;
|
||||
|
||||
public ReplyInfo(
|
||||
@NonNull List<MdnsRecord> answers,
|
||||
@NonNull List<MdnsRecord> additionalAnswers,
|
||||
long sendDelayMs,
|
||||
@NonNull InetSocketAddress destination) {
|
||||
this.answers = answers;
|
||||
this.additionalAnswers = additionalAnswers;
|
||||
this.sendDelayMs = sendDelayMs;
|
||||
this.destination = destination;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "{ReplyInfo to " + destination + ", answers: " + answers.size()
|
||||
+ ", additionalAnswers: " + additionalAnswers.size()
|
||||
+ ", sendDelayMs " + sendDelayMs + "}";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reply to send to an incoming packet.
|
||||
*
|
||||
* @param packet The incoming packet.
|
||||
* @param src The source address of the incoming packet.
|
||||
*/
|
||||
@Nullable
|
||||
public ReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
|
||||
final long now = SystemClock.elapsedRealtime();
|
||||
final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
|
||||
final ArrayList<MdnsRecord> additionalAnswerRecords = new ArrayList<>();
|
||||
final ArrayList<RecordInfo<?>> answerInfo = new ArrayList<>();
|
||||
for (MdnsRecord question : packet.questions) {
|
||||
// Add answers from general records
|
||||
addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
|
||||
null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
|
||||
answerInfo, additionalAnswerRecords);
|
||||
|
||||
// Add answers from each service
|
||||
for (int i = 0; i < mServices.size(); i++) {
|
||||
final ServiceRegistration registration = mServices.valueAt(i);
|
||||
if (registration.exiting) continue;
|
||||
addReplyFromService(question, registration.allRecords, registration.ptrRecord,
|
||||
registration.srvRecord, registration.txtRecord, replyUnicast, now,
|
||||
answerInfo, additionalAnswerRecords);
|
||||
}
|
||||
}
|
||||
|
||||
if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Determine the send delay
|
||||
final long delayMs;
|
||||
if ((packet.flags & MdnsConstants.FLAG_TRUNCATED) != 0) {
|
||||
// RFC 6762 6.: 400-500ms delay if TC bit is set
|
||||
delayMs = 400L + mDelayGenerator.nextInt(100);
|
||||
} else if (packet.questions.size() > 1
|
||||
|| CollectionUtils.any(answerInfo, a -> a.isSharedName)) {
|
||||
// 20-120ms if there may be responses from other hosts (not a fully owned
|
||||
// name) (RFC 6762 6.), or if there are multiple questions (6.3).
|
||||
// TODO: this should be 0 if this is a probe query ("can be distinguished from a
|
||||
// normal query by the fact that a probe query contains a proposed record in the
|
||||
// Authority Section that answers the question" in 6.), and the reply is for a fully
|
||||
// owned record.
|
||||
delayMs = 20L + mDelayGenerator.nextInt(100);
|
||||
} else {
|
||||
delayMs = 0L;
|
||||
}
|
||||
|
||||
// Determine the send destination
|
||||
final InetSocketAddress dest;
|
||||
if (replyUnicast) {
|
||||
dest = src;
|
||||
} else if (src.getAddress() instanceof Inet4Address) {
|
||||
dest = IPV4_ADDR;
|
||||
} else {
|
||||
dest = IPV6_ADDR;
|
||||
}
|
||||
|
||||
// Build the list of answer records from their RecordInfo
|
||||
final ArrayList<MdnsRecord> answerRecords = new ArrayList<>(answerInfo.size());
|
||||
for (RecordInfo<?> info : answerInfo) {
|
||||
// TODO: consider actual packet send delay after response aggregation
|
||||
info.lastSentTimeMs = now + delayMs;
|
||||
if (!replyUnicast) {
|
||||
info.lastAdvertisedTimeMs = info.lastSentTimeMs;
|
||||
}
|
||||
answerRecords.add(info.record);
|
||||
}
|
||||
|
||||
return new ReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add answers and additional answers for a question, from a ServiceRegistration.
|
||||
*/
|
||||
private void addReplyFromService(@NonNull MdnsRecord question,
|
||||
@NonNull List<RecordInfo<?>> serviceRecords,
|
||||
@Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
|
||||
@Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
|
||||
@Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
|
||||
boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
|
||||
@NonNull List<MdnsRecord> additionalAnswerRecords) {
|
||||
boolean hasDnsSdPtrRecordAnswer = false;
|
||||
boolean hasDnsSdSrvRecordAnswer = false;
|
||||
boolean hasFullyOwnedNameMatch = false;
|
||||
boolean hasKnownAnswer = false;
|
||||
|
||||
final int answersStartIndex = answerInfo.size();
|
||||
for (RecordInfo<?> info : serviceRecords) {
|
||||
if (info.isProbing) continue;
|
||||
|
||||
/* RFC6762 6.: the record name must match the question name, the record rrtype
|
||||
must match the question qtype unless the qtype is "ANY" (255) or the rrtype is
|
||||
"CNAME" (5), and the record rrclass must match the question qclass unless the
|
||||
qclass is "ANY" (255) */
|
||||
if (!Arrays.equals(info.record.getName(), question.getName())) continue;
|
||||
hasFullyOwnedNameMatch |= !info.isSharedName;
|
||||
|
||||
// The repository does not store CNAME records
|
||||
if (question.getType() != MdnsRecord.TYPE_ANY
|
||||
&& question.getType() != info.record.getType()) {
|
||||
continue;
|
||||
}
|
||||
if (question.getRecordClass() != MdnsRecord.CLASS_ANY
|
||||
&& question.getRecordClass() != info.record.getRecordClass()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
hasKnownAnswer = true;
|
||||
hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
|
||||
hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
|
||||
|
||||
// TODO: responses to probe queries should bypass this check and only ensure the
|
||||
// reply is sent 250ms after the last sent time (RFC 6762 p.15)
|
||||
if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
|
||||
&& now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half
|
||||
|
||||
answerInfo.add(info);
|
||||
}
|
||||
|
||||
// RFC6762 6.1:
|
||||
// "Any time a responder receives a query for a name for which it has verified exclusive
|
||||
// ownership, for a type for which that name has no records, the responder MUST [...]
|
||||
// respond asserting the nonexistence of that record"
|
||||
if (hasFullyOwnedNameMatch && !hasKnownAnswer) {
|
||||
additionalAnswerRecords.add(new MdnsNsecRecord(
|
||||
question.getName(),
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
// TODO: RFC6762 6.1: "In general, the TTL given for an NSEC record SHOULD
|
||||
// be the same as the TTL that the record would have had, had it existed."
|
||||
NAME_RECORDS_TTL_MILLIS,
|
||||
question.getName(),
|
||||
new int[] { question.getType() }));
|
||||
}
|
||||
|
||||
// No more records to add if no answer
|
||||
if (answerInfo.size() == answersStartIndex) return;
|
||||
|
||||
final List<RecordInfo<?>> additionalAnswerInfo = new ArrayList<>();
|
||||
// RFC6763 12.1: if including PTR record, include the SRV and TXT records it names
|
||||
if (hasDnsSdPtrRecordAnswer) {
|
||||
if (serviceTxtRecord != null) {
|
||||
additionalAnswerInfo.add(serviceTxtRecord);
|
||||
}
|
||||
if (serviceSrvRecord != null) {
|
||||
additionalAnswerInfo.add(serviceSrvRecord);
|
||||
}
|
||||
}
|
||||
|
||||
// RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names
|
||||
if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) {
|
||||
for (RecordInfo<?> record : mGeneralRecords) {
|
||||
if (record.record instanceof MdnsInetAddressRecord) {
|
||||
additionalAnswerInfo.add(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (RecordInfo<?> info : additionalAnswerInfo) {
|
||||
additionalAnswerRecords.add(info.record);
|
||||
}
|
||||
|
||||
// RFC6762 6.1: negative responses
|
||||
addNsecRecordsForUniqueNames(additionalAnswerRecords,
|
||||
answerInfo.listIterator(answersStartIndex),
|
||||
additionalAnswerInfo.listIterator());
|
||||
}
|
||||
|
||||
/**
|
||||
* Add NSEC records indicating that the response records are unique.
|
||||
*
|
||||
@@ -540,6 +763,7 @@ public class MdnsRecordRepository {
|
||||
final long now = SystemClock.elapsedRealtime();
|
||||
for (RecordInfo<?> record : registration.allRecords) {
|
||||
record.lastSentTimeMs = now;
|
||||
record.lastAdvertisedTimeMs = now;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,8 +16,15 @@
|
||||
|
||||
package com.android.server.connectivity.mdns;
|
||||
|
||||
import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
|
||||
|
||||
import android.annotation.NonNull;
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.os.Message;
|
||||
import android.util.Log;
|
||||
|
||||
import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.DatagramPacket;
|
||||
@@ -25,6 +32,7 @@ import java.net.Inet4Address;
|
||||
import java.net.Inet6Address;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.MulticastSocket;
|
||||
import java.util.Collections;
|
||||
|
||||
/**
|
||||
* A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
|
||||
@@ -33,20 +41,38 @@ import java.net.MulticastSocket;
|
||||
* TODO: implement sending after a delay, combining queued replies and duplicate answer suppression
|
||||
*/
|
||||
public class MdnsReplySender {
|
||||
private static final boolean DBG = MdnsAdvertiser.DBG;
|
||||
private static final int MSG_SEND = 1;
|
||||
|
||||
private final String mLogTag;
|
||||
@NonNull
|
||||
private final MdnsInterfaceSocket mSocket;
|
||||
@NonNull
|
||||
private final Looper mLooper;
|
||||
private final Handler mHandler;
|
||||
@NonNull
|
||||
private final byte[] mPacketCreationBuffer;
|
||||
|
||||
public MdnsReplySender(@NonNull Looper looper,
|
||||
public MdnsReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
|
||||
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
|
||||
mLooper = looper;
|
||||
mHandler = new SendHandler(looper);
|
||||
mLogTag = MdnsReplySender.class.getSimpleName() + "/" + interfaceTag;
|
||||
mSocket = socket;
|
||||
mPacketCreationBuffer = packetCreationBuffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue a reply to be sent when its send delay expires.
|
||||
*/
|
||||
public void queueReply(@NonNull ReplyInfo reply) {
|
||||
ensureRunningOnHandlerThread(mHandler);
|
||||
// TODO: implement response aggregation (RFC 6762 6.4)
|
||||
mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
|
||||
|
||||
if (DBG) {
|
||||
Log.v(mLogTag, "Scheduling " + reply);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a packet immediately.
|
||||
*
|
||||
@@ -54,9 +80,7 @@ public class MdnsReplySender {
|
||||
*/
|
||||
public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
|
||||
throws IOException {
|
||||
if (Thread.currentThread() != mLooper.getThread()) {
|
||||
throw new IllegalStateException("sendNow must be called in the handler thread");
|
||||
}
|
||||
ensureRunningOnHandlerThread(mHandler);
|
||||
if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
|
||||
|| (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
|
||||
// Skip sending if the socket has not joined the v4/v6 group (there was no address)
|
||||
@@ -93,4 +117,37 @@ public class MdnsReplySender {
|
||||
|
||||
mSocket.send(new DatagramPacket(outBuffer, 0, len, destination));
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel all pending sends.
|
||||
*/
|
||||
public void cancelAll() {
|
||||
ensureRunningOnHandlerThread(mHandler);
|
||||
mHandler.removeMessages(MSG_SEND);
|
||||
}
|
||||
|
||||
private class SendHandler extends Handler {
|
||||
SendHandler(@NonNull Looper looper) {
|
||||
super(looper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessage(@NonNull Message msg) {
|
||||
final ReplyInfo replyInfo = (ReplyInfo) msg.obj;
|
||||
if (DBG) Log.v(mLogTag, "Sending " + replyInfo);
|
||||
|
||||
final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
|
||||
final MdnsPacket packet = new MdnsPacket(flags,
|
||||
Collections.emptyList() /* questions */,
|
||||
replyInfo.answers,
|
||||
Collections.emptyList() /* authorityRecords */,
|
||||
replyInfo.additionalAnswers);
|
||||
|
||||
try {
|
||||
sendNow(packet, replyInfo.destination);
|
||||
} catch (IOException e) {
|
||||
Log.e(mLogTag, "Error sending MDNS response", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,5 +107,14 @@ public class MulticastPacketReader extends FdEventsReader<MulticastPacketReader.
|
||||
ensureRunningOnHandlerThread(mHandler);
|
||||
mPacketHandlers.add(handler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a packet handler added via {@link #addPacketHandler}. If the handler was not set,
|
||||
* this is a no-op.
|
||||
*/
|
||||
public void removePacketHandler(@NonNull PacketHandler handler) {
|
||||
ensureRunningOnHandlerThread(mHandler);
|
||||
mPacketHandlers.remove(handler);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers.eq
|
||||
import org.mockito.Mockito.any
|
||||
import org.mockito.Mockito.anyInt
|
||||
import org.mockito.Mockito.argThat
|
||||
import org.mockito.Mockito.atLeastOnce
|
||||
import org.mockito.Mockito.doReturn
|
||||
import org.mockito.Mockito.mock
|
||||
import org.mockito.Mockito.never
|
||||
@@ -161,6 +162,60 @@ class MdnsAdvertiserTest {
|
||||
verify(socketProvider).unrequestSocket(socketCb)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAddService_Conflicts() {
|
||||
val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
|
||||
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
|
||||
|
||||
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
|
||||
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
|
||||
val oneNetSocketCb = oneNetSocketCbCaptor.value
|
||||
|
||||
// Register a service with the same name on all networks (name conflict)
|
||||
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
|
||||
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
|
||||
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
|
||||
val allNetSocketCb = allNetSocketCbCaptor.value
|
||||
|
||||
// Callbacks for matching network and all networks both get the socket
|
||||
postSync {
|
||||
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
|
||||
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
|
||||
}
|
||||
|
||||
val expectedRenamed = NsdServiceInfo(
|
||||
"${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
|
||||
port = ALL_NETWORKS_SERVICE.port
|
||||
host = ALL_NETWORKS_SERVICE.host
|
||||
network = ALL_NETWORKS_SERVICE.network
|
||||
}
|
||||
|
||||
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
|
||||
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
|
||||
eq(thread.looper), any(), intAdvCbCaptor.capture())
|
||||
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
|
||||
argThat { it.matches(SERVICE_1) })
|
||||
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
|
||||
argThat { it.matches(expectedRenamed) })
|
||||
|
||||
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
|
||||
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
|
||||
mockInterfaceAdvertiser1, SERVICE_ID_1) }
|
||||
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
|
||||
|
||||
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
|
||||
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
|
||||
mockInterfaceAdvertiser1, SERVICE_ID_2) }
|
||||
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
|
||||
argThat { it.matches(expectedRenamed) })
|
||||
|
||||
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
|
||||
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
|
||||
|
||||
// destroyNow can be called multiple times
|
||||
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
|
||||
}
|
||||
|
||||
private fun postSync(r: () -> Unit) {
|
||||
handler.post(r)
|
||||
handler.waitForIdle(TIMEOUT_MS)
|
||||
|
||||
@@ -79,7 +79,7 @@ class MdnsAnnouncerTest {
|
||||
|
||||
@Test
|
||||
fun testAnnounce() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
|
||||
as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
|
||||
|
||||
@@ -21,6 +21,7 @@ import android.net.LinkAddress
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.Build
|
||||
import android.os.HandlerThread
|
||||
import com.android.net.module.util.HexDump
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
|
||||
@@ -30,6 +31,10 @@ import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
|
||||
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
|
||||
import com.android.testutils.DevSdkIgnoreRunner
|
||||
import com.android.testutils.waitForIdle
|
||||
import java.net.InetSocketAddress
|
||||
import kotlin.test.assertContentEquals
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
@@ -37,8 +42,10 @@ import org.junit.runner.RunWith
|
||||
import org.mockito.ArgumentCaptor
|
||||
import org.mockito.Mockito.any
|
||||
import org.mockito.Mockito.anyInt
|
||||
import org.mockito.Mockito.anyString
|
||||
import org.mockito.Mockito.doAnswer
|
||||
import org.mockito.Mockito.doReturn
|
||||
import org.mockito.Mockito.eq
|
||||
import org.mockito.Mockito.mock
|
||||
import org.mockito.Mockito.times
|
||||
import org.mockito.Mockito.verify
|
||||
@@ -67,13 +74,18 @@ class MdnsInterfaceAdvertiserTest {
|
||||
private val replySender = mock(MdnsReplySender::class.java)
|
||||
private val announcer = mock(MdnsAnnouncer::class.java)
|
||||
private val prober = mock(MdnsProber::class.java)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
|
||||
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
|
||||
as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
|
||||
private val packetHandlerCaptor = ArgumentCaptor.forClass(
|
||||
MulticastPacketReader.PacketHandler::class.java)
|
||||
|
||||
private val probeCb get() = probeCbCaptor.value
|
||||
private val announceCb get() = announceCbCaptor.value
|
||||
private val packetHandler get() = packetHandlerCaptor.value
|
||||
|
||||
private val advertiser by lazy {
|
||||
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
|
||||
@@ -82,9 +94,9 @@ class MdnsInterfaceAdvertiserTest {
|
||||
@Before
|
||||
fun setUp() {
|
||||
doReturn(repository).`when`(deps).makeRecordRepository(any())
|
||||
doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
|
||||
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
|
||||
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
|
||||
doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
|
||||
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
|
||||
doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
|
||||
|
||||
val knownServices = mutableSetOf<Int>()
|
||||
doAnswer { inv ->
|
||||
@@ -104,6 +116,7 @@ class MdnsInterfaceAdvertiserTest {
|
||||
thread.start()
|
||||
advertiser.start()
|
||||
|
||||
verify(socket).addPacketHandler(packetHandlerCaptor.capture())
|
||||
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
|
||||
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
|
||||
}
|
||||
@@ -157,6 +170,39 @@ class MdnsInterfaceAdvertiserTest {
|
||||
verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testReplyToQuery() {
|
||||
addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
|
||||
val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java)
|
||||
doReturn(mockReply).`when`(repository).getReply(any(), any())
|
||||
|
||||
// Query obtained with:
|
||||
// scapy.raw(scapy.DNS(
|
||||
// qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
|
||||
// ).hex().upper()
|
||||
val query = HexDump.hexStringToByteArray(
|
||||
"0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
|
||||
)
|
||||
val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
|
||||
packetHandler.handlePacket(query, query.size, src)
|
||||
|
||||
val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
|
||||
verify(repository).getReply(packetCaptor.capture(), eq(src))
|
||||
|
||||
packetCaptor.value.let {
|
||||
assertEquals(1, it.questions.size)
|
||||
assertEquals(0, it.answers.size)
|
||||
assertEquals(0, it.authorityRecords.size)
|
||||
assertEquals(0, it.additionalRecords.size)
|
||||
|
||||
assertTrue(it.questions[0] is MdnsPointerRecord)
|
||||
assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
|
||||
}
|
||||
|
||||
verify(replySender).queueReply(mockReply)
|
||||
}
|
||||
|
||||
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
|
||||
AnnouncementInfo {
|
||||
val testProbingInfo = mock(ProbingInfo::class.java)
|
||||
|
||||
@@ -114,7 +114,7 @@ class MdnsProberTest {
|
||||
|
||||
@Test
|
||||
fun testProbe() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
val prober = TestProber(thread.looper, replySender, cb)
|
||||
val probeInfo = TestProbeInfo(
|
||||
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
|
||||
@@ -129,7 +129,7 @@ class MdnsProberTest {
|
||||
|
||||
@Test
|
||||
fun testProbeMultipleRecords() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
val prober = TestProber(thread.looper, replySender, cb)
|
||||
val probeInfo = TestProbeInfo(listOf(
|
||||
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
|
||||
@@ -167,7 +167,7 @@ class MdnsProberTest {
|
||||
|
||||
@Test
|
||||
fun testStopProbing() {
|
||||
val replySender = MdnsReplySender(thread.looper, socket, buffer)
|
||||
val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
|
||||
val prober = TestProber(thread.looper, replySender, cb)
|
||||
val probeInfo = TestProbeInfo(
|
||||
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
|
||||
|
||||
@@ -21,10 +21,12 @@ import android.net.LinkAddress
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.Build
|
||||
import android.os.HandlerThread
|
||||
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
|
||||
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
|
||||
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
|
||||
import com.android.testutils.DevSdkIgnoreRule
|
||||
import com.android.testutils.DevSdkIgnoreRunner
|
||||
import java.net.InetSocketAddress
|
||||
import java.net.NetworkInterface
|
||||
import java.util.Collections
|
||||
import kotlin.test.assertContentEquals
|
||||
@@ -150,11 +152,7 @@ class MdnsRecordRepositoryTest {
|
||||
@Test
|
||||
fun testExitAnnouncements() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.updateAddresses(TEST_ADDRESSES)
|
||||
|
||||
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
|
||||
repository.onProbingSucceeded(probingInfo)
|
||||
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
|
||||
|
||||
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
|
||||
@@ -183,9 +181,7 @@ class MdnsRecordRepositoryTest {
|
||||
@Test
|
||||
fun testExitingServiceReAdded() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
|
||||
repository.onProbingSucceeded(probingInfo)
|
||||
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
|
||||
repository.exitService(TEST_SERVICE_ID_1)
|
||||
|
||||
@@ -199,11 +195,8 @@ class MdnsRecordRepositoryTest {
|
||||
@Test
|
||||
fun testOnProbingSucceeded() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.updateAddresses(TEST_ADDRESSES)
|
||||
|
||||
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
|
||||
val announcementInfo = repository.onProbingSucceeded(probingInfo)
|
||||
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
|
||||
val packet = announcementInfo.getPacket(0)
|
||||
|
||||
assertEquals(0x8400 /* response, authoritative */, packet.flags)
|
||||
@@ -322,4 +315,98 @@ class MdnsRecordRepositoryTest {
|
||||
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
|
||||
assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGetReply() {
|
||||
val repository = MdnsRecordRepository(thread.looper, deps)
|
||||
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
|
||||
val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
|
||||
0L /* receiptTimeMillis */,
|
||||
false /* cacheFlush */,
|
||||
// TTL and data is empty for a question
|
||||
0L /* ttlMillis */,
|
||||
null /* pointer */))
|
||||
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
|
||||
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
|
||||
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
|
||||
val reply = repository.getReply(query, src)
|
||||
|
||||
assertNotNull(reply)
|
||||
// Source address is IPv4
|
||||
assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
|
||||
assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
|
||||
|
||||
// TTLs as per RFC6762 10.
|
||||
val longTtl = 4_500_000L
|
||||
val shortTtl = 120_000L
|
||||
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
|
||||
|
||||
assertEquals(listOf(
|
||||
MdnsPointerRecord(
|
||||
arrayOf("_testservice", "_tcp", "local"),
|
||||
0L /* receiptTimeMillis */,
|
||||
false /* cacheFlush */,
|
||||
longTtl,
|
||||
serviceName),
|
||||
), reply.answers)
|
||||
|
||||
assertEquals(listOf(
|
||||
MdnsTextRecord(
|
||||
serviceName,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
longTtl,
|
||||
listOf() /* entries */),
|
||||
MdnsServiceRecord(
|
||||
serviceName,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
0 /* servicePriority */,
|
||||
0 /* serviceWeight */,
|
||||
TEST_PORT,
|
||||
TEST_HOSTNAME),
|
||||
MdnsInetAddressRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_ADDRESSES[0].address),
|
||||
MdnsInetAddressRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_ADDRESSES[1].address),
|
||||
MdnsInetAddressRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_ADDRESSES[2].address),
|
||||
MdnsNsecRecord(
|
||||
serviceName,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
longTtl,
|
||||
serviceName /* nextDomain */,
|
||||
intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
|
||||
MdnsNsecRecord(
|
||||
TEST_HOSTNAME,
|
||||
0L /* receiptTimeMillis */,
|
||||
true /* cacheFlush */,
|
||||
shortTtl,
|
||||
TEST_HOSTNAME /* nextDomain */,
|
||||
intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
|
||||
), reply.additionalAnswers)
|
||||
}
|
||||
}
|
||||
|
||||
private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
|
||||
AnnouncementInfo {
|
||||
updateAddresses(TEST_ADDRESSES)
|
||||
addService(serviceId, serviceInfo)
|
||||
val probingInfo = setServiceProbing(serviceId)
|
||||
assertNotNull(probingInfo)
|
||||
return onProbingSucceeded(probingInfo)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user