Store transaction id in MdnsPacket

The transaction id is a number that is used to identify a
specific query packet. But it's not necessary for probing
or announcing services, so the transaction id is not
currently used on advertising when creating a MdnsPacket or
decoding the response to a MdnsPacket. This means that it is not
possible to track which query packets have received
responses. Therefore, store the transaction id so that
it can be used for subsequent query packet changes.

Bug: 302269599
Test: atest FrameworksNetTests
Change-Id: I6734752b32b91678afb7df06e1fa51237cf70894
This commit is contained in:
Paul Hu
2023-09-27 18:20:28 +08:00
parent b01d0721a2
commit 6df06daaec
5 changed files with 26 additions and 9 deletions

View File

@@ -31,6 +31,7 @@ import java.util.List;
public class MdnsPacket {
private static final String TAG = MdnsPacket.class.getSimpleName();
public final int transactionId;
public final int flags;
@NonNull
public final List<MdnsRecord> questions;
@@ -46,6 +47,15 @@ public class MdnsPacket {
@NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> authorityRecords,
@NonNull List<MdnsRecord> additionalRecords) {
this(0, flags, questions, answers, authorityRecords, additionalRecords);
}
MdnsPacket(int transactionId, int flags,
@NonNull List<MdnsRecord> questions,
@NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> authorityRecords,
@NonNull List<MdnsRecord> additionalRecords) {
this.transactionId = transactionId;
this.flags = flags;
this.questions = Collections.unmodifiableList(questions);
this.answers = Collections.unmodifiableList(answers);
@@ -70,15 +80,16 @@ public class MdnsPacket {
*/
@NonNull
public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException {
final int transactionId;
final int flags;
try {
reader.readUInt16(); // transaction ID (not used)
transactionId = reader.readUInt16();
flags = reader.readUInt16();
} catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e);
}
return parseRecordsSection(reader, flags);
return parseRecordsSection(reader, flags, transactionId);
}
/**
@@ -86,8 +97,8 @@ public class MdnsPacket {
*
* The records section starts with the questions count, just after the packet flags.
*/
public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags)
throws ParseException {
public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags,
int transactionId) throws ParseException {
try {
final int numQuestions = reader.readUInt16();
final int numAnswers = reader.readUInt16();
@@ -99,7 +110,7 @@ public class MdnsPacket {
final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false);
final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false);
return new MdnsPacket(flags, questions, answers, authority, additional);
return new MdnsPacket(transactionId, flags, questions, answers, authority, additional);
} catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e);

View File

@@ -90,14 +90,14 @@ public class MdnsResponseDecoder {
final MdnsPacket mdnsPacket;
try {
reader.readUInt16(); // transaction ID (not used)
final int transactionId = reader.readUInt16();
int flags = reader.readUInt16();
if ((flags & MdnsConstants.FLAGS_RESPONSE_MASK) != MdnsConstants.FLAGS_RESPONSE) {
throw new MdnsPacket.ParseException(
MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE, "Not a response", null);
}
mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags);
mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags, transactionId);
if (mdnsPacket.answers.size() < 1) {
throw new MdnsPacket.ParseException(
MdnsResponseErrorCode.ERROR_NO_ANSWERS, "Response has no answers",

View File

@@ -189,7 +189,7 @@ public class MdnsUtils {
// TODO: support packets over size (send in multiple packets with TC bit set)
final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
writer.writeUInt16(0); // Transaction ID (advertisement: 0)
writer.writeUInt16(packet.transactionId); // Transaction ID (advertisement: 0)
writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4)
writer.writeUInt16(packet.questions.size()); // questions count
writer.writeUInt16(packet.answers.size()); // answers count

View File

@@ -32,7 +32,7 @@ class MdnsPacketTest {
// Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses
// for Android.local (similar to legacy mdnsresponder probes, although it used to put 4
// identical questions(!!) for Android.local when there were 4 addresses).
val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
val packetHex = "007b0000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
"01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" +
"0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" +
"010db8000000000000000000000789"
@@ -41,6 +41,7 @@ class MdnsPacketTest {
val reader = MdnsPacketReader(bytes, bytes.size)
val packet = MdnsPacket.parse(reader)
assertEquals(123, packet.transactionId)
assertEquals(1, packet.questions.size)
assertEquals(0, packet.answers.size)
assertEquals(4, packet.authorityRecords.size)

View File

@@ -105,6 +105,7 @@ class MdnsRecordRepositoryTest {
assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size)
@@ -173,6 +174,7 @@ class MdnsRecordRepositoryTest {
assertEquals(1, repository.servicesCount)
val packet = exitAnnouncement.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size)
@@ -202,6 +204,7 @@ class MdnsRecordRepositoryTest {
assertEquals(1, repository.servicesCount)
val packet = exitAnnouncement.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size)
@@ -249,6 +252,7 @@ class MdnsRecordRepositoryTest {
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
val packet = announcementInfo.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size)
@@ -372,6 +376,7 @@ class MdnsRecordRepositoryTest {
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local")
val offloadPacket = repository.getOffloadPacket(TEST_SERVICE_ID_1)
assertEquals(0, offloadPacket.transactionId)
assertEquals(0x8400, offloadPacket.flags)
assertEquals(0, offloadPacket.questions.size)
assertEquals(0, offloadPacket.additionalRecords.size)