1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30
31 #include <iostream>
32 #include <vector>
33
34 #include <log/log.h>
35
36 namespace test {
37
// strerror_r() comes in two incompatible flavours: the XSI one returns an int
// (0 on success) and fills the caller's buffer, the GNU one returns a pointer
// to the message (which may or may not be the caller's buffer). These two
// overloads let the compiler pick whichever matches the libc in use and turn
// the result into "pointer to message, or nullptr on failure".
__attribute__((unused))
static const char* strerrorResult(int rv, const char* buf) {
    return rv == 0 ? buf : nullptr;
}

__attribute__((unused))
static const char* strerrorResult(const char* rv, const char* /* buf */) {
    return rv;
}

// Returns the textual description of the current errno value, or an empty
// string if no description could be obtained.
std::string errno2str() {
    char error_msg[512] = { 0 };
    const char* msg = strerrorResult(
            strerror_r(errno, error_msg, sizeof(error_msg)), error_msg);
    if (msg == nullptr) return std::string();
    return std::string(msg);
}
44
// Like ALOGI, but appends ": [<errno>] <errno description>" to the message.
// |fmt| must be a string literal (it is concatenated with the suffix) and at
// least one variadic argument must be supplied.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
46
// Renders the first |len| bytes of |buffer| as upper-case hexadecimal, two
// digits per byte and without any separators.
std::string str2hex(const char* buffer, size_t len) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (size_t idx = 0; idx != len; ++idx) {
        // Go through an unsigned byte so that values >= 0x80 index correctly.
        const uint8_t byte = static_cast<uint8_t>(buffer[idx]);
        out.push_back(kHexDigits[byte >> 4]);
        out.push_back(kHexDigits[byte & 0x0F]);
    }
    return out;
}
57
// Formats the host part of |sa| numerically (no name lookup is requested).
// Returns an empty string if getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = { 0 };
    const int status = getnameinfo(sa, sa_len, numeric_host,
                                   sizeof(numeric_host), nullptr, 0,
                                   NI_NUMERICHOST);
    if (status != 0) {
        return std::string();
    }
    return std::string(numeric_host);
}
65
66 /* DNS struct helpers */
67
dnstype2str(unsigned dnstype)68 const char* dnstype2str(unsigned dnstype) {
69 static std::unordered_map<unsigned, const char*> kTypeStrs = {
70 { ns_type::ns_t_a, "A" },
71 { ns_type::ns_t_ns, "NS" },
72 { ns_type::ns_t_md, "MD" },
73 { ns_type::ns_t_mf, "MF" },
74 { ns_type::ns_t_cname, "CNAME" },
75 { ns_type::ns_t_soa, "SOA" },
76 { ns_type::ns_t_mb, "MB" },
77 { ns_type::ns_t_mb, "MG" },
78 { ns_type::ns_t_mr, "MR" },
79 { ns_type::ns_t_null, "NULL" },
80 { ns_type::ns_t_wks, "WKS" },
81 { ns_type::ns_t_ptr, "PTR" },
82 { ns_type::ns_t_hinfo, "HINFO" },
83 { ns_type::ns_t_minfo, "MINFO" },
84 { ns_type::ns_t_mx, "MX" },
85 { ns_type::ns_t_txt, "TXT" },
86 { ns_type::ns_t_rp, "RP" },
87 { ns_type::ns_t_afsdb, "AFSDB" },
88 { ns_type::ns_t_x25, "X25" },
89 { ns_type::ns_t_isdn, "ISDN" },
90 { ns_type::ns_t_rt, "RT" },
91 { ns_type::ns_t_nsap, "NSAP" },
92 { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
93 { ns_type::ns_t_sig, "SIG" },
94 { ns_type::ns_t_key, "KEY" },
95 { ns_type::ns_t_px, "PX" },
96 { ns_type::ns_t_gpos, "GPOS" },
97 { ns_type::ns_t_aaaa, "AAAA" },
98 { ns_type::ns_t_loc, "LOC" },
99 { ns_type::ns_t_nxt, "NXT" },
100 { ns_type::ns_t_eid, "EID" },
101 { ns_type::ns_t_nimloc, "NIMLOC" },
102 { ns_type::ns_t_srv, "SRV" },
103 { ns_type::ns_t_naptr, "NAPTR" },
104 { ns_type::ns_t_kx, "KX" },
105 { ns_type::ns_t_cert, "CERT" },
106 { ns_type::ns_t_a6, "A6" },
107 { ns_type::ns_t_dname, "DNAME" },
108 { ns_type::ns_t_sink, "SINK" },
109 { ns_type::ns_t_opt, "OPT" },
110 { ns_type::ns_t_apl, "APL" },
111 { ns_type::ns_t_tkey, "TKEY" },
112 { ns_type::ns_t_tsig, "TSIG" },
113 { ns_type::ns_t_ixfr, "IXFR" },
114 { ns_type::ns_t_axfr, "AXFR" },
115 { ns_type::ns_t_mailb, "MAILB" },
116 { ns_type::ns_t_maila, "MAILA" },
117 { ns_type::ns_t_any, "ANY" },
118 { ns_type::ns_t_zxfr, "ZXFR" },
119 };
120 auto it = kTypeStrs.find(dnstype);
121 static const char* kUnknownStr{ "UNKNOWN" };
122 if (it == kTypeStrs.end()) return kUnknownStr;
123 return it->second;
124 }
125
// Maps a DNS class code to a human-readable name. Returns "UNKNOWN" for codes
// that are not in the table. The returned pointer refers to a string literal
// and must not be freed.
const char* dnsclass2str(unsigned dnsclass) {
    static const std::unordered_map<unsigned, const char*> kClassStrs = {
        { ns_class::ns_c_in , "Internet" },
        { 2, "CSNet" },
        { ns_class::ns_c_chaos, "ChaosNet" },
        { ns_class::ns_c_hs, "Hesiod" },
        { ns_class::ns_c_none, "none" },
        { ns_class::ns_c_any, "any" }
    };
    static const char* kUnknownStr{ "UNKNOWN" };
    const auto it = kClassStrs.find(dnsclass);
    if (it == kClassStrs.end()) return kUnknownStr;
    return it->second;
}
141
// A domain name as it appears in a DNS message, kept in dotted text form.
struct DNSName {
    // Dotted representation. Every label, including the last one, is followed
    // by a '.' (parseField() appends it and write() requires it).
    std::string name;
    // Parses a wire-format name starting at |buffer|, appending the labels to
    // |name|. Returns the position just past the name, or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes |name| in wire format at |buffer|. Returns the position just
    // past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns |name| as a C string; the pointer is owned by this object.
    const char* toString() const;
private:
    // Parses one length-prefixed label; sets |*last| when the terminating
    // zero-length label is reached.
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};
151
// Returns the dotted name as a C string. The pointer is only valid while
// |name| is neither modified nor destroyed.
const char* DNSName::toString() const {
    return name.c_str();
}
155
read(const char * buffer,const char * buffer_end)156 const char* DNSName::read(const char* buffer, const char* buffer_end) {
157 const char* cur = buffer;
158 bool last = false;
159 do {
160 cur = parseField(cur, buffer_end, &last);
161 if (cur == nullptr) {
162 ALOGI("parsing failed at line %d", __LINE__);
163 return nullptr;
164 }
165 } while (!last);
166 return cur;
167 }
168
write(char * buffer,const char * buffer_end) const169 char* DNSName::write(char* buffer, const char* buffer_end) const {
170 char* buffer_cur = buffer;
171 for (size_t pos = 0 ; pos < name.size() ; ) {
172 size_t dot_pos = name.find('.', pos);
173 if (dot_pos == std::string::npos) {
174 // Sanity check, should never happen unless parseField is broken.
175 ALOGI("logic error: all names are expected to end with a '.'");
176 return nullptr;
177 }
178 size_t len = dot_pos - pos;
179 if (len >= 256) {
180 ALOGI("name component '%s' is %zu long, but max is 255",
181 name.substr(pos, dot_pos - pos).c_str(), len);
182 return nullptr;
183 }
184 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
185 ALOGI("buffer overflow at line %d", __LINE__);
186 return nullptr;
187 }
188 *buffer_cur++ = len;
189 buffer_cur = std::copy(std::next(name.begin(), pos),
190 std::next(name.begin(), dot_pos),
191 buffer_cur);
192 pos = dot_pos + 1;
193 }
194 // Write final zero.
195 *buffer_cur++ = 0;
196 return buffer_cur;
197 }
198
parseField(const char * buffer,const char * buffer_end,bool * last)199 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
200 bool* last) {
201 if (buffer + sizeof(uint8_t) > buffer_end) {
202 ALOGI("parsing failed at line %d", __LINE__);
203 return nullptr;
204 }
205 unsigned field_type = *buffer >> 6;
206 unsigned ofs = *buffer & 0x3F;
207 const char* cur = buffer + sizeof(uint8_t);
208 if (field_type == 0) {
209 // length + name component
210 if (ofs == 0) {
211 *last = true;
212 return cur;
213 }
214 if (cur + ofs > buffer_end) {
215 ALOGI("parsing failed at line %d", __LINE__);
216 return nullptr;
217 }
218 name.append(cur, ofs);
219 name.push_back('.');
220 return cur + ofs;
221 } else if (field_type == 3) {
222 ALOGI("name compression not implemented");
223 return nullptr;
224 }
225 ALOGI("invalid name field type");
226 return nullptr;
227 }
228
// One entry of the question section of a DNS message.
struct DNSQuestion {
    // Name being queried.
    DNSName qname;
    // Query type and class in host byte order (converted by read()/write()).
    unsigned qtype;
    unsigned qclass;
    // Parses a question at |buffer|. Returns the position just past it, or
    // nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes the question at |buffer|. Returns the position just past the
    // written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns a short description of the form "Q<name,type,class>".
    std::string toString() const;
};
237
read(const char * buffer,const char * buffer_end)238 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
239 const char* cur = qname.read(buffer, buffer_end);
240 if (cur == nullptr) {
241 ALOGI("parsing failed at line %d", __LINE__);
242 return nullptr;
243 }
244 if (cur + 2*sizeof(uint16_t) > buffer_end) {
245 ALOGI("parsing failed at line %d", __LINE__);
246 return nullptr;
247 }
248 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
249 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
250 return cur + 2*sizeof(uint16_t);
251 }
252
write(char * buffer,const char * buffer_end) const253 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
254 char* buffer_cur = qname.write(buffer, buffer_end);
255 if (buffer_cur == nullptr) return nullptr;
256 if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
257 ALOGI("buffer overflow on line %d", __LINE__);
258 return nullptr;
259 }
260 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
261 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
262 htons(qclass);
263 return buffer_cur + 2*sizeof(uint16_t);
264 }
265
toString() const266 std::string DNSQuestion::toString() const {
267 char buffer[4096];
268 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
269 dnstype2str(qtype), dnsclass2str(qclass));
270 return std::string(buffer, len);
271 }
272
// One resource record of a DNS message.
struct DNSRecord {
    // Owner name of the record.
    DNSName name;
    // Type, class and TTL in host byte order (converted by the IntFields
    // helpers below).
    unsigned rtype;
    unsigned rclass;
    unsigned ttl;
    // Raw record data, copied verbatim from / to the wire.
    std::vector<char> rdata;
    // Parses a record at |buffer|. Returns the position just past it, or
    // nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes the record at |buffer|. Returns the position just past the
    // written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns a short description of the form "R<name,type,class>".
    std::string toString() const;
private:
    // Wire layout of the fixed-size fields that follow the name; all values
    // are in network byte order. Packed so that it has no padding.
    struct IntFields {
        uint16_t rtype;
        uint16_t rclass;
        uint32_t ttl;
        uint16_t rdlen;
    } __attribute__((__packed__));

