Merge "Send rawOffloadPacket to OffloadEngine" into main

This commit is contained in:
Yuyang Huang
2023-09-06 00:04:00 +00:00
committed by Gerrit Code Review
9 changed files with 266 additions and 38 deletions

View File

@@ -160,6 +160,23 @@ public final class OffloadServiceInfo implements Parcelable {
}
}
/**
* Create a new OffloadServiceInfo with payload updated.
*
* @hide
*/
@NonNull
public OffloadServiceInfo withOffloadPayload(@NonNull byte[] offloadPayload) {
return new OffloadServiceInfo(
this.getKey(),
this.getSubtypes(),
this.getHostname(),
offloadPayload,
this.getPriority(),
this.getOffloadType()
);
}
/**
* Get the offloadType.
* <p>

View File

@@ -146,9 +146,12 @@ public class MdnsAdvertiser {
interfaceName, k -> new ArrayList<>());
// Remove existing offload services from cache for update.
existingOffloadServiceInfoWrappers.removeIf(item -> item.mServiceId == serviceId);
byte[] rawOffloadPacket = advertiser.getRawOffloadPayload(serviceId);
final OffloadServiceInfoWrapper newOffloadServiceInfoWrapper = createOffloadService(
serviceId,
registration);
registration,
rawOffloadPacket);
existingOffloadServiceInfoWrappers.add(newOffloadServiceInfoWrapper);
mCb.onOffloadStartOrUpdate(interfaceName,
newOffloadServiceInfoWrapper.mOffloadServiceInfo);
@@ -393,7 +396,29 @@ public class MdnsAdvertiser {
public void onAddressesChanged(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
if (advertiser != null) advertiser.updateAddresses(addresses);
if (advertiser == null) {
return;
}
advertiser.updateAddresses(addresses);
// Update address should trigger offload packet update.
final String interfaceName = advertiser.getSocketInterfaceName();
final List<OffloadServiceInfoWrapper> existingOffloadServiceInfoWrappers =
mInterfaceOffloadServices.get(interfaceName);
if (existingOffloadServiceInfoWrappers == null) {
return;
}
final List<OffloadServiceInfoWrapper> updatedOffloadServiceInfoWrappers =
new ArrayList<>(existingOffloadServiceInfoWrappers.size());
for (OffloadServiceInfoWrapper oldWrapper : existingOffloadServiceInfoWrappers) {
OffloadServiceInfoWrapper newWrapper = new OffloadServiceInfoWrapper(
oldWrapper.mServiceId,
oldWrapper.mOffloadServiceInfo.withOffloadPayload(
advertiser.getRawOffloadPayload(oldWrapper.mServiceId))
);
updatedOffloadServiceInfoWrappers.add(newWrapper);
mCb.onOffloadStartOrUpdate(interfaceName, newWrapper.mOffloadServiceInfo);
}
mInterfaceOffloadServices.put(interfaceName, updatedOffloadServiceInfoWrappers);
}
}
@@ -630,9 +655,9 @@ public class MdnsAdvertiser {
}
private OffloadServiceInfoWrapper createOffloadService(int serviceId,
@NonNull Registration registration) {
@NonNull Registration registration, byte[] rawOffloadPacket) {
final NsdServiceInfo nsdServiceInfo = registration.getServiceInfo();
List<String> subTypes = new ArrayList<>();
final List<String> subTypes = new ArrayList<>();
String subType = registration.getSubtype();
if (subType != null) {
subTypes.add(subType);
@@ -642,7 +667,7 @@ public class MdnsAdvertiser {
nsdServiceInfo.getServiceType()),
subTypes,
String.join(".", mDeviceHostName),
null /* rawOffloadPacket */,
rawOffloadPacket,
// TODO: define overlayable resources in
// ServiceConnectivityResources that set the priority based on
// service type.
@@ -651,5 +676,4 @@ public class MdnsAdvertiser {
OffloadEngine.OFFLOAD_TYPE_REPLY);
return new OffloadServiceInfoWrapper(serviceId, offloadServiceInfo);
}
}

View File

@@ -28,6 +28,7 @@ import com.android.net.module.util.HexDump;
import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException;
import java.net.InetSocketAddress;
@@ -351,7 +352,25 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
mReplySender.queueReply(answers);
}
/**
* Get the socket interface name.
*/
public String getSocketInterfaceName() {
return mSocket.getInterface().getName();
}
/**
* Gets the offload MdnsPacket.
* @param serviceId The serviceId.
* @return the raw offload payload
*/
public byte[] getRawOffloadPayload(int serviceId) {
try {
return MdnsUtils.createRawDnsPacket(mReplySender.getPacketCreationBuffer(),
mRecordRepository.getOffloadPacket(serviceId));
} catch (IOException | IllegalArgumentException e) {
mSharedLog.wtf("Cannot create rawOffloadPacket: " + e.getMessage());
return new byte[0];
}
}
}

