Merge changes from topic "cherrypicker-L78000000960635464:N55900001368285360" into udc-dev

* changes:
  Close sockets from ConnectivityService#setFirewallChainEnabled
  Add test to verify socket close when firewall is enabled
This commit is contained in:
Motomu Utsumi
2023-05-15 04:15:48 +00:00
committed by Android (Google) Code Review
6 changed files with 276 additions and 3 deletions

View File

@@ -384,7 +384,6 @@ public class BpfNetMaps {
* ALLOWLIST means the firewall denies all by default, uids must be explicitly allowed
* DENYLIST means the firewall allows all by default, uids must be explicitly denyed
*/
@VisibleForTesting
public boolean isFirewallAllowList(final int chain) {
switch (chain) {
case FIREWALL_CHAIN_DOZABLE:
@@ -745,6 +744,65 @@ public class BpfNetMaps {
}
}
private Set<Integer> getUidsMatchEnabled(final int childChain) throws ErrnoException {
final long match = getMatchByFirewallChain(childChain);
Set<Integer> uids = new ArraySet<>();
synchronized (sUidOwnerMap) {
sUidOwnerMap.forEach((uid, val) -> {
if (val == null) {
Log.wtf(TAG, "sUidOwnerMap entry was deleted while holding a lock");
} else {
if ((val.rule & match) != 0) {
uids.add(uid.val);
}
}
});
}
return uids;
}
/**
* Get uids that has FIREWALL_RULE_ALLOW on allowlist chain.
* Allowlist means the firewall denies all by default, uids must be explicitly allowed.
*
* Note that uids that has FIREWALL_RULE_DENY on allowlist chain can not be computed from the
* bpf map, since all the uids that does not have explicit FIREWALL_RULE_ALLOW rule in bpf map
* are determined to have FIREWALL_RULE_DENY.
*
* @param childChain target chain
* @return Set of uids
*/
public Set<Integer> getUidsWithAllowRuleOnAllowListChain(final int childChain)
throws ErrnoException {
if (!isFirewallAllowList(childChain)) {
throw new IllegalArgumentException("getUidsWithAllowRuleOnAllowListChain is called with"
+ " denylist chain:" + childChain);
}
// Corresponding match is enabled for uids that has FIREWALL_RULE_ALLOW on allowlist chain.
return getUidsMatchEnabled(childChain);
}
/**
* Get uids that has FIREWALL_RULE_DENY on denylist chain.
* Denylist means the firewall allows all by default, uids must be explicitly denyed
*
* Note that uids that has FIREWALL_RULE_ALLOW on denylist chain can not be computed from the
* bpf map, since all the uids that does not have explicit FIREWALL_RULE_DENY rule in bpf map
* are determined to have the FIREWALL_RULE_ALLOW.
*
* @param childChain target chain
* @return Set of uids
*/
public Set<Integer> getUidsWithDenyRuleOnDenyListChain(final int childChain)
throws ErrnoException {
if (isFirewallAllowList(childChain)) {
throw new IllegalArgumentException("getUidsWithDenyRuleOnDenyListChain is called with"
+ " allowlist chain:" + childChain);
}
// Corresponding match is enabled for uids that has FIREWALL_RULE_DENY on denylist chain.
return getUidsMatchEnabled(childChain);
}
/**
* Add ingress interface filtering rules to a list of UIDs
*

View File

@@ -1509,6 +1509,16 @@ public class ConnectivityService extends IConnectivityManager.Stub
throws SocketException, InterruptedIOException, ErrnoException {
InetDiagMessage.destroyLiveTcpSockets(ranges, exemptUids);
}
/**
* Call {@link InetDiagMessage#destroyLiveTcpSocketsByOwnerUids(Set)}
*
* @param ownerUids target uids to close sockets
*/
public void destroyLiveTcpSocketsByOwnerUids(final Set<Integer> ownerUids)
throws SocketException, InterruptedIOException, ErrnoException {
InetDiagMessage.destroyLiveTcpSocketsByOwnerUids(ownerUids);
}
}
public ConnectivityService(Context context) {
@@ -12002,6 +12012,23 @@ public class ConnectivityService extends IConnectivityManager.Stub
return rule;
}
private void closeSocketsForFirewallChainLocked(final int chain)
throws ErrnoException, SocketException, InterruptedIOException {
if (mBpfNetMaps.isFirewallAllowList(chain)) {
// Allowlist means the firewall denies all by default, uids must be explicitly allowed
// So, close all non-system socket owned by uids that are not explicitly allowed
Set<Range<Integer>> ranges = new ArraySet<>();
ranges.add(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE));
final Set<Integer> exemptUids = mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain);
mDeps.destroyLiveTcpSockets(ranges, exemptUids);
} else {
// Denylist means the firewall allows all by default, uids must be explicitly denied
// So, close socket owned by uids that are explicitly denied
final Set<Integer> ownerUids = mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
mDeps.destroyLiveTcpSocketsByOwnerUids(ownerUids);
}
}
@Override
public void setFirewallChainEnabled(final int chain, final boolean enable) {
enforceNetworkStackOrSettingsPermission();
@@ -12011,6 +12038,14 @@ public class ConnectivityService extends IConnectivityManager.Stub
} catch (ServiceSpecificException e) {
throw new IllegalStateException(e);
}
if (SdkLevel.isAtLeastU() && enable) {
try {
closeSocketsForFirewallChainLocked(chain);
} catch (ErrnoException | SocketException | InterruptedIOException e) {
Log.e(TAG, "Failed to close sockets after enabling chain (" + chain + "): " + e);
}
}
}
@Override

