• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the necessary hooks for mbedTLS.
32  */
33 
34 #include "dtls.hpp"
35 
36 #include <mbedtls/debug.h>
37 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
38 #include <mbedtls/pem.h>
39 #endif
40 
41 #include <openthread/platform/radio.h>
42 
43 #include "common/as_core_type.hpp"
44 #include "common/code_utils.hpp"
45 #include "common/debug.hpp"
46 #include "common/encoding.hpp"
47 #include "common/instance.hpp"
48 #include "common/locator_getters.hpp"
49 #include "common/log.hpp"
50 #include "common/timer.hpp"
51 #include "crypto/mbedtls.hpp"
52 #include "crypto/sha256.hpp"
53 #include "thread/thread_netif.hpp"
54 
55 #if OPENTHREAD_CONFIG_DTLS_ENABLE
56 
57 namespace ot {
58 namespace MeshCoP {
59 
60 RegisterLogModule("Dtls");
61 
62 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
63 const uint16_t Dtls::sGroups[] = {MBEDTLS_SSL_IANA_TLS_GROUP_SECP256R1, MBEDTLS_SSL_IANA_TLS_GROUP_NONE};
64 #else
65 const mbedtls_ecp_group_id Dtls::sCurves[] = {MBEDTLS_ECP_DP_SECP256R1, MBEDTLS_ECP_DP_NONE};
66 #endif
67 
68 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
69 const int Dtls::sHashes[] = {MBEDTLS_MD_SHA256, MBEDTLS_MD_NONE};
70 #endif
71 
Dtls(Instance & aInstance,bool aLayerTwoSecurity)72 Dtls::Dtls(Instance &aInstance, bool aLayerTwoSecurity)
73     : InstanceLocator(aInstance)
74     , mState(kStateClosed)
75     , mPskLength(0)
76     , mVerifyPeerCertificate(true)
77     , mTimer(aInstance, Dtls::HandleTimer, this)
78     , mTimerIntermediate(0)
79     , mTimerSet(false)
80     , mLayerTwoSecurity(aLayerTwoSecurity)
81     , mReceiveMessage(nullptr)
82     , mConnectedHandler(nullptr)
83     , mReceiveHandler(nullptr)
84     , mContext(nullptr)
85     , mSocket(aInstance)
86     , mTransportCallback(nullptr)
87     , mTransportContext(nullptr)
88     , mMessageSubType(Message::kSubTypeNone)
89     , mMessageDefaultSubType(Message::kSubTypeNone)
90 {
91 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
92 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
93     mPreSharedKey         = nullptr;
94     mPreSharedKeyIdentity = nullptr;
95     mPreSharedKeyIdLength = 0;
96     mPreSharedKeyLength   = 0;
97 #endif
98 
99 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
100     mCaChainSrc       = nullptr;
101     mCaChainLength    = 0;
102     mOwnCertSrc       = nullptr;
103     mOwnCertLength    = 0;
104     mPrivateKeySrc    = nullptr;
105     mPrivateKeyLength = 0;
106     memset(&mCaChain, 0, sizeof(mCaChain));
107     memset(&mOwnCert, 0, sizeof(mOwnCert));
108     memset(&mPrivateKey, 0, sizeof(mPrivateKey));
109 #endif
110 #endif
111 
112     memset(mCipherSuites, 0, sizeof(mCipherSuites));
113     memset(mPsk, 0, sizeof(mPsk));
114     memset(&mSsl, 0, sizeof(mSsl));
115     memset(&mConf, 0, sizeof(mConf));
116 
117 #ifdef MBEDTLS_SSL_COOKIE_C
118     memset(&mCookieCtx, 0, sizeof(mCookieCtx));
119 #endif
120 }
121 
FreeMbedtls(void)122 void Dtls::FreeMbedtls(void)
123 {
124 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
125     mbedtls_ssl_cookie_free(&mCookieCtx);
126 #endif
127 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
128 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
129     mbedtls_x509_crt_free(&mCaChain);
130     mbedtls_x509_crt_free(&mOwnCert);
131     mbedtls_pk_free(&mPrivateKey);
132 #endif
133 #endif
134     mbedtls_ssl_config_free(&mConf);
135     mbedtls_ssl_free(&mSsl);
136 }
137 
Open(ReceiveHandler aReceiveHandler,ConnectedHandler aConnectedHandler,void * aContext)138 Error Dtls::Open(ReceiveHandler aReceiveHandler, ConnectedHandler aConnectedHandler, void *aContext)
139 {
140     Error error;
141 
142     VerifyOrExit(mState == kStateClosed, error = kErrorAlready);
143 
144     SuccessOrExit(error = mSocket.Open(&Dtls::HandleUdpReceive, this));
145 
146     mReceiveHandler   = aReceiveHandler;
147     mConnectedHandler = aConnectedHandler;
148     mContext          = aContext;
149     mState            = kStateOpen;
150 
151 exit:
152     return error;
153 }
154 
Connect(const Ip6::SockAddr & aSockAddr)155 Error Dtls::Connect(const Ip6::SockAddr &aSockAddr)
156 {
157     Error error;
158 
159     VerifyOrExit(mState == kStateOpen, error = kErrorInvalidState);
160 
161     mMessageInfo.SetPeerAddr(aSockAddr.GetAddress());
162     mMessageInfo.SetPeerPort(aSockAddr.mPort);
163 
164     error = Setup(true);
165 
166 exit:
167     return error;
168 }
169 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)170 void Dtls::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
171 {
172     static_cast<Dtls *>(aContext)->HandleUdpReceive(AsCoreType(aMessage), AsCoreType(aMessageInfo));
173 }
174 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)175 void Dtls::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
176 {
177     switch (mState)
178     {
179     case Dtls::kStateClosed:
180         ExitNow();
181 
182     case Dtls::kStateOpen:
183         IgnoreError(mSocket.Connect(Ip6::SockAddr(aMessageInfo.GetPeerAddr(), aMessageInfo.GetPeerPort())));
184 
185         mMessageInfo.SetPeerAddr(aMessageInfo.GetPeerAddr());
186         mMessageInfo.SetPeerPort(aMessageInfo.GetPeerPort());
187         mMessageInfo.SetIsHostInterface(aMessageInfo.IsHostInterface());
188 
189         if (Get<ThreadNetif>().HasUnicastAddress(aMessageInfo.GetSockAddr()))
190         {
191             mMessageInfo.SetSockAddr(aMessageInfo.GetSockAddr());
192         }
193 
194         mMessageInfo.SetSockPort(aMessageInfo.GetSockPort());
195 
196         SuccessOrExit(Setup(false));
197         break;
198 
199     default:
200         // Once DTLS session is started, communicate only with a peer.
201         VerifyOrExit((mMessageInfo.GetPeerAddr() == aMessageInfo.GetPeerAddr()) &&
202                      (mMessageInfo.GetPeerPort() == aMessageInfo.GetPeerPort()));
203         break;
204     }
205 
206 #ifdef MBEDTLS_SSL_SRV_C
207     if (mState == Dtls::kStateConnecting)
208     {
209         IgnoreError(SetClientId(mMessageInfo.GetPeerAddr().mFields.m8, sizeof(mMessageInfo.GetPeerAddr().mFields)));
210     }
211 #endif
212 
213     Receive(aMessage);
214 
215 exit:
216     return;
217 }
218 
GetUdpPort(void) const219 uint16_t Dtls::GetUdpPort(void) const
220 {
221     return mSocket.GetSockName().GetPort();
222 }
223 
Bind(uint16_t aPort)224 Error Dtls::Bind(uint16_t aPort)
225 {
226     Error error;
227 
228     VerifyOrExit(mState == kStateOpen, error = kErrorInvalidState);
229     VerifyOrExit(mTransportCallback == nullptr, error = kErrorAlready);
230 
231     SuccessOrExit(error = mSocket.Bind(aPort, OT_NETIF_UNSPECIFIED));
232 
233 exit:
234     return error;
235 }
236 
Bind(TransportCallback aCallback,void * aContext)237 Error Dtls::Bind(TransportCallback aCallback, void *aContext)
238 {
239     Error error = kErrorNone;
240 
241     VerifyOrExit(mState == kStateOpen, error = kErrorInvalidState);
242     VerifyOrExit(!mSocket.IsBound(), error = kErrorAlready);
243     VerifyOrExit(mTransportCallback == nullptr, error = kErrorAlready);
244 
245     mTransportCallback = aCallback;
246     mTransportContext  = aContext;
247 
248 exit:
249     return error;
250 }
251 
Setup(bool aClient)252 Error Dtls::Setup(bool aClient)
253 {
254     int rval;
255 
256     // do not handle new connection before guard time expired
257     VerifyOrExit(mState == kStateOpen, rval = MBEDTLS_ERR_SSL_TIMEOUT);
258 
259     mState = kStateInitializing;
260 
261     mbedtls_ssl_init(&mSsl);
262     mbedtls_ssl_config_init(&mConf);
263 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
264 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
265     mbedtls_x509_crt_init(&mCaChain);
266     mbedtls_x509_crt_init(&mOwnCert);
267     mbedtls_pk_init(&mPrivateKey);
268 #endif
269 #endif
270 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
271     mbedtls_ssl_cookie_init(&mCookieCtx);
272 #endif
273 
274     rval = mbedtls_ssl_config_defaults(&mConf, aClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
275                                        MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT);
276     VerifyOrExit(rval == 0);
277 
278 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
279     if (mVerifyPeerCertificate && mCipherSuites[0] == MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8)
280     {
281         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
282     }
283     else
284     {
285         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_NONE);
286     }
287 #else
288     OT_UNUSED_VARIABLE(mVerifyPeerCertificate);
289 #endif
290 
291     mbedtls_ssl_conf_rng(&mConf, Crypto::MbedTls::CryptoSecurePrng, nullptr);
292     mbedtls_ssl_conf_min_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
293     mbedtls_ssl_conf_max_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
294 
295     OT_ASSERT(mCipherSuites[1] == 0);
296     mbedtls_ssl_conf_ciphersuites(&mConf, mCipherSuites);
297     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
298     {
299 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
300         mbedtls_ssl_conf_groups(&mConf, sGroups);
301 #else
302         mbedtls_ssl_conf_curves(&mConf, sCurves);
303 #endif
304 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
305         mbedtls_ssl_conf_sig_hashes(&mConf, sHashes);
306 #endif
307     }
308 
309 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
310     mbedtls_ssl_set_export_keys_cb(&mSsl, HandleMbedtlsExportKeys, this);
311 #else
312     mbedtls_ssl_conf_export_keys_cb(&mConf, HandleMbedtlsExportKeys, this);
313 #endif
314 
315     mbedtls_ssl_conf_handshake_timeout(&mConf, 8000, 60000);
316     mbedtls_ssl_conf_dbg(&mConf, HandleMbedtlsDebug, this);
317 
318 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
319     if (!aClient)
320     {
321         rval = mbedtls_ssl_cookie_setup(&mCookieCtx, Crypto::MbedTls::CryptoSecurePrng, nullptr);
322         VerifyOrExit(rval == 0);
323 
324         mbedtls_ssl_conf_dtls_cookies(&mConf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &mCookieCtx);
325     }
326 #endif
327 
328     rval = mbedtls_ssl_setup(&mSsl, &mConf);
329     VerifyOrExit(rval == 0);
330 
331     mbedtls_ssl_set_bio(&mSsl, this, &Dtls::HandleMbedtlsTransmit, HandleMbedtlsReceive, nullptr);
332     mbedtls_ssl_set_timer_cb(&mSsl, this, &Dtls::HandleMbedtlsSetTimer, HandleMbedtlsGetTimer);
333 
334     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
335     {
336         rval = mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
337     }
338 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
339     else
340     {
341         rval = SetApplicationCoapSecureKeys();
342     }
343 #endif
344     VerifyOrExit(rval == 0);
345 
346     mReceiveMessage = nullptr;
347     mMessageSubType = Message::kSubTypeNone;
348     mState          = kStateConnecting;
349 
350     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
351     {
352         LogInfo("DTLS started");
353     }
354 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
355     else
356     {
357         LogInfo("Application Coap Secure DTLS started");
358     }
359 #endif
360 
361     mState = kStateConnecting;
362 
363     Process();
364 
365 exit:
366     if ((mState == kStateInitializing) && (rval != 0))
367     {
368         mState = kStateOpen;
369         FreeMbedtls();
370     }
371 
372     return Crypto::MbedTls::MapError(rval);
373 }
374 
375 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
SetApplicationCoapSecureKeys(void)376 int Dtls::SetApplicationCoapSecureKeys(void)
377 {
378     int rval = 0;
379 
380     switch (mCipherSuites[0])
381     {
382     case MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
383 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
384         if (mCaChainSrc != nullptr)
385         {
386             rval = mbedtls_x509_crt_parse(&mCaChain, static_cast<const unsigned char *>(mCaChainSrc),
387                                           static_cast<size_t>(mCaChainLength));
388             VerifyOrExit(rval == 0);
389             mbedtls_ssl_conf_ca_chain(&mConf, &mCaChain, nullptr);
390         }
391 
392         if (mOwnCertSrc != nullptr && mPrivateKeySrc != nullptr)
393         {
394             rval = mbedtls_x509_crt_parse(&mOwnCert, static_cast<const unsigned char *>(mOwnCertSrc),
395                                           static_cast<size_t>(mOwnCertLength));
396             VerifyOrExit(rval == 0);
397 
398 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
399             rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
400                                         static_cast<size_t>(mPrivateKeyLength), nullptr, 0,
401                                         Crypto::MbedTls::CryptoSecurePrng, nullptr);
402 #else
403             rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
404                                         static_cast<size_t>(mPrivateKeyLength), nullptr, 0);
405 #endif
406             VerifyOrExit(rval == 0);
407             rval = mbedtls_ssl_conf_own_cert(&mConf, &mOwnCert, &mPrivateKey);
408             VerifyOrExit(rval == 0);
409         }
410 #endif
411         break;
412 
413     case MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8:
414 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
415         rval = mbedtls_ssl_conf_psk(&mConf, static_cast<const unsigned char *>(mPreSharedKey), mPreSharedKeyLength,
416                                     static_cast<const unsigned char *>(mPreSharedKeyIdentity), mPreSharedKeyIdLength);
417         VerifyOrExit(rval == 0);
418 #endif
419         break;
420 
421     default:
422         LogCrit("Application Coap Secure: Not supported cipher.");
423         rval = MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
424         ExitNow();
425         break;
426     }
427 
428 exit:
429     return rval;
430 }
431 
432 #endif // OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
433 
Close(void)434 void Dtls::Close(void)
435 {
436     Disconnect();
437 
438     mState             = kStateClosed;
439     mTransportCallback = nullptr;
440     mTransportContext  = nullptr;
441     mTimerSet          = false;
442 
443     IgnoreError(mSocket.Close());
444     mTimer.Stop();
445 }
446 
Disconnect(void)447 void Dtls::Disconnect(void)
448 {
449     VerifyOrExit(mState == kStateConnecting || mState == kStateConnected);
450 
451     mbedtls_ssl_close_notify(&mSsl);
452     mState = kStateCloseNotify;
453     mTimer.Start(kGuardTimeNewConnectionMilli);
454 
455     mMessageInfo.Clear();
456     IgnoreError(mSocket.Connect());
457 
458     FreeMbedtls();
459 
460 exit:
461     return;
462 }
463 
SetPsk(const uint8_t * aPsk,uint8_t aPskLength)464 Error Dtls::SetPsk(const uint8_t *aPsk, uint8_t aPskLength)
465 {
466     Error error = kErrorNone;
467 
468     VerifyOrExit(aPskLength <= sizeof(mPsk), error = kErrorInvalidArgs);
469 
470     memcpy(mPsk, aPsk, aPskLength);
471     mPskLength       = aPskLength;
472     mCipherSuites[0] = MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8;
473     mCipherSuites[1] = 0;
474 
475 exit:
476     return error;
477 }
478 
479 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
480 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
481 
SetCertificate(const uint8_t * aX509Certificate,uint32_t aX509CertLength,const uint8_t * aPrivateKey,uint32_t aPrivateKeyLength)482 void Dtls::SetCertificate(const uint8_t *aX509Certificate,
483                           uint32_t       aX509CertLength,
484                           const uint8_t *aPrivateKey,
485                           uint32_t       aPrivateKeyLength)
486 {
487     OT_ASSERT(aX509CertLength > 0);
488     OT_ASSERT(aX509Certificate != nullptr);
489 
490     OT_ASSERT(aPrivateKeyLength > 0);
491     OT_ASSERT(aPrivateKey != nullptr);
492 
493     mOwnCertSrc       = aX509Certificate;
494     mOwnCertLength    = aX509CertLength;
495     mPrivateKeySrc    = aPrivateKey;
496     mPrivateKeyLength = aPrivateKeyLength;
497 
498     mCipherSuites[0] = MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8;
499     mCipherSuites[1] = 0;
500 }
501 
SetCaCertificateChain(const uint8_t * aX509CaCertificateChain,uint32_t aX509CaCertChainLength)502 void Dtls::SetCaCertificateChain(const uint8_t *aX509CaCertificateChain, uint32_t aX509CaCertChainLength)
503 {
504     OT_ASSERT(aX509CaCertChainLength > 0);
505     OT_ASSERT(aX509CaCertificateChain != nullptr);
506 
507     mCaChainSrc    = aX509CaCertificateChain;
508     mCaChainLength = aX509CaCertChainLength;
509 }
510 
511 #endif // MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
512 
513 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
SetPreSharedKey(const uint8_t * aPsk,uint16_t aPskLength,const uint8_t * aPskIdentity,uint16_t aPskIdLength)514 void Dtls::SetPreSharedKey(const uint8_t *aPsk, uint16_t aPskLength, const uint8_t *aPskIdentity, uint16_t aPskIdLength)
515 {
516     OT_ASSERT(aPsk != nullptr);
517     OT_ASSERT(aPskIdentity != nullptr);
518     OT_ASSERT(aPskLength > 0);
519     OT_ASSERT(aPskIdLength > 0);
520 
521     mPreSharedKey         = aPsk;
522     mPreSharedKeyLength   = aPskLength;
523     mPreSharedKeyIdentity = aPskIdentity;
524     mPreSharedKeyIdLength = aPskIdLength;
525 
526     mCipherSuites[0] = MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8;
527     mCipherSuites[1] = 0;
528 }
529 #endif
530 
531 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength,size_t aCertBufferSize)532 Error Dtls::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength, size_t aCertBufferSize)
533 {
534     Error error = kErrorNone;
535 
536     VerifyOrExit(mState == kStateConnected, error = kErrorInvalidState);
537 
538 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
539     VerifyOrExit(mbedtls_base64_encode(aPeerCert, aCertBufferSize, aCertLength,
540                                        mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.p,
541                                        mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.len) == 0,
542                  error = kErrorNoBufs);
543 #else
544     VerifyOrExit(
545         mbedtls_base64_encode(
546             aPeerCert, aCertBufferSize, aCertLength,
547             mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(p),
548             mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(len)) == 0,
549         error = kErrorNoBufs);
550 #endif
551 
552 exit:
553     return error;
554 }
555 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
556 
557 #endif // OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
558 
559 #ifdef MBEDTLS_SSL_SRV_C
SetClientId(const uint8_t * aClientId,uint8_t aLength)560 Error Dtls::SetClientId(const uint8_t *aClientId, uint8_t aLength)
561 {
562     int rval = mbedtls_ssl_set_client_transport_id(&mSsl, aClientId, aLength);
563     return Crypto::MbedTls::MapError(rval);
564 }
565 #endif
566 
Send(Message & aMessage,uint16_t aLength)567 Error Dtls::Send(Message &aMessage, uint16_t aLength)
568 {
569     Error   error = kErrorNone;
570     uint8_t buffer[kApplicationDataMaxLength];
571 
572     VerifyOrExit(aLength <= kApplicationDataMaxLength, error = kErrorNoBufs);
573 
574     // Store message specific sub type.
575     if (aMessage.GetSubType() != Message::kSubTypeNone)
576     {
577         mMessageSubType = aMessage.GetSubType();
578     }
579 
580     aMessage.ReadBytes(0, buffer, aLength);
581 
582     SuccessOrExit(error = Crypto::MbedTls::MapError(mbedtls_ssl_write(&mSsl, buffer, aLength)));
583 
584     aMessage.Free();
585 
586 exit:
587     return error;
588 }
589 
Receive(Message & aMessage)590 void Dtls::Receive(Message &aMessage)
591 {
592     mReceiveMessage = &aMessage;
593 
594     Process();
595 
596     mReceiveMessage = nullptr;
597 }
598 
HandleMbedtlsTransmit(void * aContext,const unsigned char * aBuf,size_t aLength)599 int Dtls::HandleMbedtlsTransmit(void *aContext, const unsigned char *aBuf, size_t aLength)
600 {
601     return static_cast<Dtls *>(aContext)->HandleMbedtlsTransmit(aBuf, aLength);
602 }
603 
HandleMbedtlsTransmit(const unsigned char * aBuf,size_t aLength)604 int Dtls::HandleMbedtlsTransmit(const unsigned char *aBuf, size_t aLength)
605 {
606     Error error;
607     int   rval = 0;
608 
609     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
610     {
611         LogDebg("HandleMbedtlsTransmit");
612     }
613 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
614     else
615     {
616         LogDebg("ApplicationCoapSecure HandleMbedtlsTransmit");
617     }
618 #endif
619 
620     error = HandleDtlsSend(aBuf, static_cast<uint16_t>(aLength), mMessageSubType);
621 
622     // Restore default sub type.
623     mMessageSubType = mMessageDefaultSubType;
624 
625     switch (error)
626     {
627     case kErrorNone:
628         rval = static_cast<int>(aLength);
629         break;
630 
631     case kErrorNoBufs:
632         rval = MBEDTLS_ERR_SSL_WANT_WRITE;
633         break;
634 
635     default:
636         LogWarn("HandleMbedtlsTransmit: %s error", ErrorToString(error));
637         rval = MBEDTLS_ERR_NET_SEND_FAILED;
638         break;
639     }
640 
641     return rval;
642 }
643 
HandleMbedtlsReceive(void * aContext,unsigned char * aBuf,size_t aLength)644 int Dtls::HandleMbedtlsReceive(void *aContext, unsigned char *aBuf, size_t aLength)
645 {
646     return static_cast<Dtls *>(aContext)->HandleMbedtlsReceive(aBuf, aLength);
647 }
648 
HandleMbedtlsReceive(unsigned char * aBuf,size_t aLength)649 int Dtls::HandleMbedtlsReceive(unsigned char *aBuf, size_t aLength)
650 {
651     int rval;
652 
653     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
654     {
655         LogDebg("HandleMbedtlsReceive");
656     }
657 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
658     else
659     {
660         LogDebg("ApplicationCoapSecure HandleMbedtlsReceive");
661     }
662 #endif
663 
664     VerifyOrExit(mReceiveMessage != nullptr && (rval = mReceiveMessage->GetLength() - mReceiveMessage->GetOffset()) > 0,
665                  rval = MBEDTLS_ERR_SSL_WANT_READ);
666 
667     if (aLength > static_cast<size_t>(rval))
668     {
669         aLength = static_cast<size_t>(rval);
670     }
671 
672     rval = mReceiveMessage->ReadBytes(mReceiveMessage->GetOffset(), aBuf, static_cast<uint16_t>(aLength));
673     mReceiveMessage->MoveOffset(rval);
674 
675 exit:
676     return rval;
677 }
678 
HandleMbedtlsGetTimer(void * aContext)679 int Dtls::HandleMbedtlsGetTimer(void *aContext)
680 {
681     return static_cast<Dtls *>(aContext)->HandleMbedtlsGetTimer();
682 }
683 
HandleMbedtlsGetTimer(void)684 int Dtls::HandleMbedtlsGetTimer(void)
685 {
686     int rval;
687 
688     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
689     {
690         LogDebg("HandleMbedtlsGetTimer");
691     }
692 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
693     else
694     {
695         LogDebg("ApplicationCoapSecure HandleMbedtlsGetTimer");
696     }
697 #endif
698 
699     if (!mTimerSet)
700     {
701         rval = -1;
702     }
703     else if (!mTimer.IsRunning())
704     {
705         rval = 2;
706     }
707     else if (mTimerIntermediate <= TimerMilli::GetNow())
708     {
709         rval = 1;
710     }
711     else
712     {
713         rval = 0;
714     }
715 
716     return rval;
717 }
718 
HandleMbedtlsSetTimer(void * aContext,uint32_t aIntermediate,uint32_t aFinish)719 void Dtls::HandleMbedtlsSetTimer(void *aContext, uint32_t aIntermediate, uint32_t aFinish)
720 {
721     static_cast<Dtls *>(aContext)->HandleMbedtlsSetTimer(aIntermediate, aFinish);
722 }
723 
HandleMbedtlsSetTimer(uint32_t aIntermediate,uint32_t aFinish)724 void Dtls::HandleMbedtlsSetTimer(uint32_t aIntermediate, uint32_t aFinish)
725 {
726     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
727     {
728         LogDebg("SetTimer");
729     }
730 #if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
731     else
732     {
733         LogDebg("ApplicationCoapSecure SetTimer");
734     }
735 #endif
736 
737     if (aFinish == 0)
738     {
739         mTimerSet = false;
740         mTimer.Stop();
741     }
742     else
743     {
744         mTimerSet = true;
745         mTimer.Start(aFinish);
746         mTimerIntermediate = TimerMilli::GetNow() + aIntermediate;
747     }
748 }
749 
750 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
751 
HandleMbedtlsExportKeys(void * aContext,mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)752 void Dtls::HandleMbedtlsExportKeys(void *                      aContext,
753                                    mbedtls_ssl_key_export_type aType,
754                                    const unsigned char *       aMasterSecret,
755                                    size_t                      aMasterSecretLen,
756                                    const unsigned char         aClientRandom[32],
757                                    const unsigned char         aServerRandom[32],
758                                    mbedtls_tls_prf_types       aTlsPrfType)
759 {
760     static_cast<Dtls *>(aContext)->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen, aClientRandom,
761                                                            aServerRandom, aTlsPrfType);
762 }
763 
HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)764 void Dtls::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
765                                    const unsigned char *       aMasterSecret,
766                                    size_t                      aMasterSecretLen,
767                                    const unsigned char         aClientRandom[32],
768                                    const unsigned char         aServerRandom[32],
769                                    mbedtls_tls_prf_types       aTlsPrfType)
770 {
771     Crypto::Sha256::Hash kek;
772     Crypto::Sha256       sha256;
773     unsigned char        keyBlock[kDtlsKeyBlockSize];
774     unsigned char        randBytes[2 * kDtlsRandomBufferSize];
775 
776     VerifyOrExit(mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8);
777     VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);
778 
779     memcpy(randBytes, aServerRandom, kDtlsRandomBufferSize);
780     memcpy(randBytes + kDtlsRandomBufferSize, aClientRandom, kDtlsRandomBufferSize);
781 
782     // Retrieve the Key block from Master secret
783     mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
784                         keyBlock, sizeof(keyBlock));
785 
786     sha256.Start();
787     sha256.Update(keyBlock, kDtlsKeyBlockSize);
788     sha256.Finish(kek);
789 
790     LogDebg("Generated KEK");
791     Get<KeyManager>().SetKek(kek.GetBytes());
792 
793 exit:
794     return;
795 }
796 
797 #else
798 
HandleMbedtlsExportKeys(void * aContext,const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)799 int Dtls::HandleMbedtlsExportKeys(void *               aContext,
800                                   const unsigned char *aMasterSecret,
801                                   const unsigned char *aKeyBlock,
802                                   size_t               aMacLength,
803                                   size_t               aKeyLength,
804                                   size_t               aIvLength)
805 {
806     return static_cast<Dtls *>(aContext)->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength, aKeyLength,
807                                                                   aIvLength);
808 }
809 
HandleMbedtlsExportKeys(const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)810 int Dtls::HandleMbedtlsExportKeys(const unsigned char *aMasterSecret,
811                                   const unsigned char *aKeyBlock,
812                                   size_t               aMacLength,
813                                   size_t               aKeyLength,
814                                   size_t               aIvLength)
815 {
816     OT_UNUSED_VARIABLE(aMasterSecret);
817 
818     Crypto::Sha256::Hash kek;
819     Crypto::Sha256       sha256;
820 
821     VerifyOrExit(mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8);
822 
823     sha256.Start();
824     sha256.Update(aKeyBlock, 2 * static_cast<uint16_t>(aMacLength + aKeyLength + aIvLength));
825     sha256.Finish(kek);
826 
827     LogDebg("Generated KEK");
828     Get<KeyManager>().SetKek(kek.GetBytes());
829 
830 exit:
831     return 0;
832 }
833 
834 #endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000)
835 
HandleTimer(Timer & aTimer)836 void Dtls::HandleTimer(Timer &aTimer)
837 {
838     static_cast<Dtls *>(static_cast<TimerMilliContext &>(aTimer).GetContext())->HandleTimer();
839 }
840 
HandleTimer(void)841 void Dtls::HandleTimer(void)
842 {
843     switch (mState)
844     {
845     case kStateConnecting:
846     case kStateConnected:
847         Process();
848         break;
849 
850     case kStateCloseNotify:
851         mState = kStateOpen;
852         mTimer.Stop();
853 
854         if (mConnectedHandler != nullptr)
855         {
856             mConnectedHandler(mContext, false);
857         }
858         break;
859 
860     default:
861         OT_ASSERT(false);
862         OT_UNREACHABLE_CODE(break);
863     }
864 }
865 
// Drives the DTLS session until there is nothing more to do right now.
//
// While the session is connecting, each iteration advances the handshake with
// mbedtls_ssl_handshake(). Once mbedTLS's internal state reaches
// MBEDTLS_SSL_HANDSHAKE_OVER, the state becomes kStateConnected and the
// connected handler (if set) is called with `true`. While connected, each
// iteration reads application data with mbedtls_ssl_read() into a stack buffer
// and delivers it to the receive handler (if set).
//
// The loop ends when mbedTLS returns 0, WANT_READ or WANT_WRITE, after a session
// reset on a non-fatal error, or via ExitNow() to `exit` when the session must be
// torn down, in which case Disconnect() is called.
void Dtls::Process(void)
{
    uint8_t buf[OPENTHREAD_CONFIG_DTLS_MAX_CONTENT_LEN];
    bool    shouldDisconnect = false;
    int     rval;

    // mState is re-evaluated on every iteration: it changes inside the loop
    // (connecting -> connected) and possibly from the handler callbacks.
    while ((mState == kStateConnecting) || (mState == kStateConnected))
    {
        if (mState == kStateConnecting)
        {
            rval = mbedtls_ssl_handshake(&mSsl);

            // Completion is detected from mbedTLS's handshake state rather than
            // from rval alone.
            if (mSsl.MBEDTLS_PRIVATE(state) == MBEDTLS_SSL_HANDSHAKE_OVER)
            {
                mState = kStateConnected;

                if (mConnectedHandler != nullptr)
                {
                    mConnectedHandler(mContext, true);
                }
            }
        }
        else
        {
            rval = mbedtls_ssl_read(&mSsl, buf, sizeof(buf));
        }

        if (rval > 0)
        {
            // Positive rval is used as the number of valid bytes in buf.
            if (mReceiveHandler != nullptr)
            {
                mReceiveHandler(mContext, buf, static_cast<uint16_t>(rval));
            }
        }
        else if (rval == 0 || rval == MBEDTLS_ERR_SSL_WANT_READ || rval == MBEDTLS_ERR_SSL_WANT_WRITE)
        {
            // Nothing more to do for now; Process() runs again later
            // (e.g. from HandleTimer()).
            break;
        }
        else
        {
            // Negative rval: an mbedTLS error code.
            switch (rval)
            {
            case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
                // Peer closed the session: send our close_notify and tear down.
                mbedtls_ssl_close_notify(&mSsl);
                ExitNow(shouldDisconnect = true);
                OT_UNREACHABLE_CODE(break);

            case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
                // Not fatal: handled by the session reset after the switch.
                break;

            case MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
                // Peer sent a fatal alert: send close_notify and tear down.
                mbedtls_ssl_close_notify(&mSsl);
                ExitNow(shouldDisconnect = true);
                OT_UNREACHABLE_CODE(break);

            case MBEDTLS_ERR_SSL_INVALID_MAC:
                // Bad MAC during the handshake is fatal (bad_record_mac alert);
                // once the handshake is over, fall through to the session reset.
                if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
                {
                    mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
                                                   MBEDTLS_SSL_ALERT_MSG_BAD_RECORD_MAC);
                    ExitNow(shouldDisconnect = true);
                }

                break;

            default:
                // Any other error during the handshake is fatal (handshake_failure
                // alert); once the handshake is over, fall through to the reset.
                if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
                {
                    mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
                                                   MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
                    ExitNow(shouldDisconnect = true);
                }

                break;
            }

            // Non-fatal error: reset the mbedTLS session context. For EC-JPAKE
            // the password is installed again afterwards (presumably because the
            // reset clears it — confirm against mbedTLS), then processing stops.
            mbedtls_ssl_session_reset(&mSsl);
            if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
            {
                mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
            }
            break;
        }
    }

