Store transaction id in MdnsPacket

The transaction id is a number that is used to identify a
specific query packet. But it's not necessary for probing
or announcing services, so the transaction id is not
currently used on advertising when creating a MdnsPacket or
decoding the response to a MdnsPacket. This means that it is not
possible to track which query packets have received
responses. Therefore, store the transaction id so that
it can be used for subsequent query packet changes.

Bug: 302269599
Test: atest FrameworksNetTests
Change-Id: I6734752b32b91678afb7df06e1fa51237cf70894
This commit is contained in:
Paul Hu
2023-09-27 18:20:28 +08:00
parent b01d0721a2
commit 6df06daaec
5 changed files with 26 additions and 9 deletions

View File

@@ -31,6 +31,7 @@ import java.util.List;
public class MdnsPacket { public class MdnsPacket {
private static final String TAG = MdnsPacket.class.getSimpleName(); private static final String TAG = MdnsPacket.class.getSimpleName();
public final int transactionId;
public final int flags; public final int flags;
@NonNull @NonNull
public final List<MdnsRecord> questions; public final List<MdnsRecord> questions;
@@ -46,6 +47,15 @@ public class MdnsPacket {
@NonNull List<MdnsRecord> answers, @NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> authorityRecords, @NonNull List<MdnsRecord> authorityRecords,
@NonNull List<MdnsRecord> additionalRecords) { @NonNull List<MdnsRecord> additionalRecords) {
this(0, flags, questions, answers, authorityRecords, additionalRecords);
}
MdnsPacket(int transactionId, int flags,
@NonNull List<MdnsRecord> questions,
@NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> authorityRecords,
@NonNull List<MdnsRecord> additionalRecords) {
this.transactionId = transactionId;
this.flags = flags; this.flags = flags;
this.questions = Collections.unmodifiableList(questions); this.questions = Collections.unmodifiableList(questions);
this.answers = Collections.unmodifiableList(answers); this.answers = Collections.unmodifiableList(answers);
@@ -70,15 +80,16 @@ public class MdnsPacket {
*/ */
@NonNull @NonNull
public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException { public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException {
final int transactionId;
final int flags; final int flags;
try { try {
reader.readUInt16(); // transaction ID (not used) transactionId = reader.readUInt16();
flags = reader.readUInt16(); flags = reader.readUInt16();
} catch (EOFException e) { } catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE, throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e); "Reached the end of the mDNS response unexpectedly.", e);
} }
return parseRecordsSection(reader, flags); return parseRecordsSection(reader, flags, transactionId);
} }
/** /**
@@ -86,8 +97,8 @@ public class MdnsPacket {
* *
* The records section starts with the questions count, just after the packet flags. * The records section starts with the questions count, just after the packet flags.
*/ */
public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags) public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags,
throws ParseException { int transactionId) throws ParseException {
try { try {
final int numQuestions = reader.readUInt16(); final int numQuestions = reader.readUInt16();
final int numAnswers = reader.readUInt16(); final int numAnswers = reader.readUInt16();
@@ -99,7 +110,7 @@ public class MdnsPacket {
final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false); final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false);
final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false); final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false);
return new MdnsPacket(flags, questions, answers, authority, additional); return new MdnsPacket(transactionId, flags, questions, answers, authority, additional);
} catch (EOFException e) { } catch (EOFException e) {
throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE, throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
"Reached the end of the mDNS response unexpectedly.", e); "Reached the end of the mDNS response unexpectedly.", e);

View File