View File

@@ -225,6 +225,7 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
@@ -278,6 +279,7 @@ public class ConnectivityManagerTest {
// TODO(b/252972908): reset the original timer when aosp/2188755 is ramped up.
private static final int LISTEN_ACTIVITY_TIMEOUT_MS = 30_000;
private static final int NO_CALLBACK_TIMEOUT_MS = 100;
private static final int NETWORK_REQUEST_TIMEOUT_MS = 3000;
private static final int SOCKET_TIMEOUT_MS = 100;
private static final int NUM_TRIES_MULTIPATH_PREF_CHECK = 20;
private static final long INTERVAL_MULTIPATH_PREF_CHECK_MS = 500;
@@ -3548,6 +3550,103 @@ public class ConnectivityManagerTest {
doTestFirewallBlocking(FIREWALL_CHAIN_OEM_DENY_3, DENYLIST);
}
private void assertSocketOpen(final Socket socket) throws Exception {
mCtsNetUtils.testHttpRequest(socket);
}
private void assertSocketClosed(final Socket socket) throws Exception {
try {
mCtsNetUtils.testHttpRequest(socket);
fail("Socket is expected to be closed");
} catch (SocketException expected) {
}
}
private static final boolean EXPECT_OPEN = false;
private static final boolean EXPECT_CLOSE = true;
private void doTestFirewallCloseSocket(final int chain, final int rule, final int targetUid,
final boolean expectClose) {
runWithShellPermissionIdentity(() -> {
// Firewall chain status will be restored after the test.
final boolean wasChainEnabled = mCm.getFirewallChainEnabled(chain);
final int previousUidFirewallRule = mCm.getUidFirewallRule(chain, targetUid);
final Socket socket = new Socket(TEST_HOST, HTTP_PORT);
socket.setSoTimeout(NETWORK_REQUEST_TIMEOUT_MS);
testAndCleanup(() -> {
mCm.setFirewallChainEnabled(chain, false /* enable */);
assertSocketOpen(socket);
try {
mCm.setUidFirewallRule(chain, targetUid, rule);
} catch (IllegalStateException ignored) {
// Removing match causes an exception when the rule entry for the uid does
// not exist. But this is fine and can be ignored.
}
mCm.setFirewallChainEnabled(chain, true /* enable */);
if (expectClose) {
assertSocketClosed(socket);
} else {
assertSocketOpen(socket);
}
}, /* cleanup */ () -> {
// Restore the global chain status
mCm.setFirewallChainEnabled(chain, wasChainEnabled);
}, /* cleanup */ () -> {
// Restore the uid firewall rule status
try {
mCm.setUidFirewallRule(chain, targetUid, previousUidFirewallRule);
} catch (IllegalStateException ignored) {
// Removing match causes an exception when the rule entry for the uid does
// not exist. But this is fine and can be ignored.
}
}, /* cleanup */ () -> {
socket.close();
});
}, NETWORK_SETTINGS);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
public void testFirewallCloseSocketAllowlistChainAllow() {
doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_ALLOW,
Process.myUid(), EXPECT_OPEN);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
public void testFirewallCloseSocketAllowlistChainDeny() {
doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_DENY,
Process.myUid(), EXPECT_CLOSE);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
public void testFirewallCloseSocketAllowlistChainOtherUid() {
doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_ALLOW,
Process.myUid() + 1, EXPECT_CLOSE);
doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_DENY,
Process.myUid() + 1, EXPECT_CLOSE);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
public void testFirewallCloseSocketDenylistChainAllow() {
doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_ALLOW,
Process.myUid(), EXPECT_OPEN);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
public void testFirewallCloseSocketDenylistChainDeny() {
doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_DENY,
Process.myUid(), EXPECT_CLOSE);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
public void testFirewallCloseSocketDenylistChainOtherUid() {
doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_ALLOW,
Process.myUid() + 1, EXPECT_OPEN);
doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_DENY,
Process.myUid() + 1, EXPECT_OPEN);
}
private void assumeTestSApis() {
// Cannot use @IgnoreUpTo(Build.VERSION_CODES.R) because this test also requires API 31
// shims, and @IgnoreUpTo does not check that.

View File

@@ -422,7 +422,7 @@ public final class CtsNetUtils {
.build();
}
private void testHttpRequest(Socket s) throws IOException {
public void testHttpRequest(Socket s) throws IOException {
OutputStream out = s.getOutputStream();
InputStream in = s.getInputStream();
@@ -430,7 +430,9 @@ public final class CtsNetUtils {
byte[] responseBytes = new byte[4096];
out.write(requestBytes);
in.read(responseBytes);
assertTrue(new String(responseBytes, "UTF-8").startsWith("HTTP/1.0 204 No Content\r\n"));
final String response = new String(responseBytes, "UTF-8");
assertTrue("Received unexpected response: " + response,
response.startsWith("HTTP/1.0 204 No Content\r\n"));
}
private Socket getBoundSocket(Network network, String host, int port) throws IOException {

View File

@@ -66,6 +66,7 @@ import android.net.INetd;
import android.os.Build;
import android.os.ServiceSpecificException;
import android.system.ErrnoException;
import android.util.ArraySet;
import android.util.IndentingPrintWriter;
import androidx.test.filters.SmallTest;
@@ -1151,4 +1152,33 @@ public final class BpfNetMapsTest {
mCookieTagMap.updateEntry(new CookieTagMapKey(123), new CookieTagMapValue(456, 0x789));
assertDumpContains(getDump(), "cookie=123 tag=0x789 uid=456");
}
@Test
public void testGetUids() throws ErrnoException {
final int uid0 = TEST_UIDS[0];
final int uid1 = TEST_UIDS[1];
final long match0 = DOZABLE_MATCH | POWERSAVE_MATCH;
final long match1 = DOZABLE_MATCH | STANDBY_MATCH;
mUidOwnerMap.updateEntry(new S32(uid0), new UidOwnerValue(NULL_IIF, match0));
mUidOwnerMap.updateEntry(new S32(uid1), new UidOwnerValue(NULL_IIF, match1));
assertEquals(new ArraySet<>(List.of(uid0, uid1)),
mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_DOZABLE));
assertEquals(new ArraySet<>(List.of(uid0)),
mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_POWERSAVE));
assertEquals(new ArraySet<>(List.of(uid1)),
mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_STANDBY));
assertEquals(new ArraySet<>(),
mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_OEM_DENY_1));
}
@Test
public void testGetUidsIllegalArgument() {
final Class<IllegalArgumentException> expected = IllegalArgumentException.class;
assertThrows(expected,
() -> mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_DOZABLE));
assertThrows(expected,
() -> mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_OEM_DENY_1));
}
}

