diff --git a/tests/net/java/com/android/server/NetworkManagementServiceTest.java b/tests/net/java/com/android/server/NetworkManagementServiceTest.java index 968b3071bf..b8b5886951 100644 --- a/tests/net/java/com/android/server/NetworkManagementServiceTest.java +++ b/tests/net/java/com/android/server/NetworkManagementServiceTest.java @@ -16,6 +16,12 @@ package com.android.server; +import static android.util.DebugUtils.valueToString; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -29,15 +35,19 @@ import android.content.Context; import android.net.INetd; import android.net.INetdUnsolicitedEventListener; import android.net.LinkAddress; +import android.net.NetworkPolicyManager; import android.os.BatteryStats; import android.os.Binder; import android.os.IBinder; +import android.os.Process; +import android.os.RemoteException; import android.test.suitebuilder.annotation.SmallTest; +import android.util.ArrayMap; import androidx.test.runner.AndroidJUnit4; import com.android.internal.app.IBatteryStats; -import com.android.server.NetworkManagementService.SystemServices; +import com.android.server.NetworkManagementService.Dependencies; import com.android.server.net.BaseNetworkObserver; import org.junit.After; @@ -49,13 +59,14 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import java.util.function.BiFunction; + /** * Tests for {@link NetworkManagementService}. */ @RunWith(AndroidJUnit4.class) @SmallTest public class NetworkManagementServiceTest { - private NetworkManagementService mNMService; @Mock private Context mContext; @@ -66,7 +77,9 @@ public class NetworkManagementServiceTest { @Captor private ArgumentCaptor mUnsolListenerCaptor; - private final SystemServices mServices = new SystemServices() { + private final MockDependencies mDeps = new MockDependencies(); + + private final class MockDependencies extends Dependencies { @Override public IBinder getService(String name) { switch (name) { @@ -76,14 +89,21 @@ public class NetworkManagementServiceTest { throw new UnsupportedOperationException("Unknown service " + name); } } + @Override public void registerLocalService(NetworkManagementInternal nmi) { } + @Override public INetd getNetd() { return mNetdService; } - }; + + @Override + public int getCallingUid() { + return Process.SYSTEM_UID; + } + } @Before public void setUp() throws Exception { @@ -91,7 +111,7 @@ public class NetworkManagementServiceTest { doNothing().when(mNetdService) .registerUnsolicitedEventListener(mUnsolListenerCaptor.capture()); // Start the service and wait until it connects to our socket. - mNMService = NetworkManagementService.create(mContext, mServices); + mNMService = NetworkManagementService.create(mContext, mDeps); } @After @@ -192,4 +212,98 @@ public class NetworkManagementServiceTest { // Make sure nothing else was called. verifyNoMoreInteractions(observer); } + + @Test + public void testFirewallEnabled() { + mNMService.setFirewallEnabled(true); + assertTrue(mNMService.isFirewallEnabled()); + + mNMService.setFirewallEnabled(false); + assertFalse(mNMService.isFirewallEnabled()); + } + + private static final int TEST_UID = 111; + + @Test + public void testNetworkRestrictedDefault() { + assertFalse(mNMService.isNetworkRestricted(TEST_UID)); + } + + @Test + public void testMeteredNetworkRestrictions() throws RemoteException { + // Make sure the mocked netd method returns true. + doReturn(true).when(mNetdService).bandwidthEnableDataSaver(anyBoolean()); + + // Restrict usage of mobile data in background + mNMService.setUidMeteredNetworkDenylist(TEST_UID, true); + assertTrue("Should be true since mobile data usage is restricted", + mNMService.isNetworkRestricted(TEST_UID)); + + mNMService.setDataSaverModeEnabled(true); + verify(mNetdService).bandwidthEnableDataSaver(true); + + mNMService.setUidMeteredNetworkDenylist(TEST_UID, false); + assertTrue("Should be true since data saver is on and the uid is not allowlisted", + mNMService.isNetworkRestricted(TEST_UID)); + + mNMService.setUidMeteredNetworkAllowlist(TEST_UID, true); + assertFalse("Should be false since data saver is on and the uid is allowlisted", + mNMService.isNetworkRestricted(TEST_UID)); + + // remove uid from allowlist and turn datasaver off again + mNMService.setUidMeteredNetworkAllowlist(TEST_UID, false); + mNMService.setDataSaverModeEnabled(false); + verify(mNetdService).bandwidthEnableDataSaver(false); + assertFalse("Network should not be restricted when data saver is off", + mNMService.isNetworkRestricted(TEST_UID)); + } + + @Test + public void testFirewallChains() { + final ArrayMap> expected = new ArrayMap<>(); + // Dozable chain + final ArrayMap isRestrictedForDozable = new ArrayMap<>(); + isRestrictedForDozable.put(NetworkPolicyManager.FIREWALL_RULE_DEFAULT, true); + isRestrictedForDozable.put(INetd.FIREWALL_RULE_ALLOW, false); + isRestrictedForDozable.put(INetd.FIREWALL_RULE_DENY, true); + expected.put(INetd.FIREWALL_CHAIN_DOZABLE, isRestrictedForDozable); + // Powersaver chain + final ArrayMap isRestrictedForPowerSave = new ArrayMap<>(); + isRestrictedForPowerSave.put(NetworkPolicyManager.FIREWALL_RULE_DEFAULT, true); + isRestrictedForPowerSave.put(INetd.FIREWALL_RULE_ALLOW, false); + isRestrictedForPowerSave.put(INetd.FIREWALL_RULE_DENY, true); + expected.put(INetd.FIREWALL_CHAIN_POWERSAVE, isRestrictedForPowerSave); + // Standby chain + final ArrayMap isRestrictedForStandby = new ArrayMap<>(); + isRestrictedForStandby.put(NetworkPolicyManager.FIREWALL_RULE_DEFAULT, false); + isRestrictedForStandby.put(INetd.FIREWALL_RULE_ALLOW, false); + isRestrictedForStandby.put(INetd.FIREWALL_RULE_DENY, true); + expected.put(INetd.FIREWALL_CHAIN_STANDBY, isRestrictedForStandby); + + final int[] chains = { + INetd.FIREWALL_CHAIN_STANDBY, + INetd.FIREWALL_CHAIN_POWERSAVE, + INetd.FIREWALL_CHAIN_DOZABLE + }; + final int[] states = { + INetd.FIREWALL_RULE_ALLOW, + INetd.FIREWALL_RULE_DENY, + NetworkPolicyManager.FIREWALL_RULE_DEFAULT + }; + BiFunction errorMsg = (chain, state) -> { + return String.format("Unexpected value for chain: %s and state: %s", + valueToString(INetd.class, "FIREWALL_CHAIN_", chain), + valueToString(INetd.class, "FIREWALL_RULE_", state)); + }; + for (int chain : chains) { + final ArrayMap expectedValues = expected.get(chain); + mNMService.setFirewallChainEnabled(chain, true); + for (int state : states) { + mNMService.setFirewallUidRule(chain, TEST_UID, state); + assertEquals(errorMsg.apply(chain, state), + expectedValues.get(state), mNMService.isNetworkRestricted(TEST_UID)); + } + mNMService.setFirewallChainEnabled(chain, false); + } + } }