View File

@@ -735,6 +735,38 @@ public class MdnsRecordRepository {
answers, additionalAnswers);
}
/**
* Gets the offload MdnsPacket.
* @param serviceId The serviceId.
* @return The offload {@link MdnsPacket} that contains PTR/SRV/TXT/A/AAAA records.
*/
public MdnsPacket getOffloadPacket(int serviceId) throws IllegalArgumentException {
final ServiceRegistration registration = mServices.get(serviceId);
if (registration == null) throw new IllegalArgumentException(
"Service is not registered: " + serviceId);
final ArrayList<MdnsRecord> answers = new ArrayList<>();
// Adds all PTR, SRV, TXT, A/AAAA records.
for (RecordInfo<MdnsPointerRecord> ptrRecord : registration.ptrRecords) {
answers.add(ptrRecord.record);
}
answers.add(registration.srvRecord.record);
answers.add(registration.txtRecord.record);
for (RecordInfo<?> record : mGeneralRecords) {
if (record.record instanceof MdnsInetAddressRecord) {
answers.add(record.record);
}
}
final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
return new MdnsPacket(flags,
Collections.emptyList() /* questions */,
answers,
Collections.emptyList() /* authorityRecords */,
Collections.emptyList() /* additionalRecords */);
}
/**
* Get the service IDs of services conflicting with a received packet.
*/

View File