View File

@@ -2173,6 +2173,11 @@ public class ConnectivityServiceTest {
final Set<Integer> exemptUids) {
// This function is empty since the invocation of this method is verified by mocks
}
@Override
public void destroyLiveTcpSocketsByOwnerUids(final Set<Integer> ownerUids) {
// This function is empty since the invocation of this method is verified by mocks
}
}
private class AutomaticOnOffKeepaliveTrackerDependencies
@@ -10244,6 +10249,50 @@ public class ConnectivityServiceTest {
}
}
private void doTestSetFirewallChainEnabledCloseSocket(final int chain,
final boolean isAllowList) throws Exception {
reset(mDeps);
mCm.setFirewallChainEnabled(chain, true /* enabled */);
final Set<Integer> uids =
new ArraySet<>(List.of(TEST_PACKAGE_UID, TEST_PACKAGE_UID2));
if (isAllowList) {
final Set<Range<Integer>> range = new ArraySet<>(
List.of(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE)));
verify(mDeps).destroyLiveTcpSockets(range, uids);
} else {
verify(mDeps).destroyLiveTcpSocketsByOwnerUids(uids);
}
mCm.setFirewallChainEnabled(chain, false /* enabled */);
verifyNoMoreInteractions(mDeps);
}
@Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
public void testSetFirewallChainEnabledCloseSocket() throws Exception {
doReturn(new ArraySet<>(Arrays.asList(TEST_PACKAGE_UID, TEST_PACKAGE_UID2)))
.when(mBpfNetMaps)
.getUidsWithDenyRuleOnDenyListChain(anyInt());
doReturn(new ArraySet<>(Arrays.asList(TEST_PACKAGE_UID, TEST_PACKAGE_UID2)))
.when(mBpfNetMaps)
.getUidsWithAllowRuleOnAllowListChain(anyInt());
final boolean allowlist = true;
final boolean denylist = false;
doReturn(true).when(mBpfNetMaps).isFirewallAllowList(anyInt());
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_DOZABLE, allowlist);
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_POWERSAVE, allowlist);
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_RESTRICTED, allowlist);
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_LOW_POWER_STANDBY, allowlist);
doReturn(false).when(mBpfNetMaps).isFirewallAllowList(anyInt());
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_STANDBY, denylist);
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_1, denylist);
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_2, denylist);
doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_3, denylist);
}
private void doTestReplaceFirewallChain(final int chain) {
final int[] uids = new int[] {1001, 1002};
mCm.replaceFirewallChain(chain, uids);