1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43
44 /**
45 * @file
46 * This file implements the DNS client.
47 */
48
49 namespace ot {
50 namespace Dns {
51
// Register "DnsClient" as the log module name for this file.
RegisterLogModule("DnsClient");
53
54 //---------------------------------------------------------------------------------------------------------------------
55 // Client::QueryConfig
56
// Default DNS server IPv6 address in textual form, taken from the build-time
// configuration `OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS`.
// It is parsed by the `QueryConfig(InitMode)` constructor.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61 OT_UNUSED_VARIABLE(aMode);
62
63 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64 GetServerSockAddr().SetPort(kDefaultServerPort);
65 SetResponseTimeout(kDefaultResponseTimeout);
66 SetMaxTxAttempts(kDefaultMaxTxAttempts);
67 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
69 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
70 #endif
71 }
72
SetFrom(const QueryConfig & aConfig,const QueryConfig & aDefaultConfig)73 void Client::QueryConfig::SetFrom(const QueryConfig &aConfig, const QueryConfig &aDefaultConfig)
74 {
75 // This method sets the config from `aConfig` replacing any
76 // unspecified fields (value zero) with the fields from
77 // `aDefaultConfig`.
78
79 *this = aConfig;
80
81 if (GetServerSockAddr().GetAddress().IsUnspecified())
82 {
83 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
84 }
85
86 if (GetServerSockAddr().GetPort() == 0)
87 {
88 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
89 }
90
91 if (GetResponseTimeout() == 0)
92 {
93 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
94 }
95
96 if (GetMaxTxAttempts() == 0)
97 {
98 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
99 }
100
101 if (GetRecursionFlag() == kFlagUnspecified)
102 {
103 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
104 }
105
106 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
107 if (GetNat64Mode() == kNat64Unspecified)
108 {
109 SetNat64Mode(aDefaultConfig.GetNat64Mode());
110 }
111 #endif
112 }
113
114 //---------------------------------------------------------------------------------------------------------------------
115 // Client::Response
116
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const117 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
118 {
119 switch (aSection)
120 {
121 case kAnswerSection:
122 aOffset = mAnswerOffset;
123 aNumRecord = mAnswerRecordCount;
124 break;
125 case kAdditionalDataSection:
126 default:
127 aOffset = mAdditionalOffset;
128 aNumRecord = mAdditionalRecordCount;
129 break;
130 }
131 }
132
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const133 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
134 {
135 uint16_t offset = kNameOffsetInQuery;
136
137 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
138 }
139
CheckForHostNameAlias(Section aSection,Name & aHostName) const140 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
141 {
142 // If the response includes a CNAME record mapping the query host
143 // name to a canonical name, we update `aHostName` to the new alias
144 // name. Otherwise `aHostName` remains as before.
145
146 Error error;
147 uint16_t offset;
148 uint16_t numRecords;
149 CnameRecord cnameRecord;
150
151 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
152
153 SelectSection(aSection, offset, numRecords);
154 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
155
156 switch (error)
157 {
158 case kErrorNone:
159 // A CNAME record was found. `offset` now points to after the
160 // last read byte within the `mMessage` into the `cnameRecord`
161 // (which is the start of the new canonical name).
162 aHostName.SetFromMessage(*mMessage, offset);
163 error = Name::ParseName(*mMessage, offset);
164 break;
165
166 case kErrorNotFound:
167 error = kErrorNone;
168 break;
169
170 default:
171 break;
172 }
173
174 exit:
175 return error;
176 }
177
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const178 Error Client::Response::FindHostAddress(Section aSection,
179 const Name & aHostName,
180 uint16_t aIndex,
181 Ip6::Address &aAddress,
182 uint32_t & aTtl) const
183 {
184 Error error;
185 uint16_t offset;
186 uint16_t numRecords;
187 Name name = aHostName;
188 AaaaRecord aaaaRecord;
189
190 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
191
192 SelectSection(aSection, offset, numRecords);
193 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
194 aAddress = aaaaRecord.GetAddress();
195 aTtl = aaaaRecord.GetTtl();
196
197 exit:
198 return error;
199 }
200
201 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
202
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const203 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
204 {
205 Error error;
206 uint16_t offset;
207 uint16_t numRecords;
208 Name name = aHostName;
209
210 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
211
212 SelectSection(aSection, offset, numRecords);
213 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
214
215 exit:
216 return error;
217 }
218
219 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
220
221 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
222
FindServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const223 Error Client::Response::FindServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
224 {
225 // This method searches for SRV and TXT records in the given
226 // section matching the record name against `aName`, and updates
227 // the `aServiceInfo` accordingly. It also searches for AAAA
228 // record for host name associated with the service (from SRV
229 // record). The search for AAAA record is always performed in
230 // Additional Data section (independent of the value given in
231 // `aSection`).
232
233 Error error;
234 uint16_t offset;
235 uint16_t numRecords;
236 Name hostName;
237 SrvRecord srvRecord;
238 TxtRecord txtRecord;
239
240 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
241
242 // Search for a matching SRV record
243 SelectSection(aSection, offset, numRecords);
244 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
245
246 aServiceInfo.mTtl = srvRecord.GetTtl();
247 aServiceInfo.mPort = srvRecord.GetPort();
248 aServiceInfo.mPriority = srvRecord.GetPriority();
249 aServiceInfo.mWeight = srvRecord.GetWeight();
250
251 hostName.SetFromMessage(*mMessage, offset);
252
253 if (aServiceInfo.mHostNameBuffer != nullptr)
254 {
255 SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
256 aServiceInfo.mHostNameBufferSize));
257 }
258 else
259 {
260 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
261 }
262
263 // Search in additional section for AAAA record for the host name.
264
265 error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
266 aServiceInfo.mHostAddressTtl);
267
268 if (error == kErrorNotFound)
269 {
270 AsCoreType(&aServiceInfo.mHostAddress).Clear();
271 aServiceInfo.mHostAddressTtl = 0;
272 }
273 else
274 {
275 SuccessOrExit(error);
276 }
277
278 // A null `mTxtData` indicates that caller does not want to retrieve TXT data.
279 VerifyOrExit(aServiceInfo.mTxtData != nullptr);
280
281 // Search for a matching TXT record. If not found, indicate this by
282 // setting `aServiceInfo.mTxtDataSize` to zero.
283
284 SelectSection(aSection, offset, numRecords);
285 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord);
286
287 switch (error)
288 {
289 case kErrorNone:
290 SuccessOrExit(error =
291 txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize));
292 aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
293 break;
294
295 case kErrorNotFound:
296 aServiceInfo.mTxtDataSize = 0;
297 aServiceInfo.mTxtDataTtl = 0;
298 break;
299
300 default:
301 ExitNow();
302 }
303
304 exit:
305 return error;
306 }
307
308 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
309
310 //---------------------------------------------------------------------------------------------------------------------
311 // Client::AddressResponse
312
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const313 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
314 {
315 Error error = kErrorNone;
316 Name name(*mQuery, kNameOffsetInQuery);
317
318 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
319
320 // If the response is for an IPv4 address query or if it is an
321 // IPv6 address query response with no IPv6 address but with
322 // an IPv4 in its additional section, we read the IPv4 address
323 // and translate it to an IPv6 address.
324
325 QueryInfo info;
326
327 info.ReadFrom(*mQuery);
328
329 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
330 {
331 Section section;
332 ARecord aRecord;
333 NetworkData::ExternalRouteConfig nat64Prefix;
334
335 VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
336 error = kErrorInvalidState);
337
338 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
339 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
340
341 aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
342 aTtl = aRecord.GetTtl();
343
344 ExitNow();
345 }
346
347 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
348
349 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
350
351 exit:
352 return error;
353 }
354
355 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
356
357 //---------------------------------------------------------------------------------------------------------------------
358 // Client::BrowseResponse
359
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const360 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
361 {
362 Error error;
363 uint16_t offset;
364 uint16_t numRecords;
365 Name serviceName(*mQuery, kNameOffsetInQuery);
366 PtrRecord ptrRecord;
367
368 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
369
370 SelectSection(kAnswerSection, offset, numRecords);
371 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
372 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
373
374 exit:
375 return error;
376 }
377
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const378 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
379 {
380 Error error;
381 Name instanceName;
382
383 // Find a matching PTR record for the service instance label.
384 // Then search and read SRV, TXT and AAAA records in Additional Data section
385 // matching the same name to populate `aServiceInfo`.
386
387 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
388 error = FindServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo);
389
390 exit:
391 return error;
392 }
393
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const394 Error Client::BrowseResponse::GetHostAddress(const char * aHostName,
395 uint16_t aIndex,
396 Ip6::Address &aAddress,
397 uint32_t & aTtl) const
398 {
399 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
400 }
401
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // whose owner name matches the queried service name and whose
    // target name starts with the label `aInstanceLabel`. If found,
    // `aInstanceName` is updated to refer to the full target (instance)
    // name inside the response message.
    //
    // Returns `kErrorNone` on a match, and `kErrorNotFound` when there is
    // no response message or no matching PTR record. Any other error
    // from reading or comparing names/records is returned as-is.
    //
    // NOTE(review): any non-success result from `CompareName()`
    // (presumably including a simple owner-name mismatch) ends the
    // search instead of skipping that record. Confirm that this is
    // intended.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        // Owner name must be the service name; `offset` advances past it.
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label. A copy of `offset` is used so that `offset`
        // still points at the start of the target name.

        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record (its record data,
        // i.e., total record size minus the already-read record header).
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    error = kErrorNotFound;

exit:
    return error;
}
457
458 //---------------------------------------------------------------------------------------------------------------------
459 // Client::ServiceResponse
460
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const461 Error Client::ServiceResponse::GetServiceName(char * aLabelBuffer,
462 uint8_t aLabelBufferSize,
463 char * aNameBuffer,
464 uint16_t aNameBufferSize) const
465 {
466 Error error;
467 uint16_t offset = kNameOffsetInQuery;
468
469 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
470
471 VerifyOrExit(aNameBuffer != nullptr);
472 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
473
474 exit:
475 return error;
476 }
477
GetServiceInfo(ServiceInfo & aServiceInfo) const478 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
479 {
480 // Search and read SRV, TXT records in Answer Section
481 // matching name from query.
482
483 return FindServiceInfo(kAnswerSection, Name(*mQuery, kNameOffsetInQuery), aServiceInfo);
484 }
485
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const486 Error Client::ServiceResponse::GetHostAddress(const char * aHostName,
487 uint16_t aIndex,
488 Ip6::Address &aAddress,
489 uint32_t & aTtl) const
490 {
491 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
492 }
493
494 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
495
496 //---------------------------------------------------------------------------------------------------------------------
497 // Client
498
// Record types asked for by each query type. `SendQuery()` emits one
// question per entry, each carrying the same query name.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions per query type, indexed by `QueryType`. Entry
// order must follow the `QueryType` enumerator values, which the
// `Client` constructor verifies with `static_assert`s for each
// feature combination.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA records
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A records
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR records
    /* kServiceQuery -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
#endif
};

// Per-`QueryType` pointer to the matching record type list above.
// Indexed and ordered exactly like `kQuestionCount[]`.
const uint16_t *Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ kBrowseQueryRecordTypes,
    /* kServiceQuery -> */ kServiceQueryRecordTypes,
