• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_dso.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "common/random.hpp"
41 
42 /**
43  * @file
44  *   This file implements the DNS Stateful Operations (DSO) per RFC 8490.
45  */
46 
47 namespace ot {
48 namespace Dns {
49 
50 RegisterLogModule("DnsDso");
51 
52 //---------------------------------------------------------------------------------------------------------------------
53 // otPlatDso transport callbacks
54 
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)55 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
56 {
57     return &AsCoreType(aConnection).GetInstance();
58 }
59 
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)60 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
61 {
62     return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
63 }
64 
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)65 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
66 {
67     AsCoreType(aConnection).HandleConnected();
68 }
69 
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)70 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
71 {
72     AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
73 }
74 
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)75 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
76 {
77     AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
78 }
79 
80 //---------------------------------------------------------------------------------------------------------------------
81 // Dso::Connection
82 
Connection(Instance & aInstance,const Ip6::SockAddr & aPeerSockAddr,Callbacks & aCallbacks,uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)83 Dso::Connection::Connection(Instance &           aInstance,
84                             const Ip6::SockAddr &aPeerSockAddr,
85                             Callbacks &          aCallbacks,
86                             uint32_t             aInactivityTimeout,
87                             uint32_t             aKeepAliveInterval)
88     : InstanceLocator(aInstance)
89     , mNext(nullptr)
90     , mCallbacks(aCallbacks)
91     , mPeerSockAddr(aPeerSockAddr)
92     , mState(kStateDisconnected)
93     , mIsServer(false)
94     , mInactivity(aInactivityTimeout)
95     , mKeepAlive(aKeepAliveInterval)
96 {
97     OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
98     Init(/* aIsServer */ false);
99 }
100 
Init(bool aIsServer)101 void Dso::Connection::Init(bool aIsServer)
102 {
103     mNextMessageId       = 1;
104     mIsServer            = aIsServer;
105     mStateDidChange      = false;
106     mLongLivedOperation  = false;
107     mRetryDelay          = 0;
108     mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
109     mDisconnectReason    = kReasonUnknown;
110 }
111 
SetState(State aState)112 void Dso::Connection::SetState(State aState)
113 {
114     VerifyOrExit(mState != aState);
115 
116     LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
117             mPeerSockAddr.ToString().AsCString());
118 
119     mState          = aState;
120     mStateDidChange = true;
121 
122 exit:
123     return;
124 }
125 
SignalAnyStateChange(void)126 void Dso::Connection::SignalAnyStateChange(void)
127 {
128     VerifyOrExit(mStateDidChange);
129     mStateDidChange = false;
130 
131     switch (mState)
132     {
133     case kStateDisconnected:
134         mCallbacks.mHandleDisconnected(*this);
135         break;
136 
137     case kStateConnectedButSessionless:
138         mCallbacks.mHandleConnected(*this);
139         break;
140 
141     case kStateSessionEstablished:
142         mCallbacks.mHandleSessionEstablished(*this);
143         break;
144 
145     case kStateConnecting:
146     case kStateEstablishingSession:
147         break;
148     };
149 
150 exit:
151     return;
152 }
153 
NewMessage(void)154 Message *Dso::Connection::NewMessage(void)
155 {
156     return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
157                                        Message::Settings(Message::kPriorityNormal));
158 }
159 
Connect(void)160 void Dso::Connection::Connect(void)
161 {
162     OT_ASSERT(mState == kStateDisconnected);
163 
164     Init(/* aIsServer */ false);
165     Get<Dso>().mClientConnections.Push(*this);
166     MarkAsConnecting();
167     otPlatDsoConnect(this, &mPeerSockAddr);
168 }
169 
Accept(void)170 void Dso::Connection::Accept(void)
171 {
172     OT_ASSERT(mState == kStateDisconnected);
173 
174     Init(/* aIsServer */ true);
175     Get<Dso>().mServerConnections.Push(*this);
176     MarkAsConnecting();
177 }
178 
MarkAsConnecting(void)179 void Dso::Connection::MarkAsConnecting(void)
180 {
181     SetState(kStateConnecting);
182 
183     // While in `kStateConnecting` state we use the `mKeepAlive` to
184     // track the `kConnectingTimeout` (if connection is not established
185     // within the timeout, we consider it as failure and close it).
186 
187     mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
188     Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
189 
190     // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
191     // or timeout.
192 }
193 
HandleConnected(void)194 void Dso::Connection::HandleConnected(void)
195 {
196     OT_ASSERT(mState == kStateConnecting);
197 
198     SetState(kStateConnectedButSessionless);
199     ResetTimeouts(/* aIsKeepAliveMessage */ false);
200 
201     SignalAnyStateChange();
202 }
203 
Disconnect(DisconnectMode aMode,DisconnectReason aReason)204 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
205 {
206     VerifyOrExit(mState != kStateDisconnected);
207 
208     mDisconnectReason = aReason;
209     MarkAsDisconnected();
210 
211     otPlatDsoDisconnect(this, MapEnum(aMode));
212 
213 exit:
214     return;
215 }
216 
HandleDisconnected(DisconnectMode aMode)217 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
218 {
219     VerifyOrExit(mState != kStateDisconnected);
220 
221     if (mState == kStateConnecting)
222     {
223         mDisconnectReason = kReasonFailedToConnect;
224     }
225     else
226     {
227         switch (aMode)
228         {
229         case kGracefullyClose:
230             mDisconnectReason = kReasonPeerClosed;
231             break;
232 
233         case kForciblyAbort:
234             mDisconnectReason = kReasonPeerAborted;
235         }
236     }
237 
238     MarkAsDisconnected();
239     SignalAnyStateChange();
240 
241 exit:
242     return;
243 }
244 
MarkAsDisconnected(void)245 void Dso::Connection::MarkAsDisconnected(void)
246 {
247     if (IsClient())
248     {
249         IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
250     }
251     else
252     {
253         IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
254     }
255 
256     mPendingRequests.Clear();
257     SetState(kStateDisconnected);
258 
259     LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
260 }
261 
MarkSessionEstablished(void)262 void Dso::Connection::MarkSessionEstablished(void)
263 {
264     switch (mState)
265     {
266     case kStateConnectedButSessionless:
267     case kStateEstablishingSession:
268     case kStateSessionEstablished:
269         break;
270 
271     case kStateDisconnected:
272     case kStateConnecting:
273         OT_ASSERT(false);
274     }
275 
276     SetState(kStateSessionEstablished);
277 }
278 
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)279 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
280 {
281     return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
282 }
283 
SendUnidirectionalMessage(Message & aMessage)284 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
285 {
286     MessageId messageId = 0;
287 
288     return SendMessage(aMessage, kUnidirectionalMessage, messageId);
289 }
290 
SendResponseMessage(Message & aMessage,MessageId aResponseId)291 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
292 {
293     return SendMessage(aMessage, kResponseMessage, aResponseId);
294 }
295 
SetLongLivedOperation(bool aLongLivedOperation)296 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
297 {
298     VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
299 
300     mLongLivedOperation = aLongLivedOperation;
301 
302     LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
303 
304     if (!mLongLivedOperation)
305     {
306         TimeMilli now = TimerMilli::GetNow();
307         TimeMilli nextTime;
308 
309         nextTime = GetNextFireTime(now);
310 
311         if (nextTime != now.GetDistantFuture())
312         {
313             Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
314         }
315     }
316 
317 exit:
318     return;
319 }
320 
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)321 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
322 {
323     Error         error   = kErrorNone;
324     Message *     message = nullptr;
325     RetryDelayTlv retryDelayTlv;
326     MessageId     messageId;
327 
328     switch (mState)
329     {
330     case kStateSessionEstablished:
331         OT_ASSERT(IsServer());
332         break;
333 
334     case kStateConnectedButSessionless:
335     case kStateEstablishingSession:
336     case kStateDisconnected:
337     case kStateConnecting:
338         OT_ASSERT(false);
339     }
340 
341     message = NewMessage();
342     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
343 
344     retryDelayTlv.Init();
345     retryDelayTlv.SetRetryDelay(aDelay);
346     SuccessOrExit(error = message->Append(retryDelayTlv));
347     error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
348 
349 exit:
350     FreeMessageOnError(message, error);
351     return error;
352 }
353 
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)354 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
355 {
356     Error error = kErrorNone;
357 
358     VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
359 
360     // If acting as server, the timeout values are the ones we grant
361     // to a connecting clients. If acting as client, the timeout
362     // values are what to request when sending Keep Alive message.
363     // If in `kStateDisconnected` we set both (since we don't know
364     // yet whether we are going to connect as client or server).
365 
366     if ((mState == kStateDisconnected) || IsServer())
367     {
368         mKeepAlive.SetInterval(aKeepAliveInterval);
369         AdjustInactivityTimeout(aInactivityTimeout);
370     }
371 
372     if ((mState == kStateDisconnected) || IsClient())
373     {
374         mKeepAlive.SetRequestInterval(aKeepAliveInterval);
375         mInactivity.SetRequestInterval(aInactivityTimeout);
376     }
377 
378     switch (mState)
379     {
380     case kStateDisconnected:
381     case kStateConnecting:
382         break;
383 
384     case kStateConnectedButSessionless:
385     case kStateEstablishingSession:
386         if (IsServer())
387         {
388             break;
389         }
390 
391         OT_FALL_THROUGH;
392 
393     case kStateSessionEstablished:
394         error = SendKeepAliveMessage();
395     }
396 
397 exit:
398     return error;
399 }
400 
SendKeepAliveMessage(void)401 Error Dso::Connection::SendKeepAliveMessage(void)
402 {
403     return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
404 }
405 
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)406 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
407 {
408     // Sends a Keep Alive message of a given type. This is a common
409     // method used by both client and server. `aResponseId` is
410     // applicable and used only when the message type is
411     // `kResponseMessage`.
412 
413     Error        error   = kErrorNone;
414     Message *    message = nullptr;
415     KeepAliveTlv keepAliveTlv;
416 
417     switch (mState)
418     {
419     case kStateConnectedButSessionless:
420     case kStateEstablishingSession:
421         if (IsServer())
422         {
423             // While session is being established, server is only allowed
424             // to send a Keep Alive response to a request from client.
425             OT_ASSERT(aMessageType == kResponseMessage);
426         }
427         break;
428 
429     case kStateSessionEstablished:
430         break;
431 
432     case kStateDisconnected:
433     case kStateConnecting:
434         OT_ASSERT(false);
435     }
436 
437     // Server can send Keep Alive response (to a request from client)
438     // or a unidirectional Keep Alive message. Client can send
439     // KeepAlive request message.
440 
441     if (IsServer())
442     {
443         if (aMessageType == kResponseMessage)
444         {
445             OT_ASSERT(aResponseId != 0);
446         }
447         else
448         {
449             OT_ASSERT(aMessageType == kUnidirectionalMessage);
450         }
451     }
452     else
453     {
454         OT_ASSERT(aMessageType == kRequestMessage);
455     }
456 
457     message = NewMessage();
458     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
459 
460     keepAliveTlv.Init();
461 
462     if (IsServer())
463     {
464         keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
465         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
466     }
467     else
468     {
469         keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
470         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
471     }
472 
473     SuccessOrExit(error = message->Append(keepAliveTlv));
474 
475     error = SendMessage(*message, aMessageType, aResponseId);
476 
477 exit:
478     FreeMessageOnError(message, error);
479     return error;
480 }
481 
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)482 Error Dso::Connection::SendMessage(Message &             aMessage,
483                                    MessageType           aMessageType,
484                                    MessageId &           aMessageId,
485                                    Dns::Header::Response aResponseCode,
486                                    uint32_t              aResponseTimeout)
487 {
488     Error       error          = kErrorNone;
489     Tlv::Type   primaryTlvType = Tlv::kReservedType;
490     Dns::Header header;
491 
492     switch (mState)
493     {
494     case kStateConnectedButSessionless:
495         // To establish session, client MUST send a request message.
496         // Server is not allowed to send any messages. Unidirectional
497         // messages are not allowed before session is established.
498         OT_ASSERT(IsClient());
499         OT_ASSERT(aMessageType == kRequestMessage);
500         break;
501 
502     case kStateEstablishingSession:
503         // During session establishment, client is allowed to send
504         // additional request messages, server is only allowed to
505         // send response.
506         if (IsClient())
507         {
508             OT_ASSERT(aMessageType == kRequestMessage);
509         }
510         else
511         {
512             OT_ASSERT(aMessageType == kResponseMessage);
513         }
514         break;
515 
516     case kStateSessionEstablished:
517         // All message types are allowed.
518         break;
519 
520     case kStateDisconnected:
521     case kStateConnecting:
522         OT_ASSERT(false);
523     }
524 
525     // A DSO request or unidirectional message MUST contain at
526     // least one TLV. The first TLV is the "Primary TLV" and
527     // determines the nature of the operation being performed.
528     // A DSO response message may contain no TLVs, or may contain
529     // one or more TLVs. Response Primary TLV(s) MUST appear first
530     // in a DSO response message.
531 
532     aMessage.SetOffset(0);
533     IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
534 
535     switch (aMessageType)
536     {
537     case kResponseMessage:
538         break;
539     case kRequestMessage:
540     case kUnidirectionalMessage:
541         OT_ASSERT(primaryTlvType != Tlv::kReservedType);
542     }
543 
544     // `header` is cleared from its constructor call so all fields
545     // start as zero.
546 
547     switch (aMessageType)
548     {
549     case kRequestMessage:
550         header.SetType(Dns::Header::kTypeQuery);
551         aMessageId = mNextMessageId;
552         break;
553 
554     case kResponseMessage:
555         header.SetType(Dns::Header::kTypeResponse);
556         break;
557 
558     case kUnidirectionalMessage:
559         header.SetType(Dns::Header::kTypeQuery);
560         aMessageId = 0;
561         break;
562     }
563 
564     header.SetMessageId(aMessageId);
565     header.SetQueryType(Dns::Header::kQueryTypeDso);
566     header.SetResponseCode(aResponseCode);
567     SuccessOrExit(error = aMessage.Prepend(header));
568 
569     SuccessOrExit(error = AppendPadding(aMessage));
570 
571     // Update `mPendingRequests` list with the new request info
572 
573     if (aMessageType == kRequestMessage)
574     {
575         SuccessOrExit(
576             error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
577 
578         if (++mNextMessageId == 0)
579         {
580             mNextMessageId = 1;
581         }
582     }
583 
584     LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
585             mPeerSockAddr.ToString().AsCString());
586 
587     switch (mState)
588     {
589     case kStateConnectedButSessionless:
590         // On client we transition from "connected" state to
591         // "establishing session" state on successfully sending a
592         // request message.
593         if (IsClient())
594         {
595             SetState(kStateEstablishingSession);
596         }
597         break;
598 
599     case kStateEstablishingSession:
600         // On server we transition from "establishing session" state
601         // to "established" on sending a response with success
602         // response code.
603         if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
604         {
605             SetState(kStateSessionEstablished);
606         }
607 
608     default:
609         break;
610     }
611 
612     ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
613 
614     otPlatDsoSend(this, &aMessage);
615 
616     // Signal any state changes. This is done at the very end when the
617     // `SendMessage()` is fully processed (all state and local
618     // variables are updated) to ensure that we do not have any
619     // reentrancy issues (e.g., if the callback signalling state
620     // change triggers another tx).
621 
622     SignalAnyStateChange();
623 
624 exit:
625     return error;
626 }
627 
AppendPadding(Message & aMessage)628 Error Dso::Connection::AppendPadding(Message &aMessage)
629 {
630     // This method appends Encryption Padding TLV to a DSO message.
631     // It uses the padding policy "Random-Block-Length Padding" from
632     // RFC 8467.
633 
634     static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
635 
636     Error                error = kErrorNone;
637     uint16_t             blockLength;
638     EncryptionPaddingTlv paddingTlv;
639 
640     // We pick a random block length. The random selection can be
641     // based on a "weak" source of randomness (so the use of
642     // `NonCrypto` is fine). We add padding to the message such
643     // that its padded length is a multiple of the chosen block
644     // length.
645 
646     blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
647 
648     paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
649 
650     SuccessOrExit(error = aMessage.Append(paddingTlv));
651 
652     for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
653     {
654         SuccessOrExit(error = aMessage.Append<uint8_t>(0));
655     }
656 
657 exit:
658     return error;
659 }
660 
HandleReceive(Message & aMessage)661 void Dso::Connection::HandleReceive(Message &aMessage)
662 {
663     Error       error          = kErrorAbort;
664     Tlv::Type   primaryTlvType = Tlv::kReservedType;
665     Dns::Header header;
666 
667     SuccessOrExit(aMessage.Read(0, header));
668 
669     if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
670     {
671         if (header.GetType() == Dns::Header::kTypeQuery)
672         {
673             SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
674             error = kErrorNone;
675         }
676 
677         ExitNow();
678     }
679 
680     switch (mState)
681     {
682     case kStateConnectedButSessionless:
683         // After connection is established, client should initiate
684         // establishing session (by sending a request). So no rx is
685         // allowed before this. On server, we allow rx of a request
686         // message only.
687         VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
688         break;
689 
690     case kStateEstablishingSession:
691         // Unidirectional message are allowed after session is
692         // established. While session is being established, on client,
693         // we allow rx on response message. On server we can rx
694         // request or response.
695 
696         VerifyOrExit(header.GetMessageId() != 0);
697 
698         if (IsClient())
699         {
700             VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
701         }
702         break;
703 
704     case kStateSessionEstablished:
705         // All message types are allowed.
706         break;
707 
708     case kStateDisconnected:
709     case kStateConnecting:
710         ExitNow();
711     }
712 
713     // All count fields MUST be set to zero in the header.
714     VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
715                  (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));
716 
717     aMessage.SetOffset(sizeof(header));
718 
719     switch (ReadPrimaryTlv(aMessage, primaryTlvType))
720     {
721     case kErrorNone:
722         VerifyOrExit(primaryTlvType != Tlv::kReservedType);
723         break;
724 
725     case kErrorNotFound:
726         // The `primaryTlvType` is set to `Tlv::kReservedType`
727         // (value zero) to indicate that there is no primary TLV.
728         break;
729 
730     default:
731         ExitNow();
732     }
733 
734     switch (header.GetType())
735     {
736     case Dns::Header::kTypeQuery:
737         error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
738         break;
739 
740     case Dns::Header::kTypeResponse:
741         error = ProcessResponseMessage(header, aMessage, primaryTlvType);
742         break;
743     }
744 
745 exit:
746     aMessage.Free();
747 
748     if (error == kErrorNone)
749     {
750         ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
751     }
752     else
753     {
754         Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
755     }
756 
757     // We signal any state change at the very end when the received
758     // message is fully processed (all state and local variables are
759     // updated) to ensure that we do not have any reentrancy issues
760     // (e.g., if a `Connection` method happens to be called from the
761     // callback).
762 
763     SignalAnyStateChange();
764 }
765 
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const766 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
767 {
768     // Read and validate the primary TLV (first TLV  after the header).
769     // The `aMessage.GetOffset()` must point to the first TLV. If no
770     // TLV then `kErrorNotFound` is returned. If TLV in message is not
771     // well-formed `kErrorParse` is returned. The read TLV type is
772     // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
773     // (value zero) when `kErrorNotFound`).
774 
775     Error error = kErrorNotFound;
776     Tlv   tlv;
777 
778     aPrimaryTlvType = Tlv::kReservedType;
779 
780     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
781     VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
782     aPrimaryTlvType = tlv.GetType();
783     error           = kErrorNone;
784 
785 exit:
786     return error;
787 }
788 
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)789 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
790                                                              const Message &    aMessage,
791                                                              Tlv::Type          aPrimaryTlvType)
792 {
793     Error error = kErrorAbort;
794 
795     if (IsServer() && (mState == kStateConnectedButSessionless))
796     {
797         SetState(kStateEstablishingSession);
798     }
799 
800     // A DSO request or unidirectional message MUST contain at
801     // least one TLV which is the "Primary TLV" and determines
802     // the nature of the operation being performed.
803 
804     switch (aPrimaryTlvType)
805     {
806     case KeepAliveTlv::kType:
807         error = ProcessKeepAliveMessage(aHeader, aMessage);
808         break;
809 
810     case RetryDelayTlv::kType:
811         error = ProcessRetryDelayMessage(aHeader, aMessage);
812         break;
813 
814     case Tlv::kReservedType:
815     case EncryptionPaddingTlv::kType:
816         // Misbehavior by peer.
817         break;
818 
819     default:
820         if (aHeader.GetMessageId() == 0)
821         {
822             LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
823 
824             error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
825         }
826         else
827         {
828             MessageId messageId = aHeader.GetMessageId();
829 
830             LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
831 
832             error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
833 
834             // `kErrorNotFound` indicates that TLV type is not known.
835 
836             if (error == kErrorNotFound)
837             {
838                 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
839                 error = kErrorNone;
840             }
841         }
842         break;
843     }
844 
845     return error;
846 }
847 
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)848 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
849                                               const Message &    aMessage,
850                                               Tlv::Type          aPrimaryTlvType)
851 {
852     Error     error = kErrorAbort;
853     Tlv::Type requestPrimaryTlvType;
854 
855     // If a client or server receives a response where the message
856     // ID is zero, or is any other value that does not match the
857     // message ID of any of its outstanding operations, this is a
858     // fatal error and the recipient MUST forcibly abort the
859     // connection immediately.
860 
861     VerifyOrExit(aHeader.GetMessageId() != 0);
862     VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
863 
864     // If the response has no error and contains a primary TLV, it
865     // MUST match the request primary TLV.
866 
867     if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
868     {
869         VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
870     }
871 
872     mPendingRequests.Remove(aHeader.GetMessageId());
873 
874     switch (requestPrimaryTlvType)
875     {
876     case KeepAliveTlv::kType:
877         SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
878         break;
879 
880     default:
881         SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
882                                                                  requestPrimaryTlvType));
883         break;
884     }
885 
886     // DSO session is established when client sends a request message
887     // and receives a response from server with no error code.
888 
889     if (IsClient() && (mState == kStateEstablishingSession) &&
890         (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
891     {
892         SetState(kStateSessionEstablished);
893     }
894 
895 exit:
896     return error;
897 }
898 
ProcessKeepAliveMessage(const Dns::Header & aHeader,const Message & aMessage)899 Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
900 {
901     Error        error  = kErrorAbort;
902     uint16_t     offset = aMessage.GetOffset();
903     Tlv          tlv;
904     KeepAliveTlv keepAliveTlv;
905 
906     if (aHeader.GetType() == Dns::Header::kTypeResponse)
907     {
908         // A Keep Alive response message is allowed on a client from a sever.
909 
910         VerifyOrExit(IsClient());
911 
912         if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
913         {
914             // We got an error response code from server for our
915             // Keep Alive request message. If this happens while
916             // establishing the DSO session, it indicates that server
917             // does not support DSO, so we close the connection. If
918             // this happens while session is already established, it
919             // is a misbehavior (fatal error) by server.
920 
921             if (mState == kStateEstablishingSession)
922             {
923                 Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
924                 error = kErrorNone;
925             }
926 
927             ExitNow();
928         }
929     }
930 
931     // Parse and validate the Keep Alive Message
932 
933     SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
934     offset += keepAliveTlv.GetSize();
935 
936     VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());
937 
938     // Keep Alive message MUST contain only one Keep Alive TLV.
939 
940     while (offset < aMessage.GetLength())
941     {
942         SuccessOrExit(aMessage.Read(offset, tlv));
943         offset += tlv.GetSize();
944 
945         VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
946     }
947 
948     VerifyOrExit(offset == aMessage.GetLength());
949 
950     if (aHeader.GetType() == Dns::Header::kTypeQuery)
951     {
952         if (IsServer())
953         {
954             // Received a Keep Alive message from client. It MUST
955             // be a request message (not unidirectional). We prepare
956             // and send a Keep Alive response.
957 
958             VerifyOrExit(aHeader.GetMessageId() != 0);
959 
960             LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());
961 
962             IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
963             error = kErrorNone;
964             ExitNow();
965         }
966 
967         // Received a Keep Alive message on client from server. Server
968         // Keep Alive message MUST be unidirectional (message ID
969         // zero).
970 
971         VerifyOrExit(aHeader.GetMessageId() == 0);
972     }
973 
974     LogInfo("Received Keep Alive %s message from server %s",
975             (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());
976 
977     // Receiving a Keep Alive interval value from server less than the
978     // minimum (ten seconds) is a fatal error and client MUST then
979     // abort the connection.
980 
981     VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);
982 
983     // Update the timeout intervals on the connection from
984     // the new values we got from the server. The receive
985     // of the Keep Alive message does not itself reset the
986     // inactivity timer. So we use `AdjustInactivityTimeout`
987     // which takes into account the time elapsed since the
988     // last activity.
989 
990     AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
991     mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());
992 
993     LogInfo("Timeouts Inactivity:%u, KeepAlive:%u", mInactivity.GetInterval(), mKeepAlive.GetInterval());
994 
995     error = kErrorNone;
996 
997 exit:
998     return error;
999 }
1000 
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)1001 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
1002 
1003 {
1004     Error         error = kErrorAbort;
1005     RetryDelayTlv retryDelayTlv;
1006 
1007     // Retry Delay TLV can be used as the Primary TLV only in
1008     // a unidirectional message sent from server to client.
1009     // It is used by the server to instruct the client to
1010     // close the session and its underlying connection, and not
1011     // to reconnect for the indicated time interval.
1012 
1013     VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1014 
1015     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1016     VerifyOrExit(retryDelayTlv.IsValid());
1017 
1018     mRetryDelayErrorCode = aHeader.GetResponseCode();
1019     mRetryDelay          = retryDelayTlv.GetRetryDelay();
1020 
1021     LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1022     LogInfo("   RetryDelay:%u ms, ResponseCode:%d", mRetryDelay, mRetryDelayErrorCode);
1023 
1024     Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1025 
1026 exit:
1027     return error;
1028 }
1029 
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1030 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1031 {
1032     Message *   response = NewMessage();
1033     Dns::Header header;
1034 
1035     VerifyOrExit(response != nullptr);
1036 
1037     header.SetMessageId(aHeader.GetMessageId());
1038     header.SetType(Dns::Header::kTypeResponse);
1039     header.SetQueryType(aHeader.GetQueryType());
1040     header.SetResponseCode(aResponseCode);
1041 
1042     SuccessOrExit(response->Prepend(header));
1043 
1044     otPlatDsoSend(this, response);
1045     response = nullptr;
1046 
1047 exit:
1048     FreeMessage(response);
1049 }
1050 
AdjustInactivityTimeout(uint32_t aNewTimeout)1051 void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
1052 {
1053     // This method sets the inactivity timeout interval to a new value
1054     // and updates the expiration time based on the new timeout value.
1055     //
1056     // On client, it is called on receiving a Keep Alive response or
1057     // unidirectional message from server. Note that the receive of
1058     // the Keep Alive message does not itself reset the inactivity
1059     // timer. So the time elapsed since the last activity should be
1060     // taken into account with the new inactivity timeout value.
1061     //
1062     // On server this method is called from `SetTimeouts()` when a new
1063     // inactivity timeout value is set.
1064 
1065     TimeMilli now = TimerMilli::GetNow();
1066     TimeMilli start;
1067     TimeMilli newExpiration;
1068 
1069     if (mState == kStateDisconnected)
1070     {
1071         mInactivity.SetInterval(aNewTimeout);
1072         ExitNow();
1073     }
1074 
1075     VerifyOrExit(aNewTimeout != mInactivity.GetInterval());
1076 
1077     // Calculate the start time (i.e., the last time inactivity timer
1078     // was cleared). If the previous inactivity time is set to
1079     // `kInfinite` value (`IsUsed()` returns `false`) then
1080     // `GetExpirationTime()` returns the start time. Otherwise, we
1081     // calculate it going back from the current expiration time with
1082     // the current wait interval.
1083 
1084     if (!mInactivity.IsUsed())
1085     {
1086         start = mInactivity.GetExpirationTime();
1087     }
1088     else if (IsClient())
1089     {
1090         start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
1091     }
1092     else
1093     {
1094         start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
1095     }
1096 
1097     mInactivity.SetInterval(aNewTimeout);
1098 
1099     if (!mInactivity.IsUsed())
1100     {
1101         newExpiration = start;
1102     }
1103     else if (IsClient())
1104     {
1105         newExpiration = start + aNewTimeout;
1106 
1107         if (newExpiration < now)
1108         {
1109             newExpiration = now;
1110         }
1111     }
1112     else
1113     {
1114         newExpiration = start + CalculateServerInactivityWaitTime();
1115 
1116         if (newExpiration < now)
1117         {
1118             // If the server abruptly reduces the inactivity timeout
1119             // such that current elapsed time is already more than
1120             // twice the new inactivity timeout, then the client is
1121             // immediately considered delinquent (server can forcibly
1122             // abort the connection). So to give the client time to
1123             // close the connection gracefully, the server SHOULD
1124             // give the client an additional grace period of either
1125             // five seconds or one quarter of the new inactivity
1126             // timeout, whichever is greater [RFC 8490 - 7.1.1].
1127 
1128             newExpiration = now + OT_MAX(kMinServerInactivityWaitTime, aNewTimeout / 4);
1129         }
1130     }
1131 
1132     mInactivity.SetExpirationTime(newExpiration);
1133 
1134 exit:
1135     return;
1136 }
1137 
CalculateServerInactivityWaitTime(void) const1138 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1139 {
1140     // A server will abort an idle session after five seconds
1141     // (`kMinServerInactivityWaitTime`) or twice the inactivity
1142     // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1143 
1144     OT_ASSERT(mInactivity.IsUsed());
1145 
1146     return OT_MAX(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1147 }
1148 
ResetTimeouts(bool aIsKeepAliveMessage)1149 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1150 {
1151     TimeMilli now = TimerMilli::GetNow();
1152     TimeMilli nextTime;
1153 
1154     // At both servers and clients, the generation or reception of any
1155     // complete DNS message resets both timers for that DSO
1156     // session, with the one exception being that a DSO Keep Alive
1157     // message resets only the keep alive timer, not the inactivity
1158     // timeout timer [RFC 8490 - 6.3]
1159 
1160     if (mKeepAlive.IsUsed())
1161     {
1162         // On client, we wait for the Keep Alive interval but on server
1163         // we wait for twice the interval before considering Keep Alive
1164         // timeout.
1165         //
1166         // Note that we limit the interval to `Timeout::kMaxInterval`
1167         // (which is ~12 days). This max limit ensures that even twice
1168         // the interval is less than max OpenThread timer duration so
1169         // that the expiration time calculations below stay within the
1170         // `TimerMilli` range.
1171 
1172         mKeepAlive.SetExpirationTime(now + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1173     }
1174 
1175     if (!aIsKeepAliveMessage)
1176     {
1177         if (mInactivity.IsUsed())
1178         {
1179             mInactivity.SetExpirationTime(
1180                 now + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1181         }
1182         else
1183         {
1184             // When Inactivity timeout is not used (i.e., interval is set
1185             // to the special `kInfinite` value), we still need to track
1186             // the time so that if/when later the inactivity interval
1187             // gets changed, we can adjust the remaining time correctly
1188             // from `AdjustInactivityTimeout()`. In this case, we just
1189             // track the current time as "expiration time".
1190 
1191             mInactivity.SetExpirationTime(now);
1192         }
1193     }
1194 
1195     nextTime = GetNextFireTime(now);
1196 
1197     if (nextTime != now.GetDistantFuture())
1198     {
1199         Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1200     }
1201 }
1202 
GetNextFireTime(TimeMilli aNow) const1203 TimeMilli Dso::Connection::GetNextFireTime(TimeMilli aNow) const
1204 {
1205     TimeMilli nextTime = aNow.GetDistantFuture();
1206 
1207     switch (mState)
1208     {
1209     case kStateDisconnected:
1210         break;
1211 
1212     case kStateConnecting:
1213         // While in `kStateConnecting`, Keep Alive timer is
1214         // used for `kConnectingTimeout`.
1215         VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1216         nextTime = mKeepAlive.GetExpirationTime();
1217         break;
1218 
1219     case kStateConnectedButSessionless:
1220     case kStateEstablishingSession:
1221     case kStateSessionEstablished:
1222         nextTime = OT_MIN(nextTime, mPendingRequests.GetNextFireTime(aNow));
1223 
1224         if (mKeepAlive.IsUsed())
1225         {
1226             VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1227             nextTime = OT_MIN(nextTime, mKeepAlive.GetExpirationTime());
1228         }
1229 
1230         if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1231         {
1232             // An operation being active on a DSO Session includes
1233             // a request message waiting for a response, or an
1234             // active long-lived operation.
1235 
1236             VerifyOrExit(mInactivity.GetExpirationTime() > aNow, nextTime = aNow);
1237             nextTime = OT_MIN(nextTime, mInactivity.GetExpirationTime());
1238         }
1239 
1240         break;
1241     }
1242 
1243 exit:
1244     return nextTime;
1245 }
1246 
HandleTimer(TimeMilli aNow,TimeMilli & aNextTime)1247 void Dso::Connection::HandleTimer(TimeMilli aNow, TimeMilli &aNextTime)
1248 {
1249     switch (mState)
1250     {
1251     case kStateDisconnected:
1252         break;
1253 
1254     case kStateConnecting:
1255         if (mKeepAlive.IsExpired(aNow))
1256         {
1257             Disconnect(kGracefullyClose, kReasonFailedToConnect);
1258         }
1259         break;
1260 
1261     case kStateConnectedButSessionless:
1262     case kStateEstablishingSession:
1263     case kStateSessionEstablished:
1264         if (mPendingRequests.HasAnyTimedOut(aNow))
1265         {
1266             // If server sends no response to a request, client
1267             // waits for 30 seconds (`kResponseTimeout`) after which
1268             // client MUST forcibly abort the connection.
1269             Disconnect(kForciblyAbort, kReasonResponseTimeout);
1270             ExitNow();
1271         }
1272 
1273         // The inactivity timer is kept clear, while an operation is
1274         // active on the session (which includes a request waiting for
1275         // response or an active long-lived operation).
1276 
1277         if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation && mInactivity.IsExpired(aNow))
1278         {
1279             // On client, if the inactivity timeout is reached, the
1280             // connection is closed gracefully. On server, if too much
1281             // time (`CalculateServerInactivityWaitTime()`, i.e., five
1282             // seconds or twice the current inactivity timeout interval,
1283             // whichever is grater) elapses server MUST consider the
1284             // client delinquent and MUST forcibly abort the connection.
1285 
1286             Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
1287             ExitNow();
1288         }
1289 
1290         if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNow))
1291         {
1292             // On client, if the Keep Alive interval elapses without any
1293             // DNS messages being sent or received, the client MUST take
1294             // action and send a DSO Keep Alive message.
1295             //
1296             // On server, if twice the Keep Alive interval value elapses
1297             // without any messages being sent or received, the server
1298             // considers the client delinquent and aborts the connection.
1299 
1300             if (IsClient())
1301             {
1302                 IgnoreError(SendKeepAliveMessage());
1303             }
1304             else
1305             {
1306                 Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
1307                 ExitNow();
1308             }
1309         }
1310         break;
1311     }
1312 
1313 exit:
1314     aNextTime = OT_MIN(aNextTime, GetNextFireTime(aNow));
1315     SignalAnyStateChange();
1316 }
1317 
StateToString(State aState)1318 const char *Dso::Connection::StateToString(State aState)
1319 {
1320     static const char *const kStateStrings[] = {
1321         "Disconnected",            // (0) kStateDisconnected,
1322         "Connecting",              // (1) kStateConnecting,
1323         "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1324         "EstablishingSession",     // (3) kStateEstablishingSession,
1325         "SessionEstablished",      // (4) kStateSessionEstablished,
1326     };
1327 
1328     static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect");
1329     static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect");
1330     static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect");
1331     static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect");
1332     static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect");
1333 
1334     return kStateStrings[aState];
1335 }
1336 
MessageTypeToString(MessageType aMessageType)1337 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1338 {
1339     static const char *const kMessageTypeStrings[] = {
1340         "Request",        // (0) kRequestMessage
1341         "Response",       // (1) kResponseMessage
1342         "Unidirectional", // (2) kUnidirectionalMessage
1343     };
1344 
1345     static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect");
1346     static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect");
1347     static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect");
1348 
1349     return kMessageTypeStrings[aMessageType];
1350 }
1351 
DisconnectReasonToString(DisconnectReason aReason)1352 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1353 {
1354     static const char *const kDisconnectReasonStrings[] = {
1355         "FailedToConnect",         // (0) kReasonFailedToConnect
1356         "ResponseTimeout",         // (1) kReasonResponseTimeout
1357         "PeerDoesNotSupportDso",   // (2) kReasonPeerDoesNotSupportDso
1358         "PeerClosed",              // (3) kReasonPeerClosed
1359         "PeerAborted",             // (4) kReasonPeerAborted
1360         "InactivityTimeout",       // (5) kReasonInactivityTimeout
1361         "KeepAliveTimeout",        // (6) kReasonKeepAliveTimeout
1362         "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1363         "PeerMisbehavior",         // (8) kReasonPeerMisbehavior
1364         "Unknown",                 // (9) kReasonUnknown
1365     };
1366 
1367     static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect");
1368     static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect");
1369     static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect");
1370     static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect");
1371     static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect");
1372     static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect");
1373     static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect");
1374     static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect");
1375     static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect");
1376     static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect");
1377 
1378     return kDisconnectReasonStrings[aReason];
1379 }
1380 
1381 //---------------------------------------------------------------------------------------------------------------------
1382 // Dso::Connection::PendingRequests
1383 
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1384 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1385 {
1386     bool         contains = true;
1387     const Entry *entry    = mRequests.FindMatching(aMessageId);
1388 
1389     VerifyOrExit(entry != nullptr, contains = false);
1390     aPrimaryTlvType = entry->mPrimaryTlvType;
1391 
1392 exit:
1393     return contains;
1394 }
1395 
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1396 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1397 {
1398     Error  error = kErrorNone;
1399     Entry *entry = mRequests.PushBack();
1400 
1401     VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1402     entry->mMessageId      = aMessageId;
1403     entry->mPrimaryTlvType = aPrimaryTlvType;
1404     entry->mTimeout        = aResponseTimeout;
1405 
1406 exit:
1407     return error;
1408 }
1409 
Remove(MessageId aMessageId)1410 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId)
1411 {
1412     mRequests.RemoveMatching(aMessageId);
1413 }
1414 
HasAnyTimedOut(TimeMilli aNow) const1415 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1416 {
1417     bool timedOut = false;
1418 
1419     for (const Entry &entry : mRequests)
1420     {
1421         if (entry.mTimeout <= aNow)
1422         {
1423             timedOut = true;
1424             break;
1425         }
1426     }
1427 
1428     return timedOut;
1429 }
1430 
GetNextFireTime(TimeMilli aNow) const1431 TimeMilli Dso::Connection::PendingRequests::GetNextFireTime(TimeMilli aNow) const
1432 {
1433     TimeMilli nextTime = aNow.GetDistantFuture();
1434 
1435     for (const Entry &entry : mRequests)
1436     {
1437         VerifyOrExit(entry.mTimeout > aNow, nextTime = aNow);
1438         nextTime = OT_MIN(entry.mTimeout, nextTime);
1439     }
1440 
1441 exit:
1442     return nextTime;
1443 }
1444 
1445 //---------------------------------------------------------------------------------------------------------------------
1446 // Dso
1447 
Dso(Instance & aInstance)1448 Dso::Dso(Instance &aInstance)
1449     : InstanceLocator(aInstance)
1450     , mAcceptHandler(nullptr)
1451     , mTimer(aInstance, HandleTimer)
1452 {
1453 }
1454 
StartListening(AcceptHandler aAcceptHandler)1455 void Dso::StartListening(AcceptHandler aAcceptHandler)
1456 {
1457     mAcceptHandler = aAcceptHandler;
1458     otPlatDsoEnableListening(&GetInstance(), true);
1459 }
1460 
StopListening(void)1461 void Dso::StopListening(void)
1462 {
1463     otPlatDsoEnableListening(&GetInstance(), false);
1464 }
1465 
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1466 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1467 {
1468     return mClientConnections.FindMatching(aPeerSockAddr);
1469 }
1470 
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1471 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1472 {
1473     return mServerConnections.FindMatching(aPeerSockAddr);
1474 }
1475 
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1476 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1477 {
1478     Connection *connection = nullptr;
1479 
1480     VerifyOrExit(mAcceptHandler != nullptr);
1481     connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1482 
1483     VerifyOrExit(connection != nullptr);
1484     connection->Accept();
1485 
1486 exit:
1487     return connection;
1488 }
1489 
HandleTimer(Timer & aTimer)1490 void Dso::HandleTimer(Timer &aTimer)
1491 {
1492     aTimer.Get<Dso>().HandleTimer();
1493 }
1494 
HandleTimer(void)1495 void Dso::HandleTimer(void)
1496 {
1497     TimeMilli   now      = TimerMilli::GetNow();
1498     TimeMilli   nextTime = now.GetDistantFuture();
1499     Connection *conn;
1500     Connection *next;
1501 
1502     for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1503     {
1504         next = conn->GetNext();
1505         conn->HandleTimer(now, nextTime);
1506     }
1507 
1508     for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1509     {
1510         next = conn->GetNext();
1511         conn->HandleTimer(now, nextTime);
1512     }
1513 
1514     if (nextTime != now.GetDistantFuture())
1515     {
1516         mTimer.FireAtIfEarlier(nextTime);
1517     }
1518 }
1519 
1520 } // namespace Dns
1521 } // namespace ot
1522 
1523 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1524