[MS68.2] Adapt NetworkPolicyManagerService to use registerUsageCallback

This change also address comments at aosp/1958145.

Test: atest FrameworksNetTests
Bug: 204830222
Change-Id: I415d74df34caa91b1e1160478ebff30fbd1c7b6f
This commit is contained in:
junyulai
2022-01-24 21:02:15 +08:00
parent f8a1d3e5e6
commit 0a007c248c
3 changed files with 47 additions and 28 deletions

View File

@@ -42,7 +42,6 @@ import android.net.NetworkIdentitySet;
import android.net.NetworkStats; import android.net.NetworkStats;
import android.net.NetworkStatsAccess; import android.net.NetworkStatsAccess;
import android.net.NetworkTemplate; import android.net.NetworkTemplate;
import android.net.netstats.IUsageCallback;
import android.os.HandlerThread; import android.os.HandlerThread;
import android.os.IBinder; import android.os.IBinder;
import android.os.Looper; import android.os.Looper;
@@ -101,7 +100,7 @@ public class NetworkStatsObserversTest {
private ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces; private ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
@Mock private IBinder mUsageCallbackBinder; @Mock private IBinder mUsageCallbackBinder;
@Mock private IUsageCallback mUsageCallback; private TestableUsageCallback mUsageCallback;
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
@@ -119,20 +118,27 @@ public class NetworkStatsObserversTest {
mActiveIfaces = new ArrayMap<>(); mActiveIfaces = new ArrayMap<>();
mActiveUidIfaces = new ArrayMap<>(); mActiveUidIfaces = new ArrayMap<>();
Mockito.when(mUsageCallback.asBinder()).thenReturn(mUsageCallbackBinder); mUsageCallback = new TestableUsageCallback(mUsageCallbackBinder);
} }
@Test @Test
public void testRegister_thresholdTooLow_setsDefaultThreshold() throws Exception { public void testRegister_thresholdTooLow_setsDefaultThreshold() throws Exception {
long thresholdTooLowBytes = 1L; final long thresholdTooLowBytes = 1L;
DataUsageRequest inputRequest = new DataUsageRequest( final DataUsageRequest inputRequest = new DataUsageRequest(
DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdTooLowBytes); DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdTooLowBytes);
DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback, final DataUsageRequest requestByApp = mStatsObservers.register(inputRequest, mUsageCallback,
Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE); UID_RED, NetworkStatsAccess.Level.DEVICE);
assertTrue(request.requestId > 0); assertTrue(requestByApp.requestId > 0);
assertTrue(Objects.equals(sTemplateWifi, request.template)); assertTrue(Objects.equals(sTemplateWifi, requestByApp.template));
assertEquals(THRESHOLD_BYTES, request.thresholdInBytes); assertEquals(THRESHOLD_BYTES, requestByApp.thresholdInBytes);
// Verify the threshold requested by system uid won't be overridden.
final DataUsageRequest requestBySystem = mStatsObservers.register(inputRequest,
mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
assertTrue(requestBySystem.requestId > 0);
assertTrue(Objects.equals(sTemplateWifi, requestBySystem.template));
assertEquals(1, requestBySystem.thresholdInBytes);
} }
@Test @Test
@@ -304,7 +310,7 @@ public class NetworkStatsObserversTest {
mStatsObservers.updateStats( mStatsObservers.updateStats(
xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START);
waitForObserverToIdle(); waitForObserverToIdle();
Mockito.verify(mUsageCallback).onThresholdReached(any()); mUsageCallback.expectOnThresholdReached(request);
} }
@Test @Test
@@ -337,7 +343,7 @@ public class NetworkStatsObserversTest {
mStatsObservers.updateStats( mStatsObservers.updateStats(
xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START);
waitForObserverToIdle(); waitForObserverToIdle();
Mockito.verify(mUsageCallback).onThresholdReached(any()); mUsageCallback.expectOnThresholdReached(request);
} }
@Test @Test
@@ -402,7 +408,7 @@ public class NetworkStatsObserversTest {
mStatsObservers.updateStats( mStatsObservers.updateStats(
xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START); xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START);
waitForObserverToIdle(); waitForObserverToIdle();
Mockito.verify(mUsageCallback).onThresholdReached(any()); mUsageCallback.expectOnThresholdReached(request);
} }
@Test @Test

