1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30
31 #include <iostream>
32 #include <vector>
33
34 #define LOG_TAG "DNSResponder"
35 #include <log/log.h>
36 #include <netdutils/SocketOption.h>
37
38 using android::netdutils::enableSockopt;
39
40 namespace test {
41
// strerror_r() exists in two incompatible flavours: the XSI one returns an int
// status (0 on success) and fills the caller's buffer, while the GNU one
// returns a pointer to the message, which is not necessarily the caller's
// buffer. These two overloads let overload resolution pick whichever matches
// the prototype that is in effect for this translation unit.
static inline const char* strerrorMessage(int status, const char* buf) {
    return status == 0 ? buf : nullptr;
}

static inline const char* strerrorMessage(const char* msg, const char* /*buf*/) {
    return msg;
}

// Returns the textual description of the current errno value, or an empty
// string if no description could be obtained.
std::string errno2str() {
    char error_msg[512] = { 0 };
    const char* msg = strerrorMessage(
            strerror_r(errno, error_msg, sizeof(error_msg)), error_msg);
    if (msg == nullptr) return std::string();
    return std::string(msg);
}
48
// Like ALOGI, but appends ": [<errno>] <errno description>" to the message.
// The expansion places a comma after __VA_ARGS__, so at least one argument
// must follow the format string.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
50
// Renders the first `len` bytes of `buffer` as upper-case hexadecimal, two
// digits per byte and no separators.
std::string str2hex(const char* buffer, size_t len) {
    static const char kDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        // Go through uint8_t so that bytes >= 0x80 index the table correctly.
        const uint8_t byte = static_cast<uint8_t>(*p);
        out.push_back(kDigits[byte >> 4]);
        out.push_back(kDigits[byte & 0x0F]);
    }
    return out;
}
61
// Formats a socket address as a numeric host string (NI_NUMERICHOST, so no
// name lookup is requested). Returns an empty string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = { 0 };
    const int status = getnameinfo(sa, sa_len, numeric_host,
                                   sizeof(numeric_host), nullptr, 0,
                                   NI_NUMERICHOST);
    if (status != 0) {
        return std::string();
    }
    return std::string(numeric_host);
}
69
70 /* DNS struct helpers */
71
dnstype2str(unsigned dnstype)72 const char* dnstype2str(unsigned dnstype) {
73 static std::unordered_map<unsigned, const char*> kTypeStrs = {
74 { ns_type::ns_t_a, "A" },
75 { ns_type::ns_t_ns, "NS" },
76 { ns_type::ns_t_md, "MD" },
77 { ns_type::ns_t_mf, "MF" },
78 { ns_type::ns_t_cname, "CNAME" },
79 { ns_type::ns_t_soa, "SOA" },
80 { ns_type::ns_t_mb, "MB" },
81 { ns_type::ns_t_mb, "MG" },
82 { ns_type::ns_t_mr, "MR" },
83 { ns_type::ns_t_null, "NULL" },
84 { ns_type::ns_t_wks, "WKS" },
85 { ns_type::ns_t_ptr, "PTR" },
86 { ns_type::ns_t_hinfo, "HINFO" },
87 { ns_type::ns_t_minfo, "MINFO" },
88 { ns_type::ns_t_mx, "MX" },
89 { ns_type::ns_t_txt, "TXT" },
90 { ns_type::ns_t_rp, "RP" },
91 { ns_type::ns_t_afsdb, "AFSDB" },
92 { ns_type::ns_t_x25, "X25" },
93 { ns_type::ns_t_isdn, "ISDN" },
94 { ns_type::ns_t_rt, "RT" },
95 { ns_type::ns_t_nsap, "NSAP" },
96 { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
97 { ns_type::ns_t_sig, "SIG" },
98 { ns_type::ns_t_key, "KEY" },
99 { ns_type::ns_t_px, "PX" },
100 { ns_type::ns_t_gpos, "GPOS" },
101 { ns_type::ns_t_aaaa, "AAAA" },
102 { ns_type::ns_t_loc, "LOC" },
103 { ns_type::ns_t_nxt, "NXT" },
104 { ns_type::ns_t_eid, "EID" },
105 { ns_type::ns_t_nimloc, "NIMLOC" },
106 { ns_type::ns_t_srv, "SRV" },
107 { ns_type::ns_t_naptr, "NAPTR" },
108 { ns_type::ns_t_kx, "KX" },
109 { ns_type::ns_t_cert, "CERT" },
110 { ns_type::ns_t_a6, "A6" },
111 { ns_type::ns_t_dname, "DNAME" },
112 { ns_type::ns_t_sink, "SINK" },
113 { ns_type::ns_t_opt, "OPT" },
114 { ns_type::ns_t_apl, "APL" },
115 { ns_type::ns_t_tkey, "TKEY" },
116 { ns_type::ns_t_tsig, "TSIG" },
117 { ns_type::ns_t_ixfr, "IXFR" },
118 { ns_type::ns_t_axfr, "AXFR" },
119 { ns_type::ns_t_mailb, "MAILB" },
120 { ns_type::ns_t_maila, "MAILA" },
121 { ns_type::ns_t_any, "ANY" },
122 { ns_type::ns_t_zxfr, "ZXFR" },
123 };
124 auto it = kTypeStrs.find(dnstype);
125 static const char* kUnknownStr{ "UNKNOWN" };
126 if (it == kTypeStrs.end()) return kUnknownStr;
127 return it->second;
128 }
129
dnsclass2str(unsigned dnsclass)130 const char* dnsclass2str(unsigned dnsclass) {
131 static std::unordered_map<unsigned, const char*> kClassStrs = {
132 { ns_class::ns_c_in , "Internet" },
133 { 2, "CSNet" },
134 { ns_class::ns_c_chaos, "ChaosNet" },
135 { ns_class::ns_c_hs, "Hesiod" },
136 { ns_class::ns_c_none, "none" },
137 { ns_class::ns_c_any, "any" }
138 };
139 auto it = kClassStrs.find(dnsclass);
140 static const char* kUnknownStr{ "UNKNOWN" };
141 if (it == kClassStrs.end()) return kUnknownStr;
142 return it->second;
143 return "unknown";
144 }
145
// A domain name kept in dotted text form, with conversion to and from the
// length-prefixed label encoding used inside DNS messages.
struct DNSName {
    // Dotted text; parseField() appends a '.' after every label it decodes.
    std::string name;
    // Decodes labels from [buffer, buffer_end) and appends them to `name`.
    // Returns the position just past the zero-length end marker, or nullptr
    // on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Encodes `name` into [buffer, buffer_end). Returns the position just past
    // the bytes written, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns `name` as a C string.
    const char* toString() const;
private:
    // Decodes a single field; sets *last when the zero-length end marker is
    // encountered. Returns nullptr on failure.
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};
155
// Returns the dotted name as a C string. The pointer is obtained from
// std::string::c_str() on `name`, so it stays valid only while `name` is
// neither modified nor destroyed.
const char* DNSName::toString() const {
    return name.c_str();
}
159
read(const char * buffer,const char * buffer_end)160 const char* DNSName::read(const char* buffer, const char* buffer_end) {
161 const char* cur = buffer;
162 bool last = false;
163 do {
164 cur = parseField(cur, buffer_end, &last);
165 if (cur == nullptr) {
166 ALOGI("parsing failed at line %d", __LINE__);
167 return nullptr;
168 }
169 } while (!last);
170 return cur;
171 }
172
write(char * buffer,const char * buffer_end) const173 char* DNSName::write(char* buffer, const char* buffer_end) const {
174 char* buffer_cur = buffer;
175 for (size_t pos = 0 ; pos < name.size() ; ) {
176 size_t dot_pos = name.find('.', pos);
177 if (dot_pos == std::string::npos) {
178 // Sanity check, should never happen unless parseField is broken.
179 ALOGI("logic error: all names are expected to end with a '.'");
180 return nullptr;
181 }
182 size_t len = dot_pos - pos;
183 if (len >= 256) {
184 ALOGI("name component '%s' is %zu long, but max is 255",
185 name.substr(pos, dot_pos - pos).c_str(), len);
186 return nullptr;
187 }
188 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
189 ALOGI("buffer overflow at line %d", __LINE__);
190 return nullptr;
191 }
192 *buffer_cur++ = len;
193 buffer_cur = std::copy(std::next(name.begin(), pos),
194 std::next(name.begin(), dot_pos),
195 buffer_cur);
196 pos = dot_pos + 1;
197 }
198 // Write final zero.
199 *buffer_cur++ = 0;
200 return buffer_cur;
201 }
202
parseField(const char * buffer,const char * buffer_end,bool * last)203 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
204 bool* last) {
205 if (buffer + sizeof(uint8_t) > buffer_end) {
206 ALOGI("parsing failed at line %d", __LINE__);
207 return nullptr;
208 }
209 unsigned field_type = *buffer >> 6;
210 unsigned ofs = *buffer & 0x3F;
211 const char* cur = buffer + sizeof(uint8_t);
212 if (field_type == 0) {
213 // length + name component
214 if (ofs == 0) {
215 *last = true;
216 return cur;
217 }
218 if (cur + ofs > buffer_end) {
219 ALOGI("parsing failed at line %d", __LINE__);
220 return nullptr;
221 }
222 name.append(cur, ofs);
223 name.push_back('.');
224 return cur + ofs;
225 } else if (field_type == 3) {
226 ALOGI("name compression not implemented");
227 return nullptr;
228 }
229 ALOGI("invalid name field type");
230 return nullptr;
231 }
232
// One entry of the question section of a DNS message.
struct DNSQuestion {
    DNSName qname;    // name being queried
    unsigned qtype;   // record type code, host byte order (see dnstype2str)
    unsigned qclass;  // class code, host byte order (see dnsclass2str)
    // Parses a question from [buffer, buffer_end). Returns the position just
    // past it, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serialises the question into [buffer, buffer_end). Returns the position
    // just past the bytes written, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns a "Q<name,type,class>" description for logging.
    std::string toString() const;
};
241
read(const char * buffer,const char * buffer_end)242 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
243 const char* cur = qname.read(buffer, buffer_end);
244 if (cur == nullptr) {
245 ALOGI("parsing failed at line %d", __LINE__);
246 return nullptr;
247 }
248 if (cur + 2*sizeof(uint16_t) > buffer_end) {
249 ALOGI("parsing failed at line %d", __LINE__);
250 return nullptr;
251 }
252 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
253 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
254 return cur + 2*sizeof(uint16_t);
255 }
256
write(char * buffer,const char * buffer_end) const257 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
258 char* buffer_cur = qname.write(buffer, buffer_end);
259 if (buffer_cur == nullptr) return nullptr;
260 if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
261 ALOGI("buffer overflow on line %d", __LINE__);
262 return nullptr;
263 }
264 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
265 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
266 htons(qclass);
267 return buffer_cur + 2*sizeof(uint16_t);
268 }
269
toString() const270 std::string DNSQuestion::toString() const {
271 char buffer[4096];
272 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
273 dnstype2str(qtype), dnsclass2str(qclass));
274 return std::string(buffer, len);
275 }
276
// One resource record, as found in the answer, authority and additional
// sections of a DNS message.
struct DNSRecord {
    DNSName name;             // owner name
    unsigned rtype;           // record type code, host byte order
    unsigned rclass;          // class code, host byte order
    unsigned ttl;             // time to live, host byte order
    std::vector<char> rdata;  // raw record payload
    // Parses a record from [buffer, buffer_end). Returns the position just
    // past it, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serialises the record into [buffer, buffer_end). Returns the position
    // just past the bytes written, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns an "R<name,type,class>" description for logging.
    std::string toString() const;
private:
    // Fixed-size fields that follow the name on the wire. Packed so that the
    // struct can be overlaid directly on the message bytes.
    struct IntFields {
        uint16_t rtype;
        uint16_t rclass;
        uint32_t ttl;
        uint16_t rdlen;
    } __attribute__((__packed__));

