1 /*
2 * Copyright (c) 2020, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements generating and processing of DNS headers and helper functions/methods.
32 */
33
34 #include "dns_types.hpp"
35
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/instance.hpp"
39 #include "common/random.hpp"
40 #include "common/string.hpp"
41
42 namespace ot {
43 namespace Dns {
44
45 using ot::Encoding::BigEndian::HostSwap16;
46
SetRandomMessageId(void)47 Error Header::SetRandomMessageId(void)
48 {
49 return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
50 }
51
ResponseCodeToError(Response aResponse)52 Error Header::ResponseCodeToError(Response aResponse)
53 {
54 Error error = kErrorFailed;
55
56 switch (aResponse)
57 {
58 case kResponseSuccess:
59 error = kErrorNone;
60 break;
61
62 case kResponseFormatError: // Server unable to interpret request due to format error.
63 case kResponseBadName: // Bad name.
64 case kResponseBadTruncation: // Bad truncation.
65 case kResponseNotZone: // A name is not in the zone.
66 error = kErrorParse;
67 break;
68
69 case kResponseServerFailure: // Server encountered an internal failure.
70 error = kErrorFailed;
71 break;
72
73 case kResponseNameError: // Name that ought to exist, does not exists.
74 case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
75 error = kErrorNotFound;
76 break;
77
78 case kResponseNotImplemented: // Server does not support the query type (OpCode).
79 case kDsoTypeNotImplemented: // DSO TLV type is not implemented.
80 error = kErrorNotImplemented;
81 break;
82
83 case kResponseBadAlg: // Bad algorithm.
84 error = kErrorNotCapable;
85 break;
86
87 case kResponseNameExists: // Some name that ought not to exist, does exist.
88 case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
89 error = kErrorDuplicated;
90 break;
91
92 case kResponseRefused: // Server refused to perform operation for policy or security reasons.
93 case kResponseNotAuth: // Service is not authoritative for zone.
94 error = kErrorSecurity;
95 break;
96
97 default:
98 break;
99 }
100
101 return error;
102 }
103
AppendTo(Message & aMessage) const104 Error Name::AppendTo(Message &aMessage) const
105 {
106 Error error;
107
108 if (IsEmpty())
109 {
110 error = AppendTerminator(aMessage);
111 }
112 else if (IsFromCString())
113 {
114 error = AppendName(GetAsCString(), aMessage);
115 }
116 else
117 {
118 // Name is from a message. Read labels one by one from
119 // `mMessage` and and append each to the `aMessage`.
120
121 LabelIterator iterator(*mMessage, mOffset);
122
123 while (true)
124 {
125 error = iterator.GetNextLabel();
126
127 switch (error)
128 {
129 case kErrorNone:
130 SuccessOrExit(error = iterator.AppendLabel(aMessage));
131 break;
132
133 case kErrorNotFound:
134 // We reached the end of name successfully.
135 error = AppendTerminator(aMessage);
136
137 OT_FALL_THROUGH;
138
139 default:
140 ExitNow();
141 }
142 }
143 }
144
145 exit:
146 return error;
147 }
148
AppendLabel(const char * aLabel,Message & aMessage)149 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
150 {
151 return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
152 }
153
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)154 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
155 {
156 Error error = kErrorNone;
157
158 VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
159
160 SuccessOrExit(error = aMessage.Append(aLength));
161 error = aMessage.AppendBytes(aLabel, aLength);
162
163 exit:
164 return error;
165 }
166
AppendMultipleLabels(const char * aLabels,Message & aMessage)167 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
168 {
169 return AppendMultipleLabels(aLabels, kMaxNameLength, aMessage);
170 }
171
AppendMultipleLabels(const char * aLabels,uint8_t aLength,Message & aMessage)172 Error Name::AppendMultipleLabels(const char *aLabels, uint8_t aLength, Message &aMessage)
173 {
174 Error error = kErrorNone;
175 uint16_t index = 0;
176 uint16_t labelStartIndex = 0;
177 char ch;
178
179 VerifyOrExit(aLabels != nullptr);
180
181 do
182 {
183 ch = index < aLength ? aLabels[index] : static_cast<char>(kNullChar);
184
185 if ((ch == kNullChar) || (ch == kLabelSeperatorChar))
186 {
187 uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
188
189 if (labelLength == 0)
190 {
191 // Empty label (e.g., consecutive dots) is invalid, but we
192 // allow for two cases: (1) where `aLabels` ends with a dot
193 // (`labelLength` is zero but we are at end of `aLabels` string
194 // and `ch` is null char. (2) if `aLabels` is just "." (we
195 // see a dot at index 0, and index 1 is null char).
196
197 error =
198 ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
199 ExitNow();
200 }
201
202 VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
203 SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
204
205 labelStartIndex = index + 1;
206 }
207
208 index++;
209
210 } while (ch != kNullChar);
211
212 exit:
213 return error;
214 }
215
AppendTerminator(Message & aMessage)216 Error Name::AppendTerminator(Message &aMessage)
217 {
218 uint8_t terminator = 0;
219
220 return aMessage.Append(terminator);
221 }
222
AppendPointerLabel(uint16_t aOffset,Message & aMessage)223 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
224 {
225 Error error;
226 uint16_t value;
227
228 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
229 if (!Instance::IsDnsNameCompressionEnabled())
230 {
231 // If "DNS name compression" mode is disabled, instead of
232 // appending the pointer label, read the name from the message
233 // and append it uncompressed. Note that the `aOffset` parameter
234 // in this method is given relative to the start of DNS header
235 // in `aMessage` (which `aMessage.GetOffset()` specifies).
236
237 error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
238 ExitNow();
239 }
240 #endif
241
242 // A pointer label takes the form of a two byte sequence as a
243 // `uint16_t` value. The first two bits are ones. This allows a
244 // pointer to be distinguished from a text label, since the text
245 // label must begin with two zero bits (note that labels are
246 // restricted to 63 octets or less). The next 14-bits specify
247 // an offset value relative to start of DNS header.
248
249 OT_ASSERT(aOffset < kPointerLabelTypeUint16);
250
251 value = HostSwap16(aOffset | kPointerLabelTypeUint16);
252
253 ExitNow(error = aMessage.Append(value));
254
255 exit:
256 return error;
257 }
258
AppendName(const char * aName,Message & aMessage)259 Error Name::AppendName(const char *aName, Message &aMessage)
260 {
261 Error error;
262
263 SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
264 error = AppendTerminator(aMessage);
265
266 exit:
267 return error;
268 }
269
ParseName(const Message & aMessage,uint16_t & aOffset)270 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
271 {
272 Error error;
273 LabelIterator iterator(aMessage, aOffset);
274
275 while (true)
276 {
277 error = iterator.GetNextLabel();
278
279 switch (error)
280 {
281 case kErrorNone:
282 break;
283
284 case kErrorNotFound:
285 // We reached the end of name successfully.
286 aOffset = iterator.mNameEndOffset;
287 error = kErrorNone;
288
289 OT_FALL_THROUGH;
290
291 default:
292 ExitNow();
293 }
294 }
295
296 exit:
297 return error;
298 }
299
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)300 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
301 {
302 Error error;
303 LabelIterator iterator(aMessage, aOffset);
304
305 SuccessOrExit(error = iterator.GetNextLabel());
306 SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
307 aOffset = iterator.mNextLabelOffset;
308
309 exit:
310 return error;
311 }
312
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)313 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
314 {
315 Error error;
316 LabelIterator iterator(aMessage, aOffset);
317 bool firstLabel = true;
318 uint8_t labelLength;
319
320 while (true)
321 {
322 error = iterator.GetNextLabel();
323
324 switch (error)
325 {
326 case kErrorNone:
327
328 if (!firstLabel)
329 {
330 *aNameBuffer++ = kLabelSeperatorChar;
331 aNameBufferSize--;
332
333 // No need to check if we have reached end of the name buffer
334 // here since `iterator.ReadLabel()` would verify it.
335 }
336
337 labelLength = static_cast<uint8_t>(OT_MIN(static_cast<uint8_t>(kMaxLabelSize), aNameBufferSize));
338 SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ false));
339 aNameBuffer += labelLength;
340 aNameBufferSize -= labelLength;
341 firstLabel = false;
342 break;
343
344 case kErrorNotFound:
345 // We reach the end of name successfully. Always add a terminating dot
346 // at the end.
347 *aNameBuffer++ = kLabelSeperatorChar;
348 aNameBufferSize--;
349 VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
350 *aNameBuffer = kNullChar;
351 aOffset = iterator.mNameEndOffset;
352 error = kErrorNone;
353
354 OT_FALL_THROUGH;
355
356 default:
357 ExitNow();
358 }
359 }
360
361 exit:
362 return error;
363 }
364
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)365 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
366 {
367 Error error;
368 LabelIterator iterator(aMessage, aOffset);
369
370 SuccessOrExit(error = iterator.GetNextLabel());
371 VerifyOrExit(iterator.CompareLabel(aLabel, kIsSingleLabel), error = kErrorNotFound);
372 aOffset = iterator.mNextLabelOffset;
373
374 exit:
375 return error;
376 }
377
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)378 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
379 {
380 Error error;
381 LabelIterator iterator(aMessage, aOffset);
382 bool matches = true;
383
384 if (*aName == kLabelSeperatorChar)
385 {
386 aName++;
387 VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
388 }
389
390 while (true)
391 {
392 error = iterator.GetNextLabel();
393
394 switch (error)
395 {
396 case kErrorNone:
397 if (matches && !iterator.CompareLabel(aName, !kIsSingleLabel))
398 {
399 matches = false;
400 }
401
402 break;
403
404 case kErrorNotFound:
405 // We reached the end of the name in `aMessage`. We check if
406 // all the previous labels matched so far, and we are also
407 // at the end of `aName` string (see null char), then we
408 // return `kErrorNone` indicating a successful comparison
409 // (full match). Otherwise we return `kErrorNotFound` to
410 // indicate failed comparison.
411
412 if (matches && (*aName == kNullChar))
413 {
414 error = kErrorNone;
415 }
416
417 aOffset = iterator.mNameEndOffset;
418
419 OT_FALL_THROUGH;
420
421 default:
422 ExitNow();
423 }
424 }
425
426 exit:
427 return error;
428 }
429
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)430 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
431 {
432 Error error;
433 LabelIterator iterator(aMessage, aOffset);
434 LabelIterator iterator2(aMessage2, aOffset2);
435 bool matches = true;
436
437 while (true)
438 {
439 error = iterator.GetNextLabel();
440
441 switch (error)
442 {
443 case kErrorNone:
444 // If all the previous labels matched so far, then verify
445 // that we can get the next label on `iterator2` and that it
446 // matches the label from `iterator`.
447 if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
448 {
449 matches = false;
450 }
451
452 break;
453
454 case kErrorNotFound:
455 // We reached the end of the name in `aMessage`. We check
456 // that `iterator2` is also at its end, and if all previous
457 // labels matched we return `kErrorNone`.
458
459 if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
460 {
461 error = kErrorNone;
462 }
463
464 aOffset = iterator.mNameEndOffset;
465
466 OT_FALL_THROUGH;
467
468 default:
469 ExitNow();
470 }
471 }
472
473 exit:
474 return error;
475 }
476
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)477 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
478 {
479 return aName.IsFromCString()
480 ? CompareName(aMessage, aOffset, aName.mString)
481 : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
482 : ParseName(aMessage, aOffset));
483 }
484
GetNextLabel(void)485 Error Name::LabelIterator::GetNextLabel(void)
486 {
487 Error error;
488
489 while (true)
490 {
491 uint8_t labelLength;
492 uint8_t labelType;
493
494 SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));
495
496 labelType = labelLength & kLabelTypeMask;
497
498 if (labelType == kTextLabelType)
499 {
500 if (labelLength == 0)
501 {
502 // Zero label length indicates end of a name.
503
504 if (!IsEndOffsetSet())
505 {
506 mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
507 }
508
509 ExitNow(error = kErrorNotFound);
510 }
511
512 mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
513 mLabelLength = labelLength;
514 mNextLabelOffset = mLabelStartOffset + labelLength;
515 ExitNow();
516 }
517 else if (labelType == kPointerLabelType)
518 {
519 // A pointer label takes the form of a two byte sequence as a
520 // `uint16_t` value. The first two bits are ones. The next 14 bits
521 // specify an offset value from the start of the DNS header.
522
523 uint16_t pointerValue;
524
525 SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));
526
527 if (!IsEndOffsetSet())
528 {
529 mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
530 }
531
532 // `mMessage.GetOffset()` must point to the start of the
533 // DNS header.
534 mNextLabelOffset = mMessage.GetOffset() + (HostSwap16(pointerValue) & kPointerLabelOffsetMask);
535
536 // Go back through the `while(true)` loop to get the next label.
537 }
538 else
539 {
540 ExitNow(error = kErrorParse);
541 }
542 }
543
544 exit:
545 return error;
546 }
547
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const548 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
549 {
550 Error error;
551
552 VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
553
554 SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
555 aLabelBuffer[mLabelLength] = kNullChar;
556 aLabelLength = mLabelLength;
557
558 if (!aAllowDotCharInLabel)
559 {
560 VerifyOrExit(StringFind(aLabelBuffer, kLabelSeperatorChar) == nullptr, error = kErrorParse);
561 }
562
563 exit:
564 return error;
565 }
566
CaseInsensitiveMatch(uint8_t aFirst,uint8_t aSecond)567 bool Name::LabelIterator::CaseInsensitiveMatch(uint8_t aFirst, uint8_t aSecond)
568 {
569 return ToLowercase(static_cast<char>(aFirst)) == ToLowercase(static_cast<char>(aSecond));
570 }
571
CompareLabel(const char * & aName,bool aIsSingleLabel) const572 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
573 {
574 // This method compares the current label in the iterator with the
575 // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
576 // single label, or a sequence of labels separated by dot '.' char.
577 // If the label matches `aName`, then `aName` pointer is moved
578 // forward to the start of the next label (skipping over the `.`
579 // char). This method returns `true` when the labels match, `false`
580 // otherwise.
581
582 bool matches = false;
583
584 VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
585 matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength, CaseInsensitiveMatch);
586
587 VerifyOrExit(matches);
588
589 aName += mLabelLength;
590
591 // If `aName` is a single label, we should be also at the end of the
592 // `aName` string. Otherwise, we should see either null or dot '.'
593 // character (in case `aName` contains multiple labels).
594
595 matches = (*aName == kNullChar);
596
597 if (!aIsSingleLabel && (*aName == kLabelSeperatorChar))
598 {
599 matches = true;
600 aName++;
601 }
602
603 exit:
604 return matches;
605 }
606
CompareLabel(const LabelIterator & aOtherIterator) const607 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
608 {
609 // This method compares the current label in the iterator with the
610 // label from another iterator.
611
612 return (mLabelLength == aOtherIterator.mLabelLength) &&
613 mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
614 mLabelLength, CaseInsensitiveMatch);
615 }
616
AppendLabel(Message & aMessage) const617 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
618 {
619 // This method reads and appends the current label in the iterator
620 // to `aMessage`.
621
622 Error error;
623
624 VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
625 SuccessOrExit(error = aMessage.Append(mLabelLength));
626 error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
627
628 exit:
629 return error;
630 }
631
IsSubDomainOf(const char * aName,const char * aDomain)632 bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
633 {
634 bool match = false;
635 bool nameEndsWithDot = false;
636 bool domainEndsWithDot = false;
637 uint16_t nameLength = StringLength(aName, kMaxNameLength);
638 uint16_t domainLength = StringLength(aDomain, kMaxNameLength);
639
640 if (nameLength > 0 && aName[nameLength - 1] == kLabelSeperatorChar)
641 {
642 nameEndsWithDot = true;
643 --nameLength;
644 }
645
646 if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeperatorChar)
647 {
648 domainEndsWithDot = true;
649 --domainLength;
650 }
651
652 VerifyOrExit(nameLength >= domainLength);
653
654 aName += nameLength - domainLength;
655
656 if (nameLength > domainLength)
657 {
658 VerifyOrExit(aName[-1] == kLabelSeperatorChar);
659 }
660
661 // This method allows either `aName` or `aDomain` to include or
662 // exclude the last `.` character. If both include it or if both
663 // do not, we do a full comparison using `StringMatch()`.
664 // Otherwise (i.e., when one includes and the other one does not)
665 // we use `StringStartWith()` to allow the extra `.` character.
666
667 if (nameEndsWithDot == domainEndsWithDot)
668 {
669 match = StringMatch(aName, aDomain, kStringCaseInsensitiveMatch);
670 }
671 else if (nameEndsWithDot)
672 {
673 // `aName` ends with dot, but `aDomain` does not.
674 match = StringStartsWith(aName, aDomain, kStringCaseInsensitiveMatch);
675 }
676 else
677 {
678 // `aDomain` ends with dot, but `aName` does not.
679 match = StringStartsWith(aDomain, aName, kStringCaseInsensitiveMatch);
680 }
681
682 exit:
683 return match;
684 }
685
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)686 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
687 {
688 Error error = kErrorNone;
689
690 while (aNumRecords > 0)
691 {
692 ResourceRecord record;
693
694 SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
695 SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
696 aOffset += static_cast<uint16_t>(record.GetSize());
697 aNumRecords--;
698 }
699
700 exit:
701 return error;
702 }
703
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)704 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
705 {
706 Error error;
707
708 while (aNumRecords > 0)
709 {
710 bool matches = true;
711 ResourceRecord record;
712
713 error = Name::CompareName(aMessage, aOffset, aName);
714
715 switch (error)
716 {
717 case kErrorNone:
718 break;
719 case kErrorNotFound:
720 matches = false;
721 break;
722 default:
723 ExitNow();
724 }
725
726 SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
727 aNumRecords--;
728 VerifyOrExit(!matches);
729 aOffset += static_cast<uint16_t>(record.GetSize());
730 }
731
732 error = kErrorNotFound;
733
734 exit:
735 return error;
736 }
737
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords,uint16_t aIndex,const Name & aName,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)738 Error ResourceRecord::FindRecord(const Message & aMessage,
739 uint16_t & aOffset,
740 uint16_t aNumRecords,
741 uint16_t aIndex,
742 const Name & aName,
743 uint16_t aType,
744 ResourceRecord &aRecord,
745 uint16_t aMinRecordSize)
746 {
747 // This static method searches in `aMessage` starting from `aOffset`
748 // up to maximum of `aNumRecords`, for the `(aIndex+1)`th
749 // occurrence of a resource record of type `aType` with record name
750 // matching `aName`. It also verifies that the record size is larger
751 // than `aMinRecordSize`. If found, `aMinRecordSize` bytes from the
752 // record are read and copied into `aRecord`. In this case `aOffset`
753 // is updated to point to the last record byte read from the message
754 // (so that the caller can read any remaining fields in the record
755 // data).
756
757 Error error;
758 uint16_t offset = aOffset;
759 uint16_t recordOffset;
760
761 while (aNumRecords > 0)
762 {
763 SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));
764
765 // Save the offset to start of `ResourceRecord` fields.
766 recordOffset = offset;
767
768 error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);
769
770 if (error == kErrorNotFound)
771 {
772 // `ReadRecord()` already updates the `offset` to skip
773 // over a non-matching record.
774 continue;
775 }
776
777 SuccessOrExit(error);
778
779 if (aIndex == 0)
780 {
781 aOffset = offset;
782 ExitNow();
783 }
784
785 aIndex--;
786
787 // Skip over the record.
788 offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
789 }
790
791 error = kErrorNotFound;
792
793 exit:
794 return error;
795 }
796
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)797 Error ResourceRecord::ReadRecord(const Message & aMessage,
798 uint16_t & aOffset,
799 uint16_t aType,
800 ResourceRecord &aRecord,
801 uint16_t aMinRecordSize)
802 {
803 // This static method tries to read a matching resource record of a
804 // given type and a minimum record size from a message. The `aType`
805 // value of `kTypeAny` matches any type. If the record in the
806 // message does not match, it skips over the record. Please see
807 // `ReadRecord<RecordType>()` for more details.
808
809 Error error;
810 ResourceRecord record;
811
812 SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
813
814 if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
815 {
816 IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
817 aOffset += aMinRecordSize;
818 }
819 else
820 {
821 // Skip over the entire record.
822 aOffset += static_cast<uint16_t>(record.GetSize());
823 error = kErrorNotFound;
824 }
825
826 exit:
827 return error;
828 }
829
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const830 Error ResourceRecord::ReadName(const Message &aMessage,
831 uint16_t & aOffset,
832 uint16_t aStartOffset,
833 char * aNameBuffer,
834 uint16_t aNameBufferSize,
835 bool aSkipRecord) const
836 {
837 // This protected method parses and reads a name field in a record
838 // from a message. It is intended only for sub-classes of
839 // `ResourceRecord`.
840 //
841 // On input `aOffset` gives the offset in `aMessage` to the start of
842 // name field. `aStartOffset` gives the offset to the start of the
843 // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
844 // the entire resource record or just the read name. On exit, when
845 // successfully read, `aOffset` is updated to either point after the
846 // end of record or after the the name field.
847 //
848 // When read successfully, this method returns `kErrorNone`. On a
849 // parse error (invalid format) returns `kErrorParse`. If the
850 // name does not fit in the given name buffer it returns
851 // `kErrorNoBufs`
852
853 Error error = kErrorNone;
854
855 SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
856 VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
857
858 VerifyOrExit(aSkipRecord);
859 aOffset = aStartOffset;
860 error = SkipRecord(aMessage, aOffset);
861
862 exit:
863 return error;
864 }
865
SkipRecord(const Message & aMessage,uint16_t & aOffset) const866 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
867 {
868 // This protected method parses and skips over a resource record
869 // in a message.
870 //
871 // On input `aOffset` gives the offset in `aMessage` to the start of
872 // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
873 // is updated to point to byte after the entire record.
874
875 Error error;
876
877 SuccessOrExit(error = CheckRecord(aMessage, aOffset));
878 aOffset += static_cast<uint16_t>(GetSize());
879
880 exit:
881 return error;
882 }
883
CheckRecord(const Message & aMessage,uint16_t aOffset) const884 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
885 {
886 // This method checks that the entire record (including record data)
887 // is present in `aMessage` at `aOffset` (pointing to the start of
888 // the `ResourceRecord` in `aMessage`).
889
890 return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
891 }
892
ReadFrom(const Message & aMessage,uint16_t aOffset)893 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
894 {
895 // This method reads the `ResourceRecord` from `aMessage` at
896 // `aOffset`. It verifies that the entire record (including record
897 // data) is present in the message.
898
899 Error error;
900
901 SuccessOrExit(error = aMessage.Read(aOffset, *this));
902 error = CheckRecord(aMessage, aOffset);
903
904 exit:
905 return error;
906 }
907
Init(const uint8_t * aTxtData,uint16_t aTxtDataLength)908 void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
909 {
910 SetTxtData(aTxtData);
911 SetTxtDataLength(aTxtDataLength);
912 SetTxtDataPosition(0);
913 }
914
GetNextEntry(TxtEntry & aEntry)915 Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
916 {
917 Error error = kErrorNone;
918 uint8_t length;
919 uint8_t index;
920 const char *cur;
921 char * keyBuffer = GetKeyBuffer();
922
923 static_assert(sizeof(mChar) == TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");
924
925 VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);
926
927 aEntry.mKey = keyBuffer;
928
929 while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
930 {
931 length = static_cast<uint8_t>(*cur);
932
933 cur++;
934 VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);
935 IncreaseTxtDataPosition(sizeof(uint8_t) + length);
936
937 // Silently skip over an empty string or if the string starts with
938 // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.
939
940 if ((length == 0) || (cur[0] == kKeyValueSeparator))
941 {
942 continue;
943 }
944
945 for (index = 0; index < length; index++)
946 {
947 if (cur[index] == kKeyValueSeparator)
948 {
949 keyBuffer[index++] = kNullChar; // Increment index to skip over `=`.
950 aEntry.mValue = reinterpret_cast<const uint8_t *>(&cur[index]);
951 aEntry.mValueLength = length - index;
952 ExitNow();
953 }
954
955 if (index >= kMaxKeyLength)
956 {
957 // The key is larger than recommended max key length.
958 // In this case, we return the full encoded string in
959 // `mValue` and `mValueLength` and set `mKey` to
960 // `nullptr`.
961
962 aEntry.mKey = nullptr;
963 aEntry.mValue = reinterpret_cast<const uint8_t *>(cur);
964 aEntry.mValueLength = length;
965 ExitNow();
966 }
967
968 keyBuffer[index] = cur[index];
969 }
970
971 // If we reach the end of the string without finding `=` then
972 // it is a boolean key attribute (encoded as "key").
973
974 keyBuffer[index] = kNullChar;
975 aEntry.mValue = nullptr;
976 aEntry.mValueLength = 0;
977 ExitNow();
978 }
979
980 error = kErrorNotFound;
981
982 exit:
983 return error;
984 }
985
AppendTo(Message & aMessage) const986 Error TxtEntry::AppendTo(Message &aMessage) const
987 {
988 Appender appender(aMessage);
989
990 return AppendTo(appender);
991 }
992
AppendTo(Appender & aAppender) const993 Error TxtEntry::AppendTo(Appender &aAppender) const
994 {
995 Error error = kErrorNone;
996 uint16_t keyLength;
997 char separator = kKeyValueSeparator;
998
999 if (mKey == nullptr)
1000 {
1001 VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
1002 error = aAppender.AppendBytes(mValue, mValueLength);
1003 ExitNow();
1004 }
1005
1006 keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
1007
1008 VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
1009
1010 if (mValue == nullptr)
1011 {
1012 // Treat as a boolean attribute and encoded as "key" (with no `=`).
1013 SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
1014 error = aAppender.AppendBytes(mKey, keyLength);
1015 ExitNow();
1016 }
1017
1018 // Treat as key/value and encode as "key=value", value may be empty.
1019
1020 VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
1021
1022 SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
1023 SuccessOrExit(error = aAppender.AppendBytes(mKey, keyLength));
1024 SuccessOrExit(error = aAppender.Append(separator));
1025 error = aAppender.AppendBytes(mValue, mValueLength);
1026
1027 exit:
1028 return error;
1029 }
1030
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,Message & aMessage)1031 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, Message &aMessage)
1032 {
1033 Appender appender(aMessage);
1034
1035 return AppendEntries(aEntries, aNumEntries, appender);
1036 }
1037
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,MutableData<kWithUint16Length> & aData)1038 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, MutableData<kWithUint16Length> &aData)
1039 {
1040 Error error;
1041 Appender appender(aData.GetBytes(), aData.GetLength());
1042
1043 SuccessOrExit(error = AppendEntries(aEntries, aNumEntries, appender));
1044 appender.GetAsData(aData);
1045
1046 exit:
1047 return error;
1048 }
1049
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,Appender & aAppender)1050 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, Appender &aAppender)
1051 {
1052 Error error = kErrorNone;
1053
1054 for (uint8_t index = 0; index < aNumEntries; index++)
1055 {
1056 SuccessOrExit(error = aEntries[index].AppendTo(aAppender));
1057 }
1058
1059 if (aAppender.GetAppendedLength() == 0)
1060 {
1061 error = aAppender.Append<uint8_t>(0);
1062 }
1063
1064 exit:
1065 return error;
1066 }
1067
IsValid(void) const1068 bool AaaaRecord::IsValid(void) const
1069 {
1070 return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1071 }
1072
IsValid(void) const1073 bool KeyRecord::IsValid(void) const
1074 {
1075 return GetType() == Dns::ResourceRecord::kTypeKey;
1076 }
1077
1078 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Init(void)1079 void Ecdsa256KeyRecord::Init(void)
1080 {
1081 KeyRecord::Init();
1082 SetAlgorithm(kAlgorithmEcdsaP256Sha256);
1083 }
1084
IsValid(void) const1085 bool Ecdsa256KeyRecord::IsValid(void) const
1086 {
1087 return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1088 GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1089 }
1090 #endif
1091
IsValid(void) const1092 bool SigRecord::IsValid(void) const
1093 {
1094 return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1095 }
1096
IsValid(void) const1097 bool LeaseOption::IsValid(void) const
1098 {
1099 return GetLeaseInterval() <= GetKeyLeaseInterval();
1100 }
1101
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1102 Error PtrRecord::ReadPtrName(const Message &aMessage,
1103 uint16_t & aOffset,
1104 char * aLabelBuffer,
1105 uint8_t aLabelBufferSize,
1106 char * aNameBuffer,
1107 uint16_t aNameBufferSize) const
1108 {
1109 Error error = kErrorNone;
1110 uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1111
1112 // Verify that the name is within the record data length.
1113 SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1114 VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1115
1116 aOffset = startOffset + sizeof(PtrRecord);
1117 SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1118
1119 if (aNameBuffer != nullptr)
1120 {
1121 SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1122 }
1123
1124 aOffset = startOffset;
1125 error = SkipRecord(aMessage, aOffset);
1126
1127 exit:
1128 return error;
1129 }
1130
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1131 Error TxtRecord::ReadTxtData(const Message &aMessage,
1132 uint16_t & aOffset,
1133 uint8_t * aTxtBuffer,
1134 uint16_t & aTxtBufferSize) const
1135 {
1136 Error error = kErrorNone;
1137
1138 VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1139 SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, GetLength()));
1140 VerifyOrExit(VerifyTxtData(aTxtBuffer, GetLength(), /* aAllowEmpty */ true), error = kErrorParse);
1141 aTxtBufferSize = GetLength();
1142 aOffset += GetLength();
1143
1144 exit:
1145 return error;
1146 }
1147
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength,bool aAllowEmpty)1148 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool aAllowEmpty)
1149 {
1150 bool valid = false;
1151 uint8_t curEntryLength = 0;
1152
1153 // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1154 VerifyOrExit(aAllowEmpty || aTxtLength > 0);
1155
1156 for (uint16_t i = 0; i < aTxtLength; ++i)
1157 {
1158 if (curEntryLength == 0)
1159 {
1160 curEntryLength = aTxtData[i];
1161 }
1162 else
1163 {
1164 --curEntryLength;
1165 }
1166 }
1167
1168 valid = (curEntryLength == 0);
1169
1170 exit:
1171 return valid;
1172 }
1173
1174 } // namespace Dns
1175 } // namespace ot
1176