Use InetDiagMessage.destroyLiveTcpSocket instead of netd.socketDestroy

Netd is not updatable since it's not mainlined.
To make socket destroy code updatable, the code was re-implemented in
java and moved to Connectivity.

Bug: 270298713
Test: atest FrameworksNetTests
Change-Id: I5439c0c76c42a9f738a1b25a1f62e701755cbd05
This commit is contained in:
Motomu Utsumi
2023-03-16 17:04:21 +09:00
parent 7002f0bee1
commit 93a2218e41
2 changed files with 38 additions and 18 deletions

View File

@@ -244,6 +244,7 @@ import android.util.ArraySet;
import android.util.LocalLog; import android.util.LocalLog;
import android.util.Log; import android.util.Log;
import android.util.Pair; import android.util.Pair;
import android.util.Range;
import android.util.SparseArray; import android.util.SparseArray;
import android.util.SparseIntArray; import android.util.SparseIntArray;
@@ -310,11 +311,13 @@ import org.xmlpull.v1.XmlPullParserException;
import java.io.FileDescriptor; import java.io.FileDescriptor;
import java.io.IOException; import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.io.Writer; import java.io.Writer;
import java.net.Inet4Address; import java.net.Inet4Address;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@@ -1485,6 +1488,18 @@ public class ConnectivityService extends IConnectivityManager.Stub
@NonNull final UserHandle user) { @NonNull final UserHandle user) {
return CompatChanges.isChangeEnabled(changeId, packageName, user); return CompatChanges.isChangeEnabled(changeId, packageName, user);
} }
/**
* Call {@link InetDiagMessage#destroyLiveTcpSockets(Set, Set)}
*
* @param ranges target uid ranges
* @param exemptUids uids to skip close socket
*/
public void destroyLiveTcpSockets(@NonNull final Set<Range<Integer>> ranges,
@NonNull final Set<Integer> exemptUids)
throws SocketException, InterruptedIOException, ErrnoException {
InetDiagMessage.destroyLiveTcpSockets(ranges, exemptUids);
}
} }
public ConnectivityService(Context context) { public ConnectivityService(Context context) {
@@ -8448,11 +8463,11 @@ public class ConnectivityService extends IConnectivityManager.Stub
return stableRanges; return stableRanges;
} }
private void maybeCloseSockets(NetworkAgentInfo nai, UidRangeParcel[] ranges, private void maybeCloseSockets(NetworkAgentInfo nai, Set<UidRange> ranges,
int[] exemptUids) { Set<Integer> exemptUids) {
if (nai.isVPN() && !nai.networkAgentConfig.allowBypass) { if (nai.isVPN() && !nai.networkAgentConfig.allowBypass) {
try { try {
mNetd.socketDestroy(ranges, exemptUids); mDeps.destroyLiveTcpSockets(UidRange.toIntRanges(ranges), exemptUids);
} catch (Exception e) { } catch (Exception e) {
loge("Exception in socket destroy: ", e); loge("Exception in socket destroy: ", e);
} }
@@ -8460,16 +8475,16 @@ public class ConnectivityService extends IConnectivityManager.Stub
} }
private void updateVpnUidRanges(boolean add, NetworkAgentInfo nai, Set<UidRange> uidRanges) { private void updateVpnUidRanges(boolean add, NetworkAgentInfo nai, Set<UidRange> uidRanges) {
int[] exemptUids = new int[2]; final Set<Integer> exemptUids = new ArraySet<>();
// TODO: Excluding VPN_UID is necessary in order to not to kill the TCP connection used // TODO: Excluding VPN_UID is necessary in order to not to kill the TCP connection used
// by PPTP. Fix this by making Vpn set the owner UID to VPN_UID instead of system when // by PPTP. Fix this by making Vpn set the owner UID to VPN_UID instead of system when
// starting a legacy VPN, and remove VPN_UID here. (b/176542831) // starting a legacy VPN, and remove VPN_UID here. (b/176542831)
exemptUids[0] = VPN_UID; exemptUids.add(VPN_UID);
exemptUids[1] = nai.networkCapabilities.getOwnerUid(); exemptUids.add(nai.networkCapabilities.getOwnerUid());
UidRangeParcel[] ranges = toUidRangeStableParcels(uidRanges); UidRangeParcel[] ranges = toUidRangeStableParcels(uidRanges);
// Close sockets before modifying uid ranges so that RST packets can reach to the server. // Close sockets before modifying uid ranges so that RST packets can reach to the server.
maybeCloseSockets(nai, ranges, exemptUids); maybeCloseSockets(nai, uidRanges, exemptUids);
try { try {
if (add) { if (add) {
mNetd.networkAddUidRangesParcel(new NativeUidRangeConfig( mNetd.networkAddUidRangesParcel(new NativeUidRangeConfig(
@@ -8483,7 +8498,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
" on netId " + nai.network.netId + ". " + e); " on netId " + nai.network.netId + ". " + e);
} }
// Close sockets that established connection while requesting netd. // Close sockets that established connection while requesting netd.
maybeCloseSockets(nai, ranges, exemptUids); maybeCloseSockets(nai, uidRanges, exemptUids);
} }
private boolean isProxySetOnAnyDefaultNetwork() { private boolean isProxySetOnAnyDefaultNetwork() {

View File

@@ -1853,7 +1853,7 @@ public class ConnectivityServiceTest {
final Context mockResContext = mock(Context.class); final Context mockResContext = mock(Context.class);
doReturn(mResources).when(mockResContext).getResources(); doReturn(mResources).when(mockResContext).getResources();
ConnectivityResources.setResourcesContextForTest(mockResContext); ConnectivityResources.setResourcesContextForTest(mockResContext);
mDeps = new ConnectivityServiceDependencies(mockResContext); mDeps = spy(new ConnectivityServiceDependencies(mockResContext));
mAutoOnOffKeepaliveDependencies = mAutoOnOffKeepaliveDependencies =
new AutomaticOnOffKeepaliveTrackerDependencies(mServiceContext); new AutomaticOnOffKeepaliveTrackerDependencies(mServiceContext);
mService = new ConnectivityService(mServiceContext, mService = new ConnectivityService(mServiceContext,
@@ -1912,7 +1912,8 @@ public class ConnectivityServiceTest {
.getBoolean(R.bool.config_cellular_radio_timesharing_capable); .getBoolean(R.bool.config_cellular_radio_timesharing_capable);
} }
class ConnectivityServiceDependencies extends ConnectivityService.Dependencies { // ConnectivityServiceDependencies is public to use Mockito.spy
public class ConnectivityServiceDependencies extends ConnectivityService.Dependencies {
final ConnectivityResources mConnRes; final ConnectivityResources mConnRes;
ConnectivityServiceDependencies(final Context mockResContext) { ConnectivityServiceDependencies(final Context mockResContext) {
@@ -2148,6 +2149,12 @@ public class ConnectivityServiceTest {
} }
} }
} }
@Override
public void destroyLiveTcpSockets(final Set<Range<Integer>> ranges,
final Set<Integer> exemptUids) {
// This function is empty since the invocation of this method is verified by mocks
}
} }
private class AutomaticOnOffKeepaliveTrackerDependencies private class AutomaticOnOffKeepaliveTrackerDependencies
@@ -12469,12 +12476,11 @@ public class ConnectivityServiceTest {
private void assertVpnUidRangesUpdated(boolean add, Set<UidRange> vpnRanges, int exemptUid) private void assertVpnUidRangesUpdated(boolean add, Set<UidRange> vpnRanges, int exemptUid)
throws Exception { throws Exception {
InOrder inOrder = inOrder(mMockNetd); InOrder inOrder = inOrder(mMockNetd, mDeps);
ArgumentCaptor<int[]> exemptUidCaptor = ArgumentCaptor.forClass(int[].class); final Set<Integer> exemptUidSet = new ArraySet<>(List.of(exemptUid, Process.VPN_UID));
inOrder.verify(mMockNetd, times(1)).socketDestroy(eq(toUidRangeStableParcels(vpnRanges)), inOrder.verify(mDeps).destroyLiveTcpSockets(UidRange.toIntRanges(vpnRanges),
exemptUidCaptor.capture()); exemptUidSet);
assertContainsExactly(exemptUidCaptor.getValue(), Process.VPN_UID, exemptUid);
if (add) { if (add) {
inOrder.verify(mMockNetd, times(1)).networkAddUidRangesParcel( inOrder.verify(mMockNetd, times(1)).networkAddUidRangesParcel(
@@ -12486,9 +12492,8 @@ public class ConnectivityServiceTest {
toUidRangeStableParcels(vpnRanges), PREFERENCE_ORDER_VPN)); toUidRangeStableParcels(vpnRanges), PREFERENCE_ORDER_VPN));
} }
inOrder.verify(mMockNetd, times(1)).socketDestroy(eq(toUidRangeStableParcels(vpnRanges)), inOrder.verify(mDeps).destroyLiveTcpSockets(UidRange.toIntRanges(vpnRanges),
exemptUidCaptor.capture()); exemptUidSet);
assertContainsExactly(exemptUidCaptor.getValue(), Process.VPN_UID, exemptUid);
} }
@Test @Test