Merge "Revert "[ST03] Add test dns server for integration tests""

This commit is contained in:
Treehugger Robot
2022-11-16 05:13:03 +00:00
committed by Gerrit Code Review
3 changed files with 0 additions and 356 deletions

View File

@@ -1,122 +0,0 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.testutils
import android.net.DnsResolver.CLASS_IN
import android.net.DnsResolver.TYPE_AAAA
import android.net.Network
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.DnsPacket.DnsRecord
import libcore.net.InetAddressUtils
import org.junit.After
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.net.InetSocketAddress
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertTrue
val TEST_V6_ADDR = InetAddressUtils.parseNumericAddress("2001:db8::3")
const val TEST_DOMAIN = "hello.example.com"
@RunWith(AndroidJUnit4::class)
@SmallTest
class TestDnsServerTest {
private val network = Mockito.mock(Network::class.java)
private val localAddr = InetSocketAddress(InetAddress.getLocalHost(), 0 /* port */)
private val testServer: TestDnsServer = TestDnsServer(network, localAddr)
@After
fun tearDown() {
if (testServer.isAlive) testServer.stop()
}
@Test
fun testStartStop() {
repeat(100) {
val server = TestDnsServer(network, localAddr)
server.start()
assertTrue(server.isAlive)
server.stop()
assertFalse(server.isAlive)
}
// Test illegal start/stop.
assertFailsWith<IllegalStateException> { testServer.stop() }
testServer.start()
assertTrue(testServer.isAlive)
assertFailsWith<IllegalStateException> { testServer.start() }
testServer.stop()
assertFalse(testServer.isAlive)
assertFailsWith<IllegalStateException> { testServer.stop() }
// TestDnsServer rejects start after stop.
assertFailsWith<IllegalStateException> { testServer.start() }
}
@Test
fun testHandleDnsQuery() {
testServer.setAnswer(TEST_DOMAIN, listOf(TEST_V6_ADDR))
testServer.start()
// Mock query and send it to the test server.
val queryHeader = DnsPacket.DnsHeader(0xbeef /* id */,
0x0 /* flag */, 1 /* qcount */, 0 /* ancount */)
val qlist = listOf(DnsRecord.makeQuestion(TEST_DOMAIN, TYPE_AAAA, CLASS_IN))
val queryPacket = TestDnsServer.DnsQueryPacket(queryHeader, qlist, emptyList())
val response = resolve(queryPacket, testServer.port)
// Verify expected answer packet. Set QR bit of flag to 1 for response packet
// according to RFC 1035 section 4.1.1.
val answerHeader = DnsPacket.DnsHeader(0xbeef,
1 shl 15 /* flag */, 1 /* qcount */, 1 /* ancount */)
val alist = listOf(DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION, TEST_DOMAIN,
CLASS_IN, DEFAULT_TTL_S, TEST_V6_ADDR))
val expectedAnswerPacket = TestDnsServer.DnsAnswerPacket(answerHeader, qlist, alist)
assertEquals(expectedAnswerPacket, response)
// Clean up the server in tearDown.
}
private fun resolve(queryDnsPacket: DnsPacket, serverPort: Int): TestDnsServer.DnsAnswerPacket {
val bytes = queryDnsPacket.bytes
// Create a new client socket, the socket will be bound to a
// random port other than the server port.
val socket = DatagramSocket(localAddr).also { it.soTimeout = 100 }
val queryPacket = DatagramPacket(bytes, bytes.size, localAddr.address, serverPort)
// Send query and wait for the reply.
socket.send(queryPacket)
val buffer = ByteArray(MAX_BUF_SIZE)
val reply = DatagramPacket(buffer, buffer.size)
socket.receive(reply)
return TestDnsServer.DnsAnswerPacket(reply.data)
}
// TODO: Add more tests, which includes:
// * Empty question RR packet (or more unexpected states)
// * No answer found (setAnswer empty list at L.78)
// * Test one or multi A record(s)
// * Test multi AAAA records
// * Test CNAME records
}

View File

@@ -1,65 +0,0 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.testutils
import android.net.DnsResolver.CLASS_IN
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.DnsPacket.ANSECTION
import java.net.InetAddress
import java.util.concurrent.ConcurrentHashMap
const val DEFAULT_TTL_S = 5L
/**
* Helper class to store the mapping of DNS queries.
*
* DnsAnswerProvider is built atop a ConcurrentHashMap and as such it provides the same
* guarantees as ConcurrentHashMap between writing and reading elements. Specifically :
* - Setting an answer happens-before reading the same answer.
* - Callers can read and write concurrently from DnsAnswerProvider and expect no
* ConcurrentModificationException.
* Freshness of the answers depends on ordering of the threads ; if callers need a
* freshness guarantee, they need to provide the happens-before relationship from a
* write that they want to observe to the read that they need to be observed.
*/
class DnsAnswerProvider {
private val mDnsKeyToRecords = ConcurrentHashMap<String, List<DnsPacket.DnsRecord>>()
/**
* Get answer for the specified hostname.
*
* @param query the target hostname.
* @param type type of record, could be A or AAAA.
*
* @return list of [DnsPacket.DnsRecord] associated to the query. Empty if no record matches.
*/
fun getAnswer(query: String, type: Int) = mDnsKeyToRecords[query]
.orEmpty().filter { it.nsType == type }
/** Set answer for the specified {@code query}.
*
* @param query the target hostname
* @param addresses [List<InetAddress>] which could be used to generate multiple A or AAAA
* RRs with the corresponding addresses.
*/
fun setAnswer(query: String, hosts: List<InetAddress>) = mDnsKeyToRecords.put(query, hosts.map {
DnsPacket.DnsRecord.makeAOrAAAARecord(ANSECTION, query, CLASS_IN, DEFAULT_TTL_S, it)
})
fun clearAnswer(query: String) = mDnsKeyToRecords.remove(query)
fun clearAll() = mDnsKeyToRecords.clear()
}

