Avoid going over max service name length when renaming NSD services

As per  RFC 1034/1035, the max size of the label is 63 bytes. It should
also be guaranteed when the serviceName is renamed due to the conflict.

Bug: 265865456
Test: atest FrameworksNetTests
Change-Id: I077d8abdb91071db62b9618d9918e3a12682aaf4
This commit is contained in:
Yuyang Huang
2023-05-02 17:14:22 +09:00
parent 441a32e39d
commit de802c8dc4
6 changed files with 83 additions and 25 deletions

View File

@@ -23,6 +23,7 @@ import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
import static android.provider.DeviceConfig.NAMESPACE_TETHERING; import static android.provider.DeviceConfig.NAMESPACE_TETHERING;
import static com.android.modules.utils.build.SdkLevel.isAtLeastU; import static com.android.modules.utils.build.SdkLevel.isAtLeastU;
import static com.android.server.connectivity.mdns.MdnsRecord.MAX_LABEL_LENGTH;
import android.annotation.NonNull; import android.annotation.NonNull;
import android.annotation.Nullable; import android.annotation.Nullable;
@@ -73,6 +74,7 @@ import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
import com.android.server.connectivity.mdns.MdnsServiceInfo; import com.android.server.connectivity.mdns.MdnsServiceInfo;
import com.android.server.connectivity.mdns.MdnsSocketClientBase; import com.android.server.connectivity.mdns.MdnsSocketClientBase;
import com.android.server.connectivity.mdns.MdnsSocketProvider; import com.android.server.connectivity.mdns.MdnsSocketProvider;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.FileDescriptor; import java.io.FileDescriptor;
import java.io.PrintWriter; import java.io.PrintWriter;
@@ -81,11 +83,6 @@ import java.net.InetAddress;
import java.net.NetworkInterface; import java.net.NetworkInterface;
import java.net.SocketException; import java.net.SocketException;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@@ -108,8 +105,6 @@ public class NsdService extends INsdManager.Stub {
*/ */
private static final String MDNS_DISCOVERY_MANAGER_VERSION = "mdns_discovery_manager_version"; private static final String MDNS_DISCOVERY_MANAGER_VERSION = "mdns_discovery_manager_version";
private static final String LOCAL_DOMAIN_NAME = "local"; private static final String LOCAL_DOMAIN_NAME = "local";
// Max label length as per RFC 1034/1035
private static final int MAX_LABEL_LENGTH = 63;
/** /**
* Enable advertising using the Java MdnsAdvertiser, instead of the legacy mdnsresponder * Enable advertising using the Java MdnsAdvertiser, instead of the legacy mdnsresponder
@@ -570,18 +565,7 @@ public class NsdService extends INsdManager.Stub {
*/ */
@NonNull @NonNull
private String truncateServiceName(@NonNull String originalName) { private String truncateServiceName(@NonNull String originalName) {
// UTF-8 is at most 4 bytes per character; return early in the common case where return MdnsUtils.truncateServiceName(originalName, MAX_LABEL_LENGTH);
// the name can't possibly be over the limit given its string length.
if (originalName.length() <= MAX_LABEL_LENGTH / 4) return originalName;
final Charset utf8 = StandardCharsets.UTF_8;
final CharsetEncoder encoder = utf8.newEncoder();
final ByteBuffer out = ByteBuffer.allocate(MAX_LABEL_LENGTH);
// encode will write as many characters as possible to the out buffer, and just
// return an overflow code if there were too many characters (no need to check the
// return code here, this method truncates the name on purpose).
encoder.encode(CharBuffer.wrap(originalName), out, true /* endOfInput */);
return new String(out.array(), 0, out.position(), utf8);
} }
private void stopDiscoveryManagerRequest(ClientRequest request, int clientId, int id, private void stopDiscoveryManagerRequest(ClientRequest request, int clientId, int id,

View File

@@ -16,6 +16,8 @@
package com.android.server.connectivity.mdns; package com.android.server.connectivity.mdns;
import static com.android.server.connectivity.mdns.MdnsRecord.MAX_LABEL_LENGTH;
import android.annotation.NonNull; import android.annotation.NonNull;
import android.annotation.Nullable; import android.annotation.Nullable;
import android.net.LinkAddress; import android.net.LinkAddress;
@@ -29,6 +31,7 @@ import android.util.SparseArray;
import com.android.internal.annotations.VisibleForTesting; import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.SharedLog; import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.List; import java.util.List;
@@ -359,7 +362,7 @@ public class MdnsAdvertiser {
// "Name (2)", then "Name (3)" etc. // "Name (2)", then "Name (3)" etc.
// TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t // TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t
final NsdServiceInfo newInfo = new NsdServiceInfo(); final NsdServiceInfo newInfo = new NsdServiceInfo();
newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")"); newInfo.setServiceName(getUpdatedServiceName(renameCount));
newInfo.setServiceType(mServiceInfo.getServiceType()); newInfo.setServiceType(mServiceInfo.getServiceType());
for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) { for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) {
newInfo.setAttribute(attr.getKey(), newInfo.setAttribute(attr.getKey(),
@@ -372,6 +375,13 @@ public class MdnsAdvertiser {
return newInfo; return newInfo;
} }
private String getUpdatedServiceName(int renameCount) {
final String suffix = " (" + (mConflictCount + renameCount + 1) + ")";
final String truncatedServiceName = MdnsUtils.truncateServiceName(mOriginalName,
MAX_LABEL_LENGTH - suffix.length());
return truncatedServiceName + suffix;
}
@NonNull @NonNull
public NsdServiceInfo getServiceInfo() { public NsdServiceInfo getServiceInfo() {
return mServiceInfo; return mServiceInfo;

View File

@@ -46,6 +46,8 @@ public abstract class MdnsRecord {
public static final long RECEIPT_TIME_NOT_SENT = 0L; public static final long RECEIPT_TIME_NOT_SENT = 0L;
public static final int CLASS_ANY = 0x00ff; public static final int CLASS_ANY = 0x00ff;
/** Max label length as per RFC 1034/1035 */
public static final int MAX_LABEL_LENGTH = 63;
/** Status indicating that the record is current. */ /** Status indicating that the record is current. */
public static final int STATUS_OK = 0; public static final int STATUS_OK = 0;

View File

@@ -21,6 +21,12 @@ import android.annotation.Nullable;
import android.net.Network; import android.net.Network;
import android.os.Handler; import android.os.Handler;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.StandardCharsets;
/** /**
* Mdns utility functions. * Mdns utility functions.
*/ */
@@ -73,4 +79,22 @@ public class MdnsUtils {
@Nullable Network currentNetwork) { @Nullable Network currentNetwork) {
return targetNetwork == null || targetNetwork.equals(currentNetwork); return targetNetwork == null || targetNetwork.equals(currentNetwork);
} }
}
/**
* Truncate a service name to up to maxLength UTF-8 bytes.
*/
public static String truncateServiceName(@NonNull String originalName, int maxLength) {
// UTF-8 is at most 4 bytes per character; return early in the common case where
// the name can't possibly be over the limit given its string length.
if (originalName.length() <= maxLength / 4) return originalName;
final Charset utf8 = StandardCharsets.UTF_8;
final CharsetEncoder encoder = utf8.newEncoder();
final ByteBuffer out = ByteBuffer.allocate(maxLength);
// encode will write as many characters as possible to the out buffer, and just
// return an overflow code if there were too many characters (no need to check the
// return code here, this method truncates the name on purpose).
encoder.encode(CharBuffer.wrap(originalName), out, true /* endOfInput */);
return new String(out.array(), 0, out.position(), utf8);
}
}

View File

@@ -47,6 +47,8 @@ import org.mockito.Mockito.verify
private const val SERVICE_ID_1 = 1 private const val SERVICE_ID_1 = 1
private const val SERVICE_ID_2 = 2 private const val SERVICE_ID_2 = 2
private const val LONG_SERVICE_ID_1 = 3
private const val LONG_SERVICE_ID_2 = 4
private const val TIMEOUT_MS = 10_000L private const val TIMEOUT_MS = 10_000L
private val TEST_ADDR = parseNumericAddress("2001:db8::123") private val TEST_ADDR = parseNumericAddress("2001:db8::123")
private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */) private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
@@ -56,16 +58,30 @@ private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345 port = 12345
host = TEST_ADDR hostAddresses = listOf(TEST_ADDR)
network = TEST_NETWORK_1 network = TEST_NETWORK_1
} }
private val LONG_SERVICE_1 =
NsdServiceInfo("a".repeat(48) + "TestServiceName", "_longadvertisertest._tcp").apply {
port = 12345
hostAddresses = listOf(TEST_ADDR)
network = TEST_NETWORK_1
}
private val ALL_NETWORKS_SERVICE = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { private val ALL_NETWORKS_SERVICE = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345 port = 12345
host = TEST_ADDR hostAddresses = listOf(TEST_ADDR)
network = null network = null
} }
private val LONG_ALL_NETWORKS_SERVICE =
NsdServiceInfo("a".repeat(48) + "TestServiceName", "_longadvertisertest._tcp").apply {
port = 12345
hostAddresses = listOf(TEST_ADDR)
network = null
}
@RunWith(DevSdkIgnoreRunner::class) @RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2) @IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsAdvertiserTest { class MdnsAdvertiserTest {
@@ -191,6 +207,9 @@ class MdnsAdvertiserTest {
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value val allNetSocketCb = allNetSocketCbCaptor.value
postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1) }
postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE) }
// Callbacks for matching network and all networks both get the socket // Callbacks for matching network and all networks both get the socket
postSync { postSync {
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
@@ -200,10 +219,18 @@ class MdnsAdvertiserTest {
val expectedRenamed = NsdServiceInfo( val expectedRenamed = NsdServiceInfo(
"${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply { "${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
port = ALL_NETWORKS_SERVICE.port port = ALL_NETWORKS_SERVICE.port
host = ALL_NETWORKS_SERVICE.host hostAddresses = ALL_NETWORKS_SERVICE.hostAddresses
network = ALL_NETWORKS_SERVICE.network network = ALL_NETWORKS_SERVICE.network
} }
val expectedLongRenamed = NsdServiceInfo(
"${LONG_ALL_NETWORKS_SERVICE.serviceName.dropLast(4)} (2)",
LONG_ALL_NETWORKS_SERVICE.serviceType).apply {
port = LONG_ALL_NETWORKS_SERVICE.port
hostAddresses = LONG_ALL_NETWORKS_SERVICE.hostAddresses
network = LONG_ALL_NETWORKS_SERVICE.network
}
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME) eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME)
@@ -212,6 +239,10 @@ class MdnsAdvertiserTest {
argThat { it.matches(SERVICE_1) }) argThat { it.matches(SERVICE_1) })
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2), verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) }) argThat { it.matches(expectedRenamed) })
verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_1),
argThat { it.matches(LONG_SERVICE_1) })
verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_2),
argThat { it.matches(expectedLongRenamed) })
doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1) doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded( postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
@@ -251,7 +282,7 @@ private fun NsdServiceInfo.matches(other: NsdServiceInfo): Boolean {
return Objects.equals(serviceName, other.serviceName) && return Objects.equals(serviceName, other.serviceName) &&
Objects.equals(serviceType, other.serviceType) && Objects.equals(serviceType, other.serviceType) &&
Objects.equals(attributes, other.attributes) && Objects.equals(attributes, other.attributes) &&
Objects.equals(host, other.host) && Objects.equals(hostAddresses, other.hostAddresses) &&
port == other.port && port == other.port &&
Objects.equals(network, other.network) Objects.equals(network, other.network)
} }

View File

@@ -19,6 +19,7 @@ package com.android.server.connectivity.mdns.util
import android.os.Build import android.os.Build
import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase
import com.android.server.connectivity.mdns.util.MdnsUtils.truncateServiceName
import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner import com.android.testutils.DevSdkIgnoreRunner
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
@@ -65,4 +66,10 @@ class MdnsUtilsTest {
"Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" + "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<")) "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
} }
@Test
fun testTruncateServiceName() {
assertEquals(truncateServiceName("测试abcde", 7), "测试a")
assertEquals(truncateServiceName("测试abcde", 100), "测试abcde")
}
} }