@@ -90,14 +90,14 @@ public class MdnsResponseDecoder {
final MdnsPacket mdnsPacket; final MdnsPacket mdnsPacket;
try { try {
reader.readUInt16(); // transaction ID (not used) final int transactionId = reader.readUInt16();
int flags = reader.readUInt16(); int flags = reader.readUInt16();
if ((flags & MdnsConstants.FLAGS_RESPONSE_MASK) != MdnsConstants.FLAGS_RESPONSE) { if ((flags & MdnsConstants.FLAGS_RESPONSE_MASK) != MdnsConstants.FLAGS_RESPONSE) {
throw new MdnsPacket.ParseException( throw new MdnsPacket.ParseException(
MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE, "Not a response", null); MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE, "Not a response", null);
} }
mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags); mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags, transactionId);
if (mdnsPacket.answers.size() < 1) { if (mdnsPacket.answers.size() < 1) {
throw new MdnsPacket.ParseException( throw new MdnsPacket.ParseException(
MdnsResponseErrorCode.ERROR_NO_ANSWERS, "Response has no answers", MdnsResponseErrorCode.ERROR_NO_ANSWERS, "Response has no answers",

View File

@@ -189,7 +189,7 @@ public class MdnsUtils {
// TODO: support packets over size (send in multiple packets with TC bit set) // TODO: support packets over size (send in multiple packets with TC bit set)
final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer); final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
writer.writeUInt16(0); // Transaction ID (advertisement: 0) writer.writeUInt16(packet.transactionId); // Transaction ID (advertisement: 0)
writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4) writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4)
writer.writeUInt16(packet.questions.size()); // questions count writer.writeUInt16(packet.questions.size()); // questions count
writer.writeUInt16(packet.answers.size()); // answers count writer.writeUInt16(packet.answers.size()); // answers count

View File

@@ -32,7 +32,7 @@ class MdnsPacketTest {
// Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses // Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses
// for Android.local (similar to legacy mdnsresponder probes, although it used to put 4 // for Android.local (similar to legacy mdnsresponder probes, although it used to put 4
// identical questions(!!) for Android.local when there were 4 addresses). // identical questions(!!) for Android.local when there were 4 addresses).
val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" + val packetHex = "007b0000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
"01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" + "01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" +
"0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" + "0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" +
"010db8000000000000000000000789" "010db8000000000000000000000789"
@@ -41,6 +41,7 @@ class MdnsPacketTest {
val reader = MdnsPacketReader(bytes, bytes.size) val reader = MdnsPacketReader(bytes, bytes.size)
val packet = MdnsPacket.parse(reader) val packet = MdnsPacket.parse(reader)
assertEquals(123, packet.transactionId)
assertEquals(1, packet.questions.size) assertEquals(1, packet.questions.size)
assertEquals(0, packet.answers.size) assertEquals(0, packet.answers.size)
assertEquals(4, packet.authorityRecords.size) assertEquals(4, packet.authorityRecords.size)

View File

@@ -105,6 +105,7 @@ class MdnsRecordRepositoryTest {
assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId) assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0) val packet = probingInfo.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags) assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size) assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size) assertEquals(0, packet.additionalRecords.size)
@@ -173,6 +174,7 @@ class MdnsRecordRepositoryTest {
assertEquals(1, repository.servicesCount) assertEquals(1, repository.servicesCount)
val packet = exitAnnouncement.getPacket(0) val packet = exitAnnouncement.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags) assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size) assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size) assertEquals(0, packet.authorityRecords.size)
@@ -202,6 +204,7 @@ class MdnsRecordRepositoryTest {
assertEquals(1, repository.servicesCount) assertEquals(1, repository.servicesCount)
val packet = exitAnnouncement.getPacket(0) val packet = exitAnnouncement.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags) assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size) assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size) assertEquals(0, packet.authorityRecords.size)
@@ -249,6 +252,7 @@ class MdnsRecordRepositoryTest {
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
val packet = announcementInfo.getPacket(0) val packet = announcementInfo.getPacket(0)
assertEquals(0, packet.transactionId)
assertEquals(0x8400 /* response, authoritative */, packet.flags) assertEquals(0x8400 /* response, authoritative */, packet.flags)
assertEquals(0, packet.questions.size) assertEquals(0, packet.questions.size)
assertEquals(0, packet.authorityRecords.size) assertEquals(0, packet.authorityRecords.size)
@@ -372,6 +376,7 @@ class MdnsRecordRepositoryTest {
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local") val serviceType = arrayOf("_testservice", "_tcp", "local")
val offloadPacket = repository.getOffloadPacket(TEST_SERVICE_ID_1) val offloadPacket = repository.getOffloadPacket(TEST_SERVICE_ID_1)
assertEquals(0, offloadPacket.transactionId)
assertEquals(0x8400, offloadPacket.flags) assertEquals(0x8400, offloadPacket.flags)
assertEquals(0, offloadPacket.questions.size) assertEquals(0, offloadPacket.questions.size)
assertEquals(0, offloadPacket.additionalRecords.size) assertEquals(0, offloadPacket.additionalRecords.size)