1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "discovery/mdns/mdns_records.h"

#include <algorithm>
#include <cctype>
#include <limits>
#include <sstream>
#include <type_traits>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "discovery/mdns/mdns_writer.h"
17
18 namespace openscreen {
19 namespace discovery {
20
21 namespace {
22
// Largest RDATA payload whose length fits in a uint16_t.
constexpr size_t kMaxRawRecordSize = size_t{std::numeric_limits<uint16_t>::max()};

// Largest number of entries allowed in any single message section (questions,
// answers, authority records, additional records); checked against each
// section's size when building an MdnsMessage.
constexpr size_t kMaxMessageFieldEntryCount =
    size_t{std::numeric_limits<uint16_t>::max()};
27
// Three-way comparison of two DNS labels that ignores ASCII case.
//
// 'A'-'Z' are folded to 'a'-'z'; every other byte is compared as an unsigned
// octet, unchanged. Returns -1 if |x| sorts before |y|, 0 if they are equal,
// and 1 if |x| sorts after |y|. When one string is a prefix of the other, the
// shorter string sorts first.
//
// std::tolower() is deliberately not used: passing it a negative char (any
// byte >= 0x80 on platforms where char is signed) is undefined behavior, and
// its result depends on the current C locale.
inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
  // Lowercases only ASCII letters, leaving all other octets untouched.
  const auto to_lower_octet = [](char c) {
    const unsigned char octet = static_cast<unsigned char>(c);
    return (octet >= 'A' && octet <= 'Z')
               ? static_cast<unsigned char>(octet - 'A' + 'a')
               : octet;
  };

  const size_t common_length = std::min(x.size(), y.size());
  for (size_t i = 0; i < common_length; ++i) {
    const unsigned char x_char = to_lower_octet(x[i]);
    const unsigned char y_char = to_lower_octet(y[i]);
    if (x_char != y_char) {
      return x_char < y_char ? -1 : 1;
    }
  }

  if (x.size() == y.size()) {
    return 0;
  }
  return x.size() < y.size() ? -1 : 1;
}
44
45 template <typename RDataType>
IsGreaterThan(const Rdata & lhs,const Rdata & rhs)46 bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) {
47 const RDataType& lhs_cast = absl::get<RDataType>(lhs);
48 const RDataType& rhs_cast = absl::get<RDataType>(rhs);
49
50 // The Extra 2 in length is from the record size that Write() prepends to the
51 // result.
52 const size_t lhs_size = lhs_cast.MaxWireSize() + 2;
53 const size_t rhs_size = rhs_cast.MaxWireSize() + 2;
54
55 uint8_t lhs_bytes[lhs_size];
56 uint8_t rhs_bytes[rhs_size];
57 MdnsWriter lhs_writer(lhs_bytes, lhs_size);
58 MdnsWriter rhs_writer(rhs_bytes, rhs_size);
59
60 const bool lhs_write = lhs_writer.Write(lhs_cast);
61 const bool rhs_write = rhs_writer.Write(rhs_cast);
62 OSP_DCHECK(lhs_write);
63 OSP_DCHECK(rhs_write);
64
65 // Skip the size bits.
66 const size_t min_size = std::min(lhs_writer.offset(), rhs_writer.offset());
67 for (size_t i = 2; i < min_size; i++) {
68 if (lhs_bytes[i] != rhs_bytes[i]) {
69 return lhs_bytes[i] > rhs_bytes[i];
70 }
71 }
72
73 return lhs_size > rhs_size;
74 }
75
IsGreaterThan(DnsType type,const Rdata & lhs,const Rdata & rhs)76 bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) {
77 switch (type) {
78 case DnsType::kA:
79 return IsGreaterThan<ARecordRdata>(lhs, rhs);
80 case DnsType::kPTR:
81 return IsGreaterThan<PtrRecordRdata>(lhs, rhs);
82 case DnsType::kTXT:
83 return IsGreaterThan<TxtRecordRdata>(lhs, rhs);
84 case DnsType::kAAAA:
85 return IsGreaterThan<AAAARecordRdata>(lhs, rhs);
86 case DnsType::kSRV:
87 return IsGreaterThan<SrvRecordRdata>(lhs, rhs);
88 case DnsType::kNSEC:
89 return IsGreaterThan<NsecRecordRdata>(lhs, rhs);
90 default:
91 return IsGreaterThan<RawRecordRdata>(lhs, rhs);
92 }
93 }
94
95 } // namespace
96
IsValidDomainLabel(absl::string_view label)97 bool IsValidDomainLabel(absl::string_view label) {
98 const size_t label_size = label.size();
99 return label_size > 0 && label_size <= kMaxLabelLength;
100 }
101
102 DomainName::DomainName() = default;
103
DomainName(std::vector<std::string> labels)104 DomainName::DomainName(std::vector<std::string> labels)
105 : DomainName(labels.begin(), labels.end()) {}
106
DomainName(const std::vector<absl::string_view> & labels)107 DomainName::DomainName(const std::vector<absl::string_view>& labels)
108 : DomainName(labels.begin(), labels.end()) {}
109
DomainName(std::initializer_list<absl::string_view> labels)110 DomainName::DomainName(std::initializer_list<absl::string_view> labels)
111 : DomainName(labels.begin(), labels.end()) {}
112
DomainName(std::vector<std::string> labels,size_t max_wire_size)113 DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size)
114 : max_wire_size_(max_wire_size), labels_(std::move(labels)) {}
115
116 DomainName::DomainName(const DomainName& other) = default;
117
118 DomainName::DomainName(DomainName&& other) noexcept = default;
119
120 DomainName& DomainName::operator=(const DomainName& rhs) = default;
121
122 DomainName& DomainName::operator=(DomainName&& rhs) = default;
123
ToString() const124 std::string DomainName::ToString() const {
125 return absl::StrJoin(labels_, ".");
126 }
127
operator <(const DomainName & rhs) const128 bool DomainName::operator<(const DomainName& rhs) const {
129 size_t i = 0;
130 for (; i < labels_.size(); i++) {
131 if (i == rhs.labels_.size()) {
132 return false;
133 } else {
134 int result = CompareIgnoreCase(labels_[i], rhs.labels_[i]);
135 if (result < 0) {
136 return true;
137 } else if (result > 0) {
138 return false;
139 }
140 }
141 }
142 return i < rhs.labels_.size();
143 }
144
operator <=(const DomainName & rhs) const145 bool DomainName::operator<=(const DomainName& rhs) const {
146 return (*this < rhs) || (*this == rhs);
147 }
148
operator >(const DomainName & rhs) const149 bool DomainName::operator>(const DomainName& rhs) const {
150 return !(*this < rhs) && !(*this == rhs);
151 }
152
operator >=(const DomainName & rhs) const153 bool DomainName::operator>=(const DomainName& rhs) const {
154 return !(*this < rhs);
155 }
156
operator ==(const DomainName & rhs) const157 bool DomainName::operator==(const DomainName& rhs) const {
158 if (labels_.size() != rhs.labels_.size()) {
159 return false;
160 }
161 for (size_t i = 0; i < labels_.size(); i++) {
162 if (CompareIgnoreCase(labels_[i], rhs.labels_[i]) != 0) {
163 return false;
164 }
165 }
166 return true;
167 }
168
operator !=(const DomainName & rhs) const169 bool DomainName::operator!=(const DomainName& rhs) const {
170 return !(*this == rhs);
171 }
172
MaxWireSize() const173 size_t DomainName::MaxWireSize() const {
174 return max_wire_size_;
175 }
176
177 // static
TryCreate(std::vector<uint8_t> rdata)178 ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) {
179 if (rdata.size() > kMaxRawRecordSize) {
180 return Error::Code::kIndexOutOfBounds;
181 } else {
182 return RawRecordRdata(std::move(rdata));
183 }
184 }
185
186 RawRecordRdata::RawRecordRdata() = default;
187
RawRecordRdata(std::vector<uint8_t> rdata)188 RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata)
189 : rdata_(std::move(rdata)) {
190 // Ensure RDATA length does not exceed the maximum allowed.
191 OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize);
192 }
193
RawRecordRdata(const uint8_t * begin,size_t size)194 RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size)
195 : RawRecordRdata(std::vector<uint8_t>(begin, begin + size)) {}
196
197 RawRecordRdata::RawRecordRdata(const RawRecordRdata& other) = default;
198
199 RawRecordRdata::RawRecordRdata(RawRecordRdata&& other) noexcept = default;
200
201 RawRecordRdata& RawRecordRdata::operator=(const RawRecordRdata& rhs) = default;
202
203 RawRecordRdata& RawRecordRdata::operator=(RawRecordRdata&& rhs) = default;
204
operator ==(const RawRecordRdata & rhs) const205 bool RawRecordRdata::operator==(const RawRecordRdata& rhs) const {
206 return rdata_ == rhs.rdata_;
207 }
208
operator !=(const RawRecordRdata & rhs) const209 bool RawRecordRdata::operator!=(const RawRecordRdata& rhs) const {
210 return !(*this == rhs);
211 }
212
MaxWireSize() const213 size_t RawRecordRdata::MaxWireSize() const {
214 // max_wire_size includes uint16_t record length field.
215 return sizeof(uint16_t) + rdata_.size();
216 }
217
218 SrvRecordRdata::SrvRecordRdata() = default;
219
SrvRecordRdata(uint16_t priority,uint16_t weight,uint16_t port,DomainName target)220 SrvRecordRdata::SrvRecordRdata(uint16_t priority,
221 uint16_t weight,
222 uint16_t port,
223 DomainName target)
224 : priority_(priority),
225 weight_(weight),
226 port_(port),
227 target_(std::move(target)) {}
228
229 SrvRecordRdata::SrvRecordRdata(const SrvRecordRdata& other) = default;
230
231 SrvRecordRdata::SrvRecordRdata(SrvRecordRdata&& other) noexcept = default;
232
233 SrvRecordRdata& SrvRecordRdata::operator=(const SrvRecordRdata& rhs) = default;
234
235 SrvRecordRdata& SrvRecordRdata::operator=(SrvRecordRdata&& rhs) = default;
236
operator ==(const SrvRecordRdata & rhs) const237 bool SrvRecordRdata::operator==(const SrvRecordRdata& rhs) const {
238 return priority_ == rhs.priority_ && weight_ == rhs.weight_ &&
239 port_ == rhs.port_ && target_ == rhs.target_;
240 }
241
operator !=(const SrvRecordRdata & rhs) const242 bool SrvRecordRdata::operator!=(const SrvRecordRdata& rhs) const {
243 return !(*this == rhs);
244 }
245
MaxWireSize() const246 size_t SrvRecordRdata::MaxWireSize() const {
247 // max_wire_size includes uint16_t record length field.
248 return sizeof(uint16_t) + sizeof(priority_) + sizeof(weight_) +
249 sizeof(port_) + target_.MaxWireSize();
250 }
251
252 ARecordRdata::ARecordRdata() = default;
253
ARecordRdata(IPAddress ipv4_address,NetworkInterfaceIndex interface_index)254 ARecordRdata::ARecordRdata(IPAddress ipv4_address,
255 NetworkInterfaceIndex interface_index)
256 : ipv4_address_(std::move(ipv4_address)),
257 interface_index_(interface_index) {
258 OSP_CHECK(ipv4_address_.IsV4());
259 }
260
261 ARecordRdata::ARecordRdata(const ARecordRdata& other) = default;
262
263 ARecordRdata::ARecordRdata(ARecordRdata&& other) noexcept = default;
264
265 ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default;
266
267 ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default;
268
operator ==(const ARecordRdata & rhs) const269 bool ARecordRdata::operator==(const ARecordRdata& rhs) const {
270 return ipv4_address_ == rhs.ipv4_address_ &&
271 interface_index_ == rhs.interface_index_;
272 }
273
operator !=(const ARecordRdata & rhs) const274 bool ARecordRdata::operator!=(const ARecordRdata& rhs) const {
275 return !(*this == rhs);
276 }
277
MaxWireSize() const278 size_t ARecordRdata::MaxWireSize() const {
279 // max_wire_size includes uint16_t record length field.
280 return sizeof(uint16_t) + IPAddress::kV4Size;
281 }
282
283 AAAARecordRdata::AAAARecordRdata() = default;
284
AAAARecordRdata(IPAddress ipv6_address,NetworkInterfaceIndex interface_index)285 AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address,
286 NetworkInterfaceIndex interface_index)
287 : ipv6_address_(std::move(ipv6_address)),
288 interface_index_(interface_index) {
289 OSP_CHECK(ipv6_address_.IsV6());
290 }
291
292 AAAARecordRdata::AAAARecordRdata(const AAAARecordRdata& other) = default;
293
294 AAAARecordRdata::AAAARecordRdata(AAAARecordRdata&& other) noexcept = default;
295
296 AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) =
297 default;
298
299 AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default;
300
operator ==(const AAAARecordRdata & rhs) const301 bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const {
302 return ipv6_address_ == rhs.ipv6_address_ &&
303 interface_index_ == rhs.interface_index_;
304 }
305
operator !=(const AAAARecordRdata & rhs) const306 bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const {
307 return !(*this == rhs);
308 }
309
MaxWireSize() const310 size_t AAAARecordRdata::MaxWireSize() const {
311 // max_wire_size includes uint16_t record length field.
312 return sizeof(uint16_t) + IPAddress::kV6Size;
313 }
314
315 PtrRecordRdata::PtrRecordRdata() = default;
316
PtrRecordRdata(DomainName ptr_domain)317 PtrRecordRdata::PtrRecordRdata(DomainName ptr_domain)
318 : ptr_domain_(ptr_domain) {}
319
320 PtrRecordRdata::PtrRecordRdata(const PtrRecordRdata& other) = default;
321
322 PtrRecordRdata::PtrRecordRdata(PtrRecordRdata&& other) noexcept = default;
323
324 PtrRecordRdata& PtrRecordRdata::operator=(const PtrRecordRdata& rhs) = default;
325
326 PtrRecordRdata& PtrRecordRdata::operator=(PtrRecordRdata&& rhs) = default;
327
operator ==(const PtrRecordRdata & rhs) const328 bool PtrRecordRdata::operator==(const PtrRecordRdata& rhs) const {
329 return ptr_domain_ == rhs.ptr_domain_;
330 }
331
operator !=(const PtrRecordRdata & rhs) const332 bool PtrRecordRdata::operator!=(const PtrRecordRdata& rhs) const {
333 return !(*this == rhs);
334 }
335
MaxWireSize() const336 size_t PtrRecordRdata::MaxWireSize() const {
337 // max_wire_size includes uint16_t record length field.
338 return sizeof(uint16_t) + ptr_domain_.MaxWireSize();
339 }
340
341 // static
TryCreate(std::vector<Entry> texts)342 ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) {
343 std::vector<std::string> str_texts;
344 size_t max_wire_size = 3;
345 if (texts.size() > 0) {
346 str_texts.reserve(texts.size());
347 // max_wire_size includes uint16_t record length field.
348 max_wire_size = sizeof(uint16_t);
349 for (const auto& text : texts) {
350 if (text.empty()) {
351 return Error::Code::kParameterInvalid;
352 }
353 str_texts.push_back(
354 std::string(reinterpret_cast<const char*>(text.data()), text.size()));
355 // Include the length byte in the size calculation.
356 max_wire_size += text.size() + 1;
357 }
358 }
359 return TxtRecordRdata(std::move(str_texts), max_wire_size);
360 }
361
362 TxtRecordRdata::TxtRecordRdata() = default;
363
TxtRecordRdata(std::vector<Entry> texts)364 TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
365 ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
366 *this = std::move(rdata.value());
367 }
368
TxtRecordRdata(std::vector<std::string> texts,size_t max_wire_size)369 TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts,
370 size_t max_wire_size)
371 : max_wire_size_(max_wire_size), texts_(std::move(texts)) {}
372
373 TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default;
374
375 TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) noexcept = default;
376
377 TxtRecordRdata& TxtRecordRdata::operator=(const TxtRecordRdata& rhs) = default;
378
379 TxtRecordRdata& TxtRecordRdata::operator=(TxtRecordRdata&& rhs) = default;
380
operator ==(const TxtRecordRdata & rhs) const381 bool TxtRecordRdata::operator==(const TxtRecordRdata& rhs) const {
382 return texts_ == rhs.texts_;
383 }
384
operator !=(const TxtRecordRdata & rhs) const385 bool TxtRecordRdata::operator!=(const TxtRecordRdata& rhs) const {
386 return !(*this == rhs);
387 }
388
MaxWireSize() const389 size_t TxtRecordRdata::MaxWireSize() const {
390 return max_wire_size_;
391 }
392
NsecRecordRdata::NsecRecordRdata() = default;

