1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_packet_parser.h"
17
18 #include <cstring>
19
20 namespace OHOS {
21 namespace NetManagerStandard {
22
23 namespace {
24
// Capacity reserved up front for a name while it is being decoded.
constexpr size_t MDNS_STR_INITIAL_SIZE = 16;

// Both top bits set in a length octet mark a compression pointer instead of a label.
constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xC0;
// The same two marker bits, positioned in the high byte of the 16-bit pointer field.
constexpr uint16_t DNS_STR_PTR_U16_MASK = static_cast<uint16_t>(DNS_STR_PTR_U8_MASK << 8);
// Low six bits of a length octet, i.e. the largest label length that can be encoded.
constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3F;
// Zero octet written after the last label of a serialized name.
constexpr uint8_t DNS_STR_EOL = 0;
31
WriteRawData(const T & data,MDnsPayload & payload)32 template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33 {
34 const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35 payload.insert(payload.end(), begin, begin + sizeof(T));
36 }
37
// Copies the object representation of `data` (sizeof(T) bytes, host layout) to `ptr`.
// The caller must make sure at least sizeof(T) bytes are writable at `ptr`.
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    const auto *src = reinterpret_cast<const uint8_t *>(&data);
    for (size_t idx = 0; idx < sizeof(T); ++idx) {
        ptr[idx] = src[idx];
    }
}
45
// Copies sizeof(T) bytes starting at `raw` into `data` without any byte-order conversion.
// Returns the address just past the bytes consumed. The caller must guarantee that
// sizeof(T) bytes are readable at `raw`.
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    // Fields inside a packet sit at arbitrary byte offsets, so a typed load through a
    // casted pointer would be misaligned; memcpy is well defined for any alignment.
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}
51
// Decodes a 16-bit value stored in network (big-endian) byte order at `raw` into `data`
// as a host-order integer and returns the address just past it.
// The caller must guarantee that two bytes are readable at `raw`.
const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
{
    // Assemble the value byte by byte: this works for any alignment of `raw` and does not
    // depend on the host byte order.
    constexpr unsigned BITS_PER_BYTE = 8;
    data = static_cast<uint16_t>((raw[0] << BITS_PER_BYTE) | raw[1]);
    return raw + sizeof(uint16_t);
}
58
// Decodes a 32-bit value stored in network (big-endian) byte order at `raw` into `data`
// as a host-order integer and returns the address just past it.
// The caller must guarantee that four bytes are readable at `raw`.
const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
{
    // Assemble the value byte by byte: this works for any alignment of `raw` and does not
    // depend on the host byte order.
    constexpr unsigned BITS_PER_BYTE = 8;
    uint32_t value = 0;
    for (size_t i = 0; i < sizeof(uint32_t); ++i) {
        value = (value << BITS_PER_BYTE) | raw[i];
    }
    data = value;
    return raw + sizeof(uint32_t);
}
65
66 } // namespace
67
FromBytes(const MDnsPayload & payload)68 MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
69 {
70 MDnsMessage msg;
71 errorFlags_ = PARSE_OK;
72 pos_ = Parse(payload.data(), payload, msg);
73 return msg;
74 }
75
ToBytes(const MDnsMessage & msg)76 MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
77 {
78 MDnsPayload payload;
79 cachedPayload_ = &payload;
80 strCacheMap_.clear();
81 Serialize(msg, payload);
82 cachedPayload_ = nullptr;
83 strCacheMap_.clear();
84 return payload;
85 }
86
Parse(const uint8_t * begin,const MDnsPayload & payload,MDnsMessage & msg)87 const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
88 {
89 begin = ParseHeader(begin, payload, msg.header);
90 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
91 return begin;
92 }
93 for (int i = 0; i < msg.header.qdcount; ++i) {
94 begin = ParseQuestion(begin, payload, msg.questions);
95 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
96 return begin;
97 }
98 }
99 for (int i = 0; i < msg.header.ancount; ++i) {
100 begin = ParseRR(begin, payload, msg.answers);
101 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
102 return begin;
103 }
104 }
105 for (int i = 0; i < msg.header.nscount; ++i) {
106 begin = ParseRR(begin, payload, msg.authorities);
107 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
108 return begin;
109 }
110 }
111 for (int i = 0; i < msg.header.arcount; ++i) {
112 begin = ParseRR(begin, payload, msg.additional);
113 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
114 return begin;
115 }
116 }
117 return begin;
118 }
119
ParseHeader(const uint8_t * begin,const MDnsPayload & payload,DNSProto::Header & header)120 const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
121 DNSProto::Header &header)
122 {
123 const uint8_t *end = payload.data() + payload.size();
124 if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
125 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
126 return begin;
127 }
128
129 begin = ReadNUint16(begin, header.id);
130 begin = ReadNUint16(begin, header.flags);
131 begin = ReadNUint16(begin, header.qdcount);
132 begin = ReadNUint16(begin, header.ancount);
133 begin = ReadNUint16(begin, header.nscount);
134 begin = ReadNUint16(begin, header.arcount);
135 return begin;
136 }
137
ParseQuestion(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::Question> & questions)138 const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
139 std::vector<DNSProto::Question> &questions)
140 {
141 questions.emplace_back();
142 begin = ParseDnsString(begin, payload, questions.back().name);
143 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
144 questions.pop_back();
145 return begin;
146 }
147
148 const uint8_t *end = payload.data() + payload.size();
149 if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
150 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
151 questions.pop_back();
152 return begin;
153 }
154
155 begin = ReadNUint16(begin, questions.back().qtype);
156 begin = ReadNUint16(begin, questions.back().qclass);
157 return begin;
158 }
159
ParseRR(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::ResourceRecord> & answers)160 const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
161 std::vector<DNSProto::ResourceRecord> &answers)
162 {
163 answers.emplace_back();
164 begin = ParseDnsString(begin, payload, answers.back().name);
165 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
166 answers.pop_back();
167 return begin;
168 }
169
170 const uint8_t *end = payload.data() + payload.size();
171 if (static_cast<ssize_t>(end - begin) <
172 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
173 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
174 answers.pop_back();
175 return begin;
176 }
177 begin = ReadNUint16(begin, answers.back().rtype);
178 begin = ReadNUint16(begin, answers.back().rclass);
179 begin = ReadNUint32(begin, answers.back().ttl);
180 begin = ReadNUint16(begin, answers.back().length);
181 return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
182 }
183
ParseRData(const uint8_t * begin,const MDnsPayload & payload,int type,int length,std::any & data)184 const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
185 std::any &data)
186 {
187 switch (type) {
188 case DNSProto::RRTYPE_A: {
189 const uint8_t *end = payload.data() + payload.size();
190 if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
191 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
192 return begin;
193 }
194 in_addr addr;
195 begin = ReadRawData(begin, addr);
196 data = addr;
197 return begin;
198 }
199 case DNSProto::RRTYPE_AAAA: {
200 const uint8_t *end = payload.data() + payload.size();
201 if (static_cast<ssize_t>(end - begin) <
202 static_cast<ssize_t>(sizeof(in6_addr) || length != sizeof(in6_addr))) {
203 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
204 return begin;
205 }
206 in6_addr addr;
207 begin = ReadRawData(begin, addr);
208 data = addr;
209 return begin;
210 }
211 case DNSProto::RRTYPE_PTR: {
212 std::string str;
213 begin = ParseDnsString(begin, payload, str);
214 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
215 return begin;
216 }
217 data = str;
218 return begin;
219 }
220 case DNSProto::RRTYPE_SRV: {
221 return ParseSrv(begin, payload, data);
222 }
223 case DNSProto::RRTYPE_TXT: {
224 return ParseTxt(begin, payload, length, data);
225 }
226 default: {
227 errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
228 return begin + length;
229 }
230 }
231 }
232
ParseSrv(const uint8_t * begin,const MDnsPayload & payload,std::any & data)233 const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
234 {
235 const uint8_t *end = payload.data() + payload.size();
236 if (static_cast<ssize_t>(end - begin) <
237 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
238 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
239 return begin;
240 }
241
242 DNSProto::RDataSrv srv;
243 begin = ReadNUint16(begin, srv.priority);
244 begin = ReadNUint16(begin, srv.weight);
245 begin = ReadNUint16(begin, srv.port);
246 begin = ParseDnsString(begin, payload, srv.name);
247 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
248 return begin;
249 }
250 data = srv;
251 return begin;
252 }
253
ParseTxt(const uint8_t * begin,const MDnsPayload & payload,int length,std::any & data)254 const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
255 {
256 const uint8_t *end = payload.data() + payload.size();
257 if (end - begin < length) {
258 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
259 return begin;
260 }
261
262 data = TxtRecordEncoded(begin, begin + length);
263 return begin + length;
264 }
265
ParseDnsString(const uint8_t * begin,const MDnsPayload & payload,std::string & str)266 const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
267 {
268 const uint8_t *end = payload.data() + payload.size();
269 const uint8_t *p = begin;
270 str.reserve(MDNS_STR_INITIAL_SIZE);
271 while (p < end) {
272 if (*p == 0) {
273 return p + 1;
274 }
275 if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
276 str.append(reinterpret_cast<const char *>(p) + 1, *p);
277 str.push_back(MDNS_DOMAIN_SPLITER);
278 p += (*p + 1);
279 } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
280 uint16_t offset;
281 const uint8_t *tmp = ReadNUint16(p, offset);
282 offset = offset & ~DNS_STR_PTR_U16_MASK;
283 if (offset >= payload.size()) {
284 errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
285 return begin;
286 }
287 ParseDnsString(payload.data() + (offset & ~DNS_STR_PTR_U16_MASK), payload, str);
288 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
289 return begin;
290 }
291 return tmp;
292 } else {
293 errorFlags_ |= PARSE_ERROR_BAD_STR;
294 return p;
295 }
296 }
297 return p;
298 }
299
Serialize(const MDnsMessage & msg,MDnsPayload & payload)300 void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload)
301 {
302 payload.reserve(sizeof(DNSProto::Message));
303 DNSProto::Header header = msg.header;
304 header.qdcount = msg.questions.size();
305 header.ancount = msg.answers.size();
306 header.nscount = msg.authorities.size();
307 header.arcount = msg.additional.size();
308 SerializeHeader(header, msg, payload);
309 for (uint16_t i = 0; i < header.qdcount; ++i) {
310 SerializeQuestion(msg.questions[i], msg, payload);
311 }
312 for (uint16_t i = 0; i < header.ancount; ++i) {
313 SerializeRR(msg.answers[i], msg, payload);
314 }
315 for (uint16_t i = 0; i < header.nscount; ++i) {
316 SerializeRR(msg.authorities[i], msg, payload);
317 }
318 for (uint16_t i = 0; i < header.arcount; ++i) {
319 SerializeRR(msg.additional[i], msg, payload);
320 }
321 }
322
SerializeHeader(const DNSProto::Header & header,const MDnsMessage & msg,MDnsPayload & payload)323 void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
324 {
325 WriteRawData(htons(header.id), payload);
326 WriteRawData(htons(header.flags), payload);
327 WriteRawData(htons(header.qdcount), payload);
328 WriteRawData(htons(header.ancount), payload);
329 WriteRawData(htons(header.nscount), payload);
330 WriteRawData(htons(header.arcount), payload);
331 }
332
SerializeQuestion(const DNSProto::Question & question,const MDnsMessage & msg,MDnsPayload & payload)333 void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, const MDnsMessage &msg,
334 MDnsPayload &payload)
335 {
336 SerializeDnsString(question.name, msg, payload);
337 WriteRawData(htons(question.qtype), payload);
338 WriteRawData(htons(question.qclass), payload);
339 }
340
SerializeRR(const DNSProto::ResourceRecord & rr,const MDnsMessage & msg,MDnsPayload & payload)341 void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, const MDnsMessage &msg, MDnsPayload &payload)
342 {
343 SerializeDnsString(rr.name, msg, payload);
344 WriteRawData(htons(rr.rtype), payload);
345 WriteRawData(htons(rr.rclass), payload);
346 WriteRawData(htonl(rr.ttl), payload);
347 size_t lenStart = payload.size();
348 WriteRawData(htons(rr.length), payload);
349 SerializeRData(rr.rdata, msg, payload);
350 uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
351 WriteRawData(htons(len), payload.data() + lenStart);
352 }
353
SerializeRData(const std::any & rdata,const MDnsMessage & msg,MDnsPayload & payload)354 void MDnsPayloadParser::SerializeRData(const std::any &rdata, const MDnsMessage &msg, MDnsPayload &payload)
355 {
356 if (std::any_cast<const in_addr>(&rdata)) {
357 WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
358 } else if (std::any_cast<const in6_addr>(&rdata)) {
359 WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
360 } else if (std::any_cast<const std::string>(&rdata)) {
361 SerializeDnsString(*std::any_cast<const std::string>(&rdata), msg, payload);
362 } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
363 const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
364 WriteRawData(htons(srv->priority), payload);
365 WriteRawData(htons(srv->weight), payload);
366 WriteRawData(htons(srv->port), payload);
367 SerializeDnsString(srv->name, msg, payload);
368 } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
369 const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
370 payload.insert(payload.end(), txt->begin(), txt->end());
371 }
372 }
373
SerializeDnsString(const std::string & str,const MDnsMessage & msg,MDnsPayload & payload)374 void MDnsPayloadParser::SerializeDnsString(const std::string &str, const MDnsMessage &msg, MDnsPayload &payload)
375 {
376 size_t pos = 0;
377 while (pos < str.size()) {
378 if (cachedPayload_ == &payload && strCacheMap_.find(str.substr(pos)) != strCacheMap_.end()) {
379 return WriteRawData(htons(strCacheMap_[str.substr(pos)]), payload);
380 }
381
382 size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
383 if (nextDot == std::string::npos) {
384 nextDot = str.size();
385 }
386 uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
387 uint16_t strptr = payload.size();
388 WriteRawData(segLen, payload);
389 for (int i = 0; i < segLen; ++i) {
390 WriteRawData(str[pos + i], payload);
391 }
392 strCacheMap_[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
393 pos = nextDot + 1;
394 }
395 WriteRawData(DNS_STR_EOL, payload);
396 }
397
GetError() const398 uint32_t MDnsPayloadParser::GetError() const
399 {
400 return errorFlags_ & PARSE_ERROR;
401 }
402
403 } // namespace NetManagerStandard
404 } // namespace OHOS
405