diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java index ba503e0091..7b019fa814 100755 --- a/service/src/com/android/server/ConnectivityService.java +++ b/service/src/com/android/server/ConnectivityService.java @@ -244,6 +244,7 @@ import android.util.ArraySet; import android.util.LocalLog; import android.util.Log; import android.util.Pair; +import android.util.Range; import android.util.SparseArray; import android.util.SparseIntArray; @@ -310,11 +311,13 @@ import org.xmlpull.v1.XmlPullParserException; import java.io.FileDescriptor; import java.io.IOException; +import java.io.InterruptedIOException; import java.io.PrintWriter; import java.io.Writer; import java.net.Inet4Address; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.SocketException; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; @@ -1485,6 +1488,18 @@ public class ConnectivityService extends IConnectivityManager.Stub @NonNull final UserHandle user) { return CompatChanges.isChangeEnabled(changeId, packageName, user); } + + /** + * Call {@link InetDiagMessage#destroyLiveTcpSockets(Set, Set)} + * + * @param ranges target uid ranges + * @param exemptUids uids to skip close socket + */ + public void destroyLiveTcpSockets(@NonNull final Set> ranges, + @NonNull final Set exemptUids) + throws SocketException, InterruptedIOException, ErrnoException { + InetDiagMessage.destroyLiveTcpSockets(ranges, exemptUids); + } } public ConnectivityService(Context context) { @@ -8448,11 +8463,11 @@ public class ConnectivityService extends IConnectivityManager.Stub return stableRanges; } - private void maybeCloseSockets(NetworkAgentInfo nai, UidRangeParcel[] ranges, - int[] exemptUids) { + private void maybeCloseSockets(NetworkAgentInfo nai, Set ranges, + Set exemptUids) { if (nai.isVPN() && !nai.networkAgentConfig.allowBypass) { try { - mNetd.socketDestroy(ranges, exemptUids); + mDeps.destroyLiveTcpSockets(UidRange.toIntRanges(ranges), exemptUids); } catch (Exception e) { loge("Exception in socket destroy: ", e); } @@ -8460,16 +8475,16 @@ public class ConnectivityService extends IConnectivityManager.Stub } private void updateVpnUidRanges(boolean add, NetworkAgentInfo nai, Set uidRanges) { - int[] exemptUids = new int[2]; + final Set exemptUids = new ArraySet<>(); // TODO: Excluding VPN_UID is necessary in order to not to kill the TCP connection used // by PPTP. Fix this by making Vpn set the owner UID to VPN_UID instead of system when // starting a legacy VPN, and remove VPN_UID here. (b/176542831) - exemptUids[0] = VPN_UID; - exemptUids[1] = nai.networkCapabilities.getOwnerUid(); + exemptUids.add(VPN_UID); + exemptUids.add(nai.networkCapabilities.getOwnerUid()); UidRangeParcel[] ranges = toUidRangeStableParcels(uidRanges); // Close sockets before modifying uid ranges so that RST packets can reach to the server. - maybeCloseSockets(nai, ranges, exemptUids); + maybeCloseSockets(nai, uidRanges, exemptUids); try { if (add) { mNetd.networkAddUidRangesParcel(new NativeUidRangeConfig( @@ -8483,7 +8498,7 @@ public class ConnectivityService extends IConnectivityManager.Stub " on netId " + nai.network.netId + ". " + e); } // Close sockets that established connection while requesting netd. - maybeCloseSockets(nai, ranges, exemptUids); + maybeCloseSockets(nai, uidRanges, exemptUids); } private boolean isProxySetOnAnyDefaultNetwork() { diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java index 1cc0c8909d..03942bb878 100755 --- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java +++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java @@ -1853,7 +1853,7 @@ public class ConnectivityServiceTest { final Context mockResContext = mock(Context.class); doReturn(mResources).when(mockResContext).getResources(); ConnectivityResources.setResourcesContextForTest(mockResContext); - mDeps = new ConnectivityServiceDependencies(mockResContext); + mDeps = spy(new ConnectivityServiceDependencies(mockResContext)); mAutoOnOffKeepaliveDependencies = new AutomaticOnOffKeepaliveTrackerDependencies(mServiceContext); mService = new ConnectivityService(mServiceContext, @@ -1912,7 +1912,8 @@ public class ConnectivityServiceTest { .getBoolean(R.bool.config_cellular_radio_timesharing_capable); } - class ConnectivityServiceDependencies extends ConnectivityService.Dependencies { + // ConnectivityServiceDependencies is public to use Mockito.spy + public class ConnectivityServiceDependencies extends ConnectivityService.Dependencies { final ConnectivityResources mConnRes; ConnectivityServiceDependencies(final Context mockResContext) { @@ -2148,6 +2149,12 @@ public class ConnectivityServiceTest { } } } + + @Override + public void destroyLiveTcpSockets(final Set> ranges, + final Set exemptUids) { + // This function is empty since the invocation of this method is verified by mocks + } } private class AutomaticOnOffKeepaliveTrackerDependencies @@ -12469,12 +12476,11 @@ public class ConnectivityServiceTest { private void assertVpnUidRangesUpdated(boolean add, Set vpnRanges, int exemptUid) throws Exception { - InOrder inOrder = inOrder(mMockNetd); - ArgumentCaptor exemptUidCaptor = ArgumentCaptor.forClass(int[].class); + InOrder inOrder = inOrder(mMockNetd, mDeps); + final Set exemptUidSet = new ArraySet<>(List.of(exemptUid, Process.VPN_UID)); - inOrder.verify(mMockNetd, times(1)).socketDestroy(eq(toUidRangeStableParcels(vpnRanges)), - exemptUidCaptor.capture()); - assertContainsExactly(exemptUidCaptor.getValue(), Process.VPN_UID, exemptUid); + inOrder.verify(mDeps).destroyLiveTcpSockets(UidRange.toIntRanges(vpnRanges), + exemptUidSet); if (add) { inOrder.verify(mMockNetd, times(1)).networkAddUidRangesParcel( @@ -12486,9 +12492,8 @@ public class ConnectivityServiceTest { toUidRangeStableParcels(vpnRanges), PREFERENCE_ORDER_VPN)); } - inOrder.verify(mMockNetd, times(1)).socketDestroy(eq(toUidRangeStableParcels(vpnRanges)), - exemptUidCaptor.capture()); - assertContainsExactly(exemptUidCaptor.getValue(), Process.VPN_UID, exemptUid); + inOrder.verify(mDeps).destroyLiveTcpSockets(UidRange.toIntRanges(vpnRanges), + exemptUidSet); } @Test