1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/mdns/mdns_writer.h"
6
7 #include <limits>
8 #include <string>
9 #include <utility>
10 #include <vector>
11
12 #include "absl/hash/hash.h"
13 #include "absl/strings/ascii.h"
14 #include "util/hashing.h"
15 #include "util/osp_logging.h"
16
17 namespace openscreen {
18 namespace discovery {
19
20 namespace {
21
ComputeDomainNameSubhashes(const DomainName & name)22 std::vector<uint64_t> ComputeDomainNameSubhashes(const DomainName& name) {
23 const std::vector<std::string>& labels = name.labels();
24 // Use a large prime between 2^63 and 2^64 as a starting value.
25 // This is taken from absl::Hash implementation.
26 uint64_t hash_value = UINT64_C(0xc3a5c85c97cb3127);
27 std::vector<uint64_t> subhashes(labels.size());
28 for (size_t i = labels.size(); i-- > 0;) {
29 hash_value =
30 ComputeAggregateHash(hash_value, absl::AsciiStrToLower(labels[i]));
31 subhashes[i] = hash_value;
32 }
33 return subhashes;
34 }
35
36 // This helper method writes the number of bytes between |begin| and |end| minus
37 // the size of the uint16_t into the uint16_t length field at |begin|. The
38 // method returns true if the number of bytes between |begin| and |end| fits in
39 // uint16_t type, returns false otherwise.
UpdateRecordLength(const uint8_t * end,uint8_t * begin)40 bool UpdateRecordLength(const uint8_t* end, uint8_t* begin) {
41 OSP_DCHECK_LE(begin + sizeof(uint16_t), end);
42 ptrdiff_t record_length = end - begin - sizeof(uint16_t);
43 if (record_length <= std::numeric_limits<uint16_t>::max()) {
44 WriteBigEndian<uint16_t>(record_length, begin);
45 return true;
46 }
47 return false;
48 }
49
50 } // namespace
51
Write(absl::string_view value)52 bool MdnsWriter::Write(absl::string_view value) {
53 if (value.length() > std::numeric_limits<uint8_t>::max()) {
54 return false;
55 }
56 Cursor cursor(this);
57 if (Write(static_cast<uint8_t>(value.length())) &&
58 Write(value.data(), value.length())) {
59 cursor.Commit();
60 return true;
61 }
62 return false;
63 }
64
Write(const std::string & value)65 bool MdnsWriter::Write(const std::string& value) {
66 return Write(absl::string_view(value));
67 }
68
// Writes |name| in DNS wire format, compressing it against previously written
// names where possible.
// RFC 1035: https://www.ietf.org/rfc/rfc1035.txt
// See section 4.1.4. Message compression
//
// Labels are written front to back as direct labels (a label byte built by
// MakeDirectLabel followed by the label bytes). Before each label, the hash of
// the suffix that starts at that label is looked up in |dictionary_|. On a hit,
// the stored pointer label is written in place of the remaining labels and the
// name is complete. Without any hit, the name ends with kLabelTermination.
// Pointer labels for the suffixes written here are added to |dictionary_| only
// after the name has been fully written, so a failed write never adds entries.
//
// Returns false if |name| is empty or if any part of it fails to be written.
bool MdnsWriter::Write(const DomainName& name) {
  if (name.empty()) {
    return false;
  }

  // Writes below are committed only on the success paths.
  // NOTE(review): assumes an uncommitted Cursor rewinds the writer when it is
  // destroyed - confirm against the Cursor definition.
  Cursor cursor(this);
  // subhashes[i] is the case-insensitive hash of the suffix labels[i..end).
  const std::vector<uint64_t> subhashes = ComputeDomainNameSubhashes(name);
  // Tentative dictionary contains label pointer entries to be added to the
  // compression dictionary after successfully writing the domain name.
  std::unordered_map<uint64_t, uint16_t> tentative_dictionary;
  const std::vector<std::string>& labels = name.labels();
  for (size_t i = 0; i < labels.size(); ++i) {
    OSP_DCHECK(IsValidDomainLabel(labels[i]));
    // We only need to do a look up in the compression dictionary and not in the
    // tentative dictionary. The tentative dictionary cannot possibly contain a
    // valid label pointer as all the entries previously added to it are for
    // names that are longer than the currently processed sub-name.
    auto find_result = dictionary_.find(subhashes[i]);
    if (find_result != dictionary_.end()) {
      // The pointer label stands in for every remaining label, so the name is
      // finished here; no terminator is written after a pointer.
      if (!Write(find_result->second)) {
        return false;
      }
      // The suffixes written before the pointer become compression targets too.
      dictionary_.insert(tentative_dictionary.begin(),
                         tentative_dictionary.end());
      cursor.Commit();
      return true;
    }
    // Record where this suffix starts (before writing it) so later names can
    // point back to it.
    // Only add a pointer_label for compression if the offset into the buffer
    // fits into the bits available to store it.
    if (IsValidPointerLabelOffset(current() - begin())) {
      tentative_dictionary.insert(
          std::make_pair(subhashes[i], MakePointerLabel(current() - begin())));
    }
    if (!Write(MakeDirectLabel(labels[i].size())) ||
        !Write(labels[i].data(), labels[i].size())) {
      return false;
    }
  }
  // No suffix was found in the dictionary: terminate the uncompressed name.
  if (!Write(kLabelTermination)) {
    return false;
  }
  // The probability of a collision is extremely low in this application, as the
  // number of domain names compressed is insignificant in comparison to the
  // hash function image.
  dictionary_.insert(tentative_dictionary.begin(), tentative_dictionary.end());
  cursor.Commit();
  return true;
}
119
Write(const RawRecordRdata & rdata)120 bool MdnsWriter::Write(const RawRecordRdata& rdata) {
121 Cursor cursor(this);
122 if (Write(rdata.size()) && Write(rdata.data(), rdata.size())) {
123 cursor.Commit();
124 return true;
125 }
126 return false;
127 }
128
Write(const SrvRecordRdata & rdata)129 bool MdnsWriter::Write(const SrvRecordRdata& rdata) {
130 Cursor cursor(this);
131 // Leave space at the beginning at |rollback_position| to write the record
132 // length. Cannot write it upfront, since the exact space taken by the target
133 // domain name is not known as it might be compressed.
134 if (Skip(sizeof(uint16_t)) && Write(rdata.priority()) &&
135 Write(rdata.weight()) && Write(rdata.port()) && Write(rdata.target()) &&
136 UpdateRecordLength(current(), cursor.origin())) {
137 cursor.Commit();
138 return true;
139 }
140 return false;
141 }
142
Write(const ARecordRdata & rdata)143 bool MdnsWriter::Write(const ARecordRdata& rdata) {
144 Cursor cursor(this);
145 if (Write(static_cast<uint16_t>(IPAddress::kV4Size)) &&
146 Write(rdata.ipv4_address())) {
147 cursor.Commit();
148 return true;
149 }
150 return false;
151 }
152
Write(const AAAARecordRdata & rdata)153 bool MdnsWriter::Write(const AAAARecordRdata& rdata) {
154 Cursor cursor(this);
155 if (Write(static_cast<uint16_t>(IPAddress::kV6Size)) &&
156 Write(rdata.ipv6_address())) {
157 cursor.Commit();
158 return true;
159 }
160 return false;
161 }
162
Write(const PtrRecordRdata & rdata)163 bool MdnsWriter::Write(const PtrRecordRdata& rdata) {
164 Cursor cursor(this);
165 // Leave space at the beginning at |rollback_position| to write the record
166 // length. Cannot write it upfront, since the exact space taken by the target
167 // domain name is not known as it might be compressed.
168 if (Skip(sizeof(uint16_t)) && Write(rdata.ptr_domain()) &&
169 UpdateRecordLength(current(), cursor.origin())) {
170 cursor.Commit();
171 return true;
172 }
173 return false;
174 }
175
Write(const TxtRecordRdata & rdata)176 bool MdnsWriter::Write(const TxtRecordRdata& rdata) {
177 Cursor cursor(this);
178 // Leave space at the beginning at |rollback_position| to write the record
179 // length. It's cheaper to update it at the end than precompute the length.
180 if (!Skip(sizeof(uint16_t))) {
181 return false;
182 }
183 if (rdata.texts().size() > 0) {
184 if (!Write(rdata.texts())) {
185 return false;
186 }
187 } else {
188 if (!Write(kTXTEmptyRdata)) {
189 return false;
190 }
191 }
192 if (!UpdateRecordLength(current(), cursor.origin())) {
193 return false;
194 }
195 cursor.Commit();
196 return true;
197 }
198
Write(const NsecRecordRdata & rdata)199 bool MdnsWriter::Write(const NsecRecordRdata& rdata) {
200 Cursor cursor(this);
201 if (Skip(sizeof(uint16_t)) && Write(rdata.next_domain_name()) &&
202 Write(rdata.encoded_types()) &&
203 UpdateRecordLength(current(), cursor.origin())) {
204 cursor.Commit();
205 return true;
206 }
207 return false;
208 }
209
Write(const OptRecordRdata & rdata)210 bool MdnsWriter::Write(const OptRecordRdata& rdata) {
211 // OPT records are currently not supported for outgoing messages.
212 OSP_UNIMPLEMENTED();
213 return false;
214 }
215
Write(const MdnsRecord & record)216 bool MdnsWriter::Write(const MdnsRecord& record) {
217 Cursor cursor(this);
218 if (Write(record.name()) && Write(static_cast<uint16_t>(record.dns_type())) &&
219 Write(MakeRecordClass(record.dns_class(), record.record_type())) &&
220 Write(static_cast<uint32_t>(record.ttl().count())) &&
221 Write(record.rdata())) {
222 cursor.Commit();
223 return true;
224 }
225 return false;
226 }
227
Write(const MdnsQuestion & question)228 bool MdnsWriter::Write(const MdnsQuestion& question) {
229 Cursor cursor(this);
230 if (Write(question.name()) &&
231 Write(static_cast<uint16_t>(question.dns_type())) &&
232 Write(
233 MakeQuestionClass(question.dns_class(), question.response_type()))) {
234 cursor.Commit();
235 return true;
236 }
237 return false;
238 }
239
Write(const MdnsMessage & message)240 bool MdnsWriter::Write(const MdnsMessage& message) {
241 Cursor cursor(this);
242 Header header;
243 header.id = message.id();
244 header.flags = MakeFlags(message.type(), message.is_truncated());
245 header.question_count = message.questions().size();
246 header.answer_count = message.answers().size();
247 header.authority_record_count = message.authority_records().size();
248 header.additional_record_count = message.additional_records().size();
249 if (Write(header) && Write(message.questions()) && Write(message.answers()) &&
250 Write(message.authority_records()) &&
251 Write(message.additional_records())) {
252 cursor.Commit();
253 return true;
254 }
255 return false;
256 }
257
Write(const IPAddress & address)258 bool MdnsWriter::Write(const IPAddress& address) {
259 uint8_t bytes[IPAddress::kV6Size];
260 size_t size;
261 if (address.IsV6()) {
262 address.CopyToV6(bytes);
263 size = IPAddress::kV6Size;
264 } else {
265 address.CopyToV4(bytes);
266 size = IPAddress::kV4Size;
267 }
268 return Write(bytes, size);
269 }
270
Write(const Rdata & rdata)271 bool MdnsWriter::Write(const Rdata& rdata) {
272 return absl::visit([this](const auto& rdata) { return this->Write(rdata); },
273 rdata);
274 }
275
Write(const Header & header)276 bool MdnsWriter::Write(const Header& header) {
277 Cursor cursor(this);
278 if (Write(header.id) && Write(header.flags) && Write(header.question_count) &&
279 Write(header.answer_count) && Write(header.authority_record_count) &&
280 Write(header.additional_record_count)) {
281 cursor.Commit();
282 return true;
283 }
284 return false;
285 }
286
287 } // namespace discovery
288 } // namespace openscreen
289