1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_packet_parser.h"
17
18 #include <cstring>
19
20 namespace OHOS {
21 namespace NetManagerStandard {
22
23 namespace {
24
// Capacity reserved for the output string before a name is decoded.
constexpr size_t MDNS_STR_INITIAL_SIZE = 16;

// A length octet with both of these bits set starts a two-octet compression pointer.
constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xC0;
// The same two marker bits as they appear in the 16-bit pointer word; the other bits hold the offset.
constexpr uint16_t DNS_STR_PTR_U16_MASK = 0xC000;
// Low six bits of a length octet, i.e. the largest label length that can be written.
constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3F;
// Zero octet written after the last label of a name.
constexpr uint8_t DNS_STR_EOL = '\0';
31
WriteRawData(const T & data,MDnsPayload & payload)32 template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33 {
34 const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35 payload.insert(payload.end(), begin, begin + sizeof(T));
36 }
37
// Copies the sizeof(T) bytes of `data`, in memory order, into the buffer starting at `ptr`.
// The caller must make sure at least sizeof(T) writable bytes are available at `ptr`.
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    const auto *src = reinterpret_cast<const uint8_t *>(&data);
    for (size_t offset = 0; offset != sizeof(T); ++offset) {
        ptr[offset] = src[offset];
    }
}
45
// Copies sizeof(T) bytes starting at `raw` into `data` and returns the position just past them.
//
// The bytes are copied with memcpy instead of dereferencing a casted pointer: fields inside a DNS
// packet sit at arbitrary byte offsets, so `raw` is generally not aligned for T, and reading it
// through a `const T *` would be undefined behaviour (and faults on strict-alignment targets).
// No byte-order conversion is done here. The caller must guarantee sizeof(T) readable bytes at `raw`.
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}
51
ReadNUint16(const uint8_t * raw,uint16_t & data)52 const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
53 {
54 const uint8_t *tmp = ReadRawData(raw, data);
55 data = ntohs(data);
56 return tmp;
57 }
58
ReadNUint32(const uint8_t * raw,uint32_t & data)59 const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
60 {
61 const uint8_t *tmp = ReadRawData(raw, data);
62 data = ntohl(data);
63 return tmp;
64 }
65
UnDotted(const std::string & name)66 std::string UnDotted(const std::string &name)
67 {
68 return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
69 }
70
71 } // namespace
72
FromBytes(const MDnsPayload & payload)73 MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
74 {
75 MDnsMessage msg;
76 errorFlags_ = PARSE_OK;
77 pos_ = Parse(payload.data(), payload, msg);
78 return msg;
79 }
80
ToBytes(const MDnsMessage & msg)81 MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
82 {
83 MDnsPayload payload;
84 cachedPayload_ = &payload;
85 strCacheMap_.clear();
86 Serialize(msg, payload);
87 cachedPayload_ = nullptr;
88 strCacheMap_.clear();
89 return payload;
90 }
91
Parse(const uint8_t * begin,const MDnsPayload & payload,MDnsMessage & msg)92 const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
93 {
94 begin = ParseHeader(begin, payload, msg.header);
95 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
96 return begin;
97 }
98 for (int i = 0; i < msg.header.qdcount; ++i) {
99 begin = ParseQuestion(begin, payload, msg.questions);
100 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
101 return begin;
102 }
103 }
104 for (int i = 0; i < msg.header.ancount; ++i) {
105 begin = ParseRR(begin, payload, msg.answers);
106 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
107 return begin;
108 }
109 }
110 for (int i = 0; i < msg.header.nscount; ++i) {
111 begin = ParseRR(begin, payload, msg.authorities);
112 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
113 return begin;
114 }
115 }
116 for (int i = 0; i < msg.header.arcount; ++i) {
117 begin = ParseRR(begin, payload, msg.additional);
118 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
119 return begin;
120 }
121 }
122 return begin;
123 }
124
ParseHeader(const uint8_t * begin,const MDnsPayload & payload,DNSProto::Header & header)125 const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
126 DNSProto::Header &header)
127 {
128 const uint8_t *end = payload.data() + payload.size();
129 if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
130 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
131 return begin;
132 }
133
134 begin = ReadNUint16(begin, header.id);
135 begin = ReadNUint16(begin, header.flags);
136 begin = ReadNUint16(begin, header.qdcount);
137 begin = ReadNUint16(begin, header.ancount);
138 begin = ReadNUint16(begin, header.nscount);
139 begin = ReadNUint16(begin, header.arcount);
140 return begin;
141 }
142
ParseQuestion(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::Question> & questions)143 const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
144 std::vector<DNSProto::Question> &questions)
145 {
146 questions.emplace_back();
147 begin = ParseDnsString(begin, payload, questions.back().name);
148 questions.back().name = UnDotted(questions.back().name);
149 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
150 questions.pop_back();
151 return begin;
152 }
153
154 const uint8_t *end = payload.data() + payload.size();
155 if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
156 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
157 questions.pop_back();
158 return begin;
159 }
160
161 begin = ReadNUint16(begin, questions.back().qtype);
162 begin = ReadNUint16(begin, questions.back().qclass);
163 return begin;
164 }
165
ParseRR(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::ResourceRecord> & answers)166 const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
167 std::vector<DNSProto::ResourceRecord> &answers)
168 {
169 answers.emplace_back();
170 begin = ParseDnsString(begin, payload, answers.back().name);
171 answers.back().name = UnDotted(answers.back().name);
172 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
173 answers.pop_back();
174 return begin;
175 }
176
177 const uint8_t *end = payload.data() + payload.size();
178 if (static_cast<ssize_t>(end - begin) <
179 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
180 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
181 answers.pop_back();
182 return begin;
183 }
184 begin = ReadNUint16(begin, answers.back().rtype);
185 begin = ReadNUint16(begin, answers.back().rclass);
186 begin = ReadNUint32(begin, answers.back().ttl);
187 begin = ReadNUint16(begin, answers.back().length);
188 return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
189 }
190
ParseRData(const uint8_t * begin,const MDnsPayload & payload,int type,int length,std::any & data)191 const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
192 std::any &data)
193 {
194 switch (type) {
195 case DNSProto::RRTYPE_A: {
196 const uint8_t *end = payload.data() + payload.size();
197 if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
198 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
199 return begin;
200 }
201 in_addr addr;
202 begin = ReadRawData(begin, addr);
203 data = addr;
204 return begin;
205 }
206 case DNSProto::RRTYPE_AAAA: {
207 const uint8_t *end = payload.data() + payload.size();
208 if (static_cast<ssize_t>(end - begin) <
209 static_cast<ssize_t>(sizeof(in6_addr) || length != sizeof(in6_addr))) {
210 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
211 return begin;
212 }
213 in6_addr addr;
214 begin = ReadRawData(begin, addr);
215 data = addr;
216 return begin;
217 }
218 case DNSProto::RRTYPE_PTR: {
219 std::string str;
220 begin = ParseDnsString(begin, payload, str);
221 str = UnDotted(str);
222 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
223 return begin;
224 }
225 data = str;
226 return begin;
227 }
228 case DNSProto::RRTYPE_SRV: {
229 return ParseSrv(begin, payload, data);
230 }
231 case DNSProto::RRTYPE_TXT: {
232 return ParseTxt(begin, payload, length, data);
233 }
234 default: {
235 errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
236 return begin + length;
237 }
238 }
239 }
240
ParseSrv(const uint8_t * begin,const MDnsPayload & payload,std::any & data)241 const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
242 {
243 const uint8_t *end = payload.data() + payload.size();
244 if (static_cast<ssize_t>(end - begin) <
245 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
246 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
247 return begin;
248 }
249
250 DNSProto::RDataSrv srv;
251 begin = ReadNUint16(begin, srv.priority);
252 begin = ReadNUint16(begin, srv.weight);
253 begin = ReadNUint16(begin, srv.port);
254 begin = ParseDnsString(begin, payload, srv.name);
255 srv.name = UnDotted(srv.name);
256 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
257 return begin;
258 }
259 data = srv;
260 return begin;
261 }
262
ParseTxt(const uint8_t * begin,const MDnsPayload & payload,int length,std::any & data)263 const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
264 {
265 const uint8_t *end = payload.data() + payload.size();
266 if (end - begin < length) {
267 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
268 return begin;
269 }
270
271 data = TxtRecordEncoded(begin, begin + length);
272 return begin + length;
273 }
274
ParseDnsString(const uint8_t * begin,const MDnsPayload & payload,std::string & str)275 const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
276 {
277 const uint8_t *end = payload.data() + payload.size();
278 const uint8_t *p = begin;
279 str.reserve(MDNS_STR_INITIAL_SIZE);
280 while (p && p < end) {
281 if (*p == 0) {
282 return p + 1;
283 }
284 if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
285 str.append(reinterpret_cast<const char *>(p) + 1, *p);
286 str.push_back(MDNS_DOMAIN_SPLITER);
287 p += (*p + 1);
288 } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
289 if (end - p < static_cast<int>(sizeof(uint16_t))) {
290 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
291 return begin;
292 }
293
294 uint16_t offset;
295 const uint8_t *tmp = ReadNUint16(p, offset);
296 offset = offset & ~DNS_STR_PTR_U16_MASK;
297 const uint8_t *next = payload.data() + (offset & ~DNS_STR_PTR_U16_MASK);
298
299 if (next >= end || next >= begin) {
300 errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
301 return begin;
302 }
303 ParseDnsString(next, payload, str);
304 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
305 return begin;
306 }
307 return tmp;
308 } else {
309 errorFlags_ |= PARSE_ERROR_BAD_STR;
310 return p;
311 }
312 }
313 return p;
314 }
315
Serialize(const MDnsMessage & msg,MDnsPayload & payload)316 void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload)
317 {
318 payload.reserve(sizeof(DNSProto::Message));
319 DNSProto::Header header = msg.header;
320 header.qdcount = msg.questions.size();
321 header.ancount = msg.answers.size();
322 header.nscount = msg.authorities.size();
323 header.arcount = msg.additional.size();
324 SerializeHeader(header, msg, payload);
325 for (uint16_t i = 0; i < header.qdcount; ++i) {
326 SerializeQuestion(msg.questions[i], msg, payload);
327 }
328 for (uint16_t i = 0; i < header.ancount; ++i) {
329 SerializeRR(msg.answers[i], msg, payload);
330 }
331 for (uint16_t i = 0; i < header.nscount; ++i) {
332 SerializeRR(msg.authorities[i], msg, payload);
333 }
334 for (uint16_t i = 0; i < header.arcount; ++i) {
335 SerializeRR(msg.additional[i], msg, payload);
336 }
337 }
338
SerializeHeader(const DNSProto::Header & header,const MDnsMessage & msg,MDnsPayload & payload)339 void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
340 {
341 WriteRawData(htons(header.id), payload);
342 WriteRawData(htons(header.flags), payload);
343 WriteRawData(htons(header.qdcount), payload);
344 WriteRawData(htons(header.ancount), payload);
345 WriteRawData(htons(header.nscount), payload);
346 WriteRawData(htons(header.arcount), payload);
347 }
348
SerializeQuestion(const DNSProto::Question & question,const MDnsMessage & msg,MDnsPayload & payload)349 void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, const MDnsMessage &msg,
350 MDnsPayload &payload)
351 {
352 SerializeDnsString(question.name, msg, payload);
353 WriteRawData(htons(question.qtype), payload);
354 WriteRawData(htons(question.qclass), payload);
355 }
356
SerializeRR(const DNSProto::ResourceRecord & rr,const MDnsMessage & msg,MDnsPayload & payload)357 void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, const MDnsMessage &msg, MDnsPayload &payload)
358 {
359 SerializeDnsString(rr.name, msg, payload);
360 WriteRawData(htons(rr.rtype), payload);
361 WriteRawData(htons(rr.rclass), payload);
362 WriteRawData(htonl(rr.ttl), payload);
363 size_t lenStart = payload.size();
364 WriteRawData(htons(rr.length), payload);
365 SerializeRData(rr.rdata, msg, payload);
366 uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
367 WriteRawData(htons(len), payload.data() + lenStart);
368 }
369
SerializeRData(const std::any & rdata,const MDnsMessage & msg,MDnsPayload & payload)370 void MDnsPayloadParser::SerializeRData(const std::any &rdata, const MDnsMessage &msg, MDnsPayload &payload)
371 {
372 if (std::any_cast<const in_addr>(&rdata)) {
373 WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
374 } else if (std::any_cast<const in6_addr>(&rdata)) {
375 WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
376 } else if (std::any_cast<const std::string>(&rdata)) {
377 SerializeDnsString(*std::any_cast<const std::string>(&rdata), msg, payload);
378 } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
379 const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
380 WriteRawData(htons(srv->priority), payload);
381 WriteRawData(htons(srv->weight), payload);
382 WriteRawData(htons(srv->port), payload);
383 SerializeDnsString(srv->name, msg, payload);
384 } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
385 const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
386 payload.insert(payload.end(), txt->begin(), txt->end());
387 }
388 }
389
SerializeDnsString(const std::string & str,const MDnsMessage & msg,MDnsPayload & payload)390 void MDnsPayloadParser::SerializeDnsString(const std::string &str, const MDnsMessage &msg, MDnsPayload &payload)
391 {
392 size_t pos = 0;
393 while (pos < str.size()) {
394 if (cachedPayload_ == &payload && strCacheMap_.find(str.substr(pos)) != strCacheMap_.end()) {
395 return WriteRawData(htons(strCacheMap_[str.substr(pos)]), payload);
396 }
397
398 size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
399 if (nextDot == std::string::npos) {
400 nextDot = str.size();
401 }
402 uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
403 uint16_t strptr = payload.size();
404 WriteRawData(segLen, payload);
405 for (int i = 0; i < segLen; ++i) {
406 WriteRawData(str[pos + i], payload);
407 }
408 strCacheMap_[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
409 pos = nextDot + 1;
410 }
411 WriteRawData(DNS_STR_EOL, payload);
412 }
413
GetError() const414 uint32_t MDnsPayloadParser::GetError() const
415 {
416 return errorFlags_ & PARSE_ERROR;
417 }
418
419 } // namespace NetManagerStandard
420 } // namespace OHOS
421