    // Decodes the fixed-size fields into rtype/rclass/ttl and *rdlen.
    const char* readIntFields(const char* buffer, const char* buffer_end,
                              unsigned* rdlen);
    // Encodes rtype/rclass/ttl and the given rdlen as the fixed-size fields.
    char* writeIntFields(unsigned rdlen, char* buffer,
                         const char* buffer_end) const;
};
299
read(const char * buffer,const char * buffer_end)300 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
301 const char* cur = name.read(buffer, buffer_end);
302 if (cur == nullptr) {
303 ALOGI("parsing failed at line %d", __LINE__);
304 return nullptr;
305 }
306 unsigned rdlen = 0;
307 cur = readIntFields(cur, buffer_end, &rdlen);
308 if (cur == nullptr) {
309 ALOGI("parsing failed at line %d", __LINE__);
310 return nullptr;
311 }
312 if (cur + rdlen > buffer_end) {
313 ALOGI("parsing failed at line %d", __LINE__);
314 return nullptr;
315 }
316 rdata.assign(cur, cur + rdlen);
317 return cur + rdlen;
318 }
319
write(char * buffer,const char * buffer_end) const320 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
321 char* buffer_cur = name.write(buffer, buffer_end);
322 if (buffer_cur == nullptr) return nullptr;
323 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
324 if (buffer_cur == nullptr) return nullptr;
325 if (buffer_cur + rdata.size() > buffer_end) {
326 ALOGI("buffer overflow on line %d", __LINE__);
327 return nullptr;
328 }
329 return std::copy(rdata.begin(), rdata.end(), buffer_cur);
330 }
331
toString() const332 std::string DNSRecord::toString() const {
333 char buffer[4096];
334 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
335 dnstype2str(rtype), dnsclass2str(rclass));
336 return std::string(buffer, len);
337 }
338
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)339 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
340 unsigned* rdlen) {
341 if (buffer + sizeof(IntFields) > buffer_end ) {
342 ALOGI("parsing failed at line %d", __LINE__);
343 return nullptr;
344 }
345 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
346 rtype = ntohs(intfields.rtype);
347 rclass = ntohs(intfields.rclass);
348 ttl = ntohl(intfields.ttl);
349 *rdlen = ntohs(intfields.rdlen);
350 return buffer + sizeof(IntFields);
351 }
352
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const353 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
354 const char* buffer_end) const {
355 if (buffer + sizeof(IntFields) > buffer_end ) {
356 ALOGI("buffer overflow on line %d", __LINE__);
357 return nullptr;
358 }
359 auto& intfields = *reinterpret_cast<IntFields*>(buffer);
360 intfields.rtype = htons(rtype);
361 intfields.rclass = htons(rclass);
362 intfields.ttl = htonl(ttl);
363 intfields.rdlen = htons(rdlen);
364 return buffer + sizeof(IntFields);
365 }
366
// In-memory form of a complete DNS message: the decoded header fields plus
// the four record sections.
struct DNSHeader {
    unsigned id;     // header bytes 0-1, host byte order
    bool ra;         // byte 3, bit 7; decoded by readHeader(), not emitted by write()
    uint8_t rcode;   // byte 3, bits 0-3
    bool qr;         // byte 2, bit 7
    uint8_t opcode;  // byte 2, bits 3-6
    bool aa;         // byte 2, bit 2
    bool tr;         // byte 2, bit 1
    bool rd;         // byte 2, bit 0
    bool ad;         // byte 3, bit 5
    std::vector<DNSQuestion> questions;
    std::vector<DNSRecord> answers;
    std::vector<DNSRecord> authorities;
    std::vector<DNSRecord> additionals;
    // Parses a whole message from [buffer, buffer_end). Returns the position
    // just past it, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serialises the whole message into [buffer, buffer_end). Returns the
    // position just past the bytes written, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Currently a stub that returns an empty string.
    std::string toString() const;

private:
    // Wire layout of the 12-byte fixed header. Packed so that the struct can
    // be overlaid directly on the message bytes.
    struct Header {
        uint16_t id;
        uint8_t flags0;
        uint8_t flags1;
        uint16_t qdcount;
        uint16_t ancount;
        uint16_t nscount;
        uint16_t arcount;
    } __attribute__((__packed__));

