• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "ble_secure.hpp"
30 
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32 
33 #include <openthread/platform/ble.h>
34 
35 #include "instance/instance.hpp"
36 
37 using namespace ot;
38 
39 /**
40  * @file
41  *   This file implements the secure Ble agent.
42  */
43 
44 namespace ot {
45 namespace Ble {
46 
47 RegisterLogModule("BleSecure");
48 
BleSecure(Instance & aInstance)49 BleSecure::BleSecure(Instance &aInstance)
50     : InstanceLocator(aInstance)
51     , MeshCoP::Tls::Extension(mTls)
52     , mTls(aInstance, kNoLinkSecurity, *this)
53     , mTcatAgent(aInstance)
54     , mTlvMode(false)
55     , mReceivedMessage(nullptr)
56     , mSendMessage(nullptr)
57     , mTransmitTask(aInstance)
58     , mBleState(kStopped)
59     , mMtuSize(kInitialMtuSize)
60 {
61 }
62 
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)63 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
64 {
65     Error    error             = kErrorNone;
66     uint16_t advertisementLen  = 0;
67     uint8_t *advertisementData = nullptr;
68 
69     VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
70 
71     mConnectCallback.Set(aConnectHandler, aContext);
72     mReceiveCallback.Set(aReceiveHandler, aContext);
73     mTlvMode = aTlvMode;
74     mMtuSize = kInitialMtuSize;
75 
76     SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
77 
78     SuccessOrExit(error = otPlatBleGetAdvertisementBuffer(&GetInstance(), &advertisementData));
79     SuccessOrExit(error = mTcatAgent.GetAdvertisementData(advertisementLen, advertisementData));
80     VerifyOrExit(advertisementData != nullptr, error = kErrorFailed);
81     SuccessOrExit(error = otPlatBleGapAdvSetData(&GetInstance(), advertisementData, advertisementLen));
82     SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
83 
84     SuccessOrExit(error = mTls.Open());
85     mTls.SetReceiveCallback(HandleTlsReceive, this);
86     mTls.SetConnectCallback(HandleTlsConnectEvent, this);
87     SuccessOrExit(error = mTls.Bind(HandleTransport, this));
88 
89 exit:
90     if (error == kErrorNone)
91     {
92         mBleState = kAdvertising;
93     }
94     return error;
95 }
96 
TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)97 Error BleSecure::TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)
98 {
99     Error error;
100 
101     VerifyOrExit(mBleState != kStopped, error = kErrorInvalidState);
102 
103     error = mTcatAgent.Start(mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
104 
105 exit:
106     return error;
107 }
108 
Stop(void)109 void BleSecure::Stop(void)
110 {
111     VerifyOrExit(mBleState != kStopped);
112     SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
113     SuccessOrExit(otPlatBleDisable(&GetInstance()));
114     mBleState = kStopped;
115     mMtuSize  = kInitialMtuSize;
116 
117     if (mTcatAgent.IsEnabled())
118     {
119         mTcatAgent.Stop();
120     }
121 
122     mTls.Close();
123 
124     mTransmitQueue.DequeueAndFreeAll();
125 
126     mConnectCallback.Clear();
127     mReceiveCallback.Clear();
128 
129     FreeMessage(mReceivedMessage);
130     mReceivedMessage = nullptr;
131     FreeMessage(mSendMessage);
132     mSendMessage = nullptr;
133 
134 exit:
135     return;
136 }
137 
Connect(void)138 Error BleSecure::Connect(void)
139 {
140     Ip6::SockAddr sockaddr;
141     Error         error;
142 
143     VerifyOrExit(mBleState == kConnected, error = kErrorInvalidState);
144 
145     error = mTls.Connect(sockaddr);
146 
147 exit:
148     return error;
149 }
150 
Disconnect(void)151 void BleSecure::Disconnect(void)
152 {
153     if (mTls.IsConnected())
154     {
155         mTls.Disconnect();
156     }
157 
158     if (mBleState == kConnected)
159     {
160         mBleState = kAdvertising;
161         IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
162     }
163 
164     mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
165 }
166 
SetPsk(const MeshCoP::JoinerPskd & aPskd)167 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
168 {
169     static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
170                       static_cast<uint16_t>(MeshCoP::Tls::kPskMaxLength),
171                   "The maximum length of TLS PSK is smaller than joiner PSKd");
172 
173     SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
174 }
175 
SendMessage(ot::Message & aMessage)176 Error BleSecure::SendMessage(ot::Message &aMessage)
177 {
178     Error error = kErrorNone;
179 
180     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
181     if (mSendMessage == nullptr)
182     {
183         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
184         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
185     }
186     SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
187     SuccessOrExit(error = Flush());
188 
189 exit:
190     aMessage.Free();
191     return error;
192 }
193 
Send(uint8_t * aBuf,uint16_t aLength)194 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
195 {
196     Error error = kErrorNone;
197 
198     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
199     if (mSendMessage == nullptr)
200     {
201         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
202         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
203     }
204     SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
205 
206 exit:
207     return error;
208 }
209 
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)210 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
211 {
212     Error error = kErrorNone;
213     if (aLength > Tlv::kBaseTlvMaxLength)
214     {
215         ot::ExtendedTlv tlv;
216 
217         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
218         tlv.SetLength(aLength);
219         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
220     }
221     else
222     {
223         ot::Tlv tlv;
224 
225         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
226         tlv.SetLength((uint8_t)aLength);
227         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
228     }
229 
230     error = Send(aBuf, aLength);
231 exit:
232     return error;
233 }
234 
Flush(void)235 Error BleSecure::Flush(void)
236 {
237     Error error = kErrorNone;
238 
239     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
240     VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
241 
242     mTransmitQueue.Enqueue(*mSendMessage);
243     mTransmitTask.Post();
244 
245     mSendMessage = nullptr;
246 
247 exit:
248     return error;
249 }
250 
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)251 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
252 {
253     ot::Message     *message = nullptr;
254     Ip6::MessageInfo messageInfo;
255     Error            error = kErrorNone;
256 
257     if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
258     {
259         error = kErrorNoBufs;
260         ExitNow();
261     }
262     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
263 
264     // Cannot call Receive(..) directly because Setup(..) and mState are private
265     mTls.HandleReceive(*message, messageInfo);
266 
267 exit:
268     FreeMessage(message);
269     return error;
270 }
271 
HandleBleConnected(uint16_t aConnectionId)272 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
273 {
274     OT_UNUSED_VARIABLE(aConnectionId);
275 
276     mBleState = kConnected;
277 
278     IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
279 
280     mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
281 }
282 
HandleBleDisconnected(uint16_t aConnectionId)283 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
284 {
285     OT_UNUSED_VARIABLE(aConnectionId);
286 
287     mBleState = kAdvertising;
288     mMtuSize  = kInitialMtuSize;
289 
290     Disconnect(); // Stop TLS connection
291 }
292 
HandleBleMtuUpdate(uint16_t aMtu)293 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
294 {
295     Error error = kErrorNone;
296 
297     if (aMtu <= OT_BLE_ATT_MTU_MAX)
298     {
299         mMtuSize = aMtu;
300     }
301     else
302     {
303         mMtuSize = OT_BLE_ATT_MTU_MAX;
304         error    = kErrorInvalidArgs;
305     }
306 
307     return error;
308 }
309 
HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent,void * aContext)310 void BleSecure::HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent, void *aContext)
311 {
312     return static_cast<BleSecure *>(aContext)->HandleTlsConnectEvent(aEvent);
313 }
314 
HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent)315 void BleSecure::HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent)
316 {
317     if (aEvent == MeshCoP::Tls::kConnected)
318     {
319         Error err;
320 
321         if (mReceivedMessage == nullptr)
322         {
323             mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
324         }
325         err = mTcatAgent.Connected(*this);
326 
327         if (err != kErrorNone)
328         {
329             mTls.Disconnect(); // must not use Close(), so that next Commissioner can connect
330             LogWarn("Rejected TCAT Commissioner, error: %s", ErrorToString(err));
331             ExitNow();
332         }
333     }
334     else
335     {
336         FreeMessage(mReceivedMessage);
337         mReceivedMessage = nullptr;
338 
339         if (mTcatAgent.IsEnabled())
340         {
341             mTcatAgent.Disconnected();
342         }
343     }
344 
345     mConnectCallback.InvokeIfSet(&GetInstance(), aEvent == MeshCoP::Tls::kConnected, true);
346 
347 exit:
348     return;
349 }
350 
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)351 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
352 {
353     return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
354 }
355 
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)356 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
357 {
358     VerifyOrExit(mReceivedMessage != nullptr);
359 
360     if (!mTlvMode)
361     {
362         SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
363         mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
364         IgnoreReturnValue(mReceivedMessage->SetLength(0));
365     }
366     else
367     {
368         ot::Tlv  tlv;
369         uint32_t requiredBytes = sizeof(Tlv);
370         uint32_t offset;
371 
372         while (aLength > 0)
373         {
374             if (mReceivedMessage->GetLength() < requiredBytes)
375             {
376                 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
377 
378                 if (missingBytes > aLength)
379                 {
380                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
381                     break;
382                 }
383                 else
384                 {
385                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
386                     aLength -= missingBytes;
387                     aBuf += missingBytes;
388                 }
389             }
390 
391             IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
392 
393             if (tlv.IsExtended())
394             {
395                 ot::ExtendedTlv extTlv;
396                 requiredBytes = sizeof(extTlv);
397 
398                 if (mReceivedMessage->GetLength() < requiredBytes)
399                 {
400                     continue;
401                 }
402 
403                 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
404                 requiredBytes = extTlv.GetSize();
405                 offset        = sizeof(extTlv);
406             }
407             else
408             {
409                 requiredBytes = tlv.GetSize();
410                 offset        = sizeof(tlv);
411             }
412 
413             if (mReceivedMessage->GetLength() < requiredBytes)
414             {
415                 continue;
416             }
417 
418             // TLV fully loaded
419 
420             if (mTcatAgent.IsEnabled())
421             {
422                 ot::Message *message;
423                 Error        error = kErrorNone;
424 
425                 message = Get<MessagePool>().Allocate(Message::kTypeBle);
426                 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
427 
428                 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
429                 if (message->GetLength() != 0)
430                 {
431                     IgnoreReturnValue(SendMessage(*message));
432                 }
433 
434                 if (error == kErrorAbort)
435                 {
436                     LogInfo("Disconnecting TCAT client.");
437                     // kErrorAbort indicates that a Disconnect command TLV has been received.
438                     Disconnect();
439                     // BleSecure is not stopped here, it must remain active in advertising state and
440                     // must be ready to receive a next TCAT commissioner.
441                     ExitNow();
442                 }
443             }
444             else
445             {
446                 mReceivedMessage->SetOffset((uint16_t)offset);
447                 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
448                                              OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
449             }
450 
451             SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
452             requiredBytes = sizeof(Tlv);
453         }
454     }
455 
456 exit:
457     return;
458 }
459 
HandleTransmit(void)460 void BleSecure::HandleTransmit(void)
461 {
462     Error        error   = kErrorNone;
463     ot::Message *message = mTransmitQueue.GetHead();
464 
465     VerifyOrExit(message != nullptr);
466     mTransmitQueue.Dequeue(*message);
467 
468     if (mTransmitQueue.GetHead() != nullptr)
469     {
470         mTransmitTask.Post();
471     }
472 
473     SuccessOrExit(error = mTls.Send(*message));
474     LogDebg("Transmit");
475 
476 exit:
477     FreeMessageOnError(message, error);
478     LogWarnOnError(error, "transmit");
479 }
480 
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)481 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
482 {
483     OT_UNUSED_VARIABLE(aMessageInfo);
484     return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
485 }
486 
HandleTransport(ot::Message & aMessage)487 Error BleSecure::HandleTransport(ot::Message &aMessage)
488 {
489     otBleRadioPacket packet;
490     uint16_t         len    = aMessage.GetLength();
491     uint16_t         offset = 0;
492     Error            error  = kErrorNone;
493 
494     while (len > 0)
495     {
496         if (len <= mMtuSize - kGattOverhead)
497         {
498             packet.mLength = len;
499         }
500         else
501         {
502             packet.mLength = mMtuSize - kGattOverhead;
503         }
504 
505         if (packet.mLength > kPacketBufferSize)
506         {
507             packet.mLength = kPacketBufferSize;
508         }
509 
510         IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
511         packet.mValue = mPacketBuffer;
512         packet.mPower = OT_BLE_DEFAULT_POWER;
513 
514         SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
515 
516         len -= packet.mLength;
517         offset += packet.mLength;
518     }
519 
520     aMessage.Free();
521 exit:
522     return error;
523 }
524 
525 } // namespace Ble
526 } // namespace ot
527 
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)528 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
529 {
530     OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
531 
532     VerifyOrExit(aPacket != nullptr);
533     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
534 exit:
535     return;
536 }
537 
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)538 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
539 {
540     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
541 }
542 
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)543 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
544 {
545     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
546 }
547 
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)548 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
549 {
550     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
551 }
552 
553 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
554