• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "instance/instance.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43 
44 /**
45  * @file
46  *   This file implements the DNS client.
47  */
48 
49 namespace ot {
50 namespace Dns {
51 
52 RegisterLogModule("DnsClient");
53 
54 //---------------------------------------------------------------------------------------------------------------------
55 // Client::QueryConfig
56 
57 const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58 
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61     OT_UNUSED_VARIABLE(aMode);
62 
63     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64     GetServerSockAddr().SetPort(kDefaultServerPort);
65     SetResponseTimeout(kDefaultResponseTimeout);
66     SetMaxTxAttempts(kDefaultMaxTxAttempts);
67     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68     SetServiceMode(kDefaultServiceMode);
69 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
70     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
71 #endif
72     SetTransportProto(kDnsTransportUdp);
73 }
74 
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)75 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
76 {
77     // This method sets the config from `aConfig` replacing any
78     // unspecified fields (value zero) with the fields from
79     // `aDefaultConfig`. If `aConfig` is `nullptr` then
80     // `aDefaultConfig` is used.
81 
82     if (aConfig == nullptr)
83     {
84         *this = aDefaultConfig;
85         ExitNow();
86     }
87 
88     *this = *aConfig;
89 
90     if (GetServerSockAddr().GetAddress().IsUnspecified())
91     {
92         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
93     }
94 
95     if (GetServerSockAddr().GetPort() == 0)
96     {
97         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
98     }
99 
100     if (GetResponseTimeout() == 0)
101     {
102         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
103     }
104 
105     if (GetMaxTxAttempts() == 0)
106     {
107         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
108     }
109 
110     if (GetRecursionFlag() == kFlagUnspecified)
111     {
112         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
113     }
114 
115 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
116     if (GetNat64Mode() == kNat64Unspecified)
117     {
118         SetNat64Mode(aDefaultConfig.GetNat64Mode());
119     }
120 #endif
121 
122     if (GetServiceMode() == kServiceModeUnspecified)
123     {
124         SetServiceMode(aDefaultConfig.GetServiceMode());
125     }
126 
127     if (GetTransportProto() == kDnsTransportUnspecified)
128     {
129         SetTransportProto(aDefaultConfig.GetTransportProto());
130     }
131 
132 exit:
133     return;
134 }
135 
136 //---------------------------------------------------------------------------------------------------------------------
137 // Client::Response
138 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const139 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
140 {
141     switch (aSection)
142     {
143     case kAnswerSection:
144         aOffset    = mAnswerOffset;
145         aNumRecord = mAnswerRecordCount;
146         break;
147     case kAdditionalDataSection:
148     default:
149         aOffset    = mAdditionalOffset;
150         aNumRecord = mAdditionalRecordCount;
151         break;
152     }
153 }
154 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const155 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
156 {
157     uint16_t offset = kNameOffsetInQuery;
158 
159     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
160 }
161 
CheckForHostNameAlias(Section aSection,Name & aHostName) const162 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
163 {
164     // If the response includes a CNAME record mapping the query host
165     // name to a canonical name, we update `aHostName` to the new alias
166     // name. Otherwise `aHostName` remains as before. This method handles
167     // when there are multiple CNAME records mapping the host name multiple
168     // times. We limit number of changes to `kMaxCnameAliasNameChanges`
169     // to detect and handle if the response contains CNAME record loops.
170 
171     Error       error;
172     uint16_t    offset;
173     uint16_t    numRecords;
174     CnameRecord cnameRecord;
175 
176     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
177 
178     for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
179     {
180         SelectSection(aSection, offset, numRecords);
181         error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
182 
183         if (error == kErrorNotFound)
184         {
185             error = kErrorNone;
186             ExitNow();
187         }
188 
189         SuccessOrExit(error);
190 
191         // A CNAME record was found. `offset` now points to after the
192         // last read byte within the `mMessage` into the `cnameRecord`
193         // (which is the start of the new canonical name).
194 
195         aHostName.SetFromMessage(*mMessage, offset);
196         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
197 
198         // Loop back to check if there may be a CNAME record for the
199         // new `aHostName`.
200     }
201 
202     error = kErrorParse;
203 
204 exit:
205     return error;
206 }
207 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const208 Error Client::Response::FindHostAddress(Section       aSection,
209                                         const Name   &aHostName,
210                                         uint16_t      aIndex,
211                                         Ip6::Address &aAddress,
212                                         uint32_t     &aTtl) const
213 {
214     Error      error;
215     uint16_t   offset;
216     uint16_t   numRecords;
217     Name       name = aHostName;
218     AaaaRecord aaaaRecord;
219 
220     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
221 
222     SelectSection(aSection, offset, numRecords);
223     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
224     aAddress = aaaaRecord.GetAddress();
225     aTtl     = aaaaRecord.GetTtl();
226 
227 exit:
228     return error;
229 }
230 
231 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
232 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const233 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
234 {
235     Error    error;
236     uint16_t offset;
237     uint16_t numRecords;
238     Name     name = aHostName;
239 
240     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
241 
242     SelectSection(aSection, offset, numRecords);
243     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
244 
245 exit:
246     return error;
247 }
248 
249 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
250 
251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
252 
InitServiceInfo(ServiceInfo & aServiceInfo) const253 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
254 {
255     // This method initializes `aServiceInfo` setting all
256     // TTLs to zero and host name to empty string.
257 
258     aServiceInfo.mTtl              = 0;
259     aServiceInfo.mHostAddressTtl   = 0;
260     aServiceInfo.mTxtDataTtl       = 0;
261     aServiceInfo.mTxtDataTruncated = false;
262 
263     AsCoreType(&aServiceInfo.mHostAddress).Clear();
264 
265     if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
266     {
267         aServiceInfo.mHostNameBuffer[0] = '\0';
268     }
269 }
270 
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const271 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
272 {
273     // This method searches for SRV record in the given `aSection`
274     // matching the record name against `aName`, and updates the
275     // `aServiceInfo` accordingly. It also searches for AAAA record
276     // for host name associated with the service (from SRV record).
277     // The search for AAAA record is always performed in Additional
278     // Data section (independent of the value given in `aSection`).
279 
280     Error     error = kErrorNone;
281     uint16_t  offset;
282     uint16_t  numRecords;
283     Name      hostName;
284     SrvRecord srvRecord;
285 
286     // A non-zero `mTtl` indicates that SRV record is already found
287     // and parsed from a previous response.
288     VerifyOrExit(aServiceInfo.mTtl == 0);
289 
290     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
291 
292     // Search for a matching SRV record
293     SelectSection(aSection, offset, numRecords);
294     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
295 
296     aServiceInfo.mTtl      = srvRecord.GetTtl();
297     aServiceInfo.mPort     = srvRecord.GetPort();
298     aServiceInfo.mPriority = srvRecord.GetPriority();
299     aServiceInfo.mWeight   = srvRecord.GetWeight();
300 
301     hostName.SetFromMessage(*mMessage, offset);
302 
303     if (aServiceInfo.mHostNameBuffer != nullptr)
304     {
305         SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
306                                                            aServiceInfo.mHostNameBufferSize));
307     }
308     else
309     {
310         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
311     }
312 
313     // Search in additional section for AAAA record for the host name.
314 
315     VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
316 
317     error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
318                             aServiceInfo.mHostAddressTtl);
319 
320     if (error == kErrorNotFound)
321     {
322         error = kErrorNone;
323     }
324 
325 exit:
326     return error;
327 }
328 
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const329 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
330 {
331     // This method searches a TXT record in the given `aSection`
332     // matching the record name against `aName` and updates the TXT
333     // related properties in `aServicesInfo`.
334     //
335     // If no match is found `mTxtDataTtl` (which is initialized to zero)
336     // remains unchanged to indicate this. In this case this method still
337     // returns `kErrorNone`.
338 
339     Error     error = kErrorNone;
340     uint16_t  offset;
341     uint16_t  numRecords;
342     TxtRecord txtRecord;
343 
344     // A non-zero `mTxtDataTtl` indicates that TXT record is already
345     // found and parsed from a previous response.
346     VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
347 
348     // A null `mTxtData` indicates that caller does not want to retrieve
349     // TXT data.
350     VerifyOrExit(aServiceInfo.mTxtData != nullptr);
351 
352     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
353 
354     SelectSection(aSection, offset, numRecords);
355 
356     aServiceInfo.mTxtDataTruncated = false;
357 
358     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
359 
360     error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
361 
362     if (error == kErrorNoBufs)
363     {
364         error = kErrorNone;
365 
366         // Mark `mTxtDataTruncated` to indicate that we could not read
367         // the full TXT record into the given `mTxtData` buffer.
368         aServiceInfo.mTxtDataTruncated = true;
369     }
370 
371     SuccessOrExit(error);
372     aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
373 
374 exit:
375     if (error == kErrorNotFound)
376     {
377         error = kErrorNone;
378     }
379 
380     return error;
381 }
382 
383 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
384 
PopulateFrom(const Message & aMessage)385 void Client::Response::PopulateFrom(const Message &aMessage)
386 {
387     // Populate `Response` with info from `aMessage`.
388 
389     uint16_t offset = aMessage.GetOffset();
390     Header   header;
391 
392     mMessage = &aMessage;
393 
394     IgnoreError(aMessage.Read(offset, header));
395     offset += sizeof(Header);
396 
397     for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
398     {
399         IgnoreError(Name::ParseName(aMessage, offset));
400         offset += sizeof(Question);
401     }
402 
403     mAnswerOffset = offset;
404     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
405     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
406     mAdditionalOffset = offset;
407     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
408 
409     mAnswerRecordCount     = header.GetAnswerCount();
410     mAdditionalRecordCount = header.GetAdditionalRecordCount();
411 }
412 
413 //---------------------------------------------------------------------------------------------------------------------
414 // Client::AddressResponse
415 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const416 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
417 {
418     Error error = kErrorNone;
419     Name  name(*mQuery, kNameOffsetInQuery);
420 
421 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
422 
423     // If the response is for an IPv4 address query or if it is an
424     // IPv6 address query response with no IPv6 address but with
425     // an IPv4 in its additional section, we read the IPv4 address
426     // and translate it to an IPv6 address.
427 
428     QueryInfo info;
429 
430     info.ReadFrom(*mQuery);
431 
432     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
433     {
434         Section                          section;
435         ARecord                          aRecord;
436         NetworkData::ExternalRouteConfig nat64Prefix;
437 
438         VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
439                      error = kErrorInvalidState);
440 
441         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
442         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
443 
444         aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
445         aTtl = aRecord.GetTtl();
446 
447         ExitNow();
448     }
449 
450 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
451 
452     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
453 
454 exit:
455     return error;
456 }
457 
458 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
459 
460 //---------------------------------------------------------------------------------------------------------------------
461 // Client::BrowseResponse
462 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const463 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
464 {
465     Error     error;
466     uint16_t  offset;
467     uint16_t  numRecords;
468     Name      serviceName(*mQuery, kNameOffsetInQuery);
469     PtrRecord ptrRecord;
470 
471     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
472 
473     SelectSection(kAnswerSection, offset, numRecords);
474     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
475     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
476 
477 exit:
478     return error;
479 }
480 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const481 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
482 {
483     Error error;
484     Name  instanceName;
485 
486     // Find a matching PTR record for the service instance label. Then
487     // search and read SRV, TXT and AAAA records in Additional Data
488     // section matching the same name to populate `aServiceInfo`.
489 
490     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
491 
492     InitServiceInfo(aServiceInfo);
493     SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
494     SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
495 
496     if (aServiceInfo.mTxtDataTtl == 0)
497     {
498         aServiceInfo.mTxtDataSize = 0;
499     }
500 
501 exit:
502     return error;
503 }
504 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const505 Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
506                                              uint16_t      aIndex,
507                                              Ip6::Address &aAddress,
508                                              uint32_t     &aTtl) const
509 {
510     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
511 }
512 
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const513 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
514 {
515     // This method searches within the Answer Section for a PTR record
516     // matching a given instance label @aInstanceLabel. If found, the
517     // `aName` is updated to return the name in the message.
518 
519     Error     error;
520     uint16_t  offset;
521     Name      serviceName(*mQuery, kNameOffsetInQuery);
522     uint16_t  numRecords;
523     uint16_t  labelOffset;
524     PtrRecord ptrRecord;
525 
526     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
527 
528     SelectSection(kAnswerSection, offset, numRecords);
529 
530     for (; numRecords > 0; numRecords--)
531     {
532         SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
533 
534         error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
535 
536         if (error == kErrorNotFound)
537         {
538             // `ReadRecord()` updates `offset` to skip over a
539             // non-matching record.
540             continue;
541         }
542 
543         SuccessOrExit(error);
544 
545         // It is a PTR record. Check the first label to match the
546         // instance label.
547 
548         labelOffset = offset;
549         error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
550 
551         if (error == kErrorNone)
552         {
553             aInstanceName.SetFromMessage(*mMessage, offset);
554             ExitNow();
555         }
556 
557         VerifyOrExit(error == kErrorNotFound);
558 
559         // Update offset to skip over the PTR record.
560         offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
561     }
562 
563     error = kErrorNotFound;
564 
565 exit:
566     return error;
567 }
568 
569 //---------------------------------------------------------------------------------------------------------------------
570 // Client::ServiceResponse
571 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const572 Error Client::ServiceResponse::GetServiceName(char    *aLabelBuffer,
573                                               uint8_t  aLabelBufferSize,
574                                               char    *aNameBuffer,
575                                               uint16_t aNameBufferSize) const
576 {
577     Error    error;
578     uint16_t offset = kNameOffsetInQuery;
579 
580     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
581 
582     VerifyOrExit(aNameBuffer != nullptr);
583     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
584 
585 exit:
586     return error;
587 }
588 
GetServiceInfo(ServiceInfo & aServiceInfo) const589 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
590 {
591     // Search and read SRV, TXT records matching name from query.
592 
593     Error error = kErrorNotFound;
594 
595     InitServiceInfo(aServiceInfo);
596 
597     for (const Response *response = this; response != nullptr; response = response->mNext)
598     {
599         Name      name(*response->mQuery, kNameOffsetInQuery);
600         QueryInfo info;
601         Section   srvSection;
602         Section   txtSection;
603 
604         info.ReadFrom(*response->mQuery);
605 
606         switch (info.mQueryType)
607         {
608         case kIp6AddressQuery:
609 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
610         case kIp4AddressQuery:
611 #endif
612             IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
613                                                   AsCoreType(&aServiceInfo.mHostAddress),
614                                                   aServiceInfo.mHostAddressTtl));
615 
616             continue; // to `for()` loop
617 
618         case kServiceQuerySrvTxt:
619         case kServiceQuerySrv:
620         case kServiceQueryTxt:
621             break;
622 
623         default:
624             continue;
625         }
626 
627         // Determine from which section we should try to read the SRV and
628         // TXT records based on the query type.
629         //
630         // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
631         // only one record (SRV or TXT) in the answer section, but we
632         // still try to read the other records from additional data
633         // section in case server provided them.
634 
635         srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
636         txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;
637 
638         error = response->ReadServiceInfo(srvSection, name, aServiceInfo);
639 
640         if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
641         {
642             error = kErrorNone;
643         }
644 
645         SuccessOrExit(error);
646 
647         SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
648     }
649 
650     if (aServiceInfo.mTxtDataTtl == 0)
651     {
652         aServiceInfo.mTxtDataSize = 0;
653     }
654 
655 exit:
656     return error;
657 }
658 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const659 Error Client::ServiceResponse::GetHostAddress(const char   *aHostName,
660                                               uint16_t      aIndex,
661                                               Ip6::Address &aAddress,
662                                               uint32_t     &aTtl) const
663 {
664     Error error = kErrorNotFound;
665 
666     for (const Response *response = this; response != nullptr; response = response->mNext)
667     {
668         Section   section = kAdditionalDataSection;
669         QueryInfo info;
670 
671         info.ReadFrom(*response->mQuery);
672 
673         switch (info.mQueryType)
674         {
675         case kIp6AddressQuery:
676 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
677         case kIp4AddressQuery:
678 #endif
679             section = kAnswerSection;
680             break;
681 
682         default:
683             break;
684         }
685 
686         error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
687 
688         if (error == kErrorNone)
689         {
690             break;
691         }
692     }
693 
694     return error;
695 }
696 
697 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
698 
699 //---------------------------------------------------------------------------------------------------------------------
700 // Client
701 
702 const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
703 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
704 const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
705 #endif
706 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
707 const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
708 const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
709 #endif
710 
711 const uint8_t Client::kQuestionCount[] = {
712     /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
713 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
714     /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
715 #endif
716 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
717     /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
718     /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
719     /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
720     /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
721 #endif
722 };
723 
724 const uint16_t *const Client::kQuestionRecordTypes[] = {
725     /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
726 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
727     /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
728 #endif
729 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
730     /* kBrowseQuery  -> */ kBrowseQueryRecordTypes,
731     /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
732     /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
733     /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],
734 
735 #endif
736 };
737 
Client(Instance & aInstance)738 Client::Client(Instance &aInstance)
739     : InstanceLocator(aInstance)
740     , mSocket(aInstance)
741 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
742     , mTcpState(kTcpUninitialized)
743 #endif
744     , mTimer(aInstance)
745     , mDefaultConfig(QueryConfig::kInitFromDefaults)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
747     , mUserDidSetDefaultAddress(false)
748 #endif
749 {
750     static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
752     static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
753 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
754     static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
755     static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
756     static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
757     static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
758 #endif
759 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
760     static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
761     static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
762     static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
763     static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
764 #endif
765 }
766 
Start(void)767 Error Client::Start(void)
768 {
769     Error error;
770 
771     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
772     SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
773 
774 exit:
775     return error;
776 }
777 
Stop(void)778 void Client::Stop(void)
779 {
780     Query *query;
781 
782     while ((query = mMainQueries.GetHead()) != nullptr)
783     {
784         FinalizeQuery(*query, kErrorAbort);
785     }
786 
787     IgnoreError(mSocket.Close());
788 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
789     if (mTcpState != kTcpUninitialized)
790     {
791         IgnoreError(mEndpoint.Deinitialize());
792     }
793 #endif
794 }
795 
796 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)797 Error Client::InitTcpSocket(void)
798 {
799     Error                       error;
800     otTcpEndpointInitializeArgs endpointArgs;
801 
802     ClearAllBytes(endpointArgs);
803     endpointArgs.mSendDoneCallback         = HandleTcpSendDoneCallback;
804     endpointArgs.mEstablishedCallback      = HandleTcpEstablishedCallback;
805     endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
806     endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
807     endpointArgs.mContext                  = this;
808     endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
809     endpointArgs.mReceiveBufferSize        = sizeof(mReceiveBufferBytes);
810 
811     mSendLink.mNext   = nullptr;
812     mSendLink.mData   = mSendBufferBytes;
813     mSendLink.mLength = 0;
814 
815     SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
816 exit:
817     return error;
818 }
819 #endif
820 
SetDefaultConfig(const QueryConfig & aQueryConfig)821 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
822 {
823     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
824 
825     mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
826 
827 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
828     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
829     UpdateDefaultConfigAddress();
830 #endif
831 }
832 
ResetDefaultConfig(void)833 void Client::ResetDefaultConfig(void)
834 {
835     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
836 
837 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
838     mUserDidSetDefaultAddress = false;
839     UpdateDefaultConfigAddress();
840 #endif
841 }
842 
843 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)844 void Client::UpdateDefaultConfigAddress(void)
845 {
846     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
847 
848     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
849         !srpServerAddr.IsUnspecified())
850     {
851         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
852     }
853 }
854 #endif
855 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)856 Error Client::ResolveAddress(const char        *aHostName,
857                              AddressCallback    aCallback,
858                              void              *aContext,
859                              const QueryConfig *aConfig)
860 {
861     QueryInfo info;
862 
863     info.Clear();
864     info.mQueryType = kIp6AddressQuery;
865     info.mConfig.SetFrom(aConfig, mDefaultConfig);
866     info.mCallback.mAddressCallback = aCallback;
867     info.mCallbackContext           = aContext;
868 
869     return StartQuery(info, nullptr, aHostName);
870 }
871 
872 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)873 Error Client::ResolveIp4Address(const char        *aHostName,
874                                 AddressCallback    aCallback,
875                                 void              *aContext,
876                                 const QueryConfig *aConfig)
877 {
878     QueryInfo info;
879 
880     info.Clear();
881     info.mQueryType = kIp4AddressQuery;
882     info.mConfig.SetFrom(aConfig, mDefaultConfig);
883     info.mCallback.mAddressCallback = aCallback;
884     info.mCallbackContext           = aContext;
885 
886     return StartQuery(info, nullptr, aHostName);
887 }
888 #endif
889 
890 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
891 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)892 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
893 {
894     QueryInfo info;
895 
896     info.Clear();
897     info.mQueryType = kBrowseQuery;
898     info.mConfig.SetFrom(aConfig, mDefaultConfig);
899     info.mCallback.mBrowseCallback = aCallback;
900     info.mCallbackContext          = aContext;
901 
902     return StartQuery(info, nullptr, aServiceName);
903 }
904 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)905 Error Client::ResolveService(const char        *aInstanceLabel,
906                              const char        *aServiceName,
907                              ServiceCallback    aCallback,
908                              void              *aContext,
909                              const QueryConfig *aConfig)
910 {
911     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
912 }
913 
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)914 Error Client::ResolveServiceAndHostAddress(const char        *aInstanceLabel,
915                                            const char        *aServiceName,
916                                            ServiceCallback    aCallback,
917                                            void              *aContext,
918                                            const QueryConfig *aConfig)
919 {
920     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
921 }
922 
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)923 Error Client::Resolve(const char        *aInstanceLabel,
924                       const char        *aServiceName,
925                       ServiceCallback    aCallback,
926                       void              *aContext,
927                       const QueryConfig *aConfig,
928                       bool               aShouldResolveHostAddr)
929 {
930     QueryInfo info;
931     Error     error;
932     QueryType secondQueryType = kNoQuery;
933 
934     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
935 
936     info.Clear();
937 
938     info.mConfig.SetFrom(aConfig, mDefaultConfig);
939     info.mShouldResolveHostAddr = aShouldResolveHostAddr;
940 
941     switch (info.mConfig.GetServiceMode())
942     {
943     case QueryConfig::kServiceModeSrvTxtSeparate:
944         secondQueryType = kServiceQueryTxt;
945 
946         OT_FALL_THROUGH;
947 
948     case QueryConfig::kServiceModeSrv:
949         info.mQueryType = kServiceQuerySrv;
950         break;
951 
952     case QueryConfig::kServiceModeTxt:
953         info.mQueryType = kServiceQueryTxt;
954         VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
955         break;
956 
957     case QueryConfig::kServiceModeSrvTxt:
958     case QueryConfig::kServiceModeSrvTxtOptimize:
959     default:
960         info.mQueryType = kServiceQuerySrvTxt;
961         break;
962     }
963 
964     info.mCallback.mServiceCallback = aCallback;
965     info.mCallbackContext           = aContext;
966 
967     error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
968 
969 exit:
970     return error;
971 }
972 
973 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
974 
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)975 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
976 {
977     // The `aLabel` can be `nullptr` and then `aName` provides the
978     // full name, otherwise the name is appended as `{aLabel}.
979     // {aName}`.
980 
981     Error  error;
982     Query *query;
983 
984     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
985 
986 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
987     if (aInfo.mQueryType == kIp4AddressQuery)
988     {
989         NetworkData::ExternalRouteConfig nat64Prefix;
990 
991         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
992         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
993                      error = kErrorInvalidState);
994     }
995 #endif
996 
997     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
998 
999     mMainQueries.Enqueue(*query);
1000 
1001     error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
1002     VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1003 
1004     if (aSecondType != kNoQuery)
1005     {
1006         Query *secondQuery;
1007 
1008         aInfo.mQueryType         = aSecondType;
1009         aInfo.mMessageId         = 0;
1010         aInfo.mTransmissionCount = 0;
1011         aInfo.mMainQuery         = query;
1012 
1013         // We intentionally do not use `error` here so in the unlikely
1014         // case where we cannot allocate the second query we can proceed
1015         // with the first one.
1016         SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1017 
1018         IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1019 
1020         // Update first query to link to second one by updating
1021         // its `mNextQuery`.
1022         aInfo.ReadFrom(*query);
1023         aInfo.mNextQuery = secondQuery;
1024         UpdateQuery(*query, aInfo);
1025     }
1026 
1027 exit:
1028     return error;
1029 }
1030 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1031 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1032 {
1033     Error error = kErrorNone;
1034 
1035     aQuery = nullptr;
1036 
1037     VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1038 
1039     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1040     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1041 
1042     SuccessOrExit(error = aQuery->Append(aInfo));
1043 
1044     if (aLabel != nullptr)
1045     {
1046         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1047     }
1048 
1049     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1050 
1051 exit:
1052     FreeAndNullMessageOnError(aQuery, error);
1053     return error;
1054 }
1055 
FindMainQuery(Query & aQuery)1056 Client::Query &Client::FindMainQuery(Query &aQuery)
1057 {
1058     QueryInfo info;
1059 
1060     info.ReadFrom(aQuery);
1061 
1062     return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1063 }
1064 
FreeQuery(Query & aQuery)1065 void Client::FreeQuery(Query &aQuery)
1066 {
1067     Query    &mainQuery = FindMainQuery(aQuery);
1068     QueryInfo info;
1069 
1070     mMainQueries.Dequeue(mainQuery);
1071 
1072     for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1073     {
1074         info.ReadFrom(*query);
1075         FreeMessage(info.mSavedResponse);
1076         query->Free();
1077     }
1078 }
1079 
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)1080 Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
1081 {
1082     // This method prepares and sends a query message represented by
1083     // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
1084     // the new `mRetransmissionTime`) and updates it in `aQuery` as
1085     // well. `aUpdateTimer` indicates whether the timer should be
1086     // updated when query is sent or not (used in the case where timer
1087     // is handled by caller).
1088 
1089     Error            error   = kErrorNone;
1090     Message         *message = nullptr;
1091     Header           header;
1092     Ip6::MessageInfo messageInfo;
1093     uint16_t         length = 0;
1094 
1095     aInfo.mTransmissionCount++;
1096     aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
1097 
1098     if (aInfo.mMessageId == 0)
1099     {
1100         do
1101         {
1102             SuccessOrExit(error = header.SetRandomMessageId());
1103         } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
1104 
1105         aInfo.mMessageId = header.GetMessageId();
1106     }
1107     else
1108     {
1109         header.SetMessageId(aInfo.mMessageId);
1110     }
1111 
1112     header.SetType(Header::kTypeQuery);
1113     header.SetQueryType(Header::kQueryTypeStandard);
1114 
1115     if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
1116     {
1117         header.SetRecursionDesiredFlag();
1118     }
1119 
1120     header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
1121 
1122     message = mSocket.NewMessage();
1123     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
1124 
1125     SuccessOrExit(error = message->Append(header));
1126 
1127     // Prepare the question section.
1128 
1129     for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
1130     {
1131         SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
1132         SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
1133     }
1134 
1135     length = message->GetLength() - message->GetOffset();
1136 
1137     if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1138 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1139     {
1140         // Check if query will fit into tcp buffer if not return error.
1141         VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
1142                          OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
1143                      error = kErrorNoBufs);
1144 
1145         // In case of initialized connection check if connected peer and new query have the same address.
1146         if (mTcpState != kTcpUninitialized)
1147         {
1148             VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
1149                          error = kErrorFailed);
1150         }
1151 
1152         switch (mTcpState)
1153         {
1154         case kTcpUninitialized:
1155             SuccessOrExit(error = InitTcpSocket());
1156             SuccessOrExit(
1157                 error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
1158             mTcpState = kTcpConnecting;
1159             PrepareTcpMessage(*message);
1160             break;
1161         case kTcpConnectedIdle:
1162             PrepareTcpMessage(*message);
1163             SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1164             mTcpState = kTcpConnectedSending;
1165             break;
1166         case kTcpConnecting:
1167             PrepareTcpMessage(*message);
1168             break;
1169         case kTcpConnectedSending:
1170             BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1171             SuccessOrAssert(error = message->Read(message->GetOffset(),
1172                                                   (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1173             IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
1174             break;
1175         }
1176         message->Free();
1177         message = nullptr;
1178     }
1179 #else
1180     {
1181         error = kErrorInvalidArgs;
1182         LogWarn("DNS query over TCP not supported.");
1183         ExitNow();
1184     }
1185 #endif
1186     else
1187     {
1188         VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
1189         messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
1190         messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
1191         SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
1192     }
1193 
1194 exit:
1195 
1196     FreeMessageOnError(message, error);
1197     if (aUpdateTimer)
1198     {
1199         mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
1200     }
1201 
1202     UpdateQuery(aQuery, aInfo);
1203 
1204     return error;
1205 }
1206 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1207 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1208 {
1209     // The name is encoded and included after the `Info` in `aQuery`
1210     // starting at `kNameOffsetInQuery`.
1211 
1212     return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1213 }
1214 
FinalizeQuery(Query & aQuery,Error aError)1215 void Client::FinalizeQuery(Query &aQuery, Error aError)
1216 {
1217     Response response;
1218     Query   &mainQuery = FindMainQuery(aQuery);
1219 
1220     response.mInstance = &Get<Instance>();
1221     response.mQuery    = &mainQuery;
1222 
1223     FinalizeQuery(response, aError);
1224 }
1225 
FinalizeQuery(Response & aResponse,Error aError)1226 void Client::FinalizeQuery(Response &aResponse, Error aError)
1227 {
1228     QueryType type;
1229     Callback  callback;
1230     void     *context;
1231 
1232     GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1233 
1234     switch (type)
1235     {
1236     case kIp6AddressQuery:
1237 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1238     case kIp4AddressQuery:
1239 #endif
1240         if (callback.mAddressCallback != nullptr)
1241         {
1242             callback.mAddressCallback(aError, &aResponse, context);
1243         }
1244         break;
1245 
1246 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1247     case kBrowseQuery:
1248         if (callback.mBrowseCallback != nullptr)
1249         {
1250             callback.mBrowseCallback(aError, &aResponse, context);
1251         }
1252         break;
1253 
1254     case kServiceQuerySrvTxt:
1255     case kServiceQuerySrv:
1256     case kServiceQueryTxt:
1257         if (callback.mServiceCallback != nullptr)
1258         {
1259             callback.mServiceCallback(aError, &aResponse, context);
1260         }
1261         break;
1262 #endif
1263     case kNoQuery:
1264         break;
1265     }
1266 
1267     FreeQuery(*aResponse.mQuery);
1268 }
1269 
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1270 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1271 {
1272     QueryInfo info;
1273 
1274     info.ReadFrom(aQuery);
1275 
1276     aType     = info.mQueryType;
1277     aCallback = info.mCallback;
1278     aContext  = info.mCallbackContext;
1279 }
1280 
FindQueryById(uint16_t aMessageId)1281 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1282 {
1283     Query    *matchedQuery = nullptr;
1284     QueryInfo info;
1285 
1286     for (Query &mainQuery : mMainQueries)
1287     {
1288         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1289         {
1290             info.ReadFrom(*query);
1291 
1292             if (info.mMessageId == aMessageId)
1293             {
1294                 matchedQuery = query;
1295                 ExitNow();
1296             }
1297         }
1298     }
1299 
1300 exit:
1301     return matchedQuery;
1302 }
1303 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1304 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1305 {
1306     OT_UNUSED_VARIABLE(aMsgInfo);
1307 
1308     static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1309 }
1310 
ProcessResponse(const Message & aResponseMessage)1311 void Client::ProcessResponse(const Message &aResponseMessage)
1312 {
1313     Error  responseError;
1314     Query *query;
1315 
1316     SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1317 
1318     if (responseError != kErrorNone)
1319     {
1320         // Received an error from server, check if we can replace
1321         // the query.
1322 
1323 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1324         if (ReplaceWithIp4Query(*query) == kErrorNone)
1325         {
1326             ExitNow();
1327         }
1328 #endif
1329 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1330         if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1331         {
1332             ExitNow();
1333         }
1334 #endif
1335 
1336         FinalizeQuery(*query, responseError);
1337         ExitNow();
1338     }
1339 
1340     // Received successful response from server.
1341 
1342 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1343     ResolveHostAddressIfNeeded(*query, aResponseMessage);
1344 #endif
1345 
1346     if (!CanFinalizeQuery(*query))
1347     {
1348         SaveQueryResponse(*query, aResponseMessage);
1349         ExitNow();
1350     }
1351 
1352     PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1353 
1354 exit:
1355     return;
1356 }
1357 
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1358 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1359 {
1360     Error     error  = kErrorNone;
1361     uint16_t  offset = aResponseMessage.GetOffset();
1362     Header    header;
1363     QueryInfo info;
1364     Name      queryName;
1365 
1366     SuccessOrExit(error = aResponseMessage.Read(offset, header));
1367     offset += sizeof(Header);
1368 
1369     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1370                      !header.IsTruncationFlagSet(),
1371                  error = kErrorDrop);
1372 
1373     aQuery = FindQueryById(header.GetMessageId());
1374     VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1375 
1376     info.ReadFrom(*aQuery);
1377 
1378     queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1379 
1380     // Check the Question Section
1381 
1382     if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1383     {
1384         for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1385         {
1386             SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1387             offset += sizeof(Question);
1388         }
1389     }
1390     else
1391     {
1392         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1393                      error = kErrorParse);
1394     }
1395 
1396     // Check the answer, authority and additional record sections
1397 
1398     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1399     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1400     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1401 
1402     // Read the response code
1403 
1404     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1405 
1406 exit:
1407     return error;
1408 }
1409 
CanFinalizeQuery(Query & aQuery)1410 bool Client::CanFinalizeQuery(Query &aQuery)
1411 {
1412     // Determines whether we can finalize a main query by checking if
1413     // we have received and saved responses for all other related
1414     // queries associated with `aQuery`. Note that this method is
1415     // called when we receive a response for `aQuery`, so no need to
1416     // check for a saved response for `aQuery` itself.
1417 
1418     bool      canFinalize = true;
1419     QueryInfo info;
1420 
1421     for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1422     {
1423         info.ReadFrom(*query);
1424 
1425         if (query == &aQuery)
1426         {
1427             continue;
1428         }
1429 
1430         if (info.mSavedResponse == nullptr)
1431         {
1432             canFinalize = false;
1433             ExitNow();
1434         }
1435     }
1436 
1437 exit:
1438     return canFinalize;
1439 }
1440 
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1441 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1442 {
1443     QueryInfo info;
1444 
1445     info.ReadFrom(aQuery);
1446     VerifyOrExit(info.mSavedResponse == nullptr);
1447 
1448     // If `Clone()` fails we let retry or timeout handle the error.
1449     info.mSavedResponse = aResponseMessage.Clone();
1450 
1451     UpdateQuery(aQuery, info);
1452 
1453 exit:
1454     return;
1455 }
1456 
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1457 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1458 {
1459     // Populate `aResponse` for `aQuery`. If there is a saved response
1460     // message for `aQuery` we use it, otherwise, we use
1461     // `aResponseMessage`.
1462 
1463     QueryInfo info;
1464 
1465     info.ReadFrom(aQuery);
1466 
1467     aResponse.mInstance = &Get<Instance>();
1468     aResponse.mQuery    = &aQuery;
1469     aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1470 
1471     return info.mNextQuery;
1472 }
1473 
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1474 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1475 {
1476     // This method prepares a list of chained `Response` instances
1477     // corresponding to all related (chained) queries. It uses
1478     // recursion to go through the queries and construct the
1479     // `Response` chain.
1480 
1481     Response response;
1482     Query   *nextQuery;
1483 
1484     nextQuery      = PopulateResponse(response, aQuery, aResponseMessage);
1485     response.mNext = aPrevResponse;
1486 
1487     if (nextQuery != nullptr)
1488     {
1489         PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1490     }
1491     else
1492     {
1493         FinalizeQuery(response, kErrorNone);
1494     }
1495 }
1496 
HandleTimer(void)1497 void Client::HandleTimer(void)
1498 {
1499     TimeMilli now      = TimerMilli::GetNow();
1500     TimeMilli nextTime = now.GetDistantFuture();
1501     QueryInfo info;
1502 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1503     bool hasTcpQuery = false;
1504 #endif
1505 
1506     for (Query &mainQuery : mMainQueries)
1507     {
1508         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1509         {
1510             info.ReadFrom(*query);
1511 
1512             if (info.mSavedResponse != nullptr)
1513             {
1514                 continue;
1515             }
1516 
1517             if (now >= info.mRetransmissionTime)
1518             {
1519                 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1520                 {
1521                     FinalizeQuery(*query, kErrorResponseTimeout);
1522                     break;
1523                 }
1524 
1525                 IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1526             }
1527 
1528             if (nextTime > info.mRetransmissionTime)
1529             {
1530                 nextTime = info.mRetransmissionTime;
1531             }
1532 
1533 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1534             if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1535             {
1536                 hasTcpQuery = true;
1537             }
1538 #endif
1539         }
1540     }
1541 
1542     if (nextTime < now.GetDistantFuture())
1543     {
1544         mTimer.FireAt(nextTime);
1545     }
1546 
1547 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1548     if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1549     {
1550         IgnoreError(mEndpoint.SendEndOfStream());
1551     }
1552 #endif
1553 }
1554 
1555 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1556 
ReplaceWithIp4Query(Query & aQuery)1557 Error Client::ReplaceWithIp4Query(Query &aQuery)
1558 {
1559     Error     error = kErrorFailed;
1560     QueryInfo info;
1561 
1562     info.ReadFrom(aQuery);
1563 
1564     VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1565     VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1566 
1567     // We send a new query for IPv4 address resolution
1568     // for the same host name. We reuse the existing `aQuery`
1569     // instance and keep all the info but clear `mTransmissionCount`
1570     // and `mMessageId` (so that a new random message ID is
1571     // selected). The new `info` will be saved in the query in
1572     // `SendQuery()`. Note that the current query is still in the
1573     // `mMainQueries` list when `SendQuery()` selects a new random
1574     // message ID, so the existing message ID for this query will
1575     // not be reused.
1576 
1577     info.mQueryType         = kIp4AddressQuery;
1578     info.mMessageId         = 0;
1579     info.mTransmissionCount = 0;
1580 
1581     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1582     error = kErrorNone;
1583 
1584 exit:
1585     return error;
1586 }
1587 
1588 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1589 
1590 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1591 
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1592 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1593 {
1594     Error     error = kErrorFailed;
1595     QueryInfo info;
1596     Query    *secondQuery;
1597 
1598     info.ReadFrom(aQuery);
1599 
1600     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1601     VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1602 
1603     secondQuery = aQuery.Clone();
1604     VerifyOrExit(secondQuery != nullptr);
1605 
1606     info.mQueryType         = kServiceQueryTxt;
1607     info.mMessageId         = 0;
1608     info.mTransmissionCount = 0;
1609     info.mMainQuery         = &aQuery;
1610     IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1611 
1612     info.mQueryType         = kServiceQuerySrv;
1613     info.mMessageId         = 0;
1614     info.mTransmissionCount = 0;
1615     info.mNextQuery         = secondQuery;
1616     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1617     error = kErrorNone;
1618 
1619 exit:
1620     return error;
1621 }
1622 
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1623 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1624 {
1625     QueryInfo   info;
1626     Response    response;
1627     ServiceInfo serviceInfo;
1628     char        hostName[Name::kMaxNameSize];
1629 
1630     info.ReadFrom(aQuery);
1631 
1632     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1633     VerifyOrExit(info.mShouldResolveHostAddr);
1634 
1635     PopulateResponse(response, aQuery, aResponseMessage);
1636 
1637     ClearAllBytes(serviceInfo);
1638     serviceInfo.mHostNameBuffer     = hostName;
1639     serviceInfo.mHostNameBufferSize = sizeof(hostName);
1640     SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1641 
1642     // Check whether AAAA record for host address is provided in the SRV query response
1643 
1644     if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1645     {
1646         Query *newQuery;
1647 
1648         info.mQueryType         = kIp6AddressQuery;
1649         info.mMessageId         = 0;
1650         info.mTransmissionCount = 0;
1651         info.mMainQuery         = &FindMainQuery(aQuery);
1652 
1653         SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1654         IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1655 
1656         // Update `aQuery` to be linked with new query (inserting
1657         // the `newQuery` into the linked-list after `aQuery`).
1658 
1659         info.ReadFrom(aQuery);
1660         info.mNextQuery = newQuery;
1661         UpdateQuery(aQuery, info);
1662     }
1663 
1664 exit:
1665     return;
1666 }
1667 
1668 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1669 
1670 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1671 void Client::PrepareTcpMessage(Message &aMessage)
1672 {
1673     uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1674 
1675     // Prepending the DNS query with length of the packet according to RFC1035.
1676     BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1677     SuccessOrAssert(
1678         aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1679     mSendLink.mLength += length + sizeof(uint16_t);
1680 }
1681 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1682 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1683 {
1684     OT_UNUSED_VARIABLE(aEndpoint);
1685     OT_UNUSED_VARIABLE(aData);
1686     OT_ASSERT(mTcpState == kTcpConnectedSending);
1687 
1688     mSendLink.mLength = 0;
1689     mTcpState         = kTcpConnectedIdle;
1690 }
1691 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1692 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1693 {
1694     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1695 }
1696 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1697 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1698 {
1699     OT_UNUSED_VARIABLE(aEndpoint);
1700     IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1701     mTcpState = kTcpConnectedSending;
1702 }
1703 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1704 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1705 {
1706     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1707 }
1708 
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1709 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1710                                  size_t                &aOffset,
1711                                  Message               &aMessage,
1712                                  uint16_t               aLength)
1713 {
1714     // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1715     // and copy the content into `aMessage`. As we read we can move
1716     // to the next `aLinkedBuffer` and update `aOffset`.
1717     // Returns:
1718     // - `kErrorNone` if `aLength` bytes are successfully read and
1719     //    `aOffset` and `aLinkedBuffer` are updated.
1720     // - `kErrorNotFound` is not enough bytes available to read
1721     //    from `aLinkedBuffer`.
1722     // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1723 
1724     Error error = kErrorNone;
1725 
1726     while (aLength > 0)
1727     {
1728         uint16_t bytesToRead = aLength;
1729 
1730         VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1731 
1732         if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1733         {
1734             bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1735         }
1736 
1737         SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1738 
1739         aLength -= bytesToRead;
1740         aOffset += bytesToRead;
1741 
1742         if (aOffset == aLinkedBuffer->mLength)
1743         {
1744             aLinkedBuffer = aLinkedBuffer->mNext;
1745             aOffset       = 0;
1746         }
1747     }
1748 
1749 exit:
1750     return error;
1751 }
1752 
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1753 void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
1754                                        size_t         aBytesAvailable,
1755                                        bool           aEndOfStream,
1756                                        size_t         aBytesRemaining)
1757 {
1758     OT_UNUSED_VARIABLE(aEndpoint);
1759     OT_UNUSED_VARIABLE(aBytesRemaining);
1760 
1761     Message              *message   = nullptr;
1762     size_t                totalRead = 0;
1763     size_t                offset    = 0;
1764     const otLinkedBuffer *data;
1765 
1766     if (aEndOfStream)
1767     {
1768         // Cleanup is done in disconnected callback.
1769         IgnoreError(mEndpoint.SendEndOfStream());
1770     }
1771 
1772     SuccessOrExit(mEndpoint.ReceiveByReference(data));
1773     VerifyOrExit(data != nullptr);
1774 
1775     message = mSocket.NewMessage();
1776     VerifyOrExit(message != nullptr);
1777 
1778     while (aBytesAvailable > totalRead)
1779     {
1780         uint16_t length;
1781 
1782         // Read the `length` field.
1783         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));
1784 
1785         IgnoreError(message->Read(/* aOffset */ 0, length));
1786         length = BigEndian::HostSwap16(length);
1787 
1788         // Try to read `length` bytes.
1789         IgnoreError(message->SetLength(0));
1790         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));
1791 
1792         totalRead += length + sizeof(uint16_t);
1793 
1794         // Now process the read message as query response.
1795         ProcessResponse(*message);
1796 
1797         IgnoreError(message->SetLength(0));
1798 
1799         // Loop again to see if we can read another response.
1800     }
1801 
1802 exit:
1803     // Inform `mEndPoint` about the total read and processed bytes
1804     IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
1805     FreeMessage(message);
1806 }
1807 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1808 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1809                                                size_t         aBytesAvailable,
1810                                                bool           aEndOfStream,
1811                                                size_t         aBytesRemaining)
1812 {
1813     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1814         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1815 }
1816 
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1817 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1818 {
1819     OT_UNUSED_VARIABLE(aEndpoint);
1820     OT_UNUSED_VARIABLE(aReason);
1821     QueryInfo info;
1822 
1823     IgnoreError(mEndpoint.Deinitialize());
1824     mTcpState = kTcpUninitialized;
1825 
1826     // Abort queries in case of connection failures
1827     for (Query &mainQuery : mMainQueries)
1828     {
1829         info.ReadFrom(mainQuery);
1830 
1831         if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1832         {
1833             FinalizeQuery(mainQuery, kErrorAbort);
1834         }
1835     }
1836 }
1837 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1838 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1839 {
1840     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1841 }
1842 
1843 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1844 
1845 } // namespace Dns
1846 } // namespace ot
1847 
1848 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1849