/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.testutils

import android.net.Network
import android.util.Log
import com.android.internal.annotations.GuardedBy
import com.android.internal.annotations.VisibleForTesting
import com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE
import com.android.net.module.util.DnsPacket
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.net.SocketException
import java.util.ArrayList

private const val TAG = "TestDnsServer"
private const val VDBG = true
@VisibleForTesting(visibility = PRIVATE)
const val MAX_BUF_SIZE = 8192

/**
 * A simple implementation of Dns Server that can be bound on specific address and Network.
 *
 * The caller should use [start] to make the server start a new thread to receive DNS queries
 * on the bound address, [isAlive] to check status, and [stop] for stopping.
 * The server allows user to manipulate the records to be answered through
 * [setAnswer] at runtime.
 *
 * This server runs on its own thread. Please make sure writing the query to the socket
 * happens-after using [setAnswer] to guarantee the correct answer is returned. If possible,
 * use [setAnswer] before calling [start] for simplicity.
 *
 * @param network the network the server socket is bound to.
 * @param addr the local address the server socket listens on.
 */
class TestDnsServer(network: Network, addr: InetSocketAddress) {
    /** Lifecycle of the server; it only ever moves NOT_STARTED -> STARTED -> STOPPED. */
    enum class Status {
        NOT_STARTED, STARTED, STOPPED
    }

    @GuardedBy("thread")
    private var status: Status = Status.NOT_STARTED
    private val thread = ReceivingThread()
    private val socket = DatagramSocket(addr).also { network.bindSocket(it) }
    private val ansProvider = DnsAnswerProvider()

    // The buffer to store the received packet. They are being reused for
    // efficiency and it's fine because they are only ever accessed
    // on the server thread in a sequential manner.
    private val buffer = ByteArray(MAX_BUF_SIZE)
    private val packet = DatagramPacket(buffer, buffer.size)

    /** Registers [answer] as the addresses to be returned for queries about [hostname]. */
    fun setAnswer(hostname: String, answer: List<InetAddress>) =
        ansProvider.setAnswer(hostname, answer)

    /** Blocks until one datagram arrives, parses it as a DNS query and replies to its sender. */
    private fun processPacket() {
        // Blocking read and try construct a DnsQueryPacket object.
        socket.receive(packet)
        // packet.data is the whole reused buffer. Only hand the bytes of this datagram to the
        // parser so that leftovers from an earlier, longer datagram are never interpreted as
        // part of the current query.
        val received = packet.data.copyOfRange(packet.offset, packet.offset + packet.length)
        val q = DnsQueryPacket(received)
        handleDnsQuery(q, packet.socketAddress)
    }

    // TODO: Add support to reply some error with a DNS reply packet with failure RCODE.
    /**
     * Looks up the answer for the single question in [q] and sends the reply to [src].
     *
     * @throws IllegalArgumentException if [q] does not carry exactly one query record.
     */
    private fun handleDnsQuery(q: DnsQueryPacket, src: SocketAddress) {
        val queryRecords = q.queryRecords
        require(queryRecords.size == 1) {
            "Expected one dns query record but got ${queryRecords.size}"
        }
        val answerRecords = queryRecords[0].let { ansProvider.getAnswer(it.dName, it.nsType) }

        if (VDBG) {
            val queries = queryRecords.joinToString { "${it.dName},${it.nsType}" }
            Log.v(TAG, "handleDnsPacket: $queries" +
                    " ansCount=${answerRecords.size} socketAddress=$src")
        }

        val bytes = q.getAnswerPacket(answerRecords).bytes
        val reply = DatagramPacket(bytes, bytes.size, src)
        socket.send(reply)
    }

    /**
     * Starts the receiving thread.
     *
     * @throws IllegalStateException if the server is not in the NOT_STARTED state.
     */
    fun start() {
        synchronized(thread) {
            check(status == Status.NOT_STARTED) { "unexpected status: $status" }
            thread.start()
            status = Status.STARTED
        }
    }

    /**
     * Stops the server: interrupts the receiving thread, closes the socket and waits for the
     * thread to finish.
     *
     * @throws IllegalStateException if the server is not in the STARTED state.
     */
    fun stop() {
        synchronized(thread) {
            check(status == Status.STARTED) { "unexpected status: $status" }
            // The thread needs to be interrupted before closing the socket to prevent a data
            // race where the thread tries to read from the socket while it's being closed.
            // DatagramSocket is not thread-safe and running both concurrently can end up in
            // getPort() returning -1 after it's been checked not to, resulting in a crash by
            // IllegalArgumentException inside the DatagramSocket implementation.
            thread.interrupt()
            socket.close()
            thread.join()
            status = Status.STOPPED
        }
    }

    /** Whether the receiving thread is currently alive. */
    val isAlive get() = thread.isAlive

    /** The local port of the server socket. */
    val port get() = socket.localPort

    /**
     * Thread that keeps serving queries until it is interrupted or the socket is closed.
     *
     * Only [InterruptedException] and [SocketException] are treated as a request to exit;
     * any other exception raised while processing a packet is not caught here.
     */
    inner class ReceivingThread : Thread() {
        override fun run() {
            while (!interrupted() && !socket.isClosed) {
                try {
                    processPacket()
                } catch (e: InterruptedException) {
                    // The caller terminated the server, exit.
                    break
                } catch (e: SocketException) {
                    // The caller terminated the server, exit.
                    break
                }
            }
            Log.i(TAG, "exiting socket=$socket")
        }
    }

    /** A [DnsPacket] that must be a query: construction fails if the QR bit marks a response. */
    @VisibleForTesting(visibility = PRIVATE)
    class DnsQueryPacket : DnsPacket {
        constructor(data: ByteArray) : super(data)
        constructor(header: DnsHeader, qd: List<DnsRecord>, an: List<DnsRecord>) :
                super(header, qd, an)

        init {
            if (mHeader.isResponse) {
                throw ParseException("Not a query packet")
            }
        }

        /** The records of the question section of this query. */
        val queryRecords: List<DnsRecord>
            get() = mRecords[QDSECTION]

        /**
         * Builds the response to this query, echoing the question section and the header id
         * and carrying [ar] as the answer section.
         */
        fun getAnswerPacket(ar: List<DnsRecord>): DnsAnswerPacket {
            // Set QR bit of flag to 1 for response packet according to RFC 1035 section 4.1.1.
            val flags = 1 shl 15
            val qr = ArrayList(mRecords[QDSECTION])
            // Copy the query packet header id to the answer packet as RFC 1035 section 4.1.1.
            val header = DnsHeader(mHeader.id, flags, qr.size, ar.size)
            return DnsAnswerPacket(header, qr, ar)
        }
    }

    /** A [DnsPacket] used as the reply to a [DnsQueryPacket]. */
    class DnsAnswerPacket : DnsPacket {
        constructor(header: DnsHeader, qr: List<DnsRecord>, ar: List<DnsRecord>) :
                super(header, qr, ar)
        @VisibleForTesting(visibility = PRIVATE)
        constructor(bytes: ByteArray) : super(bytes)
    }
}