diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java new file mode 100644 index 0000000000..db7049e768 --- /dev/null +++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns; + +import android.annotation.NonNull; +import android.os.Looper; + +import com.android.internal.annotations.VisibleForTesting; +import com.android.net.module.util.CollectionUtils; + +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; + +/** + * Sends mDns probe requests to verify service records are unique on the network. + * + * TODO: implement receiving replies and handling conflicts. + */ +public class MdnsProber extends MdnsPacketRepeater { + @NonNull + private final String mLogTag; + + public MdnsProber(@NonNull String interfaceTag, @NonNull Looper looper, + @NonNull MdnsReplySender replySender, + @NonNull PacketRepeaterCallback cb) { + // 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + super(looper, replySender, cb); + mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag; + } + + static class ProbingInfo implements Request { + + private final int mServiceId; + @NonNull + private final MdnsPacket mPacket; + @NonNull + private final Supplier> mDestinationsSupplier; + + /** + * Create a new ProbingInfo + * @param serviceId Service to probe for. + * @param probeRecords Records to be probed for uniqueness. + * @param destinationsSupplier Supplier for the probe destinations. Will be called on the + * probe handler thread for each probe. + */ + ProbingInfo(int serviceId, @NonNull List probeRecords, + @NonNull Supplier> destinationsSupplier) { + mServiceId = serviceId; + mPacket = makePacket(probeRecords); + mDestinationsSupplier = destinationsSupplier; + } + + public int getServiceId() { + return mServiceId; + } + + @NonNull + @Override + public MdnsPacket getPacket(int index) { + return mPacket; + } + + @NonNull + @Override + public Iterable getDestinations(int index) { + return mDestinationsSupplier.get(); + } + + @Override + public long getDelayMs(int nextIndex) { + // As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + return 250L; + } + + @Override + public int getNumSends() { + // 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + return 3; + } + + private static MdnsPacket makePacket(@NonNull List records) { + final ArrayList questions = new ArrayList<>(records.size()); + for (final MdnsRecord record : records) { + if (containsName(questions, record.getName())) { + // Already added this name + continue; + } + + // TODO: legacy Android mDNS used to send the first probe (only) as unicast, even + // though https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 says they + // SHOULD all be. rfc6762 15.1 says that if the port is shared with another + // responder unicast questions should not be used, and the legacy mdnsresponder may + // be running, so not using unicast at all may be better. Consider using legacy + // behavior if this causes problems. + questions.add(new MdnsAnyRecord(record.getName(), false /* unicast */)); + } + + return new MdnsPacket( + MdnsConstants.FLAGS_QUERY, + questions, + Collections.emptyList() /* answers */, + records /* authorityRecords */, + Collections.emptyList() /* additionalRecords */); + } + + /** + * Return whether the specified name is present in the list of records. + */ + private static boolean containsName(@NonNull List records, + @NonNull String[] name) { + return CollectionUtils.any(records, r -> Arrays.equals(name, r.getName())); + } + } + + @NonNull + @Override + protected String getTag() { + return mLogTag; + } + + @VisibleForTesting + protected long getInitialDelay() { + // First wait for a random time in 0-250ms + // as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + return (long) (Math.random() * 250); + } + + /** + * Start sending packets for probing. + */ + public void startProbing(@NonNull ProbingInfo info) { + startProbing(info, getInitialDelay()); + } + + private void startProbing(@NonNull ProbingInfo info, long delay) { + startSending(info.getServiceId(), info, delay); + } +} diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java b/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java index 10b882539c..00871ea2d4 100644 --- a/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java +++ b/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java @@ -200,6 +200,17 @@ public abstract class MdnsRecord { */ protected abstract void readData(MdnsPacketReader reader) throws IOException; + /** + * Write the first fields of the record, which are common fields for questions and answers. + * + * @param writer The writer to use. + */ + public final void writeHeaderFields(MdnsPacketWriter writer) throws IOException { + writer.writeLabels(name); + writer.writeUInt16(type); + writer.writeUInt16(cls); + } + /** * Writes the record to a packet. * @@ -208,9 +219,7 @@ public abstract class MdnsRecord { */ @VisibleForTesting public final void write(MdnsPacketWriter writer, long now) throws IOException { - writer.writeLabels(name); - writer.writeUInt16(type); - writer.writeUInt16(cls); + writeHeaderFields(writer); writer.writeUInt32(MILLISECONDS.toSeconds(getRemainingTTL(now))); diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java index 2acd7898cb..1fdbc5cff0 100644 --- a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java +++ b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java @@ -67,7 +67,8 @@ public class MdnsReplySender { writer.writeUInt16(packet.additionalRecords.size()); // additional records count for (MdnsRecord record : packet.questions) { - record.write(writer, 0L); + // Questions do not have TTL or data + record.writeHeaderFields(writer); } for (MdnsRecord record : packet.answers) { record.write(writer, 0L); diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp index 8ed735ac9c..209430a232 100644 --- a/tests/unit/Android.bp +++ b/tests/unit/Android.bp @@ -74,6 +74,7 @@ filegroup { "java/com/android/server/connectivity/VpnTest.java", "java/com/android/server/net/ipmemorystore/*.java", "java/com/android/server/connectivity/mdns/**/*.java", + "java/com/android/server/connectivity/mdns/**/*.kt", ] } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt new file mode 100644 index 0000000000..cc7519193c --- /dev/null +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns + +import android.os.Build +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import com.android.internal.util.HexDump +import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo +import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo +import com.android.testutils.DevSdkIgnoreRunner +import java.net.DatagramPacket +import java.net.InetSocketAddress +import java.net.MulticastSocket +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.any +import org.mockito.Mockito.atLeast +import org.mockito.Mockito.mock +import org.mockito.Mockito.never +import org.mockito.Mockito.timeout +import org.mockito.Mockito.times +import org.mockito.Mockito.verify + +private val destinationsSupplier = { + listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) } + +private const val TEST_TIMEOUT_MS = 10_000L +private const val SHORT_TIMEOUT_MS = 200L + +private val TEST_SERVICE_NAME_1 = arrayOf("testservice", "_nmt", "_tcp", "local") +private val TEST_SERVICE_NAME_2 = arrayOf("testservice2", "_nmt", "_tcp", "local") + +@RunWith(DevSdkIgnoreRunner::class) +@IgnoreUpTo(Build.VERSION_CODES.S_V2) +class MdnsProberTest { + private val thread = HandlerThread(MdnsProberTest::class.simpleName) + private val socket = mock(MulticastSocket::class.java) + @Suppress("UNCHECKED_CAST") + private val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java) + as MdnsPacketRepeater.PacketRepeaterCallback + private val buffer = ByteArray(1500) + + @Before + fun setUp() { + thread.start() + } + + @After + fun tearDown() { + thread.quitSafely() + } + + private class TestProbeInfo(probeRecords: List, private val delayMs: Long = 1L) : + ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) { + // Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already + // done in MdnsAnnouncerTest. + override fun getDelayMs(nextIndex: Int) = delayMs + } + + private class TestProber( + looper: Looper, + replySender: MdnsReplySender, + cb: PacketRepeaterCallback + ) : MdnsProber("testiface", looper, replySender, cb) { + override fun getInitialDelay() = 0L + } + + private fun assertProbesSent(probeInfo: TestProbeInfo, expectedHex: String) { + repeat(probeInfo.numSends) { i -> + verify(cb, timeout(TEST_TIMEOUT_MS)).onSent(i, probeInfo) + // If the probe interval is short, more than (i+1) probes may have been sent already + verify(socket, atLeast(i + 1)).send(any()) + } + + val captor = ArgumentCaptor.forClass(DatagramPacket::class.java) + // There should be exactly numSends probes sent at the end + verify(socket, times(probeInfo.numSends)).send(captor.capture()) + + captor.allValues.forEach { + assertEquals(expectedHex, HexDump.toHexString(it.data)) + } + verify(cb, timeout(TEST_TIMEOUT_MS)).onFinished(probeInfo) + } + + private fun makeServiceRecord(name: Array, port: Int) = MdnsServiceRecord( + name, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 120_000L /* ttlMillis */, + 0 /* servicePriority */, + 0 /* serviceWeight */, + port, + arrayOf("myhostname", "local")) + + @Test + fun testProbe() { + val replySender = MdnsReplySender(thread.looper, socket, buffer) + val prober = TestProber(thread.looper, replySender, cb) + val probeInfo = TestProbeInfo( + listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890))) + prober.startProbing(probeInfo) + + // Inspect with python3: + // import scapy.all as scapy; scapy.DNS(bytes.fromhex('[bytes]')).show2() + val expected = "0000000000010000000100000B7465737473657276696365045F6E6D74045F746370056C" + + "6F63616C0000FF0001C00C002100010000007800130000000094020A6D79686F73746E616D65C022" + assertProbesSent(probeInfo, expected) + } + + @Test + fun testProbeMultipleRecords() { + val replySender = MdnsReplySender(thread.looper, socket, buffer) + val prober = TestProber(thread.looper, replySender, cb) + val probeInfo = TestProbeInfo(listOf( + makeServiceRecord(TEST_SERVICE_NAME_1, 37890), + makeServiceRecord(TEST_SERVICE_NAME_2, 37891), + MdnsTextRecord( + // Same name as the first record; there should not be 2 duplicated questions + TEST_SERVICE_NAME_1, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 120_000L /* ttlMillis */, + listOf(MdnsServiceInfo.TextEntry("testKey", "testValue"))))) + prober.startProbing(probeInfo) + + /* + Expected data obtained with: + scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, + qd = + scapy.DNSQR(qname='testservice._nmt._tcp.local.', qtype='ALL') / + scapy.DNSQR(qname='testservice2._nmt._tcp.local.', qtype='ALL'), + ns= + scapy.DNSRRSRV(rrname='testservice._nmt._tcp.local.', type='SRV', ttl=120, + port=37890, target='myhostname.local.') / + scapy.DNSRRSRV(rrname='testservice2._nmt._tcp.local.', type='SRV', ttl=120, + port=37891, target='myhostname.local.') / + scapy.DNSRR(type='TXT', ttl=120, rrname='testservice._nmt._tcp.local.', + rdata='testKey=testValue')) + )).hex().upper() + // NOTE: due to a bug the second "myhostname" is not getting DNS compressed in the current + // actual probe, so data below is slightly different. Fix compression so it gets compressed. + */ + val expected = "0000000000020000000300000B7465737473657276696365045F6E6D74045F746370056C6" + + "F63616C0000FF00010C746573747365727669636532C01800FF0001C00C002100010000007800130" + + "000000094020A6D79686F73746E616D65C0220C746573747365727669636532C0180021000100000" + + "07800130000000094030A6D79686F73746E616D65C022C00C0010000100000078001211746573744" + + "B65793D7465737456616C7565" + assertProbesSent(probeInfo, expected) + } + + @Test + fun testStopProbing() { + val replySender = MdnsReplySender(thread.looper, socket, buffer) + val prober = TestProber(thread.looper, replySender, cb) + val probeInfo = TestProbeInfo( + listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)), + // delayMs is the delay between each probe, so does not apply to the first one + delayMs = SHORT_TIMEOUT_MS) + prober.startProbing(probeInfo) + + // Expect the initial probe + verify(cb, timeout(TEST_TIMEOUT_MS)).onSent(0, probeInfo) + + // Stop probing + val stopResult = CompletableFuture() + Handler(thread.looper).post { stopResult.complete(prober.stop(probeInfo.serviceId)) } + assertTrue(stopResult.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS), + "stop should return true when probing was in progress") + + // Wait for a bit (more than the probe delay) to ensure no more probes were sent + Thread.sleep(SHORT_TIMEOUT_MS * 2) + verify(cb, never()).onSent(1, probeInfo) + verify(cb, never()).onFinished(probeInfo) + + // Only one sent packet + verify(socket, times(1)).send(any()) + } +}