    // Reads the fixed-size fields into this record and stores the data length
    // in |*rdlen|.
    const char* readIntFields(const char* buffer, const char* buffer_end,
                              unsigned* rdlen);
    // Writes the fixed-size fields, using |rdlen| as the data length.
    char* writeIntFields(unsigned rdlen, char* buffer,
                         const char* buffer_end) const;
};
295
read(const char * buffer,const char * buffer_end)296 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
297 const char* cur = name.read(buffer, buffer_end);
298 if (cur == nullptr) {
299 ALOGI("parsing failed at line %d", __LINE__);
300 return nullptr;
301 }
302 unsigned rdlen = 0;
303 cur = readIntFields(cur, buffer_end, &rdlen);
304 if (cur == nullptr) {
305 ALOGI("parsing failed at line %d", __LINE__);
306 return nullptr;
307 }
308 if (cur + rdlen > buffer_end) {
309 ALOGI("parsing failed at line %d", __LINE__);
310 return nullptr;
311 }
312 rdata.assign(cur, cur + rdlen);
313 return cur + rdlen;
314 }
315
write(char * buffer,const char * buffer_end) const316 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
317 char* buffer_cur = name.write(buffer, buffer_end);
318 if (buffer_cur == nullptr) return nullptr;
319 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
320 if (buffer_cur == nullptr) return nullptr;
321 if (buffer_cur + rdata.size() > buffer_end) {
322 ALOGI("buffer overflow on line %d", __LINE__);
323 return nullptr;
324 }
325 return std::copy(rdata.begin(), rdata.end(), buffer_cur);
326 }
327
toString() const328 std::string DNSRecord::toString() const {
329 char buffer[4096];
330 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
331 dnstype2str(rtype), dnsclass2str(rclass));
332 return std::string(buffer, len);
333 }
334
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)335 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
336 unsigned* rdlen) {
337 if (buffer + sizeof(IntFields) > buffer_end ) {
338 ALOGI("parsing failed at line %d", __LINE__);
339 return nullptr;
340 }
341 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
342 rtype = ntohs(intfields.rtype);
343 rclass = ntohs(intfields.rclass);
344 ttl = ntohl(intfields.ttl);
345 *rdlen = ntohs(intfields.rdlen);
346 return buffer + sizeof(IntFields);
347 }
348
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const349 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
350 const char* buffer_end) const {
351 if (buffer + sizeof(IntFields) > buffer_end ) {
352 ALOGI("buffer overflow on line %d", __LINE__);
353 return nullptr;
354 }
355 auto& intfields = *reinterpret_cast<IntFields*>(buffer);
356 intfields.rtype = htons(rtype);
357 intfields.rclass = htons(rclass);
358 intfields.ttl = htonl(ttl);
359 intfields.rdlen = htons(rdlen);
360 return buffer + sizeof(IntFields);
361 }
362
// A complete DNS message: the decoded header fields plus the four sections.
struct DNSHeader {
    // Message id, host byte order.
    unsigned id;
    // Flag bits and codes, decoded by readHeader(): ra is bit 7 and rcode the
    // low 4 bits of the second flags byte; qr is bit 7, opcode bits 3-6, aa
    // bit 2, tr bit 1 and rd bit 0 of the first flags byte.
    bool ra;
    uint8_t rcode;
    bool qr;
    uint8_t opcode;
    bool aa;
    bool tr;
    bool rd;
    // Message sections, in wire order.
    std::vector<DNSQuestion> questions;
    std::vector<DNSRecord> answers;
    std::vector<DNSRecord> authorities;
    std::vector<DNSRecord> additionals;
    // Parses a whole message at |buffer|. Returns the position just past it,
    // or nullptr on failure.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes the whole message at |buffer|. Returns the position just
    // past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;
    // Not implemented yet; currently returns an empty string.
    std::string toString() const;

private:
    // Wire layout of the fixed 12-byte header; multi-byte fields are in
    // network byte order. Packed so that it has no padding.
    struct Header {
        uint16_t id;
        uint8_t flags0;
        uint8_t flags1;
        uint16_t qdcount;
        uint16_t ancount;
        uint16_t nscount;
        uint16_t arcount;
    } __attribute__((__packed__));