exit:

    if (shouldDisconnect)
    {
        Disconnect();
    }
}
958 
HandleMbedtlsDebug(void * aContext,int aLevel,const char * aFile,int aLine,const char * aStr)959 void Dtls::HandleMbedtlsDebug(void *aContext, int aLevel, const char *aFile, int aLine, const char *aStr)
960 {
961     static_cast<Dtls *>(aContext)->HandleMbedtlsDebug(aLevel, aFile, aLine, aStr);
962 }
963 
HandleMbedtlsDebug(int aLevel,const char * aFile,int aLine,const char * aStr)964 void Dtls::HandleMbedtlsDebug(int aLevel, const char *aFile, int aLine, const char *aStr)
965 {
966     OT_UNUSED_VARIABLE(aStr);
967     OT_UNUSED_VARIABLE(aFile);
968     OT_UNUSED_VARIABLE(aLine);
969 
970     switch (aLevel)
971     {
972     case 1:
973         LogCrit("[%hu] %s", mSocket.GetSockName().mPort, aStr);
974         break;
975 
976     case 2:
977         LogWarn("[%hu] %s", mSocket.GetSockName().mPort, aStr);
978         break;
979 
980     case 3:
981         LogInfo("[%hu] %s", mSocket.GetSockName().mPort, aStr);
982         break;
983 
984     case 4:
985     default:
986         LogDebg("[%hu] %s", mSocket.GetSockName().mPort, aStr);
987         break;
988     }
989 }
990 
HandleDtlsSend(const uint8_t * aBuf,uint16_t aLength,Message::SubType aMessageSubType)991 Error Dtls::HandleDtlsSend(const uint8_t *aBuf, uint16_t aLength, Message::SubType aMessageSubType)
992 {
993     Error        error   = kErrorNone;
994     ot::Message *message = nullptr;
995 
996     VerifyOrExit((message = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs);
997     message->SetSubType(aMessageSubType);
998     message->SetLinkSecurityEnabled(mLayerTwoSecurity);
999 
1000     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
1001 
1002     // Set message sub type in case Joiner Finalize Response is appended to the message.
1003     if (aMessageSubType != Message::kSubTypeNone)
1004     {
1005         message->SetSubType(aMessageSubType);
1006     }
1007 
1008     if (mTransportCallback)
1009     {
1010         SuccessOrExit(error = mTransportCallback(mTransportContext, *message, mMessageInfo));
1011     }
1012     else
1013     {
1014         SuccessOrExit(error = mSocket.SendTo(*message, mMessageInfo));
1015     }
1016 
1017 exit:
1018     FreeMessageOnError(message, error);
1019     return error;
1020 }
1021 
1022 } // namespace MeshCoP
1023 } // namespace ot
1024 
1025 #endif // OPENTHREAD_CONFIG_DTLS_ENABLE
1026