Add MdnsProber

MdnsProber is an implementation of MdnsPacketRepeater that will be used
to send probes for service names before advertising them, to know if
they are already in use.

Bug: 241738458
Test: atest
Change-Id: I4e5f779b891e2c665ba7f752fb5fbd4255070725
This commit is contained in:
Remi NGUYEN VAN
2022-11-16 18:00:07 +09:00
parent edbf34a182
commit 3568fddb36
5 changed files with 372 additions and 4 deletions

View File

@@ -0,0 +1,156 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
import android.annotation.NonNull;
import android.os.Looper;
import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.CollectionUtils;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
/**
* Sends mDns probe requests to verify service records are unique on the network.
*
* TODO: implement receiving replies and handling conflicts.
*/
public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
@NonNull
private final String mLogTag;
public MdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
@NonNull PacketRepeaterCallback<ProbingInfo> cb) {
// 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
super(looper, replySender, cb);
mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
}
static class ProbingInfo implements Request {
private final int mServiceId;
@NonNull
private final MdnsPacket mPacket;
@NonNull
private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
/**
* Create a new ProbingInfo
* @param serviceId Service to probe for.
* @param probeRecords Records to be probed for uniqueness.
* @param destinationsSupplier Supplier for the probe destinations. Will be called on the
* probe handler thread for each probe.
*/
ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords,
@NonNull Supplier<Iterable<SocketAddress>> destinationsSupplier) {
mServiceId = serviceId;
mPacket = makePacket(probeRecords);
mDestinationsSupplier = destinationsSupplier;
}
public int getServiceId() {
return mServiceId;
}
@NonNull
@Override
public MdnsPacket getPacket(int index) {
return mPacket;
}
@NonNull
@Override
public Iterable<SocketAddress> getDestinations(int index) {
return mDestinationsSupplier.get();
}
@Override
public long getDelayMs(int nextIndex) {
// As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
return 250L;
}
@Override
public int getNumSends() {
// 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
return 3;
}
private static MdnsPacket makePacket(@NonNull List<MdnsRecord> records) {
final ArrayList<MdnsRecord> questions = new ArrayList<>(records.size());
for (final MdnsRecord record : records) {
if (containsName(questions, record.getName())) {
// Already added this name
continue;
}
// TODO: legacy Android mDNS used to send the first probe (only) as unicast, even
// though https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 says they
// SHOULD all be. rfc6762 15.1 says that if the port is shared with another
// responder unicast questions should not be used, and the legacy mdnsresponder may
// be running, so not using unicast at all may be better. Consider using legacy
// behavior if this causes problems.
questions.add(new MdnsAnyRecord(record.getName(), false /* unicast */));
}
return new MdnsPacket(
MdnsConstants.FLAGS_QUERY,
questions,
Collections.emptyList() /* answers */,
records /* authorityRecords */,
Collections.emptyList() /* additionalRecords */);
}
/**
* Return whether the specified name is present in the list of records.
*/
private static boolean containsName(@NonNull List<MdnsRecord> records,
@NonNull String[] name) {
return CollectionUtils.any(records, r -> Arrays.equals(name, r.getName()));
}
}
@NonNull
@Override
protected String getTag() {
return mLogTag;
}
@VisibleForTesting
protected long getInitialDelay() {
// First wait for a random time in 0-250ms
// as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
return (long) (Math.random() * 250);
}
/**
* Start sending packets for probing.
*/
public void startProbing(@NonNull ProbingInfo info) {
startProbing(info, getInitialDelay());
}
private void startProbing(@NonNull ProbingInfo info, long delay) {
startSending(info.getServiceId(), info, delay);
}
}

View File

