diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp index 48751f4383..2927e43e36 100644 --- a/tests/unit/Android.bp +++ b/tests/unit/Android.bp @@ -89,6 +89,7 @@ filegroup { "java/com/android/server/connectivity/VpnTest.java", "java/com/android/server/net/ipmemorystore/*.java", "java/com/android/server/net/NetworkStats*.java", + "java/com/android/server/net/TestableUsageCallback.kt", ] } diff --git a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java index c86e699466..ec515376af 100644 --- a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java +++ b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java @@ -296,7 +296,7 @@ public class MultipathPolicyTrackerTest { false /* roaming */); verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(12)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(12)), any(), any()); } @Test @@ -315,7 +315,7 @@ public class MultipathPolicyTrackerTest { // Daily budget should be 15MB (5% of daily quota), 7MB used today: callback set for 8MB verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(8)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(8)), any(), any()); } @Test @@ -334,7 +334,7 @@ public class MultipathPolicyTrackerTest { // Daily budget should be 15MB (5% of daily quota), 7MB used today: callback set for 8MB verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(8)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(8)), any(), any()); } @Test @@ -351,7 +351,7 @@ public class MultipathPolicyTrackerTest { // Default global setting should be used: 12 - 7 = 5 verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(5)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(5)), any(), any()); } @Test @@ -366,7 +366,7 @@ public class MultipathPolicyTrackerTest { false /* roaming */); verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(8)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(8)), any(), any()); // Update setting setDefaultQuotaGlobalSetting(DataUnit.MEGABYTES.toBytes(14)); @@ -376,7 +376,7 @@ public class MultipathPolicyTrackerTest { // Callback must have been re-registered with new setting verify(mStatsManager, times(1)).unregisterUsageCallback(any()); verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(12)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(12)), any(), any()); } @Test @@ -391,7 +391,7 @@ public class MultipathPolicyTrackerTest { false /* roaming */); verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(12)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(12)), any(), any()); when(mResources.getInteger(R.integer.config_networkDefaultDailyMultipathQuotaBytes)) .thenReturn((int) DataUnit.MEGABYTES.toBytes(16)); @@ -402,6 +402,6 @@ public class MultipathPolicyTrackerTest { // Uses the new setting (16 - 2 = 14MB) verify(mStatsManager, times(1)).registerUsageCallback( - any(), anyInt(), eq(DataUnit.MEGABYTES.toBytes(14)), any(), any()); + any(), eq(DataUnit.MEGABYTES.toBytes(14)), any(), any()); } } diff --git a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java index d993d1f5a5..6a7da9ec1a 100644 --- a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java +++ b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java @@ -36,19 +36,16 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; -import android.app.usage.NetworkStatsManager; import android.net.DataUsageRequest; import android.net.NetworkIdentity; import android.net.NetworkIdentitySet; import android.net.NetworkStats; import android.net.NetworkStatsAccess; import android.net.NetworkTemplate; -import android.os.ConditionVariable; -import android.os.Handler; +import android.net.netstats.IUsageCallback; import android.os.HandlerThread; import android.os.IBinder; import android.os.Looper; -import android.os.Messenger; import android.os.Process; import android.os.UserHandle; import android.telephony.TelephonyManager; @@ -56,7 +53,6 @@ import android.util.ArrayMap; import androidx.test.filters.SmallTest; -import com.android.server.net.NetworkStatsServiceTest.LatchedHandler; import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRunner; import com.android.testutils.HandlerUtils; @@ -97,21 +93,15 @@ public class NetworkStatsObserversTest { private static final long WAIT_TIMEOUT_MS = 500; private static final long THRESHOLD_BYTES = 2 * MB_IN_BYTES; private static final long BASE_BYTES = 7 * MB_IN_BYTES; - private static final int INVALID_TYPE = -1; - - private long mElapsedRealtime; private HandlerThread mObserverHandlerThread; - private Handler mObserverNoopHandler; - - private LatchedHandler mHandler; private NetworkStatsObservers mStatsObservers; - private Messenger mMessenger; private ArrayMap mActiveIfaces; private ArrayMap mActiveUidIfaces; - @Mock private IBinder mockBinder; + @Mock private IBinder mUsageCallbackBinder; + @Mock private IUsageCallback mUsageCallback; @Before public void setUp() throws Exception { @@ -127,11 +117,9 @@ public class NetworkStatsObserversTest { } }; - mHandler = new LatchedHandler(Looper.getMainLooper(), new ConditionVariable()); - mMessenger = new Messenger(mHandler); - mActiveIfaces = new ArrayMap<>(); mActiveUidIfaces = new ArrayMap<>(); + Mockito.when(mUsageCallback.asBinder()).thenReturn(mUsageCallbackBinder); } @Test @@ -140,7 +128,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdTooLowBytes); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateWifi, request.template)); @@ -153,7 +141,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, highThresholdBytes); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateWifi, request.template)); @@ -165,13 +153,13 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, THRESHOLD_BYTES); - DataUsageRequest request1 = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request1 = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request1.requestId > 0); assertTrue(Objects.equals(sTemplateWifi, request1.template)); assertEquals(THRESHOLD_BYTES, request1.thresholdInBytes); - DataUsageRequest request2 = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request2 = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request2.requestId > request1.requestId); assertTrue(Objects.equals(sTemplateWifi, request2.template)); @@ -191,17 +179,19 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); assertEquals(THRESHOLD_BYTES, request.thresholdInBytes); - Mockito.verify(mockBinder).linkToDeath(any(IBinder.DeathRecipient.class), anyInt()); + Mockito.verify(mUsageCallbackBinder).linkToDeath(any(IBinder.DeathRecipient.class), + anyInt()); mStatsObservers.unregister(request, Process.SYSTEM_UID); waitForObserverToIdle(); - Mockito.verify(mockBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()); + Mockito.verify(mUsageCallbackBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), + anyInt()); } @Test @@ -209,17 +199,18 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, UID_RED, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); assertEquals(THRESHOLD_BYTES, request.thresholdInBytes); - Mockito.verify(mockBinder).linkToDeath(any(IBinder.DeathRecipient.class), anyInt()); + Mockito.verify(mUsageCallbackBinder) + .linkToDeath(any(IBinder.DeathRecipient.class), anyInt()); mStatsObservers.unregister(request, UID_BLUE); waitForObserverToIdle(); - Mockito.verifyZeroInteractions(mockBinder); + Mockito.verifyZeroInteractions(mUsageCallbackBinder); } private NetworkIdentitySet makeTestIdentSet() { @@ -236,7 +227,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -260,7 +251,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -290,7 +281,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -313,7 +304,7 @@ public class NetworkStatsObserversTest { mStatsObservers.updateStats( xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); waitForObserverToIdle(); - assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.lastMessageType); + Mockito.verify(mUsageCallback).onThresholdReached(any()); } @Test @@ -321,7 +312,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, UID_RED, NetworkStatsAccess.Level.DEFAULT); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -346,7 +337,7 @@ public class NetworkStatsObserversTest { mStatsObservers.updateStats( xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); waitForObserverToIdle(); - assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.lastMessageType); + Mockito.verify(mUsageCallback).onThresholdReached(any()); } @Test @@ -354,7 +345,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, UID_BLUE, NetworkStatsAccess.Level.DEFAULT); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -386,7 +377,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, UID_BLUE, NetworkStatsAccess.Level.USER); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -411,7 +402,7 @@ public class NetworkStatsObserversTest { mStatsObservers.updateStats( xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); waitForObserverToIdle(); - assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, mHandler.lastMessageType); + Mockito.verify(mUsageCallback).onThresholdReached(any()); } @Test @@ -419,7 +410,7 @@ public class NetworkStatsObserversTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES); - DataUsageRequest request = mStatsObservers.register(inputRequest, mMessenger, mockBinder, + DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, UID_RED, NetworkStatsAccess.Level.USER); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateImsi1, request.template)); @@ -448,6 +439,5 @@ public class NetworkStatsObserversTest { private void waitForObserverToIdle() { HandlerUtils.waitForIdle(mObserverHandlerThread, WAIT_TIMEOUT_MS); - HandlerUtils.waitForIdle(mHandler, WAIT_TIMEOUT_MS); } } diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java index d7bbf5018a..e3b3621288 100644 --- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java +++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java @@ -80,7 +80,6 @@ import static org.mockito.Mockito.when; import android.annotation.NonNull; import android.app.AlarmManager; -import android.app.usage.NetworkStatsManager; import android.content.Context; import android.content.Intent; import android.database.ContentObserver; @@ -101,13 +100,9 @@ import android.net.UnderlyingNetworkInfo; import android.net.netstats.provider.INetworkStatsProviderCallback; import android.net.wifi.WifiInfo; import android.os.Build; -import android.os.ConditionVariable; import android.os.Handler; import android.os.HandlerThread; import android.os.IBinder; -import android.os.Looper; -import android.os.Message; -import android.os.Messenger; import android.os.PowerManager; import android.os.SimpleClock; import android.provider.Settings; @@ -188,7 +183,8 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { private @Mock TetheringManager mTetheringManager; private @Mock NetworkStatsFactory mStatsFactory; private @Mock NetworkStatsSettings mSettings; - private @Mock IBinder mBinder; + private @Mock IBinder mUsageCallbackBinder; + private TestableUsageCallback mUsageCallback; private @Mock AlarmManager mAlarmManager; @Mock private NetworkStatsSubscriptionsMonitor mNetworkStatsSubscriptionsMonitor; @@ -313,6 +309,8 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { verify(mTetheringManager).registerTetheringEventCallback( any(), tetheringEventCbCaptor.capture()); mTetheringEventCallback = tetheringEventCbCaptor.getValue(); + + mUsageCallback = new TestableUsageCallback(mUsageCallbackBinder); } @NonNull @@ -1240,20 +1238,14 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { DataUsageRequest inputRequest = new DataUsageRequest( DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdInBytes); - // Create a messenger that waits for callback activity - ConditionVariable cv = new ConditionVariable(false); - LatchedHandler latchedHandler = new LatchedHandler(Looper.getMainLooper(), cv); - Messenger messenger = new Messenger(latchedHandler); - // Force poll expectDefaultSettings(); expectNetworkStatsSummary(buildEmptyStats()); expectNetworkStatsUidDetail(buildEmptyStats()); // Register and verify request and that binder was called - DataUsageRequest request = - mService.registerUsageCallback(mServiceContext.getOpPackageName(), inputRequest, - messenger, mBinder); + DataUsageRequest request = mService.registerUsageCallback( + mServiceContext.getOpPackageName(), inputRequest, mUsageCallback); assertTrue(request.requestId > 0); assertTrue(Objects.equals(sTemplateWifi, request.template)); long minThresholdInBytes = 2 * 1024 * 1024; // 2 MB @@ -1262,7 +1254,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { HandlerUtils.waitForIdle(mHandlerThread, WAIT_TIMEOUT); // Make sure that the caller binder gets connected - verify(mBinder).linkToDeath(any(IBinder.DeathRecipient.class), anyInt()); + verify(mUsageCallbackBinder).linkToDeath(any(IBinder.DeathRecipient.class), anyInt()); // modify some number on wifi, and trigger poll event // not enough traffic to call data usage callback @@ -1277,7 +1269,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { assertNetworkTotal(sTemplateWifi, 1024L, 1L, 2048L, 2L, 0); // make sure callback has not being called - assertEquals(INVALID_TYPE, latchedHandler.lastMessageType); + mUsageCallback.assertNoCallback(); // and bump forward again, with counters going higher. this is // important, since it will trigger the data usage callback @@ -1292,23 +1284,21 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { assertNetworkTotal(sTemplateWifi, 4096000L, 4L, 8192000L, 8L, 0); - // Wait for the caller to ack receipt of CALLBACK_LIMIT_REACHED - assertTrue(cv.block(WAIT_TIMEOUT)); - assertEquals(NetworkStatsManager.CALLBACK_LIMIT_REACHED, latchedHandler.lastMessageType); - cv.close(); + // Wait for the caller to invoke expectOnThresholdReached. + mUsageCallback.expectOnThresholdReached(); // Allow binder to disconnect - when(mBinder.unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt())).thenReturn(true); + when(mUsageCallbackBinder.unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt())) + .thenReturn(true); // Unregister request mService.unregisterUsageRequest(request); - // Wait for the caller to ack receipt of CALLBACK_RELEASED - assertTrue(cv.block(WAIT_TIMEOUT)); - assertEquals(NetworkStatsManager.CALLBACK_RELEASED, latchedHandler.lastMessageType); + // Wait for the caller to invoke expectOnCallbackReleased. + mUsageCallback.expectOnCallbackReleased(); // Make sure that the caller binder gets disconnected - verify(mBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()); + verify(mUsageCallbackBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()); } @Test @@ -1884,21 +1874,4 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { private void waitForIdle() { HandlerUtils.waitForIdle(mHandlerThread, WAIT_TIMEOUT); } - - static class LatchedHandler extends Handler { - private final ConditionVariable mCv; - int lastMessageType = INVALID_TYPE; - - LatchedHandler(Looper looper, ConditionVariable cv) { - super(looper); - mCv = cv; - } - - @Override - public void handleMessage(Message msg) { - lastMessageType = msg.what; - mCv.open(); - super.handleMessage(msg); - } - } } diff --git a/tests/unit/java/com/android/server/net/TestableUsageCallback.kt b/tests/unit/java/com/android/server/net/TestableUsageCallback.kt new file mode 100644 index 0000000000..44f588c669 --- /dev/null +++ b/tests/unit/java/com/android/server/net/TestableUsageCallback.kt @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.net + +import android.net.DataUsageRequest +import android.net.netstats.IUsageCallback +import android.os.IBinder +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.fail + +private const val DEFAULT_TIMEOUT_MS = 200L + +// TODO: Move the class to static libs once all downstream have IUsageCallback definition. +open class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() { + sealed class CallbackType { + object OnThresholdReached : CallbackType() + object OnCallbackReleased : CallbackType() + } + + // TODO: Change to use ArrayTrackRecord once moved into to the module. + private val history = LinkedBlockingQueue() + + override fun onThresholdReached(request: DataUsageRequest) { + history.add(CallbackType.OnThresholdReached) + } + + override fun onCallbackReleased(request: DataUsageRequest) { + history.add(CallbackType.OnCallbackReleased) + } + + fun expectOnThresholdReached() { + assertEquals(CallbackType.OnThresholdReached, + history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)) + } + + fun expectOnCallbackReleased() { + assertEquals(CallbackType.OnCallbackReleased, + history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)) + } + + @JvmOverloads + fun assertNoCallback(timeout: Long = DEFAULT_TIMEOUT_MS) { + val cb = history.poll(timeout, TimeUnit.MILLISECONDS) + cb?.let { fail("Expected no callback but got $cb") } + } + + override fun asBinder(): IBinder { + return binder + } +} \ No newline at end of file