• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/dns_response.h"
6 
7 #include <algorithm>
8 #include <cstdint>
9 #include <limits>
10 #include <numeric>
11 #include <utility>
12 #include <vector>
13 
14 #include "base/big_endian.h"
15 #include "base/containers/span.h"
16 #include "base/logging.h"
17 #include "base/numerics/safe_conversions.h"
18 #include "base/strings/string_util.h"
19 #include "base/sys_byteorder.h"
20 #include "net/base/io_buffer.h"
21 #include "net/base/net_errors.h"
22 #include "net/dns/dns_names_util.h"
23 #include "net/dns/dns_query.h"
24 #include "net/dns/dns_response_result_extractor.h"
25 #include "net/dns/dns_util.h"
26 #include "net/dns/public/dns_protocol.h"
27 #include "net/dns/record_rdata.h"
28 #include "third_party/abseil-cpp/absl/types/optional.h"
29 
30 namespace net {
31 
32 namespace {
33 
34 const size_t kHeaderSize = sizeof(dns_protocol::Header);
35 
36 const uint8_t kRcodeMask = 0xf;
37 
38 }  // namespace
39 
40 DnsResourceRecord::DnsResourceRecord() = default;
41 
DnsResourceRecord(const DnsResourceRecord & other)42 DnsResourceRecord::DnsResourceRecord(const DnsResourceRecord& other)
43     : name(other.name),
44       type(other.type),
45       klass(other.klass),
46       ttl(other.ttl),
47       owned_rdata(other.owned_rdata) {
48   if (!owned_rdata.empty())
49     rdata = owned_rdata;
50   else
51     rdata = other.rdata;
52 }
53 
DnsResourceRecord(DnsResourceRecord && other)54 DnsResourceRecord::DnsResourceRecord(DnsResourceRecord&& other)
55     : name(std::move(other.name)),
56       type(other.type),
57       klass(other.klass),
58       ttl(other.ttl),
59       owned_rdata(std::move(other.owned_rdata)) {
60   if (!owned_rdata.empty())
61     rdata = owned_rdata;
62   else
63     rdata = other.rdata;
64 }
65 
66 DnsResourceRecord::~DnsResourceRecord() = default;
67 
operator =(const DnsResourceRecord & other)68 DnsResourceRecord& DnsResourceRecord::operator=(
69     const DnsResourceRecord& other) {
70   name = other.name;
71   type = other.type;
72   klass = other.klass;
73   ttl = other.ttl;
74   owned_rdata = other.owned_rdata;
75 
76   if (!owned_rdata.empty())
77     rdata = owned_rdata;
78   else
79     rdata = other.rdata;
80 
81   return *this;
82 }
83 
operator =(DnsResourceRecord && other)84 DnsResourceRecord& DnsResourceRecord::operator=(DnsResourceRecord&& other) {
85   name = std::move(other.name);
86   type = other.type;
87   klass = other.klass;
88   ttl = other.ttl;
89   owned_rdata = std::move(other.owned_rdata);
90 
91   if (!owned_rdata.empty())
92     rdata = owned_rdata;
93   else
94     rdata = other.rdata;
95 
96   return *this;
97 }
98 
SetOwnedRdata(std::string value)99 void DnsResourceRecord::SetOwnedRdata(std::string value) {
100   DCHECK(!value.empty());
101   owned_rdata = std::move(value);
102   rdata = owned_rdata;
103   DCHECK_EQ(owned_rdata.data(), rdata.data());
104 }
105 
CalculateRecordSize() const106 size_t DnsResourceRecord::CalculateRecordSize() const {
107   bool has_final_dot = name.back() == '.';
108   // Depending on if |name| in the dotted format has the final dot for the root
109   // domain or not, the corresponding wire data in the DNS domain name format is
110   // 1 byte (with dot) or 2 bytes larger in size. See RFC 1035, Section 3.1 and
111   // DNSDomainFromDot.
112   return name.size() + (has_final_dot ? 1 : 2) +
113          net::dns_protocol::kResourceRecordSizeInBytesWithoutNameAndRData +
114          (owned_rdata.empty() ? rdata.size() : owned_rdata.size());
115 }
116 
117 DnsRecordParser::DnsRecordParser() = default;
118 
DnsRecordParser(const void * packet,size_t length,size_t offset,size_t num_records)119 DnsRecordParser::DnsRecordParser(const void* packet,
120                                  size_t length,
121                                  size_t offset,
122                                  size_t num_records)
123     : packet_(reinterpret_cast<const char*>(packet)),
124       length_(length),
125       num_records_(num_records),
126       cur_(packet_ + offset) {
127   CHECK_LE(offset, length);
128 }
129 
ReadName(const void * const vpos,std::string * out) const130 unsigned DnsRecordParser::ReadName(const void* const vpos,
131                                    std::string* out) const {
132   static const char kAbortMsg[] = "Abort parsing of noncompliant DNS record.";
133 
134   CHECK(packet_);
135   CHECK_LE(packet_, vpos);
136   CHECK_LE(vpos, packet_ + length_);
137   size_t initial_offset = (const char*)vpos - packet_;
138 
139   size_t offset = initial_offset;
140   // Count number of seen bytes to detect loops.
141   unsigned seen = 0;
142   // Remember how many bytes were consumed before first jump.
143   unsigned consumed = 0;
144   // The length of the encoded name (sum of label octets and label lengths).
145   // For context, RFC 1034 states that the total number of octets representing a
146   // domain name (the sum of all label octets and label lengths) is limited to
147   // 255. RFC 1035 introduces message compression as a way to reduce packet size
148   // on the wire, not to increase the maximum domain name length.
149   unsigned encoded_name_len = 0;
150 
151   if (initial_offset >= length_) {
152     return 0;
153   }
154 
155   if (out) {
156     out->clear();
157     out->reserve(dns_protocol::kMaxCharNameLength);
158   }
159 
160   for (;;) {
161     // The first two bits of the length give the type of the length. It's
162     // either a direct length or a pointer to the remainder of the name.
163     switch (packet_[offset] & dns_protocol::kLabelMask) {
164       case dns_protocol::kLabelPointer: {
165         if (offset + sizeof(uint16_t) > length_) {
166           VLOG(1) << kAbortMsg << " Truncated or missing label pointer.";
167           return 0;
168         }
169         if (consumed == 0) {
170           consumed = offset - initial_offset + sizeof(uint16_t);
171           if (!out)
172             return consumed;  // If name is not stored, that's all we need.
173         }
174         seen += sizeof(uint16_t);
175         // If seen the whole packet, then we must be in a loop.
176         if (seen > length_) {
177           VLOG(1) << kAbortMsg << " Detected loop in label pointers.";
178           return 0;
179         }
180         uint16_t new_offset;
181         base::ReadBigEndian(reinterpret_cast<const uint8_t*>(packet_ + offset),
182                             &new_offset);
183         offset = new_offset & dns_protocol::kOffsetMask;
184         if (offset >= length_) {
185           VLOG(1) << kAbortMsg << " Label pointer points outside packet.";
186           return 0;
187         }
188         break;
189       }
190       case dns_protocol::kLabelDirect: {
191         uint8_t label_len = packet_[offset];
192         ++offset;
193         // Note: root domain (".") is NOT included.
194         if (label_len == 0) {
195           if (consumed == 0) {
196             consumed = offset - initial_offset;
197           }  // else we set |consumed| before first jump
198           return consumed;
199         }
200         // Add one octet for the length and |label_len| for the number of
201         // following octets.
202         encoded_name_len += 1 + label_len;
203         if (encoded_name_len > dns_protocol::kMaxNameLength) {
204           VLOG(1) << kAbortMsg << " Name is too long.";
205           return 0;
206         }
207         if (offset + label_len >= length_) {
208           VLOG(1) << kAbortMsg << " Truncated or missing label.";
209           return 0;  // Truncated or missing label.
210         }
211         if (out) {
212           if (!out->empty())
213             out->append(".");
214           out->append(packet_ + offset, label_len);
215           CHECK_LE(out->size(), dns_protocol::kMaxCharNameLength);
216         }
217         offset += label_len;
218         seen += 1 + label_len;
219         break;
220       }
221       default:
222         // unhandled label type
223         VLOG(1) << kAbortMsg << " Unhandled label type.";
224         return 0;
225     }
226   }
227 }
228 
ReadRecord(DnsResourceRecord * out)229 bool DnsRecordParser::ReadRecord(DnsResourceRecord* out) {
230   CHECK(packet_);
231 
232   // Disallow parsing any more than the claimed number of records.
233   if (num_records_parsed_ >= num_records_)
234     return false;
235 
236   size_t consumed = ReadName(cur_, &out->name);
237   if (!consumed)
238     return false;
239   base::BigEndianReader reader(
240       reinterpret_cast<const uint8_t*>(cur_ + consumed),
241       packet_ + length_ - (cur_ + consumed));
242   uint16_t rdlen;
243   if (reader.ReadU16(&out->type) &&
244       reader.ReadU16(&out->klass) &&
245       reader.ReadU32(&out->ttl) &&
246       reader.ReadU16(&rdlen) &&
247       reader.ReadPiece(&out->rdata, rdlen)) {
248     cur_ = reinterpret_cast<const char*>(reader.ptr());
249     ++num_records_parsed_;
250     return true;
251   }
252   return false;
253 }
254 
ReadQuestion(std::string & out_dotted_qname,uint16_t & out_qtype)255 bool DnsRecordParser::ReadQuestion(std::string& out_dotted_qname,
256                                    uint16_t& out_qtype) {
257   size_t consumed = ReadName(cur_, &out_dotted_qname);
258   if (!consumed)
259     return false;
260 
261   if (consumed + 2 * sizeof(uint16_t) > (size_t)((packet_ + length_) - cur_)) {
262     return false;
263   }
264 
265   base::ReadBigEndian(reinterpret_cast<const uint8_t*>(cur_ + consumed),
266                       &out_qtype);
267 
268   cur_ += consumed + 2 * sizeof(uint16_t);  // QTYPE + QCLASS
269 
270   return true;
271 }
272 
DnsResponse(uint16_t id,bool is_authoritative,const std::vector<DnsResourceRecord> & answers,const std::vector<DnsResourceRecord> & authority_records,const std::vector<DnsResourceRecord> & additional_records,const absl::optional<DnsQuery> & query,uint8_t rcode,bool validate_records,bool validate_names_as_internet_hostnames)273 DnsResponse::DnsResponse(
274     uint16_t id,
275     bool is_authoritative,
276     const std::vector<DnsResourceRecord>& answers,
277     const std::vector<DnsResourceRecord>& authority_records,
278     const std::vector<DnsResourceRecord>& additional_records,
279     const absl::optional<DnsQuery>& query,
280     uint8_t rcode,
281     bool validate_records,
282     bool validate_names_as_internet_hostnames) {
283   bool has_query = query.has_value();
284   dns_protocol::Header header;
285   header.id = id;
286   bool success = true;
287   if (has_query) {
288     success &= (id == query.value().id());
289     DCHECK(success);
290     // DnsQuery only supports a single question.
291     header.qdcount = 1;
292   }
293   header.flags |= dns_protocol::kFlagResponse;
294   if (is_authoritative)
295     header.flags |= dns_protocol::kFlagAA;
296   DCHECK_EQ(0, rcode & ~kRcodeMask);
297   header.flags |= rcode;
298 
299   header.ancount = answers.size();
300   header.nscount = authority_records.size();
301   header.arcount = additional_records.size();
302 
303   // Response starts with the header and the question section (if any).
304   size_t response_size = has_query
305                              ? sizeof(header) + query.value().question_size()
306                              : sizeof(header);
307   // Add the size of all answers and additional records.
308   auto do_accumulation = [](size_t cur_size, const DnsResourceRecord& record) {
309     return cur_size + record.CalculateRecordSize();
310   };
311   response_size = std::accumulate(answers.begin(), answers.end(), response_size,
312                                   do_accumulation);
313   response_size =
314       std::accumulate(authority_records.begin(), authority_records.end(),
315                       response_size, do_accumulation);
316   response_size =
317       std::accumulate(additional_records.begin(), additional_records.end(),
318                       response_size, do_accumulation);
319 
320   auto io_buffer = base::MakeRefCounted<IOBufferWithSize>(response_size);
321   base::BigEndianWriter writer(io_buffer->data(), response_size);
322   success &= WriteHeader(&writer, header);
323   DCHECK(success);
324   if (has_query) {
325     success &= WriteQuestion(&writer, query.value());
326     DCHECK(success);
327   }
328   // Start the Answer section.
329   for (const auto& answer : answers) {
330     success &= WriteAnswer(&writer, answer, query, validate_records,
331                            validate_names_as_internet_hostnames);
332     DCHECK(success);
333   }
334   // Start the Authority section.
335   for (const auto& record : authority_records) {
336     success &= WriteRecord(&writer, record, validate_records,
337                            validate_names_as_internet_hostnames);
338     DCHECK(success);
339   }
340   // Start the Additional section.
341   for (const auto& record : additional_records) {
342     success &= WriteRecord(&writer, record, validate_records,
343                            validate_names_as_internet_hostnames);
344     DCHECK(success);
345   }
346   if (!success) {
347     return;
348   }
349   io_buffer_ = io_buffer;
350   io_buffer_size_ = response_size;
351   // Ensure we don't have any remaining uninitialized bytes in the buffer.
352   DCHECK(!writer.remaining());
353   memset(writer.ptr(), 0, writer.remaining());
354   if (has_query)
355     InitParse(io_buffer_size_, query.value());
356   else
357     InitParseWithoutQuery(io_buffer_size_);
358 }
359 
DnsResponse()360 DnsResponse::DnsResponse()
361     : io_buffer_(base::MakeRefCounted<IOBufferWithSize>(
362           dns_protocol::kMaxUDPSize + 1)),
363       io_buffer_size_(dns_protocol::kMaxUDPSize + 1) {}
364 
DnsResponse(scoped_refptr<IOBuffer> buffer,size_t size)365 DnsResponse::DnsResponse(scoped_refptr<IOBuffer> buffer, size_t size)
366     : io_buffer_(std::move(buffer)), io_buffer_size_(size) {}
367 
DnsResponse(size_t length)368 DnsResponse::DnsResponse(size_t length)
369     : io_buffer_(base::MakeRefCounted<IOBufferWithSize>(length)),
370       io_buffer_size_(length) {}
371 
DnsResponse(const void * data,size_t length,size_t answer_offset)372 DnsResponse::DnsResponse(const void* data, size_t length, size_t answer_offset)
373     : io_buffer_(base::MakeRefCounted<IOBufferWithSize>(length)),
374       io_buffer_size_(length),
375       parser_(io_buffer_->data(),
376               length,
377               answer_offset,
378               std::numeric_limits<size_t>::max()) {
379   DCHECK(data);
380   std::copy(static_cast<const char*>(data),
381             static_cast<const char*>(data) + length, io_buffer_->data());
382 }
383 
384 // static
CreateEmptyNoDataResponse(uint16_t id,bool is_authoritative,base::span<const uint8_t> qname,uint16_t qtype)385 DnsResponse DnsResponse::CreateEmptyNoDataResponse(
386     uint16_t id,
387     bool is_authoritative,
388     base::span<const uint8_t> qname,
389     uint16_t qtype) {
390   return DnsResponse(id, is_authoritative,
391                      /*answers=*/{},
392                      /*authority_records=*/{},
393                      /*additional_records=*/{}, DnsQuery(id, qname, qtype));
394 }
395 
396 DnsResponse::DnsResponse(DnsResponse&& other) = default;
397 DnsResponse& DnsResponse::operator=(DnsResponse&& other) = default;
398 
399 DnsResponse::~DnsResponse() = default;
400 
InitParse(size_t nbytes,const DnsQuery & query)401 bool DnsResponse::InitParse(size_t nbytes, const DnsQuery& query) {
402   const base::StringPiece question = query.question();
403 
404   // Response includes question, it should be at least that size.
405   if (nbytes < kHeaderSize + question.size() || nbytes > io_buffer_size_) {
406     return false;
407   }
408 
409   // At this point, it has been validated that the response is at least large
410   // enough to read the ID field.
411   id_available_ = true;
412 
413   // Match the query id.
414   DCHECK(id());
415   if (id().value() != query.id())
416     return false;
417 
418   // Not a response?
419   if ((base::NetToHost16(header()->flags) & dns_protocol::kFlagResponse) == 0)
420     return false;
421 
422   // Match question count.
423   if (base::NetToHost16(header()->qdcount) != 1)
424     return false;
425 
426   // Match the question section.
427   if (question !=
428       base::StringPiece(io_buffer_->data() + kHeaderSize, question.size())) {
429     return false;
430   }
431 
432   absl::optional<std::string> dotted_qname =
433       dns_names_util::NetworkToDottedName(query.qname());
434   if (!dotted_qname.has_value())
435     return false;
436   dotted_qnames_.push_back(std::move(dotted_qname).value());
437   qtypes_.push_back(query.qtype());
438 
439   size_t num_records = base::NetToHost16(header()->ancount) +
440                        base::NetToHost16(header()->nscount) +
441                        base::NetToHost16(header()->arcount);
442 
443   // Construct the parser. Only allow parsing up to `num_records` records. If
444   // more records are present in the buffer, it's just garbage extra data after
445   // the formal end of the response and should be ignored.
446   parser_ = DnsRecordParser(io_buffer_->data(), nbytes,
447                             kHeaderSize + question.size(), num_records);
448   return true;
449 }
450 
InitParseWithoutQuery(size_t nbytes)451 bool DnsResponse::InitParseWithoutQuery(size_t nbytes) {
452   if (nbytes < kHeaderSize || nbytes > io_buffer_size_) {
453     return false;
454   }
455   id_available_ = true;
456 
457   // Not a response?
458   if ((base::NetToHost16(header()->flags) & dns_protocol::kFlagResponse) == 0)
459     return false;
460 
461   size_t num_records = base::NetToHost16(header()->ancount) +
462                        base::NetToHost16(header()->nscount) +
463                        base::NetToHost16(header()->arcount);
464   // Only allow parsing up to `num_records` records. If more records are present
465   // in the buffer, it's just garbage extra data after the formal end of the
466   // response and should be ignored.
467   parser_ =
468       DnsRecordParser(io_buffer_->data(), nbytes, kHeaderSize, num_records);
469 
470   unsigned qdcount = base::NetToHost16(header()->qdcount);
471   for (unsigned i = 0; i < qdcount; ++i) {
472     std::string dotted_qname;
473     uint16_t qtype;
474     if (!parser_.ReadQuestion(dotted_qname, qtype)) {
475       parser_ = DnsRecordParser();  // Make parser invalid again.
476       return false;
477     }
478     dotted_qnames_.push_back(std::move(dotted_qname));
479     qtypes_.push_back(qtype);
480   }
481 
482   return true;
483 }
484 
id() const485 absl::optional<uint16_t> DnsResponse::id() const {
486   if (!id_available_)
487     return absl::nullopt;
488 
489   return base::NetToHost16(header()->id);
490 }
491 
IsValid() const492 bool DnsResponse::IsValid() const {
493   return parser_.IsValid();
494 }
495 
flags() const496 uint16_t DnsResponse::flags() const {
497   DCHECK(parser_.IsValid());
498   return base::NetToHost16(header()->flags) & ~(kRcodeMask);
499 }
500 
rcode() const501 uint8_t DnsResponse::rcode() const {
502   DCHECK(parser_.IsValid());
503   return base::NetToHost16(header()->flags) & kRcodeMask;
504 }
505 
question_count() const506 unsigned DnsResponse::question_count() const {
507   DCHECK(parser_.IsValid());
508   return base::NetToHost16(header()->qdcount);
509 }
510 
answer_count() const511 unsigned DnsResponse::answer_count() const {
512   DCHECK(parser_.IsValid());
513   return base::NetToHost16(header()->ancount);
514 }
515 
authority_count() const516 unsigned DnsResponse::authority_count() const {
517   DCHECK(parser_.IsValid());
518   return base::NetToHost16(header()->nscount);
519 }
520 
additional_answer_count() const521 unsigned DnsResponse::additional_answer_count() const {
522   DCHECK(parser_.IsValid());
523   return base::NetToHost16(header()->arcount);
524 }
525 
GetSingleQType() const526 uint16_t DnsResponse::GetSingleQType() const {
527   DCHECK_EQ(qtypes().size(), 1u);
528   return qtypes().front();
529 }
530 
GetSingleDottedName() const531 base::StringPiece DnsResponse::GetSingleDottedName() const {
532   DCHECK_EQ(dotted_qnames().size(), 1u);
533   return dotted_qnames().front();
534 }
535 
Parser() const536 DnsRecordParser DnsResponse::Parser() const {
537   DCHECK(parser_.IsValid());
538   // Return a copy of the parser.
539   return parser_;
540 }
541 
header() const542 const dns_protocol::Header* DnsResponse::header() const {
543   return reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
544 }
545 
WriteHeader(base::BigEndianWriter * writer,const dns_protocol::Header & header)546 bool DnsResponse::WriteHeader(base::BigEndianWriter* writer,
547                               const dns_protocol::Header& header) {
548   return writer->WriteU16(header.id) && writer->WriteU16(header.flags) &&
549          writer->WriteU16(header.qdcount) && writer->WriteU16(header.ancount) &&
550          writer->WriteU16(header.nscount) && writer->WriteU16(header.arcount);
551 }
552 
WriteQuestion(base::BigEndianWriter * writer,const DnsQuery & query)553 bool DnsResponse::WriteQuestion(base::BigEndianWriter* writer,
554                                 const DnsQuery& query) {
555   base::StringPiece question = query.question();
556   return writer->WriteBytes(question.data(), question.size());
557 }
558 
WriteRecord(base::BigEndianWriter * writer,const DnsResourceRecord & record,bool validate_record,bool validate_name_as_internet_hostname)559 bool DnsResponse::WriteRecord(base::BigEndianWriter* writer,
560                               const DnsResourceRecord& record,
561                               bool validate_record,
562                               bool validate_name_as_internet_hostname) {
563   if (record.rdata != base::StringPiece(record.owned_rdata)) {
564     VLOG(1) << "record.rdata should point to record.owned_rdata.";
565     return false;
566   }
567 
568   if (validate_record &&
569       !RecordRdata::HasValidSize(record.owned_rdata, record.type)) {
570     VLOG(1) << "Invalid RDATA size for a record.";
571     return false;
572   }
573 
574   absl::optional<std::vector<uint8_t>> domain_name =
575       dns_names_util::DottedNameToNetwork(record.name,
576                                           validate_name_as_internet_hostname);
577   if (!domain_name.has_value()) {
578     VLOG(1) << "Invalid dotted name (as "
579             << (validate_name_as_internet_hostname ? "Internet hostname)."
580                                                    : "DNS name).");
581     return false;
582   }
583 
584   return writer->WriteBytes(domain_name.value().data(),
585                             domain_name.value().size()) &&
586          writer->WriteU16(record.type) && writer->WriteU16(record.klass) &&
587          writer->WriteU32(record.ttl) &&
588          writer->WriteU16(record.owned_rdata.size()) &&
589          // Use the owned RDATA in the record to construct the response.
590          writer->WriteBytes(record.owned_rdata.data(),
591                             record.owned_rdata.size());
592 }
593 
WriteAnswer(base::BigEndianWriter * writer,const DnsResourceRecord & answer,const absl::optional<DnsQuery> & query,bool validate_record,bool validate_name_as_internet_hostname)594 bool DnsResponse::WriteAnswer(base::BigEndianWriter* writer,
595                               const DnsResourceRecord& answer,
596                               const absl::optional<DnsQuery>& query,
597                               bool validate_record,
598                               bool validate_name_as_internet_hostname) {
599   // Generally assumed to be a mistake if we write answers that don't match the
600   // query type, except CNAME answers which can always be added.
601   if (validate_record && query.has_value() &&
602       answer.type != query.value().qtype() &&
603       answer.type != dns_protocol::kTypeCNAME) {
604     VLOG(1) << "Mismatched answer resource record type and qtype.";
605     return false;
606   }
607   return WriteRecord(writer, answer, validate_record,
608                      validate_name_as_internet_hostname);
609 }
610 
611 }  // namespace net
612