• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the necessary hooks for mbedTLS.
32  */
33 
34 #include "secure_transport.hpp"
35 
36 #include <mbedtls/debug.h>
37 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
38 #include <mbedtls/pem.h>
39 #endif
40 
41 #include "instance/instance.hpp"
42 
43 #if OPENTHREAD_CONFIG_SECURE_TRANSPORT_ENABLE
44 
45 namespace ot {
46 namespace MeshCoP {
47 
48 RegisterLogModule("SecTransport");
49 
50 //---------------------------------------------------------------------------------------------------------------------
51 // SecureSession
52 
SecureSession(SecureTransport & aTransport)53 SecureSession::SecureSession(SecureTransport &aTransport)
54     : mTransport(aTransport)
55 {
56     Init();
57 }
58 
Init(void)59 void SecureSession::Init(void)
60 {
61     mTimerSet       = false;
62     mIsServer       = false;
63     mState          = kStateDisconnected;
64     mMessageSubType = Message::kSubTypeNone;
65     mConnectEvent   = kDisconnectedError;
66     mReceiveMessage = nullptr;
67     mMessageInfo.Clear();
68 
69     MarkAsNotUsed();
70     ClearAllBytes(mSsl);
71     ClearAllBytes(mConf);
72 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
73     ClearAllBytes(mCookieCtx);
74 #endif
75 }
76 
FreeMbedtls(void)77 void SecureSession::FreeMbedtls(void)
78 {
79 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
80     if (mTransport.mDatagramTransport)
81     {
82         mbedtls_ssl_cookie_free(&mCookieCtx);
83     }
84 #endif
85 #if OPENTHREAD_CONFIG_TLS_API_ENABLE && defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED)
86     if (mTransport.mExtension != nullptr)
87     {
88         mTransport.mExtension->mEcdheEcdsaInfo.Free();
89     }
90 #endif
91     mbedtls_ssl_config_free(&mConf);
92     mbedtls_ssl_free(&mSsl);
93 }
94 
SetState(State aState)95 void SecureSession::SetState(State aState)
96 {
97     VerifyOrExit(mState != aState);
98 
99     LogInfo("Session state: %s -> %s", StateToString(mState), StateToString(aState));
100     mState = aState;
101 
102 exit:
103     return;
104 }
105 
Connect(const Ip6::SockAddr & aSockAddr)106 Error SecureSession::Connect(const Ip6::SockAddr &aSockAddr)
107 {
108     Error error;
109 
110     VerifyOrExit(mTransport.mIsOpen, error = kErrorInvalidState);
111     VerifyOrExit(!IsSessionInUse(), error = kErrorInvalidState);
112 
113     Init();
114     mMessageInfo.SetPeerAddr(aSockAddr.GetAddress());
115     mMessageInfo.SetPeerPort(aSockAddr.mPort);
116 
117     SuccessOrExit(error = Setup());
118 
119     mTransport.mSessions.Push(*this);
120 
121 exit:
122     return error;
123 }
124 
Accept(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)125 void SecureSession::Accept(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
126 {
127     mMessageInfo.SetPeerAddr(aMessageInfo.GetPeerAddr());
128     mMessageInfo.SetPeerPort(aMessageInfo.GetPeerPort());
129     mMessageInfo.SetIsHostInterface(aMessageInfo.IsHostInterface());
130     mMessageInfo.SetSockAddr(aMessageInfo.GetSockAddr());
131     mMessageInfo.SetSockPort(aMessageInfo.GetSockPort());
132 
133     mIsServer = true;
134 
135     if (Setup() == kErrorNone)
136     {
137         HandleTransportReceive(aMessage);
138     }
139 }
140 
HandleTransportReceive(Message & aMessage)141 void SecureSession::HandleTransportReceive(Message &aMessage)
142 {
143     VerifyOrExit(!IsDisconnected());
144 
145 #ifdef MBEDTLS_SSL_SRV_C
146     if (IsConnecting())
147     {
148         mbedtls_ssl_set_client_transport_id(&mSsl, mMessageInfo.GetPeerAddr().GetBytes(), sizeof(Ip6::Address));
149     }
150 #endif
151 
152     mReceiveMessage = &aMessage;
153     Process();
154     mReceiveMessage = nullptr;
155 
156 exit:
157     return;
158 }
159 
Setup(void)160 Error SecureSession::Setup(void)
161 {
162     Error error = kErrorNone;
163     int   rval  = 0;
164 
165     OT_ASSERT(mTransport.mCipherSuite != SecureTransport::kUnspecifiedCipherSuite);
166 
167     SetState(kStateInitializing);
168 
169     if (mTransport.HasNoRemainingConnectionAttempts())
170     {
171         mConnectEvent = kDisconnectedMaxAttempts;
172         error         = kErrorNoBufs;
173         ExitNow();
174     }
175 
176     mTransport.DecremenetRemainingConnectionAttempts();
177 
178     //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
179     // Setup the mbedtls_ssl_config `mConf`.
180 
181     mbedtls_ssl_config_init(&mConf);
182 
183     rval = mbedtls_ssl_config_defaults(&mConf, mIsServer ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT,
184                                        mTransport.mDatagramTransport ? MBEDTLS_SSL_TRANSPORT_DATAGRAM
185                                                                      : MBEDTLS_SSL_TRANSPORT_STREAM,
186                                        MBEDTLS_SSL_PRESET_DEFAULT);
187     VerifyOrExit(rval == 0);
188 
189 #if OPENTHREAD_CONFIG_TLS_API_ENABLE && defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED)
190     if (mTransport.mVerifyPeerCertificate &&
191         (mTransport.mCipherSuite == SecureTransport::kEcdheEcdsaWithAes128Ccm8 ||
192          mTransport.mCipherSuite == SecureTransport::kEcdheEcdsaWithAes128GcmSha256))
193     {
194         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
195     }
196     else
197     {
198         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_NONE);
199     }
200 #endif
201 
202     mbedtls_ssl_conf_rng(&mConf, Crypto::MbedTls::CryptoSecurePrng, nullptr);
203 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
204     mbedtls_ssl_conf_min_tls_version(&mConf, MBEDTLS_SSL_VERSION_TLS1_2);
205     mbedtls_ssl_conf_max_tls_version(&mConf, MBEDTLS_SSL_VERSION_TLS1_2);
206 #else
207     mbedtls_ssl_conf_min_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
208     mbedtls_ssl_conf_max_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
209 #endif
210 
211     {
212         // We use `kCipherSuites[mCipherSuite]` to look up the cipher
213         // suites array to pass to `mbedtls_ssl_conf_ciphersuites()`
214         // associated with `mCipherSuite`. We validate that the `enum`
215         // values are correct and match the order in the `kCipherSuites[]`
216         // array.
217 
218         struct EnumCheck
219         {
220             InitEnumValidatorCounter();
221             ValidateNextEnum(SecureTransport::kEcjpakeWithAes128Ccm8);
222 #if OPENTHREAD_CONFIG_TLS_API_ENABLE && defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
223             ValidateNextEnum(SecureTransport::kPskWithAes128Ccm8);
224 #endif
225 #if OPENTHREAD_CONFIG_TLS_API_ENABLE && defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED)
226             ValidateNextEnum(SecureTransport::kEcdheEcdsaWithAes128Ccm8);
227             ValidateNextEnum(SecureTransport::kEcdheEcdsaWithAes128GcmSha256);
228 #endif
229         };
230 
231         mbedtls_ssl_conf_ciphersuites(&mConf, SecureTransport::kCipherSuites[mTransport.mCipherSuite]);
232     }
233 
234     if (mTransport.mCipherSuite == SecureTransport::kEcjpakeWithAes128Ccm8)
235     {
236 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
237         mbedtls_ssl_conf_groups(&mConf, SecureTransport::kGroups);
238 #else
239         mbedtls_ssl_conf_curves(&mConf, SecureTransport::kCurves);
240 #endif
241 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
242 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
243         mbedtls_ssl_conf_sig_algs(&mConf, SecureTransport::kSignatures);
244 #else
245         mbedtls_ssl_conf_sig_hashes(&mConf, SecureTransport::kHashes);
246 #endif
247 #endif
248     }
249 
250 #if (MBEDTLS_VERSION_NUMBER < 0x03000000)
251     mbedtls_ssl_conf_export_keys_cb(&mConf, SecureTransport::HandleMbedtlsExportKeys, &mTransport);
252 #endif
253 
254     mbedtls_ssl_conf_handshake_timeout(&mConf, 8000, 60000);
255     mbedtls_ssl_conf_dbg(&mConf, SecureTransport::HandleMbedtlsDebug, &mTransport);
256 
257     //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
258     // Setup the `Extension` components.
259 
260 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
261     if (mTransport.mExtension != nullptr)
262     {
263 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED)
264         mTransport.mExtension->mEcdheEcdsaInfo.Init();
265 #endif
266         rval = mTransport.mExtension->SetApplicationSecureKeys(mConf);
267         VerifyOrExit(rval == 0);
268     }
269 #endif
270 
271     //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
272     // Setup the mbedtls_ssl_cookie_ctx `mCookieCtx`.
273 
274 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
275     if (mTransport.mDatagramTransport)
276     {
277         mbedtls_ssl_cookie_init(&mCookieCtx);
278 
279         if (mIsServer)
280         {
281             rval = mbedtls_ssl_cookie_setup(&mCookieCtx, Crypto::MbedTls::CryptoSecurePrng, nullptr);
282             VerifyOrExit(rval == 0);
283 
284             mbedtls_ssl_conf_dtls_cookies(&mConf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &mCookieCtx);
285         }
286     }
287 #endif
288 
289     //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
290     // Setup the mbedtls_ssl_context `mSsl`.
291 
292     mbedtls_ssl_init(&mSsl);
293 
294     rval = mbedtls_ssl_setup(&mSsl, &mConf);
295     VerifyOrExit(rval == 0);
296 
297     mbedtls_ssl_set_bio(&mSsl, this, HandleMbedtlsTransmit, HandleMbedtlsReceive, /* RecvTimeoutFn */ nullptr);
298 
299     if (mTransport.mDatagramTransport)
300     {
301         mbedtls_ssl_set_timer_cb(&mSsl, this, HandleMbedtlsSetTimer, HandleMbedtlsGetTimer);
302     }
303 
304 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
305     mbedtls_ssl_set_export_keys_cb(&mSsl, SecureTransport::HandleMbedtlsExportKeys, &mTransport);
306 #endif
307 
308     if (mTransport.mCipherSuite == SecureTransport::kEcjpakeWithAes128Ccm8)
309     {
310         rval = mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mTransport.mPsk, mTransport.mPskLength);
311         VerifyOrExit(rval == 0);
312     }
313 
314     mReceiveMessage = nullptr;
315     mMessageSubType = Message::kSubTypeNone;
316 
317     SetState(kStateConnecting);
318 
319     Process();
320 
321 exit:
322     if (IsInitializing())
323     {
324         error = (error == kErrorNone) ? Crypto::MbedTls::MapError(rval) : error;
325 
326         SetState(kStateDisconnected);
327         FreeMbedtls();
328         mTransport.mUpdateTask.Post();
329     }
330 
331     return error;
332 }
333 
Disconnect(ConnectEvent aEvent)334 void SecureSession::Disconnect(ConnectEvent aEvent)
335 {
336     VerifyOrExit(mTransport.mIsOpen);
337     VerifyOrExit(IsConnectingOrConnected());
338 
339     mbedtls_ssl_close_notify(&mSsl);
340     SetState(kStateDisconnecting);
341     mConnectEvent = aEvent;
342 
343     mTimerSet    = false;
344     mTimerFinish = TimerMilli::GetNow() + kGuardTimeNewConnectionMilli;
345     mTransport.mTimer.FireAtIfEarlier(mTimerFinish);
346 
347     FreeMbedtls();
348 
349 exit:
350     return;
351 }
352 
Send(Message & aMessage)353 Error SecureSession::Send(Message &aMessage)
354 {
355     Error    error  = kErrorNone;
356     uint16_t length = aMessage.GetLength();
357     uint8_t  buffer[kApplicationDataMaxLength];
358 
359     VerifyOrExit(length <= sizeof(buffer), error = kErrorNoBufs);
360 
361     mMessageSubType = aMessage.GetSubType();
362     aMessage.ReadBytes(0, buffer, length);
363 
364     SuccessOrExit(error = Crypto::MbedTls::MapError(mbedtls_ssl_write(&mSsl, buffer, length)));
365 
366     aMessage.Free();
367 
368 exit:
369     return error;
370 }
371 
IsMbedtlsHandshakeOver(mbedtls_ssl_context * aSslContext)372 bool SecureSession::IsMbedtlsHandshakeOver(mbedtls_ssl_context *aSslContext)
373 {
374     return
375 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
376         mbedtls_ssl_is_handshake_over(aSslContext);
377 #else
378         (aSslContext->MBEDTLS_PRIVATE(state) == MBEDTLS_SSL_HANDSHAKE_OVER);
379 #endif
380 }
381 
HandleMbedtlsTransmit(void * aContext,const unsigned char * aBuf,size_t aLength)382 int SecureSession::HandleMbedtlsTransmit(void *aContext, const unsigned char *aBuf, size_t aLength)
383 {
384     return static_cast<SecureSession *>(aContext)->HandleMbedtlsTransmit(aBuf, aLength);
385 }
386 
HandleMbedtlsTransmit(const unsigned char * aBuf,size_t aLength)387 int SecureSession::HandleMbedtlsTransmit(const unsigned char *aBuf, size_t aLength)
388 {
389     Message::SubType msgSubType = mMessageSubType;
390 
391     mMessageSubType = Message::kSubTypeNone;
392 
393     return mTransport.Transmit(aBuf, aLength, mMessageInfo, msgSubType);
394 }
395 
HandleMbedtlsReceive(void * aContext,unsigned char * aBuf,size_t aLength)396 int SecureSession::HandleMbedtlsReceive(void *aContext, unsigned char *aBuf, size_t aLength)
397 {
398     return static_cast<SecureSession *>(aContext)->HandleMbedtlsReceive(aBuf, aLength);
399 }
400 
HandleMbedtlsReceive(unsigned char * aBuf,size_t aLength)401 int SecureSession::HandleMbedtlsReceive(unsigned char *aBuf, size_t aLength)
402 {
403     int      rval = MBEDTLS_ERR_SSL_WANT_READ;
404     uint16_t readLength;
405 
406     VerifyOrExit(mReceiveMessage != nullptr);
407 
408     readLength = mReceiveMessage->ReadBytes(mReceiveMessage->GetOffset(), aBuf, static_cast<uint16_t>(aLength));
409     VerifyOrExit(readLength > 0);
410 
411     mReceiveMessage->MoveOffset(readLength);
412     rval = static_cast<int>(readLength);
413 
414 exit:
415     return rval;
416 }
417 
HandleMbedtlsGetTimer(void * aContext)418 int SecureSession::HandleMbedtlsGetTimer(void *aContext)
419 {
420     return static_cast<SecureSession *>(aContext)->HandleMbedtlsGetTimer();
421 }
422 
HandleMbedtlsGetTimer(void)423 int SecureSession::HandleMbedtlsGetTimer(void)
424 {
425     int rval = 0;
426 
427     // `mbedtls_ssl_get_timer_t` return values:
428     //   -1 if cancelled
429     //    0 if none of the delays have passed,
430     //    1 if only the intermediate delay has passed,
431     //    2 if the final delay has passed.
432 
433     if (!mTimerSet)
434     {
435         rval = -1;
436     }
437     else
438     {
439         TimeMilli now = TimerMilli::GetNow();
440 
441         if (now >= mTimerFinish)
442         {
443             rval = 2;
444         }
445         else if (now >= mTimerIntermediate)
446         {
447             rval = 1;
448         }
449     }
450 
451     return rval;
452 }
453 
HandleMbedtlsSetTimer(void * aContext,uint32_t aIntermediate,uint32_t aFinish)454 void SecureSession::HandleMbedtlsSetTimer(void *aContext, uint32_t aIntermediate, uint32_t aFinish)
455 {
456     static_cast<SecureSession *>(aContext)->HandleMbedtlsSetTimer(aIntermediate, aFinish);
457 }
458 
HandleMbedtlsSetTimer(uint32_t aIntermediate,uint32_t aFinish)459 void SecureSession::HandleMbedtlsSetTimer(uint32_t aIntermediate, uint32_t aFinish)
460 {
461     if (aFinish == 0)
462     {
463         mTimerSet = false;
464     }
465     else
466     {
467         TimeMilli now = TimerMilli::GetNow();
468 
469         mTimerSet          = true;
470         mTimerIntermediate = now + aIntermediate;
471         mTimerFinish       = now + aFinish;
472 
473         mTransport.mTimer.FireAtIfEarlier(mTimerFinish);
474     }
475 }
476 
HandleTimer(TimeMilli aNow)477 void SecureSession::HandleTimer(TimeMilli aNow)
478 {
479     if (IsConnectingOrConnected())
480     {
481         VerifyOrExit(mTimerSet);
482 
483         if (aNow < mTimerFinish)
484         {
485             mTransport.mTimer.FireAtIfEarlier(mTimerFinish);
486             ExitNow();
487         }
488 
489         Process();
490         ExitNow();
491     }
492 
493     if (IsDisconnecting())
494     {
495         if (aNow < mTimerFinish)
496         {
497             mTransport.mTimer.FireAtIfEarlier(mTimerFinish);
498             ExitNow();
499         }
500 
501         SetState(kStateDisconnected);
502         mTransport.mUpdateTask.Post();
503     }
504 
505 exit:
506     return;
507 }
508 
Process(void)509 void SecureSession::Process(void)
510 {
511     uint8_t      buf[kMaxContentLen];
512     int          rval;
513     ConnectEvent disconnectEvent;
514     bool         shouldReset;
515 
516     while (IsConnectingOrConnected())
517     {
518         if (IsConnecting())
519         {
520             rval = mbedtls_ssl_handshake(&mSsl);
521 
522             if (IsMbedtlsHandshakeOver(&mSsl))
523             {
524                 SetState(kStateConnected);
525                 mConnectEvent = kConnected;
526                 mConnectedCallback.InvokeIfSet(mConnectEvent);
527             }
528         }
529         else
530         {
531             rval = mbedtls_ssl_read(&mSsl, buf, sizeof(buf));
532 
533             if (rval > 0)
534             {
535                 mReceiveCallback.InvokeIfSet(buf, static_cast<uint16_t>(rval));
536                 continue;
537             }
538         }
539 
540         // Check `rval` to determine if the connection should be
541         // disconnected, reset, or if we should wait.
542 
543         disconnectEvent = kConnected;
544         shouldReset     = true;
545 
546         switch (rval)
547         {
548         case 0:
549         case MBEDTLS_ERR_SSL_WANT_READ:
550         case MBEDTLS_ERR_SSL_WANT_WRITE:
551             shouldReset = false;
552             break;
553 
554         case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
555             disconnectEvent = kDisconnectedPeerClosed;
556             break;
557 
558         case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
559             break;
560 
561         case MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
562             disconnectEvent = kDisconnectedError;
563             break;
564 
565         case MBEDTLS_ERR_SSL_INVALID_MAC:
566             if (!IsMbedtlsHandshakeOver(&mSsl))
567             {
568                 mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
569                                                MBEDTLS_SSL_ALERT_MSG_BAD_RECORD_MAC);
570                 disconnectEvent = kDisconnectedError;
571             }
572             break;
573 
574         default:
575             if (!IsMbedtlsHandshakeOver(&mSsl))
576             {
577                 mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
578                                                MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
579                 disconnectEvent = kDisconnectedError;
580             }
581 
582             break;
583         }
584 
585         if (disconnectEvent != kConnected)
586         {
587             Disconnect(disconnectEvent);
588         }
589         else if (shouldReset)
590         {
591             mbedtls_ssl_session_reset(&mSsl);
592 
593             if (mTransport.mCipherSuite == SecureTransport::kEcjpakeWithAes128Ccm8)
594             {
595                 mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mTransport.mPsk, mTransport.mPskLength);
596             }
597         }
598 
599         break; // from `while()` loop
600     }
601 }
602 
603 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
604 
StateToString(State aState)605 const char *SecureSession::StateToString(State aState)
606 {
607     static const char *const kStateStrings[] = {
608         "Disconnected",  // (0) kStateDisconnected
609         "Initializing",  // (1) kStateInitializing
610         "Connecting",    // (2) kStateConnecting
611         "Connected",     // (3) kStateConnected
612         "Disconnecting", // (4) kStateDisconnecting
613     };
614 
615     struct EnumCheck
616     {
617         InitEnumValidatorCounter();
618         ValidateNextEnum(kStateDisconnected);
619         ValidateNextEnum(kStateInitializing);
620         ValidateNextEnum(kStateConnecting);
621         ValidateNextEnum(kStateConnected);
622         ValidateNextEnum(kStateDisconnecting);
623     };
624 
625     return kStateStrings[aState];
626 }
627 
628 #endif
629 
630 //---------------------------------------------------------------------------------------------------------------------
631 // SecureTransport
632 
633 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
634 const uint16_t SecureTransport::kGroups[] = {MBEDTLS_SSL_IANA_TLS_GROUP_SECP256R1, MBEDTLS_SSL_IANA_TLS_GROUP_NONE};
635 #else
636 const mbedtls_ecp_group_id SecureTransport::kCurves[] = {MBEDTLS_ECP_DP_SECP256R1, MBEDTLS_ECP_DP_NONE};
637 #endif
638 
639 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
640 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
641 const uint16_t SecureTransport::kSignatures[] = {MBEDTLS_TLS1_3_SIG_ECDSA_SECP256R1_SHA256, MBEDTLS_TLS1_3_SIG_NONE};
642 #else
643 const int SecureTransport::kHashes[] = {MBEDTLS_MD_SHA256, MBEDTLS_MD_NONE};
644 #endif
645 #endif
646 
647 const int SecureTransport::kCipherSuites[][2] = {
648     /* kEcjpakeWithAes128Ccm8         */ {MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8, 0},
649 #if OPENTHREAD_CONFIG_TLS_API_ENABLE && defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
650     /* kPskWithAes128Ccm8             */ {MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8, 0},
651 #endif
652 #if OPENTHREAD_CONFIG_TLS_API_ENABLE && defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED)
653     /* kEcdheEcdsaWithAes128Ccm8      */ {MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, 0},
654     /* kEcdheEcdsaWithAes128GcmSha256 */ {MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 0},
655 #endif
656 };
657 
SecureTransport(Instance & aInstance,LinkSecurityMode aLayerTwoSecurity,bool aDatagramTransport)658 SecureTransport::SecureTransport(Instance &aInstance, LinkSecurityMode aLayerTwoSecurity, bool aDatagramTransport)
659     : mLayerTwoSecurity(aLayerTwoSecurity)
660     , mDatagramTransport(aDatagramTransport)
661     , mIsOpen(false)
662     , mIsClosing(false)
663     , mVerifyPeerCertificate(true)
664     , mCipherSuite(kUnspecifiedCipherSuite)
665     , mPskLength(0)
666     , mMaxConnectionAttempts(0)
667     , mRemainingConnectionAttempts(0)
668     , mSocket(aInstance, *this)
669     , mTimer(aInstance, HandleTimer, this)
670     , mUpdateTask(aInstance, HandleUpdateTask, this)
671 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
672     , mExtension(nullptr)
673 #endif
674 {
675     ClearAllBytes(mPsk);
676     OT_UNUSED_VARIABLE(mVerifyPeerCertificate);
677 }
678 
Open(Ip6::NetifIdentifier aNetifIdentifier)679 Error SecureTransport::Open(Ip6::NetifIdentifier aNetifIdentifier)
680 {
681     Error error;
682 
683     VerifyOrExit(!mIsOpen, error = kErrorAlready);
684 
685     SuccessOrExit(error = mSocket.Open(aNetifIdentifier));
686     mIsOpen                      = true;
687     mRemainingConnectionAttempts = mMaxConnectionAttempts;
688 
689 exit:
690     return error;
691 }
692 
SetMaxConnectionAttempts(uint16_t aMaxAttempts,AutoCloseCallback aCallback,void * aContext)693 Error SecureTransport::SetMaxConnectionAttempts(uint16_t aMaxAttempts, AutoCloseCallback aCallback, void *aContext)
694 {
695     Error error = kErrorNone;
696 
697     VerifyOrExit(!mIsOpen, error = kErrorInvalidState);
698 
699     mMaxConnectionAttempts = aMaxAttempts;
700     mAutoCloseCallback.Set(aCallback, aContext);
701 
702 exit:
703     return error;
704 }
705 
HandleReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)706 void SecureTransport::HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
707 {
708     SecureSession *session;
709 
710     VerifyOrExit(mIsOpen);
711 
712     session = mSessions.FindMatching(aMessageInfo);
713 
714     if (session != nullptr)
715     {
716         session->HandleTransportReceive(aMessage);
717         ExitNow();
718     }
719 
720     // A new connection request
721 
722     VerifyOrExit(mAcceptCallback.IsSet());
723 
724     session = mAcceptCallback.Invoke(aMessageInfo);
725     VerifyOrExit(session != nullptr);
726 
727     session->Init();
728     mSessions.Push(*session);
729 
730     session->Accept(aMessage, aMessageInfo);
731 
732 exit:
733     return;
734 }
735 
Bind(uint16_t aPort)736 Error SecureTransport::Bind(uint16_t aPort)
737 {
738     Error error;
739 
740     VerifyOrExit(mIsOpen, error = kErrorInvalidState);
741     VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);
742 
743     VerifyOrExit(mSessions.IsEmpty(), error = kErrorInvalidState);
744 
745     error = mSocket.Bind(aPort);
746 
747 exit:
748     return error;
749 }
750 
Bind(TransportCallback aCallback,void * aContext)751 Error SecureTransport::Bind(TransportCallback aCallback, void *aContext)
752 {
753     Error error = kErrorNone;
754 
755     VerifyOrExit(mIsOpen, error = kErrorInvalidState);
756     VerifyOrExit(!mSocket.IsBound(), error = kErrorAlready);
757     VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);
758 
759     VerifyOrExit(mSessions.IsEmpty(), error = kErrorInvalidState);
760 
761     mTransportCallback.Set(aCallback, aContext);
762 
763 exit:
764     return error;
765 }
766 
Close(void)767 void SecureTransport::Close(void)
768 {
769     VerifyOrExit(mIsOpen);
770     VerifyOrExit(!mIsClosing);
771 
772     // `mIsClosing` is used to protect against multiple
773     // calls to `Close()` and re-entry. As the transport is closed,
774     // all existing sessions are disconnected, which can trigger
775     // connect and remove callbacks to be invoked. These callbacks
776     // may call `Close()` again.
777 
778     mIsClosing = true;
779 
780     for (SecureSession &session : mSessions)
781     {
782         session.Disconnect(SecureSession::kDisconnectedLocalClosed);
783         session.SetState(SecureSession::kStateDisconnected);
784     }
785 
786     RemoveDisconnectedSessions();
787 
788     mIsOpen    = false;
789     mIsClosing = false;
790     mTransportCallback.Clear();
791     IgnoreError(mSocket.Close());
792     mTimer.Stop();
793 
794 exit:
795     return;
796 }
797 
RemoveDisconnectedSessions(void)798 void SecureTransport::RemoveDisconnectedSessions(void)
799 {
800     LinkedList<SecureSession> disconnectedSessions;
801     SecureSession            *session;
802 
803     mSessions.RemoveAllMatching(disconnectedSessions, SecureSession::kStateDisconnected);
804 
805     while ((session = disconnectedSessions.Pop()) != nullptr)
806     {
807         session->mConnectedCallback.InvokeIfSet(session->mConnectEvent);
808         session->MarkAsNotUsed();
809         session->mMessageInfo.Clear();
810         mRemoveSessionCallback.InvokeIfSet(*session);
811     }
812 }
813 
DecremenetRemainingConnectionAttempts(void)814 void SecureTransport::DecremenetRemainingConnectionAttempts(void)
815 {
816     if (mRemainingConnectionAttempts > 0)
817     {
818         mRemainingConnectionAttempts--;
819     }
820 }
821 
HasNoRemainingConnectionAttempts(void) const822 bool SecureTransport::HasNoRemainingConnectionAttempts(void) const
823 {
824     return (mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0);
825 }
826 
SetPsk(const uint8_t * aPsk,uint8_t aPskLength)827 Error SecureTransport::SetPsk(const uint8_t *aPsk, uint8_t aPskLength)
828 {
829     Error error = kErrorNone;
830 
831     VerifyOrExit(aPskLength <= sizeof(mPsk), error = kErrorInvalidArgs);
832 
833     memcpy(mPsk, aPsk, aPskLength);
834     mPskLength   = aPskLength;
835     mCipherSuite = kEcjpakeWithAes128Ccm8;
836 
837 exit:
838     return error;
839 }
840 
SetPsk(const JoinerPskd & aPskd)841 void SecureTransport::SetPsk(const JoinerPskd &aPskd)
842 {
843     static_assert(JoinerPskd::kMaxLength <= kPskMaxLength, "The max DTLS PSK length is smaller than joiner PSKd");
844 
845     IgnoreError(SetPsk(aPskd.GetBytes(), aPskd.GetLength()));
846 }
847 
Transmit(const unsigned char * aBuf,size_t aLength,const Ip6::MessageInfo & aMessageInfo,Message::SubType aMessageSubType)848 int SecureTransport::Transmit(const unsigned char    *aBuf,
849                               size_t                  aLength,
850                               const Ip6::MessageInfo &aMessageInfo,
851                               Message::SubType        aMessageSubType)
852 {
853     Error    error   = kErrorNone;
854     Message *message = mSocket.NewMessage();
855     int      rval;
856 
857     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
858     message->SetSubType(aMessageSubType);
859     message->SetLinkSecurityEnabled(mLayerTwoSecurity);
860 
861     SuccessOrExit(error = message->AppendBytes(aBuf, static_cast<uint16_t>(aLength)));
862 
863     if (mTransportCallback.IsSet())
864     {
865         error = mTransportCallback.Invoke(*message, aMessageInfo);
866     }
867     else
868     {
869         error = mSocket.SendTo(*message, aMessageInfo);
870     }
871 
872 exit:
873     FreeMessageOnError(message, error);
874 
875     switch (error)
876     {
877     case kErrorNone:
878         rval = static_cast<int>(aLength);
879         break;
880 
881     case kErrorNoBufs:
882         rval = MBEDTLS_ERR_SSL_WANT_WRITE;
883         break;
884 
885     default:
886         LogWarnOnError(error, "HandleMbedtlsTransmit");
887         rval = MBEDTLS_ERR_NET_SEND_FAILED;
888         break;
889     }
890 
891     return rval;
892 }
893 
894 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
895 
HandleMbedtlsExportKeys(void * aContext,mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)896 void SecureTransport::HandleMbedtlsExportKeys(void                       *aContext,
897                                               mbedtls_ssl_key_export_type aType,
898                                               const unsigned char        *aMasterSecret,
899                                               size_t                      aMasterSecretLen,
900                                               const unsigned char         aClientRandom[32],
901                                               const unsigned char         aServerRandom[32],
902                                               mbedtls_tls_prf_types       aTlsPrfType)
903 {
904     static_cast<SecureTransport *>(aContext)->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen,
905                                                                       aClientRandom, aServerRandom, aTlsPrfType);
906 }
907 
HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)908 void SecureTransport::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
909                                               const unsigned char        *aMasterSecret,
910                                               size_t                      aMasterSecretLen,
911                                               const unsigned char         aClientRandom[32],
912                                               const unsigned char         aServerRandom[32],
913                                               mbedtls_tls_prf_types       aTlsPrfType)
914 {
915     Crypto::Sha256::Hash kek;
916     Crypto::Sha256       sha256;
917     unsigned char        keyBlock[kSecureTransportKeyBlockSize];
918     unsigned char        randBytes[2 * kSecureTransportRandomBufferSize];
919 
920     VerifyOrExit(mCipherSuite == kEcjpakeWithAes128Ccm8);
921     VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);
922 
923     memcpy(randBytes, aServerRandom, kSecureTransportRandomBufferSize);
924     memcpy(randBytes + kSecureTransportRandomBufferSize, aClientRandom, kSecureTransportRandomBufferSize);
925 
926     // Retrieve the Key block from Master secret
927     mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
928                         keyBlock, sizeof(keyBlock));
929 
930     sha256.Start();
931     sha256.Update(keyBlock, kSecureTransportKeyBlockSize);
932     sha256.Finish(kek);
933 
934     mTimer.Get<KeyManager>().SetKek(kek.GetBytes());
935 
936 exit:
937     return;
938 }
939 
940 #else
941 
HandleMbedtlsExportKeys(void * aContext,const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)942 int SecureTransport::HandleMbedtlsExportKeys(void *aContext,
943                                              const unsigned char *aMasterSecret,
944                                              const unsigned char *aKeyBlock,
945                                              size_t aMacLength,
946                                              size_t aKeyLength,
947                                              size_t aIvLength)
948 {
949     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength,
950                                                                              aKeyLength, aIvLength);
951 }
952 
HandleMbedtlsExportKeys(const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)953 int SecureTransport::HandleMbedtlsExportKeys(const unsigned char *aMasterSecret,
954                                              const unsigned char *aKeyBlock,
955                                              size_t aMacLength,
956                                              size_t aKeyLength,
957                                              size_t aIvLength)
958 {
959     OT_UNUSED_VARIABLE(aMasterSecret);
960 
961     Crypto::Sha256::Hash kek;
962     Crypto::Sha256 sha256;
963 
964     VerifyOrExit(mCipherSuite == kEcjpakeWithAes128Ccm8);
965 
966     sha256.Start();
967     sha256.Update(aKeyBlock, 2 * static_cast<uint16_t>(aMacLength + aKeyLength + aIvLength));
968     sha256.Finish(kek);
969 
970     mTimer.Get<KeyManager>().SetKek(kek.GetBytes());
971 
972 exit:
973     return 0;
974 }
975 
976 #endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000)
977 
HandleUpdateTask(Tasklet & aTasklet)978 void SecureTransport::HandleUpdateTask(Tasklet &aTasklet)
979 {
980     static_cast<SecureTransport *>(static_cast<TaskletContext &>(aTasklet).GetContext())->HandleUpdateTask();
981 }
982 
HandleUpdateTask(void)983 void SecureTransport::HandleUpdateTask(void)
984 {
985     RemoveDisconnectedSessions();
986 
987     if (mSessions.IsEmpty() && HasNoRemainingConnectionAttempts())
988     {
989         Close();
990         mAutoCloseCallback.InvokeIfSet();
991     }
992 }
993 
HandleTimer(Timer & aTimer)994 void SecureTransport::HandleTimer(Timer &aTimer)
995 {
996     static_cast<SecureTransport *>(static_cast<TimerMilliContext &>(aTimer).GetContext())->HandleTimer();
997 }
998 
HandleTimer(void)999 void SecureTransport::HandleTimer(void)
1000 {
1001     TimeMilli now = TimerMilli::GetNow();
1002 
1003     VerifyOrExit(mIsOpen);
1004 
1005     for (SecureSession &session : mSessions)
1006     {
1007         session.HandleTimer(now);
1008     }
1009 
1010 exit:
1011     return;
1012 }
1013 
HandleMbedtlsDebug(void * aContext,int aLevel,const char * aFile,int aLine,const char * aStr)1014 void SecureTransport::HandleMbedtlsDebug(void *aContext, int aLevel, const char *aFile, int aLine, const char *aStr)
1015 {
1016     static_cast<SecureTransport *>(aContext)->HandleMbedtlsDebug(aLevel, aFile, aLine, aStr);
1017 }
1018 
HandleMbedtlsDebug(int aLevel,const char * aFile,int aLine,const char * aStr)1019 void SecureTransport::HandleMbedtlsDebug(int aLevel, const char *aFile, int aLine, const char *aStr)
1020 {
1021     LogLevel logLevel = kLogLevelDebg;
1022 
1023     switch (aLevel)
1024     {
1025     case 1:
1026         logLevel = kLogLevelCrit;
1027         break;
1028 
1029     case 2:
1030         logLevel = kLogLevelWarn;
1031         break;
1032 
1033     case 3:
1034         logLevel = kLogLevelInfo;
1035         break;
1036 
1037     case 4:
1038     default:
1039         break;
1040     }
1041 
1042     LogAt(logLevel, "[%u] %s", mSocket.GetSockName().mPort, aStr);
1043 
1044     OT_UNUSED_VARIABLE(aStr);
1045     OT_UNUSED_VARIABLE(aFile);
1046     OT_UNUSED_VARIABLE(aLine);
1047     OT_UNUSED_VARIABLE(logLevel);
1048 }
1049 
1050 //---------------------------------------------------------------------------------------------------------------------
1051 // SecureTransport::Extension
1052 
1053 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
1054 
SetApplicationSecureKeys(mbedtls_ssl_config & aConfig)1055 int SecureTransport::Extension::SetApplicationSecureKeys(mbedtls_ssl_config &aConfig)
1056 {
1057     int rval = 0;
1058 
1059     switch (mSecureTransport.mCipherSuite)
1060     {
1061     case kEcjpakeWithAes128Ccm8:
1062         // PSK will be set on `mbedtls_ssl_context` when set up.
1063         break;
1064 
1065 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
1066     case kEcdheEcdsaWithAes128Ccm8:
1067     case kEcdheEcdsaWithAes128GcmSha256:
1068         rval = mEcdheEcdsaInfo.SetSecureKeys(aConfig);
1069         VerifyOrExit(rval == 0);
1070         break;
1071 #endif
1072 
1073 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
1074     case kPskWithAes128Ccm8:
1075         rval = mPskInfo.SetSecureKeys(aConfig);
1076         VerifyOrExit(rval == 0);
1077         break;
1078 #endif
1079 
1080     default:
1081         LogCrit("Application Coap Secure: Not supported cipher.");
1082         rval = MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
1083         ExitNow();
1084     }
1085 
1086 exit:
1087     return rval;
1088 }
1089 
1090 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
1091 
Init(void)1092 void SecureTransport::Extension::EcdheEcdsaInfo::Init(void)
1093 {
1094     mbedtls_x509_crt_init(&mCaChain);
1095     mbedtls_x509_crt_init(&mOwnCert);
1096     mbedtls_pk_init(&mPrivateKey);
1097 }
1098 
Free(void)1099 void SecureTransport::Extension::EcdheEcdsaInfo::Free(void)
1100 {
1101     mbedtls_x509_crt_free(&mCaChain);
1102     mbedtls_x509_crt_free(&mOwnCert);
1103     mbedtls_pk_free(&mPrivateKey);
1104 }
1105 
SetSecureKeys(mbedtls_ssl_config & aConfig)1106 int SecureTransport::Extension::EcdheEcdsaInfo::SetSecureKeys(mbedtls_ssl_config &aConfig)
1107 {
1108     int rval = 0;
1109 
1110     if (mCaChainSrc != nullptr)
1111     {
1112         rval = mbedtls_x509_crt_parse(&mCaChain, static_cast<const unsigned char *>(mCaChainSrc),
1113                                       static_cast<size_t>(mCaChainLength));
1114         VerifyOrExit(rval == 0);
1115         mbedtls_ssl_conf_ca_chain(&aConfig, &mCaChain, nullptr);
1116     }
1117 
1118     if (mOwnCertSrc != nullptr && mPrivateKeySrc != nullptr)
1119     {
1120         rval = mbedtls_x509_crt_parse(&mOwnCert, static_cast<const unsigned char *>(mOwnCertSrc),
1121                                       static_cast<size_t>(mOwnCertLength));
1122         VerifyOrExit(rval == 0);
1123 
1124 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
1125         rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
1126                                     static_cast<size_t>(mPrivateKeyLength), nullptr, 0,
1127                                     Crypto::MbedTls::CryptoSecurePrng, nullptr);
1128 #else
1129         rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
1130                                     static_cast<size_t>(mPrivateKeyLength), nullptr, 0);
1131 #endif
1132         VerifyOrExit(rval == 0);
1133         rval = mbedtls_ssl_conf_own_cert(&aConfig, &mOwnCert, &mPrivateKey);
1134     }
1135 
1136 exit:
1137     return rval;
1138 }
1139 
SetCertificate(const uint8_t * aX509Certificate,uint32_t aX509CertLength,const uint8_t * aPrivateKey,uint32_t aPrivateKeyLength)1140 void SecureTransport::Extension::SetCertificate(const uint8_t *aX509Certificate,
1141                                                 uint32_t       aX509CertLength,
1142                                                 const uint8_t *aPrivateKey,
1143                                                 uint32_t       aPrivateKeyLength)
1144 {
1145     OT_ASSERT(aX509CertLength > 0);
1146     OT_ASSERT(aX509Certificate != nullptr);
1147 
1148     OT_ASSERT(aPrivateKeyLength > 0);
1149     OT_ASSERT(aPrivateKey != nullptr);
1150 
1151     mEcdheEcdsaInfo.mOwnCertSrc       = aX509Certificate;
1152     mEcdheEcdsaInfo.mOwnCertLength    = aX509CertLength;
1153     mEcdheEcdsaInfo.mPrivateKeySrc    = aPrivateKey;
1154     mEcdheEcdsaInfo.mPrivateKeyLength = aPrivateKeyLength;
1155 
1156     mSecureTransport.mCipherSuite =
1157         mSecureTransport.mDatagramTransport ? kEcdheEcdsaWithAes128Ccm8 : kEcdheEcdsaWithAes128GcmSha256;
1158 }
1159 
SetCaCertificateChain(const uint8_t * aX509CaCertificateChain,uint32_t aX509CaCertChainLength)1160 void SecureTransport::Extension::SetCaCertificateChain(const uint8_t *aX509CaCertificateChain,
1161                                                        uint32_t       aX509CaCertChainLength)
1162 {
1163     OT_ASSERT(aX509CaCertChainLength > 0);
1164     OT_ASSERT(aX509CaCertificateChain != nullptr);
1165 
1166     mEcdheEcdsaInfo.mCaChainSrc    = aX509CaCertificateChain;
1167     mEcdheEcdsaInfo.mCaChainLength = aX509CaCertChainLength;
1168 }
1169 
1170 #endif // MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
1171 
1172 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
1173 
SetSecureKeys(mbedtls_ssl_config & aConfig) const1174 int SecureTransport::Extension::PskInfo::SetSecureKeys(mbedtls_ssl_config &aConfig) const
1175 {
1176     return mbedtls_ssl_conf_psk(&aConfig, static_cast<const unsigned char *>(mPreSharedKey), mPreSharedKeyLength,
1177                                 static_cast<const unsigned char *>(mPreSharedKeyIdentity), mPreSharedKeyIdLength);
1178 }
1179 
SetPreSharedKey(const uint8_t * aPsk,uint16_t aPskLength,const uint8_t * aPskIdentity,uint16_t aPskIdLength)1180 void SecureTransport::Extension::SetPreSharedKey(const uint8_t *aPsk,
1181                                                  uint16_t       aPskLength,
1182                                                  const uint8_t *aPskIdentity,
1183                                                  uint16_t       aPskIdLength)
1184 {
1185     OT_ASSERT(aPsk != nullptr);
1186     OT_ASSERT(aPskIdentity != nullptr);
1187     OT_ASSERT(aPskLength > 0);
1188     OT_ASSERT(aPskIdLength > 0);
1189 
1190     mPskInfo.mPreSharedKey         = aPsk;
1191     mPskInfo.mPreSharedKeyLength   = aPskLength;
1192     mPskInfo.mPreSharedKeyIdentity = aPskIdentity;
1193     mPskInfo.mPreSharedKeyIdLength = aPskIdLength;
1194 
1195     mSecureTransport.mCipherSuite = kPskWithAes128Ccm8;
1196 }
1197 
1198 #endif // MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
1199 
1200 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength,size_t aCertBufferSize)1201 Error SecureTransport::Extension::GetPeerCertificateBase64(unsigned char *aPeerCert,
1202                                                            size_t        *aCertLength,
1203                                                            size_t         aCertBufferSize)
1204 {
1205     Error          error   = kErrorNone;
1206     SecureSession *session = mSecureTransport.mSessions.GetHead();
1207 
1208     VerifyOrExit(session != nullptr, error = kErrorInvalidState);
1209     VerifyOrExit(session->IsConnected(), error = kErrorInvalidState);
1210 
1211 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
1212     VerifyOrExit(mbedtls_base64_encode(aPeerCert, aCertBufferSize, aCertLength,
1213                                        session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.p,
1214                                        session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.len) ==
1215                      0,
1216                  error = kErrorNoBufs);
1217 #else
1218     VerifyOrExit(
1219         mbedtls_base64_encode(
1220             aPeerCert, aCertBufferSize, aCertLength,
1221             session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(p),
1222             session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(
1223                 len)) == 0,
1224         error = kErrorNoBufs);
1225 #endif
1226 
1227 exit:
1228     return error;
1229 }
1230 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
1231 
1232 #if defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateDer(uint8_t * aPeerCert,size_t * aCertLength,size_t aCertBufferSize)1233 Error SecureTransport::Extension::GetPeerCertificateDer(uint8_t *aPeerCert, size_t *aCertLength, size_t aCertBufferSize)
1234 {
1235     Error          error   = kErrorNone;
1236     SecureSession *session = mSecureTransport.mSessions.GetHead();
1237 
1238     VerifyOrExit(session->IsConnected(), error = kErrorInvalidState);
1239 
1240 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
1241     VerifyOrExit(session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.len < aCertBufferSize,
1242                  error = kErrorNoBufs);
1243 
1244     *aCertLength = session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.len;
1245     memcpy(aPeerCert, session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.p, *aCertLength);
1246 
1247 #else
1248     VerifyOrExit(
1249         session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(len) <
1250             aCertBufferSize,
1251         error = kErrorNoBufs);
1252 
1253     *aCertLength =
1254         session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(len);
1255     memcpy(aPeerCert,
1256            session->mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(p),
1257            *aCertLength);
1258 #endif
1259 
1260 exit:
1261     return error;
1262 }
1263 
GetPeerSubjectAttributeByOid(const char * aOid,size_t aOidLength,uint8_t * aAttributeBuffer,size_t * aAttributeLength,int * aAsn1Type)1264 Error SecureTransport::Extension::GetPeerSubjectAttributeByOid(const char *aOid,
1265                                                                size_t      aOidLength,
1266                                                                uint8_t    *aAttributeBuffer,
1267                                                                size_t     *aAttributeLength,
1268                                                                int        *aAsn1Type)
1269 {
1270     Error                          error = kErrorNone;
1271     const mbedtls_asn1_named_data *data;
1272     size_t                         length;
1273     size_t                         attributeBufferSize;
1274     SecureSession                 *session;
1275     mbedtls_x509_crt              *peerCert;
1276 
1277     session = mSecureTransport.mSessions.GetHead();
1278     VerifyOrExit(session != nullptr, error = kErrorInvalidState);
1279 
1280     peerCert = const_cast<mbedtls_x509_crt *>(mbedtls_ssl_get_peer_cert(&session->mSsl));
1281 
1282     VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs);
1283     attributeBufferSize = *aAttributeLength;
1284     *aAttributeLength   = 0;
1285 
1286     VerifyOrExit(aAttributeBuffer != nullptr, error = kErrorNoBufs);
1287     VerifyOrExit(peerCert != nullptr, error = kErrorInvalidState);
1288 
1289     data = mbedtls_asn1_find_named_data(&peerCert->subject, aOid, aOidLength);
1290     VerifyOrExit(data != nullptr, error = kErrorNotFound);
1291 
1292     length = data->val.len;
1293     VerifyOrExit(length <= attributeBufferSize, error = kErrorNoBufs);
1294     *aAttributeLength = length;
1295 
1296     if (aAsn1Type != nullptr)
1297     {
1298         *aAsn1Type = data->val.tag;
1299     }
1300 
1301     memcpy(aAttributeBuffer, data->val.p, length);
1302 
1303 exit:
1304     return error;
1305 }
1306 
GetThreadAttributeFromPeerCertificate(int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)1307 Error SecureTransport::Extension::GetThreadAttributeFromPeerCertificate(int      aThreadOidDescriptor,
1308                                                                         uint8_t *aAttributeBuffer,
1309                                                                         size_t  *aAttributeLength)
1310 {
1311     Error                   error;
1312     SecureSession          *session = mSecureTransport.mSessions.GetHead();
1313     const mbedtls_x509_crt *cert;
1314 
1315     VerifyOrExit(session != nullptr, error = kErrorInvalidState);
1316     cert  = mbedtls_ssl_get_peer_cert(&session->mSsl);
1317     error = GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength);
1318 
1319 exit:
1320     return error;
1321 }
1322 
1323 #endif // defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
1324 
GetThreadAttributeFromOwnCertificate(int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)1325 Error SecureTransport::Extension::GetThreadAttributeFromOwnCertificate(int      aThreadOidDescriptor,
1326                                                                        uint8_t *aAttributeBuffer,
1327                                                                        size_t  *aAttributeLength)
1328 {
1329     const mbedtls_x509_crt *cert = &mEcdheEcdsaInfo.mOwnCert;
1330 
1331     return GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength);
1332 }
1333 
GetThreadAttributeFromCertificate(const mbedtls_x509_crt * aCert,int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)1334 Error SecureTransport::Extension::GetThreadAttributeFromCertificate(const mbedtls_x509_crt *aCert,
1335                                                                     int                     aThreadOidDescriptor,
1336                                                                     uint8_t                *aAttributeBuffer,
1337                                                                     size_t                 *aAttributeLength)
1338 {
1339     Error            error  = kErrorNotFound;
1340     char             oid[9] = {0x2B, 0x06, 0x01, 0x04, 0x01, static_cast<char>(0x82), static_cast<char>(0xDF),
1341                                0x2A, 0x00}; // 1.3.6.1.4.1.44970.0
1342     mbedtls_x509_buf v3_ext;
1343     unsigned char   *p, *end, *endExtData;
1344     size_t           len;
1345     size_t           attributeBufferSize;
1346     mbedtls_x509_buf extnOid;
1347     int              ret, isCritical;
1348 
1349     VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs);
1350     attributeBufferSize = *aAttributeLength;
1351     *aAttributeLength   = 0;
1352 
1353     VerifyOrExit(aCert != nullptr, error = kErrorInvalidState);
1354     v3_ext = aCert->v3_ext;
1355     p      = v3_ext.p;
1356     VerifyOrExit(p != nullptr, error = kErrorInvalidState);
1357     end = p + v3_ext.len;
1358     VerifyOrExit(mbedtls_asn1_get_tag(&p, end, &len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE) == 0,
1359                  error = kErrorParse);
1360     VerifyOrExit(end == p + len, error = kErrorParse);
1361 
1362     VerifyOrExit(aThreadOidDescriptor < 128, error = kErrorNotImplemented);
1363     oid[sizeof(oid) - 1] = static_cast<char>(aThreadOidDescriptor);
1364 
1365     while (p < end)
1366     {
1367         isCritical = 0;
1368         VerifyOrExit(mbedtls_asn1_get_tag(&p, end, &len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE) == 0,
1369                      error = kErrorParse);
1370         endExtData = p + len;
1371 
1372         // Get extension ID
1373         VerifyOrExit(mbedtls_asn1_get_tag(&p, endExtData, &extnOid.len, MBEDTLS_ASN1_OID) == 0, error = kErrorParse);
1374         extnOid.tag = MBEDTLS_ASN1_OID;
1375         extnOid.p   = p;
1376         p += extnOid.len;
1377 
1378         // Get optional critical
1379         ret = mbedtls_asn1_get_bool(&p, endExtData, &isCritical);
1380         VerifyOrExit(ret == 0 || ret == MBEDTLS_ERR_ASN1_UNEXPECTED_TAG, error = kErrorParse);
1381 
1382         // Data must be octet string type, see https://datatracker.ietf.org/doc/html/rfc5280#section-4.1
1383         VerifyOrExit(mbedtls_asn1_get_tag(&p, endExtData, &len, MBEDTLS_ASN1_OCTET_STRING) == 0, error = kErrorParse);
1384         VerifyOrExit(endExtData == p + len, error = kErrorParse);
1385 
1386         // TODO: extensions with isCritical == 1 that are unknown should lead to rejection of the entire cert.
1387         if (extnOid.len == sizeof(oid) && memcmp(extnOid.p, oid, sizeof(oid)) == 0)
1388         {
1389             // per RFC 5280, octet string must contain ASN.1 Type Length Value octets
1390             VerifyOrExit(len >= 2, error = kErrorParse);
1391             VerifyOrExit(*(p + 1) == len - 2, error = kErrorParse); // check TLV Length, not Type.
1392             *aAttributeLength = len - 2; // strip the ASN.1 Type Length bytes from embedded TLV
1393 
1394             if (aAttributeBuffer != nullptr)
1395             {
1396                 VerifyOrExit(*aAttributeLength <= attributeBufferSize, error = kErrorNoBufs);
1397                 memcpy(aAttributeBuffer, p + 2, *aAttributeLength);
1398             }
1399 
1400             error = kErrorNone;
1401             break;
1402         }
1403         p += len;
1404     }
1405 
1406 exit:
1407     return error;
1408 }
1409 
1410 #endif // OPENTHREAD_CONFIG_TLS_API_ENABLE
1411 
1412 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
1413 
1414 //---------------------------------------------------------------------------------------------------------------------
1415 // Tls
1416 
HandleAccept(void * aContext,const Ip6::MessageInfo & aMessageInfo)1417 SecureSession *Tls::HandleAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo)
1418 {
1419     OT_UNUSED_VARIABLE(aMessageInfo);
1420 
1421     return static_cast<Tls *>(aContext)->HandleAccept();
1422 }
1423 
HandleAccept(void)1424 SecureSession *Tls::HandleAccept(void) { return IsSessionInUse() ? nullptr : static_cast<SecureSession *>(this); }
1425 
1426 #endif
1427 
1428 } // namespace MeshCoP
1429 } // namespace ot
1430 
1431 #endif // OPENTHREAD_CONFIG_SECURE_TRANSPORT_ENABLE
1432