1 /*
2 * Copyright (c) 2023, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "ble_secure.hpp"
30
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32
33 #include <openthread/platform/ble.h>
34 #include "common/locator_getters.hpp"
35 #include "common/log.hpp"
36 #include "common/tlvs.hpp"
37 #include "instance/instance.hpp"
38 #include "meshcop/secure_transport.hpp"
39
40 using namespace ot;
41
42 /**
43 * @file
44 * This file implements the secure Ble agent.
45 */
46
47 namespace ot {
48 namespace Ble {
49
50 RegisterLogModule("BleSecure");
51
BleSecure(Instance & aInstance)52 BleSecure::BleSecure(Instance &aInstance)
53 : InstanceLocator(aInstance)
54 , mTls(aInstance, false, false)
55 , mTcatAgent(aInstance)
56 , mTlvMode(false)
57 , mReceivedMessage(nullptr)
58 , mSendMessage(nullptr)
59 , mTransmitTask(aInstance)
60 , mBleState(kStopped)
61 , mMtuSize(kInitialMtuSize)
62 {
63 }
64
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)65 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
66 {
67 Error error = kErrorNone;
68
69 VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
70
71 mConnectCallback.Set(aConnectHandler, aContext);
72 mReceiveCallback.Set(aReceiveHandler, aContext);
73 mTlvMode = aTlvMode;
74 mMtuSize = kInitialMtuSize;
75
76 SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
77 SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
78 SuccessOrExit(error = mTls.Open(&BleSecure::HandleTlsReceive, &BleSecure::HandleTlsConnected, this));
79 SuccessOrExit(error = mTls.Bind(HandleTransport, this));
80
81 exit:
82 if (error == kErrorNone)
83 {
84 mBleState = kAdvertising;
85 }
86 return error;
87 }
88
TcatStart(const MeshCoP::TcatAgent::VendorInfo & aVendorInfo,MeshCoP::TcatAgent::JoinCallback aJoinHandler)89 Error BleSecure::TcatStart(const MeshCoP::TcatAgent::VendorInfo &aVendorInfo,
90 MeshCoP::TcatAgent::JoinCallback aJoinHandler)
91 {
92 Error error;
93
94 VerifyOrExit(mBleState != kStopped, error = kErrorInvalidState);
95
96 error = mTcatAgent.Start(aVendorInfo, mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
97
98 exit:
99 return error;
100 }
101
Stop(void)102 void BleSecure::Stop(void)
103 {
104 VerifyOrExit(mBleState != kStopped);
105 SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
106 SuccessOrExit(otPlatBleDisable(&GetInstance()));
107 mBleState = kStopped;
108 mMtuSize = kInitialMtuSize;
109
110 if (mTcatAgent.IsEnabled())
111 {
112 mTcatAgent.Stop();
113 }
114
115 mTls.Close();
116
117 mTransmitQueue.DequeueAndFreeAll();
118
119 mConnectCallback.Clear();
120 mReceiveCallback.Clear();
121
122 FreeMessage(mReceivedMessage);
123 mReceivedMessage = nullptr;
124 FreeMessage(mSendMessage);
125 mSendMessage = nullptr;
126
127 exit:
128 return;
129 }
130
Connect(void)131 Error BleSecure::Connect(void)
132 {
133 Ip6::SockAddr sockaddr;
134 Error error;
135
136 VerifyOrExit(mBleState == kConnected, error = kErrorInvalidState);
137
138 error = mTls.Connect(sockaddr);
139
140 exit:
141 return error;
142 }
143
Disconnect(void)144 void BleSecure::Disconnect(void)
145 {
146 if (mTls.IsConnected())
147 {
148 mTls.Disconnect();
149 }
150
151 if (mBleState == kConnected)
152 {
153 mBleState = kAdvertising;
154 IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
155 }
156
157 mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
158 }
159
SetPsk(const MeshCoP::JoinerPskd & aPskd)160 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
161 {
162 static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
163 static_cast<uint16_t>(MeshCoP::SecureTransport::kPskMaxLength),
164 "The maximum length of TLS PSK is smaller than joiner PSKd");
165
166 SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
167 }
168
#if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
/**
 * Retrieves the peer certificate in base64 form from the TLS layer.
 *
 * @param[out]    aPeerCert    Buffer that receives the certificate.
 * @param[in,out] aCertLength  On input, its value is passed to the TLS layer as the buffer size; the pointer is
 *                             also passed so the TLS layer can report the resulting length.
 *
 * @retval kErrorInvalidArgs  @p aCertLength is null.
 *
 * @returns Otherwise, the result of the TLS layer call.
 */
Error BleSecure::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength)
{
    Error error = kErrorInvalidArgs;

    if (aCertLength != nullptr)
    {
        error = mTls.GetPeerCertificateBase64(aPeerCert, aCertLength, *aCertLength);
    }

    return error;
}
#endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
182
SendMessage(ot::Message & aMessage)183 Error BleSecure::SendMessage(ot::Message &aMessage)
184 {
185 Error error = kErrorNone;
186
187 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
188 if (mSendMessage == nullptr)
189 {
190 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
191 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
192 }
193 SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
194 SuccessOrExit(error = Flush());
195
196 exit:
197 aMessage.Free();
198 return error;
199 }
200
Send(uint8_t * aBuf,uint16_t aLength)201 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
202 {
203 Error error = kErrorNone;
204
205 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
206 if (mSendMessage == nullptr)
207 {
208 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
209 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
210 }
211 SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
212
213 exit:
214 return error;
215 }
216
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)217 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
218 {
219 Error error = kErrorNone;
220 if (aLength > Tlv::kBaseTlvMaxLength)
221 {
222 ot::ExtendedTlv tlv;
223
224 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
225 tlv.SetLength(aLength);
226 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
227 }
228 else
229 {
230 ot::Tlv tlv;
231
232 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
233 tlv.SetLength((uint8_t)aLength);
234 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
235 }
236
237 error = Send(aBuf, aLength);
238 exit:
239 return error;
240 }
241
Flush(void)242 Error BleSecure::Flush(void)
243 {
244 Error error = kErrorNone;
245
246 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
247 VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
248
249 mTransmitQueue.Enqueue(*mSendMessage);
250 mTransmitTask.Post();
251
252 mSendMessage = nullptr;
253
254 exit:
255 return error;
256 }
257
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)258 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
259 {
260 ot::Message *message = nullptr;
261 Ip6::MessageInfo messageInfo;
262 Error error = kErrorNone;
263
264 if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
265 {
266 error = kErrorNoBufs;
267 ExitNow();
268 }
269 SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
270
271 // Cannot call Receive(..) directly because Setup(..) and mState are private
272 mTls.HandleReceive(*message, messageInfo);
273
274 exit:
275 FreeMessage(message);
276 return error;
277 }
278
HandleBleConnected(uint16_t aConnectionId)279 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
280 {
281 OT_UNUSED_VARIABLE(aConnectionId);
282
283 mBleState = kConnected;
284
285 IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
286
287 mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
288 }
289
HandleBleDisconnected(uint16_t aConnectionId)290 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
291 {
292 OT_UNUSED_VARIABLE(aConnectionId);
293
294 mBleState = kAdvertising;
295 mMtuSize = kInitialMtuSize;
296
297 Disconnect(); // Stop TLS connection
298 }
299
HandleBleMtuUpdate(uint16_t aMtu)300 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
301 {
302 Error error = kErrorNone;
303
304 if (aMtu <= OT_BLE_ATT_MTU_MAX)
305 {
306 mMtuSize = aMtu;
307 }
308 else
309 {
310 mMtuSize = OT_BLE_ATT_MTU_MAX;
311 error = kErrorInvalidArgs;
312 }
313
314 return error;
315 }
316
HandleTlsConnected(void * aContext,bool aConnected)317 void BleSecure::HandleTlsConnected(void *aContext, bool aConnected)
318 {
319 return static_cast<BleSecure *>(aContext)->HandleTlsConnected(aConnected);
320 }
321
HandleTlsConnected(bool aConnected)322 void BleSecure::HandleTlsConnected(bool aConnected)
323 {
324 if (aConnected)
325 {
326 if (mReceivedMessage == nullptr)
327 {
328 mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
329 }
330
331 if (mTcatAgent.IsEnabled())
332 {
333 IgnoreReturnValue(mTcatAgent.Connected(mTls));
334 }
335 }
336 else
337 {
338 FreeMessage(mReceivedMessage);
339 mReceivedMessage = nullptr;
340
341 if (mTcatAgent.IsEnabled())
342 {
343 mTcatAgent.Disconnected();
344 }
345 }
346
347 mConnectCallback.InvokeIfSet(&GetInstance(), aConnected, true);
348 }
349
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)350 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
351 {
352 return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
353 }
354
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)355 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
356 {
357 VerifyOrExit(mReceivedMessage != nullptr);
358
359 if (!mTlvMode)
360 {
361 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
362 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
363 IgnoreReturnValue(mReceivedMessage->SetLength(0));
364 }
365 else
366 {
367 ot::Tlv tlv;
368 uint32_t requiredBytes = sizeof(Tlv);
369 uint32_t offset;
370
371 while (aLength > 0)
372 {
373 if (mReceivedMessage->GetLength() < requiredBytes)
374 {
375 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
376
377 if (missingBytes > aLength)
378 {
379 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
380 break;
381 }
382 else
383 {
384 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
385 aLength -= missingBytes;
386 aBuf += missingBytes;
387 }
388 }
389
390 IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
391
392 if (tlv.IsExtended())
393 {
394 ot::ExtendedTlv extTlv;
395 requiredBytes = sizeof(extTlv);
396
397 if (mReceivedMessage->GetLength() < requiredBytes)
398 {
399 continue;
400 }
401
402 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
403 requiredBytes = extTlv.GetSize();
404 offset = sizeof(extTlv);
405 }
406 else
407 {
408 requiredBytes = tlv.GetSize();
409 offset = sizeof(tlv);
410 }
411
412 if (mReceivedMessage->GetLength() < requiredBytes)
413 {
414 continue;
415 }
416
417 // TLV fully loaded
418
419 if (mTcatAgent.IsEnabled())
420 {
421 ot::Message *message;
422 Error error = kErrorNone;
423
424 message = Get<MessagePool>().Allocate(Message::kTypeBle);
425 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
426
427 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
428 if (message->GetLength() != 0)
429 {
430 IgnoreReturnValue(SendMessage(*message));
431 }
432
433 if (error == kErrorAbort)
434 {
435 Disconnect();
436 Stop();
437 ExitNow();
438 }
439 }
440 else
441 {
442 mReceivedMessage->SetOffset((uint16_t)offset);
443 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
444 OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
445 }
446
447 SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
448 requiredBytes = sizeof(Tlv);
449 }
450 }
451
452 exit:
453 return;
454 }
455
HandleTransmit(void)456 void BleSecure::HandleTransmit(void)
457 {
458 Error error = kErrorNone;
459 ot::Message *message = mTransmitQueue.GetHead();
460
461 VerifyOrExit(message != nullptr);
462 mTransmitQueue.Dequeue(*message);
463
464 if (mTransmitQueue.GetHead() != nullptr)
465 {
466 mTransmitTask.Post();
467 }
468
469 SuccessOrExit(error = mTls.Send(*message, message->GetLength()));
470
471 exit:
472 if (error != kErrorNone)
473 {
474 LogNote("Transmit: %s", ErrorToString(error));
475 message->Free();
476 }
477 else
478 {
479 LogDebg("Transmit: %s", ErrorToString(error));
480 }
481 }
482
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)483 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
484 {
485 OT_UNUSED_VARIABLE(aMessageInfo);
486 return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
487 }
488
HandleTransport(ot::Message & aMessage)489 Error BleSecure::HandleTransport(ot::Message &aMessage)
490 {
491 otBleRadioPacket packet;
492 uint16_t len = aMessage.GetLength();
493 uint16_t offset = 0;
494 Error error = kErrorNone;
495
496 while (len > 0)
497 {
498 if (len <= mMtuSize - kGattOverhead)
499 {
500 packet.mLength = len;
501 }
502 else
503 {
504 packet.mLength = mMtuSize - kGattOverhead;
505 }
506
507 if (packet.mLength > kPacketBufferSize)
508 {
509 packet.mLength = kPacketBufferSize;
510 }
511
512 IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
513 packet.mValue = mPacketBuffer;
514 packet.mPower = OT_BLE_DEFAULT_POWER;
515
516 SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
517
518 len -= packet.mLength;
519 offset += packet.mLength;
520 }
521
522 aMessage.Free();
523 exit:
524 return error;
525 }
526
527 } // namespace Ble
528 } // namespace ot
529
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)530 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
531 {
532 OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
533
534 VerifyOrExit(aPacket != nullptr);
535 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
536 exit:
537 return;
538 }
539
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)540 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
541 {
542 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
543 }
544
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)545 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
546 {
547 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
548 }
549
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)550 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
551 {
552 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
553 }
554
555 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
556