#endif
};
529
Client(Instance & aInstance)530 Client::Client(Instance &aInstance)
531 : InstanceLocator(aInstance)
532 , mSocket(aInstance)
533 , mTimer(aInstance, Client::HandleTimer)
534 , mDefaultConfig(QueryConfig::kInitFromDefaults)
535 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
536 , mUserDidSetDefaultAddress(false)
537 #endif
538 {
539 static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
540 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
541 static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
542 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
543 static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
544 static_assert(kServiceQuery == 3, "kServiceQuery value is not correct");
545 #endif
546 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
547 static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
548 static_assert(kServiceQuery == 2, "kServiceQuery value is not correct");
549 #endif
550 }
551
Start(void)552 Error Client::Start(void)
553 {
554 Error error;
555
556 SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
557 SuccessOrExit(error = mSocket.Bind(0, OT_NETIF_UNSPECIFIED));
558
559 exit:
560 return error;
561 }
562
Stop(void)563 void Client::Stop(void)
564 {
565 Query *query;
566
567 while ((query = mQueries.GetHead()) != nullptr)
568 {
569 FinalizeQuery(*query, kErrorAbort);
570 }
571
572 IgnoreError(mSocket.Close());
573 }
574
SetDefaultConfig(const QueryConfig & aQueryConfig)575 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
576 {
577 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
578
579 mDefaultConfig.SetFrom(aQueryConfig, startingDefault);
580
581 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
582 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
583 UpdateDefaultConfigAddress();
584 #endif
585 }
586
ResetDefaultConfig(void)587 void Client::ResetDefaultConfig(void)
588 {
589 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
590
591 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
592 mUserDidSetDefaultAddress = false;
593 UpdateDefaultConfigAddress();
594 #endif
595 }
596
597 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)598 void Client::UpdateDefaultConfigAddress(void)
599 {
600 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
601
602 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
603 !srpServerAddr.IsUnspecified())
604 {
605 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
606 }
607 }
608 #endif
609
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)610 Error Client::ResolveAddress(const char * aHostName,
611 AddressCallback aCallback,
612 void * aContext,
613 const QueryConfig *aConfig)
614 {
615 QueryInfo info;
616
617 info.Clear();
618 info.mQueryType = kIp6AddressQuery;
619 info.mCallback.mAddressCallback = aCallback;
620
621 return StartQuery(info, aConfig, nullptr, aHostName, aContext);
622 }
623
624 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)625 Error Client::ResolveIp4Address(const char * aHostName,
626 AddressCallback aCallback,
627 void * aContext,
628 const QueryConfig *aConfig)
629 {
630 QueryInfo info;
631
632 info.Clear();
633 info.mQueryType = kIp4AddressQuery;
634 info.mCallback.mAddressCallback = aCallback;
635
636 return StartQuery(info, aConfig, nullptr, aHostName, aContext);
637 }
638 #endif
639
640 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
641
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)642 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
643 {
644 QueryInfo info;
645
646 info.Clear();
647 info.mQueryType = kBrowseQuery;
648 info.mCallback.mBrowseCallback = aCallback;
649
650 return StartQuery(info, aConfig, nullptr, aServiceName, aContext);
651 }
652
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)653 Error Client::ResolveService(const char * aInstanceLabel,
654 const char * aServiceName,
655 ServiceCallback aCallback,
656 void * aContext,
657 const QueryConfig *aConfig)
658 {
659 QueryInfo info;
660 Error error;
661
662 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
663
664 info.Clear();
665 info.mQueryType = kServiceQuery;
666 info.mCallback.mServiceCallback = aCallback;
667
668 error = StartQuery(info, aConfig, aInstanceLabel, aServiceName, aContext);
669
670 exit:
671 return error;
672 }
673
674 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
675
StartQuery(QueryInfo & aInfo,const QueryConfig * aConfig,const char * aLabel,const char * aName,void * aContext)676 Error Client::StartQuery(QueryInfo & aInfo,
677 const QueryConfig *aConfig,
678 const char * aLabel,
679 const char * aName,
680 void * aContext)
681 {
682 // This method assumes that `mQueryType` and `mCallback` to be
683 // already set by caller on `aInfo`. The `aLabel` can be `nullptr`
684 // and then `aName` provides the full name, otherwise the name is
685 // appended as `{aLabel}.{aName}`.
686
687 Error error;
688 Query *query;
689
690 VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
691
692 if (aConfig == nullptr)
693 {
694 aInfo.mConfig = mDefaultConfig;
695 }
696 else
697 {
698 // To form the config for this query, replace any unspecified
699 // fields (zero value) in the given `aConfig` with the fields
700 // from `mDefaultConfig`.
701
702 aInfo.mConfig.SetFrom(*aConfig, mDefaultConfig);
703 }
704
705 aInfo.mCallbackContext = aContext;
706
707 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
708 if (aInfo.mQueryType == kIp4AddressQuery)
709 {
710 NetworkData::ExternalRouteConfig nat64Prefix;
711
712 VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
713 VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
714 error = kErrorInvalidState);
715 }
716 #endif
717
718 SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
719 mQueries.Enqueue(*query);
720
721 SendQuery(*query, aInfo, /* aUpdateTimer */ true);
722
723 exit:
724 return error;
725 }
726
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)727 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
728 {
729 Error error = kErrorNone;
730
731 aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
732 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
733
734 SuccessOrExit(error = aQuery->Append(aInfo));
735
736 if (aLabel != nullptr)
737 {
738 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
739 }
740
741 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
742
743 exit:
744 FreeAndNullMessageOnError(aQuery, error);
745 return error;
746 }
747
FreeQuery(Query & aQuery)748 void Client::FreeQuery(Query &aQuery)
749 {
750 mQueries.DequeueAndFree(aQuery);
751 }
752
void Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).
    //
    // Failures while building or sending the message are not reported
    // to the caller. The updated `aInfo` (incremented transmission
    // count, new retransmission time) is still saved into `aQuery`, and
    // the timer is still armed when `aUpdateTimer` is true. The query is
    // therefore presumably retried or timed out by the timer handler.
    // TODO: confirm against `HandleTimer()`.

    Error            error   = kErrorNone;
    Message         *message = nullptr;
    Header           header;
    Ip6::MessageInfo messageInfo;

    // Count this attempt and compute when it will be considered lost.
    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // First transmission: pick a random, non-zero message ID that is
        // not used by any other pending query. Zero means "not yet
        // assigned". Retransmissions reuse the stored ID (else branch).
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage(0);
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section.
    // One question per record type of this query type, each repeating
    // the encoded query name stored in `aQuery`.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
    messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());

    SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));

