diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt index 32f2bfafe2..14f52e1e7c 100644 --- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt +++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt @@ -24,7 +24,9 @@ import android.net.LinkProperties import android.net.Network import android.net.NetworkAgent import android.net.NetworkAgent.CMD_ADD_KEEPALIVE_PACKET_FILTER +import android.net.NetworkAgent.CMD_PREVENT_AUTOMATIC_RECONNECT import android.net.NetworkAgent.CMD_REMOVE_KEEPALIVE_PACKET_FILTER +import android.net.NetworkAgent.CMD_SAVE_ACCEPT_UNVALIDATED import android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE import android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE import android.net.NetworkAgentConfig @@ -39,9 +41,11 @@ import android.os.Looper import android.os.Message import android.os.Messenger import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAddKeepalivePacketFilter +import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAutomaticReconnectDisabled import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnBandwidthUpdateRequested import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkUnwanted import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnRemoveKeepalivePacketFilter +import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnSaveAcceptUnvalidated import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStartSocketKeepalive import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStopSocketKeepalive import androidx.test.InstrumentationRegistry @@ -60,6 +64,7 @@ import org.junit.runner.RunWith import java.net.InetAddress import java.time.Duration import kotlin.test.assertEquals +import kotlin.test.assertFalse import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertNull @@ -123,27 +128,38 @@ class NetworkAgentTest { * only keeps track of one async channel. */ private class FakeConnectivityService(looper: Looper) { + private val CMD_EXPECT_DISCONNECT = 1 + private var disconnectExpected = false private val msgHistory = ArrayTrackRecord().newReadHead() private val asyncChannel = AsyncChannel() private val handler = object : Handler(looper) { override fun handleMessage(msg: Message) { msgHistory.add(Message.obtain(msg)) // make a copy as the original will be recycled when (msg.what) { + CMD_EXPECT_DISCONNECT -> disconnectExpected = true AsyncChannel.CMD_CHANNEL_HALF_CONNECTED -> asyncChannel.sendMessage(AsyncChannel.CMD_CHANNEL_FULL_CONNECTION) - AsyncChannel.CMD_CHANNEL_DISCONNECT, AsyncChannel.CMD_CHANNEL_DISCONNECTED -> - fail("Agent unexpectedly disconnected") + AsyncChannel.CMD_CHANNEL_DISCONNECTED -> + if (!disconnectExpected) { + fail("Agent unexpectedly disconnected") + } else { + disconnectExpected = false + } } } } fun connect(agentMsngr: Messenger) = asyncChannel.connect(context, handler, agentMsngr) + fun disconnect() = asyncChannel.disconnect() + fun sendMessage(what: Int, arg1: Int = 0, arg2: Int = 0, obj: Any? = null) = asyncChannel.sendMessage(Message(what, arg1, arg2, obj)) fun expectMessage(what: Int) = assertNotNull(msgHistory.poll(DEFAULT_TIMEOUT_MS) { it.what == what }) + + fun willExpectDisconnectOnce() = handler.sendEmptyMessage(CMD_EXPECT_DISCONNECT) } private open class TestableNetworkAgent( @@ -169,6 +185,8 @@ class NetworkAgentTest { val packet: KeepalivePacketData ) : CallbackEntry() data class OnStopSocketKeepalive(val slot: Int) : CallbackEntry() + data class OnSaveAcceptUnvalidated(val accept: Boolean) : CallbackEntry() + object OnAutomaticReconnectDisabled : CallbackEntry() } override fun onBandwidthUpdateRequested() { @@ -199,6 +217,14 @@ class NetworkAgentTest { history.add(OnStopSocketKeepalive(slot)) } + override fun onSaveAcceptUnvalidated(accept: Boolean) { + history.add(OnSaveAcceptUnvalidated(accept)) + } + + override fun onAutomaticReconnectDisabled() { + history.add(OnAutomaticReconnectDisabled) + } + inline fun expectCallback(): T { val foundCallback = history.poll(DEFAULT_TIMEOUT_MS) assertTrue(foundCallback is T, "Expected ${T::class} but found $foundCallback") @@ -315,4 +341,40 @@ class NetworkAgentTest { assertEquals(it.slot, slot) } } + + @Test + fun testSetAcceptUnvalidated() { + createNetworkAgentWithFakeCS().let { agent -> + mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 1) + agent.expectCallback().let { + assertTrue(it.accept) + } + agent.assertNoCallback() + } + createNetworkAgentWithFakeCS().let { agent -> + mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 0) + mFakeConnectivityService.sendMessage(CMD_PREVENT_AUTOMATIC_RECONNECT) + agent.expectCallback().let { + assertFalse(it.accept) + } + agent.expectCallback() + agent.assertNoCallback() + // When automatic reconnect is turned off, the network is torn down and + // ConnectivityService sends a disconnect. This in turn causes the agent + // to send a DISCONNECTED message to CS. + mFakeConnectivityService.willExpectDisconnectOnce() + mFakeConnectivityService.disconnect() + mFakeConnectivityService.expectMessage(AsyncChannel.CMD_CHANNEL_DISCONNECTED) + agent.expectCallback() + } + createNetworkAgentWithFakeCS().let { agent -> + mFakeConnectivityService.sendMessage(CMD_PREVENT_AUTOMATIC_RECONNECT) + agent.expectCallback() + agent.assertNoCallback() + mFakeConnectivityService.willExpectDisconnectOnce() + mFakeConnectivityService.disconnect() + mFakeConnectivityService.expectMessage(AsyncChannel.CMD_CHANNEL_DISCONNECTED) + agent.expectCallback() + } + } }