• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the DNS-SD server.
32  */
33 
34 #include "dnssd_server.hpp"
35 
36 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
37 
38 #include "common/array.hpp"
39 #include "common/as_core_type.hpp"
40 #include "common/code_utils.hpp"
41 #include "common/debug.hpp"
42 #include "common/instance.hpp"
43 #include "common/locator_getters.hpp"
44 #include "common/log.hpp"
45 #include "common/string.hpp"
46 #include "net/srp_server.hpp"
47 #include "net/udp6.hpp"
48 
49 namespace ot {
50 namespace Dns {
51 namespace ServiceDiscovery {
52 
53 RegisterLogModule("DnssdServer");
54 
55 const char Server::kDnssdProtocolUdp[]  = "_udp";
56 const char Server::kDnssdProtocolTcp[]  = "_tcp";
57 const char Server::kDnssdSubTypeLabel[] = "._sub.";
58 const char Server::kDefaultDomainName[] = "default.service.arpa.";
59 
Server(Instance & aInstance)60 Server::Server(Instance &aInstance)
61     : InstanceLocator(aInstance)
62     , mSocket(aInstance)
63     , mQueryCallbackContext(nullptr)
64     , mQuerySubscribe(nullptr)
65     , mQueryUnsubscribe(nullptr)
66     , mTimer(aInstance, Server::HandleTimer)
67 {
68     mCounters.Clear();
69 }
70 
Start(void)71 Error Server::Start(void)
72 {
73     Error error = kErrorNone;
74 
75     VerifyOrExit(!IsRunning());
76 
77     SuccessOrExit(error = mSocket.Open(&Server::HandleUdpReceive, this));
78     SuccessOrExit(error = mSocket.Bind(kPort, kBindUnspecifiedNetif ? OT_NETIF_UNSPECIFIED : OT_NETIF_THREAD));
79 
80 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
81     Get<Srp::Server>().HandleDnssdServerStateChange();
82 #endif
83 
84 exit:
85     LogInfo("started: %s", ErrorToString(error));
86 
87     if (error != kErrorNone)
88     {
89         IgnoreError(mSocket.Close());
90     }
91 
92     return error;
93 }
94 
Stop(void)95 void Server::Stop(void)
96 {
97     // Abort all query transactions
98     for (QueryTransaction &query : mQueryTransactions)
99     {
100         if (query.IsValid())
101         {
102             FinalizeQuery(query, Header::kResponseServerFailure);
103         }
104     }
105 
106     mTimer.Stop();
107 
108     IgnoreError(mSocket.Close());
109     LogInfo("stopped");
110 
111 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
112     Get<Srp::Server>().HandleDnssdServerStateChange();
113 #endif
114 }
115 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)116 void Server::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
117 {
118     static_cast<Server *>(aContext)->HandleUdpReceive(AsCoreType(aMessage), AsCoreType(aMessageInfo));
119 }
120 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)121 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
122 {
123     Header requestHeader;
124 
125 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
126     // We first let the `Srp::Server` process the received message.
127     // It returns `kErrorNone` to indicate that it successfully
128     // processed the message.
129 
130     VerifyOrExit(Get<Srp::Server>().HandleDnssdServerUdpReceive(aMessage, aMessageInfo) != kErrorNone);
131 #endif
132 
133     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), requestHeader));
134     VerifyOrExit(requestHeader.GetType() == Header::kTypeQuery);
135 
136     ProcessQuery(requestHeader, aMessage, aMessageInfo);
137 
138 exit:
139     return;
140 }
141 
ProcessQuery(const Header & aRequestHeader,Message & aRequestMessage,const Ip6::MessageInfo & aMessageInfo)142 void Server::ProcessQuery(const Header &aRequestHeader, Message &aRequestMessage, const Ip6::MessageInfo &aMessageInfo)
143 {
144     Error            error           = kErrorNone;
145     Message *        responseMessage = nullptr;
146     Header           responseHeader;
147     NameCompressInfo compressInfo(kDefaultDomainName);
148     Header::Response response                = Header::kResponseSuccess;
149     bool             resolveByQueryCallbacks = false;
150 
151     responseMessage = mSocket.NewMessage(0);
152     VerifyOrExit(responseMessage != nullptr, error = kErrorNoBufs);
153 
154     // Allocate space for DNS header
155     SuccessOrExit(error = responseMessage->SetLength(sizeof(Header)));
156 
157     // Setup initial DNS response header
158     responseHeader.Clear();
159     responseHeader.SetType(Header::kTypeResponse);
160     responseHeader.SetMessageId(aRequestHeader.GetMessageId());
161 
162     // Validate the query
163     VerifyOrExit(aRequestHeader.GetQueryType() == Header::kQueryTypeStandard,
164                  response = Header::kResponseNotImplemented);
165     VerifyOrExit(!aRequestHeader.IsTruncationFlagSet(), response = Header::kResponseFormatError);
166     VerifyOrExit(aRequestHeader.GetQuestionCount() > 0, response = Header::kResponseFormatError);
167 
168     response = AddQuestions(aRequestHeader, aRequestMessage, responseHeader, *responseMessage, compressInfo);
169     VerifyOrExit(response == Header::kResponseSuccess);
170 
171 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
172     // Answer the questions
173     response = ResolveBySrp(responseHeader, *responseMessage, compressInfo);
174 #endif
175 
176     // Resolve the question using query callbacks if SRP server failed to resolve the questions.
177     if (responseHeader.GetAnswerCount() == 0)
178     {
179         if (kErrorNone == ResolveByQueryCallbacks(responseHeader, *responseMessage, compressInfo, aMessageInfo))
180         {
181             resolveByQueryCallbacks = true;
182         }
183     }
184 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
185     else
186     {
187         ++mCounters.mResolvedBySrp;
188     }
189 #endif
190 
191 exit:
192     if (error == kErrorNone && !resolveByQueryCallbacks)
193     {
194         SendResponse(responseHeader, response, *responseMessage, aMessageInfo, mSocket);
195     }
196 
197     FreeMessageOnError(responseMessage, error);
198 }
199 
SendResponse(Header aHeader,Header::Response aResponseCode,Message & aMessage,const Ip6::MessageInfo & aMessageInfo,Ip6::Udp::Socket & aSocket)200 void Server::SendResponse(Header                  aHeader,
201                           Header::Response        aResponseCode,
202                           Message &               aMessage,
203                           const Ip6::MessageInfo &aMessageInfo,
204                           Ip6::Udp::Socket &      aSocket)
205 {
206     Error error;
207 
208     if (aResponseCode == Header::kResponseServerFailure)
209     {
210         LogWarn("failed to handle DNS query due to server failure");
211         aHeader.SetQuestionCount(0);
212         aHeader.SetAnswerCount(0);
213         aHeader.SetAdditionalRecordCount(0);
214         IgnoreError(aMessage.SetLength(sizeof(Header)));
215     }
216 
217     aHeader.SetResponseCode(aResponseCode);
218     aMessage.Write(0, aHeader);
219 
220     error = aSocket.SendTo(aMessage, aMessageInfo);
221 
222     FreeMessageOnError(&aMessage, error);
223 
224     if (error != kErrorNone)
225     {
226         LogWarn("failed to send DNS-SD reply: %s", ErrorToString(error));
227     }
228     else
229     {
230         LogInfo("send DNS-SD reply: %s, RCODE=%d", ErrorToString(error), aResponseCode);
231     }
232 
233     UpdateResponseCounters(aResponseCode);
234 }
235 
AddQuestions(const Header & aRequestHeader,const Message & aRequestMessage,Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo)236 Header::Response Server::AddQuestions(const Header &    aRequestHeader,
237                                       const Message &   aRequestMessage,
238                                       Header &          aResponseHeader,
239                                       Message &         aResponseMessage,
240                                       NameCompressInfo &aCompressInfo)
241 {
242     Question         question;
243     uint16_t         readOffset;
244     Header::Response response = Header::kResponseSuccess;
245     char             name[Name::kMaxNameSize];
246 
247     readOffset = sizeof(Header);
248 
249     // Check and append the questions
250     for (uint16_t i = 0; i < aRequestHeader.GetQuestionCount(); i++)
251     {
252         NameComponentsOffsetInfo nameComponentsOffsetInfo;
253         uint16_t                 qtype;
254 
255         VerifyOrExit(kErrorNone == Name::ReadName(aRequestMessage, readOffset, name, sizeof(name)),
256                      response = Header::kResponseFormatError);
257         VerifyOrExit(kErrorNone == aRequestMessage.Read(readOffset, question), response = Header::kResponseFormatError);
258         readOffset += sizeof(question);
259 
260         qtype = question.GetType();
261 
262         VerifyOrExit(qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv ||
263                          qtype == ResourceRecord::kTypeTxt || qtype == ResourceRecord::kTypeAaaa,
264                      response = Header::kResponseNotImplemented);
265 
266         VerifyOrExit(kErrorNone == FindNameComponents(name, aCompressInfo.GetDomainName(), nameComponentsOffsetInfo),
267                      response = Header::kResponseNameError);
268 
269         switch (question.GetType())
270         {
271         case ResourceRecord::kTypePtr:
272             VerifyOrExit(nameComponentsOffsetInfo.IsServiceName(), response = Header::kResponseNameError);
273             break;
274         case ResourceRecord::kTypeSrv:
275             VerifyOrExit(nameComponentsOffsetInfo.IsServiceInstanceName(), response = Header::kResponseNameError);
276             break;
277         case ResourceRecord::kTypeTxt:
278             VerifyOrExit(nameComponentsOffsetInfo.IsServiceInstanceName(), response = Header::kResponseNameError);
279             break;
280         case ResourceRecord::kTypeAaaa:
281             VerifyOrExit(nameComponentsOffsetInfo.IsHostName(), response = Header::kResponseNameError);
282             break;
283         default:
284             ExitNow(response = Header::kResponseNotImplemented);
285         }
286 
287         VerifyOrExit(AppendQuestion(name, question, aResponseMessage, aCompressInfo) == kErrorNone,
288                      response = Header::kResponseServerFailure);
289     }
290 
291     aResponseHeader.SetQuestionCount(aRequestHeader.GetQuestionCount());
292 
293 exit:
294     return response;
295 }
296 
AppendQuestion(const char * aName,const Question & aQuestion,Message & aMessage,NameCompressInfo & aCompressInfo)297 Error Server::AppendQuestion(const char *      aName,
298                              const Question &  aQuestion,
299                              Message &         aMessage,
300                              NameCompressInfo &aCompressInfo)
301 {
302     Error error = kErrorNone;
303 
304     switch (aQuestion.GetType())
305     {
306     case ResourceRecord::kTypePtr:
307         SuccessOrExit(error = AppendServiceName(aMessage, aName, aCompressInfo));
308         break;
309     case ResourceRecord::kTypeSrv:
310     case ResourceRecord::kTypeTxt:
311         SuccessOrExit(error = AppendInstanceName(aMessage, aName, aCompressInfo));
312         break;
313     case ResourceRecord::kTypeAaaa:
314         SuccessOrExit(error = AppendHostName(aMessage, aName, aCompressInfo));
315         break;
316     default:
317         OT_ASSERT(false);
318     }
319 
320     error = aMessage.Append(aQuestion);
321 
322 exit:
323     return error;
324 }
325 
AppendPtrRecord(Message & aMessage,const char * aServiceName,const char * aInstanceName,uint32_t aTtl,NameCompressInfo & aCompressInfo)326 Error Server::AppendPtrRecord(Message &         aMessage,
327                               const char *      aServiceName,
328                               const char *      aInstanceName,
329                               uint32_t          aTtl,
330                               NameCompressInfo &aCompressInfo)
331 {
332     Error     error;
333     PtrRecord ptrRecord;
334     uint16_t  recordOffset;
335 
336     ptrRecord.Init();
337     ptrRecord.SetTtl(aTtl);
338 
339     SuccessOrExit(error = AppendServiceName(aMessage, aServiceName, aCompressInfo));
340 
341     recordOffset = aMessage.GetLength();
342     SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(ptrRecord)));
343 
344     SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
345 
346     ptrRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
347     aMessage.Write(recordOffset, ptrRecord);
348 
349 exit:
350     return error;
351 }
352 
AppendSrvRecord(Message & aMessage,const char * aInstanceName,const char * aHostName,uint32_t aTtl,uint16_t aPriority,uint16_t aWeight,uint16_t aPort,NameCompressInfo & aCompressInfo)353 Error Server::AppendSrvRecord(Message &         aMessage,
354                               const char *      aInstanceName,
355                               const char *      aHostName,
356                               uint32_t          aTtl,
357                               uint16_t          aPriority,
358                               uint16_t          aWeight,
359                               uint16_t          aPort,
360                               NameCompressInfo &aCompressInfo)
361 {
362     SrvRecord srvRecord;
363     Error     error = kErrorNone;
364     uint16_t  recordOffset;
365 
366     srvRecord.Init();
367     srvRecord.SetTtl(aTtl);
368     srvRecord.SetPriority(aPriority);
369     srvRecord.SetWeight(aWeight);
370     srvRecord.SetPort(aPort);
371 
372     SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
373 
374     recordOffset = aMessage.GetLength();
375     SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(srvRecord)));
376 
377     SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
378 
379     srvRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
380     aMessage.Write(recordOffset, srvRecord);
381 
382 exit:
383     return error;
384 }
385 
AppendAaaaRecord(Message & aMessage,const char * aHostName,const Ip6::Address & aAddress,uint32_t aTtl,NameCompressInfo & aCompressInfo)386 Error Server::AppendAaaaRecord(Message &           aMessage,
387                                const char *        aHostName,
388                                const Ip6::Address &aAddress,
389                                uint32_t            aTtl,
390                                NameCompressInfo &  aCompressInfo)
391 {
392     AaaaRecord aaaaRecord;
393     Error      error;
394 
395     aaaaRecord.Init();
396     aaaaRecord.SetTtl(aTtl);
397     aaaaRecord.SetAddress(aAddress);
398 
399     SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
400     error = aMessage.Append(aaaaRecord);
401 
402 exit:
403     return error;
404 }
405 
AppendServiceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)406 Error Server::AppendServiceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
407 {
408     Error       error;
409     uint16_t    serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, aName);
410     const char *serviceName;
411 
412     // Check whether `aName` is a sub-type service name.
413     serviceName = StringFind(aName, kDnssdSubTypeLabel, kStringCaseInsensitiveMatch);
414 
415     if (serviceName != nullptr)
416     {
417         uint8_t subTypeLabelLength = static_cast<uint8_t>(serviceName - aName) + sizeof(kDnssdSubTypeLabel) - 1;
418 
419         SuccessOrExit(error = Name::AppendMultipleLabels(aName, subTypeLabelLength, aMessage));
420 
421         // Skip over the "._sub." label to get to the root service name.
422         serviceName += sizeof(kDnssdSubTypeLabel) - 1;
423     }
424     else
425     {
426         serviceName = aName;
427     }
428 
429     if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
430     {
431         error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
432     }
433     else
434     {
435         uint8_t  domainStart          = static_cast<uint8_t>(StringLength(serviceName, Name::kMaxNameSize - 1) -
436                                                    StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
437         uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
438 
439         serviceCompressOffset = aMessage.GetLength();
440         aCompressInfo.SetServiceNameOffset(serviceCompressOffset);
441 
442         if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
443         {
444             aCompressInfo.SetDomainNameOffset(serviceCompressOffset + domainStart);
445             error = Name::AppendName(serviceName, aMessage);
446         }
447         else
448         {
449             SuccessOrExit(error = Name::AppendMultipleLabels(serviceName, domainStart, aMessage));
450             error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
451         }
452     }
453 
454 exit:
455     return error;
456 }
457 
AppendInstanceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)458 Error Server::AppendInstanceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
459 {
460     Error    error;
461     uint16_t instanceCompressOffset = aCompressInfo.GetInstanceNameOffset(aMessage, aName);
462 
463     if (instanceCompressOffset != NameCompressInfo::kUnknownOffset)
464     {
465         error = Name::AppendPointerLabel(instanceCompressOffset, aMessage);
466     }
467     else
468     {
469         NameComponentsOffsetInfo nameComponentsInfo;
470 
471         IgnoreError(FindNameComponents(aName, aCompressInfo.GetDomainName(), nameComponentsInfo));
472         OT_ASSERT(nameComponentsInfo.IsServiceInstanceName());
473 
474         aCompressInfo.SetInstanceNameOffset(aMessage.GetLength());
475 
476         // Append the instance name as one label
477         SuccessOrExit(error = Name::AppendLabel(aName, nameComponentsInfo.mServiceOffset - 1, aMessage));
478 
479         {
480             const char *serviceName           = aName + nameComponentsInfo.mServiceOffset;
481             uint16_t    serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, serviceName);
482 
483             if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
484             {
485                 error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
486             }
487             else
488             {
489                 aCompressInfo.SetServiceNameOffset(aMessage.GetLength());
490                 error = Name::AppendName(serviceName, aMessage);
491             }
492         }
493     }
494 
495 exit:
496     return error;
497 }
498 
AppendTxtRecord(Message & aMessage,const char * aInstanceName,const void * aTxtData,uint16_t aTxtLength,uint32_t aTtl,NameCompressInfo & aCompressInfo)499 Error Server::AppendTxtRecord(Message &         aMessage,
500                               const char *      aInstanceName,
501                               const void *      aTxtData,
502                               uint16_t          aTxtLength,
503                               uint32_t          aTtl,
504                               NameCompressInfo &aCompressInfo)
505 {
506     Error         error = kErrorNone;
507     TxtRecord     txtRecord;
508     const uint8_t kEmptyTxt = 0;
509 
510     SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
511 
512     txtRecord.Init();
513     txtRecord.SetTtl(aTtl);
514     txtRecord.SetLength(aTxtLength > 0 ? aTxtLength : sizeof(kEmptyTxt));
515 
516     SuccessOrExit(error = aMessage.Append(txtRecord));
517     if (aTxtLength > 0)
518     {
519         error = aMessage.AppendBytes(aTxtData, aTxtLength);
520     }
521     else
522     {
523         error = aMessage.Append(kEmptyTxt);
524     }
525 
526 exit:
527     return error;
528 }
529 
AppendHostName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)530 Error Server::AppendHostName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
531 {
532     Error    error;
533     uint16_t hostCompressOffset = aCompressInfo.GetHostNameOffset(aMessage, aName);
534 
535     if (hostCompressOffset != NameCompressInfo::kUnknownOffset)
536     {
537         error = Name::AppendPointerLabel(hostCompressOffset, aMessage);
538     }
539     else
540     {
541         uint8_t  domainStart          = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength) -
542                                                    StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
543         uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
544 
545         hostCompressOffset = aMessage.GetLength();
546         aCompressInfo.SetHostNameOffset(hostCompressOffset);
547 
548         if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
549         {
550             aCompressInfo.SetDomainNameOffset(hostCompressOffset + domainStart);
551             error = Name::AppendName(aName, aMessage);
552         }
553         else
554         {
555             SuccessOrExit(error = Name::AppendMultipleLabels(aName, domainStart, aMessage));
556             error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
557         }
558     }
559 
560 exit:
561     return error;
562 }
563 
IncResourceRecordCount(Header & aHeader,bool aAdditional)564 void Server::IncResourceRecordCount(Header &aHeader, bool aAdditional)
565 {
566     if (aAdditional)
567     {
568         aHeader.SetAdditionalRecordCount(aHeader.GetAdditionalRecordCount() + 1);
569     }
570     else
571     {
572         aHeader.SetAnswerCount(aHeader.GetAnswerCount() + 1);
573     }
574 }
575 
FindNameComponents(const char * aName,const char * aDomain,NameComponentsOffsetInfo & aInfo)576 Error Server::FindNameComponents(const char *aName, const char *aDomain, NameComponentsOffsetInfo &aInfo)
577 {
578     uint8_t nameLen   = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength));
579     uint8_t domainLen = static_cast<uint8_t>(StringLength(aDomain, Name::kMaxNameLength));
580     Error   error     = kErrorNone;
581     uint8_t labelBegin, labelEnd;
582 
583     VerifyOrExit(Name::IsSubDomainOf(aName, aDomain), error = kErrorInvalidArgs);
584 
585     labelBegin          = nameLen - domainLen;
586     aInfo.mDomainOffset = labelBegin;
587 
588     while (true)
589     {
590         error = FindPreviousLabel(aName, labelBegin, labelEnd);
591 
592         VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
593 
594         if (labelEnd == labelBegin + kProtocolLabelLength &&
595             (StringStartsWith(&aName[labelBegin], kDnssdProtocolUdp, kStringCaseInsensitiveMatch) ||
596              StringStartsWith(&aName[labelBegin], kDnssdProtocolTcp, kStringCaseInsensitiveMatch)))
597         {
598             // <Protocol> label found
599             aInfo.mProtocolOffset = labelBegin;
600             break;
601         }
602     }
603 
604     // Get service label <Service>
605     error = FindPreviousLabel(aName, labelBegin, labelEnd);
606     VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
607 
608     aInfo.mServiceOffset = labelBegin;
609 
610     // Check for service subtype
611     error = FindPreviousLabel(aName, labelBegin, labelEnd);
612     VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
613 
614     // Note that `kDnssdSubTypeLabel` is "._sub.". Here we get the
615     // label only so we want to compare it with "_sub".
616     if ((labelEnd == labelBegin + kSubTypeLabelLength) &&
617         StringStartsWith(&aName[labelBegin], kDnssdSubTypeLabel + 1, kStringCaseInsensitiveMatch))
618     {
619         SuccessOrExit(error = FindPreviousLabel(aName, labelBegin, labelEnd));
620         VerifyOrExit(labelBegin == 0, error = kErrorInvalidArgs);
621         aInfo.mSubTypeOffset = labelBegin;
622         ExitNow();
623     }
624 
625     // Treat everything before <Service> as <Instance> label
626     aInfo.mInstanceOffset = 0;
627 
628 exit:
629     return error;
630 }
631 
FindPreviousLabel(const char * aName,uint8_t & aStart,uint8_t & aStop)632 Error Server::FindPreviousLabel(const char *aName, uint8_t &aStart, uint8_t &aStop)
633 {
634     // This method finds the previous label before the current label (whose start index is @p aStart), and updates @p
635     // aStart to the start index of the label and @p aStop to the index of the dot just after the label.
636     // @note The input value of @p aStop does not matter because it is only used to output.
637 
638     Error   error = kErrorNone;
639     uint8_t start = aStart;
640     uint8_t end;
641 
642     VerifyOrExit(start > 0, error = kErrorNotFound);
643     VerifyOrExit(aName[--start] == Name::kLabelSeperatorChar, error = kErrorInvalidArgs);
644 
645     end = start;
646     while (start > 0 && aName[start - 1] != Name::kLabelSeperatorChar)
647     {
648         start--;
649     }
650 
651     VerifyOrExit(start < end, error = kErrorInvalidArgs);
652 
653     aStart = start;
654     aStop  = end;
655 
656 exit:
657     return error;
658 }
659 
660 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
ResolveBySrp(Header & aResponseHeader,Message & aResponseMessage,Server::NameCompressInfo & aCompressInfo)661 Header::Response Server::ResolveBySrp(Header &                  aResponseHeader,
662                                       Message &                 aResponseMessage,
663                                       Server::NameCompressInfo &aCompressInfo)
664 {
665     Question         question;
666     uint16_t         readOffset = sizeof(Header);
667     Header::Response response   = Header::kResponseSuccess;
668     char             name[Name::kMaxNameSize];
669 
670     for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
671     {
672         IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
673         IgnoreError(aResponseMessage.Read(readOffset, question));
674         readOffset += sizeof(question);
675 
676         response = ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo,
677                                         /* aAdditional */ false);
678 
679         LogInfo("ANSWER: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d", aResponseHeader.GetMessageId(), name,
680                 question.GetClass(), question.GetType(), response);
681     }
682 
683     // Answer the questions with additional RRs if required
684     if (aResponseHeader.GetAnswerCount() > 0)
685     {
686         readOffset = sizeof(Header);
687         for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
688         {
689             IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
690             IgnoreError(aResponseMessage.Read(readOffset, question));
691             readOffset += sizeof(question);
692 
693             VerifyOrExit(Header::kResponseServerFailure != ResolveQuestionBySrp(name, question, aResponseHeader,
694                                                                                 aResponseMessage, aCompressInfo,
695                                                                                 /* aAdditional */ true),
696                          response = Header::kResponseServerFailure);
697 
698             LogInfo("ADDITIONAL: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d", aResponseHeader.GetMessageId(),
699                     name, question.GetClass(), question.GetType(), response);
700         }
701     }
702 exit:
703     return response;
704 }
705 
ResolveQuestionBySrp(const char * aName,const Question & aQuestion,Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,bool aAdditional)706 Header::Response Server::ResolveQuestionBySrp(const char *      aName,
707                                               const Question &  aQuestion,
708                                               Header &          aResponseHeader,
709                                               Message &         aResponseMessage,
710                                               NameCompressInfo &aCompressInfo,
711                                               bool              aAdditional)
712 {
713     Error                    error    = kErrorNone;
714     const Srp::Server::Host *host     = nullptr;
715     TimeMilli                now      = TimerMilli::GetNow();
716     uint16_t                 qtype    = aQuestion.GetType();
717     Header::Response         response = Header::kResponseNameError;
718 
719     while ((host = GetNextSrpHost(host)) != nullptr)
720     {
721         bool        needAdditionalAaaaRecord = false;
722         const char *hostName                 = host->GetFullName();
723 
724         // Handle PTR/SRV/TXT query
725         if (qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv || qtype == ResourceRecord::kTypeTxt)
726         {
727             const Srp::Server::Service *service = nullptr;
728 
729             while ((service = GetNextSrpService(*host, service)) != nullptr)
730             {
731                 uint32_t    instanceTtl         = TimeMilli::MsecToSec(service->GetExpireTime() - TimerMilli::GetNow());
732                 const char *instanceName        = service->GetInstanceName();
733                 bool        serviceNameMatched  = service->MatchesServiceName(aName);
734                 bool        instanceNameMatched = service->MatchesInstanceName(aName);
735                 bool        ptrQueryMatched     = qtype == ResourceRecord::kTypePtr && serviceNameMatched;
736                 bool        srvQueryMatched     = qtype == ResourceRecord::kTypeSrv && instanceNameMatched;
737                 bool        txtQueryMatched     = qtype == ResourceRecord::kTypeTxt && instanceNameMatched;
738 
739                 if (ptrQueryMatched || srvQueryMatched)
740                 {
741                     needAdditionalAaaaRecord = true;
742                 }
743 
744                 if (!aAdditional && ptrQueryMatched)
745                 {
746                     SuccessOrExit(
747                         error = AppendPtrRecord(aResponseMessage, aName, instanceName, instanceTtl, aCompressInfo));
748                     IncResourceRecordCount(aResponseHeader, aAdditional);
749                     response = Header::kResponseSuccess;
750                 }
751 
752                 if ((!aAdditional && srvQueryMatched) ||
753                     (aAdditional && ptrQueryMatched &&
754                      !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeSrv)))
755                 {
756                     SuccessOrExit(error = AppendSrvRecord(aResponseMessage, instanceName, hostName, instanceTtl,
757                                                           service->GetPriority(), service->GetWeight(),
758                                                           service->GetPort(), aCompressInfo));
759                     IncResourceRecordCount(aResponseHeader, aAdditional);
760                     response = Header::kResponseSuccess;
761                 }
762 
763                 if ((!aAdditional && txtQueryMatched) ||
764                     (aAdditional && ptrQueryMatched &&
765                      !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeTxt)))
766                 {
767                     SuccessOrExit(error = AppendTxtRecord(aResponseMessage, instanceName, service->GetTxtData(),
768                                                           service->GetTxtDataLength(), instanceTtl, aCompressInfo));
769                     IncResourceRecordCount(aResponseHeader, aAdditional);
770                     response = Header::kResponseSuccess;
771                 }
772             }
773         }
774 
775         // Handle AAAA query
776         if ((!aAdditional && qtype == ResourceRecord::kTypeAaaa && host->Matches(aName)) ||
777             (aAdditional && needAdditionalAaaaRecord &&
778              !HasQuestion(aResponseHeader, aResponseMessage, hostName, ResourceRecord::kTypeAaaa)))
779         {
780             uint8_t             addrNum;
781             const Ip6::Address *addrs   = host->GetAddresses(addrNum);
782             uint32_t            hostTtl = TimeMilli::MsecToSec(host->GetExpireTime() - now);
783 
784             for (uint8_t i = 0; i < addrNum; i++)
785             {
786                 SuccessOrExit(error = AppendAaaaRecord(aResponseMessage, hostName, addrs[i], hostTtl, aCompressInfo));
787                 IncResourceRecordCount(aResponseHeader, aAdditional);
788             }
789 
790             response = Header::kResponseSuccess;
791         }
792     }
793 
794 exit:
795     return error == kErrorNone ? response : Header::kResponseServerFailure;
796 }
797 
/**
 * Returns the SRP server host that follows @p aHost, skipping hosts marked as deleted.
 *
 * @param[in] aHost  The host to continue from, or `nullptr` to start from the first host.
 *
 * @returns The next non-deleted host, or `nullptr` when there is none.
 */
