diff --git a/core/java/android/net/NetworkAgent.java b/core/java/android/net/NetworkAgent.java index 51ef4a6209..4c49bc9f28 100644 --- a/core/java/android/net/NetworkAgent.java +++ b/core/java/android/net/NetworkAgent.java @@ -29,6 +29,7 @@ import com.android.internal.util.AsyncChannel; import com.android.internal.util.Protocol; import java.util.ArrayList; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -224,6 +225,11 @@ public abstract class NetworkAgent extends Handler { Context.CONNECTIVITY_SERVICE); netId = cm.registerNetworkAgent(new Messenger(this), new NetworkInfo(ni), new LinkProperties(lp), new NetworkCapabilities(nc), score, misc); + + final Set uids = nc.getUids(); + if (null != uids) { + addUidRanges(uids.toArray(new UidRange[uids.size()])); + } } @Override diff --git a/core/java/android/net/NetworkCapabilities.java b/core/java/android/net/NetworkCapabilities.java index 6bcaffd97a..d0f6f9bf77 100644 --- a/core/java/android/net/NetworkCapabilities.java +++ b/core/java/android/net/NetworkCapabilities.java @@ -20,6 +20,7 @@ import android.annotation.IntDef; import android.net.ConnectivityManager.NetworkCallback; import android.os.Parcel; import android.os.Parcelable; +import android.util.ArraySet; import android.util.proto.ProtoOutputStream; import com.android.internal.annotations.VisibleForTesting; @@ -29,6 +30,7 @@ import com.android.internal.util.Preconditions; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.util.Objects; +import java.util.Set; import java.util.StringJoiner; /** @@ -64,6 +66,7 @@ public final class NetworkCapabilities implements Parcelable { mLinkDownBandwidthKbps = nc.mLinkDownBandwidthKbps; mNetworkSpecifier = nc.mNetworkSpecifier; mSignalStrength = nc.mSignalStrength; + mUids = nc.mUids; } } @@ -77,6 +80,7 @@ public final class NetworkCapabilities implements Parcelable { mLinkUpBandwidthKbps = mLinkDownBandwidthKbps = LINK_BANDWIDTH_UNSPECIFIED; mNetworkSpecifier = null; mSignalStrength = SIGNAL_STRENGTH_UNSPECIFIED; + mUids = null; } /** @@ -836,6 +840,150 @@ public final class NetworkCapabilities implements Parcelable { return this.mSignalStrength == nc.mSignalStrength; } + /** + * List of UIDs this network applies to. No restriction if null. + *

+ * This is typically (and at this time, only) used by VPN. This network is only available to + * the UIDs in this list, and it is their default network. Apps in this list that wish to + * bypass the VPN can do so iff the VPN app allows them to or if they are privileged. If this + * member is null, then the network is not restricted by app UID. If it's an empty list, then + * it means nobody can use it. + *

+ * Please note that in principle a single app can be associated with multiple UIDs because + * each app will have a different UID when it's run as a different (macro-)user. A single + * macro user can only have a single active VPN app at any given time however. + *

+ * Also please be aware this class does not try to enforce any normalization on this. Callers + * can only alter the UIDs by setting them wholesale : this class does not provide any utility + * to add or remove individual UIDs or ranges. If callers have any normalization needs on + * their own (like requiring sortedness or no overlap) they need to enforce it + * themselves. Some of the internal methods also assume this is normalized as in no adjacent + * or overlapping ranges are present. + * + * @hide + */ + private Set mUids = null; + + /** + * Set the list of UIDs this network applies to. + * This makes a copy of the set so that callers can't modify it after the call. + * @hide + */ + public NetworkCapabilities setUids(Set uids) { + if (null == uids) { + mUids = null; + } else { + mUids = new ArraySet<>(uids); + } + return this; + } + + /** + * Get the list of UIDs this network applies to. + * This returns a copy of the set so that callers can't modify the original object. + * @hide + */ + public Set getUids() { + return null == mUids ? null : new ArraySet<>(mUids); + } + + /** + * Test whether this network applies to this UID. + * @hide + */ + public boolean appliesToUid(int uid) { + if (null == mUids) return true; + for (UidRange range : mUids) { + if (range.contains(uid)) { + return true; + } + } + return false; + } + + /** + * Tests if the set of UIDs that this network applies to is the same of the passed set of UIDs. + *

+ * This test only checks whether equal range objects are in both sets. It will + * return false if the ranges are not exactly the same, even if the covered UIDs + * are for an equivalent result. + *

+ * Note that this method is not very optimized, which is fine as long as it's not used very + * often. + *

+ * nc is assumed nonnull. + * + * @hide + */ + @VisibleForTesting + public boolean equalsUids(NetworkCapabilities nc) { + Set comparedUids = nc.mUids; + if (null == comparedUids) return null == mUids; + if (null == mUids) return false; + // Make a copy so it can be mutated to check that all ranges in mUids + // also are in uids. + final Set uids = new ArraySet<>(mUids); + for (UidRange range : comparedUids) { + if (!uids.contains(range)) { + return false; + } + uids.remove(range); + } + return uids.isEmpty(); + } + + /** + * Test whether the passed NetworkCapabilities satisfies the UIDs this capabilities require. + * + * This is called on the NetworkCapabilities embedded in a request with the capabilities + * of an available network. + * nc is assumed nonnull. + * @see #appliesToUid + * @hide + */ + public boolean satisfiedByUids(NetworkCapabilities nc) { + if (null == nc.mUids) return true; // The network satisfies everything. + if (null == mUids) return false; // Not everything allowed but requires everything + for (UidRange requiredRange : mUids) { + if (!nc.appliesToUidRange(requiredRange)) { + return false; + } + } + return true; + } + + /** + * Returns whether this network applies to the passed ranges. + * This assumes that to apply, the passed range has to be entirely contained + * within one of the ranges this network applies to. If the ranges are not normalized, + * this method may return false even though all required UIDs are covered because no + * single range contained them all. + * @hide + */ + @VisibleForTesting + public boolean appliesToUidRange(UidRange requiredRange) { + if (null == mUids) return true; + for (UidRange uidRange : mUids) { + if (uidRange.containsRange(requiredRange)) { + return true; + } + } + return false; + } + + /** + * Combine the UIDs this network currently applies to with the UIDs the passed + * NetworkCapabilities apply to. + * nc is assumed nonnull. + */ + private void combineUids(NetworkCapabilities nc) { + if (null == nc.mUids || null == mUids) { + mUids = null; + return; + } + mUids.addAll(nc.mUids); + } + /** * Combine a set of Capabilities to this one. Useful for coming up with the complete set * @hide @@ -846,6 +994,7 @@ public final class NetworkCapabilities implements Parcelable { combineLinkBandwidths(nc); combineSpecifiers(nc); combineSignalStrength(nc); + combineUids(nc); } /** @@ -858,12 +1007,13 @@ public final class NetworkCapabilities implements Parcelable { * @hide */ private boolean satisfiedByNetworkCapabilities(NetworkCapabilities nc, boolean onlyImmutable) { - return (nc != null && - satisfiedByNetCapabilities(nc, onlyImmutable) && - satisfiedByTransportTypes(nc) && - (onlyImmutable || satisfiedByLinkBandwidths(nc)) && - satisfiedBySpecifier(nc) && - (onlyImmutable || satisfiedBySignalStrength(nc))); + return (nc != null + && satisfiedByNetCapabilities(nc, onlyImmutable) + && satisfiedByTransportTypes(nc) + && (onlyImmutable || satisfiedByLinkBandwidths(nc)) + && satisfiedBySpecifier(nc) + && (onlyImmutable || satisfiedBySignalStrength(nc)) + && (onlyImmutable || satisfiedByUids(nc))); } /** @@ -946,24 +1096,26 @@ public final class NetworkCapabilities implements Parcelable { @Override public boolean equals(Object obj) { if (obj == null || (obj instanceof NetworkCapabilities == false)) return false; - NetworkCapabilities that = (NetworkCapabilities)obj; - return (equalsNetCapabilities(that) && - equalsTransportTypes(that) && - equalsLinkBandwidths(that) && - equalsSignalStrength(that) && - equalsSpecifier(that)); + NetworkCapabilities that = (NetworkCapabilities) obj; + return (equalsNetCapabilities(that) + && equalsTransportTypes(that) + && equalsLinkBandwidths(that) + && equalsSignalStrength(that) + && equalsSpecifier(that) + && equalsUids(that)); } @Override public int hashCode() { - return ((int)(mNetworkCapabilities & 0xFFFFFFFF) + - ((int)(mNetworkCapabilities >> 32) * 3) + - ((int)(mTransportTypes & 0xFFFFFFFF) * 5) + - ((int)(mTransportTypes >> 32) * 7) + - (mLinkUpBandwidthKbps * 11) + - (mLinkDownBandwidthKbps * 13) + - Objects.hashCode(mNetworkSpecifier) * 17 + - (mSignalStrength * 19)); + return ((int) (mNetworkCapabilities & 0xFFFFFFFF) + + ((int) (mNetworkCapabilities >> 32) * 3) + + ((int) (mTransportTypes & 0xFFFFFFFF) * 5) + + ((int) (mTransportTypes >> 32) * 7) + + (mLinkUpBandwidthKbps * 11) + + (mLinkDownBandwidthKbps * 13) + + Objects.hashCode(mNetworkSpecifier) * 17 + + (mSignalStrength * 19) + + Objects.hashCode(mUids) * 23); } @Override @@ -978,6 +1130,7 @@ public final class NetworkCapabilities implements Parcelable { dest.writeInt(mLinkDownBandwidthKbps); dest.writeParcelable((Parcelable) mNetworkSpecifier, flags); dest.writeInt(mSignalStrength); + dest.writeArraySet(new ArraySet<>(mUids)); } public static final Creator CREATOR = @@ -992,6 +1145,8 @@ public final class NetworkCapabilities implements Parcelable { netCap.mLinkDownBandwidthKbps = in.readInt(); netCap.mNetworkSpecifier = in.readParcelable(null); netCap.mSignalStrength = in.readInt(); + netCap.mUids = (ArraySet) in.readArraySet( + null /* ClassLoader, null for default */); return netCap; } @Override @@ -1024,7 +1179,10 @@ public final class NetworkCapabilities implements Parcelable { String signalStrength = (hasSignalStrength() ? " SignalStrength: " + mSignalStrength : ""); - return "[" + transports + capabilities + upBand + dnBand + specifier + signalStrength + "]"; + String uids = (null != mUids ? " Uids: <" + mUids + ">" : ""); + + return "[" + transports + capabilities + upBand + dnBand + specifier + signalStrength + + uids + "]"; } /** diff --git a/tests/net/java/android/net/NetworkCapabilitiesTest.java b/tests/net/java/android/net/NetworkCapabilitiesTest.java index e6170cb42c..4c6a64464b 100644 --- a/tests/net/java/android/net/NetworkCapabilitiesTest.java +++ b/tests/net/java/android/net/NetworkCapabilitiesTest.java @@ -34,12 +34,15 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +import android.os.Parcel; import android.support.test.runner.AndroidJUnit4; import android.test.suitebuilder.annotation.SmallTest; +import android.util.ArraySet; import org.junit.Test; import org.junit.runner.RunWith; +import java.util.Set; @RunWith(AndroidJUnit4.class) @SmallTest @@ -189,4 +192,84 @@ public class NetworkCapabilitiesTest { assertEquals(20, NetworkCapabilities .maxBandwidth(10, 20)); } + + @Test + public void testSetUids() { + final NetworkCapabilities netCap = new NetworkCapabilities(); + final Set uids = new ArraySet<>(); + uids.add(new UidRange(50, 100)); + uids.add(new UidRange(3000, 4000)); + netCap.setUids(uids); + assertTrue(netCap.appliesToUid(50)); + assertTrue(netCap.appliesToUid(80)); + assertTrue(netCap.appliesToUid(100)); + assertTrue(netCap.appliesToUid(3000)); + assertTrue(netCap.appliesToUid(3001)); + assertFalse(netCap.appliesToUid(10)); + assertFalse(netCap.appliesToUid(25)); + assertFalse(netCap.appliesToUid(49)); + assertFalse(netCap.appliesToUid(101)); + assertFalse(netCap.appliesToUid(2000)); + assertFalse(netCap.appliesToUid(100000)); + + assertTrue(netCap.appliesToUidRange(new UidRange(50, 100))); + assertTrue(netCap.appliesToUidRange(new UidRange(70, 72))); + assertTrue(netCap.appliesToUidRange(new UidRange(3500, 3912))); + assertFalse(netCap.appliesToUidRange(new UidRange(1, 100))); + assertFalse(netCap.appliesToUidRange(new UidRange(49, 100))); + assertFalse(netCap.appliesToUidRange(new UidRange(1, 10))); + assertFalse(netCap.appliesToUidRange(new UidRange(60, 101))); + assertFalse(netCap.appliesToUidRange(new UidRange(60, 3400))); + + NetworkCapabilities netCap2 = new NetworkCapabilities(); + assertFalse(netCap2.satisfiedByUids(netCap)); + assertFalse(netCap2.equalsUids(netCap)); + netCap2.setUids(uids); + assertTrue(netCap2.satisfiedByUids(netCap)); + assertTrue(netCap.equalsUids(netCap2)); + assertTrue(netCap2.equalsUids(netCap)); + + uids.add(new UidRange(600, 700)); + netCap2.setUids(uids); + assertFalse(netCap2.satisfiedByUids(netCap)); + assertFalse(netCap.appliesToUid(650)); + assertTrue(netCap2.appliesToUid(650)); + netCap.combineCapabilities(netCap2); + assertTrue(netCap2.satisfiedByUids(netCap)); + assertTrue(netCap.appliesToUid(650)); + assertFalse(netCap.appliesToUid(500)); + + assertFalse(new NetworkCapabilities().satisfiedByUids(netCap)); + netCap.combineCapabilities(new NetworkCapabilities()); + assertTrue(netCap.appliesToUid(500)); + assertTrue(netCap.appliesToUidRange(new UidRange(1, 100000))); + assertFalse(netCap2.appliesToUid(500)); + assertFalse(netCap2.appliesToUidRange(new UidRange(1, 100000))); + assertTrue(new NetworkCapabilities().satisfiedByUids(netCap)); + } + + @Test + public void testParcelNetworkCapabilities() { + final Set uids = new ArraySet<>(); + uids.add(new UidRange(50, 100)); + uids.add(new UidRange(3000, 4000)); + final NetworkCapabilities netCap = new NetworkCapabilities() + .addCapability(NET_CAPABILITY_INTERNET) + .setUids(uids) + .addCapability(NET_CAPABILITY_EIMS) + .addCapability(NET_CAPABILITY_NOT_METERED); + assertEqualsThroughMarshalling(netCap); + } + + private void assertEqualsThroughMarshalling(NetworkCapabilities netCap) { + Parcel p = Parcel.obtain(); + netCap.writeToParcel(p, /* flags */ 0); + p.setDataPosition(0); + byte[] marshalledData = p.marshall(); + + p = Parcel.obtain(); + p.unmarshall(marshalledData, 0, marshalledData.length); + p.setDataPosition(0); + assertEquals(NetworkCapabilities.CREATOR.createFromParcel(p), netCap); + } }