• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.Network
20 import android.util.Log
21 import com.android.internal.annotations.GuardedBy
22 import com.android.internal.annotations.VisibleForTesting
23 import com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE
24 import com.android.net.module.util.DnsPacket
25 import java.net.DatagramPacket
26 import java.net.DatagramSocket
27 import java.net.InetAddress
28 import java.net.InetSocketAddress
29 import java.net.SocketAddress
30 import java.net.SocketException
31 import java.util.ArrayList
32 
// Log tag shared by every message emitted from this file.
private const val TAG: String = "TestDnsServer"

// Verbose-debug switch: when true, handled queries are logged at VERBOSE level.
private const val VDBG: Boolean = true

/** Size in bytes of the reusable buffer that incoming UDP datagrams are received into. */
@VisibleForTesting(visibility = PRIVATE)
const val MAX_BUF_SIZE: Int = 8192
37 
/**
 * A simple implementation of a DNS server that can be bound to a specific address and Network.
 *
 * The caller should use [start] to make the server start a new thread that receives DNS queries
 * on the bound address, [isAlive] to check its status, and [stop] to stop it.
 * The server allows the user to change the records to be answered through [setAnswer] at
 * runtime.
 *
 * This server runs on its own thread. Please make sure writing the query to the socket
 * happens-after using [setAnswer] to guarantee the correct answer is returned. If possible,
 * use [setAnswer] before calling [start] for simplicity.
 *
 * Datagrams that cannot be parsed as a DNS query, and queries that do not carry exactly one
 * question record, are logged and dropped without a reply; they do not stop the server.
 *
 * @param network the network the server socket is bound to (via `Network.bindSocket`).
 * @param addr the local address the server socket is bound to.
 */
class TestDnsServer(network: Network, addr: InetSocketAddress) {
    /** Lifecycle state of the server; it only moves NOT_STARTED -> STARTED -> STOPPED. */
    enum class Status {
        NOT_STARTED, STARTED, STOPPED
    }

    // Dedicated monitor guarding the lifecycle. The Thread object is deliberately not used as
    // the lock: the Thread javadoc advises against using wait/notify on Thread instances, and on
    // OpenJDK Thread.join() waits on the Thread's own monitor, which would let another caller
    // enter start()/stop() while stop() is still joining.
    private val lock = Any()

    @GuardedBy("lock")
    private var status: Status = Status.NOT_STARTED
    private val thread = ReceivingThread()
    private val socket = DatagramSocket(addr).also {
        try {
            network.bindSocket(it)
        } catch (e: Throwable) {
            // The socket is already open; close it so a failed construction does not leak it.
            it.close()
            throw e
        }
    }
    private val ansProvider = DnsAnswerProvider()

    // The buffer to store the received packet. They are being reused for
    // efficiency and it's fine because they are only ever accessed
    // on the server thread in a sequential manner.
    private val buffer = ByteArray(MAX_BUF_SIZE)
    private val packet = DatagramPacket(buffer, buffer.size)

    /**
     * Sets the addresses the server answers with for [hostname].
     *
     * Delegates to the internal DnsAnswerProvider; see the class documentation for the ordering
     * requirement relative to queries that are already in flight.
     */
    fun setAnswer(hostname: String, answer: List<InetAddress>) =
        ansProvider.setAnswer(hostname, answer)

    /**
     * Blocks until one datagram is received, then answers it if it is a valid DNS query.
     *
     * @throws SocketException when the socket has been closed, which is how [stop] ends the loop.
     */
    private fun processPacket() {
        socket.receive(packet)
        // Parse only the bytes actually received: the rest of the reused buffer may still hold
        // stale data from an earlier, longer datagram.
        val data = packet.data.copyOfRange(packet.offset, packet.offset + packet.length)
        val src = packet.socketAddress
        val query = try {
            DnsQueryPacket(data)
        } catch (e: DnsPacket.ParseException) {
            // A malformed or non-query datagram must not terminate the server thread.
            Log.e(TAG, "Dropping unparseable DNS packet from $src", e)
            return
        }
        handleDnsQuery(query, src)
    }

