diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt index 45a84f8985..85d0a2e7eb 100644 --- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt +++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt @@ -82,6 +82,7 @@ import org.mockito.Mockito.mock import org.mockito.Mockito.verify import java.net.InetAddress import java.time.Duration +import java.util.Arrays import java.util.UUID import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -312,9 +313,11 @@ class NetworkAgentTest { private fun createNetworkAgent( context: Context = realContext, - name: String? = null + name: String? = null, + nc: NetworkCapabilities = NetworkCapabilities(), + lp: LinkProperties = LinkProperties() ): TestableNetworkAgent { - val nc = NetworkCapabilities().apply { + nc.apply { addTransportType(NetworkCapabilities.TRANSPORT_TEST) removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED) removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) @@ -325,7 +328,7 @@ class NetworkAgentTest { setNetworkSpecifier(StringNetworkSpecifier(name)) } } - val lp = LinkProperties().apply { + lp.apply { addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 0)) } val config = NetworkAgentConfig.Builder().build() @@ -541,8 +544,63 @@ class NetworkAgentTest { // tearDown() will unregister the requests and agents } + private fun hasAllTransports(nc: NetworkCapabilities?, transports: IntArray) = + nc != null && transports.all { nc.hasTransport(it) } + @Test - @IgnoreUpTo(android.os.Build.VERSION_CODES.R) + @IgnoreUpTo(Build.VERSION_CODES.R) + fun testSetUnderlyingNetworks() { + val request = NetworkRequest.Builder() + .addTransportType(NetworkCapabilities.TRANSPORT_TEST) + .addTransportType(NetworkCapabilities.TRANSPORT_VPN) + .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) + .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED) // TODO: add to VPN! + .build() + val callback = TestableNetworkCallback() + mCM.registerNetworkCallback(request, callback) + + val nc = NetworkCapabilities().apply { + addTransportType(NetworkCapabilities.TRANSPORT_TEST) + addTransportType(NetworkCapabilities.TRANSPORT_VPN) + removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) + } + val defaultNetwork = mCM.activeNetwork + assertNotNull(defaultNetwork) + val defaultNetworkTransports = mCM.getNetworkCapabilities(defaultNetwork).transportTypes + + val agent = createNetworkAgent(nc = nc) + agent.register() + agent.markConnected() + callback.expectAvailableThenValidatedCallbacks(agent.network!!) + + var vpnNc = mCM.getNetworkCapabilities(agent.network) + assertNotNull(vpnNc) + assertTrue(NetworkCapabilities.TRANSPORT_VPN in vpnNc.transportTypes) + assertTrue(hasAllTransports(vpnNc, defaultNetworkTransports), + "VPN transports ${Arrays.toString(vpnNc.transportTypes)}" + + " lacking transports from ${Arrays.toString(defaultNetworkTransports)}") + + agent.setUnderlyingNetworks(listOf()) + callback.expectCapabilitiesThat(agent.network!!) { + it.transportTypes.size == 1 && it.hasTransport(NetworkCapabilities.TRANSPORT_VPN) + } + + val expectedTransports = (defaultNetworkTransports.toSet() + + NetworkCapabilities.TRANSPORT_VPN).toIntArray() + agent.setUnderlyingNetworks(null) + callback.expectCapabilitiesThat(agent.network!!) { + it.transportTypes.size == expectedTransports.size && + hasAllTransports(it, expectedTransports) + } + + agent.unregister() + callback.expectCallback(agent.network) + + mCM.unregisterNetworkCallback(callback) + } + + @Test + @IgnoreUpTo(Build.VERSION_CODES.R) fun testAgentStartsInConnecting() { val mockContext = mock(Context::class.java) val mockCm = mock(ConnectivityManager::class.java)