diff --git a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java index 6a7da9ec1a..66dcf6d2c7 100644 --- a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java +++ b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java @@ -42,7 +42,6 @@ import android.net.NetworkIdentitySet; import android.net.NetworkStats; import android.net.NetworkStatsAccess; import android.net.NetworkTemplate; -import android.net.netstats.IUsageCallback; import android.os.HandlerThread; import android.os.IBinder; import android.os.Looper; @@ -101,7 +100,7 @@ public class NetworkStatsObserversTest { private ArrayMap mActiveUidIfaces; @Mock private IBinder mUsageCallbackBinder; - @Mock private IUsageCallback mUsageCallback; + private TestableUsageCallback mUsageCallback; @Before public void setUp() throws Exception { @@ -119,20 +118,27 @@ public class NetworkStatsObserversTest { mActiveIfaces = new ArrayMap<>(); mActiveUidIfaces = new ArrayMap<>(); - Mockito.when(mUsageCallback.asBinder()).thenReturn(mUsageCallbackBinder); + mUsageCallback = new TestableUsageCallback(mUsageCallbackBinder); } @Test public void testRegister_thresholdTooLow_setsDefaultThreshold() throws Exception { - long thresholdTooLowBytes = 1L; - DataUsageRequest inputRequest = new DataUsageRequest( + final long thresholdTooLowBytes = 1L; + final DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdTooLowBytes); - DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, - Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); - assertTrue(request.requestId > 0); - assertTrue(Objects.equals(sTemplateWifi, request.template)); - assertEquals(THRESHOLD_BYTES, request.thresholdInBytes); + final DataUsageRequest requestByApp = mStatsObservers.register(inputRequest, mUsageCallback, + UID_RED, NetworkStatsAccess.Level.DEVICE); + assertTrue(requestByApp.requestId > 0); + assertTrue(Objects.equals(sTemplateWifi, requestByApp.template)); + assertEquals(THRESHOLD_BYTES, requestByApp.thresholdInBytes); + + // Verify the threshold requested by system uid won't be overridden. + final DataUsageRequest requestBySystem = mStatsObservers.register(inputRequest, + mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); + assertTrue(requestBySystem.requestId > 0); + assertTrue(Objects.equals(sTemplateWifi, requestBySystem.template)); + assertEquals(1, requestBySystem.thresholdInBytes); } @Test @@ -304,7 +310,7 @@ public class NetworkStatsObserversTest { mStatsObservers.updateStats( xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); waitForObserverToIdle(); - Mockito.verify(mUsageCallback).onThresholdReached(any()); + mUsageCallback.expectOnThresholdReached(request); } @Test @@ -337,7 +343,7 @@ public class NetworkStatsObserversTest { mStatsObservers.updateStats( xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); waitForObserverToIdle(); - Mockito.verify(mUsageCallback).onThresholdReached(any()); + mUsageCallback.expectOnThresholdReached(request); } @Test @@ -402,7 +408,7 @@ public class NetworkStatsObserversTest { mStatsObservers.updateStats( xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); waitForObserverToIdle(); - Mockito.verify(mUsageCallback).onThresholdReached(any()); + mUsageCallback.expectOnThresholdReached(request); } @Test diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java index 94a4f3d1f7..ea35c31613 100644 --- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java +++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java @@ -1292,7 +1292,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { // Wait for the caller to invoke expectOnThresholdReached. - mUsageCallback.expectOnThresholdReached(); + mUsageCallback.expectOnThresholdReached(request); // Allow binder to disconnect when(mUsageCallbackBinder.unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt())) @@ -1302,7 +1302,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { mService.unregisterUsageRequest(request); // Wait for the caller to invoke expectOnCallbackReleased. - mUsageCallback.expectOnCallbackReleased(); + mUsageCallback.expectOnCallbackReleased(request); // Make sure that the caller binder gets disconnected verify(mUsageCallbackBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()); diff --git a/tests/unit/java/com/android/server/net/TestableUsageCallback.kt b/tests/unit/java/com/android/server/net/TestableUsageCallback.kt index 44f588c669..1917ec3b7b 100644 --- a/tests/unit/java/com/android/server/net/TestableUsageCallback.kt +++ b/tests/unit/java/com/android/server/net/TestableUsageCallback.kt @@ -21,37 +21,34 @@ import android.net.netstats.IUsageCallback import android.os.IBinder import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit -import kotlin.test.assertEquals import kotlin.test.fail private const val DEFAULT_TIMEOUT_MS = 200L // TODO: Move the class to static libs once all downstream have IUsageCallback definition. -open class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() { - sealed class CallbackType { - object OnThresholdReached : CallbackType() - object OnCallbackReleased : CallbackType() +class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() { + sealed class CallbackType(val request: DataUsageRequest) { + class OnThresholdReached(request: DataUsageRequest) : CallbackType(request) + class OnCallbackReleased(request: DataUsageRequest) : CallbackType(request) } // TODO: Change to use ArrayTrackRecord once moved into to the module. private val history = LinkedBlockingQueue() override fun onThresholdReached(request: DataUsageRequest) { - history.add(CallbackType.OnThresholdReached) + history.add(CallbackType.OnThresholdReached(request)) } override fun onCallbackReleased(request: DataUsageRequest) { - history.add(CallbackType.OnCallbackReleased) + history.add(CallbackType.OnCallbackReleased(request)) } - fun expectOnThresholdReached() { - assertEquals(CallbackType.OnThresholdReached, - history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)) + fun expectOnThresholdReached(request: DataUsageRequest) { + expectCallback(request, DEFAULT_TIMEOUT_MS) } - fun expectOnCallbackReleased() { - assertEquals(CallbackType.OnCallbackReleased, - history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)) + fun expectOnCallbackReleased(request: DataUsageRequest) { + expectCallback(request, DEFAULT_TIMEOUT_MS) } @JvmOverloads @@ -60,6 +57,22 @@ open class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.S cb?.let { fail("Expected no callback but got $cb") } } + // Expects a callback of the specified request on the specified network within the timeout. + // If no callback arrives, or a different callback arrives, fail. + private inline fun expectCallback( + expectedRequest: DataUsageRequest, + timeoutMs: Long + ) { + history.poll(timeoutMs, TimeUnit.MILLISECONDS).let { + if (it !is T || it.request != expectedRequest) { + fail("Unexpected callback : $it," + + " expected ${T::class} with Request[$expectedRequest]") + } else { + it + } + } + } + override fun asBinder(): IBinder { return binder }