Merge "QOS filter matching support based on remote address and port number for connected sockets"

This commit is contained in:
Jayachandran Chinnakkannu
2021-05-14 20:02:30 +00:00
committed by Gerrit Code Review
5 changed files with 78 additions and 16 deletions

View File

@@ -381,6 +381,7 @@ package android.net {
public abstract class QosFilter { public abstract class QosFilter {
method @NonNull public abstract android.net.Network getNetwork(); method @NonNull public abstract android.net.Network getNetwork();
method public abstract boolean matchesLocalAddress(@NonNull java.net.InetAddress, int, int); method public abstract boolean matchesLocalAddress(@NonNull java.net.InetAddress, int, int);
method public abstract boolean matchesRemoteAddress(@NonNull java.net.InetAddress, int, int);
} }
public final class QosSession implements android.os.Parcelable { public final class QosSession implements android.os.Parcelable {
@@ -403,6 +404,7 @@ package android.net {
method public int describeContents(); method public int describeContents();
method @NonNull public java.net.InetSocketAddress getLocalSocketAddress(); method @NonNull public java.net.InetSocketAddress getLocalSocketAddress();
method @NonNull public android.net.Network getNetwork(); method @NonNull public android.net.Network getNetwork();
method @Nullable public java.net.InetSocketAddress getRemoteSocketAddress();
method public void writeToParcel(@NonNull android.os.Parcel, int); method public void writeToParcel(@NonNull android.os.Parcel, int);
field @NonNull public static final android.os.Parcelable.Creator<android.net.QosSocketInfo> CREATOR; field @NonNull public static final android.os.Parcelable.Creator<android.net.QosSocketInfo> CREATOR;
} }

View File

@@ -71,5 +71,16 @@ public abstract class QosFilter {
*/ */
public abstract boolean matchesLocalAddress(@NonNull InetAddress address, public abstract boolean matchesLocalAddress(@NonNull InetAddress address,
int startPort, int endPort); int startPort, int endPort);
/**
* Determines whether or not the parameters is a match for the filter.
*
* @param address the remote address
* @param startPort the start of the port range
* @param endPort the end of the port range
* @return whether the parameters match the remote address of the filter
*/
public abstract boolean matchesRemoteAddress(@NonNull InetAddress address,
int startPort, int endPort);
} }

View File

@@ -138,13 +138,26 @@ public class QosSocketFilter extends QosFilter {
if (mQosSocketInfo.getLocalSocketAddress() == null) { if (mQosSocketInfo.getLocalSocketAddress() == null) {
return false; return false;
} }
return matchesAddress(mQosSocketInfo.getLocalSocketAddress(), address, startPort,
return matchesLocalAddress(mQosSocketInfo.getLocalSocketAddress(), address, startPort,
endPort); endPort);
} }
/** /**
* Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)} with the * @inheritDoc
*/
@Override
public boolean matchesRemoteAddress(@NonNull final InetAddress address, final int startPort,
final int endPort) {
if (mQosSocketInfo.getRemoteSocketAddress() == null) {
return false;
}
return matchesAddress(mQosSocketInfo.getRemoteSocketAddress(), address, startPort,
endPort);
}
/**
* Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)}
* and {@link QosSocketFilter#matchesRemoteAddress(InetAddress, int, int)} with the
* filterSocketAddress coming from {@link QosSocketInfo#getLocalSocketAddress()}. * filterSocketAddress coming from {@link QosSocketInfo#getLocalSocketAddress()}.
* <p> * <p>
* This method exists for testing purposes since {@link QosSocketInfo} couldn't be mocked * This method exists for testing purposes since {@link QosSocketInfo} couldn't be mocked
@@ -156,7 +169,7 @@ public class QosSocketFilter extends QosFilter {
* @param endPort the end of the port range to check * @param endPort the end of the port range to check
*/ */
@VisibleForTesting @VisibleForTesting
public static boolean matchesLocalAddress(@NonNull final InetSocketAddress filterSocketAddress, public static boolean matchesAddress(@NonNull final InetSocketAddress filterSocketAddress,
@NonNull final InetAddress address, @NonNull final InetAddress address,
final int startPort, final int endPort) { final int startPort, final int endPort) {
return startPort <= filterSocketAddress.getPort() return startPort <= filterSocketAddress.getPort()

View File

@@ -17,6 +17,7 @@
package android.net; package android.net;
import android.annotation.NonNull; import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.SystemApi; import android.annotation.SystemApi;
import android.os.Parcel; import android.os.Parcel;
import android.os.ParcelFileDescriptor; import android.os.ParcelFileDescriptor;
@@ -32,7 +33,8 @@ import java.util.Objects;
/** /**
* Used in conjunction with * Used in conjunction with
* {@link ConnectivityManager#registerQosCallback} * {@link ConnectivityManager#registerQosCallback}
* in order to receive Qos Sessions related to the local address and port of a bound {@link Socket}. * in order to receive Qos Sessions related to the local address and port of a bound {@link Socket}
* and/or remote address and port of a connected {@link Socket}.
* *
* @hide * @hide
*/ */
@@ -48,6 +50,9 @@ public final class QosSocketInfo implements Parcelable {
@NonNull @NonNull
private final InetSocketAddress mLocalSocketAddress; private final InetSocketAddress mLocalSocketAddress;
@Nullable
private final InetSocketAddress mRemoteSocketAddress;
/** /**
* The {@link Network} the socket is on. * The {@link Network} the socket is on.
* *
@@ -80,6 +85,18 @@ public final class QosSocketInfo implements Parcelable {
return mLocalSocketAddress; return mLocalSocketAddress;
} }
/**
* The remote address of the socket passed into {@link QosSocketInfo(Network, Socket)}.
* The value does not reflect any changes that occur to the socket after it is first set
* in the constructor.
*
* @return the remote address of the socket if socket is connected, null otherwise
*/
@Nullable
public InetSocketAddress getRemoteSocketAddress() {
return mRemoteSocketAddress;
}
/** /**
* Creates a {@link QosSocketInfo} given a {@link Network} and bound {@link Socket}. The * Creates a {@link QosSocketInfo} given a {@link Network} and bound {@link Socket}. The
* {@link Socket} must remain bound in order to receive {@link QosSession}s. * {@link Socket} must remain bound in order to receive {@link QosSession}s.
@@ -95,6 +112,12 @@ public final class QosSocketInfo implements Parcelable {
mParcelFileDescriptor = ParcelFileDescriptor.fromSocket(socket); mParcelFileDescriptor = ParcelFileDescriptor.fromSocket(socket);
mLocalSocketAddress = mLocalSocketAddress =
new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort()); new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort());
if (socket.isConnected()) {
mRemoteSocketAddress = (InetSocketAddress) socket.getRemoteSocketAddress();
} else {
mRemoteSocketAddress = null;
}
} }
/* Parcelable methods */ /* Parcelable methods */
@@ -102,11 +125,15 @@ public final class QosSocketInfo implements Parcelable {
mNetwork = Objects.requireNonNull(Network.CREATOR.createFromParcel(in)); mNetwork = Objects.requireNonNull(Network.CREATOR.createFromParcel(in));
mParcelFileDescriptor = ParcelFileDescriptor.CREATOR.createFromParcel(in); mParcelFileDescriptor = ParcelFileDescriptor.CREATOR.createFromParcel(in);
final int addressLength = in.readInt(); final int localAddressLength = in.readInt();
mLocalSocketAddress = readSocketAddress(in, addressLength); mLocalSocketAddress = readSocketAddress(in, localAddressLength);
final int remoteAddressLength = in.readInt();
mRemoteSocketAddress = remoteAddressLength == 0 ? null
: readSocketAddress(in, remoteAddressLength);
} }
private InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) { private @NonNull InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) {
final byte[] address = new byte[addressLength]; final byte[] address = new byte[addressLength];
in.readByteArray(address); in.readByteArray(address);
final int port = in.readInt(); final int port = in.readInt();
@@ -130,10 +157,19 @@ public final class QosSocketInfo implements Parcelable {
mNetwork.writeToParcel(dest, 0); mNetwork.writeToParcel(dest, 0);
mParcelFileDescriptor.writeToParcel(dest, 0); mParcelFileDescriptor.writeToParcel(dest, 0);
final byte[] address = mLocalSocketAddress.getAddress().getAddress(); final byte[] localAddress = mLocalSocketAddress.getAddress().getAddress();
dest.writeInt(address.length); dest.writeInt(localAddress.length);
dest.writeByteArray(address); dest.writeByteArray(localAddress);
dest.writeInt(mLocalSocketAddress.getPort()); dest.writeInt(mLocalSocketAddress.getPort());
if (mRemoteSocketAddress == null) {
dest.writeInt(0);
} else {
final byte[] remoteAddress = mRemoteSocketAddress.getAddress().getAddress();
dest.writeInt(remoteAddress.length);
dest.writeByteArray(remoteAddress);
dest.writeInt(mRemoteSocketAddress.getPort());
}
} }
@NonNull @NonNull

View File

@@ -35,7 +35,7 @@ public class QosSocketFilterTest {
public void testPortExactMatch() { public void testPortExactMatch() {
final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
assertTrue(QosSocketFilter.matchesLocalAddress( assertTrue(QosSocketFilter.matchesAddress(
new InetSocketAddress(addressA, 10), addressB, 10, 10)); new InetSocketAddress(addressA, 10), addressB, 10, 10));
} }
@@ -44,7 +44,7 @@ public class QosSocketFilterTest {
public void testPortLessThanStart() { public void testPortLessThanStart() {
final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
assertFalse(QosSocketFilter.matchesLocalAddress( assertFalse(QosSocketFilter.matchesAddress(
new InetSocketAddress(addressA, 8), addressB, 10, 10)); new InetSocketAddress(addressA, 8), addressB, 10, 10));
} }
@@ -52,7 +52,7 @@ public class QosSocketFilterTest {
public void testPortGreaterThanEnd() { public void testPortGreaterThanEnd() {
final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
assertFalse(QosSocketFilter.matchesLocalAddress( assertFalse(QosSocketFilter.matchesAddress(
new InetSocketAddress(addressA, 18), addressB, 10, 10)); new InetSocketAddress(addressA, 18), addressB, 10, 10));
} }
@@ -60,7 +60,7 @@ public class QosSocketFilterTest {
public void testPortBetweenStartAndEnd() { public void testPortBetweenStartAndEnd() {
final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
assertTrue(QosSocketFilter.matchesLocalAddress( assertTrue(QosSocketFilter.matchesAddress(
new InetSocketAddress(addressA, 10), addressB, 8, 18)); new InetSocketAddress(addressA, 10), addressB, 8, 18));
} }
@@ -68,7 +68,7 @@ public class QosSocketFilterTest {
public void testAddressesDontMatch() { public void testAddressesDontMatch() {
final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4"); final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.5"); final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.5");
assertFalse(QosSocketFilter.matchesLocalAddress( assertFalse(QosSocketFilter.matchesAddress(
new InetSocketAddress(addressA, 10), addressB, 10, 10)); new InetSocketAddress(addressA, 10), addressB, 10, 10));
} }
} }