Merge "[VCN05] Pass request type when requesting network"

This commit is contained in:
Junyu Lai
2021-01-14 06:52:46 +00:00
committed by Gerrit Code Review
5 changed files with 84 additions and 44 deletions

View File

@@ -16,6 +16,9 @@
package android.net; package android.net;
import static android.net.IpSecManager.INVALID_RESOURCE_ID; import static android.net.IpSecManager.INVALID_RESOURCE_ID;
import static android.net.NetworkRequest.Type.LISTEN;
import static android.net.NetworkRequest.Type.REQUEST;
import static android.net.NetworkRequest.Type.TRACK_DEFAULT;
import android.annotation.CallbackExecutor; import android.annotation.CallbackExecutor;
import android.annotation.IntDef; import android.annotation.IntDef;
@@ -3730,14 +3733,12 @@ public class ConnectivityManager {
private static final HashMap<NetworkRequest, NetworkCallback> sCallbacks = new HashMap<>(); private static final HashMap<NetworkRequest, NetworkCallback> sCallbacks = new HashMap<>();
private static CallbackHandler sCallbackHandler; private static CallbackHandler sCallbackHandler;
private static final int LISTEN = 1;
private static final int REQUEST = 2;
private NetworkRequest sendRequestForNetwork(NetworkCapabilities need, NetworkCallback callback, private NetworkRequest sendRequestForNetwork(NetworkCapabilities need, NetworkCallback callback,
int timeoutMs, int action, int legacyType, CallbackHandler handler) { int timeoutMs, NetworkRequest.Type reqType, int legacyType, CallbackHandler handler) {
printStackTrace(); printStackTrace();
checkCallbackNotNull(callback); checkCallbackNotNull(callback);
Preconditions.checkArgument(action == REQUEST || need != null, "null NetworkCapabilities"); Preconditions.checkArgument(
reqType == TRACK_DEFAULT || need != null, "null NetworkCapabilities");
final NetworkRequest request; final NetworkRequest request;
final String callingPackageName = mContext.getOpPackageName(); final String callingPackageName = mContext.getOpPackageName();
try { try {
@@ -3750,13 +3751,13 @@ public class ConnectivityManager {
} }
Messenger messenger = new Messenger(handler); Messenger messenger = new Messenger(handler);
Binder binder = new Binder(); Binder binder = new Binder();
if (action == LISTEN) { if (reqType == LISTEN) {
request = mService.listenForNetwork( request = mService.listenForNetwork(
need, messenger, binder, callingPackageName); need, messenger, binder, callingPackageName);
} else { } else {
request = mService.requestNetwork( request = mService.requestNetwork(
need, messenger, timeoutMs, binder, legacyType, callingPackageName, need, reqType.ordinal(), messenger, timeoutMs, binder, legacyType,
getAttributionTag()); callingPackageName, getAttributionTag());
} }
if (request != null) { if (request != null) {
sCallbacks.put(request, callback); sCallbacks.put(request, callback);
@@ -4260,7 +4261,7 @@ public class ConnectivityManager {
// request, i.e., the system default network. // request, i.e., the system default network.
CallbackHandler cbHandler = new CallbackHandler(handler); CallbackHandler cbHandler = new CallbackHandler(handler);
sendRequestForNetwork(null /* NetworkCapabilities need */, networkCallback, 0, sendRequestForNetwork(null /* NetworkCapabilities need */, networkCallback, 0,
REQUEST, TYPE_NONE, cbHandler); TRACK_DEFAULT, TYPE_NONE, cbHandler);
} }
/** /**

View File

@@ -167,7 +167,7 @@ interface IConnectivityManager
in NetworkCapabilities nc, int score, in NetworkAgentConfig config, in NetworkCapabilities nc, int score, in NetworkAgentConfig config,
in int factorySerialNumber); in int factorySerialNumber);
NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities, NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities, int reqType,
in Messenger messenger, int timeoutSec, in IBinder binder, int legacy, in Messenger messenger, int timeoutSec, in IBinder binder, int legacy,
String callingPackageName, String callingAttributionTag); String callingPackageName, String callingAttributionTag);

View File

@@ -5642,31 +5642,40 @@ public class ConnectivityService extends IConnectivityManager.Stub
@Override @Override
public NetworkRequest requestNetwork(NetworkCapabilities networkCapabilities, public NetworkRequest requestNetwork(NetworkCapabilities networkCapabilities,
Messenger messenger, int timeoutMs, IBinder binder, int legacyType, int reqTypeInt, Messenger messenger, int timeoutMs, IBinder binder,
@NonNull String callingPackageName, @Nullable String callingAttributionTag) { int legacyType, @NonNull String callingPackageName,
@Nullable String callingAttributionTag) {
if (legacyType != TYPE_NONE && !checkNetworkStackPermission()) { if (legacyType != TYPE_NONE && !checkNetworkStackPermission()) {
if (checkUnsupportedStartingFrom(Build.VERSION_CODES.M, callingPackageName)) { if (checkUnsupportedStartingFrom(Build.VERSION_CODES.M, callingPackageName)) {
throw new SecurityException("Insufficient permissions to specify legacy type"); throw new SecurityException("Insufficient permissions to specify legacy type");
} }
} }
final int callingUid = mDeps.getCallingUid(); final int callingUid = mDeps.getCallingUid();
final NetworkRequest.Type type = (networkCapabilities == null) final NetworkRequest.Type reqType;
? NetworkRequest.Type.TRACK_DEFAULT try {
: NetworkRequest.Type.REQUEST; reqType = NetworkRequest.Type.values()[reqTypeInt];
// If the requested networkCapabilities is null, take them instead from } catch (ArrayIndexOutOfBoundsException e) {
// the default network request. This allows callers to keep track of throw new IllegalArgumentException("Unsupported request type " + reqTypeInt);
// the system default network. }
if (type == NetworkRequest.Type.TRACK_DEFAULT) { switch (reqType) {
case TRACK_DEFAULT:
// If the request type is TRACK_DEFAULT, the passed {@code networkCapabilities}
// is unused and will be replaced by the one from the default network request.
// This allows callers to keep track of the system default network.
networkCapabilities = createDefaultNetworkCapabilitiesForUid(callingUid); networkCapabilities = createDefaultNetworkCapabilitiesForUid(callingUid);
enforceAccessPermission(); enforceAccessPermission();
} else { break;
case REQUEST:
networkCapabilities = new NetworkCapabilities(networkCapabilities); networkCapabilities = new NetworkCapabilities(networkCapabilities);
enforceNetworkRequestPermissions(networkCapabilities, callingPackageName, enforceNetworkRequestPermissions(networkCapabilities, callingPackageName,
callingAttributionTag); callingAttributionTag);
// TODO: this is incorrect. We mark the request as metered or not depending on the state // TODO: this is incorrect. We mark the request as metered or not depending on
// of the app when the request is filed, but we never change the request if the app // the state of the app when the request is filed, but we never change the
// changes network state. http://b/29964605 // request if the app changes network state. http://b/29964605
enforceMeteredApnPolicy(networkCapabilities); enforceMeteredApnPolicy(networkCapabilities);
break;
default:
throw new IllegalArgumentException("Unsupported request type " + reqType);
} }
ensureRequestableCapabilities(networkCapabilities); ensureRequestableCapabilities(networkCapabilities);
ensureSufficientPermissionsForRequest(networkCapabilities, ensureSufficientPermissionsForRequest(networkCapabilities,
@@ -5685,7 +5694,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
ensureValid(networkCapabilities); ensureValid(networkCapabilities);
NetworkRequest networkRequest = new NetworkRequest(networkCapabilities, legacyType, NetworkRequest networkRequest = new NetworkRequest(networkCapabilities, legacyType,
nextNetworkRequestId(), type); nextNetworkRequestId(), reqType);
NetworkRequestInfo nri = new NetworkRequestInfo(messenger, networkRequest, binder); NetworkRequestInfo nri = new NetworkRequestInfo(messenger, networkRequest, binder);
if (DBG) log("requestNetwork for " + nri); if (DBG) log("requestNetwork for " + nri);

View File

@@ -16,6 +16,7 @@
package android.net; package android.net;
import static android.net.ConnectivityManager.TYPE_NONE;
import static android.net.NetworkCapabilities.NET_CAPABILITY_CBS; import static android.net.NetworkCapabilities.NET_CAPABILITY_CBS;
import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN; import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
import static android.net.NetworkCapabilities.NET_CAPABILITY_FOTA; import static android.net.NetworkCapabilities.NET_CAPABILITY_FOTA;
@@ -31,16 +32,21 @@ import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR; import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET; import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI; import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.NetworkRequest.Type.REQUEST;
import static android.net.NetworkRequest.Type.TRACK_DEFAULT;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -49,9 +55,7 @@ import static org.mockito.Mockito.when;
import android.app.PendingIntent; import android.app.PendingIntent;
import android.content.Context; import android.content.Context;
import android.content.pm.ApplicationInfo; import android.content.pm.ApplicationInfo;
import android.net.ConnectivityManager;
import android.net.ConnectivityManager.NetworkCallback; import android.net.ConnectivityManager.NetworkCallback;
import android.net.NetworkCapabilities;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import android.os.Bundle; import android.os.Bundle;
import android.os.Handler; import android.os.Handler;
@@ -213,9 +217,8 @@ public class ConnectivityManagerTest {
ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class); ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
// register callback // register callback
when(mService.requestNetwork( when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class))) any(), nullable(String.class))).thenReturn(request);
.thenReturn(request);
manager.requestNetwork(request, callback, handler); manager.requestNetwork(request, callback, handler);
// callback triggers // callback triggers
@@ -242,9 +245,8 @@ public class ConnectivityManagerTest {
ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class); ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
// register callback // register callback
when(mService.requestNetwork( when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class))) any(), nullable(String.class))).thenReturn(req1);
.thenReturn(req1);
manager.requestNetwork(req1, callback, handler); manager.requestNetwork(req1, callback, handler);
// callback triggers // callback triggers
@@ -261,9 +263,8 @@ public class ConnectivityManagerTest {
verify(callback, timeout(100).times(0)).onLosing(any(), anyInt()); verify(callback, timeout(100).times(0)).onLosing(any(), anyInt());
// callback can be registered again // callback can be registered again
when(mService.requestNetwork( when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class))) any(), nullable(String.class))).thenReturn(req2);
.thenReturn(req2);
manager.requestNetwork(req2, callback, handler); manager.requestNetwork(req2, callback, handler);
// callback triggers // callback triggers
@@ -286,7 +287,7 @@ public class ConnectivityManagerTest {
info.targetSdkVersion = VERSION_CODES.N_MR1 + 1; info.targetSdkVersion = VERSION_CODES.N_MR1 + 1;
when(mCtx.getApplicationInfo()).thenReturn(info); when(mCtx.getApplicationInfo()).thenReturn(info);
when(mService.requestNetwork(any(), any(), anyInt(), any(), anyInt(), any(), when(mService.requestNetwork(any(), anyInt(), any(), anyInt(), any(), anyInt(), any(),
nullable(String.class))).thenReturn(request); nullable(String.class))).thenReturn(request);
Handler handler = new Handler(Looper.getMainLooper()); Handler handler = new Handler(Looper.getMainLooper());
@@ -340,6 +341,35 @@ public class ConnectivityManagerTest {
} }
} }
@Test
public void testRequestType() throws Exception {
final String testPkgName = "MyPackage";
final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
when(mCtx.getOpPackageName()).thenReturn(testPkgName);
final NetworkRequest request = makeRequest(1);
final NetworkCallback callback = new ConnectivityManager.NetworkCallback();
manager.requestNetwork(request, callback);
verify(mService).requestNetwork(eq(request.networkCapabilities),
eq(REQUEST.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE),
eq(testPkgName), eq(null));
reset(mService);
// Verify that register network callback does not calls requestNetwork at all.
manager.registerNetworkCallback(request, callback);
verify(mService, never()).requestNetwork(any(), anyInt(), any(), anyInt(), any(),
anyInt(), any(), any());
verify(mService).listenForNetwork(eq(request.networkCapabilities), any(), any(),
eq(testPkgName));
reset(mService);
manager.registerDefaultNetworkCallback(callback);
verify(mService).requestNetwork(eq(null),
eq(TRACK_DEFAULT.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE),
eq(testPkgName), eq(null));
reset(mService);
}
static Message makeMessage(NetworkRequest req, int messageType) { static Message makeMessage(NetworkRequest req, int messageType) {
Bundle bundle = new Bundle(); Bundle bundle = new Bundle();
bundle.putParcelable(NetworkRequest.class.getSimpleName(), req); bundle.putParcelable(NetworkRequest.class.getSimpleName(), req);

View File

@@ -3360,8 +3360,8 @@ public class ConnectivityServiceTest {
NetworkCapabilities networkCapabilities = new NetworkCapabilities(); NetworkCapabilities networkCapabilities = new NetworkCapabilities();
networkCapabilities.addTransportType(TRANSPORT_WIFI) networkCapabilities.addTransportType(TRANSPORT_WIFI)
.setNetworkSpecifier(new MatchAllNetworkSpecifier()); .setNetworkSpecifier(new MatchAllNetworkSpecifier());
mService.requestNetwork(networkCapabilities, null, 0, null, mService.requestNetwork(networkCapabilities, NetworkRequest.Type.REQUEST.ordinal(),
ConnectivityManager.TYPE_WIFI, mContext.getPackageName(), null, 0, null, ConnectivityManager.TYPE_WIFI, mContext.getPackageName(),
getAttributionTag()); getAttributionTag());
}); });