@@ -25,6 +25,7 @@ import android.os.Message;
import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException;
import java.net.DatagramPacket;
@@ -86,36 +87,13 @@ public class MdnsReplySender {
// Skip sending if the socket has not joined the v4/v6 group (there was no address)
return;
}
final byte[] outBuffer = MdnsUtils.createRawDnsPacket(mPacketCreationBuffer, packet);
mSocket.send(new DatagramPacket(outBuffer, 0, outBuffer.length, destination));
}
// TODO: support packets over size (send in multiple packets with TC bit set)
final MdnsPacketWriter writer = new MdnsPacketWriter(mPacketCreationBuffer);
writer.writeUInt16(0); // Transaction ID (advertisement: 0)
writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4)
writer.writeUInt16(packet.questions.size()); // questions count
writer.writeUInt16(packet.answers.size()); // answers count
writer.writeUInt16(packet.authorityRecords.size()); // authority entries count
writer.writeUInt16(packet.additionalRecords.size()); // additional records count
for (MdnsRecord record : packet.questions) {
// Questions do not have TTL or data
record.writeHeaderFields(writer);
}
for (MdnsRecord record : packet.answers) {
record.write(writer, 0L);
}
for (MdnsRecord record : packet.authorityRecords) {
record.write(writer, 0L);
}
for (MdnsRecord record : packet.additionalRecords) {
record.write(writer, 0L);
}
final int len = writer.getWritePosition();
final byte[] outBuffer = new byte[len];
System.arraycopy(mPacketCreationBuffer, 0, outBuffer, 0, len);
mSocket.send(new DatagramPacket(outBuffer, 0, len, destination));
/** Get the packetCreationBuffer */
public byte[] getPacketCreationBuffer() {
return mPacketCreationBuffer;
}
/**

View File

@@ -24,8 +24,11 @@ import android.os.SystemClock;
import android.util.ArraySet;
import com.android.server.connectivity.mdns.MdnsConstants;
import com.android.server.connectivity.mdns.MdnsPacket;
import com.android.server.connectivity.mdns.MdnsPacketWriter;
import com.android.server.connectivity.mdns.MdnsRecord;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
@@ -164,6 +167,41 @@ public class MdnsUtils {
return new String(out.array(), 0, out.position(), utf8);
}
/**
* Create a raw DNS packet.
*/
public static byte[] createRawDnsPacket(@NonNull byte[] packetCreationBuffer,
@NonNull MdnsPacket packet) throws IOException {
// TODO: support packets over size (send in multiple packets with TC bit set)
final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
writer.writeUInt16(0); // Transaction ID (advertisement: 0)
writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4)
writer.writeUInt16(packet.questions.size()); // questions count
writer.writeUInt16(packet.answers.size()); // answers count
writer.writeUInt16(packet.authorityRecords.size()); // authority entries count
writer.writeUInt16(packet.additionalRecords.size()); // additional records count
for (MdnsRecord record : packet.questions) {
// Questions do not have TTL or data
record.writeHeaderFields(writer);
}
for (MdnsRecord record : packet.answers) {
record.write(writer, 0L);
}
for (MdnsRecord record : packet.authorityRecords) {
record.write(writer, 0L);
}
for (MdnsRecord record : packet.additionalRecords) {
record.write(writer, 0L);
}
final int len = writer.getWritePosition();
final byte[] outBuffer = new byte[len];
System.arraycopy(packetCreationBuffer, 0, outBuffer, 0, len);
return outBuffer;
}
/**
* Checks if the MdnsRecord needs to be renewed or not.
*

View File

@@ -900,6 +900,39 @@ class NsdManagerTest {
assertTrue(serviceInfo.hostname.endsWith("local"))
assertEquals(0, serviceInfo.priority)
assertEquals(OffloadEngine.OFFLOAD_TYPE_REPLY.toLong(), serviceInfo.offloadType)
val offloadPayload = serviceInfo.offloadPayload
assertNotNull(offloadPayload)
val dnsPacket = TestDnsPacket(offloadPayload)
assertEquals(0x8400, dnsPacket.header.flags)
assertEquals(0, dnsPacket.records[DnsPacket.QDSECTION].size)
assertTrue(dnsPacket.records[DnsPacket.ANSECTION].size >= 5)
assertEquals(0, dnsPacket.records[DnsPacket.NSSECTION].size)
assertEquals(0, dnsPacket.records[DnsPacket.ARSECTION].size)
val ptrRecord = dnsPacket.records[DnsPacket.ANSECTION][0]
assertEquals("$expectedServiceType.local", ptrRecord.dName)
assertEquals(0x0C /* PTR */, ptrRecord.nsType)
val ptrSubRecord = dnsPacket.records[DnsPacket.ANSECTION][1]
assertEquals("_subtype._sub.$expectedServiceType.local", ptrSubRecord.dName)
assertEquals(0x0C /* PTR */, ptrSubRecord.nsType)
val srvRecord = dnsPacket.records[DnsPacket.ANSECTION][2]
assertEquals("${si.serviceName}.$expectedServiceType.local", srvRecord.dName)
assertEquals(0x21 /* SRV */, srvRecord.nsType)
val txtRecord = dnsPacket.records[DnsPacket.ANSECTION][3]
assertEquals("${si.serviceName}.$expectedServiceType.local", txtRecord.dName)
assertEquals(0x10 /* TXT */, txtRecord.nsType)
val iface = NetworkInterface.getByName(testNetwork1.iface.interfaceName)
val allAddress = iface.inetAddresses.toList()
for (i in 4 until dnsPacket.records[DnsPacket.ANSECTION].size) {
val addressRecord = dnsPacket.records[DnsPacket.ANSECTION][i]
assertTrue(addressRecord.dName.startsWith("Android_"))
assertTrue(addressRecord.dName.endsWith("local"))
assertTrue(addressRecord.nsType in arrayOf(0x1C /* AAAA */, 0x01 /* A */))
val rData = addressRecord.rr
assertNotNull(rData)
val addr = InetAddress.getByAddress(rData)
assertTrue(addr in allAddress)
}
}
@Test
@@ -1410,6 +1443,11 @@ private fun TapPacketReader.pollForAdvertisement(
): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
private class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
val header: DnsHeader
get() = mHeader
val records: Array<List<DnsRecord>>
get() = mRecords
fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
it.dName == name && it.nsType == 0xff /* ANY */
}

View File

