• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the DNS-SD server.
32  */
33 
34 #include "dnssd_server.hpp"
35 
36 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
37 
38 #include "instance/instance.hpp"
39 
40 namespace ot {
41 namespace Dns {
42 namespace ServiceDiscovery {
43 
// Declares the log module name ("DnssdServer"); presumably the tag used by the
// `LogInfo()` / `LogWarnOnError()` calls in this file.
44 RegisterLogModule("DnssdServer");
45 
// Domain served by this server. A query name must end in this domain to be
// resolved locally (`ParseQueryName()` fails otherwise), and names under it
// are never forwarded upstream (`ShouldForwardToUpstream()`).
46 const char Server::kDefaultDomainName[] = "default.service.arpa.";
// Label that marks a sub-type service name inside a PTR query name.
47 const char Server::kSubLabel[]          = "_sub";
48 
49 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
// Query names equal to one of these domains are never forwarded to the
// upstream resolver (checked in `ShouldForwardToUpstream()`).
50 const char *Server::kBlockedDomains[] = {"ipv4only.arpa."};
51 #endif
52 
// Constructs the DNS-SD server in the stopped state. The UDP socket is created
// here but only opened by `Start()`. Upstream query forwarding (when compiled
// in) starts disabled, the test mode starts as `kTestModeDisabled`, and all
// counters are cleared.
// NOTE(review): keep the initializer order as-is; it presumably mirrors the
// member declaration order in the header (TODO confirm before reordering).
Server(Instance & aInstance)53 Server::Server(Instance &aInstance)
54     : InstanceLocator(aInstance)
55     , mSocket(aInstance, *this)
56 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
57     , mDiscoveryProxy(aInstance)
58 #endif
59 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
60     , mEnableUpstreamQuery(false)
61 #endif
62     , mTimer(aInstance)
63     , mTestMode(kTestModeDisabled)
64 {
65     mCounters.Clear();
66 }
67 
Start(void)68 Error Server::Start(void)
69 {
70     Error error = kErrorNone;
71 
72     VerifyOrExit(!IsRunning());
73 
74     SuccessOrExit(error = mSocket.Open(kBindUnspecifiedNetif ? Ip6::kNetifUnspecified : Ip6::kNetifThreadInternal));
75     SuccessOrExit(error = mSocket.Bind(kPort));
76 
77 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
78     Get<Srp::Server>().HandleDnssdServerStateChange();
79 #endif
80 
81     LogInfo("Started");
82 
83 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
84     mDiscoveryProxy.UpdateState();
85 #endif
86 
87 exit:
88     if (error != kErrorNone)
89     {
90         IgnoreError(mSocket.Close());
91     }
92 
93     return error;
94 }
95 
Stop(void)96 void Server::Stop(void)
97 {
98     for (ProxyQuery &query : mProxyQueries)
99     {
100         Finalize(query, Header::kResponseServerFailure);
101     }
102 
103 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
104     mDiscoveryProxy.Stop();
105 #endif
106 
107 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
108     for (UpstreamQueryTransaction &txn : mUpstreamQueryTransactions)
109     {
110         if (txn.IsValid())
111         {
112             ResetUpstreamQueryTransaction(txn, kErrorFailed);
113         }
114     }
115 #endif
116 
117     mTimer.Stop();
118 
119     IgnoreError(mSocket.Close());
120     LogInfo("Stopped");
121 
122 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
123     Get<Srp::Server>().HandleDnssdServerStateChange();
124 #endif
125 }
126 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)127 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
128 {
129     Request request;
130 
131 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
132     // We first let the `Srp::Server` process the received message.
133     // It returns `kErrorNone` to indicate that it successfully
134     // processed the message.
135 
136     VerifyOrExit(Get<Srp::Server>().HandleDnssdServerUdpReceive(aMessage, aMessageInfo) != kErrorNone);
137 #endif
138 
139     request.mMessage     = &aMessage;
140     request.mMessageInfo = &aMessageInfo;
141     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), request.mHeader));
142 
143     VerifyOrExit(request.mHeader.GetType() == Header::kTypeQuery);
144 
145     LogInfo("Received query from %s", aMessageInfo.GetPeerAddr().ToString().AsCString());
146 
147     ProcessQuery(request);
148 
149 exit:
150     return;
151 }
152 
// Handles one received DNS query and, unless test mode suppresses it, sends a
// response to the requester.
//
// Resolution order:
//  1. When upstream forwarding is enabled and the query qualifies
//     (`ShouldForwardToUpstream()`), the query is forwarded. On success no
//     local response is produced. On failure the query is answered with
//     SERVFAIL.
//  2. Otherwise the questions are validated (`ParseQuestions()`) and copied
//     into a freshly allocated response (`AddQuestionsFrom()`).
//  3. The query is answered from SRP-registered hosts/services, when compiled in.
//  4. If SRP has no match, `rcode` is pre-set to NXDOMAIN and the response is
//     handed to `ResolveByProxy()`. That call may convert the response message
//     into a pending proxy query.
//
// At `exit`, any non-success `rcode` is written into the response header.
// `Send()` is a no-op when the response holds no message: allocation failed,
// upstream took over, or the proxy took ownership.
ProcessQuery(Request & aRequest)153 void Server::ProcessQuery(Request &aRequest)
154 {
155     ResponseCode rcode         = Header::kResponseSuccess;
156     bool         shouldRespond = true;
157     Response     response(GetInstance());
158 
159 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
160     if (mEnableUpstreamQuery && ShouldForwardToUpstream(aRequest))
161     {
162         Error error = ResolveByUpstream(aRequest);
163 
164         if (error == kErrorNone)
165         {
            // The upstream reply is relayed later by `OnUpstreamQueryDone()`;
            // `response` has no message yet, so the `Send()` at `exit` does nothing.
166             ExitNow();
167         }
168 
169         LogWarnOnError(error, "forwarding to upstream");
170 
171         rcode = Header::kResponseServerFailure;
172 
173         // Continue to allocate and prepare the response message
174         // to send the `kResponseServerFailure` response code.
175     }
176 #endif
177 
    // On allocation failure there is no message, so no response can be sent.
178     SuccessOrExit(response.AllocateAndInitFrom(aRequest));
179 
180 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
181     // Forwarding the query to the upstream may have already set the
182     // response error code.
183     SuccessOrExit(rcode);
184 #endif
185 
186     SuccessOrExit(rcode = aRequest.ParseQuestions(mTestMode, shouldRespond));
187     SuccessOrExit(rcode = response.AddQuestionsFrom(aRequest));
188 
189 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
190     response.Log();
191 #endif
192 
193 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
194     switch (response.ResolveBySrp())
195     {
196     case kErrorNone:
197         mCounters.mResolvedBySrp++;
198         ExitNow();
199 
200     case kErrorNotFound:
        // Pre-set NXDOMAIN; it is what gets sent unless the proxy below
        // takes over the response.
201         rcode = Header::kResponseNameError;
202         break;
203 
204     default:
205         rcode = Header::kResponseServerFailure;
206         ExitNow();
207     }
208 #endif
209 
210     ResolveByProxy(response, *aRequest.mMessageInfo);
211 
212 exit:
213     if (rcode != Header::kResponseSuccess)
214     {
215         response.SetResponseCode(rcode);
216     }
217 
218     if (shouldRespond)
219     {
220         response.Send(*aRequest.mMessageInfo);
221     }
222 }
223 
// Constructs an empty response for `aInstance`. The response message itself
// is allocated later by `AllocateAndInitFrom()`; here only the
// name-compression offsets are cleared.
Response(Instance & aInstance)224 Server::Response::Response(Instance &aInstance)
225     : InstanceLocator(aInstance)
226 {
227     // `mHeader`'s own constructor already clears it.
228 
229     mOffsets.Clear();
230 }
231 
AllocateAndInitFrom(const Request & aRequest)232 Error Server::Response::AllocateAndInitFrom(const Request &aRequest)
233 {
234     Error error = kErrorNone;
235 
236     mMessage.Reset(Get<Server>().mSocket.NewMessage());
237     VerifyOrExit(!mMessage.IsNull(), error = kErrorNoBufs);
238 
239     mHeader.SetType(Header::kTypeResponse);
240     mHeader.SetMessageId(aRequest.mHeader.GetMessageId());
241     mHeader.SetQueryType(aRequest.mHeader.GetQueryType());
242 
243     if (aRequest.mHeader.IsRecursionDesiredFlagSet())
244     {
245         mHeader.SetRecursionDesiredFlag();
246     }
247 
248     // Append the empty header to reserve room for it in the message.
249     // Header will be updated in the message before sending it.
250     error = mMessage->Append(mHeader);
251 
252 exit:
253     if (error != kErrorNone)
254     {
255         mMessage.Free();
256     }
257 
258     return error;
259 }
260 
Send(const Ip6::MessageInfo & aMessageInfo)261 void Server::Response::Send(const Ip6::MessageInfo &aMessageInfo)
262 {
263     ResponseCode rcode = mHeader.GetResponseCode();
264 
265     VerifyOrExit(!mMessage.IsNull());
266 
267     if (rcode == Header::kResponseServerFailure)
268     {
269         mHeader.SetQuestionCount(0);
270         mHeader.SetAnswerCount(0);
271         mHeader.SetAdditionalRecordCount(0);
272         IgnoreError(mMessage->SetLength(sizeof(Header)));
273     }
274 
275     mMessage->Write(0, mHeader);
276 
277     SuccessOrExit(Get<Server>().mSocket.SendTo(*mMessage, aMessageInfo));
278 
279     // When `SendTo()` returns success it takes over ownership of
280     // the given message, so we release ownership of `mMessage`.
281 
282     mMessage.Release();
283 
284     LogInfo("Send response, rcode:%u", rcode);
285 
286     Get<Server>().UpdateResponseCounters(rcode);
287 
288 exit:
289     return;
290 }
291 
// Validates the header and question section of a received query and sets `mType`.
//
// Accepted forms:
//  - a single PTR, SRV, TXT, AAAA or A question;
//  - exactly two questions for the same name, one SRV and one TXT, which
//    yields `kSrvTxtQuery`.
// Queries with more than one question can be rejected or silently ignored via
// the `aTestMode` flags.
//
// @param[in]  aTestMode       Bitmask of `kTestMode*` flags.
// @param[out] aShouldRespond  Set to true on entry. Set to false when the query
//                             must be ignored (`kTestModeIgnoreMultiQuestionQuery`).
//
// @returns `kResponseSuccess`. Or `kResponseNotImplemented` for a non-standard
//          query type or an unsupported first question type. Or
//          `kResponseFormatError` for every other validation failure.
ParseQuestions(uint8_t aTestMode,bool & aShouldRespond)292 Server::ResponseCode Server::Request::ParseQuestions(uint8_t aTestMode, bool &aShouldRespond)
293 {
294     // Parse header and questions from a `Request` query message and
295     // determine the `QueryType`.
296 
297     ResponseCode rcode         = Header::kResponseFormatError;
298     uint16_t     offset        = sizeof(Header);
299     uint16_t     questionCount = mHeader.GetQuestionCount();
300     Question     question;
301 
302     aShouldRespond = true;
303 
304     VerifyOrExit(mHeader.GetQueryType() == Header::kQueryTypeStandard, rcode = Header::kResponseNotImplemented);
305     VerifyOrExit(!mHeader.IsTruncationFlagSet());
306 
307     VerifyOrExit(questionCount > 0);
308 
    // NOTE(review): the question class of the first question is not checked here.
309     SuccessOrExit(Name::ParseName(*mMessage, offset));
310     SuccessOrExit(mMessage->Read(offset, question));
311     offset += sizeof(question);
312 
313     switch (question.GetType())
314     {
315     case ResourceRecord::kTypePtr:
316         mType = kPtrQuery;
317         break;
318     case ResourceRecord::kTypeSrv:
319         mType = kSrvQuery;
320         break;
321     case ResourceRecord::kTypeTxt:
322         mType = kTxtQuery;
323         break;
324     case ResourceRecord::kTypeAaaa:
325         mType = kAaaaQuery;
326         break;
327     case ResourceRecord::kTypeA:
328         mType = kAQuery;
329         break;
330     default:
331         ExitNow(rcode = Header::kResponseNotImplemented);
332     }
333 
334     if (questionCount > 1)
335     {
336         VerifyOrExit(!(aTestMode & kTestModeRejectMultiQuestionQuery));
337         VerifyOrExit(!(aTestMode & kTestModeIgnoreMultiQuestionQuery), aShouldRespond = false);
338 
        // The only supported multi-question form is the SRV + TXT pair.
339         VerifyOrExit(questionCount == 2);
340 
        // The second name must equal the first one, which starts right after the header.
341         SuccessOrExit(Name::CompareName(*mMessage, offset, *mMessage, sizeof(Header)));
342         SuccessOrExit(mMessage->Read(offset, question));
343 
344         switch (question.GetType())
345         {
346         case ResourceRecord::kTypeSrv:
347             VerifyOrExit(mType == kTxtQuery);
348             break;
349 
350         case ResourceRecord::kTypeTxt:
351             VerifyOrExit(mType == kSrvQuery);
352             break;
353 
354         default:
355             ExitNow();
356         }
357 
358         mType = kSrvTxtQuery;
359     }
360 
361     rcode = Header::kResponseSuccess;
362 
363 exit:
364     return rcode;
365 }
366 
AddQuestionsFrom(const Request & aRequest)367 Server::ResponseCode Server::Response::AddQuestionsFrom(const Request &aRequest)
368 {
369     ResponseCode rcode = Header::kResponseServerFailure;
370     uint16_t     offset;
371 
372     mType = aRequest.mType;
373 
374     // Read the name from `aRequest.mMessage` and append it as is to
375     // the response message. This ensures all name formats, including
376     // service instance names with dot characters in the instance
377     // label, are appended correctly.
378 
379     SuccessOrExit(Name(*aRequest.mMessage, sizeof(Header)).AppendTo(*mMessage));
380 
381     // Check the name to include the correct domain name and determine
382     // the domain name offset (for DNS name compression).
383 
384     VerifyOrExit(ParseQueryName() == kErrorNone, rcode = Header::kResponseNameError);
385 
386     mHeader.SetQuestionCount(aRequest.mHeader.GetQuestionCount());
387 
388     offset = sizeof(Header);
389 
390     for (uint16_t questionCount = 0; questionCount < mHeader.GetQuestionCount(); questionCount++)
391     {
392         Question question;
393 
394         // The names and questions in `aRequest` are validated already
395         // from `ParseQuestions()`, so we can `IgnoreError()`  here.
396 
397         IgnoreError(Name::ParseName(*aRequest.mMessage, offset));
398         IgnoreError(aRequest.mMessage->Read(offset, question));
399         offset += sizeof(question);
400 
401         if (questionCount != 0)
402         {
403             SuccessOrExit(AppendQueryName());
404         }
405 
406         SuccessOrExit(mMessage->Append(question));
407     }
408 
409     rcode = Header::kResponseSuccess;
410 
411 exit:
412     return rcode;
413 }
414 
ParseQueryName(void)415 Error Server::Response::ParseQueryName(void)
416 {
417     // Parses and validates the query name and updates
418     // the name compression offsets.
419 
420     Error        error = kErrorNone;
421     Name::Buffer name;
422     uint16_t     offset;
423 
424     offset = sizeof(Header);
425     SuccessOrExit(error = Name::ReadName(*mMessage, offset, name));
426 
427     switch (mType)
428     {
429     case kPtrQuery:
430         // `mOffsets.mServiceName` may be updated as we read labels and if we
431         // determine that the query name is a sub-type service.
432         mOffsets.mServiceName = sizeof(Header);
433         break;
434 
435     case kSrvQuery:
436     case kTxtQuery:
437     case kSrvTxtQuery:
438         mOffsets.mInstanceName = sizeof(Header);
439         break;
440 
441     case kAaaaQuery:
442     case kAQuery:
443         mOffsets.mHostName = sizeof(Header);
444         break;
445     }
446 
447     // Read the query name labels one by one to check if the name is
448     // service sub-type and also check that it is sub-domain of the
449     // default domain name and determine its offset
450 
451     offset = sizeof(Header);
452 
453     while (true)
454     {
455         Name::LabelBuffer label;
456         uint8_t           labelLength = sizeof(label);
457         uint16_t          comapreOffset;
458 
459         SuccessOrExit(error = Name::ReadLabel(*mMessage, offset, label, labelLength));
460 
461         if ((mType == kPtrQuery) && StringMatch(label, kSubLabel, kStringCaseInsensitiveMatch))
462         {
463             mOffsets.mServiceName = offset;
464         }
465 
466         comapreOffset = offset;
467 
468         if (Name::CompareName(*mMessage, comapreOffset, kDefaultDomainName) == kErrorNone)
469         {
470             mOffsets.mDomainName = offset;
471             ExitNow();
472         }
473     }
474 
475     error = kErrorParse;
476 
477 exit:
478     return error;
479 }
480 
ReadQueryName(Name::Buffer & aName) const481 void Server::Response::ReadQueryName(Name::Buffer &aName) const { Server::ReadQueryName(*mMessage, aName); }
482 
QueryNameMatches(const char * aName) const483 bool Server::Response::QueryNameMatches(const char *aName) const { return Server::QueryNameMatches(*mMessage, aName); }
484 
AppendQueryName(void)485 Error Server::Response::AppendQueryName(void) { return Name::AppendPointerLabel(sizeof(Header), *mMessage); }
486 
AppendPtrRecord(const char * aInstanceLabel,uint32_t aTtl)487 Error Server::Response::AppendPtrRecord(const char *aInstanceLabel, uint32_t aTtl)
488 {
489     Error     error;
490     uint16_t  recordOffset;
491     PtrRecord ptrRecord;
492 
493     ptrRecord.Init();
494     ptrRecord.SetTtl(aTtl);
495 
496     SuccessOrExit(error = AppendQueryName());
497 
498     recordOffset = mMessage->GetLength();
499     SuccessOrExit(error = mMessage->Append(ptrRecord));
500 
501     mOffsets.mInstanceName = mMessage->GetLength();
502     SuccessOrExit(error = Name::AppendLabel(aInstanceLabel, *mMessage));
503     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mServiceName, *mMessage));
504 
505     UpdateRecordLength(ptrRecord, recordOffset);
506 
507     IncResourceRecordCount();
508 
509 exit:
510     return error;
511 }
512 
#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
// Appends an SRV record for an SRP-registered `aService`. The TTL is the
// service's remaining lifetime, in seconds.
Error Server::Response::AppendSrvRecord(const Srp::Server::Service &aService)
{
    const uint32_t remainingSec = TimeMilli::MsecToSec(aService.GetExpireTime() - TimerMilli::GetNow());

    return AppendSrvRecord(/* aHostName */ aService.GetHost().GetFullName(), /* aTtl */ remainingSec,
                           /* aPriority */ aService.GetPriority(), /* aWeight */ aService.GetWeight(),
                           /* aPort */ aService.GetPort());
}
#endif
522 
AppendSrvRecord(const ServiceInstanceInfo & aInstanceInfo)523 Error Server::Response::AppendSrvRecord(const ServiceInstanceInfo &aInstanceInfo)
524 {
525     return AppendSrvRecord(aInstanceInfo.mHostName, aInstanceInfo.mTtl, aInstanceInfo.mPriority, aInstanceInfo.mWeight,
526                            aInstanceInfo.mPort);
527 }
528 
AppendSrvRecord(const char * aHostName,uint32_t aTtl,uint16_t aPriority,uint16_t aWeight,uint16_t aPort)529 Error Server::Response::AppendSrvRecord(const char *aHostName,
530                                         uint32_t    aTtl,
531                                         uint16_t    aPriority,
532                                         uint16_t    aWeight,
533                                         uint16_t    aPort)
534 {
535     Error        error = kErrorNone;
536     SrvRecord    srvRecord;
537     uint16_t     recordOffset;
538     Name::Buffer hostLabels;
539 
540     SuccessOrExit(error = Name::ExtractLabels(aHostName, kDefaultDomainName, hostLabels));
541 
542     srvRecord.Init();
543     srvRecord.SetTtl(aTtl);
544     srvRecord.SetPriority(aPriority);
545     srvRecord.SetWeight(aWeight);
546     srvRecord.SetPort(aPort);
547 
548     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mInstanceName, *mMessage));
549 
550     recordOffset = mMessage->GetLength();
551     SuccessOrExit(error = mMessage->Append(srvRecord));
552 
553     mOffsets.mHostName = mMessage->GetLength();
554     SuccessOrExit(error = Name::AppendMultipleLabels(hostLabels, *mMessage));
555     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mDomainName, *mMessage));
556 
557     UpdateRecordLength(srvRecord, recordOffset);
558 
559     IncResourceRecordCount();
560 
561 exit:
562     return error;
563 }
564 
#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
// Appends AAAA records for every address of the SRP-registered `aHost`. The
// TTL is the host's remaining lifetime, in seconds.
Error Server::Response::AppendHostAddresses(const Srp::Server::Host &aHost)
{
    uint8_t             numAddrs;
    const Ip6::Address *addrList = aHost.GetAddresses(numAddrs);
    const uint32_t      ttl      = TimeMilli::MsecToSec(aHost.GetExpireTime() - TimerMilli::GetNow());

    return AppendHostAddresses(kIp6AddrType, addrList, numAddrs, ttl);
}
#endif
578 
AppendHostAddresses(AddrType aAddrType,const HostInfo & aHostInfo)579 Error Server::Response::AppendHostAddresses(AddrType aAddrType, const HostInfo &aHostInfo)
580 {
581     return AppendHostAddresses(aAddrType, AsCoreTypePtr(aHostInfo.mAddresses), aHostInfo.mAddressNum, aHostInfo.mTtl);
582 }
583 
AppendHostAddresses(const ServiceInstanceInfo & aInstanceInfo)584 Error Server::Response::AppendHostAddresses(const ServiceInstanceInfo &aInstanceInfo)
585 {
586     return AppendHostAddresses(kIp6AddrType, AsCoreTypePtr(aInstanceInfo.mAddresses), aInstanceInfo.mAddressNum,
587                                aInstanceInfo.mTtl);
588 }
589 
AppendHostAddresses(AddrType aAddrType,const Ip6::Address * aAddrs,uint16_t aAddrsLength,uint32_t aTtl)590 Error Server::Response::AppendHostAddresses(AddrType            aAddrType,
591                                             const Ip6::Address *aAddrs,
592                                             uint16_t            aAddrsLength,
593                                             uint32_t            aTtl)
594 {
595     Error error = kErrorNone;
596 
597     for (uint16_t index = 0; index < aAddrsLength; index++)
598     {
599         const Ip6::Address &address = aAddrs[index];
600 
601         switch (aAddrType)
602         {
603         case kIp6AddrType:
604             SuccessOrExit(error = AppendAaaaRecord(address, aTtl));
605             break;
606 
607         case kIp4AddrType:
608             SuccessOrExit(error = AppendARecord(address, aTtl));
609             break;
610         }
611     }
612 
613 exit:
614     return error;
615 }
616 
AppendAaaaRecord(const Ip6::Address & aAddress,uint32_t aTtl)617 Error Server::Response::AppendAaaaRecord(const Ip6::Address &aAddress, uint32_t aTtl)
618 {
619     Error      error = kErrorNone;
620     AaaaRecord aaaaRecord;
621 
622     VerifyOrExit(!aAddress.IsIp4Mapped());
623 
624     aaaaRecord.Init();
625     aaaaRecord.SetTtl(aTtl);
626     aaaaRecord.SetAddress(aAddress);
627 
628     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mHostName, *mMessage));
629     SuccessOrExit(error = mMessage->Append(aaaaRecord));
630     IncResourceRecordCount();
631 
632 exit:
633     return error;
634 }
635 
AppendARecord(const Ip6::Address & aAddress,uint32_t aTtl)636 Error Server::Response::AppendARecord(const Ip6::Address &aAddress, uint32_t aTtl)
637 {
638     Error        error = kErrorNone;
639     ARecord      aRecord;
640     Ip4::Address ip4Address;
641 
642     SuccessOrExit(ip4Address.ExtractFromIp4MappedIp6Address(aAddress));
643 
644     aRecord.Init();
645     aRecord.SetTtl(aTtl);
646     aRecord.SetAddress(ip4Address);
647 
648     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mHostName, *mMessage));
649     SuccessOrExit(error = mMessage->Append(aRecord));
650     IncResourceRecordCount();
651 
652 exit:
653     return error;
654 }
655 
#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
// Appends a TXT record with the TXT data of the SRP-registered `aService`. The
// TTL is the service's remaining lifetime, in seconds.
Error Server::Response::AppendTxtRecord(const Srp::Server::Service &aService)
{
    const uint32_t ttl = TimeMilli::MsecToSec(aService.GetExpireTime() - TimerMilli::GetNow());

    return AppendTxtRecord(aService.GetTxtData(), aService.GetTxtDataLength(), ttl);
}
#endif
663 
AppendTxtRecord(const ServiceInstanceInfo & aInstanceInfo)664 Error Server::Response::AppendTxtRecord(const ServiceInstanceInfo &aInstanceInfo)
665 {
666     return AppendTxtRecord(aInstanceInfo.mTxtData, aInstanceInfo.mTxtLength, aInstanceInfo.mTtl);
667 }
668 
AppendTxtRecord(const void * aTxtData,uint16_t aTxtLength,uint32_t aTtl)669 Error Server::Response::AppendTxtRecord(const void *aTxtData, uint16_t aTxtLength, uint32_t aTtl)
670 {
671     Error     error = kErrorNone;
672     TxtRecord txtRecord;
673     uint8_t   emptyTxt = 0;
674 
675     if (aTxtLength == 0)
676     {
677         aTxtData   = &emptyTxt;
678         aTxtLength = sizeof(emptyTxt);
679     }
680 
681     txtRecord.Init();
682     txtRecord.SetTtl(aTtl);
683     txtRecord.SetLength(aTxtLength);
684 
685     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mInstanceName, *mMessage));
686     SuccessOrExit(error = mMessage->Append(txtRecord));
687     SuccessOrExit(error = mMessage->AppendBytes(aTxtData, aTxtLength));
688 
689     IncResourceRecordCount();
690 
691 exit:
692     return error;
693 }
694 
UpdateRecordLength(ResourceRecord & aRecord,uint16_t aOffset)695 void Server::Response::UpdateRecordLength(ResourceRecord &aRecord, uint16_t aOffset)
696 {
697     // Calculates RR DATA length and updates and re-writes it in the
698     // response message. This should be called immediately
699     // after all the fields in the record are written in the message.
700     // `aOffset` gives the offset in the message to the start of the
701     // record.
702 
703     aRecord.SetLength(mMessage->GetLength() - aOffset - sizeof(Dns::ResourceRecord));
704     mMessage->Write(aOffset, aRecord);
705 }
706 
IncResourceRecordCount(void)707 void Server::Response::IncResourceRecordCount(void)
708 {
709     switch (mSection)
710     {
711     case kAnswerSection:
712         mHeader.SetAnswerCount(mHeader.GetAnswerCount() + 1);
713         break;
714     case kAdditionalDataSection:
715         mHeader.SetAdditionalRecordCount(mHeader.GetAdditionalRecordCount() + 1);
716         break;
717     }
718 }
719 
720 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
Log(void) const721 void Server::Response::Log(void) const
722 {
723     Name::Buffer name;
724 
725     ReadQueryName(name);
726     LogInfo("%s query for '%s'", QueryTypeToString(mType), name);
727 }
728 
QueryTypeToString(QueryType aType)729 const char *Server::Response::QueryTypeToString(QueryType aType)
730 {
731     static const char *const kTypeNames[] = {
732         "PTR",       // (0) kPtrQuery
733         "SRV",       // (1) kSrvQuery
734         "TXT",       // (2) kTxtQuery
735         "SRV & TXT", // (3) kSrvTxtQuery
736         "AAAA",      // (4) kAaaaQuery
737         "A",         // (5) kAQuery
738     };
739 
740     struct EumCheck
741     {
742         InitEnumValidatorCounter();
743         ValidateNextEnum(kPtrQuery);
744         ValidateNextEnum(kSrvQuery);
745         ValidateNextEnum(kTxtQuery);
746         ValidateNextEnum(kSrvTxtQuery);
747         ValidateNextEnum(kAaaaQuery);
748         ValidateNextEnum(kAQuery);
749     };
750 
751     return kTypeNames[aType];
752 }
753 #endif
754 
755 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
756 
// Tries to answer the query from the hosts/services registered with the SRP
// server, appending records to this response.
//
// - AAAA/A query: appends AAAA records for the first non-deleted host whose
//   full name matches. They go in the answer section for AAAA and in the
//   additional section for A; no A records are produced from SRP hosts here.
// - PTR query: appends one PTR record for every non-deleted service whose
//   base or sub-type service name matches.
// - SRV/TXT/SRV+TXT query: selects the first non-deleted service whose
//   instance name matches.
//
// For a single matched service, SRV and TXT records are then added, each in
// the answer or additional section depending on the query type, followed by
// the host's addresses. A PTR query with more than one answer gets no
// additional records. `kTestModeEmptyAdditionalSection` suppresses the
// additional section.
//
// Returns `kErrorNone` if the query was answered, `kErrorNotFound` if nothing
// matched, or another error if appending to the message failed.
ResolveBySrp(void)757 Error Server::Response::ResolveBySrp(void)
758 {
759     static const Section kSections[] = {kAnswerSection, kAdditionalDataSection};
760 
761     Error                       error          = kErrorNotFound;
762     const Srp::Server::Service *matchedService = nullptr;
763     bool                        found          = false;
764     Section                     srvSection;
765     Section                     txtSection;
766 
767     mSection = kAnswerSection;
768 
769     for (const Srp::Server::Host &host : Get<Srp::Server>().GetHosts())
770     {
771         if (host.IsDeleted())
772         {
773             continue;
774         }
775 
776         if ((mType == kAaaaQuery) || (mType == kAQuery))
777         {
778             if (QueryNameMatches(host.GetFullName()))
779             {
                // `AppendHostAddresses(host)` only emits AAAA records; for an A
                // query they are offered as additional data.
780                 mSection = (mType == kAaaaQuery) ? kAnswerSection : kAdditionalDataSection;
781                 error    = AppendHostAddresses(host);
782                 ExitNow();
783             }
784 
785             continue;
786         }
787 
788         // `mType` is PTR or SRV/TXT query
789 
790         for (const Srp::Server::Service &service : host.GetServices())
791         {
792             if (service.IsDeleted())
793             {
794                 continue;
795             }
796 
797             if (mType == kPtrQuery)
798             {
799                 if (QueryNameMatchesService(service))
800                 {
801                     uint32_t ttl = TimeMilli::MsecToSec(service.GetExpireTime() - TimerMilli::GetNow());
802 
803                     SuccessOrExit(error = AppendPtrRecord(service.GetInstanceLabel(), ttl));
                    // Remember the last match; it is only used when it is the sole answer.
804                     matchedService = &service;
805                 }
806             }
807             else if (QueryNameMatches(service.GetInstanceName()))
808             {
809                 matchedService = &service;
810                 found          = true;
811                 break;
812             }
813         }
814 
815         if (found)
816         {
817             break;
818         }
819     }
820 
821     VerifyOrExit(matchedService != nullptr);
822 
823     if (mType == kPtrQuery)
824     {
825         // Skip adding additional records, when answering a
826         // PTR query with more than one answer. This is the
827         // recommended behavior to keep the size of the
828         // response small.
829 
830         VerifyOrExit(mHeader.GetAnswerCount() == 1);
831     }
832 
833     srvSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
834     txtSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
835 
    // Emit all answer-section records first, then the additional-section ones,
    // so the record order in the message matches the section order.
836     for (Section section : kSections)
837     {
838         mSection = section;
839 
840         if (mSection == kAdditionalDataSection)
841         {
842             VerifyOrExit(!(Get<Server>().mTestMode & kTestModeEmptyAdditionalSection));
843         }
844 
845         if (srvSection == mSection)
846         {
847             SuccessOrExit(error = AppendSrvRecord(*matchedService));
848         }
849 
850         if (txtSection == mSection)
851         {
852             SuccessOrExit(error = AppendTxtRecord(*matchedService));
853         }
854     }
855 
    // `mSection` is still `kAdditionalDataSection` here, so the host
    // addresses are added as additional records.
856     SuccessOrExit(error = AppendHostAddresses(matchedService->GetHost()));
857 
858 exit:
859     return error;
860 }
861 
QueryNameMatchesService(const Srp::Server::Service & aService) const862 bool Server::Response::QueryNameMatchesService(const Srp::Server::Service &aService) const
863 {
864     // Check if the query name matches the base service name or any
865     // sub-type service names associated with `aService`.
866 
867     bool matches = QueryNameMatches(aService.GetServiceName());
868 
869     VerifyOrExit(!matches);
870 
871     for (uint16_t index = 0; index < aService.GetNumberOfSubTypes(); index++)
872     {
873         matches = QueryNameMatches(aService.GetSubTypeServiceNameAt(index));
874         VerifyOrExit(!matches);
875     }
876 
877 exit:
878     return matches;
879 }
880 
881 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
882 
883 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
ShouldForwardToUpstream(const Request & aRequest)884 bool Server::ShouldForwardToUpstream(const Request &aRequest)
885 {
886     bool         shouldForward = false;
887     uint16_t     readOffset;
888     Name::Buffer name;
889 
890     VerifyOrExit(aRequest.mHeader.IsRecursionDesiredFlagSet());
891     readOffset = sizeof(Header);
892 
893     for (uint16_t i = 0; i < aRequest.mHeader.GetQuestionCount(); i++)
894     {
895         SuccessOrExit(Name::ReadName(*aRequest.mMessage, readOffset, name));
896         readOffset += sizeof(Question);
897 
898         VerifyOrExit(!Name::IsSubDomainOf(name, kDefaultDomainName));
899 
900         for (const char *blockedDomain : kBlockedDomains)
901         {
902             VerifyOrExit(!Name::IsSameDomain(name, blockedDomain));
903         }
904     }
905 
906     shouldForward = true;
907 
908 exit:
909     return shouldForward;
910 }
911 
OnUpstreamQueryDone(UpstreamQueryTransaction & aQueryTransaction,Message * aResponseMessage)912 void Server::OnUpstreamQueryDone(UpstreamQueryTransaction &aQueryTransaction, Message *aResponseMessage)
913 {
914     Error error = kErrorNone;
915 
916     VerifyOrExit(aQueryTransaction.IsValid(), error = kErrorInvalidArgs);
917 
918     if (aResponseMessage != nullptr)
919     {
920         error = mSocket.SendTo(*aResponseMessage, aQueryTransaction.GetMessageInfo());
921     }
922     else
923     {
924         error = kErrorResponseTimeout;
925     }
926 
927     ResetUpstreamQueryTransaction(aQueryTransaction, error);
928 
929 exit:
930     FreeMessageOnError(aResponseMessage, error);
931 }
932 
AllocateUpstreamQueryTransaction(const Ip6::MessageInfo & aMessageInfo)933 Server::UpstreamQueryTransaction *Server::AllocateUpstreamQueryTransaction(const Ip6::MessageInfo &aMessageInfo)
934 {
935     UpstreamQueryTransaction *newTxn = nullptr;
936 
937     for (UpstreamQueryTransaction &txn : mUpstreamQueryTransactions)
938     {
939         if (!txn.IsValid())
940         {
941             newTxn = &txn;
942             break;
943         }
944     }
945 
946     VerifyOrExit(newTxn != nullptr, mCounters.mUpstreamDnsCounters.mFailures++);
947 
948     newTxn->Init(aMessageInfo);
949     LogInfo("Upstream query transaction %d initialized.", static_cast<int>(newTxn - mUpstreamQueryTransactions));
950     mTimer.FireAtIfEarlier(newTxn->GetExpireTime());
951 
952 exit:
953     return newTxn;
954 }
955 
ResolveByUpstream(const Request & aRequest)956 Error Server::ResolveByUpstream(const Request &aRequest)
957 {
958     Error                     error = kErrorNone;
959     UpstreamQueryTransaction *txn;
960 
961     txn = AllocateUpstreamQueryTransaction(*aRequest.mMessageInfo);
962     VerifyOrExit(txn != nullptr, error = kErrorNoBufs);
963 
964     otPlatDnsStartUpstreamQuery(&GetInstance(), txn, aRequest.mMessage);
965     mCounters.mUpstreamDnsCounters.mQueries++;
966 
967 exit:
968     return error;
969 }
970 #endif // OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
971 
ResolveByProxy(Response & aResponse,const Ip6::MessageInfo & aMessageInfo)972 void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessageInfo)
973 {
974     ProxyQuery    *query;
975     ProxyQueryInfo info;
976 
977 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
978     VerifyOrExit(mQuerySubscribe.IsSet() || mDiscoveryProxy.IsRunning());
979 #else
980     VerifyOrExit(mQuerySubscribe.IsSet());
981 #endif
982 
983     // We try to convert `aResponse.mMessage` to a `ProxyQuery` by
984     // appending `ProxyQueryInfo` to it.
985 
986     info.mType        = aResponse.mType;
987     info.mMessageInfo = aMessageInfo;
988     info.mExpireTime  = TimerMilli::GetNow() + kQueryTimeout;
989     info.mOffsets     = aResponse.mOffsets;
990 
991 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
992     info.mAction = kNoAction;
993 #endif
994 
995     if (aResponse.mMessage->Append(info) != kErrorNone)
996     {
997         aResponse.SetResponseCode(Header::kResponseServerFailure);
998         ExitNow();
999     }
1000 
1001     // Take over the ownership of `aResponse.mMessage` and add it as a
1002     // `ProxyQuery` in `mProxyQueries` list.
1003 
1004     query = aResponse.mMessage.Release();
1005 
1006     query->Write(0, aResponse.mHeader);
1007     mProxyQueries.Enqueue(*query);
1008 
1009     mTimer.FireAtIfEarlier(info.mExpireTime);
1010 
1011 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1012     if (mQuerySubscribe.IsSet())
1013 #endif
1014     {
1015         Name::Buffer name;
1016 
1017         ReadQueryName(*query, name);
1018         mQuerySubscribe.Invoke(name);
1019     }
1020 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1021     else
1022     {
1023         mDiscoveryProxy.Resolve(*query, info);
1024     }
1025 #endif
1026 
1027 exit:
1028     return;
1029 }
1030 
ReadQueryName(const Message & aQuery,Name::Buffer & aName)1031 void Server::ReadQueryName(const Message &aQuery, Name::Buffer &aName)
1032 {
1033     uint16_t offset = sizeof(Header);
1034 
1035     IgnoreError(Name::ReadName(aQuery, offset, aName));
1036 }
1037 
QueryNameMatches(const Message & aQuery,const char * aName)1038 bool Server::QueryNameMatches(const Message &aQuery, const char *aName)
1039 {
1040     uint16_t offset = sizeof(Header);
1041 
1042     return (Name::CompareName(aQuery, offset, aName) == kErrorNone);
1043 }
1044 
ReadQueryInstanceName(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::Buffer & aName)1045 void Server::ReadQueryInstanceName(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Name::Buffer &aName)
1046 {
1047     uint16_t offset = aInfo.mOffsets.mInstanceName;
1048 
1049     IgnoreError(Name::ReadName(aQuery, offset, aName, sizeof(aName)));
1050 }
1051 
ReadQueryInstanceName(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::LabelBuffer & aInstanceLabel,Name::Buffer & aServiceType)1052 void Server::ReadQueryInstanceName(const ProxyQuery     &aQuery,
1053                                    const ProxyQueryInfo &aInfo,
1054                                    Name::LabelBuffer    &aInstanceLabel,
1055                                    Name::Buffer         &aServiceType)
1056 {
1057     // Reads the service instance label and service type with domain
1058     // name stripped.
1059 
1060     uint16_t offset      = aInfo.mOffsets.mInstanceName;
1061     uint8_t  labelLength = sizeof(aInstanceLabel);
1062 
1063     IgnoreError(Dns::Name::ReadLabel(aQuery, offset, aInstanceLabel, labelLength));
1064     IgnoreError(Dns::Name::ReadName(aQuery, offset, aServiceType));
1065     IgnoreError(StripDomainName(aServiceType));
1066 }
1067 
QueryInstanceNameMatches(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,const char * aName)1068 bool Server::QueryInstanceNameMatches(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, const char *aName)
1069 {
1070     uint16_t offset = aInfo.mOffsets.mInstanceName;
1071 
1072     return (Name::CompareName(aQuery, offset, aName) == kErrorNone);
1073 }
1074 
ReadQueryHostName(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::Buffer & aName)1075 void Server::ReadQueryHostName(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Name::Buffer &aName)
1076 {
1077     uint16_t offset = aInfo.mOffsets.mHostName;
1078 
1079     IgnoreError(Name::ReadName(aQuery, offset, aName, sizeof(aName)));
1080 }
1081 
QueryHostNameMatches(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,const char * aName)1082 bool Server::QueryHostNameMatches(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, const char *aName)
1083 {
1084     uint16_t offset = aInfo.mOffsets.mHostName;
1085 
1086     return (Name::CompareName(aQuery, offset, aName) == kErrorNone);
1087 }
1088 
StripDomainName(Name::Buffer & aName)1089 Error Server::StripDomainName(Name::Buffer &aName)
1090 {
1091     // In-place removes the domain name from `aName`.
1092 
1093     return Name::StripName(aName, kDefaultDomainName);
1094 }
1095 
StripDomainName(const char * aFullName,Name::Buffer & aLabels)1096 Error Server::StripDomainName(const char *aFullName, Name::Buffer &aLabels)
1097 {
1098     // Remove the domain name from `aFullName` and copies
1099     // the result into `aLabels`.
1100 
1101     return Name::ExtractLabels(aFullName, kDefaultDomainName, aLabels, sizeof(aLabels));
1102 }
1103 
ConstructFullName(const char * aLabels,Name::Buffer & aFullName)1104 void Server::ConstructFullName(const char *aLabels, Name::Buffer &aFullName)
1105 {
1106     // Construct a full name by appending the default domain name
1107     // to `aLabels`.
1108 
1109     StringWriter fullName(aFullName, sizeof(aFullName));
1110 
1111     fullName.Append("%s.%s", aLabels, kDefaultDomainName);
1112 }
1113 
ConstructFullInstanceName(const char * aInstanceLabel,const char * aServiceType,Name::Buffer & aFullName)1114 void Server::ConstructFullInstanceName(const char *aInstanceLabel, const char *aServiceType, Name::Buffer &aFullName)
1115 {
1116     StringWriter fullName(aFullName, sizeof(aFullName));
1117 
1118     fullName.Append("%s.%s.%s", aInstanceLabel, aServiceType, kDefaultDomainName);
1119 }
1120 
ConstructFullServiceSubTypeName(const char * aServiceType,const char * aSubTypeLabel,Name::Buffer & aFullName)1121 void Server::ConstructFullServiceSubTypeName(const char   *aServiceType,
1122                                              const char   *aSubTypeLabel,
1123                                              Name::Buffer &aFullName)
1124 {
1125     StringWriter fullName(aFullName, sizeof(aFullName));
1126 
1127     fullName.Append("%s._sub.%s.%s", aSubTypeLabel, aServiceType, kDefaultDomainName);
1128 }
1129 
ExtractServiceInstanceLabel(const char * aInstanceName,Name::LabelBuffer & aLabel)1130 Error Server::Response::ExtractServiceInstanceLabel(const char *aInstanceName, Name::LabelBuffer &aLabel)
1131 {
1132     uint16_t     offset;
1133     Name::Buffer serviceName;
1134 
1135     offset = mOffsets.mServiceName;
1136     IgnoreError(Name::ReadName(*mMessage, offset, serviceName));
1137 
1138     return Name::ExtractLabels(aInstanceName, serviceName, aLabel);
1139 }
1140 
RemoveQueryAndPrepareResponse(ProxyQuery & aQuery,ProxyQueryInfo & aInfo,Response & aResponse)1141 void Server::RemoveQueryAndPrepareResponse(ProxyQuery &aQuery, ProxyQueryInfo &aInfo, Response &aResponse)
1142 {
1143 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1144     mDiscoveryProxy.CancelAction(aQuery, aInfo);
1145 #endif
1146 
1147     mProxyQueries.Dequeue(aQuery);
1148     aInfo.RemoveFrom(aQuery);
1149 
1150     if (mQueryUnsubscribe.IsSet())
1151     {
1152         Name::Buffer name;
1153 
1154         ReadQueryName(aQuery, name);
1155         mQueryUnsubscribe.Invoke(name);
1156     }
1157 
1158     aResponse.InitFrom(aQuery, aInfo);
1159 }
1160 
InitFrom(ProxyQuery & aQuery,const ProxyQueryInfo & aInfo)1161 void Server::Response::InitFrom(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo)
1162 {
1163     mMessage.Reset(&aQuery);
1164     IgnoreError(mMessage->Read(0, mHeader));
1165     mType    = aInfo.mType;
1166     mOffsets = aInfo.mOffsets;
1167 }
1168 
Answer(const ServiceInstanceInfo & aInstanceInfo,const Ip6::MessageInfo & aMessageInfo)1169 void Server::Response::Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip6::MessageInfo &aMessageInfo)
1170 {
1171     static const Section kSections[] = {kAnswerSection, kAdditionalDataSection};
1172 
1173     Error   error      = kErrorNone;
1174     Section srvSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
1175     Section txtSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
1176 
1177     if (mType == kPtrQuery)
1178     {
1179         Name::LabelBuffer instanceLabel;
1180 
1181         SuccessOrExit(error = ExtractServiceInstanceLabel(aInstanceInfo.mFullName, instanceLabel));
1182         mSection = kAnswerSection;
1183         SuccessOrExit(error = AppendPtrRecord(instanceLabel, aInstanceInfo.mTtl));
1184     }
1185 
1186     for (Section section : kSections)
1187     {
1188         mSection = section;
1189 
1190         if (mSection == kAdditionalDataSection)
1191         {
1192             VerifyOrExit(!(Get<Server>().mTestMode & kTestModeEmptyAdditionalSection));
1193         }
1194 
1195         if (srvSection == mSection)
1196         {
1197             SuccessOrExit(error = AppendSrvRecord(aInstanceInfo));
1198         }
1199 
1200         if (txtSection == mSection)
1201         {
1202             SuccessOrExit(error = AppendTxtRecord(aInstanceInfo));
1203         }
1204     }
1205 
1206     error = AppendHostAddresses(aInstanceInfo);
1207 
1208 exit:
1209     if (error != kErrorNone)
1210     {
1211         SetResponseCode(Header::kResponseServerFailure);
1212     }
1213 
1214     Send(aMessageInfo);
1215 }
1216 
Answer(const HostInfo & aHostInfo,const Ip6::MessageInfo & aMessageInfo)1217 void Server::Response::Answer(const HostInfo &aHostInfo, const Ip6::MessageInfo &aMessageInfo)
1218 {
1219     // Caller already ensures that `mType` is either `kAaaaQuery` or
1220     // `kAQuery`.
1221 
1222     AddrType addrType = (mType == kAaaaQuery) ? kIp6AddrType : kIp4AddrType;
1223 
1224     mSection = kAnswerSection;
1225 
1226     if (AppendHostAddresses(addrType, aHostInfo) != kErrorNone)
1227     {
1228         SetResponseCode(Header::kResponseServerFailure);
1229     }
1230 
1231     Send(aMessageInfo);
1232 }
1233 
SetQueryCallbacks(SubscribeCallback aSubscribe,UnsubscribeCallback aUnsubscribe,void * aContext)1234 void Server::SetQueryCallbacks(SubscribeCallback aSubscribe, UnsubscribeCallback aUnsubscribe, void *aContext)
1235 {
1236     OT_ASSERT((aSubscribe == nullptr) == (aUnsubscribe == nullptr));
1237 
1238     mQuerySubscribe.Set(aSubscribe, aContext);
1239     mQueryUnsubscribe.Set(aUnsubscribe, aContext);
1240 }
1241 
HandleDiscoveredServiceInstance(const char * aServiceFullName,const ServiceInstanceInfo & aInstanceInfo)1242 void Server::HandleDiscoveredServiceInstance(const char *aServiceFullName, const ServiceInstanceInfo &aInstanceInfo)
1243 {
1244     OT_ASSERT(StringEndsWith(aServiceFullName, Name::kLabelSeparatorChar));
1245     OT_ASSERT(StringEndsWith(aInstanceInfo.mFullName, Name::kLabelSeparatorChar));
1246     OT_ASSERT(StringEndsWith(aInstanceInfo.mHostName, Name::kLabelSeparatorChar));
1247 
1248     // It is safe to remove entries from `mProxyQueries` as we iterate
1249     // over it since it is a `MessageQueue`.
1250 
1251     for (ProxyQuery &query : mProxyQueries)
1252     {
1253         bool           canAnswer = false;
1254         ProxyQueryInfo info;
1255 
1256         info.ReadFrom(query);
1257 
1258         switch (info.mType)
1259         {
1260         case kPtrQuery:
1261             canAnswer = QueryNameMatches(query, aServiceFullName);
1262             break;
1263 
1264         case kSrvQuery:
1265         case kTxtQuery:
1266         case kSrvTxtQuery:
1267             canAnswer = QueryNameMatches(query, aInstanceInfo.mFullName);
1268             break;
1269 
1270         case kAaaaQuery:
1271         case kAQuery:
1272             break;
1273         }
1274 
1275         if (canAnswer)
1276         {
1277             Response response(GetInstance());
1278 
1279             RemoveQueryAndPrepareResponse(query, info, response);
1280             response.Answer(aInstanceInfo, info.mMessageInfo);
1281         }
1282     }
1283 }
1284 
HandleDiscoveredHost(const char * aHostFullName,const HostInfo & aHostInfo)1285 void Server::HandleDiscoveredHost(const char *aHostFullName, const HostInfo &aHostInfo)
1286 {
1287     OT_ASSERT(StringEndsWith(aHostFullName, Name::kLabelSeparatorChar));
1288 
1289     for (ProxyQuery &query : mProxyQueries)
1290     {
1291         ProxyQueryInfo info;
1292 
1293         info.ReadFrom(query);
1294 
1295         switch (info.mType)
1296         {
1297         case kAaaaQuery:
1298         case kAQuery:
1299             if (QueryNameMatches(query, aHostFullName))
1300             {
1301                 Response response(GetInstance());
1302 
1303                 RemoveQueryAndPrepareResponse(query, info, response);
1304                 response.Answer(aHostInfo, info.mMessageInfo);
1305             }
1306 
1307             break;
1308 
1309         default:
1310             break;
1311         }
1312     }
1313 }
1314 
const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
{
    // Iterates the pending proxy queries: `nullptr` yields the first
    // entry, otherwise the entry following `aQuery`.

    if (aQuery == nullptr)
    {
        return mProxyQueries.GetHead();
    }

    return static_cast<const ProxyQuery *>(aQuery)->GetNext();
}
1321 
GetQueryTypeAndName(const otDnssdQuery * aQuery,Dns::Name::Buffer & aName)1322 Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns::Name::Buffer &aName)
1323 {
1324     const ProxyQuery *query = static_cast<const ProxyQuery *>(aQuery);
1325     ProxyQueryInfo    info;
1326     DnsQueryType      type;
1327 
1328     ReadQueryName(*query, aName);
1329     info.ReadFrom(*query);
1330 
1331     type = kDnsQueryBrowse;
1332 
1333     switch (info.mType)
1334     {
1335     case kPtrQuery:
1336         break;
1337 
1338     case kSrvQuery:
1339     case kTxtQuery:
1340     case kSrvTxtQuery:
1341         type = kDnsQueryResolve;
1342         break;
1343 
1344     case kAaaaQuery:
1345     case kAQuery:
1346         type = kDnsQueryResolveHost;
1347         break;
1348     }
1349 
1350     return type;
1351 }
1352 
HandleTimer(void)1353 void Server::HandleTimer(void)
1354 {
1355     NextFireTime nextExpire;
1356 
1357     for (ProxyQuery &query : mProxyQueries)
1358     {
1359         ProxyQueryInfo info;
1360 
1361         info.ReadFrom(query);
1362 
1363         if (info.mExpireTime <= nextExpire.GetNow())
1364         {
1365             Finalize(query, Header::kResponseSuccess);
1366         }
1367         else
1368         {
1369             nextExpire.UpdateIfEarlier(info.mExpireTime);
1370         }
1371     }
1372 
1373 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
1374     for (UpstreamQueryTransaction &query : mUpstreamQueryTransactions)
1375     {
1376         if (!query.IsValid())
1377         {
1378             continue;
1379         }
1380 
1381         if (query.GetExpireTime() <= nextExpire.GetNow())
1382         {
1383             otPlatDnsCancelUpstreamQuery(&GetInstance(), &query);
1384         }
1385         else
1386         {
1387             nextExpire.UpdateIfEarlier(query.GetExpireTime());
1388         }
1389     }
1390 #endif
1391 
1392     mTimer.FireAtIfEarlier(nextExpire);
1393 }
1394 
Finalize(ProxyQuery & aQuery,ResponseCode aResponseCode)1395 void Server::Finalize(ProxyQuery &aQuery, ResponseCode aResponseCode)
1396 {
1397     Response       response(GetInstance());
1398     ProxyQueryInfo info;
1399 
1400     info.ReadFrom(aQuery);
1401     RemoveQueryAndPrepareResponse(aQuery, info, response);
1402 
1403     response.SetResponseCode(aResponseCode);
1404     response.Send(info.mMessageInfo);
1405 }
1406 
UpdateResponseCounters(ResponseCode aResponseCode)1407 void Server::UpdateResponseCounters(ResponseCode aResponseCode)
1408 {
1409     switch (aResponseCode)
1410     {
1411     case UpdateHeader::kResponseSuccess:
1412         ++mCounters.mSuccessResponse;
1413         break;
1414     case UpdateHeader::kResponseServerFailure:
1415         ++mCounters.mServerFailureResponse;
1416         break;
1417     case UpdateHeader::kResponseFormatError:
1418         ++mCounters.mFormatErrorResponse;
1419         break;
1420     case UpdateHeader::kResponseNameError:
1421         ++mCounters.mNameErrorResponse;
1422         break;
1423     case UpdateHeader::kResponseNotImplemented:
1424         ++mCounters.mNotImplementedResponse;
1425         break;
1426     default:
1427         ++mCounters.mOtherResponse;
1428         break;
1429     }
1430 }
1431 
1432 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
Init(const Ip6::MessageInfo & aMessageInfo)1433 void Server::UpstreamQueryTransaction::Init(const Ip6::MessageInfo &aMessageInfo)
1434 {
1435     mMessageInfo = aMessageInfo;
1436     mValid       = true;
1437     mExpireTime  = TimerMilli::GetNow() + kQueryTimeout;
1438 }
1439 
ResetUpstreamQueryTransaction(UpstreamQueryTransaction & aTxn,Error aError)1440 void Server::ResetUpstreamQueryTransaction(UpstreamQueryTransaction &aTxn, Error aError)
1441 {
1442     int index = static_cast<int>(&aTxn - mUpstreamQueryTransactions);
1443 
1444     // Avoid the warnings when info / warn logging is disabled.
1445     OT_UNUSED_VARIABLE(index);
1446     if (aError == kErrorNone)
1447     {
1448         mCounters.mUpstreamDnsCounters.mResponses++;
1449         LogInfo("Upstream query transaction %d completed.", index);
1450     }
1451     else
1452     {
1453         mCounters.mUpstreamDnsCounters.mFailures++;
1454         LogWarn("Upstream query transaction %d closed: %s.", index, ErrorToString(aError));
1455     }
1456     aTxn.Reset();
1457 }
1458 #endif
1459 
1460 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1461 
DiscoveryProxy(Instance & aInstance)1462 Server::DiscoveryProxy::DiscoveryProxy(Instance &aInstance)
1463     : InstanceLocator(aInstance)
1464     , mIsRunning(false)
1465 {
1466 }
1467 
UpdateState(void)1468 void Server::DiscoveryProxy::UpdateState(void)
1469 {
1470     if (Get<Server>().IsRunning() && Get<Dnssd>().IsReady() && Get<BorderRouter::InfraIf>().IsRunning())
1471     {
1472         Start();
1473     }
1474     else
1475     {
1476         Stop();
1477     }
1478 }
1479 
Start(void)1480 void Server::DiscoveryProxy::Start(void)
1481 {
1482     VerifyOrExit(!mIsRunning);
1483     mIsRunning = true;
1484     LogInfo("Started discovery proxy");
1485 
1486 exit:
1487     return;
1488 }
1489 
Stop(void)1490 void Server::DiscoveryProxy::Stop(void)
1491 {
1492     VerifyOrExit(mIsRunning);
1493 
1494     for (ProxyQuery &query : Get<Server>().mProxyQueries)
1495     {
1496         Get<Server>().Finalize(query, Header::kResponseSuccess);
1497     }
1498 
1499     mIsRunning = false;
1500     LogInfo("Stopped discovery proxy");
1501 
1502 exit:
1503     return;
1504 }
1505 
Resolve(ProxyQuery & aQuery,ProxyQueryInfo & aInfo)1506 void Server::DiscoveryProxy::Resolve(ProxyQuery &aQuery, ProxyQueryInfo &aInfo)
1507 {
1508     ProxyAction action = kNoAction;
1509 
1510     switch (aInfo.mType)
1511     {
1512     case kPtrQuery:
1513         action = kBrowsing;
1514         break;
1515 
1516     case kSrvQuery:
1517     case kSrvTxtQuery:
1518         action = kResolvingSrv;
1519         break;
1520 
1521     case kTxtQuery:
1522         action = kResolvingTxt;
1523         break;
1524 
1525     case kAaaaQuery:
1526         action = kResolvingIp6Address;
1527         break;
1528     case kAQuery:
1529         action = kResolvingIp4Address;
1530         break;
1531     }
1532 
1533     Perform(action, aQuery, aInfo);
1534 }
1535 
Perform(ProxyAction aAction,ProxyQuery & aQuery,ProxyQueryInfo & aInfo)1536 void Server::DiscoveryProxy::Perform(ProxyAction aAction, ProxyQuery &aQuery, ProxyQueryInfo &aInfo)
1537 {
1538     bool         shouldStart;
1539     Name::Buffer name;
1540 
1541     VerifyOrExit(aAction != kNoAction);
1542 
1543     // The order of the steps below is crucial. First, we read the
1544     // name associated with the action. Then we check if another
1545     // query has an active browser/resolver for the same name. This
1546     // helps us determine if a new browser/resolver is needed. Then,
1547     // we update the `ProxyQueryInfo` within `aQuery` to reflect the
1548     // `aAction` being performed. Finally, if necessary, we start the
1549     // proper browser/resolver on DNS-SD/mDNS. Placing this last
1550     // ensures correct processing even if a DNS-SD/mDNS callback is
1551     // invoked immediately.
1552 
1553     ReadNameFor(aAction, aQuery, aInfo, name);
1554 
1555     shouldStart = !HasActive(aAction, name);
1556 
1557     aInfo.mAction = aAction;
1558     aInfo.UpdateIn(aQuery);
1559 
1560     VerifyOrExit(shouldStart);
1561     UpdateProxy(kStart, aAction, aQuery, aInfo, name);
1562 
1563 exit:
1564     return;
1565 }
1566 
ReadNameFor(ProxyAction aAction,ProxyQuery & aQuery,ProxyQueryInfo & aInfo,Name::Buffer & aName) const1567 void Server::DiscoveryProxy::ReadNameFor(ProxyAction     aAction,
1568                                          ProxyQuery     &aQuery,
1569                                          ProxyQueryInfo &aInfo,
1570                                          Name::Buffer   &aName) const
1571 {
1572     // Read the name corresponding to `aAction` from `aQuery`.
1573 
1574     switch (aAction)
1575     {
1576     case kNoAction:
1577         break;
1578     case kBrowsing:
1579         ReadQueryName(aQuery, aName);
1580         break;
1581     case kResolvingSrv:
1582     case kResolvingTxt:
1583         ReadQueryInstanceName(aQuery, aInfo, aName);
1584         break;
1585     case kResolvingIp6Address:
1586     case kResolvingIp4Address:
1587         ReadQueryHostName(aQuery, aInfo, aName);
1588         break;
1589     }
1590 }
1591 
CancelAction(ProxyQuery & aQuery,ProxyQueryInfo & aInfo)1592 void Server::DiscoveryProxy::CancelAction(ProxyQuery &aQuery, ProxyQueryInfo &aInfo)
1593 {
1594     // Cancel the current action for a given `aQuery`, then
1595     // determine if we need to stop any browser/resolver
1596     // on infrastructure.
1597 
1598     ProxyAction  action = aInfo.mAction;
1599     Name::Buffer name;
1600 
1601     VerifyOrExit(mIsRunning);
1602     VerifyOrExit(action != kNoAction);
1603 
1604     // We first update the `aInfo` on `aQuery` before calling
1605     // `HasActive()`. This ensures that the current query is not
1606     // taken into account when we try to determine if any query
1607     // is waiting for same `aAction` browser/resolver.
1608 
1609     ReadNameFor(action, aQuery, aInfo, name);
1610 
1611     aInfo.mAction = kNoAction;
1612     aInfo.UpdateIn(aQuery);
1613 
1614     VerifyOrExit(!HasActive(action, name));
1615     UpdateProxy(kStop, action, aQuery, aInfo, name);
1616 
1617 exit:
1618     return;
1619 }
1620 
UpdateProxy(Command aCommand,ProxyAction aAction,const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::Buffer & aName)1621 void Server::DiscoveryProxy::UpdateProxy(Command               aCommand,
1622                                          ProxyAction           aAction,
1623                                          const ProxyQuery     &aQuery,
1624                                          const ProxyQueryInfo &aInfo,
1625                                          Name::Buffer         &aName)
1626 {
1627     // Start or stop browser/resolver corresponding to `aAction`.
1628     // `aName` may be changed.
1629 
1630     switch (aAction)
1631     {
1632     case kNoAction:
1633         break;
1634     case kBrowsing:
1635         StartOrStopBrowser(aCommand, aName);
1636         break;
1637     case kResolvingSrv:
1638         StartOrStopSrvResolver(aCommand, aQuery, aInfo);
1639         break;
1640     case kResolvingTxt:
1641         StartOrStopTxtResolver(aCommand, aQuery, aInfo);
1642         break;
1643     case kResolvingIp6Address:
1644         StartOrStopIp6Resolver(aCommand, aName);
1645         break;
1646     case kResolvingIp4Address:
1647         StartOrStopIp4Resolver(aCommand, aName);
1648         break;
1649     }
1650 }
1651 
StartOrStopBrowser(Command aCommand,Name::Buffer & aServiceName)1652 void Server::DiscoveryProxy::StartOrStopBrowser(Command aCommand, Name::Buffer &aServiceName)
1653 {
1654     // Start or stop a service browser for a given service type
1655     // or sub-type.
1656 
1657     static const char kFullSubLabel[] = "._sub.";
1658 
1659     Dnssd::Browser browser;
1660     char          *ptr;
1661 
1662     browser.Clear();
1663 
1664     IgnoreError(StripDomainName(aServiceName));
1665 
1666     // Check if the service name is a sub-type with name
1667     // format: "<sub-label>._sub.<service-labels>.
1668 
1669     ptr = AsNonConst(StringFind(aServiceName, kFullSubLabel, kStringCaseInsensitiveMatch));
1670 
1671     if (ptr != nullptr)
1672     {
1673         *ptr = kNullChar;
1674         ptr += sizeof(kFullSubLabel) - 1;
1675 
1676         browser.mServiceType  = ptr;
1677         browser.mSubTypeLabel = aServiceName;
1678     }
1679     else
1680     {
1681         browser.mServiceType  = aServiceName;
1682         browser.mSubTypeLabel = nullptr;
1683     }
1684 
1685     browser.mInfraIfIndex = Get<BorderRouter::InfraIf>().GetIfIndex();
1686     browser.mCallback     = HandleBrowseResult;
1687 
1688     switch (aCommand)
1689     {
1690     case kStart:
1691         Get<Dnssd>().StartBrowser(browser);
1692         break;
1693 
1694     case kStop:
1695         Get<Dnssd>().StopBrowser(browser);
1696         break;
1697     }
1698 }
1699 
StartOrStopSrvResolver(Command aCommand,const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo)1700 void Server::DiscoveryProxy::StartOrStopSrvResolver(Command               aCommand,
1701                                                     const ProxyQuery     &aQuery,
1702                                                     const ProxyQueryInfo &aInfo)
1703 {
1704     // Start or stop an SRV record resolver for a given query.
1705 
1706     Dnssd::SrvResolver resolver;
1707     Name::LabelBuffer  instanceLabel;
1708     Name::Buffer       serviceType;
1709 
1710     ReadQueryInstanceName(aQuery, aInfo, instanceLabel, serviceType);
1711 
1712     resolver.Clear();
1713 
1714     resolver.mServiceInstance = instanceLabel;
1715     resolver.mServiceType     = serviceType;
1716     resolver.mInfraIfIndex    = Get<BorderRouter::InfraIf>().GetIfIndex();
1717     resolver.mCallback        = HandleSrvResult;
1718 
1719     switch (aCommand)
1720     {
1721     case kStart:
1722         Get<Dnssd>().StartSrvResolver(resolver);
1723         break;
1724 
1725     case kStop:
1726         Get<Dnssd>().StopSrvResolver(resolver);
1727         break;
1728     }
1729 }
1730 
StartOrStopTxtResolver(Command aCommand,const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo)1731 void Server::DiscoveryProxy::StartOrStopTxtResolver(Command               aCommand,
1732                                                     const ProxyQuery     &aQuery,
1733                                                     const ProxyQueryInfo &aInfo)
1734 {
1735     // Start or stop a TXT record resolver for a given query.
1736 
1737     Dnssd::TxtResolver resolver;
1738     Name::LabelBuffer  instanceLabel;
1739     Name::Buffer       serviceType;
1740 
1741     ReadQueryInstanceName(aQuery, aInfo, instanceLabel, serviceType);
1742 
1743     resolver.Clear();
1744 
1745     resolver.mServiceInstance = instanceLabel;
1746     resolver.mServiceType     = serviceType;
1747     resolver.mInfraIfIndex    = Get<BorderRouter::InfraIf>().GetIfIndex();
1748     resolver.mCallback        = HandleTxtResult;
1749 
1750     switch (aCommand)
1751     {
1752     case kStart:
1753         Get<Dnssd>().StartTxtResolver(resolver);
1754         break;
1755 
1756     case kStop:
1757         Get<Dnssd>().StopTxtResolver(resolver);
1758         break;
1759     }
1760 }
1761 
StartOrStopIp6Resolver(Command aCommand,Name::Buffer & aHostName)1762 void Server::DiscoveryProxy::StartOrStopIp6Resolver(Command aCommand, Name::Buffer &aHostName)
1763 {
1764     // Start or stop an IPv6 address resolver for a given host name.
1765 
1766     Dnssd::AddressResolver resolver;
1767 
1768     IgnoreError(StripDomainName(aHostName));
1769 
1770     resolver.mHostName     = aHostName;
1771     resolver.mInfraIfIndex = Get<BorderRouter::InfraIf>().GetIfIndex();
1772     resolver.mCallback     = HandleIp6AddressResult;
1773 
1774     switch (aCommand)
1775     {
1776     case kStart:
1777         Get<Dnssd>().StartIp6AddressResolver(resolver);
1778         break;
1779 
1780     case kStop:
1781         Get<Dnssd>().StopIp6AddressResolver(resolver);
1782         break;
1783     }
1784 }
1785 
StartOrStopIp4Resolver(Command aCommand,Name::Buffer & aHostName)1786 void Server::DiscoveryProxy::StartOrStopIp4Resolver(Command aCommand, Name::Buffer &aHostName)
1787 {
1788     // Start or stop an IPv4 address resolver for a given host name.
1789 
1790     Dnssd::AddressResolver resolver;
1791 
1792     IgnoreError(StripDomainName(aHostName));
1793 
1794     resolver.mHostName     = aHostName;
1795     resolver.mInfraIfIndex = Get<BorderRouter::InfraIf>().GetIfIndex();
1796     resolver.mCallback     = HandleIp4AddressResult;
1797 
1798     switch (aCommand)
1799     {
1800     case kStart:
1801         Get<Dnssd>().StartIp4AddressResolver(resolver);
1802         break;
1803 
1804     case kStop:
1805         Get<Dnssd>().StopIp4AddressResolver(resolver);
1806         break;
1807     }
1808 }
1809 
QueryMatches(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,ProxyAction aAction,const Name::Buffer & aName) const1810 bool Server::DiscoveryProxy::QueryMatches(const ProxyQuery     &aQuery,
1811                                           const ProxyQueryInfo &aInfo,
1812                                           ProxyAction           aAction,
1813                                           const Name::Buffer   &aName) const
1814 {
1815     // Check whether `aQuery` is performing `aAction` and
1816     // its name matches `aName`.
1817 
1818     bool matches = false;
1819 
1820     VerifyOrExit(aInfo.mAction == aAction);
1821 
1822     switch (aAction)
1823     {
1824     case kBrowsing:
1825         VerifyOrExit(QueryNameMatches(aQuery, aName));
1826         break;
1827     case kResolvingSrv:
1828     case kResolvingTxt:
1829         VerifyOrExit(QueryInstanceNameMatches(aQuery, aInfo, aName));
1830         break;
1831     case kResolvingIp6Address:
1832     case kResolvingIp4Address:
1833         VerifyOrExit(QueryHostNameMatches(aQuery, aInfo, aName));
1834         break;
1835     case kNoAction:
1836         ExitNow();
1837     }
1838 
1839     matches = true;
1840 
1841 exit:
1842     return matches;
1843 }
1844 
HasActive(ProxyAction aAction,const Name::Buffer & aName) const1845 bool Server::DiscoveryProxy::HasActive(ProxyAction aAction, const Name::Buffer &aName) const
1846 {
1847     // Determine whether or not we have an active browser/resolver
1848     // corresponding to `aAction` for `aName`.
1849 
1850     bool has = false;
1851 
1852     for (const ProxyQuery &query : Get<Server>().mProxyQueries)
1853     {
1854         ProxyQueryInfo info;
1855 
1856         info.ReadFrom(query);
1857 
1858         if (QueryMatches(query, info, aAction, aName))
1859         {
1860             has = true;
1861             break;
1862         }
1863     }
1864 
1865     return has;
1866 }
1867 
HandleBrowseResult(otInstance * aInstance,const otPlatDnssdBrowseResult * aResult)1868 void Server::DiscoveryProxy::HandleBrowseResult(otInstance *aInstance, const otPlatDnssdBrowseResult *aResult)
1869 {
1870     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleBrowseResult(*aResult);
1871 }
1872 
HandleBrowseResult(const Dnssd::BrowseResult & aResult)1873 void Server::DiscoveryProxy::HandleBrowseResult(const Dnssd::BrowseResult &aResult)
1874 {
1875     Name::Buffer serviceName;
1876 
1877     VerifyOrExit(mIsRunning);
1878     VerifyOrExit(aResult.mTtl != 0);
1879     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1880 
1881     if (aResult.mSubTypeLabel != nullptr)
1882     {
1883         ConstructFullServiceSubTypeName(aResult.mServiceType, aResult.mSubTypeLabel, serviceName);
1884     }
1885     else
1886     {
1887         ConstructFullName(aResult.mServiceType, serviceName);
1888     }
1889 
1890     HandleResult(kBrowsing, serviceName, &Response::AppendPtrRecord, ProxyResult(aResult));
1891 
1892 exit:
1893     return;
1894 }
1895 
HandleSrvResult(otInstance * aInstance,const otPlatDnssdSrvResult * aResult)1896 void Server::DiscoveryProxy::HandleSrvResult(otInstance *aInstance, const otPlatDnssdSrvResult *aResult)
1897 {
1898     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleSrvResult(*aResult);
1899 }
1900 
HandleSrvResult(const Dnssd::SrvResult & aResult)1901 void Server::DiscoveryProxy::HandleSrvResult(const Dnssd::SrvResult &aResult)
1902 {
1903     Name::Buffer instanceName;
1904 
1905     VerifyOrExit(mIsRunning);
1906     VerifyOrExit(aResult.mTtl != 0);
1907     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1908 
1909     ConstructFullInstanceName(aResult.mServiceInstance, aResult.mServiceType, instanceName);
1910     HandleResult(kResolvingSrv, instanceName, &Response::AppendSrvRecord, ProxyResult(aResult));
1911 
1912 exit:
1913     return;
1914 }
1915 
HandleTxtResult(otInstance * aInstance,const otPlatDnssdTxtResult * aResult)1916 void Server::DiscoveryProxy::HandleTxtResult(otInstance *aInstance, const otPlatDnssdTxtResult *aResult)
1917 {
1918     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleTxtResult(*aResult);
1919 }
1920 
HandleTxtResult(const Dnssd::TxtResult & aResult)1921 void Server::DiscoveryProxy::HandleTxtResult(const Dnssd::TxtResult &aResult)
1922 {
1923     Name::Buffer instanceName;
1924 
1925     VerifyOrExit(mIsRunning);
1926     VerifyOrExit(aResult.mTtl != 0);
1927     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1928 
1929     ConstructFullInstanceName(aResult.mServiceInstance, aResult.mServiceType, instanceName);
1930     HandleResult(kResolvingTxt, instanceName, &Response::AppendTxtRecord, ProxyResult(aResult));
1931 
1932 exit:
1933     return;
1934 }
1935 
HandleIp6AddressResult(otInstance * aInstance,const otPlatDnssdAddressResult * aResult)1936 void Server::DiscoveryProxy::HandleIp6AddressResult(otInstance *aInstance, const otPlatDnssdAddressResult *aResult)
1937 {
1938     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleIp6AddressResult(*aResult);
1939 }
1940 
HandleIp6AddressResult(const Dnssd::AddressResult & aResult)1941 void Server::DiscoveryProxy::HandleIp6AddressResult(const Dnssd::AddressResult &aResult)
1942 {
1943     bool         hasValidAddress = false;
1944     Name::Buffer fullHostName;
1945 
1946     VerifyOrExit(mIsRunning);
1947     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1948 
1949     for (uint16_t index = 0; index < aResult.mAddressesLength; index++)
1950     {
1951         const Dnssd::AddressAndTtl &entry   = aResult.mAddresses[index];
1952         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
1953 
1954         if (entry.mTtl == 0)
1955         {
1956             continue;
1957         }
1958 
1959         if (IsProxyAddressValid(address))
1960         {
1961             hasValidAddress = true;
1962             break;
1963         }
1964     }
1965 
1966     VerifyOrExit(hasValidAddress);
1967 
1968     ConstructFullName(aResult.mHostName, fullHostName);
1969     HandleResult(kResolvingIp6Address, fullHostName, &Response::AppendHostIp6Addresses, ProxyResult(aResult));
1970 
1971 exit:
1972     return;
1973 }
1974 
HandleIp4AddressResult(otInstance * aInstance,const otPlatDnssdAddressResult * aResult)1975 void Server::DiscoveryProxy::HandleIp4AddressResult(otInstance *aInstance, const otPlatDnssdAddressResult *aResult)
1976 {
1977     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleIp4AddressResult(*aResult);
1978 }
1979 
HandleIp4AddressResult(const Dnssd::AddressResult & aResult)1980 void Server::DiscoveryProxy::HandleIp4AddressResult(const Dnssd::AddressResult &aResult)
1981 {
1982     bool         hasValidAddress = false;
1983     Name::Buffer fullHostName;
1984 
1985     VerifyOrExit(mIsRunning);
1986     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1987 
1988     for (uint16_t index = 0; index < aResult.mAddressesLength; index++)
1989     {
1990         const Dnssd::AddressAndTtl &entry   = aResult.mAddresses[index];
1991         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
1992 
1993         if (entry.mTtl == 0)
1994         {
1995             continue;
1996         }
1997 
1998         if (address.IsIp4Mapped())
1999         {
2000             hasValidAddress = true;
2001             break;
2002         }
2003     }
2004 
2005     VerifyOrExit(hasValidAddress);
2006 
2007     ConstructFullName(aResult.mHostName, fullHostName);
2008     HandleResult(kResolvingIp4Address, fullHostName, &Response::AppendHostIp4Addresses, ProxyResult(aResult));
2009 
2010 exit:
2011     return;
2012 }
2013 
HandleResult(ProxyAction aAction,const Name::Buffer & aName,ResponseAppender aAppender,const ProxyResult & aResult)2014 void Server::DiscoveryProxy::HandleResult(ProxyAction         aAction,
2015                                           const Name::Buffer &aName,
2016                                           ResponseAppender    aAppender,
2017                                           const ProxyResult  &aResult)
2018 {
2019     // Common method that handles result from DNS-SD/mDNS. It
2020     // iterates over all `ProxyQuery` entries and checks if any entry
2021     // is waiting for the result of `aAction` for `aName`. Matching
2022     // queries are updated using the `aAppender` method pointer,
2023     // which appends the corresponding record(s) to the response. We
2024     // then determine the next action to be performed for the
2025     // `ProxyQuery` or if it can be finalized.
2026 
2027     ProxyQueryList nextActionQueries;
2028     ProxyQueryInfo info;
2029     ProxyAction    nextAction;
2030 
2031     for (ProxyQuery &query : Get<Server>().mProxyQueries)
2032     {
2033         Response response(GetInstance());
2034         bool     shouldFinalize;
2035 
2036         info.ReadFrom(query);
2037 
2038         if (!QueryMatches(query, info, aAction, aName))
2039         {
2040             continue;
2041         }
2042 
2043         CancelAction(query, info);
2044 
2045         nextAction = kNoAction;
2046 
2047         switch (aAction)
2048         {
2049         case kBrowsing:
2050             nextAction = kResolvingSrv;
2051             break;
2052         case kResolvingSrv:
2053             nextAction = (info.mType == kSrvQuery) ? kResolvingIp6Address : kResolvingTxt;
2054             break;
2055         case kResolvingTxt:
2056             nextAction = (info.mType == kTxtQuery) ? kNoAction : kResolvingIp6Address;
2057             break;
2058         case kNoAction:
2059         case kResolvingIp6Address:
2060         case kResolvingIp4Address:
2061             break;
2062         }
2063 
2064         shouldFinalize = (nextAction == kNoAction);
2065 
2066         if ((Get<Server>().mTestMode & kTestModeEmptyAdditionalSection) &&
2067             IsActionForAdditionalSection(nextAction, info.mType))
2068         {
2069             shouldFinalize = true;
2070         }
2071 
2072         Get<Server>().mProxyQueries.Dequeue(query);
2073         info.RemoveFrom(query);
2074         response.InitFrom(query, info);
2075 
2076         if ((response.*aAppender)(aResult) != kErrorNone)
2077         {
2078             response.SetResponseCode(Header::kResponseServerFailure);
2079             shouldFinalize = true;
2080         }
2081 
2082         if (shouldFinalize)
2083         {
2084             response.Send(info.mMessageInfo);
2085             continue;
2086         }
2087 
2088         // The `query` is not yet finished and we need to perform
2089         // the `nextAction` for it.
2090 
2091         // Reinitialize `response` as a `ProxyQuey` by updating
2092         // and appending `info` to it after the newly appended
2093         // records from `aResult` and saving the `mHeader`.
2094 
2095         info.mOffsets = response.mOffsets;
2096         info.mAction  = nextAction;
2097         response.mMessage->Write(0, response.mHeader);
2098 
2099         if (response.mMessage->Append(info) != kErrorNone)
2100         {
2101             response.SetResponseCode(Header::kResponseServerFailure);
2102             response.Send(info.mMessageInfo);
2103             continue;
2104         }
2105 
2106         // Take back ownership of `response.mMessage` as we still
2107         // treat it as a `ProxyQuery`.
2108 
2109         response.mMessage.Release();
2110 
2111         // We place the `query` in a separate list and add it back to
2112         // the main `mProxyQueries` list after we are done with the
2113         // current iteration. This ensures that other entries in the
2114         // `mProxyQueries` list are not updated or removed due to the
2115         // DNS-SD platform callback being invoked immediately when we
2116         // potentially start a browser or resolver to perform the
2117         // `nextAction` for `query`.
2118 
2119         nextActionQueries.Enqueue(query);
2120     }
2121 
2122     for (ProxyQuery &query : nextActionQueries)
2123     {
2124         nextActionQueries.Dequeue(query);
2125 
2126         info.ReadFrom(query);
2127 
2128         nextAction = info.mAction;
2129 
2130         info.mAction = kNoAction;
2131         info.UpdateIn(query);
2132 
2133         Get<Server>().mProxyQueries.Enqueue(query);
2134         Perform(nextAction, query, info);
2135     }
2136 }
2137 
IsActionForAdditionalSection(ProxyAction aAction,QueryType aQueryType)2138 bool Server::DiscoveryProxy::IsActionForAdditionalSection(ProxyAction aAction, QueryType aQueryType)
2139 {
2140     bool isForAddnlSection = false;
2141 
2142     switch (aAction)
2143     {
2144     case kResolvingSrv:
2145         VerifyOrExit((aQueryType == kSrvQuery) || (aQueryType == kSrvTxtQuery));
2146         break;
2147     case kResolvingTxt:
2148         VerifyOrExit((aQueryType == kTxtQuery) || (aQueryType == kSrvTxtQuery));
2149         break;
2150 
2151     case kResolvingIp6Address:
2152         VerifyOrExit(aQueryType == kAaaaQuery);
2153         break;
2154 
2155     case kResolvingIp4Address:
2156         VerifyOrExit(aQueryType == kAQuery);
2157         break;
2158 
2159     case kNoAction:
2160     case kBrowsing:
2161         ExitNow();
2162     }
2163 
2164     isForAddnlSection = true;
2165 
2166 exit:
2167     return isForAddnlSection;
2168 }
2169 
AppendPtrRecord(const ProxyResult & aResult)2170 Error Server::Response::AppendPtrRecord(const ProxyResult &aResult)
2171 {
2172     const Dnssd::BrowseResult *browseResult = aResult.mBrowseResult;
2173 
2174     mSection = kAnswerSection;
2175 
2176     return AppendPtrRecord(browseResult->mServiceInstance, browseResult->mTtl);
2177 }
2178 
AppendSrvRecord(const ProxyResult & aResult)2179 Error Server::Response::AppendSrvRecord(const ProxyResult &aResult)
2180 {
2181     const Dnssd::SrvResult *srvResult = aResult.mSrvResult;
2182     Name::Buffer            fullHostName;
2183 
2184     mSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
2185 
2186     ConstructFullName(srvResult->mHostName, fullHostName);
2187 
2188     return AppendSrvRecord(fullHostName, srvResult->mTtl, srvResult->mPriority, srvResult->mWeight, srvResult->mPort);
2189 }
2190 
AppendTxtRecord(const ProxyResult & aResult)2191 Error Server::Response::AppendTxtRecord(const ProxyResult &aResult)
2192 {
2193     const Dnssd::TxtResult *txtResult = aResult.mTxtResult;
2194 
2195     mSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
2196 
2197     return AppendTxtRecord(txtResult->mTxtData, txtResult->mTxtDataLength, txtResult->mTtl);
2198 }
2199 
AppendHostIp6Addresses(const ProxyResult & aResult)2200 Error Server::Response::AppendHostIp6Addresses(const ProxyResult &aResult)
2201 {
2202     Error                       error      = kErrorNone;
2203     const Dnssd::AddressResult *addrResult = aResult.mAddressResult;
2204 
2205     mSection = (mType == kAaaaQuery) ? kAnswerSection : kAdditionalDataSection;
2206 
2207     for (uint16_t index = 0; index < addrResult->mAddressesLength; index++)
2208     {
2209         const Dnssd::AddressAndTtl &entry   = addrResult->mAddresses[index];
2210         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
2211 
2212         if (entry.mTtl == 0)
2213         {
2214             continue;
2215         }
2216 
2217         if (!IsProxyAddressValid(address))
2218         {
2219             continue;
2220         }
2221 
2222         SuccessOrExit(error = AppendAaaaRecord(address, entry.mTtl));
2223     }
2224 
2225 exit:
2226     return error;
2227 }
2228 
AppendHostIp4Addresses(const ProxyResult & aResult)2229 Error Server::Response::AppendHostIp4Addresses(const ProxyResult &aResult)
2230 {
2231     Error                       error      = kErrorNone;
2232     const Dnssd::AddressResult *addrResult = aResult.mAddressResult;
2233 
2234     mSection = (mType == kAQuery) ? kAnswerSection : kAdditionalDataSection;
2235 
2236     for (uint16_t index = 0; index < addrResult->mAddressesLength; index++)
2237     {
2238         const Dnssd::AddressAndTtl &entry   = addrResult->mAddresses[index];
2239         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
2240 
2241         if (entry.mTtl == 0)
2242         {
2243             continue;
2244         }
2245 
2246         SuccessOrExit(error = AppendARecord(address, entry.mTtl));
2247     }
2248 
2249 exit:
2250     return error;
2251 }
2252 
IsProxyAddressValid(const Ip6::Address & aAddress)2253 bool Server::IsProxyAddressValid(const Ip6::Address &aAddress)
2254 {
2255     return !aAddress.IsLinkLocalUnicast() && !aAddress.IsMulticast() && !aAddress.IsUnspecified() &&
2256            !aAddress.IsLoopback();
2257 }
2258 
2259 #endif // OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
2260 
2261 } // namespace ServiceDiscovery
2262 } // namespace Dns
2263 } // namespace ot
2264 
2265 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE && OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_MOCK_PLAT_APIS_ENABLE
otPlatDnsStartUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)2266 void otPlatDnsStartUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
2267 {
2268     OT_UNUSED_VARIABLE(aInstance);
2269     OT_UNUSED_VARIABLE(aTxn);
2270     OT_UNUSED_VARIABLE(aQuery);
2271 }
2272 
otPlatDnsCancelUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn)2273 void otPlatDnsCancelUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn)
2274 {
2275     otPlatDnsUpstreamQueryDone(aInstance, aTxn, nullptr);
2276 }
2277 #endif
2278 
#endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
2280