1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_packet_parser.h"
17 #include "netmgr_ext_log_wrapper.h"
18 #include <cstring>
19
20 namespace OHOS {
21 namespace NetManagerStandard {
22
23 namespace {
24
// Capacity reserved up front for a decoded, dotted domain name.
constexpr size_t MDNS_STR_INITIAL_SIZE{16};

// A length byte with both high bits set introduces a name-compression pointer (byte view).
constexpr uint8_t DNS_STR_PTR_U8_MASK{0xC0};
// The same marker bits as seen in the 16-bit pointer word after byte-order conversion.
constexpr uint16_t DNS_STR_PTR_U16_MASK{0xC000};
// Low six bits: the largest label length a single length byte can express (63).
constexpr uint16_t DNS_STR_PTR_LENGTH{0x3F};
// Zero length byte that terminates an encoded name.
constexpr uint8_t DNS_STR_EOL{0};
31
WriteRawData(const T & data,MDnsPayload & payload)32 template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33 {
34 const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35 payload.insert(payload.end(), begin, begin + sizeof(T));
36 }
37
// Overwrites the sizeof(T) bytes starting at `ptr` with the raw object representation of `data`
// (host byte order). Assumes `ptr` points to at least sizeof(T) writable bytes.
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    std::memcpy(ptr, &data, sizeof(T));
}
45
// Copies sizeof(T) raw bytes starting at `raw` into `data` (no byte-order conversion) and returns
// the position just past them. Bounds must be checked by the caller.
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    // Packet fields sit at arbitrary byte offsets (they follow variable-length names), so dereferencing
    // a reinterpret_cast'ed T* would be a misaligned, aliasing-violating load. memcpy is well-defined
    // for any alignment and compiles to the same load where the hardware allows it.
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}
51
ReadNUint16(const uint8_t * raw,uint16_t & data)52 const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
53 {
54 const uint8_t *tmp = ReadRawData(raw, data);
55 data = ntohs(data);
56 return tmp;
57 }
58
ReadNUint32(const uint8_t * raw,uint32_t & data)59 const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
60 {
61 const uint8_t *tmp = ReadRawData(raw, data);
62 data = ntohl(data);
63 return tmp;
64 }
65
UnDotted(const std::string & name)66 std::string UnDotted(const std::string &name)
67 {
68 return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
69 }
70
71 } // namespace
72
FromBytes(const MDnsPayload & payload)73 MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
74 {
75 MDnsMessage msg;
76 errorFlags_ = PARSE_OK;
77 pos_ = Parse(payload.data(), payload, msg);
78 return msg;
79 }
80
ToBytes(const MDnsMessage & msg)81 MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
82 {
83 MDnsPayload payload;
84 MDnsPayload *cachedPayload = &payload;
85 std::map<std::string, uint16_t> strCacheMap;
86 Serialize(msg, payload, cachedPayload, strCacheMap);
87 return payload;
88 }
89
Parse(const uint8_t * begin,const MDnsPayload & payload,MDnsMessage & msg)90 const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
91 {
92 begin = ParseHeader(begin, payload, msg.header);
93 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
94 return begin;
95 }
96 for (int i = 0; i < msg.header.qdcount; ++i) {
97 begin = ParseQuestion(begin, payload, msg.questions);
98 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
99 return begin;
100 }
101 }
102 for (int i = 0; i < msg.header.ancount; ++i) {
103 begin = ParseRR(begin, payload, msg.answers);
104 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
105 return begin;
106 }
107 }
108 for (int i = 0; i < msg.header.nscount; ++i) {
109 begin = ParseRR(begin, payload, msg.authorities);
110 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
111 return begin;
112 }
113 }
114 for (int i = 0; i < msg.header.arcount; ++i) {
115 begin = ParseRR(begin, payload, msg.additional);
116 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
117 return begin;
118 }
119 }
120 return begin;
121 }
122
ParseHeader(const uint8_t * begin,const MDnsPayload & payload,DNSProto::Header & header)123 const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
124 DNSProto::Header &header)
125 {
126 const uint8_t *end = payload.data() + payload.size();
127 if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
128 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
129 return begin;
130 }
131
132 begin = ReadNUint16(begin, header.id);
133 begin = ReadNUint16(begin, header.flags);
134 begin = ReadNUint16(begin, header.qdcount);
135 begin = ReadNUint16(begin, header.ancount);
136 begin = ReadNUint16(begin, header.nscount);
137 begin = ReadNUint16(begin, header.arcount);
138 return begin;
139 }
140
ParseQuestion(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::Question> & questions)141 const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
142 std::vector<DNSProto::Question> &questions)
143 {
144 questions.emplace_back();
145 begin = ParseDnsString(begin, payload, questions.back().name);
146 questions.back().name = UnDotted(questions.back().name);
147 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
148 questions.pop_back();
149 return begin;
150 }
151
152 const uint8_t *end = payload.data() + payload.size();
153 if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
154 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
155 questions.pop_back();
156 return begin;
157 }
158
159 begin = ReadNUint16(begin, questions.back().qtype);
160 begin = ReadNUint16(begin, questions.back().qclass);
161 return begin;
162 }
163
ParseRR(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::ResourceRecord> & answers)164 const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
165 std::vector<DNSProto::ResourceRecord> &answers)
166 {
167 answers.emplace_back();
168 begin = ParseDnsString(begin, payload, answers.back().name);
169 answers.back().name = UnDotted(answers.back().name);
170 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
171 answers.pop_back();
172 return begin;
173 }
174
175 const uint8_t *end = payload.data() + payload.size();
176 if (static_cast<ssize_t>(end - begin) <
177 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
178 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
179 answers.pop_back();
180 return begin;
181 }
182 begin = ReadNUint16(begin, answers.back().rtype);
183 begin = ReadNUint16(begin, answers.back().rclass);
184 begin = ReadNUint32(begin, answers.back().ttl);
185 begin = ReadNUint16(begin, answers.back().length);
186 return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
187 }
188
ParseRData(const uint8_t * begin,const MDnsPayload & payload,int type,int length,std::any & data)189 const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
190 std::any &data)
191 {
192 switch (type) {
193 case DNSProto::RRTYPE_A: {
194 const uint8_t *end = payload.data() + payload.size();
195 if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
196 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
197 return begin;
198 }
199 in_addr addr;
200 begin = ReadRawData(begin, addr);
201 data = addr;
202 return begin;
203 }
204 case DNSProto::RRTYPE_AAAA: {
205 const uint8_t *end = payload.data() + payload.size();
206 if ((static_cast<ssize_t>(end - begin) <
207 static_cast<ssize_t>(sizeof(in6_addr))) || (length != sizeof(in6_addr))) {
208 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
209 return begin;
210 }
211 in6_addr addr;
212 begin = ReadRawData(begin, addr);
213 data = addr;
214 return begin;
215 }
216 case DNSProto::RRTYPE_PTR: {
217 std::string str;
218 begin = ParseDnsString(begin, payload, str);
219 str = UnDotted(str);
220 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
221 return begin;
222 }
223 data = str;
224 return begin;
225 }
226 case DNSProto::RRTYPE_SRV: {
227 return ParseSrv(begin, payload, data);
228 }
229 case DNSProto::RRTYPE_TXT: {
230 return ParseTxt(begin, payload, length, data);
231 }
232 default: {
233 errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
234 return begin + length;
235 }
236 }
237 }
238
ParseSrv(const uint8_t * begin,const MDnsPayload & payload,std::any & data)239 const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
240 {
241 const uint8_t *end = payload.data() + payload.size();
242 if (static_cast<ssize_t>(end - begin) <
243 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
244 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
245 return begin;
246 }
247
248 DNSProto::RDataSrv srv;
249 begin = ReadNUint16(begin, srv.priority);
250 begin = ReadNUint16(begin, srv.weight);
251 begin = ReadNUint16(begin, srv.port);
252 begin = ParseDnsString(begin, payload, srv.name);
253 srv.name = UnDotted(srv.name);
254 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
255 return begin;
256 }
257 data = srv;
258 return begin;
259 }
260
ParseTxt(const uint8_t * begin,const MDnsPayload & payload,int length,std::any & data)261 const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
262 {
263 const uint8_t *end = payload.data() + payload.size();
264 if (end - begin < length) {
265 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
266 return begin;
267 }
268
269 data = TxtRecordEncoded(begin, begin + length);
270 return begin + length;
271 }
272
ParseDnsString(const uint8_t * begin,const MDnsPayload & payload,std::string & str)273 const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
274 {
275 const uint8_t *end = payload.data() + payload.size();
276 const uint8_t *p = begin;
277 str.reserve(MDNS_STR_INITIAL_SIZE);
278 while (p && p < end) {
279 if (*p == 0) {
280 return p + 1;
281 }
282 if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
283 str.append(reinterpret_cast<const char *>(p) + 1, *p);
284 str.push_back(MDNS_DOMAIN_SPLITER);
285 p += (*p + 1);
286 } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
287 if (end - p < static_cast<int>(sizeof(uint16_t))) {
288 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
289 return begin;
290 }
291
292 uint16_t offset;
293 const uint8_t *tmp = ReadNUint16(p, offset);
294 offset = offset & ~DNS_STR_PTR_U16_MASK;
295 const uint8_t *next = payload.data() + (offset & ~DNS_STR_PTR_U16_MASK);
296
297 if (next >= end || next >= begin) {
298 errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
299 return begin;
300 }
301 ParseDnsString(next, payload, str);
302 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
303 return begin;
304 }
305 return tmp;
306 } else {
307 errorFlags_ |= PARSE_ERROR_BAD_STR;
308 return p;
309 }
310 }
311 return p;
312 }
313
Serialize(const MDnsMessage & msg,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)314 void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload, MDnsPayload *cachedPayload,
315 std::map<std::string, uint16_t> &strCacheMap)
316 {
317 payload.reserve(sizeof(DNSProto::Message));
318 DNSProto::Header header = msg.header;
319 header.qdcount = msg.questions.size();
320 header.ancount = msg.answers.size();
321 header.nscount = msg.authorities.size();
322 header.arcount = msg.additional.size();
323 SerializeHeader(header, msg, payload);
324 for (uint16_t i = 0; i < header.qdcount; ++i) {
325 SerializeQuestion(msg.questions[i], payload, cachedPayload, strCacheMap);
326 }
327 for (uint16_t i = 0; i < header.ancount; ++i) {
328 SerializeRR(msg.answers[i], payload, cachedPayload, strCacheMap);
329 }
330 for (uint16_t i = 0; i < header.nscount; ++i) {
331 SerializeRR(msg.authorities[i], payload, cachedPayload, strCacheMap);
332 }
333 for (uint16_t i = 0; i < header.arcount; ++i) {
334 SerializeRR(msg.additional[i], payload, cachedPayload, strCacheMap);
335 }
336 }
337
SerializeHeader(const DNSProto::Header & header,const MDnsMessage & msg,MDnsPayload & payload)338 void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
339 {
340 WriteRawData(htons(header.id), payload);
341 WriteRawData(htons(header.flags), payload);
342 WriteRawData(htons(header.qdcount), payload);
343 WriteRawData(htons(header.ancount), payload);
344 WriteRawData(htons(header.nscount), payload);
345 WriteRawData(htons(header.arcount), payload);
346 }
347
SerializeQuestion(const DNSProto::Question & question,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)348 void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, MDnsPayload &payload,
349 MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
350 {
351 SerializeDnsString(question.name, payload, cachedPayload, strCacheMap);
352 WriteRawData(htons(question.qtype), payload);
353 WriteRawData(htons(question.qclass), payload);
354 }
355
SerializeRR(const DNSProto::ResourceRecord & rr,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)356 void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, MDnsPayload &payload,
357 MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
358 {
359 SerializeDnsString(rr.name, payload, cachedPayload, strCacheMap);
360 WriteRawData(htons(rr.rtype), payload);
361 WriteRawData(htons(rr.rclass), payload);
362 WriteRawData(htonl(rr.ttl), payload);
363 size_t lenStart = payload.size();
364 WriteRawData(htons(rr.length), payload);
365 SerializeRData(rr.rdata, payload, cachedPayload, strCacheMap);
366 uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
367 WriteRawData(htons(len), payload.data() + lenStart);
368 }
369
SerializeRData(const std::any & rdata,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)370 void MDnsPayloadParser::SerializeRData(const std::any &rdata, MDnsPayload &payload, MDnsPayload *cachedPayload,
371 std::map<std::string, uint16_t> &strCacheMap)
372 {
373 if (std::any_cast<const in_addr>(&rdata)) {
374 WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
375 } else if (std::any_cast<const in6_addr>(&rdata)) {
376 WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
377 } else if (std::any_cast<const std::string>(&rdata)) {
378 SerializeDnsString(*std::any_cast<const std::string>(&rdata), payload, cachedPayload, strCacheMap);
379 } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
380 const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
381 WriteRawData(htons(srv->priority), payload);
382 WriteRawData(htons(srv->weight), payload);
383 WriteRawData(htons(srv->port), payload);
384 SerializeDnsString(srv->name, payload, cachedPayload, strCacheMap);
385 } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
386 const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
387 payload.insert(payload.end(), txt->begin(), txt->end());
388 }
389 }
390
SerializeDnsString(const std::string & str,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)391 void MDnsPayloadParser::SerializeDnsString(const std::string &str, MDnsPayload &payload, MDnsPayload *cachedPayload,
392 std::map<std::string, uint16_t> &strCacheMap)
393 {
394 size_t pos = 0;
395 while (pos < str.size()) {
396 if ((cachedPayload == &payload) && (strCacheMap.find(str.substr(pos)) != strCacheMap.end())) {
397 return WriteRawData(htons(strCacheMap[str.substr(pos)]), payload);
398 }
399
400 size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
401 if (nextDot == std::string::npos) {
402 nextDot = str.size();
403 }
404 uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
405
406 uint16_t strptr = payload.size();
407 WriteRawData(segLen, payload);
408 for (int i = 0; i < segLen; ++i) {
409 WriteRawData(str[pos + i], payload);
410 }
411 strCacheMap[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
412 pos = nextDot + 1;
413 }
414 WriteRawData(DNS_STR_EOL, payload);
415 }
416
GetError() const417 uint32_t MDnsPayloadParser::GetError() const
418 {
419 return errorFlags_ & PARSE_ERROR;
420 }
421
422 } // namespace NetManagerStandard
423 } // namespace OHOS
424