View File

@@ -1285,7 +1285,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest {
// Wait for the caller to invoke expectOnThresholdReached. // Wait for the caller to invoke expectOnThresholdReached.
mUsageCallback.expectOnThresholdReached(); mUsageCallback.expectOnThresholdReached(request);
// Allow binder to disconnect // Allow binder to disconnect
when(mUsageCallbackBinder.unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt())) when(mUsageCallbackBinder.unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()))
@@ -1295,7 +1295,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest {
mService.unregisterUsageRequest(request); mService.unregisterUsageRequest(request);
// Wait for the caller to invoke expectOnCallbackReleased. // Wait for the caller to invoke expectOnCallbackReleased.
mUsageCallback.expectOnCallbackReleased(); mUsageCallback.expectOnCallbackReleased(request);
// Make sure that the caller binder gets disconnected // Make sure that the caller binder gets disconnected
verify(mUsageCallbackBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()); verify(mUsageCallbackBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt());

View File

@@ -21,37 +21,34 @@ import android.net.netstats.IUsageCallback
import android.os.IBinder import android.os.IBinder
import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.fail import kotlin.test.fail
private const val DEFAULT_TIMEOUT_MS = 200L private const val DEFAULT_TIMEOUT_MS = 200L
// TODO: Move the class to static libs once all downstream have IUsageCallback definition. // TODO: Move the class to static libs once all downstream have IUsageCallback definition.
open class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() { class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() {
sealed class CallbackType { sealed class CallbackType(val request: DataUsageRequest) {
object OnThresholdReached : CallbackType() class OnThresholdReached(request: DataUsageRequest) : CallbackType(request)
object OnCallbackReleased : CallbackType() class OnCallbackReleased(request: DataUsageRequest) : CallbackType(request)
} }
// TODO: Change to use ArrayTrackRecord once moved into to the module. // TODO: Change to use ArrayTrackRecord once moved into to the module.
private val history = LinkedBlockingQueue<CallbackType>() private val history = LinkedBlockingQueue<CallbackType>()
override fun onThresholdReached(request: DataUsageRequest) { override fun onThresholdReached(request: DataUsageRequest) {
history.add(CallbackType.OnThresholdReached) history.add(CallbackType.OnThresholdReached(request))
} }
override fun onCallbackReleased(request: DataUsageRequest) { override fun onCallbackReleased(request: DataUsageRequest) {
history.add(CallbackType.OnCallbackReleased) history.add(CallbackType.OnCallbackReleased(request))
} }
fun expectOnThresholdReached() { fun expectOnThresholdReached(request: DataUsageRequest) {
assertEquals(CallbackType.OnThresholdReached, expectCallback<CallbackType.OnThresholdReached>(request, DEFAULT_TIMEOUT_MS)
history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS))
} }
fun expectOnCallbackReleased() { fun expectOnCallbackReleased(request: DataUsageRequest) {
assertEquals(CallbackType.OnCallbackReleased, expectCallback<CallbackType.OnCallbackReleased>(request, DEFAULT_TIMEOUT_MS)
history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS))
} }
@JvmOverloads @JvmOverloads
@@ -60,6 +57,22 @@ open class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.S
cb?.let { fail("Expected no callback but got $cb") } cb?.let { fail("Expected no callback but got $cb") }
} }
// Expects a callback of the specified request on the specified network within the timeout.
// If no callback arrives, or a different callback arrives, fail.
private inline fun <reified T : CallbackType> expectCallback(
expectedRequest: DataUsageRequest,
timeoutMs: Long
) {
history.poll(timeoutMs, TimeUnit.MILLISECONDS).let {
if (it !is T || it.request != expectedRequest) {
fail("Unexpected callback : $it," +
" expected ${T::class} with Request[$expectedRequest]")
} else {
it
}
}
}
override fun asBinder(): IBinder { override fun asBinder(): IBinder {
return binder return binder
} }