• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements generating and processing of DNS headers and helper functions/methods.
32  */
33 
34 #include "dns_types.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/instance.hpp"
39 #include "common/random.hpp"
40 #include "common/string.hpp"
41 
42 namespace ot {
43 namespace Dns {
44 
45 using ot::Encoding::BigEndian::HostSwap16;
46 
SetRandomMessageId(void)47 Error Header::SetRandomMessageId(void)
48 {
49     return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
50 }
51 
ResponseCodeToError(Response aResponse)52 Error Header::ResponseCodeToError(Response aResponse)
53 {
54     Error error = kErrorFailed;
55 
56     switch (aResponse)
57     {
58     case kResponseSuccess:
59         error = kErrorNone;
60         break;
61 
62     case kResponseFormatError:   // Server unable to interpret request due to format error.
63     case kResponseBadName:       // Bad name.
64     case kResponseBadTruncation: // Bad truncation.
65     case kResponseNotZone:       // A name is not in the zone.
66         error = kErrorParse;
67         break;
68 
69     case kResponseServerFailure: // Server encountered an internal failure.
70         error = kErrorFailed;
71         break;
72 
73     case kResponseNameError:       // Name that ought to exist, does not exists.
74     case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
75         error = kErrorNotFound;
76         break;
77 
78     case kResponseNotImplemented: // Server does not support the query type (OpCode).
79     case kDsoTypeNotImplemented:  // DSO TLV type is not implemented.
80         error = kErrorNotImplemented;
81         break;
82 
83     case kResponseBadAlg: // Bad algorithm.
84         error = kErrorNotCapable;
85         break;
86 
87     case kResponseNameExists:   // Some name that ought not to exist, does exist.
88     case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
89         error = kErrorDuplicated;
90         break;
91 
92     case kResponseRefused: // Server refused to perform operation for policy or security reasons.
93     case kResponseNotAuth: // Service is not authoritative for zone.
94         error = kErrorSecurity;
95         break;
96 
97     default:
98         break;
99     }
100 
101     return error;
102 }
103 
AppendTo(Message & aMessage) const104 Error Name::AppendTo(Message &aMessage) const
105 {
106     Error error;
107 
108     if (IsEmpty())
109     {
110         error = AppendTerminator(aMessage);
111     }
112     else if (IsFromCString())
113     {
114         error = AppendName(GetAsCString(), aMessage);
115     }
116     else
117     {
118         // Name is from a message. Read labels one by one from
119         // `mMessage` and and append each to the `aMessage`.
120 
121         LabelIterator iterator(*mMessage, mOffset);
122 
123         while (true)
124         {
125             error = iterator.GetNextLabel();
126 
127             switch (error)
128             {
129             case kErrorNone:
130                 SuccessOrExit(error = iterator.AppendLabel(aMessage));
131                 break;
132 
133             case kErrorNotFound:
134                 // We reached the end of name successfully.
135                 error = AppendTerminator(aMessage);
136 
137                 OT_FALL_THROUGH;
138 
139             default:
140                 ExitNow();
141             }
142         }
143     }
144 
145 exit:
146     return error;
147 }
148 
AppendLabel(const char * aLabel,Message & aMessage)149 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
150 {
151     return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
152 }
153 
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)154 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
155 {
156     Error error = kErrorNone;
157 
158     VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
159 
160     SuccessOrExit(error = aMessage.Append(aLength));
161     error = aMessage.AppendBytes(aLabel, aLength);
162 
163 exit:
164     return error;
165 }
166 
AppendMultipleLabels(const char * aLabels,Message & aMessage)167 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
168 {
169     return AppendMultipleLabels(aLabels, kMaxNameLength, aMessage);
170 }
171 
AppendMultipleLabels(const char * aLabels,uint8_t aLength,Message & aMessage)172 Error Name::AppendMultipleLabels(const char *aLabels, uint8_t aLength, Message &aMessage)
173 {
174     Error    error           = kErrorNone;
175     uint16_t index           = 0;
176     uint16_t labelStartIndex = 0;
177     char     ch;
178 
179     VerifyOrExit(aLabels != nullptr);
180 
181     do
182     {
183         ch = index < aLength ? aLabels[index] : static_cast<char>(kNullChar);
184 
185         if ((ch == kNullChar) || (ch == kLabelSeperatorChar))
186         {
187             uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
188 
189             if (labelLength == 0)
190             {
191                 // Empty label (e.g., consecutive dots) is invalid, but we
192                 // allow for two cases: (1) where `aLabels` ends with a dot
193                 // (`labelLength` is zero but we are at end of `aLabels` string
194                 // and `ch` is null char. (2) if `aLabels` is just "." (we
195                 // see a dot at index 0, and index 1 is null char).
196 
197                 error =
198                     ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
199                 ExitNow();
200             }
201 
202             VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
203             SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
204 
205             labelStartIndex = index + 1;
206         }
207 
208         index++;
209 
210     } while (ch != kNullChar);
211 
212 exit:
213     return error;
214 }
215 
AppendTerminator(Message & aMessage)216 Error Name::AppendTerminator(Message &aMessage)
217 {
218     uint8_t terminator = 0;
219 
220     return aMessage.Append(terminator);
221 }
222 
AppendPointerLabel(uint16_t aOffset,Message & aMessage)223 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
224 {
225     Error    error;
226     uint16_t value;
227 
228 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
229     if (!Instance::IsDnsNameCompressionEnabled())
230     {
231         // If "DNS name compression" mode is disabled, instead of
232         // appending the pointer label, read the name from the message
233         // and append it uncompressed. Note that the `aOffset` parameter
234         // in this method is given relative to the start of DNS header
235         // in `aMessage` (which `aMessage.GetOffset()` specifies).
236 
237         error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
238         ExitNow();
239     }
240 #endif
241 
242     // A pointer label takes the form of a two byte sequence as a
243     // `uint16_t` value. The first two bits are ones. This allows a
244     // pointer to be distinguished from a text label, since the text
245     // label must begin with two zero bits (note that labels are
246     // restricted to 63 octets or less). The next 14-bits specify
247     // an offset value relative to start of DNS header.
248 
249     OT_ASSERT(aOffset < kPointerLabelTypeUint16);
250 
251     value = HostSwap16(aOffset | kPointerLabelTypeUint16);
252 
253     ExitNow(error = aMessage.Append(value));
254 
255 exit:
256     return error;
257 }
258 
AppendName(const char * aName,Message & aMessage)259 Error Name::AppendName(const char *aName, Message &aMessage)
260 {
261     Error error;
262 
263     SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
264     error = AppendTerminator(aMessage);
265 
266 exit:
267     return error;
268 }
269 
ParseName(const Message & aMessage,uint16_t & aOffset)270 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
271 {
272     Error         error;
273     LabelIterator iterator(aMessage, aOffset);
274 
275     while (true)
276     {
277         error = iterator.GetNextLabel();
278 
279         switch (error)
280         {
281         case kErrorNone:
282             break;
283 
284         case kErrorNotFound:
285             // We reached the end of name successfully.
286             aOffset = iterator.mNameEndOffset;
287             error   = kErrorNone;
288 
289             OT_FALL_THROUGH;
290 
291         default:
292             ExitNow();
293         }
294     }
295 
296 exit:
297     return error;
298 }
299 
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)300 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
301 {
302     Error         error;
303     LabelIterator iterator(aMessage, aOffset);
304 
305     SuccessOrExit(error = iterator.GetNextLabel());
306     SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
307     aOffset = iterator.mNextLabelOffset;
308 
309 exit:
310     return error;
311 }
312 
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)313 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
314 {
315     Error         error;
316     LabelIterator iterator(aMessage, aOffset);
317     bool          firstLabel = true;
318     uint8_t       labelLength;
319 
320     while (true)
321     {
322         error = iterator.GetNextLabel();
323 
324         switch (error)
325         {
326         case kErrorNone:
327 
328             if (!firstLabel)
329             {
330                 *aNameBuffer++ = kLabelSeperatorChar;
331                 aNameBufferSize--;
332 
333                 // No need to check if we have reached end of the name buffer
334                 // here since `iterator.ReadLabel()` would verify it.
335             }
336 
337             labelLength = static_cast<uint8_t>(OT_MIN(static_cast<uint8_t>(kMaxLabelSize), aNameBufferSize));
338             SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ false));
339             aNameBuffer += labelLength;
340             aNameBufferSize -= labelLength;
341             firstLabel = false;
342             break;
343 
344         case kErrorNotFound:
345             // We reach the end of name successfully. Always add a terminating dot
346             // at the end.
347             *aNameBuffer++ = kLabelSeperatorChar;
348             aNameBufferSize--;
349             VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
350             *aNameBuffer = kNullChar;
351             aOffset      = iterator.mNameEndOffset;
352             error        = kErrorNone;
353 
354             OT_FALL_THROUGH;
355 
356         default:
357             ExitNow();
358         }
359     }
360 
361 exit:
362     return error;
363 }
364 
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)365 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
366 {
367     Error         error;
368     LabelIterator iterator(aMessage, aOffset);
369 
370     SuccessOrExit(error = iterator.GetNextLabel());
371     VerifyOrExit(iterator.CompareLabel(aLabel, kIsSingleLabel), error = kErrorNotFound);
372     aOffset = iterator.mNextLabelOffset;
373 
374 exit:
375     return error;
376 }
377 
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)378 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
379 {
380     Error         error;
381     LabelIterator iterator(aMessage, aOffset);
382     bool          matches = true;
383 
384     if (*aName == kLabelSeperatorChar)
385     {
386         aName++;
387         VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
388     }
389 
390     while (true)
391     {
392         error = iterator.GetNextLabel();
393 
394         switch (error)
395         {
396         case kErrorNone:
397             if (matches && !iterator.CompareLabel(aName, !kIsSingleLabel))
398             {
399                 matches = false;
400             }
401 
402             break;
403 
404         case kErrorNotFound:
405             // We reached the end of the name in `aMessage`. We check if
406             // all the previous labels matched so far, and we are also
407             // at the end of `aName` string (see null char), then we
408             // return `kErrorNone` indicating a successful comparison
409             // (full match). Otherwise we return `kErrorNotFound` to
410             // indicate failed comparison.
411 
412             if (matches && (*aName == kNullChar))
413             {
414                 error = kErrorNone;
415             }
416 
417             aOffset = iterator.mNameEndOffset;
418 
419             OT_FALL_THROUGH;
420 
421         default:
422             ExitNow();
423         }
424     }
425 
426 exit:
427     return error;
428 }
429 
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)430 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
431 {
432     Error         error;
433     LabelIterator iterator(aMessage, aOffset);
434     LabelIterator iterator2(aMessage2, aOffset2);
435     bool          matches = true;
436 
437     while (true)
438     {
439         error = iterator.GetNextLabel();
440 
441         switch (error)
442         {
443         case kErrorNone:
444             // If all the previous labels matched so far, then verify
445             // that we can get the next label on `iterator2` and that it
446             // matches the label from `iterator`.
447             if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
448             {
449                 matches = false;
450             }
451 
452             break;
453 
454         case kErrorNotFound:
455             // We reached the end of the name in `aMessage`. We check
456             // that `iterator2` is also at its end, and if all previous
457             // labels matched we return `kErrorNone`.
458 
459             if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
460             {
461                 error = kErrorNone;
462             }
463 
464             aOffset = iterator.mNameEndOffset;
465 
466             OT_FALL_THROUGH;
467 
468         default:
469             ExitNow();
470         }
471     }
472 
473 exit:
474     return error;
475 }
476 
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)477 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
478 {
479     return aName.IsFromCString()
480                ? CompareName(aMessage, aOffset, aName.mString)
481                : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
482                                         : ParseName(aMessage, aOffset));
483 }
484 
GetNextLabel(void)485 Error Name::LabelIterator::GetNextLabel(void)
486 {
487     Error error;
488 
489     while (true)
490     {
491         uint8_t labelLength;
492         uint8_t labelType;
493 
494         SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));
495 
496         labelType = labelLength & kLabelTypeMask;
497 
498         if (labelType == kTextLabelType)
499         {
500             if (labelLength == 0)
501             {
502                 // Zero label length indicates end of a name.
503 
504                 if (!IsEndOffsetSet())
505                 {
506                     mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
507                 }
508 
509                 ExitNow(error = kErrorNotFound);
510             }
511 
512             mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
513             mLabelLength      = labelLength;
514             mNextLabelOffset  = mLabelStartOffset + labelLength;
515             ExitNow();
516         }
517         else if (labelType == kPointerLabelType)
518         {
519             // A pointer label takes the form of a two byte sequence as a
520             // `uint16_t` value. The first two bits are ones. The next 14 bits
521             // specify an offset value from the start of the DNS header.
522 
523             uint16_t pointerValue;
524 
525             SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));
526 
527             if (!IsEndOffsetSet())
528             {
529                 mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
530             }
531 
532             // `mMessage.GetOffset()` must point to the start of the
533             // DNS header.
534             mNextLabelOffset = mMessage.GetOffset() + (HostSwap16(pointerValue) & kPointerLabelOffsetMask);
535 
536             // Go back through the `while(true)` loop to get the next label.
537         }
538         else
539         {
540             ExitNow(error = kErrorParse);
541         }
542     }
543 
544 exit:
545     return error;
546 }
547 
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const548 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
549 {
550     Error error;
551 
552     VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
553 
554     SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
555     aLabelBuffer[mLabelLength] = kNullChar;
556     aLabelLength               = mLabelLength;
557 
558     if (!aAllowDotCharInLabel)
559     {
560         VerifyOrExit(StringFind(aLabelBuffer, kLabelSeperatorChar) == nullptr, error = kErrorParse);
561     }
562 
563 exit:
564     return error;
565 }
566 
CaseInsensitiveMatch(uint8_t aFirst,uint8_t aSecond)567 bool Name::LabelIterator::CaseInsensitiveMatch(uint8_t aFirst, uint8_t aSecond)
568 {
569     return ToLowercase(static_cast<char>(aFirst)) == ToLowercase(static_cast<char>(aSecond));
570 }
571 
CompareLabel(const char * & aName,bool aIsSingleLabel) const572 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
573 {
574     // This method compares the current label in the iterator with the
575     // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
576     // single label, or a sequence of labels separated by dot '.' char.
577     // If the label matches `aName`, then `aName` pointer is moved
578     // forward to the start of the next label (skipping over the `.`
579     // char). This method returns `true` when the labels match, `false`
580     // otherwise.
581 
582     bool matches = false;
583 
584     VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
585     matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength, CaseInsensitiveMatch);
586 
587     VerifyOrExit(matches);
588 
589     aName += mLabelLength;
590 
591     // If `aName` is a single label, we should be also at the end of the
592     // `aName` string. Otherwise, we should see either null or dot '.'
593     // character (in case `aName` contains multiple labels).
594 
595     matches = (*aName == kNullChar);
596 
597     if (!aIsSingleLabel && (*aName == kLabelSeperatorChar))
598     {
599         matches = true;
600         aName++;
601     }
602 
603 exit:
604     return matches;
605 }
606 
CompareLabel(const LabelIterator & aOtherIterator) const607 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
608 {
609     // This method compares the current label in the iterator with the
610     // label from another iterator.
611 
612     return (mLabelLength == aOtherIterator.mLabelLength) &&
613            mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
614                                  mLabelLength, CaseInsensitiveMatch);
615 }
616 
AppendLabel(Message & aMessage) const617 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
618 {
619     // This method reads and appends the current label in the iterator
620     // to `aMessage`.
621 
622     Error error;
623 
624     VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
625     SuccessOrExit(error = aMessage.Append(mLabelLength));
626     error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
627 
628 exit:
629     return error;
630 }
631 
IsSubDomainOf(const char * aName,const char * aDomain)632 bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
633 {
634     bool     match             = false;
635     bool     nameEndsWithDot   = false;
636     bool     domainEndsWithDot = false;
637     uint16_t nameLength        = StringLength(aName, kMaxNameLength);
638     uint16_t domainLength      = StringLength(aDomain, kMaxNameLength);
639 
640     if (nameLength > 0 && aName[nameLength - 1] == kLabelSeperatorChar)
641     {
642         nameEndsWithDot = true;
643         --nameLength;
644     }
645 
646     if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeperatorChar)
647     {
648         domainEndsWithDot = true;
649         --domainLength;
650     }
651 
652     VerifyOrExit(nameLength >= domainLength);
653 
654     aName += nameLength - domainLength;
655 
656     if (nameLength > domainLength)
657     {
658         VerifyOrExit(aName[-1] == kLabelSeperatorChar);
659     }
660 
661     // This method allows either `aName` or `aDomain` to include or
662     // exclude the last `.` character. If both include it or if both
663     // do not, we do a full comparison using `StringMatch()`.
664     // Otherwise (i.e., when one includes and the other one does not)
665     // we use `StringStartWith()` to allow the extra `.` character.
666 
667     if (nameEndsWithDot == domainEndsWithDot)
668     {
669         match = StringMatch(aName, aDomain, kStringCaseInsensitiveMatch);
670     }
671     else if (nameEndsWithDot)
672     {
673         // `aName` ends with dot, but `aDomain` does not.
674         match = StringStartsWith(aName, aDomain, kStringCaseInsensitiveMatch);
675     }
676     else
677     {
678         // `aDomain` ends with dot, but `aName` does not.
679         match = StringStartsWith(aDomain, aName, kStringCaseInsensitiveMatch);
680     }
681 
682 exit:
683     return match;
684 }
685 
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)686 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
687 {
688     Error error = kErrorNone;
689 
690     while (aNumRecords > 0)
691     {
692         ResourceRecord record;
693 
694         SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
695         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
696         aOffset += static_cast<uint16_t>(record.GetSize());
697         aNumRecords--;
698     }
699 
700 exit:
701     return error;
702 }
703 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)704 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
705 {
706     Error error;
707 
708     while (aNumRecords > 0)
709     {
710         bool           matches = true;
711         ResourceRecord record;
712 
713         error = Name::CompareName(aMessage, aOffset, aName);
714 
715         switch (error)
716         {
717         case kErrorNone:
718             break;
719         case kErrorNotFound:
720             matches = false;
721             break;
722         default:
723             ExitNow();
724         }
725 
726         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
727         aNumRecords--;
728         VerifyOrExit(!matches);
729         aOffset += static_cast<uint16_t>(record.GetSize());
730     }
731 
732     error = kErrorNotFound;
733 
734 exit:
735     return error;
736 }
737 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords,uint16_t aIndex,const Name & aName,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)738 Error ResourceRecord::FindRecord(const Message & aMessage,
739                                  uint16_t &      aOffset,
740                                  uint16_t        aNumRecords,
741                                  uint16_t        aIndex,
742                                  const Name &    aName,
743                                  uint16_t        aType,
744                                  ResourceRecord &aRecord,
745                                  uint16_t        aMinRecordSize)
746 {
747     // This static method searches in `aMessage` starting from `aOffset`
748     // up to maximum of `aNumRecords`, for the `(aIndex+1)`th
749     // occurrence of a resource record of type `aType` with record name
750     // matching `aName`. It also verifies that the record size is larger
751     // than `aMinRecordSize`. If found, `aMinRecordSize` bytes from the
752     // record are read and copied into `aRecord`. In this case `aOffset`
753     // is updated to point to the last record byte read from the message
754     // (so that the caller can read any remaining fields in the record
755     // data).
756 
757     Error    error;
758     uint16_t offset = aOffset;
759     uint16_t recordOffset;
760 
761     while (aNumRecords > 0)
762     {
763         SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));
764 
765         // Save the offset to start of `ResourceRecord` fields.
766         recordOffset = offset;
767 
768         error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);
769 
770         if (error == kErrorNotFound)
771         {
772             // `ReadRecord()` already updates the `offset` to skip
773             // over a non-matching record.
774             continue;
775         }
776 
777         SuccessOrExit(error);
778 
779         if (aIndex == 0)
780         {
781             aOffset = offset;
782             ExitNow();
783         }
784 
785         aIndex--;
786 
787         // Skip over the record.
788         offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
789     }
790 
791     error = kErrorNotFound;
792 
793 exit:
794     return error;
795 }
796 
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)797 Error ResourceRecord::ReadRecord(const Message & aMessage,
798                                  uint16_t &      aOffset,
799                                  uint16_t        aType,
800                                  ResourceRecord &aRecord,
801                                  uint16_t        aMinRecordSize)
802 {
803     // This static method tries to read a matching resource record of a
804     // given type and a minimum record size from a message. The `aType`
805     // value of `kTypeAny` matches any type.  If the record in the
806     // message does not match, it skips over the record. Please see
807     // `ReadRecord<RecordType>()` for more details.
808 
809     Error          error;
810     ResourceRecord record;
811 
812     SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
813 
814     if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
815     {
816         IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
817         aOffset += aMinRecordSize;
818     }
819     else
820     {
821         // Skip over the entire record.
822         aOffset += static_cast<uint16_t>(record.GetSize());
823         error = kErrorNotFound;
824     }
825 
826 exit:
827     return error;
828 }
829 
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const830 Error ResourceRecord::ReadName(const Message &aMessage,
831                                uint16_t &     aOffset,
832                                uint16_t       aStartOffset,
833                                char *         aNameBuffer,
834                                uint16_t       aNameBufferSize,
835                                bool           aSkipRecord) const
836 {
837     // This protected method parses and reads a name field in a record
838     // from a message. It is intended only for sub-classes of
839     // `ResourceRecord`.
840     //
841     // On input `aOffset` gives the offset in `aMessage` to the start of
842     // name field. `aStartOffset` gives the offset to the start of the
843     // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
844     // the entire resource record or just the read name. On exit, when
845     // successfully read, `aOffset` is updated to either point after the
846     // end of record or after the the name field.
847     //
848     // When read successfully, this method returns `kErrorNone`. On a
849     // parse error (invalid format) returns `kErrorParse`. If the
850     // name does not fit in the given name buffer it returns
851     // `kErrorNoBufs`
852 
853     Error error = kErrorNone;
854 
855     SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
856     VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
857 
858     VerifyOrExit(aSkipRecord);
859     aOffset = aStartOffset;
860     error   = SkipRecord(aMessage, aOffset);
861 
862 exit:
863     return error;
864 }
865 
SkipRecord(const Message & aMessage,uint16_t & aOffset) const866 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
867 {
868     // This protected method parses and skips over a resource record
869     // in a message.
870     //
871     // On input `aOffset` gives the offset in `aMessage` to the start of
872     // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
873     // is updated to point to byte after the entire record.
874 
875     Error error;
876 
877     SuccessOrExit(error = CheckRecord(aMessage, aOffset));
878     aOffset += static_cast<uint16_t>(GetSize());
879 
880 exit:
881     return error;
882 }
883 
CheckRecord(const Message & aMessage,uint16_t aOffset) const884 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
885 {
886     // This method checks that the entire record (including record data)
887     // is present in `aMessage` at `aOffset` (pointing to the start of
888     // the `ResourceRecord` in `aMessage`).
889 
890     return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
891 }
892 
ReadFrom(const Message & aMessage,uint16_t aOffset)893 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
894 {
895     // This method reads the `ResourceRecord` from `aMessage` at
896     // `aOffset`. It verifies that the entire record (including record
897     // data) is present in the message.
898 
899     Error error;
900 
901     SuccessOrExit(error = aMessage.Read(aOffset, *this));
902     error = CheckRecord(aMessage, aOffset);
903 
904 exit:
905     return error;
906 }
907 
Init(const uint8_t * aTxtData,uint16_t aTxtDataLength)908 void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
909 {
910     SetTxtData(aTxtData);
911     SetTxtDataLength(aTxtDataLength);
912     SetTxtDataPosition(0);
913 }
914 
GetNextEntry(TxtEntry & aEntry)915 Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
916 {
917     Error       error = kErrorNone;
918     uint8_t     length;
919     uint8_t     index;
920     const char *cur;
921     char *      keyBuffer = GetKeyBuffer();
922 
923     static_assert(sizeof(mChar) == TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");
924 
925     VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);
926 
927     aEntry.mKey = keyBuffer;
928 
929     while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
930     {
931         length = static_cast<uint8_t>(*cur);
932 
933         cur++;
934         VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);
935         IncreaseTxtDataPosition(sizeof(uint8_t) + length);
936 
937         // Silently skip over an empty string or if the string starts with
938         // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.
939 
940         if ((length == 0) || (cur[0] == kKeyValueSeparator))
941         {
942             continue;
943         }
944 
945         for (index = 0; index < length; index++)
946         {
947             if (cur[index] == kKeyValueSeparator)
948             {
949                 keyBuffer[index++]  = kNullChar; // Increment index to skip over `=`.
950                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(&cur[index]);
951                 aEntry.mValueLength = length - index;
952                 ExitNow();
953             }
954 
955             if (index >= kMaxKeyLength)
956             {
957                 // The key is larger than recommended max key length.
958                 // In this case, we return the full encoded string in
959                 // `mValue` and `mValueLength` and set `mKey` to
960                 // `nullptr`.
961 
962                 aEntry.mKey         = nullptr;
963                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(cur);
964                 aEntry.mValueLength = length;
965                 ExitNow();
966             }
967 
968             keyBuffer[index] = cur[index];
969         }
970 
971         // If we reach the end of the string without finding `=` then
972         // it is a boolean key attribute (encoded as "key").
973 
974         keyBuffer[index]    = kNullChar;
975         aEntry.mValue       = nullptr;
976         aEntry.mValueLength = 0;
977         ExitNow();
978     }
979 
980     error = kErrorNotFound;
981 
982 exit:
983     return error;
984 }
985 
AppendTo(Message & aMessage) const986 Error TxtEntry::AppendTo(Message &aMessage) const
987 {
988     Appender appender(aMessage);
989 
990     return AppendTo(appender);
991 }
992 
AppendTo(Appender & aAppender) const993 Error TxtEntry::AppendTo(Appender &aAppender) const
994 {
995     Error    error = kErrorNone;
996     uint16_t keyLength;
997     char     separator = kKeyValueSeparator;
998 
999     if (mKey == nullptr)
1000     {
1001         VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
1002         error = aAppender.AppendBytes(mValue, mValueLength);
1003         ExitNow();
1004     }
1005 
1006     keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
1007 
1008     VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
1009 
1010     if (mValue == nullptr)
1011     {
1012         // Treat as a boolean attribute and encoded as "key" (with no `=`).
1013         SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
1014         error = aAppender.AppendBytes(mKey, keyLength);
1015         ExitNow();
1016     }
1017 
1018     // Treat as key/value and encode as "key=value", value may be empty.
1019 
1020     VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
1021 
1022     SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
1023     SuccessOrExit(error = aAppender.AppendBytes(mKey, keyLength));
1024     SuccessOrExit(error = aAppender.Append(separator));
1025     error = aAppender.AppendBytes(mValue, mValueLength);
1026 
1027 exit:
1028     return error;
1029 }
1030 
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,Message & aMessage)1031 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, Message &aMessage)
1032 {
1033     Appender appender(aMessage);
1034 
1035     return AppendEntries(aEntries, aNumEntries, appender);
1036 }
1037 
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,MutableData<kWithUint16Length> & aData)1038 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, MutableData<kWithUint16Length> &aData)
1039 {
1040     Error    error;
1041     Appender appender(aData.GetBytes(), aData.GetLength());
1042 
1043     SuccessOrExit(error = AppendEntries(aEntries, aNumEntries, appender));
1044     appender.GetAsData(aData);
1045 
1046 exit:
1047     return error;
1048 }
1049 
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,Appender & aAppender)1050 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, Appender &aAppender)
1051 {
1052     Error error = kErrorNone;
1053 
1054     for (uint8_t index = 0; index < aNumEntries; index++)
1055     {
1056         SuccessOrExit(error = aEntries[index].AppendTo(aAppender));
1057     }
1058 
1059     if (aAppender.GetAppendedLength() == 0)
1060     {
1061         error = aAppender.Append<uint8_t>(0);
1062     }
1063 
1064 exit:
1065     return error;
1066 }
1067 
IsValid(void) const1068 bool AaaaRecord::IsValid(void) const
1069 {
1070     return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1071 }
1072 
IsValid(void) const1073 bool KeyRecord::IsValid(void) const
1074 {
1075     return GetType() == Dns::ResourceRecord::kTypeKey;
1076 }
1077 
1078 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Init(void)1079 void Ecdsa256KeyRecord::Init(void)
1080 {
1081     KeyRecord::Init();
1082     SetAlgorithm(kAlgorithmEcdsaP256Sha256);
1083 }
1084 
IsValid(void) const1085 bool Ecdsa256KeyRecord::IsValid(void) const
1086 {
1087     return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1088            GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1089 }
1090 #endif
1091 
IsValid(void) const1092 bool SigRecord::IsValid(void) const
1093 {
1094     return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1095 }
1096 
IsValid(void) const1097 bool LeaseOption::IsValid(void) const
1098 {
1099     return GetLeaseInterval() <= GetKeyLeaseInterval();
1100 }
1101 
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1102 Error PtrRecord::ReadPtrName(const Message &aMessage,
1103                              uint16_t &     aOffset,
1104                              char *         aLabelBuffer,
1105                              uint8_t        aLabelBufferSize,
1106                              char *         aNameBuffer,
1107                              uint16_t       aNameBufferSize) const
1108 {
1109     Error    error       = kErrorNone;
1110     uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1111 
1112     // Verify that the name is within the record data length.
1113     SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1114     VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1115 
1116     aOffset = startOffset + sizeof(PtrRecord);
1117     SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1118 
1119     if (aNameBuffer != nullptr)
1120     {
1121         SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1122     }
1123 
1124     aOffset = startOffset;
1125     error   = SkipRecord(aMessage, aOffset);
1126 
1127 exit:
1128     return error;
1129 }
1130 
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1131 Error TxtRecord::ReadTxtData(const Message &aMessage,
1132                              uint16_t &     aOffset,
1133                              uint8_t *      aTxtBuffer,
1134                              uint16_t &     aTxtBufferSize) const
1135 {
1136     Error error = kErrorNone;
1137 
1138     VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1139     SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, GetLength()));
1140     VerifyOrExit(VerifyTxtData(aTxtBuffer, GetLength(), /* aAllowEmpty */ true), error = kErrorParse);
1141     aTxtBufferSize = GetLength();
1142     aOffset += GetLength();
1143 
1144 exit:
1145     return error;
1146 }
1147 
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength,bool aAllowEmpty)1148 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool aAllowEmpty)
1149 {
1150     bool    valid          = false;
1151     uint8_t curEntryLength = 0;
1152 
1153     // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1154     VerifyOrExit(aAllowEmpty || aTxtLength > 0);
1155 
1156     for (uint16_t i = 0; i < aTxtLength; ++i)
1157     {
1158         if (curEntryLength == 0)
1159         {
1160             curEntryLength = aTxtData[i];
1161         }
1162         else
1163         {
1164             --curEntryLength;
1165         }
1166     }
1167 
1168     valid = (curEntryLength == 0);
1169 
1170 exit:
1171     return valid;
1172 }
1173 
1174 } // namespace Dns
1175 } // namespace ot
1176