• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_responder.h"
18 
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <sys/epoll.h>
26 #include <sys/eventfd.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 #include <span>
31 
32 #include <chrono>
33 #include <iostream>
34 #include <set>
35 #include <vector>
36 
37 #define LOG_TAG "DNSResponder"
38 #include <android-base/logging.h>
39 #include <android-base/strings.h>
40 #include <netdutils/BackoffSequence.h>
41 #include <netdutils/InternetAddresses.h>
42 #include <netdutils/SocketOption.h>
43 
44 using android::base::ErrnoError;
45 using android::base::Result;
46 using android::base::unique_fd;
47 using android::netdutils::BackoffSequence;
48 using android::netdutils::enableSockopt;
49 using android::netdutils::ScopedAddrinfo;
50 using std::chrono::milliseconds;
51 
52 namespace test {
53 
// Returns the human-readable description of the current |errno| value.
std::string errno2str() {
    // Scratch storage for strerror_r(). The GNU variant (used here, it returns |char*| rather
    // than |int|) may return either this buffer or a pointer to a static string.
    // PLOG is an option though it requires lots of changes from ALOGx() to LOG(x).
    char scratch[512] = {};
    const char* description = strerror_r(errno, scratch, sizeof(scratch));
    return std::string(description);
}
60 
// Formats |sa| (of length |sa_len|) as a numeric host string, e.g. "127.0.0.1" or "::1".
// Returns an empty string if getnameinfo() cannot format the address.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int status =
            getnameinfo(sa, sa_len, numeric_host, sizeof(numeric_host), nullptr, 0, NI_NUMERICHOST);
    return status == 0 ? std::string(numeric_host) : std::string();
}
67 
bytesToHexStr(std::span<const uint8_t> bytes)68 std::string bytesToHexStr(std::span<const uint8_t> bytes) {
69     static char const hex[16] = {'0', '1', '2', '3', '4', '5', '6', '7',
70                                  '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
71     std::string str;
72     str.reserve(bytes.size() * 2);
73     for (uint8_t ch : bytes) {
74         str.append({hex[(ch & 0xf0) >> 4], hex[ch & 0xf]});
75     }
76     return str;
77 }
78 
79 // Because The address might still being set up (b/186181084), This is a wrapper function
80 // that retries bind() if errno is EADDRNOTAVAIL
bindSocket(int socket,const sockaddr * address,socklen_t address_len)81 Result<void> bindSocket(int socket, const sockaddr* address, socklen_t address_len) {
82     // Set the wrapper to try bind() at most 6 times with backoff time
83     // (100 ms, 200 ms, ..., 1600 ms).
84     auto backoff = BackoffSequence<milliseconds>::Builder()
85                            .withInitialRetransmissionTime(milliseconds(100))
86                            .withMaximumRetransmissionCount(5)
87                            .build();
88 
89     while (true) {
90         if (0 == bind(socket, address, address_len)) return {};
91         if (errno != EADDRNOTAVAIL) return ErrnoError();
92         if (!backoff.hasNextTimeout()) return ErrnoError();
93 
94         LOG(WARNING) << "Retry to bind " << addr2str(address, address_len);
95         std::this_thread::sleep_for(backoff.getNextTimeout());
96     }
97 }
98 
99 /* DNS struct helpers */
100 
dnstype2str(unsigned dnstype)101 const char* dnstype2str(unsigned dnstype) {
102     static std::unordered_map<unsigned, const char*> kTypeStrs = {
103             {ns_type::ns_t_a, "A"},
104             {ns_type::ns_t_ns, "NS"},
105             {ns_type::ns_t_md, "MD"},
106             {ns_type::ns_t_mf, "MF"},
107             {ns_type::ns_t_cname, "CNAME"},
108             {ns_type::ns_t_soa, "SOA"},
109             {ns_type::ns_t_mb, "MB"},
110             {ns_type::ns_t_mb, "MG"},
111             {ns_type::ns_t_mr, "MR"},
112             {ns_type::ns_t_null, "NULL"},
113             {ns_type::ns_t_wks, "WKS"},
114             {ns_type::ns_t_ptr, "PTR"},
115             {ns_type::ns_t_hinfo, "HINFO"},
116             {ns_type::ns_t_minfo, "MINFO"},
117             {ns_type::ns_t_mx, "MX"},
118             {ns_type::ns_t_txt, "TXT"},
119             {ns_type::ns_t_rp, "RP"},
120             {ns_type::ns_t_afsdb, "AFSDB"},
121             {ns_type::ns_t_x25, "X25"},
122             {ns_type::ns_t_isdn, "ISDN"},
123             {ns_type::ns_t_rt, "RT"},
124             {ns_type::ns_t_nsap, "NSAP"},
125             {ns_type::ns_t_nsap_ptr, "NSAP-PTR"},
126             {ns_type::ns_t_sig, "SIG"},
127             {ns_type::ns_t_key, "KEY"},
128             {ns_type::ns_t_px, "PX"},
129             {ns_type::ns_t_gpos, "GPOS"},
130             {ns_type::ns_t_aaaa, "AAAA"},
131             {ns_type::ns_t_loc, "LOC"},
132             {ns_type::ns_t_nxt, "NXT"},
133             {ns_type::ns_t_eid, "EID"},
134             {ns_type::ns_t_nimloc, "NIMLOC"},
135             {ns_type::ns_t_srv, "SRV"},
136             {ns_type::ns_t_naptr, "NAPTR"},
137             {ns_type::ns_t_kx, "KX"},
138             {ns_type::ns_t_cert, "CERT"},
139             {ns_type::ns_t_a6, "A6"},
140             {ns_type::ns_t_dname, "DNAME"},
141             {ns_type::ns_t_sink, "SINK"},
142             {ns_type::ns_t_opt, "OPT"},
143             {ns_type::ns_t_apl, "APL"},
144             {ns_type::ns_t_tkey, "TKEY"},
145             {ns_type::ns_t_tsig, "TSIG"},
146             {ns_type::ns_t_ixfr, "IXFR"},
147             {ns_type::ns_t_axfr, "AXFR"},
148             {ns_type::ns_t_mailb, "MAILB"},
149             {ns_type::ns_t_maila, "MAILA"},
150             {ns_type::ns_t_any, "ANY"},
151             {ns_type::ns_t_zxfr, "ZXFR"},
152     };
153     auto it = kTypeStrs.find(dnstype);
154     static const char* kUnknownStr{"UNKNOWN"};
155     if (it == kTypeStrs.end()) return kUnknownStr;
156     return it->second;
157 }
158 
// Returns a human-readable name for DNS class |dnsclass|, or "UNKNOWN" if it is not recognized.
const char* dnsclass2str(unsigned dnsclass) {
    switch (dnsclass) {
        case ns_class::ns_c_in:
            return "Internet";
        case 2:  // CSNET: obsolete class, no named ns_class enumerator is used for it here.
            return "CSNet";
        case ns_class::ns_c_chaos:
            return "ChaosNet";
        case ns_class::ns_c_hs:
            return "Hesiod";
        case ns_class::ns_c_none:
            return "none";
        case ns_class::ns_c_any:
            return "any";
        default:
            return "UNKNOWN";
    }
}
169 
// Returns "TCP" or "UDP" for the corresponding IPPROTO_* value, "UNKNOWN" for anything else.
const char* dnsproto2str(int protocol) {
    if (protocol == IPPROTO_TCP) return "TCP";
    if (protocol == IPPROTO_UDP) return "UDP";
    return "UNKNOWN";
}
180 
read(const char * buffer,const char * buffer_end)181 const char* DNSName::read(const char* buffer, const char* buffer_end) {
182     const char* cur = buffer;
183     bool last = false;
184     do {
185         cur = parseField(cur, buffer_end, &last);
186         if (cur == nullptr) {
187             LOG(ERROR) << "parsing failed at line " << __LINE__;
188             return nullptr;
189         }
190     } while (!last);
191     return cur;
192 }
193 
write(char * buffer,const char * buffer_end) const194 char* DNSName::write(char* buffer, const char* buffer_end) const {
195     char* buffer_cur = buffer;
196     for (size_t pos = 0; pos < name.size();) {
197         size_t dot_pos = name.find('.', pos);
198         if (dot_pos == std::string::npos) {
199             // Soundness check, should never happen unless parseField is broken.
200             LOG(ERROR) << "logic error: all names are expected to end with a '.'";
201             return nullptr;
202         }
203         const size_t len = dot_pos - pos;
204         if (len >= 256) {
205             LOG(ERROR) << "name component '" << name.substr(pos, dot_pos - pos) << "' is " << len
206                        << " long, but max is 255";
207             return nullptr;
208         }
209         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
210             LOG(ERROR) << "buffer overflow at line " << __LINE__;
211             return nullptr;
212         }
213         *buffer_cur++ = len;
214         buffer_cur = std::copy(std::next(name.begin(), pos), std::next(name.begin(), dot_pos),
215                                buffer_cur);
216         pos = dot_pos + 1;
217     }
218     // Write final zero.
219     *buffer_cur++ = 0;
220     return buffer_cur;
221 }
222 
parseField(const char * buffer,const char * buffer_end,bool * last)223 const char* DNSName::parseField(const char* buffer, const char* buffer_end, bool* last) {
224     if (buffer + sizeof(uint8_t) > buffer_end) {
225         LOG(ERROR) << "parsing failed at line " << __LINE__;
226         return nullptr;
227     }
228     unsigned field_type = *buffer >> 6;
229     unsigned ofs = *buffer & 0x3F;
230     const char* cur = buffer + sizeof(uint8_t);
231     if (field_type == 0) {
232         // length + name component
233         if (ofs == 0) {
234             *last = true;
235             return cur;
236         }
237         if (cur + ofs > buffer_end) {
238             LOG(ERROR) << "parsing failed at line " << __LINE__;
239             return nullptr;
240         }
241         name.append(cur, ofs);
242         name.push_back('.');
243         return cur + ofs;
244     } else if (field_type == 3) {
245         LOG(ERROR) << "name compression not implemented";
246         return nullptr;
247     }
248     LOG(ERROR) << "invalid name field type";
249     return nullptr;
250 }
251 
read(const char * buffer,const char * buffer_end)252 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
253     const char* cur = qname.read(buffer, buffer_end);
254     if (cur == nullptr) {
255         LOG(ERROR) << "parsing failed at line " << __LINE__;
256         return nullptr;
257     }
258     if (cur + 2 * sizeof(uint16_t) > buffer_end) {
259         LOG(ERROR) << "parsing failed at line " << __LINE__;
260         return nullptr;
261     }
262     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
263     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
264     return cur + 2 * sizeof(uint16_t);
265 }
266 
write(char * buffer,const char * buffer_end) const267 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
268     char* buffer_cur = qname.write(buffer, buffer_end);
269     if (buffer_cur == nullptr) return nullptr;
270     if (buffer_cur + 2 * sizeof(uint16_t) > buffer_end) {
271         LOG(ERROR) << "buffer overflow on line " << __LINE__;
272         return nullptr;
273     }
274     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
275     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = htons(qclass);
276     return buffer_cur + 2 * sizeof(uint16_t);
277 }
278 
toString() const279 std::string DNSQuestion::toString() const {
280     char buffer[16384];
281     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
282                        dnstype2str(qtype), dnsclass2str(qclass));
283     return std::string(buffer, len);
284 }
285 
read(const char * buffer,const char * buffer_end)286 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
287     const char* cur = name.read(buffer, buffer_end);
288     if (cur == nullptr) {
289         LOG(ERROR) << "parsing failed at line " << __LINE__;
290         return nullptr;
291     }
292     unsigned rdlen = 0;
293     cur = readIntFields(cur, buffer_end, &rdlen);
294     if (cur == nullptr) {
295         LOG(ERROR) << "parsing failed at line " << __LINE__;
296         return nullptr;
297     }
298     if (cur + rdlen > buffer_end) {
299         LOG(ERROR) << "parsing failed at line " << __LINE__;
300         return nullptr;
301     }
302     rdata.assign(cur, cur + rdlen);
303     return cur + rdlen;
304 }
305 
write(char * buffer,const char * buffer_end) const306 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
307     char* buffer_cur = name.write(buffer, buffer_end);
308     if (buffer_cur == nullptr) return nullptr;
309     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
310     if (buffer_cur == nullptr) return nullptr;
311     if (buffer_cur + rdata.size() > buffer_end) {
312         LOG(ERROR) << "buffer overflow on line " << __LINE__;
313         return nullptr;
314     }
315     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
316 }
317 
toString() const318 std::string DNSRecord::toString() const {
319     char buffer[16384];
320     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
321                        dnsclass2str(rclass));
322     return std::string(buffer, len);
323 }
324 
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)325 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, unsigned* rdlen) {
326     if (buffer + sizeof(IntFields) > buffer_end) {
327         LOG(ERROR) << "parsing failed at line " << __LINE__;
328         return nullptr;
329     }
330     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
331     rtype = ntohs(intfields.rtype);
332     rclass = ntohs(intfields.rclass);
333     ttl = ntohl(intfields.ttl);
334     *rdlen = ntohs(intfields.rdlen);
335     return buffer + sizeof(IntFields);
336 }
337 
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const338 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const {
339     if (buffer + sizeof(IntFields) > buffer_end) {
340         LOG(ERROR) << "buffer overflow on line " << __LINE__;
341         return nullptr;
342     }
343     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
344     intfields.rtype = htons(rtype);
345     intfields.rclass = htons(rclass);
346     intfields.ttl = htonl(ttl);
347     intfields.rdlen = htons(rdlen);
348     return buffer + sizeof(IntFields);
349 }
350 
read(const char * buffer,const char * buffer_end)351 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
352     unsigned qdcount;
353     unsigned ancount;
354     unsigned nscount;
355     unsigned arcount;
356     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, &nscount, &arcount);
357     if (cur == nullptr) {
358         LOG(ERROR) << "parsing failed at line " << __LINE__;
359         return nullptr;
360     }
361     if (qdcount) {
362         questions.resize(qdcount);
363         for (unsigned i = 0; i < qdcount; ++i) {
364             cur = questions[i].read(cur, buffer_end);
365             if (cur == nullptr) {
366                 LOG(ERROR) << "parsing failed at line " << __LINE__;
367                 return nullptr;
368             }
369         }
370     }
371     if (ancount) {
372         answers.resize(ancount);
373         for (unsigned i = 0; i < ancount; ++i) {
374             cur = answers[i].read(cur, buffer_end);
375             if (cur == nullptr) {
376                 LOG(ERROR) << "parsing failed at line " << __LINE__;
377                 return nullptr;
378             }
379         }
380     }
381     if (nscount) {
382         authorities.resize(nscount);
383         for (unsigned i = 0; i < nscount; ++i) {
384             cur = authorities[i].read(cur, buffer_end);
385             if (cur == nullptr) {
386                 LOG(ERROR) << "parsing failed at line " << __LINE__;
387                 return nullptr;
388             }
389         }
390     }
391     if (arcount) {
392         additionals.resize(arcount);
393         for (unsigned i = 0; i < arcount; ++i) {
394             cur = additionals[i].read(cur, buffer_end);
395             if (cur == nullptr) {
396                 LOG(ERROR) << "parsing failed at line " << __LINE__;
397                 return nullptr;
398             }
399         }
400     }
401     return cur;
402 }
403 
write(char * buffer,const char * buffer_end) const404 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
405     if (buffer + sizeof(Header) > buffer_end) {
406         LOG(ERROR) << "buffer overflow on line " << __LINE__;
407         return nullptr;
408     }
409     Header& header = *reinterpret_cast<Header*>(buffer);
410     // bytes 0-1
411     header.id = htons(id);
412     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
413     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
414     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
415     // Fake behavior: if the query set the "ad" bit, set it in the response too.
416     // In a real server, this should be set only if the data is authentic and the
417     // query contained an "ad" bit or DNSSEC extensions.
418     header.flags1 = (ad << 5) | rcode;
419     // rest of header
420     header.qdcount = htons(questions.size());
421     header.ancount = htons(answers.size());
422     header.nscount = htons(authorities.size());
423     header.arcount = htons(additionals.size());
424     char* buffer_cur = buffer + sizeof(Header);
425     for (const DNSQuestion& question : questions) {
426         buffer_cur = question.write(buffer_cur, buffer_end);
427         if (buffer_cur == nullptr) return nullptr;
428     }
429     for (const DNSRecord& answer : answers) {
430         buffer_cur = answer.write(buffer_cur, buffer_end);
431         if (buffer_cur == nullptr) return nullptr;
432     }
433     for (const DNSRecord& authority : authorities) {
434         buffer_cur = authority.write(buffer_cur, buffer_end);
435         if (buffer_cur == nullptr) return nullptr;
436     }
437     for (const DNSRecord& additional : additionals) {
438         buffer_cur = additional.write(buffer_cur, buffer_end);
439         if (buffer_cur == nullptr) return nullptr;
440     }
441     return buffer_cur;
442 }
443 
444 // TODO: convert all callers to this interface, then delete the old one.
write(std::vector<uint8_t> * out) const445 bool DNSHeader::write(std::vector<uint8_t>* out) const {
446     char buffer[16384];
447     char* end = this->write(buffer, buffer + sizeof buffer);
448     if (end == nullptr) return false;
449     out->insert(out->end(), buffer, end);
450     return true;
451 }
452 
toString() const453 std::string DNSHeader::toString() const {
454     // TODO
455     return std::string();
456 }
457 
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)458 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, unsigned* qdcount,
459                                   unsigned* ancount, unsigned* nscount, unsigned* arcount) {
460     if (buffer + sizeof(Header) > buffer_end) return nullptr;
461     const auto& header = *reinterpret_cast<const Header*>(buffer);
462     // bytes 0-1
463     id = ntohs(header.id);
464     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
465     qr = header.flags0 >> 7;
466     opcode = (header.flags0 >> 3) & 0x0F;
467     aa = (header.flags0 >> 2) & 1;
468     tr = (header.flags0 >> 1) & 1;
469     rd = header.flags0 & 1;
470     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
471     ra = header.flags1 >> 7;
472     ad = (header.flags1 >> 5) & 1;
473     rcode = header.flags1 & 0xF;
474     // rest of header
475     *qdcount = ntohs(header.qdcount);
476     *ancount = ntohs(header.ancount);
477     *nscount = ntohs(header.nscount);
478     *arcount = ntohs(header.arcount);
479     return buffer + sizeof(Header);
480 }
481 
482 /* DNS responder */
483 
DNSResponder(std::string listen_address,std::string listen_service,ns_rcode error_rcode,MappingType mapping_type)484 DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
485                            ns_rcode error_rcode, MappingType mapping_type)
486     : listen_address_(std::move(listen_address)),
487       listen_service_(std::move(listen_service)),
488       error_rcode_(error_rcode),
489       mapping_type_(mapping_type) {}
490 
~DNSResponder()491 DNSResponder::~DNSResponder() {
492     stopServer();
493 }
494 
addMapping(const std::string & name,ns_type type,const std::string & addr)495 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
496     std::lock_guard lock(mappings_mutex_);
497     mappings_[{name, type}] = addr;
498 }
499 
addMappingDnsHeader(const std::string & name,ns_type type,const DNSHeader & header)500 void DNSResponder::addMappingDnsHeader(const std::string& name, ns_type type,
501                                        const DNSHeader& header) {
502     std::lock_guard lock(mappings_mutex_);
503     dnsheader_mappings_[{name, type}] = header;
504 }
505 
addMappingBinaryPacket(const std::vector<uint8_t> & query,const std::vector<uint8_t> & response)506 void DNSResponder::addMappingBinaryPacket(const std::vector<uint8_t>& query,
507                                           const std::vector<uint8_t>& response) {
508     std::lock_guard lock(mappings_mutex_);
509     packet_mappings_[query] = response;
510 }
511 
removeMapping(const std::string & name,ns_type type)512 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
513     std::lock_guard lock(mappings_mutex_);
514     if (!mappings_.erase({name, type})) {
515         LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
516                    << "), not present in registered mappings";
517     }
518 }
519 
removeMappingDnsHeader(const std::string & name,ns_type type)520 void DNSResponder::removeMappingDnsHeader(const std::string& name, ns_type type) {
521     std::lock_guard lock(mappings_mutex_);
522     if (!dnsheader_mappings_.erase({name, type})) {
523         LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
524                    << "), not present in registered DnsHeader mappings";
525     }
526 }
527 
removeMappingBinaryPacket(const std::vector<uint8_t> & query)528 void DNSResponder::removeMappingBinaryPacket(const std::vector<uint8_t>& query) {
529     std::lock_guard lock(mappings_mutex_);
530     if (!packet_mappings_.erase(query)) {
531         LOG(ERROR) << "Cannot remove mapping, not present in registered BinaryPacket mappings";
532         LOG(INFO) << "Hex dump:";
533         LOG(INFO) << bytesToHexStr(query);
534     }
535 }
536 
537 // Set response probability on all supported protocols.
setResponseProbability(double response_probability)538 void DNSResponder::setResponseProbability(double response_probability) {
539     setResponseProbability(response_probability, IPPROTO_TCP);
540     setResponseProbability(response_probability, IPPROTO_UDP);
541 }
542 
setResponseDelayMs(unsigned timeMs)543 void DNSResponder::setResponseDelayMs(unsigned timeMs) {
544     response_delayed_ms_ = timeMs;
545 }
546 
547 // Set response probability on specific protocol. It's caller's duty to ensure that the |protocol|
548 // can be supported by DNSResponder.
setResponseProbability(double response_probability,int protocol)549 void DNSResponder::setResponseProbability(double response_probability, int protocol) {
550     switch (protocol) {
551         case IPPROTO_TCP:
552             response_probability_tcp_ = response_probability;
553             break;
554         case IPPROTO_UDP:
555             response_probability_udp_ = response_probability;
556             break;
557         default:
558             LOG(FATAL) << "Unsupported protocol " << protocol;  // abort() by log level FATAL
559     }
560 }
561 
getResponseProbability(int protocol) const562 double DNSResponder::getResponseProbability(int protocol) const {
563     switch (protocol) {
564         case IPPROTO_TCP:
565             return response_probability_tcp_;
566         case IPPROTO_UDP:
567             return response_probability_udp_;
568         default:
569             LOG(FATAL) << "Unsupported protocol " << protocol;  // abort() by log level FATAL
570             // unreachable
571             return -1;
572     }
573 }
574 
setEdns(Edns edns)575 void DNSResponder::setEdns(Edns edns) {
576     edns_ = edns;
577 }
578 
setTtl(unsigned ttl)579 void DNSResponder::setTtl(unsigned ttl) {
580     answer_record_ttl_sec_ = ttl;
581 }
582 
running() const583 bool DNSResponder::running() const {
584     if (listen_service_ == kDefaultMdnsListenService)
585         return udp_socket_.ok();
586     else
587         return (udp_socket_.ok()) && (tcp_socket_.ok());
588 }
589 
startServer()590 bool DNSResponder::startServer() {
591     if (running()) {
592         LOG(ERROR) << "server already running";
593         return false;
594     }
595 
596     // Create UDP, TCP socket
597     if (udp_socket_ = createListeningSocket(SOCK_DGRAM); udp_socket_.get() < 0) {
598         PLOG(ERROR) << "failed to create UDP socket";
599         return false;
600     }
601 
602     if (listen_service_ != kDefaultMdnsListenService) {
603         if (tcp_socket_ = createListeningSocket(SOCK_STREAM); tcp_socket_.get() < 0) {
604             PLOG(ERROR) << "failed to create TCP socket";
605             return false;
606         }
607 
608         if (listen(tcp_socket_.get(), 1) < 0) {
609             PLOG(ERROR) << "failed to listen TCP socket";
610             return false;
611         }
612     }
613 
614     // Set up eventfd socket.
615     event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
616     if (event_fd_.get() == -1) {
617         PLOG(ERROR) << "failed to create eventfd";
618         return false;
619     }
620 
621     // Set up epoll socket.
622     epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
623     if (epoll_fd_.get() < 0) {
624         PLOG(ERROR) << "epoll_create1() failed on fd";
625         return false;
626     }
627 
628     LOG(INFO) << "adding UDP socket to epoll";
629     if (!addFd(udp_socket_.get(), EPOLLIN)) {
630         LOG(ERROR) << "failed to add the UDP socket to epoll";
631         return false;
632     }
633 
634     if (listen_service_ != kDefaultMdnsListenService) {
635         LOG(INFO) << "adding TCP socket to epoll";
636         if (!addFd(tcp_socket_.get(), EPOLLIN)) {
637             LOG(ERROR) << "failed to add the TCP socket to epoll";
638             return false;
639         }
640     }
641 
642     LOG(INFO) << "adding eventfd to epoll";
643     if (!addFd(event_fd_.get(), EPOLLIN)) {
644         LOG(ERROR) << "failed to add the eventfd to epoll";
645         return false;
646     }
647 
648     {
649         std::lock_guard lock(update_mutex_);
650         handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
651     }
652     LOG(INFO) << "server started successfully";
653     return true;
654 }
655 
stopServer()656 bool DNSResponder::stopServer() {
657     std::lock_guard lock(update_mutex_);
658     if (!running()) {
659         LOG(ERROR) << "server not running";
660         return false;
661     }
662     LOG(INFO) << "stopping server";
663     if (!sendToEventFd()) {
664         return false;
665     }
666     handler_thread_.join();
667     epoll_fd_.reset();
668     event_fd_.reset();
669     udp_socket_.reset();
670     tcp_socket_.reset();
671     LOG(INFO) << "server stopped successfully";
672     return true;
673 }
674 
queries() const675 std::vector<DNSResponder::QueryInfo> DNSResponder::queries() const {
676     std::lock_guard lock(queries_mutex_);
677     return queries_;
678 }
679 
dumpQueries() const680 std::string DNSResponder::dumpQueries() const {
681     std::lock_guard lock(queries_mutex_);
682     std::string out;
683 
684     for (const auto& q : queries_) {
685         out += "{\"" + q.name + "\", " + std::to_string(q.type) + "\", " +
686                dnsproto2str(q.protocol) + "} ";
687     }
688     return out;
689 }
690 
clearQueries()691 void DNSResponder::clearQueries() {
692     std::lock_guard lock(queries_mutex_);
693     queries_.clear();
694 }
695 
hasOptPseudoRR(DNSHeader * header) const696 bool DNSResponder::hasOptPseudoRR(DNSHeader* header) const {
697     if (header->additionals.empty()) return false;
698 
699     // OPT RR may be placed anywhere within the additional section. See RFC 6891 section 6.1.1.
700     auto found = std::find_if(header->additionals.begin(), header->additionals.end(),
701                               [](const auto& a) { return a.rtype == ns_type::ns_t_opt; });
702     return found != header->additionals.end();
703 }
704 
requestHandler()705 void DNSResponder::requestHandler() {
706     epoll_event evs[EPOLL_MAX_EVENTS];
707     while (true) {
708         int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, -1);
709         if (n <= 0) {
710             PLOG(ERROR) << "epoll_wait() failed, n=" << n;
711             return;
712         }
713 
714         for (int i = 0; i < n; i++) {
715             const int fd = evs[i].data.fd;
716             const uint32_t events = evs[i].events;
717             if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
718                 handleEventFd();
719                 return;
720             } else if (fd == udp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
721                 handleQuery(IPPROTO_UDP);
722             } else if (fd == tcp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
723                 handleQuery(IPPROTO_TCP);
724             } else {
725                 LOG(WARNING) << "unexpected epoll events " << events << " on fd " << fd;
726             }
727         }
728     }
729 }
730 
handleDNSRequest(const char * buffer,ssize_t len,int protocol,char * response,size_t * response_len) const731 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, int protocol, char* response,
732                                     size_t* response_len) const {
733     LOG(DEBUG) << "request: '" << bytesToHexStr({reinterpret_cast<const uint8_t*>(buffer), len})
734                << "', on " << dnsproto2str(protocol);
735     const char* buffer_end = buffer + len;
736     DNSHeader header;
737     const char* cur = header.read(buffer, buffer_end);
738     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
739     if (cur == nullptr) {
740         LOG(ERROR) << "failed to parse query";
741         return false;
742     }
743     if (header.qr) {
744         LOG(ERROR) << "response received instead of a query";
745         return false;
746     }
747     if (header.opcode != ns_opcode::ns_o_query) {
748         LOG(INFO) << "unsupported request opcode received";
749         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, response_len);
750     }
751     if (header.questions.empty()) {
752         LOG(INFO) << "no questions present";
753         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
754     }
755     if (!header.answers.empty()) {
756         LOG(INFO) << "already " << header.answers.size() << " answers present in query";
757         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
758     }
759 
760     if (edns_ == Edns::FORMERR_UNCOND) {
761         LOG(INFO) << "force to return RCODE FORMERR";
762         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
763     }
764 
765     if (!header.additionals.empty() && edns_ != Edns::ON) {
766         LOG(INFO) << "DNS request has an additional section (assumed EDNS). Simulating an ancient "
767                      "(pre-EDNS) server, and returning "
768                   << (edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
769         if (edns_ == Edns::FORMERR_ON_EDNS) {
770             return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
771         }
772         // No response.
773         return false;
774     }
775     {
776         std::lock_guard lock(queries_mutex_);
777         for (const DNSQuestion& question : header.questions) {
778             queries_.push_back({question.qname.name, ns_type(question.qtype), protocol});
779         }
780     }
781     // Ignore requests with the preset probability.
782     auto constexpr bound = std::numeric_limits<unsigned>::max();
783     if (arc4random_uniform(bound) > bound * getResponseProbability(protocol)) {
784         if (error_rcode_ < 0) {
785             LOG(ERROR) << "Returning no response";
786             return false;
787         } else {
788             LOG(INFO) << "returning RCODE " << static_cast<int>(error_rcode_)
789                       << " in accordance with probability distribution";
790             return makeErrorResponse(&header, error_rcode_, response, response_len);
791         }
792     }
793 
794     // Make the response. The query has been read into |header| which is used to build and return
795     // the response as well.
796     return makeResponse(&header, protocol, response, response_len);
797 }
798 
addAnswerRecords(const DNSQuestion & question,std::vector<DNSRecord> * answers) const799 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
800                                     std::vector<DNSRecord>* answers) const {
801     std::lock_guard guard(mappings_mutex_);
802     std::string rname = question.qname.name;
803     std::vector<int> rtypes;
804 
805     if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa ||
806         question.qtype == ns_type::ns_t_ptr)
807         rtypes.push_back(ns_type::ns_t_cname);
808     rtypes.push_back(question.qtype);
809     for (int rtype : rtypes) {
810         std::set<std::string> cnames_Loop;
811         std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
812         while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
813             if (rtype == ns_type::ns_t_cname) {
814                 // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
815                 // As following, the query will stop on loop3 by detecting the same cname.
816                 // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
817                 // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
818                 // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
819                 //   is found in cnames_Loop already, break the query loop.)
820                 if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
821                 cnames_Loop.insert(it->first.name);
822             }
823             DNSRecord record{
824                     .name = {.name = it->first.name},
825                     .rtype = it->first.type,
826                     .rclass = ns_class::ns_c_in,
827                     .ttl = answer_record_ttl_sec_,  // seconds
828             };
829             if (!fillRdata(it->second, record)) return false;
830             answers->push_back(std::move(record));
831             if (rtype != ns_type::ns_t_cname) break;
832             rname = it->second;
833         }
834     }
835 
836     if (answers->size() == 0) {
837         // TODO(imaipi): handle correctly
838         LOG(INFO) << "no mapping found for " << question.qname.name << " "
839                   << dnstype2str(question.qtype) << ", lazily refusing to add an answer";
840     }
841 
842     return true;
843 }
844 
fillRdata(const std::string & rdatastr,DNSRecord & record)845 bool DNSResponder::fillRdata(const std::string& rdatastr, DNSRecord& record) {
846     if (record.rtype == ns_type::ns_t_a) {
847         record.rdata.resize(4);
848         if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
849             LOG(ERROR) << "inet_pton(AF_INET, " << rdatastr << ") failed";
850             return false;
851         }
852     } else if (record.rtype == ns_type::ns_t_aaaa) {
853         record.rdata.resize(16);
854         if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
855             LOG(ERROR) << "inet_pton(AF_INET6, " << rdatastr << ") failed";
856             return false;
857         }
858     } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname) ||
859                (record.rtype == ns_type::ns_t_ns)) {
860         constexpr char delimiter = '.';
861         std::string name = rdatastr;
862         std::vector<char> rdata;
863 
864         // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
865         // The "name" should be an absolute domain name which ends in a dot.
866         if (name.back() != delimiter) {
867             LOG(ERROR) << "invalid absolute domain name";
868             return false;
869         }
870         name.pop_back();  // remove the dot in tail
871         for (const std::string& label : android::base::Split(name, {delimiter})) {
872             // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
873             if (label.length() == 0 || label.length() > 63) {
874                 LOG(ERROR) << "invalid label length";
875                 return false;
876             }
877 
878             rdata.push_back(label.length());
879             rdata.insert(rdata.end(), label.begin(), label.end());
880         }
881         rdata.push_back(0);  // Length byte of zero terminates the label list
882 
883         // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
884         if (rdata.size() > 255) {
885             LOG(ERROR) << "invalid name length";
886             return false;
887         }
888         record.rdata = std::move(rdata);
889     } else {
890         LOG(ERROR) << "unhandled qtype " << dnstype2str(record.rtype);
891         return false;
892     }
893     return true;
894 }
895 
writePacket(const DNSHeader * header,char * response,size_t * response_len) const896 bool DNSResponder::writePacket(const DNSHeader* header, char* response,
897                                size_t* response_len) const {
898     char* response_cur = header->write(response, response + *response_len);
899     if (response_cur == nullptr) {
900         return false;
901     }
902     *response_len = response_cur - response;
903     return true;
904 }
905 
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const906 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
907                                      size_t* response_len) const {
908     header->answers.clear();
909     header->authorities.clear();
910     header->additionals.clear();
911     header->rcode = rcode;
912     header->qr = true;
913     return writePacket(header, response, response_len);
914 }
915 
makeTruncatedResponse(DNSHeader * header,char * response,size_t * response_len) const916 bool DNSResponder::makeTruncatedResponse(DNSHeader* header, char* response,
917                                          size_t* response_len) const {
918     // Build a minimal response for non-EDNS response over UDP. Truncate all stub RRs in answer,
919     // authority and additional section. EDNS response truncation has not supported here yet
920     // because the EDNS response must have an OPT record. See RFC 6891 section 7.
921     header->answers.clear();
922     header->authorities.clear();
923     header->additionals.clear();
924     header->qr = true;
925     header->tr = true;
926     return writePacket(header, response, response_len);
927 }
928 
makeResponse(DNSHeader * header,int protocol,char * response,size_t * response_len) const929 bool DNSResponder::makeResponse(DNSHeader* header, int protocol, char* response,
930                                 size_t* response_len) const {
931     char buffer[16384];
932     size_t buffer_len = sizeof(buffer);
933     bool ret;
934 
935     switch (mapping_type_) {
936         case MappingType::DNS_HEADER:
937             ret = makeResponseFromDnsHeader(header, buffer, &buffer_len);
938             break;
939         case MappingType::BINARY_PACKET:
940             ret = makeResponseFromBinaryPacket(header, buffer, &buffer_len);
941             break;
942         case MappingType::ADDRESS_OR_HOSTNAME:
943         default:
944             ret = makeResponseFromAddressOrHostname(header, buffer, &buffer_len);
945     }
946 
947     if (!ret) return false;
948 
949     // Return truncated response if the built non-EDNS response size which is larger than 512 bytes
950     // will be responded over UDP. The truncated response implementation here just simply set up
951     // the TC bit and truncate all stub RRs in answer, authority and additional section. It is
952     // because the resolver will retry DNS query over TCP and use the full TCP response. See also
953     // RFC 1035 section 4.2.1 for UDP response truncation and RFC 6891 section 4.3 for EDNS larger
954     // response size capability.
955     // TODO: Perhaps keep the stub RRs as possible.
956     // TODO: Perhaps truncate the EDNS based response over UDP. See also RFC 6891 section 4.3,
957     // section 6.2.5 and section 7.
958     if (protocol == IPPROTO_UDP && buffer_len > kMaximumUdpSize &&
959         !hasOptPseudoRR(header) /* non-EDNS */) {
960         LOG(INFO) << "Return truncated response because original response length " << buffer_len
961                   << " is larger than " << kMaximumUdpSize << " bytes.";
962         return makeTruncatedResponse(header, response, response_len);
963     }
964 
965     if (buffer_len > *response_len) {
966         LOG(ERROR) << "buffer overflow on line " << __LINE__;
967         return false;
968     }
969     memcpy(response, buffer, buffer_len);
970     *response_len = buffer_len;
971     return true;
972 }
973 
makeResponseFromAddressOrHostname(DNSHeader * header,char * response,size_t * response_len) const974 bool DNSResponder::makeResponseFromAddressOrHostname(DNSHeader* header, char* response,
975                                                      size_t* response_len) const {
976     for (const DNSQuestion& question : header->questions) {
977         if (question.qclass != ns_class::ns_c_in && question.qclass != ns_class::ns_c_any) {
978             LOG(INFO) << "unsupported question class " << question.qclass;
979             return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
980         }
981 
982         if (!addAnswerRecords(question, &header->answers)) {
983             return makeErrorResponse(header, ns_rcode::ns_r_servfail, response, response_len);
984         }
985     }
986     header->qr = true;
987     return writePacket(header, response, response_len);
988 }
989 
makeResponseFromDnsHeader(DNSHeader * header,char * response,size_t * response_len) const990 bool DNSResponder::makeResponseFromDnsHeader(DNSHeader* header, char* response,
991                                              size_t* response_len) const {
992     std::lock_guard guard(mappings_mutex_);
993 
994     // Support single question record only. It should be okay because res_mkquery() sets "qdcount"
995     // as one for the operation QUERY and handleDNSRequest() checks ns_opcode::ns_o_query before
996     // making a response. In other words, only need to handle the query which has single question
997     // section. See also res_mkquery() in system/netd/resolv/res_mkquery.cpp.
998     // TODO: Perhaps add support for multi-question records.
999     const std::vector<DNSQuestion>& questions = header->questions;
1000     if (questions.size() != 1) {
1001         LOG(INFO) << "unsupported question count " << questions.size();
1002         return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1003     }
1004 
1005     if (questions[0].qclass != ns_class::ns_c_in && questions[0].qclass != ns_class::ns_c_any) {
1006         LOG(INFO) << "unsupported question class " << questions[0].qclass;
1007         return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1008     }
1009 
1010     const std::string name = questions[0].qname.name;
1011     const int qtype = questions[0].qtype;
1012     const auto it = dnsheader_mappings_.find(QueryKey(name, qtype));
1013     if (it != dnsheader_mappings_.end()) {
1014         // Store both "id" and "rd" which comes from query.
1015         const unsigned id = header->id;
1016         const bool rd = header->rd;
1017 
1018         // Build a response from the registered DNSHeader mapping.
1019         *header = it->second;
1020         // Assign both "ID" and "RD" fields from query to response. See RFC 1035 section 4.1.1.
1021         header->id = id;
1022         header->rd = rd;
1023     } else {
1024         // TODO: handle correctly. See also TODO in addAnswerRecords().
1025         LOG(INFO) << "no mapping found for " << name << " " << dnstype2str(qtype)
1026                   << ", couldn't build a response from DNSHeader mapping";
1027 
1028         // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1029         // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1030         // modified query back as a response.
1031         header->qr = true;
1032     }
1033     return writePacket(header, response, response_len);
1034 }
1035 
makeResponseFromBinaryPacket(DNSHeader * header,char * response,size_t * response_len) const1036 bool DNSResponder::makeResponseFromBinaryPacket(DNSHeader* header, char* response,
1037                                                 size_t* response_len) const {
1038     std::lock_guard guard(mappings_mutex_);
1039 
1040     // Build a search key of mapping from the query.
1041     // TODO: Perhaps pass the query packet buffer directly from the caller.
1042     std::vector<uint8_t> queryKey;
1043     if (!header->write(&queryKey)) return false;
1044     // Clear ID field (byte 0-1) because it is not required by the mapping key.
1045     queryKey[0] = 0;
1046     queryKey[1] = 0;
1047 
1048     const auto it = packet_mappings_.find(queryKey);
1049     if (it != packet_mappings_.end()) {
1050         if (it->second.size() > *response_len) {
1051             LOG(ERROR) << "buffer overflow on line " << __LINE__;
1052             return false;
1053         } else {
1054             std::copy(it->second.begin(), it->second.end(), response);
1055             // Leave the "RD" flag assignment for testing. The "RD" flag of the response keep
1056             // using the one from the raw packet mapping but the received query.
1057             // Assign "ID" field from query to response. See RFC 1035 section 4.1.1.
1058             reinterpret_cast<uint16_t*>(response)[0] = htons(header->id);  // bytes 0-1: id
1059             *response_len = it->second.size();
1060             return true;
1061         }
1062     } else {
1063         // TODO: handle correctly. See also TODO in addAnswerRecords().
1064         // TODO: Perhaps dump packet content to indicate which query failed.
1065         LOG(INFO) << "no mapping found, couldn't build a response from BinaryPacket mapping";
1066         // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1067         // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1068         // modified query back as a response.
1069         header->qr = true;
1070         return writePacket(header, response, response_len);
1071     }
1072 }
1073 
setDeferredResp(bool deferred_resp)1074 void DNSResponder::setDeferredResp(bool deferred_resp) {
1075     std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
1076     deferred_resp_ = deferred_resp;
1077     if (!deferred_resp_) {
1078         cv_for_deferred_resp_.notify_one();
1079     }
1080 }
1081 
addFd(int fd,uint32_t events)1082 bool DNSResponder::addFd(int fd, uint32_t events) {
1083     epoll_event ev;
1084     ev.events = events;
1085     ev.data.fd = fd;
1086     if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
1087         PLOG(ERROR) << "epoll_ctl() for socket " << fd << " failed";
1088         return false;
1089     }
1090     return true;
1091 }
1092 
handleQuery(int protocol)1093 void DNSResponder::handleQuery(int protocol) {
1094     char buffer[16384];
1095     sockaddr_storage sa;
1096     socklen_t sa_len = sizeof(sa);
1097     ssize_t len = 0;
1098     unique_fd tcpFd;
1099     switch (protocol) {
1100         case IPPROTO_UDP:
1101             do {
1102                 len = recvfrom(udp_socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa,
1103                                &sa_len);
1104             } while (len < 0 && (errno == EAGAIN || errno == EINTR));
1105             if (len <= 0) {
1106                 PLOG(ERROR) << "recvfrom() failed, len=" << len;
1107                 return;
1108             }
1109             break;
1110         case IPPROTO_TCP:
1111             tcpFd.reset(accept4(tcp_socket_.get(), reinterpret_cast<sockaddr*>(&sa), &sa_len,
1112                                 SOCK_CLOEXEC));
1113             if (tcpFd.get() < 0) {
1114                 PLOG(ERROR) << "failed to accept client socket";
1115                 return;
1116             }
1117             // Get the message length from two byte length field.
1118             // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1119             uint8_t queryMessageLengthField[2];
1120             if (read(tcpFd.get(), &queryMessageLengthField, 2) != 2) {
1121                 PLOG(ERROR) << "Not enough length field bytes";
1122                 return;
1123             }
1124 
1125             const uint16_t qlen = (queryMessageLengthField[0] << 8) | queryMessageLengthField[1];
1126             while (len < qlen) {
1127                 ssize_t ret = read(tcpFd.get(), buffer + len, qlen - len);
1128                 if (ret <= 0) {
1129                     PLOG(ERROR) << "Error while reading query";
1130                     return;
1131                 }
1132                 len += ret;
1133             }
1134             break;
1135     }
1136     LOG(DEBUG) << "read " << len << " bytes on " << dnsproto2str(protocol);
1137     std::lock_guard lock(cv_mutex_);
1138     char response[16384];
1139     size_t response_len = sizeof(response);
1140     // TODO: check whether sending malformed packets to DnsResponder
1141     if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
1142         std::this_thread::sleep_for(std::chrono::milliseconds(response_delayed_ms_));
1143         // place wait_for after handleDNSRequest() so we can check the number of queries in
1144         // test case before it got responded.
1145         std::unique_lock guard(cv_mutex_for_deferred_resp_);
1146         cv_for_deferred_resp_.wait(
1147                 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
1148         len = 0;
1149 
1150         switch (protocol) {
1151             case IPPROTO_UDP:
1152                 len = sendto(udp_socket_.get(), response, response_len, 0,
1153                              reinterpret_cast<const sockaddr*>(&sa), sa_len);
1154                 if (len < 0) {
1155                     PLOG(ERROR) << "Failed to send response";
1156                 }
1157                 break;
1158             case IPPROTO_TCP:
1159                 // Get the message length from two byte length field.
1160                 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1161                 uint8_t responseMessageLengthField[2];
1162                 responseMessageLengthField[0] = response_len >> 8;
1163                 responseMessageLengthField[1] = response_len;
1164                 if (write(tcpFd.get(), responseMessageLengthField, 2) != 2) {
1165                     PLOG(ERROR) << "Failed to write response length field";
1166                     break;
1167                 }
1168                 if (write(tcpFd.get(), response, response_len) !=
1169                     static_cast<ssize_t>(response_len)) {
1170                     PLOG(ERROR) << "Failed to write response";
1171                     break;
1172                 }
1173                 len = response_len;
1174                 break;
1175         }
1176         const std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
1177         if (len > 0) {
1178             LOG(DEBUG) << "sent " << len << " bytes to " << host_str;
1179         } else {
1180             const char* method_str = (protocol == IPPROTO_TCP) ? "write()" : "sendto()";
1181             LOG(ERROR) << method_str << " failed for " << host_str;
1182         }
1183         // Test that the response is actually a correct DNS message.
1184         // TODO: Perhaps make DNS message validation to support name compression. Or it throws
1185         // a warning for a valid DNS message with name compression while the binary packet mapping
1186         // is used.
1187         const char* response_end = response + len;
1188         DNSHeader header;
1189         const char* cur = header.read(response, response_end);
1190         if (cur == nullptr) LOG(WARNING) << "response is flawed";
1191     } else {
1192         LOG(WARNING) << "not responding";
1193     }
1194     cv.notify_one();
1195     return;
1196 }
1197 
sendToEventFd()1198 bool DNSResponder::sendToEventFd() {
1199     const uint64_t data = 1;
1200     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1201         PLOG(ERROR) << "failed to write eventfd, rt=" << rt;
1202         return false;
1203     }
1204     return true;
1205 }
1206 
handleEventFd()1207 void DNSResponder::handleEventFd() {
1208     int64_t data;
1209     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1210         PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
1211     }
1212 }
1213 
createListeningSocket(int socket_type)1214 unique_fd DNSResponder::createListeningSocket(int socket_type) {
1215     addrinfo ai_hints{
1216             .ai_flags = AI_PASSIVE,
1217             .ai_family = AF_UNSPEC,
1218             .ai_socktype = socket_type,
1219     };
1220     addrinfo* ai_res = nullptr;
1221     const int rv =
1222             getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &ai_hints, &ai_res);
1223     ScopedAddrinfo ai_res_cleanup(ai_res);
1224     if (rv) {
1225         LOG(ERROR) << "getaddrinfo(" << listen_address_ << ", " << listen_service_
1226                    << ") failed: " << gai_strerror(rv);
1227         return {};
1228     }
1229     for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
1230         unique_fd fd(socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK, ai->ai_protocol));
1231         if (fd.get() < 0) {
1232             PLOG(ERROR) << "ignore creating socket failed";
1233             continue;
1234         }
1235 
1236         enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
1237         const std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
1238         if ((listen_service_ == kDefaultMdnsListenService) && (socket_type == SOCK_DGRAM)) {
1239             const int mdns_port = 5353;
1240             const char mdns_multiaddrv4[] = "224.0.0.251";
1241             const char mdns_multiaddrv6[] = "ff02::fb";
1242             if (ai_res->ai_family == AF_INET) {
1243                 // Join the MDNS IPV4 multicast group
1244                 struct ip_mreq mreq;
1245                 mreq.imr_multiaddr.s_addr = inet_addr(mdns_multiaddrv4);
1246                 mreq.imr_interface.s_addr = inet_addr(host_str.c_str());
1247                 if (setsockopt(fd.get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq,
1248                                sizeof(struct ip_mreq)) == -1) {
1249                     LOG(ERROR) << "Error set setsockopt for IP_ADD_MEMBERSHIP ";
1250                     return {};
1251                 }
1252                 struct sockaddr_in addr = {.sin_family = AF_INET,
1253                                            .sin_port = htons(mdns_port),
1254                                            .sin_addr = {INADDR_ANY}};
1255                 if (auto result = bindSocket(fd.get(), (struct sockaddr*)&addr, sizeof(addr));
1256                     !result.ok()) {
1257                     LOG(ERROR) << "failed to bind. MDNS IPv4: " << result.error().message();
1258                     return {};
1259                 }
1260             } else if (ai_res->ai_family == AF_INET6) {
1261                 // Join the MDNS IPV6 multicast group
1262                 struct ipv6_mreq mreqv6;
1263                 inet_pton(AF_INET6, mdns_multiaddrv6, &mreqv6.ipv6mr_multiaddr.s6_addr);
1264                 mreqv6.ipv6mr_interface = 0;
1265                 if (setsockopt(fd.get(), IPPROTO_IPV6, IPV6_JOIN_GROUP, &mreqv6, sizeof(mreqv6)) ==
1266                     -1) {
1267                     LOG(ERROR) << "Error set setsockopt for IPV6_JOIN_GROUP ";
1268                     return {};
1269                 }
1270                 struct sockaddr_in6 addr = {
1271                         .sin6_family = AF_INET6,
1272                         .sin6_port = htons(mdns_port),
1273                         .sin6_addr = IN6ADDR_ANY_INIT,
1274                 };
1275                 if (auto result = bindSocket(fd.get(), (struct sockaddr*)&addr, sizeof(addr));
1276                     !result.ok()) {
1277                     LOG(ERROR) << "failed to bind. MDNS IPV6: " << result.error().message();
1278                     return {};
1279                 }
1280             }
1281             return fd;
1282         } else {
1283             const char* socket_str = (socket_type == SOCK_STREAM) ? "TCP" : "UDP";
1284             if (auto result = bindSocket(fd.get(), ai->ai_addr, ai->ai_addrlen); !result.ok()) {
1285                 LOG(ERROR) << "failed to bind " << socket_str << " " << host_str << ":"
1286                            << listen_service_ << " " << result.error().message();
1287                 continue;
1288             }
1289             LOG(INFO) << "bound to " << socket_str << " " << host_str << ":" << listen_service_;
1290             return fd;
1291         }
1292     }
1293     return {};
1294 }
1295 
1296 }  // namespace test
1297