diff --git a/framework/api/system-current.txt b/framework/api/system-current.txt index 5750845c58..730555ba7e 100644 --- a/framework/api/system-current.txt +++ b/framework/api/system-current.txt @@ -381,6 +381,7 @@ package android.net { public abstract class QosFilter { method @NonNull public abstract android.net.Network getNetwork(); method public abstract boolean matchesLocalAddress(@NonNull java.net.InetAddress, int, int); + method public abstract boolean matchesRemoteAddress(@NonNull java.net.InetAddress, int, int); } public final class QosSession implements android.os.Parcelable { @@ -403,6 +404,7 @@ package android.net { method public int describeContents(); method @NonNull public java.net.InetSocketAddress getLocalSocketAddress(); method @NonNull public android.net.Network getNetwork(); + method @Nullable public java.net.InetSocketAddress getRemoteSocketAddress(); method public void writeToParcel(@NonNull android.os.Parcel, int); field @NonNull public static final android.os.Parcelable.Creator CREATOR; } diff --git a/framework/src/android/net/QosFilter.java b/framework/src/android/net/QosFilter.java index ab55002e02..957c867f20 100644 --- a/framework/src/android/net/QosFilter.java +++ b/framework/src/android/net/QosFilter.java @@ -71,5 +71,16 @@ public abstract class QosFilter { */ public abstract boolean matchesLocalAddress(@NonNull InetAddress address, int startPort, int endPort); + + /** + * Determines whether or not the parameters is a match for the filter. + * + * @param address the remote address + * @param startPort the start of the port range + * @param endPort the end of the port range + * @return whether the parameters match the remote address of the filter + */ + public abstract boolean matchesRemoteAddress(@NonNull InetAddress address, + int startPort, int endPort); } diff --git a/framework/src/android/net/QosSocketFilter.java b/framework/src/android/net/QosSocketFilter.java index 2080e68f5f..69da7f4401 100644 --- a/framework/src/android/net/QosSocketFilter.java +++ b/framework/src/android/net/QosSocketFilter.java @@ -138,13 +138,26 @@ public class QosSocketFilter extends QosFilter { if (mQosSocketInfo.getLocalSocketAddress() == null) { return false; } - - return matchesLocalAddress(mQosSocketInfo.getLocalSocketAddress(), address, startPort, + return matchesAddress(mQosSocketInfo.getLocalSocketAddress(), address, startPort, endPort); } /** - * Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)} with the + * @inheritDoc + */ + @Override + public boolean matchesRemoteAddress(@NonNull final InetAddress address, final int startPort, + final int endPort) { + if (mQosSocketInfo.getRemoteSocketAddress() == null) { + return false; + } + return matchesAddress(mQosSocketInfo.getRemoteSocketAddress(), address, startPort, + endPort); + } + + /** + * Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)} + * and {@link QosSocketFilter#matchesRemoteAddress(InetAddress, int, int)} with the * filterSocketAddress coming from {@link QosSocketInfo#getLocalSocketAddress()}. *

* This method exists for testing purposes since {@link QosSocketInfo} couldn't be mocked @@ -156,7 +169,7 @@ public class QosSocketFilter extends QosFilter { * @param endPort the end of the port range to check */ @VisibleForTesting - public static boolean matchesLocalAddress(@NonNull final InetSocketAddress filterSocketAddress, + public static boolean matchesAddress(@NonNull final InetSocketAddress filterSocketAddress, @NonNull final InetAddress address, final int startPort, final int endPort) { return startPort <= filterSocketAddress.getPort() diff --git a/framework/src/android/net/QosSocketInfo.java b/framework/src/android/net/QosSocketInfo.java index 53d966937a..a45d5075d6 100644 --- a/framework/src/android/net/QosSocketInfo.java +++ b/framework/src/android/net/QosSocketInfo.java @@ -17,6 +17,7 @@ package android.net; import android.annotation.NonNull; +import android.annotation.Nullable; import android.annotation.SystemApi; import android.os.Parcel; import android.os.ParcelFileDescriptor; @@ -32,7 +33,8 @@ import java.util.Objects; /** * Used in conjunction with * {@link ConnectivityManager#registerQosCallback} - * in order to receive Qos Sessions related to the local address and port of a bound {@link Socket}. + * in order to receive Qos Sessions related to the local address and port of a bound {@link Socket} + * and/or remote address and port of a connected {@link Socket}. * * @hide */ @@ -48,6 +50,9 @@ public final class QosSocketInfo implements Parcelable { @NonNull private final InetSocketAddress mLocalSocketAddress; + @Nullable + private final InetSocketAddress mRemoteSocketAddress; + /** * The {@link Network} the socket is on. * @@ -80,6 +85,18 @@ public final class QosSocketInfo implements Parcelable { return mLocalSocketAddress; } + /** + * The remote address of the socket passed into {@link QosSocketInfo(Network, Socket)}. + * The value does not reflect any changes that occur to the socket after it is first set + * in the constructor. + * + * @return the remote address of the socket if socket is connected, null otherwise + */ + @Nullable + public InetSocketAddress getRemoteSocketAddress() { + return mRemoteSocketAddress; + } + /** * Creates a {@link QosSocketInfo} given a {@link Network} and bound {@link Socket}. The * {@link Socket} must remain bound in order to receive {@link QosSession}s. @@ -95,6 +112,12 @@ public final class QosSocketInfo implements Parcelable { mParcelFileDescriptor = ParcelFileDescriptor.fromSocket(socket); mLocalSocketAddress = new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort()); + + if (socket.isConnected()) { + mRemoteSocketAddress = (InetSocketAddress) socket.getRemoteSocketAddress(); + } else { + mRemoteSocketAddress = null; + } } /* Parcelable methods */ @@ -102,11 +125,15 @@ public final class QosSocketInfo implements Parcelable { mNetwork = Objects.requireNonNull(Network.CREATOR.createFromParcel(in)); mParcelFileDescriptor = ParcelFileDescriptor.CREATOR.createFromParcel(in); - final int addressLength = in.readInt(); - mLocalSocketAddress = readSocketAddress(in, addressLength); + final int localAddressLength = in.readInt(); + mLocalSocketAddress = readSocketAddress(in, localAddressLength); + + final int remoteAddressLength = in.readInt(); + mRemoteSocketAddress = remoteAddressLength == 0 ? null + : readSocketAddress(in, remoteAddressLength); } - private InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) { + private @NonNull InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) { final byte[] address = new byte[addressLength]; in.readByteArray(address); final int port = in.readInt(); @@ -130,10 +157,19 @@ public final class QosSocketInfo implements Parcelable { mNetwork.writeToParcel(dest, 0); mParcelFileDescriptor.writeToParcel(dest, 0); - final byte[] address = mLocalSocketAddress.getAddress().getAddress(); - dest.writeInt(address.length); - dest.writeByteArray(address); + final byte[] localAddress = mLocalSocketAddress.getAddress().getAddress(); + dest.writeInt(localAddress.length); + dest.writeByteArray(localAddress); dest.writeInt(mLocalSocketAddress.getPort()); + + if (mRemoteSocketAddress == null) { + dest.writeInt(0); + } else { + final byte[] remoteAddress = mRemoteSocketAddress.getAddress().getAddress(); + dest.writeInt(remoteAddress.length); + dest.writeByteArray(remoteAddress); + dest.writeInt(mRemoteSocketAddress.getPort()); + } } @NonNull diff --git a/tests/unit/java/android/net/QosSocketFilterTest.java b/tests/unit/java/android/net/QosSocketFilterTest.java index ad58960eaa..40f8f1b8d0 100644 --- a/tests/unit/java/android/net/QosSocketFilterTest.java +++ b/tests/unit/java/android/net/QosSocketFilterTest.java @@ -35,7 +35,7 @@ public class QosSocketFilterTest { public void testPortExactMatch() { final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); - assertTrue(QosSocketFilter.matchesLocalAddress( + assertTrue(QosSocketFilter.matchesAddress( new InetSocketAddress(addressA, 10), addressB, 10, 10)); } @@ -44,7 +44,7 @@ public class QosSocketFilterTest { public void testPortLessThanStart() { final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); - assertFalse(QosSocketFilter.matchesLocalAddress( + assertFalse(QosSocketFilter.matchesAddress( new InetSocketAddress(addressA, 8), addressB, 10, 10)); } @@ -52,7 +52,7 @@ public class QosSocketFilterTest { public void testPortGreaterThanEnd() { final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); - assertFalse(QosSocketFilter.matchesLocalAddress( + assertFalse(QosSocketFilter.matchesAddress( new InetSocketAddress(addressA, 18), addressB, 10, 10)); } @@ -60,7 +60,7 @@ public class QosSocketFilterTest { public void testPortBetweenStartAndEnd() { final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); - assertTrue(QosSocketFilter.matchesLocalAddress( + assertTrue(QosSocketFilter.matchesAddress( new InetSocketAddress(addressA, 10), addressB, 8, 18)); } @@ -68,7 +68,7 @@ public class QosSocketFilterTest { public void testAddressesDontMatch() { final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.5"); - assertFalse(QosSocketFilter.matchesLocalAddress( + assertFalse(QosSocketFilter.matchesAddress( new InetSocketAddress(addressA, 10), addressB, 10, 10)); } }