• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_responder.h"
18 
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 
31 #include <iostream>
32 #include <vector>
33 
34 #define LOG_TAG "DNSResponder"
35 #include <log/log.h>
36 #include <netdutils/SocketOption.h>
37 
38 using android::netdutils::enableSockopt;
39 
40 namespace test {
41 
// Normalizes the two strerror_r() flavors a libc may expose:
//  - XSI/POSIX: int strerror_r(int, char*, size_t) returns 0 on success and
//    writes the message into the caller's buffer;
//  - GNU: char* strerror_r(int, char*, size_t) returns a pointer to the
//    message, which may or may not point into the caller's buffer.
// Overload resolution on the actual return type picks the right
// interpretation at compile time.
static const char* strerrorResult(int rv, const char* buf) {
    return rv == 0 ? buf : nullptr;
}

static const char* strerrorResult(const char* msg, const char* /* buf */) {
    return msg;
}

// Returns a human-readable description of the current errno value, or an
// empty string if no description could be obtained.
std::string errno2str() {
    char error_msg[512] = { 0 };
    const char* msg =
            strerrorResult(strerror_r(errno, error_msg, sizeof(error_msg)), error_msg);
    return msg != nullptr ? std::string(msg) : std::string();
}

// ALOGI() variant that appends the current errno value and its description.
// Requires at least one argument after |fmt|.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
50 
// Renders |len| bytes starting at |buffer| as upper-case hexadecimal, two
// characters per byte (e.g. "\x01\xab" -> "01AB").
std::string str2hex(const char* buffer, size_t len) {
    static const char kDigits[] = "0123456789ABCDEF";
    std::string out;
    out.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        const uint8_t byte = static_cast<uint8_t>(*p);
        out.push_back(kDigits[byte >> 4]);
        out.push_back(kDigits[byte & 0x0F]);
    }
    return out;
}
61 
// Converts a socket address into its numeric host form (e.g. "127.0.0.1" or
// "::1"); the port is not included. Returns an empty string if getnameinfo()
// fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    const int err = getnameinfo(sa, sa_len, host, sizeof(host), nullptr, 0, NI_NUMERICHOST);
    return err == 0 ? std::string(host) : std::string();
}
69 
70 /* DNS struct helpers */
71 
// Returns the mnemonic for a DNS RR type value (e.g. 28 -> "AAAA"), or
// "UNKNOWN" for values not in the table. The result points to a string
// literal and never needs to be freed.
const char* dnstype2str(unsigned dnstype) {
    static const std::unordered_map<unsigned, const char*> kTypeStrs = {
        { ns_type::ns_t_a, "A" },
        { ns_type::ns_t_ns, "NS" },
        { ns_type::ns_t_md, "MD" },
        { ns_type::ns_t_mf, "MF" },
        { ns_type::ns_t_cname, "CNAME" },
        { ns_type::ns_t_soa, "SOA" },
        { ns_type::ns_t_mb, "MB" },
        // Previously keyed on ns_t_mb by mistake; the duplicate key was
        // silently dropped, so MG always mapped to "UNKNOWN".
        { ns_type::ns_t_mg, "MG" },
        { ns_type::ns_t_mr, "MR" },
        { ns_type::ns_t_null, "NULL" },
        { ns_type::ns_t_wks, "WKS" },
        { ns_type::ns_t_ptr, "PTR" },
        { ns_type::ns_t_hinfo, "HINFO" },
        { ns_type::ns_t_minfo, "MINFO" },
        { ns_type::ns_t_mx, "MX" },
        { ns_type::ns_t_txt, "TXT" },
        { ns_type::ns_t_rp, "RP" },
        { ns_type::ns_t_afsdb, "AFSDB" },
        { ns_type::ns_t_x25, "X25" },
        { ns_type::ns_t_isdn, "ISDN" },
        { ns_type::ns_t_rt, "RT" },
        { ns_type::ns_t_nsap, "NSAP" },
        { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
        { ns_type::ns_t_sig, "SIG" },
        { ns_type::ns_t_key, "KEY" },
        { ns_type::ns_t_px, "PX" },
        { ns_type::ns_t_gpos, "GPOS" },
        { ns_type::ns_t_aaaa, "AAAA" },
        { ns_type::ns_t_loc, "LOC" },
        { ns_type::ns_t_nxt, "NXT" },
        { ns_type::ns_t_eid, "EID" },
        { ns_type::ns_t_nimloc, "NIMLOC" },
        { ns_type::ns_t_srv, "SRV" },
        { ns_type::ns_t_naptr, "NAPTR" },
        { ns_type::ns_t_kx, "KX" },
        { ns_type::ns_t_cert, "CERT" },
        { ns_type::ns_t_a6, "A6" },
        { ns_type::ns_t_dname, "DNAME" },
        { ns_type::ns_t_sink, "SINK" },
        { ns_type::ns_t_opt, "OPT" },
        { ns_type::ns_t_apl, "APL" },
        { ns_type::ns_t_tkey, "TKEY" },
        { ns_type::ns_t_tsig, "TSIG" },
        { ns_type::ns_t_ixfr, "IXFR" },
        { ns_type::ns_t_axfr, "AXFR" },
        { ns_type::ns_t_mailb, "MAILB" },
        { ns_type::ns_t_maila, "MAILA" },
        { ns_type::ns_t_any, "ANY" },
        // ns_t_zxfr (BIND-specific, value 256). Spelled numerically so this
        // table does not depend on the resolver header still declaring that
        // enumerator. NOTE(review): confirm for all target libcs.
        { 256u, "ZXFR" },
    };
    const auto it = kTypeStrs.find(dnstype);
    return it != kTypeStrs.end() ? it->second : "UNKNOWN";
}
129 
// Returns a human-readable name for a DNS class value (e.g. 1 -> "Internet"),
// or "UNKNOWN" for values not in the table. The result points to a string
// literal.
const char* dnsclass2str(unsigned dnsclass) {
    static const std::unordered_map<unsigned, const char*> kClassStrs = {
        { ns_class::ns_c_in, "Internet" },
        { 2, "CSNet" },  // Class 2: the obsolete CSNET class.
        { ns_class::ns_c_chaos, "ChaosNet" },
        { ns_class::ns_c_hs, "Hesiod" },
        { ns_class::ns_c_none, "none" },
        { ns_class::ns_c_any, "any" }
    };
    const auto it = kClassStrs.find(dnsclass);
    return it != kClassStrs.end() ? it->second : "UNKNOWN";
}
145 
// A DNS domain name held in dotted text form.
struct DNSName {
    // Dotted name in which every label parsed by read() is followed by '.',
    // e.g. "www.example.com."; the root name is the empty string.
    std::string name;

    // Parses a wire-format name at |buffer|, appending its labels to |name|.
    // Returns a pointer just past the name, or nullptr on malformed or
    // truncated input (compressed names are not supported).
    const char* read(const char* buffer, const char* buffer_end);

    // Serializes |name| in wire format into [buffer, buffer_end). Returns a
    // pointer just past the written bytes, or nullptr on failure.
    char* write(char* buffer, const char* buffer_end) const;

    // Returns |name| as a C string; valid while |name| is not modified.
    const char* toString() const;

  private:
    // Parses one label at |buffer|, setting |*last| on the terminating
    // zero-length label. Returns a pointer past the label, or nullptr.
    const char* parseField(const char* buffer, const char* buffer_end, bool* last);
};
155 
toString() const156 const char* DNSName::toString() const {
157     return name.c_str();
158 }
159 
read(const char * buffer,const char * buffer_end)160 const char* DNSName::read(const char* buffer, const char* buffer_end) {
161     const char* cur = buffer;
162     bool last = false;
163     do {
164         cur = parseField(cur, buffer_end, &last);
165         if (cur == nullptr) {
166             ALOGI("parsing failed at line %d", __LINE__);
167             return nullptr;
168         }
169     } while (!last);
170     return cur;
171 }
172 
write(char * buffer,const char * buffer_end) const173 char* DNSName::write(char* buffer, const char* buffer_end) const {
174     char* buffer_cur = buffer;
175     for (size_t pos = 0 ; pos < name.size() ; ) {
176         size_t dot_pos = name.find('.', pos);
177         if (dot_pos == std::string::npos) {
178             // Sanity check, should never happen unless parseField is broken.
179             ALOGI("logic error: all names are expected to end with a '.'");
180             return nullptr;
181         }
182         size_t len = dot_pos - pos;
183         if (len >= 256) {
184             ALOGI("name component '%s' is %zu long, but max is 255",
185                     name.substr(pos, dot_pos - pos).c_str(), len);
186             return nullptr;
187         }
188         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
189             ALOGI("buffer overflow at line %d", __LINE__);
190             return nullptr;
191         }
192         *buffer_cur++ = len;
193         buffer_cur = std::copy(std::next(name.begin(), pos),
194                                std::next(name.begin(), dot_pos),
195                                buffer_cur);
196         pos = dot_pos + 1;
197     }
198     // Write final zero.
199     *buffer_cur++ = 0;
200     return buffer_cur;
201 }
202 
parseField(const char * buffer,const char * buffer_end,bool * last)203 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
204                                 bool* last) {
205     if (buffer + sizeof(uint8_t) > buffer_end) {
206         ALOGI("parsing failed at line %d", __LINE__);
207         return nullptr;
208     }
209     unsigned field_type = *buffer >> 6;
210     unsigned ofs = *buffer & 0x3F;
211     const char* cur = buffer + sizeof(uint8_t);
212     if (field_type == 0) {
213         // length + name component
214         if (ofs == 0) {
215             *last = true;
216             return cur;
217         }
218         if (cur + ofs > buffer_end) {
219             ALOGI("parsing failed at line %d", __LINE__);
220             return nullptr;
221         }
222         name.append(cur, ofs);
223         name.push_back('.');
224         return cur + ofs;
225     } else if (field_type == 3) {
226         ALOGI("name compression not implemented");
227         return nullptr;
228     }
229     ALOGI("invalid name field type");
230     return nullptr;
231 }
232 
233 struct DNSQuestion {
234     DNSName qname;
235     unsigned qtype;
236     unsigned qclass;
237     const char* read(const char* buffer, const char* buffer_end);
238     char* write(char* buffer, const char* buffer_end) const;
239     std::string toString() const;
240 };
241 
read(const char * buffer,const char * buffer_end)242 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
243     const char* cur = qname.read(buffer, buffer_end);
244     if (cur == nullptr) {
245         ALOGI("parsing failed at line %d", __LINE__);
246         return nullptr;
247     }
248     if (cur + 2*sizeof(uint16_t) > buffer_end) {
249         ALOGI("parsing failed at line %d", __LINE__);
250         return nullptr;
251     }
252     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
253     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
254     return cur + 2*sizeof(uint16_t);
255 }
256 
write(char * buffer,const char * buffer_end) const257 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
258     char* buffer_cur = qname.write(buffer, buffer_end);
259     if (buffer_cur == nullptr) return nullptr;
260     if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
261         ALOGI("buffer overflow on line %d", __LINE__);
262         return nullptr;
263     }
264     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
265     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
266             htons(qclass);
267     return buffer_cur + 2*sizeof(uint16_t);
268 }
269 
toString() const270 std::string DNSQuestion::toString() const {
271     char buffer[4096];
272     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
273                        dnstype2str(qtype), dnsclass2str(qclass));
274     return std::string(buffer, len);
275 }
276 
277 struct DNSRecord {
278     DNSName name;
279     unsigned rtype;
280     unsigned rclass;
281     unsigned ttl;
282     std::vector<char> rdata;
283     const char* read(const char* buffer, const char* buffer_end);
284     char* write(char* buffer, const char* buffer_end) const;
285     std::string toString() const;
286 private:
287     struct IntFields {
288         uint16_t rtype;
289         uint16_t rclass;
290         uint32_t ttl;
291         uint16_t rdlen;
292     } __attribute__((__packed__));
293 
294     const char* readIntFields(const char* buffer, const char* buffer_end,
295             unsigned* rdlen);
296     char* writeIntFields(unsigned rdlen, char* buffer,
297                          const char* buffer_end) const;
298 };
299 
read(const char * buffer,const char * buffer_end)300 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
301     const char* cur = name.read(buffer, buffer_end);
302     if (cur == nullptr) {
303         ALOGI("parsing failed at line %d", __LINE__);
304         return nullptr;
305     }
306     unsigned rdlen = 0;
307     cur = readIntFields(cur, buffer_end, &rdlen);
308     if (cur == nullptr) {
309         ALOGI("parsing failed at line %d", __LINE__);
310         return nullptr;
311     }
312     if (cur + rdlen > buffer_end) {
313         ALOGI("parsing failed at line %d", __LINE__);
314         return nullptr;
315     }
316     rdata.assign(cur, cur + rdlen);
317     return cur + rdlen;
318 }
319 
write(char * buffer,const char * buffer_end) const320 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
321     char* buffer_cur = name.write(buffer, buffer_end);
322     if (buffer_cur == nullptr) return nullptr;
323     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
324     if (buffer_cur == nullptr) return nullptr;
325     if (buffer_cur + rdata.size() > buffer_end) {
326         ALOGI("buffer overflow on line %d", __LINE__);
327         return nullptr;
328     }
329     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
330 }
331 
toString() const332 std::string DNSRecord::toString() const {
333     char buffer[4096];
334     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
335                        dnstype2str(rtype), dnsclass2str(rclass));
336     return std::string(buffer, len);
337 }
338 
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)339 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
340                                      unsigned* rdlen) {
341     if (buffer + sizeof(IntFields) > buffer_end ) {
342         ALOGI("parsing failed at line %d", __LINE__);
343         return nullptr;
344     }
345     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
346     rtype = ntohs(intfields.rtype);
347     rclass = ntohs(intfields.rclass);
348     ttl = ntohl(intfields.ttl);
349     *rdlen = ntohs(intfields.rdlen);
350     return buffer + sizeof(IntFields);
351 }
352 
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const353 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
354                                 const char* buffer_end) const {
355     if (buffer + sizeof(IntFields) > buffer_end ) {
356         ALOGI("buffer overflow on line %d", __LINE__);
357         return nullptr;
358     }
359     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
360     intfields.rtype = htons(rtype);
361     intfields.rclass = htons(rclass);
362     intfields.ttl = htonl(ttl);
363     intfields.rdlen = htons(rdlen);
364     return buffer + sizeof(IntFields);
365 }
366 
367 struct DNSHeader {
368     unsigned id;
369     bool ra;
370     uint8_t rcode;
371     bool qr;
372     uint8_t opcode;
373     bool aa;
374     bool tr;
375     bool rd;
376     bool ad;
377     std::vector<DNSQuestion> questions;
378     std::vector<DNSRecord> answers;
379     std::vector<DNSRecord> authorities;
380     std::vector<DNSRecord> additionals;
381     const char* read(const char* buffer, const char* buffer_end);
382     char* write(char* buffer, const char* buffer_end) const;
383     std::string toString() const;
384 
385 private:
386     struct Header {
387         uint16_t id;
388         uint8_t flags0;
389         uint8_t flags1;
390         uint16_t qdcount;
391         uint16_t ancount;
392         uint16_t nscount;
393         uint16_t arcount;
394     } __attribute__((__packed__));
395 
396     const char* readHeader(const char* buffer, const char* buffer_end,
397                            unsigned* qdcount, unsigned* ancount,
398                            unsigned* nscount, unsigned* arcount);
399 };
400 
read(const char * buffer,const char * buffer_end)401 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
402     unsigned qdcount;
403     unsigned ancount;
404     unsigned nscount;
405     unsigned arcount;
406     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
407                                  &nscount, &arcount);
408     if (cur == nullptr) {
409         ALOGI("parsing failed at line %d", __LINE__);
410         return nullptr;
411     }
412     if (qdcount) {
413         questions.resize(qdcount);
414         for (unsigned i = 0 ; i < qdcount ; ++i) {
415             cur = questions[i].read(cur, buffer_end);
416             if (cur == nullptr) {
417                 ALOGI("parsing failed at line %d", __LINE__);
418                 return nullptr;
419             }
420         }
421     }
422     if (ancount) {
423         answers.resize(ancount);
424         for (unsigned i = 0 ; i < ancount ; ++i) {
425             cur = answers[i].read(cur, buffer_end);
426             if (cur == nullptr) {
427                 ALOGI("parsing failed at line %d", __LINE__);
428                 return nullptr;
429             }
430         }
431     }
432     if (nscount) {
433         authorities.resize(nscount);
434         for (unsigned i = 0 ; i < nscount ; ++i) {
435             cur = authorities[i].read(cur, buffer_end);
436             if (cur == nullptr) {
437                 ALOGI("parsing failed at line %d", __LINE__);
438                 return nullptr;
439             }
440         }
441     }
442     if (arcount) {
443         additionals.resize(arcount);
444         for (unsigned i = 0 ; i < arcount ; ++i) {
445             cur = additionals[i].read(cur, buffer_end);
446             if (cur == nullptr) {
447                 ALOGI("parsing failed at line %d", __LINE__);
448                 return nullptr;
449             }
450         }
451     }
452     return cur;
453 }
454 
write(char * buffer,const char * buffer_end) const455 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
456     if (buffer + sizeof(Header) > buffer_end) {
457         ALOGI("buffer overflow on line %d", __LINE__);
458         return nullptr;
459     }
460     Header& header = *reinterpret_cast<Header*>(buffer);
461     // bytes 0-1
462     header.id = htons(id);
463     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
464     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
465     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
466     // Fake behavior: if the query set the "ad" bit, set it in the response too.
467     // In a real server, this should be set only if the data is authentic and the
468     // query contained an "ad" bit or DNSSEC extensions.
469     header.flags1 = (ad << 5) | rcode;
470     // rest of header
471     header.qdcount = htons(questions.size());
472     header.ancount = htons(answers.size());
473     header.nscount = htons(authorities.size());
474     header.arcount = htons(additionals.size());
475     char* buffer_cur = buffer + sizeof(Header);
476     for (const DNSQuestion& question : questions) {
477         buffer_cur = question.write(buffer_cur, buffer_end);
478         if (buffer_cur == nullptr) return nullptr;
479     }
480     for (const DNSRecord& answer : answers) {
481         buffer_cur = answer.write(buffer_cur, buffer_end);
482         if (buffer_cur == nullptr) return nullptr;
483     }
484     for (const DNSRecord& authority : authorities) {
485         buffer_cur = authority.write(buffer_cur, buffer_end);
486         if (buffer_cur == nullptr) return nullptr;
487     }
488     for (const DNSRecord& additional : additionals) {
489         buffer_cur = additional.write(buffer_cur, buffer_end);
490         if (buffer_cur == nullptr) return nullptr;
491     }
492     return buffer_cur;
493 }
494 
toString() const495 std::string DNSHeader::toString() const {
496     // TODO
497     return std::string();
498 }
499 
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)500 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
501                                   unsigned* qdcount, unsigned* ancount,
502                                   unsigned* nscount, unsigned* arcount) {
503     if (buffer + sizeof(Header) > buffer_end)
504         return 0;
505     const auto& header = *reinterpret_cast<const Header*>(buffer);
506     // bytes 0-1
507     id = ntohs(header.id);
508     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
509     qr = header.flags0 >> 7;
510     opcode = (header.flags0 >> 3) & 0x0F;
511     aa = (header.flags0 >> 2) & 1;
512     tr = (header.flags0 >> 1) & 1;
513     rd = header.flags0 & 1;
514     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
515     ra = header.flags1 >> 7;
516     ad = (header.flags1 >> 5) & 1;
517     rcode = header.flags1 & 0xF;
518     // rest of header
519     *qdcount = ntohs(header.qdcount);
520     *ancount = ntohs(header.ancount);
521     *nscount = ntohs(header.nscount);
522     *arcount = ntohs(header.arcount);
523     return buffer + sizeof(Header);
524 }
525 
526 /* DNS responder */
527 
DNSResponder(std::string listen_address,std::string listen_service,int poll_timeout_ms,uint16_t error_rcode,double response_probability)528 DNSResponder::DNSResponder(std::string listen_address,
529                            std::string listen_service, int poll_timeout_ms,
530                            uint16_t error_rcode, double response_probability) :
531     listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
532     poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
533     response_probability_(response_probability),
534     fail_on_edns_(false),
535     socket_(-1), epoll_fd_(-1), terminate_(false) { }
536 
~DNSResponder()537 DNSResponder::~DNSResponder() {
538     stopServer();
539 }
540 
addMapping(const char * name,ns_type type,const char * addr)541 void DNSResponder::addMapping(const char* name, ns_type type,
542         const char* addr) {
543     std::lock_guard<std::mutex> lock(mappings_mutex_);
544     auto it = mappings_.find(QueryKey(name, type));
545     if (it != mappings_.end()) {
546         ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
547             "address %s", name, dnstype2str(type), it->second.c_str(),
548             addr);
549         it->second = addr;
550         return;
551     }
552     mappings_.emplace(std::piecewise_construct,
553                       std::forward_as_tuple(name, type),
554                       std::forward_as_tuple(addr));
555 }
556 
removeMapping(const char * name,ns_type type)557 void DNSResponder::removeMapping(const char* name, ns_type type) {
558     std::lock_guard<std::mutex> lock(mappings_mutex_);
559     auto it = mappings_.find(QueryKey(name, type));
560     if (it != mappings_.end()) {
561         ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
562             dnstype2str(type));
563         return;
564     }
565     mappings_.erase(it);
566 }
567 
setResponseProbability(double response_probability)568 void DNSResponder::setResponseProbability(double response_probability) {
569     response_probability_ = response_probability;
570 }
571 
running() const572 bool DNSResponder::running() const {
573     return socket_ != -1;
574 }
575 
// Binds a non-blocking UDP socket to listen_address_:listen_service_ (the
// first getaddrinfo() result that binds wins), registers it with a new epoll
// instance and starts the request handler thread. Returns false if the server
// is already running or any step fails; on failure no descriptors are leaked.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }
    // Value-initialize, then set fields by name. Designated initializers are
    // only standard from C++20, and even there must follow addrinfo's
    // declaration order (ai_flags precedes ai_family).
    addrinfo ai_hints{};
    ai_hints.ai_flags = AI_PASSIVE;
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    int s = -1;
    for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        enableSockopt(s, SOL_SOCKET, SO_REUSEPORT);
        enableSockopt(s, SOL_SOCKET, SO_REUSEADDR);
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        break;
    }
    freeaddrinfo(ai_res);
    if (s < 0) {
        ALOGI("bind() failed");
        return false;
    }

    int flags = fcntl(s, F_GETFL, 0);
    if (flags < 0) flags = 0;
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", s);
        close(s);
        return false;
    }

    int ep_fd = epoll_create(1);
    if (ep_fd < 0) {
        // APLOGI already appends errno and its description; the former
        // hand-rolled strerror_r() call assumed the XSI variant and reported
        // "UNKNOWN" under the GNU one.
        APLOGI("epoll_create() failed for socket %d", s);
        close(s);
        return false;
    }
    epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.fd = s;
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
        APLOGI("epoll_ctl() failed for socket %d", s);
        close(ep_fd);
        close(s);
        return false;
    }

    epoll_fd_ = ep_fd;
    socket_ = s;
    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
