diff --git a/core/java/android/net/DnsResolver.java b/core/java/android/net/DnsResolver.java index f9e0af227a..687b721834 100644 --- a/core/java/android/net/DnsResolver.java +++ b/core/java/android/net/DnsResolver.java @@ -22,6 +22,10 @@ import static android.net.NetworkUtils.resNetworkResult; import static android.net.NetworkUtils.resNetworkSend; import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_ERROR; import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT; +import static android.system.OsConstants.AF_INET; +import static android.system.OsConstants.AF_INET6; +import static android.system.OsConstants.IPPROTO_UDP; +import static android.system.OsConstants.SOCK_DGRAM; import android.annotation.CallbackExecutor; import android.annotation.IntDef; @@ -30,12 +34,18 @@ import android.annotation.Nullable; import android.os.CancellationSignal; import android.os.Looper; import android.system.ErrnoException; +import android.system.Os; import android.util.Log; +import libcore.io.IoUtils; + import java.io.FileDescriptor; +import java.io.IOException; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.List; @@ -52,6 +62,7 @@ public final class DnsResolver { private static final String TAG = "DnsResolver"; private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR; private static final int MAXPACKET = 8 * 1024; + private static final int SLEEP_TIME = 2; @IntDef(prefix = { "CLASS_" }, value = { CLASS_IN @@ -188,9 +199,9 @@ public final class DnsResolver { * Send a raw DNS query. * The answer will be provided asynchronously through the provided {@link AnswerCallback}. * - * @param network {@link Network} specifying which network for querying. + * @param network {@link Network} specifying which network to query on. * {@code null} for query on default network. - * @param query blob message + * @param query blob message to query * @param flags flags as a combination of the FLAGS_* constants * @param executor The {@link Executor} that the callback should be executed on. * @param cancellationSignal used by the caller to signal if the query should be @@ -211,21 +222,23 @@ public final class DnsResolver { queryfd = resNetworkSend((network != null ? network.netId : NETID_UNSET), query, query.length, flags); } catch (ErrnoException e) { - callback.onQueryException(e); + executor.execute(() -> { + callback.onQueryException(e); + }); return; } - maybeAddCancellationSignal(cancellationSignal, queryfd, lock); registerFDListener(executor, queryfd, callback, cancellationSignal, lock); + maybeAddCancellationSignal(cancellationSignal, queryfd, lock); } /** * Send a DNS query with the specified name, class and query type. * The answer will be provided asynchronously through the provided {@link AnswerCallback}. * - * @param network {@link Network} specifying which network for querying. + * @param network {@link Network} specifying which network to query on. * {@code null} for query on default network. - * @param domain domain name for querying + * @param domain domain name to query * @param nsClass dns class as one of the CLASS_* constants * @param nsType dns resource record (RR) type as one of the TYPE_* constants * @param flags flags as a combination of the FLAGS_* constants @@ -249,12 +262,145 @@ public final class DnsResolver { queryfd = resNetworkQuery((network != null ? network.netId : NETID_UNSET), domain, nsClass, nsType, flags); } catch (ErrnoException e) { - callback.onQueryException(e); + executor.execute(() -> { + callback.onQueryException(e); + }); return; } - maybeAddCancellationSignal(cancellationSignal, queryfd, lock); registerFDListener(executor, queryfd, callback, cancellationSignal, lock); + maybeAddCancellationSignal(cancellationSignal, queryfd, lock); + } + + private class InetAddressAnswerAccumulator extends InetAddressAnswerCallback { + private final List mAllAnswers; + private ParseException mParseException; + private ErrnoException mErrnoException; + private final InetAddressAnswerCallback mUserCallback; + private final int mTargetAnswerCount; + private int mReceivedAnswerCount = 0; + + InetAddressAnswerAccumulator(int size, @NonNull InetAddressAnswerCallback callback) { + mTargetAnswerCount = size; + mAllAnswers = new ArrayList<>(); + mUserCallback = callback; + } + + private boolean maybeReportException() { + if (mErrnoException != null) { + mUserCallback.onQueryException(mErrnoException); + return true; + } + if (mParseException != null) { + mUserCallback.onParseException(mParseException); + return true; + } + return false; + } + + private void maybeReportAnswer() { + if (++mReceivedAnswerCount != mTargetAnswerCount) return; + if (mAllAnswers.isEmpty() && maybeReportException()) return; + // TODO: Do RFC6724 sort. + mUserCallback.onAnswer(mAllAnswers); + } + + @Override + public void onAnswer(@NonNull List answer) { + mAllAnswers.addAll(answer); + maybeReportAnswer(); + } + + @Override + public void onParseException(@NonNull ParseException e) { + mParseException = e; + maybeReportAnswer(); + } + + @Override + public void onQueryException(@NonNull ErrnoException e) { + mErrnoException = e; + maybeReportAnswer(); + } + } + + /** + * Send a DNS query with the specified name, get back a set of InetAddresses asynchronously. + * The answer will be provided asynchronously through the provided + * {@link InetAddressAnswerCallback}. + * + * @param network {@link Network} specifying which network to query on. + * {@code null} for query on default network. + * @param domain domain name to query + * @param flags flags as a combination of the FLAGS_* constants + * @param executor The {@link Executor} that the callback should be executed on. + * @param cancellationSignal used by the caller to signal if the query should be + * cancelled. May be {@code null}. + * @param callback an {@link InetAddressAnswerCallback} which will be called to notify the + * caller of the result of dns query. + */ + public void query(@Nullable Network network, @NonNull String domain, @QueryFlag int flags, + @NonNull @CallbackExecutor Executor executor, + @Nullable CancellationSignal cancellationSignal, + @NonNull InetAddressAnswerCallback callback) { + if (cancellationSignal != null && cancellationSignal.isCanceled()) { + return; + } + final Object lock = new Object(); + final boolean queryIpv6 = haveIpv6(network); + final boolean queryIpv4 = haveIpv4(network); + + final FileDescriptor v4fd; + final FileDescriptor v6fd; + + int queryCount = 0; + + if (queryIpv6) { + try { + v6fd = resNetworkQuery((network != null + ? network.netId : NETID_UNSET), domain, CLASS_IN, TYPE_AAAA, flags); + } catch (ErrnoException e) { + executor.execute(() -> { + callback.onQueryException(e); + }); + return; + } + queryCount++; + } else v6fd = null; + + // TODO: Use device flag to controll the sleep time. + // Avoiding gateways drop packets if queries are sent too close together + try { + Thread.sleep(SLEEP_TIME); + } catch (InterruptedException ex) { } + + if (queryIpv4) { + try { + v4fd = resNetworkQuery((network != null + ? network.netId : NETID_UNSET), domain, CLASS_IN, TYPE_A, flags); + } catch (ErrnoException e) { + if (queryIpv6) resNetworkCancel(v6fd); // Closes fd, marks it invalid. + executor.execute(() -> { + callback.onQueryException(e); + }); + return; + } + queryCount++; + } else v4fd = null; + + final InetAddressAnswerAccumulator accumulator = + new InetAddressAnswerAccumulator(queryCount, callback); + + if (queryIpv6) registerFDListener(executor, v6fd, accumulator, cancellationSignal, lock); + if (queryIpv4) registerFDListener(executor, v4fd, accumulator, cancellationSignal, lock); + + if (cancellationSignal == null) return; + cancellationSignal.setOnCancelListener(() -> { + synchronized (lock) { + if (queryIpv4) cancelQuery(v4fd); + if (queryIpv6) cancelQuery(v6fd); + } + }); } private void registerFDListener(@NonNull Executor executor, @@ -271,7 +417,7 @@ public final class DnsResolver { } byte[] answerbuf = null; try { - answerbuf = resNetworkResult(fd); + answerbuf = resNetworkResult(fd); // Closes fd, marks it invalid. } catch (ErrnoException e) { Log.e(TAG, "resNetworkResult:" + e.toString()); answerCallback.onQueryException(e); @@ -291,19 +437,54 @@ public final class DnsResolver { }); } + private void cancelQuery(@NonNull FileDescriptor queryfd) { + if (!queryfd.valid()) return; + Looper.getMainLooper().getQueue().removeOnFileDescriptorEventListener(queryfd); + resNetworkCancel(queryfd); // Closes fd, marks it invalid. + } + private void maybeAddCancellationSignal(@Nullable CancellationSignal cancellationSignal, @NonNull FileDescriptor queryfd, @NonNull Object lock) { if (cancellationSignal == null) return; cancellationSignal.setOnCancelListener(() -> { synchronized (lock) { - if (!queryfd.valid()) return; - Looper.getMainLooper().getQueue() - .removeOnFileDescriptorEventListener(queryfd); - resNetworkCancel(queryfd); + cancelQuery(queryfd); } }); } + // These two functions match the behaviour of have_ipv4 and have_ipv6 in the native resolver. + private boolean haveIpv4(@Nullable Network network) { + final SocketAddress addrIpv4 = + new InetSocketAddress(InetAddresses.parseNumericAddress("8.8.8.8"), 0); + return checkConnectivity(network, AF_INET, addrIpv4); + } + + private boolean haveIpv6(@Nullable Network network) { + final SocketAddress addrIpv6 = + new InetSocketAddress(InetAddresses.parseNumericAddress("2000::"), 0); + return checkConnectivity(network, AF_INET6, addrIpv6); + } + + private boolean checkConnectivity(@Nullable Network network, + int domain, @NonNull SocketAddress addr) { + final FileDescriptor socket; + try { + socket = Os.socket(domain, SOCK_DGRAM, IPPROTO_UDP); + } catch (ErrnoException e) { + return false; + } + try { + if (network != null) network.bindSocket(socket); + Os.connect(socket, addr); + } catch (IOException | ErrnoException e) { + return false; + } finally { + IoUtils.closeQuietly(socket); + } + return true; + } + private static class DnsAddressAnswer extends DnsPacket { private static final String TAG = "DnsResolver.DnsAddressAnswer"; private static final boolean DBG = false;