Merge "Factor out response decoding into MdnsPacket"

This commit is contained in:
Remi NGUYEN VAN
2023-01-18 01:12:58 +00:00
committed by Gerrit Code Review
4 changed files with 284 additions and 117 deletions

View File

@@ -16,6 +16,13 @@
package com.android.server.connectivity.mdns;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.util.Log;
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -23,21 +30,202 @@ import java.util.List;
* A class holding data that can be included in a mDNS packet.
*/
public class MdnsPacket {
private static final String TAG = MdnsPacket.class.getSimpleName();
public final int flags;
@NonNull
public final List<MdnsRecord> questions;
@NonNull
public final List<MdnsRecord> answers;
@NonNull
public final List<MdnsRecord> authorityRecords;
@NonNull
public final List<MdnsRecord> additionalRecords;
MdnsPacket(int flags,
List<MdnsRecord> questions,
List<MdnsRecord> answers,
List<MdnsRecord> authorityRecords,
List<MdnsRecord> additionalRecords) {
@NonNull List<MdnsRecord> questions,
@NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> authorityRecords,
@NonNull List<MdnsRecord> additionalRecords) {
this.flags = flags;
this.questions = Collections.unmodifiableList(questions);
this.answers = Collections.unmodifiableList(answers);
this.authorityRecords = Collections.unmodifiableList(authorityRecords);
this.additionalRecords = Collections.unmodifiableList(additionalRecords);
}
/**
* Exception thrown on parse errors.
*/
public static class ParseException extends IOException {
public final int code;
public ParseException(int code, @NonNull String message, @Nullable Throwable cause) {
super(message, cause);
this.code = code;
}
}
/**
* Parse the packet in the provided {@link MdnsPacketReader}.
*/
@NonNull
public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException {
final int flags;
try {
reader.readUInt16(); // transaction ID (not used)
flags = reader.readUInt16();
} catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e);
}
return parseRecordsSection(reader, flags);
}
/**
* Parse the records section of a mDNS packet in the provided {@link MdnsPacketReader}.
*
* The records section starts with the questions count, just after the packet flags.
*/
public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags)
throws ParseException {
try {
final int numQuestions = reader.readUInt16();
final int numAnswers = reader.readUInt16();
final int numAuthority = reader.readUInt16();
final int numAdditional = reader.readUInt16();
final ArrayList<MdnsRecord> questions = parseRecords(reader, numQuestions, true);
final ArrayList<MdnsRecord> answers = parseRecords(reader, numAnswers, false);
final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false);
final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false);
return new MdnsPacket(flags, questions, answers, authority, additional);
} catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e);
}
}
private static ArrayList<MdnsRecord> parseRecords(@NonNull MdnsPacketReader reader, int count,
boolean isQuestion)
throws ParseException {
final ArrayList<MdnsRecord> records = new ArrayList<>(count);
for (int i = 0; i < count; ++i) {
final MdnsRecord record = parseRecord(reader, isQuestion);
if (record != null) {
records.add(record);
}
}
return records;
}
@Nullable
private static MdnsRecord parseRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
throws ParseException {
String[] name;
try {
name = reader.readLabels();
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_RECORD_NAME,
"Failed to read labels from mDNS response.", e);
}
final int type;
try {
type = reader.readUInt16();
} catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e);
}
switch (type) {
case MdnsRecord.TYPE_A: {
try {
return new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader, isQuestion);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_A_RDATA,
"Failed to read A record from mDNS response.", e);
}
}
case MdnsRecord.TYPE_AAAA: {
try {
return new MdnsInetAddressRecord(name,
MdnsRecord.TYPE_AAAA, reader, isQuestion);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA,
"Failed to read AAAA record from mDNS response.", e);
}
}
case MdnsRecord.TYPE_PTR: {
try {
return new MdnsPointerRecord(name, reader, isQuestion);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_PTR_RDATA,
"Failed to read PTR record from mDNS response.", e);
}
}
case MdnsRecord.TYPE_SRV: {
try {
return new MdnsServiceRecord(name, reader, isQuestion);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_SRV_RDATA,
"Failed to read SRV record from mDNS response.", e);
}
}
case MdnsRecord.TYPE_TXT: {
try {
return new MdnsTextRecord(name, reader, isQuestion);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_TXT_RDATA,
"Failed to read TXT record from mDNS response.", e);
}
}
case MdnsRecord.TYPE_NSEC: {
try {
return new MdnsNsecRecord(name, reader, isQuestion);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_NSEC_RDATA,
"Failed to read NSEC record from mDNS response.", e);
}
}
case MdnsRecord.TYPE_ANY: {
try {
return new MdnsAnyRecord(name, reader);
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_READING_ANY_RDATA,
"Failed to read TYPE_ANY record from mDNS response.", e);
}
}
default: {
try {
if (MdnsAdvertiser.DBG) {
Log.i(TAG, "Skipping parsing of record of unhandled type " + type);
}
skipMdnsRecord(reader, isQuestion);
return null;
} catch (IOException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD,
"Failed to skip mDNS record.", e);
}
}
}
}
private static void skipMdnsRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
throws IOException {
reader.skip(2); // Skip the class
if (isQuestion) return;
// Skip TTL and data
reader.skip(4);
int dataLength = reader.readUInt16();
reader.skip(dataLength);
}
}