exit:
    // Free the message only on error. On success, it has been handed to
    // `SendTo()` (presumably taking ownership).
    FreeMessageOnError(message, error);

    // Always persist the updated info, even on failure.
    UpdateQuery(aQuery, aInfo);

    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }
}
822
AppendNameFromQuery(const Query & aQuery,Message & aMessage)823 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
824 {
825 Error error = kErrorNone;
826 uint16_t offset;
827 uint16_t length;
828
829 // The name is encoded and included after the `Info` in `aQuery`. We
830 // first calculate the encoded length of the name, then grow the
831 // message, and finally copy the encoded name bytes from `aQuery`
832 // into `aMessage`.
833
834 length = aQuery.GetLength() - kNameOffsetInQuery;
835
836 offset = aMessage.GetLength();
837 SuccessOrExit(error = aMessage.SetLength(offset + length));
838
839 aQuery.CopyTo(/* aSourceOffset */ kNameOffsetInQuery, /* aDestOffset */ offset, length, aMessage);
840
841 exit:
842 return error;
843 }
844
FinalizeQuery(Query & aQuery,Error aError)845 void Client::FinalizeQuery(Query &aQuery, Error aError)
846 {
847 Response response;
848 QueryInfo info;
849
850 response.mInstance = &Get<Instance>();
851 response.mQuery = &aQuery;
852 info.ReadFrom(aQuery);
853
854 FinalizeQuery(response, info.mQueryType, aError);
855 }
856
FinalizeQuery(Response & aResponse,QueryType aType,Error aError)857 void Client::FinalizeQuery(Response &aResponse, QueryType aType, Error aError)
858 {
859 Callback callback;
860 void * context;
861
862 GetCallback(*aResponse.mQuery, callback, context);
863
864 switch (aType)
865 {
866 case kIp6AddressQuery:
867 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
868 case kIp4AddressQuery:
869 #endif
870 if (callback.mAddressCallback != nullptr)
871 {
872 callback.mAddressCallback(aError, &aResponse, context);
873 }
874 break;
875
876 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
877 case kBrowseQuery:
878 if (callback.mBrowseCallback != nullptr)
879 {
880 callback.mBrowseCallback(aError, &aResponse, context);
881 }
882 break;
883
884 case kServiceQuery:
885 if (callback.mServiceCallback != nullptr)
886 {
887 callback.mServiceCallback(aError, &aResponse, context);
888 }
889 break;
890 #endif
891 }
892
893 FreeQuery(*aResponse.mQuery);
894 }
895
GetCallback(const Query & aQuery,Callback & aCallback,void * & aContext)896 void Client::GetCallback(const Query &aQuery, Callback &aCallback, void *&aContext)
897 {
898 QueryInfo info;
899
900 info.ReadFrom(aQuery);
901
902 aCallback = info.mCallback;
903 aContext = info.mCallbackContext;
904 }
905
FindQueryById(uint16_t aMessageId)906 Client::Query *Client::FindQueryById(uint16_t aMessageId)
907 {
908 Query * matchedQuery = nullptr;
909 QueryInfo info;
910
911 for (Query &query : mQueries)
912 {
913 info.ReadFrom(query);
914
915 if (info.mMessageId == aMessageId)
916 {
917 matchedQuery = &query;
918 break;
919 }
920 }
921
922 return matchedQuery;
923 }
924
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)925 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
926 {
927 OT_UNUSED_VARIABLE(aMsgInfo);
928
929 static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
930 }
931
ProcessResponse(const Message & aMessage)932 void Client::ProcessResponse(const Message &aMessage)
933 {
934 Response response;
935 QueryType type;
936 Error responseError;
937
938 response.mInstance = &Get<Instance>();
939 response.mMessage = &aMessage;
940
941 // We intentionally parse the response in a separate method
942 // `ParseResponse()` to free all the stack allocated variables
943 // (e.g., `QueryInfo`) used during parsing of the message before
944 // finalizing the query and invoking the user's callback.
945
946 SuccessOrExit(ParseResponse(response, type, responseError));
947 FinalizeQuery(response, type, responseError);
948
949 exit:
950 return;
951 }
952
ParseResponse(Response & aResponse,QueryType & aType,Error & aResponseError)953 Error Client::ParseResponse(Response &aResponse, QueryType &aType, Error &aResponseError)
954 {
955 Error error = kErrorNone;
956 const Message &message = *aResponse.mMessage;
957 uint16_t offset = message.GetOffset();
958 Header header;
959 QueryInfo info;
960 Name queryName;
961
962 SuccessOrExit(error = message.Read(offset, header));
963 offset += sizeof(Header);
964
965 VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
966 !header.IsTruncationFlagSet(),
967 error = kErrorDrop);
968
969 aResponse.mQuery = FindQueryById(header.GetMessageId());
970 VerifyOrExit(aResponse.mQuery != nullptr, error = kErrorNotFound);
971
972 info.ReadFrom(*aResponse.mQuery);
973 aType = info.mQueryType;
974
975 queryName.SetFromMessage(*aResponse.mQuery, kNameOffsetInQuery);
976
977 // Check the Question Section
978
979 if (header.GetQuestionCount() == kQuestionCount[aType])
980 {
981 for (uint8_t num = 0; num < kQuestionCount[aType]; num++)
982 {
983 SuccessOrExit(error = Name::CompareName(message, offset, queryName));
984 offset += sizeof(Question);
985 }
986 }
987 else
988 {
989 VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
990 error = kErrorParse);
991 }
992
993 // Check the answer, authority and additional record sections
994
995 aResponse.mAnswerOffset = offset;
996 SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAnswerCount()));
997 SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAuthorityRecordCount()));
998 aResponse.mAdditionalOffset = offset;
999 SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAdditionalRecordCount()));
1000
1001 aResponse.mAnswerRecordCount = header.GetAnswerCount();
1002 aResponse.mAdditionalRecordCount = header.GetAdditionalRecordCount();
1003
1004 // Check the response code from server
1005
1006 aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1007
1008 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1009
1010 if (aType == kIp6AddressQuery)
1011 {
1012 Ip6::Address ip6ddress;
1013 uint32_t ttl;
1014 ARecord aRecord;
1015
1016 // If the response does not contain an answer for the IPv6 address
1017 // resolution query and if NAT64 is allowed for this query, we can
1018 // perform IPv4 to IPv6 address translation.
1019
1020 VerifyOrExit(aResponse.FindHostAddress(Response::kAnswerSection, queryName, /* aIndex */ 0, ip6ddress, ttl) !=
1021 kErrorNone);
1022 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1023
1024 // First, we check if the response already contains an A record
1025 // (IPv4 address) for the query name.
1026
1027 if (aResponse.FindARecord(Response::kAdditionalDataSection, queryName, /* aIndex */ 0, aRecord) == kErrorNone)
1028 {
1029 aResponse.mIp6QueryResponseRequiresNat64 = true;
1030 aResponseError = kErrorNone;
1031 ExitNow();
1032 }
1033
1034 // Otherwise, we send a new query for IPv4 address resolution
1035 // for the same host name. We reuse the existing `query`
1036 // instance and keep all the info but clear `mTransmissionCount`
1037 // and `mMessageId` (so that a new random message ID is
1038 // selected). The new `info` will be saved in the query in
1039 // `SendQuery()`. Note that the current query is still in the
1040 // `mQueries` list when `SendQuery()` selects a new random
1041 // message ID, so the existing message ID for this query will
1042 // not be reused. Since the query is not yet resolved, we
1043 // return `kErrorPending`.
1044
1045 info.mQueryType = kIp4AddressQuery;
1046 info.mMessageId = 0;
1047 info.mTransmissionCount = 0;
1048
1049 SendQuery(*aResponse.mQuery, info, /* aUpdateTimer */ true);
1050
1051 error = kErrorPending;
1052 }
1053
1054 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1055
1056 exit:
1057 if (error != kErrorNone)
1058 {
1059 LogInfo("Failed to parse response %s", ErrorToString(error));
1060 }
1061
1062 return error;
1063 }
1064
HandleTimer(Timer & aTimer)1065 void Client::HandleTimer(Timer &aTimer)
1066 {
1067 aTimer.Get<Client>().HandleTimer();
1068 }
1069
HandleTimer(void)1070 void Client::HandleTimer(void)
1071 {
1072 TimeMilli now = TimerMilli::GetNow();
1073 TimeMilli nextTime = now.GetDistantFuture();
1074 QueryInfo info;
1075
1076 for (Query &query : mQueries)
1077 {
1078 info.ReadFrom(query);
1079
1080 if (now >= info.mRetransmissionTime)
1081 {
1082 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1083 {
1084 FinalizeQuery(query, kErrorResponseTimeout);
1085 continue;
1086 }
1087
1088 SendQuery(query, info, /* aUpdateTimer */ false);
1089 }
1090
1091 if (nextTime > info.mRetransmissionTime)
1092 {
1093 nextTime = info.mRetransmissionTime;
1094 }
1095 }
1096
1097 if (nextTime < now.GetDistantFuture())
1098 {
1099 mTimer.FireAt(nextTime);
1100 }
1101 }
1102
1103 } // namespace Dns
1104 } // namespace ot
1105
1106 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1107