• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_packet_parser.h"
17 #include "netmgr_ext_log_wrapper.h"
18 #include <cstring>
19 
20 namespace OHOS {
21 namespace NetManagerStandard {
22 
23 namespace {
24 
25 constexpr size_t MDNS_STR_INITIAL_SIZE = 16;
26 
27 constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xc0;
28 constexpr uint16_t DNS_STR_PTR_U16_MASK = 0xc000;
29 constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3f;
30 constexpr uint8_t DNS_STR_EOL = '\0';
31 
WriteRawData(const T & data,MDnsPayload & payload)32 template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33 {
34     const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35     payload.insert(payload.end(), begin, begin + sizeof(T));
36 }
37 
WriteRawData(const T & data,uint8_t * ptr)38 template <class T> void WriteRawData(const T &data, uint8_t *ptr)
39 {
40     const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
41     for (size_t i = 0; i < sizeof(T); ++i) {
42         ptr[i] = *begin++;
43     }
44 }
45 
ReadRawData(const uint8_t * raw,T & data)46 template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
47 {
48     data = *reinterpret_cast<const T *>(raw);
49     return raw + sizeof(T);
50 }
51 
ReadNUint16(const uint8_t * raw,uint16_t & data)52 const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
53 {
54     const uint8_t *tmp = ReadRawData(raw, data);
55     data = ntohs(data);
56     return tmp;
57 }
58 
ReadNUint32(const uint8_t * raw,uint32_t & data)59 const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
60 {
61     const uint8_t *tmp = ReadRawData(raw, data);
62     data = ntohl(data);
63     return tmp;
64 }
65 
UnDotted(const std::string & name)66 std::string UnDotted(const std::string &name)
67 {
68     return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
69 }
70 
71 } // namespace
72 
FromBytes(const MDnsPayload & payload)73 MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
74 {
75     MDnsMessage msg;
76     errorFlags_ = PARSE_OK;
77     pos_ = Parse(payload.data(), payload, msg);
78     return msg;
79 }
80 
ToBytes(const MDnsMessage & msg)81 MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
82 {
83     MDnsPayload payload;
84     MDnsPayload *cachedPayload = &payload;
85     std::map<std::string, uint16_t> strCacheMap;
86     Serialize(msg, payload, cachedPayload, strCacheMap);
87     return payload;
88 }
89 
Parse(const uint8_t * begin,const MDnsPayload & payload,MDnsMessage & msg)90 const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
91 {
92     begin = ParseHeader(begin, payload, msg.header);
93     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
94         return begin;
95     }
96     for (int i = 0; i < msg.header.qdcount; ++i) {
97         begin = ParseQuestion(begin, payload, msg.questions);
98         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
99             return begin;
100         }
101     }
102     for (int i = 0; i < msg.header.ancount; ++i) {
103         begin = ParseRR(begin, payload, msg.answers);
104         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
105             return begin;
106         }
107     }
108     for (int i = 0; i < msg.header.nscount; ++i) {
109         begin = ParseRR(begin, payload, msg.authorities);
110         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
111             return begin;
112         }
113     }
114     for (int i = 0; i < msg.header.arcount; ++i) {
115         begin = ParseRR(begin, payload, msg.additional);
116         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
117             return begin;
118         }
119     }
120     return begin;
121 }
122 
ParseHeader(const uint8_t * begin,const MDnsPayload & payload,DNSProto::Header & header)123 const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
124                                               DNSProto::Header &header)
125 {
126     const uint8_t *end = payload.data() + payload.size();
127     if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
128         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
129         return begin;
130     }
131 
132     begin = ReadNUint16(begin, header.id);
133     begin = ReadNUint16(begin, header.flags);
134     begin = ReadNUint16(begin, header.qdcount);
135     begin = ReadNUint16(begin, header.ancount);
136     begin = ReadNUint16(begin, header.nscount);
137     begin = ReadNUint16(begin, header.arcount);
138     return begin;
139 }
140 
ParseQuestion(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::Question> & questions)141 const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
142                                                 std::vector<DNSProto::Question> &questions)
143 {
144     questions.emplace_back();
145     begin = ParseDnsString(begin, payload, questions.back().name);
146     questions.back().name = UnDotted(questions.back().name);
147     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
148         questions.pop_back();
149         return begin;
150     }
151 
152     const uint8_t *end = payload.data() + payload.size();
153     if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
154         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
155         questions.pop_back();
156         return begin;
157     }
158 
159     begin = ReadNUint16(begin, questions.back().qtype);
160     begin = ReadNUint16(begin, questions.back().qclass);
161     return begin;
162 }
163 
ParseRR(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::ResourceRecord> & answers)164 const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
165                                           std::vector<DNSProto::ResourceRecord> &answers)
166 {
167     answers.emplace_back();
168     begin = ParseDnsString(begin, payload, answers.back().name);
169     answers.back().name = UnDotted(answers.back().name);
170     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
171         answers.pop_back();
172         return begin;
173     }
174 
175     const uint8_t *end = payload.data() + payload.size();
176     if (static_cast<ssize_t>(end - begin) <
177         static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
178         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
179         answers.pop_back();
180         return begin;
181     }
182     begin = ReadNUint16(begin, answers.back().rtype);
183     begin = ReadNUint16(begin, answers.back().rclass);
184     begin = ReadNUint32(begin, answers.back().ttl);
185     begin = ReadNUint16(begin, answers.back().length);
186     return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
187 }
188 
ParseRData(const uint8_t * begin,const MDnsPayload & payload,int type,int length,std::any & data)189 const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
190                                              std::any &data)
191 {
192     switch (type) {
193         case DNSProto::RRTYPE_A: {
194             const uint8_t *end = payload.data() + payload.size();
195             if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
196                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
197                 return begin;
198             }
199             in_addr addr;
200             begin = ReadRawData(begin, addr);
201             data = addr;
202             return begin;
203         }
204         case DNSProto::RRTYPE_AAAA: {
205             const uint8_t *end = payload.data() + payload.size();
206             if ((static_cast<ssize_t>(end - begin) <
207                 static_cast<ssize_t>(sizeof(in6_addr))) || (length != sizeof(in6_addr))) {
208                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
209                 return begin;
210             }
211             in6_addr addr;
212             begin = ReadRawData(begin, addr);
213             data = addr;
214             return begin;
215         }
216         case DNSProto::RRTYPE_PTR: {
217             std::string str;
218             begin = ParseDnsString(begin, payload, str);
219             str = UnDotted(str);
220             if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
221                 return begin;
222             }
223             data = str;
224             return begin;
225         }
226         case DNSProto::RRTYPE_SRV: {
227             return ParseSrv(begin, payload, data);
228         }
229         case DNSProto::RRTYPE_TXT: {
230             return ParseTxt(begin, payload, length, data);
231         }
232         default: {
233             errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
234             return begin + length;
235         }
236     }
237 }
238 
ParseSrv(const uint8_t * begin,const MDnsPayload & payload,std::any & data)239 const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
240 {
241     const uint8_t *end = payload.data() + payload.size();
242     if (static_cast<ssize_t>(end - begin) <
243         static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
244         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
245         return begin;
246     }
247 
248     DNSProto::RDataSrv srv;
249     begin = ReadNUint16(begin, srv.priority);
250     begin = ReadNUint16(begin, srv.weight);
251     begin = ReadNUint16(begin, srv.port);
252     begin = ParseDnsString(begin, payload, srv.name);
253     srv.name = UnDotted(srv.name);
254     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
255         return begin;
256     }
257     data = srv;
258     return begin;
259 }
260 
ParseTxt(const uint8_t * begin,const MDnsPayload & payload,int length,std::any & data)261 const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
262 {
263     const uint8_t *end = payload.data() + payload.size();
264     if (end - begin < length) {
265         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
266         return begin;
267     }
268 
269     data = TxtRecordEncoded(begin, begin + length);
270     return begin + length;
271 }
272 
ParseDnsString(const uint8_t * begin,const MDnsPayload & payload,std::string & str)273 const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
274 {
275     const uint8_t *end = payload.data() + payload.size();
276     const uint8_t *p = begin;
277     str.reserve(MDNS_STR_INITIAL_SIZE);
278     while (p && p < end) {
279         if (*p == 0) {
280             return p + 1;
281         }
282         if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
283             str.append(reinterpret_cast<const char *>(p) + 1, *p);
284             str.push_back(MDNS_DOMAIN_SPLITER);
285             p += (*p + 1);
286         } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
287             if (end - p < static_cast<int>(sizeof(uint16_t))) {
288                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
289                 return begin;
290             }
291 
292             uint16_t offset;
293             const uint8_t *tmp = ReadNUint16(p, offset);
294             offset = offset & ~DNS_STR_PTR_U16_MASK;
295             const uint8_t *next = payload.data() + (offset & ~DNS_STR_PTR_U16_MASK);
296 
297             if (next >= end || next >= begin) {
298                 errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
299                 return begin;
300             }
301             ParseDnsString(next, payload, str);
302             if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
303                 return begin;
304             }
305             return tmp;
306         } else {
307             errorFlags_ |= PARSE_ERROR_BAD_STR;
308             return p;
309         }
310     }
311     return p;
312 }
313 
Serialize(const MDnsMessage & msg,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)314 void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload, MDnsPayload *cachedPayload,
315                                   std::map<std::string, uint16_t> &strCacheMap)
316 {
317     payload.reserve(sizeof(DNSProto::Message));
318     DNSProto::Header header = msg.header;
319     header.qdcount = msg.questions.size();
320     header.ancount = msg.answers.size();
321     header.nscount = msg.authorities.size();
322     header.arcount = msg.additional.size();
323     SerializeHeader(header, msg, payload);
324     for (uint16_t i = 0; i < header.qdcount; ++i) {
325         SerializeQuestion(msg.questions[i], payload, cachedPayload, strCacheMap);
326     }
327     for (uint16_t i = 0; i < header.ancount; ++i) {
328         SerializeRR(msg.answers[i], payload, cachedPayload, strCacheMap);
329     }
330     for (uint16_t i = 0; i < header.nscount; ++i) {
331         SerializeRR(msg.authorities[i], payload, cachedPayload, strCacheMap);
332     }
333     for (uint16_t i = 0; i < header.arcount; ++i) {
334         SerializeRR(msg.additional[i], payload, cachedPayload, strCacheMap);
335     }
336 }
337 
SerializeHeader(const DNSProto::Header & header,const MDnsMessage & msg,MDnsPayload & payload)338 void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
339 {
340     WriteRawData(htons(header.id), payload);
341     WriteRawData(htons(header.flags), payload);
342     WriteRawData(htons(header.qdcount), payload);
343     WriteRawData(htons(header.ancount), payload);
344     WriteRawData(htons(header.nscount), payload);
345     WriteRawData(htons(header.arcount), payload);
346 }
347 
SerializeQuestion(const DNSProto::Question & question,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)348 void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, MDnsPayload &payload,
349                                           MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
350 {
351     SerializeDnsString(question.name, payload, cachedPayload, strCacheMap);
352     WriteRawData(htons(question.qtype), payload);
353     WriteRawData(htons(question.qclass), payload);
354 }
355 
SerializeRR(const DNSProto::ResourceRecord & rr,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)356 void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, MDnsPayload &payload,
357                                     MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
358 {
359     SerializeDnsString(rr.name, payload, cachedPayload, strCacheMap);
360     WriteRawData(htons(rr.rtype), payload);
361     WriteRawData(htons(rr.rclass), payload);
362     WriteRawData(htonl(rr.ttl), payload);
363     size_t lenStart = payload.size();
364     WriteRawData(htons(rr.length), payload);
365     SerializeRData(rr.rdata, payload, cachedPayload, strCacheMap);
366     uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
367     WriteRawData(htons(len), payload.data() + lenStart);
368 }
369 
SerializeRData(const std::any & rdata,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)370 void MDnsPayloadParser::SerializeRData(const std::any &rdata, MDnsPayload &payload, MDnsPayload *cachedPayload,
371                                        std::map<std::string, uint16_t> &strCacheMap)
372 {
373     if (std::any_cast<const in_addr>(&rdata)) {
374         WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
375     } else if (std::any_cast<const in6_addr>(&rdata)) {
376         WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
377     } else if (std::any_cast<const std::string>(&rdata)) {
378         SerializeDnsString(*std::any_cast<const std::string>(&rdata), payload, cachedPayload, strCacheMap);
379     } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
380         const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
381         WriteRawData(htons(srv->priority), payload);
382         WriteRawData(htons(srv->weight), payload);
383         WriteRawData(htons(srv->port), payload);
384         SerializeDnsString(srv->name, payload, cachedPayload, strCacheMap);
385     } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
386         const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
387         payload.insert(payload.end(), txt->begin(), txt->end());
388     }
389 }
390 
SerializeDnsString(const std::string & str,MDnsPayload & payload,MDnsPayload * cachedPayload,std::map<std::string,uint16_t> & strCacheMap)391 void MDnsPayloadParser::SerializeDnsString(const std::string &str, MDnsPayload &payload, MDnsPayload *cachedPayload,
392                                            std::map<std::string, uint16_t> &strCacheMap)
393 {
394     size_t pos = 0;
395     while (pos < str.size()) {
396         if ((cachedPayload == &payload) && (strCacheMap.find(str.substr(pos)) != strCacheMap.end())) {
397             return WriteRawData(htons(strCacheMap[str.substr(pos)]), payload);
398         }
399 
400         size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
401         if (nextDot == std::string::npos) {
402             nextDot = str.size();
403         }
404         uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
405 
406         uint16_t strptr = payload.size();
407         WriteRawData(segLen, payload);
408         for (int i = 0; i < segLen; ++i) {
409             WriteRawData(str[pos + i], payload);
410         }
411         strCacheMap[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
412         pos = nextDot + 1;
413     }
414     WriteRawData(DNS_STR_EOL, payload);
415 }
416 
GetError() const417 uint32_t MDnsPayloadParser::GetError() const
418 {
419     return errorFlags_ & PARSE_ERROR;
420 }
421 
422 } // namespace NetManagerStandard
423 } // namespace OHOS
424