• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_dso.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32 
33 #include "instance/instance.hpp"
34 
35 /**
36  * @file
37  *   This file implements the DNS Stateful Operations (DSO) per RFC 8490.
38  */
39 
40 namespace ot {
41 namespace Dns {
42 
43 RegisterLogModule("DnsDso");
44 
45 //---------------------------------------------------------------------------------------------------------------------
46 // otPlatDso transport callbacks
47 
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)48 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
49 {
50     return &AsCoreType(aConnection).GetInstance();
51 }
52 
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)53 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
54 {
55     return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
56 }
57 
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)58 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
59 {
60     AsCoreType(aConnection).HandleConnected();
61 }
62 
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)63 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
64 {
65     AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
66 }
67 
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)68 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
69 {
70     AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
71 }
72 
73 //---------------------------------------------------------------------------------------------------------------------
74 // Dso::Connection
75 
Connection(Instance & aInstance,const Ip6::SockAddr & aPeerSockAddr,Callbacks & aCallbacks,uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)76 Dso::Connection::Connection(Instance            &aInstance,
77                             const Ip6::SockAddr &aPeerSockAddr,
78                             Callbacks           &aCallbacks,
79                             uint32_t             aInactivityTimeout,
80                             uint32_t             aKeepAliveInterval)
81     : InstanceLocator(aInstance)
82     , mNext(nullptr)
83     , mCallbacks(aCallbacks)
84     , mPeerSockAddr(aPeerSockAddr)
85     , mState(kStateDisconnected)
86     , mIsServer(false)
87     , mInactivity(aInactivityTimeout)
88     , mKeepAlive(aKeepAliveInterval)
89 {
90     OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
91     Init(/* aIsServer */ false);
92 }
93 
Init(bool aIsServer)94 void Dso::Connection::Init(bool aIsServer)
95 {
96     mNextMessageId       = 1;
97     mIsServer            = aIsServer;
98     mStateDidChange      = false;
99     mLongLivedOperation  = false;
100     mRetryDelay          = 0;
101     mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
102     mDisconnectReason    = kReasonUnknown;
103 }
104 
SetState(State aState)105 void Dso::Connection::SetState(State aState)
106 {
107     VerifyOrExit(mState != aState);
108 
109     LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
110             mPeerSockAddr.ToString().AsCString());
111 
112     mState          = aState;
113     mStateDidChange = true;
114 
115 exit:
116     return;
117 }
118 
SignalAnyStateChange(void)119 void Dso::Connection::SignalAnyStateChange(void)
120 {
121     VerifyOrExit(mStateDidChange);
122     mStateDidChange = false;
123 
124     switch (mState)
125     {
126     case kStateDisconnected:
127         mCallbacks.mHandleDisconnected(*this);
128         break;
129 
130     case kStateConnectedButSessionless:
131         mCallbacks.mHandleConnected(*this);
132         break;
133 
134     case kStateSessionEstablished:
135         mCallbacks.mHandleSessionEstablished(*this);
136         break;
137 
138     case kStateConnecting:
139     case kStateEstablishingSession:
140         break;
141     };
142 
143 exit:
144     return;
145 }
146 
NewMessage(void)147 Message *Dso::Connection::NewMessage(void)
148 {
149     return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
150                                        Message::Settings(Message::kPriorityNormal));
151 }
152 
Connect(void)153 void Dso::Connection::Connect(void)
154 {
155     OT_ASSERT(mState == kStateDisconnected);
156 
157     Init(/* aIsServer */ false);
158     Get<Dso>().mClientConnections.Push(*this);
159     MarkAsConnecting();
160     otPlatDsoConnect(this, &mPeerSockAddr);
161 }
162 
Accept(void)163 void Dso::Connection::Accept(void)
164 {
165     OT_ASSERT(mState == kStateDisconnected);
166 
167     Init(/* aIsServer */ true);
168     Get<Dso>().mServerConnections.Push(*this);
169     MarkAsConnecting();
170 }
171 
MarkAsConnecting(void)172 void Dso::Connection::MarkAsConnecting(void)
173 {
174     SetState(kStateConnecting);
175 
176     // While in `kStateConnecting` state we use the `mKeepAlive` to
177     // track the `kConnectingTimeout` (if connection is not established
178     // within the timeout, we consider it as failure and close it).
179 
180     mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
181     Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
182 
183     // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
184     // or timeout.
185 }
186 
HandleConnected(void)187 void Dso::Connection::HandleConnected(void)
188 {
189     OT_ASSERT(mState == kStateConnecting);
190 
191     SetState(kStateConnectedButSessionless);
192     ResetTimeouts(/* aIsKeepAliveMessage */ false);
193 
194     SignalAnyStateChange();
195 }
196 
Disconnect(DisconnectMode aMode,DisconnectReason aReason)197 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
198 {
199     VerifyOrExit(mState != kStateDisconnected);
200 
201     mDisconnectReason = aReason;
202     MarkAsDisconnected();
203 
204     otPlatDsoDisconnect(this, MapEnum(aMode));
205 
206 exit:
207     return;
208 }
209 
HandleDisconnected(DisconnectMode aMode)210 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
211 {
212     VerifyOrExit(mState != kStateDisconnected);
213 
214     if (mState == kStateConnecting)
215     {
216         mDisconnectReason = kReasonFailedToConnect;
217     }
218     else
219     {
220         switch (aMode)
221         {
222         case kGracefullyClose:
223             mDisconnectReason = kReasonPeerClosed;
224             break;
225 
226         case kForciblyAbort:
227             mDisconnectReason = kReasonPeerAborted;
228         }
229     }
230 
231     MarkAsDisconnected();
232     SignalAnyStateChange();
233 
234 exit:
235     return;
236 }
237 
MarkAsDisconnected(void)238 void Dso::Connection::MarkAsDisconnected(void)
239 {
240     if (IsClient())
241     {
242         IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
243     }
244     else
245     {
246         IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
247     }
248 
249     mPendingRequests.Clear();
250     SetState(kStateDisconnected);
251 
252     LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
253 }
254 
MarkSessionEstablished(void)255 void Dso::Connection::MarkSessionEstablished(void)
256 {
257     switch (mState)
258     {
259     case kStateConnectedButSessionless:
260     case kStateEstablishingSession:
261     case kStateSessionEstablished:
262         break;
263 
264     case kStateDisconnected:
265     case kStateConnecting:
266         OT_ASSERT(false);
267     }
268 
269     SetState(kStateSessionEstablished);
270 }
271 
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)272 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
273 {
274     return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
275 }
276 
SendUnidirectionalMessage(Message & aMessage)277 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
278 {
279     MessageId messageId = 0;
280 
281     return SendMessage(aMessage, kUnidirectionalMessage, messageId);
282 }
283 
SendResponseMessage(Message & aMessage,MessageId aResponseId)284 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
285 {
286     return SendMessage(aMessage, kResponseMessage, aResponseId);
287 }
288 
SetLongLivedOperation(bool aLongLivedOperation)289 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
290 {
291     VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
292 
293     mLongLivedOperation = aLongLivedOperation;
294 
295     LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
296 
297     if (!mLongLivedOperation)
298     {
299         NextFireTime nextTime;
300 
301         UpdateNextFireTime(nextTime);
302         Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
303     }
304 
305 exit:
306     return;
307 }
308 
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)309 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
310 {
311     Error         error   = kErrorNone;
312     Message      *message = nullptr;
313     RetryDelayTlv retryDelayTlv;
314     MessageId     messageId;
315 
316     switch (mState)
317     {
318     case kStateSessionEstablished:
319         OT_ASSERT(IsServer());
320         break;
321 
322     case kStateConnectedButSessionless:
323     case kStateEstablishingSession:
324     case kStateDisconnected:
325     case kStateConnecting:
326         OT_ASSERT(false);
327     }
328 
329     message = NewMessage();
330     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
331 
332     retryDelayTlv.Init();
333     retryDelayTlv.SetRetryDelay(aDelay);
334     SuccessOrExit(error = message->Append(retryDelayTlv));
335     error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
336 
337 exit:
338     FreeMessageOnError(message, error);
339     return error;
340 }
341 
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)342 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
343 {
344     Error error = kErrorNone;
345 
346     VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
347 
348     // If acting as server, the timeout values are the ones we grant
349     // to a connecting clients. If acting as client, the timeout
350     // values are what to request when sending Keep Alive message.
351     // If in `kStateDisconnected` we set both (since we don't know
352     // yet whether we are going to connect as client or server).
353 
354     if ((mState == kStateDisconnected) || IsServer())
355     {
356         mKeepAlive.SetInterval(aKeepAliveInterval);
357         AdjustInactivityTimeout(aInactivityTimeout);
358     }
359 
360     if ((mState == kStateDisconnected) || IsClient())
361     {
362         mKeepAlive.SetRequestInterval(aKeepAliveInterval);
363         mInactivity.SetRequestInterval(aInactivityTimeout);
364     }
365 
366     switch (mState)
367     {
368     case kStateDisconnected:
369     case kStateConnecting:
370         break;
371 
372     case kStateConnectedButSessionless:
373     case kStateEstablishingSession:
374         if (IsServer())
375         {
376             break;
377         }
378 
379         OT_FALL_THROUGH;
380 
381     case kStateSessionEstablished:
382         error = SendKeepAliveMessage();
383     }
384 
385 exit:
386     return error;
387 }
388 
SendKeepAliveMessage(void)389 Error Dso::Connection::SendKeepAliveMessage(void)
390 {
391     return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
392 }
393 
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)394 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
395 {
396     // Sends a Keep Alive message of a given type. This is a common
397     // method used by both client and server. `aResponseId` is
398     // applicable and used only when the message type is
399     // `kResponseMessage`.
400 
401     Error        error   = kErrorNone;
402     Message     *message = nullptr;
403     KeepAliveTlv keepAliveTlv;
404 
405     switch (mState)
406     {
407     case kStateConnectedButSessionless:
408     case kStateEstablishingSession:
409         if (IsServer())
410         {
411             // While session is being established, server is only allowed
412             // to send a Keep Alive response to a request from client.
413             OT_ASSERT(aMessageType == kResponseMessage);
414         }
415         break;
416 
417     case kStateSessionEstablished:
418         break;
419 
420     case kStateDisconnected:
421     case kStateConnecting:
422         OT_ASSERT(false);
423     }
424 
425     // Server can send Keep Alive response (to a request from client)
426     // or a unidirectional Keep Alive message. Client can send
427     // KeepAlive request message.
428 
429     if (IsServer())
430     {
431         if (aMessageType == kResponseMessage)
432         {
433             OT_ASSERT(aResponseId != 0);
434         }
435         else
436         {
437             OT_ASSERT(aMessageType == kUnidirectionalMessage);
438         }
439     }
440     else
441     {
442         OT_ASSERT(aMessageType == kRequestMessage);
443     }
444 
445     message = NewMessage();
446     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
447 
448     keepAliveTlv.Init();
449 
450     if (IsServer())
451     {
452         keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
453         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
454     }
455     else
456     {
457         keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
458         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
459     }
460 
461     SuccessOrExit(error = message->Append(keepAliveTlv));
462 
463     error = SendMessage(*message, aMessageType, aResponseId);
464 
465 exit:
466     FreeMessageOnError(message, error);
467     return error;
468 }
469 
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)470 Error Dso::Connection::SendMessage(Message              &aMessage,
471                                    MessageType           aMessageType,
472                                    MessageId            &aMessageId,
473                                    Dns::Header::Response aResponseCode,
474                                    uint32_t              aResponseTimeout)
475 {
476     Error       error          = kErrorNone;
477     Tlv::Type   primaryTlvType = Tlv::kReservedType;
478     Dns::Header header;
479 
480     switch (mState)
481     {
482     case kStateConnectedButSessionless:
483         // To establish session, client MUST send a request message.
484         // Server is not allowed to send any messages. Unidirectional
485         // messages are not allowed before session is established.
486         OT_ASSERT(IsClient());
487         OT_ASSERT(aMessageType == kRequestMessage);
488         break;
489 
490     case kStateEstablishingSession:
491         // During session establishment, client is allowed to send
492         // additional request messages, server is only allowed to
493         // send response.
494         if (IsClient())
495         {
496             OT_ASSERT(aMessageType == kRequestMessage);
497         }
498         else
499         {
500             OT_ASSERT(aMessageType == kResponseMessage);
501         }
502         break;
503 
504     case kStateSessionEstablished:
505         // All message types are allowed.
506         break;
507 
508     case kStateDisconnected:
509     case kStateConnecting:
510         OT_ASSERT(false);
511     }
512 
513     // A DSO request or unidirectional message MUST contain at
514     // least one TLV. The first TLV is the "Primary TLV" and
515     // determines the nature of the operation being performed.
516     // A DSO response message may contain no TLVs, or may contain
517     // one or more TLVs. Response Primary TLV(s) MUST appear first
518     // in a DSO response message.
519 
520     aMessage.SetOffset(0);
521     IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
522 
523     switch (aMessageType)
524     {
525     case kResponseMessage:
526         break;
527     case kRequestMessage:
528     case kUnidirectionalMessage:
529         OT_ASSERT(primaryTlvType != Tlv::kReservedType);
530     }
531 
532     // `header` is cleared from its constructor call so all fields
533     // start as zero.
534 
535     switch (aMessageType)
536     {
537     case kRequestMessage:
538         header.SetType(Dns::Header::kTypeQuery);
539         aMessageId = mNextMessageId;
540         break;
541 
542     case kResponseMessage:
543         header.SetType(Dns::Header::kTypeResponse);
544         break;
545 
546     case kUnidirectionalMessage:
547         header.SetType(Dns::Header::kTypeQuery);
548         aMessageId = 0;
549         break;
550     }
551 
552     header.SetMessageId(aMessageId);
553     header.SetQueryType(Dns::Header::kQueryTypeDso);
554     header.SetResponseCode(aResponseCode);
555     SuccessOrExit(error = aMessage.Prepend(header));
556 
557     SuccessOrExit(error = AppendPadding(aMessage));
558 
559     // Update `mPendingRequests` list with the new request info
560 
561     if (aMessageType == kRequestMessage)
562     {
563         SuccessOrExit(
564             error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
565 
566         if (++mNextMessageId == 0)
567         {
568             mNextMessageId = 1;
569         }
570     }
571 
572     LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
573             mPeerSockAddr.ToString().AsCString());
574 
575     switch (mState)
576     {
577     case kStateConnectedButSessionless:
578         // On client we transition from "connected" state to
579         // "establishing session" state on successfully sending a
580         // request message.
581         if (IsClient())
582         {
583             SetState(kStateEstablishingSession);
584         }
585         break;
586 
587     case kStateEstablishingSession:
588         // On server we transition from "establishing session" state
589         // to "established" on sending a response with success
590         // response code.
591         if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
592         {
593             SetState(kStateSessionEstablished);
594         }
595 
596     default:
597         break;
598     }
599 
600     ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
601 
602     otPlatDsoSend(this, &aMessage);
603 
604     // Signal any state changes. This is done at the very end when the
605     // `SendMessage()` is fully processed (all state and local
606     // variables are updated) to ensure that we do not have any
607     // reentrancy issues (e.g., if the callback signalling state
608     // change triggers another tx).
609 
610     SignalAnyStateChange();
611 
612 exit:
613     return error;
614 }
615 
AppendPadding(Message & aMessage)616 Error Dso::Connection::AppendPadding(Message &aMessage)
617 {
618     // This method appends Encryption Padding TLV to a DSO message.
619     // It uses the padding policy "Random-Block-Length Padding" from
620     // RFC 8467.
621 
622     static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
623 
624     Error                error = kErrorNone;
625     uint16_t             blockLength;
626     EncryptionPaddingTlv paddingTlv;
627 
628     // We pick a random block length. The random selection can be
629     // based on a "weak" source of randomness (so the use of
630     // `NonCrypto` is fine). We add padding to the message such
631     // that its padded length is a multiple of the chosen block
632     // length.
633 
634     blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
635 
636     paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
637 
638     SuccessOrExit(error = aMessage.Append(paddingTlv));
639 
640     for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
641     {
642         SuccessOrExit(error = aMessage.Append<uint8_t>(0));
643     }
644 
645 exit:
646     return error;
647 }
648 
HandleReceive(Message & aMessage)649 void Dso::Connection::HandleReceive(Message &aMessage)
650 {
651     Error       error          = kErrorAbort;
652     Tlv::Type   primaryTlvType = Tlv::kReservedType;
653     Dns::Header header;
654 
655     SuccessOrExit(aMessage.Read(0, header));
656 
657     if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
658     {
659         if (header.GetType() == Dns::Header::kTypeQuery)
660         {
661             SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
662             error = kErrorNone;
663         }
664 
665         ExitNow();
666     }
667 
668     switch (mState)
669     {
670     case kStateConnectedButSessionless:
671         // After connection is established, client should initiate
672         // establishing session (by sending a request). So no rx is
673         // allowed before this. On server, we allow rx of a request
674         // message only.
675         VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
676         break;
677 
678     case kStateEstablishingSession:
679         // Unidirectional message are allowed after session is
680         // established. While session is being established, on client,
681         // we allow rx on response message. On server we can rx
682         // request or response.
683 
684         VerifyOrExit(header.GetMessageId() != 0);
685 
686         if (IsClient())
687         {
688             VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
689         }
690         break;
691 
692     case kStateSessionEstablished:
693         // All message types are allowed.
694         break;
695 
696     case kStateDisconnected:
697     case kStateConnecting:
698         ExitNow();
699     }
700 
701     // All count fields MUST be set to zero in the header.
702     VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
703                  (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));
704 
705     aMessage.SetOffset(sizeof(header));
706 
707     switch (ReadPrimaryTlv(aMessage, primaryTlvType))
708     {
709     case kErrorNone:
710         VerifyOrExit(primaryTlvType != Tlv::kReservedType);
711         break;
712 
713     case kErrorNotFound:
714         // The `primaryTlvType` is set to `Tlv::kReservedType`
715         // (value zero) to indicate that there is no primary TLV.
716         break;
717 
718     default:
719         ExitNow();
720     }
721 
722     switch (header.GetType())
723     {
724     case Dns::Header::kTypeQuery:
725         error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
726         break;
727 
728     case Dns::Header::kTypeResponse:
729         error = ProcessResponseMessage(header, aMessage, primaryTlvType);
730         break;
731     }
732 
733 exit:
734     aMessage.Free();
735 
736     if (error == kErrorNone)
737     {
738         ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
739     }
740     else
741     {
742         Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
743     }
744 
745     // We signal any state change at the very end when the received
746     // message is fully processed (all state and local variables are
747     // updated) to ensure that we do not have any reentrancy issues
748     // (e.g., if a `Connection` method happens to be called from the
749     // callback).
750 
751     SignalAnyStateChange();
752 }
753 
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const754 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
755 {
756     // Read and validate the primary TLV (first TLV  after the header).
757     // The `aMessage.GetOffset()` must point to the first TLV. If no
758     // TLV then `kErrorNotFound` is returned. If TLV in message is not
759     // well-formed `kErrorParse` is returned. The read TLV type is
760     // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
761     // (value zero) when `kErrorNotFound`).
762 
763     Error error = kErrorNotFound;
764     Tlv   tlv;
765 
766     aPrimaryTlvType = Tlv::kReservedType;
767 
768     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
769     VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
770     aPrimaryTlvType = tlv.GetType();
771     error           = kErrorNone;
772 
773 exit:
774     return error;
775 }
776 
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)777 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
778                                                              const Message     &aMessage,
779                                                              Tlv::Type          aPrimaryTlvType)
780 {
781     Error error = kErrorAbort;
782 
783     if (IsServer() && (mState == kStateConnectedButSessionless))
784     {
785         SetState(kStateEstablishingSession);
786     }
787 
788     // A DSO request or unidirectional message MUST contain at
789     // least one TLV which is the "Primary TLV" and determines
790     // the nature of the operation being performed.
791 
792     switch (aPrimaryTlvType)
793     {
794     case KeepAliveTlv::kType:
795         error = ProcessKeepAliveMessage(aHeader, aMessage);
796         break;
797 
798     case RetryDelayTlv::kType:
799         error = ProcessRetryDelayMessage(aHeader, aMessage);
800         break;
801 
802     case Tlv::kReservedType:
803     case EncryptionPaddingTlv::kType:
804         // Misbehavior by peer.
805         break;
806 
807     default:
808         if (aHeader.GetMessageId() == 0)
809         {
810             LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
811 
812             error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
813         }
814         else
815         {
816             MessageId messageId = aHeader.GetMessageId();
817 
818             LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
819 
820             error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
821 
822             // `kErrorNotFound` indicates that TLV type is not known.
823 
824             if (error == kErrorNotFound)
825             {
826                 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
827                 error = kErrorNone;
828             }
829         }
830         break;
831     }
832 
833     return error;
834 }
835 
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)836 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
837                                               const Message     &aMessage,
838                                               Tlv::Type          aPrimaryTlvType)
839 {
840     Error     error = kErrorAbort;
841     Tlv::Type requestPrimaryTlvType;
842 
843     // If a client or server receives a response where the message
844     // ID is zero, or is any other value that does not match the
845     // message ID of any of its outstanding operations, this is a
846     // fatal error and the recipient MUST forcibly abort the
847     // connection immediately.
848 
849     VerifyOrExit(aHeader.GetMessageId() != 0);
850     VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
851 
852     // If the response has no error and contains a primary TLV, it
853     // MUST match the request primary TLV.
854 
855     if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
856     {
857         VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
858     }
859 
860     mPendingRequests.Remove(aHeader.GetMessageId());
861 
862     switch (requestPrimaryTlvType)
863     {
864     case KeepAliveTlv::kType:
865         SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
866         break;
867 
868     default:
869         SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
870                                                                  requestPrimaryTlvType));
871         break;
872     }
873 
874     // DSO session is established when client sends a request message
875     // and receives a response from server with no error code.
876 
877     if (IsClient() && (mState == kStateEstablishingSession) &&
878         (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
879     {
880         SetState(kStateSessionEstablished);
881     }
882 
883 exit:
884     return error;
885 }
886 
ProcessKeepAliveMessage(const Dns::Header & aHeader,const Message & aMessage)887 Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
888 {
889     Error        error  = kErrorAbort;
890     uint16_t     offset = aMessage.GetOffset();
891     Tlv          tlv;
892     KeepAliveTlv keepAliveTlv;
893 
894     if (aHeader.GetType() == Dns::Header::kTypeResponse)
895     {
896         // A Keep Alive response message is allowed on a client from a sever.
897 
898         VerifyOrExit(IsClient());
899 
900         if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
901         {
902             // We got an error response code from server for our
903             // Keep Alive request message. If this happens while
904             // establishing the DSO session, it indicates that server
905             // does not support DSO, so we close the connection. If
906             // this happens while session is already established, it
907             // is a misbehavior (fatal error) by server.
908 
909             if (mState == kStateEstablishingSession)
910             {
911                 Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
912                 error = kErrorNone;
913             }
914 
915             ExitNow();
916         }
917     }
918 
919     // Parse and validate the Keep Alive Message
920 
921     SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
922     offset += keepAliveTlv.GetSize();
923 
924     VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());
925 
926     // Keep Alive message MUST contain only one Keep Alive TLV.
927 
928     while (offset < aMessage.GetLength())
929     {
930         SuccessOrExit(aMessage.Read(offset, tlv));
931         offset += tlv.GetSize();
932 
933         VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
934     }
935 
936     VerifyOrExit(offset == aMessage.GetLength());
937 
938     if (aHeader.GetType() == Dns::Header::kTypeQuery)
939     {
940         if (IsServer())
941         {
942             // Received a Keep Alive message from client. It MUST
943             // be a request message (not unidirectional). We prepare
944             // and send a Keep Alive response.
945 
946             VerifyOrExit(aHeader.GetMessageId() != 0);
947 
948             LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());
949 
950             IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
951             error = kErrorNone;
952             ExitNow();
953         }
954 
955         // Received a Keep Alive message on client from server. Server
956         // Keep Alive message MUST be unidirectional (message ID
957         // zero).
958 
959         VerifyOrExit(aHeader.GetMessageId() == 0);
960     }
961 
962     LogInfo("Received Keep Alive %s message from server %s",
963             (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());
964 
965     // Receiving a Keep Alive interval value from server less than the
966     // minimum (ten seconds) is a fatal error and client MUST then
967     // abort the connection.
968 
969     VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);
970 
971     // Update the timeout intervals on the connection from
972     // the new values we got from the server. The receive
973     // of the Keep Alive message does not itself reset the
974     // inactivity timer. So we use `AdjustInactivityTimeout`
975     // which takes into account the time elapsed since the
976     // last activity.
977 
978     AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
979     mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());
980 
981     LogInfo("Timeouts Inactivity:%lu, KeepAlive:%lu", ToUlong(mInactivity.GetInterval()),
982             ToUlong(mKeepAlive.GetInterval()));
983 
984     error = kErrorNone;
985 
986 exit:
987     return error;
988 }
989 
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)990 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
991 
992 {
993     Error         error = kErrorAbort;
994     RetryDelayTlv retryDelayTlv;
995 
996     // Retry Delay TLV can be used as the Primary TLV only in
997     // a unidirectional message sent from server to client.
998     // It is used by the server to instruct the client to
999     // close the session and its underlying connection, and not
1000     // to reconnect for the indicated time interval.
1001 
1002     VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1003 
1004     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1005     VerifyOrExit(retryDelayTlv.IsValid());
1006 
1007     mRetryDelayErrorCode = aHeader.GetResponseCode();
1008     mRetryDelay          = retryDelayTlv.GetRetryDelay();
1009 
1010     LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1011     LogInfo("   RetryDelay:%lu ms, ResponseCode:%d", ToUlong(mRetryDelay), mRetryDelayErrorCode);
1012 
1013     Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1014 
1015 exit:
1016     return error;
1017 }
1018 
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1019 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1020 {
1021     Message    *response = NewMessage();
1022     Dns::Header header;
1023 
1024     VerifyOrExit(response != nullptr);
1025 
1026     header.SetMessageId(aHeader.GetMessageId());
1027     header.SetType(Dns::Header::kTypeResponse);
1028     header.SetQueryType(aHeader.GetQueryType());
1029     header.SetResponseCode(aResponseCode);
1030 
1031     SuccessOrExit(response->Prepend(header));
1032 
1033     otPlatDsoSend(this, response);
1034     response = nullptr;
1035 
1036 exit:
1037     FreeMessage(response);
1038 }
1039 
AdjustInactivityTimeout(uint32_t aNewTimeout)1040 void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
1041 {
1042     // This method sets the inactivity timeout interval to a new value
1043     // and updates the expiration time based on the new timeout value.
1044     //
1045     // On client, it is called on receiving a Keep Alive response or
1046     // unidirectional message from server. Note that the receive of
1047     // the Keep Alive message does not itself reset the inactivity
1048     // timer. So the time elapsed since the last activity should be
1049     // taken into account with the new inactivity timeout value.
1050     //
1051     // On server this method is called from `SetTimeouts()` when a new
1052     // inactivity timeout value is set.
1053 
1054     TimeMilli now = TimerMilli::GetNow();
1055     TimeMilli start;
1056     TimeMilli newExpiration;
1057 
1058     if (mState == kStateDisconnected)
1059     {
1060         mInactivity.SetInterval(aNewTimeout);
1061         ExitNow();
1062     }
1063 
1064     VerifyOrExit(aNewTimeout != mInactivity.GetInterval());
1065 
1066     // Calculate the start time (i.e., the last time inactivity timer
1067     // was cleared). If the previous inactivity time is set to
1068     // `kInfinite` value (`IsUsed()` returns `false`) then
1069     // `GetExpirationTime()` returns the start time. Otherwise, we
1070     // calculate it going back from the current expiration time with
1071     // the current wait interval.
1072 
1073     if (!mInactivity.IsUsed())
1074     {
1075         start = mInactivity.GetExpirationTime();
1076     }
1077     else if (IsClient())
1078     {
1079         start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
1080     }
1081     else
1082     {
1083         start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
1084     }
1085 
1086     mInactivity.SetInterval(aNewTimeout);
1087 
1088     if (!mInactivity.IsUsed())
1089     {
1090         newExpiration = start;
1091     }
1092     else if (IsClient())
1093     {
1094         newExpiration = start + aNewTimeout;
1095 
1096         if (newExpiration < now)
1097         {
1098             newExpiration = now;
1099         }
1100     }
1101     else
1102     {
1103         newExpiration = start + CalculateServerInactivityWaitTime();
1104 
1105         if (newExpiration < now)
1106         {
1107             // If the server abruptly reduces the inactivity timeout
1108             // such that current elapsed time is already more than
1109             // twice the new inactivity timeout, then the client is
1110             // immediately considered delinquent (server can forcibly
1111             // abort the connection). So to give the client time to
1112             // close the connection gracefully, the server SHOULD
1113             // give the client an additional grace period of either
1114             // five seconds or one quarter of the new inactivity
1115             // timeout, whichever is greater [RFC 8490 - 7.1.1].
1116 
1117             newExpiration = now + Max(kMinServerInactivityWaitTime, aNewTimeout / 4);
1118         }
1119     }
1120 
1121     mInactivity.SetExpirationTime(newExpiration);
1122 
1123 exit:
1124     return;
1125 }
1126 
CalculateServerInactivityWaitTime(void) const1127 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1128 {
1129     // A server will abort an idle session after five seconds
1130     // (`kMinServerInactivityWaitTime`) or twice the inactivity
1131     // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1132 
1133     OT_ASSERT(mInactivity.IsUsed());
1134 
1135     return Max(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1136 }
1137 
ResetTimeouts(bool aIsKeepAliveMessage)1138 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1139 {
1140     NextFireTime nextTime;
1141 
1142     // At both servers and clients, the generation or reception of any
1143     // complete DNS message resets both timers for that DSO
1144     // session, with the one exception being that a DSO Keep Alive
1145     // message resets only the keep alive timer, not the inactivity
1146     // timeout timer [RFC 8490 - 6.3]
1147 
1148     if (mKeepAlive.IsUsed())
1149     {
1150         // On client, we wait for the Keep Alive interval but on server
1151         // we wait for twice the interval before considering Keep Alive
1152         // timeout.
1153         //
1154         // Note that we limit the interval to `Timeout::kMaxInterval`
1155         // (which is ~12 days). This max limit ensures that even twice
1156         // the interval is less than max OpenThread timer duration so
1157         // that the expiration time calculations below stay within the
1158         // `TimerMilli` range.
1159 
1160         mKeepAlive.SetExpirationTime(nextTime.GetNow() + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1161     }
1162 
1163     if (!aIsKeepAliveMessage)
1164     {
1165         if (mInactivity.IsUsed())
1166         {
1167             mInactivity.SetExpirationTime(
1168                 nextTime.GetNow() + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1169         }
1170         else
1171         {
1172             // When Inactivity timeout is not used (i.e., interval is set
1173             // to the special `kInfinite` value), we still need to track
1174             // the time so that if/when later the inactivity interval
1175             // gets changed, we can adjust the remaining time correctly
1176             // from `AdjustInactivityTimeout()`. In this case, we just
1177             // track the current time as "expiration time".
1178 
1179             mInactivity.SetExpirationTime(nextTime.GetNow());
1180         }
1181     }
1182 
1183     UpdateNextFireTime(nextTime);
1184 
1185     Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1186 }
1187 
UpdateNextFireTime(NextFireTime & aNextTime) const1188 void Dso::Connection::UpdateNextFireTime(NextFireTime &aNextTime) const
1189 {
1190     switch (mState)
1191     {
1192     case kStateDisconnected:
1193         break;
1194 
1195     case kStateConnecting:
1196         // While in `kStateConnecting`, Keep Alive timer is
1197         // used for `kConnectingTimeout`.
1198         aNextTime.UpdateIfEarlier(mKeepAlive.GetExpirationTime());
1199         break;
1200 
1201     case kStateConnectedButSessionless:
1202     case kStateEstablishingSession:
1203     case kStateSessionEstablished:
1204         mPendingRequests.UpdateNextFireTime(aNextTime);
1205 
1206         if (mKeepAlive.IsUsed())
1207         {
1208             aNextTime.UpdateIfEarlier(mKeepAlive.GetExpirationTime());
1209         }
1210 
1211         if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1212         {
1213             // An operation being active on a DSO Session includes
1214             // a request message waiting for a response, or an
1215             // active long-lived operation.
1216 
1217             aNextTime.UpdateIfEarlier(mInactivity.GetExpirationTime());
1218         }
1219 
1220         break;
1221     }
1222 }
1223 
HandleTimer(NextFireTime & aNextTime)1224 void Dso::Connection::HandleTimer(NextFireTime &aNextTime)
1225 {
1226     switch (mState)
1227     {
1228     case kStateDisconnected:
1229         break;
1230 
1231     case kStateConnecting:
1232         if (mKeepAlive.IsExpired(aNextTime.GetNow()))
1233         {
1234             Disconnect(kGracefullyClose, kReasonFailedToConnect);
1235         }
1236         break;
1237 
1238     case kStateConnectedButSessionless:
1239     case kStateEstablishingSession:
1240     case kStateSessionEstablished:
1241         if (mPendingRequests.HasAnyTimedOut(aNextTime.GetNow()))
1242         {
1243             // If server sends no response to a request, client
1244             // waits for 30 seconds (`kResponseTimeout`) after which
1245             // client MUST forcibly abort the connection.
1246             Disconnect(kForciblyAbort, kReasonResponseTimeout);
1247             ExitNow();
1248         }
1249 
1250         // The inactivity timer is kept clear, while an operation is
1251         // active on the session (which includes a request waiting for
1252         // response or an active long-lived operation).
1253 
1254         if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation &&
1255             mInactivity.IsExpired(aNextTime.GetNow()))
1256         {
1257             // On client, if the inactivity timeout is reached, the
1258             // connection is closed gracefully. On server, if too much
1259             // time (`CalculateServerInactivityWaitTime()`, i.e., five
1260             // seconds or twice the current inactivity timeout interval,
1261             // whichever is grater) elapses server MUST consider the
1262             // client delinquent and MUST forcibly abort the connection.
1263 
1264             Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
1265             ExitNow();
1266         }
1267 
1268         if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNextTime.GetNow()))
1269         {
1270             // On client, if the Keep Alive interval elapses without any
1271             // DNS messages being sent or received, the client MUST take
1272             // action and send a DSO Keep Alive message.
1273             //
1274             // On server, if twice the Keep Alive interval value elapses
1275             // without any messages being sent or received, the server
1276             // considers the client delinquent and aborts the connection.
1277 
1278             if (IsClient())
1279             {
1280                 IgnoreError(SendKeepAliveMessage());
1281             }
1282             else
1283             {
1284                 Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
1285                 ExitNow();
1286             }
1287         }
1288         break;
1289     }
1290 
1291 exit:
1292     UpdateNextFireTime(aNextTime);
1293     SignalAnyStateChange();
1294 }
1295 
StateToString(State aState)1296 const char *Dso::Connection::StateToString(State aState)
1297 {
1298     static const char *const kStateStrings[] = {
1299         "Disconnected",            // (0) kStateDisconnected,
1300         "Connecting",              // (1) kStateConnecting,
1301         "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1302         "EstablishingSession",     // (3) kStateEstablishingSession,
1303         "SessionEstablished",      // (4) kStateSessionEstablished,
1304     };
1305 
1306     struct EnumCheck
1307     {
1308         InitEnumValidatorCounter();
1309         ValidateNextEnum(kStateDisconnected);
1310         ValidateNextEnum(kStateConnecting);
1311         ValidateNextEnum(kStateConnectedButSessionless);
1312         ValidateNextEnum(kStateEstablishingSession);
1313         ValidateNextEnum(kStateSessionEstablished);
1314     };
1315 
1316     return kStateStrings[aState];
1317 }
1318 
MessageTypeToString(MessageType aMessageType)1319 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1320 {
1321     static const char *const kMessageTypeStrings[] = {
1322         "Request",        // (0) kRequestMessage
1323         "Response",       // (1) kResponseMessage
1324         "Unidirectional", // (2) kUnidirectionalMessage
1325     };
1326 
1327     struct EnumCheck
1328     {
1329         InitEnumValidatorCounter();
1330         ValidateNextEnum(kRequestMessage);
1331         ValidateNextEnum(kResponseMessage);
1332         ValidateNextEnum(kUnidirectionalMessage);
1333     };
1334 
1335     return kMessageTypeStrings[aMessageType];
1336 }
1337 
DisconnectReasonToString(DisconnectReason aReason)1338 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1339 {
1340     static const char *const kDisconnectReasonStrings[] = {
1341         "FailedToConnect",         // (0) kReasonFailedToConnect
1342         "ResponseTimeout",         // (1) kReasonResponseTimeout
1343         "PeerDoesNotSupportDso",   // (2) kReasonPeerDoesNotSupportDso
1344         "PeerClosed",              // (3) kReasonPeerClosed
1345         "PeerAborted",             // (4) kReasonPeerAborted
1346         "InactivityTimeout",       // (5) kReasonInactivityTimeout
1347         "KeepAliveTimeout",        // (6) kReasonKeepAliveTimeout
1348         "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1349         "PeerMisbehavior",         // (8) kReasonPeerMisbehavior
1350         "Unknown",                 // (9) kReasonUnknown
1351     };
1352 
1353     struct EnumCheck
1354     {
1355         InitEnumValidatorCounter();
1356         ValidateNextEnum(kReasonFailedToConnect);
1357         ValidateNextEnum(kReasonResponseTimeout);
1358         ValidateNextEnum(kReasonPeerDoesNotSupportDso);
1359         ValidateNextEnum(kReasonPeerClosed);
1360         ValidateNextEnum(kReasonPeerAborted);
1361         ValidateNextEnum(kReasonInactivityTimeout);
1362         ValidateNextEnum(kReasonKeepAliveTimeout);
1363         ValidateNextEnum(kReasonServerRetryDelayRequest);
1364         ValidateNextEnum(kReasonPeerMisbehavior);
1365         ValidateNextEnum(kReasonUnknown);
1366     };
1367 
1368     return kDisconnectReasonStrings[aReason];
1369 }
1370 
1371 //---------------------------------------------------------------------------------------------------------------------
1372 // Dso::Connection::PendingRequests
1373 
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1374 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1375 {
1376     bool         contains = true;
1377     const Entry *entry    = mRequests.FindMatching(aMessageId);
1378 
1379     VerifyOrExit(entry != nullptr, contains = false);
1380     aPrimaryTlvType = entry->mPrimaryTlvType;
1381 
1382 exit:
1383     return contains;
1384 }
1385 
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1386 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1387 {
1388     Error  error = kErrorNone;
1389     Entry *entry = mRequests.PushBack();
1390 
1391     VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1392     entry->mMessageId      = aMessageId;
1393     entry->mPrimaryTlvType = aPrimaryTlvType;
1394     entry->mTimeout        = aResponseTimeout;
1395 
1396 exit:
1397     return error;
1398 }
1399 
Remove(MessageId aMessageId)1400 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId) { mRequests.RemoveMatching(aMessageId); }
1401 
HasAnyTimedOut(TimeMilli aNow) const1402 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1403 {
1404     bool timedOut = false;
1405 
1406     for (const Entry &entry : mRequests)
1407     {
1408         if (entry.mTimeout <= aNow)
1409         {
1410             timedOut = true;
1411             break;
1412         }
1413     }
1414 
1415     return timedOut;
1416 }
1417 
UpdateNextFireTime(NextFireTime & aNextTime) const1418 void Dso::Connection::PendingRequests::UpdateNextFireTime(NextFireTime &aNextTime) const
1419 {
1420     for (const Entry &entry : mRequests)
1421     {
1422         aNextTime.UpdateIfEarlier(entry.mTimeout);
1423     }
1424 }
1425 
1426 //---------------------------------------------------------------------------------------------------------------------
1427 // Dso
1428 
Dso(Instance & aInstance)1429 Dso::Dso(Instance &aInstance)
1430     : InstanceLocator(aInstance)
1431     , mAcceptHandler(nullptr)
1432     , mTimer(aInstance)
1433 {
1434 }
1435 
StartListening(AcceptHandler aAcceptHandler)1436 void Dso::StartListening(AcceptHandler aAcceptHandler)
1437 {
1438     mAcceptHandler = aAcceptHandler;
1439     otPlatDsoEnableListening(&GetInstance(), true);
1440 }
1441 
StopListening(void)1442 void Dso::StopListening(void) { otPlatDsoEnableListening(&GetInstance(), false); }
1443 
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1444 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1445 {
1446     return mClientConnections.FindMatching(aPeerSockAddr);
1447 }
1448 
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1449 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1450 {
1451     return mServerConnections.FindMatching(aPeerSockAddr);
1452 }
1453 
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1454 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1455 {
1456     Connection *connection = nullptr;
1457 
1458     VerifyOrExit(mAcceptHandler != nullptr);
1459     connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1460 
1461     VerifyOrExit(connection != nullptr);
1462     connection->Accept();
1463 
1464 exit:
1465     return connection;
1466 }
1467 
HandleTimer(void)1468 void Dso::HandleTimer(void)
1469 {
1470     NextFireTime nextTime;
1471     Connection  *conn;
1472     Connection  *next;
1473 
1474     for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1475     {
1476         next = conn->GetNext();
1477         conn->HandleTimer(nextTime);
1478     }
1479 
1480     for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1481     {
1482         next = conn->GetNext();
1483         conn->HandleTimer(nextTime);
1484     }
1485 
1486     mTimer.FireAtIfEarlier(nextTime);
1487 }
1488 
1489 } // namespace Dns
1490 } // namespace ot
1491 
1492 #if OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1493 
otPlatDsoEnableListening(otInstance * aInstance,bool aEnable)1494 OT_TOOL_WEAK void otPlatDsoEnableListening(otInstance *aInstance, bool aEnable)
1495 {
1496     OT_UNUSED_VARIABLE(aInstance);
1497     OT_UNUSED_VARIABLE(aEnable);
1498 }
1499 
otPlatDsoConnect(otPlatDsoConnection * aConnection,const otSockAddr * aPeerSockAddr)1500 OT_TOOL_WEAK void otPlatDsoConnect(otPlatDsoConnection *aConnection, const otSockAddr *aPeerSockAddr)
1501 {
1502     OT_UNUSED_VARIABLE(aConnection);
1503     OT_UNUSED_VARIABLE(aPeerSockAddr);
1504 }
1505 
otPlatDsoSend(otPlatDsoConnection * aConnection,otMessage * aMessage)1506 OT_TOOL_WEAK void otPlatDsoSend(otPlatDsoConnection *aConnection, otMessage *aMessage)
1507 {
1508     OT_UNUSED_VARIABLE(aConnection);
1509     OT_UNUSED_VARIABLE(aMessage);
1510 }
1511 
otPlatDsoDisconnect(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)1512 OT_TOOL_WEAK void otPlatDsoDisconnect(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
1513 {
1514     OT_UNUSED_VARIABLE(aConnection);
1515     OT_UNUSED_VARIABLE(aMode);
1516 }
1517 
1518 #endif // OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1519 
1520 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1521