1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_dso.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "common/random.hpp"
41
42 /**
43 * @file
44 * This file implements the DNS Stateful Operations (DSO) per RFC 8490.
45 */
46
47 namespace ot {
48 namespace Dns {
49
50 RegisterLogModule("DnsDso");
51
52 //---------------------------------------------------------------------------------------------------------------------
53 // otPlatDso transport callbacks
54
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)55 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
56 {
57 return &AsCoreType(aConnection).GetInstance();
58 }
59
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)60 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
61 {
62 return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
63 }
64
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)65 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
66 {
67 AsCoreType(aConnection).HandleConnected();
68 }
69
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)70 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
71 {
72 AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
73 }
74
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)75 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
76 {
77 AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
78 }
79
80 //---------------------------------------------------------------------------------------------------------------------
81 // Dso::Connection
82
Connection(Instance & aInstance,const Ip6::SockAddr & aPeerSockAddr,Callbacks & aCallbacks,uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)83 Dso::Connection::Connection(Instance & aInstance,
84 const Ip6::SockAddr &aPeerSockAddr,
85 Callbacks & aCallbacks,
86 uint32_t aInactivityTimeout,
87 uint32_t aKeepAliveInterval)
88 : InstanceLocator(aInstance)
89 , mNext(nullptr)
90 , mCallbacks(aCallbacks)
91 , mPeerSockAddr(aPeerSockAddr)
92 , mState(kStateDisconnected)
93 , mIsServer(false)
94 , mInactivity(aInactivityTimeout)
95 , mKeepAlive(aKeepAliveInterval)
96 {
97 OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
98 Init(/* aIsServer */ false);
99 }
100
Init(bool aIsServer)101 void Dso::Connection::Init(bool aIsServer)
102 {
103 mNextMessageId = 1;
104 mIsServer = aIsServer;
105 mStateDidChange = false;
106 mLongLivedOperation = false;
107 mRetryDelay = 0;
108 mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
109 mDisconnectReason = kReasonUnknown;
110 }
111
SetState(State aState)112 void Dso::Connection::SetState(State aState)
113 {
114 VerifyOrExit(mState != aState);
115
116 LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
117 mPeerSockAddr.ToString().AsCString());
118
119 mState = aState;
120 mStateDidChange = true;
121
122 exit:
123 return;
124 }
125
SignalAnyStateChange(void)126 void Dso::Connection::SignalAnyStateChange(void)
127 {
128 VerifyOrExit(mStateDidChange);
129 mStateDidChange = false;
130
131 switch (mState)
132 {
133 case kStateDisconnected:
134 mCallbacks.mHandleDisconnected(*this);
135 break;
136
137 case kStateConnectedButSessionless:
138 mCallbacks.mHandleConnected(*this);
139 break;
140
141 case kStateSessionEstablished:
142 mCallbacks.mHandleSessionEstablished(*this);
143 break;
144
145 case kStateConnecting:
146 case kStateEstablishingSession:
147 break;
148 };
149
150 exit:
151 return;
152 }
153
NewMessage(void)154 Message *Dso::Connection::NewMessage(void)
155 {
156 return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
157 Message::Settings(Message::kPriorityNormal));
158 }
159
Connect(void)160 void Dso::Connection::Connect(void)
161 {
162 OT_ASSERT(mState == kStateDisconnected);
163
164 Init(/* aIsServer */ false);
165 Get<Dso>().mClientConnections.Push(*this);
166 MarkAsConnecting();
167 otPlatDsoConnect(this, &mPeerSockAddr);
168 }
169
Accept(void)170 void Dso::Connection::Accept(void)
171 {
172 OT_ASSERT(mState == kStateDisconnected);
173
174 Init(/* aIsServer */ true);
175 Get<Dso>().mServerConnections.Push(*this);
176 MarkAsConnecting();
177 }
178
MarkAsConnecting(void)179 void Dso::Connection::MarkAsConnecting(void)
180 {
181 SetState(kStateConnecting);
182
183 // While in `kStateConnecting` state we use the `mKeepAlive` to
184 // track the `kConnectingTimeout` (if connection is not established
185 // within the timeout, we consider it as failure and close it).
186
187 mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
188 Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
189
190 // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
191 // or timeout.
192 }
193
HandleConnected(void)194 void Dso::Connection::HandleConnected(void)
195 {
196 OT_ASSERT(mState == kStateConnecting);
197
198 SetState(kStateConnectedButSessionless);
199 ResetTimeouts(/* aIsKeepAliveMessage */ false);
200
201 SignalAnyStateChange();
202 }
203
Disconnect(DisconnectMode aMode,DisconnectReason aReason)204 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
205 {
206 VerifyOrExit(mState != kStateDisconnected);
207
208 mDisconnectReason = aReason;
209 MarkAsDisconnected();
210
211 otPlatDsoDisconnect(this, MapEnum(aMode));
212
213 exit:
214 return;
215 }
216
HandleDisconnected(DisconnectMode aMode)217 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
218 {
219 VerifyOrExit(mState != kStateDisconnected);
220
221 if (mState == kStateConnecting)
222 {
223 mDisconnectReason = kReasonFailedToConnect;
224 }
225 else
226 {
227 switch (aMode)
228 {
229 case kGracefullyClose:
230 mDisconnectReason = kReasonPeerClosed;
231 break;
232
233 case kForciblyAbort:
234 mDisconnectReason = kReasonPeerAborted;
235 }
236 }
237
238 MarkAsDisconnected();
239 SignalAnyStateChange();
240
241 exit:
242 return;
243 }
244
MarkAsDisconnected(void)245 void Dso::Connection::MarkAsDisconnected(void)
246 {
247 if (IsClient())
248 {
249 IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
250 }
251 else
252 {
253 IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
254 }
255
256 mPendingRequests.Clear();
257 SetState(kStateDisconnected);
258
259 LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
260 }
261
MarkSessionEstablished(void)262 void Dso::Connection::MarkSessionEstablished(void)
263 {
264 switch (mState)
265 {
266 case kStateConnectedButSessionless:
267 case kStateEstablishingSession:
268 case kStateSessionEstablished:
269 break;
270
271 case kStateDisconnected:
272 case kStateConnecting:
273 OT_ASSERT(false);
274 }
275
276 SetState(kStateSessionEstablished);
277 }
278
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)279 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
280 {
281 return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
282 }
283
SendUnidirectionalMessage(Message & aMessage)284 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
285 {
286 MessageId messageId = 0;
287
288 return SendMessage(aMessage, kUnidirectionalMessage, messageId);
289 }
290
SendResponseMessage(Message & aMessage,MessageId aResponseId)291 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
292 {
293 return SendMessage(aMessage, kResponseMessage, aResponseId);
294 }
295
SetLongLivedOperation(bool aLongLivedOperation)296 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
297 {
298 VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
299
300 mLongLivedOperation = aLongLivedOperation;
301
302 LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
303
304 if (!mLongLivedOperation)
305 {
306 TimeMilli now = TimerMilli::GetNow();
307 TimeMilli nextTime;
308
309 nextTime = GetNextFireTime(now);
310
311 if (nextTime != now.GetDistantFuture())
312 {
313 Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
314 }
315 }
316
317 exit:
318 return;
319 }
320
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)321 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
322 {
323 Error error = kErrorNone;
324 Message * message = nullptr;
325 RetryDelayTlv retryDelayTlv;
326 MessageId messageId;
327
328 switch (mState)
329 {
330 case kStateSessionEstablished:
331 OT_ASSERT(IsServer());
332 break;
333
334 case kStateConnectedButSessionless:
335 case kStateEstablishingSession:
336 case kStateDisconnected:
337 case kStateConnecting:
338 OT_ASSERT(false);
339 }
340
341 message = NewMessage();
342 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
343
344 retryDelayTlv.Init();
345 retryDelayTlv.SetRetryDelay(aDelay);
346 SuccessOrExit(error = message->Append(retryDelayTlv));
347 error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
348
349 exit:
350 FreeMessageOnError(message, error);
351 return error;
352 }
353
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)354 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
355 {
356 Error error = kErrorNone;
357
358 VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
359
360 // If acting as server, the timeout values are the ones we grant
361 // to a connecting clients. If acting as client, the timeout
362 // values are what to request when sending Keep Alive message.
363 // If in `kStateDisconnected` we set both (since we don't know
364 // yet whether we are going to connect as client or server).
365
366 if ((mState == kStateDisconnected) || IsServer())
367 {
368 mKeepAlive.SetInterval(aKeepAliveInterval);
369 AdjustInactivityTimeout(aInactivityTimeout);
370 }
371
372 if ((mState == kStateDisconnected) || IsClient())
373 {
374 mKeepAlive.SetRequestInterval(aKeepAliveInterval);
375 mInactivity.SetRequestInterval(aInactivityTimeout);
376 }
377
378 switch (mState)
379 {
380 case kStateDisconnected:
381 case kStateConnecting:
382 break;
383
384 case kStateConnectedButSessionless:
385 case kStateEstablishingSession:
386 if (IsServer())
387 {
388 break;
389 }
390
391 OT_FALL_THROUGH;
392
393 case kStateSessionEstablished:
394 error = SendKeepAliveMessage();
395 }
396
397 exit:
398 return error;
399 }
400
SendKeepAliveMessage(void)401 Error Dso::Connection::SendKeepAliveMessage(void)
402 {
403 return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
404 }
405
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)406 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
407 {
408 // Sends a Keep Alive message of a given type. This is a common
409 // method used by both client and server. `aResponseId` is
410 // applicable and used only when the message type is
411 // `kResponseMessage`.
412
413 Error error = kErrorNone;
414 Message * message = nullptr;
415 KeepAliveTlv keepAliveTlv;
416
417 switch (mState)
418 {
419 case kStateConnectedButSessionless:
420 case kStateEstablishingSession:
421 if (IsServer())
422 {
423 // While session is being established, server is only allowed
424 // to send a Keep Alive response to a request from client.
425 OT_ASSERT(aMessageType == kResponseMessage);
426 }
427 break;
428
429 case kStateSessionEstablished:
430 break;
431
432 case kStateDisconnected:
433 case kStateConnecting:
434 OT_ASSERT(false);
435 }
436
437 // Server can send Keep Alive response (to a request from client)
438 // or a unidirectional Keep Alive message. Client can send
439 // KeepAlive request message.
440
441 if (IsServer())
442 {
443 if (aMessageType == kResponseMessage)
444 {
445 OT_ASSERT(aResponseId != 0);
446 }
447 else
448 {
449 OT_ASSERT(aMessageType == kUnidirectionalMessage);
450 }
451 }
452 else
453 {
454 OT_ASSERT(aMessageType == kRequestMessage);
455 }
456
457 message = NewMessage();
458 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
459
460 keepAliveTlv.Init();
461
462 if (IsServer())
463 {
464 keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
465 keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
466 }
467 else
468 {
469 keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
470 keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
471 }
472
473 SuccessOrExit(error = message->Append(keepAliveTlv));
474
475 error = SendMessage(*message, aMessageType, aResponseId);
476
477 exit:
478 FreeMessageOnError(message, error);
479 return error;
480 }
481
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)482 Error Dso::Connection::SendMessage(Message & aMessage,
483 MessageType aMessageType,
484 MessageId & aMessageId,
485 Dns::Header::Response aResponseCode,
486 uint32_t aResponseTimeout)
487 {
488 Error error = kErrorNone;
489 Tlv::Type primaryTlvType = Tlv::kReservedType;
490 Dns::Header header;
491
492 switch (mState)
493 {
494 case kStateConnectedButSessionless:
495 // To establish session, client MUST send a request message.
496 // Server is not allowed to send any messages. Unidirectional
497 // messages are not allowed before session is established.
498 OT_ASSERT(IsClient());
499 OT_ASSERT(aMessageType == kRequestMessage);
500 break;
501
502 case kStateEstablishingSession:
503 // During session establishment, client is allowed to send
504 // additional request messages, server is only allowed to
505 // send response.
506 if (IsClient())
507 {
508 OT_ASSERT(aMessageType == kRequestMessage);
509 }
510 else
511 {
512 OT_ASSERT(aMessageType == kResponseMessage);
513 }
514 break;
515
516 case kStateSessionEstablished:
517 // All message types are allowed.
518 break;
519
520 case kStateDisconnected:
521 case kStateConnecting:
522 OT_ASSERT(false);
523 }
524
525 // A DSO request or unidirectional message MUST contain at
526 // least one TLV. The first TLV is the "Primary TLV" and
527 // determines the nature of the operation being performed.
528 // A DSO response message may contain no TLVs, or may contain
529 // one or more TLVs. Response Primary TLV(s) MUST appear first
530 // in a DSO response message.
531
532 aMessage.SetOffset(0);
533 IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
534
535 switch (aMessageType)
536 {
537 case kResponseMessage:
538 break;
539 case kRequestMessage:
540 case kUnidirectionalMessage:
541 OT_ASSERT(primaryTlvType != Tlv::kReservedType);
542 }
543
544 // `header` is cleared from its constructor call so all fields
545 // start as zero.
546
547 switch (aMessageType)
548 {
549 case kRequestMessage:
550 header.SetType(Dns::Header::kTypeQuery);
551 aMessageId = mNextMessageId;
552 break;
553
554 case kResponseMessage:
555 header.SetType(Dns::Header::kTypeResponse);
556 break;
557
558 case kUnidirectionalMessage:
559 header.SetType(Dns::Header::kTypeQuery);
560 aMessageId = 0;
561 break;
562 }
563
564 header.SetMessageId(aMessageId);
565 header.SetQueryType(Dns::Header::kQueryTypeDso);
566 header.SetResponseCode(aResponseCode);
567 SuccessOrExit(error = aMessage.Prepend(header));
568
569 SuccessOrExit(error = AppendPadding(aMessage));
570
571 // Update `mPendingRequests` list with the new request info
572
573 if (aMessageType == kRequestMessage)
574 {
575 SuccessOrExit(
576 error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
577
578 if (++mNextMessageId == 0)
579 {
580 mNextMessageId = 1;
581 }
582 }
583
584 LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
585 mPeerSockAddr.ToString().AsCString());
586
587 switch (mState)
588 {
589 case kStateConnectedButSessionless:
590 // On client we transition from "connected" state to
591 // "establishing session" state on successfully sending a
592 // request message.
593 if (IsClient())
594 {
595 SetState(kStateEstablishingSession);
596 }
597 break;
598
599 case kStateEstablishingSession:
600 // On server we transition from "establishing session" state
601 // to "established" on sending a response with success
602 // response code.
603 if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
604 {
605 SetState(kStateSessionEstablished);
606 }
607
608 default:
609 break;
610 }
611
612 ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
613
614 otPlatDsoSend(this, &aMessage);
615
616 // Signal any state changes. This is done at the very end when the
617 // `SendMessage()` is fully processed (all state and local
618 // variables are updated) to ensure that we do not have any
619 // reentrancy issues (e.g., if the callback signalling state
620 // change triggers another tx).
621
622 SignalAnyStateChange();
623
624 exit:
625 return error;
626 }
627
AppendPadding(Message & aMessage)628 Error Dso::Connection::AppendPadding(Message &aMessage)
629 {
630 // This method appends Encryption Padding TLV to a DSO message.
631 // It uses the padding policy "Random-Block-Length Padding" from
632 // RFC 8467.
633
634 static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
635
636 Error error = kErrorNone;
637 uint16_t blockLength;
638 EncryptionPaddingTlv paddingTlv;
639
640 // We pick a random block length. The random selection can be
641 // based on a "weak" source of randomness (so the use of
642 // `NonCrypto` is fine). We add padding to the message such
643 // that its padded length is a multiple of the chosen block
644 // length.
645
646 blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
647
648 paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
649
650 SuccessOrExit(error = aMessage.Append(paddingTlv));
651
652 for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
653 {
654 SuccessOrExit(error = aMessage.Append<uint8_t>(0));
655 }
656
657 exit:
658 return error;
659 }
660
HandleReceive(Message & aMessage)661 void Dso::Connection::HandleReceive(Message &aMessage)
662 {
663 Error error = kErrorAbort;
664 Tlv::Type primaryTlvType = Tlv::kReservedType;
665 Dns::Header header;
666
667 SuccessOrExit(aMessage.Read(0, header));
668
669 if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
670 {
671 if (header.GetType() == Dns::Header::kTypeQuery)
672 {
673 SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
674 error = kErrorNone;
675 }
676
677 ExitNow();
678 }
679
680 switch (mState)
681 {
682 case kStateConnectedButSessionless:
683 // After connection is established, client should initiate
684 // establishing session (by sending a request). So no rx is
685 // allowed before this. On server, we allow rx of a request
686 // message only.
687 VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
688 break;
689
690 case kStateEstablishingSession:
691 // Unidirectional message are allowed after session is
692 // established. While session is being established, on client,
693 // we allow rx on response message. On server we can rx
694 // request or response.
695
696 VerifyOrExit(header.GetMessageId() != 0);
697
698 if (IsClient())
699 {
700 VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
701 }
702 break;
703
704 case kStateSessionEstablished:
705 // All message types are allowed.
706 break;
707
708 case kStateDisconnected:
709 case kStateConnecting:
710 ExitNow();
711 }
712
713 // All count fields MUST be set to zero in the header.
714 VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
715 (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));
716
717 aMessage.SetOffset(sizeof(header));
718
719 switch (ReadPrimaryTlv(aMessage, primaryTlvType))
720 {
721 case kErrorNone:
722 VerifyOrExit(primaryTlvType != Tlv::kReservedType);
723 break;
724
725 case kErrorNotFound:
726 // The `primaryTlvType` is set to `Tlv::kReservedType`
727 // (value zero) to indicate that there is no primary TLV.
728 break;
729
730 default:
731 ExitNow();
732 }
733
734 switch (header.GetType())
735 {
736 case Dns::Header::kTypeQuery:
737 error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
738 break;
739
740 case Dns::Header::kTypeResponse:
741 error = ProcessResponseMessage(header, aMessage, primaryTlvType);
742 break;
743 }
744
745 exit:
746 aMessage.Free();
747
748 if (error == kErrorNone)
749 {
750 ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
751 }
752 else
753 {
754 Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
755 }
756
757 // We signal any state change at the very end when the received
758 // message is fully processed (all state and local variables are
759 // updated) to ensure that we do not have any reentrancy issues
760 // (e.g., if a `Connection` method happens to be called from the
761 // callback).
762
763 SignalAnyStateChange();
764 }
765
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const766 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
767 {
768 // Read and validate the primary TLV (first TLV after the header).
769 // The `aMessage.GetOffset()` must point to the first TLV. If no
770 // TLV then `kErrorNotFound` is returned. If TLV in message is not
771 // well-formed `kErrorParse` is returned. The read TLV type is
772 // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
773 // (value zero) when `kErrorNotFound`).
774
775 Error error = kErrorNotFound;
776 Tlv tlv;
777
778 aPrimaryTlvType = Tlv::kReservedType;
779
780 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
781 VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
782 aPrimaryTlvType = tlv.GetType();
783 error = kErrorNone;
784
785 exit:
786 return error;
787 }
788
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)789 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
790 const Message & aMessage,
791 Tlv::Type aPrimaryTlvType)
792 {
793 Error error = kErrorAbort;
794
795 if (IsServer() && (mState == kStateConnectedButSessionless))
796 {
797 SetState(kStateEstablishingSession);
798 }
799
800 // A DSO request or unidirectional message MUST contain at
801 // least one TLV which is the "Primary TLV" and determines
802 // the nature of the operation being performed.
803
804 switch (aPrimaryTlvType)
805 {
806 case KeepAliveTlv::kType:
807 error = ProcessKeepAliveMessage(aHeader, aMessage);
808 break;
809
810 case RetryDelayTlv::kType:
811 error = ProcessRetryDelayMessage(aHeader, aMessage);
812 break;
813
814 case Tlv::kReservedType:
815 case EncryptionPaddingTlv::kType:
816 // Misbehavior by peer.
817 break;
818
819 default:
820 if (aHeader.GetMessageId() == 0)
821 {
822 LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
823
824 error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
825 }
826 else
827 {
828 MessageId messageId = aHeader.GetMessageId();
829
830 LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
831
832 error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
833
834 // `kErrorNotFound` indicates that TLV type is not known.
835
836 if (error == kErrorNotFound)
837 {
838 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
839 error = kErrorNone;
840 }
841 }
842 break;
843 }
844
845 return error;
846 }
847
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)848 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
849 const Message & aMessage,
850 Tlv::Type aPrimaryTlvType)
851 {
852 Error error = kErrorAbort;
853 Tlv::Type requestPrimaryTlvType;
854
855 // If a client or server receives a response where the message
856 // ID is zero, or is any other value that does not match the
857 // message ID of any of its outstanding operations, this is a
858 // fatal error and the recipient MUST forcibly abort the
859 // connection immediately.
860
861 VerifyOrExit(aHeader.GetMessageId() != 0);
862 VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
863
864 // If the response has no error and contains a primary TLV, it
865 // MUST match the request primary TLV.
866
867 if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
868 {
869 VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
870 }
871
872 mPendingRequests.Remove(aHeader.GetMessageId());
873
874 switch (requestPrimaryTlvType)
875 {
876 case KeepAliveTlv::kType:
877 SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
878 break;
879
880 default:
881 SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
882 requestPrimaryTlvType));
883 break;
884 }
885
886 // DSO session is established when client sends a request message
887 // and receives a response from server with no error code.
888
889 if (IsClient() && (mState == kStateEstablishingSession) &&
890 (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
891 {
892 SetState(kStateSessionEstablished);
893 }
894
895 exit:
896 return error;
897 }
898
ProcessKeepAliveMessage(const Dns::Header & aHeader,const Message & aMessage)899 Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
900 {
901 Error error = kErrorAbort;
902 uint16_t offset = aMessage.GetOffset();
903 Tlv tlv;
904 KeepAliveTlv keepAliveTlv;
905
906 if (aHeader.GetType() == Dns::Header::kTypeResponse)
907 {
908 // A Keep Alive response message is allowed on a client from a sever.
909
910 VerifyOrExit(IsClient());
911
912 if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
913 {
914 // We got an error response code from server for our
915 // Keep Alive request message. If this happens while
916 // establishing the DSO session, it indicates that server
917 // does not support DSO, so we close the connection. If
918 // this happens while session is already established, it
919 // is a misbehavior (fatal error) by server.
920
921 if (mState == kStateEstablishingSession)
922 {
923 Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
924 error = kErrorNone;
925 }
926
927 ExitNow();
928 }
929 }
930
931 // Parse and validate the Keep Alive Message
932
933 SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
934 offset += keepAliveTlv.GetSize();
935
936 VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());
937
938 // Keep Alive message MUST contain only one Keep Alive TLV.
939
940 while (offset < aMessage.GetLength())
941 {
942 SuccessOrExit(aMessage.Read(offset, tlv));
943 offset += tlv.GetSize();
944
945 VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
946 }
947
948 VerifyOrExit(offset == aMessage.GetLength());
949
950 if (aHeader.GetType() == Dns::Header::kTypeQuery)
951 {
952 if (IsServer())
953 {
954 // Received a Keep Alive message from client. It MUST
955 // be a request message (not unidirectional). We prepare
956 // and send a Keep Alive response.
957
958 VerifyOrExit(aHeader.GetMessageId() != 0);
959
960 LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());
961
962 IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
963 error = kErrorNone;
964 ExitNow();
965 }
966
967 // Received a Keep Alive message on client from server. Server
968 // Keep Alive message MUST be unidirectional (message ID
969 // zero).
970
971 VerifyOrExit(aHeader.GetMessageId() == 0);
972 }
973
974 LogInfo("Received Keep Alive %s message from server %s",
975 (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());
976
977 // Receiving a Keep Alive interval value from server less than the
978 // minimum (ten seconds) is a fatal error and client MUST then
979 // abort the connection.
980
981 VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);
982
983 // Update the timeout intervals on the connection from
984 // the new values we got from the server. The receive
985 // of the Keep Alive message does not itself reset the
986 // inactivity timer. So we use `AdjustInactivityTimeout`
987 // which takes into account the time elapsed since the
988 // last activity.
989
990 AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
991 mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());
992
993 LogInfo("Timeouts Inactivity:%u, KeepAlive:%u", mInactivity.GetInterval(), mKeepAlive.GetInterval());
994
995 error = kErrorNone;
996
997 exit:
998 return error;
999 }
1000
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)1001 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
1002
1003 {
1004 Error error = kErrorAbort;
1005 RetryDelayTlv retryDelayTlv;
1006
1007 // Retry Delay TLV can be used as the Primary TLV only in
1008 // a unidirectional message sent from server to client.
1009 // It is used by the server to instruct the client to
1010 // close the session and its underlying connection, and not
1011 // to reconnect for the indicated time interval.
1012
1013 VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1014
1015 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1016 VerifyOrExit(retryDelayTlv.IsValid());
1017
1018 mRetryDelayErrorCode = aHeader.GetResponseCode();
1019 mRetryDelay = retryDelayTlv.GetRetryDelay();
1020
1021 LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1022 LogInfo(" RetryDelay:%u ms, ResponseCode:%d", mRetryDelay, mRetryDelayErrorCode);
1023
1024 Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1025
1026 exit:
1027 return error;
1028 }
1029
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1030 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1031 {
1032 Message * response = NewMessage();
1033 Dns::Header header;
1034
1035 VerifyOrExit(response != nullptr);
1036
1037 header.SetMessageId(aHeader.GetMessageId());
1038 header.SetType(Dns::Header::kTypeResponse);
1039 header.SetQueryType(aHeader.GetQueryType());
1040 header.SetResponseCode(aResponseCode);
1041
1042 SuccessOrExit(response->Prepend(header));
1043
1044 otPlatDsoSend(this, response);
1045 response = nullptr;
1046
1047 exit:
1048 FreeMessage(response);
1049 }
1050
AdjustInactivityTimeout(uint32_t aNewTimeout)1051 void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
1052 {
1053 // This method sets the inactivity timeout interval to a new value
1054 // and updates the expiration time based on the new timeout value.
1055 //
1056 // On client, it is called on receiving a Keep Alive response or
1057 // unidirectional message from server. Note that the receive of
1058 // the Keep Alive message does not itself reset the inactivity
1059 // timer. So the time elapsed since the last activity should be
1060 // taken into account with the new inactivity timeout value.
1061 //
1062 // On server this method is called from `SetTimeouts()` when a new
1063 // inactivity timeout value is set.
1064
1065 TimeMilli now = TimerMilli::GetNow();
1066 TimeMilli start;
1067 TimeMilli newExpiration;
1068
1069 if (mState == kStateDisconnected)
1070 {
1071 mInactivity.SetInterval(aNewTimeout);
1072 ExitNow();
1073 }
1074
1075 VerifyOrExit(aNewTimeout != mInactivity.GetInterval());
1076
1077 // Calculate the start time (i.e., the last time inactivity timer
1078 // was cleared). If the previous inactivity time is set to
1079 // `kInfinite` value (`IsUsed()` returns `false`) then
1080 // `GetExpirationTime()` returns the start time. Otherwise, we
1081 // calculate it going back from the current expiration time with
1082 // the current wait interval.
1083
1084 if (!mInactivity.IsUsed())
1085 {
1086 start = mInactivity.GetExpirationTime();
1087 }
1088 else if (IsClient())
1089 {
1090 start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
1091 }
1092 else
1093 {
1094 start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
1095 }
1096
1097 mInactivity.SetInterval(aNewTimeout);
1098
1099 if (!mInactivity.IsUsed())
1100 {
1101 newExpiration = start;
1102 }
1103 else if (IsClient())
1104 {
1105 newExpiration = start + aNewTimeout;
1106
1107 if (newExpiration < now)
1108 {
1109 newExpiration = now;
1110 }
1111 }
1112 else
1113 {
1114 newExpiration = start + CalculateServerInactivityWaitTime();
1115
1116 if (newExpiration < now)
1117 {
1118 // If the server abruptly reduces the inactivity timeout
1119 // such that current elapsed time is already more than
1120 // twice the new inactivity timeout, then the client is
1121 // immediately considered delinquent (server can forcibly
1122 // abort the connection). So to give the client time to
1123 // close the connection gracefully, the server SHOULD
1124 // give the client an additional grace period of either
1125 // five seconds or one quarter of the new inactivity
1126 // timeout, whichever is greater [RFC 8490 - 7.1.1].
1127
1128 newExpiration = now + OT_MAX(kMinServerInactivityWaitTime, aNewTimeout / 4);
1129 }
1130 }
1131
1132 mInactivity.SetExpirationTime(newExpiration);
1133
1134 exit:
1135 return;
1136 }
1137
CalculateServerInactivityWaitTime(void) const1138 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1139 {
1140 // A server will abort an idle session after five seconds
1141 // (`kMinServerInactivityWaitTime`) or twice the inactivity
1142 // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1143
1144 OT_ASSERT(mInactivity.IsUsed());
1145
1146 return OT_MAX(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1147 }
1148
ResetTimeouts(bool aIsKeepAliveMessage)1149 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1150 {
1151 TimeMilli now = TimerMilli::GetNow();
1152 TimeMilli nextTime;
1153
1154 // At both servers and clients, the generation or reception of any
1155 // complete DNS message resets both timers for that DSO
1156 // session, with the one exception being that a DSO Keep Alive
1157 // message resets only the keep alive timer, not the inactivity
1158 // timeout timer [RFC 8490 - 6.3]
1159
1160 if (mKeepAlive.IsUsed())
1161 {
1162 // On client, we wait for the Keep Alive interval but on server
1163 // we wait for twice the interval before considering Keep Alive
1164 // timeout.
1165 //
1166 // Note that we limit the interval to `Timeout::kMaxInterval`
1167 // (which is ~12 days). This max limit ensures that even twice
1168 // the interval is less than max OpenThread timer duration so
1169 // that the expiration time calculations below stay within the
1170 // `TimerMilli` range.
1171
1172 mKeepAlive.SetExpirationTime(now + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1173 }
1174
1175 if (!aIsKeepAliveMessage)
1176 {
1177 if (mInactivity.IsUsed())
1178 {
1179 mInactivity.SetExpirationTime(
1180 now + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1181 }
1182 else
1183 {
1184 // When Inactivity timeout is not used (i.e., interval is set
1185 // to the special `kInfinite` value), we still need to track
1186 // the time so that if/when later the inactivity interval
1187 // gets changed, we can adjust the remaining time correctly
1188 // from `AdjustInactivityTimeout()`. In this case, we just
1189 // track the current time as "expiration time".
1190
1191 mInactivity.SetExpirationTime(now);
1192 }
1193 }
1194
1195 nextTime = GetNextFireTime(now);
1196
1197 if (nextTime != now.GetDistantFuture())
1198 {
1199 Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1200 }
1201 }
1202
GetNextFireTime(TimeMilli aNow) const1203 TimeMilli Dso::Connection::GetNextFireTime(TimeMilli aNow) const
1204 {
1205 TimeMilli nextTime = aNow.GetDistantFuture();
1206
1207 switch (mState)
1208 {
1209 case kStateDisconnected:
1210 break;
1211
1212 case kStateConnecting:
1213 // While in `kStateConnecting`, Keep Alive timer is
1214 // used for `kConnectingTimeout`.
1215 VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1216 nextTime = mKeepAlive.GetExpirationTime();
1217 break;
1218
1219 case kStateConnectedButSessionless:
1220 case kStateEstablishingSession:
1221 case kStateSessionEstablished:
1222 nextTime = OT_MIN(nextTime, mPendingRequests.GetNextFireTime(aNow));
1223
1224 if (mKeepAlive.IsUsed())
1225 {
1226 VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1227 nextTime = OT_MIN(nextTime, mKeepAlive.GetExpirationTime());
1228 }
1229
1230 if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1231 {
1232 // An operation being active on a DSO Session includes
1233 // a request message waiting for a response, or an
1234 // active long-lived operation.
1235
1236 VerifyOrExit(mInactivity.GetExpirationTime() > aNow, nextTime = aNow);
1237 nextTime = OT_MIN(nextTime, mInactivity.GetExpirationTime());
1238 }
1239
1240 break;
1241 }
1242
1243 exit:
1244 return nextTime;
1245 }
1246
HandleTimer(TimeMilli aNow,TimeMilli & aNextTime)1247 void Dso::Connection::HandleTimer(TimeMilli aNow, TimeMilli &aNextTime)
1248 {
1249 switch (mState)
1250 {
1251 case kStateDisconnected:
1252 break;
1253
1254 case kStateConnecting:
1255 if (mKeepAlive.IsExpired(aNow))
1256 {
1257 Disconnect(kGracefullyClose, kReasonFailedToConnect);
1258 }
1259 break;
1260
1261 case kStateConnectedButSessionless:
1262 case kStateEstablishingSession:
1263 case kStateSessionEstablished:
1264 if (mPendingRequests.HasAnyTimedOut(aNow))
1265 {
1266 // If server sends no response to a request, client
1267 // waits for 30 seconds (`kResponseTimeout`) after which
1268 // client MUST forcibly abort the connection.
1269 Disconnect(kForciblyAbort, kReasonResponseTimeout);
1270 ExitNow();
1271 }
1272
1273 // The inactivity timer is kept clear, while an operation is
1274 // active on the session (which includes a request waiting for
1275 // response or an active long-lived operation).
1276
1277 if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation && mInactivity.IsExpired(aNow))
1278 {
1279 // On client, if the inactivity timeout is reached, the
1280 // connection is closed gracefully. On server, if too much
1281 // time (`CalculateServerInactivityWaitTime()`, i.e., five
1282 // seconds or twice the current inactivity timeout interval,
1283 // whichever is grater) elapses server MUST consider the
1284 // client delinquent and MUST forcibly abort the connection.
1285
1286 Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
1287 ExitNow();
1288 }
1289
1290 if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNow))
1291 {
1292 // On client, if the Keep Alive interval elapses without any
1293 // DNS messages being sent or received, the client MUST take
1294 // action and send a DSO Keep Alive message.
1295 //
1296 // On server, if twice the Keep Alive interval value elapses
1297 // without any messages being sent or received, the server
1298 // considers the client delinquent and aborts the connection.
1299
1300 if (IsClient())
1301 {
1302 IgnoreError(SendKeepAliveMessage());
1303 }
1304 else
1305 {
1306 Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
1307 ExitNow();
1308 }
1309 }
1310 break;
1311 }
1312
1313 exit:
1314 aNextTime = OT_MIN(aNextTime, GetNextFireTime(aNow));
1315 SignalAnyStateChange();
1316 }
1317
StateToString(State aState)1318 const char *Dso::Connection::StateToString(State aState)
1319 {
1320 static const char *const kStateStrings[] = {
1321 "Disconnected", // (0) kStateDisconnected,
1322 "Connecting", // (1) kStateConnecting,
1323 "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1324 "EstablishingSession", // (3) kStateEstablishingSession,
1325 "SessionEstablished", // (4) kStateSessionEstablished,
1326 };
1327
1328 static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect");
1329 static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect");
1330 static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect");
1331 static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect");
1332 static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect");
1333
1334 return kStateStrings[aState];
1335 }
1336
MessageTypeToString(MessageType aMessageType)1337 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1338 {
1339 static const char *const kMessageTypeStrings[] = {
1340 "Request", // (0) kRequestMessage
1341 "Response", // (1) kResponseMessage
1342 "Unidirectional", // (2) kUnidirectionalMessage
1343 };
1344
1345 static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect");
1346 static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect");
1347 static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect");
1348
1349 return kMessageTypeStrings[aMessageType];
1350 }
1351
DisconnectReasonToString(DisconnectReason aReason)1352 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1353 {
1354 static const char *const kDisconnectReasonStrings[] = {
1355 "FailedToConnect", // (0) kReasonFailedToConnect
1356 "ResponseTimeout", // (1) kReasonResponseTimeout
1357 "PeerDoesNotSupportDso", // (2) kReasonPeerDoesNotSupportDso
1358 "PeerClosed", // (3) kReasonPeerClosed
1359 "PeerAborted", // (4) kReasonPeerAborted
1360 "InactivityTimeout", // (5) kReasonInactivityTimeout
1361 "KeepAliveTimeout", // (6) kReasonKeepAliveTimeout
1362 "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1363 "PeerMisbehavior", // (8) kReasonPeerMisbehavior
1364 "Unknown", // (9) kReasonUnknown
1365 };
1366
1367 static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect");
1368 static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect");
1369 static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect");
1370 static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect");
1371 static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect");
1372 static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect");
1373 static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect");
1374 static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect");
1375 static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect");
1376 static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect");
1377
1378 return kDisconnectReasonStrings[aReason];
1379 }
1380
1381 //---------------------------------------------------------------------------------------------------------------------
1382 // Dso::Connection::PendingRequests
1383
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1384 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1385 {
1386 bool contains = true;
1387 const Entry *entry = mRequests.FindMatching(aMessageId);
1388
1389 VerifyOrExit(entry != nullptr, contains = false);
1390 aPrimaryTlvType = entry->mPrimaryTlvType;
1391
1392 exit:
1393 return contains;
1394 }
1395
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1396 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1397 {
1398 Error error = kErrorNone;
1399 Entry *entry = mRequests.PushBack();
1400
1401 VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1402 entry->mMessageId = aMessageId;
1403 entry->mPrimaryTlvType = aPrimaryTlvType;
1404 entry->mTimeout = aResponseTimeout;
1405
1406 exit:
1407 return error;
1408 }
1409
Remove(MessageId aMessageId)1410 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId)
1411 {
1412 mRequests.RemoveMatching(aMessageId);
1413 }
1414
HasAnyTimedOut(TimeMilli aNow) const1415 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1416 {
1417 bool timedOut = false;
1418
1419 for (const Entry &entry : mRequests)
1420 {
1421 if (entry.mTimeout <= aNow)
1422 {
1423 timedOut = true;
1424 break;
1425 }
1426 }
1427
1428 return timedOut;
1429 }
1430
GetNextFireTime(TimeMilli aNow) const1431 TimeMilli Dso::Connection::PendingRequests::GetNextFireTime(TimeMilli aNow) const
1432 {
1433 TimeMilli nextTime = aNow.GetDistantFuture();
1434
1435 for (const Entry &entry : mRequests)
1436 {
1437 VerifyOrExit(entry.mTimeout > aNow, nextTime = aNow);
1438 nextTime = OT_MIN(entry.mTimeout, nextTime);
1439 }
1440
1441 exit:
1442 return nextTime;
1443 }
1444
1445 //---------------------------------------------------------------------------------------------------------------------
1446 // Dso
1447
Dso(Instance & aInstance)1448 Dso::Dso(Instance &aInstance)
1449 : InstanceLocator(aInstance)
1450 , mAcceptHandler(nullptr)
1451 , mTimer(aInstance, HandleTimer)
1452 {
1453 }
1454
StartListening(AcceptHandler aAcceptHandler)1455 void Dso::StartListening(AcceptHandler aAcceptHandler)
1456 {
1457 mAcceptHandler = aAcceptHandler;
1458 otPlatDsoEnableListening(&GetInstance(), true);
1459 }
1460
StopListening(void)1461 void Dso::StopListening(void)
1462 {
1463 otPlatDsoEnableListening(&GetInstance(), false);
1464 }
1465
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1466 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1467 {
1468 return mClientConnections.FindMatching(aPeerSockAddr);
1469 }
1470
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1471 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1472 {
1473 return mServerConnections.FindMatching(aPeerSockAddr);
1474 }
1475
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1476 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1477 {
1478 Connection *connection = nullptr;
1479
1480 VerifyOrExit(mAcceptHandler != nullptr);
1481 connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1482
1483 VerifyOrExit(connection != nullptr);
1484 connection->Accept();
1485
1486 exit:
1487 return connection;
1488 }
1489
HandleTimer(Timer & aTimer)1490 void Dso::HandleTimer(Timer &aTimer)
1491 {
1492 aTimer.Get<Dso>().HandleTimer();
1493 }
1494
HandleTimer(void)1495 void Dso::HandleTimer(void)
1496 {
1497 TimeMilli now = TimerMilli::GetNow();
1498 TimeMilli nextTime = now.GetDistantFuture();
1499 Connection *conn;
1500 Connection *next;
1501
1502 for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1503 {
1504 next = conn->GetNext();
1505 conn->HandleTimer(now, nextTime);
1506 }
1507
1508 for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1509 {
1510 next = conn->GetNext();
1511 conn->HandleTimer(now, nextTime);
1512 }
1513
1514 if (nextTime != now.GetDistantFuture())
1515 {
1516 mTimer.FireAtIfEarlier(nextTime);
1517 }
1518 }
1519
1520 } // namespace Dns
1521 } // namespace ot
1522
1523 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1524