View File

@@ -1,169 +0,0 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.testutils
import android.net.Network
import android.util.Log
import com.android.internal.annotations.GuardedBy
import com.android.internal.annotations.VisibleForTesting
import com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE
import com.android.net.module.util.DnsPacket
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.net.SocketException
import java.util.ArrayList
private const val TAG = "TestDnsServer"
private const val VDBG = true
@VisibleForTesting(visibility = PRIVATE)
const val MAX_BUF_SIZE = 8192
/**
* A simple implementation of Dns Server that can be bound on specific address and Network.
*
* The caller should use start() to make the server start a new thread to receive DNS queries
* on the bound address, [isAlive] to check status, and stop() for stopping.
* The server allows user to manipulate the records to be answered through
* [setAnswer] at runtime.
*
* This server runs on its own thread. Please make sure writing the query to the socket
* happens-after using [setAnswer] to guarantee the correct answer is returned. If possible,
* use [setAnswer] before calling [start] for simplicity.
*/
class TestDnsServer(network: Network, addr: InetSocketAddress) {
enum class Status {
NOT_STARTED, STARTED, STOPPED
}
@GuardedBy("thread")
private var status: Status = Status.NOT_STARTED
private val thread = ReceivingThread()
private val socket = DatagramSocket(addr).also { network.bindSocket(it) }
private val ansProvider = DnsAnswerProvider()
// The buffer to store the received packet. They are being reused for
// efficiency and it's fine because they are only ever accessed
// on the server thread in a sequential manner.
private val buffer = ByteArray(MAX_BUF_SIZE)
private val packet = DatagramPacket(buffer, buffer.size)
fun setAnswer(hostname: String, answer: List<InetAddress>) =
ansProvider.setAnswer(hostname, answer)
private fun processPacket() {
// Blocking read and try construct a DnsQueryPacket object.
socket.receive(packet)
val q = DnsQueryPacket(packet.data)
handleDnsQuery(q, packet.socketAddress)
}
// TODO: Add support to reply some error with a DNS reply packet with failure RCODE.
private fun handleDnsQuery(q: DnsQueryPacket, src: SocketAddress) {
val queryRecords = q.queryRecords
if (queryRecords.size != 1) {
throw IllegalArgumentException(
"Expected one dns query record but got ${queryRecords.size}"
)
}
val answerRecords = queryRecords[0].let { ansProvider.getAnswer(it.dName, it.nsType) }
if (VDBG) {
Log.v(TAG, "handleDnsPacket: " +
queryRecords.map { "${it.dName},${it.nsType}" }.joinToString() +
" ansCount=${answerRecords.size} socketAddress=$src")
}
val bytes = q.getAnswerPacket(answerRecords).bytes
val reply = DatagramPacket(bytes, bytes.size, src)
socket.send(reply)
}
fun start() {
synchronized(thread) {
if (status != Status.NOT_STARTED) {
throw IllegalStateException("unexpected status: $status")
}
thread.start()
status = Status.STARTED
}
}
fun stop() {
synchronized(thread) {
if (status != Status.STARTED) {
throw IllegalStateException("unexpected status: $status")
}
socket.close()
thread.interrupt()
thread.join()
status = Status.STOPPED
}
}
val isAlive get() = thread.isAlive
val port get() = socket.localPort
inner class ReceivingThread : Thread() {
override fun run() {
Log.i(TAG, "starting addr={${socket.localSocketAddress}}")
while (!interrupted() && !socket.isClosed) {
try {
processPacket()
} catch (e: InterruptedException) {
// The caller terminated the server, exit.
break
} catch (e: SocketException) {
// The caller terminated the server, exit.
break
}
}
Log.i(TAG, "exiting socket={$socket}")
}
}
@VisibleForTesting(visibility = PRIVATE)
class DnsQueryPacket : DnsPacket {
constructor(data: ByteArray) : super(data)
constructor(header: DnsHeader, qd: List<DnsRecord>, an: List<DnsRecord>) :
super(header, qd, an)
init {
if (mHeader.isResponse) {
throw ParseException("Not a query packet")
}
}
val queryRecords: List<DnsRecord>
get() = mRecords[QDSECTION]
fun getAnswerPacket(ar: List<DnsRecord>): DnsAnswerPacket {
// Set QR bit of flag to 1 for response packet according to RFC 1035 section 4.1.1.
val flags = 1 shl 15
val qr = ArrayList(mRecords[QDSECTION])
// Copy the query packet header id to the answer packet as RFC 1035 section 4.1.1.
val header = DnsHeader(mHeader.id, flags, qr.size, ar.size)
return DnsAnswerPacket(header, qr, ar)
}
}
class DnsAnswerPacket : DnsPacket {
constructor(header: DnsHeader, qr: List<DnsRecord>, ar: List<DnsRecord>) :
super(header, qr, ar)
@VisibleForTesting(visibility = PRIVATE)
constructor(bytes: ByteArray) : super(bytes)
}
}