• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_responder.h"
18 
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <sys/epoll.h>
26 #include <sys/eventfd.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 
31 #include <chrono>
32 #include <iostream>
33 #include <set>
34 #include <vector>
35 
36 #define LOG_TAG "DNSResponder"
37 #include <android-base/logging.h>
38 #include <android-base/strings.h>
39 #include <netdutils/InternetAddresses.h>
40 #include <netdutils/Slice.h>
41 #include <netdutils/SocketOption.h>
42 
43 using android::netdutils::enableSockopt;
44 using android::netdutils::ScopedAddrinfo;
45 using android::netdutils::Slice;
46 
47 namespace test {
48 
// Returns the textual description of the current value of errno.
std::string errno2str() {
    char msg_buf[512] = {};
    // This resolves to __gnu_strerror_r(), whose return type is |char*| rather than |int|.
    // PLOG would be an alternative, but it requires converting many ALOGx() calls to LOG(x).
    const char* description = strerror_r(errno, msg_buf, sizeof(msg_buf));
    return std::string(description);
}
55 
// Encodes the first |len| bytes of |buffer| as upper-case hexadecimal, two digits per byte and
// without separators.
std::string str2hex(const char* buffer, size_t len) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    std::string encoded;
    encoded.reserve(len * 2);
    for (const char* p = buffer; p != buffer + len; ++p) {
        // Go through uint8_t so that bytes >= 0x80 index the table correctly.
        const auto byte = static_cast<uint8_t>(*p);
        encoded.push_back(kHexDigits[byte >> 4]);
        encoded.push_back(kHexDigits[byte & 0x0F]);
    }
    return encoded;
}
66 
// Formats the socket address |sa| as a numeric host string (no name lookup is requested).
// Returns an empty string if getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, numeric_host, sizeof(numeric_host), nullptr, 0,
                                   NI_NUMERICHOST);
    return (status == 0) ? std::string(numeric_host) : std::string();
}
73 
74 /* DNS struct helpers */
75 
dnstype2str(unsigned dnstype)76 const char* dnstype2str(unsigned dnstype) {
77     static std::unordered_map<unsigned, const char*> kTypeStrs = {
78             {ns_type::ns_t_a, "A"},
79             {ns_type::ns_t_ns, "NS"},
80             {ns_type::ns_t_md, "MD"},
81             {ns_type::ns_t_mf, "MF"},
82             {ns_type::ns_t_cname, "CNAME"},
83             {ns_type::ns_t_soa, "SOA"},
84             {ns_type::ns_t_mb, "MB"},
85             {ns_type::ns_t_mb, "MG"},
86             {ns_type::ns_t_mr, "MR"},
87             {ns_type::ns_t_null, "NULL"},
88             {ns_type::ns_t_wks, "WKS"},
89             {ns_type::ns_t_ptr, "PTR"},
90             {ns_type::ns_t_hinfo, "HINFO"},
91             {ns_type::ns_t_minfo, "MINFO"},
92             {ns_type::ns_t_mx, "MX"},
93             {ns_type::ns_t_txt, "TXT"},
94             {ns_type::ns_t_rp, "RP"},
95             {ns_type::ns_t_afsdb, "AFSDB"},
96             {ns_type::ns_t_x25, "X25"},
97             {ns_type::ns_t_isdn, "ISDN"},
98             {ns_type::ns_t_rt, "RT"},
99             {ns_type::ns_t_nsap, "NSAP"},
100             {ns_type::ns_t_nsap_ptr, "NSAP-PTR"},
101             {ns_type::ns_t_sig, "SIG"},
102             {ns_type::ns_t_key, "KEY"},
103             {ns_type::ns_t_px, "PX"},
104             {ns_type::ns_t_gpos, "GPOS"},
105             {ns_type::ns_t_aaaa, "AAAA"},
106             {ns_type::ns_t_loc, "LOC"},
107             {ns_type::ns_t_nxt, "NXT"},
108             {ns_type::ns_t_eid, "EID"},
109             {ns_type::ns_t_nimloc, "NIMLOC"},
110             {ns_type::ns_t_srv, "SRV"},
111             {ns_type::ns_t_naptr, "NAPTR"},
112             {ns_type::ns_t_kx, "KX"},
113             {ns_type::ns_t_cert, "CERT"},
114             {ns_type::ns_t_a6, "A6"},
115             {ns_type::ns_t_dname, "DNAME"},
116             {ns_type::ns_t_sink, "SINK"},
117             {ns_type::ns_t_opt, "OPT"},
118             {ns_type::ns_t_apl, "APL"},
119             {ns_type::ns_t_tkey, "TKEY"},
120             {ns_type::ns_t_tsig, "TSIG"},
121             {ns_type::ns_t_ixfr, "IXFR"},
122             {ns_type::ns_t_axfr, "AXFR"},
123             {ns_type::ns_t_mailb, "MAILB"},
124             {ns_type::ns_t_maila, "MAILA"},
125             {ns_type::ns_t_any, "ANY"},
126             {ns_type::ns_t_zxfr, "ZXFR"},
127     };
128     auto it = kTypeStrs.find(dnstype);
129     static const char* kUnknownStr{"UNKNOWN"};
130     if (it == kTypeStrs.end()) return kUnknownStr;
131     return it->second;
132 }
133 
// Maps a DNS class code to a descriptive name; codes missing from the table yield "UNKNOWN".
const char* dnsclass2str(unsigned dnsclass) {
    static std::unordered_map<unsigned, const char*> kClassNames = {
            {ns_class::ns_c_in, "Internet"},
            {2, "CSNet"},
            {ns_class::ns_c_chaos, "ChaosNet"},
            {ns_class::ns_c_hs, "Hesiod"},
            {ns_class::ns_c_none, "none"},
            {ns_class::ns_c_any, "any"},
    };
    static const char* kFallback{"UNKNOWN"};
    const auto entry = kClassNames.find(dnsclass);
    return (entry != kClassNames.end()) ? entry->second : kFallback;
}
144 
// Names the transport protocol: "TCP" for IPPROTO_TCP, "UDP" for IPPROTO_UDP and "UNKNOWN" for
// anything else.
const char* dnsproto2str(int protocol) {
    if (protocol == IPPROTO_TCP) return "TCP";
    if (protocol == IPPROTO_UDP) return "UDP";
    return "UNKNOWN";
}
155 
read(const char * buffer,const char * buffer_end)156 const char* DNSName::read(const char* buffer, const char* buffer_end) {
157     const char* cur = buffer;
158     bool last = false;
159     do {
160         cur = parseField(cur, buffer_end, &last);
161         if (cur == nullptr) {
162             LOG(ERROR) << "parsing failed at line " << __LINE__;
163             return nullptr;
164         }
165     } while (!last);
166     return cur;
167 }
168 
write(char * buffer,const char * buffer_end) const169 char* DNSName::write(char* buffer, const char* buffer_end) const {
170     char* buffer_cur = buffer;
171     for (size_t pos = 0; pos < name.size();) {
172         size_t dot_pos = name.find('.', pos);
173         if (dot_pos == std::string::npos) {
174             // Sanity check, should never happen unless parseField is broken.
175             LOG(ERROR) << "logic error: all names are expected to end with a '.'";
176             return nullptr;
177         }
178         const size_t len = dot_pos - pos;
179         if (len >= 256) {
180             LOG(ERROR) << "name component '" << name.substr(pos, dot_pos - pos) << "' is " << len
181                        << " long, but max is 255";
182             return nullptr;
183         }
184         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
185             LOG(ERROR) << "buffer overflow at line " << __LINE__;
186             return nullptr;
187         }
188         *buffer_cur++ = len;
189         buffer_cur = std::copy(std::next(name.begin(), pos), std::next(name.begin(), dot_pos),
190                                buffer_cur);
191         pos = dot_pos + 1;
192     }
193     // Write final zero.
194     *buffer_cur++ = 0;
195     return buffer_cur;
196 }
197 
parseField(const char * buffer,const char * buffer_end,bool * last)198 const char* DNSName::parseField(const char* buffer, const char* buffer_end, bool* last) {
199     if (buffer + sizeof(uint8_t) > buffer_end) {
200         LOG(ERROR) << "parsing failed at line " << __LINE__;
201         return nullptr;
202     }
203     unsigned field_type = *buffer >> 6;
204     unsigned ofs = *buffer & 0x3F;
205     const char* cur = buffer + sizeof(uint8_t);
206     if (field_type == 0) {
207         // length + name component
208         if (ofs == 0) {
209             *last = true;
210             return cur;
211         }
212         if (cur + ofs > buffer_end) {
213             LOG(ERROR) << "parsing failed at line " << __LINE__;
214             return nullptr;
215         }
216         name.append(cur, ofs);
217         name.push_back('.');
218         return cur + ofs;
219     } else if (field_type == 3) {
220         LOG(ERROR) << "name compression not implemented";
221         return nullptr;
222     }
223     LOG(ERROR) << "invalid name field type";
224     return nullptr;
225 }
226 
read(const char * buffer,const char * buffer_end)227 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
228     const char* cur = qname.read(buffer, buffer_end);
229     if (cur == nullptr) {
230         LOG(ERROR) << "parsing failed at line " << __LINE__;
231         return nullptr;
232     }
233     if (cur + 2 * sizeof(uint16_t) > buffer_end) {
234         LOG(ERROR) << "parsing failed at line " << __LINE__;
235         return nullptr;
236     }
237     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
238     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
239     return cur + 2 * sizeof(uint16_t);
240 }
241 
write(char * buffer,const char * buffer_end) const242 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
243     char* buffer_cur = qname.write(buffer, buffer_end);
244     if (buffer_cur == nullptr) return nullptr;
245     if (buffer_cur + 2 * sizeof(uint16_t) > buffer_end) {
246         LOG(ERROR) << "buffer overflow on line " << __LINE__;
247         return nullptr;
248     }
249     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
250     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = htons(qclass);
251     return buffer_cur + 2 * sizeof(uint16_t);
252 }
253 
toString() const254 std::string DNSQuestion::toString() const {
255     char buffer[16384];
256     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
257                        dnstype2str(qtype), dnsclass2str(qclass));
258     return std::string(buffer, len);
259 }
260 
read(const char * buffer,const char * buffer_end)261 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
262     const char* cur = name.read(buffer, buffer_end);
263     if (cur == nullptr) {
264         LOG(ERROR) << "parsing failed at line " << __LINE__;
265         return nullptr;
266     }
267     unsigned rdlen = 0;
268     cur = readIntFields(cur, buffer_end, &rdlen);
269     if (cur == nullptr) {
270         LOG(ERROR) << "parsing failed at line " << __LINE__;
271         return nullptr;
272     }
273     if (cur + rdlen > buffer_end) {
274         LOG(ERROR) << "parsing failed at line " << __LINE__;
275         return nullptr;
276     }
277     rdata.assign(cur, cur + rdlen);
278     return cur + rdlen;
279 }
280 
write(char * buffer,const char * buffer_end) const281 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
282     char* buffer_cur = name.write(buffer, buffer_end);
283     if (buffer_cur == nullptr) return nullptr;
284     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
285     if (buffer_cur == nullptr) return nullptr;
286     if (buffer_cur + rdata.size() > buffer_end) {
287         LOG(ERROR) << "buffer overflow on line " << __LINE__;
288         return nullptr;
289     }
290     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
291 }
292 
toString() const293 std::string DNSRecord::toString() const {
294     char buffer[16384];
295     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
296                        dnsclass2str(rclass));
297     return std::string(buffer, len);
298 }
299 
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)300 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, unsigned* rdlen) {
301     if (buffer + sizeof(IntFields) > buffer_end) {
302         LOG(ERROR) << "parsing failed at line " << __LINE__;
303         return nullptr;
304     }
305     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
306     rtype = ntohs(intfields.rtype);
307     rclass = ntohs(intfields.rclass);
308     ttl = ntohl(intfields.ttl);
309     *rdlen = ntohs(intfields.rdlen);
310     return buffer + sizeof(IntFields);
311 }
312 
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const313 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const {
314     if (buffer + sizeof(IntFields) > buffer_end) {
315         LOG(ERROR) << "buffer overflow on line " << __LINE__;
316         return nullptr;
317     }
318     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
319     intfields.rtype = htons(rtype);
320     intfields.rclass = htons(rclass);
321     intfields.ttl = htonl(ttl);
322     intfields.rdlen = htons(rdlen);
323     return buffer + sizeof(IntFields);
324 }
325 
read(const char * buffer,const char * buffer_end)326 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
327     unsigned qdcount;
328     unsigned ancount;
329     unsigned nscount;
330     unsigned arcount;
331     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, &nscount, &arcount);
332     if (cur == nullptr) {
333         LOG(ERROR) << "parsing failed at line " << __LINE__;
334         return nullptr;
335     }
336     if (qdcount) {
337         questions.resize(qdcount);
338         for (unsigned i = 0; i < qdcount; ++i) {
339             cur = questions[i].read(cur, buffer_end);
340             if (cur == nullptr) {
341                 LOG(ERROR) << "parsing failed at line " << __LINE__;
342                 return nullptr;
343             }
344         }
345     }
346     if (ancount) {
347         answers.resize(ancount);
348         for (unsigned i = 0; i < ancount; ++i) {
349             cur = answers[i].read(cur, buffer_end);
350             if (cur == nullptr) {
351                 LOG(ERROR) << "parsing failed at line " << __LINE__;
352                 return nullptr;
353             }
354         }
355     }
356     if (nscount) {
357         authorities.resize(nscount);
358         for (unsigned i = 0; i < nscount; ++i) {
359             cur = authorities[i].read(cur, buffer_end);
360             if (cur == nullptr) {
361                 LOG(ERROR) << "parsing failed at line " << __LINE__;
362                 return nullptr;
363             }
364         }
365     }
366     if (arcount) {
367         additionals.resize(arcount);
368         for (unsigned i = 0; i < arcount; ++i) {
369             cur = additionals[i].read(cur, buffer_end);
370             if (cur == nullptr) {
371                 LOG(ERROR) << "parsing failed at line " << __LINE__;
372                 return nullptr;
373             }
374         }
375     }
376     return cur;
377 }
378 
write(char * buffer,const char * buffer_end) const379 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
380     if (buffer + sizeof(Header) > buffer_end) {
381         LOG(ERROR) << "buffer overflow on line " << __LINE__;
382         return nullptr;
383     }
384     Header& header = *reinterpret_cast<Header*>(buffer);
385     // bytes 0-1
386     header.id = htons(id);
387     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
388     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
389     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
390     // Fake behavior: if the query set the "ad" bit, set it in the response too.
391     // In a real server, this should be set only if the data is authentic and the
392     // query contained an "ad" bit or DNSSEC extensions.
393     header.flags1 = (ad << 5) | rcode;
394     // rest of header
395     header.qdcount = htons(questions.size());
396     header.ancount = htons(answers.size());
397     header.nscount = htons(authorities.size());
398     header.arcount = htons(additionals.size());
399     char* buffer_cur = buffer + sizeof(Header);
400     for (const DNSQuestion& question : questions) {
401         buffer_cur = question.write(buffer_cur, buffer_end);
402         if (buffer_cur == nullptr) return nullptr;
403     }
404     for (const DNSRecord& answer : answers) {
405         buffer_cur = answer.write(buffer_cur, buffer_end);
406         if (buffer_cur == nullptr) return nullptr;
407     }
408     for (const DNSRecord& authority : authorities) {
409         buffer_cur = authority.write(buffer_cur, buffer_end);
410         if (buffer_cur == nullptr) return nullptr;
411     }
412     for (const DNSRecord& additional : additionals) {
413         buffer_cur = additional.write(buffer_cur, buffer_end);
414         if (buffer_cur == nullptr) return nullptr;
415     }
416     return buffer_cur;
417 }
418 
419 // TODO: convert all callers to this interface, then delete the old one.
write(std::vector<uint8_t> * out) const420 bool DNSHeader::write(std::vector<uint8_t>* out) const {
421     char buffer[16384];
422     char* end = this->write(buffer, buffer + sizeof buffer);
423     if (end == nullptr) return false;
424     out->insert(out->end(), buffer, end);
425     return true;
426 }
427 
toString() const428 std::string DNSHeader::toString() const {
429     // TODO
430     return std::string();
431 }
432 
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)433 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, unsigned* qdcount,
434                                   unsigned* ancount, unsigned* nscount, unsigned* arcount) {
435     if (buffer + sizeof(Header) > buffer_end) return nullptr;
436     const auto& header = *reinterpret_cast<const Header*>(buffer);
437     // bytes 0-1
438     id = ntohs(header.id);
439     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
440     qr = header.flags0 >> 7;
441     opcode = (header.flags0 >> 3) & 0x0F;
442     aa = (header.flags0 >> 2) & 1;
443     tr = (header.flags0 >> 1) & 1;
444     rd = header.flags0 & 1;
445     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
446     ra = header.flags1 >> 7;
447     ad = (header.flags1 >> 5) & 1;
448     rcode = header.flags1 & 0xF;
449     // rest of header
450     *qdcount = ntohs(header.qdcount);
451     *ancount = ntohs(header.ancount);
452     *nscount = ntohs(header.nscount);
453     *arcount = ntohs(header.arcount);
454     return buffer + sizeof(Header);
455 }
456 
457 /* DNS responder */
458 
DNSResponder(std::string listen_address,std::string listen_service,ns_rcode error_rcode,MappingType mapping_type)459 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
460                            ns_rcode error_rcode, MappingType mapping_type)
461     : listen_address_(std::move(listen_address)),
462       listen_service_(std::move(listen_service)),
463       error_rcode_(error_rcode),
464       mapping_type_(mapping_type) {}
465 
~DNSResponder()466 DNSResponder::~DNSResponder() {
467     stopServer();
468 }
469 
addMapping(const std::string & name,ns_type type,const std::string & addr)470 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
471     std::lock_guard lock(mappings_mutex_);
472     mappings_[{name, type}] = addr;
473 }
474 
addMappingDnsHeader(const std::string & name,ns_type type,const DNSHeader & header)475 void DNSResponder::addMappingDnsHeader(const std::string& name, ns_type type,
476                                        const DNSHeader& header) {
477     std::lock_guard lock(mappings_mutex_);
478     dnsheader_mappings_[{name, type}] = header;
479 }
480 
addMappingBinaryPacket(const std::vector<uint8_t> & query,const std::vector<uint8_t> & response)481 void DNSResponder::addMappingBinaryPacket(const std::vector<uint8_t>& query,
482                                           const std::vector<uint8_t>& response) {
483     std::lock_guard lock(mappings_mutex_);
484     packet_mappings_[query] = response;
485 }
486 
removeMapping(const std::string & name,ns_type type)487 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
488     std::lock_guard lock(mappings_mutex_);
489     if (!mappings_.erase({name, type})) {
490         LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
491                    << "), not present in registered mappings";
492     }
493 }
494 
removeMappingDnsHeader(const std::string & name,ns_type type)495 void DNSResponder::removeMappingDnsHeader(const std::string& name, ns_type type) {
496     std::lock_guard lock(mappings_mutex_);
497     if (!dnsheader_mappings_.erase({name, type})) {
498         LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
499                    << "), not present in registered DnsHeader mappings";
500     }
501 }
502 
removeMappingBinaryPacket(const std::vector<uint8_t> & query)503 void DNSResponder::removeMappingBinaryPacket(const std::vector<uint8_t>& query) {
504     std::lock_guard lock(mappings_mutex_);
505     if (!packet_mappings_.erase(query)) {
506         LOG(ERROR) << "Cannot remove mapping, not present in registered BinaryPacket mappings";
507         LOG(INFO) << "Hex dump:";
508         LOG(INFO) << android::netdutils::toHex(
509                 Slice(const_cast<uint8_t*>(query.data()), query.size()), 32);
510     }
511 }
512 
513 // Set response probability on all supported protocols.
setResponseProbability(double response_probability)514 void DNSResponder::setResponseProbability(double response_probability) {
515     setResponseProbability(response_probability, IPPROTO_TCP);
516     setResponseProbability(response_probability, IPPROTO_UDP);
517 }
518 
setResponseDelayMs(unsigned timeMs)519 void DNSResponder::setResponseDelayMs(unsigned timeMs) {
520     response_delayed_ms_ = timeMs;
521 }
522 
523 // Set response probability on specific protocol. It's caller's duty to ensure that the |protocol|
524 // can be supported by DNSResponder.
setResponseProbability(double response_probability,int protocol)525 void DNSResponder::setResponseProbability(double response_probability, int protocol) {
526     switch (protocol) {
527         case IPPROTO_TCP:
528             response_probability_tcp_ = response_probability;
529             break;
530         case IPPROTO_UDP:
531             response_probability_udp_ = response_probability;
532             break;
533         default:
534             LOG(FATAL) << "Unsupported protocol " << protocol;  // abort() by log level FATAL
535     }
536 }
537 
getResponseProbability(int protocol) const538 double DNSResponder::getResponseProbability(int protocol) const {
539     switch (protocol) {
540         case IPPROTO_TCP:
541             return response_probability_tcp_;
542         case IPPROTO_UDP:
543             return response_probability_udp_;
544         default:
545             LOG(FATAL) << "Unsupported protocol " << protocol;  // abort() by log level FATAL
546             // unreachable
547             return -1;
548     }
549 }
550 
setEdns(Edns edns)551 void DNSResponder::setEdns(Edns edns) {
552     edns_ = edns;
553 }
554 
setTtl(unsigned ttl)555 void DNSResponder::setTtl(unsigned ttl) {
556     answer_record_ttl_sec_ = ttl;
557 }
558 
running() const559 bool DNSResponder::running() const {
560     return (udp_socket_.ok()) && (tcp_socket_.ok());
561 }
562 
startServer()563 bool DNSResponder::startServer() {
564     if (running()) {
565         LOG(ERROR) << "server already running";
566         return false;
567     }
568 
569     // Create UDP, TCP socket
570     if (udp_socket_ = createListeningSocket(SOCK_DGRAM); udp_socket_.get() < 0) {
571         PLOG(ERROR) << "failed to create UDP socket";
572         return false;
573     }
574 
575     if (tcp_socket_ = createListeningSocket(SOCK_STREAM); tcp_socket_.get() < 0) {
576         PLOG(ERROR) << "failed to create TCP socket";
577         return false;
578     }
579 
580     if (listen(tcp_socket_.get(), 1) < 0) {
581         PLOG(ERROR) << "failed to listen TCP socket";
582         return false;
583     }
584 
585     // Set up eventfd socket.
586     event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
587     if (event_fd_.get() == -1) {
588         PLOG(ERROR) << "failed to create eventfd";
589         return false;
590     }
591 
592     // Set up epoll socket.
593     epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
594     if (epoll_fd_.get() < 0) {
595         PLOG(ERROR) << "epoll_create1() failed on fd";
596         return false;
597     }
598 
599     LOG(INFO) << "adding UDP socket to epoll";
600     if (!addFd(udp_socket_.get(), EPOLLIN)) {
601         LOG(ERROR) << "failed to add the UDP socket to epoll";
602         return false;
603     }
604 
605     LOG(INFO) << "adding TCP socket to epoll";
606     if (!addFd(tcp_socket_.get(), EPOLLIN)) {
607         LOG(ERROR) << "failed to add the TCP socket to epoll";
608         return false;
609     }
610 
611     LOG(INFO) << "adding eventfd to epoll";
612     if (!addFd(event_fd_.get(), EPOLLIN)) {
613         LOG(ERROR) << "failed to add the eventfd to epoll";
614         return false;
615     }
616 
617     {
618         std::lock_guard lock(update_mutex_);
619         handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
620     }
621     LOG(INFO) << "server started successfully";
622     return true;
623 }
624 
stopServer()625 bool DNSResponder::stopServer() {
626     std::lock_guard lock(update_mutex_);
627     if (!running()) {
628         LOG(ERROR) << "server not running";
629         return false;
630     }
631     LOG(INFO) << "stopping server";
632     if (!sendToEventFd()) {
633         return false;
634     }
635     handler_thread_.join();
636     epoll_fd_.reset();
637     event_fd_.reset();
638     udp_socket_.reset();
639     tcp_socket_.reset();
640     LOG(INFO) << "server stopped successfully";
641     return true;
642 }
643 
queries() const644 std::vector<DNSResponder::QueryInfo> DNSResponder::queries() const {
645     std::lock_guard lock(queries_mutex_);
646     return queries_;
647 }
648 
dumpQueries() const649 std::string DNSResponder::dumpQueries() const {
650     std::lock_guard lock(queries_mutex_);
651     std::string out;
652 
653     for (const auto& q : queries_) {
654         out += "{\"" + q.name + "\", " + std::to_string(q.type) + "\", " +
655                dnsproto2str(q.protocol) + "} ";
656     }
657     return out;
658 }
659 
clearQueries()660 void DNSResponder::clearQueries() {
661     std::lock_guard lock(queries_mutex_);
662     queries_.clear();
663 }
664 
hasOptPseudoRR(DNSHeader * header) const665 bool DNSResponder::hasOptPseudoRR(DNSHeader* header) const {
666     if (header->additionals.empty()) return false;
667 
668     // OPT RR may be placed anywhere within the additional section. See RFC 6891 section 6.1.1.
669     auto found = std::find_if(header->additionals.begin(), header->additionals.end(),
670                               [](const auto& a) { return a.rtype == ns_type::ns_t_opt; });
671     return found != header->additionals.end();
672 }
673 
requestHandler()674 void DNSResponder::requestHandler() {
675     epoll_event evs[EPOLL_MAX_EVENTS];
676     while (true) {
677         int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, -1);
678         if (n <= 0) {
679             PLOG(ERROR) << "epoll_wait() failed, n=" << n;
680             return;
681         }
682 
683         for (int i = 0; i < n; i++) {
684             const int fd = evs[i].data.fd;
685             const uint32_t events = evs[i].events;
686             if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
687                 handleEventFd();
688                 return;
689             } else if (fd == udp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
690                 handleQuery(IPPROTO_UDP);
691             } else if (fd == tcp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
692                 handleQuery(IPPROTO_TCP);
693             } else {
694                 LOG(WARNING) << "unexpected epoll events " << events << " on fd " << fd;
695             }
696         }
697     }
698 }
699 
handleDNSRequest(const char * buffer,ssize_t len,int protocol,char * response,size_t * response_len) const700 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, int protocol, char* response,
701                                     size_t* response_len) const {
702     LOG(DEBUG) << "request: '" << str2hex(buffer, len) << "', on " << dnsproto2str(protocol);
703     const char* buffer_end = buffer + len;
704     DNSHeader header;
705     const char* cur = header.read(buffer, buffer_end);
706     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
707     if (cur == nullptr) {
708         LOG(ERROR) << "failed to parse query";
709         return false;
710     }
711     if (header.qr) {
712         LOG(ERROR) << "response received instead of a query";
713         return false;
714     }
715     if (header.opcode != ns_opcode::ns_o_query) {
716         LOG(INFO) << "unsupported request opcode received";
717         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, response_len);
718     }
719     if (header.questions.empty()) {
720         LOG(INFO) << "no questions present";
721         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
722     }
723     if (!header.answers.empty()) {
724         LOG(INFO) << "already " << header.answers.size() << " answers present in query";
725         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
726     }
727 
728     if (edns_ == Edns::FORMERR_UNCOND) {
729         LOG(INFO) << "force to return RCODE FORMERR";
730         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
731     }
732 
733     if (!header.additionals.empty() && edns_ != Edns::ON) {
734         LOG(INFO) << "DNS request has an additional section (assumed EDNS). Simulating an ancient "
735                      "(pre-EDNS) server, and returning "
736                   << (edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
737         if (edns_ == Edns::FORMERR_ON_EDNS) {
738             return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
739         }
740         // No response.
741         return false;
742     }
743     {
744         std::lock_guard lock(queries_mutex_);
745         for (const DNSQuestion& question : header.questions) {
746             queries_.push_back({question.qname.name, ns_type(question.qtype), protocol});
747         }
748     }
749     // Ignore requests with the preset probability.
750     auto constexpr bound = std::numeric_limits<unsigned>::max();
751     if (arc4random_uniform(bound) > bound * getResponseProbability(protocol)) {
752         if (error_rcode_ < 0) {
753             LOG(ERROR) << "Returning no response";
754             return false;
755         } else {
756             LOG(INFO) << "returning RCODE " << static_cast<int>(error_rcode_)
757                       << " in accordance with probability distribution";
758             return makeErrorResponse(&header, error_rcode_, response, response_len);
759         }
760     }
761 
762     // Make the response. The query has been read into |header| which is used to build and return
763     // the response as well.
764     return makeResponse(&header, protocol, response, response_len);
765 }
766 
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const767 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
768                                     std::vector<DNSRecord>* answers) const {
769     std::lock_guard guard(mappings_mutex_);
770     std::string rname = question.qname.name;
771     std::vector<int> rtypes;
772 
773     if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa ||
774         question.qtype == ns_type::ns_t_ptr)
775         rtypes.push_back(ns_type::ns_t_cname);
776     rtypes.push_back(question.qtype);
777     for (int rtype : rtypes) {
778         std::set<std::string> cnames_Loop;
779         std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
780         while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
781             if (rtype == ns_type::ns_t_cname) {
782                 // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
783                 // As following, the query will stop on loop3 by detecting the same cname.
784                 // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
785                 // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
786                 // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
787                 //   is found in cnames_Loop already, break the query loop.)
788                 if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
789                 cnames_Loop.insert(it->first.name);
790             }
791             DNSRecord record{
792                     .name = {.name = it->first.name},
793                     .rtype = it->first.type,
794                     .rclass = ns_class::ns_c_in,
795                     .ttl = answer_record_ttl_sec_,  // seconds
796             };
797             if (!fillRdata(it->second, record)) return false;
798             answers->push_back(std::move(record));
799             if (rtype != ns_type::ns_t_cname) break;
800             rname = it->second;
801         }
802     }
803 
804     if (answers->size() == 0) {
805         // TODO(imaipi): handle correctly
806         LOG(INFO) << "no mapping found for " << question.qname.name << " "
807                   << dnstype2str(question.qtype) << ", lazily refusing to add an answer";
808     }
809 
810     return true;
811 }
812 
fillRdata(const std::string & rdatastr,DNSRecord & record)813 bool DNSResponder::fillRdata(const std::string& rdatastr, DNSRecord& record) {
814     if (record.rtype == ns_type::ns_t_a) {
815         record.rdata.resize(4);
816         if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
817             LOG(ERROR) << "inet_pton(AF_INET, " << rdatastr << ") failed";
818             return false;
819         }
820     } else if (record.rtype == ns_type::ns_t_aaaa) {
821         record.rdata.resize(16);
822         if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
823             LOG(ERROR) << "inet_pton(AF_INET6, " << rdatastr << ") failed";
824             return false;
825         }
826     } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname) ||
827                (record.rtype == ns_type::ns_t_ns)) {
828         constexpr char delimiter = '.';
829         std::string name = rdatastr;
830         std::vector<char> rdata;
831 
832         // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
833         // The "name" should be an absolute domain name which ends in a dot.
834         if (name.back() != delimiter) {
835             LOG(ERROR) << "invalid absolute domain name";
836             return false;
837         }
838         name.pop_back();  // remove the dot in tail
839         for (const std::string& label : android::base::Split(name, {delimiter})) {
840             // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
841             if (label.length() == 0 || label.length() > 63) {
842                 LOG(ERROR) << "invalid label length";
843                 return false;
844             }
845 
846             rdata.push_back(label.length());
847             rdata.insert(rdata.end(), label.begin(), label.end());
848         }
849         rdata.push_back(0);  // Length byte of zero terminates the label list
850 
851         // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
852         if (rdata.size() > 255) {
853             LOG(ERROR) << "invalid name length";
854             return false;
855         }
856         record.rdata = move(rdata);
857     } else {
858         LOG(ERROR) << "unhandled qtype " << dnstype2str(record.rtype);
859         return false;
860     }
861     return true;
862 }
863 
writePacket(const DNSHeader * header,char * response,size_t * response_len) const864 bool DNSResponder::writePacket(const DNSHeader* header, char* response,
865                                size_t* response_len) const {
866     char* response_cur = header->write(response, response + *response_len);
867     if (response_cur == nullptr) {
868         return false;
869     }
870     *response_len = response_cur - response;
871     return true;
872 }
873 
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const874 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
875                                      size_t* response_len) const {
876     header->answers.clear();
877     header->authorities.clear();
878     header->additionals.clear();
879     header->rcode = rcode;
880     header->qr = true;
881     return writePacket(header, response, response_len);
882 }
883 
makeTruncatedResponse(DNSHeader * header,char * response,size_t * response_len) const884 bool DNSResponder::makeTruncatedResponse(DNSHeader* header, char* response,
885                                          size_t* response_len) const {
886     // Build a minimal response for non-EDNS response over UDP. Truncate all stub RRs in answer,
887     // authority and additional section. EDNS response truncation has not supported here yet
888     // because the EDNS response must have an OPT record. See RFC 6891 section 7.
889     header->answers.clear();
890     header->authorities.clear();
891     header->additionals.clear();
892     header->qr = true;
893     header->tr = true;
894     return writePacket(header, response, response_len);
895 }
896 
makeResponse(DNSHeader * header,int protocol,char * response,size_t * response_len) const897 bool DNSResponder::makeResponse(DNSHeader* header, int protocol, char* response,
898                                 size_t* response_len) const {
899     char buffer[16384];
900     size_t buffer_len = sizeof(buffer);
901     bool ret;
902 
903     switch (mapping_type_) {
904         case MappingType::DNS_HEADER:
905             ret = makeResponseFromDnsHeader(header, buffer, &buffer_len);
906             break;
907         case MappingType::BINARY_PACKET:
908             ret = makeResponseFromBinaryPacket(header, buffer, &buffer_len);
909             break;
910         case MappingType::ADDRESS_OR_HOSTNAME:
911         default:
912             ret = makeResponseFromAddressOrHostname(header, buffer, &buffer_len);
913     }
914 
915     if (!ret) return false;
916 
917     // Return truncated response if the built non-EDNS response size which is larger than 512 bytes
918     // will be responded over UDP. The truncated response implementation here just simply set up
919     // the TC bit and truncate all stub RRs in answer, authority and additional section. It is
920     // because the resolver will retry DNS query over TCP and use the full TCP response. See also
921     // RFC 1035 section 4.2.1 for UDP response truncation and RFC 6891 section 4.3 for EDNS larger
922     // response size capability.
923     // TODO: Perhaps keep the stub RRs as possible.
924     // TODO: Perhaps truncate the EDNS based response over UDP. See also RFC 6891 section 4.3,
925     // section 6.2.5 and section 7.
926     if (protocol == IPPROTO_UDP && buffer_len > kMaximumUdpSize &&
927         !hasOptPseudoRR(header) /* non-EDNS */) {
928         LOG(INFO) << "Return truncated response because original response length " << buffer_len
929                   << " is larger than " << kMaximumUdpSize << " bytes.";
930         return makeTruncatedResponse(header, response, response_len);
931     }
932 
933     if (buffer_len > *response_len) {
934         LOG(ERROR) << "buffer overflow on line " << __LINE__;
935         return false;
936     }
937     memcpy(response, buffer, buffer_len);
938     *response_len = buffer_len;
939     return true;
940 }
941 
makeResponseFromAddressOrHostname(DNSHeader * header,char * response,size_t * response_len) const942 bool DNSResponder::makeResponseFromAddressOrHostname(DNSHeader* header, char* response,
943                                                      size_t* response_len) const {
944     for (const DNSQuestion& question : header->questions) {
945         if (question.qclass != ns_class::ns_c_in && question.qclass != ns_class::ns_c_any) {
946             LOG(INFO) << "unsupported question class " << question.qclass;
947             return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
948         }
949 
950         if (!addAnswerRecords(question, &header->answers)) {
951             return makeErrorResponse(header, ns_rcode::ns_r_servfail, response, response_len);
952         }
953     }
954     header->qr = true;
955     return writePacket(header, response, response_len);
956 }
957 
makeResponseFromDnsHeader(DNSHeader * header,char * response,size_t * response_len) const958 bool DNSResponder::makeResponseFromDnsHeader(DNSHeader* header, char* response,
959                                              size_t* response_len) const {
960     std::lock_guard guard(mappings_mutex_);
961 
962     // Support single question record only. It should be okay because res_mkquery() sets "qdcount"
963     // as one for the operation QUERY and handleDNSRequest() checks ns_opcode::ns_o_query before
964     // making a response. In other words, only need to handle the query which has single question
965     // section. See also res_mkquery() in system/netd/resolv/res_mkquery.cpp.
966     // TODO: Perhaps add support for multi-question records.
967     const std::vector<DNSQuestion>& questions = header->questions;
968     if (questions.size() != 1) {
969         LOG(INFO) << "unsupported question count " << questions.size();
970         return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
971     }
972 
973     if (questions[0].qclass != ns_class::ns_c_in && questions[0].qclass != ns_class::ns_c_any) {
974         LOG(INFO) << "unsupported question class " << questions[0].qclass;
975         return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
976     }
977 
978     const std::string name = questions[0].qname.name;
979     const int qtype = questions[0].qtype;
980     const auto it = dnsheader_mappings_.find(QueryKey(name, qtype));
981     if (it != dnsheader_mappings_.end()) {
982         // Store both "id" and "rd" which comes from query.
983         const unsigned id = header->id;
984         const bool rd = header->rd;
985 
986         // Build a response from the registered DNSHeader mapping.
987         *header = it->second;
988         // Assign both "ID" and "RD" fields from query to response. See RFC 1035 section 4.1.1.
989         header->id = id;
990         header->rd = rd;
991     } else {
992         // TODO: handle correctly. See also TODO in addAnswerRecords().
993         LOG(INFO) << "no mapping found for " << name << " " << dnstype2str(qtype)
994                   << ", couldn't build a response from DNSHeader mapping";
995 
996         // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
997         // just changes the QR flag from query (0) to response (1) in the query. Then, send the
998         // modified query back as a response.
999         header->qr = true;
1000     }
1001     return writePacket(header, response, response_len);
1002 }
1003 
makeResponseFromBinaryPacket(DNSHeader * header,char * response,size_t * response_len) const1004 bool DNSResponder::makeResponseFromBinaryPacket(DNSHeader* header, char* response,
1005                                                 size_t* response_len) const {
1006     std::lock_guard guard(mappings_mutex_);
1007 
1008     // Build a search key of mapping from the query.
1009     // TODO: Perhaps pass the query packet buffer directly from the caller.
1010     std::vector<uint8_t> queryKey;
1011     if (!header->write(&queryKey)) return false;
1012     // Clear ID field (byte 0-1) because it is not required by the mapping key.
1013     queryKey[0] = 0;
1014     queryKey[1] = 0;
1015 
1016     const auto it = packet_mappings_.find(queryKey);
1017     if (it != packet_mappings_.end()) {
1018         if (it->second.size() > *response_len) {
1019             LOG(ERROR) << "buffer overflow on line " << __LINE__;
1020             return false;
1021         } else {
1022             std::copy(it->second.begin(), it->second.end(), response);
1023             // Leave the "RD" flag assignment for testing. The "RD" flag of the response keep
1024             // using the one from the raw packet mapping but the received query.
1025             // Assign "ID" field from query to response. See RFC 1035 section 4.1.1.
1026             reinterpret_cast<uint16_t*>(response)[0] = htons(header->id);  // bytes 0-1: id
1027             *response_len = it->second.size();
1028             return true;
1029         }
1030     } else {
1031         // TODO: handle correctly. See also TODO in addAnswerRecords().
1032         // TODO: Perhaps dump packet content to indicate which query failed.
1033         LOG(INFO) << "no mapping found, couldn't build a response from BinaryPacket mapping";
1034         // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1035         // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1036         // modified query back as a response.
1037         header->qr = true;
1038         return writePacket(header, response, response_len);
1039     }
1040 }
1041 
setDeferredResp(bool deferred_resp)1042 void DNSResponder::setDeferredResp(bool deferred_resp) {
1043     std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
1044     deferred_resp_ = deferred_resp;
1045     if (!deferred_resp_) {
1046         cv_for_deferred_resp_.notify_one();
1047     }
1048 }
1049 
addFd(int fd,uint32_t events)1050 bool DNSResponder::addFd(int fd, uint32_t events) {
1051     epoll_event ev;
1052     ev.events = events;
1053     ev.data.fd = fd;
1054     if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
1055         PLOG(ERROR) << "epoll_ctl() for socket " << fd << " failed";
1056         return false;
1057     }
1058     return true;
1059 }
1060 
handleQuery(int protocol)1061 void DNSResponder::handleQuery(int protocol) {
1062     char buffer[16384];
1063     sockaddr_storage sa;
1064     socklen_t sa_len = sizeof(sa);
1065     ssize_t len = 0;
1066     android::base::unique_fd tcpFd;
1067     switch (protocol) {
1068         case IPPROTO_UDP:
1069             do {
1070                 len = recvfrom(udp_socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa,
1071                                &sa_len);
1072             } while (len < 0 && (errno == EAGAIN || errno == EINTR));
1073             if (len <= 0) {
1074                 PLOG(ERROR) << "recvfrom() failed, len=" << len;
1075                 return;
1076             }
1077             break;
1078         case IPPROTO_TCP:
1079             tcpFd.reset(accept4(tcp_socket_.get(), reinterpret_cast<sockaddr*>(&sa), &sa_len,
1080                                 SOCK_CLOEXEC));
1081             if (tcpFd.get() < 0) {
1082                 PLOG(ERROR) << "failed to accept client socket";
1083                 return;
1084             }
1085             // Get the message length from two byte length field.
1086             // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1087             uint8_t queryMessageLengthField[2];
1088             if (read(tcpFd.get(), &queryMessageLengthField, 2) != 2) {
1089                 PLOG(ERROR) << "Not enough length field bytes";
1090                 return;
1091             }
1092 
1093             const uint16_t qlen = (queryMessageLengthField[0] << 8) | queryMessageLengthField[1];
1094             while (len < qlen) {
1095                 ssize_t ret = read(tcpFd.get(), buffer + len, qlen - len);
1096                 if (ret <= 0) {
1097                     PLOG(ERROR) << "Error while reading query";
1098                     return;
1099                 }
1100                 len += ret;
1101             }
1102             break;
1103     }
1104     LOG(DEBUG) << "read " << len << " bytes on " << dnsproto2str(protocol);
1105     std::lock_guard lock(cv_mutex_);
1106     char response[16384];
1107     size_t response_len = sizeof(response);
1108     // TODO: check whether sending malformed packets to DnsResponder
1109     if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
1110         std::this_thread::sleep_for(std::chrono::milliseconds(response_delayed_ms_));
1111         // place wait_for after handleDNSRequest() so we can check the number of queries in
1112         // test case before it got responded.
1113         std::unique_lock guard(cv_mutex_for_deferred_resp_);
1114         cv_for_deferred_resp_.wait(
1115                 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
1116         len = 0;
1117 
1118         switch (protocol) {
1119             case IPPROTO_UDP:
1120                 len = sendto(udp_socket_.get(), response, response_len, 0,
1121                              reinterpret_cast<const sockaddr*>(&sa), sa_len);
1122                 if (len < 0) {
1123                     PLOG(ERROR) << "Failed to send response";
1124                 }
1125                 break;
1126             case IPPROTO_TCP:
1127                 // Get the message length from two byte length field.
1128                 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1129                 uint8_t responseMessageLengthField[2];
1130                 responseMessageLengthField[0] = response_len >> 8;
1131                 responseMessageLengthField[1] = response_len;
1132                 if (write(tcpFd.get(), responseMessageLengthField, 2) != 2) {
1133                     PLOG(ERROR) << "Failed to write response length field";
1134                     break;
1135                 }
1136                 if (write(tcpFd.get(), response, response_len) !=
1137                     static_cast<ssize_t>(response_len)) {
1138                     PLOG(ERROR) << "Failed to write response";
1139                     break;
1140                 }
1141                 len = response_len;
1142                 break;
1143         }
1144         const std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
1145         if (len > 0) {
1146             LOG(DEBUG) << "sent " << len << " bytes to " << host_str;
1147         } else {
1148             const char* method_str = (protocol == IPPROTO_TCP) ? "write()" : "sendto()";
1149             LOG(ERROR) << method_str << " failed for " << host_str;
1150         }
1151         // Test that the response is actually a correct DNS message.
1152         // TODO: Perhaps make DNS message validation to support name compression. Or it throws
1153         // a warning for a valid DNS message with name compression while the binary packet mapping
1154         // is used.
1155         const char* response_end = response + len;
1156         DNSHeader header;
1157         const char* cur = header.read(response, response_end);
1158         if (cur == nullptr) LOG(WARNING) << "response is flawed";
1159     } else {
1160         LOG(WARNING) << "not responding";
1161     }
1162     cv.notify_one();
1163     return;
1164 }
1165 
sendToEventFd()1166 bool DNSResponder::sendToEventFd() {
1167     const uint64_t data = 1;
1168     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1169         PLOG(ERROR) << "failed to write eventfd, rt=" << rt;
1170         return false;
1171     }
1172     return true;
1173 }
1174 
handleEventFd()1175 void DNSResponder::handleEventFd() {
1176     int64_t data;
1177     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1178         PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
1179     }
1180 }
1181 
createListeningSocket(int socket_type)1182 android::base::unique_fd DNSResponder::createListeningSocket(int socket_type) {
1183     addrinfo ai_hints{
1184             .ai_flags = AI_PASSIVE,
1185             .ai_family = AF_UNSPEC,
1186             .ai_socktype = socket_type,
1187     };
1188     addrinfo* ai_res = nullptr;
1189     const int rv =
1190             getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &ai_hints, &ai_res);
1191     ScopedAddrinfo ai_res_cleanup(ai_res);
1192     if (rv) {
1193         LOG(ERROR) << "getaddrinfo(" << listen_address_ << ", " << listen_service_
1194                    << ") failed: " << gai_strerror(rv);
1195         return {};
1196     }
1197     for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
1198         android::base::unique_fd fd(
1199                 socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK, ai->ai_protocol));
1200         if (fd.get() < 0) {
1201             PLOG(ERROR) << "ignore creating socket failed";
1202             continue;
1203         }
1204         enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
1205         enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
1206         const std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
1207         const char* socket_str = (socket_type == SOCK_STREAM) ? "TCP" : "UDP";
1208 
1209         if (bind(fd.get(), ai->ai_addr, ai->ai_addrlen)) {
1210             PLOG(ERROR) << "failed to bind " << socket_str << " " << host_str << ":"
1211                         << listen_service_;
1212             continue;
1213         }
1214         LOG(INFO) << "bound to " << socket_str << " " << host_str << ":" << listen_service_;
1215         return fd;
1216     }
1217     return {};
1218 }
1219 
1220 }  // namespace test
1221