diff --git a/service-t/src/com/android/server/mdns/MdnsPacket.java b/service-t/src/com/android/server/mdns/MdnsPacket.java index eae084aca7..27002b9cc6 100644 --- a/service-t/src/com/android/server/mdns/MdnsPacket.java +++ b/service-t/src/com/android/server/mdns/MdnsPacket.java @@ -16,6 +16,13 @@ package com.android.server.connectivity.mdns; +import android.annotation.NonNull; +import android.annotation.Nullable; +import android.util.Log; + +import java.io.EOFException; +import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -23,21 +30,202 @@ import java.util.List; * A class holding data that can be included in a mDNS packet. */ public class MdnsPacket { + private static final String TAG = MdnsPacket.class.getSimpleName(); + public final int flags; + @NonNull public final List questions; + @NonNull public final List answers; + @NonNull public final List authorityRecords; + @NonNull public final List additionalRecords; MdnsPacket(int flags, - List questions, - List answers, - List authorityRecords, - List additionalRecords) { + @NonNull List questions, + @NonNull List answers, + @NonNull List authorityRecords, + @NonNull List additionalRecords) { this.flags = flags; this.questions = Collections.unmodifiableList(questions); this.answers = Collections.unmodifiableList(answers); this.authorityRecords = Collections.unmodifiableList(authorityRecords); this.additionalRecords = Collections.unmodifiableList(additionalRecords); } + + /** + * Exception thrown on parse errors. + */ + public static class ParseException extends IOException { + public final int code; + + public ParseException(int code, @NonNull String message, @Nullable Throwable cause) { + super(message, cause); + this.code = code; + } + } + + /** + * Parse the packet in the provided {@link MdnsPacketReader}. + */ + @NonNull + public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException { + final int flags; + try { + reader.readUInt16(); // transaction ID (not used) + flags = reader.readUInt16(); + } catch (EOFException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE, + "Reached the end of the mDNS response unexpectedly.", e); + } + return parseRecordsSection(reader, flags); + } + + /** + * Parse the records section of a mDNS packet in the provided {@link MdnsPacketReader}. + * + * The records section starts with the questions count, just after the packet flags. + */ + public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags) + throws ParseException { + try { + final int numQuestions = reader.readUInt16(); + final int numAnswers = reader.readUInt16(); + final int numAuthority = reader.readUInt16(); + final int numAdditional = reader.readUInt16(); + + final ArrayList questions = parseRecords(reader, numQuestions, true); + final ArrayList answers = parseRecords(reader, numAnswers, false); + final ArrayList authority = parseRecords(reader, numAuthority, false); + final ArrayList additional = parseRecords(reader, numAdditional, false); + + return new MdnsPacket(flags, questions, answers, authority, additional); + } catch (EOFException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE, + "Reached the end of the mDNS response unexpectedly.", e); + } + } + + private static ArrayList parseRecords(@NonNull MdnsPacketReader reader, int count, + boolean isQuestion) + throws ParseException { + final ArrayList records = new ArrayList<>(count); + for (int i = 0; i < count; ++i) { + final MdnsRecord record = parseRecord(reader, isQuestion); + if (record != null) { + records.add(record); + } + } + return records; + } + + @Nullable + private static MdnsRecord parseRecord(@NonNull MdnsPacketReader reader, boolean isQuestion) + throws ParseException { + String[] name; + try { + name = reader.readLabels(); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_RECORD_NAME, + "Failed to read labels from mDNS response.", e); + } + + final int type; + try { + type = reader.readUInt16(); + } catch (EOFException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE, + "Reached the end of the mDNS response unexpectedly.", e); + } + + switch (type) { + case MdnsRecord.TYPE_A: { + try { + return new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader, isQuestion); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_A_RDATA, + "Failed to read A record from mDNS response.", e); + } + } + + case MdnsRecord.TYPE_AAAA: { + try { + return new MdnsInetAddressRecord(name, + MdnsRecord.TYPE_AAAA, reader, isQuestion); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA, + "Failed to read AAAA record from mDNS response.", e); + } + } + + case MdnsRecord.TYPE_PTR: { + try { + return new MdnsPointerRecord(name, reader, isQuestion); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_PTR_RDATA, + "Failed to read PTR record from mDNS response.", e); + } + } + + case MdnsRecord.TYPE_SRV: { + try { + return new MdnsServiceRecord(name, reader, isQuestion); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_SRV_RDATA, + "Failed to read SRV record from mDNS response.", e); + } + } + + case MdnsRecord.TYPE_TXT: { + try { + return new MdnsTextRecord(name, reader, isQuestion); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_TXT_RDATA, + "Failed to read TXT record from mDNS response.", e); + } + } + + case MdnsRecord.TYPE_NSEC: { + try { + return new MdnsNsecRecord(name, reader, isQuestion); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_NSEC_RDATA, + "Failed to read NSEC record from mDNS response.", e); + } + } + + case MdnsRecord.TYPE_ANY: { + try { + return new MdnsAnyRecord(name, reader); + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_READING_ANY_RDATA, + "Failed to read TYPE_ANY record from mDNS response.", e); + } + } + + default: { + try { + if (MdnsAdvertiser.DBG) { + Log.i(TAG, "Skipping parsing of record of unhandled type " + type); + } + skipMdnsRecord(reader, isQuestion); + return null; + } catch (IOException e) { + throw new ParseException(MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD, + "Failed to skip mDNS record.", e); + } + } + } + } + + private static void skipMdnsRecord(@NonNull MdnsPacketReader reader, boolean isQuestion) + throws IOException { + reader.skip(2); // Skip the class + if (isQuestion) return; + // Skip TTL and data + reader.skip(4); + int dataLength = reader.readUInt16(); + reader.skip(dataLength); + } } diff --git a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java index 50f206993b..82da2e42ed 100644 --- a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java +++ b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java @@ -24,11 +24,9 @@ import android.os.SystemClock; import com.android.server.connectivity.mdns.util.MdnsLogger; import java.io.EOFException; -import java.io.IOException; import java.net.DatagramPacket; import java.util.ArrayList; import java.util.Arrays; -import java.util.LinkedList; import java.util.List; /** A class that decodes mDNS responses from UDP packets. */ @@ -48,12 +46,6 @@ public class MdnsResponseDecoder { this.serviceType = serviceType; } - private static void skipMdnsRecord(MdnsPacketReader reader) throws IOException { - reader.skip(2 + 4); // skip the class and TTL - int dataLength = reader.readUInt16(); - reader.skip(dataLength); - } - private static MdnsResponse findResponseWithPointer( List responses, String[] pointer) { if (responses != null) { @@ -120,7 +112,7 @@ public class MdnsResponseDecoder { int interfaceIndex, @Nullable Network network) { MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length); - List records; + final MdnsPacket mdnsPacket; try { reader.readUInt16(); // transaction ID (not used) int flags = reader.readUInt16(); @@ -128,111 +120,25 @@ public class MdnsResponseDecoder { return MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE; } - int numQuestions = reader.readUInt16(); - int numAnswers = reader.readUInt16(); - int numAuthority = reader.readUInt16(); - int numRecords = reader.readUInt16(); - - LOGGER.log(String.format( - "num questions: %d, num answers: %d, num authority: %d, num records: %d", - numQuestions, numAnswers, numAuthority, numRecords)); - - if (numAnswers < 1) { + mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags); + if (mdnsPacket.answers.size() < 1) { return MdnsResponseErrorCode.ERROR_NO_ANSWERS; } - - records = new LinkedList<>(); - - for (int i = 0; i < (numAnswers + numAuthority + numRecords); ++i) { - String[] name; - try { - name = reader.readLabels(); - } catch (IOException e) { - LOGGER.e("Failed to read labels from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_READING_RECORD_NAME; - } - int type = reader.readUInt16(); - - switch (type) { - case MdnsRecord.TYPE_A: { - try { - records.add(new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader)); - } catch (IOException e) { - LOGGER.e("Failed to read A record from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_READING_A_RDATA; - } - break; - } - - case MdnsRecord.TYPE_AAAA: { - try { - // AAAA should only contain the IPv6 address. - MdnsInetAddressRecord record = - new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader); - if (record.getInet6Address() != null) { - records.add(record); - } - } catch (IOException e) { - LOGGER.e("Failed to read AAAA record from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA; - } - break; - } - - case MdnsRecord.TYPE_PTR: { - try { - records.add(new MdnsPointerRecord(name, reader)); - } catch (IOException e) { - LOGGER.e("Failed to read PTR record from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_READING_PTR_RDATA; - } - break; - } - - case MdnsRecord.TYPE_SRV: { - if (name.length == 4) { - try { - records.add(new MdnsServiceRecord(name, reader)); - } catch (IOException e) { - LOGGER.e("Failed to read SRV record from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_READING_SRV_RDATA; - } - } else { - try { - skipMdnsRecord(reader); - } catch (IOException e) { - LOGGER.e("Failed to skip SVR record from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_SKIPPING_SRV_RDATA; - } - } - break; - } - - case MdnsRecord.TYPE_TXT: { - try { - records.add(new MdnsTextRecord(name, reader)); - } catch (IOException e) { - LOGGER.e("Failed to read TXT record from mDNS response.", e); - return MdnsResponseErrorCode.ERROR_READING_TXT_RDATA; - } - break; - } - - default: { - try { - skipMdnsRecord(reader); - } catch (IOException e) { - LOGGER.e("Failed to skip mDNS record.", e); - return MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD; - } - } - } - } } catch (EOFException e) { LOGGER.e("Reached the end of the mDNS response unexpectedly.", e); return MdnsResponseErrorCode.ERROR_END_OF_FILE; + } catch (MdnsPacket.ParseException e) { + LOGGER.e(e.getMessage(), e); + return e.code; } + final ArrayList records = new ArrayList<>( + mdnsPacket.questions.size() + mdnsPacket.answers.size() + + mdnsPacket.authorityRecords.size() + mdnsPacket.additionalRecords.size()); + records.addAll(mdnsPacket.answers); + records.addAll(mdnsPacket.authorityRecords); + records.addAll(mdnsPacket.additionalRecords); + // The response records are structured in a hierarchy, where some records reference // others, as follows: // diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt index 650607d9cc..334f99d701 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt @@ -91,7 +91,7 @@ class MdnsAnnouncerTest { scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an = - scapy.DNSRR(type='PTR', rrname='123.0.2.192.in-addr.arpa.', rdata='Android.local', + scapy.DNSRR(type='PTR', rrname='123.2.0.192.in-addr.arpa.', rdata='Android.local', rclass=0x8001, ttl=120) / scapy.DNSRR(type='PTR', rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa', @@ -111,8 +111,8 @@ class MdnsAnnouncerTest { scapy.DNSRR(type='AAAA', rrname='Android.local', rclass=0x8001, rdata='2001:db8::456', ttl=120), ar = - scapy.DNSRRNSEC(rrname='123.0.2.192.in-addr.arpa.', rclass=0x8001, ttl=120, - nextname='123.0.2.192.in-addr.arpa.', typebitmaps=[12]) / + scapy.DNSRRNSEC(rrname='123.2.0.192.in-addr.arpa.', rclass=0x8001, ttl=120, + nextname='123.2.0.192.in-addr.arpa.', typebitmaps=[12]) / scapy.DNSRRNSEC( rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa', rclass=0x8001, ttl=120, @@ -131,7 +131,7 @@ class MdnsAnnouncerTest { typebitmaps=[1, 28])) )).hex().upper() */ - val expected = "00008400000000090000000503313233013001320331393207696E2D61646472046172706" + + val expected = "00008400000000090000000503313233013201300331393207696E2D61646472046172706" + "100000C800100000078000F07416E64726F6964056C6F63616C00013301320131013001300130013" + "00130013001300130013001300130013001300130013001300130013001300130013001380142014" + "40130013101300130013203697036C020000C8001000000780002C030013601350134C045000C800" + @@ -149,7 +149,7 @@ class MdnsAnnouncerTest { val v4Addr = parseNumericAddress("192.0.2.123") val v6Addr1 = parseNumericAddress("2001:DB8::123") val v6Addr2 = parseNumericAddress("2001:DB8::456") - val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa") + val v4AddrRev = getReverseDnsAddress(v4Addr) val v6Addr1Rev = getReverseDnsAddress(v6Addr1) val v6Addr2Rev = getReverseDnsAddress(v6Addr2) @@ -254,7 +254,10 @@ class MdnsAnnouncerTest { verify(socket, atLeast(i + 1)).send(any()) val now = SystemClock.elapsedRealtime() assertTrue(now > timeStart + startDelay + i * FIRST_ANNOUNCES_DELAY) - assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY) + // Loops can be much slower than the expected timing (>100ms delay), use + // TEST_TIMEOUT_MS as tolerance. + assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY + + TEST_TIMEOUT_MS) } // Subsequent announces should happen quickly (NEXT_ANNOUNCES_DELAY) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt new file mode 100644 index 0000000000..f88da1fe72 --- /dev/null +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns + +import android.net.InetAddresses +import com.android.net.module.util.HexDump +import com.android.testutils.DevSdkIgnoreRunner +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(DevSdkIgnoreRunner::class) +class MdnsPacketTest { + @Test + fun testParseQuery() { + // Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses + // for Android.local (similar to legacy mdnsresponder probes, although it used to put 4 + // identical questions(!!) for Android.local when there were 4 addresses). + val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" + + "01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" + + "0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" + + "010db8000000000000000000000789" + + val bytes = HexDump.hexStringToByteArray(packetHex) + val reader = MdnsPacketReader(bytes, bytes.size) + val packet = MdnsPacket.parse(reader) + + assertEquals(1, packet.questions.size) + assertEquals(0, packet.answers.size) + assertEquals(4, packet.authorityRecords.size) + assertEquals(0, packet.additionalRecords.size) + + val hostname = arrayOf("Android", "local") + packet.questions[0].let { + assertTrue(it is MdnsAnyRecord) + assertContentEquals(hostname, it.name) + } + + packet.authorityRecords.forEach { + assertTrue(it is MdnsInetAddressRecord) + assertContentEquals(hostname, it.name) + assertEquals(120000, it.ttl) + } + + assertEquals(InetAddresses.parseNumericAddress("192.0.2.123"), + (packet.authorityRecords[0] as MdnsInetAddressRecord).inet4Address) + assertEquals(InetAddresses.parseNumericAddress("2001:db8::123"), + (packet.authorityRecords[1] as MdnsInetAddressRecord).inet6Address) + assertEquals(InetAddresses.parseNumericAddress("2001:db8::456"), + (packet.authorityRecords[2] as MdnsInetAddressRecord).inet6Address) + assertEquals(InetAddresses.parseNumericAddress("2001:db8::789"), + (packet.authorityRecords[3] as MdnsInetAddressRecord).inet6Address) + } +}