    // Decodes the fixed header into this object and returns the four section
    // counts through the out-parameters.
    const char* readHeader(const char* buffer, const char* buffer_end,
                           unsigned* qdcount, unsigned* ancount,
                           unsigned* nscount, unsigned* arcount);
};
395
read(const char * buffer,const char * buffer_end)396 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
397 unsigned qdcount;
398 unsigned ancount;
399 unsigned nscount;
400 unsigned arcount;
401 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
402 &nscount, &arcount);
403 if (cur == nullptr) {
404 ALOGI("parsing failed at line %d", __LINE__);
405 return nullptr;
406 }
407 if (qdcount) {
408 questions.resize(qdcount);
409 for (unsigned i = 0 ; i < qdcount ; ++i) {
410 cur = questions[i].read(cur, buffer_end);
411 if (cur == nullptr) {
412 ALOGI("parsing failed at line %d", __LINE__);
413 return nullptr;
414 }
415 }
416 }
417 if (ancount) {
418 answers.resize(ancount);
419 for (unsigned i = 0 ; i < ancount ; ++i) {
420 cur = answers[i].read(cur, buffer_end);
421 if (cur == nullptr) {
422 ALOGI("parsing failed at line %d", __LINE__);
423 return nullptr;
424 }
425 }
426 }
427 if (nscount) {
428 authorities.resize(nscount);
429 for (unsigned i = 0 ; i < nscount ; ++i) {
430 cur = authorities[i].read(cur, buffer_end);
431 if (cur == nullptr) {
432 ALOGI("parsing failed at line %d", __LINE__);
433 return nullptr;
434 }
435 }
436 }
437 if (arcount) {
438 additionals.resize(arcount);
439 for (unsigned i = 0 ; i < arcount ; ++i) {
440 cur = additionals[i].read(cur, buffer_end);
441 if (cur == nullptr) {
442 ALOGI("parsing failed at line %d", __LINE__);
443 return nullptr;
444 }
445 }
446 }
447 return cur;
448 }
449
write(char * buffer,const char * buffer_end) const450 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
451 if (buffer + sizeof(Header) > buffer_end) {
452 ALOGI("buffer overflow on line %d", __LINE__);
453 return nullptr;
454 }
455 Header& header = *reinterpret_cast<Header*>(buffer);
456 // bytes 0-1
457 header.id = htons(id);
458 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
459 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
460 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
461 header.flags1 = rcode;
462 // rest of header
463 header.qdcount = htons(questions.size());
464 header.ancount = htons(answers.size());
465 header.nscount = htons(authorities.size());
466 header.arcount = htons(additionals.size());
467 char* buffer_cur = buffer + sizeof(Header);
468 for (const DNSQuestion& question : questions) {
469 buffer_cur = question.write(buffer_cur, buffer_end);
470 if (buffer_cur == nullptr) return nullptr;
471 }
472 for (const DNSRecord& answer : answers) {
473 buffer_cur = answer.write(buffer_cur, buffer_end);
474 if (buffer_cur == nullptr) return nullptr;
475 }
476 for (const DNSRecord& authority : authorities) {
477 buffer_cur = authority.write(buffer_cur, buffer_end);
478 if (buffer_cur == nullptr) return nullptr;
479 }
480 for (const DNSRecord& additional : additionals) {
481 buffer_cur = additional.write(buffer_cur, buffer_end);
482 if (buffer_cur == nullptr) return nullptr;
483 }
484 return buffer_cur;
485 }
486
// Placeholder: no textual representation is produced yet, so this always
// returns an empty string.
std::string DNSHeader::toString() const {
    // TODO
    return std::string();
}
491
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)492 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
493 unsigned* qdcount, unsigned* ancount,
494 unsigned* nscount, unsigned* arcount) {
495 if (buffer + sizeof(Header) > buffer_end)
496 return 0;
497 const auto& header = *reinterpret_cast<const Header*>(buffer);
498 // bytes 0-1
499 id = ntohs(header.id);
500 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
501 qr = header.flags0 >> 7;
502 opcode = (header.flags0 >> 3) & 0x0F;
503 aa = (header.flags0 >> 2) & 1;
504 tr = (header.flags0 >> 1) & 1;
505 rd = header.flags0 & 1;
506 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
507 ra = header.flags1 >> 7;
508 rcode = header.flags1 & 0xF;
509 // rest of header
510 *qdcount = ntohs(header.qdcount);
511 *ancount = ntohs(header.ancount);
512 *nscount = ntohs(header.nscount);
513 *arcount = ntohs(header.arcount);
514 return buffer + sizeof(Header);
515 }
516
517 /* DNS responder */
518
// Stores the configuration only; no socket is opened and no thread is started
// until startServer() is called. |listen_address| and |listen_service| are
// later passed to getaddrinfo(), |poll_timeout_ms| to epoll_wait().
DNSResponder::DNSResponder(std::string listen_address,
                           std::string listen_service, int poll_timeout_ms,
                           uint16_t error_rcode, double response_probability) :
    listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
    poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
    response_probability_(response_probability),
    socket_(-1), epoll_fd_(-1), terminate_(false) { }