const Srp::Server::Host *Server::GetNextSrpHost(const Srp::Server::Host *aHost)
{
    const Srp::Server::Host *candidate = aHost;

    // Advance at least once, then keep advancing past deleted entries.
    do
    {
        candidate = Get<Srp::Server>().GetNextHost(candidate);
    } while ((candidate != nullptr) && candidate->IsDeleted());

    return candidate;
}
809 
/**
 * Returns the service of @p aHost that follows @p aService, restricted to active services of any type
 * (per `Srp::Server::kFlagsAnyTypeActiveService`).
 *
 * @param[in] aHost     The host whose services are iterated.
 * @param[in] aService  The service to continue from, or `nullptr` to start from the first one.
 *
 * @returns The next matching service, or `nullptr` when there is none.
 */
const Srp::Server::Service *Server::GetNextSrpService(const Srp::Server::Host &   aHost,
                                                      const Srp::Server::Service *aService)
{
    const Srp::Server::Service *next = aHost.FindNextService(aService, Srp::Server::kFlagsAnyTypeActiveService);

    return next;
}
815 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
816 
ResolveByQueryCallbacks(Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)817 Error Server::ResolveByQueryCallbacks(Header &                aResponseHeader,
818                                       Message &               aResponseMessage,
819                                       NameCompressInfo &      aCompressInfo,
820                                       const Ip6::MessageInfo &aMessageInfo)
821 {
822     QueryTransaction *query = nullptr;
823     DnsQueryType      queryType;
824     char              name[Name::kMaxNameSize];
825 
826     Error error = kErrorNone;
827 
828     VerifyOrExit(mQuerySubscribe != nullptr, error = kErrorFailed);
829 
830     queryType = GetQueryTypeAndName(aResponseHeader, aResponseMessage, name);
831     VerifyOrExit(queryType != kDnsQueryNone, error = kErrorNotImplemented);
832 
833     query = NewQuery(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo);
834     VerifyOrExit(query != nullptr, error = kErrorNoBufs);
835 
836     mQuerySubscribe(mQueryCallbackContext, name);
837 
838 exit:
839     return error;
840 }
841 
NewQuery(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)842 Server::QueryTransaction *Server::NewQuery(const Header &          aResponseHeader,
843                                            Message &               aResponseMessage,
844                                            const NameCompressInfo &aCompressInfo,
845                                            const Ip6::MessageInfo &aMessageInfo)
846 {
847     QueryTransaction *newQuery = nullptr;
848 
849     for (QueryTransaction &query : mQueryTransactions)
850     {
851         if (query.IsValid())
852         {
853             continue;
854         }
855 
856         query.Init(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo, GetInstance());
857         ExitNow(newQuery = &query);
858     }
859 
860 exit:
861     if (newQuery != nullptr)
862     {
863         ResetTimer();
864     }
865 
866     return newQuery;
867 }
868 
CanAnswerQuery(const QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)869 bool Server::CanAnswerQuery(const QueryTransaction &          aQuery,
870                             const char *                      aServiceFullName,
871                             const otDnssdServiceInstanceInfo &aInstanceInfo)
872 {
873     char         name[Name::kMaxNameSize];
874     DnsQueryType sdType;
875     bool         canAnswer = false;
876 
877     sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
878 
879     switch (sdType)
880     {
881     case kDnsQueryBrowse:
882         canAnswer = StringMatch(name, aServiceFullName, kStringCaseInsensitiveMatch);
883         break;
884     case kDnsQueryResolve:
885         canAnswer = StringMatch(name, aInstanceInfo.mFullName, kStringCaseInsensitiveMatch);
886         break;
887     default:
888         break;
889     }
890 
891     return canAnswer;
892 }
893 
CanAnswerQuery(const Server::QueryTransaction & aQuery,const char * aHostFullName)894 bool Server::CanAnswerQuery(const Server::QueryTransaction &aQuery, const char *aHostFullName)
895 {
896     char         name[Name::kMaxNameSize];
897     DnsQueryType sdType;
898 
899     sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
900     return (sdType == kDnsQueryResolveHost) && StringMatch(name, aHostFullName, kStringCaseInsensitiveMatch);
901 }
902 
AnswerQuery(QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)903 void Server::AnswerQuery(QueryTransaction &                aQuery,
904                          const char *                      aServiceFullName,
905                          const otDnssdServiceInstanceInfo &aInstanceInfo)
906 {
907     Header &          responseHeader  = aQuery.GetResponseHeader();
908     Message &         responseMessage = aQuery.GetResponseMessage();
909     Error             error           = kErrorNone;
910     NameCompressInfo &compressInfo    = aQuery.GetNameCompressInfo();
911 
912     if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aServiceFullName,
913                     ResourceRecord::kTypePtr))
914     {
915         SuccessOrExit(error = AppendPtrRecord(responseMessage, aServiceFullName, aInstanceInfo.mFullName,
916                                               aInstanceInfo.mTtl, compressInfo));
917         IncResourceRecordCount(responseHeader, false);
918     }
919 
920     for (uint8_t additional = 0; additional <= 1; additional++)
921     {
922         if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
923                         ResourceRecord::kTypeSrv) == !additional)
924         {
925             SuccessOrExit(error = AppendSrvRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mHostName,
926                                                   aInstanceInfo.mTtl, aInstanceInfo.mPriority, aInstanceInfo.mWeight,
927                                                   aInstanceInfo.mPort, compressInfo));
928             IncResourceRecordCount(responseHeader, additional);
929         }
930 
931         if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
932                         ResourceRecord::kTypeTxt) == !additional)
933         {
934             SuccessOrExit(error = AppendTxtRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mTxtData,
935                                                   aInstanceInfo.mTxtLength, aInstanceInfo.mTtl, compressInfo));
936             IncResourceRecordCount(responseHeader, additional);
937         }
938 
939         if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mHostName,
940                         ResourceRecord::kTypeAaaa) == !additional)
941         {
942             for (uint8_t i = 0; i < aInstanceInfo.mAddressNum; i++)
943             {
944                 const Ip6::Address &address = AsCoreType(&aInstanceInfo.mAddresses[i]);
945 
946                 OT_ASSERT(!address.IsUnspecified() && !address.IsLinkLocal() && !address.IsMulticast() &&
947                           !address.IsLoopback());
948 
949                 SuccessOrExit(error = AppendAaaaRecord(responseMessage, aInstanceInfo.mHostName, address,
950                                                        aInstanceInfo.mTtl, compressInfo));
951                 IncResourceRecordCount(responseHeader, additional);
952             }
953         }
954     }
955 
956 exit:
957     FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
958     ResetTimer();
959 }
960 
AnswerQuery(QueryTransaction & aQuery,const char * aHostFullName,const otDnssdHostInfo & aHostInfo)961 void Server::AnswerQuery(QueryTransaction &aQuery, const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
962 {
963     Header &          responseHeader  = aQuery.GetResponseHeader();
964     Message &         responseMessage = aQuery.GetResponseMessage();
965     Error             error           = kErrorNone;
966     NameCompressInfo &compressInfo    = aQuery.GetNameCompressInfo();
967 
968     if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aHostFullName, ResourceRecord::kTypeAaaa))
969     {
970         for (uint8_t i = 0; i < aHostInfo.mAddressNum; i++)
971         {
972             const Ip6::Address &address = AsCoreType(&aHostInfo.mAddresses[i]);
973 
974             OT_ASSERT(!address.IsUnspecified() && !address.IsMulticast() && !address.IsLinkLocal() &&
975                       !address.IsLoopback());
976 
977             SuccessOrExit(error =
978                               AppendAaaaRecord(responseMessage, aHostFullName, address, aHostInfo.mTtl, compressInfo));
979             IncResourceRecordCount(responseHeader, /* aAdditional */ false);
980         }
981     }
982 
983 exit:
984     FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
985     ResetTimer();
986 }
987 
SetQueryCallbacks(otDnssdQuerySubscribeCallback aSubscribe,otDnssdQueryUnsubscribeCallback aUnsubscribe,void * aContext)988 void Server::SetQueryCallbacks(otDnssdQuerySubscribeCallback   aSubscribe,
989                                otDnssdQueryUnsubscribeCallback aUnsubscribe,
990                                void *                          aContext)
991 {
992     OT_ASSERT((aSubscribe == nullptr) == (aUnsubscribe == nullptr));
993 
994     mQuerySubscribe       = aSubscribe;
995     mQueryUnsubscribe     = aUnsubscribe;
996     mQueryCallbackContext = aContext;
997 }
998 
HandleDiscoveredServiceInstance(const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)999 void Server::HandleDiscoveredServiceInstance(const char *                      aServiceFullName,
1000                                              const otDnssdServiceInstanceInfo &aInstanceInfo)
1001 {
1002     OT_ASSERT(StringEndsWith(aServiceFullName, Name::kLabelSeperatorChar));
1003     OT_ASSERT(StringEndsWith(aInstanceInfo.mFullName, Name::kLabelSeperatorChar));
1004     OT_ASSERT(StringEndsWith(aInstanceInfo.mHostName, Name::kLabelSeperatorChar));
1005 
1006     for (QueryTransaction &query : mQueryTransactions)
1007     {
1008         if (query.IsValid() && CanAnswerQuery(query, aServiceFullName, aInstanceInfo))
1009         {
1010             AnswerQuery(query, aServiceFullName, aInstanceInfo);
1011         }
1012     }
1013 }
1014 
HandleDiscoveredHost(const char * aHostFullName,const otDnssdHostInfo & aHostInfo)1015 void Server::HandleDiscoveredHost(const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
1016 {
1017     OT_ASSERT(StringEndsWith(aHostFullName, Name::kLabelSeperatorChar));
1018 
1019     for (QueryTransaction &query : mQueryTransactions)
1020     {
1021         if (query.IsValid() && CanAnswerQuery(query, aHostFullName))
1022         {
1023             AnswerQuery(query, aHostFullName, aHostInfo);
1024         }
1025     }
1026 }
1027 
/**
 * Iterates over the pending (valid) query transactions.
 *
 * @param[in] aQuery  The query to continue after, or `nullptr` to start from the first slot.
 *
 * @returns The next valid query, or `nullptr` when there are no more.
 */
const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
{
    const QueryTransaction *begin =
        (aQuery == nullptr) ? &mQueryTransactions[0] : static_cast<const QueryTransaction *>(aQuery) + 1;
    const QueryTransaction *result = nullptr;

    for (const QueryTransaction *slot = begin; slot < GetArrayEnd(mQueryTransactions); slot++)
    {
        if (slot->IsValid())
        {
            result = slot;
            break;
        }
    }

    return static_cast<const otDnssdQuery *>(result);
}
1050 
GetQueryTypeAndName(const otDnssdQuery * aQuery,char (& aName)[Name::kMaxNameSize])1051 Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize])
1052 {
1053     const QueryTransaction *query = static_cast<const QueryTransaction *>(aQuery);
1054 
1055     OT_ASSERT(query->IsValid());
1056     return GetQueryTypeAndName(query->GetResponseHeader(), query->GetResponseMessage(), aName);
1057 }
1058 
GetQueryTypeAndName(const Header & aHeader,const Message & aMessage,char (& aName)[Name::kMaxNameSize])1059 Server::DnsQueryType Server::GetQueryTypeAndName(const Header & aHeader,
1060                                                  const Message &aMessage,
1061                                                  char (&aName)[Name::kMaxNameSize])
1062 {
1063     DnsQueryType sdType = kDnsQueryNone;
1064 
1065     for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1066     {
1067         Question question;
1068 
1069         IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1070         IgnoreError(aMessage.Read(readOffset, question));
1071         readOffset += sizeof(question);
1072 
1073         switch (question.GetType())
1074         {
1075         case ResourceRecord::kTypePtr:
1076             ExitNow(sdType = kDnsQueryBrowse);
1077         case ResourceRecord::kTypeSrv:
1078         case ResourceRecord::kTypeTxt:
1079             ExitNow(sdType = kDnsQueryResolve);
1080         }
1081     }
1082 
1083     for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1084     {
1085         Question question;
1086 
1087         IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1088         IgnoreError(aMessage.Read(readOffset, question));
1089         readOffset += sizeof(question);
1090 
1091         switch (question.GetType())
1092         {
1093         case ResourceRecord::kTypeAaaa:
1094         case ResourceRecord::kTypeA:
1095             ExitNow(sdType = kDnsQueryResolveHost);
1096         }
1097     }
1098 
1099 exit:
1100     return sdType;
1101 }
1102 
HasQuestion(const Header & aHeader,const Message & aMessage,const char * aName,uint16_t aQuestionType)1103 bool Server::HasQuestion(const Header &aHeader, const Message &aMessage, const char *aName, uint16_t aQuestionType)
1104 {
1105     bool found = false;
1106 
1107     for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1108     {
1109         Question question;
1110         Error    error;
1111 
1112         error = Name::CompareName(aMessage, readOffset, aName);
1113         IgnoreError(aMessage.Read(readOffset, question));
1114         readOffset += sizeof(question);
1115 
1116         if (error == kErrorNone && aQuestionType == question.GetType())
1117         {
1118             ExitNow(found = true);
1119         }
1120     }
1121 
1122 exit:
1123     return found;
1124 }
1125 
HandleTimer(Timer & aTimer)1126 void Server::HandleTimer(Timer &aTimer)
1127 {
1128     aTimer.Get<Server>().HandleTimer();
1129 }
1130 
HandleTimer(void)1131 void Server::HandleTimer(void)
1132 {
1133     TimeMilli now = TimerMilli::GetNow();
1134 
1135     for (QueryTransaction &query : mQueryTransactions)
1136     {
1137         TimeMilli expire;
1138 
1139         if (!query.IsValid())
1140         {
1141             continue;
1142         }
1143 
1144         expire = query.GetStartTime() + kQueryTimeout;
1145         if (expire <= now)
1146         {
1147             FinalizeQuery(query, Header::kResponseSuccess);
1148         }
1149     }
1150 
1151     ResetTimer();
1152 }
1153 
ResetTimer(void)1154 void Server::ResetTimer(void)
1155 {
1156     TimeMilli now        = TimerMilli::GetNow();
1157     TimeMilli nextExpire = now.GetDistantFuture();
1158 
1159     for (QueryTransaction &query : mQueryTransactions)
1160     {
1161         TimeMilli expire;
1162 
1163         if (!query.IsValid())
1164         {
1165             continue;
1166         }
1167 
1168         expire = query.GetStartTime() + kQueryTimeout;
1169         if (expire <= now)
1170         {
1171             nextExpire = now;
1172         }
1173         else if (expire < nextExpire)
1174         {
1175             nextExpire = expire;
1176         }
1177     }
1178 
1179     if (nextExpire < now.GetDistantFuture())
1180     {
1181         mTimer.FireAt(nextExpire);
1182     }
1183     else
1184     {
1185         mTimer.Stop();
1186     }
1187 }
1188 
FinalizeQuery(QueryTransaction & aQuery,Header::Response aResponseCode)1189 void Server::FinalizeQuery(QueryTransaction &aQuery, Header::Response aResponseCode)
1190 {
1191     char         name[Name::kMaxNameSize];
1192     DnsQueryType sdType;
1193 
1194     OT_ASSERT(mQueryUnsubscribe != nullptr);
1195 
1196     sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
1197 
1198     OT_ASSERT(sdType != kDnsQueryNone);
1199     OT_UNUSED_VARIABLE(sdType);
1200 
1201     mQueryUnsubscribe(mQueryCallbackContext, name);
1202     aQuery.Finalize(aResponseCode, mSocket);
1203 }
1204 
Init(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo,Instance & aInstance)1205 void Server::QueryTransaction::Init(const Header &          aResponseHeader,
1206                                     Message &               aResponseMessage,
1207                                     const NameCompressInfo &aCompressInfo,
1208                                     const Ip6::MessageInfo &aMessageInfo,
1209                                     Instance &              aInstance)
1210 {
1211     OT_ASSERT(mResponseMessage == nullptr);
1212 
1213     InstanceLocatorInit::Init(aInstance);
1214     mResponseHeader  = aResponseHeader;
1215     mResponseMessage = &aResponseMessage;
1216     mCompressInfo    = aCompressInfo;
1217     mMessageInfo     = aMessageInfo;
1218     mStartTime       = TimerMilli::GetNow();
1219 }
1220 
Finalize(Header::Response aResponseMessage,Ip6::Udp::Socket & aSocket)1221 void Server::QueryTransaction::Finalize(Header::Response aResponseMessage, Ip6::Udp::Socket &aSocket)
1222 {
1223     OT_ASSERT(mResponseMessage != nullptr);
1224 
1225     Get<Server>().SendResponse(mResponseHeader, aResponseMessage, *mResponseMessage, mMessageInfo, aSocket);
1226     mResponseMessage = nullptr;
1227 }
1228 
UpdateResponseCounters(Header::Response aResponseCode)1229 void Server::UpdateResponseCounters(Header::Response aResponseCode)
1230 {
1231     switch (aResponseCode)
1232     {
1233     case UpdateHeader::kResponseSuccess:
1234         ++mCounters.mSuccessResponse;
1235         break;
1236     case UpdateHeader::kResponseServerFailure:
1237         ++mCounters.mServerFailureResponse;
1238         break;
1239     case UpdateHeader::kResponseFormatError:
1240         ++mCounters.mFormatErrorResponse;
1241         break;
1242     case UpdateHeader::kResponseNameError:
1243         ++mCounters.mNameErrorResponse;
1244         break;
1245     case UpdateHeader::kResponseNotImplemented:
1246         ++mCounters.mNotImplementedResponse;
1247         break;
1248     default:
1249         ++mCounters.mOtherResponse;
1250         break;
1251     }
1252 }
1253 
1254 } // namespace ServiceDiscovery
1255 } // namespace Dns
1256 } // namespace ot
1257 
1258 #endif // OPENTHREAD_CONFIG_DNS_SERVER_ENABLE
1259