• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_responder.h"
18 
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <iostream>
#include <limits>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
35 
36 #define LOG_TAG "DNSResponder"
37 #include <android-base/strings.h>
38 #include <log/log.h>
39 #include <netdutils/SocketOption.h>
40 
41 #include "NetdConstants.h"
42 
43 using android::netdutils::enableSockopt;
44 
45 namespace test {
46 
// Returns the human-readable description of the current |errno| value.
//
// On this platform strerror_r() resolves to the GNU flavour (__gnu_strerror_r()),
// which hands back a |char*| to the message rather than an |int| status code.
// PLOG would be an alternative, but it needs every ALOGx() call turned into LOG(x).
std::string errno2str() {
    char scratch[512] = {};
    const char* const description = strerror_r(errno, scratch, sizeof(scratch));
    return std::string(description);
}
53 
// APLOGI(fmt, args...): ALOGI() that appends ": [<errno>] <description>" for the
// errno value current at the call site. At least one argument must follow |fmt|,
// because __VA_ARGS__ is spliced in directly before |errno|.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())

// DBGLOG(fmt, args...): verbose per-packet logging. It expands to nothing unless
// DNS_RESPONDER_VERBOSE_LOGGING is set to 1, in which case it forwards to ALOGI().
#define DNS_RESPONDER_VERBOSE_LOGGING 0
#if DNS_RESPONDER_VERBOSE_LOGGING
#define DBGLOG(fmt, ...) ALOGI(fmt, __VA_ARGS__)
#else
#define DBGLOG(fmt, ...)
#endif
61 
// Renders |len| bytes starting at |buffer| as uppercase hexadecimal, two digits
// per byte (e.g. {0x01, 0xAB} -> "01AB"). Intended for debug logging.
std::string str2hex(const char* buffer, size_t len) {
    static constexpr char kHexDigits[] = "0123456789ABCDEF";
    std::string hex;
    hex.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        const auto byte = static_cast<unsigned char>(*p);
        hex.push_back(kHexDigits[byte >> 4]);
        hex.push_back(kHexDigits[byte & 0x0F]);
    }
    return hex;
}
72 
// Formats the socket address |sa| (|sa_len| bytes) as a numeric host string,
// e.g. "127.0.0.1" or "::1". Returns an empty string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    if (getnameinfo(sa, sa_len, host, sizeof(host), nullptr, 0, NI_NUMERICHOST) != 0) {
        return {};
    }
    return host;
}
80 
81 /* DNS struct helpers */
82 
// Returns the mnemonic for DNS RR type |dnstype| (e.g. "A", "AAAA", "MX"), or
// "UNKNOWN" when the type is not in the table. The result points at a string
// literal and must not be freed.
const char* dnstype2str(unsigned dnstype) {
    // Built once on first use and only read afterwards.
    static const std::unordered_map<unsigned, const char*> kTypeStrs = {
        { ns_type::ns_t_a, "A" },
        { ns_type::ns_t_ns, "NS" },
        { ns_type::ns_t_md, "MD" },
        { ns_type::ns_t_mf, "MF" },
        { ns_type::ns_t_cname, "CNAME" },
        { ns_type::ns_t_soa, "SOA" },
        { ns_type::ns_t_mb, "MB" },
        // Must be keyed by ns_t_mg: a second ns_t_mb entry is silently dropped
        // by the map, which left MG reported as "UNKNOWN".
        { ns_type::ns_t_mg, "MG" },
        { ns_type::ns_t_mr, "MR" },
        { ns_type::ns_t_null, "NULL" },
        { ns_type::ns_t_wks, "WKS" },
        { ns_type::ns_t_ptr, "PTR" },
        { ns_type::ns_t_hinfo, "HINFO" },
        { ns_type::ns_t_minfo, "MINFO" },
        { ns_type::ns_t_mx, "MX" },
        { ns_type::ns_t_txt, "TXT" },
        { ns_type::ns_t_rp, "RP" },
        { ns_type::ns_t_afsdb, "AFSDB" },
        { ns_type::ns_t_x25, "X25" },
        { ns_type::ns_t_isdn, "ISDN" },
        { ns_type::ns_t_rt, "RT" },
        { ns_type::ns_t_nsap, "NSAP" },
        { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
        { ns_type::ns_t_sig, "SIG" },
        { ns_type::ns_t_key, "KEY" },
        { ns_type::ns_t_px, "PX" },
        { ns_type::ns_t_gpos, "GPOS" },
        { ns_type::ns_t_aaaa, "AAAA" },
        { ns_type::ns_t_loc, "LOC" },
        { ns_type::ns_t_nxt, "NXT" },
        { ns_type::ns_t_eid, "EID" },
        { ns_type::ns_t_nimloc, "NIMLOC" },
        { ns_type::ns_t_srv, "SRV" },
        { ns_type::ns_t_naptr, "NAPTR" },
        { ns_type::ns_t_kx, "KX" },
        { ns_type::ns_t_cert, "CERT" },
        { ns_type::ns_t_a6, "A6" },
        { ns_type::ns_t_dname, "DNAME" },
        { ns_type::ns_t_sink, "SINK" },
        { ns_type::ns_t_opt, "OPT" },
        { ns_type::ns_t_apl, "APL" },
        { ns_type::ns_t_tkey, "TKEY" },
        { ns_type::ns_t_tsig, "TSIG" },
        { ns_type::ns_t_ixfr, "IXFR" },
        { ns_type::ns_t_axfr, "AXFR" },
        { ns_type::ns_t_mailb, "MAILB" },
        { ns_type::ns_t_maila, "MAILA" },
        { ns_type::ns_t_any, "ANY" },
        // ZXFR (256) is a nonstandard, BIND-specific type. It is spelled numerically
        // because some <arpa/nameser.h> versions may not declare ns_t_zxfr.
        { 256, "ZXFR" },
    };
    const auto it = kTypeStrs.find(dnstype);
    return it == kTypeStrs.end() ? "UNKNOWN" : it->second;
}
140 
// Returns a descriptive name for DNS class |dnsclass| (e.g. "Internet" for IN),
// or "UNKNOWN" for classes not listed. The result is a string literal.
const char* dnsclass2str(unsigned dnsclass) {
    switch (dnsclass) {
        case ns_class::ns_c_in:
            return "Internet";
        case 2:  // CSNET
            return "CSNet";
        case ns_class::ns_c_chaos:
            return "ChaosNet";
        case ns_class::ns_c_hs:
            return "Hesiod";
        case ns_class::ns_c_none:
            return "none";
        case ns_class::ns_c_any:
            return "any";
        default:
            return "UNKNOWN";
    }
}
155 
// A domain name held in dotted form. Each label is stored followed by '.', so
// the wire name for www.example.com is kept as "www.example.com."; the root
// name is the empty string.
struct DNSName {
    // Dotted representation; every label is terminated by '.'.
    std::string name;

    // Decodes a wire-format name at |buffer|, appending its labels to |name|.
    // Returns the position just past the name, or nullptr on a parse error.
    const char* read(const char* buffer, const char* buffer_end);
    // Encodes |name| in wire format at |buffer|. Returns the position just past
    // the encoded name, or nullptr if it is malformed or does not fit.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns |name| as a C string, valid while this object is unmodified.
    const char* toString() const;

  private:
    // Consumes a single label; sets |*last| once the zero-length label is read.
    const char* parseField(const char* buffer, const char* buffer_end, bool* last);
};
165 
toString() const166 const char* DNSName::toString() const {
167     return name.c_str();
168 }
169 
read(const char * buffer,const char * buffer_end)170 const char* DNSName::read(const char* buffer, const char* buffer_end) {
171     const char* cur = buffer;
172     bool last = false;
173     do {
174         cur = parseField(cur, buffer_end, &last);
175         if (cur == nullptr) {
176             ALOGI("parsing failed at line %d", __LINE__);
177             return nullptr;
178         }
179     } while (!last);
180     return cur;
181 }
182 
write(char * buffer,const char * buffer_end) const183 char* DNSName::write(char* buffer, const char* buffer_end) const {
184     char* buffer_cur = buffer;
185     for (size_t pos = 0 ; pos < name.size() ; ) {
186         size_t dot_pos = name.find('.', pos);
187         if (dot_pos == std::string::npos) {
188             // Sanity check, should never happen unless parseField is broken.
189             ALOGI("logic error: all names are expected to end with a '.'");
190             return nullptr;
191         }
192         size_t len = dot_pos - pos;
193         if (len >= 256) {
194             ALOGI("name component '%s' is %zu long, but max is 255",
195                     name.substr(pos, dot_pos - pos).c_str(), len);
196             return nullptr;
197         }
198         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
199             ALOGI("buffer overflow at line %d", __LINE__);
200             return nullptr;
201         }
202         *buffer_cur++ = len;
203         buffer_cur = std::copy(std::next(name.begin(), pos),
204                                std::next(name.begin(), dot_pos),
205                                buffer_cur);
206         pos = dot_pos + 1;
207     }
208     // Write final zero.
209     *buffer_cur++ = 0;
210     return buffer_cur;
211 }
212 
parseField(const char * buffer,const char * buffer_end,bool * last)213 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
214                                 bool* last) {
215     if (buffer + sizeof(uint8_t) > buffer_end) {
216         ALOGI("parsing failed at line %d", __LINE__);
217         return nullptr;
218     }
219     unsigned field_type = *buffer >> 6;
220     unsigned ofs = *buffer & 0x3F;
221     const char* cur = buffer + sizeof(uint8_t);
222     if (field_type == 0) {
223         // length + name component
224         if (ofs == 0) {
225             *last = true;
226             return cur;
227         }
228         if (cur + ofs > buffer_end) {
229             ALOGI("parsing failed at line %d", __LINE__);
230             return nullptr;
231         }
232         name.append(cur, ofs);
233         name.push_back('.');
234         return cur + ofs;
235     } else if (field_type == 3) {
236         ALOGI("name compression not implemented");
237         return nullptr;
238     }
239     ALOGI("invalid name field type");
240     return nullptr;
241 }
242 
243 struct DNSQuestion {
244     DNSName qname;
245     unsigned qtype;
246     unsigned qclass;
247     const char* read(const char* buffer, const char* buffer_end);
248     char* write(char* buffer, const char* buffer_end) const;
249     std::string toString() const;
250 };
251 
read(const char * buffer,const char * buffer_end)252 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
253     const char* cur = qname.read(buffer, buffer_end);
254     if (cur == nullptr) {
255         ALOGI("parsing failed at line %d", __LINE__);
256         return nullptr;
257     }
258     if (cur + 2*sizeof(uint16_t) > buffer_end) {
259         ALOGI("parsing failed at line %d", __LINE__);
260         return nullptr;
261     }
262     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
263     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
264     return cur + 2*sizeof(uint16_t);
265 }
266 
write(char * buffer,const char * buffer_end) const267 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
268     char* buffer_cur = qname.write(buffer, buffer_end);
269     if (buffer_cur == nullptr) return nullptr;
270     if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
271         ALOGI("buffer overflow on line %d", __LINE__);
272         return nullptr;
273     }
274     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
275     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
276             htons(qclass);
277     return buffer_cur + 2*sizeof(uint16_t);
278 }
279 
toString() const280 std::string DNSQuestion::toString() const {
281     char buffer[4096];
282     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
283                        dnstype2str(qtype), dnsclass2str(qclass));
284     return std::string(buffer, len);
285 }
286 
287 struct DNSRecord {
288     DNSName name;
289     unsigned rtype;
290     unsigned rclass;
291     unsigned ttl;
292     std::vector<char> rdata;
293     const char* read(const char* buffer, const char* buffer_end);
294     char* write(char* buffer, const char* buffer_end) const;
295     std::string toString() const;
296 private:
297     struct IntFields {
298         uint16_t rtype;
299         uint16_t rclass;
300         uint32_t ttl;
301         uint16_t rdlen;
302     } __attribute__((__packed__));
303 
304     const char* readIntFields(const char* buffer, const char* buffer_end,
305             unsigned* rdlen);
306     char* writeIntFields(unsigned rdlen, char* buffer,
307                          const char* buffer_end) const;
308 };
309 
read(const char * buffer,const char * buffer_end)310 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
311     const char* cur = name.read(buffer, buffer_end);
312     if (cur == nullptr) {
313         ALOGI("parsing failed at line %d", __LINE__);
314         return nullptr;
315     }
316     unsigned rdlen = 0;
317     cur = readIntFields(cur, buffer_end, &rdlen);
318     if (cur == nullptr) {
319         ALOGI("parsing failed at line %d", __LINE__);
320         return nullptr;
321     }
322     if (cur + rdlen > buffer_end) {
323         ALOGI("parsing failed at line %d", __LINE__);
324         return nullptr;
325     }
326     rdata.assign(cur, cur + rdlen);
327     return cur + rdlen;
328 }
329 
write(char * buffer,const char * buffer_end) const330 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
331     char* buffer_cur = name.write(buffer, buffer_end);
332     if (buffer_cur == nullptr) return nullptr;
333     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
334     if (buffer_cur == nullptr) return nullptr;
335     if (buffer_cur + rdata.size() > buffer_end) {
336         ALOGI("buffer overflow on line %d", __LINE__);
337         return nullptr;
338     }
339     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
340 }
341 
toString() const342 std::string DNSRecord::toString() const {
343     char buffer[4096];
344     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
345                        dnstype2str(rtype), dnsclass2str(rclass));
346     return std::string(buffer, len);
347 }
348 
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)349 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
350                                      unsigned* rdlen) {
351     if (buffer + sizeof(IntFields) > buffer_end ) {
352         ALOGI("parsing failed at line %d", __LINE__);
353         return nullptr;
354     }
355     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
356     rtype = ntohs(intfields.rtype);
357     rclass = ntohs(intfields.rclass);
358     ttl = ntohl(intfields.ttl);
359     *rdlen = ntohs(intfields.rdlen);
360     return buffer + sizeof(IntFields);
361 }
362 
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const363 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
364                                 const char* buffer_end) const {
365     if (buffer + sizeof(IntFields) > buffer_end ) {
366         ALOGI("buffer overflow on line %d", __LINE__);
367         return nullptr;
368     }
369     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
370     intfields.rtype = htons(rtype);
371     intfields.rclass = htons(rclass);
372     intfields.ttl = htonl(ttl);
373     intfields.rdlen = htons(rdlen);
374     return buffer + sizeof(IntFields);
375 }
376 
377 struct DNSHeader {
378     unsigned id;
379     bool ra;
380     uint8_t rcode;
381     bool qr;
382     uint8_t opcode;
383     bool aa;
384     bool tr;
385     bool rd;
386     bool ad;
387     std::vector<DNSQuestion> questions;
388     std::vector<DNSRecord> answers;
389     std::vector<DNSRecord> authorities;
390     std::vector<DNSRecord> additionals;
391     const char* read(const char* buffer, const char* buffer_end);
392     char* write(char* buffer, const char* buffer_end) const;
393     std::string toString() const;
394 
395 private:
396     struct Header {
397         uint16_t id;
398         uint8_t flags0;
399         uint8_t flags1;
400         uint16_t qdcount;
401         uint16_t ancount;
402         uint16_t nscount;
403         uint16_t arcount;
404     } __attribute__((__packed__));
405 
406     const char* readHeader(const char* buffer, const char* buffer_end,
407                            unsigned* qdcount, unsigned* ancount,
408                            unsigned* nscount, unsigned* arcount);
409 };
410 
read(const char * buffer,const char * buffer_end)411 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
412     unsigned qdcount;
413     unsigned ancount;
414     unsigned nscount;
415     unsigned arcount;
416     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
417                                  &nscount, &arcount);
418     if (cur == nullptr) {
419         ALOGI("parsing failed at line %d", __LINE__);
420         return nullptr;
421     }
422     if (qdcount) {
423         questions.resize(qdcount);
424         for (unsigned i = 0 ; i < qdcount ; ++i) {
425             cur = questions[i].read(cur, buffer_end);
426             if (cur == nullptr) {
427                 ALOGI("parsing failed at line %d", __LINE__);
428                 return nullptr;
429             }
430         }
431     }
432     if (ancount) {
433         answers.resize(ancount);
434         for (unsigned i = 0 ; i < ancount ; ++i) {
435             cur = answers[i].read(cur, buffer_end);
436             if (cur == nullptr) {
437                 ALOGI("parsing failed at line %d", __LINE__);
438                 return nullptr;
439             }
440         }
441     }
442     if (nscount) {
443         authorities.resize(nscount);
444         for (unsigned i = 0 ; i < nscount ; ++i) {
445             cur = authorities[i].read(cur, buffer_end);
446             if (cur == nullptr) {
447                 ALOGI("parsing failed at line %d", __LINE__);
448                 return nullptr;
449             }
450         }
451     }
452     if (arcount) {
453         additionals.resize(arcount);
454         for (unsigned i = 0 ; i < arcount ; ++i) {
455             cur = additionals[i].read(cur, buffer_end);
456             if (cur == nullptr) {
457                 ALOGI("parsing failed at line %d", __LINE__);
458                 return nullptr;
459             }
460         }
461     }
462     return cur;
463 }
464 
write(char * buffer,const char * buffer_end) const465 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
466     if (buffer + sizeof(Header) > buffer_end) {
467         ALOGI("buffer overflow on line %d", __LINE__);
468         return nullptr;
469     }
470     Header& header = *reinterpret_cast<Header*>(buffer);
471     // bytes 0-1
472     header.id = htons(id);
473     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
474     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
475     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
476     // Fake behavior: if the query set the "ad" bit, set it in the response too.
477     // In a real server, this should be set only if the data is authentic and the
478     // query contained an "ad" bit or DNSSEC extensions.
479     header.flags1 = (ad << 5) | rcode;
480     // rest of header
481     header.qdcount = htons(questions.size());
482     header.ancount = htons(answers.size());
483     header.nscount = htons(authorities.size());
484     header.arcount = htons(additionals.size());
485     char* buffer_cur = buffer + sizeof(Header);
486     for (const DNSQuestion& question : questions) {
487         buffer_cur = question.write(buffer_cur, buffer_end);
488         if (buffer_cur == nullptr) return nullptr;
489     }
490     for (const DNSRecord& answer : answers) {
491         buffer_cur = answer.write(buffer_cur, buffer_end);
492         if (buffer_cur == nullptr) return nullptr;
493     }
494     for (const DNSRecord& authority : authorities) {
495         buffer_cur = authority.write(buffer_cur, buffer_end);
496         if (buffer_cur == nullptr) return nullptr;
497     }
498     for (const DNSRecord& additional : additionals) {
499         buffer_cur = additional.write(buffer_cur, buffer_end);
500         if (buffer_cur == nullptr) return nullptr;
501     }
502     return buffer_cur;
503 }
504 
toString() const505 std::string DNSHeader::toString() const {
506     // TODO
507     return std::string();
508 }
509 
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)510 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
511                                   unsigned* qdcount, unsigned* ancount,
512                                   unsigned* nscount, unsigned* arcount) {
513     if (buffer + sizeof(Header) > buffer_end)
514         return nullptr;
515     const auto& header = *reinterpret_cast<const Header*>(buffer);
516     // bytes 0-1
517     id = ntohs(header.id);
518     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
519     qr = header.flags0 >> 7;
520     opcode = (header.flags0 >> 3) & 0x0F;
521     aa = (header.flags0 >> 2) & 1;
522     tr = (header.flags0 >> 1) & 1;
523     rd = header.flags0 & 1;
524     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
525     ra = header.flags1 >> 7;
526     ad = (header.flags1 >> 5) & 1;
527     rcode = header.flags1 & 0xF;
528     // rest of header
529     *qdcount = ntohs(header.qdcount);
530     *ancount = ntohs(header.ancount);
531     *nscount = ntohs(header.nscount);
532     *arcount = ntohs(header.arcount);
533     return buffer + sizeof(Header);
534 }
535 
536 /* DNS responder */
537 
DNSResponder(std::string listen_address,std::string listen_service,int poll_timeout_ms,ns_rcode error_rcode)538 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
539                            int poll_timeout_ms, ns_rcode error_rcode)
540     : listen_address_(std::move(listen_address)),
541       listen_service_(std::move(listen_service)),
542       poll_timeout_ms_(poll_timeout_ms),
543       error_rcode_(error_rcode) {}
544 
~DNSResponder()545 DNSResponder::~DNSResponder() {
546     stopServer();
547 }
548 
addMapping(const std::string & name,ns_type type,const std::string & addr)549 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
550     std::lock_guard lock(mappings_mutex_);
551     auto it = mappings_.find(QueryKey(name, type));
552     if (it != mappings_.end()) {
553         ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
554               "address %s",
555               name.c_str(), dnstype2str(type), it->second.c_str(), addr.c_str());
556         it->second = addr;
557         return;
558     }
559     mappings_.try_emplace({name, type}, addr);
560 }
561 
removeMapping(const std::string & name,ns_type type)562 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
563     std::lock_guard lock(mappings_mutex_);
564     auto it = mappings_.find(QueryKey(name, type));
565     if (it != mappings_.end()) {
566         ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name.c_str(),
567               dnstype2str(type));
568         return;
569     }
570     mappings_.erase(it);
571 }
572 
setResponseProbability(double response_probability)573 void DNSResponder::setResponseProbability(double response_probability) {
574     response_probability_ = response_probability;
575 }
576 
setEdns(Edns edns)577 void DNSResponder::setEdns(Edns edns) {
578     edns_ = edns;
579 }
580 
running() const581 bool DNSResponder::running() const {
582     return socket_.get() != -1;
583 }
584 
// Starts the server:
//  1. binds a non-blocking UDP socket to |listen_address_|:|listen_service_|;
//  2. creates the eventfd used by stopServer() and an epoll set watching both;
//  3. launches the request handler thread.
// Returns false if the server is already running or any step fails. A failed
// start releases whatever it acquired, so running() stays false afterwards.
bool DNSResponder::startServer() {
    if (running()) {
        ALOGI("server already running");
        return false;
    }

    // running() keys off |socket_|. Leaving it open after a failure would make
    // the server look started with no handler thread, and a later stopServer()
    // would try to join a thread that never ran.
    const auto abortStart = [this]() {
        epoll_fd_.reset();
        event_fd_.reset();
        socket_.reset();
        return false;
    };

    // Set up UDP socket. Designated initializers must follow the member order of
    // struct addrinfo (ai_flags, ai_family, ai_socktype, ...).
    addrinfo ai_hints{
        .ai_flags = AI_PASSIVE,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_DGRAM,
    };
    addrinfo* ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
    ScopedAddrinfo ai_res_cleanup(ai_res);
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
            listen_service_.c_str(), gai_strerror(rv));
        return false;
    }
    bool bound = false;
    for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
        // CLOEXEC, matching the eventfd and epoll descriptors created below.
        socket_.reset(socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK | SOCK_CLOEXEC,
                             ai->ai_protocol));
        if (socket_.get() < 0) {
            APLOGI("ignore creating socket %d failed", socket_.get());
            continue;
        }
        enableSockopt(socket_.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(socket_.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(socket_.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind UDP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        bound = true;
        break;
    }
    // Without this check an unbound socket left over from the last candidate
    // would be treated as a working server.
    if (!bound) {
        ALOGE("failed to bind a UDP socket for %s:%s", listen_address_.c_str(),
              listen_service_.c_str());
        return abortStart();
    }

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        APLOGI("failed to create eventfd %d", event_fd_.get());
        return abortStart();
    }

    // Set up epoll socket.
    epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
    if (epoll_fd_.get() < 0) {
        APLOGI("epoll_create1() failed on fd %d", epoll_fd_.get());
        return abortStart();
    }

    ALOGI("adding socket %d to epoll", socket_.get());
    if (!addFd(socket_.get(), EPOLLIN)) {
        ALOGE("failed to add the socket %d to epoll", socket_.get());
        return abortStart();
    }
    ALOGI("adding eventfd %d to epoll", event_fd_.get());
    if (!addFd(event_fd_.get(), EPOLLIN)) {
        ALOGE("failed to add the eventfd %d to epoll", event_fd_.get());
        return abortStart();
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
655 
stopServer()656 bool DNSResponder::stopServer() {
657     std::lock_guard lock(update_mutex_);
658     if (!running()) {
659         ALOGI("server not running");
660         return false;
661     }
662     ALOGI("stopping server");
663     if (!sendToEventFd()) {
664         return false;
665     }
666     handler_thread_.join();
667     epoll_fd_.reset();
668     socket_.reset();
669     ALOGI("server stopped successfully");
670     return true;
671 }
672 
queries() const673 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
674     std::lock_guard lock(queries_mutex_);
675     return queries_;
676 }
677 
dumpQueries() const678 std::string DNSResponder::dumpQueries() const {
679     std::lock_guard lock(queries_mutex_);
680     std::string out;
681     for (const auto& q : queries_) {
682         out += "{\"" + q.first + "\", " + std::to_string(q.second) + "} ";
683     }
684     return out;
685 }
686 
clearQueries()687 void DNSResponder::clearQueries() {
688     std::lock_guard lock(queries_mutex_);
689     queries_.clear();
690 }
691 
requestHandler()692 void DNSResponder::requestHandler() {
693     epoll_event evs[EPOLL_MAX_EVENTS];
694     while (true) {
695         int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, poll_timeout_ms_);
696         if (n == 0) continue;
697         if (n < 0) {
698             APLOGI("epoll_wait() failed, n=%d", n);
699             return;
700         }
701 
702         for (int i = 0; i < n; i++) {
703             const int fd = evs[i].data.fd;
704             const uint32_t events = evs[i].events;
705             if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
706                 handleEventFd();
707                 return;
708             } else if (fd == socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
709                 handleQuery();
710             } else {
711                 ALOGW("unexpected epoll events 0x%x on fd %d", events, fd);
712             }
713         }
714     }
715 }
716 
handleDNSRequest(const char * buffer,ssize_t len,char * response,size_t * response_len) const717 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
718                                     char* response, size_t* response_len)
719                                     const {
720     DBGLOG("request: '%s'", str2hex(buffer, len).c_str());
721     const char* buffer_end = buffer + len;
722     DNSHeader header;
723     const char* cur = header.read(buffer, buffer_end);
724     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
725     if (cur == nullptr) {
726         ALOGI("failed to parse query");
727         return false;
728     }
729     if (header.qr) {
730         ALOGI("response received instead of a query");
731         return false;
732     }
733     if (header.opcode != ns_opcode::ns_o_query) {
734         ALOGI("unsupported request opcode received");
735         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
736                                  response_len);
737     }
738     if (header.questions.empty()) {
739         ALOGI("no questions present");
740         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
741                                  response_len);
742     }
743     if (!header.answers.empty()) {
744         ALOGI("already %zu answers present in query", header.answers.size());
745         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
746                                  response_len);
747     }
748 
749     if (edns_ == Edns::FORMERR_UNCOND) {
750         ALOGI("force to return RCODE FORMERR");
751         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
752     }
753 
754     if (!header.additionals.empty() && edns_ != Edns::ON) {
755         ALOGI("DNS request has an additional section (assumed EDNS). "
756               "Simulating an ancient (pre-EDNS) server, and returning %s",
757               edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
758         if (edns_ == Edns::FORMERR_ON_EDNS) {
759             return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
760         }
761         // No response.
762         return false;
763     }
764     {
765         std::lock_guard lock(queries_mutex_);
766         for (const DNSQuestion& question : header.questions) {
767             queries_.push_back(make_pair(question.qname.name,
768                                          ns_type(question.qtype)));
769         }
770     }
771 
772     // Ignore requests with the preset probability.
773     auto constexpr bound = std::numeric_limits<unsigned>::max();
774     if (arc4random_uniform(bound) > bound * response_probability_) {
775         if (error_rcode_ < 0) {
776             ALOGI("Returning no response");
777             return false;
778         } else {
779             ALOGI("returning RCODE %d in accordance with probability distribution",
780                   static_cast<int>(error_rcode_));
781             return makeErrorResponse(&header, error_rcode_, response, response_len);
782         }
783     }
784 
785     for (const DNSQuestion& question : header.questions) {
786         if (question.qclass != ns_class::ns_c_in &&
787             question.qclass != ns_class::ns_c_any) {
788             ALOGI("unsupported question class %u", question.qclass);
789             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
790                                      response_len);
791         }
792 
793         if (!addAnswerRecords(question, &header.answers)) {
794             return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, response_len);
795         }
796     }
797 
798     header.qr = true;
799     char* response_cur = header.write(response, response + *response_len);
800     if (response_cur == nullptr) {
801         return false;
802     }
803     *response_len = response_cur - response;
804     return true;
805 }
806 
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const807 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
808                                     std::vector<DNSRecord>* answers) const {
809     std::lock_guard guard(mappings_mutex_);
810     std::string rname = question.qname.name;
811     std::vector<int> rtypes;
812 
813     if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa)
814         rtypes.push_back(ns_type::ns_t_cname);
815     rtypes.push_back(question.qtype);
816     for (int rtype : rtypes) {
817         std::set<std::string> cnames_Loop;
818         std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
819         while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
820             if (rtype == ns_type::ns_t_cname) {
821                 // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
822                 // As following, the query will stop on loop3 by detecting the same cname.
823                 // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
824                 // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
825                 // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
826                 //   is found in cnames_Loop already, break the query loop.)
827                 if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
828                 cnames_Loop.insert(it->first.name);
829             }
830             DNSRecord record{
831                     .name = {.name = it->first.name},
832                     .rtype = it->first.type,
833                     .rclass = ns_class::ns_c_in,
834                     .ttl = 5,  // seconds
835             };
836             fillAnswerRdata(it->second, record);
837             answers->push_back(std::move(record));
838             if (rtype != ns_type::ns_t_cname) break;
839             rname = it->second;
840         }
841     }
842 
843     if (answers->size() == 0) {
844         // TODO(imaipi): handle correctly
845         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
846               question.qname.name.c_str(), dnstype2str(question.qtype));
847     }
848 
849     return true;
850 }
851 
fillAnswerRdata(const std::string & rdatastr,DNSRecord & record) const852 bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const {
853     if (record.rtype == ns_type::ns_t_a) {
854         record.rdata.resize(4);
855         if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
856             ALOGI("inet_pton(AF_INET, %s) failed", rdatastr.c_str());
857             return false;
858         }
859     } else if (record.rtype == ns_type::ns_t_aaaa) {
860         record.rdata.resize(16);
861         if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
862             ALOGI("inet_pton(AF_INET6, %s) failed", rdatastr.c_str());
863             return false;
864         }
865     } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname)) {
866         constexpr char delimiter = '.';
867         std::string name = rdatastr;
868         std::vector<char> rdata;
869 
870         // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
871         // The "name" should be an absolute domain name which ends in a dot.
872         if (name.back() != delimiter) {
873             ALOGI("invalid absolute domain name");
874             return false;
875         }
876         name.pop_back();  // remove the dot in tail
877         for (const std::string& label : android::base::Split(name, {delimiter})) {
878             // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
879             if (label.length() == 0 || label.length() > 63) {
880                 ALOGI("invalid label length");
881                 return false;
882             }
883 
884             rdata.push_back(label.length());
885             rdata.insert(rdata.end(), label.begin(), label.end());
886         }
887         rdata.push_back(0);  // Length byte of zero terminates the label list
888 
889         // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
890         if (rdata.size() > 255) {
891             ALOGI("invalid name length");
892             return false;
893         }
894         record.rdata = move(rdata);
895     } else {
896         ALOGI("unhandled qtype %s", dnstype2str(record.rtype));
897         return false;
898     }
899     return true;
900 }
901 
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const902 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
903                                      char* response, size_t* response_len)
904                                      const {
905     header->answers.clear();
906     header->authorities.clear();
907     header->additionals.clear();
908     header->rcode = rcode;
909     header->qr = true;
910     char* response_cur = header->write(response, response + *response_len);
911     if (response_cur == nullptr) return false;
912     *response_len = response_cur - response;
913     return true;
914 }
915 
setDeferredResp(bool deferred_resp)916 void DNSResponder::setDeferredResp(bool deferred_resp) {
917     std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
918     deferred_resp_ = deferred_resp;
919     if (!deferred_resp_) {
920         cv_for_deferred_resp_.notify_one();
921     }
922 }
923 
addFd(int fd,uint32_t events)924 bool DNSResponder::addFd(int fd, uint32_t events) {
925     epoll_event ev;
926     ev.events = events;
927     ev.data.fd = fd;
928     if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
929         APLOGI("epoll_ctl() for socket %d failed", fd);
930         return false;
931     }
932     return true;
933 }
934 
handleQuery()935 void DNSResponder::handleQuery() {
936     char buffer[4096];
937     sockaddr_storage sa;
938     socklen_t sa_len = sizeof(sa);
939     ssize_t len;
940     do {
941         len = recvfrom(socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa, &sa_len);
942     } while (len < 0 && (errno == EAGAIN || errno == EINTR));
943     if (len <= 0) {
944         APLOGI("recvfrom() failed, len=%zu", len);
945         return;
946     }
947     DBGLOG("read %zd bytes", len);
948     std::lock_guard lock(cv_mutex_);
949     char response[4096];
950     size_t response_len = sizeof(response);
951     if (handleDNSRequest(buffer, len, response, &response_len) && response_len > 0) {
952         // place wait_for after handleDNSRequest() so we can check the number of queries in
953         // test case before it got responded.
954         std::unique_lock guard(cv_mutex_for_deferred_resp_);
955         cv_for_deferred_resp_.wait(
956                 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
957 
958         len = sendto(socket_.get(), response, response_len, 0,
959                      reinterpret_cast<const sockaddr*>(&sa), sa_len);
960         std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
961         if (len > 0) {
962             DBGLOG("sent %zu bytes to %s", len, host_str.c_str());
963         } else {
964             APLOGI("sendto() failed for %s", host_str.c_str());
965         }
966         // Test that the response is actually a correct DNS message.
967         const char* response_end = response + len;
968         DNSHeader header;
969         const char* cur = header.read(response, response_end);
970         if (cur == nullptr) ALOGW("response is flawed");
971     } else {
972         ALOGW("not responding");
973     }
974     cv.notify_one();
975     return;
976 }
977 
sendToEventFd()978 bool DNSResponder::sendToEventFd() {
979     const uint64_t data = 1;
980     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
981         APLOGI("failed to write eventfd, rt=%zd", rt);
982         return false;
983     }
984     return true;
985 }
986 
handleEventFd()987 void DNSResponder::handleEventFd() {
988     int64_t data;
989     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
990         APLOGI("ignore reading eventfd failed, rt=%zd", rt);
991     }
992 }
993 
994 }  // namespace test
995