1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30
31 #include <iostream>
32 #include <vector>
33
34 #define LOG_TAG "DNSResponder"
35 #include <log/log.h>
36
37 namespace test {
38
// Adapters that turn the result of either strerror_r() flavour into a message
// pointer (nullptr on failure). Overload resolution picks the one matching the
// flavour the C library exposes, so exactly one of them is used per build.

// XSI/POSIX flavour: returns 0 on success and fills the caller's buffer.
__attribute__((unused))
static const char* strerrorResult(int rv, const char* buf) {
    return rv == 0 ? buf : nullptr;
}

// GNU flavour: returns the message pointer itself, which is not necessarily
// the caller's buffer.
__attribute__((unused))
static const char* strerrorResult(const char* rv, const char* /* buf */) {
    return rv;
}

// Returns the textual description of the current errno value, or an empty
// string if no description could be obtained.
std::string errno2str() {
    char error_msg[512] = { 0 };
    const char* msg = strerrorResult(
            strerror_r(errno, error_msg, sizeof(error_msg)), error_msg);
    if (msg == nullptr) return std::string();
    return std::string(msg);
}
45
// Forwards to ALOGI with ": [<errno>] <errno2str()>" appended to the message.
// At least one variadic argument must be supplied.
#define APLOGI(fmt, ...) \
    ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
47
// Renders the first `len` bytes of `buffer` as upper-case hexadecimal, two
// characters per byte, with no separators.
std::string str2hex(const char* buffer, size_t len) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        // Go through uint8_t so bytes >= 0x80 index the table correctly.
        const uint8_t byte = static_cast<uint8_t>(*p);
        out.push_back(kHexDigits[byte >> 4]);
        out.push_back(kHexDigits[byte & 0x0F]);
    }
    return out;
}
58
// Formats a socket address as its numeric host string (no name lookup, since
// NI_NUMERICHOST is passed). Returns an empty string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = { 0 };
    const int status = getnameinfo(sa, sa_len, numeric_host,
                                   sizeof(numeric_host), nullptr, 0,
                                   NI_NUMERICHOST);
    return status == 0 ? std::string(numeric_host) : std::string();
}
66
67 /* DNS struct helpers */
68
dnstype2str(unsigned dnstype)69 const char* dnstype2str(unsigned dnstype) {
70 static std::unordered_map<unsigned, const char*> kTypeStrs = {
71 { ns_type::ns_t_a, "A" },
72 { ns_type::ns_t_ns, "NS" },
73 { ns_type::ns_t_md, "MD" },
74 { ns_type::ns_t_mf, "MF" },
75 { ns_type::ns_t_cname, "CNAME" },
76 { ns_type::ns_t_soa, "SOA" },
77 { ns_type::ns_t_mb, "MB" },
78 { ns_type::ns_t_mb, "MG" },
79 { ns_type::ns_t_mr, "MR" },
80 { ns_type::ns_t_null, "NULL" },
81 { ns_type::ns_t_wks, "WKS" },
82 { ns_type::ns_t_ptr, "PTR" },
83 { ns_type::ns_t_hinfo, "HINFO" },
84 { ns_type::ns_t_minfo, "MINFO" },
85 { ns_type::ns_t_mx, "MX" },
86 { ns_type::ns_t_txt, "TXT" },
87 { ns_type::ns_t_rp, "RP" },
88 { ns_type::ns_t_afsdb, "AFSDB" },
89 { ns_type::ns_t_x25, "X25" },
90 { ns_type::ns_t_isdn, "ISDN" },
91 { ns_type::ns_t_rt, "RT" },
92 { ns_type::ns_t_nsap, "NSAP" },
93 { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
94 { ns_type::ns_t_sig, "SIG" },
95 { ns_type::ns_t_key, "KEY" },
96 { ns_type::ns_t_px, "PX" },
97 { ns_type::ns_t_gpos, "GPOS" },
98 { ns_type::ns_t_aaaa, "AAAA" },
99 { ns_type::ns_t_loc, "LOC" },
100 { ns_type::ns_t_nxt, "NXT" },
101 { ns_type::ns_t_eid, "EID" },
102 { ns_type::ns_t_nimloc, "NIMLOC" },
103 { ns_type::ns_t_srv, "SRV" },
104 { ns_type::ns_t_naptr, "NAPTR" },
105 { ns_type::ns_t_kx, "KX" },
106 { ns_type::ns_t_cert, "CERT" },
107 { ns_type::ns_t_a6, "A6" },
108 { ns_type::ns_t_dname, "DNAME" },
109 { ns_type::ns_t_sink, "SINK" },
110 { ns_type::ns_t_opt, "OPT" },
111 { ns_type::ns_t_apl, "APL" },
112 { ns_type::ns_t_tkey, "TKEY" },
113 { ns_type::ns_t_tsig, "TSIG" },
114 { ns_type::ns_t_ixfr, "IXFR" },
115 { ns_type::ns_t_axfr, "AXFR" },
116 { ns_type::ns_t_mailb, "MAILB" },
117 { ns_type::ns_t_maila, "MAILA" },
118 { ns_type::ns_t_any, "ANY" },
119 { ns_type::ns_t_zxfr, "ZXFR" },
120 };
121 auto it = kTypeStrs.find(dnstype);
122 static const char* kUnknownStr{ "UNKNOWN" };
123 if (it == kTypeStrs.end()) return kUnknownStr;
124 return it->second;
125 }
126
// Returns a descriptive name for a DNS class code, or "UNKNOWN" if the code is
// not in the table. The returned pointer refers to a string literal.
const char* dnsclass2str(unsigned dnsclass) {
    static const std::unordered_map<unsigned, const char*> kClassStrs = {
        { ns_class::ns_c_in , "Internet" },
        { 2, "CSNet" },
        { ns_class::ns_c_chaos, "ChaosNet" },
        { ns_class::ns_c_hs, "Hesiod" },
        { ns_class::ns_c_none, "none" },
        { ns_class::ns_c_any, "any" }
    };
    static const char* kUnknownStr{ "UNKNOWN" };
    const auto it = kClassStrs.find(dnsclass);
    if (it == kClassStrs.end()) return kUnknownStr;
    return it->second;
}
142
// A DNS domain name kept in dotted text form, with (de)serialization helpers
// for the length-prefixed wire encoding.
struct DNSName {
    // Dotted textual name.
    std::string name;

    // Returns `name` as a C string.
    const char* toString() const;
    // Parses a wire-format name from [buffer, buffer_end).
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes `name` into [buffer, buffer_end).
    char* write(char* buffer, const char* buffer_end) const;

private:
    // Parses a single field of a wire-format name; sets *last on the
    // terminating field.
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};
152
toString() const153 const char* DNSName::toString() const {
154 return name.c_str();
155 }
156
read(const char * buffer,const char * buffer_end)157 const char* DNSName::read(const char* buffer, const char* buffer_end) {
158 const char* cur = buffer;
159 bool last = false;
160 do {
161 cur = parseField(cur, buffer_end, &last);
162 if (cur == nullptr) {
163 ALOGI("parsing failed at line %d", __LINE__);
164 return nullptr;
165 }
166 } while (!last);
167 return cur;
168 }
169
write(char * buffer,const char * buffer_end) const170 char* DNSName::write(char* buffer, const char* buffer_end) const {
171 char* buffer_cur = buffer;
172 for (size_t pos = 0 ; pos < name.size() ; ) {
173 size_t dot_pos = name.find('.', pos);
174 if (dot_pos == std::string::npos) {
175 // Sanity check, should never happen unless parseField is broken.
176 ALOGI("logic error: all names are expected to end with a '.'");
177 return nullptr;
178 }
179 size_t len = dot_pos - pos;
180 if (len >= 256) {
181 ALOGI("name component '%s' is %zu long, but max is 255",
182 name.substr(pos, dot_pos - pos).c_str(), len);
183 return nullptr;
184 }
185 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
186 ALOGI("buffer overflow at line %d", __LINE__);
187 return nullptr;
188 }
189 *buffer_cur++ = len;
190 buffer_cur = std::copy(std::next(name.begin(), pos),
191 std::next(name.begin(), dot_pos),
192 buffer_cur);
193 pos = dot_pos + 1;
194 }
195 // Write final zero.
196 *buffer_cur++ = 0;
197 return buffer_cur;
198 }
199
parseField(const char * buffer,const char * buffer_end,bool * last)200 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
201 bool* last) {
202 if (buffer + sizeof(uint8_t) > buffer_end) {
203 ALOGI("parsing failed at line %d", __LINE__);
204 return nullptr;
205 }
206 unsigned field_type = *buffer >> 6;
207 unsigned ofs = *buffer & 0x3F;
208 const char* cur = buffer + sizeof(uint8_t);
209 if (field_type == 0) {
210 // length + name component
211 if (ofs == 0) {
212 *last = true;
213 return cur;
214 }
215 if (cur + ofs > buffer_end) {
216 ALOGI("parsing failed at line %d", __LINE__);
217 return nullptr;
218 }
219 name.append(cur, ofs);
220 name.push_back('.');
221 return cur + ofs;
222 } else if (field_type == 3) {
223 ALOGI("name compression not implemented");
224 return nullptr;
225 }
226 ALOGI("invalid name field type");
227 return nullptr;
228 }
229
230 struct DNSQuestion {
231 DNSName qname;
232 unsigned qtype;
233 unsigned qclass;
234 const char* read(const char* buffer, const char* buffer_end);
235 char* write(char* buffer, const char* buffer_end) const;
236 std::string toString() const;
237 };
238
read(const char * buffer,const char * buffer_end)239 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
240 const char* cur = qname.read(buffer, buffer_end);
241 if (cur == nullptr) {
242 ALOGI("parsing failed at line %d", __LINE__);
243 return nullptr;
244 }
245 if (cur + 2*sizeof(uint16_t) > buffer_end) {
246 ALOGI("parsing failed at line %d", __LINE__);
247 return nullptr;
248 }
249 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
250 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
251 return cur + 2*sizeof(uint16_t);
252 }
253
write(char * buffer,const char * buffer_end) const254 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
255 char* buffer_cur = qname.write(buffer, buffer_end);
256 if (buffer_cur == nullptr) return nullptr;
257 if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
258 ALOGI("buffer overflow on line %d", __LINE__);
259 return nullptr;
260 }
261 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
262 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
263 htons(qclass);
264 return buffer_cur + 2*sizeof(uint16_t);
265 }
266
toString() const267 std::string DNSQuestion::toString() const {
268 char buffer[4096];
269 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
270 dnstype2str(qtype), dnsclass2str(qclass));
271 return std::string(buffer, len);
272 }
273
274 struct DNSRecord {
275 DNSName name;
276 unsigned rtype;
277 unsigned rclass;
278 unsigned ttl;
279 std::vector<char> rdata;
280 const char* read(const char* buffer, const char* buffer_end);
281 char* write(char* buffer, const char* buffer_end) const;
282 std::string toString() const;
283 private:
284 struct IntFields {
285 uint16_t rtype;
286 uint16_t rclass;
287 uint32_t ttl;
288 uint16_t rdlen;
289 } __attribute__((__packed__));
290
291 const char* readIntFields(const char* buffer, const char* buffer_end,
292 unsigned* rdlen);
293 char* writeIntFields(unsigned rdlen, char* buffer,
294 const char* buffer_end) const;
295 };
296
read(const char * buffer,const char * buffer_end)297 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
298 const char* cur = name.read(buffer, buffer_end);
299 if (cur == nullptr) {
300 ALOGI("parsing failed at line %d", __LINE__);
301 return nullptr;
302 }
303 unsigned rdlen = 0;
304 cur = readIntFields(cur, buffer_end, &rdlen);
305 if (cur == nullptr) {
306 ALOGI("parsing failed at line %d", __LINE__);
307 return nullptr;
308 }
309 if (cur + rdlen > buffer_end) {
310 ALOGI("parsing failed at line %d", __LINE__);
311 return nullptr;
312 }
313 rdata.assign(cur, cur + rdlen);
314 return cur + rdlen;
315 }
316
write(char * buffer,const char * buffer_end) const317 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
318 char* buffer_cur = name.write(buffer, buffer_end);
319 if (buffer_cur == nullptr) return nullptr;
320 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
321 if (buffer_cur == nullptr) return nullptr;
322 if (buffer_cur + rdata.size() > buffer_end) {
323 ALOGI("buffer overflow on line %d", __LINE__);
324 return nullptr;
325 }
326 return std::copy(rdata.begin(), rdata.end(), buffer_cur);
327 }
328
toString() const329 std::string DNSRecord::toString() const {
330 char buffer[4096];
331 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
332 dnstype2str(rtype), dnsclass2str(rclass));
333 return std::string(buffer, len);
334 }
335
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)336 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
337 unsigned* rdlen) {
338 if (buffer + sizeof(IntFields) > buffer_end ) {
339 ALOGI("parsing failed at line %d", __LINE__);
340 return nullptr;
341 }
342 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
343 rtype = ntohs(intfields.rtype);
344 rclass = ntohs(intfields.rclass);
345 ttl = ntohl(intfields.ttl);
346 *rdlen = ntohs(intfields.rdlen);
347 return buffer + sizeof(IntFields);
348 }
349
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const350 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
351 const char* buffer_end) const {
352 if (buffer + sizeof(IntFields) > buffer_end ) {
353 ALOGI("buffer overflow on line %d", __LINE__);
354 return nullptr;
355 }
356 auto& intfields = *reinterpret_cast<IntFields*>(buffer);
357 intfields.rtype = htons(rtype);
358 intfields.rclass = htons(rclass);
359 intfields.ttl = htonl(ttl);
360 intfields.rdlen = htons(rdlen);
361 return buffer + sizeof(IntFields);
362 }
363
364 struct DNSHeader {
365 unsigned id;
366 bool ra;
367 uint8_t rcode;
368 bool qr;
369 uint8_t opcode;
370 bool aa;
371 bool tr;
372 bool rd;
373 std::vector<DNSQuestion> questions;
374 std::vector<DNSRecord> answers;
375 std::vector<DNSRecord> authorities;
376 std::vector<DNSRecord> additionals;
377 const char* read(const char* buffer, const char* buffer_end);
378 char* write(char* buffer, const char* buffer_end) const;
379 std::string toString() const;
380
381 private:
382 struct Header {
383 uint16_t id;
384 uint8_t flags0;
385 uint8_t flags1;
386 uint16_t qdcount;
387 uint16_t ancount;
388 uint16_t nscount;
389 uint16_t arcount;
390 } __attribute__((__packed__));
391
392 const char* readHeader(const char* buffer, const char* buffer_end,
393 unsigned* qdcount, unsigned* ancount,
394 unsigned* nscount, unsigned* arcount);
395 };
396
read(const char * buffer,const char * buffer_end)397 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
398 unsigned qdcount;
399 unsigned ancount;
400 unsigned nscount;
401 unsigned arcount;
402 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
403 &nscount, &arcount);
404 if (cur == nullptr) {
405 ALOGI("parsing failed at line %d", __LINE__);
406 return nullptr;
407 }
408 if (qdcount) {
409 questions.resize(qdcount);
410 for (unsigned i = 0 ; i < qdcount ; ++i) {
411 cur = questions[i].read(cur, buffer_end);
412 if (cur == nullptr) {
413 ALOGI("parsing failed at line %d", __LINE__);
414 return nullptr;
415 }
416 }
417 }
418 if (ancount) {
419 answers.resize(ancount);
420 for (unsigned i = 0 ; i < ancount ; ++i) {
421 cur = answers[i].read(cur, buffer_end);
422 if (cur == nullptr) {
423 ALOGI("parsing failed at line %d", __LINE__);
424 return nullptr;
425 }
426 }
427 }
428 if (nscount) {
429 authorities.resize(nscount);
430 for (unsigned i = 0 ; i < nscount ; ++i) {
431 cur = authorities[i].read(cur, buffer_end);
432 if (cur == nullptr) {
433 ALOGI("parsing failed at line %d", __LINE__);
434 return nullptr;
435 }
436 }
437 }
438 if (arcount) {
439 additionals.resize(arcount);
440 for (unsigned i = 0 ; i < arcount ; ++i) {
441 cur = additionals[i].read(cur, buffer_end);
442 if (cur == nullptr) {
443 ALOGI("parsing failed at line %d", __LINE__);
444 return nullptr;
445 }
446 }
447 }
448 return cur;
449 }
450
write(char * buffer,const char * buffer_end) const451 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
452 if (buffer + sizeof(Header) > buffer_end) {
453 ALOGI("buffer overflow on line %d", __LINE__);
454 return nullptr;
455 }
456 Header& header = *reinterpret_cast<Header*>(buffer);
457 // bytes 0-1
458 header.id = htons(id);
459 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
460 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
461 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
462 header.flags1 = rcode;
463 // rest of header
464 header.qdcount = htons(questions.size());
465 header.ancount = htons(answers.size());
466 header.nscount = htons(authorities.size());
467 header.arcount = htons(additionals.size());
468 char* buffer_cur = buffer + sizeof(Header);
469 for (const DNSQuestion& question : questions) {
470 buffer_cur = question.write(buffer_cur, buffer_end);
471 if (buffer_cur == nullptr) return nullptr;
472 }
473 for (const DNSRecord& answer : answers) {
474 buffer_cur = answer.write(buffer_cur, buffer_end);
475 if (buffer_cur == nullptr) return nullptr;
476 }
477 for (const DNSRecord& authority : authorities) {
478 buffer_cur = authority.write(buffer_cur, buffer_end);
479 if (buffer_cur == nullptr) return nullptr;
480 }
481 for (const DNSRecord& additional : additionals) {
482 buffer_cur = additional.write(buffer_cur, buffer_end);
483 if (buffer_cur == nullptr) return nullptr;
484 }
485 return buffer_cur;
486 }
487
toString() const488 std::string DNSHeader::toString() const {
489 // TODO
490 return std::string();
491 }
492
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)493 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
494 unsigned* qdcount, unsigned* ancount,
495 unsigned* nscount, unsigned* arcount) {
496 if (buffer + sizeof(Header) > buffer_end)
497 return 0;
498 const auto& header = *reinterpret_cast<const Header*>(buffer);
499 // bytes 0-1
500 id = ntohs(header.id);
501 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
502 qr = header.flags0 >> 7;
503 opcode = (header.flags0 >> 3) & 0x0F;
504 aa = (header.flags0 >> 2) & 1;
505 tr = (header.flags0 >> 1) & 1;
506 rd = header.flags0 & 1;
507 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
508 ra = header.flags1 >> 7;
509 rcode = header.flags1 & 0xF;
510 // rest of header
511 *qdcount = ntohs(header.qdcount);
512 *ancount = ntohs(header.ancount);
513 *nscount = ntohs(header.nscount);
514 *arcount = ntohs(header.arcount);
515 return buffer + sizeof(Header);
516 }
517
518 /* DNS responder */
519
DNSResponder(std::string listen_address,std::string listen_service,int poll_timeout_ms,uint16_t error_rcode,double response_probability)520 DNSResponder::DNSResponder(std::string listen_address,
521 std::string listen_service, int poll_timeout_ms,
522 uint16_t error_rcode, double response_probability) :
523 listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
524 poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
525 response_probability_(response_probability),
526 socket_(-1), epoll_fd_(-1), terminate_(false) { }
527
~DNSResponder()528 DNSResponder::~DNSResponder() {
529 stopServer();
530 }
531
addMapping(const char * name,ns_type type,const char * addr)532 void DNSResponder::addMapping(const char* name, ns_type type,
533 const char* addr) {
534 std::lock_guard<std::mutex> lock(mappings_mutex_);
535 auto it = mappings_.find(QueryKey(name, type));
536 if (it != mappings_.end()) {
537 ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
538 "address %s", name, dnstype2str(type), it->second.c_str(),
539 addr);
540 it->second = addr;
541 return;
542 }
543 mappings_.emplace(std::piecewise_construct,
544 std::forward_as_tuple(name, type),
545 std::forward_as_tuple(addr));
546 }
547
removeMapping(const char * name,ns_type type)548 void DNSResponder::removeMapping(const char* name, ns_type type) {
549 std::lock_guard<std::mutex> lock(mappings_mutex_);
550 auto it = mappings_.find(QueryKey(name, type));
551 if (it != mappings_.end()) {
552 ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
553 dnstype2str(type));
554 return;
555 }
556 mappings_.erase(it);
557 }
558
setResponseProbability(double response_probability)559 void DNSResponder::setResponseProbability(double response_probability) {
560 response_probability_ = response_probability;
561 }
562
running() const563 bool DNSResponder::running() const {
564 return socket_ != -1;
565 }
566
// Starts the responder: resolves the configured listen address/service, binds
// a non-blocking UDP socket to the first candidate that accepts it, registers
// the socket with a fresh epoll instance and launches the request handler
// thread.
// Returns true on success. Returns false if the server is already running or
// any step fails; descriptors opened so far are closed on failure.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }

    // Resolve the listen endpoint into candidate bind addresses.
    addrinfo hints{};
    hints.ai_flags = AI_PASSIVE;
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_DGRAM;
    addrinfo* candidates;
    const int gai_rv = getaddrinfo(listen_address_.c_str(),
                                   listen_service_.c_str(), &hints,
                                   &candidates);
    if (gai_rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(gai_rv));
        return false;
    }

    // Try each candidate until one can be bound.
    int sock = -1;
    for (const addrinfo* ai = candidates ; ai != nullptr ; ai = ai->ai_next) {
        sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (sock < 0) continue;
        const int enable = 1;
        setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(enable));
        if (bind(sock, ai->ai_addr, ai->ai_addrlen) == 0) {
            const std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
            ALOGI("bound to UDP %s:%s", host_str.c_str(),
                  listen_service_.c_str());
            break;
        }
        APLOGI("bind failed for socket %d", sock);
        close(sock);
        sock = -1;
    }
    freeaddrinfo(candidates);
    if (sock < 0) {
        ALOGI("bind() failed");
        return false;
    }

    // Switch the socket to non-blocking mode; a failed F_GETFL is treated as
    // "no flags set".
    const int current_flags = fcntl(sock, F_GETFL, 0);
    const int base_flags = (current_flags < 0) ? 0 : current_flags;
    if (fcntl(sock, F_SETFL, base_flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", sock);
        close(sock);
        return false;
    }

    // Set up the epoll instance the handler thread waits on.
    const int ep_fd = epoll_create(1);
    if (ep_fd < 0) {
        char error_msg[512] = { 0 };
        if (strerror_r(errno, error_msg, sizeof(error_msg)))
            strncpy(error_msg, "UNKNOWN", sizeof(error_msg));
        APLOGI("epoll_create() failed: %s", error_msg);
        close(sock);
        return false;
    }
    epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.fd = sock;
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, sock, &ev) < 0) {
        APLOGI("epoll_ctl() failed for socket %d", sock);
        close(ep_fd);
        close(sock);
        return false;
    }

    // Publish the descriptors, then start the handler thread.
    epoll_fd_ = ep_fd;
    socket_ = sock;
    {
        std::lock_guard<std::mutex> guard(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
643
stopServer()644 bool DNSResponder::stopServer() {
645 std::lock_guard<std::mutex> lock(update_mutex_);
646 if (!running()) {
647 ALOGI("server not running");
648 return false;
649 }
650 if (terminate_) {
651 ALOGI("LOGIC ERROR");
652 return false;
653 }
654 ALOGI("stopping server");
655 terminate_ = true;
656 handler_thread_.join();
657 close(epoll_fd_);
658 close(socket_);
659 terminate_ = false;
660 socket_ = -1;
661 ALOGI("server stopped successfully");
662 return true;
663 }
664
queries() const665 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
666 std::lock_guard<std::mutex> lock(queries_mutex_);
667 return queries_;
668 }
669
clearQueries()670 void DNSResponder::clearQueries() {
671 std::lock_guard<std::mutex> lock(queries_mutex_);
672 queries_.clear();
673 }
674
requestHandler()675 void DNSResponder::requestHandler() {
676 epoll_event evs[1];
677 while (!terminate_) {
678 int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
679 if (n == 0) continue;
680 if (n < 0) {
681 ALOGI("epoll_wait() failed");
682 // TODO(imaipi): terminate on error.
683 return;
684 }
685 char buffer[4096];
686 sockaddr_storage sa;
687 socklen_t sa_len = sizeof(sa);
688 ssize_t len;
689 do {
690 len = recvfrom(socket_, buffer, sizeof(buffer), 0,
691 (sockaddr*) &sa, &sa_len);
692 } while (len < 0 && (errno == EAGAIN || errno == EINTR));
693 if (len <= 0) {
694 ALOGI("recvfrom() failed");
695 continue;
696 }
697 ALOGI("read %zd bytes", len);
698 char response[4096];
699 size_t response_len = sizeof(response);
700 if (handleDNSRequest(buffer, len, response, &response_len) &&
701 response_len > 0) {
702 len = sendto(socket_, response, response_len, 0,
703 reinterpret_cast<const sockaddr*>(&sa), sa_len);
704 std::string host_str =
705 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
706 if (len > 0) {
707 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
708 } else {
709 APLOGI("sendto() failed for %s", host_str.c_str());
710 }
711 // Test that the response is actually a correct DNS message.
712 const char* response_end = response + len;
713 DNSHeader header;
714 const char* cur = header.read(response, response_end);
715 if (cur == nullptr) ALOGI("response is flawed");
716
717 } else {
718 ALOGI("not responding");
719 }
720 }
721 }
722
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const723 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
724 char* response, size_t* response_len)
725 const {
726 ALOGI("request: '%s'", str2hex(buffer, len).c_str());
727 const char* buffer_end = buffer + len;
728 DNSHeader header;
729 const char* cur = header.read(buffer, buffer_end);
730 // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
731 if (cur == nullptr) {
732 ALOGI("failed to parse query");
733 return false;
734 }
735 if (header.qr) {
736 ALOGI("response received instead of a query");
737 return false;
738 }
739 if (header.opcode != ns_opcode::ns_o_query) {
740 ALOGI("unsupported request opcode received");
741 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
742 response_len);
743 }
744 if (header.questions.empty()) {
745 ALOGI("no questions present");
746 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
747 response_len);
748 }
749 if (!header.answers.empty()) {
750 ALOGI("already %zu answers present in query", header.answers.size());
751 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
752 response_len);
753 }
754 {
755 std::lock_guard<std::mutex> lock(queries_mutex_);
756 for (const DNSQuestion& question : header.questions) {
757 queries_.push_back(make_pair(question.qname.name,
758 ns_type(question.qtype)));
759 }
760 }
761
762 // Ignore requests with the preset probability.
763 auto constexpr bound = std::numeric_limits<unsigned>::max();
764 if (arc4random_uniform(bound) > bound*response_probability_) {
765 ALOGI("returning SRVFAIL in accordance with probability distribution");
766 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
767 response_len);
768 }
769
770 for (const DNSQuestion& question : header.questions) {
771 if (question.qclass != ns_class::ns_c_in &&
772 question.qclass != ns_class::ns_c_any) {
773 ALOGI("unsupported question class %u", question.qclass);
774 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
775 response_len);
776 }
777 if (!addAnswerRecords(question, &header.answers)) {
778 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
779 response_len);
780 }
781 }
782 header.qr = true;
783 char* response_cur = header.write(response, response + *response_len);
784 if (response_cur == nullptr) {
785 return false;
786 }
787 *response_len = response_cur - response;
788 return true;
789 }
790
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const791 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
792 std::vector<DNSRecord>* answers) const {
793 auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
794 if (it == mappings_.end()) {
795 // TODO(imaipi): handle correctly
796 ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
797 question.qname.name.c_str(), dnstype2str(question.qtype));
798 return true;
799 }
800 ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
801 dnstype2str(question.qtype), it->second.c_str());
802 DNSRecord record;
803 record.name = question.qname;
804 record.rtype = question.qtype;
805 record.rclass = ns_class::ns_c_in;
806 record.ttl = 5; // seconds
807 if (question.qtype == ns_type::ns_t_a) {
808 record.rdata.resize(4);
809 if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
810 ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
811 return false;
812 }
813 } else if (question.qtype == ns_type::ns_t_aaaa) {
814 record.rdata.resize(16);
815 if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
816 ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
817 return false;
818 }
819 } else {
820 ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
821 return false;
822 }
823 answers->push_back(std::move(record));
824 return true;
825 }
826
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const827 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
828 char* response, size_t* response_len)
829 const {
830 header->answers.clear();
831 header->authorities.clear();
832 header->additionals.clear();
833 header->rcode = rcode;
834 header->qr = true;
835 char* response_cur = header->write(response, response + *response_len);
836 if (response_cur == nullptr) return false;
837 *response_len = response_cur - response;
838 return true;
839 }
840
841 } // namespace test
842
843