diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java index d9f01da54a..01b2e691d4 100644 --- a/services/core/java/com/android/server/ConnectivityService.java +++ b/services/core/java/com/android/server/ConnectivityService.java @@ -2488,10 +2488,12 @@ public class ConnectivityService extends IConnectivityManager.Stub final List netDiags = new ArrayList(); final long DIAG_TIME_MS = 5000; for (NetworkAgentInfo nai : networksSortedById()) { + PrivateDnsConfig privateDnsCfg = mDnsManager.getPrivateDnsConfig(nai.network); // Start gathering diagnostic information. netDiags.add(new NetworkDiagnostics( nai.network, new LinkProperties(nai.linkProperties), // Must be a copy. + privateDnsCfg, DIAG_TIME_MS)); } diff --git a/services/core/java/com/android/server/connectivity/DnsManager.java b/services/core/java/com/android/server/connectivity/DnsManager.java index 506c8e3919..cf6a7f6e8d 100644 --- a/services/core/java/com/android/server/connectivity/DnsManager.java +++ b/services/core/java/com/android/server/connectivity/DnsManager.java @@ -57,6 +57,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -64,7 +65,9 @@ import java.util.stream.Collectors; * Encapsulate the management of DNS settings for networks. * * This class it NOT designed for concurrent access. Furthermore, all non-static - * methods MUST be called from ConnectivityService's thread. + * methods MUST be called from ConnectivityService's thread. However, an exceptional + * case is getPrivateDnsConfig(Network) which is exclusively for + * ConnectivityService#dumpNetworkDiagnostics() on a random binder thread. * * [ Private DNS ] * The code handling Private DNS is spread across several components, but this @@ -236,8 +239,8 @@ public class DnsManager { private final ContentResolver mContentResolver; private final IDnsResolver mDnsResolver; private final MockableSystemProperties mSystemProperties; - // TODO: Replace these Maps with SparseArrays. - private final Map mPrivateDnsMap; + private final ConcurrentHashMap mPrivateDnsMap; + // TODO: Replace the Map with SparseArrays. private final Map mPrivateDnsValidationMap; private final Map mLinkPropertiesMap; private final Map mTransportsMap; @@ -247,15 +250,13 @@ public class DnsManager { private int mSuccessThreshold; private int mMinSamples; private int mMaxSamples; - private String mPrivateDnsMode; - private String mPrivateDnsSpecifier; public DnsManager(Context ctx, IDnsResolver dnsResolver, MockableSystemProperties sp) { mContext = ctx; mContentResolver = mContext.getContentResolver(); mDnsResolver = dnsResolver; mSystemProperties = sp; - mPrivateDnsMap = new HashMap<>(); + mPrivateDnsMap = new ConcurrentHashMap<>(); mPrivateDnsValidationMap = new HashMap<>(); mLinkPropertiesMap = new HashMap<>(); mTransportsMap = new HashMap<>(); @@ -275,6 +276,12 @@ public class DnsManager { mLinkPropertiesMap.remove(network.netId); } + // This is exclusively called by ConnectivityService#dumpNetworkDiagnostics() which + // is not on the ConnectivityService handler thread. + public PrivateDnsConfig getPrivateDnsConfig(@NonNull Network network) { + return mPrivateDnsMap.getOrDefault(network.netId, PRIVATE_DNS_OFF); + } + public PrivateDnsConfig updatePrivateDns(Network network, PrivateDnsConfig cfg) { Slog.w(TAG, "updatePrivateDns(" + network + ", " + cfg + ")"); return (cfg != null) diff --git a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java index a1a8e355dc..49c16ad96e 100644 --- a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java +++ b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java @@ -18,12 +18,15 @@ package com.android.server.connectivity; import static android.system.OsConstants.*; +import android.annotation.NonNull; +import android.annotation.Nullable; import android.net.LinkAddress; import android.net.LinkProperties; import android.net.Network; import android.net.NetworkUtils; import android.net.RouteInfo; import android.net.TrafficStats; +import android.net.shared.PrivateDnsConfig; import android.net.util.NetworkConstants; import android.os.SystemClock; import android.system.ErrnoException; @@ -38,6 +41,8 @@ import com.android.internal.util.TrafficStatsConstants; import libcore.io.IoUtils; import java.io.Closeable; +import java.io.DataInputStream; +import java.io.DataOutputStream; import java.io.FileDescriptor; import java.io.IOException; import java.io.InterruptedIOException; @@ -52,6 +57,7 @@ import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -59,6 +65,12 @@ import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + /** * NetworkDiagnostics * @@ -100,6 +112,7 @@ public class NetworkDiagnostics { private final Network mNetwork; private final LinkProperties mLinkProperties; + private final PrivateDnsConfig mPrivateDnsCfg; private final Integer mInterfaceIndex; private final long mTimeoutMs; @@ -163,12 +176,15 @@ public class NetworkDiagnostics { private final Map, Measurement> mExplicitSourceIcmpChecks = new HashMap<>(); private final Map mDnsUdpChecks = new HashMap<>(); + private final Map mDnsTlsChecks = new HashMap<>(); private final String mDescription; - public NetworkDiagnostics(Network network, LinkProperties lp, long timeoutMs) { + public NetworkDiagnostics(Network network, LinkProperties lp, + @NonNull PrivateDnsConfig privateDnsCfg, long timeoutMs) { mNetwork = network; mLinkProperties = lp; + mPrivateDnsCfg = privateDnsCfg; mInterfaceIndex = getInterfaceIndex(mLinkProperties.getInterfaceName()); mTimeoutMs = timeoutMs; mStartTime = now(); @@ -199,8 +215,22 @@ public class NetworkDiagnostics { } } for (InetAddress nameserver : mLinkProperties.getDnsServers()) { - prepareIcmpMeasurement(nameserver); - prepareDnsMeasurement(nameserver); + prepareIcmpMeasurement(nameserver); + prepareDnsMeasurement(nameserver); + + // Unlike the DnsResolver which doesn't do certificate validation in opportunistic mode, + // DoT probes to the DNS servers will fail if certificate validation fails. + prepareDnsTlsMeasurement(null /* hostname */, nameserver); + } + + for (InetAddress tlsNameserver : mPrivateDnsCfg.ips) { + // Reachability check is necessary since when resolving the strict mode hostname, + // NetworkMonitor always queries for both A and AAAA records, even if the network + // is IPv4-only or IPv6-only. + if (mLinkProperties.isReachable(tlsNameserver)) { + // If there are IPs, there must have been a name that resolved to them. + prepareDnsTlsMeasurement(mPrivateDnsCfg.hostname, tlsNameserver); + } } mCountDownLatch = new CountDownLatch(totalMeasurementCount()); @@ -222,6 +252,15 @@ public class NetworkDiagnostics { } } + private static String socketAddressToString(@NonNull SocketAddress sockAddr) { + // The default toString() implementation is not the prettiest. + InetSocketAddress inetSockAddr = (InetSocketAddress) sockAddr; + InetAddress localAddr = inetSockAddr.getAddress(); + return String.format( + (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"), + localAddr.getHostAddress(), inetSockAddr.getPort()); + } + private void prepareIcmpMeasurement(InetAddress target) { if (!mIcmpChecks.containsKey(target)) { Measurement measurement = new Measurement(); @@ -252,8 +291,19 @@ public class NetworkDiagnostics { } } + private void prepareDnsTlsMeasurement(@Nullable String hostname, @NonNull InetAddress target) { + // This might overwrite an existing entry in mDnsTlsChecks, because |target| can be an IP + // address configured by the network as well as an IP address learned by resolving the + // strict mode DNS hostname. If the entry is overwritten, the overwritten measurement + // thread will not execute. + Measurement measurement = new Measurement(); + measurement.thread = new Thread(new DnsTlsCheck(hostname, target, measurement)); + mDnsTlsChecks.put(target, measurement); + } + private int totalMeasurementCount() { - return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size(); + return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size() + + mDnsTlsChecks.size(); } private void startMeasurements() { @@ -266,6 +316,9 @@ public class NetworkDiagnostics { for (Measurement measurement : mDnsUdpChecks.values()) { measurement.thread.start(); } + for (Measurement measurement : mDnsTlsChecks.values()) { + measurement.thread.start(); + } } public void waitForMeasurements() { @@ -297,6 +350,11 @@ public class NetworkDiagnostics { measurements.add(entry.getValue()); } } + for (Map.Entry entry : mDnsTlsChecks.entrySet()) { + if (entry.getKey() instanceof Inet4Address) { + measurements.add(entry.getValue()); + } + } // IPv6 measurements second. for (Map.Entry entry : mIcmpChecks.entrySet()) { @@ -315,6 +373,11 @@ public class NetworkDiagnostics { measurements.add(entry.getValue()); } } + for (Map.Entry entry : mDnsTlsChecks.entrySet()) { + if (entry.getKey() instanceof Inet6Address) { + measurements.add(entry.getValue()); + } + } return measurements; } @@ -387,6 +450,8 @@ public class NetworkDiagnostics { try { mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol); } finally { + // TODO: The tag should remain set until all traffic is sent and received. + // Consider tagging the socket after the measurement thread is started. TrafficStats.setThreadStatsTag(oldTag); } // Setting SNDTIMEO is purely for defensive purposes. @@ -403,13 +468,12 @@ public class NetworkDiagnostics { mSocketAddress = Os.getsockname(mFileDescriptor); } - protected String getSocketAddressString() { - // The default toString() implementation is not the prettiest. - InetSocketAddress inetSockAddr = (InetSocketAddress) mSocketAddress; - InetAddress localAddr = inetSockAddr.getAddress(); - return String.format( - (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"), - localAddr.getHostAddress(), inetSockAddr.getPort()); + protected boolean ensureMeasurementNecessary() { + if (mMeasurement.finishTime == 0) return false; + + // Countdown latch was not decremented when the measurement failed during setup. + mCountDownLatch.countDown(); + return true; } @Override @@ -448,13 +512,7 @@ public class NetworkDiagnostics { @Override public void run() { - // Check if this measurement has already failed during setup. - if (mMeasurement.finishTime > 0) { - // If the measurement failed during construction it didn't - // decrement the countdown latch; do so here. - mCountDownLatch.countDown(); - return; - } + if (ensureMeasurementNecessary()) return; try { setupSocket(SOCK_DGRAM, mProtocol, TIMEOUT_SEND, TIMEOUT_RECV, 0); @@ -462,7 +520,7 @@ public class NetworkDiagnostics { mMeasurement.recordFailure(e.toString()); return; } - mMeasurement.description += " src{" + getSocketAddressString() + "}"; + mMeasurement.description += " src{" + socketAddressToString(mSocketAddress) + "}"; // Build a trivial ICMP packet. final byte[] icmpPacket = { @@ -507,10 +565,10 @@ public class NetworkDiagnostics { private static final int RR_TYPE_AAAA = 28; private static final int PACKET_BUFSIZE = 512; - private final Random mRandom = new Random(); + protected final Random mRandom = new Random(); // Should be static, but the compiler mocks our puny, human attempts at reason. - private String responseCodeStr(int rcode) { + protected String responseCodeStr(int rcode) { try { return DnsResponseCode.values()[rcode].toString(); } catch (IndexOutOfBoundsException e) { @@ -518,7 +576,7 @@ public class NetworkDiagnostics { } } - private final int mQueryType; + protected final int mQueryType; public DnsUdpCheck(InetAddress target, Measurement measurement) { super(target, measurement); @@ -535,13 +593,7 @@ public class NetworkDiagnostics { @Override public void run() { - // Check if this measurement has already failed during setup. - if (mMeasurement.finishTime > 0) { - // If the measurement failed during construction it didn't - // decrement the countdown latch; do so here. - mCountDownLatch.countDown(); - return; - } + if (ensureMeasurementNecessary()) return; try { setupSocket(SOCK_DGRAM, IPPROTO_UDP, TIMEOUT_SEND, TIMEOUT_RECV, @@ -550,12 +602,10 @@ public class NetworkDiagnostics { mMeasurement.recordFailure(e.toString()); return; } - mMeasurement.description += " src{" + getSocketAddressString() + "}"; // This needs to be fixed length so it can be dropped into the pre-canned packet. final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000); - mMeasurement.description += " qtype{" + mQueryType + "}" - + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}"; + appendDnsToMeasurementDescription(sixRandomDigits, mSocketAddress); // Build a trivial DNS packet. final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits); @@ -592,7 +642,7 @@ public class NetworkDiagnostics { close(); } - private byte[] getDnsQueryPacket(String sixRandomDigits) { + protected byte[] getDnsQueryPacket(String sixRandomDigits) { byte[] rnd = sixRandomDigits.getBytes(StandardCharsets.US_ASCII); return new byte[] { (byte) mRandom.nextInt(), (byte) mRandom.nextInt(), // [0-1] query ID @@ -611,5 +661,97 @@ public class NetworkDiagnostics { 0, 1 // QCLASS, set to 1 = IN (Internet) }; } + + protected void appendDnsToMeasurementDescription( + String sixRandomDigits, SocketAddress sockAddr) { + mMeasurement.description += " src{" + socketAddressToString(sockAddr) + "}" + + " qtype{" + mQueryType + "}" + + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}"; + } + } + + // TODO: Have it inherited from SimpleSocketCheck, and separate common DNS helpers out of + // DnsUdpCheck. + private class DnsTlsCheck extends DnsUdpCheck { + private static final int TCP_CONNECT_TIMEOUT_MS = 2500; + private static final int TCP_TIMEOUT_MS = 2000; + private static final int DNS_TLS_PORT = 853; + private static final int DNS_HEADER_SIZE = 12; + + private final String mHostname; + + public DnsTlsCheck(@Nullable String hostname, @NonNull InetAddress target, + @NonNull Measurement measurement) { + super(target, measurement); + + mHostname = hostname; + mMeasurement.description = "DNS TLS dst{" + mTarget.getHostAddress() + "} hostname{" + + TextUtils.emptyIfNull(mHostname) + "}"; + } + + private SSLSocket setupSSLSocket() throws IOException { + // A TrustManager will be created and initialized with a KeyStore containing system + // CaCerts. During SSL handshake, it will be used to validate the certificates from + // the server. + SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); + sslSocket.setSoTimeout(TCP_TIMEOUT_MS); + + if (!TextUtils.isEmpty(mHostname)) { + // Set SNI. + final List names = + Collections.singletonList(new SNIHostName(mHostname)); + SSLParameters params = sslSocket.getSSLParameters(); + params.setServerNames(names); + sslSocket.setSSLParameters(params); + } + + mNetwork.bindSocket(sslSocket); + return sslSocket; + } + + private void sendDoTProbe(@Nullable SSLSocket sslSocket) throws IOException { + final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000); + final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits); + + mMeasurement.startTime = now(); + sslSocket.connect(new InetSocketAddress(mTarget, DNS_TLS_PORT), TCP_CONNECT_TIMEOUT_MS); + + // Synchronous call waiting for the TLS handshake complete. + sslSocket.startHandshake(); + appendDnsToMeasurementDescription(sixRandomDigits, sslSocket.getLocalSocketAddress()); + + final DataOutputStream output = new DataOutputStream(sslSocket.getOutputStream()); + output.writeShort(dnsPacket.length); + output.write(dnsPacket, 0, dnsPacket.length); + + final DataInputStream input = new DataInputStream(sslSocket.getInputStream()); + final int replyLength = Short.toUnsignedInt(input.readShort()); + final byte[] reply = new byte[replyLength]; + int bytesRead = 0; + while (bytesRead < replyLength) { + bytesRead += input.read(reply, bytesRead, replyLength - bytesRead); + } + + if (bytesRead > DNS_HEADER_SIZE && bytesRead == replyLength) { + mMeasurement.recordSuccess("1/1 " + responseCodeStr((int) (reply[3]) & 0x0f)); + } else { + mMeasurement.recordFailure("1/1 Read " + bytesRead + " bytes while expected to be " + + replyLength + " bytes"); + } + } + + @Override + public void run() { + if (ensureMeasurementNecessary()) return; + + // No need to restore the tag, since this thread is only used for this measurement. + TrafficStats.getAndSetThreadStatsTag(TrafficStatsConstants.TAG_SYSTEM_PROBE); + + try (SSLSocket sslSocket = setupSSLSocket()) { + sendDoTProbe(sslSocket); + } catch (IOException e) { + mMeasurement.recordFailure(e.toString()); + } + } } } diff --git a/tests/net/java/com/android/server/connectivity/DnsManagerTest.java b/tests/net/java/com/android/server/connectivity/DnsManagerTest.java index 0a603b8e4b..26a28da975 100644 --- a/tests/net/java/com/android/server/connectivity/DnsManagerTest.java +++ b/tests/net/java/com/android/server/connectivity/DnsManagerTest.java @@ -62,6 +62,8 @@ import androidx.test.runner.AndroidJUnit4; import com.android.internal.util.MessageUtils; import com.android.internal.util.test.FakeSettingsProvider; +import libcore.net.InetAddressUtils; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -379,4 +381,49 @@ public class DnsManagerTest { assertEquals(name, dnsTransTypes.get(i)); } } + + @Test + public void testGetPrivateDnsConfigForNetwork() throws Exception { + final Network network = new Network(TEST_NETID); + final InetAddress dnsAddr = InetAddressUtils.parseNumericAddress("3.3.3.3"); + final InetAddress[] tlsAddrs = new InetAddress[]{ + InetAddressUtils.parseNumericAddress("6.6.6.6"), + InetAddressUtils.parseNumericAddress("2001:db8:66:66::1") + }; + final String tlsName = "strictmode.com"; + LinkProperties lp = new LinkProperties(); + lp.addDnsServer(dnsAddr); + + // The PrivateDnsConfig map is empty, so the default PRIVATE_DNS_OFF is returned. + PrivateDnsConfig privateDnsCfg = mDnsManager.getPrivateDnsConfig(network); + assertFalse(privateDnsCfg.useTls); + assertEquals("", privateDnsCfg.hostname); + assertEquals(new InetAddress[0], privateDnsCfg.ips); + + // An entry with default PrivateDnsConfig is added to the PrivateDnsConfig map. + mDnsManager.updatePrivateDns(network, mDnsManager.getPrivateDnsConfig()); + mDnsManager.noteDnsServersForNetwork(TEST_NETID, lp); + mDnsManager.updatePrivateDnsValidation( + new DnsManager.PrivateDnsValidationUpdate(TEST_NETID, dnsAddr, "", true)); + mDnsManager.updatePrivateDnsStatus(TEST_NETID, lp); + privateDnsCfg = mDnsManager.getPrivateDnsConfig(network); + assertTrue(privateDnsCfg.useTls); + assertEquals("", privateDnsCfg.hostname); + assertEquals(new InetAddress[0], privateDnsCfg.ips); + + // The original entry is overwritten by a new PrivateDnsConfig. + mDnsManager.updatePrivateDns(network, new PrivateDnsConfig(tlsName, tlsAddrs)); + mDnsManager.updatePrivateDnsStatus(TEST_NETID, lp); + privateDnsCfg = mDnsManager.getPrivateDnsConfig(network); + assertTrue(privateDnsCfg.useTls); + assertEquals(tlsName, privateDnsCfg.hostname); + assertEquals(tlsAddrs, privateDnsCfg.ips); + + // The network is removed, so the PrivateDnsConfig map becomes empty again. + mDnsManager.removeNetwork(network); + privateDnsCfg = mDnsManager.getPrivateDnsConfig(network); + assertFalse(privateDnsCfg.useTls); + assertEquals("", privateDnsCfg.hostname); + assertEquals(new InetAddress[0], privateDnsCfg.ips); + } }