• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_responder.h"
18 
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <sys/epoll.h>
26 #include <sys/eventfd.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 #include <span>
31 
32 #include <chrono>
33 #include <iostream>
34 #include <set>
35 #include <vector>
36 
37 #define LOG_TAG "DNSResponder"
38 #include <android-base/logging.h>
39 #include <android-base/strings.h>
40 #include <netdutils/BackoffSequence.h>
41 #include <netdutils/InternetAddresses.h>
42 #include <netdutils/SocketOption.h>
43 
44 using android::base::unique_fd;
45 using android::netdutils::BackoffSequence;
46 using android::netdutils::enableSockopt;
47 using android::netdutils::ScopedAddrinfo;
48 using std::chrono::milliseconds;
49 
50 namespace test {
51 
// Returns the textual description of the current |errno| value.
std::string errno2str() {
    // The GNU flavour of strerror_r() is what gets called here: it hands back a |char*|
    // (not an |int| status), so the returned pointer is what must be converted to a string.
    // PLOG would be an alternative, but it requires migrating lots of ALOGx() calls to LOG(x).
    char scratch[512] = {0};
    const char* description = strerror_r(errno, scratch, sizeof(scratch));
    return std::string(description);
}
58 
// Renders |sa| (of length |sa_len|) as a numeric host string via getnameinfo().
// Returns an empty string when getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {0};
    const int status = getnameinfo(sa, sa_len, numeric_host, sizeof(numeric_host),
                                   /*serv=*/nullptr, /*servlen=*/0, NI_NUMERICHOST);
    return status == 0 ? std::string(numeric_host) : std::string();
}
65 
bytesToHexStr(std::span<const uint8_t> bytes)66 std::string bytesToHexStr(std::span<const uint8_t> bytes) {
67     static char const hex[16] = {'0', '1', '2', '3', '4', '5', '6', '7',
68                                  '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
69     std::string str;
70     str.reserve(bytes.size() * 2);
71     for (uint8_t ch : bytes) {
72         str.append({hex[(ch & 0xf0) >> 4], hex[ch & 0xf]});
73     }
74     return str;
75 }
76 
77 // Because The address might still being set up (b/186181084), This is a wrapper function
78 // that retries bind() if errno is EADDRNOTAVAIL
bindSocket(int socket,const sockaddr * address,socklen_t address_len)79 int bindSocket(int socket, const sockaddr* address, socklen_t address_len) {
80     // Set the wrapper to try bind() at most 6 times with backoff time
81     // (100 ms, 200 ms, ..., 1600 ms).
82     auto backoff = BackoffSequence<milliseconds>::Builder()
83                            .withInitialRetransmissionTime(milliseconds(100))
84                            .withMaximumRetransmissionCount(5)
85                            .build();
86 
87     while (true) {
88         int ret = bind(socket, address, address_len);
89         if (ret == 0 || errno != EADDRNOTAVAIL) {
90             return ret;
91         }
92 
93         if (!backoff.hasNextTimeout()) break;
94 
95         LOG(WARNING) << "Retry to bind " << addr2str(address, address_len);
96         std::this_thread::sleep_for(backoff.getNextTimeout());
97     }
98 
99     // Set errno before return since it might have been changed somewhere.
100     errno = EADDRNOTAVAIL;
101     return -1;
102 }
103 
104 /* DNS struct helpers */
105 
dnstype2str(unsigned dnstype)106 const char* dnstype2str(unsigned dnstype) {
107     static std::unordered_map<unsigned, const char*> kTypeStrs = {
108             {ns_type::ns_t_a, "A"},
109             {ns_type::ns_t_ns, "NS"},
110             {ns_type::ns_t_md, "MD"},
111             {ns_type::ns_t_mf, "MF"},
112             {ns_type::ns_t_cname, "CNAME"},
113             {ns_type::ns_t_soa, "SOA"},
114             {ns_type::ns_t_mb, "MB"},
115             {ns_type::ns_t_mb, "MG"},
116             {ns_type::ns_t_mr, "MR"},
117             {ns_type::ns_t_null, "NULL"},
118             {ns_type::ns_t_wks, "WKS"},
119             {ns_type::ns_t_ptr, "PTR"},
120             {ns_type::ns_t_hinfo, "HINFO"},
121             {ns_type::ns_t_minfo, "MINFO"},
122             {ns_type::ns_t_mx, "MX"},
123             {ns_type::ns_t_txt, "TXT"},
124             {ns_type::ns_t_rp, "RP"},
125             {ns_type::ns_t_afsdb, "AFSDB"},
126             {ns_type::ns_t_x25, "X25"},
127             {ns_type::ns_t_isdn, "ISDN"},
128             {ns_type::ns_t_rt, "RT"},
129             {ns_type::ns_t_nsap, "NSAP"},
130             {ns_type::ns_t_nsap_ptr, "NSAP-PTR"},
131             {ns_type::ns_t_sig, "SIG"},
132             {ns_type::ns_t_key, "KEY"},
133             {ns_type::ns_t_px, "PX"},
134             {ns_type::ns_t_gpos, "GPOS"},
135             {ns_type::ns_t_aaaa, "AAAA"},
136             {ns_type::ns_t_loc, "LOC"},
137             {ns_type::ns_t_nxt, "NXT"},
138             {ns_type::ns_t_eid, "EID"},
139             {ns_type::ns_t_nimloc, "NIMLOC"},
140             {ns_type::ns_t_srv, "SRV"},
141             {ns_type::ns_t_naptr, "NAPTR"},
142             {ns_type::ns_t_kx, "KX"},
143             {ns_type::ns_t_cert, "CERT"},
144             {ns_type::ns_t_a6, "A6"},
145             {ns_type::ns_t_dname, "DNAME"},
146             {ns_type::ns_t_sink, "SINK"},
147             {ns_type::ns_t_opt, "OPT"},
148             {ns_type::ns_t_apl, "APL"},
149             {ns_type::ns_t_tkey, "TKEY"},
150             {ns_type::ns_t_tsig, "TSIG"},
151             {ns_type::ns_t_ixfr, "IXFR"},
152             {ns_type::ns_t_axfr, "AXFR"},
153             {ns_type::ns_t_mailb, "MAILB"},
154             {ns_type::ns_t_maila, "MAILA"},
155             {ns_type::ns_t_any, "ANY"},
156             {ns_type::ns_t_zxfr, "ZXFR"},
157     };
158     auto it = kTypeStrs.find(dnstype);
159     static const char* kUnknownStr{"UNKNOWN"};
160     if (it == kTypeStrs.end()) return kUnknownStr;
161     return it->second;
162 }
163 
// Maps a DNS class code to a descriptive name.
// Returns "UNKNOWN" for any class code without an entry below.
const char* dnsclass2str(unsigned dnsclass) {
    switch (dnsclass) {
        case ns_class::ns_c_in:
            return "Internet";
        case 2:
            return "CSNet";
        case ns_class::ns_c_chaos:
            return "ChaosNet";
        case ns_class::ns_c_hs:
            return "Hesiod";
        case ns_class::ns_c_none:
            return "none";
        case ns_class::ns_c_any:
            return "any";
        default:
            return "UNKNOWN";
    }
}
174 
// Names the transport protocol: "TCP" for IPPROTO_TCP, "UDP" for IPPROTO_UDP and
// "UNKNOWN" for everything else.
const char* dnsproto2str(int protocol) {
    if (protocol == IPPROTO_TCP) return "TCP";
    if (protocol == IPPROTO_UDP) return "UDP";
    return "UNKNOWN";
}
185 
read(const char * buffer,const char * buffer_end)186 const char* DNSName::read(const char* buffer, const char* buffer_end) {
187     const char* cur = buffer;
188     bool last = false;
189     do {
190         cur = parseField(cur, buffer_end, &last);
191         if (cur == nullptr) {
192             LOG(ERROR) << "parsing failed at line " << __LINE__;
193             return nullptr;
194         }
195     } while (!last);
196     return cur;
197 }
198 
write(char * buffer,const char * buffer_end) const199 char* DNSName::write(char* buffer, const char* buffer_end) const {
200     char* buffer_cur = buffer;
201     for (size_t pos = 0; pos < name.size();) {
202         size_t dot_pos = name.find('.', pos);
203         if (dot_pos == std::string::npos) {
204             // Soundness check, should never happen unless parseField is broken.
205             LOG(ERROR) << "logic error: all names are expected to end with a '.'";
206             return nullptr;
207         }
208         const size_t len = dot_pos - pos;
209         if (len >= 256) {
210             LOG(ERROR) << "name component '" << name.substr(pos, dot_pos - pos) << "' is " << len
211                        << " long, but max is 255";
212             return nullptr;
213         }
214         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
215             LOG(ERROR) << "buffer overflow at line " << __LINE__;
216             return nullptr;
217         }
218         *buffer_cur++ = len;
219         buffer_cur = std::copy(std::next(name.begin(), pos), std::next(name.begin(), dot_pos),
220                                buffer_cur);
221         pos = dot_pos + 1;
222     }
223     // Write final zero.
224     *buffer_cur++ = 0;
225     return buffer_cur;
226 }
227 
parseField(const char * buffer,const char * buffer_end,bool * last)228 const char* DNSName::parseField(const char* buffer, const char* buffer_end, bool* last) {
229     if (buffer + sizeof(uint8_t) > buffer_end) {
230         LOG(ERROR) << "parsing failed at line " << __LINE__;
231         return nullptr;
232     }
233     unsigned field_type = *buffer >> 6;
234     unsigned ofs = *buffer & 0x3F;
235     const char* cur = buffer + sizeof(uint8_t);
236     if (field_type == 0) {
237         // length + name component
238         if (ofs == 0) {
239             *last = true;
240             return cur;
241         }
242         if (cur + ofs > buffer_end) {
243             LOG(ERROR) << "parsing failed at line " << __LINE__;
244             return nullptr;
245         }
246         name.append(cur, ofs);
247         name.push_back('.');
248         return cur + ofs;
249     } else if (field_type == 3) {
250         LOG(ERROR) << "name compression not implemented";
251         return nullptr;
252     }
253     LOG(ERROR) << "invalid name field type";
254     return nullptr;
255 }
256 
read(const char * buffer,const char * buffer_end)257 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
258     const char* cur = qname.read(buffer, buffer_end);
259     if (cur == nullptr) {
260         LOG(ERROR) << "parsing failed at line " << __LINE__;
261         return nullptr;
262     }
263     if (cur + 2 * sizeof(uint16_t) > buffer_end) {
264         LOG(ERROR) << "parsing failed at line " << __LINE__;
265         return nullptr;
266     }
267     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
268     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
269     return cur + 2 * sizeof(uint16_t);
270 }
271 
write(char * buffer,const char * buffer_end) const272 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
273     char* buffer_cur = qname.write(buffer, buffer_end);
274     if (buffer_cur == nullptr) return nullptr;
275     if (buffer_cur + 2 * sizeof(uint16_t) > buffer_end) {
276         LOG(ERROR) << "buffer overflow on line " << __LINE__;
277         return nullptr;
278     }
279     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
280     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = htons(qclass);
281     return buffer_cur + 2 * sizeof(uint16_t);
282 }
283 
toString() const284 std::string DNSQuestion::toString() const {
285     char buffer[16384];
286     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
287                        dnstype2str(qtype), dnsclass2str(qclass));
288     return std::string(buffer, len);
289 }
290 
read(const char * buffer,const char * buffer_end)291 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
292     const char* cur = name.read(buffer, buffer_end);
293     if (cur == nullptr) {
294         LOG(ERROR) << "parsing failed at line " << __LINE__;
295         return nullptr;
296     }
297     unsigned rdlen = 0;
298     cur = readIntFields(cur, buffer_end, &rdlen);
299     if (cur == nullptr) {
300         LOG(ERROR) << "parsing failed at line " << __LINE__;
301         return nullptr;
302     }
303     if (cur + rdlen > buffer_end) {
304         LOG(ERROR) << "parsing failed at line " << __LINE__;
305         return nullptr;
306     }
307     rdata.assign(cur, cur + rdlen);
308     return cur + rdlen;
309 }
310 
write(char * buffer,const char * buffer_end) const311 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
312     char* buffer_cur = name.write(buffer, buffer_end);
313     if (buffer_cur == nullptr) return nullptr;
314     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
315     if (buffer_cur == nullptr) return nullptr;
316     if (buffer_cur + rdata.size() > buffer_end) {
317         LOG(ERROR) << "buffer overflow on line " << __LINE__;
318         return nullptr;
319     }
320     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
321 }
322 
toString() const323 std::string DNSRecord::toString() const {
324     char buffer[16384];
325     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
326                        dnsclass2str(rclass));
327     return std::string(buffer, len);
328 }
329 
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)330 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, unsigned* rdlen) {
331     if (buffer + sizeof(IntFields) > buffer_end) {
332         LOG(ERROR) << "parsing failed at line " << __LINE__;
333         return nullptr;
334     }
335     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
336     rtype = ntohs(intfields.rtype);
337     rclass = ntohs(intfields.rclass);
338     ttl = ntohl(intfields.ttl);
339     *rdlen = ntohs(intfields.rdlen);
340     return buffer + sizeof(IntFields);
341 }
342 
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const343 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const {
344     if (buffer + sizeof(IntFields) > buffer_end) {
345         LOG(ERROR) << "buffer overflow on line " << __LINE__;
346         return nullptr;
347     }
348     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
349     intfields.rtype = htons(rtype);
350     intfields.rclass = htons(rclass);
351     intfields.ttl = htonl(ttl);
352     intfields.rdlen = htons(rdlen);
353     return buffer + sizeof(IntFields);
354 }
355 
read(const char * buffer,const char * buffer_end)356 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
357     unsigned qdcount;
358     unsigned ancount;
359     unsigned nscount;
360     unsigned arcount;
361     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, &nscount, &arcount);
362     if (cur == nullptr) {
363         LOG(ERROR) << "parsing failed at line " << __LINE__;
364         return nullptr;
365     }
366     if (qdcount) {
367         questions.resize(qdcount);
368         for (unsigned i = 0; i < qdcount; ++i) {
369             cur = questions[i].read(cur, buffer_end);
370             if (cur == nullptr) {
371                 LOG(ERROR) << "parsing failed at line " << __LINE__;
372                 return nullptr;
373             }
374         }
375     }
376     if (ancount) {
377         answers.resize(ancount);
378         for (unsigned i = 0; i < ancount; ++i) {
379             cur = answers[i].read(cur, buffer_end);
380             if (cur == nullptr) {
381                 LOG(ERROR) << "parsing failed at line " << __LINE__;
382                 return nullptr;
383             }
384         }
385     }
386     if (nscount) {
387         authorities.resize(nscount);
388         for (unsigned i = 0; i < nscount; ++i) {
389             cur = authorities[i].read(cur, buffer_end);
390             if (cur == nullptr) {
391                 LOG(ERROR) << "parsing failed at line " << __LINE__;
392                 return nullptr;
393             }
394         }
395     }
396     if (arcount) {
397         additionals.resize(arcount);
398         for (unsigned i = 0; i < arcount; ++i) {
399             cur = additionals[i].read(cur, buffer_end);
400             if (cur == nullptr) {
401                 LOG(ERROR) << "parsing failed at line " << __LINE__;
402                 return nullptr;
403             }
404         }
405     }
406     return cur;
407 }
408 
write(char * buffer,const char * buffer_end) const409 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
410     if (buffer + sizeof(Header) > buffer_end) {
411         LOG(ERROR) << "buffer overflow on line " << __LINE__;
412         return nullptr;
413     }
414     Header& header = *reinterpret_cast<Header*>(buffer);
415     // bytes 0-1
416     header.id = htons(id);
417     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
418     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
419     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
420     // Fake behavior: if the query set the "ad" bit, set it in the response too.
421     // In a real server, this should be set only if the data is authentic and the
422     // query contained an "ad" bit or DNSSEC extensions.
423     header.flags1 = (ad << 5) | rcode;
424     // rest of header
425     header.qdcount = htons(questions.size());
426     header.ancount = htons(answers.size());
427     header.nscount = htons(authorities.size());
428     header.arcount = htons(additionals.size());
429     char* buffer_cur = buffer + sizeof(Header);
430     for (const DNSQuestion& question : questions) {
431         buffer_cur = question.write(buffer_cur, buffer_end);
432         if (buffer_cur == nullptr) return nullptr;
433     }
434     for (const DNSRecord& answer : answers) {
435         buffer_cur = answer.write(buffer_cur, buffer_end);
436         if (buffer_cur == nullptr) return nullptr;
437     }
438     for (const DNSRecord& authority : authorities) {
439         buffer_cur = authority.write(buffer_cur, buffer_end);
440         if (buffer_cur == nullptr) return nullptr;
441     }
442     for (const DNSRecord& additional : additionals) {
443         buffer_cur = additional.write(buffer_cur, buffer_end);
444         if (buffer_cur == nullptr) return nullptr;
445     }
446     return buffer_cur;
447 }
448 
449 // TODO: convert all callers to this interface, then delete the old one.
write(std::vector<uint8_t> * out) const450 bool DNSHeader::write(std::vector<uint8_t>* out) const {
451     char buffer[16384];
452     char* end = this->write(buffer, buffer + sizeof buffer);
453     if (end == nullptr) return false;
454     out->insert(out->end(), buffer, end);
455     return true;
456 }
457 
toString() const458 std::string DNSHeader::toString() const {
459     // TODO
460     return std::string();
461 }
462 
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)463 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, unsigned* qdcount,
464                                   unsigned* ancount, unsigned* nscount, unsigned* arcount) {
465     if (buffer + sizeof(Header) > buffer_end) return nullptr;
466     const auto& header = *reinterpret_cast<const Header*>(buffer);
467     // bytes 0-1
468     id = ntohs(header.id);
469     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
470     qr = header.flags0 >> 7;
471     opcode = (header.flags0 >> 3) & 0x0F;
472     aa = (header.flags0 >> 2) & 1;
473     tr = (header.flags0 >> 1) & 1;
474     rd = header.flags0 & 1;
475     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
476     ra = header.flags1 >> 7;
477     ad = (header.flags1 >> 5) & 1;
478     rcode = header.flags1 & 0xF;
479     // rest of header
480     *qdcount = ntohs(header.qdcount);
481     *ancount = ntohs(header.ancount);
482     *nscount = ntohs(header.nscount);
483     *arcount = ntohs(header.arcount);
484     return buffer + sizeof(Header);
485 }
486 
487 /* DNS responder */
488 
DNSResponder(std::string listen_address,std::string listen_service,ns_rcode error_rcode,MappingType mapping_type)489 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
490                            ns_rcode error_rcode, MappingType mapping_type)
491     : listen_address_(std::move(listen_address)),
492       listen_service_(std::move(listen_service)),
493       error_rcode_(error_rcode),
494       mapping_type_(mapping_type) {}
495 
~DNSResponder()496 DNSResponder::~DNSResponder() {
497     stopServer();
498 }
499 
addMapping(const std::string & name,ns_type type,const std::string & addr)500 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
501     std::lock_guard lock(mappings_mutex_);
502     mappings_[{name, type}] = addr;
503 }
504 
addMappingDnsHeader(const std::string & name,ns_type type,const DNSHeader & header)505 void DNSResponder::addMappingDnsHeader(const std::string& name, ns_type type,
506                                        const DNSHeader& header) {
507     std::lock_guard lock(mappings_mutex_);
508     dnsheader_mappings_[{name, type}] = header;
509 }
510 
addMappingBinaryPacket(const std::vector<uint8_t> & query,const std::vector<uint8_t> & response)511 void DNSResponder::addMappingBinaryPacket(const std::vector<uint8_t>& query,
512                                           const std::vector<uint8_t>& response) {
513     std::lock_guard lock(mappings_mutex_);
514     packet_mappings_[query] = response;
515 }
516 
removeMapping(const std::string & name,ns_type type)517 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
518     std::lock_guard lock(mappings_mutex_);
519     if (!mappings_.erase({name, type})) {
520         LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
521                    << "), not present in registered mappings";
522     }
523 }
524 
removeMappingDnsHeader(const std::string & name,ns_type type)525 void DNSResponder::removeMappingDnsHeader(const std::string& name, ns_type type) {
526     std::lock_guard lock(mappings_mutex_);
527     if (!dnsheader_mappings_.erase({name, type})) {
528         LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
529                    << "), not present in registered DnsHeader mappings";
530     }
531 }
532 
removeMappingBinaryPacket(const std::vector<uint8_t> & query)533 void DNSResponder::removeMappingBinaryPacket(const std::vector<uint8_t>& query) {
534     std::lock_guard lock(mappings_mutex_);
535     if (!packet_mappings_.erase(query)) {
536         LOG(ERROR) << "Cannot remove mapping, not present in registered BinaryPacket mappings";
537         LOG(INFO) << "Hex dump:";
538         LOG(INFO) << bytesToHexStr(query);
539     }
540 }
541 
542 // Set response probability on all supported protocols.
setResponseProbability(double response_probability)543 void DNSResponder::setResponseProbability(double response_probability) {
544     setResponseProbability(response_probability, IPPROTO_TCP);
545     setResponseProbability(response_probability, IPPROTO_UDP);
546 }
547 
setResponseDelayMs(unsigned timeMs)548 void DNSResponder::setResponseDelayMs(unsigned timeMs) {
549     response_delayed_ms_ = timeMs;
550 }
551 
552 // Set response probability on specific protocol. It's caller's duty to ensure that the |protocol|
553 // can be supported by DNSResponder.
setResponseProbability(double response_probability,int protocol)554 void DNSResponder::setResponseProbability(double response_probability, int protocol) {
555     switch (protocol) {
556         case IPPROTO_TCP:
557             response_probability_tcp_ = response_probability;
558             break;
559         case IPPROTO_UDP:
560             response_probability_udp_ = response_probability;
561             break;
562         default:
563             LOG(FATAL) << "Unsupported protocol " << protocol;  // abort() by log level FATAL
564     }
565 }
566 
getResponseProbability(int protocol) const567 double DNSResponder::getResponseProbability(int protocol) const {
568     switch (protocol) {
569         case IPPROTO_TCP:
570             return response_probability_tcp_;
571         case IPPROTO_UDP:
572             return response_probability_udp_;
573         default:
574             LOG(FATAL) << "Unsupported protocol " << protocol;  // abort() by log level FATAL
575             // unreachable
576             return -1;
577     }
578 }
579 
setEdns(Edns edns)580 void DNSResponder::setEdns(Edns edns) {
581     edns_ = edns;
582 }
583 
setTtl(unsigned ttl)584 void DNSResponder::setTtl(unsigned ttl) {
585     answer_record_ttl_sec_ = ttl;
586 }
587 
running() const588 bool DNSResponder::running() const {
589     if (listen_service_ == kDefaultMdnsListenService)
590         return udp_socket_.ok();
591     else
592         return (udp_socket_.ok()) && (tcp_socket_.ok());
593 }
594 
startServer()595 bool DNSResponder::startServer() {
596     if (running()) {
597         LOG(ERROR) << "server already running";
598         return false;
599     }
600 
601     // Create UDP, TCP socket
602     if (udp_socket_ = createListeningSocket(SOCK_DGRAM); udp_socket_.get() < 0) {
603         PLOG(ERROR) << "failed to create UDP socket";
604         return false;
605     }
606 
607     if (listen_service_ != kDefaultMdnsListenService) {
608         if (tcp_socket_ = createListeningSocket(SOCK_STREAM); tcp_socket_.get() < 0) {
609             PLOG(ERROR) << "failed to create TCP socket";
610             return false;
611         }
612 
613         if (listen(tcp_socket_.get(), 1) < 0) {
614             PLOG(ERROR) << "failed to listen TCP socket";
615             return false;
616         }
617     }
618 
619     // Set up eventfd socket.
620     event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
621     if (event_fd_.get() == -1) {
622         PLOG(ERROR) << "failed to create eventfd";
623         return false;
624     }
625 
626     // Set up epoll socket.
627     epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
628     if (epoll_fd_.get() < 0) {
629         PLOG(ERROR) << "epoll_create1() failed on fd";
630         return false;
631     }
632 
633     LOG(INFO) << "adding UDP socket to epoll";
634     if (!addFd(udp_socket_.get(), EPOLLIN)) {
635         LOG(ERROR) << "failed to add the UDP socket to epoll";
636         return false;
637     }
638 
639     if (listen_service_ != kDefaultMdnsListenService) {
640         LOG(INFO) << "adding TCP socket to epoll";
641         if (!addFd(tcp_socket_.get(), EPOLLIN)) {
642             LOG(ERROR) << "failed to add the TCP socket to epoll";
643             return false;
644         }
645     }
646 
647     LOG(INFO) << "adding eventfd to epoll";
648     if (!addFd(event_fd_.get(), EPOLLIN)) {
649         LOG(ERROR) << "failed to add the eventfd to epoll";
650         return false;
651     }
652 
653     {
654         std::lock_guard lock(update_mutex_);
655         handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
656     }
657     LOG(INFO) << "server started successfully";
658     return true;
659 }
660 
stopServer()661 bool DNSResponder::stopServer() {
662     std::lock_guard lock(update_mutex_);
663     if (!running()) {
664         LOG(ERROR) << "server not running";
665         return false;
666     }
667     LOG(INFO) << "stopping server";
668     if (!sendToEventFd()) {
669         return false;
670     }
671     handler_thread_.join();
672     epoll_fd_.reset();
673     event_fd_.reset();
674     udp_socket_.reset();
675     tcp_socket_.reset();
676     LOG(INFO) << "server stopped successfully";
677     return true;
678 }
679 
queries() const680 std::vector<DNSResponder::QueryInfo> DNSResponder::queries() const {
681     std::lock_guard lock(queries_mutex_);
682     return queries_;
683 }
684 
dumpQueries() const685 std::string DNSResponder::dumpQueries() const {
686     std::lock_guard lock(queries_mutex_);
687     std::string out;
688 
689     for (const auto& q : queries_) {
690         out += "{\"" + q.name + "\", " + std::to_string(q.type) + "\", " +
691                dnsproto2str(q.protocol) + "} ";
692     }
693     return out;
694 }
695 
clearQueries()696 void DNSResponder::clearQueries() {
697     std::lock_guard lock(queries_mutex_);
698     queries_.clear();
699 }
700 
hasOptPseudoRR(DNSHeader * header) const701 bool DNSResponder::hasOptPseudoRR(DNSHeader* header) const {
702     if (header->additionals.empty()) return false;
703 
704     // OPT RR may be placed anywhere within the additional section. See RFC 6891 section 6.1.1.
705     auto found = std::find_if(header->additionals.begin(), header->additionals.end(),
706                               [](const auto& a) { return a.rtype == ns_type::ns_t_opt; });
707     return found != header->additionals.end();
708 }
709 
requestHandler()710 void DNSResponder::requestHandler() {
711     epoll_event evs[EPOLL_MAX_EVENTS];
712     while (true) {
713         int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, -1);
714         if (n <= 0) {
715             PLOG(ERROR) << "epoll_wait() failed, n=" << n;
716             return;
717         }
718 
719         for (int i = 0; i < n; i++) {
720             const int fd = evs[i].data.fd;
721             const uint32_t events = evs[i].events;
722             if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
723                 handleEventFd();
724                 return;
725             } else if (fd == udp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
726                 handleQuery(IPPROTO_UDP);
727             } else if (fd == tcp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
728                 handleQuery(IPPROTO_TCP);
729             } else {
730                 LOG(WARNING) << "unexpected epoll events " << events << " on fd " << fd;
731             }
732         }
733     }
734 }
735 
handleDNSRequest(const char * buffer,ssize_t len,int protocol,char * response,size_t * response_len) const736 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, int protocol, char* response,
737                                     size_t* response_len) const {
738     LOG(DEBUG) << "request: '" << bytesToHexStr({reinterpret_cast<const uint8_t*>(buffer), len})
739                << "', on " << dnsproto2str(protocol);
740     const char* buffer_end = buffer + len;
741     DNSHeader header;
742     const char* cur = header.read(buffer, buffer_end);
743     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
744     if (cur == nullptr) {
745         LOG(ERROR) << "failed to parse query";
746         return false;
747     }
748     if (header.qr) {
749         LOG(ERROR) << "response received instead of a query";
750         return false;
751     }
752     if (header.opcode != ns_opcode::ns_o_query) {
753         LOG(INFO) << "unsupported request opcode received";
754         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, response_len);
755     }
756     if (header.questions.empty()) {
757         LOG(INFO) << "no questions present";
758         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
759     }
760     if (!header.answers.empty()) {
761         LOG(INFO) << "already " << header.answers.size() << " answers present in query";
762         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
763     }
764 
765     if (edns_ == Edns::FORMERR_UNCOND) {
766         LOG(INFO) << "force to return RCODE FORMERR";
767         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
768     }
769 
770     if (!header.additionals.empty() && edns_ != Edns::ON) {
771         LOG(INFO) << "DNS request has an additional section (assumed EDNS). Simulating an ancient "
772                      "(pre-EDNS) server, and returning "
773                   << (edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
774         if (edns_ == Edns::FORMERR_ON_EDNS) {
775             return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
776         }
777         // No response.
778         return false;
779     }
780     {
781         std::lock_guard lock(queries_mutex_);
782         for (const DNSQuestion& question : header.questions) {
783             queries_.push_back({question.qname.name, ns_type(question.qtype), protocol});
784         }
785     }
786     // Ignore requests with the preset probability.
787     auto constexpr bound = std::numeric_limits<unsigned>::max();
788     if (arc4random_uniform(bound) > bound * getResponseProbability(protocol)) {
789         if (error_rcode_ < 0) {
790             LOG(ERROR) << "Returning no response";
791             return false;
792         } else {
793             LOG(INFO) << "returning RCODE " << static_cast<int>(error_rcode_)
794                       << " in accordance with probability distribution";
795             return makeErrorResponse(&header, error_rcode_, response, response_len);
796         }
797     }
798 
799     // Make the response. The query has been read into |header| which is used to build and return
800     // the response as well.
801     return makeResponse(&header, protocol, response, response_len);
802 }
803 
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const804 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
805                                     std::vector<DNSRecord>* answers) const {
806     std::lock_guard guard(mappings_mutex_);
807     std::string rname = question.qname.name;
808     std::vector<int> rtypes;
809 
810     if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa ||
811         question.qtype == ns_type::ns_t_ptr)
812         rtypes.push_back(ns_type::ns_t_cname);
813     rtypes.push_back(question.qtype);
814     for (int rtype : rtypes) {
815         std::set<std::string> cnames_Loop;
816         std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
817         while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
818             if (rtype == ns_type::ns_t_cname) {
819                 // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
820                 // As following, the query will stop on loop3 by detecting the same cname.
821                 // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
822                 // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
823                 // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
824                 //   is found in cnames_Loop already, break the query loop.)
825                 if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
826                 cnames_Loop.insert(it->first.name);
827             }
828             DNSRecord record{
829                     .name = {.name = it->first.name},
830                     .rtype = it->first.type,
831                     .rclass = ns_class::ns_c_in,
832                     .ttl = answer_record_ttl_sec_,  // seconds
833             };
834             if (!fillRdata(it->second, record)) return false;
835             answers->push_back(std::move(record));
836             if (rtype != ns_type::ns_t_cname) break;
837             rname = it->second;
838         }
839     }
840 
841     if (answers->size() == 0) {
842         // TODO(imaipi): handle correctly
843         LOG(INFO) << "no mapping found for " << question.qname.name << " "
844                   << dnstype2str(question.qtype) << ", lazily refusing to add an answer";
845     }
846 
847     return true;
848 }
849 
fillRdata(const std::string & rdatastr,DNSRecord & record)850 bool DNSResponder::fillRdata(const std::string& rdatastr, DNSRecord& record) {
851     if (record.rtype == ns_type::ns_t_a) {
852         record.rdata.resize(4);
853         if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
854             LOG(ERROR) << "inet_pton(AF_INET, " << rdatastr << ") failed";
855             return false;
856         }
857     } else if (record.rtype == ns_type::ns_t_aaaa) {
858         record.rdata.resize(16);
859         if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
860             LOG(ERROR) << "inet_pton(AF_INET6, " << rdatastr << ") failed";
861             return false;
862         }
863     } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname) ||
864                (record.rtype == ns_type::ns_t_ns)) {
865         constexpr char delimiter = '.';
866         std::string name = rdatastr;
867         std::vector<char> rdata;
868 
869         // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
870         // The "name" should be an absolute domain name which ends in a dot.
871         if (name.back() != delimiter) {
872             LOG(ERROR) << "invalid absolute domain name";
873             return false;
874         }
875         name.pop_back();  // remove the dot in tail
876         for (const std::string& label : android::base::Split(name, {delimiter})) {
877             // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
878             if (label.length() == 0 || label.length() > 63) {
879                 LOG(ERROR) << "invalid label length";
880                 return false;
881             }
882 
883             rdata.push_back(label.length());
884             rdata.insert(rdata.end(), label.begin(), label.end());
885         }
886         rdata.push_back(0);  // Length byte of zero terminates the label list
887 
888         // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
889         if (rdata.size() > 255) {
890             LOG(ERROR) << "invalid name length";
891             return false;
892         }
893         record.rdata = move(rdata);
894     } else {
895         LOG(ERROR) << "unhandled qtype " << dnstype2str(record.rtype);
896         return false;
897     }
898     return true;
899 }
900 
writePacket(const DNSHeader * header,char * response,size_t * response_len) const901 bool DNSResponder::writePacket(const DNSHeader* header, char* response,
902                                size_t* response_len) const {
903     char* response_cur = header->write(response, response + *response_len);
904     if (response_cur == nullptr) {
905         return false;
906     }
907     *response_len = response_cur - response;
908     return true;
909 }
910 
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const911 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
912                                      size_t* response_len) const {
913     header->answers.clear();
914     header->authorities.clear();
915     header->additionals.clear();
916     header->rcode = rcode;
917     header->qr = true;
918     return writePacket(header, response, response_len);
919 }
920 
makeTruncatedResponse(DNSHeader * header,char * response,size_t * response_len) const921 bool DNSResponder::makeTruncatedResponse(DNSHeader* header, char* response,
922                                          size_t* response_len) const {
923     // Build a minimal response for non-EDNS response over UDP. Truncate all stub RRs in answer,
924     // authority and additional section. EDNS response truncation has not supported here yet
925     // because the EDNS response must have an OPT record. See RFC 6891 section 7.
926     header->answers.clear();
927     header->authorities.clear();
928     header->additionals.clear();
929     header->qr = true;
930     header->tr = true;
931     return writePacket(header, response, response_len);
932 }
933 
makeResponse(DNSHeader * header,int protocol,char * response,size_t * response_len) const934 bool DNSResponder::makeResponse(DNSHeader* header, int protocol, char* response,
935                                 size_t* response_len) const {
936     char buffer[16384];
937     size_t buffer_len = sizeof(buffer);
938     bool ret;
939 
940     switch (mapping_type_) {
941         case MappingType::DNS_HEADER:
942             ret = makeResponseFromDnsHeader(header, buffer, &buffer_len);
943             break;
944         case MappingType::BINARY_PACKET:
945             ret = makeResponseFromBinaryPacket(header, buffer, &buffer_len);
946             break;
947         case MappingType::ADDRESS_OR_HOSTNAME:
948         default:
949             ret = makeResponseFromAddressOrHostname(header, buffer, &buffer_len);
950     }
951 
952     if (!ret) return false;
953 
954     // Return truncated response if the built non-EDNS response size which is larger than 512 bytes
955     // will be responded over UDP. The truncated response implementation here just simply set up
956     // the TC bit and truncate all stub RRs in answer, authority and additional section. It is
957     // because the resolver will retry DNS query over TCP and use the full TCP response. See also
958     // RFC 1035 section 4.2.1 for UDP response truncation and RFC 6891 section 4.3 for EDNS larger
959     // response size capability.
960     // TODO: Perhaps keep the stub RRs as possible.
961     // TODO: Perhaps truncate the EDNS based response over UDP. See also RFC 6891 section 4.3,
962     // section 6.2.5 and section 7.
963     if (protocol == IPPROTO_UDP && buffer_len > kMaximumUdpSize &&
964         !hasOptPseudoRR(header) /* non-EDNS */) {
965         LOG(INFO) << "Return truncated response because original response length " << buffer_len
966                   << " is larger than " << kMaximumUdpSize << " bytes.";
967         return makeTruncatedResponse(header, response, response_len);
968     }
969 
970     if (buffer_len > *response_len) {
971         LOG(ERROR) << "buffer overflow on line " << __LINE__;
972         return false;
973     }
974     memcpy(response, buffer, buffer_len);
975     *response_len = buffer_len;
976     return true;
977 }
978 
makeResponseFromAddressOrHostname(DNSHeader * header,char * response,size_t * response_len) const979 bool DNSResponder::makeResponseFromAddressOrHostname(DNSHeader* header, char* response,
980                                                      size_t* response_len) const {
981     for (const DNSQuestion& question : header->questions) {
982         if (question.qclass != ns_class::ns_c_in && question.qclass != ns_class::ns_c_any) {
983             LOG(INFO) << "unsupported question class " << question.qclass;
984             return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
985         }
986 
987         if (!addAnswerRecords(question, &header->answers)) {
988             return makeErrorResponse(header, ns_rcode::ns_r_servfail, response, response_len);
989         }
990     }
991     header->qr = true;
992     return writePacket(header, response, response_len);
993 }
994 
makeResponseFromDnsHeader(DNSHeader * header,char * response,size_t * response_len) const995 bool DNSResponder::makeResponseFromDnsHeader(DNSHeader* header, char* response,
996                                              size_t* response_len) const {
997     std::lock_guard guard(mappings_mutex_);
998 
999     // Support single question record only. It should be okay because res_mkquery() sets "qdcount"
1000     // as one for the operation QUERY and handleDNSRequest() checks ns_opcode::ns_o_query before
1001     // making a response. In other words, only need to handle the query which has single question
1002     // section. See also res_mkquery() in system/netd/resolv/res_mkquery.cpp.
1003     // TODO: Perhaps add support for multi-question records.
1004     const std::vector<DNSQuestion>& questions = header->questions;
1005     if (questions.size() != 1) {
1006         LOG(INFO) << "unsupported question count " << questions.size();
1007         return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1008     }
1009 
1010     if (questions[0].qclass != ns_class::ns_c_in && questions[0].qclass != ns_class::ns_c_any) {
1011         LOG(INFO) << "unsupported question class " << questions[0].qclass;
1012         return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1013     }
1014 
1015     const std::string name = questions[0].qname.name;
1016     const int qtype = questions[0].qtype;
1017     const auto it = dnsheader_mappings_.find(QueryKey(name, qtype));
1018     if (it != dnsheader_mappings_.end()) {
1019         // Store both "id" and "rd" which comes from query.
1020         const unsigned id = header->id;
1021         const bool rd = header->rd;
1022 
1023         // Build a response from the registered DNSHeader mapping.
1024         *header = it->second;
1025         // Assign both "ID" and "RD" fields from query to response. See RFC 1035 section 4.1.1.
1026         header->id = id;
1027         header->rd = rd;
1028     } else {
1029         // TODO: handle correctly. See also TODO in addAnswerRecords().
1030         LOG(INFO) << "no mapping found for " << name << " " << dnstype2str(qtype)
1031                   << ", couldn't build a response from DNSHeader mapping";
1032 
1033         // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1034         // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1035         // modified query back as a response.
1036         header->qr = true;
1037     }
1038     return writePacket(header, response, response_len);
1039 }
1040 
makeResponseFromBinaryPacket(DNSHeader * header,char * response,size_t * response_len) const1041 bool DNSResponder::makeResponseFromBinaryPacket(DNSHeader* header, char* response,
1042                                                 size_t* response_len) const {
1043     std::lock_guard guard(mappings_mutex_);
1044 
1045     // Build a search key of mapping from the query.
1046     // TODO: Perhaps pass the query packet buffer directly from the caller.
1047     std::vector<uint8_t> queryKey;
1048     if (!header->write(&queryKey)) return false;
1049     // Clear ID field (byte 0-1) because it is not required by the mapping key.
1050     queryKey[0] = 0;
1051     queryKey[1] = 0;
1052 
1053     const auto it = packet_mappings_.find(queryKey);
1054     if (it != packet_mappings_.end()) {
1055         if (it->second.size() > *response_len) {
1056             LOG(ERROR) << "buffer overflow on line " << __LINE__;
1057             return false;
1058         } else {
1059             std::copy(it->second.begin(), it->second.end(), response);
1060             // Leave the "RD" flag assignment for testing. The "RD" flag of the response keep
1061             // using the one from the raw packet mapping but the received query.
1062             // Assign "ID" field from query to response. See RFC 1035 section 4.1.1.
1063             reinterpret_cast<uint16_t*>(response)[0] = htons(header->id);  // bytes 0-1: id
1064             *response_len = it->second.size();
1065             return true;
1066         }
1067     } else {
1068         // TODO: handle correctly. See also TODO in addAnswerRecords().
1069         // TODO: Perhaps dump packet content to indicate which query failed.
1070         LOG(INFO) << "no mapping found, couldn't build a response from BinaryPacket mapping";
1071         // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1072         // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1073         // modified query back as a response.
1074         header->qr = true;
1075         return writePacket(header, response, response_len);
1076     }
1077 }
1078 
setDeferredResp(bool deferred_resp)1079 void DNSResponder::setDeferredResp(bool deferred_resp) {
1080     std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
1081     deferred_resp_ = deferred_resp;
1082     if (!deferred_resp_) {
1083         cv_for_deferred_resp_.notify_one();
1084     }
1085 }
1086 
addFd(int fd,uint32_t events)1087 bool DNSResponder::addFd(int fd, uint32_t events) {
1088     epoll_event ev;
1089     ev.events = events;
1090     ev.data.fd = fd;
1091     if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
1092         PLOG(ERROR) << "epoll_ctl() for socket " << fd << " failed";
1093         return false;
1094     }
1095     return true;
1096 }
1097 
handleQuery(int protocol)1098 void DNSResponder::handleQuery(int protocol) {
1099     char buffer[16384];
1100     sockaddr_storage sa;
1101     socklen_t sa_len = sizeof(sa);
1102     ssize_t len = 0;
1103     unique_fd tcpFd;
1104     switch (protocol) {
1105         case IPPROTO_UDP:
1106             do {
1107                 len = recvfrom(udp_socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa,
1108                                &sa_len);
1109             } while (len < 0 && (errno == EAGAIN || errno == EINTR));
1110             if (len <= 0) {
1111                 PLOG(ERROR) << "recvfrom() failed, len=" << len;
1112                 return;
1113             }
1114             break;
1115         case IPPROTO_TCP:
1116             tcpFd.reset(accept4(tcp_socket_.get(), reinterpret_cast<sockaddr*>(&sa), &sa_len,
1117                                 SOCK_CLOEXEC));
1118             if (tcpFd.get() < 0) {
1119                 PLOG(ERROR) << "failed to accept client socket";
1120                 return;
1121             }
1122             // Get the message length from two byte length field.
1123             // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1124             uint8_t queryMessageLengthField[2];
1125             if (read(tcpFd.get(), &queryMessageLengthField, 2) != 2) {
1126                 PLOG(ERROR) << "Not enough length field bytes";
1127                 return;
1128             }
1129 
1130             const uint16_t qlen = (queryMessageLengthField[0] << 8) | queryMessageLengthField[1];
1131             while (len < qlen) {
1132                 ssize_t ret = read(tcpFd.get(), buffer + len, qlen - len);
1133                 if (ret <= 0) {
1134                     PLOG(ERROR) << "Error while reading query";
1135                     return;
1136                 }
1137                 len += ret;
1138             }
1139             break;
1140     }
1141     LOG(DEBUG) << "read " << len << " bytes on " << dnsproto2str(protocol);
1142     std::lock_guard lock(cv_mutex_);
1143     char response[16384];
1144     size_t response_len = sizeof(response);
1145     // TODO: check whether sending malformed packets to DnsResponder
1146     if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
1147         std::this_thread::sleep_for(std::chrono::milliseconds(response_delayed_ms_));
1148         // place wait_for after handleDNSRequest() so we can check the number of queries in
1149         // test case before it got responded.
1150         std::unique_lock guard(cv_mutex_for_deferred_resp_);
1151         cv_for_deferred_resp_.wait(
1152                 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
1153         len = 0;
1154 
1155         switch (protocol) {
1156             case IPPROTO_UDP:
1157                 len = sendto(udp_socket_.get(), response, response_len, 0,
1158                              reinterpret_cast<const sockaddr*>(&sa), sa_len);
1159                 if (len < 0) {
1160                     PLOG(ERROR) << "Failed to send response";
1161                 }
1162                 break;
1163             case IPPROTO_TCP:
1164                 // Get the message length from two byte length field.
1165                 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1166                 uint8_t responseMessageLengthField[2];
1167                 responseMessageLengthField[0] = response_len >> 8;
1168                 responseMessageLengthField[1] = response_len;
1169                 if (write(tcpFd.get(), responseMessageLengthField, 2) != 2) {
1170                     PLOG(ERROR) << "Failed to write response length field";
1171                     break;
1172                 }
1173                 if (write(tcpFd.get(), response, response_len) !=
1174                     static_cast<ssize_t>(response_len)) {
1175                     PLOG(ERROR) << "Failed to write response";
1176                     break;
1177                 }
1178                 len = response_len;
1179                 break;
1180         }
1181         const std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
1182         if (len > 0) {
1183             LOG(DEBUG) << "sent " << len << " bytes to " << host_str;
1184         } else {
1185             const char* method_str = (protocol == IPPROTO_TCP) ? "write()" : "sendto()";
1186             LOG(ERROR) << method_str << " failed for " << host_str;
1187         }
1188         // Test that the response is actually a correct DNS message.
1189         // TODO: Perhaps make DNS message validation to support name compression. Or it throws
1190         // a warning for a valid DNS message with name compression while the binary packet mapping
1191         // is used.
1192         const char* response_end = response + len;
1193         DNSHeader header;
1194         const char* cur = header.read(response, response_end);
1195         if (cur == nullptr) LOG(WARNING) << "response is flawed";
1196     } else {
1197         LOG(WARNING) << "not responding";
1198     }
1199     cv.notify_one();
1200     return;
1201 }
1202 
sendToEventFd()1203 bool DNSResponder::sendToEventFd() {
1204     const uint64_t data = 1;
1205     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1206         PLOG(ERROR) << "failed to write eventfd, rt=" << rt;
1207         return false;
1208     }
1209     return true;
1210 }
1211 
handleEventFd()1212 void DNSResponder::handleEventFd() {
1213     int64_t data;
1214     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1215         PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
1216     }
1217 }
1218 
createListeningSocket(int socket_type)1219 unique_fd DNSResponder::createListeningSocket(int socket_type) {
1220     addrinfo ai_hints{
1221             .ai_flags = AI_PASSIVE,
1222             .ai_family = AF_UNSPEC,
1223             .ai_socktype = socket_type,
1224     };
1225     addrinfo* ai_res = nullptr;
1226     const int rv =
1227             getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &ai_hints, &ai_res);
1228     ScopedAddrinfo ai_res_cleanup(ai_res);
1229     if (rv) {
1230         LOG(ERROR) << "getaddrinfo(" << listen_address_ << ", " << listen_service_
1231                    << ") failed: " << gai_strerror(rv);
1232         return {};
1233     }
1234     for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
1235         unique_fd fd(socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK, ai->ai_protocol));
1236         if (fd.get() < 0) {
1237             PLOG(ERROR) << "ignore creating socket failed";
1238             continue;
1239         }
1240 
1241         enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
1242         const std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
1243         if ((listen_service_ == kDefaultMdnsListenService) && (socket_type == SOCK_DGRAM)) {
1244             const int mdns_port = 5353;
1245             const char mdns_multiaddrv4[] = "224.0.0.251";
1246             const char mdns_multiaddrv6[] = "ff02::fb";
1247             if (ai_res->ai_family == AF_INET) {
1248                 // Join the MDNS IPV4 multicast group
1249                 struct ip_mreq mreq;
1250                 mreq.imr_multiaddr.s_addr = inet_addr(mdns_multiaddrv4);
1251                 mreq.imr_interface.s_addr = inet_addr(host_str.c_str());
1252                 if (setsockopt(fd.get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq,
1253                                sizeof(struct ip_mreq)) == -1) {
1254                     LOG(ERROR) << "Error set setsockopt for IP_ADD_MEMBERSHIP ";
1255                     return {};
1256                 }
1257                 struct sockaddr_in addr = {.sin_family = AF_INET,
1258                                            .sin_port = htons(mdns_port),
1259                                            .sin_addr = {INADDR_ANY}};
1260                 if (bindSocket(fd.get(), (struct sockaddr*)&addr, sizeof(addr))) {
1261                     LOG(ERROR) << "Unable to bind socket to interface.";
1262                     return {};
1263                 }
1264             } else if (ai_res->ai_family == AF_INET6) {
1265                 // Join the MDNS IPV6 multicast group
1266                 struct ipv6_mreq mreqv6;
1267                 inet_pton(AF_INET6, mdns_multiaddrv6, &mreqv6.ipv6mr_multiaddr.s6_addr);
1268                 mreqv6.ipv6mr_interface = 0;
1269                 if (setsockopt(fd.get(), IPPROTO_IPV6, IPV6_JOIN_GROUP, &mreqv6, sizeof(mreqv6)) ==
1270                     -1) {
1271                     LOG(ERROR) << "Error set setsockopt for IPV6_JOIN_GROUP ";
1272                     return {};
1273                 }
1274                 struct sockaddr_in6 addr = {
1275                         .sin6_family = AF_INET6,
1276                         .sin6_port = htons(mdns_port),
1277                         .sin6_addr = IN6ADDR_ANY_INIT,
1278                 };
1279                 if (bindSocket(fd.get(), (struct sockaddr*)&addr, sizeof(addr))) {
1280                     LOG(ERROR) << "Unable to bind socket to interface.MDNS IPV6";
1281                     return {};
1282                 }
1283             }
1284             return fd;
1285         } else {
1286             const char* socket_str = (socket_type == SOCK_STREAM) ? "TCP" : "UDP";
1287             if (bindSocket(fd.get(), ai->ai_addr, ai->ai_addrlen)) {
1288                 PLOG(ERROR) << "failed to bind " << socket_str << " " << host_str << ":"
1289                             << listen_service_;
1290                 continue;
1291             }
1292             LOG(INFO) << "bound to " << socket_str << " " << host_str << ":" << listen_service_;
1293             return fd;
1294         }
1295     }
1296     return {};
1297 }
1298 
1299 }  // namespace test
1300