• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file includes implementation for SRP server.
32  */
33 
34 #include "srp_server.hpp"
35 
36 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
37 
38 #include "common/as_core_type.hpp"
39 #include "common/const_cast.hpp"
40 #include "common/instance.hpp"
41 #include "common/locator_getters.hpp"
42 #include "common/log.hpp"
43 #include "common/new.hpp"
44 #include "common/random.hpp"
45 #include "net/dns_types.hpp"
46 #include "thread/thread_netif.hpp"
47 
48 namespace ot {
49 namespace Srp {
50 
51 RegisterLogModule("SrpServer");
52 
53 static const char kDefaultDomain[]       = "default.service.arpa.";
54 static const char kServiceSubTypeLabel[] = "._sub.";
55 
ErrorToDnsResponseCode(Error aError)56 static Dns::UpdateHeader::Response ErrorToDnsResponseCode(Error aError)
57 {
58     Dns::UpdateHeader::Response responseCode;
59 
60     switch (aError)
61     {
62     case kErrorNone:
63         responseCode = Dns::UpdateHeader::kResponseSuccess;
64         break;
65     case kErrorNoBufs:
66         responseCode = Dns::UpdateHeader::kResponseServerFailure;
67         break;
68     case kErrorParse:
69         responseCode = Dns::UpdateHeader::kResponseFormatError;
70         break;
71     case kErrorDuplicated:
72         responseCode = Dns::UpdateHeader::kResponseNameExists;
73         break;
74     default:
75         responseCode = Dns::UpdateHeader::kResponseRefused;
76         break;
77     }
78 
79     return responseCode;
80 }
81 
82 //---------------------------------------------------------------------------------------------------------------------
83 // Server
84 
Server(Instance & aInstance)85 Server::Server(Instance &aInstance)
86     : InstanceLocator(aInstance)
87     , mSocket(aInstance)
88     , mServiceUpdateHandler(nullptr)
89     , mServiceUpdateHandlerContext(nullptr)
90     , mLeaseTimer(aInstance, HandleLeaseTimer)
91     , mOutstandingUpdatesTimer(aInstance, HandleOutstandingUpdatesTimer)
92     , mServiceUpdateId(Random::NonCrypto::GetUint32())
93     , mPort(kUdpPortMin)
94     , mState(kStateDisabled)
95     , mAddressMode(kDefaultAddressMode)
96     , mAnycastSequenceNumber(0)
97     , mHasRegisteredAnyService(false)
98 {
99     IgnoreError(SetDomain(kDefaultDomain));
100 }
101 
SetServiceHandler(otSrpServerServiceUpdateHandler aServiceHandler,void * aServiceHandlerContext)102 void Server::SetServiceHandler(otSrpServerServiceUpdateHandler aServiceHandler, void *aServiceHandlerContext)
103 {
104     mServiceUpdateHandler        = aServiceHandler;
105     mServiceUpdateHandlerContext = aServiceHandlerContext;
106 }
107 
SetAddressMode(AddressMode aMode)108 Error Server::SetAddressMode(AddressMode aMode)
109 {
110     Error error = kErrorNone;
111 
112     VerifyOrExit(mState == kStateDisabled, error = kErrorInvalidState);
113     VerifyOrExit(mAddressMode != aMode);
114     LogInfo("Address Mode: %s -> %s", AddressModeToString(mAddressMode), AddressModeToString(aMode));
115     mAddressMode = aMode;
116 
117 exit:
118     return error;
119 }
120 
SetAnycastModeSequenceNumber(uint8_t aSequenceNumber)121 Error Server::SetAnycastModeSequenceNumber(uint8_t aSequenceNumber)
122 {
123     Error error = kErrorNone;
124 
125     VerifyOrExit(mState == kStateDisabled, error = kErrorInvalidState);
126     mAnycastSequenceNumber = aSequenceNumber;
127 
128     LogInfo("Set Anycast Address Mode Seq Number to %d", aSequenceNumber);
129 
130 exit:
131     return error;
132 }
133 
SetEnabled(bool aEnabled)134 void Server::SetEnabled(bool aEnabled)
135 {
136     if (aEnabled)
137     {
138         VerifyOrExit(mState == kStateDisabled);
139         mState = kStateStopped;
140 
141         // Request publishing of "DNS/SRP Address Service" entry in the
142         // Thread Network Data based of `mAddressMode`. Then wait for
143         // callback `HandleNetDataPublisherEntryChange()` from the
144         // `Publisher` to start the SRP server.
145 
146         switch (mAddressMode)
147         {
148         case kAddressModeUnicast:
149             SelectPort();
150             Get<NetworkData::Publisher>().PublishDnsSrpServiceUnicast(mPort);
151             break;
152 
153         case kAddressModeAnycast:
154             mPort = kAnycastAddressModePort;
155             Get<NetworkData::Publisher>().PublishDnsSrpServiceAnycast(mAnycastSequenceNumber);
156             break;
157         }
158     }
159     else
160     {
161         VerifyOrExit(mState != kStateDisabled);
162         Get<NetworkData::Publisher>().UnpublishDnsSrpService();
163         Stop();
164         mState = kStateDisabled;
165     }
166 
167 exit:
168     return;
169 }
170 
TtlConfig(void)171 Server::TtlConfig::TtlConfig(void)
172 {
173     mMinTtl = kDefaultMinTtl;
174     mMaxTtl = kDefaultMaxTtl;
175 }
176 
SetTtlConfig(const TtlConfig & aTtlConfig)177 Error Server::SetTtlConfig(const TtlConfig &aTtlConfig)
178 {
179     Error error = kErrorNone;
180 
181     VerifyOrExit(aTtlConfig.IsValid(), error = kErrorInvalidArgs);
182     mTtlConfig = aTtlConfig;
183 
184 exit:
185     return error;
186 }
187 
GrantTtl(uint32_t aLease,uint32_t aTtl) const188 uint32_t Server::TtlConfig::GrantTtl(uint32_t aLease, uint32_t aTtl) const
189 {
190     OT_ASSERT(mMinTtl <= mMaxTtl);
191 
192     return OT_MAX(mMinTtl, OT_MIN(OT_MIN(mMaxTtl, aLease), aTtl));
193 }
194 
LeaseConfig(void)195 Server::LeaseConfig::LeaseConfig(void)
196 {
197     mMinLease    = kDefaultMinLease;
198     mMaxLease    = kDefaultMaxLease;
199     mMinKeyLease = kDefaultMinKeyLease;
200     mMaxKeyLease = kDefaultMaxKeyLease;
201 }
202 
IsValid(void) const203 bool Server::LeaseConfig::IsValid(void) const
204 {
205     bool valid = false;
206 
207     // TODO: Support longer LEASE.
208     // We use milliseconds timer for LEASE & KEY-LEASE, this is to avoid overflow.
209     VerifyOrExit(mMaxKeyLease <= Time::MsecToSec(TimerMilli::kMaxDelay));
210     VerifyOrExit(mMinLease <= mMaxLease);
211     VerifyOrExit(mMinKeyLease <= mMaxKeyLease);
212     VerifyOrExit(mMinLease <= mMinKeyLease);
213     VerifyOrExit(mMaxLease <= mMaxKeyLease);
214 
215     valid = true;
216 
217 exit:
218     return valid;
219 }
220 
GrantLease(uint32_t aLease) const221 uint32_t Server::LeaseConfig::GrantLease(uint32_t aLease) const
222 {
223     OT_ASSERT(mMinLease <= mMaxLease);
224 
225     return (aLease == 0) ? 0 : OT_MAX(mMinLease, OT_MIN(mMaxLease, aLease));
226 }
227 
GrantKeyLease(uint32_t aKeyLease) const228 uint32_t Server::LeaseConfig::GrantKeyLease(uint32_t aKeyLease) const
229 {
230     OT_ASSERT(mMinKeyLease <= mMaxKeyLease);
231 
232     return (aKeyLease == 0) ? 0 : OT_MAX(mMinKeyLease, OT_MIN(mMaxKeyLease, aKeyLease));
233 }
234 
SetLeaseConfig(const LeaseConfig & aLeaseConfig)235 Error Server::SetLeaseConfig(const LeaseConfig &aLeaseConfig)
236 {
237     Error error = kErrorNone;
238 
239     VerifyOrExit(aLeaseConfig.IsValid(), error = kErrorInvalidArgs);
240     mLeaseConfig = aLeaseConfig;
241 
242 exit:
243     return error;
244 }
245 
SetDomain(const char * aDomain)246 Error Server::SetDomain(const char *aDomain)
247 {
248     Error    error = kErrorNone;
249     uint16_t length;
250 
251     VerifyOrExit(mState == kStateDisabled, error = kErrorInvalidState);
252 
253     length = StringLength(aDomain, Dns::Name::kMaxNameSize);
254     VerifyOrExit((length > 0) && (length < Dns::Name::kMaxNameSize), error = kErrorInvalidArgs);
255 
256     if (aDomain[length - 1] == '.')
257     {
258         error = mDomain.Set(aDomain);
259     }
260     else
261     {
262         // Need to append dot at the end
263 
264         char buf[Dns::Name::kMaxNameSize];
265 
266         VerifyOrExit(length < Dns::Name::kMaxNameSize - 1, error = kErrorInvalidArgs);
267 
268         memcpy(buf, aDomain, length);
269         buf[length]     = '.';
270         buf[length + 1] = '\0';
271 
272         error = mDomain.Set(buf);
273     }
274 
275 exit:
276     return error;
277 }
278 
// Iterates over the registered hosts: pass `nullptr` to obtain the
// first host, or a previously returned host to obtain its successor.
const Server::Host *Server::GetNextHost(const Server::Host *aHost)
{
    const Host *next;

    if (aHost != nullptr)
    {
        next = aHost->GetNext();
    }
    else
    {
        next = mHosts.GetHead();
    }

    return next;
}
283 
284 // This method adds a SRP service host and takes ownership of it.
285 // The caller MUST make sure that there is no existing host with the same hostname.
AddHost(Host & aHost)286 void Server::AddHost(Host &aHost)
287 {
288     LogInfo("Add new host %s", aHost.GetFullName());
289 
290     OT_ASSERT(mHosts.FindMatching(aHost.GetFullName()) == nullptr);
291     IgnoreError(mHosts.Add(aHost));
292 }
RemoveHost(Host * aHost,RetainName aRetainName,NotifyMode aNotifyServiceHandler)293 void Server::RemoveHost(Host *aHost, RetainName aRetainName, NotifyMode aNotifyServiceHandler)
294 {
295     VerifyOrExit(aHost != nullptr);
296 
297     aHost->mLease = 0;
298     aHost->ClearResources();
299 
300     if (aRetainName)
301     {
302         LogInfo("Remove host %s (but retain its name)", aHost->GetFullName());
303     }
304     else
305     {
306         aHost->mKeyLease = 0;
307         IgnoreError(mHosts.Remove(*aHost));
308         LogInfo("Fully remove host %s", aHost->GetFullName());
309     }
310 
311     if (aNotifyServiceHandler && mServiceUpdateHandler != nullptr)
312     {
313         uint32_t updateId = AllocateId();
314 
315         LogInfo("SRP update handler is notified (updatedId = %u)", updateId);
316         mServiceUpdateHandler(updateId, aHost, kDefaultEventsHandlerTimeout, mServiceUpdateHandlerContext);
317         // We don't wait for the reply from the service update handler,
318         // but always remove the host (and its services) regardless of
319         // host/service update result. Because removing a host should fail
320         // only when there is system failure of the platform mDNS implementation
321         // and in which case the host is not expected to be still registered.
322     }
323 
324     if (!aRetainName)
325     {
326         aHost->Free();
327     }
328 
329 exit:
330     return;
331 }
332 
HasNameConflictsWith(Host & aHost) const333 bool Server::HasNameConflictsWith(Host &aHost) const
334 {
335     bool        hasConflicts = false;
336     const Host *existingHost = mHosts.FindMatching(aHost.GetFullName());
337 
338     if (existingHost != nullptr && aHost.GetKeyRecord()->GetKey() != existingHost->GetKeyRecord()->GetKey())
339     {
340         LogWarn("Name conflict: host name %s has already been allocated", aHost.GetFullName());
341         ExitNow(hasConflicts = true);
342     }
343 
344     for (const Service &service : aHost.mServices)
345     {
346         // Check on all hosts for a matching service with the same
347         // instance name and if found, verify that it has the same
348         // key.
349 
350         for (const Host &host : mHosts)
351         {
352             if (host.HasServiceInstance(service.GetInstanceName()) &&
353                 aHost.GetKeyRecord()->GetKey() != host.GetKeyRecord()->GetKey())
354             {
355                 LogWarn("Name conflict: service name %s has already been allocated", service.GetInstanceName());
356                 ExitNow(hasConflicts = true);
357             }
358         }
359     }
360 
361 exit:
362     return hasConflicts;
363 }
364 
HandleServiceUpdateResult(ServiceUpdateId aId,Error aError)365 void Server::HandleServiceUpdateResult(ServiceUpdateId aId, Error aError)
366 {
367     UpdateMetadata *update = mOutstandingUpdates.FindMatching(aId);
368 
369     if (update != nullptr)
370     {
371         HandleServiceUpdateResult(update, aError);
372     }
373     else
374     {
375         LogInfo("Delayed SRP host update result, the SRP update has been committed (updateId = %u)", aId);
376     }
377 }
378 
HandleServiceUpdateResult(UpdateMetadata * aUpdate,Error aError)379 void Server::HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError)
380 {
381     LogInfo("Handler result of SRP update (id = %u) is received: %s", aUpdate->GetId(), ErrorToString(aError));
382 
383     IgnoreError(mOutstandingUpdates.Remove(*aUpdate));
384     CommitSrpUpdate(aError, *aUpdate);
385     aUpdate->Free();
386 
387     if (mOutstandingUpdates.IsEmpty())
388     {
389         mOutstandingUpdatesTimer.Stop();
390     }
391     else
392     {
393         mOutstandingUpdatesTimer.FireAt(mOutstandingUpdates.GetTail()->GetExpireTime());
394     }
395 }
396 
CommitSrpUpdate(Error aError,Host & aHost,const MessageMetadata & aMessageMetadata)397 void Server::CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &aMessageMetadata)
398 {
399     CommitSrpUpdate(aError, aHost, aMessageMetadata.mDnsHeader, aMessageMetadata.mMessageInfo,
400                     aMessageMetadata.mTtlConfig, aMessageMetadata.mLeaseConfig);
401 }
402 
CommitSrpUpdate(Error aError,UpdateMetadata & aUpdateMetadata)403 void Server::CommitSrpUpdate(Error aError, UpdateMetadata &aUpdateMetadata)
404 {
405     CommitSrpUpdate(aError, aUpdateMetadata.GetHost(), aUpdateMetadata.GetDnsHeader(),
406                     aUpdateMetadata.IsDirectRxFromClient() ? &aUpdateMetadata.GetMessageInfo() : nullptr,
407                     aUpdateMetadata.GetTtlConfig(), aUpdateMetadata.GetLeaseConfig());
408 }
409 
CommitSrpUpdate(Error aError,Host & aHost,const Dns::UpdateHeader & aDnsHeader,const Ip6::MessageInfo * aMessageInfo,const TtlConfig & aTtlConfig,const LeaseConfig & aLeaseConfig)410 void Server::CommitSrpUpdate(Error                    aError,
411                              Host &                   aHost,
412                              const Dns::UpdateHeader &aDnsHeader,
413                              const Ip6::MessageInfo * aMessageInfo,
414                              const TtlConfig &        aTtlConfig,
415                              const LeaseConfig &      aLeaseConfig)
416 {
417     Host *   existingHost;
418     uint32_t hostLease;
419     uint32_t hostKeyLease;
420     uint32_t grantedLease;
421     uint32_t grantedKeyLease;
422     uint32_t grantedTtl;
423     bool     shouldFreeHost = true;
424 
425     SuccessOrExit(aError);
426 
427     hostLease       = aHost.GetLease();
428     hostKeyLease    = aHost.GetKeyLease();
429     grantedLease    = aLeaseConfig.GrantLease(hostLease);
430     grantedKeyLease = aLeaseConfig.GrantKeyLease(hostKeyLease);
431     grantedTtl      = aTtlConfig.GrantTtl(grantedLease, aHost.GetTtl());
432 
433     aHost.SetLease(grantedLease);
434     aHost.SetKeyLease(grantedKeyLease);
435     aHost.SetTtl(grantedTtl);
436 
437     for (Service &service : aHost.mServices)
438     {
439         service.mDescription->mLease    = grantedLease;
440         service.mDescription->mKeyLease = grantedKeyLease;
441         service.mDescription->mTtl      = grantedTtl;
442     }
443 
444     existingHost = mHosts.FindMatching(aHost.GetFullName());
445 
446     if (aHost.GetLease() == 0)
447     {
448         if (aHost.GetKeyLease() == 0)
449         {
450             LogInfo("Remove key of host %s", aHost.GetFullName());
451             RemoveHost(existingHost, kDeleteName, kDoNotNotifyServiceHandler);
452         }
453         else if (existingHost != nullptr)
454         {
455             existingHost->SetKeyLease(aHost.GetKeyLease());
456             RemoveHost(existingHost, kRetainName, kDoNotNotifyServiceHandler);
457 
458             for (Service &service : existingHost->mServices)
459             {
460                 existingHost->RemoveService(&service, kRetainName, kDoNotNotifyServiceHandler);
461             }
462         }
463     }
464     else if (existingHost != nullptr)
465     {
466         SuccessOrExit(aError = existingHost->MergeServicesAndResourcesFrom(aHost));
467     }
468     else
469     {
470         AddHost(aHost);
471         shouldFreeHost = false;
472 
473         for (Service &service : aHost.GetServices())
474         {
475             service.mIsCommitted = true;
476             service.Log(Service::kAddNew);
477         }
478 
479 #if OPENTHREAD_CONFIG_SRP_SERVER_PORT_SWITCH_ENABLE
480         if (!mHasRegisteredAnyService && (mAddressMode == kAddressModeUnicast))
481         {
482             Settings::SrpServerInfo info;
483 
484             mHasRegisteredAnyService = true;
485             info.SetPort(GetSocket().mSockName.mPort);
486             IgnoreError(Get<Settings>().Save(info));
487         }
488 #endif
489     }
490 
491     // Re-schedule the lease timer.
492     HandleLeaseTimer();
493 
494 exit:
495     if (aMessageInfo != nullptr)
496     {
497         if (aError == kErrorNone && !(grantedLease == hostLease && grantedKeyLease == hostKeyLease))
498         {
499             SendResponse(aDnsHeader, grantedLease, grantedKeyLease, *aMessageInfo);
500         }
501         else
502         {
503             SendResponse(aDnsHeader, ErrorToDnsResponseCode(aError), *aMessageInfo);
504         }
505     }
506 
507     if (shouldFreeHost)
508     {
509         aHost.Free();
510     }
511 }
512 
SelectPort(void)513 void Server::SelectPort(void)
514 {
515     mPort = kUdpPortMin;
516 
517 #if OPENTHREAD_CONFIG_SRP_SERVER_PORT_SWITCH_ENABLE
518     {
519         Settings::SrpServerInfo info;
520 
521         if (Get<Settings>().Read(info) == kErrorNone)
522         {
523             mPort = info.GetPort() + 1;
524             if (mPort < kUdpPortMin || mPort > kUdpPortMax)
525             {
526                 mPort = kUdpPortMin;
527             }
528         }
529     }
530 #endif
531 
532     LogInfo("Selected port %u", mPort);
533 }
534 
Start(void)535 void Server::Start(void)
536 {
537     VerifyOrExit(mState == kStateStopped);
538 
539     mState = kStateRunning;
540     PrepareSocket();
541     LogInfo("Start listening on port %u", mPort);
542 
543 exit:
544     return;
545 }
546 
PrepareSocket(void)547 void Server::PrepareSocket(void)
548 {
549     Error error = kErrorNone;
550 
551 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
552     Ip6::Udp::Socket &dnsSocket = Get<Dns::ServiceDiscovery::Server>().mSocket;
553 
554     if (dnsSocket.GetSockName().GetPort() == mPort)
555     {
556         // If the DNS-SD socket matches our port number, we use the
557         // same socket so we close our own socket (in case it was
558         // open). `GetSocket()` will now return the DNS-SD socket.
559 
560         IgnoreError(mSocket.Close());
561         ExitNow();
562     }
563 #endif
564 
565     VerifyOrExit(!mSocket.IsOpen());
566     SuccessOrExit(error = mSocket.Open(HandleUdpReceive, this));
567     error = mSocket.Bind(mPort, OT_NETIF_THREAD);
568 
569 exit:
570     if (error != kErrorNone)
571     {
572         LogCrit("Failed to prepare socket: %s", ErrorToString(error));
573         Stop();
574     }
575 }
576 
GetSocket(void)577 Ip6::Udp::Socket &Server::GetSocket(void)
578 {
579     Ip6::Udp::Socket *socket = &mSocket;
580 
581 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
582     Ip6::Udp::Socket &dnsSocket = Get<Dns::ServiceDiscovery::Server>().mSocket;
583 
584     if (dnsSocket.GetSockName().GetPort() == mPort)
585     {
586         socket = &dnsSocket;
587     }
588 #endif
589 
590     return *socket;
591 }
592 
593 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
594 
HandleDnssdServerStateChange(void)595 void Server::HandleDnssdServerStateChange(void)
596 {
597     // This is called from` Dns::ServiceDiscovery::Server` to notify
598     // that it has started or stopped. We check whether we need to
599     // share the socket.
600 
601     if (mState == kStateRunning)
602     {
603         PrepareSocket();
604     }
605 }
606 
HandleDnssdServerUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)607 Error Server::HandleDnssdServerUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
608 {
609     // This is called from` Dns::ServiceDiscovery::Server` when a UDP
610     // message is received on its socket. We check whether we are
611     // sharing socket and if so we process the received message. We
612     // return `kErrorNone` to indicate that message was successfully
613     // processed by `Srp::Server`, otherwise `kErrorDrop` is returned.
614 
615     Error error = kErrorDrop;
616 
617     VerifyOrExit((mState == kStateRunning) && !mSocket.IsOpen());
618 
619     error = ProcessMessage(aMessage, aMessageInfo);
620 
621 exit:
622     return error;
623 }
624 
625 #endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
626 
Stop(void)627 void Server::Stop(void)
628 {
629     VerifyOrExit(mState == kStateRunning);
630 
631     mState = kStateStopped;
632 
633     while (!mHosts.IsEmpty())
634     {
635         RemoveHost(mHosts.GetHead(), kDeleteName, kNotifyServiceHandler);
636     }
637 
638     // TODO: We should cancel any outstanding service updates, but current
639     // OTBR mDNS publisher cannot properly handle it.
640     while (!mOutstandingUpdates.IsEmpty())
641     {
642         mOutstandingUpdates.Pop()->Free();
643     }
644 
645     mLeaseTimer.Stop();
646     mOutstandingUpdatesTimer.Stop();
647 
648     LogInfo("Stop listening on %u", mPort);
649     IgnoreError(mSocket.Close());
650     mHasRegisteredAnyService = false;
651 
652 exit:
653     return;
654 }
655 
HandleNetDataPublisherEvent(NetworkData::Publisher::Event aEvent)656 void Server::HandleNetDataPublisherEvent(NetworkData::Publisher::Event aEvent)
657 {
658     switch (aEvent)
659     {
660     case NetworkData::Publisher::kEventEntryAdded:
661         Start();
662         break;
663 
664     case NetworkData::Publisher::kEventEntryRemoved:
665         Stop();
666         break;
667     }
668 }
669 
FindOutstandingUpdate(const MessageMetadata & aMessageMetadata) const670 const Server::UpdateMetadata *Server::FindOutstandingUpdate(const MessageMetadata &aMessageMetadata) const
671 {
672     const UpdateMetadata *ret = nullptr;
673 
674     VerifyOrExit(aMessageMetadata.IsDirectRxFromClient());
675 
676     for (const UpdateMetadata &update : mOutstandingUpdates)
677     {
678         if (aMessageMetadata.mDnsHeader.GetMessageId() == update.GetDnsHeader().GetMessageId() &&
679             aMessageMetadata.mMessageInfo->GetPeerAddr() == update.GetMessageInfo().GetPeerAddr() &&
680             aMessageMetadata.mMessageInfo->GetPeerPort() == update.GetMessageInfo().GetPeerPort())
681         {
682             ExitNow(ret = &update);
683         }
684     }
685 
686 exit:
687     return ret;
688 }
689 
ProcessDnsUpdate(Message & aMessage,MessageMetadata & aMetadata)690 void Server::ProcessDnsUpdate(Message &aMessage, MessageMetadata &aMetadata)
691 {
692     Error error = kErrorNone;
693     Host *host  = nullptr;
694 
695     LogInfo("Received DNS update from %s", aMetadata.IsDirectRxFromClient()
696                                                ? aMetadata.mMessageInfo->GetPeerAddr().ToString().AsCString()
697                                                : "an SRPL Partner");
698 
699     SuccessOrExit(error = ProcessZoneSection(aMessage, aMetadata));
700 
701     if (FindOutstandingUpdate(aMetadata) != nullptr)
702     {
703         LogInfo("Drop duplicated SRP update request: MessageId=%hu", aMetadata.mDnsHeader.GetMessageId());
704 
705         // Silently drop duplicate requests.
706         // This could rarely happen, because the outstanding SRP update timer should
707         // be shorter than the SRP update retransmission timer.
708         ExitNow(error = kErrorNone);
709     }
710 
711     // Per 2.3.2 of SRP draft 6, no prerequisites should be included in a SRP update.
712     VerifyOrExit(aMetadata.mDnsHeader.GetPrerequisiteRecordCount() == 0, error = kErrorFailed);
713 
714     host = Host::Allocate(GetInstance(), aMetadata.mRxTime);
715     VerifyOrExit(host != nullptr, error = kErrorNoBufs);
716     SuccessOrExit(error = ProcessUpdateSection(*host, aMessage, aMetadata));
717 
718     // Parse lease time and validate signature.
719     SuccessOrExit(error = ProcessAdditionalSection(host, aMessage, aMetadata));
720 
721     SuccessOrExit(error = ValidateServiceSubTypes(*host, aMetadata));
722 
723     HandleUpdate(*host, aMetadata);
724 
725 exit:
726     if (error != kErrorNone)
727     {
728         if (host != nullptr)
729         {
730             host->Free();
731         }
732 
733         if (aMetadata.IsDirectRxFromClient())
734         {
735             SendResponse(aMetadata.mDnsHeader, ErrorToDnsResponseCode(error), *aMetadata.mMessageInfo);
736         }
737     }
738 }
739 
ProcessZoneSection(const Message & aMessage,MessageMetadata & aMetadata) const740 Error Server::ProcessZoneSection(const Message &aMessage, MessageMetadata &aMetadata) const
741 {
742     Error    error = kErrorNone;
743     char     name[Dns::Name::kMaxNameSize];
744     uint16_t offset = aMetadata.mOffset;
745 
746     VerifyOrExit(aMetadata.mDnsHeader.GetZoneRecordCount() == 1, error = kErrorParse);
747 
748     SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name)));
749     // TODO: return `Dns::kResponseNotAuth` for not authorized zone names.
750     VerifyOrExit(StringMatch(name, GetDomain(), kStringCaseInsensitiveMatch), error = kErrorSecurity);
751     SuccessOrExit(error = aMessage.Read(offset, aMetadata.mDnsZone));
752     offset += sizeof(Dns::Zone);
753 
754     VerifyOrExit(aMetadata.mDnsZone.GetType() == Dns::ResourceRecord::kTypeSoa, error = kErrorParse);
755     aMetadata.mOffset = offset;
756 
757 exit:
758     if (error != kErrorNone)
759     {
760         LogWarn("Failed to process DNS Zone section: %s", ErrorToString(error));
761     }
762 
763     return error;
764 }
765 
ProcessUpdateSection(Host & aHost,const Message & aMessage,MessageMetadata & aMetadata) const766 Error Server::ProcessUpdateSection(Host &aHost, const Message &aMessage, MessageMetadata &aMetadata) const
767 {
768     Error error = kErrorNone;
769 
770     // Process Service Discovery, Host and Service Description Instructions with
771     // 3 times iterations over all DNS update RRs. The order of those processes matters.
772 
773     // 0. Enumerate over all Service Discovery Instructions before processing any other records.
774     // So that we will know whether a name is a hostname or service instance name when processing
775     // a "Delete All RRsets from a name" record.
776     SuccessOrExit(error = ProcessServiceDiscoveryInstructions(aHost, aMessage, aMetadata));
777 
778     // 1. Enumerate over all RRs to build the Host Description Instruction.
779     SuccessOrExit(error = ProcessHostDescriptionInstruction(aHost, aMessage, aMetadata));
780 
781     // 2. Enumerate over all RRs to build the Service Description Instructions.
782     SuccessOrExit(error = ProcessServiceDescriptionInstructions(aHost, aMessage, aMetadata));
783 
784     // 3. Verify that there are no name conflicts.
785     VerifyOrExit(!HasNameConflictsWith(aHost), error = kErrorDuplicated);
786 
787 exit:
788     if (error != kErrorNone)
789     {
790         LogWarn("Failed to process DNS Update section: %s", ErrorToString(error));
791     }
792 
793     return error;
794 }
795 
ProcessHostDescriptionInstruction(Host & aHost,const Message & aMessage,const MessageMetadata & aMetadata) const796 Error Server::ProcessHostDescriptionInstruction(Host &                 aHost,
797                                                 const Message &        aMessage,
798                                                 const MessageMetadata &aMetadata) const
799 {
800     Error    error;
801     uint16_t offset = aMetadata.mOffset;
802 
803     OT_ASSERT(aHost.GetFullName() == nullptr);
804 
805     for (uint16_t numRecords = aMetadata.mDnsHeader.GetUpdateRecordCount(); numRecords > 0; numRecords--)
806     {
807         char                name[Dns::Name::kMaxNameSize];
808         Dns::ResourceRecord record;
809 
810         SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name)));
811 
812         SuccessOrExit(error = aMessage.Read(offset, record));
813 
814         if (record.GetClass() == Dns::ResourceRecord::kClassAny)
815         {
816             // Delete All RRsets from a name.
817             VerifyOrExit(IsValidDeleteAllRecord(record), error = kErrorFailed);
818 
819             // A "Delete All RRsets from a name" RR can only apply to a Service or Host Description.
820 
821             if (!aHost.HasServiceInstance(name))
822             {
823                 // If host name is already set to a different name, `SetFullName()`
824                 // will return `kErrorFailed`.
825                 SuccessOrExit(error = aHost.SetFullName(name));
826                 aHost.ClearResources();
827             }
828         }
829         else if (record.GetType() == Dns::ResourceRecord::kTypeAaaa)
830         {
831             Dns::AaaaRecord aaaaRecord;
832 
833             VerifyOrExit(record.GetClass() == aMetadata.mDnsZone.GetClass(), error = kErrorFailed);
834 
835             SuccessOrExit(error = aHost.ProcessTtl(record.GetTtl()));
836 
837             SuccessOrExit(error = aHost.SetFullName(name));
838 
839             SuccessOrExit(error = aMessage.Read(offset, aaaaRecord));
840             VerifyOrExit(aaaaRecord.IsValid(), error = kErrorParse);
841 
842             // Tolerate kErrorDrop for AAAA Resources.
843             VerifyOrExit(aHost.AddIp6Address(aaaaRecord.GetAddress()) != kErrorNoBufs, error = kErrorNoBufs);
844         }
845         else if (record.GetType() == Dns::ResourceRecord::kTypeKey)
846         {
847             // We currently support only ECDSA P-256.
848             Dns::Ecdsa256KeyRecord keyRecord;
849 
850             VerifyOrExit(record.GetClass() == aMetadata.mDnsZone.GetClass(), error = kErrorFailed);
851 
852             SuccessOrExit(error = aHost.ProcessTtl(record.GetTtl()));
853 
854             SuccessOrExit(error = aMessage.Read(offset, keyRecord));
855             VerifyOrExit(keyRecord.IsValid(), error = kErrorParse);
856 
857             VerifyOrExit(aHost.GetKeyRecord() == nullptr || *aHost.GetKeyRecord() == keyRecord, error = kErrorSecurity);
858             aHost.SetKeyRecord(keyRecord);
859         }
860 
861         offset += record.GetSize();
862     }
863 
864     // Verify that we have a complete Host Description Instruction.
865 
866     VerifyOrExit(aHost.GetFullName() != nullptr, error = kErrorFailed);
867     VerifyOrExit(aHost.GetKeyRecord() != nullptr, error = kErrorFailed);
868 
869     // We check the number of host addresses after processing of the
870     // Lease Option in the Additional Section and determining whether
871     // the host is being removed or registered.
872 
873 exit:
874     if (error != kErrorNone)
875     {
876         LogWarn("Failed to process Host Description instructions: %s", ErrorToString(error));
877     }
878 
879     return error;
880 }
881 
ProcessServiceDiscoveryInstructions(Host & aHost,const Message & aMessage,const MessageMetadata & aMetadata) const882 Error Server::ProcessServiceDiscoveryInstructions(Host &                 aHost,
883                                                   const Message &        aMessage,
884                                                   const MessageMetadata &aMetadata) const
885 {
886     Error    error  = kErrorNone;
887     uint16_t offset = aMetadata.mOffset;
888 
889     for (uint16_t numRecords = aMetadata.mDnsHeader.GetUpdateRecordCount(); numRecords > 0; numRecords--)
890     {
891         char           serviceName[Dns::Name::kMaxNameSize];
892         char           instanceName[Dns::Name::kMaxNameSize];
893         Dns::PtrRecord ptrRecord;
894         const char *   subServiceName;
895         Service *      service;
896         bool           isSubType;
897 
898         SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, serviceName, sizeof(serviceName)));
899         VerifyOrExit(Dns::Name::IsSubDomainOf(serviceName, GetDomain()), error = kErrorSecurity);
900 
901         error = Dns::ResourceRecord::ReadRecord(aMessage, offset, ptrRecord);
902 
903         if (error == kErrorNotFound)
904         {
905             // `ReadRecord()` updates `aOffset` to skip over a
906             // non-matching record.
907             error = kErrorNone;
908             continue;
909         }
910 
911         SuccessOrExit(error);
912 
913         SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, instanceName, sizeof(instanceName)));
914 
915         VerifyOrExit(ptrRecord.GetClass() == Dns::ResourceRecord::kClassNone ||
916                          ptrRecord.GetClass() == aMetadata.mDnsZone.GetClass(),
917                      error = kErrorFailed);
918 
919         // Check if the `serviceName` is a subtype with the name
920         // format: "<sub-label>._sub.<service-labels>.<domain>."
921 
922         subServiceName = StringFind(serviceName, kServiceSubTypeLabel, kStringCaseInsensitiveMatch);
923         isSubType      = (subServiceName != nullptr);
924 
925         if (isSubType)
926         {
927             // Skip over the "._sub." label to get to the base
928             // service name.
929             subServiceName += sizeof(kServiceSubTypeLabel) - 1;
930         }
931 
932         // Verify that instance name and service name are related.
933 
934         VerifyOrExit(
935             StringEndsWith(instanceName, isSubType ? subServiceName : serviceName, kStringCaseInsensitiveMatch),
936             error = kErrorFailed);
937 
938         // Ensure the same service does not exist already.
939         VerifyOrExit(aHost.FindService(serviceName, instanceName) == nullptr, error = kErrorFailed);
940 
941         service = aHost.AddNewService(serviceName, instanceName, isSubType, aMetadata.mRxTime);
942         VerifyOrExit(service != nullptr, error = kErrorNoBufs);
943 
944         // This RR is a "Delete an RR from an RRset" update when the CLASS is NONE.
945         service->mIsDeleted = (ptrRecord.GetClass() == Dns::ResourceRecord::kClassNone);
946 
947         if (!service->mIsDeleted)
948         {
949             SuccessOrExit(error = aHost.ProcessTtl(ptrRecord.GetTtl()));
950         }
951     }
952 
953 exit:
954     if (error != kErrorNone)
955     {
956         LogWarn("Failed to process Service Discovery instructions: %s", ErrorToString(error));
957     }
958 
959     return error;
960 }
961 
ProcessServiceDescriptionInstructions(Host & aHost,const Message & aMessage,MessageMetadata & aMetadata) const962 Error Server::ProcessServiceDescriptionInstructions(Host &           aHost,
963                                                     const Message &  aMessage,
964                                                     MessageMetadata &aMetadata) const
965 {
966     Error    error  = kErrorNone;
967     uint16_t offset = aMetadata.mOffset;
968 
969     for (uint16_t numRecords = aMetadata.mDnsHeader.GetUpdateRecordCount(); numRecords > 0; numRecords--)
970     {
971         RetainPtr<Service::Description> desc;
972         char                            name[Dns::Name::kMaxNameSize];
973         Dns::ResourceRecord             record;
974 
975         SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name)));
976         SuccessOrExit(error = aMessage.Read(offset, record));
977 
978         if (record.GetClass() == Dns::ResourceRecord::kClassAny)
979         {
980             // Delete All RRsets from a name.
981             VerifyOrExit(IsValidDeleteAllRecord(record), error = kErrorFailed);
982 
983             desc = aHost.FindServiceDescription(name);
984 
985             if (desc != nullptr)
986             {
987                 desc->ClearResources();
988                 desc->mUpdateTime = aMetadata.mRxTime;
989             }
990 
991             offset += record.GetSize();
992             continue;
993         }
994 
995         if (record.GetType() == Dns::ResourceRecord::kTypeSrv)
996         {
997             Dns::SrvRecord srvRecord;
998             char           hostName[Dns::Name::kMaxNameSize];
999             uint16_t       hostNameLength = sizeof(hostName);
1000 
1001             VerifyOrExit(record.GetClass() == aMetadata.mDnsZone.GetClass(), error = kErrorFailed);
1002 
1003             SuccessOrExit(error = aHost.ProcessTtl(record.GetTtl()));
1004 
1005             SuccessOrExit(error = aMessage.Read(offset, srvRecord));
1006             offset += sizeof(srvRecord);
1007 
1008             SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, hostName, hostNameLength));
1009             VerifyOrExit(Dns::Name::IsSubDomainOf(name, GetDomain()), error = kErrorSecurity);
1010             VerifyOrExit(aHost.Matches(hostName), error = kErrorFailed);
1011 
1012             desc = aHost.FindServiceDescription(name);
1013             VerifyOrExit(desc != nullptr, error = kErrorFailed);
1014 
1015             // Make sure that this is the first SRV RR for this service description
1016             VerifyOrExit(desc->mPort == 0, error = kErrorFailed);
1017             desc->mTtl        = srvRecord.GetTtl();
1018             desc->mPriority   = srvRecord.GetPriority();
1019             desc->mWeight     = srvRecord.GetWeight();
1020             desc->mPort       = srvRecord.GetPort();
1021             desc->mUpdateTime = aMetadata.mRxTime;
1022         }
1023         else if (record.GetType() == Dns::ResourceRecord::kTypeTxt)
1024         {
1025             VerifyOrExit(record.GetClass() == aMetadata.mDnsZone.GetClass(), error = kErrorFailed);
1026 
1027             SuccessOrExit(error = aHost.ProcessTtl(record.GetTtl()));
1028 
1029             desc = aHost.FindServiceDescription(name);
1030             VerifyOrExit(desc != nullptr, error = kErrorFailed);
1031 
1032             offset += sizeof(record);
1033             SuccessOrExit(error = desc->SetTxtDataFromMessage(aMessage, offset, record.GetLength()));
1034             offset += record.GetLength();
1035         }
1036         else
1037         {
1038             offset += record.GetSize();
1039         }
1040     }
1041 
1042     // Verify that all service descriptions on `aHost` are updated. Note
1043     // that `mUpdateTime` on a new `Service::Description` is set to
1044     // `GetNow().GetDistantPast()`.
1045 
1046     for (Service &service : aHost.mServices)
1047     {
1048         VerifyOrExit(service.mDescription->mUpdateTime == aMetadata.mRxTime, error = kErrorFailed);
1049 
1050         // Check that either both `mPort` and `mTxtData` are set
1051         // (i.e., we saw both SRV and TXT record) or both are default
1052         // (cleared) value (i.e., we saw neither of them).
1053 
1054         VerifyOrExit((service.mDescription->mPort == 0) == service.mDescription->mTxtData.IsNull(),
1055                      error = kErrorFailed);
1056     }
1057 
1058     aMetadata.mOffset = offset;
1059 
1060 exit:
1061     if (error != kErrorNone)
1062     {
1063         LogWarn("Failed to process Service Description instructions: %s", ErrorToString(error));
1064     }
1065 
1066     return error;
1067 }
1068 
IsValidDeleteAllRecord(const Dns::ResourceRecord & aRecord)1069 bool Server::IsValidDeleteAllRecord(const Dns::ResourceRecord &aRecord)
1070 {
1071     return aRecord.GetClass() == Dns::ResourceRecord::kClassAny && aRecord.GetType() == Dns::ResourceRecord::kTypeAny &&
1072            aRecord.GetTtl() == 0 && aRecord.GetLength() == 0;
1073 }
1074 
ProcessAdditionalSection(Host * aHost,const Message & aMessage,MessageMetadata & aMetadata) const1075 Error Server::ProcessAdditionalSection(Host *aHost, const Message &aMessage, MessageMetadata &aMetadata) const
1076 {
1077     Error            error = kErrorNone;
1078     Dns::OptRecord   optRecord;
1079     Dns::LeaseOption leaseOption;
1080     Dns::SigRecord   sigRecord;
1081     char             name[2]; // The root domain name (".") is expected.
1082     uint16_t         offset = aMetadata.mOffset;
1083     uint16_t         sigOffset;
1084     uint16_t         sigRdataOffset;
1085     char             signerName[Dns::Name::kMaxNameSize];
1086     uint16_t         signatureLength;
1087 
1088     VerifyOrExit(aMetadata.mDnsHeader.GetAdditionalRecordCount() == 2, error = kErrorFailed);
1089 
1090     // EDNS(0) Update Lease Option.
1091 
1092     SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name)));
1093     SuccessOrExit(error = aMessage.Read(offset, optRecord));
1094     SuccessOrExit(error = aMessage.Read(offset + sizeof(optRecord), leaseOption));
1095     VerifyOrExit(leaseOption.IsValid(), error = kErrorFailed);
1096     VerifyOrExit(optRecord.GetSize() == sizeof(optRecord) + sizeof(leaseOption), error = kErrorParse);
1097 
1098     offset += optRecord.GetSize();
1099 
1100     aHost->SetLease(leaseOption.GetLeaseInterval());
1101     aHost->SetKeyLease(leaseOption.GetKeyLeaseInterval());
1102 
1103     if (aHost->GetLease() > 0)
1104     {
1105         uint8_t hostAddressesNum;
1106 
1107         aHost->GetAddresses(hostAddressesNum);
1108 
1109         // There MUST be at least one valid address if we have nonzero lease.
1110         VerifyOrExit(hostAddressesNum > 0, error = kErrorFailed);
1111     }
1112 
1113     // SIG(0).
1114 
1115     sigOffset = offset;
1116     SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, name, sizeof(name)));
1117     SuccessOrExit(error = aMessage.Read(offset, sigRecord));
1118     VerifyOrExit(sigRecord.IsValid(), error = kErrorParse);
1119 
1120     sigRdataOffset = offset + sizeof(Dns::ResourceRecord);
1121     offset += sizeof(sigRecord);
1122 
1123     // TODO: Verify that the signature doesn't expire. This is not
1124     // implemented because the end device may not be able to get
1125     // the synchronized date/time.
1126 
1127     SuccessOrExit(error = Dns::Name::ReadName(aMessage, offset, signerName, sizeof(signerName)));
1128 
1129     signatureLength = sigRecord.GetLength() - (offset - sigRdataOffset);
1130     offset += signatureLength;
1131 
1132     // Verify the signature. Currently supports only ECDSA.
1133 
1134     VerifyOrExit(sigRecord.GetAlgorithm() == Dns::KeyRecord::kAlgorithmEcdsaP256Sha256, error = kErrorFailed);
1135     VerifyOrExit(sigRecord.GetTypeCovered() == 0, error = kErrorFailed);
1136     VerifyOrExit(signatureLength == Crypto::Ecdsa::P256::Signature::kSize, error = kErrorParse);
1137 
1138     SuccessOrExit(error = VerifySignature(*aHost->GetKeyRecord(), aMessage, aMetadata.mDnsHeader, sigOffset,
1139                                           sigRdataOffset, sigRecord.GetLength(), signerName));
1140 
1141     aMetadata.mOffset = offset;
1142 
1143 exit:
1144     if (error != kErrorNone)
1145     {
1146         LogWarn("Failed to process DNS Additional section: %s", ErrorToString(error));
1147     }
1148 
1149     return error;
1150 }
1151 
VerifySignature(const Dns::Ecdsa256KeyRecord & aKeyRecord,const Message & aMessage,Dns::UpdateHeader aDnsHeader,uint16_t aSigOffset,uint16_t aSigRdataOffset,uint16_t aSigRdataLength,const char * aSignerName) const1152 Error Server::VerifySignature(const Dns::Ecdsa256KeyRecord &aKeyRecord,
1153                               const Message &               aMessage,
1154                               Dns::UpdateHeader             aDnsHeader,
1155                               uint16_t                      aSigOffset,
1156                               uint16_t                      aSigRdataOffset,
1157                               uint16_t                      aSigRdataLength,
1158                               const char *                  aSignerName) const
1159 {
1160     Error                          error;
1161     uint16_t                       offset = aMessage.GetOffset();
1162     uint16_t                       signatureOffset;
1163     Crypto::Sha256                 sha256;
1164     Crypto::Sha256::Hash           hash;
1165     Crypto::Ecdsa::P256::Signature signature;
1166     Message *                      signerNameMessage = nullptr;
1167 
1168     VerifyOrExit(aSigRdataLength >= Crypto::Ecdsa::P256::Signature::kSize, error = kErrorInvalidArgs);
1169 
1170     sha256.Start();
1171 
1172     // SIG RDATA less signature.
1173     sha256.Update(aMessage, aSigRdataOffset, sizeof(Dns::SigRecord) - sizeof(Dns::ResourceRecord));
1174 
1175     // The uncompressed (canonical) form of the signer name should be used for signature
1176     // verification. See https://tools.ietf.org/html/rfc2931#section-3.1 for details.
1177     signerNameMessage = Get<Ip6::Udp>().NewMessage(0);
1178     VerifyOrExit(signerNameMessage != nullptr, error = kErrorNoBufs);
1179     SuccessOrExit(error = Dns::Name::AppendName(aSignerName, *signerNameMessage));
1180     sha256.Update(*signerNameMessage, signerNameMessage->GetOffset(), signerNameMessage->GetLength());
1181 
1182     // We need the DNS header before appending the SIG RR.
1183     aDnsHeader.SetAdditionalRecordCount(aDnsHeader.GetAdditionalRecordCount() - 1);
1184     sha256.Update(aDnsHeader);
1185     sha256.Update(aMessage, offset + sizeof(aDnsHeader), aSigOffset - offset - sizeof(aDnsHeader));
1186 
1187     sha256.Finish(hash);
1188 
1189     signatureOffset = aSigRdataOffset + aSigRdataLength - Crypto::Ecdsa::P256::Signature::kSize;
1190     SuccessOrExit(error = aMessage.Read(signatureOffset, signature));
1191 
1192     error = aKeyRecord.GetKey().Verify(hash, signature);
1193 
1194 exit:
1195     if (error != kErrorNone)
1196     {
1197         LogWarn("Failed to verify message signature: %s", ErrorToString(error));
1198     }
1199 
1200     FreeMessage(signerNameMessage);
1201     return error;
1202 }
1203 
ValidateServiceSubTypes(Host & aHost,const MessageMetadata & aMetadata)1204 Error Server::ValidateServiceSubTypes(Host &aHost, const MessageMetadata &aMetadata)
1205 {
1206     Error error = kErrorNone;
1207     Host *existingHost;
1208 
1209     // Verify that there is a matching base type service for all
1210     // sub-type services in `aHost` (which is from the received
1211     // and parsed SRP Update message).
1212 
1213     for (const Service &service : aHost.GetServices())
1214     {
1215         if (service.IsSubType() && (aHost.FindBaseService(service.GetInstanceName()) == nullptr))
1216         {
1217 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_WARN)
1218             char subLabel[Dns::Name::kMaxLabelSize];
1219 
1220             IgnoreError(service.GetServiceSubTypeLabel(subLabel, sizeof(subLabel)));
1221             LogWarn("Message contains instance %s with subtype %s without base type", service.GetInstanceName(),
1222                     subLabel);
1223 #endif
1224 
1225             ExitNow(error = kErrorParse);
1226         }
1227     }
1228 
1229     // SRP server must treat the update instructions for a service type
1230     // and all its sub-types as atomic, i.e., when a service and its
1231     // sub-types are being updated, whatever information appears in the
1232     // SRP Update is the entirety of information about that service and
1233     // its sub-types. Any previously registered sub-type that does not
1234     // appear in a new SRP Update, must be removed.
1235     //
1236     // We go though the list of registered services for the same host
1237     // and if the base service is included in the new SRP Update
1238     // message, we add any previously registered service sub-type that
1239     // does not appear in new Update message as "deleted".
1240 
1241     existingHost = mHosts.FindMatching(aHost.GetFullName());
1242     VerifyOrExit(existingHost != nullptr);
1243 
1244     for (const Service &baseService : existingHost->GetServices())
1245     {
1246         if (baseService.IsSubType() || (aHost.FindBaseService(baseService.GetInstanceName()) == nullptr))
1247         {
1248             continue;
1249         }
1250 
1251         for (const Service &subService : existingHost->GetServices())
1252         {
1253             if (!subService.IsSubType() || !subService.MatchesInstanceName(baseService.GetInstanceName()))
1254             {
1255                 continue;
1256             }
1257 
1258             SuccessOrExit(error = aHost.AddCopyOfServiceAsDeletedIfNotPresent(subService, aMetadata.mRxTime));
1259         }
1260     }
1261 
1262 exit:
1263     return error;
1264 }
1265 
HandleUpdate(Host & aHost,const MessageMetadata & aMetadata)1266 void Server::HandleUpdate(Host &aHost, const MessageMetadata &aMetadata)
1267 {
1268     Error error = kErrorNone;
1269     Host *existingHost;
1270 
1271     // Check whether the SRP update wants to remove `aHost`.
1272 
1273     VerifyOrExit(aHost.GetLease() == 0);
1274 
1275     aHost.ClearResources();
1276 
1277     existingHost = mHosts.FindMatching(aHost.GetFullName());
1278     VerifyOrExit(existingHost != nullptr);
1279 
1280     // The client may not include all services it has registered before
1281     // when removing a host. We copy and append any missing services to
1282     // `aHost` from the `existingHost` and mark them as deleted.
1283 
1284     for (Service &service : existingHost->mServices)
1285     {
1286         if (service.mIsDeleted)
1287         {
1288             continue;
1289         }
1290 
1291         SuccessOrExit(error = aHost.AddCopyOfServiceAsDeletedIfNotPresent(service, aMetadata.mRxTime));
1292     }
1293 
1294 exit:
1295     InformUpdateHandlerOrCommit(error, aHost, aMetadata);
1296 }
1297 
InformUpdateHandlerOrCommit(Error aError,Host & aHost,const MessageMetadata & aMetadata)1298 void Server::InformUpdateHandlerOrCommit(Error aError, Host &aHost, const MessageMetadata &aMetadata)
1299 {
1300     if ((aError == kErrorNone) && (mServiceUpdateHandler != nullptr))
1301     {
1302         UpdateMetadata *update = UpdateMetadata::Allocate(GetInstance(), aHost, aMetadata);
1303 
1304         if (update != nullptr)
1305         {
1306             mOutstandingUpdates.Push(*update);
1307             mOutstandingUpdatesTimer.FireAtIfEarlier(update->GetExpireTime());
1308 
1309             LogInfo("SRP update handler is notified (updatedId = %u)", update->GetId());
1310             mServiceUpdateHandler(update->GetId(), &aHost, kDefaultEventsHandlerTimeout, mServiceUpdateHandlerContext);
1311             ExitNow();
1312         }
1313 
1314         aError = kErrorNoBufs;
1315     }
1316 
1317     CommitSrpUpdate(aError, aHost, aMetadata);
1318 
1319 exit:
1320     return;
1321 }
1322 
SendResponse(const Dns::UpdateHeader & aHeader,Dns::UpdateHeader::Response aResponseCode,const Ip6::MessageInfo & aMessageInfo)1323 void Server::SendResponse(const Dns::UpdateHeader &   aHeader,
1324                           Dns::UpdateHeader::Response aResponseCode,
1325                           const Ip6::MessageInfo &    aMessageInfo)
1326 {
1327     Error             error;
1328     Message *         response = nullptr;
1329     Dns::UpdateHeader header;
1330 
1331     response = GetSocket().NewMessage(0);
1332     VerifyOrExit(response != nullptr, error = kErrorNoBufs);
1333 
1334     header.SetMessageId(aHeader.GetMessageId());
1335     header.SetType(Dns::UpdateHeader::kTypeResponse);
1336     header.SetQueryType(aHeader.GetQueryType());
1337     header.SetResponseCode(aResponseCode);
1338     SuccessOrExit(error = response->Append(header));
1339 
1340     SuccessOrExit(error = GetSocket().SendTo(*response, aMessageInfo));
1341 
1342     if (aResponseCode != Dns::UpdateHeader::kResponseSuccess)
1343     {
1344         LogWarn("Send fail response: %d", aResponseCode);
1345     }
1346     else
1347     {
1348         LogInfo("Send success response");
1349     }
1350 
1351     UpdateResponseCounters(aResponseCode);
1352 
1353 exit:
1354     if (error != kErrorNone)
1355     {
1356         LogWarn("Failed to send response: %s", ErrorToString(error));
1357         FreeMessage(response);
1358     }
1359 }
1360 
SendResponse(const Dns::UpdateHeader & aHeader,uint32_t aLease,uint32_t aKeyLease,const Ip6::MessageInfo & aMessageInfo)1361 void Server::SendResponse(const Dns::UpdateHeader &aHeader,
1362                           uint32_t                 aLease,
1363                           uint32_t                 aKeyLease,
1364                           const Ip6::MessageInfo & aMessageInfo)
1365 {
1366     Error             error;
1367     Message *         response = nullptr;
1368     Dns::UpdateHeader header;
1369     Dns::OptRecord    optRecord;
1370     Dns::LeaseOption  leaseOption;
1371 
1372     response = GetSocket().NewMessage(0);
1373     VerifyOrExit(response != nullptr, error = kErrorNoBufs);
1374 
1375     header.SetMessageId(aHeader.GetMessageId());
1376     header.SetType(Dns::UpdateHeader::kTypeResponse);
1377     header.SetQueryType(aHeader.GetQueryType());
1378     header.SetResponseCode(Dns::UpdateHeader::kResponseSuccess);
1379     header.SetAdditionalRecordCount(1);
1380     SuccessOrExit(error = response->Append(header));
1381 
1382     // Append the root domain (".").
1383     SuccessOrExit(error = Dns::Name::AppendTerminator(*response));
1384 
1385     optRecord.Init();
1386     optRecord.SetUdpPayloadSize(kUdpPayloadSize);
1387     optRecord.SetDnsSecurityFlag();
1388     optRecord.SetLength(sizeof(Dns::LeaseOption));
1389     SuccessOrExit(error = response->Append(optRecord));
1390 
1391     leaseOption.Init();
1392     leaseOption.SetLeaseInterval(aLease);
1393     leaseOption.SetKeyLeaseInterval(aKeyLease);
1394     SuccessOrExit(error = response->Append(leaseOption));
1395 
1396     SuccessOrExit(error = GetSocket().SendTo(*response, aMessageInfo));
1397 
1398     LogInfo("Send success response with granted lease: %u and key lease: %u", aLease, aKeyLease);
1399 
1400     UpdateResponseCounters(Dns::UpdateHeader::kResponseSuccess);
1401 
1402 exit:
1403     if (error != kErrorNone)
1404     {
1405         LogWarn("Failed to send response: %s", ErrorToString(error));
1406         FreeMessage(response);
1407     }
1408 }
1409 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)1410 void Server::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
1411 {
1412     static_cast<Server *>(aContext)->HandleUdpReceive(AsCoreType(aMessage), AsCoreType(aMessageInfo));
1413 }
1414 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)1415 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
1416 {
1417     Error error = ProcessMessage(aMessage, aMessageInfo);
1418 
1419     if (error != kErrorNone)
1420     {
1421         LogInfo("Failed to handle DNS message: %s", ErrorToString(error));
1422     }
1423 }
1424 
ProcessMessage(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)1425 Error Server::ProcessMessage(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
1426 {
1427     return ProcessMessage(aMessage, TimerMilli::GetNow(), mTtlConfig, mLeaseConfig, &aMessageInfo);
1428 }
1429 
ProcessMessage(Message & aMessage,TimeMilli aRxTime,const TtlConfig & aTtlConfig,const LeaseConfig & aLeaseConfig,const Ip6::MessageInfo * aMessageInfo)1430 Error Server::ProcessMessage(Message &               aMessage,
1431                              TimeMilli               aRxTime,
1432                              const TtlConfig &       aTtlConfig,
1433                              const LeaseConfig &     aLeaseConfig,
1434                              const Ip6::MessageInfo *aMessageInfo)
1435 {
1436     Error           error;
1437     MessageMetadata metadata;
1438 
1439     metadata.mOffset      = aMessage.GetOffset();
1440     metadata.mRxTime      = aRxTime;
1441     metadata.mTtlConfig   = aTtlConfig;
1442     metadata.mLeaseConfig = aLeaseConfig;
1443     metadata.mMessageInfo = aMessageInfo;
1444 
1445     SuccessOrExit(error = aMessage.Read(metadata.mOffset, metadata.mDnsHeader));
1446     metadata.mOffset += sizeof(Dns::UpdateHeader);
1447 
1448     VerifyOrExit(metadata.mDnsHeader.GetType() == Dns::UpdateHeader::Type::kTypeQuery, error = kErrorDrop);
1449     VerifyOrExit(metadata.mDnsHeader.GetQueryType() == Dns::UpdateHeader::kQueryTypeUpdate, error = kErrorDrop);
1450 
1451     ProcessDnsUpdate(aMessage, metadata);
1452 
1453 exit:
1454     return error;
1455 }
1456 
HandleLeaseTimer(Timer & aTimer)1457 void Server::HandleLeaseTimer(Timer &aTimer)
1458 {
1459     aTimer.Get<Server>().HandleLeaseTimer();
1460 }
1461 
HandleLeaseTimer(void)1462 void Server::HandleLeaseTimer(void)
1463 {
1464     TimeMilli now                = TimerMilli::GetNow();
1465     TimeMilli earliestExpireTime = now.GetDistantFuture();
1466     Host *    nextHost;
1467 
1468     for (Host *host = mHosts.GetHead(); host != nullptr; host = nextHost)
1469     {
1470         nextHost = host->GetNext();
1471 
1472         if (host->GetKeyExpireTime() <= now)
1473         {
1474             LogInfo("KEY LEASE of host %s expired", host->GetFullName());
1475 
1476             // Removes the whole host and all services if the KEY RR expired.
1477             RemoveHost(host, kDeleteName, kNotifyServiceHandler);
1478         }
1479         else if (host->IsDeleted())
1480         {
1481             // The host has been deleted, but the hostname & service instance names retain.
1482 
1483             Service *next;
1484 
1485             earliestExpireTime = OT_MIN(earliestExpireTime, host->GetKeyExpireTime());
1486 
1487             // Check if any service instance name expired.
1488             for (Service *service = host->mServices.GetHead(); service != nullptr; service = next)
1489             {
1490                 next = service->GetNext();
1491 
1492                 OT_ASSERT(service->mIsDeleted);
1493 
1494                 if (service->GetKeyExpireTime() <= now)
1495                 {
1496                     service->Log(Service::kKeyLeaseExpired);
1497                     host->RemoveService(service, kDeleteName, kNotifyServiceHandler);
1498                 }
1499                 else
1500                 {
1501                     earliestExpireTime = OT_MIN(earliestExpireTime, service->GetKeyExpireTime());
1502                 }
1503             }
1504         }
1505         else if (host->GetExpireTime() <= now)
1506         {
1507             LogInfo("LEASE of host %s expired", host->GetFullName());
1508 
1509             // If the host expired, delete all resources of this host and its services.
1510             for (Service &service : host->mServices)
1511             {
1512                 // Don't need to notify the service handler as `RemoveHost` at below will do.
1513                 host->RemoveService(&service, kRetainName, kDoNotNotifyServiceHandler);
1514             }
1515 
1516             RemoveHost(host, kRetainName, kNotifyServiceHandler);
1517 
1518             earliestExpireTime = OT_MIN(earliestExpireTime, host->GetKeyExpireTime());
1519         }
1520         else
1521         {
1522             // The host doesn't expire, check if any service expired or is explicitly removed.
1523 
1524             Service *next;
1525 
1526             OT_ASSERT(!host->IsDeleted());
1527 
1528             earliestExpireTime = OT_MIN(earliestExpireTime, host->GetExpireTime());
1529 
1530             for (Service *service = host->mServices.GetHead(); service != nullptr; service = next)
1531             {
1532                 next = service->GetNext();
1533 
1534                 if (service->GetKeyExpireTime() <= now)
1535                 {
1536                     service->Log(Service::kKeyLeaseExpired);
1537                     host->RemoveService(service, kDeleteName, kNotifyServiceHandler);
1538                 }
1539                 else if (service->mIsDeleted)
1540                 {
1541                     // The service has been deleted but the name retains.
1542                     earliestExpireTime = OT_MIN(earliestExpireTime, service->GetKeyExpireTime());
1543                 }
1544                 else if (service->GetExpireTime() <= now)
1545                 {
1546                     service->Log(Service::kLeaseExpired);
1547 
1548                     // The service is expired, delete it.
1549                     host->RemoveService(service, kRetainName, kNotifyServiceHandler);
1550                     earliestExpireTime = OT_MIN(earliestExpireTime, service->GetKeyExpireTime());
1551                 }
1552                 else
1553                 {
1554                     earliestExpireTime = OT_MIN(earliestExpireTime, service->GetExpireTime());
1555                 }
1556             }
1557         }
1558     }
1559 
1560     if (earliestExpireTime != now.GetDistantFuture())
1561     {
1562         OT_ASSERT(earliestExpireTime >= now);
1563         if (!mLeaseTimer.IsRunning() || earliestExpireTime <= mLeaseTimer.GetFireTime())
1564         {
1565             LogInfo("Lease timer is scheduled for %u seconds", Time::MsecToSec(earliestExpireTime - now));
1566             mLeaseTimer.StartAt(earliestExpireTime, 0);
1567         }
1568     }
1569     else
1570     {
1571         LogInfo("Lease timer is stopped");
1572         mLeaseTimer.Stop();
1573     }
1574 }
1575 
HandleOutstandingUpdatesTimer(Timer & aTimer)1576 void Server::HandleOutstandingUpdatesTimer(Timer &aTimer)
1577 {
1578     aTimer.Get<Server>().HandleOutstandingUpdatesTimer();
1579 }
1580 
HandleOutstandingUpdatesTimer(void)1581 void Server::HandleOutstandingUpdatesTimer(void)
1582 {
1583     while (!mOutstandingUpdates.IsEmpty() && mOutstandingUpdates.GetTail()->GetExpireTime() <= TimerMilli::GetNow())
1584     {
1585         LogInfo("Outstanding service update timeout (updateId = %u)", mOutstandingUpdates.GetTail()->GetId());
1586         HandleServiceUpdateResult(mOutstandingUpdates.GetTail(), kErrorResponseTimeout);
1587     }
1588 }
1589 
AddressModeToString(AddressMode aMode)1590 const char *Server::AddressModeToString(AddressMode aMode)
1591 {
1592     static const char *const kAddressModeStrings[] = {
1593         "unicast", // (0) kAddressModeUnicast
1594         "anycast", // (1) kAddressModeAnycast
1595     };
1596 
1597     static_assert(kAddressModeUnicast == 0, "kAddressModeUnicast value is incorrect");
1598     static_assert(kAddressModeAnycast == 1, "kAddressModeAnycast value is incorrect");
1599 
1600     return kAddressModeStrings[aMode];
1601 }
1602 
UpdateResponseCounters(Dns::UpdateHeader::Response aResponseCode)1603 void Server::UpdateResponseCounters(Dns::UpdateHeader::Response aResponseCode)
1604 {
1605     switch (aResponseCode)
1606     {
1607     case Dns::UpdateHeader::kResponseSuccess:
1608         ++mResponseCounters.mSuccess;
1609         break;
1610     case Dns::UpdateHeader::kResponseServerFailure:
1611         ++mResponseCounters.mServerFailure;
1612         break;
1613     case Dns::UpdateHeader::kResponseFormatError:
1614         ++mResponseCounters.mFormatError;
1615         break;
1616     case Dns::UpdateHeader::kResponseNameExists:
1617         ++mResponseCounters.mNameExists;
1618         break;
1619     case Dns::UpdateHeader::kResponseRefused:
1620         ++mResponseCounters.mRefused;
1621         break;
1622     default:
1623         ++mResponseCounters.mOther;
1624         break;
1625     }
1626 }
1627 
1628 //---------------------------------------------------------------------------------------------------------------------
1629 // Server::Service
1630 
Init(const char * aServiceName,Description & aDescription,bool aIsSubType,TimeMilli aUpdateTime)1631 Error Server::Service::Init(const char *aServiceName, Description &aDescription, bool aIsSubType, TimeMilli aUpdateTime)
1632 {
1633     mDescription.Reset(&aDescription);
1634     mNext        = nullptr;
1635     mUpdateTime  = aUpdateTime;
1636     mIsDeleted   = false;
1637     mIsSubType   = aIsSubType;
1638     mIsCommitted = false;
1639 
1640     return mServiceName.Set(aServiceName);
1641 }
1642 
GetServiceSubTypeLabel(char * aLabel,uint8_t aMaxSize) const1643 Error Server::Service::GetServiceSubTypeLabel(char *aLabel, uint8_t aMaxSize) const
1644 {
1645     Error       error       = kErrorNone;
1646     const char *serviceName = GetServiceName();
1647     const char *subServiceName;
1648     uint8_t     labelLength;
1649 
1650     memset(aLabel, 0, aMaxSize);
1651 
1652     VerifyOrExit(IsSubType(), error = kErrorInvalidArgs);
1653 
1654     subServiceName = StringFind(serviceName, kServiceSubTypeLabel, kStringCaseInsensitiveMatch);
1655     OT_ASSERT(subServiceName != nullptr);
1656 
1657     if (subServiceName - serviceName < aMaxSize)
1658     {
1659         labelLength = static_cast<uint8_t>(subServiceName - serviceName);
1660     }
1661     else
1662     {
1663         labelLength = aMaxSize - 1;
1664         error       = kErrorNoBufs;
1665     }
1666 
1667     memcpy(aLabel, serviceName, labelLength);
1668 
1669 exit:
1670     return error;
1671 }
1672 
GetExpireTime(void) const1673 TimeMilli Server::Service::GetExpireTime(void) const
1674 {
1675     OT_ASSERT(!mIsDeleted);
1676     OT_ASSERT(!GetHost().IsDeleted());
1677 
1678     return mUpdateTime + Time::SecToMsec(mDescription->mLease);
1679 }
1680 
GetKeyExpireTime(void) const1681 TimeMilli Server::Service::GetKeyExpireTime(void) const
1682 {
1683     return mUpdateTime + Time::SecToMsec(mDescription->mKeyLease);
1684 }
1685 
GetLeaseInfo(LeaseInfo & aLeaseInfo) const1686 void Server::Service::GetLeaseInfo(LeaseInfo &aLeaseInfo) const
1687 {
1688     TimeMilli now           = TimerMilli::GetNow();
1689     TimeMilli expireTime    = GetExpireTime();
1690     TimeMilli keyExpireTime = GetKeyExpireTime();
1691 
1692     aLeaseInfo.mLease             = Time::SecToMsec(GetLease());
1693     aLeaseInfo.mKeyLease          = Time::SecToMsec(GetKeyLease());
1694     aLeaseInfo.mRemainingLease    = (now <= expireTime) ? (expireTime - now) : 0;
1695     aLeaseInfo.mRemainingKeyLease = (now <= keyExpireTime) ? (keyExpireTime - now) : 0;
1696 }
1697 
MatchesInstanceName(const char * aInstanceName) const1698 bool Server::Service::MatchesInstanceName(const char *aInstanceName) const
1699 {
1700     return StringMatch(mDescription->mInstanceName.AsCString(), aInstanceName, kStringCaseInsensitiveMatch);
1701 }
1702 
MatchesServiceName(const char * aServiceName) const1703 bool Server::Service::MatchesServiceName(const char *aServiceName) const
1704 {
1705     return StringMatch(mServiceName.AsCString(), aServiceName, kStringCaseInsensitiveMatch);
1706 }
1707 
MatchesFlags(Flags aFlags) const1708 bool Server::Service::MatchesFlags(Flags aFlags) const
1709 {
1710     bool matches = false;
1711 
1712     if (IsSubType())
1713     {
1714         VerifyOrExit(aFlags & kFlagSubType);
1715     }
1716     else
1717     {
1718         VerifyOrExit(aFlags & kFlagBaseType);
1719     }
1720 
1721     if (IsDeleted())
1722     {
1723         VerifyOrExit(aFlags & kFlagDeleted);
1724     }
1725     else
1726     {
1727         VerifyOrExit(aFlags & kFlagActive);
1728     }
1729 
1730     matches = true;
1731 
1732 exit:
1733     return matches;
1734 }
1735 
1736 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
Log(Action aAction) const1737 void Server::Service::Log(Action aAction) const
1738 {
1739     static const char *const kActionStrings[] = {
1740         "Add new",                   // (0) kAddNew
1741         "Update existing",           // (1) kUpdateExisting
1742         "Remove but retain name of", // (2) kRemoveButRetainName
1743         "Fully remove",              // (3) kFullyRemove
1744         "LEASE expired for ",        // (4) kLeaseExpired
1745         "KEY LEASE expired for",     // (5) kKeyLeaseExpired
1746     };
1747 
1748     char subLabel[Dns::Name::kMaxLabelSize];
1749 
1750     static_assert(0 == kAddNew, "kAddNew value is incorrect");
1751     static_assert(1 == kUpdateExisting, "kUpdateExisting value is incorrect");
1752     static_assert(2 == kRemoveButRetainName, "kRemoveButRetainName value is incorrect");
1753     static_assert(3 == kFullyRemove, "kFullyRemove value is incorrect");
1754     static_assert(4 == kLeaseExpired, "kLeaseExpired value is incorrect");
1755     static_assert(5 == kKeyLeaseExpired, "kKeyLeaseExpired value is incorrect");
1756 
1757     // We only log if the `Service` is marked as committed. This
1758     // ensures that temporary `Service` entries associated with a
1759     // newly received SRP update message are not logged (e.g., when
1760     // associated `Host` is being freed).
1761 
1762     if (mIsCommitted)
1763     {
1764         IgnoreError(GetServiceSubTypeLabel(subLabel, sizeof(subLabel)));
1765 
1766         LogInfo("%s service '%s'%s%s", kActionStrings[aAction], GetInstanceName(), IsSubType() ? " subtype:" : "",
1767                 subLabel);
1768     }
1769 }
1770 #else
Log(Action) const1771 void Server::Service::Log(Action) const
1772 {
1773 }
1774 #endif // #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
1775 
1776 //---------------------------------------------------------------------------------------------------------------------
1777 // Server::Service::Description
1778 
Init(const char * aInstanceName,Host & aHost)1779 Error Server::Service::Description::Init(const char *aInstanceName, Host &aHost)
1780 {
1781     mNext       = nullptr;
1782     mHost       = &aHost;
1783     mPriority   = 0;
1784     mWeight     = 0;
1785     mTtl        = 0;
1786     mPort       = 0;
1787     mLease      = 0;
1788     mKeyLease   = 0;
1789     mUpdateTime = TimerMilli::GetNow().GetDistantPast();
1790     mTxtData.Free();
1791 
1792     return mInstanceName.Set(aInstanceName);
1793 }
1794 
Matches(const char * aInstanceName) const1795 bool Server::Service::Description::Matches(const char *aInstanceName) const
1796 {
1797     return StringMatch(mInstanceName.AsCString(), aInstanceName, kStringCaseInsensitiveMatch);
1798 }
1799 
ClearResources(void)1800 void Server::Service::Description::ClearResources(void)
1801 {
1802     mPort = 0;
1803     mTxtData.Free();
1804 }
1805 
TakeResourcesFrom(Description & aDescription)1806 void Server::Service::Description::TakeResourcesFrom(Description &aDescription)
1807 {
1808     mTxtData.SetFrom(static_cast<Heap::Data &&>(aDescription.mTxtData));
1809 
1810     mPriority = aDescription.mPriority;
1811     mWeight   = aDescription.mWeight;
1812     mPort     = aDescription.mPort;
1813 
1814     mTtl        = aDescription.mTtl;
1815     mLease      = aDescription.mLease;
1816     mKeyLease   = aDescription.mKeyLease;
1817     mUpdateTime = TimerMilli::GetNow();
1818 }
1819 
SetTxtDataFromMessage(const Message & aMessage,uint16_t aOffset,uint16_t aLength)1820 Error Server::Service::Description::SetTxtDataFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
1821 {
1822     Error error;
1823 
1824     SuccessOrExit(error = mTxtData.SetFrom(aMessage, aOffset, aLength));
1825     VerifyOrExit(Dns::TxtRecord::VerifyTxtData(mTxtData.GetBytes(), mTxtData.GetLength(), /* aAllowEmpty */ false),
1826                  error = kErrorParse);
1827 
1828 exit:
1829     if (error != kErrorNone)
1830     {
1831         mTxtData.Free();
1832     }
1833 
1834     return error;
1835 }
1836 
1837 //---------------------------------------------------------------------------------------------------------------------
1838 // Server::Host
1839 
Host(Instance & aInstance,TimeMilli aUpdateTime)1840 Server::Host::Host(Instance &aInstance, TimeMilli aUpdateTime)
1841     : InstanceLocator(aInstance)
1842     , mNext(nullptr)
1843     , mTtl(0)
1844     , mLease(0)
1845     , mKeyLease(0)
1846     , mUpdateTime(aUpdateTime)
1847 {
1848     mKeyRecord.Clear();
1849 }
1850 
~Host(void)1851 Server::Host::~Host(void)
1852 {
1853     FreeAllServices();
1854 }
1855 
SetFullName(const char * aFullName)1856 Error Server::Host::SetFullName(const char *aFullName)
1857 {
1858     // `mFullName` becomes immutable after it is set, so if it is
1859     // already set, we only accept a `aFullName` that matches the
1860     // current name.
1861 
1862     Error error;
1863 
1864     if (mFullName.IsNull())
1865     {
1866         error = mFullName.Set(aFullName);
1867     }
1868     else
1869     {
1870         error = Matches(aFullName) ? kErrorNone : kErrorFailed;
1871     }
1872 
1873     return error;
1874 }
1875 
Matches(const char * aFullName) const1876 bool Server::Host::Matches(const char *aFullName) const
1877 {
1878     return StringMatch(mFullName.AsCString(), aFullName, kStringCaseInsensitiveMatch);
1879 }
1880 
SetKeyRecord(Dns::Ecdsa256KeyRecord & aKeyRecord)1881 void Server::Host::SetKeyRecord(Dns::Ecdsa256KeyRecord &aKeyRecord)
1882 {
1883     OT_ASSERT(aKeyRecord.IsValid());
1884 
1885     mKeyRecord = aKeyRecord;
1886 }
1887 
GetExpireTime(void) const1888 TimeMilli Server::Host::GetExpireTime(void) const
1889 {
1890     OT_ASSERT(!IsDeleted());
1891 
1892     return mUpdateTime + Time::SecToMsec(mLease);
1893 }
1894 
GetKeyExpireTime(void) const1895 TimeMilli Server::Host::GetKeyExpireTime(void) const
1896 {
1897     return mUpdateTime + Time::SecToMsec(mKeyLease);
1898 }
1899 
GetLeaseInfo(LeaseInfo & aLeaseInfo) const1900 void Server::Host::GetLeaseInfo(LeaseInfo &aLeaseInfo) const
1901 {
1902     TimeMilli now           = TimerMilli::GetNow();
1903     TimeMilli expireTime    = GetExpireTime();
1904     TimeMilli keyExpireTime = GetKeyExpireTime();
1905 
1906     aLeaseInfo.mLease             = Time::SecToMsec(GetLease());
1907     aLeaseInfo.mKeyLease          = Time::SecToMsec(GetKeyLease());
1908     aLeaseInfo.mRemainingLease    = (now <= expireTime) ? (expireTime - now) : 0;
1909     aLeaseInfo.mRemainingKeyLease = (now <= keyExpireTime) ? (keyExpireTime - now) : 0;
1910 }
1911 
ProcessTtl(uint32_t aTtl)1912 Error Server::Host::ProcessTtl(uint32_t aTtl)
1913 {
1914     // This method processes the TTL value received in a resource record.
1915     //
1916     // If no TTL value is stored, this method wil set the stored value to @p aTtl and return `kErrorNone`.
1917     // If a TTL value is stored and @p aTtl equals the stored value, this method returns `kErrorNone`.
1918     // Otherwise, this method returns `kErrorRejected`.
1919 
1920     Error error = kErrorRejected;
1921 
1922     VerifyOrExit(aTtl && (mTtl == 0 || mTtl == aTtl));
1923 
1924     mTtl = aTtl;
1925 
1926     error = kErrorNone;
1927 
1928 exit:
1929     return error;
1930 }
1931 
FindNextService(const Service * aPrevService,Service::Flags aFlags,const char * aServiceName,const char * aInstanceName) const1932 const Server::Service *Server::Host::FindNextService(const Service *aPrevService,
1933                                                      Service::Flags aFlags,
1934                                                      const char *   aServiceName,
1935                                                      const char *   aInstanceName) const
1936 {
1937     const Service *service = (aPrevService == nullptr) ? GetServices().GetHead() : aPrevService->GetNext();
1938 
1939     for (; service != nullptr; service = service->GetNext())
1940     {
1941         if (!service->MatchesFlags(aFlags))
1942         {
1943             continue;
1944         }
1945 
1946         if ((aServiceName != nullptr) && !service->MatchesServiceName(aServiceName))
1947         {
1948             continue;
1949         }
1950 
1951         if ((aInstanceName != nullptr) && !service->MatchesInstanceName(aInstanceName))
1952         {
1953             continue;
1954         }
1955 
1956         break;
1957     }
1958 
1959     return service;
1960 }
1961 
AddNewService(const char * aServiceName,const char * aInstanceName,bool aIsSubType,TimeMilli aUpdateTime)1962 Server::Service *Server::Host::AddNewService(const char *aServiceName,
1963                                              const char *aInstanceName,
1964                                              bool        aIsSubType,
1965                                              TimeMilli   aUpdateTime)
1966 {
1967     Service *                       service = nullptr;
1968     RetainPtr<Service::Description> desc(FindServiceDescription(aInstanceName));
1969 
1970     if (desc == nullptr)
1971     {
1972         desc.Reset(Service::Description::AllocateAndInit(aInstanceName, *this));
1973         VerifyOrExit(desc != nullptr);
1974     }
1975 
1976     service = Service::AllocateAndInit(aServiceName, *desc, aIsSubType, aUpdateTime);
1977     VerifyOrExit(service != nullptr);
1978 
1979     mServices.Push(*service);
1980 
1981 exit:
1982     return service;
1983 }
1984 
RemoveService(Service * aService,RetainName aRetainName,NotifyMode aNotifyServiceHandler)1985 void Server::Host::RemoveService(Service *aService, RetainName aRetainName, NotifyMode aNotifyServiceHandler)
1986 {
1987     Server &server = Get<Server>();
1988 
1989     VerifyOrExit(aService != nullptr);
1990 
1991     aService->mIsDeleted = true;
1992 
1993     aService->Log(aRetainName ? Service::kRemoveButRetainName : Service::kFullyRemove);
1994 
1995     if (aNotifyServiceHandler && server.mServiceUpdateHandler != nullptr)
1996     {
1997         uint32_t updateId = server.AllocateId();
1998 
1999         LogInfo("SRP update handler is notified (updatedId = %u)", updateId);
2000         server.mServiceUpdateHandler(updateId, this, kDefaultEventsHandlerTimeout, server.mServiceUpdateHandlerContext);
2001         // We don't wait for the reply from the service update handler,
2002         // but always remove the service regardless of service update result.
2003         // Because removing a service should fail only when there is system
2004         // failure of the platform mDNS implementation and in which case the
2005         // service is not expected to be still registered.
2006     }
2007 
2008     if (!aRetainName)
2009     {
2010         IgnoreError(mServices.Remove(*aService));
2011         aService->Free();
2012     }
2013 
2014 exit:
2015     return;
2016 }
2017 
AddCopyOfServiceAsDeletedIfNotPresent(const Service & aService,TimeMilli aUpdateTime)2018 Error Server::Host::AddCopyOfServiceAsDeletedIfNotPresent(const Service &aService, TimeMilli aUpdateTime)
2019 {
2020     Error    error = kErrorNone;
2021     Service *newService;
2022 
2023     VerifyOrExit(FindService(aService.GetServiceName(), aService.GetInstanceName()) == nullptr);
2024 
2025     newService =
2026         AddNewService(aService.GetServiceName(), aService.GetInstanceName(), aService.IsSubType(), aUpdateTime);
2027 
2028     VerifyOrExit(newService != nullptr, error = kErrorNoBufs);
2029 
2030     newService->mDescription->mUpdateTime = aUpdateTime;
2031     newService->mIsDeleted                = true;
2032 
2033 exit:
2034     return error;
2035 }
2036 
FreeAllServices(void)2037 void Server::Host::FreeAllServices(void)
2038 {
2039     while (!mServices.IsEmpty())
2040     {
2041         RemoveService(mServices.GetHead(), kDeleteName, kDoNotNotifyServiceHandler);
2042     }
2043 }
2044 
ClearResources(void)2045 void Server::Host::ClearResources(void)
2046 {
2047     mAddresses.Free();
2048 }
2049 
MergeServicesAndResourcesFrom(Host & aHost)2050 Error Server::Host::MergeServicesAndResourcesFrom(Host &aHost)
2051 {
2052     // This method merges services, service descriptions, and other
2053     // resources from another `aHost` into current host. It can
2054     // possibly take ownership of some items from `aHost`.
2055 
2056     Error error = kErrorNone;
2057 
2058     LogInfo("Update host %s", GetFullName());
2059 
2060     mAddresses.TakeFrom(static_cast<Heap::Array<Ip6::Address> &&>(aHost.mAddresses));
2061     mKeyRecord  = aHost.mKeyRecord;
2062     mTtl        = aHost.mTtl;
2063     mLease      = aHost.mLease;
2064     mKeyLease   = aHost.mKeyLease;
2065     mUpdateTime = TimerMilli::GetNow();
2066 
2067     for (Service &service : aHost.mServices)
2068     {
2069         Service *existingService = FindService(service.GetServiceName(), service.GetInstanceName());
2070         Service *newService;
2071 
2072         if (service.mIsDeleted)
2073         {
2074             // `RemoveService()` does nothing if `exitsingService` is `nullptr`.
2075             RemoveService(existingService, kRetainName, kDoNotNotifyServiceHandler);
2076             continue;
2077         }
2078 
2079         // Add/Merge `service` into the existing service or a allocate a new one
2080 
2081         newService = (existingService != nullptr) ? existingService
2082                                                   : AddNewService(service.GetServiceName(), service.GetInstanceName(),
2083                                                                   service.IsSubType(), service.GetUpdateTime());
2084 
2085         VerifyOrExit(newService != nullptr, error = kErrorNoBufs);
2086 
2087         newService->mIsDeleted   = false;
2088         newService->mIsCommitted = true;
2089         newService->mUpdateTime  = TimerMilli::GetNow();
2090 
2091         if (!service.mIsSubType)
2092         {
2093             // (1) Service description is shared across a base type and all its subtypes.
2094             // (2) `TakeResourcesFrom()` releases resources pinned to its argument.
2095             // Therefore, make sure the function is called only for the base type.
2096             newService->mDescription->TakeResourcesFrom(*service.mDescription);
2097         }
2098 
2099         newService->Log((existingService != nullptr) ? Service::kUpdateExisting : Service::kAddNew);
2100     }
2101 
2102 exit:
2103     return error;
2104 }
2105 
HasServiceInstance(const char * aInstanceName) const2106 bool Server::Host::HasServiceInstance(const char *aInstanceName) const
2107 {
2108     return (FindServiceDescription(aInstanceName) != nullptr);
2109 }
2110 
FindServiceDescription(const char * aInstanceName) const2111 const RetainPtr<Server::Service::Description> Server::Host::FindServiceDescription(const char *aInstanceName) const
2112 {
2113     const Service::Description *desc = nullptr;
2114 
2115     for (const Service &service : mServices)
2116     {
2117         if (service.mDescription->Matches(aInstanceName))
2118         {
2119             desc = service.mDescription.Get();
2120             break;
2121         }
2122     }
2123 
2124     return RetainPtr<Service::Description>(AsNonConst(desc));
2125 }
2126 
FindServiceDescription(const char * aInstanceName)2127 RetainPtr<Server::Service::Description> Server::Host::FindServiceDescription(const char *aInstanceName)
2128 {
2129     return AsNonConst(AsConst(this)->FindServiceDescription(aInstanceName));
2130 }
2131 
FindService(const char * aServiceName,const char * aInstanceName) const2132 const Server::Service *Server::Host::FindService(const char *aServiceName, const char *aInstanceName) const
2133 {
2134     return FindNextService(/* aPrevService */ nullptr, kFlagsAnyService, aServiceName, aInstanceName);
2135 }
2136 
FindService(const char * aServiceName,const char * aInstanceName)2137 Server::Service *Server::Host::FindService(const char *aServiceName, const char *aInstanceName)
2138 {
2139     return AsNonConst(AsConst(this)->FindService(aServiceName, aInstanceName));
2140 }
2141 
FindBaseService(const char * aInstanceName) const2142 const Server::Service *Server::Host::FindBaseService(const char *aInstanceName) const
2143 {
2144     return FindNextService(/*a PrevService */ nullptr, kFlagsBaseTypeServiceOnly, /* aServiceName */ nullptr,
2145                            aInstanceName);
2146 }
2147 
FindBaseService(const char * aInstanceName)2148 Server::Service *Server::Host::FindBaseService(const char *aInstanceName)
2149 {
2150     return AsNonConst(AsConst(this)->FindBaseService(aInstanceName));
2151 }
2152 
AddIp6Address(const Ip6::Address & aIp6Address)2153 Error Server::Host::AddIp6Address(const Ip6::Address &aIp6Address)
2154 {
2155     Error error = kErrorNone;
2156 
2157     if (aIp6Address.IsMulticast() || aIp6Address.IsUnspecified() || aIp6Address.IsLoopback())
2158     {
2159         // We don't like those address because they cannot be used
2160         // for communication with exterior devices.
2161         ExitNow(error = kErrorDrop);
2162     }
2163 
2164     // Drop duplicate addresses.
2165     VerifyOrExit(!mAddresses.Contains(aIp6Address), error = kErrorDrop);
2166 
2167     error = mAddresses.PushBack(aIp6Address);
2168 
2169     if (error == kErrorNoBufs)
2170     {
2171         LogWarn("Too many addresses for host %s", GetFullName());
2172     }
2173 
2174 exit:
2175     return error;
2176 }
2177 
2178 //---------------------------------------------------------------------------------------------------------------------
2179 // Server::UpdateMetadata
2180 
UpdateMetadata(Instance & aInstance,Host & aHost,const MessageMetadata & aMessageMetadata)2181 Server::UpdateMetadata::UpdateMetadata(Instance &aInstance, Host &aHost, const MessageMetadata &aMessageMetadata)
2182     : InstanceLocator(aInstance)
2183     , mNext(nullptr)
2184     , mExpireTime(TimerMilli::GetNow() + kDefaultEventsHandlerTimeout)
2185     , mDnsHeader(aMessageMetadata.mDnsHeader)
2186     , mId(Get<Server>().AllocateId())
2187     , mTtlConfig(aMessageMetadata.mTtlConfig)
2188     , mLeaseConfig(aMessageMetadata.mLeaseConfig)
2189     , mHost(aHost)
2190     , mIsDirectRxFromClient(aMessageMetadata.IsDirectRxFromClient())
2191 {
2192     if (aMessageMetadata.mMessageInfo != nullptr)
2193     {
2194         mMessageInfo = *aMessageMetadata.mMessageInfo;
2195     }
2196 }
2197 
2198 } // namespace Srp
2199 } // namespace ot
2200 
2201 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
2202