    // Decodes the fixed header into the members above and returns the four
    // section counts through the out-parameters.
    const char* readHeader(const char* buffer, const char* buffer_end,
                           unsigned* qdcount, unsigned* ancount,
                           unsigned* nscount, unsigned* arcount);
};
400
read(const char * buffer,const char * buffer_end)401 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
402 unsigned qdcount;
403 unsigned ancount;
404 unsigned nscount;
405 unsigned arcount;
406 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
407 &nscount, &arcount);
408 if (cur == nullptr) {
409 ALOGI("parsing failed at line %d", __LINE__);
410 return nullptr;
411 }
412 if (qdcount) {
413 questions.resize(qdcount);
414 for (unsigned i = 0 ; i < qdcount ; ++i) {
415 cur = questions[i].read(cur, buffer_end);
416 if (cur == nullptr) {
417 ALOGI("parsing failed at line %d", __LINE__);
418 return nullptr;
419 }
420 }
421 }
422 if (ancount) {
423 answers.resize(ancount);
424 for (unsigned i = 0 ; i < ancount ; ++i) {
425 cur = answers[i].read(cur, buffer_end);
426 if (cur == nullptr) {
427 ALOGI("parsing failed at line %d", __LINE__);
428 return nullptr;
429 }
430 }
431 }
432 if (nscount) {
433 authorities.resize(nscount);
434 for (unsigned i = 0 ; i < nscount ; ++i) {
435 cur = authorities[i].read(cur, buffer_end);
436 if (cur == nullptr) {
437 ALOGI("parsing failed at line %d", __LINE__);
438 return nullptr;
439 }
440 }
441 }
442 if (arcount) {
443 additionals.resize(arcount);
444 for (unsigned i = 0 ; i < arcount ; ++i) {
445 cur = additionals[i].read(cur, buffer_end);
446 if (cur == nullptr) {
447 ALOGI("parsing failed at line %d", __LINE__);
448 return nullptr;
449 }
450 }
451 }
452 return cur;
453 }
454
write(char * buffer,const char * buffer_end) const455 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
456 if (buffer + sizeof(Header) > buffer_end) {
457 ALOGI("buffer overflow on line %d", __LINE__);
458 return nullptr;
459 }
460 Header& header = *reinterpret_cast<Header*>(buffer);
461 // bytes 0-1
462 header.id = htons(id);
463 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
464 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
465 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
466 // Fake behavior: if the query set the "ad" bit, set it in the response too.
467 // In a real server, this should be set only if the data is authentic and the
468 // query contained an "ad" bit or DNSSEC extensions.
469 header.flags1 = (ad << 5) | rcode;
470 // rest of header
471 header.qdcount = htons(questions.size());
472 header.ancount = htons(answers.size());
473 header.nscount = htons(authorities.size());
474 header.arcount = htons(additionals.size());
475 char* buffer_cur = buffer + sizeof(Header);
476 for (const DNSQuestion& question : questions) {
477 buffer_cur = question.write(buffer_cur, buffer_end);
478 if (buffer_cur == nullptr) return nullptr;
479 }
480 for (const DNSRecord& answer : answers) {
481 buffer_cur = answer.write(buffer_cur, buffer_end);
482 if (buffer_cur == nullptr) return nullptr;
483 }
484 for (const DNSRecord& authority : authorities) {
485 buffer_cur = authority.write(buffer_cur, buffer_end);
486 if (buffer_cur == nullptr) return nullptr;
487 }
488 for (const DNSRecord& additional : additionals) {
489 buffer_cur = additional.write(buffer_cur, buffer_end);
490 if (buffer_cur == nullptr) return nullptr;
491 }
492 return buffer_cur;
493 }
494
// Not implemented yet: always returns an empty string.
std::string DNSHeader::toString() const {
    // TODO
    return std::string();
}
499
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)500 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
501 unsigned* qdcount, unsigned* ancount,
502 unsigned* nscount, unsigned* arcount) {
503 if (buffer + sizeof(Header) > buffer_end)
504 return 0;
505 const auto& header = *reinterpret_cast<const Header*>(buffer);
506 // bytes 0-1
507 id = ntohs(header.id);
508 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
509 qr = header.flags0 >> 7;
510 opcode = (header.flags0 >> 3) & 0x0F;
511 aa = (header.flags0 >> 2) & 1;
512 tr = (header.flags0 >> 1) & 1;
513 rd = header.flags0 & 1;
514 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
515 ra = header.flags1 >> 7;
516 ad = (header.flags1 >> 5) & 1;
517 rcode = header.flags1 & 0xF;
518 // rest of header
519 *qdcount = ntohs(header.qdcount);
520 *ancount = ntohs(header.ancount);
521 *nscount = ntohs(header.nscount);
522 *arcount = ntohs(header.arcount);
523 return buffer + sizeof(Header);
524 }
525
526 /* DNS responder */
527
DNSResponder(std::string listen_address,std::string listen_service,int poll_timeout_ms,uint16_t error_rcode,double response_probability)528 DNSResponder::DNSResponder(std::string listen_address,
529 std::string listen_service, int poll_timeout_ms,
530 uint16_t error_rcode, double response_probability) :
531 listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
532 poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
533 response_probability_(response_probability),
534 fail_on_edns_(false),
535 socket_(-1), epoll_fd_(-1), terminate_(false) { }
536
// Stops the server on destruction. The result of stopServer() is ignored; it
// returns false when the server is not running.
DNSResponder::~DNSResponder() {
    stopServer();
}
540
addMapping(const char * name,ns_type type,const char * addr)541 void DNSResponder::addMapping(const char* name, ns_type type,
542 const char* addr) {
543 std::lock_guard<std::mutex> lock(mappings_mutex_);
544 auto it = mappings_.find(QueryKey(name, type));
545 if (it != mappings_.end()) {
546 ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
547 "address %s", name, dnstype2str(type), it->second.c_str(),
548 addr);
549 it->second = addr;
550 return;
551 }
552 mappings_.emplace(std::piecewise_construct,
553 std::forward_as_tuple(name, type),
554 std::forward_as_tuple(addr));
555 }
556
removeMapping(const char * name,ns_type type)557 void DNSResponder::removeMapping(const char* name, ns_type type) {
558 std::lock_guard<std::mutex> lock(mappings_mutex_);
559 auto it = mappings_.find(QueryKey(name, type));
560 if (it != mappings_.end()) {
561 ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
562 dnstype2str(type));
563 return;
564 }
565 mappings_.erase(it);
566 }
567
// Replaces the probability that handleDNSRequest() compares against a random
// draw for every query. No lock is taken here.
// NOTE(review): presumably expected to lie in [0, 1] — confirm with callers.
void DNSResponder::setResponseProbability(double response_probability) {
    response_probability_ = response_probability;
}
571
// Reports whether the server is up, i.e. whether socket_ holds something
// other than the -1 "not open" sentinel.
bool DNSResponder::running() const {
    return socket_ != -1;
}
575
// Brings the server up:
//   1. resolves listen_address_/listen_service_ and binds a UDP socket to the
//      first candidate address that works,
//   2. switches the socket to non-blocking mode,
//   3. registers it with a fresh epoll instance,
//   4. starts the thread running requestHandler().
// Returns false, leaving no descriptor open, if the server is already running
// or if any step fails.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }
    // Zero-initialise and assign by name rather than using designated
    // initializers, which C++ only accepts in member declaration order.
    addrinfo ai_hints{};
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_DGRAM;
    ai_hints.ai_flags = AI_PASSIVE;
    addrinfo* ai_res;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    // Try every candidate address until one can be bound.
    int s = -1;
    for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        // Best effort: the results are deliberately not checked.
        enableSockopt(s, SOL_SOCKET, SO_REUSEPORT);
        enableSockopt(s, SOL_SOCKET, SO_REUSEADDR);
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        break;
    }
    freeaddrinfo(ai_res);
    if (s < 0) {
        ALOGI("bind() failed");
        return false;
    }

    int flags = fcntl(s, F_GETFL, 0);
    if (flags < 0) flags = 0;
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", s);
        close(s);
        return false;
    }

    int ep_fd = epoll_create(1);
    if (ep_fd < 0) {
        // Use the shared helper instead of calling strerror_r() directly, so
        // that its return-type handling lives in one place.
        const std::string error_msg = errno2str();
        APLOGI("epoll_create() failed: %s",
               error_msg.empty() ? "UNKNOWN" : error_msg.c_str());
        close(s);
        return false;
    }
    epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.fd = s;
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
        APLOGI("epoll_ctl() failed for socket %d", s);
        close(ep_fd);
        close(s);
        return false;
    }

    // Publish the descriptors before the handler thread starts using them.
    epoll_fd_ = ep_fd;
    socket_ = s;
    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