@@ -56,7 +56,9 @@ private const val LONG_SERVICE_ID_2 = 4
private const val CASE_INSENSITIVE_TEST_SERVICE_ID = 5
private const val TIMEOUT_MS = 10_000L
private val TEST_ADDR = parseNumericAddress("2001:db8::123")
private val TEST_ADDR2 = parseNumericAddress("2001:db8::124")
private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
private val TEST_LINKADDR2 = LinkAddress(TEST_ADDR2, 64 /* prefixLength */)
private val TEST_NETWORK_1 = mock(Network::class.java)
private val TEST_SOCKETKEY_1 = SocketKey(1001 /* interfaceIndex */)
private val TEST_SOCKETKEY_2 = SocketKey(1002 /* interfaceIndex */)
@@ -64,6 +66,8 @@ private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SUBTYPE = "_subtype"
private val TEST_INTERFACE1 = "test_iface1"
private val TEST_INTERFACE2 = "test_iface2"
private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03)
private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04)
private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345
@@ -102,7 +106,7 @@ private val OFFLOAD_SERVICEINFO = OffloadServiceInfo(
OffloadServiceInfo.Key("TestServiceName", "_advertisertest._tcp"),
listOf(TEST_SUBTYPE),
"Android_test.local",
null, /* rawOffloadPacket */
TEST_OFFLOAD_PACKET1,
0, /* priority */
OffloadEngine.OFFLOAD_TYPE_REPLY.toLong()
)
@@ -111,7 +115,16 @@ private val OFFLOAD_SERVICEINFO_NO_SUBTYPE = OffloadServiceInfo(
OffloadServiceInfo.Key("TestServiceName", "_advertisertest._tcp"),
listOf(),
"Android_test.local",
null, /* rawOffloadPacket */
TEST_OFFLOAD_PACKET1,
0, /* priority */
OffloadEngine.OFFLOAD_TYPE_REPLY.toLong()
)
private val OFFLOAD_SERVICEINFO_NO_SUBTYPE2 = OffloadServiceInfo(
OffloadServiceInfo.Key("TestServiceName", "_advertisertest._tcp"),
listOf(),
"Android_test.local",
TEST_OFFLOAD_PACKET2,
0, /* priority */
OffloadEngine.OFFLOAD_TYPE_REPLY.toLong()
)
@@ -147,6 +160,10 @@ class MdnsAdvertiserTest {
doReturn(createEmptyNetworkInterface()).`when`(mockSocket2).getInterface()
doReturn(TEST_INTERFACE1).`when`(mockInterfaceAdvertiser1).socketInterfaceName
doReturn(TEST_INTERFACE2).`when`(mockInterfaceAdvertiser2).socketInterfaceName
doReturn(TEST_OFFLOAD_PACKET1).`when`(mockInterfaceAdvertiser1).getRawOffloadPayload(
SERVICE_ID_1)
doReturn(TEST_OFFLOAD_PACKET1).`when`(mockInterfaceAdvertiser2).getRawOffloadPayload(
SERVICE_ID_1)
}
@After
@@ -189,10 +206,23 @@ class MdnsAdvertiserTest {
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
verify(cb).onOffloadStartOrUpdate(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE))
doReturn(TEST_OFFLOAD_PACKET2).`when`(mockInterfaceAdvertiser1)
.getRawOffloadPayload(
SERVICE_ID_1
)
postSync {
socketCb.onAddressesChanged(
TEST_SOCKETKEY_1,
mockSocket1,
listOf(TEST_LINKADDR2)
)
}
verify(cb).onOffloadStartOrUpdate(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE2))
postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
verify(mockInterfaceAdvertiser1).destroyNow()
postSync { intAdvCbCaptor.value.onDestroyed(mockSocket1) }
verify(cb).onOffloadStop(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE))
verify(cb).onOffloadStop(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE2))
}
@Test

View File

@@ -365,6 +365,58 @@ class MdnsRecordRepositoryTest {
), packet.additionalRecords)
}
@Test
fun testGetOffloadPacket() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local")
val offloadPacket = repository.getOffloadPacket(TEST_SERVICE_ID_1)
assertEquals(0x8400, offloadPacket.flags)
assertEquals(0, offloadPacket.questions.size)
assertEquals(0, offloadPacket.additionalRecords.size)
assertEquals(0, offloadPacket.authorityRecords.size)
assertContentEquals(listOf(
MdnsPointerRecord(
serviceType,
0L /* receiptTimeMillis */,
// Not a unique name owned by the announcer, so cacheFlush=false
false /* cacheFlush */,
4500000L /* ttlMillis */,
serviceName),
MdnsServiceRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
0 /* servicePriority */,
0 /* serviceWeight */,
TEST_PORT /* servicePort */,
TEST_HOSTNAME),
MdnsTextRecord(
serviceName,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
4500000L /* ttlMillis */,
emptyList() /* entries */),
MdnsInetAddressRecord(TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
TEST_ADDRESSES[0].address),
MdnsInetAddressRecord(TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
TEST_ADDRESSES[1].address),
MdnsInetAddressRecord(TEST_HOSTNAME,
0L /* receiptTimeMillis */,
true /* cacheFlush */,
120000L /* ttlMillis */,
TEST_ADDRESSES[2].address),
), offloadPacket.answers)
}
@Test
fun testGetReverseDnsAddress() {
val expectedV6 = "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa"