• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_packet_parser.h"
17 
18 #include <cstring>
19 
20 namespace OHOS {
21 namespace NetManagerStandard {
22 
23 namespace {
24 
// Capacity reserved up front for a decoded domain-name string.
constexpr size_t MDNS_STR_INITIAL_SIZE = 16;

// A label byte with both top bits set introduces a compression pointer (8-bit view of the marker).
constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xC0;
// Same marker bits in the 16-bit pointer word; the remaining low 14 bits hold the offset.
constexpr uint16_t DNS_STR_PTR_U16_MASK = 0xC000;
// Mask for a 6-bit label length (at most 63).
constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3F;
// Zero-length label written at the end of every encoded name.
constexpr uint8_t DNS_STR_EOL = '\0';
31 
WriteRawData(const T & data,MDnsPayload & payload)32 template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33 {
34     const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35     payload.insert(payload.end(), begin, begin + sizeof(T));
36 }
37 
/**
 * Overwrites ptr[0 .. sizeof(T)) with the raw object representation of `data`.
 * The destination must already have room for sizeof(T) bytes; nothing is appended or resized.
 */
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    std::memcpy(ptr, &data, sizeof(T));
}
45 
/**
 * Reads sizeof(T) raw bytes starting at `raw` into `data` (host representation, no byte swapping).
 *
 * @param raw  Source bytes; the caller must have checked that sizeof(T) bytes are available.
 * @param data Receives the copied value.
 * @return Pointer just past the bytes consumed.
 */
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    // Fields sit at arbitrary offsets inside a network buffer, so `raw` is generally not aligned for T.
    // Dereferencing a reinterpret_cast'd T* there is undefined behaviour (misalignment + strict aliasing);
    // memcpy is the portable, alignment-safe way to load the bytes and compiles to a plain load.
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}
51 
ReadNUint16(const uint8_t * raw,uint16_t & data)52 const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
53 {
54     const uint8_t *tmp = ReadRawData(raw, data);
55     data = ntohs(data);
56     return tmp;
57 }
58 
ReadNUint32(const uint8_t * raw,uint32_t & data)59 const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
60 {
61     const uint8_t *tmp = ReadRawData(raw, data);
62     data = ntohl(data);
63     return tmp;
64 }
65 
UnDotted(const std::string & name)66 std::string UnDotted(const std::string &name)
67 {
68     return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
69 }
70 
71 } // namespace
72 
FromBytes(const MDnsPayload & payload)73 MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
74 {
75     MDnsMessage msg;
76     errorFlags_ = PARSE_OK;
77     pos_ = Parse(payload.data(), payload, msg);
78     return msg;
79 }
80 
ToBytes(const MDnsMessage & msg)81 MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
82 {
83     MDnsPayload payload;
84     cachedPayload_ = &payload;
85     strCacheMap_.clear();
86     Serialize(msg, payload);
87     cachedPayload_ = nullptr;
88     strCacheMap_.clear();
89     return payload;
90 }
91 
Parse(const uint8_t * begin,const MDnsPayload & payload,MDnsMessage & msg)92 const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
93 {
94     begin = ParseHeader(begin, payload, msg.header);
95     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
96         return begin;
97     }
98     for (int i = 0; i < msg.header.qdcount; ++i) {
99         begin = ParseQuestion(begin, payload, msg.questions);
100         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
101             return begin;
102         }
103     }
104     for (int i = 0; i < msg.header.ancount; ++i) {
105         begin = ParseRR(begin, payload, msg.answers);
106         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
107             return begin;
108         }
109     }
110     for (int i = 0; i < msg.header.nscount; ++i) {
111         begin = ParseRR(begin, payload, msg.authorities);
112         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
113             return begin;
114         }
115     }
116     for (int i = 0; i < msg.header.arcount; ++i) {
117         begin = ParseRR(begin, payload, msg.additional);
118         if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
119             return begin;
120         }
121     }
122     return begin;
123 }
124 
ParseHeader(const uint8_t * begin,const MDnsPayload & payload,DNSProto::Header & header)125 const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
126                                               DNSProto::Header &header)
127 {
128     const uint8_t *end = payload.data() + payload.size();
129     if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
130         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
131         return begin;
132     }
133 
134     begin = ReadNUint16(begin, header.id);
135     begin = ReadNUint16(begin, header.flags);
136     begin = ReadNUint16(begin, header.qdcount);
137     begin = ReadNUint16(begin, header.ancount);
138     begin = ReadNUint16(begin, header.nscount);
139     begin = ReadNUint16(begin, header.arcount);
140     return begin;
141 }
142 
ParseQuestion(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::Question> & questions)143 const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
144                                                 std::vector<DNSProto::Question> &questions)
145 {
146     questions.emplace_back();
147     begin = ParseDnsString(begin, payload, questions.back().name);
148     questions.back().name = UnDotted(questions.back().name);
149     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
150         questions.pop_back();
151         return begin;
152     }
153 
154     const uint8_t *end = payload.data() + payload.size();
155     if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
156         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
157         questions.pop_back();
158         return begin;
159     }
160 
161     begin = ReadNUint16(begin, questions.back().qtype);
162     begin = ReadNUint16(begin, questions.back().qclass);
163     return begin;
164 }
165 
ParseRR(const uint8_t * begin,const MDnsPayload & payload,std::vector<DNSProto::ResourceRecord> & answers)166 const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
167                                           std::vector<DNSProto::ResourceRecord> &answers)
168 {
169     answers.emplace_back();
170     begin = ParseDnsString(begin, payload, answers.back().name);
171     answers.back().name = UnDotted(answers.back().name);
172     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
173         answers.pop_back();
174         return begin;
175     }
176 
177     const uint8_t *end = payload.data() + payload.size();
178     if (static_cast<ssize_t>(end - begin) <
179         static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
180         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
181         answers.pop_back();
182         return begin;
183     }
184     begin = ReadNUint16(begin, answers.back().rtype);
185     begin = ReadNUint16(begin, answers.back().rclass);
186     begin = ReadNUint32(begin, answers.back().ttl);
187     begin = ReadNUint16(begin, answers.back().length);
188     return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
189 }
190 
ParseRData(const uint8_t * begin,const MDnsPayload & payload,int type,int length,std::any & data)191 const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
192                                              std::any &data)
193 {
194     switch (type) {
195         case DNSProto::RRTYPE_A: {
196             const uint8_t *end = payload.data() + payload.size();
197             if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
198                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
199                 return begin;
200             }
201             in_addr addr;
202             begin = ReadRawData(begin, addr);
203             data = addr;
204             return begin;
205         }
206         case DNSProto::RRTYPE_AAAA: {
207             const uint8_t *end = payload.data() + payload.size();
208             if (static_cast<ssize_t>(end - begin) <
209                 static_cast<ssize_t>(sizeof(in6_addr) || length != sizeof(in6_addr))) {
210                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
211                 return begin;
212             }
213             in6_addr addr;
214             begin = ReadRawData(begin, addr);
215             data = addr;
216             return begin;
217         }
218         case DNSProto::RRTYPE_PTR: {
219             std::string str;
220             begin = ParseDnsString(begin, payload, str);
221             str = UnDotted(str);
222             if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
223                 return begin;
224             }
225             data = str;
226             return begin;
227         }
228         case DNSProto::RRTYPE_SRV: {
229             return ParseSrv(begin, payload, data);
230         }
231         case DNSProto::RRTYPE_TXT: {
232             return ParseTxt(begin, payload, length, data);
233         }
234         default: {
235             errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
236             return begin + length;
237         }
238     }
239 }
240 
ParseSrv(const uint8_t * begin,const MDnsPayload & payload,std::any & data)241 const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
242 {
243     const uint8_t *end = payload.data() + payload.size();
244     if (static_cast<ssize_t>(end - begin) <
245         static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
246         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
247         return begin;
248     }
249 
250     DNSProto::RDataSrv srv;
251     begin = ReadNUint16(begin, srv.priority);
252     begin = ReadNUint16(begin, srv.weight);
253     begin = ReadNUint16(begin, srv.port);
254     begin = ParseDnsString(begin, payload, srv.name);
255     srv.name = UnDotted(srv.name);
256     if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
257         return begin;
258     }
259     data = srv;
260     return begin;
261 }
262 
ParseTxt(const uint8_t * begin,const MDnsPayload & payload,int length,std::any & data)263 const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
264 {
265     const uint8_t *end = payload.data() + payload.size();
266     if (end - begin < length) {
267         errorFlags_ |= PARSE_ERROR_BAD_SIZE;
268         return begin;
269     }
270 
271     data = TxtRecordEncoded(begin, begin + length);
272     return begin + length;
273 }
274 
ParseDnsString(const uint8_t * begin,const MDnsPayload & payload,std::string & str)275 const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
276 {
277     const uint8_t *end = payload.data() + payload.size();
278     const uint8_t *p = begin;
279     str.reserve(MDNS_STR_INITIAL_SIZE);
280     while (p && p < end) {
281         if (*p == 0) {
282             return p + 1;
283         }
284         if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
285             str.append(reinterpret_cast<const char *>(p) + 1, *p);
286             str.push_back(MDNS_DOMAIN_SPLITER);
287             p += (*p + 1);
288         } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
289             if (end - p < static_cast<int>(sizeof(uint16_t))) {
290                 errorFlags_ |= PARSE_ERROR_BAD_SIZE;
291                 return begin;
292             }
293 
294             uint16_t offset;
295             const uint8_t *tmp = ReadNUint16(p, offset);
296             offset = offset & ~DNS_STR_PTR_U16_MASK;
297             const uint8_t *next = payload.data() + (offset & ~DNS_STR_PTR_U16_MASK);
298 
299             if (next >= end || next >= begin) {
300                 errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
301                 return begin;
302             }
303             ParseDnsString(next, payload, str);
304             if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
305                 return begin;
306             }
307             return tmp;
308         } else {
309             errorFlags_ |= PARSE_ERROR_BAD_STR;
310             return p;
311         }
312     }
313     return p;
314 }
315 
Serialize(const MDnsMessage & msg,MDnsPayload & payload)316 void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload)
317 {
318     payload.reserve(sizeof(DNSProto::Message));
319     DNSProto::Header header = msg.header;
320     header.qdcount = msg.questions.size();
321     header.ancount = msg.answers.size();
322     header.nscount = msg.authorities.size();
323     header.arcount = msg.additional.size();
324     SerializeHeader(header, msg, payload);
325     for (uint16_t i = 0; i < header.qdcount; ++i) {
326         SerializeQuestion(msg.questions[i], msg, payload);
327     }
328     for (uint16_t i = 0; i < header.ancount; ++i) {
329         SerializeRR(msg.answers[i], msg, payload);
330     }
331     for (uint16_t i = 0; i < header.nscount; ++i) {
332         SerializeRR(msg.authorities[i], msg, payload);
333     }
334     for (uint16_t i = 0; i < header.arcount; ++i) {
335         SerializeRR(msg.additional[i], msg, payload);
336     }
337 }
338 
SerializeHeader(const DNSProto::Header & header,const MDnsMessage & msg,MDnsPayload & payload)339 void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
340 {
341     WriteRawData(htons(header.id), payload);
342     WriteRawData(htons(header.flags), payload);
343     WriteRawData(htons(header.qdcount), payload);
344     WriteRawData(htons(header.ancount), payload);
345     WriteRawData(htons(header.nscount), payload);
346     WriteRawData(htons(header.arcount), payload);
347 }
348 
SerializeQuestion(const DNSProto::Question & question,const MDnsMessage & msg,MDnsPayload & payload)349 void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, const MDnsMessage &msg,
350                                           MDnsPayload &payload)
351 {
352     SerializeDnsString(question.name, msg, payload);
353     WriteRawData(htons(question.qtype), payload);
354     WriteRawData(htons(question.qclass), payload);
355 }
356 
SerializeRR(const DNSProto::ResourceRecord & rr,const MDnsMessage & msg,MDnsPayload & payload)357 void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, const MDnsMessage &msg, MDnsPayload &payload)
358 {
359     SerializeDnsString(rr.name, msg, payload);
360     WriteRawData(htons(rr.rtype), payload);
361     WriteRawData(htons(rr.rclass), payload);
362     WriteRawData(htonl(rr.ttl), payload);
363     size_t lenStart = payload.size();
364     WriteRawData(htons(rr.length), payload);
365     SerializeRData(rr.rdata, msg, payload);
366     uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
367     WriteRawData(htons(len), payload.data() + lenStart);
368 }
369 
SerializeRData(const std::any & rdata,const MDnsMessage & msg,MDnsPayload & payload)370 void MDnsPayloadParser::SerializeRData(const std::any &rdata, const MDnsMessage &msg, MDnsPayload &payload)
371 {
372     if (std::any_cast<const in_addr>(&rdata)) {
373         WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
374     } else if (std::any_cast<const in6_addr>(&rdata)) {
375         WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
376     } else if (std::any_cast<const std::string>(&rdata)) {
377         SerializeDnsString(*std::any_cast<const std::string>(&rdata), msg, payload);
378     } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
379         const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
380         WriteRawData(htons(srv->priority), payload);
381         WriteRawData(htons(srv->weight), payload);
382         WriteRawData(htons(srv->port), payload);
383         SerializeDnsString(srv->name, msg, payload);
384     } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
385         const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
386         payload.insert(payload.end(), txt->begin(), txt->end());
387     }
388 }
389 
SerializeDnsString(const std::string & str,const MDnsMessage & msg,MDnsPayload & payload)390 void MDnsPayloadParser::SerializeDnsString(const std::string &str, const MDnsMessage &msg, MDnsPayload &payload)
391 {
392     size_t pos = 0;
393     while (pos < str.size()) {
394         if (cachedPayload_ == &payload && strCacheMap_.find(str.substr(pos)) != strCacheMap_.end()) {
395             return WriteRawData(htons(strCacheMap_[str.substr(pos)]), payload);
396         }
397 
398         size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
399         if (nextDot == std::string::npos) {
400             nextDot = str.size();
401         }
402         uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
403         uint16_t strptr = payload.size();
404         WriteRawData(segLen, payload);
405         for (int i = 0; i < segLen; ++i) {
406             WriteRawData(str[pos + i], payload);
407         }
408         strCacheMap_[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
409         pos = nextDot + 1;
410     }
411     WriteRawData(DNS_STR_EOL, payload);
412 }
413 
GetError() const414 uint32_t MDnsPayloadParser::GetError() const
415 {
416     return errorFlags_ & PARSE_ERROR;
417 }
418 
419 } // namespace NetManagerStandard
420 } // namespace OHOS
421