View File

@@ -24,11 +24,9 @@ import android.os.SystemClock;
import com.android.server.connectivity.mdns.util.MdnsLogger;
import java.io.EOFException;
import java.io.IOException;
import java.net.DatagramPacket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
/** A class that decodes mDNS responses from UDP packets. */
@@ -48,12 +46,6 @@ public class MdnsResponseDecoder {
this.serviceType = serviceType;
}
private static void skipMdnsRecord(MdnsPacketReader reader) throws IOException {
reader.skip(2 + 4); // skip the class and TTL
int dataLength = reader.readUInt16();
reader.skip(dataLength);
}
private static MdnsResponse findResponseWithPointer(
List<MdnsResponse> responses, String[] pointer) {
if (responses != null) {
@@ -120,7 +112,7 @@ public class MdnsResponseDecoder {
int interfaceIndex, @Nullable Network network) {
MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
List<MdnsRecord> records;
final MdnsPacket mdnsPacket;
try {
reader.readUInt16(); // transaction ID (not used)
int flags = reader.readUInt16();
@@ -128,111 +120,25 @@ public class MdnsResponseDecoder {
return MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE;
}
int numQuestions = reader.readUInt16();
int numAnswers = reader.readUInt16();
int numAuthority = reader.readUInt16();
int numRecords = reader.readUInt16();
LOGGER.log(String.format(
"num questions: %d, num answers: %d, num authority: %d, num records: %d",
numQuestions, numAnswers, numAuthority, numRecords));
if (numAnswers < 1) {
mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags);
if (mdnsPacket.answers.size() < 1) {
return MdnsResponseErrorCode.ERROR_NO_ANSWERS;
}
records = new LinkedList<>();
for (int i = 0; i < (numAnswers + numAuthority + numRecords); ++i) {
String[] name;
try {
name = reader.readLabels();
} catch (IOException e) {
LOGGER.e("Failed to read labels from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_READING_RECORD_NAME;
}
int type = reader.readUInt16();
switch (type) {
case MdnsRecord.TYPE_A: {
try {
records.add(new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader));
} catch (IOException e) {
LOGGER.e("Failed to read A record from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_READING_A_RDATA;
}
break;
}
case MdnsRecord.TYPE_AAAA: {
try {
// AAAA should only contain the IPv6 address.
MdnsInetAddressRecord record =
new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader);
if (record.getInet6Address() != null) {
records.add(record);
}
} catch (IOException e) {
LOGGER.e("Failed to read AAAA record from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA;
}
break;
}
case MdnsRecord.TYPE_PTR: {
try {
records.add(new MdnsPointerRecord(name, reader));
} catch (IOException e) {
LOGGER.e("Failed to read PTR record from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_READING_PTR_RDATA;
}
break;
}
case MdnsRecord.TYPE_SRV: {
if (name.length == 4) {
try {
records.add(new MdnsServiceRecord(name, reader));
} catch (IOException e) {
LOGGER.e("Failed to read SRV record from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_READING_SRV_RDATA;
}
} else {
try {
skipMdnsRecord(reader);
} catch (IOException e) {
LOGGER.e("Failed to skip SVR record from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_SKIPPING_SRV_RDATA;
}
}
break;
}
case MdnsRecord.TYPE_TXT: {
try {
records.add(new MdnsTextRecord(name, reader));
} catch (IOException e) {
LOGGER.e("Failed to read TXT record from mDNS response.", e);
return MdnsResponseErrorCode.ERROR_READING_TXT_RDATA;
}
break;
}
default: {
try {
skipMdnsRecord(reader);
} catch (IOException e) {
LOGGER.e("Failed to skip mDNS record.", e);
return MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD;
}
}
}
}
} catch (EOFException e) {
LOGGER.e("Reached the end of the mDNS response unexpectedly.", e);
return MdnsResponseErrorCode.ERROR_END_OF_FILE;
} catch (MdnsPacket.ParseException e) {
LOGGER.e(e.getMessage(), e);
return e.code;
}
final ArrayList<MdnsRecord> records = new ArrayList<>(
mdnsPacket.questions.size() + mdnsPacket.answers.size()
+ mdnsPacket.authorityRecords.size() + mdnsPacket.additionalRecords.size());
records.addAll(mdnsPacket.answers);
records.addAll(mdnsPacket.authorityRecords);
records.addAll(mdnsPacket.additionalRecords);
// The response records are structured in a hierarchy, where some records reference
// others, as follows:
//

View File

@@ -91,7 +91,7 @@ class MdnsAnnouncerTest {
scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1,
qd = None,
an =
scapy.DNSRR(type='PTR', rrname='123.0.2.192.in-addr.arpa.', rdata='Android.local',
scapy.DNSRR(type='PTR', rrname='123.2.0.192.in-addr.arpa.', rdata='Android.local',
rclass=0x8001, ttl=120) /
scapy.DNSRR(type='PTR',
rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
@@ -111,8 +111,8 @@ class MdnsAnnouncerTest {
scapy.DNSRR(type='AAAA', rrname='Android.local', rclass=0x8001, rdata='2001:db8::456',
ttl=120),
ar =
scapy.DNSRRNSEC(rrname='123.0.2.192.in-addr.arpa.', rclass=0x8001, ttl=120,
nextname='123.0.2.192.in-addr.arpa.', typebitmaps=[12]) /
scapy.DNSRRNSEC(rrname='123.2.0.192.in-addr.arpa.', rclass=0x8001, ttl=120,
nextname='123.2.0.192.in-addr.arpa.', typebitmaps=[12]) /
scapy.DNSRRNSEC(
rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
rclass=0x8001, ttl=120,
@@ -131,7 +131,7 @@ class MdnsAnnouncerTest {
typebitmaps=[1, 28]))
)).hex().upper()
*/
val expected = "00008400000000090000000503313233013001320331393207696E2D61646472046172706" +
val expected = "00008400000000090000000503313233013201300331393207696E2D61646472046172706" +
"100000C800100000078000F07416E64726F6964056C6F63616C00013301320131013001300130013" +
"00130013001300130013001300130013001300130013001300130013001300130013001380142014" +
"40130013101300130013203697036C020000C8001000000780002C030013601350134C045000C800" +
@@ -149,7 +149,7 @@ class MdnsAnnouncerTest {
val v4Addr = parseNumericAddress("192.0.2.123")
val v6Addr1 = parseNumericAddress("2001:DB8::123")
val v6Addr2 = parseNumericAddress("2001:DB8::456")
val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa")
val v4AddrRev = getReverseDnsAddress(v4Addr)
val v6Addr1Rev = getReverseDnsAddress(v6Addr1)
val v6Addr2Rev = getReverseDnsAddress(v6Addr2)
@@ -254,7 +254,10 @@ class MdnsAnnouncerTest {
verify(socket, atLeast(i + 1)).send(any())
val now = SystemClock.elapsedRealtime()
assertTrue(now > timeStart + startDelay + i * FIRST_ANNOUNCES_DELAY)
assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY)
// Loops can be much slower than the expected timing (>100ms delay), use
// TEST_TIMEOUT_MS as tolerance.
assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY +
TEST_TIMEOUT_MS)
}
// Subsequent announces should happen quickly (NEXT_ANNOUNCES_DELAY)

View File

@@ -0,0 +1,70 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.net.InetAddresses
import com.android.net.module.util.HexDump
import com.android.testutils.DevSdkIgnoreRunner
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
@RunWith(DevSdkIgnoreRunner::class)
class MdnsPacketTest {
@Test
fun testParseQuery() {
// Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses
// for Android.local (similar to legacy mdnsresponder probes, although it used to put 4
// identical questions(!!) for Android.local when there were 4 addresses).
val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
"01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" +
"0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" +
"010db8000000000000000000000789"
val bytes = HexDump.hexStringToByteArray(packetHex)
val reader = MdnsPacketReader(bytes, bytes.size)
val packet = MdnsPacket.parse(reader)
assertEquals(1, packet.questions.size)
assertEquals(0, packet.answers.size)
assertEquals(4, packet.authorityRecords.size)
assertEquals(0, packet.additionalRecords.size)
val hostname = arrayOf("Android", "local")
packet.questions[0].let {
assertTrue(it is MdnsAnyRecord)
assertContentEquals(hostname, it.name)
}
packet.authorityRecords.forEach {
assertTrue(it is MdnsInetAddressRecord)
assertContentEquals(hostname, it.name)
assertEquals(120000, it.ttl)
}
assertEquals(InetAddresses.parseNumericAddress("192.0.2.123"),
(packet.authorityRecords[0] as MdnsInetAddressRecord).inet4Address)
assertEquals(InetAddresses.parseNumericAddress("2001:db8::123"),
(packet.authorityRecords[1] as MdnsInetAddressRecord).inet6Address)
assertEquals(InetAddresses.parseNumericAddress("2001:db8::456"),
(packet.authorityRecords[2] as MdnsInetAddressRecord).inet6Address)
assertEquals(InetAddresses.parseNumericAddress("2001:db8::789"),
(packet.authorityRecords[3] as MdnsInetAddressRecord).inet6Address)
}
}