652
stopServer()653 bool DNSResponder::stopServer() {
654 std::lock_guard<std::mutex> lock(update_mutex_);
655 if (!running()) {
656 ALOGI("server not running");
657 return false;
658 }
659 if (terminate_) {
660 ALOGI("LOGIC ERROR");
661 return false;
662 }
663 ALOGI("stopping server");
664 terminate_ = true;
665 handler_thread_.join();
666 close(epoll_fd_);
667 close(socket_);
668 terminate_ = false;
669 socket_ = -1;
670 ALOGI("server stopped successfully");
671 return true;
672 }
673
// Returns a copy of the (name, type) pairs recorded so far by
// handleDNSRequest(). The copy is made while holding queries_mutex_.
std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
    std::lock_guard<std::mutex> lock(queries_mutex_);
    return queries_;
}
678
// Discards every recorded query. Holds queries_mutex_ while clearing.
void DNSResponder::clearQueries() {
    std::lock_guard<std::mutex> lock(queries_mutex_);
    queries_.clear();
}
683
requestHandler()684 void DNSResponder::requestHandler() {
685 epoll_event evs[1];
686 while (!terminate_) {
687 int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
688 if (n == 0) continue;
689 if (n < 0) {
690 ALOGI("epoll_wait() failed");
691 // TODO(imaipi): terminate on error.
692 return;
693 }
694 char buffer[4096];
695 sockaddr_storage sa;
696 socklen_t sa_len = sizeof(sa);
697 ssize_t len;
698 do {
699 len = recvfrom(socket_, buffer, sizeof(buffer), 0,
700 (sockaddr*) &sa, &sa_len);
701 } while (len < 0 && (errno == EAGAIN || errno == EINTR));
702 if (len <= 0) {
703 ALOGI("recvfrom() failed");
704 continue;
705 }
706 ALOGI("read %zd bytes", len);
707 char response[4096];
708 size_t response_len = sizeof(response);
709 if (handleDNSRequest(buffer, len, response, &response_len) &&
710 response_len > 0) {
711 len = sendto(socket_, response, response_len, 0,
712 reinterpret_cast<const sockaddr*>(&sa), sa_len);
713 std::string host_str =
714 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
715 if (len > 0) {
716 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
717 } else {
718 APLOGI("sendto() failed for %s", host_str.c_str());
719 }
720 // Test that the response is actually a correct DNS message.
721 const char* response_end = response + len;
722 DNSHeader header;
723 const char* cur = header.read(response, response_end);
724 if (cur == nullptr) ALOGI("response is flawed");
725
726 } else {
727 ALOGI("not responding");
728 }
729 }
730 }
731
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const732 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
733 char* response, size_t* response_len)
734 const {
735 ALOGI("request: '%s'", str2hex(buffer, len).c_str());
736 const char* buffer_end = buffer + len;
737 DNSHeader header;
738 const char* cur = header.read(buffer, buffer_end);
739 // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
740 if (cur == nullptr) {
741 ALOGI("failed to parse query");
742 return false;
743 }
744 if (header.qr) {
745 ALOGI("response received instead of a query");
746 return false;
747 }
748 if (header.opcode != ns_opcode::ns_o_query) {
749 ALOGI("unsupported request opcode received");
750 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
751 response_len);
752 }
753 if (header.questions.empty()) {
754 ALOGI("no questions present");
755 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
756 response_len);
757 }
758 if (!header.answers.empty()) {
759 ALOGI("already %zu answers present in query", header.answers.size());
760 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
761 response_len);
762 }
763 if (!header.additionals.empty() && fail_on_edns_) {
764 ALOGI("DNS request has an additional section (assumed EDNS). "
765 "Simulating an ancient (pre-EDNS) server.");
766 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
767 response_len);
768 }
769 {
770 std::lock_guard<std::mutex> lock(queries_mutex_);
771 for (const DNSQuestion& question : header.questions) {
772 queries_.push_back(make_pair(question.qname.name,
773 ns_type(question.qtype)));
774 }
775 }
776
777 // Ignore requests with the preset probability.
778 auto constexpr bound = std::numeric_limits<unsigned>::max();
779 if (arc4random_uniform(bound) > bound*response_probability_) {
780 ALOGI("returning SRVFAIL in accordance with probability distribution");
781 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
782 response_len);
783 }
784
785 for (const DNSQuestion& question : header.questions) {
786 if (question.qclass != ns_class::ns_c_in &&
787 question.qclass != ns_class::ns_c_any) {
788 ALOGI("unsupported question class %u", question.qclass);
789 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
790 response_len);
791 }
792 if (!addAnswerRecords(question, &header.answers)) {
793 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
794 response_len);
795 }
796 }
797 header.qr = true;
798 char* response_cur = header.write(response, response + *response_len);
799 if (response_cur == nullptr) {
800 return false;
801 }
802 *response_len = response_cur - response;
803 return true;
804 }
805
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const806 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
807 std::vector<DNSRecord>* answers) const {
808 std::lock_guard<std::mutex> guard(mappings_mutex_);
809 auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
810 if (it == mappings_.end()) {
811 // TODO(imaipi): handle correctly
812 ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
813 question.qname.name.c_str(), dnstype2str(question.qtype));
814 return true;
815 }
816 ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
817 dnstype2str(question.qtype), it->second.c_str());
818 DNSRecord record;
819 record.name = question.qname;
820 record.rtype = question.qtype;
821 record.rclass = ns_class::ns_c_in;
822 record.ttl = 5; // seconds
823 if (question.qtype == ns_type::ns_t_a) {
824 record.rdata.resize(4);
825 if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
826 ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
827 return false;
828 }
829 } else if (question.qtype == ns_type::ns_t_aaaa) {
830 record.rdata.resize(16);
831 if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
832 ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
833 return false;
834 }
835 } else {
836 ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
837 return false;
838 }
839 answers->push_back(std::move(record));
840 return true;
841 }
842
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const843 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
844 char* response, size_t* response_len)
845 const {
846 header->answers.clear();
847 header->authorities.clear();
848 header->additionals.clear();
849 header->rcode = rcode;
850 header->qr = true;
851 char* response_cur = header->write(response, response + *response_len);
852 if (response_cur == nullptr) return false;
853 *response_len = response_cur - response;
854 return true;
855 }
856
857 } // namespace test
858
859