diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java index 2a985e703a..0e5d049c78 100644 --- a/core/java/android/net/ConnectivityManager.java +++ b/core/java/android/net/ConnectivityManager.java @@ -15,8 +15,6 @@ */ package android.net; -import static com.android.internal.util.Preconditions.checkNotNull; - import android.annotation.IntDef; import android.annotation.Nullable; import android.annotation.SdkConstant; @@ -50,16 +48,19 @@ import android.util.SparseIntArray; import com.android.internal.telephony.ITelephony; import com.android.internal.telephony.PhoneConstants; -import com.android.internal.util.Protocol; import com.android.internal.util.MessageUtils; +import com.android.internal.util.Preconditions; +import com.android.internal.util.Protocol; import libcore.net.event.NetworkEventDispatcher; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.net.InetAddress; +import java.util.ArrayList; import java.util.HashMap; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.List; +import java.util.Map; /** * Class that answers queries about the state of network connectivity. It also @@ -1547,8 +1548,8 @@ public class ConnectivityManager { } private PacketKeepalive(Network network, PacketKeepaliveCallback callback) { - checkNotNull(network, "network cannot be null"); - checkNotNull(callback, "callback cannot be null"); + Preconditions.checkNotNull(network, "network cannot be null"); + Preconditions.checkNotNull(callback, "callback cannot be null"); mNetwork = network; mCallback = callback; HandlerThread thread = new HandlerThread(TAG); @@ -1835,8 +1836,8 @@ public class ConnectivityManager { * {@hide} */ public ConnectivityManager(Context context, IConnectivityManager service) { - mContext = checkNotNull(context, "missing context"); - mService = checkNotNull(service, "missing IConnectivityManager"); + mContext = Preconditions.checkNotNull(context, "missing context"); + mService = Preconditions.checkNotNull(service, "missing IConnectivityManager"); sInstance = this; } @@ -2099,7 +2100,7 @@ public class ConnectivityManager { @SystemApi public void startTethering(int type, boolean showProvisioningUi, final OnStartTetheringCallback callback, Handler handler) { - checkNotNull(callback, "OnStartTetheringCallback cannot be null."); + Preconditions.checkNotNull(callback, "OnStartTetheringCallback cannot be null."); ResultReceiver wrappedCallback = new ResultReceiver(handler) { @Override @@ -2559,8 +2560,16 @@ public class ConnectivityManager { } /** - * Base class for NetworkRequest callbacks. Used for notifications about network - * changes. Should be extended by applications wanting notifications. + * Base class for {@code NetworkRequest} callbacks. Used for notifications about network + * changes. Should be extended by applications wanting notifications. + * + * A {@code NetworkCallback} is registered by calling + * {@link #requestNetwork(NetworkRequest, NetworkCallback)}, + * {@link #registerNetworkCallback(NetworkRequest, NetworkCallback)}, + * or {@link #registerDefaultNetworkCallback(NetworkCallback). A {@code NetworkCallback} is + * unregistered by calling {@link #unregisterNetworkCallback(NetworkCallback)}. + * A {@code NetworkCallback} should be registered at most once at any time. + * A {@code NetworkCallback} that has been unregistered can be registered again. */ public static class NetworkCallback { /** @@ -2663,6 +2672,10 @@ public class ConnectivityManager { public void onNetworkResumed(Network network) {} private NetworkRequest networkRequest; + + private boolean isRegistered() { + return (networkRequest != null) && (networkRequest.requestId != REQUEST_ID_UNSET); + } } private static final int BASE = Protocol.BASE_CONNECTIVITY_MANAGER; @@ -2680,6 +2693,7 @@ public class ConnectivityManager { public static final int CALLBACK_CAP_CHANGED = BASE + 6; /** @hide */ public static final int CALLBACK_IP_CHANGED = BASE + 7; + // TODO: consider deleting CALLBACK_RELEASED and shifting following enum codes down by 1. /** @hide */ public static final int CALLBACK_RELEASED = BASE + 8; // TODO: consider deleting CALLBACK_EXIT and shifting following enum codes down by 1. @@ -2798,13 +2812,6 @@ public class ConnectivityManager { break; } case CALLBACK_RELEASED: { - final NetworkCallback callback; - synchronized(sCallbacks) { - callback = sCallbacks.remove(request); - } - if (callback == null) { - Log.e(TAG, "callback not found for RELEASED message"); - } break; } case CALLBACK_EXIT: { @@ -2822,12 +2829,12 @@ public class ConnectivityManager { } private NetworkCallback getCallback(NetworkRequest req, String name) { - NetworkCallback callback; + final NetworkCallback callback; synchronized(sCallbacks) { callback = sCallbacks.get(req); } if (callback == null) { - Log.e(TAG, "callback not found for " + name + " message"); + Log.w(TAG, "callback not found for " + name + " message"); } return callback; } @@ -2850,17 +2857,16 @@ public class ConnectivityManager { private NetworkRequest sendRequestForNetwork(NetworkCapabilities need, NetworkCallback callback, int timeoutMs, int action, int legacyType, CallbackHandler handler) { - if (callback == null) { - throw new IllegalArgumentException("null NetworkCallback"); - } - if (need == null && action != REQUEST) { - throw new IllegalArgumentException("null NetworkCapabilities"); - } - // TODO: throw an exception if callback.networkRequest is not null. - // http://b/20701525 + Preconditions.checkArgument(callback != null, "null NetworkCallback"); + Preconditions.checkArgument(action == REQUEST || need != null, "null NetworkCapabilities"); final NetworkRequest request; try { synchronized(sCallbacks) { + if (callback.isRegistered()) { + // TODO: throw exception instead and enforce 1:1 mapping of callbacks + // and requests (http://b/20701525). + Log.e(TAG, "NetworkCallback was already registered"); + } Messenger messenger = new Messenger(handler); Binder binder = new Binder(); if (action == LISTEN) { @@ -3325,25 +3331,42 @@ public class ConnectivityManager { } /** - * Unregisters callbacks about and possibly releases networks originating from + * Unregisters a {@code NetworkCallback} and possibly releases networks originating from * {@link #requestNetwork(NetworkRequest, NetworkCallback)} and * {@link #registerNetworkCallback(NetworkRequest, NetworkCallback)} calls. * If the given {@code NetworkCallback} had previously been used with * {@code #requestNetwork}, any networks that had been connected to only to satisfy that request * will be disconnected. * + * Notifications that would have triggered that {@code NetworkCallback} will immediately stop + * triggering it as soon as this call returns. + * * @param networkCallback The {@link NetworkCallback} used when making the request. */ public void unregisterNetworkCallback(NetworkCallback networkCallback) { - if (networkCallback == null || networkCallback.networkRequest == null || - networkCallback.networkRequest.requestId == REQUEST_ID_UNSET) { - throw new IllegalArgumentException("Invalid NetworkCallback"); - } - try { - // CallbackHandler will release callback when receiving CALLBACK_RELEASED. - mService.releaseNetworkRequest(networkCallback.networkRequest); - } catch (RemoteException e) { - throw e.rethrowFromSystemServer(); + Preconditions.checkArgument(networkCallback != null, "null NetworkCallback"); + final List reqs = new ArrayList<>(); + // Find all requests associated to this callback and stop callback triggers immediately. + // Callback is reusable immediately. http://b/20701525, http://b/35921499. + synchronized (sCallbacks) { + Preconditions.checkArgument( + networkCallback.isRegistered(), "NetworkCallback was not registered"); + for (Map.Entry e : sCallbacks.entrySet()) { + if (e.getValue() == networkCallback) { + reqs.add(e.getKey()); + } + } + // TODO: throw exception if callback was registered more than once (http://b/20701525). + for (NetworkRequest r : reqs) { + try { + mService.releaseNetworkRequest(r); + } catch (RemoteException e) { + throw e.rethrowFromSystemServer(); + } + // Only remove mapping if rpc was successful. + sCallbacks.remove(r); + } + networkCallback.networkRequest = null; } } diff --git a/tests/net/java/android/net/ConnectivityManagerTest.java b/tests/net/java/android/net/ConnectivityManagerTest.java index b984bbfdda..ceb0135727 100644 --- a/tests/net/java/android/net/ConnectivityManagerTest.java +++ b/tests/net/java/android/net/ConnectivityManagerTest.java @@ -36,21 +36,50 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import android.net.ConnectivityManager; import android.net.NetworkCapabilities; - +import android.content.Context; +import android.os.Bundle; +import android.os.Handler; +import android.os.Looper; +import android.os.Message; +import android.os.Messenger; +import android.content.pm.ApplicationInfo; +import android.os.Build.VERSION_CODES; +import android.net.ConnectivityManager.NetworkCallback; import android.support.test.filters.SmallTest; import android.support.test.runner.AndroidJUnit4; -import org.junit.runner.RunWith; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; @RunWith(AndroidJUnit4.class) @SmallTest public class ConnectivityManagerTest { + + @Mock Context mCtx; + @Mock IConnectivityManager mService; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + static NetworkCapabilities verifyNetworkCapabilities( int legacyType, int transportType, int... capabilities) { final NetworkCapabilities nc = ConnectivityManager.networkCapabilitiesForType(legacyType); @@ -173,4 +202,124 @@ public class ConnectivityManagerTest { verifyUnrestrictedNetworkCapabilities( ConnectivityManager.TYPE_ETHERNET, TRANSPORT_ETHERNET); } + + @Test + public void testCallbackRelease() throws Exception { + ConnectivityManager manager = new ConnectivityManager(mCtx, mService); + NetworkRequest request = makeRequest(1); + NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class); + Handler handler = new Handler(Looper.getMainLooper()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Messenger.class); + + // register callback + when(mService.requestNetwork(any(), captor.capture(), anyInt(), any(), anyInt())) + .thenReturn(request); + manager.requestNetwork(request, callback, handler); + + // callback triggers + captor.getValue().send(makeMessage(request, ConnectivityManager.CALLBACK_AVAILABLE)); + verify(callback, timeout(500).times(1)).onAvailable(any()); + + // unregister callback + manager.unregisterNetworkCallback(callback); + verify(mService, times(1)).releaseNetworkRequest(request); + + // callback does not trigger anymore. + captor.getValue().send(makeMessage(request, ConnectivityManager.CALLBACK_LOSING)); + verify(callback, timeout(500).times(0)).onLosing(any(), anyInt()); + } + + @Test + public void testCallbackRecycling() throws Exception { + ConnectivityManager manager = new ConnectivityManager(mCtx, mService); + NetworkRequest req1 = makeRequest(1); + NetworkRequest req2 = makeRequest(2); + NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class); + Handler handler = new Handler(Looper.getMainLooper()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Messenger.class); + + // register callback + when(mService.requestNetwork(any(), captor.capture(), anyInt(), any(), anyInt())) + .thenReturn(req1); + manager.requestNetwork(req1, callback, handler); + + // callback triggers + captor.getValue().send(makeMessage(req1, ConnectivityManager.CALLBACK_AVAILABLE)); + verify(callback, timeout(100).times(1)).onAvailable(any()); + + // unregister callback + manager.unregisterNetworkCallback(callback); + verify(mService, times(1)).releaseNetworkRequest(req1); + + // callback does not trigger anymore. + captor.getValue().send(makeMessage(req1, ConnectivityManager.CALLBACK_LOSING)); + verify(callback, timeout(100).times(0)).onLosing(any(), anyInt()); + + // callback can be registered again + when(mService.requestNetwork(any(), captor.capture(), anyInt(), any(), anyInt())) + .thenReturn(req2); + manager.requestNetwork(req2, callback, handler); + + // callback triggers + captor.getValue().send(makeMessage(req2, ConnectivityManager.CALLBACK_LOST)); + verify(callback, timeout(100).times(1)).onLost(any()); + + // unregister callback + manager.unregisterNetworkCallback(callback); + verify(mService, times(1)).releaseNetworkRequest(req2); + } + + // TODO: turn on this test when request callback 1:1 mapping is enforced + //@Test + private void noDoubleCallbackRegistration() throws Exception { + ConnectivityManager manager = new ConnectivityManager(mCtx, mService); + NetworkRequest request = makeRequest(1); + NetworkCallback callback = new ConnectivityManager.NetworkCallback(); + ApplicationInfo info = new ApplicationInfo(); + // TODO: update version when starting to enforce 1:1 mapping + info.targetSdkVersion = VERSION_CODES.N_MR1 + 1; + + when(mCtx.getApplicationInfo()).thenReturn(info); + when(mService.requestNetwork(any(), any(), anyInt(), any(), anyInt())).thenReturn(request); + + Handler handler = new Handler(Looper.getMainLooper()); + manager.requestNetwork(request, callback, handler); + + // callback is already registered, reregistration should fail. + Class wantException = IllegalArgumentException.class; + expectThrowable(() -> manager.requestNetwork(request, callback), wantException); + + manager.unregisterNetworkCallback(callback); + verify(mService, times(1)).releaseNetworkRequest(request); + + // unregistering the callback should make it registrable again. + manager.requestNetwork(request, callback); + } + + static Message makeMessage(NetworkRequest req, int messageType) { + Bundle bundle = new Bundle(); + bundle.putParcelable(NetworkRequest.class.getSimpleName(), req); + Message msg = Message.obtain(); + msg.what = messageType; + msg.setData(bundle); + return msg; + } + + static NetworkRequest makeRequest(int requestId) { + NetworkRequest request = new NetworkRequest.Builder().clearCapabilities().build(); + return new NetworkRequest(request.networkCapabilities, ConnectivityManager.TYPE_NONE, + requestId, NetworkRequest.Type.NONE); + } + + static void expectThrowable(Runnable block, Class throwableType) { + try { + block.run(); + } catch (Throwable t) { + if (t.getClass().equals(throwableType)) { + return; + } + fail("expected exception of type " + throwableType + ", but was " + t.getClass()); + } + fail("expected exception of type " + throwableType); + } }