// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_writer.h"
6 
7 #include <limits>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 #include "absl/hash/hash.h"
13 #include "absl/strings/ascii.h"
14 #include "util/hashing.h"
15 #include "util/osp_logging.h"
16 
17 namespace openscreen {
18 namespace discovery {
19 
20 namespace {
21 
ComputeDomainNameSubhashes(const DomainName & name)22 std::vector<uint64_t> ComputeDomainNameSubhashes(const DomainName& name) {
23   const std::vector<std::string>& labels = name.labels();
24   // Use a large prime between 2^63 and 2^64 as a starting value.
25   // This is taken from absl::Hash implementation.
26   uint64_t hash_value = UINT64_C(0xc3a5c85c97cb3127);
27   std::vector<uint64_t> subhashes(labels.size());
28   for (size_t i = labels.size(); i-- > 0;) {
29     hash_value =
30         ComputeAggregateHash(hash_value, absl::AsciiStrToLower(labels[i]));
31     subhashes[i] = hash_value;
32   }
33   return subhashes;
34 }
35 
36 // This helper method writes the number of bytes between |begin| and |end| minus
37 // the size of the uint16_t into the uint16_t length field at |begin|. The
38 // method returns true if the number of bytes between |begin| and |end| fits in
39 // uint16_t type, returns false otherwise.
UpdateRecordLength(const uint8_t * end,uint8_t * begin)40 bool UpdateRecordLength(const uint8_t* end, uint8_t* begin) {
41   OSP_DCHECK_LE(begin + sizeof(uint16_t), end);
42   ptrdiff_t record_length = end - begin - sizeof(uint16_t);
43   if (record_length <= std::numeric_limits<uint16_t>::max()) {
44     WriteBigEndian<uint16_t>(record_length, begin);
45     return true;
46   }
47   return false;
48 }
49 
50 }  // namespace
51 
Write(absl::string_view value)52 bool MdnsWriter::Write(absl::string_view value) {
53   if (value.length() > std::numeric_limits<uint8_t>::max()) {
54     return false;
55   }
56   Cursor cursor(this);
57   if (Write(static_cast<uint8_t>(value.length())) &&
58       Write(value.data(), value.length())) {
59     cursor.Commit();
60     return true;
61   }
62   return false;
63 }
64 
Write(const std::string & value)65 bool MdnsWriter::Write(const std::string& value) {
66   return Write(absl::string_view(value));
67 }
68 
69 // RFC 1035: https://www.ietf.org/rfc/rfc1035.txt
70 // See section 4.1.4. Message compression
Write(const DomainName & name)71 bool MdnsWriter::Write(const DomainName& name) {
72   if (name.empty()) {
73     return false;
74   }
75 
76   Cursor cursor(this);
77   const std::vector<uint64_t> subhashes = ComputeDomainNameSubhashes(name);
78   // Tentative dictionary contains label pointer entries to be added to the
79   // compression dictionary after successfully writing the domain name.
80   std::unordered_map<uint64_t, uint16_t> tentative_dictionary;
81   const std::vector<std::string>& labels = name.labels();
82   for (size_t i = 0; i < labels.size(); ++i) {
83     OSP_DCHECK(IsValidDomainLabel(labels[i]));
84     // We only need to do a look up in the compression dictionary and not in the
85     // tentative dictionary. The tentative dictionary cannot possibly contain a
86     // valid label pointer as all the entries previously added to it are for
87     // names that are longer than the currently processed sub-name.
88     auto find_result = dictionary_.find(subhashes[i]);
89     if (find_result != dictionary_.end()) {
90       if (!Write(find_result->second)) {
91         return false;
92       }
93       dictionary_.insert(tentative_dictionary.begin(),
94                          tentative_dictionary.end());
95       cursor.Commit();
96       return true;
97     }
98     // Only add a pointer_label for compression if the offset into the buffer
99     // fits into the bits available to store it.
100     if (IsValidPointerLabelOffset(current() - begin())) {
101       tentative_dictionary.insert(
102           std::make_pair(subhashes[i], MakePointerLabel(current() - begin())));
103     }
104     if (!Write(MakeDirectLabel(labels[i].size())) ||
105         !Write(labels[i].data(), labels[i].size())) {
106       return false;
107     }
108   }
109   if (!Write(kLabelTermination)) {
110     return false;
111   }
112   // The probability of a collision is extremely low in this application, as the
113   // number of domain names compressed is insignificant in comparison to the
114   // hash function image.
115   dictionary_.insert(tentative_dictionary.begin(), tentative_dictionary.end());
116   cursor.Commit();
117   return true;
118 }
119 
Write(const RawRecordRdata & rdata)120 bool MdnsWriter::Write(const RawRecordRdata& rdata) {
121   Cursor cursor(this);
122   if (Write(rdata.size()) && Write(rdata.data(), rdata.size())) {
123     cursor.Commit();
124     return true;
125   }
126   return false;
127 }
128 
Write(const SrvRecordRdata & rdata)129 bool MdnsWriter::Write(const SrvRecordRdata& rdata) {
130   Cursor cursor(this);
131   // Leave space at the beginning at |rollback_position| to write the record
132   // length. Cannot write it upfront, since the exact space taken by the target
133   // domain name is not known as it might be compressed.
134   if (Skip(sizeof(uint16_t)) && Write(rdata.priority()) &&
135       Write(rdata.weight()) && Write(rdata.port()) && Write(rdata.target()) &&
136       UpdateRecordLength(current(), cursor.origin())) {
137     cursor.Commit();
138     return true;
139   }
140   return false;
141 }
142 
Write(const ARecordRdata & rdata)143 bool MdnsWriter::Write(const ARecordRdata& rdata) {
144   Cursor cursor(this);
145   if (Write(static_cast<uint16_t>(IPAddress::kV4Size)) &&
146       Write(rdata.ipv4_address())) {
147     cursor.Commit();
148     return true;
149   }
150   return false;
151 }
152 
Write(const AAAARecordRdata & rdata)153 bool MdnsWriter::Write(const AAAARecordRdata& rdata) {
154   Cursor cursor(this);
155   if (Write(static_cast<uint16_t>(IPAddress::kV6Size)) &&
156       Write(rdata.ipv6_address())) {
157     cursor.Commit();
158     return true;
159   }
160   return false;
161 }
162 
Write(const PtrRecordRdata & rdata)163 bool MdnsWriter::Write(const PtrRecordRdata& rdata) {
164   Cursor cursor(this);
165   // Leave space at the beginning at |rollback_position| to write the record
166   // length. Cannot write it upfront, since the exact space taken by the target
167   // domain name is not known as it might be compressed.
168   if (Skip(sizeof(uint16_t)) && Write(rdata.ptr_domain()) &&
169       UpdateRecordLength(current(), cursor.origin())) {
170     cursor.Commit();
171     return true;
172   }
173   return false;
174 }
175 
Write(const TxtRecordRdata & rdata)176 bool MdnsWriter::Write(const TxtRecordRdata& rdata) {
177   Cursor cursor(this);
178   // Leave space at the beginning at |rollback_position| to write the record
179   // length. It's cheaper to update it at the end than precompute the length.
180   if (!Skip(sizeof(uint16_t))) {
181     return false;
182   }
183   if (rdata.texts().size() > 0) {
184     if (!Write(rdata.texts())) {
185       return false;
186     }
187   } else {
188     if (!Write(kTXTEmptyRdata)) {
189       return false;
190     }
191   }
192   if (!UpdateRecordLength(current(), cursor.origin())) {
193     return false;
194   }
195   cursor.Commit();
196   return true;
197 }
198 
Write(const NsecRecordRdata & rdata)199 bool MdnsWriter::Write(const NsecRecordRdata& rdata) {
200   Cursor cursor(this);
201   if (Skip(sizeof(uint16_t)) && Write(rdata.next_domain_name()) &&
202       Write(rdata.encoded_types()) &&
203       UpdateRecordLength(current(), cursor.origin())) {
204     cursor.Commit();
205     return true;
206   }
207   return false;
208 }
209 
Write(const OptRecordRdata & rdata)210 bool MdnsWriter::Write(const OptRecordRdata& rdata) {
211   // OPT records are currently not supported for outgoing messages.
212   OSP_UNIMPLEMENTED();
213   return false;
214 }
215 
Write(const MdnsRecord & record)216 bool MdnsWriter::Write(const MdnsRecord& record) {
217   Cursor cursor(this);
218   if (Write(record.name()) && Write(static_cast<uint16_t>(record.dns_type())) &&
219       Write(MakeRecordClass(record.dns_class(), record.record_type())) &&
220       Write(static_cast<uint32_t>(record.ttl().count())) &&
221       Write(record.rdata())) {
222     cursor.Commit();
223     return true;
224   }
225   return false;
226 }
227 
Write(const MdnsQuestion & question)228 bool MdnsWriter::Write(const MdnsQuestion& question) {
229   Cursor cursor(this);
230   if (Write(question.name()) &&
231       Write(static_cast<uint16_t>(question.dns_type())) &&
232       Write(
233           MakeQuestionClass(question.dns_class(), question.response_type()))) {
234     cursor.Commit();
235     return true;
236   }
237   return false;
238 }
239 
Write(const MdnsMessage & message)240 bool MdnsWriter::Write(const MdnsMessage& message) {
241   Cursor cursor(this);
242   Header header;
243   header.id = message.id();
244   header.flags = MakeFlags(message.type(), message.is_truncated());
245   header.question_count = message.questions().size();
246   header.answer_count = message.answers().size();
247   header.authority_record_count = message.authority_records().size();
248   header.additional_record_count = message.additional_records().size();
249   if (Write(header) && Write(message.questions()) && Write(message.answers()) &&
250       Write(message.authority_records()) &&
251       Write(message.additional_records())) {
252     cursor.Commit();
253     return true;
254   }
255   return false;
256 }
257 
Write(const IPAddress & address)258 bool MdnsWriter::Write(const IPAddress& address) {
259   uint8_t bytes[IPAddress::kV6Size];
260   size_t size;
261   if (address.IsV6()) {
262     address.CopyToV6(bytes);
263     size = IPAddress::kV6Size;
264   } else {
265     address.CopyToV4(bytes);
266     size = IPAddress::kV4Size;
267   }
268   return Write(bytes, size);
269 }
270 
Write(const Rdata & rdata)271 bool MdnsWriter::Write(const Rdata& rdata) {
272   return absl::visit([this](const auto& rdata) { return this->Write(rdata); },
273                      rdata);
274 }
275 
Write(const Header & header)276 bool MdnsWriter::Write(const Header& header) {
277   Cursor cursor(this);
278   if (Write(header.id) && Write(header.flags) && Write(header.question_count) &&
279       Write(header.answer_count) && Write(header.authority_record_count) &&
280       Write(header.additional_record_count)) {
281     cursor.Commit();
282     return true;
283   }
284   return false;
285 }
286 
287 }  // namespace discovery
288 }  // namespace openscreen
289