@@ -200,6 +200,17 @@ public abstract class MdnsRecord {
*/ */
protected abstract void readData(MdnsPacketReader reader) throws IOException; protected abstract void readData(MdnsPacketReader reader) throws IOException;
/**
* Write the first fields of the record, which are common fields for questions and answers.
*
* @param writer The writer to use.
*/
public final void writeHeaderFields(MdnsPacketWriter writer) throws IOException {
writer.writeLabels(name);
writer.writeUInt16(type);
writer.writeUInt16(cls);
}
/** /**
* Writes the record to a packet. * Writes the record to a packet.
* *
@@ -208,9 +219,7 @@ public abstract class MdnsRecord {
*/ */
@VisibleForTesting @VisibleForTesting
public final void write(MdnsPacketWriter writer, long now) throws IOException { public final void write(MdnsPacketWriter writer, long now) throws IOException {
writer.writeLabels(name); writeHeaderFields(writer);
writer.writeUInt16(type);
writer.writeUInt16(cls);
writer.writeUInt32(MILLISECONDS.toSeconds(getRemainingTTL(now))); writer.writeUInt32(MILLISECONDS.toSeconds(getRemainingTTL(now)));

View File

@@ -67,7 +67,8 @@ public class MdnsReplySender {
writer.writeUInt16(packet.additionalRecords.size()); // additional records count writer.writeUInt16(packet.additionalRecords.size()); // additional records count
for (MdnsRecord record : packet.questions) { for (MdnsRecord record : packet.questions) {
record.write(writer, 0L); // Questions do not have TTL or data
record.writeHeaderFields(writer);
} }
for (MdnsRecord record : packet.answers) { for (MdnsRecord record : packet.answers) {
record.write(writer, 0L); record.write(writer, 0L);

View File

@@ -74,6 +74,7 @@ filegroup {
"java/com/android/server/connectivity/VpnTest.java", "java/com/android/server/connectivity/VpnTest.java",
"java/com/android/server/net/ipmemorystore/*.java", "java/com/android/server/net/ipmemorystore/*.java",
"java/com/android/server/connectivity/mdns/**/*.java", "java/com/android/server/connectivity/mdns/**/*.java",
"java/com/android/server/connectivity/mdns/**/*.kt",
] ]
} }

View File

@@ -0,0 +1,201 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import android.os.Looper
import com.android.internal.util.HexDump
import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import java.net.DatagramPacket
import java.net.InetSocketAddress
import java.net.MulticastSocket
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.atLeast
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
import org.mockito.Mockito.timeout
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
private val destinationsSupplier = {
listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
private const val TEST_TIMEOUT_MS = 10_000L
private const val SHORT_TIMEOUT_MS = 200L
private val TEST_SERVICE_NAME_1 = arrayOf("testservice", "_nmt", "_tcp", "local")
private val TEST_SERVICE_NAME_2 = arrayOf("testservice2", "_nmt", "_tcp", "local")
@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsProberTest {
private val thread = HandlerThread(MdnsProberTest::class.simpleName)
private val socket = mock(MulticastSocket::class.java)
@Suppress("UNCHECKED_CAST")
private val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
as MdnsPacketRepeater.PacketRepeaterCallback<ProbingInfo>
private val buffer = ByteArray(1500)
@Before
fun setUp() {
thread.start()
}
@After
fun tearDown() {
thread.quitSafely()
}
private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) :
ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) {
// Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already
// done in MdnsAnnouncerTest.
override fun getDelayMs(nextIndex: Int) = delayMs
}
private class TestProber(
looper: Looper,
replySender: MdnsReplySender,
cb: PacketRepeaterCallback<ProbingInfo>
) : MdnsProber("testiface", looper, replySender, cb) {
override fun getInitialDelay() = 0L
}
private fun assertProbesSent(probeInfo: TestProbeInfo, expectedHex: String) {
repeat(probeInfo.numSends) { i ->
verify(cb, timeout(TEST_TIMEOUT_MS)).onSent(i, probeInfo)
// If the probe interval is short, more than (i+1) probes may have been sent already
verify(socket, atLeast(i + 1)).send(any())
}
val captor = ArgumentCaptor.forClass(DatagramPacket::class.java)
// There should be exactly numSends probes sent at the end
verify(socket, times(probeInfo.numSends)).send(captor.capture())
captor.allValues.forEach {
assertEquals(expectedHex, HexDump.toHexString(it.data))
}
verify(cb, timeout(TEST_TIMEOUT_MS)).onFinished(probeInfo)
}
private fun makeServiceRecord(name: Array<String>, port: Int) = MdnsServiceRecord(
name,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
port,
arrayOf("myhostname", "local"))
@Test
fun testProbe() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
prober.startProbing(probeInfo)
// Inspect with python3:
// import scapy.all as scapy; scapy.DNS(bytes.fromhex('[bytes]')).show2()
val expected = "0000000000010000000100000B7465737473657276696365045F6E6D74045F746370056C" +
"6F63616C0000FF0001C00C002100010000007800130000000094020A6D79686F73746E616D65C022"
assertProbesSent(probeInfo, expected)
}
@Test
fun testProbeMultipleRecords() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(listOf(
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
makeServiceRecord(TEST_SERVICE_NAME_2, 37891),
MdnsTextRecord(
// Same name as the first record; there should not be 2 duplicated questions
TEST_SERVICE_NAME_1,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
listOf(MdnsServiceInfo.TextEntry("testKey", "testValue")))))
prober.startProbing(probeInfo)
/*
Expected data obtained with:
scapy.raw(scapy.dns_compress(scapy.DNS(rd=0,
qd =
scapy.DNSQR(qname='testservice._nmt._tcp.local.', qtype='ALL') /
scapy.DNSQR(qname='testservice2._nmt._tcp.local.', qtype='ALL'),
ns=
scapy.DNSRRSRV(rrname='testservice._nmt._tcp.local.', type='SRV', ttl=120,
port=37890, target='myhostname.local.') /
scapy.DNSRRSRV(rrname='testservice2._nmt._tcp.local.', type='SRV', ttl=120,
port=37891, target='myhostname.local.') /
scapy.DNSRR(type='TXT', ttl=120, rrname='testservice._nmt._tcp.local.',
rdata='testKey=testValue'))
)).hex().upper()
// NOTE: due to a bug the second "myhostname" is not getting DNS compressed in the current
// actual probe, so data below is slightly different. Fix compression so it gets compressed.
*/
val expected = "0000000000020000000300000B7465737473657276696365045F6E6D74045F746370056C6" +
"F63616C0000FF00010C746573747365727669636532C01800FF0001C00C002100010000007800130" +
"000000094020A6D79686F73746E616D65C0220C746573747365727669636532C0180021000100000" +
"07800130000000094030A6D79686F73746E616D65C022C00C0010000100000078001211746573744" +
"B65793D7465737456616C7565"
assertProbesSent(probeInfo, expected)
}
@Test
fun testStopProbing() {
val replySender = MdnsReplySender(thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
// delayMs is the delay between each probe, so does not apply to the first one
delayMs = SHORT_TIMEOUT_MS)
prober.startProbing(probeInfo)
// Expect the initial probe
verify(cb, timeout(TEST_TIMEOUT_MS)).onSent(0, probeInfo)
// Stop probing
val stopResult = CompletableFuture<Boolean>()
Handler(thread.looper).post { stopResult.complete(prober.stop(probeInfo.serviceId)) }
assertTrue(stopResult.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS),
"stop should return true when probing was in progress")
// Wait for a bit (more than the probe delay) to ensure no more probes were sent
Thread.sleep(SHORT_TIMEOUT_MS * 2)
verify(cb, never()).onSent(1, probeInfo)
verify(cb, never()).onFinished(probeInfo)
// Only one sent packet
verify(socket, times(1)).send(any())
}
}