// Builds an NSEC rdata from the name of the next domain and the list of
// record types present for the owner name.
//
// |types| is sorted and stored in |types_| (which operator== compares). The
// type bitmap described in RFC 4034 Section 4.1.2 is precomputed into
// |encoded_types_|. Types are grouped into blocks by the high byte of their
// 16-bit value. For each block that contains at least one type,
// |encoded_types_| receives, in order:
//   - one byte holding the block number,
//   - one byte holding the length of the bitmap that follows, and
//   - the bitmap bytes. For a type whose low byte is P, bit (0x80 >> (P & 7))
//     of bitmap byte (P >> 3) is set.
// Each bitmap only extends to the byte holding the highest type in its block,
// so it never ends with an all-zero byte. Duplicate types set the same bit
// again and do not change the encoding (they do remain in |types_|).
NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name,
                                 std::vector<DnsType> types)
    : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) {
  // Sort the types_ array for easier comparison later. Sorting also
  // guarantees that |block| below never decreases across iterations.
  std::sort(types_.begin(), types_.end());

  // Calculate the bitmaps as described in RFC 4034 Section 4.1.2.
  // |block_contents| accumulates the bitmap of |current_block|.
  std::vector<uint8_t> block_contents;
  uint8_t current_block = 0;
  for (auto type : types_) {
    const uint16_t type_int = static_cast<uint16_t>(type);
    const uint8_t block = static_cast<uint8_t>(type_int >> 8);
    const uint8_t block_position = static_cast<uint8_t>(type_int & 0xFF);
    const uint8_t byte_bit_is_at = block_position >> 3;  // Upper 5 bits: byte.
    const uint8_t byte_mask = 0x80 >> (block_position & 0x07);  // Lower 3 bits.

    // If the block has changed, write the previous block's info and all of its
    // contents to the |encoded_types_| vector. Because |types_| is sorted, a
    // larger block number means the previous block is complete. Nothing is
    // emitted for block 0 unless a block-0 type set a bit in it.
    if (block > current_block) {
      if (!block_contents.empty()) {
        encoded_types_.push_back(current_block);
        encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
        encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
                              block_contents.end());
      }
      block_contents = std::vector<uint8_t>();
      current_block = block;
    }

    // Make sure |block_contents| is large enough to hold the bit representing
    // the new type, zero-filling any skipped bytes, then set it.
    if (block_contents.size() <= byte_bit_is_at) {
      block_contents.insert(block_contents.end(),
                            byte_bit_is_at - block_contents.size() + 1, 0x00);
    }

    block_contents[byte_bit_is_at] |= byte_mask;
  }

  // Flush the final (highest-numbered) block, if any type was seen.
  if (!block_contents.empty()) {
    encoded_types_.push_back(current_block);
    encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
    encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
                          block_contents.end());
  }
}
441
442 NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default;
443
444 NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) noexcept = default;
445
446 NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) =
447 default;
448
449 NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default;
450
operator ==(const NsecRecordRdata & rhs) const451 bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const {
452 return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_;
453 }
454
operator !=(const NsecRecordRdata & rhs) const455 bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const {
456 return !(*this == rhs);
457 }
458
MaxWireSize() const459 size_t NsecRecordRdata::MaxWireSize() const {
460 return next_domain_name_.MaxWireSize() + encoded_types_.size();
461 }
462
MaxWireSize() const463 size_t OptRecordRdata::Option::MaxWireSize() const {
464 // One uint16_t for each of OPTION-LENGTH and OPTION-CODE as defined in RFC
465 // 6891 section 6.1.2.
466 constexpr size_t kOptionLengthAndCodeSize = 2 * sizeof(uint16_t);
467 return data.size() + kOptionLengthAndCodeSize;
468 }
469
operator >(const OptRecordRdata::Option & rhs) const470 bool OptRecordRdata::Option::operator>(
471 const OptRecordRdata::Option& rhs) const {
472 if (code != rhs.code) {
473 return code > rhs.code;
474 } else if (length != rhs.length) {
475 return length > rhs.length;
476 } else if (data.size() != rhs.data.size()) {
477 return data.size() > rhs.data.size();
478 }
479
480 for (int i = 0; i < static_cast<int>(data.size()); i++) {
481 if (data[i] != rhs.data[i]) {
482 return data[i] > rhs.data[i];
483 }
484 }
485
486 return false;
487 }
488
operator <(const OptRecordRdata::Option & rhs) const489 bool OptRecordRdata::Option::operator<(
490 const OptRecordRdata::Option& rhs) const {
491 return rhs > *this;
492 }
493
operator >=(const OptRecordRdata::Option & rhs) const494 bool OptRecordRdata::Option::operator>=(
495 const OptRecordRdata::Option& rhs) const {
496 return !(*this < rhs);
497 }
498
operator <=(const OptRecordRdata::Option & rhs) const499 bool OptRecordRdata::Option::operator<=(
500 const OptRecordRdata::Option& rhs) const {
501 return !(*this > rhs);
502 }
503
operator ==(const OptRecordRdata::Option & rhs) const504 bool OptRecordRdata::Option::operator==(
505 const OptRecordRdata::Option& rhs) const {
506 return *this >= rhs && *this <= rhs;
507 }
508
operator !=(const OptRecordRdata::Option & rhs) const509 bool OptRecordRdata::Option::operator!=(
510 const OptRecordRdata::Option& rhs) const {
511 return !(*this == rhs);
512 }
513
514 OptRecordRdata::OptRecordRdata() = default;
515
OptRecordRdata(std::vector<Option> options)516 OptRecordRdata::OptRecordRdata(std::vector<Option> options)
517 : options_(std::move(options)) {
518 for (const auto& option : options_) {
519 max_wire_size_ += option.MaxWireSize();
520 }
521 std::sort(options_.begin(), options_.end());
522 }
523
524 OptRecordRdata::OptRecordRdata(const OptRecordRdata& other) = default;
525
526 OptRecordRdata::OptRecordRdata(OptRecordRdata&& other) noexcept = default;
527
528 OptRecordRdata& OptRecordRdata::operator=(const OptRecordRdata& rhs) = default;
529
530 OptRecordRdata& OptRecordRdata::operator=(OptRecordRdata&& rhs) = default;
531
operator ==(const OptRecordRdata & rhs) const532 bool OptRecordRdata::operator==(const OptRecordRdata& rhs) const {
533 return options_ == rhs.options_;
534 }
535
operator !=(const OptRecordRdata & rhs) const536 bool OptRecordRdata::operator!=(const OptRecordRdata& rhs) const {
537 return !(*this == rhs);
538 }
539
540 // static
TryCreate(DomainName name,DnsType dns_type,DnsClass dns_class,RecordType record_type,std::chrono::seconds ttl,Rdata rdata)541 ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name,
542 DnsType dns_type,
543 DnsClass dns_class,
544 RecordType record_type,
545 std::chrono::seconds ttl,
546 Rdata rdata) {
547 if (!IsValidConfig(name, dns_type, ttl, rdata)) {
548 return Error::Code::kParameterInvalid;
549 } else {
550 return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl,
551 std::move(rdata));
552 }
553 }
554
555 MdnsRecord::MdnsRecord() = default;
556
MdnsRecord(DomainName name,DnsType dns_type,DnsClass dns_class,RecordType record_type,std::chrono::seconds ttl,Rdata rdata)557 MdnsRecord::MdnsRecord(DomainName name,
558 DnsType dns_type,
559 DnsClass dns_class,
560 RecordType record_type,
561 std::chrono::seconds ttl,
562 Rdata rdata)
563 : name_(std::move(name)),
564 dns_type_(dns_type),
565 dns_class_(dns_class),
566 record_type_(record_type),
567 ttl_(ttl),
568 rdata_(std::move(rdata)) {
569 OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_));
570 }
571
572 MdnsRecord::MdnsRecord(const MdnsRecord& other) = default;
573
574 MdnsRecord::MdnsRecord(MdnsRecord&& other) noexcept = default;
575
576 MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default;
577
578 MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default;
579
580 // static
IsValidConfig(const DomainName & name,DnsType dns_type,std::chrono::seconds ttl,const Rdata & rdata)581 bool MdnsRecord::IsValidConfig(const DomainName& name,
582 DnsType dns_type,
583 std::chrono::seconds ttl,
584 const Rdata& rdata) {
585 // NOTE: Although the name_ field was initially expected to be non-empty, this
586 // validation is no longer accurate for some record types (such as OPT
587 // records). To ensure that future record types correctly parse into
588 // RawRecordData types and do not invalidate the received message, this check
589 // has been removed.
590 return ttl.count() <= std::numeric_limits<uint32_t>::max() &&
591 ((dns_type == DnsType::kSRV &&
592 absl::holds_alternative<SrvRecordRdata>(rdata)) ||
593 (dns_type == DnsType::kA &&
594 absl::holds_alternative<ARecordRdata>(rdata)) ||
595 (dns_type == DnsType::kAAAA &&
596 absl::holds_alternative<AAAARecordRdata>(rdata)) ||
597 (dns_type == DnsType::kPTR &&
598 absl::holds_alternative<PtrRecordRdata>(rdata)) ||
599 (dns_type == DnsType::kTXT &&
600 absl::holds_alternative<TxtRecordRdata>(rdata)) ||
601 (dns_type == DnsType::kNSEC &&
602 absl::holds_alternative<NsecRecordRdata>(rdata)) ||
603 (dns_type == DnsType::kOPT &&
604 absl::holds_alternative<OptRecordRdata>(rdata)) ||
605 absl::holds_alternative<RawRecordRdata>(rdata));
606 }
607
operator ==(const MdnsRecord & rhs) const608 bool MdnsRecord::operator==(const MdnsRecord& rhs) const {
609 return IsReannouncementOf(rhs) && ttl_ == rhs.ttl_;
610 }
611
operator !=(const MdnsRecord & rhs) const612 bool MdnsRecord::operator!=(const MdnsRecord& rhs) const {
613 return !(*this == rhs);
614 }
615
operator >(const MdnsRecord & rhs) const616 bool MdnsRecord::operator>(const MdnsRecord& rhs) const {
617 // Returns the record which is lexicographically later. The determination of
618 // "lexicographically later" is performed by first comparing the record class,
619 // then the record type, then raw comparison of the binary content of the
620 // rdata without regard for meaning or structure.
621 // NOTE: Per RFC, the TTL is not included in this comparison.
622 if (name() != rhs.name()) {
623 return name() > rhs.name();
624 }
625
626 if (record_type() != rhs.record_type()) {
627 return record_type() == RecordType::kUnique;
628 }
629
630 if (dns_class() != rhs.dns_class()) {
631 return dns_class() > rhs.dns_class();
632 }
633
634 uint16_t this_type = static_cast<uint16_t>(dns_type()) & kClassMask;
635 uint16_t other_type = static_cast<uint16_t>(rhs.dns_type()) & kClassMask;
636 if (this_type != other_type) {
637 return this_type > other_type;
638 }
639
640 return IsGreaterThan(dns_type(), rdata(), rhs.rdata());
641 }
642
operator <(const MdnsRecord & rhs) const643 bool MdnsRecord::operator<(const MdnsRecord& rhs) const {
644 return rhs > *this;
645 }
646
operator <=(const MdnsRecord & rhs) const647 bool MdnsRecord::operator<=(const MdnsRecord& rhs) const {
648 return !(*this > rhs);
649 }
650
operator >=(const MdnsRecord & rhs) const651 bool MdnsRecord::operator>=(const MdnsRecord& rhs) const {
652 return !(*this < rhs);
653 }
654
IsReannouncementOf(const MdnsRecord & rhs) const655 bool MdnsRecord::IsReannouncementOf(const MdnsRecord& rhs) const {
656 return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
657 record_type_ == rhs.record_type_ && name_ == rhs.name_ &&
658 rdata_ == rhs.rdata_;
659 }
660
MaxWireSize() const661 size_t MdnsRecord::MaxWireSize() const {
662 auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); };
663 // NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size
664 return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8;
665 }
666
ToString() const667 std::string MdnsRecord::ToString() const {
668 std::stringstream ss;
669 ss << "name: '" << name_.ToString() << "'";
670 ss << ", type: " << dns_type_;
671
672 if (dns_type_ == DnsType::kPTR) {
673 const DomainName& target = absl::get<PtrRecordRdata>(rdata_).ptr_domain();
674 ss << ", target: '" << target.ToString() << "'";
675 } else if (dns_type_ == DnsType::kSRV) {
676 const DomainName& target = absl::get<SrvRecordRdata>(rdata_).target();
677 ss << ", target: '" << target.ToString() << "'";
678 } else if (dns_type_ == DnsType::kNSEC) {
679 const auto& nsec_rdata = absl::get<NsecRecordRdata>(rdata_);
680 std::vector<DnsType> types = nsec_rdata.types();
681 ss << ", representing [";
682 if (!types.empty()) {
683 auto it = types.begin();
684 ss << *it++;
685 while (it != types.end()) {
686 ss << ", " << *it++;
687 }
688 ss << "]";
689 }
690 }
691
692 return ss.str();
693 }
694
CreateAddressRecord(DomainName name,const IPAddress & address)695 MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) {
696 Rdata rdata;
697 DnsType type;
698 std::chrono::seconds ttl;
699 if (address.IsV4()) {
700 type = DnsType::kA;
701 rdata = ARecordRdata(address);
702 ttl = kARecordTtl;
703 } else {
704 type = DnsType::kAAAA;
705 rdata = AAAARecordRdata(address);
706 ttl = kAAAARecordTtl;
707 }
708
709 return MdnsRecord(std::move(name), type, DnsClass::kIN, RecordType::kUnique,
710 ttl, std::move(rdata));
711 }
712
713 // static
TryCreate(DomainName name,DnsType dns_type,DnsClass dns_class,ResponseType response_type)714 ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name,
715 DnsType dns_type,
716 DnsClass dns_class,
717 ResponseType response_type) {
718 if (name.empty()) {
719 return Error::Code::kParameterInvalid;
720 }
721
722 return MdnsQuestion(std::move(name), dns_type, dns_class, response_type);
723 }
724
MdnsQuestion(DomainName name,DnsType dns_type,DnsClass dns_class,ResponseType response_type)725 MdnsQuestion::MdnsQuestion(DomainName name,
726 DnsType dns_type,
727 DnsClass dns_class,
728 ResponseType response_type)
729 : name_(std::move(name)),
730 dns_type_(dns_type),
731 dns_class_(dns_class),
732 response_type_(response_type) {
733 OSP_CHECK(!name_.empty());
734 }
735
operator ==(const MdnsQuestion & rhs) const736 bool MdnsQuestion::operator==(const MdnsQuestion& rhs) const {
737 return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
738 response_type_ == rhs.response_type_ && name_ == rhs.name_;
739 }
740
operator !=(const MdnsQuestion & rhs) const741 bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const {
742 return !(*this == rhs);
743 }
744
MaxWireSize() const745 size_t MdnsQuestion::MaxWireSize() const {
746 // NAME size, 2-byte TYPE, 2-byte CLASS
747 return name_.MaxWireSize() + 4;
748 }
749
750 // static
TryCreate(uint16_t id,MessageType type,std::vector<MdnsQuestion> questions,std::vector<MdnsRecord> answers,std::vector<MdnsRecord> authority_records,std::vector<MdnsRecord> additional_records)751 ErrorOr<MdnsMessage> MdnsMessage::TryCreate(
752 uint16_t id,
753 MessageType type,
754 std::vector<MdnsQuestion> questions,
755 std::vector<MdnsRecord> answers,
756 std::vector<MdnsRecord> authority_records,
757 std::vector<MdnsRecord> additional_records) {
758 if (questions.size() >= kMaxMessageFieldEntryCount ||
759 answers.size() >= kMaxMessageFieldEntryCount ||
760 authority_records.size() >= kMaxMessageFieldEntryCount ||
761 additional_records.size() >= kMaxMessageFieldEntryCount) {
762 return Error::Code::kParameterInvalid;
763 }
764
765 return MdnsMessage(id, type, std::move(questions), std::move(answers),
766 std::move(authority_records),
767 std::move(additional_records));
768 }
769
MdnsMessage(uint16_t id,MessageType type)770 MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
771 : id_(id), type_(type) {}
772
MdnsMessage(uint16_t id,MessageType type,std::vector<MdnsQuestion> questions,std::vector<MdnsRecord> answers,std::vector<MdnsRecord> authority_records,std::vector<MdnsRecord> additional_records)773 MdnsMessage::MdnsMessage(uint16_t id,
774 MessageType type,
775 std::vector<MdnsQuestion> questions,
776 std::vector<MdnsRecord> answers,
777 std::vector<MdnsRecord> authority_records,
778 std::vector<MdnsRecord> additional_records)
779 : id_(id),
780 type_(type),
781 questions_(std::move(questions)),
782 answers_(std::move(answers)),
783 authority_records_(std::move(authority_records)),
784 additional_records_(std::move(additional_records)) {
785 OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
786 OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
787 OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
788 OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
789
790 for (const MdnsQuestion& question : questions_) {
791 max_wire_size_ += question.MaxWireSize();
792 }
793 for (const MdnsRecord& record : answers_) {
794 max_wire_size_ += record.MaxWireSize();
795 }
796 for (const MdnsRecord& record : authority_records_) {
797 max_wire_size_ += record.MaxWireSize();
798 }
799 for (const MdnsRecord& record : additional_records_) {
800 max_wire_size_ += record.MaxWireSize();
801 }
802 }
803
operator ==(const MdnsMessage & rhs) const804 bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
805 return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ &&
806 answers_ == rhs.answers_ &&
807 authority_records_ == rhs.authority_records_ &&
808 additional_records_ == rhs.additional_records_;
809 }
810
operator !=(const MdnsMessage & rhs) const811 bool MdnsMessage::operator!=(const MdnsMessage& rhs) const {
812 return !(*this == rhs);
813 }
814
IsProbeQuery() const815 bool MdnsMessage::IsProbeQuery() const {
816 // A message is a probe query if it contains records in the authority section
817 // which answer the question being asked.
818 if (questions().empty() || authority_records().empty()) {
819 return false;
820 }
821
822 for (const MdnsQuestion& question : questions_) {
823 for (const MdnsRecord& record : authority_records_) {
824 if (question.name() == record.name() &&
825 ((question.dns_type() == record.dns_type()) ||
826 (question.dns_type() == DnsType::kANY)) &&
827 ((question.dns_class() == record.dns_class()) ||
828 (question.dns_class() == DnsClass::kANY))) {
829 return true;
830 }
831 }
832 }
833
834 return false;
835 }
836
MaxWireSize() const837 size_t MdnsMessage::MaxWireSize() const {
838 return max_wire_size_;
839 }
840
AddQuestion(MdnsQuestion question)841 void MdnsMessage::AddQuestion(MdnsQuestion question) {
842 OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
843 max_wire_size_ += question.MaxWireSize();
844 questions_.emplace_back(std::move(question));
845 }
846
AddAnswer(MdnsRecord record)847 void MdnsMessage::AddAnswer(MdnsRecord record) {
848 OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
849 max_wire_size_ += record.MaxWireSize();
850 answers_.emplace_back(std::move(record));
851 }
852
AddAuthorityRecord(MdnsRecord record)853 void MdnsMessage::AddAuthorityRecord(MdnsRecord record) {
854 OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
855 max_wire_size_ += record.MaxWireSize();
856 authority_records_.emplace_back(std::move(record));
857 }
858
AddAdditionalRecord(MdnsRecord record)859 void MdnsMessage::AddAdditionalRecord(MdnsRecord record) {
860 OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
861 max_wire_size_ += record.MaxWireSize();
862 additional_records_.emplace_back(std::move(record));
863 }
864
CanAddRecord(const MdnsRecord & record)865 bool MdnsMessage::CanAddRecord(const MdnsRecord& record) {
866 return (max_wire_size_ + record.MaxWireSize()) < kMaxMulticastMessageSize;
867 }
868
// Returns 0 on the first call and one more on each later call, wrapping from
// 65535 back to 0. NOTE(review): the counter is a plain function-local static
// with no synchronization; concurrent callers would race.
uint16_t CreateMessageId() {
  static uint16_t next_id = 0;
  const uint16_t current_id = next_id;
  next_id = static_cast<uint16_t>(next_id + 1);
  return current_id;
}
873
CanBePublished(DnsType type)874 bool CanBePublished(DnsType type) {
875 // NOTE: A 'default' switch statement has intentionally been avoided below to
876 // enforce that new DnsTypes added must be added below through a compile-time
877 // check.
878 switch (type) {
879 case DnsType::kA:
880 case DnsType::kAAAA:
881 case DnsType::kPTR:
882 case DnsType::kTXT:
883 case DnsType::kSRV:
884 return true;
885 case DnsType::kOPT:
886 case DnsType::kNSEC:
887 case DnsType::kANY:
888 break;
889 }
890
891 return false;
892 }
893
CanBePublished(const MdnsRecord & record)894 bool CanBePublished(const MdnsRecord& record) {
895 return CanBePublished(record.dns_type());
896 }
897
CanBeQueried(DnsType type)898 bool CanBeQueried(DnsType type) {
899 // NOTE: A 'default' switch statement has intentionally been avoided below to
900 // enforce that new DnsTypes added must be added below through a compile-time
901 // check.
902 switch (type) {
903 case DnsType::kA:
904 case DnsType::kAAAA:
905 case DnsType::kPTR:
906 case DnsType::kTXT:
907 case DnsType::kSRV:
908 case DnsType::kANY:
909 return true;
910 case DnsType::kOPT:
911 case DnsType::kNSEC:
912 break;
913 }
914
915 return false;
916 }
917
CanBeProcessed(DnsType type)918 bool CanBeProcessed(DnsType type) {
919 // NOTE: A 'default' switch statement has intentionally been avoided below to
920 // enforce that new DnsTypes added must be added below through a compile-time
921 // check.
922 switch (type) {
923 case DnsType::kA:
924 case DnsType::kAAAA:
925 case DnsType::kPTR:
926 case DnsType::kTXT:
927 case DnsType::kSRV:
928 case DnsType::kNSEC:
929 return true;
930 case DnsType::kOPT:
931 case DnsType::kANY:
932 break;
933 }
934
935 return false;
936 }
937
938 } // namespace discovery
939 } // namespace openscreen
940