1 /*
2 * Copyright 2018 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "p2p/base/mdns_message.h"
12
13 #include "rtc_base/logging.h"
14 #include "rtc_base/net_helpers.h"
15 #include "rtc_base/string_encode.h"
16
17 namespace webrtc {
18
19 namespace {
// DNS header FLAGS field bits (RFC 1035, Section 4.1.1).
//
// QR: the most significant flag bit; 0 marks a query, 1 marks a response.
constexpr uint16_t kMdnsFlagMaskQueryOrResponse = 1u << 15;
// AA: "authoritative answer".
constexpr uint16_t kMdnsFlagMaskAuthoritative = 1u << 10;
// QCLASS (RFC 1035, Section 4.1.2). RFC 6762, Section 18.12 repurposes the
// top bit of QCLASS as the unicast-response bit.
constexpr uint16_t kMdnsQClassMaskUnicastResponse = 1u << 15;
// The fixed-size header is six 16-bit fields: ID, FLAGS, QDCOUNT, ANCOUNT,
// NSCOUNT and ARCOUNT.
constexpr size_t kMdnsHeaderSizeBytes = 6 * sizeof(uint16_t);
30
ReadDomainName(MessageBufferReader * buf,std::string * name)31 bool ReadDomainName(MessageBufferReader* buf, std::string* name) {
32 size_t name_start_pos = buf->CurrentOffset();
33 uint8_t label_length;
34 if (!buf->ReadUInt8(&label_length)) {
35 return false;
36 }
37 // RFC 1035, Section 4.1.4.
38 //
39 // If the first two bits of the length octet are ones, the name is compressed
40 // and the rest six bits with the next octet denotes its position in the
41 // message by the offset from the start of the message.
42 auto is_pointer = [](uint8_t octet) {
43 return (octet & 0x80) && (octet & 0x40);
44 };
45 while (label_length && !is_pointer(label_length)) {
46 // RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less.
47 if (label_length > 63) {
48 return false;
49 }
50 std::string label;
51 if (!buf->ReadString(&label, label_length)) {
52 return false;
53 }
54 (*name) += label + ".";
55 if (!buf->ReadUInt8(&label_length)) {
56 return false;
57 }
58 }
59 if (is_pointer(label_length)) {
60 uint8_t next_octet;
61 if (!buf->ReadUInt8(&next_octet)) {
62 return false;
63 }
64 size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet;
65 // A legitimate pointer only refers to a prior occurrence of the same name,
66 // and we should only move strictly backward to a prior name field after the
67 // header.
68 if (pos_jump_to >= name_start_pos || pos_jump_to < kMdnsHeaderSizeBytes) {
69 return false;
70 }
71 MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength());
72 if (!new_buf.Consume(pos_jump_to)) {
73 return false;
74 }
75 return ReadDomainName(&new_buf, name);
76 }
77 return true;
78 }
79
WriteDomainName(rtc::ByteBufferWriter * buf,const std::string & name)80 void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) {
81 std::vector<std::string> labels;
82 rtc::tokenize(name, '.', &labels);
83 for (const auto& label : labels) {
84 buf->WriteUInt8(label.length());
85 buf->WriteString(label);
86 }
87 buf->WriteUInt8(0);
88 }
89
90 } // namespace
91
SetQueryOrResponse(bool is_query)92 void MdnsHeader::SetQueryOrResponse(bool is_query) {
93 if (is_query) {
94 flags &= ~kMdnsFlagMaskQueryOrResponse;
95 } else {
96 flags |= kMdnsFlagMaskQueryOrResponse;
97 }
98 }
99
SetAuthoritative(bool is_authoritative)100 void MdnsHeader::SetAuthoritative(bool is_authoritative) {
101 if (is_authoritative) {
102 flags |= kMdnsFlagMaskAuthoritative;
103 } else {
104 flags &= ~kMdnsFlagMaskAuthoritative;
105 }
106 }
107
IsAuthoritative() const108 bool MdnsHeader::IsAuthoritative() const {
109 return flags & kMdnsFlagMaskAuthoritative;
110 }
111
Read(MessageBufferReader * buf)112 bool MdnsHeader::Read(MessageBufferReader* buf) {
113 if (!buf->ReadUInt16(&id) || !buf->ReadUInt16(&flags) ||
114 !buf->ReadUInt16(&qdcount) || !buf->ReadUInt16(&ancount) ||
115 !buf->ReadUInt16(&nscount) || !buf->ReadUInt16(&arcount)) {
116 RTC_LOG(LS_ERROR) << "Invalid mDNS header.";
117 return false;
118 }
119 return true;
120 }
121
Write(rtc::ByteBufferWriter * buf) const122 void MdnsHeader::Write(rtc::ByteBufferWriter* buf) const {
123 buf->WriteUInt16(id);
124 buf->WriteUInt16(flags);
125 buf->WriteUInt16(qdcount);
126 buf->WriteUInt16(ancount);
127 buf->WriteUInt16(nscount);
128 buf->WriteUInt16(arcount);
129 }
130
IsQuery() const131 bool MdnsHeader::IsQuery() const {
132 return !(flags & kMdnsFlagMaskQueryOrResponse);
133 }
134
// Special members of MdnsSectionEntry, defaulted and defined out of line here:
// member-wise default construction, destruction and copy.
135 MdnsSectionEntry::MdnsSectionEntry() = default;
136 MdnsSectionEntry::~MdnsSectionEntry() = default;
137 MdnsSectionEntry::MdnsSectionEntry(const MdnsSectionEntry& other) = default;
138
SetType(SectionEntryType type)139 void MdnsSectionEntry::SetType(SectionEntryType type) {
140 switch (type) {
141 case SectionEntryType::kA:
142 type_ = 1;
143 return;
144 case SectionEntryType::kAAAA:
145 type_ = 28;
146 return;
147 default:
148 RTC_NOTREACHED();
149 }
150 }
151
GetType() const152 SectionEntryType MdnsSectionEntry::GetType() const {
153 switch (type_) {
154 case 1:
155 return SectionEntryType::kA;
156 case 28:
157 return SectionEntryType::kAAAA;
158 default:
159 return SectionEntryType::kUnsupported;
160 }
161 }
162
SetClass(SectionEntryClass cls)163 void MdnsSectionEntry::SetClass(SectionEntryClass cls) {
164 switch (cls) {
165 case SectionEntryClass::kIN:
166 class_ = 1;
167 return;
168 default:
169 RTC_NOTREACHED();
170 }
171 }
172
GetClass() const173 SectionEntryClass MdnsSectionEntry::GetClass() const {
174 switch (class_) {
175 case 1:
176 return SectionEntryClass::kIN;
177 default:
178 return SectionEntryClass::kUnsupported;
179 }
180 }
181
// Special members of MdnsQuestion, defaulted and defined out of line here:
// member-wise default construction, copy and destruction.
182 MdnsQuestion::MdnsQuestion() = default;
183 MdnsQuestion::MdnsQuestion(const MdnsQuestion& other) = default;
184 MdnsQuestion::~MdnsQuestion() = default;
185
Read(MessageBufferReader * buf)186 bool MdnsQuestion::Read(MessageBufferReader* buf) {
187 if (!ReadDomainName(buf, &name_)) {
188 RTC_LOG(LS_ERROR) << "Invalid name.";
189 return false;
190 }
191 if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_)) {
192 RTC_LOG(LS_ERROR) << "Invalid type and class.";
193 return false;
194 }
195 return true;
196 }
197
Write(rtc::ByteBufferWriter * buf) const198 bool MdnsQuestion::Write(rtc::ByteBufferWriter* buf) const {
199 WriteDomainName(buf, name_);
200 buf->WriteUInt16(type_);
201 buf->WriteUInt16(class_);
202 return true;
203 }
204
SetUnicastResponse(bool should_unicast)205 void MdnsQuestion::SetUnicastResponse(bool should_unicast) {
206 if (should_unicast) {
207 class_ |= kMdnsQClassMaskUnicastResponse;
208 } else {
209 class_ &= ~kMdnsQClassMaskUnicastResponse;
210 }
211 }
212
ShouldUnicastResponse() const213 bool MdnsQuestion::ShouldUnicastResponse() const {
214 return class_ & kMdnsQClassMaskUnicastResponse;
215 }
216
// Special members of MdnsResourceRecord, defaulted and defined out of line
// here: member-wise default construction, copy and destruction.
217 MdnsResourceRecord::MdnsResourceRecord() = default;
218 MdnsResourceRecord::MdnsResourceRecord(const MdnsResourceRecord& other) =
219 default;
220 MdnsResourceRecord::~MdnsResourceRecord() = default;
221
Read(MessageBufferReader * buf)222 bool MdnsResourceRecord::Read(MessageBufferReader* buf) {
223 if (!ReadDomainName(buf, &name_)) {
224 return false;
225 }
226 if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) ||
227 !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) {
228 return false;
229 }
230
231 switch (GetType()) {
232 case SectionEntryType::kA:
233 return ReadARData(buf);
234 case SectionEntryType::kAAAA:
235 return ReadQuadARData(buf);
236 case SectionEntryType::kUnsupported:
237 return false;
238 default:
239 RTC_NOTREACHED();
240 }
241 return false;
242 }
ReadARData(MessageBufferReader * buf)243 bool MdnsResourceRecord::ReadARData(MessageBufferReader* buf) {
244 // A RDATA contains a 32-bit IPv4 address.
245 return buf->ReadString(&rdata_, 4);
246 }
247
ReadQuadARData(MessageBufferReader * buf)248 bool MdnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) {
249 // AAAA RDATA contains a 128-bit IPv6 address.
250 return buf->ReadString(&rdata_, 16);
251 }
252
Write(rtc::ByteBufferWriter * buf) const253 bool MdnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const {
254 WriteDomainName(buf, name_);
255 buf->WriteUInt16(type_);
256 buf->WriteUInt16(class_);
257 buf->WriteUInt32(ttl_seconds_);
258 buf->WriteUInt16(rdlength_);
259 switch (GetType()) {
260 case SectionEntryType::kA:
261 WriteARData(buf);
262 return true;
263 case SectionEntryType::kAAAA:
264 WriteQuadARData(buf);
265 return true;
266 case SectionEntryType::kUnsupported:
267 return false;
268 default:
269 RTC_NOTREACHED();
270 }
271 return true;
272 }
273
WriteARData(rtc::ByteBufferWriter * buf) const274 void MdnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const {
275 buf->WriteString(rdata_);
276 }
277
WriteQuadARData(rtc::ByteBufferWriter * buf) const278 void MdnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const {
279 buf->WriteString(rdata_);
280 }
281
SetIPAddressInRecordData(const rtc::IPAddress & address)282 bool MdnsResourceRecord::SetIPAddressInRecordData(
283 const rtc::IPAddress& address) {
284 int af = address.family();
285 if (af != AF_INET && af != AF_INET6) {
286 return false;
287 }
288 char out[16] = {0};
289 if (!rtc::inet_pton(af, address.ToString().c_str(), out)) {
290 return false;
291 }
292 rdlength_ = (af == AF_INET) ? 4 : 16;
293 rdata_ = std::string(out, rdlength_);
294 return true;
295 }
296
GetIPAddressFromRecordData(rtc::IPAddress * address) const297 bool MdnsResourceRecord::GetIPAddressFromRecordData(
298 rtc::IPAddress* address) const {
299 if (GetType() != SectionEntryType::kA &&
300 GetType() != SectionEntryType::kAAAA) {
301 return false;
302 }
303 if (rdata_.size() != 4 && rdata_.size() != 16) {
304 return false;
305 }
306 char out[INET6_ADDRSTRLEN] = {0};
307 int af = (GetType() == SectionEntryType::kA) ? AF_INET : AF_INET6;
308 if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) {
309 return false;
310 }
311 return rtc::IPFromString(std::string(out), address);
312 }
313
// Special members of MdnsMessage, defaulted and defined out of line here:
// member-wise default construction and destruction.
314 MdnsMessage::MdnsMessage() = default;
315 MdnsMessage::~MdnsMessage() = default;
316
Read(MessageBufferReader * buf)317 bool MdnsMessage::Read(MessageBufferReader* buf) {
318 RTC_DCHECK_EQ(0u, buf->CurrentOffset());
319 if (!header_.Read(buf)) {
320 return false;
321 }
322
323 auto read_question = [&buf](std::vector<MdnsQuestion>* section,
324 uint16_t count) {
325 section->resize(count);
326 for (auto& question : (*section)) {
327 if (!question.Read(buf)) {
328 return false;
329 }
330 }
331 return true;
332 };
333 auto read_rr = [&buf](std::vector<MdnsResourceRecord>* section,
334 uint16_t count) {
335 section->resize(count);
336 for (auto& rr : (*section)) {
337 if (!rr.Read(buf)) {
338 return false;
339 }
340 }
341 return true;
342 };
343
344 if (!read_question(&question_section_, header_.qdcount) ||
345 !read_rr(&answer_section_, header_.ancount) ||
346 !read_rr(&authority_section_, header_.nscount) ||
347 !read_rr(&additional_section_, header_.arcount)) {
348 return false;
349 }
350 return true;
351 }
352
Write(rtc::ByteBufferWriter * buf) const353 bool MdnsMessage::Write(rtc::ByteBufferWriter* buf) const {
354 header_.Write(buf);
355
356 auto write_rr = [&buf](const std::vector<MdnsResourceRecord>& section) {
357 for (const auto& rr : section) {
358 if (!rr.Write(buf)) {
359 return false;
360 }
361 }
362 return true;
363 };
364
365 for (const auto& question : question_section_) {
366 if (!question.Write(buf)) {
367 return false;
368 }
369 }
370 if (!write_rr(answer_section_) || !write_rr(authority_section_) ||
371 !write_rr(additional_section_)) {
372 return false;
373 }
374
375 return true;
376 }
377
ShouldUnicastResponse() const378 bool MdnsMessage::ShouldUnicastResponse() const {
379 bool should_unicast = false;
380 for (const auto& question : question_section_) {
381 should_unicast |= question.ShouldUnicastResponse();
382 }
383 return should_unicast;
384 }
385
AddQuestion(const MdnsQuestion & question)386 void MdnsMessage::AddQuestion(const MdnsQuestion& question) {
387 question_section_.push_back(question);
388 header_.qdcount = question_section_.size();
389 }
390
AddAnswerRecord(const MdnsResourceRecord & answer)391 void MdnsMessage::AddAnswerRecord(const MdnsResourceRecord& answer) {
392 answer_section_.push_back(answer);
393 header_.ancount = answer_section_.size();
394 }
395
396 } // namespace webrtc
397