diff --git a/tests/integration/src/com/android/server/net/integrationtests/TestNetworkStackService.kt b/tests/integration/src/com/android/server/net/integrationtests/TestNetworkStackService.kt index eff66584d6..c7cf040bc7 100644 --- a/tests/integration/src/com/android/server/net/integrationtests/TestNetworkStackService.kt +++ b/tests/integration/src/com/android/server/net/integrationtests/TestNetworkStackService.kt @@ -36,7 +36,6 @@ import org.mockito.Mockito.spy import java.io.ByteArrayInputStream import java.net.HttpURLConnection import java.net.URL -import java.net.URLConnection import java.nio.charset.StandardCharsets private const val TEST_NETID = 42 @@ -63,6 +62,28 @@ class TestNetworkStackService : Service() { override fun getPrivateDnsBypassNetwork(network: Network?) = privateDnsBypassNetwork } + /** + * Mock [HttpURLConnection] to simulate reply from a server. + */ + private class MockConnection( + url: URL, + private val response: HttpResponse + ) : HttpURLConnection(url) { + private val responseBytes = response.content.toByteArray(StandardCharsets.UTF_8) + override fun getResponseCode() = response.responseCode + override fun getContentLengthLong() = responseBytes.size.toLong() + override fun getHeaderField(field: String): String? { + return when (field) { + "location" -> response.redirectUrl + else -> null + } + } + override fun getInputStream() = ByteArrayInputStream(responseBytes) + override fun connect() = Unit + override fun disconnect() = Unit + override fun usingProxy() = false + } + private inner class TestNetworkStackConnector(context: Context) : NetworkStackConnector( context, TestPermissionChecker(), NetworkStackService.Dependencies()) { @@ -70,17 +91,8 @@ class TestNetworkStackService : Service() { private val privateDnsBypassNetwork = TestNetwork(TEST_NETID) private inner class TestNetwork(netId: Int) : Network(netId) { - override fun openConnection(url: URL): URLConnection { - val response = InstrumentationConnector.processRequest(url) - val responseBytes = response.content.toByteArray(StandardCharsets.UTF_8) - - val connection = mock(HttpURLConnection::class.java) - doReturn(response.responseCode).`when`(connection).responseCode - doReturn(responseBytes.size.toLong()).`when`(connection).contentLengthLong - doReturn(response.redirectUrl).`when`(connection).getHeaderField("location") - doReturn(ByteArrayInputStream(responseBytes)).`when`(connection).inputStream - return connection - } + override fun openConnection(url: URL) = MockConnection( + url, InstrumentationConnector.processRequest(url)) } override fun makeNetworkMonitor(