• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the necessary hooks for mbedTLS.
32  */
33 
34 #include "secure_transport.hpp"
35 
36 #include <mbedtls/debug.h>
37 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
38 #include <mbedtls/pem.h>
39 #endif
40 
41 #include <openthread/platform/radio.h>
42 
43 #include "common/as_core_type.hpp"
44 #include "common/clearable.hpp"
45 #include "common/code_utils.hpp"
46 #include "common/debug.hpp"
47 #include "common/encoding.hpp"
48 #include "common/locator_getters.hpp"
49 #include "common/log.hpp"
50 #include "common/timer.hpp"
51 #include "crypto/mbedtls.hpp"
52 #include "crypto/sha256.hpp"
53 #include "instance/instance.hpp"
54 #include "thread/thread_netif.hpp"
55 
56 #if OPENTHREAD_CONFIG_SECURE_TRANSPORT_ENABLE
57 
58 namespace ot {
59 namespace MeshCoP {
60 
61 RegisterLogModule("SecTransport");
62 
63 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
64 const uint16_t SecureTransport::sGroups[] = {MBEDTLS_SSL_IANA_TLS_GROUP_SECP256R1, MBEDTLS_SSL_IANA_TLS_GROUP_NONE};
65 #else
66 const mbedtls_ecp_group_id SecureTransport::sCurves[] = {MBEDTLS_ECP_DP_SECP256R1, MBEDTLS_ECP_DP_NONE};
67 #endif
68 
69 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
70 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
71 const uint16_t SecureTransport::sSignatures[] = {MBEDTLS_TLS1_3_SIG_ECDSA_SECP256R1_SHA256, MBEDTLS_TLS1_3_SIG_NONE};
72 #else
73 const int SecureTransport::sHashes[] = {MBEDTLS_MD_SHA256, MBEDTLS_MD_NONE};
74 #endif
75 #endif
76 
SecureTransport(Instance & aInstance,bool aLayerTwoSecurity,bool aDatagramTransport)77 SecureTransport::SecureTransport(Instance &aInstance, bool aLayerTwoSecurity, bool aDatagramTransport)
78     : InstanceLocator(aInstance)
79     , mState(kStateClosed)
80     , mPskLength(0)
81     , mVerifyPeerCertificate(true)
82     , mTimer(aInstance, SecureTransport::HandleTimer, this)
83     , mTimerIntermediate(0)
84     , mTimerSet(false)
85     , mLayerTwoSecurity(aLayerTwoSecurity)
86     , mDatagramTransport(aDatagramTransport)
87     , mMaxConnectionAttempts(0)
88     , mRemainingConnectionAttempts(0)
89     , mReceiveMessage(nullptr)
90     , mSocket(aInstance)
91     , mMessageSubType(Message::kSubTypeNone)
92     , mMessageDefaultSubType(Message::kSubTypeNone)
93 {
94 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
95 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
96     mPreSharedKey         = nullptr;
97     mPreSharedKeyIdentity = nullptr;
98     mPreSharedKeyIdLength = 0;
99     mPreSharedKeyLength   = 0;
100 #endif
101 
102 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
103     mCaChainSrc       = nullptr;
104     mCaChainLength    = 0;
105     mOwnCertSrc       = nullptr;
106     mOwnCertLength    = 0;
107     mPrivateKeySrc    = nullptr;
108     mPrivateKeyLength = 0;
109     ClearAllBytes(mCaChain);
110     ClearAllBytes(mOwnCert);
111     ClearAllBytes(mPrivateKey);
112 #endif
113 #endif
114 
115     ClearAllBytes(mCipherSuites);
116     ClearAllBytes(mPsk);
117     ClearAllBytes(mSsl);
118     ClearAllBytes(mConf);
119 
120 #ifdef MBEDTLS_SSL_COOKIE_C
121     ClearAllBytes(mCookieCtx);
122 #endif
123 }
124 
FreeMbedtls(void)125 void SecureTransport::FreeMbedtls(void)
126 {
127 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
128     if (mDatagramTransport)
129     {
130         mbedtls_ssl_cookie_free(&mCookieCtx);
131     }
132 #endif
133 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
134 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
135     mbedtls_x509_crt_free(&mCaChain);
136     mbedtls_x509_crt_free(&mOwnCert);
137     mbedtls_pk_free(&mPrivateKey);
138 #endif
139 #endif
140     mbedtls_ssl_config_free(&mConf);
141     mbedtls_ssl_free(&mSsl);
142 }
143 
SetState(State aState)144 void SecureTransport::SetState(State aState)
145 {
146     VerifyOrExit(mState != aState);
147 
148     LogInfo("State: %s -> %s", StateToString(mState), StateToString(aState));
149     mState = aState;
150 
151 exit:
152     return;
153 }
154 
Open(ReceiveHandler aReceiveHandler,ConnectedHandler aConnectedHandler,void * aContext)155 Error SecureTransport::Open(ReceiveHandler aReceiveHandler, ConnectedHandler aConnectedHandler, void *aContext)
156 {
157     Error error;
158 
159     VerifyOrExit(IsStateClosed(), error = kErrorAlready);
160 
161     SuccessOrExit(error = mSocket.Open(&SecureTransport::HandleReceive, this));
162 
163     mConnectedCallback.Set(aConnectedHandler, aContext);
164     mReceiveCallback.Set(aReceiveHandler, aContext);
165 
166     mRemainingConnectionAttempts = mMaxConnectionAttempts;
167 
168     SetState(kStateOpen);
169 
170 exit:
171     return error;
172 }
173 
SetMaxConnectionAttempts(uint16_t aMaxAttempts,AutoCloseCallback aCallback,void * aContext)174 Error SecureTransport::SetMaxConnectionAttempts(uint16_t aMaxAttempts, AutoCloseCallback aCallback, void *aContext)
175 {
176     Error error = kErrorNone;
177 
178     VerifyOrExit(IsStateClosed(), error = kErrorInvalidState);
179 
180     mMaxConnectionAttempts = aMaxAttempts;
181     mAutoCloseCallback.Set(aCallback, aContext);
182 
183 exit:
184     return error;
185 }
186 
Connect(const Ip6::SockAddr & aSockAddr)187 Error SecureTransport::Connect(const Ip6::SockAddr &aSockAddr)
188 {
189     Error error;
190 
191     VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
192 
193     if (mRemainingConnectionAttempts > 0)
194     {
195         mRemainingConnectionAttempts--;
196     }
197 
198     mMessageInfo.SetPeerAddr(aSockAddr.GetAddress());
199     mMessageInfo.SetPeerPort(aSockAddr.mPort);
200 
201     error = Setup(true);
202 
203 exit:
204     return error;
205 }
206 
HandleReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)207 void SecureTransport::HandleReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
208 {
209     static_cast<SecureTransport *>(aContext)->HandleReceive(AsCoreType(aMessage), AsCoreType(aMessageInfo));
210 }
211 
HandleReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)212 void SecureTransport::HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
213 {
214     VerifyOrExit(!IsStateClosed());
215 
216     if (IsStateOpen())
217     {
218         if (mRemainingConnectionAttempts > 0)
219         {
220             mRemainingConnectionAttempts--;
221         }
222 
223         mMessageInfo.SetPeerAddr(aMessageInfo.GetPeerAddr());
224         mMessageInfo.SetPeerPort(aMessageInfo.GetPeerPort());
225         mMessageInfo.SetIsHostInterface(aMessageInfo.IsHostInterface());
226 
227         mMessageInfo.SetSockAddr(aMessageInfo.GetSockAddr());
228         mMessageInfo.SetSockPort(aMessageInfo.GetSockPort());
229 
230         SuccessOrExit(Setup(false));
231     }
232     else
233     {
234         // Once DTLS session is started, communicate only with a peer.
235         VerifyOrExit((mMessageInfo.GetPeerAddr() == aMessageInfo.GetPeerAddr()) &&
236                      (mMessageInfo.GetPeerPort() == aMessageInfo.GetPeerPort()));
237     }
238 
239 #ifdef MBEDTLS_SSL_SRV_C
240     if (IsStateConnecting())
241     {
242         IgnoreError(SetClientId(mMessageInfo.GetPeerAddr().mFields.m8, sizeof(mMessageInfo.GetPeerAddr().mFields)));
243     }
244 #endif
245 
246     Receive(aMessage);
247 
248 exit:
249     return;
250 }
251 
GetUdpPort(void) const252 uint16_t SecureTransport::GetUdpPort(void) const { return mSocket.GetSockName().GetPort(); }
253 
Bind(uint16_t aPort)254 Error SecureTransport::Bind(uint16_t aPort)
255 {
256     Error error;
257 
258     VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
259     VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);
260 
261     SuccessOrExit(error = mSocket.Bind(aPort, Ip6::kNetifUnspecified));
262 
263 exit:
264     return error;
265 }
266 
Bind(TransportCallback aCallback,void * aContext)267 Error SecureTransport::Bind(TransportCallback aCallback, void *aContext)
268 {
269     Error error = kErrorNone;
270 
271     VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
272     VerifyOrExit(!mSocket.IsBound(), error = kErrorAlready);
273     VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);
274 
275     mTransportCallback.Set(aCallback, aContext);
276 
277 exit:
278     return error;
279 }
280 
Setup(bool aClient)281 Error SecureTransport::Setup(bool aClient)
282 {
283     int rval;
284 
285     // do not handle new connection before guard time expired
286     VerifyOrExit(IsStateOpen(), rval = MBEDTLS_ERR_SSL_TIMEOUT);
287 
288     SetState(kStateInitializing);
289 
290     mbedtls_ssl_init(&mSsl);
291     mbedtls_ssl_config_init(&mConf);
292 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
293 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
294     mbedtls_x509_crt_init(&mCaChain);
295     mbedtls_x509_crt_init(&mOwnCert);
296     mbedtls_pk_init(&mPrivateKey);
297 #endif
298 #endif
299 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
300     if (mDatagramTransport)
301     {
302         mbedtls_ssl_cookie_init(&mCookieCtx);
303     }
304 #endif
305 
306     rval = mbedtls_ssl_config_defaults(
307         &mConf, aClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
308         mDatagramTransport ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
309     VerifyOrExit(rval == 0);
310 
311 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
312     if (mVerifyPeerCertificate && (mCipherSuites[0] == MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ||
313                                    mCipherSuites[0] == MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256))
314     {
315         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
316     }
317     else
318     {
319         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_NONE);
320     }
321 #else
322     OT_UNUSED_VARIABLE(mVerifyPeerCertificate);
323 #endif
324 
325     mbedtls_ssl_conf_rng(&mConf, Crypto::MbedTls::CryptoSecurePrng, nullptr);
326 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
327     mbedtls_ssl_conf_min_tls_version(&mConf, MBEDTLS_SSL_VERSION_TLS1_2);
328     mbedtls_ssl_conf_max_tls_version(&mConf, MBEDTLS_SSL_VERSION_TLS1_2);
329 #else
330     mbedtls_ssl_conf_min_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
331     mbedtls_ssl_conf_max_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
332 #endif
333 
334     OT_ASSERT(mCipherSuites[1] == 0);
335     mbedtls_ssl_conf_ciphersuites(&mConf, mCipherSuites);
336     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
337     {
338 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
339         mbedtls_ssl_conf_groups(&mConf, sGroups);
340 #else
341         mbedtls_ssl_conf_curves(&mConf, sCurves);
342 #endif
343 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
344 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
345         mbedtls_ssl_conf_sig_algs(&mConf, sSignatures);
346 #else
347         mbedtls_ssl_conf_sig_hashes(&mConf, sHashes);
348 #endif
349 #endif
350     }
351 
352 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
353     mbedtls_ssl_set_export_keys_cb(&mSsl, HandleMbedtlsExportKeys, this);
354 #else
355     mbedtls_ssl_conf_export_keys_cb(&mConf, HandleMbedtlsExportKeys, this);
356 #endif
357 
358     mbedtls_ssl_conf_handshake_timeout(&mConf, 8000, 60000);
359     mbedtls_ssl_conf_dbg(&mConf, HandleMbedtlsDebug, this);
360 
361 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
362     if (!aClient && mDatagramTransport)
363     {
364         rval = mbedtls_ssl_cookie_setup(&mCookieCtx, Crypto::MbedTls::CryptoSecurePrng, nullptr);
365         VerifyOrExit(rval == 0);
366 
367         mbedtls_ssl_conf_dtls_cookies(&mConf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &mCookieCtx);
368     }
369 #endif
370 
371     rval = mbedtls_ssl_setup(&mSsl, &mConf);
372     VerifyOrExit(rval == 0);
373 
374     mbedtls_ssl_set_bio(&mSsl, this, &SecureTransport::HandleMbedtlsTransmit, HandleMbedtlsReceive, nullptr);
375 
376     if (mDatagramTransport)
377     {
378         mbedtls_ssl_set_timer_cb(&mSsl, this, &SecureTransport::HandleMbedtlsSetTimer, HandleMbedtlsGetTimer);
379     }
380 
381     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
382     {
383         rval = mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
384     }
385 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
386     else
387     {
388         rval = SetApplicationSecureKeys();
389     }
390 #endif
391     VerifyOrExit(rval == 0);
392 
393     mReceiveMessage = nullptr;
394     mMessageSubType = Message::kSubTypeNone;
395 
396     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
397     {
398         LogInfo("DTLS started");
399     }
400 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
401     else
402     {
403         LogInfo("Application Secure (D)TLS started");
404     }
405 #endif
406 
407     SetState(kStateConnecting);
408 
409     Process();
410 
411 exit:
412     if (IsStateInitializing() && (rval != 0))
413     {
414         if ((mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0))
415         {
416             Close();
417             mAutoCloseCallback.InvokeIfSet();
418         }
419         else
420         {
421             SetState(kStateOpen);
422             FreeMbedtls();
423         }
424     }
425 
426     return Crypto::MbedTls::MapError(rval);
427 }
428 
429 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
SetApplicationSecureKeys(void)430 int SecureTransport::SetApplicationSecureKeys(void)
431 {
432     int rval = 0;
433 
434     switch (mCipherSuites[0])
435     {
436     case MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
437     case MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
438 
439 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
440         if (mCaChainSrc != nullptr)
441         {
442             rval = mbedtls_x509_crt_parse(&mCaChain, static_cast<const unsigned char *>(mCaChainSrc),
443                                           static_cast<size_t>(mCaChainLength));
444             VerifyOrExit(rval == 0);
445             mbedtls_ssl_conf_ca_chain(&mConf, &mCaChain, nullptr);
446         }
447 
448         if (mOwnCertSrc != nullptr && mPrivateKeySrc != nullptr)
449         {
450             rval = mbedtls_x509_crt_parse(&mOwnCert, static_cast<const unsigned char *>(mOwnCertSrc),
451                                           static_cast<size_t>(mOwnCertLength));
452             VerifyOrExit(rval == 0);
453 
454 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
455             rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
456                                         static_cast<size_t>(mPrivateKeyLength), nullptr, 0,
457                                         Crypto::MbedTls::CryptoSecurePrng, nullptr);
458 #else
459             rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
460                                         static_cast<size_t>(mPrivateKeyLength), nullptr, 0);
461 #endif
462             VerifyOrExit(rval == 0);
463             rval = mbedtls_ssl_conf_own_cert(&mConf, &mOwnCert, &mPrivateKey);
464             VerifyOrExit(rval == 0);
465         }
466 #endif
467         break;
468 
469     case MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8:
470 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
471         rval = mbedtls_ssl_conf_psk(&mConf, static_cast<const unsigned char *>(mPreSharedKey), mPreSharedKeyLength,
472                                     static_cast<const unsigned char *>(mPreSharedKeyIdentity), mPreSharedKeyIdLength);
473         VerifyOrExit(rval == 0);
474 #endif
475         break;
476 
477     default:
478         LogCrit("Application Coap Secure: Not supported cipher.");
479         rval = MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
480         ExitNow();
481         break;
482     }
483 
484 exit:
485     return rval;
486 }
487 
488 #endif // OPENTHREAD_CONFIG_TLS_API_ENABLE
489 
Close(void)490 void SecureTransport::Close(void)
491 {
492     Disconnect();
493 
494     SetState(kStateClosed);
495     mTimerSet = false;
496     mTransportCallback.Clear();
497 
498     IgnoreError(mSocket.Close());
499     mTimer.Stop();
500 }
501 
Disconnect(void)502 void SecureTransport::Disconnect(void)
503 {
504     VerifyOrExit(IsStateConnectingOrConnected());
505 
506     mbedtls_ssl_close_notify(&mSsl);
507     SetState(kStateCloseNotify);
508     mTimer.Start(kGuardTimeNewConnectionMilli);
509 
510     mMessageInfo.Clear();
511 
512     FreeMbedtls();
513 
514 exit:
515     return;
516 }
517 
SetPsk(const uint8_t * aPsk,uint8_t aPskLength)518 Error SecureTransport::SetPsk(const uint8_t *aPsk, uint8_t aPskLength)
519 {
520     Error error = kErrorNone;
521 
522     VerifyOrExit(aPskLength <= sizeof(mPsk), error = kErrorInvalidArgs);
523 
524     memcpy(mPsk, aPsk, aPskLength);
525     mPskLength       = aPskLength;
526     mCipherSuites[0] = MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8;
527     mCipherSuites[1] = 0;
528 
529 exit:
530     return error;
531 }
532 
533 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
534 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
535 
SetCertificate(const uint8_t * aX509Certificate,uint32_t aX509CertLength,const uint8_t * aPrivateKey,uint32_t aPrivateKeyLength)536 void SecureTransport::SetCertificate(const uint8_t *aX509Certificate,
537                                      uint32_t       aX509CertLength,
538                                      const uint8_t *aPrivateKey,
539                                      uint32_t       aPrivateKeyLength)
540 {
541     OT_ASSERT(aX509CertLength > 0);
542     OT_ASSERT(aX509Certificate != nullptr);
543 
544     OT_ASSERT(aPrivateKeyLength > 0);
545     OT_ASSERT(aPrivateKey != nullptr);
546 
547     mOwnCertSrc       = aX509Certificate;
548     mOwnCertLength    = aX509CertLength;
549     mPrivateKeySrc    = aPrivateKey;
550     mPrivateKeyLength = aPrivateKeyLength;
551 
552     if (mDatagramTransport)
553     {
554         mCipherSuites[0] = MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8;
555     }
556     else
557     {
558         mCipherSuites[0] = MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256;
559     }
560 
561     mCipherSuites[1] = 0;
562 }
563 
SetCaCertificateChain(const uint8_t * aX509CaCertificateChain,uint32_t aX509CaCertChainLength)564 void SecureTransport::SetCaCertificateChain(const uint8_t *aX509CaCertificateChain, uint32_t aX509CaCertChainLength)
565 {
566     OT_ASSERT(aX509CaCertChainLength > 0);
567     OT_ASSERT(aX509CaCertificateChain != nullptr);
568 
569     mCaChainSrc    = aX509CaCertificateChain;
570     mCaChainLength = aX509CaCertChainLength;
571 }
572 
573 #endif // MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
574 
575 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
SetPreSharedKey(const uint8_t * aPsk,uint16_t aPskLength,const uint8_t * aPskIdentity,uint16_t aPskIdLength)576 void SecureTransport::SetPreSharedKey(const uint8_t *aPsk,
577                                       uint16_t       aPskLength,
578                                       const uint8_t *aPskIdentity,
579                                       uint16_t       aPskIdLength)
580 {
581     OT_ASSERT(aPsk != nullptr);
582     OT_ASSERT(aPskIdentity != nullptr);
583     OT_ASSERT(aPskLength > 0);
584     OT_ASSERT(aPskIdLength > 0);
585 
586     mPreSharedKey         = aPsk;
587     mPreSharedKeyLength   = aPskLength;
588     mPreSharedKeyIdentity = aPskIdentity;
589     mPreSharedKeyIdLength = aPskIdLength;
590 
591     mCipherSuites[0] = MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8;
592     mCipherSuites[1] = 0;
593 }
594 #endif
595 
596 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength,size_t aCertBufferSize)597 Error SecureTransport::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength, size_t aCertBufferSize)
598 {
599     Error error = kErrorNone;
600 
601     VerifyOrExit(IsStateConnected(), error = kErrorInvalidState);
602 
603 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
604     VerifyOrExit(mbedtls_base64_encode(aPeerCert, aCertBufferSize, aCertLength,
605                                        mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.p,
606                                        mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.len) == 0,
607                  error = kErrorNoBufs);
608 #else
609     VerifyOrExit(
610         mbedtls_base64_encode(
611             aPeerCert, aCertBufferSize, aCertLength,
612             mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(p),
613             mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(len)) == 0,
614         error = kErrorNoBufs);
615 #endif
616 
617 exit:
618     return error;
619 }
620 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
621 
622 #if defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerSubjectAttributeByOid(const char * aOid,size_t aOidLength,uint8_t * aAttributeBuffer,size_t * aAttributeLength,int * aAsn1Type)623 Error SecureTransport::GetPeerSubjectAttributeByOid(const char *aOid,
624                                                     size_t      aOidLength,
625                                                     uint8_t    *aAttributeBuffer,
626                                                     size_t     *aAttributeLength,
627                                                     int        *aAsn1Type)
628 {
629     Error                          error = kErrorNone;
630     const mbedtls_asn1_named_data *data;
631     size_t                         length;
632     size_t                         attributeBufferSize;
633     mbedtls_x509_crt              *peerCert = const_cast<mbedtls_x509_crt *>(mbedtls_ssl_get_peer_cert(&mSsl));
634 
635     VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs);
636     attributeBufferSize = *aAttributeLength;
637     *aAttributeLength   = 0;
638 
639     VerifyOrExit(aAttributeBuffer != nullptr, error = kErrorNoBufs);
640     VerifyOrExit(peerCert != nullptr, error = kErrorInvalidState);
641     data = mbedtls_asn1_find_named_data(&peerCert->subject, aOid, aOidLength);
642     VerifyOrExit(data != nullptr, error = kErrorNotFound);
643     length = data->val.len;
644     VerifyOrExit(length <= attributeBufferSize, error = kErrorNoBufs);
645 
646     if (aAttributeLength != nullptr)
647     {
648         *aAttributeLength = length;
649     }
650 
651     if (aAsn1Type != nullptr)
652     {
653         *aAsn1Type = data->val.tag;
654     }
655 
656     memcpy(aAttributeBuffer, data->val.p, length);
657 
658 exit:
659     return error;
660 }
661 
GetThreadAttributeFromPeerCertificate(int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)662 Error SecureTransport::GetThreadAttributeFromPeerCertificate(int      aThreadOidDescriptor,
663                                                              uint8_t *aAttributeBuffer,
664                                                              size_t  *aAttributeLength)
665 {
666     const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&mSsl);
667 
668     return GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength);
669 }
670 
671 #endif // defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
672 
GetThreadAttributeFromOwnCertificate(int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)673 Error SecureTransport::GetThreadAttributeFromOwnCertificate(int      aThreadOidDescriptor,
674                                                             uint8_t *aAttributeBuffer,
675                                                             size_t  *aAttributeLength)
676 {
677     const mbedtls_x509_crt *cert = &mOwnCert;
678 
679     return GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength);
680 }
681 
GetThreadAttributeFromCertificate(const mbedtls_x509_crt * aCert,int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)682 Error SecureTransport::GetThreadAttributeFromCertificate(const mbedtls_x509_crt *aCert,
683                                                          int                     aThreadOidDescriptor,
684                                                          uint8_t                *aAttributeBuffer,
685                                                          size_t                 *aAttributeLength)
686 {
687     Error            error  = kErrorNotFound;
688     char             oid[9] = {0x2B, 0x06, 0x01, 0x04, 0x01, static_cast<char>(0x82), static_cast<char>(0xDF),
689                                0x2A, 0x00}; // 1.3.6.1.4.1.44970.0
690     mbedtls_x509_buf v3_ext;
691     unsigned char   *p, *end, *endExtData;
692     size_t           len;
693     size_t           attributeBufferSize;
694     mbedtls_x509_buf extnOid;
695     int              ret, isCritical;
696 
697     VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs);
698     attributeBufferSize = *aAttributeLength;
699     *aAttributeLength   = 0;
700 
701     VerifyOrExit(aCert != nullptr, error = kErrorInvalidState);
702     v3_ext = aCert->v3_ext;
703     p      = v3_ext.p;
704     VerifyOrExit(p != nullptr, error = kErrorInvalidState);
705     end = p + v3_ext.len;
706     VerifyOrExit(mbedtls_asn1_get_tag(&p, end, &len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE) == 0,
707                  error = kErrorParse);
708     VerifyOrExit(end == p + len, error = kErrorParse);
709 
710     VerifyOrExit(aThreadOidDescriptor < 128, error = kErrorNotImplemented);
711     oid[sizeof(oid) - 1] = static_cast<char>(aThreadOidDescriptor);
712 
713     while (p < end)
714     {
715         isCritical = 0;
716         VerifyOrExit(mbedtls_asn1_get_tag(&p, end, &len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE) == 0,
717                      error = kErrorParse);
718         endExtData = p + len;
719 
720         // Get extension ID
721         VerifyOrExit(mbedtls_asn1_get_tag(&p, endExtData, &extnOid.len, MBEDTLS_ASN1_OID) == 0, error = kErrorParse);
722         extnOid.tag = MBEDTLS_ASN1_OID;
723         extnOid.p   = p;
724         p += extnOid.len;
725 
726         // Get optional critical
727         ret = mbedtls_asn1_get_bool(&p, endExtData, &isCritical);
728         VerifyOrExit(ret == 0 || ret == MBEDTLS_ERR_ASN1_UNEXPECTED_TAG, error = kErrorParse);
729 
730         // Data should be octet string type
731         VerifyOrExit(mbedtls_asn1_get_tag(&p, endExtData, &len, MBEDTLS_ASN1_OCTET_STRING) == 0, error = kErrorParse);
732         VerifyOrExit(endExtData == p + len, error = kErrorParse);
733 
734         if (isCritical || extnOid.len != sizeof(oid))
735         {
736             continue;
737         }
738 
739         if (memcmp(extnOid.p, oid, sizeof(oid)) == 0)
740         {
741             *aAttributeLength = len;
742 
743             if (aAttributeBuffer != nullptr)
744             {
745                 VerifyOrExit(len <= attributeBufferSize, error = kErrorNoBufs);
746                 memcpy(aAttributeBuffer, p, len);
747             }
748 
749             error = kErrorNone;
750             break;
751         }
752     }
753 
754 exit:
755     return error;
756 }
757 
758 #endif // OPENTHREAD_CONFIG_TLS_API_ENABLE
759 
760 #ifdef MBEDTLS_SSL_SRV_C
SetClientId(const uint8_t * aClientId,uint8_t aLength)761 Error SecureTransport::SetClientId(const uint8_t *aClientId, uint8_t aLength)
762 {
763     int rval = mbedtls_ssl_set_client_transport_id(&mSsl, aClientId, aLength);
764     return Crypto::MbedTls::MapError(rval);
765 }
766 #endif
767 
Send(Message & aMessage,uint16_t aLength)768 Error SecureTransport::Send(Message &aMessage, uint16_t aLength)
769 {
770     Error   error = kErrorNone;
771     uint8_t buffer[kApplicationDataMaxLength];
772 
773     VerifyOrExit(aLength <= kApplicationDataMaxLength, error = kErrorNoBufs);
774 
775     // Store message specific sub type.
776     if (aMessage.GetSubType() != Message::kSubTypeNone)
777     {
778         mMessageSubType = aMessage.GetSubType();
779     }
780 
781     aMessage.ReadBytes(0, buffer, aLength);
782 
783     SuccessOrExit(error = Crypto::MbedTls::MapError(mbedtls_ssl_write(&mSsl, buffer, aLength)));
784 
785     aMessage.Free();
786 
787 exit:
788     return error;
789 }
790 
Receive(Message & aMessage)791 void SecureTransport::Receive(Message &aMessage)
792 {
793     mReceiveMessage = &aMessage;
794 
795     Process();
796 
797     mReceiveMessage = nullptr;
798 }
799 
HandleMbedtlsTransmit(void * aContext,const unsigned char * aBuf,size_t aLength)800 int SecureTransport::HandleMbedtlsTransmit(void *aContext, const unsigned char *aBuf, size_t aLength)
801 {
802     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsTransmit(aBuf, aLength);
803 }
804 
HandleMbedtlsTransmit(const unsigned char * aBuf,size_t aLength)805 int SecureTransport::HandleMbedtlsTransmit(const unsigned char *aBuf, size_t aLength)
806 {
807     Error error;
808     int   rval = 0;
809 
810     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
811     {
812         LogDebg("HandleMbedtlsTransmit DTLS");
813     }
814 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
815     else
816     {
817         LogDebg("HandleMbedtlsTransmit TLS");
818     }
819 #endif
820 
821     error = HandleSecureTransportSend(aBuf, static_cast<uint16_t>(aLength), mMessageSubType);
822 
823     // Restore default sub type.
824     mMessageSubType = mMessageDefaultSubType;
825 
826     switch (error)
827     {
828     case kErrorNone:
829         rval = static_cast<int>(aLength);
830         break;
831 
832     case kErrorNoBufs:
833         rval = MBEDTLS_ERR_SSL_WANT_WRITE;
834         break;
835 
836     default:
837         LogWarn("HandleMbedtlsTransmit: %s error", ErrorToString(error));
838         rval = MBEDTLS_ERR_NET_SEND_FAILED;
839         break;
840     }
841 
842     return rval;
843 }
844 
HandleMbedtlsReceive(void * aContext,unsigned char * aBuf,size_t aLength)845 int SecureTransport::HandleMbedtlsReceive(void *aContext, unsigned char *aBuf, size_t aLength)
846 {
847     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsReceive(aBuf, aLength);
848 }
849 
HandleMbedtlsReceive(unsigned char * aBuf,size_t aLength)850 int SecureTransport::HandleMbedtlsReceive(unsigned char *aBuf, size_t aLength)
851 {
852     int rval;
853 
854     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
855     {
856         LogDebg("HandleMbedtlsReceive DTLS");
857     }
858 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
859     else
860     {
861         LogDebg("HandleMbedtlsReceive TLS");
862     }
863 #endif
864 
865     VerifyOrExit(mReceiveMessage != nullptr && (rval = mReceiveMessage->GetLength() - mReceiveMessage->GetOffset()) > 0,
866                  rval = MBEDTLS_ERR_SSL_WANT_READ);
867 
868     if (aLength > static_cast<size_t>(rval))
869     {
870         aLength = static_cast<size_t>(rval);
871     }
872 
873     rval = mReceiveMessage->ReadBytes(mReceiveMessage->GetOffset(), aBuf, static_cast<uint16_t>(aLength));
874     mReceiveMessage->MoveOffset(rval);
875 
876 exit:
877     return rval;
878 }
879 
HandleMbedtlsGetTimer(void * aContext)880 int SecureTransport::HandleMbedtlsGetTimer(void *aContext)
881 {
882     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsGetTimer();
883 }
884 
HandleMbedtlsGetTimer(void)885 int SecureTransport::HandleMbedtlsGetTimer(void)
886 {
887     int rval;
888 
889     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
890     {
891         LogDebg("HandleMbedtlsGetTimer");
892     }
893 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
894     else
895     {
896         LogDebg("HandleMbedtlsGetTimer");
897     }
898 #endif
899 
900     if (!mTimerSet)
901     {
902         rval = -1;
903     }
904     else if (!mTimer.IsRunning())
905     {
906         rval = 2;
907     }
908     else if (mTimerIntermediate <= TimerMilli::GetNow())
909     {
910         rval = 1;
911     }
912     else
913     {
914         rval = 0;
915     }
916 
917     return rval;
918 }
919 
HandleMbedtlsSetTimer(void * aContext,uint32_t aIntermediate,uint32_t aFinish)920 void SecureTransport::HandleMbedtlsSetTimer(void *aContext, uint32_t aIntermediate, uint32_t aFinish)
921 {
922     static_cast<SecureTransport *>(aContext)->HandleMbedtlsSetTimer(aIntermediate, aFinish);
923 }
924 
HandleMbedtlsSetTimer(uint32_t aIntermediate,uint32_t aFinish)925 void SecureTransport::HandleMbedtlsSetTimer(uint32_t aIntermediate, uint32_t aFinish)
926 {
927     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
928     {
929         LogDebg("SetTimer DTLS");
930     }
931 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
932     else
933     {
934         LogDebg("SetTimer TLS");
935     }
936 #endif
937 
938     if (aFinish == 0)
939     {
940         mTimerSet = false;
941         mTimer.Stop();
942     }
943     else
944     {
945         mTimerSet = true;
946         mTimer.Start(aFinish);
947         mTimerIntermediate = TimerMilli::GetNow() + aIntermediate;
948     }
949 }
950 
951 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
952 
HandleMbedtlsExportKeys(void * aContext,mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)953 void SecureTransport::HandleMbedtlsExportKeys(void                       *aContext,
954                                               mbedtls_ssl_key_export_type aType,
955                                               const unsigned char        *aMasterSecret,
956                                               size_t                      aMasterSecretLen,
957                                               const unsigned char         aClientRandom[32],
958                                               const unsigned char         aServerRandom[32],
959                                               mbedtls_tls_prf_types       aTlsPrfType)
960 {
961     static_cast<SecureTransport *>(aContext)->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen,
962                                                                       aClientRandom, aServerRandom, aTlsPrfType);
963 }
964 
HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)965 void SecureTransport::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
966                                               const unsigned char        *aMasterSecret,
967                                               size_t                      aMasterSecretLen,
968                                               const unsigned char         aClientRandom[32],
969                                               const unsigned char         aServerRandom[32],
970                                               mbedtls_tls_prf_types       aTlsPrfType)
971 {
972     Crypto::Sha256::Hash kek;
973     Crypto::Sha256       sha256;
974     unsigned char        keyBlock[kSecureTransportKeyBlockSize];
975     unsigned char        randBytes[2 * kSecureTransportRandomBufferSize];
976 
977     VerifyOrExit(mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8);
978     VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);
979 
980     memcpy(randBytes, aServerRandom, kSecureTransportRandomBufferSize);
981     memcpy(randBytes + kSecureTransportRandomBufferSize, aClientRandom, kSecureTransportRandomBufferSize);
982 
983     // Retrieve the Key block from Master secret
984     mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
985                         keyBlock, sizeof(keyBlock));
986 
987     sha256.Start();
988     sha256.Update(keyBlock, kSecureTransportKeyBlockSize);
989     sha256.Finish(kek);
990 
991     LogDebg("Generated KEK");
992     Get<KeyManager>().SetKek(kek.GetBytes());
993 
994 exit:
995     return;
996 }
997 
998 #else
999 
HandleMbedtlsExportKeys(void * aContext,const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)1000 int SecureTransport::HandleMbedtlsExportKeys(void                *aContext,
1001                                              const unsigned char *aMasterSecret,
1002                                              const unsigned char *aKeyBlock,
1003                                              size_t               aMacLength,
1004                                              size_t               aKeyLength,
1005                                              size_t               aIvLength)
1006 {
1007     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength,
1008                                                                              aKeyLength, aIvLength);
1009 }
1010 
HandleMbedtlsExportKeys(const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)1011 int SecureTransport::HandleMbedtlsExportKeys(const unsigned char *aMasterSecret,
1012                                              const unsigned char *aKeyBlock,
1013                                              size_t               aMacLength,
1014                                              size_t               aKeyLength,
1015                                              size_t               aIvLength)
1016 {
1017     OT_UNUSED_VARIABLE(aMasterSecret);
1018 
1019     Crypto::Sha256::Hash kek;
1020     Crypto::Sha256       sha256;
1021 
1022     VerifyOrExit(mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8);
1023 
1024     sha256.Start();
1025     sha256.Update(aKeyBlock, 2 * static_cast<uint16_t>(aMacLength + aKeyLength + aIvLength));
1026     sha256.Finish(kek);
1027 
1028     LogDebg("Generated KEK");
1029     Get<KeyManager>().SetKek(kek.GetBytes());
1030 
1031 exit:
1032     return 0;
1033 }
1034 
1035 #endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000)
1036 
HandleTimer(Timer & aTimer)1037 void SecureTransport::HandleTimer(Timer &aTimer)
1038 {
1039     static_cast<SecureTransport *>(static_cast<TimerMilliContext &>(aTimer).GetContext())->HandleTimer();
1040 }
1041 
HandleTimer(void)1042 void SecureTransport::HandleTimer(void)
1043 {
1044     if (IsStateConnectingOrConnected())
1045     {
1046         Process();
1047     }
1048     else if (IsStateCloseNotify())
1049     {
1050         if ((mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0))
1051         {
1052             Close();
1053             mConnectedCallback.InvokeIfSet(false);
1054             mAutoCloseCallback.InvokeIfSet();
1055         }
1056         else
1057         {
1058             SetState(kStateOpen);
1059             mTimer.Stop();
1060             mConnectedCallback.InvokeIfSet(false);
1061         }
1062     }
1063 }
1064 
Process(void)1065 void SecureTransport::Process(void)
1066 {
1067     uint8_t buf[OPENTHREAD_CONFIG_DTLS_MAX_CONTENT_LEN];
1068     bool    shouldDisconnect = false;
1069     int     rval;
1070 
1071     while (IsStateConnectingOrConnected())
1072     {
1073         if (IsStateConnecting())
1074         {
1075             rval = mbedtls_ssl_handshake(&mSsl);
1076 
1077             if (mSsl.MBEDTLS_PRIVATE(state) == MBEDTLS_SSL_HANDSHAKE_OVER)
1078             {
1079                 SetState(kStateConnected);
1080                 mConnectedCallback.InvokeIfSet(true);
1081             }
1082         }
1083         else
1084         {
1085             rval = mbedtls_ssl_read(&mSsl, buf, sizeof(buf));
1086         }
1087 
1088         if (rval > 0)
1089         {
1090             mReceiveCallback.InvokeIfSet(buf, static_cast<uint16_t>(rval));
1091         }
1092         else if (rval == 0 || rval == MBEDTLS_ERR_SSL_WANT_READ || rval == MBEDTLS_ERR_SSL_WANT_WRITE)
1093         {
1094             break;
1095         }
1096         else
1097         {
1098             switch (rval)
1099             {
1100             case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
1101                 mbedtls_ssl_close_notify(&mSsl);
1102                 ExitNow(shouldDisconnect = true);
1103                 OT_UNREACHABLE_CODE(break);
1104 
1105             case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
1106                 break;
1107 
1108             case MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
1109                 mbedtls_ssl_close_notify(&mSsl);
1110                 ExitNow(shouldDisconnect = true);
1111                 OT_UNREACHABLE_CODE(break);
1112 
1113             case MBEDTLS_ERR_SSL_INVALID_MAC:
1114                 if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
1115                 {
1116                     mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
1117                                                    MBEDTLS_SSL_ALERT_MSG_BAD_RECORD_MAC);
1118                     ExitNow(shouldDisconnect = true);
1119                 }
1120 
1121                 break;
1122 
1123             default:
1124                 if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
1125                 {
1126                     mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
1127                                                    MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
1128                     ExitNow(shouldDisconnect = true);
1129                 }
1130 
1131                 break;
1132             }
1133 
1134             mbedtls_ssl_session_reset(&mSsl);
1135             if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
1136             {
1137                 mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
1138             }
1139             break;
1140         }
1141     }
1142 
1143 exit:
1144 
1145     if (shouldDisconnect)
1146     {
1147         Disconnect();
1148     }
1149 }
1150 
HandleMbedtlsDebug(void * aContext,int aLevel,const char * aFile,int aLine,const char * aStr)1151 void SecureTransport::HandleMbedtlsDebug(void *aContext, int aLevel, const char *aFile, int aLine, const char *aStr)
1152 {
1153     static_cast<SecureTransport *>(aContext)->HandleMbedtlsDebug(aLevel, aFile, aLine, aStr);
1154 }
1155 
HandleMbedtlsDebug(int aLevel,const char * aFile,int aLine,const char * aStr)1156 void SecureTransport::HandleMbedtlsDebug(int aLevel, const char *aFile, int aLine, const char *aStr)
1157 {
1158     OT_UNUSED_VARIABLE(aStr);
1159     OT_UNUSED_VARIABLE(aFile);
1160     OT_UNUSED_VARIABLE(aLine);
1161 
1162     switch (aLevel)
1163     {
1164     case 1:
1165         LogCrit("[%u] %s", mSocket.GetSockName().mPort, aStr);
1166         break;
1167 
1168     case 2:
1169         LogWarn("[%u] %s", mSocket.GetSockName().mPort, aStr);
1170         break;
1171 
1172     case 3:
1173         LogInfo("[%u] %s", mSocket.GetSockName().mPort, aStr);
1174         break;
1175 
1176     case 4:
1177     default:
1178         LogDebg("[%u] %s", mSocket.GetSockName().mPort, aStr);
1179         break;
1180     }
1181 }
1182 
HandleSecureTransportSend(const uint8_t * aBuf,uint16_t aLength,Message::SubType aMessageSubType)1183 Error SecureTransport::HandleSecureTransportSend(const uint8_t   *aBuf,
1184                                                  uint16_t         aLength,
1185                                                  Message::SubType aMessageSubType)
1186 {
1187     Error        error   = kErrorNone;
1188     ot::Message *message = nullptr;
1189 
1190     VerifyOrExit((message = mSocket.NewMessage()) != nullptr, error = kErrorNoBufs);
1191     message->SetSubType(aMessageSubType);
1192     message->SetLinkSecurityEnabled(mLayerTwoSecurity);
1193 
1194     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
1195 
1196     // Set message sub type in case Joiner Finalize Response is appended to the message.
1197     if (aMessageSubType != Message::kSubTypeNone)
1198     {
1199         message->SetSubType(aMessageSubType);
1200     }
1201 
1202     if (mTransportCallback.IsSet())
1203     {
1204         SuccessOrExit(error = mTransportCallback.Invoke(*message, mMessageInfo));
1205     }
1206     else
1207     {
1208         SuccessOrExit(error = mSocket.SendTo(*message, mMessageInfo));
1209     }
1210 
1211 exit:
1212     FreeMessageOnError(message, error);
1213     return error;
1214 }
1215 
1216 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
1217 
StateToString(State aState)1218 const char *SecureTransport::StateToString(State aState)
1219 {
1220     static const char *const kStateStrings[] = {
1221         "Closed",       // (0) kStateClosed
1222         "Open",         // (1) kStateOpen
1223         "Initializing", // (2) kStateInitializing
1224         "Connecting",   // (3) kStateConnecting
1225         "Connected",    // (4) kStateConnected
1226         "CloseNotify",  // (5) kStateCloseNotify
1227     };
1228 
1229     static_assert(0 == kStateClosed, "kStateClosed valid is incorrect");
1230     static_assert(1 == kStateOpen, "kStateOpen valid is incorrect");
1231     static_assert(2 == kStateInitializing, "kStateInitializing valid is incorrect");
1232     static_assert(3 == kStateConnecting, "kStateConnecting valid is incorrect");
1233     static_assert(4 == kStateConnected, "kStateConnected valid is incorrect");
1234     static_assert(5 == kStateCloseNotify, "kStateCloseNotify valid is incorrect");
1235 
1236     return kStateStrings[aState];
1237 }
1238 
1239 #endif
1240 
1241 } // namespace MeshCoP
1242 } // namespace ot
1243 
1244 #endif // OPENTHREAD_CONFIG_SECURE_TRANSPORT_ENABLE
1245