1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/dns/record_rdata.h"
11
12 #include <algorithm>
13 #include <numeric>
14 #include <string_view>
15 #include <utility>
16
17 #include "base/containers/span.h"
18 #include "base/containers/span_reader.h"
19 #include "base/logging.h"
20 #include "base/memory/ptr_util.h"
21 #include "base/rand_util.h"
22 #include "net/base/ip_address.h"
23 #include "net/dns/dns_response.h"
24 #include "net/dns/public/dns_protocol.h"
25
26 namespace net {
27
// Minimal SRV rdata is 2 octets priority + 2 octets weight + 2 octets port;
// the target name follows these fixed-width fields.
static constexpr size_t kSrvRecordMinimumSize = 6;

// Minimal HTTPS rdata is 2 octets priority + 1 octet empty name.
static constexpr size_t kHttpsRdataMinimumSize = 3;
32
HasValidSize(std::string_view data,uint16_t type)33 bool RecordRdata::HasValidSize(std::string_view data, uint16_t type) {
34 switch (type) {
35 case dns_protocol::kTypeSRV:
36 return data.size() >= kSrvRecordMinimumSize;
37 case dns_protocol::kTypeA:
38 return data.size() == IPAddress::kIPv4AddressSize;
39 case dns_protocol::kTypeAAAA:
40 return data.size() == IPAddress::kIPv6AddressSize;
41 case dns_protocol::kTypeHttps:
42 return data.size() >= kHttpsRdataMinimumSize;
43 case dns_protocol::kTypeCNAME:
44 case dns_protocol::kTypePTR:
45 case dns_protocol::kTypeTXT:
46 case dns_protocol::kTypeNSEC:
47 case dns_protocol::kTypeOPT:
48 case dns_protocol::kTypeSOA:
49 return true;
50 default:
51 VLOG(1) << "Unrecognized RDATA type.";
52 return true;
53 }
54 }
55
56 SrvRecordRdata::SrvRecordRdata() = default;
57
58 SrvRecordRdata::~SrvRecordRdata() = default;
59
60 // static
Create(std::string_view data,const DnsRecordParser & parser)61 std::unique_ptr<SrvRecordRdata> SrvRecordRdata::Create(
62 std::string_view data,
63 const DnsRecordParser& parser) {
64 if (!HasValidSize(data, kType))
65 return nullptr;
66
67 auto rdata = base::WrapUnique(new SrvRecordRdata());
68
69 auto reader = base::SpanReader(base::as_byte_span(data));
70 // 2 bytes for priority, 2 bytes for weight, 2 bytes for port.
71 reader.ReadU16BigEndian(rdata->priority_);
72 reader.ReadU16BigEndian(rdata->weight_);
73 reader.ReadU16BigEndian(rdata->port_);
74
75 if (!parser.ReadName(data.substr(kSrvRecordMinimumSize).data(),
76 &rdata->target_)) {
77 return nullptr;
78 }
79
80 return rdata;
81 }
82
Type() const83 uint16_t SrvRecordRdata::Type() const {
84 return SrvRecordRdata::kType;
85 }
86
IsEqual(const RecordRdata * other) const87 bool SrvRecordRdata::IsEqual(const RecordRdata* other) const {
88 if (other->Type() != Type()) return false;
89 const SrvRecordRdata* srv_other = static_cast<const SrvRecordRdata*>(other);
90 return weight_ == srv_other->weight_ &&
91 port_ == srv_other->port_ &&
92 priority_ == srv_other->priority_ &&
93 target_ == srv_other->target_;
94 }
95
96 ARecordRdata::ARecordRdata() = default;
97
98 ARecordRdata::~ARecordRdata() = default;
99
100 // static
Create(std::string_view data,const DnsRecordParser & parser)101 std::unique_ptr<ARecordRdata> ARecordRdata::Create(
102 std::string_view data,
103 const DnsRecordParser& parser) {
104 if (!HasValidSize(data, kType))
105 return nullptr;
106
107 auto rdata = base::WrapUnique(new ARecordRdata());
108 rdata->address_ = IPAddress(base::as_byte_span(data));
109 return rdata;
110 }
111
Type() const112 uint16_t ARecordRdata::Type() const {
113 return ARecordRdata::kType;
114 }
115
IsEqual(const RecordRdata * other) const116 bool ARecordRdata::IsEqual(const RecordRdata* other) const {
117 if (other->Type() != Type()) return false;
118 const ARecordRdata* a_other = static_cast<const ARecordRdata*>(other);
119 return address_ == a_other->address_;
120 }
121
122 AAAARecordRdata::AAAARecordRdata() = default;
123
124 AAAARecordRdata::~AAAARecordRdata() = default;
125
126 // static
Create(std::string_view data,const DnsRecordParser & parser)127 std::unique_ptr<AAAARecordRdata> AAAARecordRdata::Create(
128 std::string_view data,
129 const DnsRecordParser& parser) {
130 if (!HasValidSize(data, kType))
131 return nullptr;
132
133 auto rdata = base::WrapUnique(new AAAARecordRdata());
134 rdata->address_ = IPAddress(base::as_byte_span(data));
135 return rdata;
136 }
137
Type() const138 uint16_t AAAARecordRdata::Type() const {
139 return AAAARecordRdata::kType;
140 }
141
IsEqual(const RecordRdata * other) const142 bool AAAARecordRdata::IsEqual(const RecordRdata* other) const {
143 if (other->Type() != Type()) return false;
144 const AAAARecordRdata* a_other = static_cast<const AAAARecordRdata*>(other);
145 return address_ == a_other->address_;
146 }
147
148 CnameRecordRdata::CnameRecordRdata() = default;
149
150 CnameRecordRdata::~CnameRecordRdata() = default;
151
152 // static
Create(std::string_view data,const DnsRecordParser & parser)153 std::unique_ptr<CnameRecordRdata> CnameRecordRdata::Create(
154 std::string_view data,
155 const DnsRecordParser& parser) {
156 auto rdata = base::WrapUnique(new CnameRecordRdata());
157
158 if (!parser.ReadName(data.data(), &rdata->cname_)) {
159 return nullptr;
160 }
161
162 return rdata;
163 }
164
Type() const165 uint16_t CnameRecordRdata::Type() const {
166 return CnameRecordRdata::kType;
167 }
168
IsEqual(const RecordRdata * other) const169 bool CnameRecordRdata::IsEqual(const RecordRdata* other) const {
170 if (other->Type() != Type()) return false;
171 const CnameRecordRdata* cname_other =
172 static_cast<const CnameRecordRdata*>(other);
173 return cname_ == cname_other->cname_;
174 }
175
176 PtrRecordRdata::PtrRecordRdata() = default;
177
178 PtrRecordRdata::~PtrRecordRdata() = default;
179
180 // static
Create(std::string_view data,const DnsRecordParser & parser)181 std::unique_ptr<PtrRecordRdata> PtrRecordRdata::Create(
182 std::string_view data,
183 const DnsRecordParser& parser) {
184 auto rdata = base::WrapUnique(new PtrRecordRdata());
185
186 if (!parser.ReadName(data.data(), &rdata->ptrdomain_)) {
187 return nullptr;
188 }
189
190 return rdata;
191 }
192
Type() const193 uint16_t PtrRecordRdata::Type() const {
194 return PtrRecordRdata::kType;
195 }
196
IsEqual(const RecordRdata * other) const197 bool PtrRecordRdata::IsEqual(const RecordRdata* other) const {
198 if (other->Type() != Type()) return false;
199 const PtrRecordRdata* ptr_other = static_cast<const PtrRecordRdata*>(other);
200 return ptrdomain_ == ptr_other->ptrdomain_;
201 }
202
203 TxtRecordRdata::TxtRecordRdata() = default;
204
205 TxtRecordRdata::~TxtRecordRdata() = default;
206
207 // static
Create(std::string_view data,const DnsRecordParser & parser)208 std::unique_ptr<TxtRecordRdata> TxtRecordRdata::Create(
209 std::string_view data,
210 const DnsRecordParser& parser) {
211 auto rdata = base::WrapUnique(new TxtRecordRdata());
212
213 for (size_t i = 0; i < data.size(); ) {
214 uint8_t length = data[i];
215
216 if (i + length >= data.size())
217 return nullptr;
218
219 rdata->texts_.push_back(std::string(data.substr(i + 1, length)));
220
221 // Move to the next string.
222 i += length + 1;
223 }
224
225 return rdata;
226 }
227
Type() const228 uint16_t TxtRecordRdata::Type() const {
229 return TxtRecordRdata::kType;
230 }
231
IsEqual(const RecordRdata * other) const232 bool TxtRecordRdata::IsEqual(const RecordRdata* other) const {
233 if (other->Type() != Type()) return false;
234 const TxtRecordRdata* txt_other = static_cast<const TxtRecordRdata*>(other);
235 return texts_ == txt_other->texts_;
236 }
237
238 NsecRecordRdata::NsecRecordRdata() = default;
239
240 NsecRecordRdata::~NsecRecordRdata() = default;
241
242 // static
Create(std::string_view data,const DnsRecordParser & parser)243 std::unique_ptr<NsecRecordRdata> NsecRecordRdata::Create(
244 std::string_view data,
245 const DnsRecordParser& parser) {
246 auto rdata = base::WrapUnique(new NsecRecordRdata());
247
248 // Read the "next domain". This part for the NSEC record format is
249 // ignored for mDNS, since it has no semantic meaning.
250 unsigned next_domain_length = parser.ReadName(data.data(), nullptr);
251
252 // If we did not succeed in getting the next domain or the data length
253 // is too short for reading the bitmap header, return.
254 if (next_domain_length == 0 || data.length() < next_domain_length + 2)
255 return nullptr;
256
257 struct BitmapHeader {
258 uint8_t block_number; // The block number should be zero.
259 uint8_t length; // Bitmap length in bytes. Between 1 and 32.
260 };
261
262 const BitmapHeader* header = reinterpret_cast<const BitmapHeader*>(
263 data.data() + next_domain_length);
264
265 // The block number must be zero in mDns-specific NSEC records. The bitmap
266 // length must be between 1 and 32.
267 if (header->block_number != 0 || header->length == 0 || header->length > 32)
268 return nullptr;
269
270 std::string_view bitmap_data = data.substr(next_domain_length + 2);
271
272 // Since we may only have one block, the data length must be exactly equal to
273 // the domain length plus bitmap size.
274 if (bitmap_data.length() != header->length)
275 return nullptr;
276
277 rdata->bitmap_.insert(rdata->bitmap_.begin(),
278 bitmap_data.begin(),
279 bitmap_data.end());
280
281 return rdata;
282 }
283
Type() const284 uint16_t NsecRecordRdata::Type() const {
285 return NsecRecordRdata::kType;
286 }
287
IsEqual(const RecordRdata * other) const288 bool NsecRecordRdata::IsEqual(const RecordRdata* other) const {
289 if (other->Type() != Type())
290 return false;
291 const NsecRecordRdata* nsec_other =
292 static_cast<const NsecRecordRdata*>(other);
293 return bitmap_ == nsec_other->bitmap_;
294 }
295
GetBit(unsigned i) const296 bool NsecRecordRdata::GetBit(unsigned i) const {
297 unsigned byte_num = i/8;
298 if (bitmap_.size() < byte_num + 1)
299 return false;
300
301 unsigned bit_num = 7 - i % 8;
302 return (bitmap_[byte_num] & (1 << bit_num)) != 0;
303 }
304
305 } // namespace net
306