/* * Copyright (c) 2021, The OpenThread Authors. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * 3. Neither the name of the copyright holder nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include "dns_dso.hpp" #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE #include "common/array.hpp" #include "common/as_core_type.hpp" #include "common/code_utils.hpp" #include "common/debug.hpp" #include "common/instance.hpp" #include "common/locator_getters.hpp" #include "common/log.hpp" #include "common/random.hpp" /** * @file * This file implements the DNS Stateful Operations (DSO) per RFC 8490. */ namespace ot { namespace Dns { RegisterLogModule("DnsDso"); //--------------------------------------------------------------------------------------------------------------------- // otPlatDso transport callbacks extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection) { return &AsCoreType(aConnection).GetInstance(); } extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr) { return AsCoreType(aInstance).Get().AcceptConnection(AsCoreType(aPeerSockAddr)); } extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection) { AsCoreType(aConnection).HandleConnected(); } extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage) { AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage)); } extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode) { AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode)); } //--------------------------------------------------------------------------------------------------------------------- // Dso::Connection Dso::Connection::Connection(Instance & aInstance, const Ip6::SockAddr &aPeerSockAddr, Callbacks & aCallbacks, uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval) : InstanceLocator(aInstance) , mNext(nullptr) , mCallbacks(aCallbacks) , mPeerSockAddr(aPeerSockAddr) , mState(kStateDisconnected) , mIsServer(false) , mInactivity(aInactivityTimeout) , mKeepAlive(aKeepAliveInterval) { OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval); Init(/* aIsServer */ false); } void Dso::Connection::Init(bool aIsServer) { mNextMessageId = 1; mIsServer = aIsServer; mStateDidChange = false; mLongLivedOperation = false; mRetryDelay = 0; mRetryDelayErrorCode = Dns::Header::kResponseSuccess; mDisconnectReason = kReasonUnknown; } void Dso::Connection::SetState(State aState) { VerifyOrExit(mState != aState); LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState), mPeerSockAddr.ToString().AsCString()); mState = aState; mStateDidChange = true; exit: return; } void Dso::Connection::SignalAnyStateChange(void) { VerifyOrExit(mStateDidChange); mStateDidChange = false; switch (mState) { case kStateDisconnected: mCallbacks.mHandleDisconnected(*this); break; case kStateConnectedButSessionless: mCallbacks.mHandleConnected(*this); break; case kStateSessionEstablished: mCallbacks.mHandleSessionEstablished(*this); break; case kStateConnecting: case kStateEstablishingSession: break; }; exit: return; } Message *Dso::Connection::NewMessage(void) { return Get().Allocate(Message::kTypeOther, sizeof(Dns::Header), Message::Settings(Message::kPriorityNormal)); } void Dso::Connection::Connect(void) { OT_ASSERT(mState == kStateDisconnected); Init(/* aIsServer */ false); Get().mClientConnections.Push(*this); MarkAsConnecting(); otPlatDsoConnect(this, &mPeerSockAddr); } void Dso::Connection::Accept(void) { OT_ASSERT(mState == kStateDisconnected); Init(/* aIsServer */ true); Get().mServerConnections.Push(*this); MarkAsConnecting(); } void Dso::Connection::MarkAsConnecting(void) { SetState(kStateConnecting); // While in `kStateConnecting` state we use the `mKeepAlive` to // track the `kConnectingTimeout` (if connection is not established // within the timeout, we consider it as failure and close it). mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout); Get().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime()); // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks // or timeout. } void Dso::Connection::HandleConnected(void) { OT_ASSERT(mState == kStateConnecting); SetState(kStateConnectedButSessionless); ResetTimeouts(/* aIsKeepAliveMessage */ false); SignalAnyStateChange(); } void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason) { VerifyOrExit(mState != kStateDisconnected); mDisconnectReason = aReason; MarkAsDisconnected(); otPlatDsoDisconnect(this, MapEnum(aMode)); exit: return; } void Dso::Connection::HandleDisconnected(DisconnectMode aMode) { VerifyOrExit(mState != kStateDisconnected); if (mState == kStateConnecting) { mDisconnectReason = kReasonFailedToConnect; } else { switch (aMode) { case kGracefullyClose: mDisconnectReason = kReasonPeerClosed; break; case kForciblyAbort: mDisconnectReason = kReasonPeerAborted; } } MarkAsDisconnected(); SignalAnyStateChange(); exit: return; } void Dso::Connection::MarkAsDisconnected(void) { if (IsClient()) { IgnoreError(Get().mClientConnections.Remove(*this)); } else { IgnoreError(Get().mServerConnections.Remove(*this)); } mPendingRequests.Clear(); SetState(kStateDisconnected); LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason)); } void Dso::Connection::MarkSessionEstablished(void) { switch (mState) { case kStateConnectedButSessionless: case kStateEstablishingSession: case kStateSessionEstablished: break; case kStateDisconnected: case kStateConnecting: OT_ASSERT(false); } SetState(kStateSessionEstablished); } Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout) { return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout); } Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage) { MessageId messageId = 0; return SendMessage(aMessage, kUnidirectionalMessage, messageId); } Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId) { return SendMessage(aMessage, kResponseMessage, aResponseId); } void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation) { VerifyOrExit(mLongLivedOperation != aLongLivedOperation); mLongLivedOperation = aLongLivedOperation; LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped"); if (!mLongLivedOperation) { TimeMilli now = TimerMilli::GetNow(); TimeMilli nextTime; nextTime = GetNextFireTime(now); if (nextTime != now.GetDistantFuture()) { Get().mTimer.FireAtIfEarlier(nextTime); } } exit: return; } Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode) { Error error = kErrorNone; Message * message = nullptr; RetryDelayTlv retryDelayTlv; MessageId messageId; switch (mState) { case kStateSessionEstablished: OT_ASSERT(IsServer()); break; case kStateConnectedButSessionless: case kStateEstablishingSession: case kStateDisconnected: case kStateConnecting: OT_ASSERT(false); } message = NewMessage(); VerifyOrExit(message != nullptr, error = kErrorNoBufs); retryDelayTlv.Init(); retryDelayTlv.SetRetryDelay(aDelay); SuccessOrExit(error = message->Append(retryDelayTlv)); error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode); exit: FreeMessageOnError(message, error); return error; } Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval) { Error error = kErrorNone; VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs); // If acting as server, the timeout values are the ones we grant // to a connecting clients. If acting as client, the timeout // values are what to request when sending Keep Alive message. // If in `kStateDisconnected` we set both (since we don't know // yet whether we are going to connect as client or server). if ((mState == kStateDisconnected) || IsServer()) { mKeepAlive.SetInterval(aKeepAliveInterval); AdjustInactivityTimeout(aInactivityTimeout); } if ((mState == kStateDisconnected) || IsClient()) { mKeepAlive.SetRequestInterval(aKeepAliveInterval); mInactivity.SetRequestInterval(aInactivityTimeout); } switch (mState) { case kStateDisconnected: case kStateConnecting: break; case kStateConnectedButSessionless: case kStateEstablishingSession: if (IsServer()) { break; } OT_FALL_THROUGH; case kStateSessionEstablished: error = SendKeepAliveMessage(); } exit: return error; } Error Dso::Connection::SendKeepAliveMessage(void) { return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0); } Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId) { // Sends a Keep Alive message of a given type. This is a common // method used by both client and server. `aResponseId` is // applicable and used only when the message type is // `kResponseMessage`. Error error = kErrorNone; Message * message = nullptr; KeepAliveTlv keepAliveTlv; switch (mState) { case kStateConnectedButSessionless: case kStateEstablishingSession: if (IsServer()) { // While session is being established, server is only allowed // to send a Keep Alive response to a request from client. OT_ASSERT(aMessageType == kResponseMessage); } break; case kStateSessionEstablished: break; case kStateDisconnected: case kStateConnecting: OT_ASSERT(false); } // Server can send Keep Alive response (to a request from client) // or a unidirectional Keep Alive message. Client can send // KeepAlive request message. if (IsServer()) { if (aMessageType == kResponseMessage) { OT_ASSERT(aResponseId != 0); } else { OT_ASSERT(aMessageType == kUnidirectionalMessage); } } else { OT_ASSERT(aMessageType == kRequestMessage); } message = NewMessage(); VerifyOrExit(message != nullptr, error = kErrorNoBufs); keepAliveTlv.Init(); if (IsServer()) { keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval()); keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval()); } else { keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval()); keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval()); } SuccessOrExit(error = message->Append(keepAliveTlv)); error = SendMessage(*message, aMessageType, aResponseId); exit: FreeMessageOnError(message, error); return error; } Error Dso::Connection::SendMessage(Message & aMessage, MessageType aMessageType, MessageId & aMessageId, Dns::Header::Response aResponseCode, uint32_t aResponseTimeout) { Error error = kErrorNone; Tlv::Type primaryTlvType = Tlv::kReservedType; Dns::Header header; switch (mState) { case kStateConnectedButSessionless: // To establish session, client MUST send a request message. // Server is not allowed to send any messages. Unidirectional // messages are not allowed before session is established. OT_ASSERT(IsClient()); OT_ASSERT(aMessageType == kRequestMessage); break; case kStateEstablishingSession: // During session establishment, client is allowed to send // additional request messages, server is only allowed to // send response. if (IsClient()) { OT_ASSERT(aMessageType == kRequestMessage); } else { OT_ASSERT(aMessageType == kResponseMessage); } break; case kStateSessionEstablished: // All message types are allowed. break; case kStateDisconnected: case kStateConnecting: OT_ASSERT(false); } // A DSO request or unidirectional message MUST contain at // least one TLV. The first TLV is the "Primary TLV" and // determines the nature of the operation being performed. // A DSO response message may contain no TLVs, or may contain // one or more TLVs. Response Primary TLV(s) MUST appear first // in a DSO response message. aMessage.SetOffset(0); IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType)); switch (aMessageType) { case kResponseMessage: break; case kRequestMessage: case kUnidirectionalMessage: OT_ASSERT(primaryTlvType != Tlv::kReservedType); } // `header` is cleared from its constructor call so all fields // start as zero. switch (aMessageType) { case kRequestMessage: header.SetType(Dns::Header::kTypeQuery); aMessageId = mNextMessageId; break; case kResponseMessage: header.SetType(Dns::Header::kTypeResponse); break; case kUnidirectionalMessage: header.SetType(Dns::Header::kTypeQuery); aMessageId = 0; break; } header.SetMessageId(aMessageId); header.SetQueryType(Dns::Header::kQueryTypeDso); header.SetResponseCode(aResponseCode); SuccessOrExit(error = aMessage.Prepend(header)); SuccessOrExit(error = AppendPadding(aMessage)); // Update `mPendingRequests` list with the new request info if (aMessageType == kRequestMessage) { SuccessOrExit( error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout)); if (++mNextMessageId == 0) { mNextMessageId = 1; } } LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId, mPeerSockAddr.ToString().AsCString()); switch (mState) { case kStateConnectedButSessionless: // On client we transition from "connected" state to // "establishing session" state on successfully sending a // request message. if (IsClient()) { SetState(kStateEstablishingSession); } break; case kStateEstablishingSession: // On server we transition from "establishing session" state // to "established" on sending a response with success // response code. if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess)) { SetState(kStateSessionEstablished); } default: break; } ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType)); otPlatDsoSend(this, &aMessage); // Signal any state changes. This is done at the very end when the // `SendMessage()` is fully processed (all state and local // variables are updated) to ensure that we do not have any // reentrancy issues (e.g., if the callback signalling state // change triggers another tx). SignalAnyStateChange(); exit: return error; } Error Dso::Connection::AppendPadding(Message &aMessage) { // This method appends Encryption Padding TLV to a DSO message. // It uses the padding policy "Random-Block-Length Padding" from // RFC 8467. static const uint16_t kBlockLengths[] = {8, 11, 17, 21}; Error error = kErrorNone; uint16_t blockLength; EncryptionPaddingTlv paddingTlv; // We pick a random block length. The random selection can be // based on a "weak" source of randomness (so the use of // `NonCrypto` is fine). We add padding to the message such // that its padded length is a multiple of the chosen block // length. blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))]; paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength); SuccessOrExit(error = aMessage.Append(paddingTlv)); for (uint16_t len = paddingTlv.GetLength(); len > 0; len--) { SuccessOrExit(error = aMessage.Append(0)); } exit: return error; } void Dso::Connection::HandleReceive(Message &aMessage) { Error error = kErrorAbort; Tlv::Type primaryTlvType = Tlv::kReservedType; Dns::Header header; SuccessOrExit(aMessage.Read(0, header)); if (header.GetQueryType() != Dns::Header::kQueryTypeDso) { if (header.GetType() == Dns::Header::kTypeQuery) { SendErrorResponse(header, Dns::Header::kResponseNotImplemented); error = kErrorNone; } ExitNow(); } switch (mState) { case kStateConnectedButSessionless: // After connection is established, client should initiate // establishing session (by sending a request). So no rx is // allowed before this. On server, we allow rx of a request // message only. VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0)); break; case kStateEstablishingSession: // Unidirectional message are allowed after session is // established. While session is being established, on client, // we allow rx on response message. On server we can rx // request or response. VerifyOrExit(header.GetMessageId() != 0); if (IsClient()) { VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse); } break; case kStateSessionEstablished: // All message types are allowed. break; case kStateDisconnected: case kStateConnecting: ExitNow(); } // All count fields MUST be set to zero in the header. VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) && (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0)); aMessage.SetOffset(sizeof(header)); switch (ReadPrimaryTlv(aMessage, primaryTlvType)) { case kErrorNone: VerifyOrExit(primaryTlvType != Tlv::kReservedType); break; case kErrorNotFound: // The `primaryTlvType` is set to `Tlv::kReservedType` // (value zero) to indicate that there is no primary TLV. break; default: ExitNow(); } switch (header.GetType()) { case Dns::Header::kTypeQuery: error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType); break; case Dns::Header::kTypeResponse: error = ProcessResponseMessage(header, aMessage, primaryTlvType); break; } exit: aMessage.Free(); if (error == kErrorNone) { ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType)); } else { Disconnect(kForciblyAbort, kReasonPeerMisbehavior); } // We signal any state change at the very end when the received // message is fully processed (all state and local variables are // updated) to ensure that we do not have any reentrancy issues // (e.g., if a `Connection` method happens to be called from the // callback). SignalAnyStateChange(); } Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const { // Read and validate the primary TLV (first TLV after the header). // The `aMessage.GetOffset()` must point to the first TLV. If no // TLV then `kErrorNotFound` is returned. If TLV in message is not // well-formed `kErrorParse` is returned. The read TLV type is // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType` // (value zero) when `kErrorNotFound`). Error error = kErrorNotFound; Tlv tlv; aPrimaryTlvType = Tlv::kReservedType; SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv)); VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse); aPrimaryTlvType = tlv.GetType(); error = kErrorNone; exit: return error; } Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader, const Message & aMessage, Tlv::Type aPrimaryTlvType) { Error error = kErrorAbort; if (IsServer() && (mState == kStateConnectedButSessionless)) { SetState(kStateEstablishingSession); } // A DSO request or unidirectional message MUST contain at // least one TLV which is the "Primary TLV" and determines // the nature of the operation being performed. switch (aPrimaryTlvType) { case KeepAliveTlv::kType: error = ProcessKeepAliveMessage(aHeader, aMessage); break; case RetryDelayTlv::kType: error = ProcessRetryDelayMessage(aHeader, aMessage); break; case Tlv::kReservedType: case EncryptionPaddingTlv::kType: // Misbehavior by peer. break; default: if (aHeader.GetMessageId() == 0) { LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString()); error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType); } else { MessageId messageId = aHeader.GetMessageId(); LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString()); error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType); // `kErrorNotFound` indicates that TLV type is not known. if (error == kErrorNotFound) { SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented); error = kErrorNone; } } break; } return error; } Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader, const Message & aMessage, Tlv::Type aPrimaryTlvType) { Error error = kErrorAbort; Tlv::Type requestPrimaryTlvType; // If a client or server receives a response where the message // ID is zero, or is any other value that does not match the // message ID of any of its outstanding operations, this is a // fatal error and the recipient MUST forcibly abort the // connection immediately. VerifyOrExit(aHeader.GetMessageId() != 0); VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType)); // If the response has no error and contains a primary TLV, it // MUST match the request primary TLV. if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType)) { VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType); } mPendingRequests.Remove(aHeader.GetMessageId()); switch (requestPrimaryTlvType) { case KeepAliveTlv::kType: SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage)); break; default: SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType, requestPrimaryTlvType)); break; } // DSO session is established when client sends a request message // and receives a response from server with no error code. if (IsClient() && (mState == kStateEstablishingSession) && (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess)) { SetState(kStateSessionEstablished); } exit: return error; } Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage) { Error error = kErrorAbort; uint16_t offset = aMessage.GetOffset(); Tlv tlv; KeepAliveTlv keepAliveTlv; if (aHeader.GetType() == Dns::Header::kTypeResponse) { // A Keep Alive response message is allowed on a client from a sever. VerifyOrExit(IsClient()); if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess) { // We got an error response code from server for our // Keep Alive request message. If this happens while // establishing the DSO session, it indicates that server // does not support DSO, so we close the connection. If // this happens while session is already established, it // is a misbehavior (fatal error) by server. if (mState == kStateEstablishingSession) { Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso); error = kErrorNone; } ExitNow(); } } // Parse and validate the Keep Alive Message SuccessOrExit(aMessage.Read(offset, keepAliveTlv)); offset += keepAliveTlv.GetSize(); VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid()); // Keep Alive message MUST contain only one Keep Alive TLV. while (offset < aMessage.GetLength()) { SuccessOrExit(aMessage.Read(offset, tlv)); offset += tlv.GetSize(); VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType)); } VerifyOrExit(offset == aMessage.GetLength()); if (aHeader.GetType() == Dns::Header::kTypeQuery) { if (IsServer()) { // Received a Keep Alive message from client. It MUST // be a request message (not unidirectional). We prepare // and send a Keep Alive response. VerifyOrExit(aHeader.GetMessageId() != 0); LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString()); IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId())); error = kErrorNone; ExitNow(); } // Received a Keep Alive message on client from server. Server // Keep Alive message MUST be unidirectional (message ID // zero). VerifyOrExit(aHeader.GetMessageId() == 0); } LogInfo("Received Keep Alive %s message from server %s", (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString()); // Receiving a Keep Alive interval value from server less than the // minimum (ten seconds) is a fatal error and client MUST then // abort the connection. VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval); // Update the timeout intervals on the connection from // the new values we got from the server. The receive // of the Keep Alive message does not itself reset the // inactivity timer. So we use `AdjustInactivityTimeout` // which takes into account the time elapsed since the // last activity. AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout()); mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval()); LogInfo("Timeouts Inactivity:%u, KeepAlive:%u", mInactivity.GetInterval(), mKeepAlive.GetInterval()); error = kErrorNone; exit: return error; } Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage) { Error error = kErrorAbort; RetryDelayTlv retryDelayTlv; // Retry Delay TLV can be used as the Primary TLV only in // a unidirectional message sent from server to client. // It is used by the server to instruct the client to // close the session and its underlying connection, and not // to reconnect for the indicated time interval. VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0)); SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv)); VerifyOrExit(retryDelayTlv.IsValid()); mRetryDelayErrorCode = aHeader.GetResponseCode(); mRetryDelay = retryDelayTlv.GetRetryDelay(); LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString()); LogInfo(" RetryDelay:%u ms, ResponseCode:%d", mRetryDelay, mRetryDelayErrorCode); Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest); exit: return error; } void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode) { Message * response = NewMessage(); Dns::Header header; VerifyOrExit(response != nullptr); header.SetMessageId(aHeader.GetMessageId()); header.SetType(Dns::Header::kTypeResponse); header.SetQueryType(aHeader.GetQueryType()); header.SetResponseCode(aResponseCode); SuccessOrExit(response->Prepend(header)); otPlatDsoSend(this, response); response = nullptr; exit: FreeMessage(response); } void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout) { // This method sets the inactivity timeout interval to a new value // and updates the expiration time based on the new timeout value. // // On client, it is called on receiving a Keep Alive response or // unidirectional message from server. Note that the receive of // the Keep Alive message does not itself reset the inactivity // timer. So the time elapsed since the last activity should be // taken into account with the new inactivity timeout value. // // On server this method is called from `SetTimeouts()` when a new // inactivity timeout value is set. TimeMilli now = TimerMilli::GetNow(); TimeMilli start; TimeMilli newExpiration; if (mState == kStateDisconnected) { mInactivity.SetInterval(aNewTimeout); ExitNow(); } VerifyOrExit(aNewTimeout != mInactivity.GetInterval()); // Calculate the start time (i.e., the last time inactivity timer // was cleared). If the previous inactivity time is set to // `kInfinite` value (`IsUsed()` returns `false`) then // `GetExpirationTime()` returns the start time. Otherwise, we // calculate it going back from the current expiration time with // the current wait interval. if (!mInactivity.IsUsed()) { start = mInactivity.GetExpirationTime(); } else if (IsClient()) { start = mInactivity.GetExpirationTime() - mInactivity.GetInterval(); } else { start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime(); } mInactivity.SetInterval(aNewTimeout); if (!mInactivity.IsUsed()) { newExpiration = start; } else if (IsClient()) { newExpiration = start + aNewTimeout; if (newExpiration < now) { newExpiration = now; } } else { newExpiration = start + CalculateServerInactivityWaitTime(); if (newExpiration < now) { // If the server abruptly reduces the inactivity timeout // such that current elapsed time is already more than // twice the new inactivity timeout, then the client is // immediately considered delinquent (server can forcibly // abort the connection). So to give the client time to // close the connection gracefully, the server SHOULD // give the client an additional grace period of either // five seconds or one quarter of the new inactivity // timeout, whichever is greater [RFC 8490 - 7.1.1]. newExpiration = now + OT_MAX(kMinServerInactivityWaitTime, aNewTimeout / 4); } } mInactivity.SetExpirationTime(newExpiration); exit: return; } uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const { // A server will abort an idle session after five seconds // (`kMinServerInactivityWaitTime`) or twice the inactivity // timeout value, whichever is greater [RFC 8490 - 6.4.1]. OT_ASSERT(mInactivity.IsUsed()); return OT_MAX(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime); } void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage) { TimeMilli now = TimerMilli::GetNow(); TimeMilli nextTime; // At both servers and clients, the generation or reception of any // complete DNS message resets both timers for that DSO // session, with the one exception being that a DSO Keep Alive // message resets only the keep alive timer, not the inactivity // timeout timer [RFC 8490 - 6.3] if (mKeepAlive.IsUsed()) { // On client, we wait for the Keep Alive interval but on server // we wait for twice the interval before considering Keep Alive // timeout. // // Note that we limit the interval to `Timeout::kMaxInterval` // (which is ~12 days). This max limit ensures that even twice // the interval is less than max OpenThread timer duration so // that the expiration time calculations below stay within the // `TimerMilli` range. mKeepAlive.SetExpirationTime(now + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1)); } if (!aIsKeepAliveMessage) { if (mInactivity.IsUsed()) { mInactivity.SetExpirationTime( now + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval())); } else { // When Inactivity timeout is not used (i.e., interval is set // to the special `kInfinite` value), we still need to track // the time so that if/when later the inactivity interval // gets changed, we can adjust the remaining time correctly // from `AdjustInactivityTimeout()`. In this case, we just // track the current time as "expiration time". mInactivity.SetExpirationTime(now); } } nextTime = GetNextFireTime(now); if (nextTime != now.GetDistantFuture()) { Get().mTimer.FireAtIfEarlier(nextTime); } } TimeMilli Dso::Connection::GetNextFireTime(TimeMilli aNow) const { TimeMilli nextTime = aNow.GetDistantFuture(); switch (mState) { case kStateDisconnected: break; case kStateConnecting: // While in `kStateConnecting`, Keep Alive timer is // used for `kConnectingTimeout`. VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow); nextTime = mKeepAlive.GetExpirationTime(); break; case kStateConnectedButSessionless: case kStateEstablishingSession: case kStateSessionEstablished: nextTime = OT_MIN(nextTime, mPendingRequests.GetNextFireTime(aNow)); if (mKeepAlive.IsUsed()) { VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow); nextTime = OT_MIN(nextTime, mKeepAlive.GetExpirationTime()); } if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation) { // An operation being active on a DSO Session includes // a request message waiting for a response, or an // active long-lived operation. VerifyOrExit(mInactivity.GetExpirationTime() > aNow, nextTime = aNow); nextTime = OT_MIN(nextTime, mInactivity.GetExpirationTime()); } break; } exit: return nextTime; } void Dso::Connection::HandleTimer(TimeMilli aNow, TimeMilli &aNextTime) { switch (mState) { case kStateDisconnected: break; case kStateConnecting: if (mKeepAlive.IsExpired(aNow)) { Disconnect(kGracefullyClose, kReasonFailedToConnect); } break; case kStateConnectedButSessionless: case kStateEstablishingSession: case kStateSessionEstablished: if (mPendingRequests.HasAnyTimedOut(aNow)) { // If server sends no response to a request, client // waits for 30 seconds (`kResponseTimeout`) after which // client MUST forcibly abort the connection. Disconnect(kForciblyAbort, kReasonResponseTimeout); ExitNow(); } // The inactivity timer is kept clear, while an operation is // active on the session (which includes a request waiting for // response or an active long-lived operation). if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation && mInactivity.IsExpired(aNow)) { // On client, if the inactivity timeout is reached, the // connection is closed gracefully. On server, if too much // time (`CalculateServerInactivityWaitTime()`, i.e., five // seconds or twice the current inactivity timeout interval, // whichever is grater) elapses server MUST consider the // client delinquent and MUST forcibly abort the connection. Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout); ExitNow(); } if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNow)) { // On client, if the Keep Alive interval elapses without any // DNS messages being sent or received, the client MUST take // action and send a DSO Keep Alive message. // // On server, if twice the Keep Alive interval value elapses // without any messages being sent or received, the server // considers the client delinquent and aborts the connection. if (IsClient()) { IgnoreError(SendKeepAliveMessage()); } else { Disconnect(kForciblyAbort, kReasonKeepAliveTimeout); ExitNow(); } } break; } exit: aNextTime = OT_MIN(aNextTime, GetNextFireTime(aNow)); SignalAnyStateChange(); } const char *Dso::Connection::StateToString(State aState) { static const char *const kStateStrings[] = { "Disconnected", // (0) kStateDisconnected, "Connecting", // (1) kStateConnecting, "ConnectedButSessionless", // (2) kStateConnectedButSessionless, "EstablishingSession", // (3) kStateEstablishingSession, "SessionEstablished", // (4) kStateSessionEstablished, }; static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect"); static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect"); static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect"); static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect"); static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect"); return kStateStrings[aState]; } const char *Dso::Connection::MessageTypeToString(MessageType aMessageType) { static const char *const kMessageTypeStrings[] = { "Request", // (0) kRequestMessage "Response", // (1) kResponseMessage "Unidirectional", // (2) kUnidirectionalMessage }; static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect"); static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect"); static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect"); return kMessageTypeStrings[aMessageType]; } const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason) { static const char *const kDisconnectReasonStrings[] = { "FailedToConnect", // (0) kReasonFailedToConnect "ResponseTimeout", // (1) kReasonResponseTimeout "PeerDoesNotSupportDso", // (2) kReasonPeerDoesNotSupportDso "PeerClosed", // (3) kReasonPeerClosed "PeerAborted", // (4) kReasonPeerAborted "InactivityTimeout", // (5) kReasonInactivityTimeout "KeepAliveTimeout", // (6) kReasonKeepAliveTimeout "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest "PeerMisbehavior", // (8) kReasonPeerMisbehavior "Unknown", // (9) kReasonUnknown }; static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect"); static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect"); static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect"); static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect"); static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect"); static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect"); static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect"); static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect"); static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect"); static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect"); return kDisconnectReasonStrings[aReason]; } //--------------------------------------------------------------------------------------------------------------------- // Dso::Connection::PendingRequests bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const { bool contains = true; const Entry *entry = mRequests.FindMatching(aMessageId); VerifyOrExit(entry != nullptr, contains = false); aPrimaryTlvType = entry->mPrimaryTlvType; exit: return contains; } Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout) { Error error = kErrorNone; Entry *entry = mRequests.PushBack(); VerifyOrExit(entry != nullptr, error = kErrorNoBufs); entry->mMessageId = aMessageId; entry->mPrimaryTlvType = aPrimaryTlvType; entry->mTimeout = aResponseTimeout; exit: return error; } void Dso::Connection::PendingRequests::Remove(MessageId aMessageId) { mRequests.RemoveMatching(aMessageId); } bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const { bool timedOut = false; for (const Entry &entry : mRequests) { if (entry.mTimeout <= aNow) { timedOut = true; break; } } return timedOut; } TimeMilli Dso::Connection::PendingRequests::GetNextFireTime(TimeMilli aNow) const { TimeMilli nextTime = aNow.GetDistantFuture(); for (const Entry &entry : mRequests) { VerifyOrExit(entry.mTimeout > aNow, nextTime = aNow); nextTime = OT_MIN(entry.mTimeout, nextTime); } exit: return nextTime; } //--------------------------------------------------------------------------------------------------------------------- // Dso Dso::Dso(Instance &aInstance) : InstanceLocator(aInstance) , mAcceptHandler(nullptr) , mTimer(aInstance, HandleTimer) { } void Dso::StartListening(AcceptHandler aAcceptHandler) { mAcceptHandler = aAcceptHandler; otPlatDsoEnableListening(&GetInstance(), true); } void Dso::StopListening(void) { otPlatDsoEnableListening(&GetInstance(), false); } Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr) { return mClientConnections.FindMatching(aPeerSockAddr); } Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr) { return mServerConnections.FindMatching(aPeerSockAddr); } Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr) { Connection *connection = nullptr; VerifyOrExit(mAcceptHandler != nullptr); connection = mAcceptHandler(GetInstance(), aPeerSockAddr); VerifyOrExit(connection != nullptr); connection->Accept(); exit: return connection; } void Dso::HandleTimer(Timer &aTimer) { aTimer.Get().HandleTimer(); } void Dso::HandleTimer(void) { TimeMilli now = TimerMilli::GetNow(); TimeMilli nextTime = now.GetDistantFuture(); Connection *conn; Connection *next; for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next) { next = conn->GetNext(); conn->HandleTimer(now, nextTime); } for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next) { next = conn->GetNext(); conn->HandleTimer(now, nextTime); } if (nextTime != now.GetDistantFuture()) { mTimer.FireAtIfEarlier(nextTime); } } } // namespace Dns } // namespace ot #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE