• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements generating and processing of DNS headers and helper functions/methods.
32  */
33 
34 #include "dns_types.hpp"
35 
36 #include "instance/instance.hpp"
37 
38 namespace ot {
39 namespace Dns {
40 
SetRandomMessageId(void)41 Error Header::SetRandomMessageId(void)
42 {
43     return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
44 }
45 
ResponseCodeToError(Response aResponse)46 Error Header::ResponseCodeToError(Response aResponse)
47 {
48     Error error = kErrorFailed;
49 
50     switch (aResponse)
51     {
52     case kResponseSuccess:
53         error = kErrorNone;
54         break;
55 
56     case kResponseFormatError:   // Server unable to interpret request due to format error.
57     case kResponseBadName:       // Bad name.
58     case kResponseBadTruncation: // Bad truncation.
59     case kResponseNotZone:       // A name is not in the zone.
60         error = kErrorParse;
61         break;
62 
63     case kResponseServerFailure: // Server encountered an internal failure.
64         error = kErrorFailed;
65         break;
66 
67     case kResponseNameError:       // Name that ought to exist, does not exists.
68     case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
69         error = kErrorNotFound;
70         break;
71 
72     case kResponseNotImplemented: // Server does not support the query type (OpCode).
73     case kDsoTypeNotImplemented:  // DSO TLV type is not implemented.
74         error = kErrorNotImplemented;
75         break;
76 
77     case kResponseBadAlg: // Bad algorithm.
78         error = kErrorNotCapable;
79         break;
80 
81     case kResponseNameExists:   // Some name that ought not to exist, does exist.
82     case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
83         error = kErrorDuplicated;
84         break;
85 
86     case kResponseRefused: // Server refused to perform operation for policy or security reasons.
87     case kResponseNotAuth: // Service is not authoritative for zone.
88         error = kErrorSecurity;
89         break;
90 
91     default:
92         break;
93     }
94 
95     return error;
96 }
97 
Matches(const char * aFirstLabel,const char * aLabels,const char * aDomain) const98 bool Name::Matches(const char *aFirstLabel, const char *aLabels, const char *aDomain) const
99 {
100     bool matches = false;
101 
102     VerifyOrExit(!IsEmpty());
103 
104     if (IsFromCString())
105     {
106         const char *namePtr = mString;
107 
108         if (aFirstLabel != nullptr)
109         {
110             matches = CompareAndSkipLabels(namePtr, aFirstLabel, kLabelSeparatorChar);
111             VerifyOrExit(matches);
112         }
113 
114         matches = CompareAndSkipLabels(namePtr, aLabels, kLabelSeparatorChar);
115         VerifyOrExit(matches);
116 
117         matches = CompareAndSkipLabels(namePtr, aDomain, kNullChar);
118     }
119     else
120     {
121         uint16_t offset = mOffset;
122 
123         if (aFirstLabel != nullptr)
124         {
125             SuccessOrExit(CompareLabel(*mMessage, offset, aFirstLabel));
126         }
127 
128         SuccessOrExit(CompareMultipleLabels(*mMessage, offset, aLabels));
129         SuccessOrExit(CompareName(*mMessage, offset, aDomain));
130         matches = true;
131     }
132 
133 exit:
134     return matches;
135 }
136 
CompareAndSkipLabels(const char * & aNamePtr,const char * aLabels,char aExpectedNextChar)137 bool Name::CompareAndSkipLabels(const char *&aNamePtr, const char *aLabels, char aExpectedNextChar)
138 {
139     // Compares `aNamePtr` to the label string `aLabels` followed by
140     // the `aExpectedNextChar`(using case-insensitive match). Upon
141     // successful comparison, `aNamePtr` is advanced to point after
142     // the matched portion.
143 
144     bool     matches = false;
145     uint16_t len     = StringLength(aLabels, kMaxNameSize);
146 
147     VerifyOrExit(len < kMaxNameSize);
148 
149     VerifyOrExit(StringStartsWith(aNamePtr, aLabels, kStringCaseInsensitiveMatch));
150     aNamePtr += len;
151 
152     VerifyOrExit(*aNamePtr == aExpectedNextChar);
153     aNamePtr++;
154 
155     matches = true;
156 
157 exit:
158     return matches;
159 }
160 
AppendTo(Message & aMessage) const161 Error Name::AppendTo(Message &aMessage) const
162 {
163     Error error;
164 
165     if (IsEmpty())
166     {
167         error = AppendTerminator(aMessage);
168     }
169     else if (IsFromCString())
170     {
171         error = AppendName(GetAsCString(), aMessage);
172     }
173     else
174     {
175         // Name is from a message. Read labels one by one from
176         // `mMessage` and and append each to the `aMessage`.
177 
178         LabelIterator iterator(*mMessage, mOffset);
179 
180         while (true)
181         {
182             error = iterator.GetNextLabel();
183 
184             switch (error)
185             {
186             case kErrorNone:
187                 SuccessOrExit(error = iterator.AppendLabel(aMessage));
188                 break;
189 
190             case kErrorNotFound:
191                 // We reached the end of name successfully.
192                 error = AppendTerminator(aMessage);
193 
194                 OT_FALL_THROUGH;
195 
196             default:
197                 ExitNow();
198             }
199         }
200     }
201 
202 exit:
203     return error;
204 }
205 
AppendLabel(const char * aLabel,Message & aMessage)206 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
207 {
208     return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
209 }
210 
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)211 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
212 {
213     Error error = kErrorNone;
214 
215     VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
216 
217     SuccessOrExit(error = aMessage.Append(aLength));
218     error = aMessage.AppendBytes(aLabel, aLength);
219 
220 exit:
221     return error;
222 }
223 
AppendMultipleLabels(const char * aLabels,Message & aMessage)224 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
225 {
226     Error    error           = kErrorNone;
227     uint16_t index           = 0;
228     uint16_t labelStartIndex = 0;
229     char     ch;
230 
231     VerifyOrExit(aLabels != nullptr);
232 
233     do
234     {
235         ch = aLabels[index];
236 
237         if ((ch == kNullChar) || (ch == kLabelSeparatorChar))
238         {
239             uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
240 
241             if (labelLength == 0)
242             {
243                 // Empty label (e.g., consecutive dots) is invalid, but we
244                 // allow for two cases: (1) where `aLabels` ends with a dot
245                 // (`labelLength` is zero but we are at end of `aLabels` string
246                 // and `ch` is null char. (2) if `aLabels` is just "." (we
247                 // see a dot at index 0, and index 1 is null char).
248 
249                 error =
250                     ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
251                 ExitNow();
252             }
253 
254             VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
255             SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
256 
257             labelStartIndex = index + 1;
258         }
259 
260         index++;
261 
262     } while (ch != kNullChar);
263 
264 exit:
265     return error;
266 }
267 
AppendTerminator(Message & aMessage)268 Error Name::AppendTerminator(Message &aMessage)
269 {
270     uint8_t terminator = 0;
271 
272     return aMessage.Append(terminator);
273 }
274 
AppendPointerLabel(uint16_t aOffset,Message & aMessage)275 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
276 {
277     Error    error;
278     uint16_t value;
279 
280 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
281     if (!Instance::IsDnsNameCompressionEnabled())
282     {
283         // If "DNS name compression" mode is disabled, instead of
284         // appending the pointer label, read the name from the message
285         // and append it uncompressed. Note that the `aOffset` parameter
286         // in this method is given relative to the start of DNS header
287         // in `aMessage` (which `aMessage.GetOffset()` specifies).
288 
289         error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
290         ExitNow();
291     }
292 #endif
293 
294     // A pointer label takes the form of a two byte sequence as a
295     // `uint16_t` value. The first two bits are ones. This allows a
296     // pointer to be distinguished from a text label, since the text
297     // label must begin with two zero bits (note that labels are
298     // restricted to 63 octets or less). The next 14-bits specify
299     // an offset value relative to start of DNS header.
300 
301     OT_ASSERT(aOffset < kPointerLabelTypeUint16);
302 
303     value = BigEndian::HostSwap16(aOffset | kPointerLabelTypeUint16);
304 
305     ExitNow(error = aMessage.Append(value));
306 
307 exit:
308     return error;
309 }
310 
AppendName(const char * aName,Message & aMessage)311 Error Name::AppendName(const char *aName, Message &aMessage)
312 {
313     Error error;
314 
315     SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
316     error = AppendTerminator(aMessage);
317 
318 exit:
319     return error;
320 }
321 
ParseName(const Message & aMessage,uint16_t & aOffset)322 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
323 {
324     Error         error;
325     LabelIterator iterator(aMessage, aOffset);
326 
327     while (true)
328     {
329         error = iterator.GetNextLabel();
330 
331         switch (error)
332         {
333         case kErrorNone:
334             break;
335 
336         case kErrorNotFound:
337             // We reached the end of name successfully.
338             aOffset = iterator.mNameEndOffset;
339             error   = kErrorNone;
340 
341             OT_FALL_THROUGH;
342 
343         default:
344             ExitNow();
345         }
346     }
347 
348 exit:
349     return error;
350 }
351 
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)352 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
353 {
354     Error         error;
355     LabelIterator iterator(aMessage, aOffset);
356 
357     SuccessOrExit(error = iterator.GetNextLabel());
358     SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
359     aOffset = iterator.mNextLabelOffset;
360 
361 exit:
362     return error;
363 }
364 
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)365 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
366 {
367     Error         error;
368     LabelIterator iterator(aMessage, aOffset);
369     bool          firstLabel = true;
370     uint8_t       labelLength;
371 
372     while (true)
373     {
374         error = iterator.GetNextLabel();
375 
376         switch (error)
377         {
378         case kErrorNone:
379 
380             if (!firstLabel)
381             {
382                 *aNameBuffer++ = kLabelSeparatorChar;
383                 aNameBufferSize--;
384 
385                 // No need to check if we have reached end of the name buffer
386                 // here since `iterator.ReadLabel()` would verify it.
387             }
388 
389             labelLength = static_cast<uint8_t>(Min(static_cast<uint16_t>(kMaxLabelSize), aNameBufferSize));
390             SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ firstLabel));
391             aNameBuffer += labelLength;
392             aNameBufferSize -= labelLength;
393             firstLabel = false;
394             break;
395 
396         case kErrorNotFound:
397             // We reach the end of name successfully. Always add a terminating dot
398             // at the end.
399             *aNameBuffer++ = kLabelSeparatorChar;
400             aNameBufferSize--;
401             VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
402             *aNameBuffer = kNullChar;
403             aOffset      = iterator.mNameEndOffset;
404             error        = kErrorNone;
405 
406             OT_FALL_THROUGH;
407 
408         default:
409             ExitNow();
410         }
411     }
412 
413 exit:
414     return error;
415 }
416 
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)417 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
418 {
419     Error         error;
420     LabelIterator iterator(aMessage, aOffset);
421 
422     SuccessOrExit(error = iterator.GetNextLabel());
423     VerifyOrExit(iterator.CompareLabel(aLabel, kIsSingleLabel), error = kErrorNotFound);
424     aOffset = iterator.mNextLabelOffset;
425 
426 exit:
427     return error;
428 }
429 
CompareMultipleLabels(const Message & aMessage,uint16_t & aOffset,const char * aLabels)430 Error Name::CompareMultipleLabels(const Message &aMessage, uint16_t &aOffset, const char *aLabels)
431 {
432     Error         error;
433     LabelIterator iterator(aMessage, aOffset);
434 
435     while (true)
436     {
437         SuccessOrExit(error = iterator.GetNextLabel());
438         VerifyOrExit(iterator.CompareLabel(aLabels, !kIsSingleLabel), error = kErrorNotFound);
439 
440         if (*aLabels == kNullChar)
441         {
442             aOffset = iterator.mNextLabelOffset;
443             ExitNow();
444         }
445     }
446 
447 exit:
448     return error;
449 }
450 
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)451 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
452 {
453     Error         error;
454     LabelIterator iterator(aMessage, aOffset);
455     bool          matches = true;
456 
457     if (*aName == kLabelSeparatorChar)
458     {
459         aName++;
460         VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
461     }
462 
463     while (true)
464     {
465         error = iterator.GetNextLabel();
466 
467         switch (error)
468         {
469         case kErrorNone:
470             if (matches && !iterator.CompareLabel(aName, !kIsSingleLabel))
471             {
472                 matches = false;
473             }
474 
475             break;
476 
477         case kErrorNotFound:
478             // We reached the end of the name in `aMessage`. We check if
479             // all the previous labels matched so far, and we are also
480             // at the end of `aName` string (see null char), then we
481             // return `kErrorNone` indicating a successful comparison
482             // (full match). Otherwise we return `kErrorNotFound` to
483             // indicate failed comparison.
484 
485             if (matches && (*aName == kNullChar))
486             {
487                 error = kErrorNone;
488             }
489 
490             aOffset = iterator.mNameEndOffset;
491 
492             OT_FALL_THROUGH;
493 
494         default:
495             ExitNow();
496         }
497     }
498 
499 exit:
500     return error;
501 }
502 
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)503 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
504 {
505     Error         error;
506     LabelIterator iterator(aMessage, aOffset);
507     LabelIterator iterator2(aMessage2, aOffset2);
508     bool          matches = true;
509 
510     while (true)
511     {
512         error = iterator.GetNextLabel();
513 
514         switch (error)
515         {
516         case kErrorNone:
517             // If all the previous labels matched so far, then verify
518             // that we can get the next label on `iterator2` and that it
519             // matches the label from `iterator`.
520             if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
521             {
522                 matches = false;
523             }
524 
525             break;
526 
527         case kErrorNotFound:
528             // We reached the end of the name in `aMessage`. We check
529             // that `iterator2` is also at its end, and if all previous
530             // labels matched we return `kErrorNone`.
531 
532             if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
533             {
534                 error = kErrorNone;
535             }
536 
537             aOffset = iterator.mNameEndOffset;
538 
539             OT_FALL_THROUGH;
540 
541         default:
542             ExitNow();
543         }
544     }
545 
546 exit:
547     return error;
548 }
549 
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)550 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
551 {
552     return aName.IsFromCString()
553                ? CompareName(aMessage, aOffset, aName.mString)
554                : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
555                                         : ParseName(aMessage, aOffset));
556 }
557 
GetNextLabel(void)558 Error Name::LabelIterator::GetNextLabel(void)
559 {
560     Error error;
561 
562     while (true)
563     {
564         uint8_t labelLength;
565         uint8_t labelType;
566 
567         SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));
568 
569         labelType = labelLength & kLabelTypeMask;
570 
571         if (labelType == kTextLabelType)
572         {
573             if (labelLength == 0)
574             {
575                 // Zero label length indicates end of a name.
576 
577                 if (!IsEndOffsetSet())
578                 {
579                     mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
580                 }
581 
582                 ExitNow(error = kErrorNotFound);
583             }
584 
585             mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
586             mLabelLength      = labelLength;
587             mNextLabelOffset  = mLabelStartOffset + labelLength;
588             ExitNow();
589         }
590         else if (labelType == kPointerLabelType)
591         {
592             // A pointer label takes the form of a two byte sequence as a
593             // `uint16_t` value. The first two bits are ones. The next 14 bits
594             // specify an offset value from the start of the DNS header.
595 
596             uint16_t pointerValue;
597             uint16_t nextLabelOffset;
598 
599             SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));
600 
601             if (!IsEndOffsetSet())
602             {
603                 mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
604             }
605 
606             // `mMessage.GetOffset()` must point to the start of the
607             // DNS header.
608             nextLabelOffset = mMessage.GetOffset() + (BigEndian::HostSwap16(pointerValue) & kPointerLabelOffsetMask);
609             VerifyOrExit(nextLabelOffset < mMinLabelOffset, error = kErrorParse);
610             mNextLabelOffset = nextLabelOffset;
611             mMinLabelOffset  = nextLabelOffset;
612 
613             // Go back through the `while(true)` loop to get the next label.
614         }
615         else
616         {
617             ExitNow(error = kErrorParse);
618         }
619     }
620 
621 exit:
622     return error;
623 }
624 
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const625 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
626 {
627     Error error;
628 
629     VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
630 
631     SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
632     aLabelBuffer[mLabelLength] = kNullChar;
633     aLabelLength               = mLabelLength;
634 
635     if (!aAllowDotCharInLabel)
636     {
637         VerifyOrExit(StringFind(aLabelBuffer, kLabelSeparatorChar) == nullptr, error = kErrorParse);
638     }
639 
640 exit:
641     return error;
642 }
643 
CaseInsensitiveMatch(uint8_t aFirst,uint8_t aSecond)644 bool Name::LabelIterator::CaseInsensitiveMatch(uint8_t aFirst, uint8_t aSecond)
645 {
646     return ToLowercase(static_cast<char>(aFirst)) == ToLowercase(static_cast<char>(aSecond));
647 }
648 
CompareLabel(const char * & aName,bool aIsSingleLabel) const649 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
650 {
651     // This method compares the current label in the iterator with the
652     // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
653     // single label, or a sequence of labels separated by dot '.' char.
654     // If the label matches `aName`, then `aName` pointer is moved
655     // forward to the start of the next label (skipping over the `.`
656     // char). This method returns `true` when the labels match, `false`
657     // otherwise.
658 
659     bool matches = false;
660 
661     VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
662     matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength, CaseInsensitiveMatch);
663 
664     VerifyOrExit(matches);
665 
666     aName += mLabelLength;
667 
668     // If `aName` is a single label, we should be also at the end of the
669     // `aName` string. Otherwise, we should see either null or dot '.'
670     // character (in case `aName` contains multiple labels).
671 
672     matches = (*aName == kNullChar);
673 
674     if (!aIsSingleLabel && (*aName == kLabelSeparatorChar))
675     {
676         matches = true;
677         aName++;
678     }
679 
680 exit:
681     return matches;
682 }
683 
CompareLabel(const LabelIterator & aOtherIterator) const684 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
685 {
686     // This method compares the current label in the iterator with the
687     // label from another iterator.
688 
689     return (mLabelLength == aOtherIterator.mLabelLength) &&
690            mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
691                                  mLabelLength, CaseInsensitiveMatch);
692 }
693 
AppendLabel(Message & aMessage) const694 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
695 {
696     // This method reads and appends the current label in the iterator
697     // to `aMessage`.
698 
699     Error error;
700 
701     VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
702     SuccessOrExit(error = aMessage.Append(mLabelLength));
703     error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
704 
705 exit:
706     return error;
707 }
708 
ExtractLabels(const char * aName,const char * aSuffixName,char * aLabels,uint16_t aLabelsSize)709 Error Name::ExtractLabels(const char *aName, const char *aSuffixName, char *aLabels, uint16_t aLabelsSize)
710 {
711     Error       error        = kErrorParse;
712     uint16_t    nameLength   = StringLength(aName, kMaxNameSize);
713     uint16_t    suffixLength = StringLength(aSuffixName, kMaxNameSize);
714     const char *suffixStart;
715 
716     VerifyOrExit(nameLength < kMaxNameSize);
717     VerifyOrExit(suffixLength < kMaxNameSize);
718 
719     VerifyOrExit(nameLength > suffixLength);
720 
721     suffixStart = aName + nameLength - suffixLength;
722     VerifyOrExit(StringMatch(suffixStart, aSuffixName, kStringCaseInsensitiveMatch));
723     suffixStart--;
724     VerifyOrExit(*suffixStart == kLabelSeparatorChar);
725 
726     // Determine the labels length to copy
727     nameLength -= (suffixLength + 1);
728     VerifyOrExit(nameLength < aLabelsSize, error = kErrorNoBufs);
729 
730     if (aLabels != aName)
731     {
732         memmove(aLabels, aName, nameLength);
733     }
734 
735     aLabels[nameLength] = kNullChar;
736     error               = kErrorNone;
737 
738 exit:
739     return error;
740 }
741 
IsSubDomainOf(const char * aName,const char * aDomain)742 bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
743 {
744     bool     match             = false;
745     bool     nameEndsWithDot   = false;
746     bool     domainEndsWithDot = false;
747     uint16_t nameLength        = StringLength(aName, kMaxNameLength);
748     uint16_t domainLength      = StringLength(aDomain, kMaxNameLength);
749 
750     if (nameLength > 0 && aName[nameLength - 1] == kLabelSeparatorChar)
751     {
752         nameEndsWithDot = true;
753         --nameLength;
754     }
755 
756     if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeparatorChar)
757     {
758         domainEndsWithDot = true;
759         --domainLength;
760     }
761 
762     VerifyOrExit(nameLength >= domainLength);
763 
764     aName += nameLength - domainLength;
765 
766     if (nameLength > domainLength)
767     {
768         VerifyOrExit(aName[-1] == kLabelSeparatorChar);
769     }
770 
771     // This method allows either `aName` or `aDomain` to include or
772     // exclude the last `.` character. If both include it or if both
773     // do not, we do a full comparison using `StringMatch()`.
774     // Otherwise (i.e., when one includes and the other one does not)
775     // we use `StringStartWith()` to allow the extra `.` character.
776 
777     if (nameEndsWithDot == domainEndsWithDot)
778     {
779         match = StringMatch(aName, aDomain, kStringCaseInsensitiveMatch);
780     }
781     else if (nameEndsWithDot)
782     {
783         // `aName` ends with dot, but `aDomain` does not.
784         match = StringStartsWith(aName, aDomain, kStringCaseInsensitiveMatch);
785     }
786     else
787     {
788         // `aDomain` ends with dot, but `aName` does not.
789         match = StringStartsWith(aDomain, aName, kStringCaseInsensitiveMatch);
790     }
791 
792 exit:
793     return match;
794 }
795 
IsSameDomain(const char * aDomain1,const char * aDomain2)796 bool Name::IsSameDomain(const char *aDomain1, const char *aDomain2)
797 {
798     return IsSubDomainOf(aDomain1, aDomain2) && IsSubDomainOf(aDomain2, aDomain1);
799 }
800 
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)801 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
802 {
803     Error error = kErrorNone;
804 
805     while (aNumRecords > 0)
806     {
807         ResourceRecord record;
808 
809         SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
810         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
811         aOffset += static_cast<uint16_t>(record.GetSize());
812         aNumRecords--;
813     }
814 
815 exit:
816     return error;
817 }
818 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)819 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
820 {
821     Error error;
822 
823     while (aNumRecords > 0)
824     {
825         bool           matches = true;
826         ResourceRecord record;
827 
828         error = Name::CompareName(aMessage, aOffset, aName);
829 
830         switch (error)
831         {
832         case kErrorNone:
833             break;
834         case kErrorNotFound:
835             matches = false;
836             break;
837         default:
838             ExitNow();
839         }
840 
841         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
842         aNumRecords--;
843         VerifyOrExit(!matches);
844         aOffset += static_cast<uint16_t>(record.GetSize());
845     }
846 
847     error = kErrorNotFound;
848 
849 exit:
850     return error;
851 }
852 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords,uint16_t aIndex,const Name & aName,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)853 Error ResourceRecord::FindRecord(const Message  &aMessage,
854                                  uint16_t       &aOffset,
855                                  uint16_t        aNumRecords,
856                                  uint16_t        aIndex,
857                                  const Name     &aName,
858                                  uint16_t        aType,
859                                  ResourceRecord &aRecord,
860                                  uint16_t        aMinRecordSize)
861 {
862     // This static method searches in `aMessage` starting from `aOffset`
863     // up to maximum of `aNumRecords`, for the `(aIndex+1)`th
864     // occurrence of a resource record of type `aType` with record name
865     // matching `aName`. It also verifies that the record size is larger
866     // than `aMinRecordSize`. If found, `aMinRecordSize` bytes from the
867     // record are read and copied into `aRecord`. In this case `aOffset`
868     // is updated to point to the last record byte read from the message
869     // (so that the caller can read any remaining fields in the record
870     // data).
871 
872     Error    error;
873     uint16_t offset = aOffset;
874     uint16_t recordOffset;
875 
876     while (aNumRecords > 0)
877     {
878         SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));
879 
880         // Save the offset to start of `ResourceRecord` fields.
881         recordOffset = offset;
882 
883         error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);
884 
885         if (error == kErrorNotFound)
886         {
887             // `ReadRecord()` already updates the `offset` to skip
888             // over a non-matching record.
889             continue;
890         }
891 
892         SuccessOrExit(error);
893 
894         if (aIndex == 0)
895         {
896             aOffset = offset;
897             ExitNow();
898         }
899 
900         aIndex--;
901 
902         // Skip over the record.
903         offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
904     }
905 
906     error = kErrorNotFound;
907 
908 exit:
909     return error;
910 }
911 
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)912 Error ResourceRecord::ReadRecord(const Message  &aMessage,
913                                  uint16_t       &aOffset,
914                                  uint16_t        aType,
915                                  ResourceRecord &aRecord,
916                                  uint16_t        aMinRecordSize)
917 {
918     // This static method tries to read a matching resource record of a
919     // given type and a minimum record size from a message. The `aType`
920     // value of `kTypeAny` matches any type.  If the record in the
921     // message does not match, it skips over the record. Please see
922     // `ReadRecord<RecordType>()` for more details.
923 
924     Error          error;
925     ResourceRecord record;
926 
927     SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
928 
929     if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
930     {
931         IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
932         aOffset += aMinRecordSize;
933     }
934     else
935     {
936         // Skip over the entire record.
937         aOffset += static_cast<uint16_t>(record.GetSize());
938         error = kErrorNotFound;
939     }
940 
941 exit:
942     return error;
943 }
944 
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const945 Error ResourceRecord::ReadName(const Message &aMessage,
946                                uint16_t      &aOffset,
947                                uint16_t       aStartOffset,
948                                char          *aNameBuffer,
949                                uint16_t       aNameBufferSize,
950                                bool           aSkipRecord) const
951 {
952     // This protected method parses and reads a name field in a record
953     // from a message. It is intended only for sub-classes of
954     // `ResourceRecord`.
955     //
956     // On input `aOffset` gives the offset in `aMessage` to the start of
957     // name field. `aStartOffset` gives the offset to the start of the
958     // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
959     // the entire resource record or just the read name. On exit, when
960     // successfully read, `aOffset` is updated to either point after the
961     // end of record or after the the name field.
962     //
963     // When read successfully, this method returns `kErrorNone`. On a
964     // parse error (invalid format) returns `kErrorParse`. If the
965     // name does not fit in the given name buffer it returns
966     // `kErrorNoBufs`
967 
968     Error error = kErrorNone;
969 
970     SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
971     VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
972 
973     VerifyOrExit(aSkipRecord);
974     aOffset = aStartOffset;
975     error   = SkipRecord(aMessage, aOffset);
976 
977 exit:
978     return error;
979 }
980 
SkipRecord(const Message & aMessage,uint16_t & aOffset) const981 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
982 {
983     // This protected method parses and skips over a resource record
984     // in a message.
985     //
986     // On input `aOffset` gives the offset in `aMessage` to the start of
987     // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
988     // is updated to point to byte after the entire record.
989 
990     Error error;
991 
992     SuccessOrExit(error = CheckRecord(aMessage, aOffset));
993     aOffset += static_cast<uint16_t>(GetSize());
994 
995 exit:
996     return error;
997 }
998 
CheckRecord(const Message & aMessage,uint16_t aOffset) const999 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
1000 {
1001     // This method checks that the entire record (including record data)
1002     // is present in `aMessage` at `aOffset` (pointing to the start of
1003     // the `ResourceRecord` in `aMessage`).
1004 
1005     return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
1006 }
1007 
ReadFrom(const Message & aMessage,uint16_t aOffset)1008 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
1009 {
1010     // This method reads the `ResourceRecord` from `aMessage` at
1011     // `aOffset`. It verifies that the entire record (including record
1012     // data) is present in the message.
1013 
1014     Error error;
1015 
1016     SuccessOrExit(error = aMessage.Read(aOffset, *this));
1017     error = CheckRecord(aMessage, aOffset);
1018 
1019 exit:
1020     return error;
1021 }
1022 
Init(const uint8_t * aTxtData,uint16_t aTxtDataLength)1023 void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
1024 {
1025     SetTxtData(aTxtData);
1026     SetTxtDataLength(aTxtDataLength);
1027     SetTxtDataPosition(0);
1028 }
1029 
GetNextEntry(TxtEntry & aEntry)1030 Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
1031 {
1032     Error       error = kErrorNone;
1033     uint8_t     length;
1034     uint8_t     index;
1035     const char *cur;
1036     char       *keyBuffer = GetKeyBuffer();
1037 
1038     static_assert(sizeof(mChar) >= TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");
1039 
1040     VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);
1041 
1042     aEntry.mKey = keyBuffer;
1043 
1044     while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
1045     {
1046         length = static_cast<uint8_t>(*cur);
1047 
1048         cur++;
1049         VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);
1050         IncreaseTxtDataPosition(sizeof(uint8_t) + length);
1051 
1052         // Silently skip over an empty string or if the string starts with
1053         // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.
1054 
1055         if ((length == 0) || (cur[0] == kKeyValueSeparator))
1056         {
1057             continue;
1058         }
1059 
1060         for (index = 0; index < length; index++)
1061         {
1062             if (cur[index] == kKeyValueSeparator)
1063             {
1064                 keyBuffer[index++]  = kNullChar; // Increment index to skip over `=`.
1065                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(&cur[index]);
1066                 aEntry.mValueLength = length - index;
1067                 ExitNow();
1068             }
1069 
1070             if (index >= sizeof(mChar) - 1)
1071             {
1072                 // The key is larger than supported key string length.
1073                 // In this case, we return the full encoded string in
1074                 // `mValue` and `mValueLength` and set `mKey` to
1075                 // `nullptr`.
1076 
1077                 aEntry.mKey         = nullptr;
1078                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(cur);
1079                 aEntry.mValueLength = length;
1080                 ExitNow();
1081             }
1082 
1083             keyBuffer[index] = cur[index];
1084         }
1085 
1086         // If we reach the end of the string without finding `=` then
1087         // it is a boolean key attribute (encoded as "key").
1088 
1089         keyBuffer[index]    = kNullChar;
1090         aEntry.mValue       = nullptr;
1091         aEntry.mValueLength = 0;
1092         ExitNow();
1093     }
1094 
1095     error = kErrorNotFound;
1096 
1097 exit:
1098     return error;
1099 }
1100 
AppendTo(Message & aMessage) const1101 Error TxtEntry::AppendTo(Message &aMessage) const
1102 {
1103     Appender appender(aMessage);
1104 
1105     return AppendTo(appender);
1106 }
1107 
AppendTo(Appender & aAppender) const1108 Error TxtEntry::AppendTo(Appender &aAppender) const
1109 {
1110     Error    error = kErrorNone;
1111     uint16_t keyLength;
1112     char     separator = kKeyValueSeparator;
1113 
1114     if (mKey == nullptr)
1115     {
1116         VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
1117         error = aAppender.AppendBytes(mValue, mValueLength);
1118         ExitNow();
1119     }
1120 
1121     keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
1122 
1123     VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
1124 
1125     if (mValue == nullptr)
1126     {
1127         // Treat as a boolean attribute and encoded as "key" (with no `=`).
1128         SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
1129         error = aAppender.AppendBytes(mKey, keyLength);
1130         ExitNow();
1131     }
1132 
1133     // Treat as key/value and encode as "key=value", value may be empty.
1134 
1135     VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
1136 
1137     SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
1138     SuccessOrExit(error = aAppender.AppendBytes(mKey, keyLength));
1139     SuccessOrExit(error = aAppender.Append(separator));
1140     error = aAppender.AppendBytes(mValue, mValueLength);
1141 
1142 exit:
1143     return error;
1144 }
1145 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Message & aMessage)1146 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Message &aMessage)
1147 {
1148     Appender appender(aMessage);
1149 
1150     return AppendEntries(aEntries, aNumEntries, appender);
1151 }
1152 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,MutableData<kWithUint16Length> & aData)1153 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, MutableData<kWithUint16Length> &aData)
1154 {
1155     Error    error;
1156     Appender appender(aData.GetBytes(), aData.GetLength());
1157 
1158     SuccessOrExit(error = AppendEntries(aEntries, aNumEntries, appender));
1159     appender.GetAsData(aData);
1160 
1161 exit:
1162     return error;
1163 }
1164 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Appender & aAppender)1165 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Appender &aAppender)
1166 {
1167     Error error = kErrorNone;
1168 
1169     for (uint16_t index = 0; index < aNumEntries; index++)
1170     {
1171         SuccessOrExit(error = aEntries[index].AppendTo(aAppender));
1172     }
1173 
1174     if (aAppender.GetAppendedLength() == 0)
1175     {
1176         error = aAppender.Append<uint8_t>(0);
1177     }
1178 
1179 exit:
1180     return error;
1181 }
1182 
IsValid(void) const1183 bool AaaaRecord::IsValid(void) const
1184 {
1185     return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1186 }
1187 
IsValid(void) const1188 bool KeyRecord::IsValid(void) const { return GetType() == Dns::ResourceRecord::kTypeKey; }
1189 
1190 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Init(void)1191 void Ecdsa256KeyRecord::Init(void)
1192 {
1193     KeyRecord::Init();
1194     SetAlgorithm(kAlgorithmEcdsaP256Sha256);
1195 }
1196 
IsValid(void) const1197 bool Ecdsa256KeyRecord::IsValid(void) const
1198 {
1199     return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1200            GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1201 }
1202 #endif
1203 
IsValid(void) const1204 bool SigRecord::IsValid(void) const
1205 {
1206     return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1207 }
1208 
InitAsShortVariant(uint32_t aLeaseInterval)1209 void LeaseOption::InitAsShortVariant(uint32_t aLeaseInterval)
1210 {
1211     SetOptionCode(kUpdateLease);
1212     SetOptionLength(kShortLength);
1213     SetLeaseInterval(aLeaseInterval);
1214 }
1215 
InitAsLongVariant(uint32_t aLeaseInterval,uint32_t aKeyLeaseInterval)1216 void LeaseOption::InitAsLongVariant(uint32_t aLeaseInterval, uint32_t aKeyLeaseInterval)
1217 {
1218     SetOptionCode(kUpdateLease);
1219     SetOptionLength(kLongLength);
1220     SetLeaseInterval(aLeaseInterval);
1221     SetKeyLeaseInterval(aKeyLeaseInterval);
1222 }
1223 
IsValid(void) const1224 bool LeaseOption::IsValid(void) const
1225 {
1226     bool isValid = false;
1227 
1228     VerifyOrExit((GetOptionLength() == kShortLength) || (GetOptionLength() >= kLongLength));
1229     isValid = (GetLeaseInterval() <= GetKeyLeaseInterval());
1230 
1231 exit:
1232     return isValid;
1233 }
1234 
ReadFrom(const Message & aMessage,uint16_t aOffset,uint16_t aLength)1235 Error LeaseOption::ReadFrom(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
1236 {
1237     Error    error = kErrorNone;
1238     uint16_t endOffset;
1239 
1240     VerifyOrExit(static_cast<uint32_t>(aOffset) + aLength <= aMessage.GetLength(), error = kErrorParse);
1241 
1242     endOffset = aOffset + aLength;
1243 
1244     while (aOffset < endOffset)
1245     {
1246         uint16_t size;
1247 
1248         SuccessOrExit(error = aMessage.Read(aOffset, this, sizeof(Option)));
1249 
1250         VerifyOrExit(aOffset + GetSize() <= endOffset, error = kErrorParse);
1251 
1252         size = static_cast<uint16_t>(GetSize());
1253 
1254         if (GetOptionCode() == kUpdateLease)
1255         {
1256             VerifyOrExit(GetOptionLength() >= kShortLength, error = kErrorParse);
1257 
1258             IgnoreError(aMessage.Read(aOffset, this, Min(size, static_cast<uint16_t>(sizeof(LeaseOption)))));
1259             VerifyOrExit(IsValid(), error = kErrorParse);
1260 
1261             ExitNow();
1262         }
1263 
1264         aOffset += size;
1265     }
1266 
1267     error = kErrorNotFound;
1268 
1269 exit:
1270     return error;
1271 }
1272 
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1273 Error PtrRecord::ReadPtrName(const Message &aMessage,
1274                              uint16_t      &aOffset,
1275                              char          *aLabelBuffer,
1276                              uint8_t        aLabelBufferSize,
1277                              char          *aNameBuffer,
1278                              uint16_t       aNameBufferSize) const
1279 {
1280     Error    error       = kErrorNone;
1281     uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1282 
1283     // Verify that the name is within the record data length.
1284     SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1285     VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1286 
1287     aOffset = startOffset + sizeof(PtrRecord);
1288     SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1289 
1290     if (aNameBuffer != nullptr)
1291     {
1292         SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1293     }
1294 
1295     aOffset = startOffset;
1296     error   = SkipRecord(aMessage, aOffset);
1297 
1298 exit:
1299     return error;
1300 }
1301 
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1302 Error TxtRecord::ReadTxtData(const Message &aMessage,
1303                              uint16_t      &aOffset,
1304                              uint8_t       *aTxtBuffer,
1305                              uint16_t      &aTxtBufferSize) const
1306 {
1307     Error error = kErrorNone;
1308 
1309     SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, Min(GetLength(), aTxtBufferSize)));
1310     aOffset += GetLength();
1311 
1312     VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1313     aTxtBufferSize = GetLength();
1314     VerifyOrExit(VerifyTxtData(aTxtBuffer, aTxtBufferSize, /* aAllowEmpty */ true), error = kErrorParse);
1315 
1316 exit:
1317     return error;
1318 }
1319 
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength,bool aAllowEmpty)1320 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool aAllowEmpty)
1321 {
1322     bool    valid          = false;
1323     uint8_t curEntryLength = 0;
1324 
1325     // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1326     VerifyOrExit(aAllowEmpty || aTxtLength > 0);
1327 
1328     for (uint16_t i = 0; i < aTxtLength; ++i)
1329     {
1330         if (curEntryLength == 0)
1331         {
1332             curEntryLength = aTxtData[i];
1333         }
1334         else
1335         {
1336             --curEntryLength;
1337         }
1338     }
1339 
1340     valid = (curEntryLength == 0);
1341 
1342 exit:
1343     return valid;
1344 }
1345 
AddType(uint16_t aType)1346 void NsecRecord::TypeBitMap::AddType(uint16_t aType)
1347 {
1348     if ((aType >> 8) == mBlockNumber)
1349     {
1350         uint8_t  type  = static_cast<uint8_t>(aType & 0xff);
1351         uint8_t  index = (type / kBitsPerByte);
1352         uint16_t mask  = (0x80 >> (type % kBitsPerByte));
1353 
1354         mBitmaps[index] |= mask;
1355         mBitmapLength = Max<uint8_t>(mBitmapLength, index + 1);
1356     }
1357 }
1358 
ContainsType(uint16_t aType) const1359 bool NsecRecord::TypeBitMap::ContainsType(uint16_t aType) const
1360 {
1361     bool     contains = false;
1362     uint8_t  type     = static_cast<uint8_t>(aType & 0xff);
1363     uint8_t  index    = (type / kBitsPerByte);
1364     uint16_t mask     = (0x80 >> (type % kBitsPerByte));
1365 
1366     VerifyOrExit((aType >> 8) == mBlockNumber);
1367 
1368     VerifyOrExit(index < mBitmapLength);
1369 
1370     contains = (mBitmaps[index] & mask);
1371 
1372 exit:
1373     return contains;
1374 }
1375 
1376 } // namespace Dns
1377 } // namespace ot
1378