1 /*
2 * Copyright (c) 2023, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "ble_secure.hpp"
30
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32
33 #include <openthread/platform/ble.h>
34
35 #include "instance/instance.hpp"
36
37 using namespace ot;
38
/**
 * @file
 *   This file implements the secure BLE agent, which runs a TLS session over a BLE GATT connection for TCAT.
 */
43
44 namespace ot {
45 namespace Ble {
46
// Registers the log module name used by the LogWarn()/LogInfo()/LogDebg() calls in this file.
RegisterLogModule("BleSecure");
48
/**
 * Constructs the BLE secure agent in the `kStopped` state.
 *
 * `mTls` is created with `kNoLinkSecurity` and is handed `*this` (presumably as its TLS extension, since this
 * class derives from `MeshCoP::Tls::Extension`, whose base is bound to `mTls`). The MTU starts at
 * `kInitialMtuSize` until the platform reports a value.
 *
 * @param[in] aInstance  The OpenThread instance.
 */
BleSecure::BleSecure(Instance &aInstance)
    : InstanceLocator(aInstance)
    , MeshCoP::Tls::Extension(mTls)
    , mTls(aInstance, kNoLinkSecurity, *this)
    , mTcatAgent(aInstance)
    , mTlvMode(false)
    , mReceivedMessage(nullptr)
    , mSendMessage(nullptr)
    , mTransmitTask(aInstance)
    , mBleState(kStopped)
    , mMtuSize(kInitialMtuSize)
{
}
62
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)63 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
64 {
65 Error error = kErrorNone;
66 uint16_t advertisementLen = 0;
67 uint8_t *advertisementData = nullptr;
68
69 VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
70
71 mConnectCallback.Set(aConnectHandler, aContext);
72 mReceiveCallback.Set(aReceiveHandler, aContext);
73 mTlvMode = aTlvMode;
74 mMtuSize = kInitialMtuSize;
75
76 SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
77
78 SuccessOrExit(error = otPlatBleGetAdvertisementBuffer(&GetInstance(), &advertisementData));
79 SuccessOrExit(error = mTcatAgent.GetAdvertisementData(advertisementLen, advertisementData));
80 VerifyOrExit(advertisementData != nullptr, error = kErrorFailed);
81 SuccessOrExit(error = otPlatBleGapAdvSetData(&GetInstance(), advertisementData, advertisementLen));
82 SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
83
84 SuccessOrExit(error = mTls.Open());
85 mTls.SetReceiveCallback(HandleTlsReceive, this);
86 mTls.SetConnectCallback(HandleTlsConnectEvent, this);
87 SuccessOrExit(error = mTls.Bind(HandleTransport, this));
88
89 exit:
90 if (error == kErrorNone)
91 {
92 mBleState = kAdvertising;
93 }
94 return error;
95 }
96
TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)97 Error BleSecure::TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)
98 {
99 Error error;
100
101 VerifyOrExit(mBleState != kStopped, error = kErrorInvalidState);
102
103 error = mTcatAgent.Start(mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
104
105 exit:
106 return error;
107 }
108
Stop(void)109 void BleSecure::Stop(void)
110 {
111 VerifyOrExit(mBleState != kStopped);
112 SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
113 SuccessOrExit(otPlatBleDisable(&GetInstance()));
114 mBleState = kStopped;
115 mMtuSize = kInitialMtuSize;
116
117 if (mTcatAgent.IsEnabled())
118 {
119 mTcatAgent.Stop();
120 }
121
122 mTls.Close();
123
124 mTransmitQueue.DequeueAndFreeAll();
125
126 mConnectCallback.Clear();
127 mReceiveCallback.Clear();
128
129 FreeMessage(mReceivedMessage);
130 mReceivedMessage = nullptr;
131 FreeMessage(mSendMessage);
132 mSendMessage = nullptr;
133
134 exit:
135 return;
136 }
137
Connect(void)138 Error BleSecure::Connect(void)
139 {
140 Ip6::SockAddr sockaddr;
141 Error error;
142
143 VerifyOrExit(mBleState == kConnected, error = kErrorInvalidState);
144
145 error = mTls.Connect(sockaddr);
146
147 exit:
148 return error;
149 }
150
Disconnect(void)151 void BleSecure::Disconnect(void)
152 {
153 if (mTls.IsConnected())
154 {
155 mTls.Disconnect();
156 }
157
158 if (mBleState == kConnected)
159 {
160 mBleState = kAdvertising;
161 IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
162 }
163
164 mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
165 }
166
SetPsk(const MeshCoP::JoinerPskd & aPskd)167 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
168 {
169 static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
170 static_cast<uint16_t>(MeshCoP::Tls::kPskMaxLength),
171 "The maximum length of TLS PSK is smaller than joiner PSKd");
172
173 SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
174 }
175
SendMessage(ot::Message & aMessage)176 Error BleSecure::SendMessage(ot::Message &aMessage)
177 {
178 Error error = kErrorNone;
179
180 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
181 if (mSendMessage == nullptr)
182 {
183 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
184 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
185 }
186 SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
187 SuccessOrExit(error = Flush());
188
189 exit:
190 aMessage.Free();
191 return error;
192 }
193
Send(uint8_t * aBuf,uint16_t aLength)194 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
195 {
196 Error error = kErrorNone;
197
198 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
199 if (mSendMessage == nullptr)
200 {
201 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
202 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
203 }
204 SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
205
206 exit:
207 return error;
208 }
209
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)210 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
211 {
212 Error error = kErrorNone;
213 if (aLength > Tlv::kBaseTlvMaxLength)
214 {
215 ot::ExtendedTlv tlv;
216
217 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
218 tlv.SetLength(aLength);
219 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
220 }
221 else
222 {
223 ot::Tlv tlv;
224
225 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
226 tlv.SetLength((uint8_t)aLength);
227 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
228 }
229
230 error = Send(aBuf, aLength);
231 exit:
232 return error;
233 }
234
Flush(void)235 Error BleSecure::Flush(void)
236 {
237 Error error = kErrorNone;
238
239 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
240 VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
241
242 mTransmitQueue.Enqueue(*mSendMessage);
243 mTransmitTask.Post();
244
245 mSendMessage = nullptr;
246
247 exit:
248 return error;
249 }
250
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)251 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
252 {
253 ot::Message *message = nullptr;
254 Ip6::MessageInfo messageInfo;
255 Error error = kErrorNone;
256
257 if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
258 {
259 error = kErrorNoBufs;
260 ExitNow();
261 }
262 SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
263
264 // Cannot call Receive(..) directly because Setup(..) and mState are private
265 mTls.HandleReceive(*message, messageInfo);
266
267 exit:
268 FreeMessage(message);
269 return error;
270 }
271
HandleBleConnected(uint16_t aConnectionId)272 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
273 {
274 OT_UNUSED_VARIABLE(aConnectionId);
275
276 mBleState = kConnected;
277
278 IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
279
280 mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
281 }
282
HandleBleDisconnected(uint16_t aConnectionId)283 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
284 {
285 OT_UNUSED_VARIABLE(aConnectionId);
286
287 mBleState = kAdvertising;
288 mMtuSize = kInitialMtuSize;
289
290 Disconnect(); // Stop TLS connection
291 }
292
HandleBleMtuUpdate(uint16_t aMtu)293 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
294 {
295 Error error = kErrorNone;
296
297 if (aMtu <= OT_BLE_ATT_MTU_MAX)
298 {
299 mMtuSize = aMtu;
300 }
301 else
302 {
303 mMtuSize = OT_BLE_ATT_MTU_MAX;
304 error = kErrorInvalidArgs;
305 }
306
307 return error;
308 }
309
HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent,void * aContext)310 void BleSecure::HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent, void *aContext)
311 {
312 return static_cast<BleSecure *>(aContext)->HandleTlsConnectEvent(aEvent);
313 }
314
HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent)315 void BleSecure::HandleTlsConnectEvent(MeshCoP::Tls::ConnectEvent aEvent)
316 {
317 if (aEvent == MeshCoP::Tls::kConnected)
318 {
319 Error err;
320
321 if (mReceivedMessage == nullptr)
322 {
323 mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
324 }
325 err = mTcatAgent.Connected(*this);
326
327 if (err != kErrorNone)
328 {
329 mTls.Disconnect(); // must not use Close(), so that next Commissioner can connect
330 LogWarn("Rejected TCAT Commissioner, error: %s", ErrorToString(err));
331 ExitNow();
332 }
333 }
334 else
335 {
336 FreeMessage(mReceivedMessage);
337 mReceivedMessage = nullptr;
338
339 if (mTcatAgent.IsEnabled())
340 {
341 mTcatAgent.Disconnected();
342 }
343 }
344
345 mConnectCallback.InvokeIfSet(&GetInstance(), aEvent == MeshCoP::Tls::kConnected, true);
346
347 exit:
348 return;
349 }
350
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)351 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
352 {
353 return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
354 }
355
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)356 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
357 {
358 VerifyOrExit(mReceivedMessage != nullptr);
359
360 if (!mTlvMode)
361 {
362 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
363 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
364 IgnoreReturnValue(mReceivedMessage->SetLength(0));
365 }
366 else
367 {
368 ot::Tlv tlv;
369 uint32_t requiredBytes = sizeof(Tlv);
370 uint32_t offset;
371
372 while (aLength > 0)
373 {
374 if (mReceivedMessage->GetLength() < requiredBytes)
375 {
376 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
377
378 if (missingBytes > aLength)
379 {
380 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
381 break;
382 }
383 else
384 {
385 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
386 aLength -= missingBytes;
387 aBuf += missingBytes;
388 }
389 }
390
391 IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
392
393 if (tlv.IsExtended())
394 {
395 ot::ExtendedTlv extTlv;
396 requiredBytes = sizeof(extTlv);
397
398 if (mReceivedMessage->GetLength() < requiredBytes)
399 {
400 continue;
401 }
402
403 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
404 requiredBytes = extTlv.GetSize();
405 offset = sizeof(extTlv);
406 }
407 else
408 {
409 requiredBytes = tlv.GetSize();
410 offset = sizeof(tlv);
411 }
412
413 if (mReceivedMessage->GetLength() < requiredBytes)
414 {
415 continue;
416 }
417
418 // TLV fully loaded
419
420 if (mTcatAgent.IsEnabled())
421 {
422 ot::Message *message;
423 Error error = kErrorNone;
424
425 message = Get<MessagePool>().Allocate(Message::kTypeBle);
426 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
427
428 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
429 if (message->GetLength() != 0)
430 {
431 IgnoreReturnValue(SendMessage(*message));
432 }
433
434 if (error == kErrorAbort)
435 {
436 LogInfo("Disconnecting TCAT client.");
437 // kErrorAbort indicates that a Disconnect command TLV has been received.
438 Disconnect();
439 // BleSecure is not stopped here, it must remain active in advertising state and
440 // must be ready to receive a next TCAT commissioner.
441 ExitNow();
442 }
443 }
444 else
445 {
446 mReceivedMessage->SetOffset((uint16_t)offset);
447 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
448 OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
449 }
450
451 SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
452 requiredBytes = sizeof(Tlv);
453 }
454 }
455
456 exit:
457 return;
458 }
459
HandleTransmit(void)460 void BleSecure::HandleTransmit(void)
461 {
462 Error error = kErrorNone;
463 ot::Message *message = mTransmitQueue.GetHead();
464
465 VerifyOrExit(message != nullptr);
466 mTransmitQueue.Dequeue(*message);
467
468 if (mTransmitQueue.GetHead() != nullptr)
469 {
470 mTransmitTask.Post();
471 }
472
473 SuccessOrExit(error = mTls.Send(*message));
474 LogDebg("Transmit");
475
476 exit:
477 FreeMessageOnError(message, error);
478 LogWarnOnError(error, "transmit");
479 }
480
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)481 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
482 {
483 OT_UNUSED_VARIABLE(aMessageInfo);
484 return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
485 }
486
HandleTransport(ot::Message & aMessage)487 Error BleSecure::HandleTransport(ot::Message &aMessage)
488 {
489 otBleRadioPacket packet;
490 uint16_t len = aMessage.GetLength();
491 uint16_t offset = 0;
492 Error error = kErrorNone;
493
494 while (len > 0)
495 {
496 if (len <= mMtuSize - kGattOverhead)
497 {
498 packet.mLength = len;
499 }
500 else
501 {
502 packet.mLength = mMtuSize - kGattOverhead;
503 }
504
505 if (packet.mLength > kPacketBufferSize)
506 {
507 packet.mLength = kPacketBufferSize;
508 }
509
510 IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
511 packet.mValue = mPacketBuffer;
512 packet.mPower = OT_BLE_DEFAULT_POWER;
513
514 SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
515
516 len -= packet.mLength;
517 offset += packet.mLength;
518 }
519
520 aMessage.Free();
521 exit:
522 return error;
523 }
524
525 } // namespace Ble
526 } // namespace ot
527
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)528 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
529 {
530 OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
531
532 VerifyOrExit(aPacket != nullptr);
533 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
534 exit:
535 return;
536 }
537
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)538 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
539 {
540 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
541 }
542
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)543 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
544 {
545 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
546 }
547
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)548 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
549 {
550 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
551 }
552
553 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
554