1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <sys/epoll.h>
26 #include <sys/eventfd.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30
31 #include <chrono>
32 #include <iostream>
33 #include <set>
34 #include <vector>
35
36 #define LOG_TAG "DNSResponder"
37 #include <android-base/logging.h>
38 #include <android-base/strings.h>
39 #include <netdutils/BackoffSequence.h>
40 #include <netdutils/InternetAddresses.h>
41 #include <netdutils/Slice.h>
42 #include <netdutils/SocketOption.h>
43
44 using android::netdutils::BackoffSequence;
45 using android::netdutils::enableSockopt;
46 using android::netdutils::ScopedAddrinfo;
47 using android::netdutils::Slice;
48 using std::chrono::milliseconds;
49
50 namespace test {
51
// Returns a human-readable description of the current errno value.
// This relies on the GNU flavour of strerror_r(), which returns a |char*| (possibly pointing at
// an internal static string rather than |msg|) instead of the XSI |int| status code.
std::string errno2str() {
    char msg[512] = {};
    const char* text = strerror_r(errno, msg, sizeof(msg));
    return std::string(text);
}
58
// Hex-encodes the first |len| bytes of |buffer| using uppercase digits, two characters per
// byte and no separators. Returns an empty string when |len| is 0.
std::string str2hex(const char* buffer, size_t len) {
    static constexpr char kDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        // Treat each byte as unsigned so values >= 0x80 encode correctly.
        const auto byte = static_cast<uint8_t>(*p);
        out.push_back(kDigits[byte >> 4]);
        out.push_back(kDigits[byte & 0x0F]);
    }
    return out;
}
69
// Formats the socket address |sa| (of |sa_len| bytes) as a numeric host string, such as
// "192.0.2.1" or "2001:db8::1". Returns an empty string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    if (getnameinfo(sa, sa_len, host, sizeof(host), nullptr, 0, NI_NUMERICHOST) != 0) {
        return std::string();
    }
    return std::string(host);
}
76
77 // Because The address might still being set up (b/186181084), This is a wrapper function
78 // that retries bind() if errno is EADDRNOTAVAIL
bindSocket(int socket,const sockaddr * address,socklen_t address_len)79 int bindSocket(int socket, const sockaddr* address, socklen_t address_len) {
80 // Set the wrapper to try bind() at most 6 times with backoff time
81 // (100 ms, 200 ms, ..., 1600 ms).
82 auto backoff = BackoffSequence<milliseconds>::Builder()
83 .withInitialRetransmissionTime(milliseconds(100))
84 .withMaximumRetransmissionCount(5)
85 .build();
86
87 while (true) {
88 int ret = bind(socket, address, address_len);
89 if (ret == 0 || errno != EADDRNOTAVAIL) {
90 return ret;
91 }
92
93 if (!backoff.hasNextTimeout()) break;
94
95 LOG(WARNING) << "Retry to bind " << addr2str(address, address_len);
96 std::this_thread::sleep_for(backoff.getNextTimeout());
97 }
98
99 // Set errno before return since it might have been changed somewhere.
100 errno = EADDRNOTAVAIL;
101 return -1;
102 }
103
104 /* DNS struct helpers */
105
dnstype2str(unsigned dnstype)106 const char* dnstype2str(unsigned dnstype) {
107 static std::unordered_map<unsigned, const char*> kTypeStrs = {
108 {ns_type::ns_t_a, "A"},
109 {ns_type::ns_t_ns, "NS"},
110 {ns_type::ns_t_md, "MD"},
111 {ns_type::ns_t_mf, "MF"},
112 {ns_type::ns_t_cname, "CNAME"},
113 {ns_type::ns_t_soa, "SOA"},
114 {ns_type::ns_t_mb, "MB"},
115 {ns_type::ns_t_mb, "MG"},
116 {ns_type::ns_t_mr, "MR"},
117 {ns_type::ns_t_null, "NULL"},
118 {ns_type::ns_t_wks, "WKS"},
119 {ns_type::ns_t_ptr, "PTR"},
120 {ns_type::ns_t_hinfo, "HINFO"},
121 {ns_type::ns_t_minfo, "MINFO"},
122 {ns_type::ns_t_mx, "MX"},
123 {ns_type::ns_t_txt, "TXT"},
124 {ns_type::ns_t_rp, "RP"},
125 {ns_type::ns_t_afsdb, "AFSDB"},
126 {ns_type::ns_t_x25, "X25"},
127 {ns_type::ns_t_isdn, "ISDN"},
128 {ns_type::ns_t_rt, "RT"},
129 {ns_type::ns_t_nsap, "NSAP"},
130 {ns_type::ns_t_nsap_ptr, "NSAP-PTR"},
131 {ns_type::ns_t_sig, "SIG"},
132 {ns_type::ns_t_key, "KEY"},
133 {ns_type::ns_t_px, "PX"},
134 {ns_type::ns_t_gpos, "GPOS"},
135 {ns_type::ns_t_aaaa, "AAAA"},
136 {ns_type::ns_t_loc, "LOC"},
137 {ns_type::ns_t_nxt, "NXT"},
138 {ns_type::ns_t_eid, "EID"},
139 {ns_type::ns_t_nimloc, "NIMLOC"},
140 {ns_type::ns_t_srv, "SRV"},
141 {ns_type::ns_t_naptr, "NAPTR"},
142 {ns_type::ns_t_kx, "KX"},
143 {ns_type::ns_t_cert, "CERT"},
144 {ns_type::ns_t_a6, "A6"},
145 {ns_type::ns_t_dname, "DNAME"},
146 {ns_type::ns_t_sink, "SINK"},
147 {ns_type::ns_t_opt, "OPT"},
148 {ns_type::ns_t_apl, "APL"},
149 {ns_type::ns_t_tkey, "TKEY"},
150 {ns_type::ns_t_tsig, "TSIG"},
151 {ns_type::ns_t_ixfr, "IXFR"},
152 {ns_type::ns_t_axfr, "AXFR"},
153 {ns_type::ns_t_mailb, "MAILB"},
154 {ns_type::ns_t_maila, "MAILA"},
155 {ns_type::ns_t_any, "ANY"},
156 {ns_type::ns_t_zxfr, "ZXFR"},
157 };
158 auto it = kTypeStrs.find(dnstype);
159 static const char* kUnknownStr{"UNKNOWN"};
160 if (it == kTypeStrs.end()) return kUnknownStr;
161 return it->second;
162 }
163
// Returns a descriptive name for the DNS class |dnsclass| (e.g. "Internet" for class IN),
// or "UNKNOWN" for classes not listed. Never returns null.
const char* dnsclass2str(unsigned dnsclass) {
    switch (dnsclass) {
        case ns_class::ns_c_in:
            return "Internet";
        case 2:
            return "CSNet";
        case ns_class::ns_c_chaos:
            return "ChaosNet";
        case ns_class::ns_c_hs:
            return "Hesiod";
        case ns_class::ns_c_none:
            return "none";
        case ns_class::ns_c_any:
            return "any";
        default:
            return "UNKNOWN";
    }
}
174
// Maps a transport protocol number to "TCP" or "UDP"; any other value yields "UNKNOWN".
const char* dnsproto2str(int protocol) {
    if (protocol == IPPROTO_TCP) return "TCP";
    if (protocol == IPPROTO_UDP) return "UDP";
    return "UNKNOWN";
}
185
read(const char * buffer,const char * buffer_end)186 const char* DNSName::read(const char* buffer, const char* buffer_end) {
187 const char* cur = buffer;
188 bool last = false;
189 do {
190 cur = parseField(cur, buffer_end, &last);
191 if (cur == nullptr) {
192 LOG(ERROR) << "parsing failed at line " << __LINE__;
193 return nullptr;
194 }
195 } while (!last);
196 return cur;
197 }
198
write(char * buffer,const char * buffer_end) const199 char* DNSName::write(char* buffer, const char* buffer_end) const {
200 char* buffer_cur = buffer;
201 for (size_t pos = 0; pos < name.size();) {
202 size_t dot_pos = name.find('.', pos);
203 if (dot_pos == std::string::npos) {
204 // Soundness check, should never happen unless parseField is broken.
205 LOG(ERROR) << "logic error: all names are expected to end with a '.'";
206 return nullptr;
207 }
208 const size_t len = dot_pos - pos;
209 if (len >= 256) {
210 LOG(ERROR) << "name component '" << name.substr(pos, dot_pos - pos) << "' is " << len
211 << " long, but max is 255";
212 return nullptr;
213 }
214 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
215 LOG(ERROR) << "buffer overflow at line " << __LINE__;
216 return nullptr;
217 }
218 *buffer_cur++ = len;
219 buffer_cur = std::copy(std::next(name.begin(), pos), std::next(name.begin(), dot_pos),
220 buffer_cur);
221 pos = dot_pos + 1;
222 }
223 // Write final zero.
224 *buffer_cur++ = 0;
225 return buffer_cur;
226 }
227
parseField(const char * buffer,const char * buffer_end,bool * last)228 const char* DNSName::parseField(const char* buffer, const char* buffer_end, bool* last) {
229 if (buffer + sizeof(uint8_t) > buffer_end) {
230 LOG(ERROR) << "parsing failed at line " << __LINE__;
231 return nullptr;
232 }
233 unsigned field_type = *buffer >> 6;
234 unsigned ofs = *buffer & 0x3F;
235 const char* cur = buffer + sizeof(uint8_t);
236 if (field_type == 0) {
237 // length + name component
238 if (ofs == 0) {
239 *last = true;
240 return cur;
241 }
242 if (cur + ofs > buffer_end) {
243 LOG(ERROR) << "parsing failed at line " << __LINE__;
244 return nullptr;
245 }
246 name.append(cur, ofs);
247 name.push_back('.');
248 return cur + ofs;
249 } else if (field_type == 3) {
250 LOG(ERROR) << "name compression not implemented";
251 return nullptr;
252 }
253 LOG(ERROR) << "invalid name field type";
254 return nullptr;
255 }
256
read(const char * buffer,const char * buffer_end)257 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
258 const char* cur = qname.read(buffer, buffer_end);
259 if (cur == nullptr) {
260 LOG(ERROR) << "parsing failed at line " << __LINE__;
261 return nullptr;
262 }
263 if (cur + 2 * sizeof(uint16_t) > buffer_end) {
264 LOG(ERROR) << "parsing failed at line " << __LINE__;
265 return nullptr;
266 }
267 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
268 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
269 return cur + 2 * sizeof(uint16_t);
270 }
271
write(char * buffer,const char * buffer_end) const272 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
273 char* buffer_cur = qname.write(buffer, buffer_end);
274 if (buffer_cur == nullptr) return nullptr;
275 if (buffer_cur + 2 * sizeof(uint16_t) > buffer_end) {
276 LOG(ERROR) << "buffer overflow on line " << __LINE__;
277 return nullptr;
278 }
279 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
280 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = htons(qclass);
281 return buffer_cur + 2 * sizeof(uint16_t);
282 }
283
toString() const284 std::string DNSQuestion::toString() const {
285 char buffer[16384];
286 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
287 dnstype2str(qtype), dnsclass2str(qclass));
288 return std::string(buffer, len);
289 }
290
read(const char * buffer,const char * buffer_end)291 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
292 const char* cur = name.read(buffer, buffer_end);
293 if (cur == nullptr) {
294 LOG(ERROR) << "parsing failed at line " << __LINE__;
295 return nullptr;
296 }
297 unsigned rdlen = 0;
298 cur = readIntFields(cur, buffer_end, &rdlen);
299 if (cur == nullptr) {
300 LOG(ERROR) << "parsing failed at line " << __LINE__;
301 return nullptr;
302 }
303 if (cur + rdlen > buffer_end) {
304 LOG(ERROR) << "parsing failed at line " << __LINE__;
305 return nullptr;
306 }
307 rdata.assign(cur, cur + rdlen);
308 return cur + rdlen;
309 }
310
write(char * buffer,const char * buffer_end) const311 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
312 char* buffer_cur = name.write(buffer, buffer_end);
313 if (buffer_cur == nullptr) return nullptr;
314 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
315 if (buffer_cur == nullptr) return nullptr;
316 if (buffer_cur + rdata.size() > buffer_end) {
317 LOG(ERROR) << "buffer overflow on line " << __LINE__;
318 return nullptr;
319 }
320 return std::copy(rdata.begin(), rdata.end(), buffer_cur);
321 }
322
toString() const323 std::string DNSRecord::toString() const {
324 char buffer[16384];
325 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
326 dnsclass2str(rclass));
327 return std::string(buffer, len);
328 }
329
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)330 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, unsigned* rdlen) {
331 if (buffer + sizeof(IntFields) > buffer_end) {
332 LOG(ERROR) << "parsing failed at line " << __LINE__;
333 return nullptr;
334 }
335 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
336 rtype = ntohs(intfields.rtype);
337 rclass = ntohs(intfields.rclass);
338 ttl = ntohl(intfields.ttl);
339 *rdlen = ntohs(intfields.rdlen);
340 return buffer + sizeof(IntFields);
341 }
342
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const343 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const {
344 if (buffer + sizeof(IntFields) > buffer_end) {
345 LOG(ERROR) << "buffer overflow on line " << __LINE__;
346 return nullptr;
347 }
348 auto& intfields = *reinterpret_cast<IntFields*>(buffer);
349 intfields.rtype = htons(rtype);
350 intfields.rclass = htons(rclass);
351 intfields.ttl = htonl(ttl);
352 intfields.rdlen = htons(rdlen);
353 return buffer + sizeof(IntFields);
354 }
355
read(const char * buffer,const char * buffer_end)356 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
357 unsigned qdcount;
358 unsigned ancount;
359 unsigned nscount;
360 unsigned arcount;
361 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, &nscount, &arcount);
362 if (cur == nullptr) {
363 LOG(ERROR) << "parsing failed at line " << __LINE__;
364 return nullptr;
365 }
366 if (qdcount) {
367 questions.resize(qdcount);
368 for (unsigned i = 0; i < qdcount; ++i) {
369 cur = questions[i].read(cur, buffer_end);
370 if (cur == nullptr) {
371 LOG(ERROR) << "parsing failed at line " << __LINE__;
372 return nullptr;
373 }
374 }
375 }
376 if (ancount) {
377 answers.resize(ancount);
378 for (unsigned i = 0; i < ancount; ++i) {
379 cur = answers[i].read(cur, buffer_end);
380 if (cur == nullptr) {
381 LOG(ERROR) << "parsing failed at line " << __LINE__;
382 return nullptr;
383 }
384 }
385 }
386 if (nscount) {
387 authorities.resize(nscount);
388 for (unsigned i = 0; i < nscount; ++i) {
389 cur = authorities[i].read(cur, buffer_end);
390 if (cur == nullptr) {
391 LOG(ERROR) << "parsing failed at line " << __LINE__;
392 return nullptr;
393 }
394 }
395 }
396 if (arcount) {
397 additionals.resize(arcount);
398 for (unsigned i = 0; i < arcount; ++i) {
399 cur = additionals[i].read(cur, buffer_end);
400 if (cur == nullptr) {
401 LOG(ERROR) << "parsing failed at line " << __LINE__;
402 return nullptr;
403 }
404 }
405 }
406 return cur;
407 }
408
write(char * buffer,const char * buffer_end) const409 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
410 if (buffer + sizeof(Header) > buffer_end) {
411 LOG(ERROR) << "buffer overflow on line " << __LINE__;
412 return nullptr;
413 }
414 Header& header = *reinterpret_cast<Header*>(buffer);
415 // bytes 0-1
416 header.id = htons(id);
417 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
418 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
419 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
420 // Fake behavior: if the query set the "ad" bit, set it in the response too.
421 // In a real server, this should be set only if the data is authentic and the
422 // query contained an "ad" bit or DNSSEC extensions.
423 header.flags1 = (ad << 5) | rcode;
424 // rest of header
425 header.qdcount = htons(questions.size());
426 header.ancount = htons(answers.size());
427 header.nscount = htons(authorities.size());
428 header.arcount = htons(additionals.size());
429 char* buffer_cur = buffer + sizeof(Header);
430 for (const DNSQuestion& question : questions) {
431 buffer_cur = question.write(buffer_cur, buffer_end);
432 if (buffer_cur == nullptr) return nullptr;
433 }
434 for (const DNSRecord& answer : answers) {
435 buffer_cur = answer.write(buffer_cur, buffer_end);
436 if (buffer_cur == nullptr) return nullptr;
437 }
438 for (const DNSRecord& authority : authorities) {
439 buffer_cur = authority.write(buffer_cur, buffer_end);
440 if (buffer_cur == nullptr) return nullptr;
441 }
442 for (const DNSRecord& additional : additionals) {
443 buffer_cur = additional.write(buffer_cur, buffer_end);
444 if (buffer_cur == nullptr) return nullptr;
445 }
446 return buffer_cur;
447 }
448
449 // TODO: convert all callers to this interface, then delete the old one.
write(std::vector<uint8_t> * out) const450 bool DNSHeader::write(std::vector<uint8_t>* out) const {
451 char buffer[16384];
452 char* end = this->write(buffer, buffer + sizeof buffer);
453 if (end == nullptr) return false;
454 out->insert(out->end(), buffer, end);
455 return true;
456 }
457
toString() const458 std::string DNSHeader::toString() const {
459 // TODO
460 return std::string();
461 }
462
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)463 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, unsigned* qdcount,
464 unsigned* ancount, unsigned* nscount, unsigned* arcount) {
465 if (buffer + sizeof(Header) > buffer_end) return nullptr;
466 const auto& header = *reinterpret_cast<const Header*>(buffer);
467 // bytes 0-1
468 id = ntohs(header.id);
469 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
470 qr = header.flags0 >> 7;
471 opcode = (header.flags0 >> 3) & 0x0F;
472 aa = (header.flags0 >> 2) & 1;
473 tr = (header.flags0 >> 1) & 1;
474 rd = header.flags0 & 1;
475 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
476 ra = header.flags1 >> 7;
477 ad = (header.flags1 >> 5) & 1;
478 rcode = header.flags1 & 0xF;
479 // rest of header
480 *qdcount = ntohs(header.qdcount);
481 *ancount = ntohs(header.ancount);
482 *nscount = ntohs(header.nscount);
483 *arcount = ntohs(header.arcount);
484 return buffer + sizeof(Header);
485 }
486
487 /* DNS responder */
488
DNSResponder(std::string listen_address,std::string listen_service,ns_rcode error_rcode,MappingType mapping_type)489 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
490 ns_rcode error_rcode, MappingType mapping_type)
491 : listen_address_(std::move(listen_address)),
492 listen_service_(std::move(listen_service)),
493 error_rcode_(error_rcode),
494 mapping_type_(mapping_type) {}
495
~DNSResponder()496 DNSResponder::~DNSResponder() {
497 stopServer();
498 }
499
addMapping(const std::string & name,ns_type type,const std::string & addr)500 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
501 std::lock_guard lock(mappings_mutex_);
502 mappings_[{name, type}] = addr;
503 }
504
addMappingDnsHeader(const std::string & name,ns_type type,const DNSHeader & header)505 void DNSResponder::addMappingDnsHeader(const std::string& name, ns_type type,
506 const DNSHeader& header) {
507 std::lock_guard lock(mappings_mutex_);
508 dnsheader_mappings_[{name, type}] = header;
509 }
510
addMappingBinaryPacket(const std::vector<uint8_t> & query,const std::vector<uint8_t> & response)511 void DNSResponder::addMappingBinaryPacket(const std::vector<uint8_t>& query,
512 const std::vector<uint8_t>& response) {
513 std::lock_guard lock(mappings_mutex_);
514 packet_mappings_[query] = response;
515 }
516
removeMapping(const std::string & name,ns_type type)517 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
518 std::lock_guard lock(mappings_mutex_);
519 if (!mappings_.erase({name, type})) {
520 LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
521 << "), not present in registered mappings";
522 }
523 }
524
removeMappingDnsHeader(const std::string & name,ns_type type)525 void DNSResponder::removeMappingDnsHeader(const std::string& name, ns_type type) {
526 std::lock_guard lock(mappings_mutex_);
527 if (!dnsheader_mappings_.erase({name, type})) {
528 LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
529 << "), not present in registered DnsHeader mappings";
530 }
531 }
532
removeMappingBinaryPacket(const std::vector<uint8_t> & query)533 void DNSResponder::removeMappingBinaryPacket(const std::vector<uint8_t>& query) {
534 std::lock_guard lock(mappings_mutex_);
535 if (!packet_mappings_.erase(query)) {
536 LOG(ERROR) << "Cannot remove mapping, not present in registered BinaryPacket mappings";
537 LOG(INFO) << "Hex dump:";
538 LOG(INFO) << android::netdutils::toHex(
539 Slice(const_cast<uint8_t*>(query.data()), query.size()), 32);
540 }
541 }
542
543 // Set response probability on all supported protocols.
setResponseProbability(double response_probability)544 void DNSResponder::setResponseProbability(double response_probability) {
545 setResponseProbability(response_probability, IPPROTO_TCP);
546 setResponseProbability(response_probability, IPPROTO_UDP);
547 }
548
setResponseDelayMs(unsigned timeMs)549 void DNSResponder::setResponseDelayMs(unsigned timeMs) {
550 response_delayed_ms_ = timeMs;
551 }
552
553 // Set response probability on specific protocol. It's caller's duty to ensure that the |protocol|
554 // can be supported by DNSResponder.
setResponseProbability(double response_probability,int protocol)555 void DNSResponder::setResponseProbability(double response_probability, int protocol) {
556 switch (protocol) {
557 case IPPROTO_TCP:
558 response_probability_tcp_ = response_probability;
559 break;
560 case IPPROTO_UDP:
561 response_probability_udp_ = response_probability;
562 break;
563 default:
564 LOG(FATAL) << "Unsupported protocol " << protocol; // abort() by log level FATAL
565 }
566 }
567
getResponseProbability(int protocol) const568 double DNSResponder::getResponseProbability(int protocol) const {
569 switch (protocol) {
570 case IPPROTO_TCP:
571 return response_probability_tcp_;
572 case IPPROTO_UDP:
573 return response_probability_udp_;
574 default:
575 LOG(FATAL) << "Unsupported protocol " << protocol; // abort() by log level FATAL
576 // unreachable
577 return -1;
578 }
579 }
580
setEdns(Edns edns)581 void DNSResponder::setEdns(Edns edns) {
582 edns_ = edns;
583 }
584
setTtl(unsigned ttl)585 void DNSResponder::setTtl(unsigned ttl) {
586 answer_record_ttl_sec_ = ttl;
587 }
588
running() const589 bool DNSResponder::running() const {
590 return (udp_socket_.ok()) && (tcp_socket_.ok());
591 }
592
startServer()593 bool DNSResponder::startServer() {
594 if (running()) {
595 LOG(ERROR) << "server already running";
596 return false;
597 }
598
599 // Create UDP, TCP socket
600 if (udp_socket_ = createListeningSocket(SOCK_DGRAM); udp_socket_.get() < 0) {
601 PLOG(ERROR) << "failed to create UDP socket";
602 return false;
603 }
604
605 if (tcp_socket_ = createListeningSocket(SOCK_STREAM); tcp_socket_.get() < 0) {
606 PLOG(ERROR) << "failed to create TCP socket";
607 return false;
608 }
609
610 if (listen(tcp_socket_.get(), 1) < 0) {
611 PLOG(ERROR) << "failed to listen TCP socket";
612 return false;
613 }
614
615 // Set up eventfd socket.
616 event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
617 if (event_fd_.get() == -1) {
618 PLOG(ERROR) << "failed to create eventfd";
619 return false;
620 }
621
622 // Set up epoll socket.
623 epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
624 if (epoll_fd_.get() < 0) {
625 PLOG(ERROR) << "epoll_create1() failed on fd";
626 return false;
627 }
628
629 LOG(INFO) << "adding UDP socket to epoll";
630 if (!addFd(udp_socket_.get(), EPOLLIN)) {
631 LOG(ERROR) << "failed to add the UDP socket to epoll";
632 return false;
633 }
634
635 LOG(INFO) << "adding TCP socket to epoll";
636 if (!addFd(tcp_socket_.get(), EPOLLIN)) {
637 LOG(ERROR) << "failed to add the TCP socket to epoll";
638 return false;
639 }
640
641 LOG(INFO) << "adding eventfd to epoll";
642 if (!addFd(event_fd_.get(), EPOLLIN)) {
643 LOG(ERROR) << "failed to add the eventfd to epoll";
644 return false;
645 }
646
647 {
648 std::lock_guard lock(update_mutex_);
649 handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
650 }
651 LOG(INFO) << "server started successfully";
652 return true;
653 }
654
stopServer()655 bool DNSResponder::stopServer() {
656 std::lock_guard lock(update_mutex_);
657 if (!running()) {
658 LOG(ERROR) << "server not running";
659 return false;
660 }
661 LOG(INFO) << "stopping server";
662 if (!sendToEventFd()) {
663 return false;
664 }
665 handler_thread_.join();
666 epoll_fd_.reset();
667 event_fd_.reset();
668 udp_socket_.reset();
669 tcp_socket_.reset();
670 LOG(INFO) << "server stopped successfully";
671 return true;
672 }
673
queries() const674 std::vector<DNSResponder::QueryInfo> DNSResponder::queries() const {
675 std::lock_guard lock(queries_mutex_);
676 return queries_;
677 }
678
dumpQueries() const679 std::string DNSResponder::dumpQueries() const {
680 std::lock_guard lock(queries_mutex_);
681 std::string out;
682
683 for (const auto& q : queries_) {
684 out += "{\"" + q.name + "\", " + std::to_string(q.type) + "\", " +
685 dnsproto2str(q.protocol) + "} ";
686 }
687 return out;
688 }
689
clearQueries()690 void DNSResponder::clearQueries() {
691 std::lock_guard lock(queries_mutex_);
692 queries_.clear();
693 }
694
hasOptPseudoRR(DNSHeader * header) const695 bool DNSResponder::hasOptPseudoRR(DNSHeader* header) const {
696 if (header->additionals.empty()) return false;
697
698 // OPT RR may be placed anywhere within the additional section. See RFC 6891 section 6.1.1.
699 auto found = std::find_if(header->additionals.begin(), header->additionals.end(),
700 [](const auto& a) { return a.rtype == ns_type::ns_t_opt; });
701 return found != header->additionals.end();
702 }
703
requestHandler()704 void DNSResponder::requestHandler() {
705 epoll_event evs[EPOLL_MAX_EVENTS];
706 while (true) {
707 int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, -1);
708 if (n <= 0) {
709 PLOG(ERROR) << "epoll_wait() failed, n=" << n;
710 return;
711 }
712
713 for (int i = 0; i < n; i++) {
714 const int fd = evs[i].data.fd;
715 const uint32_t events = evs[i].events;
716 if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
717 handleEventFd();
718 return;
719 } else if (fd == udp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
720 handleQuery(IPPROTO_UDP);
721 } else if (fd == tcp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
722 handleQuery(IPPROTO_TCP);
723 } else {
724 LOG(WARNING) << "unexpected epoll events " << events << " on fd " << fd;
725 }
726 }
727 }
728 }
729
handleDNSRequest(const char * buffer,ssize_t len,int protocol,char * response,size_t * response_len) const730 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, int protocol, char* response,
731 size_t* response_len) const {
732 LOG(DEBUG) << "request: '" << str2hex(buffer, len) << "', on " << dnsproto2str(protocol);
733 const char* buffer_end = buffer + len;
734 DNSHeader header;
735 const char* cur = header.read(buffer, buffer_end);
736 // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
737 if (cur == nullptr) {
738 LOG(ERROR) << "failed to parse query";
739 return false;
740 }
741 if (header.qr) {
742 LOG(ERROR) << "response received instead of a query";
743 return false;
744 }
745 if (header.opcode != ns_opcode::ns_o_query) {
746 LOG(INFO) << "unsupported request opcode received";
747 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, response_len);
748 }
749 if (header.questions.empty()) {
750 LOG(INFO) << "no questions present";
751 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
752 }
753 if (!header.answers.empty()) {
754 LOG(INFO) << "already " << header.answers.size() << " answers present in query";
755 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
756 }
757
758 if (edns_ == Edns::FORMERR_UNCOND) {
759 LOG(INFO) << "force to return RCODE FORMERR";
760 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
761 }
762
763 if (!header.additionals.empty() && edns_ != Edns::ON) {
764 LOG(INFO) << "DNS request has an additional section (assumed EDNS). Simulating an ancient "
765 "(pre-EDNS) server, and returning "
766 << (edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
767 if (edns_ == Edns::FORMERR_ON_EDNS) {
768 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
769 }
770 // No response.
771 return false;
772 }
773 {
774 std::lock_guard lock(queries_mutex_);
775 for (const DNSQuestion& question : header.questions) {
776 queries_.push_back({question.qname.name, ns_type(question.qtype), protocol});
777 }
778 }
779 // Ignore requests with the preset probability.
780 auto constexpr bound = std::numeric_limits<unsigned>::max();
781 if (arc4random_uniform(bound) > bound * getResponseProbability(protocol)) {
782 if (error_rcode_ < 0) {
783 LOG(ERROR) << "Returning no response";
784 return false;
785 } else {
786 LOG(INFO) << "returning RCODE " << static_cast<int>(error_rcode_)
787 << " in accordance with probability distribution";
788 return makeErrorResponse(&header, error_rcode_, response, response_len);
789 }
790 }
791
792 // Make the response. The query has been read into |header| which is used to build and return
793 // the response as well.
794 return makeResponse(&header, protocol, response, response_len);
795 }
796
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const797 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
798 std::vector<DNSRecord>* answers) const {
799 std::lock_guard guard(mappings_mutex_);
800 std::string rname = question.qname.name;
801 std::vector<int> rtypes;
802
803 if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa ||
804 question.qtype == ns_type::ns_t_ptr)
805 rtypes.push_back(ns_type::ns_t_cname);
806 rtypes.push_back(question.qtype);
807 for (int rtype : rtypes) {
808 std::set<std::string> cnames_Loop;
809 std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
810 while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
811 if (rtype == ns_type::ns_t_cname) {
812 // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
813 // As following, the query will stop on loop3 by detecting the same cname.
814 // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
815 // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
816 // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
817 // is found in cnames_Loop already, break the query loop.)
818 if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
819 cnames_Loop.insert(it->first.name);
820 }
821 DNSRecord record{
822 .name = {.name = it->first.name},
823 .rtype = it->first.type,
824 .rclass = ns_class::ns_c_in,
825 .ttl = answer_record_ttl_sec_, // seconds
826 };
827 if (!fillRdata(it->second, record)) return false;
828 answers->push_back(std::move(record));
829 if (rtype != ns_type::ns_t_cname) break;
830 rname = it->second;
831 }
832 }
833
834 if (answers->size() == 0) {
835 // TODO(imaipi): handle correctly
836 LOG(INFO) << "no mapping found for " << question.qname.name << " "
837 << dnstype2str(question.qtype) << ", lazily refusing to add an answer";
838 }
839
840 return true;
841 }
842
fillRdata(const std::string & rdatastr,DNSRecord & record)843 bool DNSResponder::fillRdata(const std::string& rdatastr, DNSRecord& record) {
844 if (record.rtype == ns_type::ns_t_a) {
845 record.rdata.resize(4);
846 if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
847 LOG(ERROR) << "inet_pton(AF_INET, " << rdatastr << ") failed";
848 return false;
849 }
850 } else if (record.rtype == ns_type::ns_t_aaaa) {
851 record.rdata.resize(16);
852 if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
853 LOG(ERROR) << "inet_pton(AF_INET6, " << rdatastr << ") failed";
854 return false;
855 }
856 } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname) ||
857 (record.rtype == ns_type::ns_t_ns)) {
858 constexpr char delimiter = '.';
859 std::string name = rdatastr;
860 std::vector<char> rdata;
861
862 // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
863 // The "name" should be an absolute domain name which ends in a dot.
864 if (name.back() != delimiter) {
865 LOG(ERROR) << "invalid absolute domain name";
866 return false;
867 }
868 name.pop_back(); // remove the dot in tail
869 for (const std::string& label : android::base::Split(name, {delimiter})) {
870 // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
871 if (label.length() == 0 || label.length() > 63) {
872 LOG(ERROR) << "invalid label length";
873 return false;
874 }
875
876 rdata.push_back(label.length());
877 rdata.insert(rdata.end(), label.begin(), label.end());
878 }
879 rdata.push_back(0); // Length byte of zero terminates the label list
880
881 // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
882 if (rdata.size() > 255) {
883 LOG(ERROR) << "invalid name length";
884 return false;
885 }
886 record.rdata = move(rdata);
887 } else {
888 LOG(ERROR) << "unhandled qtype " << dnstype2str(record.rtype);
889 return false;
890 }
891 return true;
892 }
893
writePacket(const DNSHeader * header,char * response,size_t * response_len) const894 bool DNSResponder::writePacket(const DNSHeader* header, char* response,
895 size_t* response_len) const {
896 char* response_cur = header->write(response, response + *response_len);
897 if (response_cur == nullptr) {
898 return false;
899 }
900 *response_len = response_cur - response;
901 return true;
902 }
903
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const904 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
905 size_t* response_len) const {
906 header->answers.clear();
907 header->authorities.clear();
908 header->additionals.clear();
909 header->rcode = rcode;
910 header->qr = true;
911 return writePacket(header, response, response_len);
912 }
913
makeTruncatedResponse(DNSHeader * header,char * response,size_t * response_len) const914 bool DNSResponder::makeTruncatedResponse(DNSHeader* header, char* response,
915 size_t* response_len) const {
916 // Build a minimal response for non-EDNS response over UDP. Truncate all stub RRs in answer,
917 // authority and additional section. EDNS response truncation has not supported here yet
918 // because the EDNS response must have an OPT record. See RFC 6891 section 7.
919 header->answers.clear();
920 header->authorities.clear();
921 header->additionals.clear();
922 header->qr = true;
923 header->tr = true;
924 return writePacket(header, response, response_len);
925 }
926
makeResponse(DNSHeader * header,int protocol,char * response,size_t * response_len) const927 bool DNSResponder::makeResponse(DNSHeader* header, int protocol, char* response,
928 size_t* response_len) const {
929 char buffer[16384];
930 size_t buffer_len = sizeof(buffer);
931 bool ret;
932
933 switch (mapping_type_) {
934 case MappingType::DNS_HEADER:
935 ret = makeResponseFromDnsHeader(header, buffer, &buffer_len);
936 break;
937 case MappingType::BINARY_PACKET:
938 ret = makeResponseFromBinaryPacket(header, buffer, &buffer_len);
939 break;
940 case MappingType::ADDRESS_OR_HOSTNAME:
941 default:
942 ret = makeResponseFromAddressOrHostname(header, buffer, &buffer_len);
943 }
944
945 if (!ret) return false;
946
947 // Return truncated response if the built non-EDNS response size which is larger than 512 bytes
948 // will be responded over UDP. The truncated response implementation here just simply set up
949 // the TC bit and truncate all stub RRs in answer, authority and additional section. It is
950 // because the resolver will retry DNS query over TCP and use the full TCP response. See also
951 // RFC 1035 section 4.2.1 for UDP response truncation and RFC 6891 section 4.3 for EDNS larger
952 // response size capability.
953 // TODO: Perhaps keep the stub RRs as possible.
954 // TODO: Perhaps truncate the EDNS based response over UDP. See also RFC 6891 section 4.3,
955 // section 6.2.5 and section 7.
956 if (protocol == IPPROTO_UDP && buffer_len > kMaximumUdpSize &&
957 !hasOptPseudoRR(header) /* non-EDNS */) {
958 LOG(INFO) << "Return truncated response because original response length " << buffer_len
959 << " is larger than " << kMaximumUdpSize << " bytes.";
960 return makeTruncatedResponse(header, response, response_len);
961 }
962
963 if (buffer_len > *response_len) {
964 LOG(ERROR) << "buffer overflow on line " << __LINE__;
965 return false;
966 }
967 memcpy(response, buffer, buffer_len);
968 *response_len = buffer_len;
969 return true;
970 }
971
makeResponseFromAddressOrHostname(DNSHeader * header,char * response,size_t * response_len) const972 bool DNSResponder::makeResponseFromAddressOrHostname(DNSHeader* header, char* response,
973 size_t* response_len) const {
974 for (const DNSQuestion& question : header->questions) {
975 if (question.qclass != ns_class::ns_c_in && question.qclass != ns_class::ns_c_any) {
976 LOG(INFO) << "unsupported question class " << question.qclass;
977 return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
978 }
979
980 if (!addAnswerRecords(question, &header->answers)) {
981 return makeErrorResponse(header, ns_rcode::ns_r_servfail, response, response_len);
982 }
983 }
984 header->qr = true;
985 return writePacket(header, response, response_len);
986 }
987
makeResponseFromDnsHeader(DNSHeader * header,char * response,size_t * response_len) const988 bool DNSResponder::makeResponseFromDnsHeader(DNSHeader* header, char* response,
989 size_t* response_len) const {
990 std::lock_guard guard(mappings_mutex_);
991
992 // Support single question record only. It should be okay because res_mkquery() sets "qdcount"
993 // as one for the operation QUERY and handleDNSRequest() checks ns_opcode::ns_o_query before
994 // making a response. In other words, only need to handle the query which has single question
995 // section. See also res_mkquery() in system/netd/resolv/res_mkquery.cpp.
996 // TODO: Perhaps add support for multi-question records.
997 const std::vector<DNSQuestion>& questions = header->questions;
998 if (questions.size() != 1) {
999 LOG(INFO) << "unsupported question count " << questions.size();
1000 return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1001 }
1002
1003 if (questions[0].qclass != ns_class::ns_c_in && questions[0].qclass != ns_class::ns_c_any) {
1004 LOG(INFO) << "unsupported question class " << questions[0].qclass;
1005 return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1006 }
1007
1008 const std::string name = questions[0].qname.name;
1009 const int qtype = questions[0].qtype;
1010 const auto it = dnsheader_mappings_.find(QueryKey(name, qtype));
1011 if (it != dnsheader_mappings_.end()) {
1012 // Store both "id" and "rd" which comes from query.
1013 const unsigned id = header->id;
1014 const bool rd = header->rd;
1015
1016 // Build a response from the registered DNSHeader mapping.
1017 *header = it->second;
1018 // Assign both "ID" and "RD" fields from query to response. See RFC 1035 section 4.1.1.
1019 header->id = id;
1020 header->rd = rd;
1021 } else {
1022 // TODO: handle correctly. See also TODO in addAnswerRecords().
1023 LOG(INFO) << "no mapping found for " << name << " " << dnstype2str(qtype)
1024 << ", couldn't build a response from DNSHeader mapping";
1025
1026 // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1027 // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1028 // modified query back as a response.
1029 header->qr = true;
1030 }
1031 return writePacket(header, response, response_len);
1032 }
1033
makeResponseFromBinaryPacket(DNSHeader * header,char * response,size_t * response_len) const1034 bool DNSResponder::makeResponseFromBinaryPacket(DNSHeader* header, char* response,
1035 size_t* response_len) const {
1036 std::lock_guard guard(mappings_mutex_);
1037
1038 // Build a search key of mapping from the query.
1039 // TODO: Perhaps pass the query packet buffer directly from the caller.
1040 std::vector<uint8_t> queryKey;
1041 if (!header->write(&queryKey)) return false;
1042 // Clear ID field (byte 0-1) because it is not required by the mapping key.
1043 queryKey[0] = 0;
1044 queryKey[1] = 0;
1045
1046 const auto it = packet_mappings_.find(queryKey);
1047 if (it != packet_mappings_.end()) {
1048 if (it->second.size() > *response_len) {
1049 LOG(ERROR) << "buffer overflow on line " << __LINE__;
1050 return false;
1051 } else {
1052 std::copy(it->second.begin(), it->second.end(), response);
1053 // Leave the "RD" flag assignment for testing. The "RD" flag of the response keep
1054 // using the one from the raw packet mapping but the received query.
1055 // Assign "ID" field from query to response. See RFC 1035 section 4.1.1.
1056 reinterpret_cast<uint16_t*>(response)[0] = htons(header->id); // bytes 0-1: id
1057 *response_len = it->second.size();
1058 return true;
1059 }
1060 } else {
1061 // TODO: handle correctly. See also TODO in addAnswerRecords().
1062 // TODO: Perhaps dump packet content to indicate which query failed.
1063 LOG(INFO) << "no mapping found, couldn't build a response from BinaryPacket mapping";
1064 // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1065 // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1066 // modified query back as a response.
1067 header->qr = true;
1068 return writePacket(header, response, response_len);
1069 }
1070 }
1071
setDeferredResp(bool deferred_resp)1072 void DNSResponder::setDeferredResp(bool deferred_resp) {
1073 std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
1074 deferred_resp_ = deferred_resp;
1075 if (!deferred_resp_) {
1076 cv_for_deferred_resp_.notify_one();
1077 }
1078 }
1079
addFd(int fd,uint32_t events)1080 bool DNSResponder::addFd(int fd, uint32_t events) {
1081 epoll_event ev;
1082 ev.events = events;
1083 ev.data.fd = fd;
1084 if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
1085 PLOG(ERROR) << "epoll_ctl() for socket " << fd << " failed";
1086 return false;
1087 }
1088 return true;
1089 }
1090
handleQuery(int protocol)1091 void DNSResponder::handleQuery(int protocol) {
1092 char buffer[16384];
1093 sockaddr_storage sa;
1094 socklen_t sa_len = sizeof(sa);
1095 ssize_t len = 0;
1096 android::base::unique_fd tcpFd;
1097 switch (protocol) {
1098 case IPPROTO_UDP:
1099 do {
1100 len = recvfrom(udp_socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa,
1101 &sa_len);
1102 } while (len < 0 && (errno == EAGAIN || errno == EINTR));
1103 if (len <= 0) {
1104 PLOG(ERROR) << "recvfrom() failed, len=" << len;
1105 return;
1106 }
1107 break;
1108 case IPPROTO_TCP:
1109 tcpFd.reset(accept4(tcp_socket_.get(), reinterpret_cast<sockaddr*>(&sa), &sa_len,
1110 SOCK_CLOEXEC));
1111 if (tcpFd.get() < 0) {
1112 PLOG(ERROR) << "failed to accept client socket";
1113 return;
1114 }
1115 // Get the message length from two byte length field.
1116 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1117 uint8_t queryMessageLengthField[2];
1118 if (read(tcpFd.get(), &queryMessageLengthField, 2) != 2) {
1119 PLOG(ERROR) << "Not enough length field bytes";
1120 return;
1121 }
1122
1123 const uint16_t qlen = (queryMessageLengthField[0] << 8) | queryMessageLengthField[1];
1124 while (len < qlen) {
1125 ssize_t ret = read(tcpFd.get(), buffer + len, qlen - len);
1126 if (ret <= 0) {
1127 PLOG(ERROR) << "Error while reading query";
1128 return;
1129 }
1130 len += ret;
1131 }
1132 break;
1133 }
1134 LOG(DEBUG) << "read " << len << " bytes on " << dnsproto2str(protocol);
1135 std::lock_guard lock(cv_mutex_);
1136 char response[16384];
1137 size_t response_len = sizeof(response);
1138 // TODO: check whether sending malformed packets to DnsResponder
1139 if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
1140 std::this_thread::sleep_for(std::chrono::milliseconds(response_delayed_ms_));
1141 // place wait_for after handleDNSRequest() so we can check the number of queries in
1142 // test case before it got responded.
1143 std::unique_lock guard(cv_mutex_for_deferred_resp_);
1144 cv_for_deferred_resp_.wait(
1145 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
1146 len = 0;
1147
1148 switch (protocol) {
1149 case IPPROTO_UDP:
1150 len = sendto(udp_socket_.get(), response, response_len, 0,
1151 reinterpret_cast<const sockaddr*>(&sa), sa_len);
1152 if (len < 0) {
1153 PLOG(ERROR) << "Failed to send response";
1154 }
1155 break;
1156 case IPPROTO_TCP:
1157 // Get the message length from two byte length field.
1158 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1159 uint8_t responseMessageLengthField[2];
1160 responseMessageLengthField[0] = response_len >> 8;
1161 responseMessageLengthField[1] = response_len;
1162 if (write(tcpFd.get(), responseMessageLengthField, 2) != 2) {
1163 PLOG(ERROR) << "Failed to write response length field";
1164 break;
1165 }
1166 if (write(tcpFd.get(), response, response_len) !=
1167 static_cast<ssize_t>(response_len)) {
1168 PLOG(ERROR) << "Failed to write response";
1169 break;
1170 }
1171 len = response_len;
1172 break;
1173 }
1174 const std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
1175 if (len > 0) {
1176 LOG(DEBUG) << "sent " << len << " bytes to " << host_str;
1177 } else {
1178 const char* method_str = (protocol == IPPROTO_TCP) ? "write()" : "sendto()";
1179 LOG(ERROR) << method_str << " failed for " << host_str;
1180 }
1181 // Test that the response is actually a correct DNS message.
1182 // TODO: Perhaps make DNS message validation to support name compression. Or it throws
1183 // a warning for a valid DNS message with name compression while the binary packet mapping
1184 // is used.
1185 const char* response_end = response + len;
1186 DNSHeader header;
1187 const char* cur = header.read(response, response_end);
1188 if (cur == nullptr) LOG(WARNING) << "response is flawed";
1189 } else {
1190 LOG(WARNING) << "not responding";
1191 }
1192 cv.notify_one();
1193 return;
1194 }
1195
sendToEventFd()1196 bool DNSResponder::sendToEventFd() {
1197 const uint64_t data = 1;
1198 if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1199 PLOG(ERROR) << "failed to write eventfd, rt=" << rt;
1200 return false;
1201 }
1202 return true;
1203 }
1204
handleEventFd()1205 void DNSResponder::handleEventFd() {
1206 int64_t data;
1207 if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1208 PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
1209 }
1210 }
1211
createListeningSocket(int socket_type)1212 android::base::unique_fd DNSResponder::createListeningSocket(int socket_type) {
1213 addrinfo ai_hints{
1214 .ai_flags = AI_PASSIVE,
1215 .ai_family = AF_UNSPEC,
1216 .ai_socktype = socket_type,
1217 };
1218 addrinfo* ai_res = nullptr;
1219 const int rv =
1220 getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &ai_hints, &ai_res);
1221 ScopedAddrinfo ai_res_cleanup(ai_res);
1222 if (rv) {
1223 LOG(ERROR) << "getaddrinfo(" << listen_address_ << ", " << listen_service_
1224 << ") failed: " << gai_strerror(rv);
1225 return {};
1226 }
1227 for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
1228 android::base::unique_fd fd(
1229 socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK, ai->ai_protocol));
1230 if (fd.get() < 0) {
1231 PLOG(ERROR) << "ignore creating socket failed";
1232 continue;
1233 }
1234 enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
1235 enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
1236 const std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
1237 const char* socket_str = (socket_type == SOCK_STREAM) ? "TCP" : "UDP";
1238
1239 if (bindSocket(fd.get(), ai->ai_addr, ai->ai_addrlen)) {
1240 PLOG(ERROR) << "failed to bind " << socket_str << " " << host_str << ":"
1241 << listen_service_;
1242 continue;
1243 }
1244 LOG(INFO) << "bound to " << socket_str << " " << host_str << ":" << listen_service_;
1245 return fd;
1246 }
1247 return {};
1248 }
1249
1250 } // namespace test
1251