• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_responder.h"
18 
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 
31 #include <iostream>
32 #include <vector>
33 
34 #define LOG_TAG "DNSResponder"
35 #include <log/log.h>
36 
37 namespace test {
38 
// strerror_r() exists in two incompatible flavors, and which one is declared
// depends on feature-test macros (C++ compilers commonly define _GNU_SOURCE):
//  - XSI: int strerror_r(int, char*, size_t) returns 0 and fills the buffer.
//  - GNU: char* strerror_r(int, char*, size_t) returns a pointer to the
//    message, which may be a static string rather than the caller's buffer.
// Overloading on the result type lets errno2str() work with either flavor.
inline std::string strerrorResultToString(int rv, const char* buf) {
    return rv == 0 ? std::string(buf) : std::string();
}

inline std::string strerrorResultToString(const char* msg, const char* /* buf */) {
    return msg != nullptr ? std::string(msg) : std::string();
}

// Returns a human-readable description of the current errno value, or an
// empty string if it cannot be determined.
std::string errno2str() {
    char error_msg[512] = { 0 };
    return strerrorResultToString(
            strerror_r(errno, error_msg, sizeof(error_msg)), error_msg);
}
45 
46 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
47 
// Renders |len| bytes of |buffer| as uppercase hexadecimal, two characters per
// byte with no separators (e.g. {0x01, 0xAB} -> "01AB").
std::string str2hex(const char* buffer, size_t len) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    std::string hex;
    hex.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        const uint8_t byte = static_cast<uint8_t>(*p);
        hex.push_back(kHexDigits[byte >> 4]);
        hex.push_back(kHexDigits[byte & 0x0F]);
    }
    return hex;
}
58 
// Formats the socket address |sa| numerically (NI_NUMERICHOST, so no name
// lookup is performed). Returns an empty string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, host, sizeof(host),
                                   nullptr, 0, NI_NUMERICHOST);
    return status == 0 ? std::string(host) : std::string();
}
66 
67 /* DNS struct helpers */
68 
dnstype2str(unsigned dnstype)69 const char* dnstype2str(unsigned dnstype) {
70     static std::unordered_map<unsigned, const char*> kTypeStrs = {
71         { ns_type::ns_t_a, "A" },
72         { ns_type::ns_t_ns, "NS" },
73         { ns_type::ns_t_md, "MD" },
74         { ns_type::ns_t_mf, "MF" },
75         { ns_type::ns_t_cname, "CNAME" },
76         { ns_type::ns_t_soa, "SOA" },
77         { ns_type::ns_t_mb, "MB" },
78         { ns_type::ns_t_mb, "MG" },
79         { ns_type::ns_t_mr, "MR" },
80         { ns_type::ns_t_null, "NULL" },
81         { ns_type::ns_t_wks, "WKS" },
82         { ns_type::ns_t_ptr, "PTR" },
83         { ns_type::ns_t_hinfo, "HINFO" },
84         { ns_type::ns_t_minfo, "MINFO" },
85         { ns_type::ns_t_mx, "MX" },
86         { ns_type::ns_t_txt, "TXT" },
87         { ns_type::ns_t_rp, "RP" },
88         { ns_type::ns_t_afsdb, "AFSDB" },
89         { ns_type::ns_t_x25, "X25" },
90         { ns_type::ns_t_isdn, "ISDN" },
91         { ns_type::ns_t_rt, "RT" },
92         { ns_type::ns_t_nsap, "NSAP" },
93         { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
94         { ns_type::ns_t_sig, "SIG" },
95         { ns_type::ns_t_key, "KEY" },
96         { ns_type::ns_t_px, "PX" },
97         { ns_type::ns_t_gpos, "GPOS" },
98         { ns_type::ns_t_aaaa, "AAAA" },
99         { ns_type::ns_t_loc, "LOC" },
100         { ns_type::ns_t_nxt, "NXT" },
101         { ns_type::ns_t_eid, "EID" },
102         { ns_type::ns_t_nimloc, "NIMLOC" },
103         { ns_type::ns_t_srv, "SRV" },
104         { ns_type::ns_t_naptr, "NAPTR" },
105         { ns_type::ns_t_kx, "KX" },
106         { ns_type::ns_t_cert, "CERT" },
107         { ns_type::ns_t_a6, "A6" },
108         { ns_type::ns_t_dname, "DNAME" },
109         { ns_type::ns_t_sink, "SINK" },
110         { ns_type::ns_t_opt, "OPT" },
111         { ns_type::ns_t_apl, "APL" },
112         { ns_type::ns_t_tkey, "TKEY" },
113         { ns_type::ns_t_tsig, "TSIG" },
114         { ns_type::ns_t_ixfr, "IXFR" },
115         { ns_type::ns_t_axfr, "AXFR" },
116         { ns_type::ns_t_mailb, "MAILB" },
117         { ns_type::ns_t_maila, "MAILA" },
118         { ns_type::ns_t_any, "ANY" },
119         { ns_type::ns_t_zxfr, "ZXFR" },
120     };
121     auto it = kTypeStrs.find(dnstype);
122     static const char* kUnknownStr{ "UNKNOWN" };
123     if (it == kTypeStrs.end()) return kUnknownStr;
124     return it->second;
125 }
126 
// Returns a descriptive name for DNS class |dnsclass|, or "UNKNOWN" for
// classes not listed below.
const char* dnsclass2str(unsigned dnsclass) {
    // Never modified after initialization.
    static const std::unordered_map<unsigned, const char*> kClassStrs = {
        { ns_class::ns_c_in , "Internet" },
        { 2, "CSNet" },
        { ns_class::ns_c_chaos, "ChaosNet" },
        { ns_class::ns_c_hs, "Hesiod" },
        { ns_class::ns_c_none, "none" },
        { ns_class::ns_c_any, "any" }
    };
    const auto it = kClassStrs.find(dnsclass);
    return it == kClassStrs.end() ? "UNKNOWN" : it->second;
}
142 
// A domain name held in dotted form. Every label parsed by read() is followed
// by a '.', so non-empty names end with a trailing dot; write() relies on
// that invariant.
struct DNSName {
    // Dotted name, e.g. "www.example.com."; empty for the root name.
    std::string name;

    // Parses a wire-format name starting at |buffer|, appending its labels to
    // |name|. Returns the position just past the name, or nullptr on error.
    // Compressed names are not supported.
    const char* read(const char* buffer, const char* buffer_end);

    // Serializes |name| in wire format into [buffer, buffer_end). Returns the
    // position just past the written bytes, or nullptr on error.
    char* write(char* buffer, const char* buffer_end) const;

    // Returns |name| as a C string.
    const char* toString() const;

private:
    // Parses a single label field; sets *last once the terminating
    // zero-length label has been consumed.
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};
152 
toString() const153 const char* DNSName::toString() const {
154     return name.c_str();
155 }
156 
read(const char * buffer,const char * buffer_end)157 const char* DNSName::read(const char* buffer, const char* buffer_end) {
158     const char* cur = buffer;
159     bool last = false;
160     do {
161         cur = parseField(cur, buffer_end, &last);
162         if (cur == nullptr) {
163             ALOGI("parsing failed at line %d", __LINE__);
164             return nullptr;
165         }
166     } while (!last);
167     return cur;
168 }
169 
write(char * buffer,const char * buffer_end) const170 char* DNSName::write(char* buffer, const char* buffer_end) const {
171     char* buffer_cur = buffer;
172     for (size_t pos = 0 ; pos < name.size() ; ) {
173         size_t dot_pos = name.find('.', pos);
174         if (dot_pos == std::string::npos) {
175             // Sanity check, should never happen unless parseField is broken.
176             ALOGI("logic error: all names are expected to end with a '.'");
177             return nullptr;
178         }
179         size_t len = dot_pos - pos;
180         if (len >= 256) {
181             ALOGI("name component '%s' is %zu long, but max is 255",
182                     name.substr(pos, dot_pos - pos).c_str(), len);
183             return nullptr;
184         }
185         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
186             ALOGI("buffer overflow at line %d", __LINE__);
187             return nullptr;
188         }
189         *buffer_cur++ = len;
190         buffer_cur = std::copy(std::next(name.begin(), pos),
191                                std::next(name.begin(), dot_pos),
192                                buffer_cur);
193         pos = dot_pos + 1;
194     }
195     // Write final zero.
196     *buffer_cur++ = 0;
197     return buffer_cur;
198 }
199 
parseField(const char * buffer,const char * buffer_end,bool * last)200 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
201                                 bool* last) {
202     if (buffer + sizeof(uint8_t) > buffer_end) {
203         ALOGI("parsing failed at line %d", __LINE__);
204         return nullptr;
205     }
206     unsigned field_type = *buffer >> 6;
207     unsigned ofs = *buffer & 0x3F;
208     const char* cur = buffer + sizeof(uint8_t);
209     if (field_type == 0) {
210         // length + name component
211         if (ofs == 0) {
212             *last = true;
213             return cur;
214         }
215         if (cur + ofs > buffer_end) {
216             ALOGI("parsing failed at line %d", __LINE__);
217             return nullptr;
218         }
219         name.append(cur, ofs);
220         name.push_back('.');
221         return cur + ofs;
222     } else if (field_type == 3) {
223         ALOGI("name compression not implemented");
224         return nullptr;
225     }
226     ALOGI("invalid name field type");
227     return nullptr;
228 }
229 
230 struct DNSQuestion {
231     DNSName qname;
232     unsigned qtype;
233     unsigned qclass;
234     const char* read(const char* buffer, const char* buffer_end);
235     char* write(char* buffer, const char* buffer_end) const;
236     std::string toString() const;
237 };
238 
read(const char * buffer,const char * buffer_end)239 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
240     const char* cur = qname.read(buffer, buffer_end);
241     if (cur == nullptr) {
242         ALOGI("parsing failed at line %d", __LINE__);
243         return nullptr;
244     }
245     if (cur + 2*sizeof(uint16_t) > buffer_end) {
246         ALOGI("parsing failed at line %d", __LINE__);
247         return nullptr;
248     }
249     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
250     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
251     return cur + 2*sizeof(uint16_t);
252 }
253 
write(char * buffer,const char * buffer_end) const254 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
255     char* buffer_cur = qname.write(buffer, buffer_end);
256     if (buffer_cur == nullptr) return nullptr;
257     if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
258         ALOGI("buffer overflow on line %d", __LINE__);
259         return nullptr;
260     }
261     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
262     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
263             htons(qclass);
264     return buffer_cur + 2*sizeof(uint16_t);
265 }
266 
toString() const267 std::string DNSQuestion::toString() const {
268     char buffer[4096];
269     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
270                        dnstype2str(qtype), dnsclass2str(qclass));
271     return std::string(buffer, len);
272 }
273 
274 struct DNSRecord {
275     DNSName name;
276     unsigned rtype;
277     unsigned rclass;
278     unsigned ttl;
279     std::vector<char> rdata;
280     const char* read(const char* buffer, const char* buffer_end);
281     char* write(char* buffer, const char* buffer_end) const;
282     std::string toString() const;
283 private:
284     struct IntFields {
285         uint16_t rtype;
286         uint16_t rclass;
287         uint32_t ttl;
288         uint16_t rdlen;
289     } __attribute__((__packed__));
290 
291     const char* readIntFields(const char* buffer, const char* buffer_end,
292             unsigned* rdlen);
293     char* writeIntFields(unsigned rdlen, char* buffer,
294                          const char* buffer_end) const;
295 };
296 
read(const char * buffer,const char * buffer_end)297 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
298     const char* cur = name.read(buffer, buffer_end);
299     if (cur == nullptr) {
300         ALOGI("parsing failed at line %d", __LINE__);
301         return nullptr;
302     }
303     unsigned rdlen = 0;
304     cur = readIntFields(cur, buffer_end, &rdlen);
305     if (cur == nullptr) {
306         ALOGI("parsing failed at line %d", __LINE__);
307         return nullptr;
308     }
309     if (cur + rdlen > buffer_end) {
310         ALOGI("parsing failed at line %d", __LINE__);
311         return nullptr;
312     }
313     rdata.assign(cur, cur + rdlen);
314     return cur + rdlen;
315 }
316 
write(char * buffer,const char * buffer_end) const317 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
318     char* buffer_cur = name.write(buffer, buffer_end);
319     if (buffer_cur == nullptr) return nullptr;
320     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
321     if (buffer_cur == nullptr) return nullptr;
322     if (buffer_cur + rdata.size() > buffer_end) {
323         ALOGI("buffer overflow on line %d", __LINE__);
324         return nullptr;
325     }
326     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
327 }
328 
toString() const329 std::string DNSRecord::toString() const {
330     char buffer[4096];
331     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
332                        dnstype2str(rtype), dnsclass2str(rclass));
333     return std::string(buffer, len);
334 }
335 
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)336 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
337                                      unsigned* rdlen) {
338     if (buffer + sizeof(IntFields) > buffer_end ) {
339         ALOGI("parsing failed at line %d", __LINE__);
340         return nullptr;
341     }
342     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
343     rtype = ntohs(intfields.rtype);
344     rclass = ntohs(intfields.rclass);
345     ttl = ntohl(intfields.ttl);
346     *rdlen = ntohs(intfields.rdlen);
347     return buffer + sizeof(IntFields);
348 }
349 
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const350 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
351                                 const char* buffer_end) const {
352     if (buffer + sizeof(IntFields) > buffer_end ) {
353         ALOGI("buffer overflow on line %d", __LINE__);
354         return nullptr;
355     }
356     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
357     intfields.rtype = htons(rtype);
358     intfields.rclass = htons(rclass);
359     intfields.ttl = htonl(ttl);
360     intfields.rdlen = htons(rdlen);
361     return buffer + sizeof(IntFields);
362 }
363 
364 struct DNSHeader {
365     unsigned id;
366     bool ra;
367     uint8_t rcode;
368     bool qr;
369     uint8_t opcode;
370     bool aa;
371     bool tr;
372     bool rd;
373     std::vector<DNSQuestion> questions;
374     std::vector<DNSRecord> answers;
375     std::vector<DNSRecord> authorities;
376     std::vector<DNSRecord> additionals;
377     const char* read(const char* buffer, const char* buffer_end);
378     char* write(char* buffer, const char* buffer_end) const;
379     std::string toString() const;
380 
381 private:
382     struct Header {
383         uint16_t id;
384         uint8_t flags0;
385         uint8_t flags1;
386         uint16_t qdcount;
387         uint16_t ancount;
388         uint16_t nscount;
389         uint16_t arcount;
390     } __attribute__((__packed__));
391 
392     const char* readHeader(const char* buffer, const char* buffer_end,
393                            unsigned* qdcount, unsigned* ancount,
394                            unsigned* nscount, unsigned* arcount);
395 };
396 
read(const char * buffer,const char * buffer_end)397 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
398     unsigned qdcount;
399     unsigned ancount;
400     unsigned nscount;
401     unsigned arcount;
402     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
403                                  &nscount, &arcount);
404     if (cur == nullptr) {
405         ALOGI("parsing failed at line %d", __LINE__);
406         return nullptr;
407     }
408     if (qdcount) {
409         questions.resize(qdcount);
410         for (unsigned i = 0 ; i < qdcount ; ++i) {
411             cur = questions[i].read(cur, buffer_end);
412             if (cur == nullptr) {
413                 ALOGI("parsing failed at line %d", __LINE__);
414                 return nullptr;
415             }
416         }
417     }
418     if (ancount) {
419         answers.resize(ancount);
420         for (unsigned i = 0 ; i < ancount ; ++i) {
421             cur = answers[i].read(cur, buffer_end);
422             if (cur == nullptr) {
423                 ALOGI("parsing failed at line %d", __LINE__);
424                 return nullptr;
425             }
426         }
427     }
428     if (nscount) {
429         authorities.resize(nscount);
430         for (unsigned i = 0 ; i < nscount ; ++i) {
431             cur = authorities[i].read(cur, buffer_end);
432             if (cur == nullptr) {
433                 ALOGI("parsing failed at line %d", __LINE__);
434                 return nullptr;
435             }
436         }
437     }
438     if (arcount) {
439         additionals.resize(arcount);
440         for (unsigned i = 0 ; i < arcount ; ++i) {
441             cur = additionals[i].read(cur, buffer_end);
442             if (cur == nullptr) {
443                 ALOGI("parsing failed at line %d", __LINE__);
444                 return nullptr;
445             }
446         }
447     }
448     return cur;
449 }
450 
write(char * buffer,const char * buffer_end) const451 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
452     if (buffer + sizeof(Header) > buffer_end) {
453         ALOGI("buffer overflow on line %d", __LINE__);
454         return nullptr;
455     }
456     Header& header = *reinterpret_cast<Header*>(buffer);
457     // bytes 0-1
458     header.id = htons(id);
459     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
460     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
461     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
462     header.flags1 = rcode;
463     // rest of header
464     header.qdcount = htons(questions.size());
465     header.ancount = htons(answers.size());
466     header.nscount = htons(authorities.size());
467     header.arcount = htons(additionals.size());
468     char* buffer_cur = buffer + sizeof(Header);
469     for (const DNSQuestion& question : questions) {
470         buffer_cur = question.write(buffer_cur, buffer_end);
471         if (buffer_cur == nullptr) return nullptr;
472     }
473     for (const DNSRecord& answer : answers) {
474         buffer_cur = answer.write(buffer_cur, buffer_end);
475         if (buffer_cur == nullptr) return nullptr;
476     }
477     for (const DNSRecord& authority : authorities) {
478         buffer_cur = authority.write(buffer_cur, buffer_end);
479         if (buffer_cur == nullptr) return nullptr;
480     }
481     for (const DNSRecord& additional : additionals) {
482         buffer_cur = additional.write(buffer_cur, buffer_end);
483         if (buffer_cur == nullptr) return nullptr;
484     }
485     return buffer_cur;
486 }
487 
toString() const488 std::string DNSHeader::toString() const {
489     // TODO
490     return std::string();
491 }
492 
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)493 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
494                                   unsigned* qdcount, unsigned* ancount,
495                                   unsigned* nscount, unsigned* arcount) {
496     if (buffer + sizeof(Header) > buffer_end)
497         return 0;
498     const auto& header = *reinterpret_cast<const Header*>(buffer);
499     // bytes 0-1
500     id = ntohs(header.id);
501     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
502     qr = header.flags0 >> 7;
503     opcode = (header.flags0 >> 3) & 0x0F;
504     aa = (header.flags0 >> 2) & 1;
505     tr = (header.flags0 >> 1) & 1;
506     rd = header.flags0 & 1;
507     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
508     ra = header.flags1 >> 7;
509     rcode = header.flags1 & 0xF;
510     // rest of header
511     *qdcount = ntohs(header.qdcount);
512     *ancount = ntohs(header.ancount);
513     *nscount = ntohs(header.nscount);
514     *arcount = ntohs(header.arcount);
515     return buffer + sizeof(Header);
516 }
517 
518 /* DNS responder */
519 
DNSResponder(std::string listen_address,std::string listen_service,int poll_timeout_ms,uint16_t error_rcode,double response_probability)520 DNSResponder::DNSResponder(std::string listen_address,
521                            std::string listen_service, int poll_timeout_ms,
522                            uint16_t error_rcode, double response_probability) :
523     listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
524     poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
525     response_probability_(response_probability),
526     socket_(-1), epoll_fd_(-1), terminate_(false) { }
527 
~DNSResponder()528 DNSResponder::~DNSResponder() {
529     stopServer();
530 }
531 
addMapping(const char * name,ns_type type,const char * addr)532 void DNSResponder::addMapping(const char* name, ns_type type,
533         const char* addr) {
534     std::lock_guard<std::mutex> lock(mappings_mutex_);
535     auto it = mappings_.find(QueryKey(name, type));
536     if (it != mappings_.end()) {
537         ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
538             "address %s", name, dnstype2str(type), it->second.c_str(),
539             addr);
540         it->second = addr;
541         return;
542     }
543     mappings_.emplace(std::piecewise_construct,
544                       std::forward_as_tuple(name, type),
545                       std::forward_as_tuple(addr));
546 }
547 
removeMapping(const char * name,ns_type type)548 void DNSResponder::removeMapping(const char* name, ns_type type) {
549     std::lock_guard<std::mutex> lock(mappings_mutex_);
550     auto it = mappings_.find(QueryKey(name, type));
551     if (it != mappings_.end()) {
552         ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
553             dnstype2str(type));
554         return;
555     }
556     mappings_.erase(it);
557 }
558 
setResponseProbability(double response_probability)559 void DNSResponder::setResponseProbability(double response_probability) {
560     response_probability_ = response_probability;
561 }
562 
running() const563 bool DNSResponder::running() const {
564     return socket_ != -1;
565 }
566 
// Binds a non-blocking UDP socket to listen_address_:listen_service_, registers
// it with a new epoll instance and starts the request handler thread.
// Returns false (after cleaning up) if the server is already running or any
// step fails.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }
    // Fields are assigned individually: designated initializers are not
    // standard C++ before C++20, and C++20 requires them in declaration order
    // (addrinfo declares ai_flags before ai_family and ai_socktype).
    addrinfo ai_hints = {};
    ai_hints.ai_flags = AI_PASSIVE;
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    // Use the first returned address that can be bound.
    int s = -1;
    for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        const int one = 1;
        setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        break;
    }
    freeaddrinfo(ai_res);
    if (s < 0) {
        ALOGI("bind() failed");
        return false;
    }

    int flags = fcntl(s, F_GETFL, 0);
    if (flags < 0) flags = 0;
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", s);
        close(s);
        return false;
    }

    int ep_fd = epoll_create(1);
    if (ep_fd < 0) {
        // APLOGI already appends errno and its description.
        APLOGI("epoll_create() failed for socket %d", s);
        close(s);
        return false;
    }
    epoll_event ev = {};
    ev.events = EPOLLIN;
    ev.data.fd = s;
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
        APLOGI("epoll_ctl() failed for socket %d", s);
        close(ep_fd);
        close(s);
        return false;
    }

    {
        // Publish the descriptors under the same lock stopServer() takes.
        // They are assigned before the thread starts, so the handler sees them.
        std::lock_guard<std::mutex> lock(update_mutex_);
        epoll_fd_ = ep_fd;
        socket_ = s;
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
643 
stopServer()644 bool DNSResponder::stopServer() {
645     std::lock_guard<std::mutex> lock(update_mutex_);
646     if (!running()) {
647         ALOGI("server not running");
648         return false;
649     }
650     if (terminate_) {
651         ALOGI("LOGIC ERROR");
652         return false;
653     }
654     ALOGI("stopping server");
655     terminate_ = true;
656     handler_thread_.join();
657     close(epoll_fd_);
658     close(socket_);
659     terminate_ = false;
660     socket_ = -1;
661     ALOGI("server stopped successfully");
662     return true;
663 }
664 
queries() const665 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
666     std::lock_guard<std::mutex> lock(queries_mutex_);
667     return queries_;
668 }
669 
clearQueries()670 void DNSResponder::clearQueries() {
671     std::lock_guard<std::mutex> lock(queries_mutex_);
672     queries_.clear();
673 }
674 
requestHandler()675 void DNSResponder::requestHandler() {
676     epoll_event evs[1];
677     while (!terminate_) {
678         int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
679         if (n == 0) continue;
680         if (n < 0) {
681             ALOGI("epoll_wait() failed");
682             // TODO(imaipi): terminate on error.
683             return;
684         }
685         char buffer[4096];
686         sockaddr_storage sa;
687         socklen_t sa_len = sizeof(sa);
688         ssize_t len;
689         do {
690             len = recvfrom(socket_, buffer, sizeof(buffer), 0,
691                            (sockaddr*) &sa, &sa_len);
692         } while (len < 0 && (errno == EAGAIN || errno == EINTR));
693         if (len <= 0) {
694             ALOGI("recvfrom() failed");
695             continue;
696         }
697         ALOGI("read %zd bytes", len);
698         char response[4096];
699         size_t response_len = sizeof(response);
700         if (handleDNSRequest(buffer, len, response, &response_len) &&
701             response_len > 0) {
702             len = sendto(socket_, response, response_len, 0,
703                          reinterpret_cast<const sockaddr*>(&sa), sa_len);
704             std::string host_str =
705                 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
706             if (len > 0) {
707                 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
708             } else {
709                 APLOGI("sendto() failed for %s", host_str.c_str());
710             }
711             // Test that the response is actually a correct DNS message.
712             const char* response_end = response + len;
713             DNSHeader header;
714             const char* cur = header.read(response, response_end);
715             if (cur == nullptr) ALOGI("response is flawed");
716 
717         } else {
718             ALOGI("not responding");
719         }
720     }
721 }
722 
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const723 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
724                                     char* response, size_t* response_len)
725                                     const {
726     ALOGI("request: '%s'", str2hex(buffer, len).c_str());
727     const char* buffer_end = buffer + len;
728     DNSHeader header;
729     const char* cur = header.read(buffer, buffer_end);
730     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
731     if (cur == nullptr) {
732         ALOGI("failed to parse query");
733         return false;
734     }
735     if (header.qr) {
736         ALOGI("response received instead of a query");
737         return false;
738     }
739     if (header.opcode != ns_opcode::ns_o_query) {
740         ALOGI("unsupported request opcode received");
741         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
742                                  response_len);
743     }
744     if (header.questions.empty()) {
745         ALOGI("no questions present");
746         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
747                                  response_len);
748     }
749     if (!header.answers.empty()) {
750         ALOGI("already %zu answers present in query", header.answers.size());
751         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
752                                  response_len);
753     }
754     {
755         std::lock_guard<std::mutex> lock(queries_mutex_);
756         for (const DNSQuestion& question : header.questions) {
757             queries_.push_back(make_pair(question.qname.name,
758                                          ns_type(question.qtype)));
759         }
760     }
761 
762     // Ignore requests with the preset probability.
763     auto constexpr bound = std::numeric_limits<unsigned>::max();
764     if (arc4random_uniform(bound) > bound*response_probability_) {
765         ALOGI("returning SRVFAIL in accordance with probability distribution");
766         return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
767                                  response_len);
768     }
769 
770     for (const DNSQuestion& question : header.questions) {
771         if (question.qclass != ns_class::ns_c_in &&
772             question.qclass != ns_class::ns_c_any) {
773             ALOGI("unsupported question class %u", question.qclass);
774             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
775                                      response_len);
776         }
777         if (!addAnswerRecords(question, &header.answers)) {
778             return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
779                                      response_len);
780         }
781     }
782     header.qr = true;
783     char* response_cur = header.write(response, response + *response_len);
784     if (response_cur == nullptr) {
785         return false;
786     }
787     *response_len = response_cur - response;
788     return true;
789 }
790 
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const791 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
792                                     std::vector<DNSRecord>* answers) const {
793     auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
794     if (it == mappings_.end()) {
795         // TODO(imaipi): handle correctly
796         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
797             question.qname.name.c_str(), dnstype2str(question.qtype));
798         return true;
799     }
800     ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
801         dnstype2str(question.qtype), it->second.c_str());
802     DNSRecord record;
803     record.name = question.qname;
804     record.rtype = question.qtype;
805     record.rclass = ns_class::ns_c_in;
806     record.ttl = 5;  // seconds
807     if (question.qtype == ns_type::ns_t_a) {
808         record.rdata.resize(4);
809         if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
810             ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
811             return false;
812         }
813     } else if (question.qtype == ns_type::ns_t_aaaa) {
814         record.rdata.resize(16);
815         if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
816             ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
817             return false;
818         }
819     } else {
820         ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
821         return false;
822     }
823     answers->push_back(std::move(record));
824     return true;
825 }
826 
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const827 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
828                                      char* response, size_t* response_len)
829                                      const {
830     header->answers.clear();
831     header->authorities.clear();
832     header->additionals.clear();
833     header->rcode = rcode;
834     header->qr = true;
835     char* response_cur = header->write(response, response + *response_len);
836     if (response_cur == nullptr) return false;
837     *response_len = response_cur - response;
838     return true;
839 }
840 
841 }  // namespace test
842 
843