• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright 2018 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "p2p/base/mdns_message.h"
12 
13 #include "rtc_base/logging.h"
14 #include "rtc_base/net_helpers.h"
15 #include "rtc_base/string_encode.h"
16 
17 namespace webrtc {
18 
19 namespace {
// Bit masks for the 16-bit flags word of the message header
// (RFC 1035, Section 4.1.1).
//
// QR bit (bit 15): clear in queries, set in responses.
constexpr uint16_t kMdnsFlagMaskQueryOrResponse = 1 << 15;
// AA bit (bit 10): authoritative answer.
constexpr uint16_t kMdnsFlagMaskAuthoritative = 1 << 10;
// RFC 1035, Section 4.1.2 (QCLASS) together with RFC 6762, Section 18.12:
// mDNS repurposes the top bit of QCLASS as the unicast-response bit.
constexpr uint16_t kMdnsQClassMaskUnicastResponse = 1 << 15;
// The header is made of six 16-bit fields (ID, flags and four counts).
constexpr size_t kMdnsHeaderSizeBytes = 6 * sizeof(uint16_t);
30 
ReadDomainName(MessageBufferReader * buf,std::string * name)31 bool ReadDomainName(MessageBufferReader* buf, std::string* name) {
32   size_t name_start_pos = buf->CurrentOffset();
33   uint8_t label_length;
34   if (!buf->ReadUInt8(&label_length)) {
35     return false;
36   }
37   // RFC 1035, Section 4.1.4.
38   //
39   // If the first two bits of the length octet are ones, the name is compressed
40   // and the rest six bits with the next octet denotes its position in the
41   // message by the offset from the start of the message.
42   auto is_pointer = [](uint8_t octet) {
43     return (octet & 0x80) && (octet & 0x40);
44   };
45   while (label_length && !is_pointer(label_length)) {
46     // RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less.
47     if (label_length > 63) {
48       return false;
49     }
50     std::string label;
51     if (!buf->ReadString(&label, label_length)) {
52       return false;
53     }
54     (*name) += label + ".";
55     if (!buf->ReadUInt8(&label_length)) {
56       return false;
57     }
58   }
59   if (is_pointer(label_length)) {
60     uint8_t next_octet;
61     if (!buf->ReadUInt8(&next_octet)) {
62       return false;
63     }
64     size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet;
65     // A legitimate pointer only refers to a prior occurrence of the same name,
66     // and we should only move strictly backward to a prior name field after the
67     // header.
68     if (pos_jump_to >= name_start_pos || pos_jump_to < kMdnsHeaderSizeBytes) {
69       return false;
70     }
71     MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength());
72     if (!new_buf.Consume(pos_jump_to)) {
73       return false;
74     }
75     return ReadDomainName(&new_buf, name);
76   }
77   return true;
78 }
79 
WriteDomainName(rtc::ByteBufferWriter * buf,const std::string & name)80 void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) {
81   std::vector<std::string> labels;
82   rtc::tokenize(name, '.', &labels);
83   for (const auto& label : labels) {
84     buf->WriteUInt8(label.length());
85     buf->WriteString(label);
86   }
87   buf->WriteUInt8(0);
88 }
89 
90 }  // namespace
91 
SetQueryOrResponse(bool is_query)92 void MdnsHeader::SetQueryOrResponse(bool is_query) {
93   if (is_query) {
94     flags &= ~kMdnsFlagMaskQueryOrResponse;
95   } else {
96     flags |= kMdnsFlagMaskQueryOrResponse;
97   }
98 }
99 
SetAuthoritative(bool is_authoritative)100 void MdnsHeader::SetAuthoritative(bool is_authoritative) {
101   if (is_authoritative) {
102     flags |= kMdnsFlagMaskAuthoritative;
103   } else {
104     flags &= ~kMdnsFlagMaskAuthoritative;
105   }
106 }
107 
IsAuthoritative() const108 bool MdnsHeader::IsAuthoritative() const {
109   return flags & kMdnsFlagMaskAuthoritative;
110 }
111 
Read(MessageBufferReader * buf)112 bool MdnsHeader::Read(MessageBufferReader* buf) {
113   if (!buf->ReadUInt16(&id) || !buf->ReadUInt16(&flags) ||
114       !buf->ReadUInt16(&qdcount) || !buf->ReadUInt16(&ancount) ||
115       !buf->ReadUInt16(&nscount) || !buf->ReadUInt16(&arcount)) {
116     RTC_LOG(LS_ERROR) << "Invalid mDNS header.";
117     return false;
118   }
119   return true;
120 }
121 
Write(rtc::ByteBufferWriter * buf) const122 void MdnsHeader::Write(rtc::ByteBufferWriter* buf) const {
123   buf->WriteUInt16(id);
124   buf->WriteUInt16(flags);
125   buf->WriteUInt16(qdcount);
126   buf->WriteUInt16(ancount);
127   buf->WriteUInt16(nscount);
128   buf->WriteUInt16(arcount);
129 }
130 
IsQuery() const131 bool MdnsHeader::IsQuery() const {
132   return !(flags & kMdnsFlagMaskQueryOrResponse);
133 }
134 
135 MdnsSectionEntry::MdnsSectionEntry() = default;
136 MdnsSectionEntry::~MdnsSectionEntry() = default;
137 MdnsSectionEntry::MdnsSectionEntry(const MdnsSectionEntry& other) = default;
138 
SetType(SectionEntryType type)139 void MdnsSectionEntry::SetType(SectionEntryType type) {
140   switch (type) {
141     case SectionEntryType::kA:
142       type_ = 1;
143       return;
144     case SectionEntryType::kAAAA:
145       type_ = 28;
146       return;
147     default:
148       RTC_NOTREACHED();
149   }
150 }
151 
GetType() const152 SectionEntryType MdnsSectionEntry::GetType() const {
153   switch (type_) {
154     case 1:
155       return SectionEntryType::kA;
156     case 28:
157       return SectionEntryType::kAAAA;
158     default:
159       return SectionEntryType::kUnsupported;
160   }
161 }
162 
SetClass(SectionEntryClass cls)163 void MdnsSectionEntry::SetClass(SectionEntryClass cls) {
164   switch (cls) {
165     case SectionEntryClass::kIN:
166       class_ = 1;
167       return;
168     default:
169       RTC_NOTREACHED();
170   }
171 }
172 
GetClass() const173 SectionEntryClass MdnsSectionEntry::GetClass() const {
174   switch (class_) {
175     case 1:
176       return SectionEntryClass::kIN;
177     default:
178       return SectionEntryClass::kUnsupported;
179   }
180 }
181 
182 MdnsQuestion::MdnsQuestion() = default;
183 MdnsQuestion::MdnsQuestion(const MdnsQuestion& other) = default;
184 MdnsQuestion::~MdnsQuestion() = default;
185 
Read(MessageBufferReader * buf)186 bool MdnsQuestion::Read(MessageBufferReader* buf) {
187   if (!ReadDomainName(buf, &name_)) {
188     RTC_LOG(LS_ERROR) << "Invalid name.";
189     return false;
190   }
191   if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_)) {
192     RTC_LOG(LS_ERROR) << "Invalid type and class.";
193     return false;
194   }
195   return true;
196 }
197 
Write(rtc::ByteBufferWriter * buf) const198 bool MdnsQuestion::Write(rtc::ByteBufferWriter* buf) const {
199   WriteDomainName(buf, name_);
200   buf->WriteUInt16(type_);
201   buf->WriteUInt16(class_);
202   return true;
203 }
204 
SetUnicastResponse(bool should_unicast)205 void MdnsQuestion::SetUnicastResponse(bool should_unicast) {
206   if (should_unicast) {
207     class_ |= kMdnsQClassMaskUnicastResponse;
208   } else {
209     class_ &= ~kMdnsQClassMaskUnicastResponse;
210   }
211 }
212 
ShouldUnicastResponse() const213 bool MdnsQuestion::ShouldUnicastResponse() const {
214   return class_ & kMdnsQClassMaskUnicastResponse;
215 }
216 
217 MdnsResourceRecord::MdnsResourceRecord() = default;
218 MdnsResourceRecord::MdnsResourceRecord(const MdnsResourceRecord& other) =
219     default;
220 MdnsResourceRecord::~MdnsResourceRecord() = default;
221 
Read(MessageBufferReader * buf)222 bool MdnsResourceRecord::Read(MessageBufferReader* buf) {
223   if (!ReadDomainName(buf, &name_)) {
224     return false;
225   }
226   if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) ||
227       !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) {
228     return false;
229   }
230 
231   switch (GetType()) {
232     case SectionEntryType::kA:
233       return ReadARData(buf);
234     case SectionEntryType::kAAAA:
235       return ReadQuadARData(buf);
236     case SectionEntryType::kUnsupported:
237       return false;
238     default:
239       RTC_NOTREACHED();
240   }
241   return false;
242 }
ReadARData(MessageBufferReader * buf)243 bool MdnsResourceRecord::ReadARData(MessageBufferReader* buf) {
244   // A RDATA contains a 32-bit IPv4 address.
245   return buf->ReadString(&rdata_, 4);
246 }
247 
ReadQuadARData(MessageBufferReader * buf)248 bool MdnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) {
249   // AAAA RDATA contains a 128-bit IPv6 address.
250   return buf->ReadString(&rdata_, 16);
251 }
252 
Write(rtc::ByteBufferWriter * buf) const253 bool MdnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const {
254   WriteDomainName(buf, name_);
255   buf->WriteUInt16(type_);
256   buf->WriteUInt16(class_);
257   buf->WriteUInt32(ttl_seconds_);
258   buf->WriteUInt16(rdlength_);
259   switch (GetType()) {
260     case SectionEntryType::kA:
261       WriteARData(buf);
262       return true;
263     case SectionEntryType::kAAAA:
264       WriteQuadARData(buf);
265       return true;
266     case SectionEntryType::kUnsupported:
267       return false;
268     default:
269       RTC_NOTREACHED();
270   }
271   return true;
272 }
273 
WriteARData(rtc::ByteBufferWriter * buf) const274 void MdnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const {
275   buf->WriteString(rdata_);
276 }
277 
WriteQuadARData(rtc::ByteBufferWriter * buf) const278 void MdnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const {
279   buf->WriteString(rdata_);
280 }
281 
SetIPAddressInRecordData(const rtc::IPAddress & address)282 bool MdnsResourceRecord::SetIPAddressInRecordData(
283     const rtc::IPAddress& address) {
284   int af = address.family();
285   if (af != AF_INET && af != AF_INET6) {
286     return false;
287   }
288   char out[16] = {0};
289   if (!rtc::inet_pton(af, address.ToString().c_str(), out)) {
290     return false;
291   }
292   rdlength_ = (af == AF_INET) ? 4 : 16;
293   rdata_ = std::string(out, rdlength_);
294   return true;
295 }
296 
GetIPAddressFromRecordData(rtc::IPAddress * address) const297 bool MdnsResourceRecord::GetIPAddressFromRecordData(
298     rtc::IPAddress* address) const {
299   if (GetType() != SectionEntryType::kA &&
300       GetType() != SectionEntryType::kAAAA) {
301     return false;
302   }
303   if (rdata_.size() != 4 && rdata_.size() != 16) {
304     return false;
305   }
306   char out[INET6_ADDRSTRLEN] = {0};
307   int af = (GetType() == SectionEntryType::kA) ? AF_INET : AF_INET6;
308   if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) {
309     return false;
310   }
311   return rtc::IPFromString(std::string(out), address);
312 }
313 
314 MdnsMessage::MdnsMessage() = default;
315 MdnsMessage::~MdnsMessage() = default;
316 
Read(MessageBufferReader * buf)317 bool MdnsMessage::Read(MessageBufferReader* buf) {
318   RTC_DCHECK_EQ(0u, buf->CurrentOffset());
319   if (!header_.Read(buf)) {
320     return false;
321   }
322 
323   auto read_question = [&buf](std::vector<MdnsQuestion>* section,
324                               uint16_t count) {
325     section->resize(count);
326     for (auto& question : (*section)) {
327       if (!question.Read(buf)) {
328         return false;
329       }
330     }
331     return true;
332   };
333   auto read_rr = [&buf](std::vector<MdnsResourceRecord>* section,
334                         uint16_t count) {
335     section->resize(count);
336     for (auto& rr : (*section)) {
337       if (!rr.Read(buf)) {
338         return false;
339       }
340     }
341     return true;
342   };
343 
344   if (!read_question(&question_section_, header_.qdcount) ||
345       !read_rr(&answer_section_, header_.ancount) ||
346       !read_rr(&authority_section_, header_.nscount) ||
347       !read_rr(&additional_section_, header_.arcount)) {
348     return false;
349   }
350   return true;
351 }
352 
Write(rtc::ByteBufferWriter * buf) const353 bool MdnsMessage::Write(rtc::ByteBufferWriter* buf) const {
354   header_.Write(buf);
355 
356   auto write_rr = [&buf](const std::vector<MdnsResourceRecord>& section) {
357     for (const auto& rr : section) {
358       if (!rr.Write(buf)) {
359         return false;
360       }
361     }
362     return true;
363   };
364 
365   for (const auto& question : question_section_) {
366     if (!question.Write(buf)) {
367       return false;
368     }
369   }
370   if (!write_rr(answer_section_) || !write_rr(authority_section_) ||
371       !write_rr(additional_section_)) {
372     return false;
373   }
374 
375   return true;
376 }
377 
ShouldUnicastResponse() const378 bool MdnsMessage::ShouldUnicastResponse() const {
379   bool should_unicast = false;
380   for (const auto& question : question_section_) {
381     should_unicast |= question.ShouldUnicastResponse();
382   }
383   return should_unicast;
384 }
385 
AddQuestion(const MdnsQuestion & question)386 void MdnsMessage::AddQuestion(const MdnsQuestion& question) {
387   question_section_.push_back(question);
388   header_.qdcount = question_section_.size();
389 }
390 
AddAnswerRecord(const MdnsResourceRecord & answer)391 void MdnsMessage::AddAnswerRecord(const MdnsResourceRecord& answer) {
392   answer_section_.push_back(answer);
393   header_.ancount = answer_section_.size();
394 }
395 
396 }  // namespace webrtc
397