• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_packet_parser.h"
17 
18 #include <cstring>
19 
20 namespace OHOS {
21 namespace NetManagerStandard {
22 
23 namespace {
24 
// Initial capacity reserved for a decoded domain-name string.
constexpr size_t MDNS_STR_INITIAL_SIZE{16};

// Name compression: a length byte whose two top bits are both set starts a 2-byte pointer;
// masking the 16-bit pointer with ~DNS_STR_PTR_U16_MASK yields the payload offset it refers to.
constexpr uint8_t DNS_STR_PTR_U8_MASK{0xC0};
constexpr uint16_t DNS_STR_PTR_U16_MASK{0xC000};
// Low six bits of a length byte: 63, the longest label the wire format can express.
constexpr uint16_t DNS_STR_PTR_LENGTH{0x3F};
// Zero length byte that terminates every uncompressed encoded name.
constexpr uint8_t DNS_STR_EOL{'\0'};
31 
WriteRawData(const T & data,MDnsPayload & payload)32 template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33 {
34     const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35     payload.insert(payload.end(), begin, begin + sizeof(T));
36 }
37 
// Overwrites the sizeof(T) bytes starting at ptr with the in-memory representation of data.
// Used to back-patch a field already reserved in a payload; no bounds check is performed,
// so the caller must guarantee that sizeof(T) writable bytes follow ptr.
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    std::memcpy(ptr, &data, sizeof(T));
}
45 
// Copies sizeof(T) bytes starting at raw into data and returns the position just past them.
// std::memcpy is used instead of dereferencing a reinterpret_cast'ed pointer: raw can point at
// any byte offset of a received packet, so the cast form is a potentially misaligned load and a
// strict-aliasing violation (both undefined behaviour). T must be trivially copyable.
// No bounds check is performed; the caller must guarantee sizeof(T) readable bytes at raw.
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}
51 
ReadNUint16(const uint8_t * raw,uint16_t & data)52 const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
53 {
54     const uint8_t *tmp = ReadRawData(raw, data);
55     data = ntohs(data);
56     return tmp;
57 }
58 
ReadNUint32(const uint8_t * raw,uint32_t & data)59 const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
60 {
61     const uint8_t *tmp = ReadRawData(raw, data);
62     data = ntohl(data);
63     return tmp;
64 }
65 
66 } // namespace
67 
FromBytes(const MDnsPayload & payload)68 MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
69 {
70     MDnsMessage msg;
71     errorFlags_ = PARSE_OK;
72     pos_ = Parse(payload.data(), payload, msg);
73     return msg;
74 }
75 
ToBytes(const MDnsMessage & msg)76 MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
77 {
78     MDnsPayload payload;
79     cachedPayload_ = &payload;
80     strCacheMap_.clear();
81     Serialize(msg, payload);
82     cachedPayload_ = nullptr;
83     strCacheMap_.clear();
84     return payload;
85 }
86 
Parse(const uint8_t * begin,const MDnsPayload & payload,MDnsMessage & msg)87 const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
88 {
89     begin = ParseHeader(begin, payload, msg.header);
90     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
91         return begin;
92     }
93     for (int i = 0; i < msg.header.qdcount; ++i) {
94         begin = ParseQuestion(begin, payload, msg.questions);
95         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
96             return begin;
97         }
98     }
99     for (int i = 0; i < msg.header.ancount; ++i) {
100         begin = ParseRR(begin, payload, msg.answers);
101         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
102             return begin;
103         }
104     }
105     for (int i = 0; i < msg.header.nscount; ++i) {
106         begin = ParseRR(begin, payload, msg.authorities);
107         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
108             return begin;
109         }
110     }
111     for (int i = 0; i < msg.header.arcount; ++i) {
112         begin = ParseRR(begin, payload, msg.additional);
113         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
114             return begin;
115         }
116     }
117     return begin;
118 }
119 
ParseHeader(const uint8_t * begin,const MDnsPayload & payload,DNSProto::Header & header)120 const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
121                                               DNSProto::Header &header)
122 {
123     const uint8_t *end = payload.data() + payload.size();
124     if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
125         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
126         return begin;
127     }
128 
129     begin = ReadNUint16(begin, header.id);
130     begin = ReadNUint16(begin, header.flags);
131     begin = ReadNUint16(begin, header.qdcount);
132     begin = ReadNUint16(begin, header.ancount);
133     begin = ReadNUint16(begin, header.nscount);
134     begin = ReadNUint16(begin, header.arcount);
135     return begin;
136 }
137 
ParseQuestion(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::Question> & questions)138 const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
139                                                 std::vector<DNSProto::Question> &questions)
140 {
141     questions.emplace_back();
142     begin = ParseDnsString(begin, payload, questions.back().name);
143     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
144         questions.pop_back();
145         return begin;
146     }
147 
148     const uint8_t *end = payload.data() + payload.size();
149     if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
150         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
151         questions.pop_back();
152         return begin;
153     }
154 
155     begin = ReadNUint16(begin, questions.back().qtype);
156     begin = ReadNUint16(begin, questions.back().qclass);
157     return begin;
158 }
159 
ParseRR(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::ResourceRecord> & answers)160 const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
161                                           std::vector<DNSProto::ResourceRecord> &answers)
162 {
163     answers.emplace_back();
164     begin = ParseDnsString(begin, payload, answers.back().name);
165     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
166         answers.pop_back();
167         return begin;
168     }
169 
170     const uint8_t *end = payload.data() + payload.size();
171     if (static_cast<ssize_t>(end - begin) <
172         static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
173         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
174         answers.pop_back();
175         return begin;
176     }
177     begin = ReadNUint16(begin, answers.back().rtype);
178     begin = ReadNUint16(begin, answers.back().rclass);
179     begin = ReadNUint32(begin, answers.back().ttl);
180     begin = ReadNUint16(begin, answers.back().length);
181     return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
182 }
183 
ParseRData(const uint8_t * begin,const MDnsPayload & payload,int type,int length,std::any & data)184 const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
185                                              std::any &data)
186 {
187     switch (type) {
188         case DNSProto::RRTYPE_A: {
189             const uint8_t *end = payload.data() + payload.size();
190             if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
191                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
192                 return begin;
193             }
194             in_addr addr;
195             begin = ReadRawData(begin, addr);
196             data = addr;
197             return begin;
198         }
199         case DNSProto::RRTYPE_AAAA: {
200             const uint8_t *end = payload.data() + payload.size();
201             if (static_cast<ssize_t>(end - begin) <
202                 static_cast<ssize_t>(sizeof(in6_addr) || length != sizeof(in6_addr))) {
203                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
204                 return begin;
205             }
206             in6_addr addr;
207             begin = ReadRawData(begin, addr);
208             data = addr;
209             return begin;
210         }
211         case DNSProto::RRTYPE_PTR: {
212             std::string str;
213             begin = ParseDnsString(begin, payload, str);
214             if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
215                 return begin;
216             }
217             data = str;
218             return begin;
219         }
220         case DNSProto::RRTYPE_SRV: {
221             return ParseSrv(begin, payload, data);
222         }
223         case DNSProto::RRTYPE_TXT: {
224             return ParseTxt(begin, payload, length, data);
225         }
226         default: {
227             errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
228             return begin + length;
229         }
230     }
231 }
232 
ParseSrv(const uint8_t * begin,const MDnsPayload & payload,std::any & data)233 const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
234 {
235     const uint8_t *end = payload.data() + payload.size();
236     if (static_cast<ssize_t>(end - begin) <
237         static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
238         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
239         return begin;
240     }
241 
242     DNSProto::RDataSrv srv;
243     begin = ReadNUint16(begin, srv.priority);
244     begin = ReadNUint16(begin, srv.weight);
245     begin = ReadNUint16(begin, srv.port);
246     begin = ParseDnsString(begin, payload, srv.name);
247     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
248         return begin;
249     }
250     data = srv;
251     return begin;
252 }
253 
ParseTxt(const uint8_t * begin,const MDnsPayload & payload,int length,std::any & data)254 const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
255 {
256     const uint8_t *end = payload.data() + payload.size();
257     if (end - begin < length) {
258         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
259         return begin;
260     }
261 
262     data = TxtRecordEncoded(begin, begin + length);
263     return begin + length;
264 }
265 
ParseDnsString(const uint8_t * begin,const MDnsPayload & payload,std::string & str)266 const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
267 {
268     const uint8_t *end = payload.data() + payload.size();
269     const uint8_t *p = begin;
270     str.reserve(MDNS_STR_INITIAL_SIZE);
271     while (p < end) {
272         if (*p == 0) {
273             return p + 1;
274         }
275         if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
276             str.append(reinterpret_cast<const char *>(p) + 1, *p);
277             str.push_back(MDNS_DOMAIN_SPLITER);
278             p += (*p + 1);
279         } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
280             uint16_t offset;
281             const uint8_t *tmp = ReadNUint16(p, offset);
282             offset = offset & ~DNS_STR_PTR_U16_MASK;
283             if (offset >= payload.size()) {
284                 errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
285                 return begin;
286             }
287             ParseDnsString(payload.data() + (offset & ~DNS_STR_PTR_U16_MASK), payload, str);
288             if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
289                 return begin;
290             }
291             return tmp;
292         } else {
293             errorFlags_ |= PARSE_ERROR_BAD_STR;
294             return p;
295         }
296     }
297     return p;
298 }
299 
Serialize(const MDnsMessage & msg,MDnsPayload & payload)300 void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload)
301 {
302     payload.reserve(sizeof(DNSProto::Message));
303     DNSProto::Header header = msg.header;
304     header.qdcount = msg.questions.size();
305     header.ancount = msg.answers.size();
306     header.nscount = msg.authorities.size();
307     header.arcount = msg.additional.size();
308     SerializeHeader(header, msg, payload);
309     for (uint16_t i = 0; i < header.qdcount; ++i) {
310         SerializeQuestion(msg.questions[i], msg, payload);
311     }
312     for (uint16_t i = 0; i < header.ancount; ++i) {
313         SerializeRR(msg.answers[i], msg, payload);
314     }
315     for (uint16_t i = 0; i < header.nscount; ++i) {
316         SerializeRR(msg.authorities[i], msg, payload);
317     }
318     for (uint16_t i = 0; i < header.arcount; ++i) {
319         SerializeRR(msg.additional[i], msg, payload);
320     }
321 }
322 
SerializeHeader(const DNSProto::Header & header,const MDnsMessage & msg,MDnsPayload & payload)323 void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
324 {
325     WriteRawData(htons(header.id), payload);
326     WriteRawData(htons(header.flags), payload);
327     WriteRawData(htons(header.qdcount), payload);
328     WriteRawData(htons(header.ancount), payload);
329     WriteRawData(htons(header.nscount), payload);
330     WriteRawData(htons(header.arcount), payload);
331 }
332 
SerializeQuestion(const DNSProto::Question & question,const MDnsMessage & msg,MDnsPayload & payload)333 void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, const MDnsMessage &msg,
334                                           MDnsPayload &payload)
335 {
336     SerializeDnsString(question.name, msg, payload);
337     WriteRawData(htons(question.qtype), payload);
338     WriteRawData(htons(question.qclass), payload);
339 }
340 
SerializeRR(const DNSProto::ResourceRecord & rr,const MDnsMessage & msg,MDnsPayload & payload)341 void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, const MDnsMessage &msg, MDnsPayload &payload)
342 {
343     SerializeDnsString(rr.name, msg, payload);
344     WriteRawData(htons(rr.rtype), payload);
345     WriteRawData(htons(rr.rclass), payload);
346     WriteRawData(htonl(rr.ttl), payload);
347     size_t lenStart = payload.size();
348     WriteRawData(htons(rr.length), payload);
349     SerializeRData(rr.rdata, msg, payload);
350     uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
351     WriteRawData(htons(len), payload.data() + lenStart);
352 }
353 
SerializeRData(const std::any & rdata,const MDnsMessage & msg,MDnsPayload & payload)354 void MDnsPayloadParser::SerializeRData(const std::any &rdata, const MDnsMessage &msg, MDnsPayload &payload)
355 {
356     if (std::any_cast<const in_addr>(&rdata)) {
357         WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
358     } else if (std::any_cast<const in6_addr>(&rdata)) {
359         WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
360     } else if (std::any_cast<const std::string>(&rdata)) {
361         SerializeDnsString(*std::any_cast<const std::string>(&rdata), msg, payload);
362     } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
363         const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
364         WriteRawData(htons(srv->priority), payload);
365         WriteRawData(htons(srv->weight), payload);
366         WriteRawData(htons(srv->port), payload);
367         SerializeDnsString(srv->name, msg, payload);
368     } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
369         const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
370         payload.insert(payload.end(), txt->begin(), txt->end());
371     }
372 }
373 
SerializeDnsString(const std::string & str,const MDnsMessage & msg,MDnsPayload & payload)374 void MDnsPayloadParser::SerializeDnsString(const std::string &str, const MDnsMessage &msg, MDnsPayload &payload)
375 {
376     size_t pos = 0;
377     while (pos < str.size()) {
378         if (cachedPayload_ == &payload && strCacheMap_.find(str.substr(pos)) != strCacheMap_.end()) {
379             return WriteRawData(htons(strCacheMap_[str.substr(pos)]), payload);
380         }
381 
382         size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
383         if (nextDot == std::string::npos) {
384             nextDot = str.size();
385         }
386         uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
387         uint16_t strptr = payload.size();
388         WriteRawData(segLen, payload);
389         for (int i = 0; i < segLen; ++i) {
390             WriteRawData(str[pos + i], payload);
391         }
392         strCacheMap_[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
393         pos = nextDot + 1;
394     }
395     WriteRawData(DNS_STR_EOL, payload);
396 }
397 
GetError() const398 uint32_t MDnsPayloadParser::GetError() const
399 {
400     return errorFlags_ & PARSE_ERROR;
401 }
402 
403 } // namespace NetManagerStandard
404 } // namespace OHOS
405