526
// Stops the server if it is still running (stopServer() is a no-op that
// returns false otherwise).
DNSResponder::~DNSResponder() {
    stopServer();
}
530
addMapping(const char * name,ns_type type,const char * addr)531 void DNSResponder::addMapping(const char* name, ns_type type,
532 const char* addr) {
533 std::lock_guard<std::mutex> lock(mappings_mutex_);
534 auto it = mappings_.find(QueryKey(name, type));
535 if (it != mappings_.end()) {
536 ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
537 "address %s", name, dnstype2str(type), it->second.c_str(),
538 addr);
539 it->second = addr;
540 return;
541 }
542 mappings_.emplace(std::piecewise_construct,
543 std::forward_as_tuple(name, type),
544 std::forward_as_tuple(addr));
545 }
546
removeMapping(const char * name,ns_type type)547 void DNSResponder::removeMapping(const char* name, ns_type type) {
548 std::lock_guard<std::mutex> lock(mappings_mutex_);
549 auto it = mappings_.find(QueryKey(name, type));
550 if (it != mappings_.end()) {
551 ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
552 dnstype2str(type));
553 return;
554 }
555 mappings_.erase(it);
556 }
557
// Replaces the probability used by handleDNSRequest() to decide whether a
// query is answered normally or with a server failure.
// NOTE(review): no lock is taken here while the handler thread reads this
// value; confirm that the member's type makes this safe.
void DNSResponder::setResponseProbability(double response_probability) {
    response_probability_ = response_probability;
}
561
// The server counts as running while it owns a socket: socket_ is set by
// startServer() and reset to -1 by stopServer().
bool DNSResponder::running() const {
    return socket_ != -1;
}
565
// Binds a non-blocking UDP socket to the configured address and service,
// registers it with a new epoll instance and starts the request handler
// thread. Returns true on success. Returns false, leaving no descriptors
// open, if the server is already running or any step fails.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }
    // Zero-initialize and assign the members individually: designated
    // initializers listed out of declaration order are rejected by some
    // compilers, and the member order of addrinfo is not ours to control.
    addrinfo ai_hints{};
    ai_hints.ai_flags = AI_PASSIVE;
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    // Use the first candidate address that can be both created and bound.
    int s = -1;
    for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        // Best effort: a failure to set SO_REUSEPORT is not treated as fatal.
        const int one = 1;
        setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        break;
    }
    freeaddrinfo(ai_res);
    if (s < 0) {
        ALOGI("bind() failed");
        return false;
    }

    int flags = fcntl(s, F_GETFL, 0);
    if (flags < 0) flags = 0;
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", s);
        close(s);
        return false;
    }

    int ep_fd = epoll_create(1);
    if (ep_fd < 0) {
        // APLOGI already appends errno and its description, so there is no
        // need to format the error a second time (the previous hand-rolled
        // strerror_r() call also mis-handled the GNU variant).
        APLOGI("epoll_create() failed for socket %d", s);
        close(s);
        return false;
    }
    epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.fd = s;
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
        APLOGI("epoll_ctl() failed for socket %d", s);
        close(ep_fd);
        close(s);
        return false;
    }

    // Publish the descriptors before the handler thread starts using them.
    epoll_fd_ = ep_fd;
    socket_ = s;
    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
