diff --git a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java index 349ca5662e..c86e699466 100644 --- a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java +++ b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java @@ -31,15 +31,16 @@ import static com.android.server.net.NetworkPolicyManagerService.OPPORTUNISTIC_Q import static org.junit.Assert.assertNotNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import android.app.usage.NetworkStats; import android.app.usage.NetworkStatsManager; import android.content.BroadcastReceiver; import android.content.Context; @@ -61,6 +62,7 @@ import android.provider.Settings; import android.telephony.TelephonyManager; import android.test.mock.MockContentResolver; import android.util.DataUnit; +import android.util.Range; import android.util.RecurrenceRule; import androidx.test.filters.SmallTest; @@ -69,7 +71,6 @@ import com.android.internal.R; import com.android.internal.util.test.FakeSettingsProvider; import com.android.server.LocalServices; import com.android.server.net.NetworkPolicyManagerInternal; -import com.android.server.net.NetworkStatsManagerInternal; import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRunner; @@ -88,6 +89,7 @@ import java.time.Period; import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.temporal.ChronoUnit; +import java.util.Set; @RunWith(DevSdkIgnoreRunner.class) @SmallTest @@ -107,7 +109,6 @@ public class MultipathPolicyTrackerTest { @Mock private NetworkPolicyManager mNPM; @Mock private NetworkStatsManager mStatsManager; @Mock private NetworkPolicyManagerInternal mNPMI; - @Mock private NetworkStatsManagerInternal mNetworkStatsManagerInternal; @Mock private TelephonyManager mTelephonyManager; private MockContentResolver mContentResolver; @@ -165,9 +166,6 @@ public class MultipathPolicyTrackerTest { LocalServices.removeServiceForTest(NetworkPolicyManagerInternal.class); LocalServices.addService(NetworkPolicyManagerInternal.class, mNPMI); - LocalServices.removeServiceForTest(NetworkStatsManagerInternal.class); - LocalServices.addService(NetworkStatsManagerInternal.class, mNetworkStatsManagerInternal); - mTracker = new MultipathPolicyTracker(mContext, mHandler, mDeps); } @@ -202,6 +200,11 @@ public class MultipathPolicyTrackerTest { when(mNPMI.getSubscriptionOpportunisticQuota(TEST_NETWORK, QUOTA_TYPE_MULTIPATH)) .thenReturn(subscriptionQuota); + // Prepare stats to be mocked. + final NetworkStats.Bucket mockedStatsBucket = mock(NetworkStats.Bucket.class); + when(mockedStatsBucket.getTxBytes()).thenReturn(usedBytesToday / 3); + when(mockedStatsBucket.getRxBytes()).thenReturn(usedBytesToday - usedBytesToday / 3); + // Setup user policy warning / limit if (policyWarning != WARNING_DISABLED || policyLimit != LIMIT_DISABLED) { final Instant recurrenceStart = Instant.parse("2017-04-01T00:00:00Z"); @@ -215,7 +218,9 @@ public class MultipathPolicyTrackerTest { final boolean snoozeLimit = policyLimit == POLICY_SNOOZED; when(mNPM.getNetworkPolicies()).thenReturn(new NetworkPolicy[] { new NetworkPolicy( - NetworkTemplate.buildTemplateMobileWildcard(), + new NetworkTemplate.Builder(NetworkTemplate.MATCH_MOBILE) + .setSubscriberIds(Set.of(TEST_IMSI1)) + .setMeteredness(android.net.NetworkStats.METERED_YES).build(), recurrenceRule, snoozeWarning ? 0 : policyWarning, snoozeLimit ? 0 : policyLimit, @@ -225,6 +230,13 @@ public class MultipathPolicyTrackerTest { true /* metered */, false /* inferred */) }); + + // Mock stats for this month. + final Range cycleOfTheMonth = recurrenceRule.cycleIterator().next(); + when(mStatsManager.querySummaryForDevice(any(), + eq(cycleOfTheMonth.getLower().toInstant().toEpochMilli()), + eq(cycleOfTheMonth.getUpper().toInstant().toEpochMilli()))) + .thenReturn(mockedStatsBucket); } else { when(mNPM.getNetworkPolicies()).thenReturn(new NetworkPolicy[0]); } @@ -236,8 +248,10 @@ public class MultipathPolicyTrackerTest { when(mResources.getInteger(R.integer.config_networkDefaultDailyMultipathQuotaBytes)) .thenReturn((int) defaultResSetting); - when(mNetworkStatsManagerInternal.getNetworkTotalBytes( - any(), anyLong(), anyLong())).thenReturn(usedBytesToday); + // Mock stats for today. + when(mStatsManager.querySummaryForDevice(any(), + eq(startOfDay.toInstant().toEpochMilli()), + eq(now.toInstant().toEpochMilli()))).thenReturn(mockedStatsBucket); ArgumentCaptor networkCallback = ArgumentCaptor.forClass(ConnectivityManager.NetworkCallback.class);