• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "ble_secure.hpp"
30 
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32 
33 #include <openthread/platform/ble.h>
34 #include "common/locator_getters.hpp"
35 #include "common/log.hpp"
36 #include "common/tlvs.hpp"
37 #include "instance/instance.hpp"
38 #include "meshcop/secure_transport.hpp"
39 
40 using namespace ot;
41 
/**
 * @file
 *   This file implements the secure BLE agent (TLS over a BLE GATT link, used by TCAT).
 */
46 
47 namespace ot {
48 namespace Ble {
49 
50 RegisterLogModule("BleSecure");
51 
/**
 * Constructs the BLE secure agent in the stopped state.
 *
 * No message buffers are allocated here: `mSendMessage` is allocated lazily by `Send()`/`SendMessage()` and
 * `mReceivedMessage` is allocated in `HandleTlsConnected()`.
 *
 * @param[in]  aInstance  The OpenThread instance this object belongs to.
 */
BleSecure::BleSecure(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mTls(aInstance, false, false)
    , mTcatAgent(aInstance)
    , mTlvMode(false)
    , mReceivedMessage(nullptr)
    , mSendMessage(nullptr)
    , mTransmitTask(aInstance)
    , mBleState(kStopped)
    , mMtuSize(kInitialMtuSize)
{
}
64 
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)65 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
66 {
67     Error error = kErrorNone;
68 
69     VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
70 
71     mConnectCallback.Set(aConnectHandler, aContext);
72     mReceiveCallback.Set(aReceiveHandler, aContext);
73     mTlvMode = aTlvMode;
74     mMtuSize = kInitialMtuSize;
75 
76     SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
77     SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
78     SuccessOrExit(error = mTls.Open(&BleSecure::HandleTlsReceive, &BleSecure::HandleTlsConnected, this));
79     SuccessOrExit(error = mTls.Bind(HandleTransport, this));
80 
81 exit:
82     if (error == kErrorNone)
83     {
84         mBleState = kAdvertising;
85     }
86     return error;
87 }
88 
TcatStart(const MeshCoP::TcatAgent::VendorInfo & aVendorInfo,MeshCoP::TcatAgent::JoinCallback aJoinHandler)89 Error BleSecure::TcatStart(const MeshCoP::TcatAgent::VendorInfo &aVendorInfo,
90                            MeshCoP::TcatAgent::JoinCallback      aJoinHandler)
91 {
92     Error error;
93 
94     VerifyOrExit(mBleState != kStopped, error = kErrorInvalidState);
95 
96     error = mTcatAgent.Start(aVendorInfo, mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
97 
98 exit:
99     return error;
100 }
101 
Stop(void)102 void BleSecure::Stop(void)
103 {
104     VerifyOrExit(mBleState != kStopped);
105     SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
106     SuccessOrExit(otPlatBleDisable(&GetInstance()));
107     mBleState = kStopped;
108     mMtuSize  = kInitialMtuSize;
109 
110     if (mTcatAgent.IsEnabled())
111     {
112         mTcatAgent.Stop();
113     }
114 
115     mTls.Close();
116 
117     mTransmitQueue.DequeueAndFreeAll();
118 
119     mConnectCallback.Clear();
120     mReceiveCallback.Clear();
121 
122     FreeMessage(mReceivedMessage);
123     mReceivedMessage = nullptr;
124     FreeMessage(mSendMessage);
125     mSendMessage = nullptr;
126 
127 exit:
128     return;
129 }
130 
Connect(void)131 Error BleSecure::Connect(void)
132 {
133     Ip6::SockAddr sockaddr;
134     Error         error;
135 
136     VerifyOrExit(mBleState == kConnected, error = kErrorInvalidState);
137 
138     error = mTls.Connect(sockaddr);
139 
140 exit:
141     return error;
142 }
143 
Disconnect(void)144 void BleSecure::Disconnect(void)
145 {
146     if (mTls.IsConnected())
147     {
148         mTls.Disconnect();
149     }
150 
151     if (mBleState == kConnected)
152     {
153         mBleState = kAdvertising;
154         IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
155     }
156 
157     mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
158 }
159 
SetPsk(const MeshCoP::JoinerPskd & aPskd)160 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
161 {
162     static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
163                       static_cast<uint16_t>(MeshCoP::SecureTransport::kPskMaxLength),
164                   "The maximum length of TLS PSK is smaller than joiner PSKd");
165 
166     SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
167 }
168 
#if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
/**
 * Retrieves the peer certificate, base64 encoded, from the TLS layer.
 *
 * @param[out]    aPeerCert    Buffer receiving the encoded certificate.
 * @param[in,out] aCertLength  On input, passed to the TLS layer as the buffer size; the TLS layer also
 *                             receives the pointer itself to report the resulting length.
 *
 * @retval kErrorInvalidArgs  @p aCertLength is null.
 * @returns Otherwise the result of the TLS layer call.
 */
Error BleSecure::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength)
{
    if (aCertLength == nullptr)
    {
        return kErrorInvalidArgs;
    }

    return mTls.GetPeerCertificateBase64(aPeerCert, aCertLength, *aCertLength);
}
#endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
182 
SendMessage(ot::Message & aMessage)183 Error BleSecure::SendMessage(ot::Message &aMessage)
184 {
185     Error error = kErrorNone;
186 
187     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
188     if (mSendMessage == nullptr)
189     {
190         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
191         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
192     }
193     SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
194     SuccessOrExit(error = Flush());
195 
196 exit:
197     aMessage.Free();
198     return error;
199 }
200 
Send(uint8_t * aBuf,uint16_t aLength)201 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
202 {
203     Error error = kErrorNone;
204 
205     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
206     if (mSendMessage == nullptr)
207     {
208         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
209         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
210     }
211     SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
212 
213 exit:
214     return error;
215 }
216 
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)217 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
218 {
219     Error error = kErrorNone;
220     if (aLength > Tlv::kBaseTlvMaxLength)
221     {
222         ot::ExtendedTlv tlv;
223 
224         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
225         tlv.SetLength(aLength);
226         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
227     }
228     else
229     {
230         ot::Tlv tlv;
231 
232         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
233         tlv.SetLength((uint8_t)aLength);
234         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
235     }
236 
237     error = Send(aBuf, aLength);
238 exit:
239     return error;
240 }
241 
Flush(void)242 Error BleSecure::Flush(void)
243 {
244     Error error = kErrorNone;
245 
246     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
247     VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
248 
249     mTransmitQueue.Enqueue(*mSendMessage);
250     mTransmitTask.Post();
251 
252     mSendMessage = nullptr;
253 
254 exit:
255     return error;
256 }
257 
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)258 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
259 {
260     ot::Message     *message = nullptr;
261     Ip6::MessageInfo messageInfo;
262     Error            error = kErrorNone;
263 
264     if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
265     {
266         error = kErrorNoBufs;
267         ExitNow();
268     }
269     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
270 
271     // Cannot call Receive(..) directly because Setup(..) and mState are private
272     mTls.HandleReceive(*message, messageInfo);
273 
274 exit:
275     FreeMessage(message);
276     return error;
277 }
278 
HandleBleConnected(uint16_t aConnectionId)279 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
280 {
281     OT_UNUSED_VARIABLE(aConnectionId);
282 
283     mBleState = kConnected;
284 
285     IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
286 
287     mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
288 }
289 
HandleBleDisconnected(uint16_t aConnectionId)290 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
291 {
292     OT_UNUSED_VARIABLE(aConnectionId);
293 
294     mBleState = kAdvertising;
295     mMtuSize  = kInitialMtuSize;
296 
297     Disconnect(); // Stop TLS connection
298 }
299 
HandleBleMtuUpdate(uint16_t aMtu)300 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
301 {
302     Error error = kErrorNone;
303 
304     if (aMtu <= OT_BLE_ATT_MTU_MAX)
305     {
306         mMtuSize = aMtu;
307     }
308     else
309     {
310         mMtuSize = OT_BLE_ATT_MTU_MAX;
311         error    = kErrorInvalidArgs;
312     }
313 
314     return error;
315 }
316 
HandleTlsConnected(void * aContext,bool aConnected)317 void BleSecure::HandleTlsConnected(void *aContext, bool aConnected)
318 {
319     return static_cast<BleSecure *>(aContext)->HandleTlsConnected(aConnected);
320 }
321 
HandleTlsConnected(bool aConnected)322 void BleSecure::HandleTlsConnected(bool aConnected)
323 {
324     if (aConnected)
325     {
326         if (mReceivedMessage == nullptr)
327         {
328             mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
329         }
330 
331         if (mTcatAgent.IsEnabled())
332         {
333             IgnoreReturnValue(mTcatAgent.Connected(mTls));
334         }
335     }
336     else
337     {
338         FreeMessage(mReceivedMessage);
339         mReceivedMessage = nullptr;
340 
341         if (mTcatAgent.IsEnabled())
342         {
343             mTcatAgent.Disconnected();
344         }
345     }
346 
347     mConnectCallback.InvokeIfSet(&GetInstance(), aConnected, true);
348 }
349 
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)350 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
351 {
352     return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
353 }
354 
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)355 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
356 {
357     VerifyOrExit(mReceivedMessage != nullptr);
358 
359     if (!mTlvMode)
360     {
361         SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
362         mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
363         IgnoreReturnValue(mReceivedMessage->SetLength(0));
364     }
365     else
366     {
367         ot::Tlv  tlv;
368         uint32_t requiredBytes = sizeof(Tlv);
369         uint32_t offset;
370 
371         while (aLength > 0)
372         {
373             if (mReceivedMessage->GetLength() < requiredBytes)
374             {
375                 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
376 
377                 if (missingBytes > aLength)
378                 {
379                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
380                     break;
381                 }
382                 else
383                 {
384                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
385                     aLength -= missingBytes;
386                     aBuf += missingBytes;
387                 }
388             }
389 
390             IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
391 
392             if (tlv.IsExtended())
393             {
394                 ot::ExtendedTlv extTlv;
395                 requiredBytes = sizeof(extTlv);
396 
397                 if (mReceivedMessage->GetLength() < requiredBytes)
398                 {
399                     continue;
400                 }
401 
402                 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
403                 requiredBytes = extTlv.GetSize();
404                 offset        = sizeof(extTlv);
405             }
406             else
407             {
408                 requiredBytes = tlv.GetSize();
409                 offset        = sizeof(tlv);
410             }
411 
412             if (mReceivedMessage->GetLength() < requiredBytes)
413             {
414                 continue;
415             }
416 
417             // TLV fully loaded
418 
419             if (mTcatAgent.IsEnabled())
420             {
421                 ot::Message *message;
422                 Error        error = kErrorNone;
423 
424                 message = Get<MessagePool>().Allocate(Message::kTypeBle);
425                 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
426 
427                 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
428                 if (message->GetLength() != 0)
429                 {
430                     IgnoreReturnValue(SendMessage(*message));
431                 }
432 
433                 if (error == kErrorAbort)
434                 {
435                     Disconnect();
436                     Stop();
437                     ExitNow();
438                 }
439             }
440             else
441             {
442                 mReceivedMessage->SetOffset((uint16_t)offset);
443                 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
444                                              OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
445             }
446 
447             SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
448             requiredBytes = sizeof(Tlv);
449         }
450     }
451 
452 exit:
453     return;
454 }
455 
HandleTransmit(void)456 void BleSecure::HandleTransmit(void)
457 {
458     Error        error   = kErrorNone;
459     ot::Message *message = mTransmitQueue.GetHead();
460 
461     VerifyOrExit(message != nullptr);
462     mTransmitQueue.Dequeue(*message);
463 
464     if (mTransmitQueue.GetHead() != nullptr)
465     {
466         mTransmitTask.Post();
467     }
468 
469     SuccessOrExit(error = mTls.Send(*message, message->GetLength()));
470 
471 exit:
472     if (error != kErrorNone)
473     {
474         LogNote("Transmit: %s", ErrorToString(error));
475         message->Free();
476     }
477     else
478     {
479         LogDebg("Transmit: %s", ErrorToString(error));
480     }
481 }
482 
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)483 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
484 {
485     OT_UNUSED_VARIABLE(aMessageInfo);
486     return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
487 }
488 
HandleTransport(ot::Message & aMessage)489 Error BleSecure::HandleTransport(ot::Message &aMessage)
490 {
491     otBleRadioPacket packet;
492     uint16_t         len    = aMessage.GetLength();
493     uint16_t         offset = 0;
494     Error            error  = kErrorNone;
495 
496     while (len > 0)
497     {
498         if (len <= mMtuSize - kGattOverhead)
499         {
500             packet.mLength = len;
501         }
502         else
503         {
504             packet.mLength = mMtuSize - kGattOverhead;
505         }
506 
507         if (packet.mLength > kPacketBufferSize)
508         {
509             packet.mLength = kPacketBufferSize;
510         }
511 
512         IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
513         packet.mValue = mPacketBuffer;
514         packet.mPower = OT_BLE_DEFAULT_POWER;
515 
516         SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
517 
518         len -= packet.mLength;
519         offset += packet.mLength;
520     }
521 
522     aMessage.Free();
523 exit:
524     return error;
525 }
526 
527 } // namespace Ble
528 } // namespace ot
529 
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)530 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
531 {
532     OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
533 
534     VerifyOrExit(aPacket != nullptr);
535     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
536 exit:
537     return;
538 }
539 
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)540 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
541 {
542     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
543 }
544 
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)545 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
546 {
547     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
548 }
549 
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)550 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
551 {
552     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
553 }
554 
555 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
556