diff --git a/tests/common/java/android/net/NetworkProviderTest.kt b/tests/common/java/android/net/NetworkProviderTest.kt index 7424157bea..97d3c5a802 100644 --- a/tests/common/java/android/net/NetworkProviderTest.kt +++ b/tests/common/java/android/net/NetworkProviderTest.kt @@ -18,6 +18,7 @@ package android.net import android.app.Instrumentation import android.content.Context +import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED import android.net.NetworkCapabilities.TRANSPORT_TEST import android.net.NetworkProviderTest.TestNetworkCallback.CallbackEntry.OnUnavailable import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequestWithdrawn @@ -25,14 +26,18 @@ import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetwo import android.os.Build import android.os.HandlerThread import android.os.Looper +import android.util.Log import androidx.test.InstrumentationRegistry import com.android.net.module.util.ArrayTrackRecord import com.android.testutils.CompatUtil +import com.android.testutils.DevSdkIgnoreRule +import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRunner import com.android.testutils.isDevSdkInRange import org.junit.After import org.junit.Before +import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.mockito.Mockito.doReturn @@ -41,6 +46,7 @@ import org.mockito.Mockito.verifyNoMoreInteractions import java.util.UUID import kotlin.test.assertEquals import kotlin.test.assertNotEquals +import kotlin.test.fail private const val DEFAULT_TIMEOUT_MS = 5000L private val instrumentation: Instrumentation @@ -51,6 +57,8 @@ private val PROVIDER_NAME = "NetworkProviderTest" @RunWith(DevSdkIgnoreRunner::class) @IgnoreUpTo(Build.VERSION_CODES.Q) class NetworkProviderTest { + @Rule @JvmField + val mIgnoreRule = DevSdkIgnoreRule() private val mCm = context.getSystemService(ConnectivityManager::class.java) private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread") @@ -68,6 +76,7 @@ class NetworkProviderTest { private class TestNetworkProvider(context: Context, looper: Looper) : NetworkProvider(context, looper, PROVIDER_NAME) { + private val TAG = this::class.simpleName private val seenEvents = ArrayTrackRecord().newReadHead() sealed class CallbackEntry { @@ -80,22 +89,30 @@ class NetworkProviderTest { } override fun onNetworkRequested(request: NetworkRequest, score: Int, id: Int) { + Log.d(TAG, "onNetworkRequested $request, $score, $id") seenEvents.add(OnNetworkRequested(request, score, id)) } override fun onNetworkRequestWithdrawn(request: NetworkRequest) { + Log.d(TAG, "onNetworkRequestWithdrawn $request") seenEvents.add(OnNetworkRequestWithdrawn(request)) } - inline fun expectCallback( + inline fun eventuallyExpectCallbackThat( crossinline predicate: (T) -> Boolean ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) } + ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms") } private fun createNetworkProvider(ctx: Context = context): TestNetworkProvider { return TestNetworkProvider(ctx, mHandlerThread.looper) } + // In S+ framework, do not run this test, since the provider will no longer receive + // onNetworkRequested for every request. Instead, provider needs to + // call {@code registerNetworkOffer} with the description of networks they + // might have ability to setup, and expects {@link NetworkOfferCallback#onNetworkNeeded}. + @IgnoreAfter(Build.VERSION_CODES.R) @Test fun testOnNetworkRequested() { val provider = createNetworkProvider() @@ -105,13 +122,15 @@ class NetworkProviderTest { val specifier = CompatUtil.makeTestNetworkSpecifier( UUID.randomUUID().toString()) + // Test network is not allowed to be trusted. val nr: NetworkRequest = NetworkRequest.Builder() .addTransportType(TRANSPORT_TEST) + .removeCapability(NET_CAPABILITY_TRUSTED) .setNetworkSpecifier(specifier) .build() val cb = ConnectivityManager.NetworkCallback() mCm.requestNetwork(nr, cb) - provider.expectCallback() { callback -> + provider.eventuallyExpectCallbackThat() { callback -> callback.request.getNetworkSpecifier() == specifier && callback.request.hasTransport(TRANSPORT_TEST) } @@ -131,22 +150,24 @@ class NetworkProviderTest { val config = NetworkAgentConfig.Builder().build() val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc, lp, initialScore, config, provider) {} + agent.register() + agent.markConnected() - provider.expectCallback() { callback -> + provider.eventuallyExpectCallbackThat() { callback -> callback.request.getNetworkSpecifier() == specifier && callback.score == initialScore && callback.id == agent.providerId } agent.sendNetworkScore(updatedScore) - provider.expectCallback() { callback -> + provider.eventuallyExpectCallbackThat() { callback -> callback.request.getNetworkSpecifier() == specifier && callback.score == updatedScore && callback.id == agent.providerId } mCm.unregisterNetworkCallback(cb) - provider.expectCallback() { callback -> + provider.eventuallyExpectCallbackThat() { callback -> callback.request.getNetworkSpecifier() == specifier && callback.request.hasTransport(TRANSPORT_TEST) }