• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43 
44 /**
45  * @file
46  *   This file implements the DNS client.
47  */
48 
49 namespace ot {
50 namespace Dns {
51 
52 RegisterLogModule("DnsClient");
53 
54 //---------------------------------------------------------------------------------------------------------------------
55 // Client::QueryConfig
56 
57 const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58 
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61     OT_UNUSED_VARIABLE(aMode);
62 
63     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64     GetServerSockAddr().SetPort(kDefaultServerPort);
65     SetResponseTimeout(kDefaultResponseTimeout);
66     SetMaxTxAttempts(kDefaultMaxTxAttempts);
67     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
69     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
70 #endif
71 }
72 
SetFrom(const QueryConfig & aConfig,const QueryConfig & aDefaultConfig)73 void Client::QueryConfig::SetFrom(const QueryConfig &aConfig, const QueryConfig &aDefaultConfig)
74 {
75     // This method sets the config from `aConfig` replacing any
76     // unspecified fields (value zero) with the fields from
77     // `aDefaultConfig`.
78 
79     *this = aConfig;
80 
81     if (GetServerSockAddr().GetAddress().IsUnspecified())
82     {
83         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
84     }
85 
86     if (GetServerSockAddr().GetPort() == 0)
87     {
88         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
89     }
90 
91     if (GetResponseTimeout() == 0)
92     {
93         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
94     }
95 
96     if (GetMaxTxAttempts() == 0)
97     {
98         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
99     }
100 
101     if (GetRecursionFlag() == kFlagUnspecified)
102     {
103         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
104     }
105 
106 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
107     if (GetNat64Mode() == kNat64Unspecified)
108     {
109         SetNat64Mode(aDefaultConfig.GetNat64Mode());
110     }
111 #endif
112 }
113 
114 //---------------------------------------------------------------------------------------------------------------------
115 // Client::Response
116 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const117 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
118 {
119     switch (aSection)
120     {
121     case kAnswerSection:
122         aOffset    = mAnswerOffset;
123         aNumRecord = mAnswerRecordCount;
124         break;
125     case kAdditionalDataSection:
126     default:
127         aOffset    = mAdditionalOffset;
128         aNumRecord = mAdditionalRecordCount;
129         break;
130     }
131 }
132 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const133 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
134 {
135     uint16_t offset = kNameOffsetInQuery;
136 
137     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
138 }
139 
CheckForHostNameAlias(Section aSection,Name & aHostName) const140 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
141 {
142     // If the response includes a CNAME record mapping the query host
143     // name to a canonical name, we update `aHostName` to the new alias
144     // name. Otherwise `aHostName` remains as before.
145 
146     Error       error;
147     uint16_t    offset;
148     uint16_t    numRecords;
149     CnameRecord cnameRecord;
150 
151     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
152 
153     SelectSection(aSection, offset, numRecords);
154     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
155 
156     switch (error)
157     {
158     case kErrorNone:
159         // A CNAME record was found. `offset` now points to after the
160         // last read byte within the `mMessage` into the `cnameRecord`
161         // (which is the start of the new canonical name).
162         aHostName.SetFromMessage(*mMessage, offset);
163         error = Name::ParseName(*mMessage, offset);
164         break;
165 
166     case kErrorNotFound:
167         error = kErrorNone;
168         break;
169 
170     default:
171         break;
172     }
173 
174 exit:
175     return error;
176 }
177 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const178 Error Client::Response::FindHostAddress(Section       aSection,
179                                         const Name &  aHostName,
180                                         uint16_t      aIndex,
181                                         Ip6::Address &aAddress,
182                                         uint32_t &    aTtl) const
183 {
184     Error      error;
185     uint16_t   offset;
186     uint16_t   numRecords;
187     Name       name = aHostName;
188     AaaaRecord aaaaRecord;
189 
190     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
191 
192     SelectSection(aSection, offset, numRecords);
193     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
194     aAddress = aaaaRecord.GetAddress();
195     aTtl     = aaaaRecord.GetTtl();
196 
197 exit:
198     return error;
199 }
200 
201 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
202 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const203 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
204 {
205     Error    error;
206     uint16_t offset;
207     uint16_t numRecords;
208     Name     name = aHostName;
209 
210     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
211 
212     SelectSection(aSection, offset, numRecords);
213     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
214 
215 exit:
216     return error;
217 }
218 
219 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
220 
221 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
222 
FindServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const223 Error Client::Response::FindServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
224 {
225     // This method searches for SRV and TXT records in the given
226     // section matching the record name against `aName`, and updates
227     // the `aServiceInfo` accordingly. It also searches for AAAA
228     // record for host name associated with the service (from SRV
229     // record). The search for AAAA record is always performed in
230     // Additional Data section (independent of the value given in
231     // `aSection`).
232 
233     Error     error;
234     uint16_t  offset;
235     uint16_t  numRecords;
236     Name      hostName;
237     SrvRecord srvRecord;
238     TxtRecord txtRecord;
239 
240     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
241 
242     // Search for a matching SRV record
243     SelectSection(aSection, offset, numRecords);
244     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
245 
246     aServiceInfo.mTtl      = srvRecord.GetTtl();
247     aServiceInfo.mPort     = srvRecord.GetPort();
248     aServiceInfo.mPriority = srvRecord.GetPriority();
249     aServiceInfo.mWeight   = srvRecord.GetWeight();
250 
251     hostName.SetFromMessage(*mMessage, offset);
252 
253     if (aServiceInfo.mHostNameBuffer != nullptr)
254     {
255         SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
256                                                            aServiceInfo.mHostNameBufferSize));
257     }
258     else
259     {
260         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
261     }
262 
263     // Search in additional section for AAAA record for the host name.
264 
265     error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
266                             aServiceInfo.mHostAddressTtl);
267 
268     if (error == kErrorNotFound)
269     {
270         AsCoreType(&aServiceInfo.mHostAddress).Clear();
271         aServiceInfo.mHostAddressTtl = 0;
272     }
273     else
274     {
275         SuccessOrExit(error);
276     }
277 
278     // A null `mTxtData` indicates that caller does not want to retrieve TXT data.
279     VerifyOrExit(aServiceInfo.mTxtData != nullptr);
280 
281     // Search for a matching TXT record. If not found, indicate this by
282     // setting `aServiceInfo.mTxtDataSize` to zero.
283 
284     SelectSection(aSection, offset, numRecords);
285     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord);
286 
287     switch (error)
288     {
289     case kErrorNone:
290         SuccessOrExit(error =
291                           txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize));
292         aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
293         break;
294 
295     case kErrorNotFound:
296         aServiceInfo.mTxtDataSize = 0;
297         aServiceInfo.mTxtDataTtl  = 0;
298         break;
299 
300     default:
301         ExitNow();
302     }
303 
304 exit:
305     return error;
306 }
307 
308 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
309 
310 //---------------------------------------------------------------------------------------------------------------------
311 // Client::AddressResponse
312 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const313 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
314 {
315     Error error = kErrorNone;
316     Name  name(*mQuery, kNameOffsetInQuery);
317 
318 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
319 
320     // If the response is for an IPv4 address query or if it is an
321     // IPv6 address query response with no IPv6 address but with
322     // an IPv4 in its additional section, we read the IPv4 address
323     // and translate it to an IPv6 address.
324 
325     QueryInfo info;
326 
327     info.ReadFrom(*mQuery);
328 
329     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
330     {
331         Section                          section;
332         ARecord                          aRecord;
333         NetworkData::ExternalRouteConfig nat64Prefix;
334 
335         VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
336                      error = kErrorInvalidState);
337 
338         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
339         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
340 
341         aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
342         aTtl = aRecord.GetTtl();
343 
344         ExitNow();
345     }
346 
347 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
348 
349     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
350 
351 exit:
352     return error;
353 }
354 
355 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
356 
357 //---------------------------------------------------------------------------------------------------------------------
358 // Client::BrowseResponse
359 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const360 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
361 {
362     Error     error;
363     uint16_t  offset;
364     uint16_t  numRecords;
365     Name      serviceName(*mQuery, kNameOffsetInQuery);
366     PtrRecord ptrRecord;
367 
368     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
369 
370     SelectSection(kAnswerSection, offset, numRecords);
371     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
372     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
373 
374 exit:
375     return error;
376 }
377 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const378 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
379 {
380     Error error;
381     Name  instanceName;
382 
383     // Find a matching PTR record for the service instance label.
384     // Then search and read SRV, TXT and AAAA records in Additional Data section
385     // matching the same name to populate `aServiceInfo`.
386 
387     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
388     error = FindServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo);
389 
390 exit:
391     return error;
392 }
393 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const394 Error Client::BrowseResponse::GetHostAddress(const char *  aHostName,
395                                              uint16_t      aIndex,
396                                              Ip6::Address &aAddress,
397                                              uint32_t &    aTtl) const
398 {
399     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
400 }
401 
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const402 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
403 {
404     // This method searches within the Answer Section for a PTR record
405     // matching a given instance label @aInstanceLabel. If found, the
406     // `aName` is updated to return the name in the message.
407 
408     Error     error;
409     uint16_t  offset;
410     Name      serviceName(*mQuery, kNameOffsetInQuery);
411     uint16_t  numRecords;
412     uint16_t  labelOffset;
413     PtrRecord ptrRecord;
414 
415     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
416 
417     SelectSection(kAnswerSection, offset, numRecords);
418 
419     for (; numRecords > 0; numRecords--)
420     {
421         SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
422 
423         error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
424 
425         if (error == kErrorNotFound)
426         {
427             // `ReadRecord()` updates `offset` to skip over a
428             // non-matching record.
429             continue;
430         }
431 
432         SuccessOrExit(error);
433 
434         // It is a PTR record. Check the first label to match the
435         // instance label.
436 
437         labelOffset = offset;
438         error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
439 
440         if (error == kErrorNone)
441         {
442             aInstanceName.SetFromMessage(*mMessage, offset);
443             ExitNow();
444         }
445 
446         VerifyOrExit(error == kErrorNotFound);
447 
448         // Update offset to skip over the PTR record.
449         offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
450     }
451 
452     error = kErrorNotFound;
453 
454 exit:
455     return error;
456 }
457 
458 //---------------------------------------------------------------------------------------------------------------------
459 // Client::ServiceResponse
460 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const461 Error Client::ServiceResponse::GetServiceName(char *   aLabelBuffer,
462                                               uint8_t  aLabelBufferSize,
463                                               char *   aNameBuffer,
464                                               uint16_t aNameBufferSize) const
465 {
466     Error    error;
467     uint16_t offset = kNameOffsetInQuery;
468 
469     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
470 
471     VerifyOrExit(aNameBuffer != nullptr);
472     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
473 
474 exit:
475     return error;
476 }
477 
GetServiceInfo(ServiceInfo & aServiceInfo) const478 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
479 {
480     // Search and read SRV, TXT records in Answer Section
481     // matching name from query.
482 
483     return FindServiceInfo(kAnswerSection, Name(*mQuery, kNameOffsetInQuery), aServiceInfo);
484 }
485 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const486 Error Client::ServiceResponse::GetHostAddress(const char *  aHostName,
487                                               uint16_t      aIndex,
488                                               Ip6::Address &aAddress,
489                                               uint32_t &    aTtl) const
490 {
491     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
492 }
493 
494 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
495 
496 //---------------------------------------------------------------------------------------------------------------------
497 // Client
498 
499 const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
500 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
501 const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
502 #endif
503 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
504 const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
505 const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
506 #endif
507 
508 const uint8_t Client::kQuestionCount[] = {
509     /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA records
510 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
511     /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A records
512 #endif
513 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
514     /* kBrowseQuery  -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR records
515     /* kServiceQuery -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
516 #endif
517 };
518 
519 const uint16_t *Client::kQuestionRecordTypes[] = {
520     /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
521 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
522     /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
523 #endif
524 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
525     /* kBrowseQuery  -> */ kBrowseQueryRecordTypes,
526     /* kServiceQuery -> */ kServiceQueryRecordTypes,
527 #endif
528 };
529 
Client(Instance & aInstance)530 Client::Client(Instance &aInstance)
531     : InstanceLocator(aInstance)
532     , mSocket(aInstance)
533     , mTimer(aInstance, Client::HandleTimer)
534     , mDefaultConfig(QueryConfig::kInitFromDefaults)
535 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
536     , mUserDidSetDefaultAddress(false)
537 #endif
538 {
539     static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
540 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
541     static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
542 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
543     static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
544     static_assert(kServiceQuery == 3, "kServiceQuery value is not correct");
545 #endif
546 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
547     static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
548     static_assert(kServiceQuery == 2, "kServiceQuery value is not correct");
549 #endif
550 }
551 
Start(void)552 Error Client::Start(void)
553 {
554     Error error;
555 
556     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
557     SuccessOrExit(error = mSocket.Bind(0, OT_NETIF_UNSPECIFIED));
558 
559 exit:
560     return error;
561 }
562 
Stop(void)563 void Client::Stop(void)
564 {
565     Query *query;
566 
567     while ((query = mQueries.GetHead()) != nullptr)
568     {
569         FinalizeQuery(*query, kErrorAbort);
570     }
571 
572     IgnoreError(mSocket.Close());
573 }
574 
SetDefaultConfig(const QueryConfig & aQueryConfig)575 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
576 {
577     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
578 
579     mDefaultConfig.SetFrom(aQueryConfig, startingDefault);
580 
581 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
582     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
583     UpdateDefaultConfigAddress();
584 #endif
585 }
586 
ResetDefaultConfig(void)587 void Client::ResetDefaultConfig(void)
588 {
589     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
590 
591 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
592     mUserDidSetDefaultAddress = false;
593     UpdateDefaultConfigAddress();
594 #endif
595 }
596 
597 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)598 void Client::UpdateDefaultConfigAddress(void)
599 {
600     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
601 
602     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
603         !srpServerAddr.IsUnspecified())
604     {
605         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
606     }
607 }
608 #endif
609 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)610 Error Client::ResolveAddress(const char *       aHostName,
611                              AddressCallback    aCallback,
612                              void *             aContext,
613                              const QueryConfig *aConfig)
614 {
615     QueryInfo info;
616 
617     info.Clear();
618     info.mQueryType                 = kIp6AddressQuery;
619     info.mCallback.mAddressCallback = aCallback;
620 
621     return StartQuery(info, aConfig, nullptr, aHostName, aContext);
622 }
623 
624 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)625 Error Client::ResolveIp4Address(const char *       aHostName,
626                                 AddressCallback    aCallback,
627                                 void *             aContext,
628                                 const QueryConfig *aConfig)
629 {
630     QueryInfo info;
631 
632     info.Clear();
633     info.mQueryType                 = kIp4AddressQuery;
634     info.mCallback.mAddressCallback = aCallback;
635 
636     return StartQuery(info, aConfig, nullptr, aHostName, aContext);
637 }
638 #endif
639 
640 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
641 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)642 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
643 {
644     QueryInfo info;
645 
646     info.Clear();
647     info.mQueryType                = kBrowseQuery;
648     info.mCallback.mBrowseCallback = aCallback;
649 
650     return StartQuery(info, aConfig, nullptr, aServiceName, aContext);
651 }
652 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)653 Error Client::ResolveService(const char *       aInstanceLabel,
654                              const char *       aServiceName,
655                              ServiceCallback    aCallback,
656                              void *             aContext,
657                              const QueryConfig *aConfig)
658 {
659     QueryInfo info;
660     Error     error;
661 
662     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
663 
664     info.Clear();
665     info.mQueryType                 = kServiceQuery;
666     info.mCallback.mServiceCallback = aCallback;
667 
668     error = StartQuery(info, aConfig, aInstanceLabel, aServiceName, aContext);
669 
670 exit:
671     return error;
672 }
673 
674 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
675 
StartQuery(QueryInfo & aInfo,const QueryConfig * aConfig,const char * aLabel,const char * aName,void * aContext)676 Error Client::StartQuery(QueryInfo &        aInfo,
677                          const QueryConfig *aConfig,
678                          const char *       aLabel,
679                          const char *       aName,
680                          void *             aContext)
681 {
682     // This method assumes that `mQueryType` and `mCallback` to be
683     // already set by caller on `aInfo`. The `aLabel` can be `nullptr`
684     // and then `aName` provides the full name, otherwise the name is
685     // appended as `{aLabel}.{aName}`.
686 
687     Error  error;
688     Query *query;
689 
690     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
691 
692     if (aConfig == nullptr)
693     {
694         aInfo.mConfig = mDefaultConfig;
695     }
696     else
697     {
698         // To form the config for this query, replace any unspecified
699         // fields (zero value) in the given `aConfig` with the fields
700         // from `mDefaultConfig`.
701 
702         aInfo.mConfig.SetFrom(*aConfig, mDefaultConfig);
703     }
704 
705     aInfo.mCallbackContext = aContext;
706 
707 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
708     if (aInfo.mQueryType == kIp4AddressQuery)
709     {
710         NetworkData::ExternalRouteConfig nat64Prefix;
711 
712         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
713         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
714                      error = kErrorInvalidState);
715     }
716 #endif
717 
718     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
719     mQueries.Enqueue(*query);
720 
721     SendQuery(*query, aInfo, /* aUpdateTimer */ true);
722 
723 exit:
724     return error;
725 }
726 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)727 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
728 {
729     Error error = kErrorNone;
730 
731     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
732     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
733 
734     SuccessOrExit(error = aQuery->Append(aInfo));
735 
736     if (aLabel != nullptr)
737     {
738         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
739     }
740 
741     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
742 
743 exit:
744     FreeAndNullMessageOnError(aQuery, error);
745     return error;
746 }
747 
FreeQuery(Query & aQuery)748 void Client::FreeQuery(Query &aQuery)
749 {
750     mQueries.DequeueAndFree(aQuery);
751 }
752 
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)753 void Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
754 {
755     // This method prepares and sends a query message represented by
756     // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
757     // the new `mRetransmissionTime`) and updates it in `aQuery` as
758     // well. `aUpdateTimer` indicates whether the timer should be
759     // updated when query is sent or not (used in the case where timer
760     // is handled by caller).
761 
762     Error            error   = kErrorNone;
763     Message *        message = nullptr;
764     Header           header;
765     Ip6::MessageInfo messageInfo;
766 
767     aInfo.mTransmissionCount++;
768     aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
769 
770     if (aInfo.mMessageId == 0)
771     {
772         do
773         {
774             SuccessOrExit(error = header.SetRandomMessageId());
775         } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
776 
777         aInfo.mMessageId = header.GetMessageId();
778     }
779     else
780     {
781         header.SetMessageId(aInfo.mMessageId);
782     }
783 
784     header.SetType(Header::kTypeQuery);
785     header.SetQueryType(Header::kQueryTypeStandard);
786 
787     if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
788     {
789         header.SetRecursionDesiredFlag();
790     }
791 
792     header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
793 
794     message = mSocket.NewMessage(0);
795     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
796 
797     SuccessOrExit(error = message->Append(header));
798 
799     // Prepare the question section.
800 
801     for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
802     {
803         SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
804         SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
805     }
806 
807     messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
808     messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
809 
810     SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
811 
812 exit:
813     FreeMessageOnError(message, error);
814 
815     UpdateQuery(aQuery, aInfo);
816 
817     if (aUpdateTimer)
818     {
819         mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
820     }
821 }
822 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)823 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
824 {
825     Error    error = kErrorNone;
826     uint16_t offset;
827     uint16_t length;
828 
829     // The name is encoded and included after the `Info` in `aQuery`. We
830     // first calculate the encoded length of the name, then grow the
831     // message, and finally copy the encoded name bytes from `aQuery`
832     // into `aMessage`.
833 
834     length = aQuery.GetLength() - kNameOffsetInQuery;
835 
836     offset = aMessage.GetLength();
837     SuccessOrExit(error = aMessage.SetLength(offset + length));
838 
839     aQuery.CopyTo(/* aSourceOffset */ kNameOffsetInQuery, /* aDestOffset */ offset, length, aMessage);
840 
841 exit:
842     return error;
843 }
844 
FinalizeQuery(Query & aQuery,Error aError)845 void Client::FinalizeQuery(Query &aQuery, Error aError)
846 {
847     Response  response;
848     QueryInfo info;
849 
850     response.mInstance = &Get<Instance>();
851     response.mQuery    = &aQuery;
852     info.ReadFrom(aQuery);
853 
854     FinalizeQuery(response, info.mQueryType, aError);
855 }
856 
FinalizeQuery(Response & aResponse,QueryType aType,Error aError)857 void Client::FinalizeQuery(Response &aResponse, QueryType aType, Error aError)
858 {
859     Callback callback;
860     void *   context;
861 
862     GetCallback(*aResponse.mQuery, callback, context);
863 
864     switch (aType)
865     {
866     case kIp6AddressQuery:
867 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
868     case kIp4AddressQuery:
869 #endif
870         if (callback.mAddressCallback != nullptr)
871         {
872             callback.mAddressCallback(aError, &aResponse, context);
873         }
874         break;
875 
876 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
877     case kBrowseQuery:
878         if (callback.mBrowseCallback != nullptr)
879         {
880             callback.mBrowseCallback(aError, &aResponse, context);
881         }
882         break;
883 
884     case kServiceQuery:
885         if (callback.mServiceCallback != nullptr)
886         {
887             callback.mServiceCallback(aError, &aResponse, context);
888         }
889         break;
890 #endif
891     }
892 
893     FreeQuery(*aResponse.mQuery);
894 }
895 
GetCallback(const Query & aQuery,Callback & aCallback,void * & aContext)896 void Client::GetCallback(const Query &aQuery, Callback &aCallback, void *&aContext)
897 {
898     QueryInfo info;
899 
900     info.ReadFrom(aQuery);
901 
902     aCallback = info.mCallback;
903     aContext  = info.mCallbackContext;
904 }
905 
FindQueryById(uint16_t aMessageId)906 Client::Query *Client::FindQueryById(uint16_t aMessageId)
907 {
908     Query *   matchedQuery = nullptr;
909     QueryInfo info;
910 
911     for (Query &query : mQueries)
912     {
913         info.ReadFrom(query);
914 
915         if (info.mMessageId == aMessageId)
916         {
917             matchedQuery = &query;
918             break;
919         }
920     }
921 
922     return matchedQuery;
923 }
924 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)925 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
926 {
927     OT_UNUSED_VARIABLE(aMsgInfo);
928 
929     static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
930 }
931 
ProcessResponse(const Message & aMessage)932 void Client::ProcessResponse(const Message &aMessage)
933 {
934     Response  response;
935     QueryType type;
936     Error     responseError;
937 
938     response.mInstance = &Get<Instance>();
939     response.mMessage  = &aMessage;
940 
941     // We intentionally parse the response in a separate method
942     // `ParseResponse()` to free all the stack allocated variables
943     // (e.g., `QueryInfo`) used during parsing of the message before
944     // finalizing the query and invoking the user's callback.
945 
946     SuccessOrExit(ParseResponse(response, type, responseError));
947     FinalizeQuery(response, type, responseError);
948 
949 exit:
950     return;
951 }
952 
ParseResponse(Response & aResponse,QueryType & aType,Error & aResponseError)953 Error Client::ParseResponse(Response &aResponse, QueryType &aType, Error &aResponseError)
954 {
955     Error          error   = kErrorNone;
956     const Message &message = *aResponse.mMessage;
957     uint16_t       offset  = message.GetOffset();
958     Header         header;
959     QueryInfo      info;
960     Name           queryName;
961 
962     SuccessOrExit(error = message.Read(offset, header));
963     offset += sizeof(Header);
964 
965     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
966                      !header.IsTruncationFlagSet(),
967                  error = kErrorDrop);
968 
969     aResponse.mQuery = FindQueryById(header.GetMessageId());
970     VerifyOrExit(aResponse.mQuery != nullptr, error = kErrorNotFound);
971 
972     info.ReadFrom(*aResponse.mQuery);
973     aType = info.mQueryType;
974 
975     queryName.SetFromMessage(*aResponse.mQuery, kNameOffsetInQuery);
976 
977     // Check the Question Section
978 
979     if (header.GetQuestionCount() == kQuestionCount[aType])
980     {
981         for (uint8_t num = 0; num < kQuestionCount[aType]; num++)
982         {
983             SuccessOrExit(error = Name::CompareName(message, offset, queryName));
984             offset += sizeof(Question);
985         }
986     }
987     else
988     {
989         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
990                      error = kErrorParse);
991     }
992 
993     // Check the answer, authority and additional record sections
994 
995     aResponse.mAnswerOffset = offset;
996     SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAnswerCount()));
997     SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAuthorityRecordCount()));
998     aResponse.mAdditionalOffset = offset;
999     SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAdditionalRecordCount()));
1000 
1001     aResponse.mAnswerRecordCount     = header.GetAnswerCount();
1002     aResponse.mAdditionalRecordCount = header.GetAdditionalRecordCount();
1003 
1004     // Check the response code from server
1005 
1006     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1007 
1008 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1009 
1010     if (aType == kIp6AddressQuery)
1011     {
1012         Ip6::Address ip6ddress;
1013         uint32_t     ttl;
1014         ARecord      aRecord;
1015 
1016         // If the response does not contain an answer for the IPv6 address
1017         // resolution query and if NAT64 is allowed for this query, we can
1018         // perform IPv4 to IPv6 address translation.
1019 
1020         VerifyOrExit(aResponse.FindHostAddress(Response::kAnswerSection, queryName, /* aIndex */ 0, ip6ddress, ttl) !=
1021                      kErrorNone);
1022         VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1023 
1024         // First, we check if the response already contains an A record
1025         // (IPv4 address) for the query name.
1026 
1027         if (aResponse.FindARecord(Response::kAdditionalDataSection, queryName, /* aIndex */ 0, aRecord) == kErrorNone)
1028         {
1029             aResponse.mIp6QueryResponseRequiresNat64 = true;
1030             aResponseError                           = kErrorNone;
1031             ExitNow();
1032         }
1033 
1034         // Otherwise, we send a new query for IPv4 address resolution
1035         // for the same host name. We reuse the existing `query`
1036         // instance and keep all the info but clear `mTransmissionCount`
1037         // and `mMessageId` (so that a new random message ID is
1038         // selected). The new `info` will be saved in the query in
1039         // `SendQuery()`. Note that the current query is still in the
1040         // `mQueries` list when `SendQuery()` selects a new random
1041         // message ID, so the existing message ID for this query will
1042         // not be reused. Since the query is not yet resolved, we
1043         // return `kErrorPending`.
1044 
1045         info.mQueryType         = kIp4AddressQuery;
1046         info.mMessageId         = 0;
1047         info.mTransmissionCount = 0;
1048 
1049         SendQuery(*aResponse.mQuery, info, /* aUpdateTimer */ true);
1050 
1051         error = kErrorPending;
1052     }
1053 
1054 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1055 
1056 exit:
1057     if (error != kErrorNone)
1058     {
1059         LogInfo("Failed to parse response %s", ErrorToString(error));
1060     }
1061 
1062     return error;
1063 }
1064 
HandleTimer(Timer & aTimer)1065 void Client::HandleTimer(Timer &aTimer)
1066 {
1067     aTimer.Get<Client>().HandleTimer();
1068 }
1069 
HandleTimer(void)1070 void Client::HandleTimer(void)
1071 {
1072     TimeMilli now      = TimerMilli::GetNow();
1073     TimeMilli nextTime = now.GetDistantFuture();
1074     QueryInfo info;
1075 
1076     for (Query &query : mQueries)
1077     {
1078         info.ReadFrom(query);
1079 
1080         if (now >= info.mRetransmissionTime)
1081         {
1082             if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1083             {
1084                 FinalizeQuery(query, kErrorResponseTimeout);
1085                 continue;
1086             }
1087 
1088             SendQuery(query, info, /* aUpdateTimer */ false);
1089         }
1090 
1091         if (nextTime > info.mRetransmissionTime)
1092         {
1093             nextTime = info.mRetransmissionTime;
1094         }
1095     }
1096 
1097     if (nextTime < now.GetDistantFuture())
1098     {
1099         mTimer.FireAt(nextTime);
1100     }
1101 }
1102 
1103 } // namespace Dns
1104 } // namespace ot
1105 
1106 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1107