1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/eventfd.h>
28 #include <sys/socket.h>
29 #include <sys/types.h>
30 #include <unistd.h>
31 #include <set>
32
33 #include <iostream>
34 #include <vector>
35
36 #define LOG_TAG "DNSResponder"
37 #include <android-base/strings.h>
38 #include <log/log.h>
39 #include <netdutils/SocketOption.h>
40
41 #include "NetdConstants.h"
42
43 using android::netdutils::enableSockopt;
44
45 namespace test {
46
// Returns the human-readable description of the current |errno| value.
std::string errno2str() {
    // With _GNU_SOURCE the GNU strerror_r() is selected: it returns a char* that may
    // point either into |buf| or at a static string, so its return value is what we
    // must use. PLOG would avoid this, but requires switching from ALOGx() to LOG(x).
    char buf[512] = {};
    const char* const description = strerror_r(errno, buf, sizeof(buf));
    return std::string(description);
}
53
54 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
55
56 #if 0
57 #define DBGLOG(fmt, ...) ALOGI(fmt, __VA_ARGS__)
58 #else
59 #define DBGLOG(fmt, ...)
60 #endif
61
// Renders |len| bytes of |buffer| as upper-case hexadecimal, two digits per byte.
std::string str2hex(const char* buffer, size_t len) {
    static constexpr char kDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        // Treat the byte as unsigned so the high nibble is never sign-extended.
        const auto byte = static_cast<uint8_t>(*p);
        out.push_back(kDigits[byte >> 4]);
        out.push_back(kDigits[byte & 0x0F]);
    }
    return out;
}
72
// Formats the address part of |sa| numerically (e.g. "127.0.0.1" or "::1").
// Returns an empty string if getnameinfo() rejects the address.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    const int err = getnameinfo(sa, sa_len, host, sizeof(host), nullptr, 0, NI_NUMERICHOST);
    if (err != 0) {
        return std::string();
    }
    return std::string(host);
}
80
81 /* DNS struct helpers */
82
// Returns the mnemonic for a DNS RR TYPE value (e.g. "A", "AAAA", "CNAME"), or
// "UNKNOWN" when the value is not in the table below.
const char* dnstype2str(unsigned dnstype) {
    // Built once and never modified afterwards.
    static const std::unordered_map<unsigned, const char*> kTypeStrs = {
        { ns_type::ns_t_a, "A" },
        { ns_type::ns_t_ns, "NS" },
        { ns_type::ns_t_md, "MD" },
        { ns_type::ns_t_mf, "MF" },
        { ns_type::ns_t_cname, "CNAME" },
        { ns_type::ns_t_soa, "SOA" },
        { ns_type::ns_t_mb, "MB" },
        // Previously keyed on ns_t_mb by mistake; the duplicate key was silently
        // dropped and MG (type 8) printed as "UNKNOWN".
        { ns_type::ns_t_mg, "MG" },
        { ns_type::ns_t_mr, "MR" },
        { ns_type::ns_t_null, "NULL" },
        { ns_type::ns_t_wks, "WKS" },
        { ns_type::ns_t_ptr, "PTR" },
        { ns_type::ns_t_hinfo, "HINFO" },
        { ns_type::ns_t_minfo, "MINFO" },
        { ns_type::ns_t_mx, "MX" },
        { ns_type::ns_t_txt, "TXT" },
        { ns_type::ns_t_rp, "RP" },
        { ns_type::ns_t_afsdb, "AFSDB" },
        { ns_type::ns_t_x25, "X25" },
        { ns_type::ns_t_isdn, "ISDN" },
        { ns_type::ns_t_rt, "RT" },
        { ns_type::ns_t_nsap, "NSAP" },
        { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
        { ns_type::ns_t_sig, "SIG" },
        { ns_type::ns_t_key, "KEY" },
        { ns_type::ns_t_px, "PX" },
        { ns_type::ns_t_gpos, "GPOS" },
        { ns_type::ns_t_aaaa, "AAAA" },
        { ns_type::ns_t_loc, "LOC" },
        { ns_type::ns_t_nxt, "NXT" },
        { ns_type::ns_t_eid, "EID" },
        { ns_type::ns_t_nimloc, "NIMLOC" },
        { ns_type::ns_t_srv, "SRV" },
        { ns_type::ns_t_naptr, "NAPTR" },
        { ns_type::ns_t_kx, "KX" },
        { ns_type::ns_t_cert, "CERT" },
        { ns_type::ns_t_a6, "A6" },
        { ns_type::ns_t_dname, "DNAME" },
        { ns_type::ns_t_sink, "SINK" },
        { ns_type::ns_t_opt, "OPT" },
        { ns_type::ns_t_apl, "APL" },
        { ns_type::ns_t_tkey, "TKEY" },
        { ns_type::ns_t_tsig, "TSIG" },
        { ns_type::ns_t_ixfr, "IXFR" },
        { ns_type::ns_t_axfr, "AXFR" },
        { ns_type::ns_t_mailb, "MAILB" },
        { ns_type::ns_t_maila, "MAILA" },
        { ns_type::ns_t_any, "ANY" },
        // ZXFR is a nonstandard BIND type (ns_t_zxfr == 256 in BIND-derived headers).
        // Spelled numerically, like class 2 in dnsclass2str(), so this also builds
        // against <arpa/nameser.h> variants that do not declare the enumerator.
        { 256, "ZXFR" },
    };
    static const char* const kUnknownStr = "UNKNOWN";
    const auto it = kTypeStrs.find(dnstype);
    if (it == kTypeStrs.end()) return kUnknownStr;
    return it->second;
}
140
// Returns a readable name for a DNS CLASS value, or "UNKNOWN" for values not
// listed below.
const char* dnsclass2str(unsigned dnsclass) {
    switch (dnsclass) {
        case ns_class::ns_c_in:
            return "Internet";
        case 2:  // CSNET (obsolete; no portable enumerator)
            return "CSNet";
        case ns_class::ns_c_chaos:
            return "ChaosNet";
        case ns_class::ns_c_hs:
            return "Hesiod";
        case ns_class::ns_c_none:
            return "none";
        case ns_class::ns_c_any:
            return "any";
        default:
            return "UNKNOWN";
    }
}
155
// A domain name kept in dotted text form. Every label read off the wire is
// followed by a '.', so a parsed two-label name looks like "example.com." and
// the root name is the empty string.
struct DNSName {
    std::string name;

    // Parses an uncompressed wire-format name starting at |buffer|.
    // Returns the position just past it, or nullptr on malformed/truncated input.
    const char* read(const char* buffer, const char* buffer_end);

    // Serializes |name| in wire format. Returns the position just past the
    // written bytes, or nullptr if the name is invalid or does not fit.
    char* write(char* buffer, const char* buffer_end) const;

    // Returns |name| as a NUL-terminated string owned by this object.
    const char* toString() const;

  private:
    // Parses a single label; sets |*last| when the terminating zero length is seen.
    const char* parseField(const char* buffer, const char* buffer_end, bool* last);
};
165
toString() const166 const char* DNSName::toString() const {
167 return name.c_str();
168 }
169
read(const char * buffer,const char * buffer_end)170 const char* DNSName::read(const char* buffer, const char* buffer_end) {
171 const char* cur = buffer;
172 bool last = false;
173 do {
174 cur = parseField(cur, buffer_end, &last);
175 if (cur == nullptr) {
176 ALOGI("parsing failed at line %d", __LINE__);
177 return nullptr;
178 }
179 } while (!last);
180 return cur;
181 }
182
write(char * buffer,const char * buffer_end) const183 char* DNSName::write(char* buffer, const char* buffer_end) const {
184 char* buffer_cur = buffer;
185 for (size_t pos = 0 ; pos < name.size() ; ) {
186 size_t dot_pos = name.find('.', pos);
187 if (dot_pos == std::string::npos) {
188 // Sanity check, should never happen unless parseField is broken.
189 ALOGI("logic error: all names are expected to end with a '.'");
190 return nullptr;
191 }
192 size_t len = dot_pos - pos;
193 if (len >= 256) {
194 ALOGI("name component '%s' is %zu long, but max is 255",
195 name.substr(pos, dot_pos - pos).c_str(), len);
196 return nullptr;
197 }
198 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
199 ALOGI("buffer overflow at line %d", __LINE__);
200 return nullptr;
201 }
202 *buffer_cur++ = len;
203 buffer_cur = std::copy(std::next(name.begin(), pos),
204 std::next(name.begin(), dot_pos),
205 buffer_cur);
206 pos = dot_pos + 1;
207 }
208 // Write final zero.
209 *buffer_cur++ = 0;
210 return buffer_cur;
211 }
212
parseField(const char * buffer,const char * buffer_end,bool * last)213 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
214 bool* last) {
215 if (buffer + sizeof(uint8_t) > buffer_end) {
216 ALOGI("parsing failed at line %d", __LINE__);
217 return nullptr;
218 }
219 unsigned field_type = *buffer >> 6;
220 unsigned ofs = *buffer & 0x3F;
221 const char* cur = buffer + sizeof(uint8_t);
222 if (field_type == 0) {
223 // length + name component
224 if (ofs == 0) {
225 *last = true;
226 return cur;
227 }
228 if (cur + ofs > buffer_end) {
229 ALOGI("parsing failed at line %d", __LINE__);
230 return nullptr;
231 }
232 name.append(cur, ofs);
233 name.push_back('.');
234 return cur + ofs;
235 } else if (field_type == 3) {
236 ALOGI("name compression not implemented");
237 return nullptr;
238 }
239 ALOGI("invalid name field type");
240 return nullptr;
241 }
242
243 struct DNSQuestion {
244 DNSName qname;
245 unsigned qtype;
246 unsigned qclass;
247 const char* read(const char* buffer, const char* buffer_end);
248 char* write(char* buffer, const char* buffer_end) const;
249 std::string toString() const;
250 };
251
read(const char * buffer,const char * buffer_end)252 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
253 const char* cur = qname.read(buffer, buffer_end);
254 if (cur == nullptr) {
255 ALOGI("parsing failed at line %d", __LINE__);
256 return nullptr;
257 }
258 if (cur + 2*sizeof(uint16_t) > buffer_end) {
259 ALOGI("parsing failed at line %d", __LINE__);
260 return nullptr;
261 }
262 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
263 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
264 return cur + 2*sizeof(uint16_t);
265 }
266
write(char * buffer,const char * buffer_end) const267 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
268 char* buffer_cur = qname.write(buffer, buffer_end);
269 if (buffer_cur == nullptr) return nullptr;
270 if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
271 ALOGI("buffer overflow on line %d", __LINE__);
272 return nullptr;
273 }
274 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
275 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
276 htons(qclass);
277 return buffer_cur + 2*sizeof(uint16_t);
278 }
279
toString() const280 std::string DNSQuestion::toString() const {
281 char buffer[4096];
282 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
283 dnstype2str(qtype), dnsclass2str(qclass));
284 return std::string(buffer, len);
285 }
286
287 struct DNSRecord {
288 DNSName name;
289 unsigned rtype;
290 unsigned rclass;
291 unsigned ttl;
292 std::vector<char> rdata;
293 const char* read(const char* buffer, const char* buffer_end);
294 char* write(char* buffer, const char* buffer_end) const;
295 std::string toString() const;
296 private:
297 struct IntFields {
298 uint16_t rtype;
299 uint16_t rclass;
300 uint32_t ttl;
301 uint16_t rdlen;
302 } __attribute__((__packed__));
303
304 const char* readIntFields(const char* buffer, const char* buffer_end,
305 unsigned* rdlen);
306 char* writeIntFields(unsigned rdlen, char* buffer,
307 const char* buffer_end) const;
308 };
309
read(const char * buffer,const char * buffer_end)310 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
311 const char* cur = name.read(buffer, buffer_end);
312 if (cur == nullptr) {
313 ALOGI("parsing failed at line %d", __LINE__);
314 return nullptr;
315 }
316 unsigned rdlen = 0;
317 cur = readIntFields(cur, buffer_end, &rdlen);
318 if (cur == nullptr) {
319 ALOGI("parsing failed at line %d", __LINE__);
320 return nullptr;
321 }
322 if (cur + rdlen > buffer_end) {
323 ALOGI("parsing failed at line %d", __LINE__);
324 return nullptr;
325 }
326 rdata.assign(cur, cur + rdlen);
327 return cur + rdlen;
328 }
329
write(char * buffer,const char * buffer_end) const330 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
331 char* buffer_cur = name.write(buffer, buffer_end);
332 if (buffer_cur == nullptr) return nullptr;
333 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
334 if (buffer_cur == nullptr) return nullptr;
335 if (buffer_cur + rdata.size() > buffer_end) {
336 ALOGI("buffer overflow on line %d", __LINE__);
337 return nullptr;
338 }
339 return std::copy(rdata.begin(), rdata.end(), buffer_cur);
340 }
341
toString() const342 std::string DNSRecord::toString() const {
343 char buffer[4096];
344 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
345 dnstype2str(rtype), dnsclass2str(rclass));
346 return std::string(buffer, len);
347 }
348
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)349 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
350 unsigned* rdlen) {
351 if (buffer + sizeof(IntFields) > buffer_end ) {
352 ALOGI("parsing failed at line %d", __LINE__);
353 return nullptr;
354 }
355 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
356 rtype = ntohs(intfields.rtype);
357 rclass = ntohs(intfields.rclass);
358 ttl = ntohl(intfields.ttl);
359 *rdlen = ntohs(intfields.rdlen);
360 return buffer + sizeof(IntFields);
361 }
362
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const363 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
364 const char* buffer_end) const {
365 if (buffer + sizeof(IntFields) > buffer_end ) {
366 ALOGI("buffer overflow on line %d", __LINE__);
367 return nullptr;
368 }
369 auto& intfields = *reinterpret_cast<IntFields*>(buffer);
370 intfields.rtype = htons(rtype);
371 intfields.rclass = htons(rclass);
372 intfields.ttl = htonl(ttl);
373 intfields.rdlen = htons(rdlen);
374 return buffer + sizeof(IntFields);
375 }
376
377 struct DNSHeader {
378 unsigned id;
379 bool ra;
380 uint8_t rcode;
381 bool qr;
382 uint8_t opcode;
383 bool aa;
384 bool tr;
385 bool rd;
386 bool ad;
387 std::vector<DNSQuestion> questions;
388 std::vector<DNSRecord> answers;
389 std::vector<DNSRecord> authorities;
390 std::vector<DNSRecord> additionals;
391 const char* read(const char* buffer, const char* buffer_end);
392 char* write(char* buffer, const char* buffer_end) const;
393 std::string toString() const;
394
395 private:
396 struct Header {
397 uint16_t id;
398 uint8_t flags0;
399 uint8_t flags1;
400 uint16_t qdcount;
401 uint16_t ancount;
402 uint16_t nscount;
403 uint16_t arcount;
404 } __attribute__((__packed__));
405
406 const char* readHeader(const char* buffer, const char* buffer_end,
407 unsigned* qdcount, unsigned* ancount,
408 unsigned* nscount, unsigned* arcount);
409 };
410
read(const char * buffer,const char * buffer_end)411 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
412 unsigned qdcount;
413 unsigned ancount;
414 unsigned nscount;
415 unsigned arcount;
416 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
417 &nscount, &arcount);
418 if (cur == nullptr) {
419 ALOGI("parsing failed at line %d", __LINE__);
420 return nullptr;
421 }
422 if (qdcount) {
423 questions.resize(qdcount);
424 for (unsigned i = 0 ; i < qdcount ; ++i) {
425 cur = questions[i].read(cur, buffer_end);
426 if (cur == nullptr) {
427 ALOGI("parsing failed at line %d", __LINE__);
428 return nullptr;
429 }
430 }
431 }
432 if (ancount) {
433 answers.resize(ancount);
434 for (unsigned i = 0 ; i < ancount ; ++i) {
435 cur = answers[i].read(cur, buffer_end);
436 if (cur == nullptr) {
437 ALOGI("parsing failed at line %d", __LINE__);
438 return nullptr;
439 }
440 }
441 }
442 if (nscount) {
443 authorities.resize(nscount);
444 for (unsigned i = 0 ; i < nscount ; ++i) {
445 cur = authorities[i].read(cur, buffer_end);
446 if (cur == nullptr) {
447 ALOGI("parsing failed at line %d", __LINE__);
448 return nullptr;
449 }
450 }
451 }
452 if (arcount) {
453 additionals.resize(arcount);
454 for (unsigned i = 0 ; i < arcount ; ++i) {
455 cur = additionals[i].read(cur, buffer_end);
456 if (cur == nullptr) {
457 ALOGI("parsing failed at line %d", __LINE__);
458 return nullptr;
459 }
460 }
461 }
462 return cur;
463 }
464
write(char * buffer,const char * buffer_end) const465 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
466 if (buffer + sizeof(Header) > buffer_end) {
467 ALOGI("buffer overflow on line %d", __LINE__);
468 return nullptr;
469 }
470 Header& header = *reinterpret_cast<Header*>(buffer);
471 // bytes 0-1
472 header.id = htons(id);
473 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
474 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
475 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
476 // Fake behavior: if the query set the "ad" bit, set it in the response too.
477 // In a real server, this should be set only if the data is authentic and the
478 // query contained an "ad" bit or DNSSEC extensions.
479 header.flags1 = (ad << 5) | rcode;
480 // rest of header
481 header.qdcount = htons(questions.size());
482 header.ancount = htons(answers.size());
483 header.nscount = htons(authorities.size());
484 header.arcount = htons(additionals.size());
485 char* buffer_cur = buffer + sizeof(Header);
486 for (const DNSQuestion& question : questions) {
487 buffer_cur = question.write(buffer_cur, buffer_end);
488 if (buffer_cur == nullptr) return nullptr;
489 }
490 for (const DNSRecord& answer : answers) {
491 buffer_cur = answer.write(buffer_cur, buffer_end);
492 if (buffer_cur == nullptr) return nullptr;
493 }
494 for (const DNSRecord& authority : authorities) {
495 buffer_cur = authority.write(buffer_cur, buffer_end);
496 if (buffer_cur == nullptr) return nullptr;
497 }
498 for (const DNSRecord& additional : additionals) {
499 buffer_cur = additional.write(buffer_cur, buffer_end);
500 if (buffer_cur == nullptr) return nullptr;
501 }
502 return buffer_cur;
503 }
504
toString() const505 std::string DNSHeader::toString() const {
506 // TODO
507 return std::string();
508 }
509
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)510 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
511 unsigned* qdcount, unsigned* ancount,
512 unsigned* nscount, unsigned* arcount) {
513 if (buffer + sizeof(Header) > buffer_end)
514 return nullptr;
515 const auto& header = *reinterpret_cast<const Header*>(buffer);
516 // bytes 0-1
517 id = ntohs(header.id);
518 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
519 qr = header.flags0 >> 7;
520 opcode = (header.flags0 >> 3) & 0x0F;
521 aa = (header.flags0 >> 2) & 1;
522 tr = (header.flags0 >> 1) & 1;
523 rd = header.flags0 & 1;
524 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
525 ra = header.flags1 >> 7;
526 ad = (header.flags1 >> 5) & 1;
527 rcode = header.flags1 & 0xF;
528 // rest of header
529 *qdcount = ntohs(header.qdcount);
530 *ancount = ntohs(header.ancount);
531 *nscount = ntohs(header.nscount);
532 *arcount = ntohs(header.arcount);
533 return buffer + sizeof(Header);
534 }
535
536 /* DNS responder */
537
DNSResponder(std::string listen_address,std::string listen_service,int poll_timeout_ms,ns_rcode error_rcode)538 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
539 int poll_timeout_ms, ns_rcode error_rcode)
540 : listen_address_(std::move(listen_address)),
541 listen_service_(std::move(listen_service)),
542 poll_timeout_ms_(poll_timeout_ms),
543 error_rcode_(error_rcode) {}
544
~DNSResponder()545 DNSResponder::~DNSResponder() {
546 stopServer();
547 }
548
addMapping(const std::string & name,ns_type type,const std::string & addr)549 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
550 std::lock_guard lock(mappings_mutex_);
551 auto it = mappings_.find(QueryKey(name, type));
552 if (it != mappings_.end()) {
553 ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
554 "address %s",
555 name.c_str(), dnstype2str(type), it->second.c_str(), addr.c_str());
556 it->second = addr;
557 return;
558 }
559 mappings_.try_emplace({name, type}, addr);
560 }
561
removeMapping(const std::string & name,ns_type type)562 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
563 std::lock_guard lock(mappings_mutex_);
564 auto it = mappings_.find(QueryKey(name, type));
565 if (it != mappings_.end()) {
566 ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name.c_str(),
567 dnstype2str(type));
568 return;
569 }
570 mappings_.erase(it);
571 }
572
setResponseProbability(double response_probability)573 void DNSResponder::setResponseProbability(double response_probability) {
574 response_probability_ = response_probability;
575 }
576
setEdns(Edns edns)577 void DNSResponder::setEdns(Edns edns) {
578 edns_ = edns;
579 }
580
running() const581 bool DNSResponder::running() const {
582 return socket_.get() != -1;
583 }
584
// Binds a non-blocking UDP socket to listen_address_:listen_service_, creates the
// eventfd used to request shutdown and an epoll instance watching both, then
// starts the request handler thread.
// Returns false on any failure, in which case the responder is left not running.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }

    // Undo partial setup so running() (which checks socket_) reports false and
    // stopServer() does not later act on a server whose thread never started.
    const auto fail = [this]() {
        epoll_fd_.reset();
        event_fd_.reset();
        socket_.reset();
        return false;
    };

    // Set up UDP socket.
    // Designators follow addrinfo's member declaration order, as C++20 requires.
    addrinfo ai_hints{
        .ai_flags = AI_PASSIVE,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_DGRAM,
    };
    addrinfo* ai_res = nullptr;
    const int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                               &ai_hints, &ai_res);
    ScopedAddrinfo ai_res_cleanup(ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    bool bound = false;
    for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
        socket_.reset(socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK, ai->ai_protocol));
        if (socket_.get() < 0) {
            APLOGI("ignore creating socket %d failed", socket_.get());
            continue;
        }
        enableSockopt(socket_.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(socket_.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(socket_.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind UDP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        bound = true;
        break;
    }
    // Without this check a socket whose bind() failed (the last candidate) would
    // be kept and the server reported as started.
    if (!bound) {
        ALOGI("failed to bind UDP %s:%s to any address", listen_address_.c_str(),
              listen_service_.c_str());
        return fail();
    }

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        APLOGI("failed to create eventfd %d", event_fd_.get());
        return fail();
    }

    // Set up epoll socket.
    epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
    if (epoll_fd_.get() < 0) {
        APLOGI("epoll_create1() failed on fd %d", epoll_fd_.get());
        return fail();
    }

    ALOGI("adding socket %d to epoll", socket_.get());
    if (!addFd(socket_.get(), EPOLLIN)) {
        ALOGE("failed to add the socket %d to epoll", socket_.get());
        return fail();
    }
    ALOGI("adding eventfd %d to epoll", event_fd_.get());
    if (!addFd(event_fd_.get(), EPOLLIN)) {
        ALOGE("failed to add the eventfd %d to epoll", event_fd_.get());
        return fail();
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
655
stopServer()656 bool DNSResponder::stopServer() {
657 std::lock_guard lock(update_mutex_);
658 if (!running()) {
659 ALOGI("server not running");
660 return false;
661 }
662 ALOGI("stopping server");
663 if (!sendToEventFd()) {
664 return false;
665 }
666 handler_thread_.join();
667 epoll_fd_.reset();
668 socket_.reset();
669 ALOGI("server stopped successfully");
670 return true;
671 }
672
queries() const673 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
674 std::lock_guard lock(queries_mutex_);
675 return queries_;
676 }
677
dumpQueries() const678 std::string DNSResponder::dumpQueries() const {
679 std::lock_guard lock(queries_mutex_);
680 std::string out;
681 for (const auto& q : queries_) {
682 out += "{\"" + q.first + "\", " + std::to_string(q.second) + "} ";
683 }
684 return out;
685 }
686
clearQueries()687 void DNSResponder::clearQueries() {
688 std::lock_guard lock(queries_mutex_);
689 queries_.clear();
690 }
691
requestHandler()692 void DNSResponder::requestHandler() {
693 epoll_event evs[EPOLL_MAX_EVENTS];
694 while (true) {
695 int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, poll_timeout_ms_);
696 if (n == 0) continue;
697 if (n < 0) {
698 APLOGI("epoll_wait() failed, n=%d", n);
699 return;
700 }
701
702 for (int i = 0; i < n; i++) {
703 const int fd = evs[i].data.fd;
704 const uint32_t events = evs[i].events;
705 if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
706 handleEventFd();
707 return;
708 } else if (fd == socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
709 handleQuery();
710 } else {
711 ALOGW("unexpected epoll events 0x%x on fd %d", events, fd);
712 }
713 }
714 }
715 }
716
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const717 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
718 char* response, size_t* response_len)
719 const {
720 DBGLOG("request: '%s'", str2hex(buffer, len).c_str());
721 const char* buffer_end = buffer + len;
722 DNSHeader header;
723 const char* cur = header.read(buffer, buffer_end);
724 // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
725 if (cur == nullptr) {
726 ALOGI("failed to parse query");
727 return false;
728 }
729 if (header.qr) {
730 ALOGI("response received instead of a query");
731 return false;
732 }
733 if (header.opcode != ns_opcode::ns_o_query) {
734 ALOGI("unsupported request opcode received");
735 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
736 response_len);
737 }
738 if (header.questions.empty()) {
739 ALOGI("no questions present");
740 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
741 response_len);
742 }
743 if (!header.answers.empty()) {
744 ALOGI("already %zu answers present in query", header.answers.size());
745 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
746 response_len);
747 }
748
749 if (edns_ == Edns::FORMERR_UNCOND) {
750 ALOGI("force to return RCODE FORMERR");
751 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
752 }
753
754 if (!header.additionals.empty() && edns_ != Edns::ON) {
755 ALOGI("DNS request has an additional section (assumed EDNS). "
756 "Simulating an ancient (pre-EDNS) server, and returning %s",
757 edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
758 if (edns_ == Edns::FORMERR_ON_EDNS) {
759 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
760 }
761 // No response.
762 return false;
763 }
764 {
765 std::lock_guard lock(queries_mutex_);
766 for (const DNSQuestion& question : header.questions) {
767 queries_.push_back(make_pair(question.qname.name,
768 ns_type(question.qtype)));
769 }
770 }
771
772 // Ignore requests with the preset probability.
773 auto constexpr bound = std::numeric_limits<unsigned>::max();
774 if (arc4random_uniform(bound) > bound * response_probability_) {
775 if (error_rcode_ < 0) {
776 ALOGI("Returning no response");
777 return false;
778 } else {
779 ALOGI("returning RCODE %d in accordance with probability distribution",
780 static_cast<int>(error_rcode_));
781 return makeErrorResponse(&header, error_rcode_, response, response_len);
782 }
783 }
784
785 for (const DNSQuestion& question : header.questions) {
786 if (question.qclass != ns_class::ns_c_in &&
787 question.qclass != ns_class::ns_c_any) {
788 ALOGI("unsupported question class %u", question.qclass);
789 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
790 response_len);
791 }
792
793 if (!addAnswerRecords(question, &header.answers)) {
794 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, response_len);
795 }
796 }
797
798 header.qr = true;
799 char* response_cur = header.write(response, response + *response_len);
800 if (response_cur == nullptr) {
801 return false;
802 }
803 *response_len = response_cur - response;
804 return true;
805 }
806
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const807 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
808 std::vector<DNSRecord>* answers) const {
809 std::lock_guard guard(mappings_mutex_);
810 std::string rname = question.qname.name;
811 std::vector<int> rtypes;
812
813 if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa)
814 rtypes.push_back(ns_type::ns_t_cname);
815 rtypes.push_back(question.qtype);
816 for (int rtype : rtypes) {
817 std::set<std::string> cnames_Loop;
818 std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
819 while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
820 if (rtype == ns_type::ns_t_cname) {
821 // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
822 // As following, the query will stop on loop3 by detecting the same cname.
823 // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
824 // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
825 // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
826 // is found in cnames_Loop already, break the query loop.)
827 if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
828 cnames_Loop.insert(it->first.name);
829 }
830 DNSRecord record{
831 .name = {.name = it->first.name},
832 .rtype = it->first.type,
833 .rclass = ns_class::ns_c_in,
834 .ttl = 5, // seconds
835 };
836 fillAnswerRdata(it->second, record);
837 answers->push_back(std::move(record));
838 if (rtype != ns_type::ns_t_cname) break;
839 rname = it->second;
840 }
841 }
842
843 if (answers->size() == 0) {
844 // TODO(imaipi): handle correctly
845 ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
846 question.qname.name.c_str(), dnstype2str(question.qtype));
847 }
848
849 return true;
850 }
851
fillAnswerRdata(const std::string & rdatastr,DNSRecord & record) const852 bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const {
853 if (record.rtype == ns_type::ns_t_a) {
854 record.rdata.resize(4);
855 if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
856 ALOGI("inet_pton(AF_INET, %s) failed", rdatastr.c_str());
857 return false;
858 }
859 } else if (record.rtype == ns_type::ns_t_aaaa) {
860 record.rdata.resize(16);
861 if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
862 ALOGI("inet_pton(AF_INET6, %s) failed", rdatastr.c_str());
863 return false;
864 }
865 } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname)) {
866 constexpr char delimiter = '.';
867 std::string name = rdatastr;
868 std::vector<char> rdata;
869
870 // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
871 // The "name" should be an absolute domain name which ends in a dot.
872 if (name.back() != delimiter) {
873 ALOGI("invalid absolute domain name");
874 return false;
875 }
876 name.pop_back(); // remove the dot in tail
877 for (const std::string& label : android::base::Split(name, {delimiter})) {
878 // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
879 if (label.length() == 0 || label.length() > 63) {
880 ALOGI("invalid label length");
881 return false;
882 }
883
884 rdata.push_back(label.length());
885 rdata.insert(rdata.end(), label.begin(), label.end());
886 }
887 rdata.push_back(0); // Length byte of zero terminates the label list
888
889 // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
890 if (rdata.size() > 255) {
891 ALOGI("invalid name length");
892 return false;
893 }
894 record.rdata = move(rdata);
895 } else {
896 ALOGI("unhandled qtype %s", dnstype2str(record.rtype));
897 return false;
898 }
899 return true;
900 }
901
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const902 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
903 char* response, size_t* response_len)
904 const {
905 header->answers.clear();
906 header->authorities.clear();
907 header->additionals.clear();
908 header->rcode = rcode;
909 header->qr = true;
910 char* response_cur = header->write(response, response + *response_len);
911 if (response_cur == nullptr) return false;
912 *response_len = response_cur - response;
913 return true;
914 }
915
setDeferredResp(bool deferred_resp)916 void DNSResponder::setDeferredResp(bool deferred_resp) {
917 std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
918 deferred_resp_ = deferred_resp;
919 if (!deferred_resp_) {
920 cv_for_deferred_resp_.notify_one();
921 }
922 }
923
addFd(int fd,uint32_t events)924 bool DNSResponder::addFd(int fd, uint32_t events) {
925 epoll_event ev;
926 ev.events = events;
927 ev.data.fd = fd;
928 if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
929 APLOGI("epoll_ctl() for socket %d failed", fd);
930 return false;
931 }
932 return true;
933 }
934
handleQuery()935 void DNSResponder::handleQuery() {
936 char buffer[4096];
937 sockaddr_storage sa;
938 socklen_t sa_len = sizeof(sa);
939 ssize_t len;
940 do {
941 len = recvfrom(socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa, &sa_len);
942 } while (len < 0 && (errno == EAGAIN || errno == EINTR));
943 if (len <= 0) {
944 APLOGI("recvfrom() failed, len=%zu", len);
945 return;
946 }
947 DBGLOG("read %zd bytes", len);
948 std::lock_guard lock(cv_mutex_);
949 char response[4096];
950 size_t response_len = sizeof(response);
951 if (handleDNSRequest(buffer, len, response, &response_len) && response_len > 0) {
952 // place wait_for after handleDNSRequest() so we can check the number of queries in
953 // test case before it got responded.
954 std::unique_lock guard(cv_mutex_for_deferred_resp_);
955 cv_for_deferred_resp_.wait(
956 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
957
958 len = sendto(socket_.get(), response, response_len, 0,
959 reinterpret_cast<const sockaddr*>(&sa), sa_len);
960 std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
961 if (len > 0) {
962 DBGLOG("sent %zu bytes to %s", len, host_str.c_str());
963 } else {
964 APLOGI("sendto() failed for %s", host_str.c_str());
965 }
966 // Test that the response is actually a correct DNS message.
967 const char* response_end = response + len;
968 DNSHeader header;
969 const char* cur = header.read(response, response_end);
970 if (cur == nullptr) ALOGW("response is flawed");
971 } else {
972 ALOGW("not responding");
973 }
974 cv.notify_one();
975 return;
976 }
977
sendToEventFd()978 bool DNSResponder::sendToEventFd() {
979 const uint64_t data = 1;
980 if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
981 APLOGI("failed to write eventfd, rt=%zd", rt);
982 return false;
983 }
984 return true;
985 }
986
handleEventFd()987 void DNSResponder::handleEventFd() {
988 int64_t data;
989 if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
990 APLOGI("ignore reading eventfd failed, rt=%zd", rt);
991 }
992 }
993
994 } // namespace test
995