    /**
     * Looks up the answer for the single question in [q] and sends the reply back to [src].
     *
     * Queries that do not contain exactly one question record are logged and dropped.
     */
    // TODO: Add support to reply some error with a DNS reply packet with failure RCODE.
    private fun handleDnsQuery(q: DnsQueryPacket, src: SocketAddress) {
        val queryRecords = q.queryRecords
        if (queryRecords.size != 1) {
            // Until error replies are supported, drop the query instead of throwing: an exception
            // here would escape ReceivingThread.run() and kill the server thread.
            Log.e(TAG, "Expected one dns query record but got ${queryRecords.size}, " +
                    "dropping query from $src")
            return
        }
        val answerRecords = queryRecords[0].let { ansProvider.getAnswer(it.dName, it.nsType) }

        if (VDBG) {
            Log.v(TAG, "handleDnsPacket: " +
                    queryRecords.joinToString { "${it.dName},${it.nsType}" } +
                    " ansCount=${answerRecords.size} socketAddress=$src")
        }

        val bytes = q.getAnswerPacket(answerRecords).bytes
        socket.send(DatagramPacket(bytes, bytes.size, src))
    }

    /**
     * Starts the receiving thread.
     *
     * @throws IllegalStateException if the server has already been started or stopped.
     */
    fun start() {
        synchronized(lock) {
            if (status != Status.NOT_STARTED) {
                throw IllegalStateException("unexpected status: $status")
            }
            thread.start()
            status = Status.STARTED
        }
    }

    /**
     * Stops the server: closes the socket and waits for the receiving thread to exit.
     *
     * @throws IllegalStateException if the server is not currently started.
     */
    fun stop() {
        synchronized(lock) {
            if (status != Status.STARTED) {
                throw IllegalStateException("unexpected status: $status")
            }
            // The thread needs to be interrupted before closing the socket to prevent a data
            // race where the thread tries to read from the socket while it's being closed.
            // DatagramSocket is not thread-safe and running both concurrently can end up in
            // getPort() returning -1 after it's been checked not to, resulting in a crash by
            // IllegalArgumentException inside the DatagramSocket implementation.
            thread.interrupt()
            socket.close()
            thread.join()
            status = Status.STOPPED
        }
    }

    /** Whether the receiving thread is currently running. */
    val isAlive: Boolean get() = thread.isAlive

    /** The local port of the server socket. */
    val port: Int get() = socket.localPort

    /** Server thread: handles datagrams one at a time until interrupted or the socket closes. */
    inner class ReceivingThread : Thread() {
        override fun run() {
            while (!interrupted() && !socket.isClosed) {
                try {
                    processPacket()
                } catch (e: InterruptedException) {
                    // The caller terminated the server, exit.
                    break
                } catch (e: SocketException) {
                    // The caller terminated the server, exit.
                    break
                }
            }
            Log.i(TAG, "exiting socket=$socket")
        }
    }

    /** A DNS packet that is rejected at construction if its header marks it as a response. */
    @VisibleForTesting(visibility = PRIVATE)
    class DnsQueryPacket : DnsPacket {
        constructor(data: ByteArray) : super(data)
        constructor(header: DnsHeader, qd: List<DnsRecord>, an: List<DnsRecord>) :
                super(header, qd, an)

        init {
            if (mHeader.isResponse) {
                throw ParseException("Not a query packet")
            }
        }

        /** The records of the question section. */
        val queryRecords: List<DnsRecord>
            get() = mRecords[QDSECTION]

        /**
         * Builds the response to this query, echoing its question section and carrying [ar] in
         * the answer section.
         *
         * Only the QR bit is set in the response flags; every other flag (including RCODE) is 0.
         * NOTE(review): RFC 1035 section 4.1.1 says RD is copied from the query into the
         * response, which is not done here.
         */
        fun getAnswerPacket(ar: List<DnsRecord>): DnsAnswerPacket {
            // Set QR bit of flag to 1 for response packet according to RFC 1035 section 4.1.1.
            val flags = 1 shl 15
            val qr = ArrayList(mRecords[QDSECTION])
            // Copy the query packet header id to the answer packet as RFC 1035 section 4.1.1.
            val header = DnsHeader(mHeader.id, flags, qr.size, ar.size)
            return DnsAnswerPacket(header, qr, ar)
        }
    }

    /** A DNS response packet built by [DnsQueryPacket.getAnswerPacket], or parsed from bytes. */
    class DnsAnswerPacket : DnsPacket {
        constructor(header: DnsHeader, qr: List<DnsRecord>, ar: List<DnsRecord>) :
                super(header, qr, ar)
        @VisibleForTesting(visibility = PRIVATE)
        constructor(bytes: ByteArray) : super(bytes)
    }
}
174