642
stopServer()643 bool DNSResponder::stopServer() {
644 std::lock_guard<std::mutex> lock(update_mutex_);
645 if (!running()) {
646 ALOGI("server not running");
647 return false;
648 }
649 if (terminate_) {
650 ALOGI("LOGIC ERROR");
651 return false;
652 }
653 ALOGI("stopping server");
654 terminate_ = true;
655 handler_thread_.join();
656 close(epoll_fd_);
657 close(socket_);
658 terminate_ = false;
659 socket_ = -1;
660 ALOGI("server stopped successfully");
661 return true;
662 }
663
// Returns a copy of the (name, type) pairs recorded by handleDNSRequest() so
// far, taken under queries_mutex_.
std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
    std::lock_guard<std::mutex> lock(queries_mutex_);
    return queries_;
}
668
// Discards all recorded queries, under queries_mutex_.
void DNSResponder::clearQueries() {
    std::lock_guard<std::mutex> lock(queries_mutex_);
    queries_.clear();
}
673
requestHandler()674 void DNSResponder::requestHandler() {
675 epoll_event evs[1];
676 while (!terminate_) {
677 int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
678 if (n == 0) continue;
679 if (n < 0) {
680 ALOGI("epoll_wait() failed");
681 // TODO(imaipi): terminate on error.
682 return;
683 }
684 char buffer[4096];
685 sockaddr_storage sa;
686 socklen_t sa_len = sizeof(sa);
687 ssize_t len;
688 do {
689 len = recvfrom(socket_, buffer, sizeof(buffer), 0,
690 (sockaddr*) &sa, &sa_len);
691 } while (len < 0 && (errno == EAGAIN || errno == EINTR));
692 if (len <= 0) {
693 ALOGI("recvfrom() failed");
694 continue;
695 }
696 ALOGI("read %zd bytes", len);
697 char response[4096];
698 size_t response_len = sizeof(response);
699 if (handleDNSRequest(buffer, len, response, &response_len) &&
700 response_len > 0) {
701 len = sendto(socket_, response, response_len, 0,
702 reinterpret_cast<const sockaddr*>(&sa), sa_len);
703 std::string host_str =
704 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
705 if (len > 0) {
706 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
707 } else {
708 APLOGI("sendto() failed for %s", host_str.c_str());
709 }
710 // Test that the response is actually a correct DNS message.
711 const char* response_end = response + len;
712 DNSHeader header;
713 const char* cur = header.read(response, response_end);
714 if (cur == nullptr) ALOGI("response is flawed");
715
716 } else {
717 ALOGI("not responding");
718 }
719 }
720 }
721
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const722 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
723 char* response, size_t* response_len)
724 const {
725 ALOGI("request: '%s'", str2hex(buffer, len).c_str());
726 const char* buffer_end = buffer + len;
727 DNSHeader header;
728 const char* cur = header.read(buffer, buffer_end);
729 // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
730 if (cur == nullptr) {
731 ALOGI("failed to parse query");
732 return false;
733 }
734 if (header.qr) {
735 ALOGI("response received instead of a query");
736 return false;
737 }
738 if (header.opcode != ns_opcode::ns_o_query) {
739 ALOGI("unsupported request opcode received");
740 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
741 response_len);
742 }
743 if (header.questions.empty()) {
744 ALOGI("no questions present");
745 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
746 response_len);
747 }
748 if (!header.answers.empty()) {
749 ALOGI("already %zu answers present in query", header.answers.size());
750 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
751 response_len);
752 }
753 {
754 std::lock_guard<std::mutex> lock(queries_mutex_);
755 for (const DNSQuestion& question : header.questions) {
756 queries_.push_back(make_pair(question.qname.name,
757 ns_type(question.qtype)));
758 }
759 }
760
761 // Ignore requests with the preset probability.
762 auto constexpr bound = std::numeric_limits<unsigned>::max();
763 if (arc4random_uniform(bound) > bound*response_probability_) {
764 ALOGI("returning SRVFAIL in accordance with probability distribution");
765 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
766 response_len);
767 }
768
769 for (const DNSQuestion& question : header.questions) {
770 if (question.qclass != ns_class::ns_c_in &&
771 question.qclass != ns_class::ns_c_any) {
772 ALOGI("unsupported question class %u", question.qclass);
773 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
774 response_len);
775 }
776 if (!addAnswerRecords(question, &header.answers)) {
777 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
778 response_len);
779 }
780 }
781 header.qr = true;
782 char* response_cur = header.write(response, response + *response_len);
783 if (response_cur == nullptr) {
784 return false;
785 }
786 *response_len = response_cur - response;
787 return true;
788 }
789
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const790 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
791 std::vector<DNSRecord>* answers) const {
792 auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
793 if (it == mappings_.end()) {
794 // TODO(imaipi): handle correctly
795 ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
796 question.qname.name.c_str(), dnstype2str(question.qtype));
797 return true;
798 }
799 ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
800 dnstype2str(question.qtype), it->second.c_str());
801 DNSRecord record;
802 record.name = question.qname;
803 record.rtype = question.qtype;
804 record.rclass = ns_class::ns_c_in;
805 record.ttl = 5; // seconds
806 if (question.qtype == ns_type::ns_t_a) {
807 record.rdata.resize(4);
808 if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
809 ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
810 return false;
811 }
812 } else if (question.qtype == ns_type::ns_t_aaaa) {
813 record.rdata.resize(16);
814 if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
815 ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
816 return false;
817 }
818 } else {
819 ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
820 return false;
821 }
822 answers->push_back(std::move(record));
823 return true;
824 }
825
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const826 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
827 char* response, size_t* response_len)
828 const {
829 header->answers.clear();
830 header->authorities.clear();
831 header->additionals.clear();
832 header->rcode = rcode;
833 header->qr = true;
834 char* response_cur = header->write(response, response + *response_len);
835 if (response_cur == nullptr) return false;
836 *response_len = response_cur - response;
837 return true;
838 }
839
840 } // namespace test
841
842