652 
stopServer()653 bool DNSResponder::stopServer() {
654     std::lock_guard<std::mutex> lock(update_mutex_);
655     if (!running()) {
656         ALOGI("server not running");
657         return false;
658     }
659     if (terminate_) {
660         ALOGI("LOGIC ERROR");
661         return false;
662     }
663     ALOGI("stopping server");
664     terminate_ = true;
665     handler_thread_.join();
666     close(epoll_fd_);
667     close(socket_);
668     terminate_ = false;
669     socket_ = -1;
670     ALOGI("server stopped successfully");
671     return true;
672 }
673 
queries() const674 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
675     std::lock_guard<std::mutex> lock(queries_mutex_);
676     return queries_;
677 }
678 
clearQueries()679 void DNSResponder::clearQueries() {
680     std::lock_guard<std::mutex> lock(queries_mutex_);
681     queries_.clear();
682 }
683 
requestHandler()684 void DNSResponder::requestHandler() {
685     epoll_event evs[1];
686     while (!terminate_) {
687         int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
688         if (n == 0) continue;
689         if (n < 0) {
690             ALOGI("epoll_wait() failed");
691             // TODO(imaipi): terminate on error.
692             return;
693         }
694         char buffer[4096];
695         sockaddr_storage sa;
696         socklen_t sa_len = sizeof(sa);
697         ssize_t len;
698         do {
699             len = recvfrom(socket_, buffer, sizeof(buffer), 0,
700                            (sockaddr*) &sa, &sa_len);
701         } while (len < 0 && (errno == EAGAIN || errno == EINTR));
702         if (len <= 0) {
703             ALOGI("recvfrom() failed");
704             continue;
705         }
706         ALOGI("read %zd bytes", len);
707         char response[4096];
708         size_t response_len = sizeof(response);
709         if (handleDNSRequest(buffer, len, response, &response_len) &&
710             response_len > 0) {
711             len = sendto(socket_, response, response_len, 0,
712                          reinterpret_cast<const sockaddr*>(&sa), sa_len);
713             std::string host_str =
714                 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
715             if (len > 0) {
716                 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
717             } else {
718                 APLOGI("sendto() failed for %s", host_str.c_str());
719             }
720             // Test that the response is actually a correct DNS message.
721             const char* response_end = response + len;
722             DNSHeader header;
723             const char* cur = header.read(response, response_end);
724             if (cur == nullptr) ALOGI("response is flawed");
725 
726         } else {
727             ALOGI("not responding");
728         }
729     }
730 }
731 
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const732 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
733                                     char* response, size_t* response_len)
734                                     const {
735     ALOGI("request: '%s'", str2hex(buffer, len).c_str());
736     const char* buffer_end = buffer + len;
737     DNSHeader header;
738     const char* cur = header.read(buffer, buffer_end);
739     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
740     if (cur == nullptr) {
741         ALOGI("failed to parse query");
742         return false;
743     }
744     if (header.qr) {
745         ALOGI("response received instead of a query");
746         return false;
747     }
748     if (header.opcode != ns_opcode::ns_o_query) {
749         ALOGI("unsupported request opcode received");
750         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
751                                  response_len);
752     }
753     if (header.questions.empty()) {
754         ALOGI("no questions present");
755         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
756                                  response_len);
757     }
758     if (!header.answers.empty()) {
759         ALOGI("already %zu answers present in query", header.answers.size());
760         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
761                                  response_len);
762     }
763     if (!header.additionals.empty() && fail_on_edns_) {
764         ALOGI("DNS request has an additional section (assumed EDNS). "
765               "Simulating an ancient (pre-EDNS) server.");
766         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
767                                  response_len);
768     }
769     {
770         std::lock_guard<std::mutex> lock(queries_mutex_);
771         for (const DNSQuestion& question : header.questions) {
772             queries_.push_back(make_pair(question.qname.name,
773                                          ns_type(question.qtype)));
774         }
775     }
776 
777     // Ignore requests with the preset probability.
778     auto constexpr bound = std::numeric_limits<unsigned>::max();
779     if (arc4random_uniform(bound) > bound*response_probability_) {
780         ALOGI("returning SRVFAIL in accordance with probability distribution");
781         return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
782                                  response_len);
783     }
784 
785     for (const DNSQuestion& question : header.questions) {
786         if (question.qclass != ns_class::ns_c_in &&
787             question.qclass != ns_class::ns_c_any) {
788             ALOGI("unsupported question class %u", question.qclass);
789             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
790                                      response_len);
791         }
792         if (!addAnswerRecords(question, &header.answers)) {
793             return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
794                                      response_len);
795         }
796     }
797     header.qr = true;
798     char* response_cur = header.write(response, response + *response_len);
799     if (response_cur == nullptr) {
800         return false;
801     }
802     *response_len = response_cur - response;
803     return true;
804 }
805 
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const806 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
807                                     std::vector<DNSRecord>* answers) const {
808     std::lock_guard<std::mutex> guard(mappings_mutex_);
809     auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
810     if (it == mappings_.end()) {
811         // TODO(imaipi): handle correctly
812         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
813             question.qname.name.c_str(), dnstype2str(question.qtype));
814         return true;
815     }
816     ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
817         dnstype2str(question.qtype), it->second.c_str());
818     DNSRecord record;
819     record.name = question.qname;
820     record.rtype = question.qtype;
821     record.rclass = ns_class::ns_c_in;
822     record.ttl = 5;  // seconds
823     if (question.qtype == ns_type::ns_t_a) {
824         record.rdata.resize(4);
825         if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
826             ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
827             return false;
828         }
829     } else if (question.qtype == ns_type::ns_t_aaaa) {
830         record.rdata.resize(16);
831         if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
832             ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
833             return false;
834         }
835     } else {
836         ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
837         return false;
838     }
839     answers->push_back(std::move(record));
840     return true;
841 }
842 
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const843 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
844                                      char* response, size_t* response_len)
845                                      const {
846     header->answers.clear();
847     header->authorities.clear();
848     header->additionals.clear();
849     header->rcode = rcode;
850     header->qr = true;
851     char* response_cur = header->write(response, response + *response_len);
852     if (response_cur == nullptr) return false;
853     *response_len = response_cur - response;
854     return true;
855 }
856 
857 }  // namespace test
858 
859