diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt index 2fdd5fb201..d0e3023db6 100644 --- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt +++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt @@ -67,6 +67,9 @@ class NetworkAgentTest { private class Provider(context: Context, looper: Looper) : NetworkProvider(context, looper, "NetworkAgentTest NetworkProvider") + private val agentsToCleanUp = mutableListOf() + private val callbacksToCleanUp = mutableListOf() + @Before fun setUp() { instrumentation.getUiAutomation().adoptShellPermissionIdentity() @@ -75,11 +78,13 @@ class NetworkAgentTest { @After fun tearDown() { + agentsToCleanUp.forEach { it.unregister() } + callbacksToCleanUp.forEach { mCM.unregisterNetworkCallback(it) } mHandlerThread.quitSafely() instrumentation.getUiAutomation().dropShellPermissionIdentity() } - internal class TestableNetworkAgent( + private class TestableNetworkAgent( looper: Looper, nc: NetworkCapabilities, lp: LinkProperties, @@ -94,12 +99,10 @@ class NetworkAgentTest { } override fun onBandwidthUpdateRequested() { - super.onBandwidthUpdateRequested() history.add(OnBandwidthUpdateRequested) } override fun onNetworkUnwanted() { - super.onNetworkUnwanted() history.add(OnNetworkUnwanted) } @@ -109,6 +112,11 @@ class NetworkAgentTest { } } + private fun requestNetwork(request: NetworkRequest, callback: TestableNetworkCallback) { + mCM.requestNetwork(request, callback) + callbacksToCleanUp.add(callback) + } + private fun createNetworkAgent(): TestableNetworkAgent { val nc = NetworkCapabilities().apply { addTransportType(NetworkCapabilities.TRANSPORT_TEST) @@ -120,7 +128,9 @@ class NetworkAgentTest { } val lp = LinkProperties() val config = NetworkAgentConfig.Builder().build() - return TestableNetworkAgent(mHandlerThread.looper, nc, lp, config) + return TestableNetworkAgent(mHandlerThread.looper, nc, lp, config).also { + agentsToCleanUp.add(it) + } } private fun createConnectedNetworkAgent(): Pair { @@ -129,8 +139,9 @@ class NetworkAgentTest { .addTransportType(NetworkCapabilities.TRANSPORT_TEST) .build() val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS) - mCM.requestNetwork(request, callback) - val agent = createNetworkAgent().also { it.register() } + requestNetwork(request, callback) + val agent = createNetworkAgent() + agent.register() agent.markConnected() return agent to callback }