1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements TCP/IPv6 sockets.
32 */
33
34 #include "openthread-core-config.h"
35
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37
38 #include "tcp6.hpp"
39
40 #include "common/as_core_type.hpp"
41 #include "common/code_utils.hpp"
42 #include "common/error.hpp"
43 #include "common/instance.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/log.hpp"
46 #include "common/random.hpp"
47 #include "net/checksum.hpp"
48 #include "net/ip6.hpp"
49 #include "net/netif.hpp"
50
51 #include "../../third_party/tcplp/tcplp.h"
52
53 namespace ot {
54 namespace Ip6 {
55
56 using ot::Encoding::BigEndian::HostSwap16;
57 using ot::Encoding::BigEndian::HostSwap32;
58
59 RegisterLogModule("Tcp");
60
61 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
62 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
63 "mTcb field in otTcpEndpoint is aligned incorrectly");
64 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
65
66 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
67 "mTcbListen field in otTcpListener is sized incorrectly");
68 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
69 "mTcbListen field in otTcpListener is aligned incorrectly");
70 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
71
/**
 * Constructs the TCP module.
 *
 * Binds the module-wide timer and tasklet to their static handlers, and starts
 * ephemeral-port allocation at the bottom of the dynamic port range.
 *
 * @param[in] aInstance  The OpenThread instance that owns this module.
 */
Tcp::Tcp(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mTimer(aInstance, Tcp::HandleTimer)
    , mTasklet(aInstance, Tcp::HandleTasklet)
    , mEphemeralPort(kDynamicPortMin)
{
    // mEphemeralPort is read/advanced by AutoBind(); this keeps the compiler quiet in configurations where it may not be.
    OT_UNUSED_VARIABLE(mEphemeralPort);
}
80
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)81 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
82 {
83 Error error;
84 struct tcpcb &tp = GetTcb();
85
86 memset(&tp, 0x00, sizeof(tp));
87
88 SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
89
90 mContext = aArgs.mContext;
91 mEstablishedCallback = aArgs.mEstablishedCallback;
92 mSendDoneCallback = aArgs.mSendDoneCallback;
93 mForwardProgressCallback = aArgs.mForwardProgressCallback;
94 mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
95 mDisconnectedCallback = aArgs.mDisconnectedCallback;
96
97 memset(mTimers, 0x00, sizeof(mTimers));
98 memset(&mSockAddr, 0x00, sizeof(mSockAddr));
99 mPendingCallbacks = 0;
100
101 /*
102 * Initialize buffers --- formerly in initialize_tcb.
103 */
104 {
105 uint8_t *recvbuf = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
106 size_t recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
107 uint8_t *reassbmp = recvbuf + recvbuflen;
108
109 lbuf_init(&tp.sendbuf);
110 cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
111 tp.reassbmp = reassbmp;
112 bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
113 }
114
115 tp.accepted_from = nullptr;
116 initialize_tcb(&tp);
117
118 /* Note that we do not need to zero-initialize mReceiveLinks. */
119
120 tp.instance = &aInstance;
121
122 exit:
123 return error;
124 }
125
GetInstance(void) const126 Instance &Tcp::Endpoint::GetInstance(void) const
127 {
128 return AsNonConst(AsCoreType(GetTcb().instance));
129 }
130
GetLocalAddress(void) const131 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
132 {
133 const struct tcpcb &tp = GetTcb();
134
135 static otSockAddr temp;
136
137 memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
138 temp.mPort = HostSwap16(tp.lport);
139
140 return AsCoreType(&temp);
141 }
142
GetPeerAddress(void) const143 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
144 {
145 const struct tcpcb &tp = GetTcb();
146
147 static otSockAddr temp;
148
149 memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
150 temp.mPort = HostSwap16(tp.fport);
151
152 return AsCoreType(&temp);
153 }
154
Bind(const SockAddr & aSockName)155 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
156 {
157 Error error;
158 struct tcpcb &tp = GetTcb();
159
160 VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
161 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
162
163 memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
164 tp.lport = HostSwap16(aSockName.mPort);
165 error = kErrorNone;
166
167 exit:
168 return error;
169 }
170
Connect(const SockAddr & aSockName,uint32_t aFlags)171 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
172 {
173 Error error = kErrorNone;
174 struct tcpcb & tp = GetTcb();
175 struct sockaddr_in6 sin6p;
176
177 OT_UNUSED_VARIABLE(aFlags);
178
179 VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
180
181 memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
182 sin6p.sin6_port = HostSwap16(aSockName.mPort);
183 error = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
184
185 exit:
186 return error;
187 }
188
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)189 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
190 {
191 Error error;
192 struct tcpcb &tp = GetTcb();
193
194 size_t backlogBefore = GetBacklogBytes();
195 size_t sent = aBuffer.mLength;
196
197 SuccessOrExit(error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0)));
198
199 PostCallbacksAfterSend(sent, backlogBefore);
200
201 exit:
202 return error;
203 }
204
SendByExtension(size_t aNumBytes,uint32_t aFlags)205 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
206 {
207 Error error;
208 bool moreToCome = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
209 struct tcpcb &tp = GetTcb();
210 size_t backlogBefore = GetBacklogBytes();
211 int bsdError;
212
213 VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
214
215 bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes);
216 SuccessOrExit(error = BsdErrorToOtError(bsdError));
217
218 PostCallbacksAfterSend(aNumBytes, backlogBefore);
219
220 exit:
221 return error;
222 }
223
ReceiveByReference(const otLinkedBuffer * & aBuffer)224 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
225 {
226 struct tcpcb &tp = GetTcb();
227
228 cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
229 aBuffer = &mReceiveLinks[0];
230
231 return kErrorNone;
232 }
233
/**
 * Presumably intended to rearrange the receive buffer so its contents are contiguous.
 *
 * Not supported by this implementation.
 *
 * @returns kErrorNotImplemented, always.
 */
Error Tcp::Endpoint::ReceiveContiguify(void)
{
    return kErrorNotImplemented;
}
238
CommitReceive(size_t aNumBytes,uint32_t aFlags)239 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
240 {
241 Error error = kErrorNone;
242 struct tcpcb &tp = GetTcb();
243
244 OT_UNUSED_VARIABLE(aFlags);
245
246 VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
247 VerifyOrExit(aNumBytes > 0, error = kErrorNone);
248
249 cbuf_pop(&tp.recvbuf, aNumBytes);
250 error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
251
252 exit:
253 return error;
254 }
255
SendEndOfStream(void)256 Error Tcp::Endpoint::SendEndOfStream(void)
257 {
258 struct tcpcb &tp = GetTcb();
259
260 return BsdErrorToOtError(tcp_usr_shutdown(&tp));
261 }
262
Abort(void)263 Error Tcp::Endpoint::Abort(void)
264 {
265 struct tcpcb &tp = GetTcb();
266
267 tcp_usr_abort(&tp);
268 /* connection_lost will do any reinitialization work for this socket. */
269 return kErrorNone;
270 }
271
Deinitialize(void)272 Error Tcp::Endpoint::Deinitialize(void)
273 {
274 Error error;
275
276 SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
277 SetNext(nullptr);
278
279 SuccessOrExit(error = Abort());
280
281 exit:
282 return error;
283 }
284
IsClosed(void) const285 bool Tcp::Endpoint::IsClosed(void) const
286 {
287 return GetTcb().t_state == TCP6S_CLOSED;
288 }
289
TimerFlagToIndex(uint8_t aTimerFlag)290 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
291 {
292 uint8_t timerIndex = 0;
293
294 switch (aTimerFlag)
295 {
296 case TT_DELACK:
297 timerIndex = kTimerDelack;
298 break;
299 case TT_REXMT:
300 case TT_PERSIST:
301 timerIndex = kTimerRexmtPersist;
302 break;
303 case TT_KEEP:
304 timerIndex = kTimerKeep;
305 break;
306 case TT_2MSL:
307 timerIndex = kTimer2Msl;
308 break;
309 }
310
311 return timerIndex;
312 }
313
IsTimerActive(uint8_t aTimerIndex)314 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
315 {
316 bool active = false;
317 struct tcpcb *tp = &GetTcb();
318
319 OT_ASSERT(aTimerIndex < kNumTimers);
320 switch (aTimerIndex)
321 {
322 case kTimerDelack:
323 active = tcp_timer_active(tp, TT_DELACK);
324 break;
325 case kTimerRexmtPersist:
326 active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
327 break;
328 case kTimerKeep:
329 active = tcp_timer_active(tp, TT_KEEP);
330 break;
331 case kTimer2Msl:
332 active = tcp_timer_active(tp, TT_2MSL);
333 break;
334 }
335
336 return active;
337 }
338
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)339 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
340 {
341 /*
342 * TCPlp has already set the flag for this timer to record that it's
343 * running. So, all that's left to do is record the expiry time and
344 * (re)set the main timer as appropriate.
345 */
346
347 TimeMilli now = TimerMilli::GetNow();
348 TimeMilli newFireTime = now + aDelay;
349 uint8_t timerIndex = TimerFlagToIndex(aTimerFlag);
350
351 mTimers[timerIndex] = newFireTime.GetValue();
352 LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
353 static_cast<unsigned int>(aDelay));
354
355 Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
356 }
357
/**
 * Called by TCPlp when it stops one of this endpoint's timers; only logs the event.
 *
 * @param[in] aTimerFlag  The TCPlp timer flag (TT_*) being stopped.
 */
void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
{
    /*
     * TCPlp has already cleared the timer flag before calling this. Since the
     * main timer's callback properly handles the case where no timers are
     * actually due, there's actually no work to be done here.
     */

    // aTimerFlag is used only by the log statement, which may compile out.
    OT_UNUSED_VARIABLE(aTimerFlag);

    LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
            static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
}
371
FirePendingTimers(TimeMilli aNow,bool & aHasFutureTimer,TimeMilli & aEarliestFutureExpiry)372 bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
373 {
374 bool calledUserCallback = false;
375 struct tcpcb *tp = &GetTcb();
376
377 /*
378 * NOTE: Firing a timer might potentially activate/deactivate other timers.
379 * If timers x and y expire at the same time, but the callback for timer x
380 * (for x < y) cancels or postpones timer y, should timer y's callback be
381 * called? Our answer is no, since timer x's callback has updated the
382 * TCP stack's state in such a way that it no longer expects timer y's
383 * callback to to be called. Because the TCP stack thinks that timer y
384 * has been cancelled, calling timer y's callback could potentially cause
385 * problems.
386 *
387 * If the timer callbacks set other timers, then they may not be taken
388 * into account when setting aEarliestFutureExpiry. But mTimer's expiry
389 * time will be updated by those, so we can just compare against mTimer's
390 * expiry time when resetting mTimer.
391 */
392 for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
393 {
394 if (IsTimerActive(timerIndex))
395 {
396 TimeMilli expiry(mTimers[timerIndex]);
397
398 if (expiry <= aNow)
399 {
400 /*
401 * If a user callback is called, then return true. For TCPlp,
402 * this only happens if the connection is dropped (e.g., it
403 * times out).
404 */
405 int dropped;
406
407 switch (timerIndex)
408 {
409 case kTimerDelack:
410 dropped = tcp_timer_delack(tp);
411 break;
412 case kTimerRexmtPersist:
413 if (tcp_timer_active(tp, TT_REXMT))
414 {
415 dropped = tcp_timer_rexmt(tp);
416 }
417 else
418 {
419 dropped = tcp_timer_persist(tp);
420 }
421 break;
422 case kTimerKeep:
423 dropped = tcp_timer_keep(tp);
424 break;
425 case kTimer2Msl:
426 dropped = tcp_timer_2msl(tp);
427 break;
428 }
429 VerifyOrExit(dropped == 0, calledUserCallback = true);
430 }
431 else
432 {
433 aHasFutureTimer = true;
434 aEarliestFutureExpiry = OT_MIN(aEarliestFutureExpiry, expiry);
435 }
436 }
437 }
438
439 exit:
440 return calledUserCallback;
441 }
442
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)443 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
444 {
445 size_t backlogAfter = GetBacklogBytes();
446
447 if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
448 {
449 mPendingCallbacks |= kForwardProgressCallbackFlag;
450 Get<Tcp>().mTasklet.Post();
451 }
452 }
453
FirePendingCallbacks(void)454 bool Tcp::Endpoint::FirePendingCallbacks(void)
455 {
456 bool calledUserCallback = false;
457
458 if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
459 {
460 mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
461 calledUserCallback = true;
462 }
463
464 mPendingCallbacks = 0;
465
466 return calledUserCallback;
467 }
468
GetSendBufferBytes(void) const469 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
470 {
471 const struct tcpcb &tp = GetTcb();
472 return lbuf_used_space(&tp.sendbuf);
473 }
474
GetInFlightBytes(void) const475 size_t Tcp::Endpoint::GetInFlightBytes(void) const
476 {
477 const struct tcpcb &tp = GetTcb();
478 return tp.snd_max - tp.snd_una;
479 }
480
GetBacklogBytes(void) const481 size_t Tcp::Endpoint::GetBacklogBytes(void) const
482 {
483 return GetSendBufferBytes() - GetInFlightBytes();
484 }
485
GetLocalIp6Address(void)486 Address &Tcp::Endpoint::GetLocalIp6Address(void)
487 {
488 return *reinterpret_cast<Address *>(&GetTcb().laddr);
489 }
490
GetLocalIp6Address(void) const491 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
492 {
493 return *reinterpret_cast<const Address *>(&GetTcb().laddr);
494 }
495
GetForeignIp6Address(void)496 Address &Tcp::Endpoint::GetForeignIp6Address(void)
497 {
498 return *reinterpret_cast<Address *>(&GetTcb().faddr);
499 }
500
GetForeignIp6Address(void) const501 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
502 {
503 return *reinterpret_cast<const Address *>(&GetTcb().faddr);
504 }
505
Matches(const MessageInfo & aMessageInfo) const506 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
507 {
508 bool matches = false;
509 const struct tcpcb *tp = &GetTcb();
510
511 VerifyOrExit(tp->t_state != TCP6S_CLOSED);
512 VerifyOrExit(tp->lport == HostSwap16(aMessageInfo.GetSockPort()));
513 VerifyOrExit(tp->fport == HostSwap16(aMessageInfo.GetPeerPort()));
514 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
515 VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
516
517 matches = true;
518
519 exit:
520 return matches;
521 }
522
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)523 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
524 {
525 Error error;
526 struct tcpcb_listen *tpl = &GetTcbListen();
527
528 SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
529
530 mContext = aArgs.mContext;
531 mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
532 mAcceptDoneCallback = aArgs.mAcceptDoneCallback;
533
534 memset(tpl, 0x00, sizeof(struct tcpcb_listen));
535 tpl->instance = &aInstance;
536
537 exit:
538 return error;
539 }
540
GetInstance(void) const541 Instance &Tcp::Listener::GetInstance(void) const
542 {
543 return AsNonConst(AsCoreType(GetTcbListen().instance));
544 }
545
Listen(const SockAddr & aSockName)546 Error Tcp::Listener::Listen(const SockAddr &aSockName)
547 {
548 Error error;
549 uint16_t port = HostSwap16(aSockName.mPort);
550 struct tcpcb_listen *tpl = &GetTcbListen();
551
552 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
553
554 memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
555 tpl->lport = port;
556 tpl->t_state = TCP6S_LISTEN;
557 error = kErrorNone;
558
559 exit:
560 return error;
561 }
562
StopListening(void)563 Error Tcp::Listener::StopListening(void)
564 {
565 struct tcpcb_listen *tpl = &GetTcbListen();
566
567 memset(&tpl->laddr, 0x00, sizeof(tpl->laddr));
568 tpl->lport = 0;
569 tpl->t_state = TCP6S_CLOSED;
570 return kErrorNone;
571 }
572
Deinitialize(void)573 Error Tcp::Listener::Deinitialize(void)
574 {
575 Error error;
576
577 SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
578 SetNext(nullptr);
579
580 exit:
581 return error;
582 }
583
IsClosed(void) const584 bool Tcp::Listener::IsClosed(void) const
585 {
586 return GetTcbListen().t_state == TCP6S_CLOSED;
587 }
588
GetLocalIp6Address(void)589 Address &Tcp::Listener::GetLocalIp6Address(void)
590 {
591 return *reinterpret_cast<Address *>(&GetTcbListen().laddr);
592 }
593
GetLocalIp6Address(void) const594 const Address &Tcp::Listener::GetLocalIp6Address(void) const
595 {
596 return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
597 }
598
Matches(const MessageInfo & aMessageInfo) const599 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
600 {
601 bool matches = false;
602 const struct tcpcb_listen *tpl = &GetTcbListen();
603
604 VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
605 VerifyOrExit(tpl->lport == HostSwap16(aMessageInfo.GetSockPort()));
606 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
607
608 matches = true;
609
610 exit:
611 return matches;
612 }
613
HandleMessage(ot::Ip6::Header & aIp6Header,Message & aMessage,MessageInfo & aMessageInfo)614 Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
615 {
616 Error error = kErrorNotImplemented;
617
618 /*
619 * The type uint32_t was chosen for alignment purposes. The size is the
620 * maximum TCP header size, including options.
621 */
622 uint32_t header[15];
623
624 uint16_t length = aIp6Header.GetPayloadLength();
625 uint8_t headerSize;
626
627 struct ip6_hdr *ip6Header;
628 struct tcphdr * tcpHeader;
629
630 Endpoint *endpoint;
631 Endpoint *endpointPrev;
632
633 Listener *listener;
634 Listener *listenerPrev;
635
636 VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
637 VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
638 SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
639 headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
640 VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
641 static_cast<uint16_t>(headerSize) <= length,
642 error = kErrorParse);
643 SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
644 SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));
645
646 ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
647 tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
648 tcp_fields_to_host(tcpHeader);
649
650 aMessageInfo.mPeerPort = HostSwap16(tcpHeader->th_sport);
651 aMessageInfo.mSockPort = HostSwap16(tcpHeader->th_dport);
652
653 endpoint = mEndpoints.FindMatching(aMessageInfo, endpointPrev);
654 if (endpoint != nullptr)
655 {
656 struct tcplp_signals sig;
657 int nextAction;
658 struct tcpcb * tp = &endpoint->GetTcb();
659
660 otLinkedBuffer *priorHead = lbuf_head(&tp->sendbuf);
661 size_t priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();
662
663 memset(&sig, 0x00, sizeof(sig));
664 nextAction = tcp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
665 if (nextAction != RELOOKUP_REQUIRED)
666 {
667 ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
668 ExitNow();
669 }
670 /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
671 }
672
673 listener = mListeners.FindMatching(aMessageInfo, listenerPrev);
674 if (listener != nullptr)
675 {
676 struct tcpcb_listen *tpl = &listener->GetTcbListen();
677
678 tcp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, nullptr);
679 ExitNow();
680 }
681
682 tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
683 ECONNREFUSED);
684
685 exit:
686 return error;
687 }
688
ProcessSignals(Endpoint & aEndpoint,otLinkedBuffer * aPriorHead,size_t aPriorBacklog,struct tcplp_signals & aSignals)689 void Tcp::ProcessSignals(Endpoint & aEndpoint,
690 otLinkedBuffer * aPriorHead,
691 size_t aPriorBacklog,
692 struct tcplp_signals &aSignals)
693 {
694 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
695 if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
696 {
697 aEndpoint.mEstablishedCallback(&aEndpoint);
698 }
699
700 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
701 if (aEndpoint.mSendDoneCallback != nullptr)
702 {
703 otLinkedBuffer *curr = aPriorHead;
704
705 for (uint32_t i = 0; i != aSignals.links_popped; i++)
706 {
707 otLinkedBuffer *next = curr->mNext;
708
709 VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));
710
711 curr->mNext = nullptr;
712 aEndpoint.mSendDoneCallback(&aEndpoint, curr);
713 curr = next;
714 }
715 }
716
717 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
718 if (aEndpoint.mForwardProgressCallback != nullptr)
719 {
720 size_t backlogBytes = aEndpoint.GetBacklogBytes();
721
722 if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
723 {
724 aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
725 aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
726 }
727 }
728
729 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
730 if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
731 {
732 aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
733 aEndpoint.GetTcb().reass_fin_index != -1,
734 cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
735 }
736
737 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
738 if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
739 {
740 aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
741 }
742
743 exit:
744 return;
745 }
746
BsdErrorToOtError(int aBsdError)747 Error Tcp::BsdErrorToOtError(int aBsdError)
748 {
749 Error error = kErrorFailed;
750
751 switch (aBsdError)
752 {
753 case 0:
754 error = kErrorNone;
755 break;
756 }
757
758 return error;
759 }
760
CanBind(const SockAddr & aSockName)761 bool Tcp::CanBind(const SockAddr &aSockName)
762 {
763 uint16_t port = HostSwap16(aSockName.mPort);
764 bool allowed = false;
765
766 for (Endpoint &endpoint : mEndpoints)
767 {
768 struct tcpcb *tp = &endpoint.GetTcb();
769
770 if (tp->lport == port)
771 {
772 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
773 VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
774 VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
775 }
776 }
777
778 for (Listener &listener : mListeners)
779 {
780 struct tcpcb_listen *tpl = &listener.GetTcbListen();
781
782 if (tpl->lport == port)
783 {
784 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
785 VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
786 VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
787 }
788 }
789
790 allowed = true;
791
792 exit:
793 return allowed;
794 }
795
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)796 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
797 {
798 bool success;
799
800 if (aBindAddress)
801 {
802 MessageInfo peerInfo;
803 const Netif::UnicastAddress *netifAddress;
804
805 peerInfo.Clear();
806 peerInfo.SetPeerAddr(aPeer.GetAddress());
807 netifAddress = Get<Ip6>().SelectSourceAddress(peerInfo);
808 VerifyOrExit(netifAddress != nullptr, success = false);
809 aToBind.GetAddress() = netifAddress->GetAddress();
810 }
811
812 if (aBindPort)
813 {
814 /*
815 * TODO: Use a less naive algorithm to allocate ephemeral ports. For
816 * example, see RFC 6056.
817 */
818
819 for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
820 {
821 aToBind.SetPort(mEphemeralPort);
822
823 if (mEphemeralPort == kDynamicPortMax)
824 {
825 mEphemeralPort = kDynamicPortMin;
826 }
827 else
828 {
829 mEphemeralPort++;
830 }
831
832 if (CanBind(aToBind))
833 {
834 ExitNow(success = true);
835 }
836 }
837
838 ExitNow(success = false);
839 }
840
841 success = CanBind(aToBind);
842
843 exit:
844 return success;
845 }
846
HandleTimer(Timer & aTimer)847 void Tcp::HandleTimer(Timer &aTimer)
848 {
849 OT_ASSERT(&aTimer == &aTimer.Get<Tcp>().mTimer);
850 LogDebg("Main TCP timer expired");
851 aTimer.Get<Tcp>().ProcessTimers();
852 }
853
ProcessTimers(void)854 void Tcp::ProcessTimers(void)
855 {
856 TimeMilli now = TimerMilli::GetNow();
857 bool pendingTimer;
858 TimeMilli earliestPendingTimerExpiry;
859
860 OT_ASSERT(!mTimer.IsRunning());
861
862 /*
863 * The timer callbacks could potentially set/reset/cancel timers.
864 * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
865 * this function to recompute the timer. If they did, we'd have a
866 * re-entrancy problem, where the callbacks called in this function could
867 * wind up re-entering this function in a nested call frame.
868 *
869 * In general, calling this function from Endpoint::SetTimer and
870 * Endpoint::CancelTimer could be inefficient, since those functions are
871 * called multiple times on each received TCP segment. If we want to
872 * prevent the main timer from firing except when an actual TCP timer
873 * expires, a better alternative is to reset the main timer in
874 * HandleMessage, right before processing signals. That would achieve that
875 * objective while avoiding re-entrancy issues altogether.
876 */
877 restart:
878 pendingTimer = false;
879 earliestPendingTimerExpiry = now.GetDistantFuture();
880
881 for (Endpoint &endpoint : mEndpoints)
882 {
883 if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
884 {
885 /*
886 * If a non-OpenThread callback is called --- which, in practice,
887 * happens if the connection times out and the user-defined
888 * connection lost callback is called --- then we might have to
889 * start over. The reason is that the user might deinitialize
890 * endpoints, changing the structure of the linked list. For
891 * example, if the user deinitializes both this endpoint and the
892 * next one in the linked list, then we can't continue traversing
893 * the linked list.
894 */
895 goto restart;
896 }
897 }
898
899 if (pendingTimer)
900 {
901 /*
902 * We need to use Timer::FireAtIfEarlier instead of timer::FireAt
903 * because one of the earlier callbacks might have set TCP timers,
904 * in which case `mTimer` would have been set to the earliest of those
905 * timers.
906 */
907 mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
908 LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
909 }
910 else
911 {
912 LogDebg("Did not reset main TCP timer");
913 }
914 }
915
HandleTasklet(Tasklet & aTasklet)916 void Tcp::HandleTasklet(Tasklet &aTasklet)
917 {
918 OT_ASSERT(&aTasklet == &aTasklet.Get<Tcp>().mTasklet);
919 LogDebg("TCP tasklet invoked");
920 aTasklet.Get<Tcp>().ProcessCallbacks();
921 }
922
ProcessCallbacks(void)923 void Tcp::ProcessCallbacks(void)
924 {
925 for (Endpoint &endpoint : mEndpoints)
926 {
927 if (endpoint.FirePendingCallbacks())
928 {
929 mTasklet.Post();
930 break;
931 }
932 }
933 }
934
935 } // namespace Ip6
936 } // namespace ot
937
938 /*
939 * Implement TCPlp system stubs declared in tcplp.h.
940 *
941 * Because these functions have C linkage, it is important that only one
 * definition is given for each function name, regardless of the namespace it
 * is in. For example, if we give two definitions of tcplp_sys_new_message, we
944 * will get errors, even if they are in different namespaces. To avoid
945 * confusion, I've put these functions outside of any namespace.
946 */
947
948 using namespace ot;
949 using namespace ot::Ip6;
950
951 extern "C" {
952
tcplp_sys_new_message(otInstance * aInstance)953 otMessage *tcplp_sys_new_message(otInstance *aInstance)
954 {
955 Instance &instance = AsCoreType(aInstance);
956 Message * message = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
957
958 if (message)
959 {
960 message->SetLinkSecurityEnabled(true);
961 }
962
963 return message;
964 }
965
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)966 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
967 {
968 OT_UNUSED_VARIABLE(aInstance);
969 Message &message = AsCoreType(aMessage);
970 message.Free();
971 }
972
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)973 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
974 {
975 Instance & instance = AsCoreType(aInstance);
976 Message & message = AsCoreType(aMessage);
977 MessageInfo &info = AsCoreType(aMessageInfo);
978
979 LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
980
981 IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
982 }
983
tcplp_sys_get_ticks(void)984 uint32_t tcplp_sys_get_ticks(void)
985 {
986 return TimerMilli::GetNow().GetValue();
987 }
988
tcplp_sys_get_millis(void)989 uint32_t tcplp_sys_get_millis(void)
990 {
991 return TimerMilli::GetNow().GetValue();
992 }
993
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)994 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
995 {
996 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
997 endpoint.SetTimer(aTimerFlag, aDelay);
998 }
999
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)1000 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
1001 {
1002 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1003 endpoint.CancelTimer(aTimerFlag);
1004 }
1005
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)1006 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
1007 {
1008 Tcp::Listener & listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1009 Tcp & tcp = listener.Get<Tcp>();
1010 struct tcpcb * rv = (struct tcpcb *)-1;
1011 otSockAddr addr;
1012 otTcpEndpoint * endpointPtr;
1013 otTcpIncomingConnectionAction action;
1014
1015 VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
1016
1017 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1018 addr.mPort = HostSwap16(aPort);
1019 action = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
1020
1021 VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
1022
1023 switch (action)
1024 {
1025 case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
1026 {
1027 Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
1028
1029 /*
1030 * The documentation says that the user must initialize the
1031 * endpoint before passing it here, so we do a sanity check to make
1032 * sure the endpoint is initialized and closed. That check may not
1033 * be necessary, but we do it anyway.
1034 */
1035 VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
1036
1037 rv = &endpoint.GetTcb();
1038
1039 break;
1040 }
1041 case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1042 rv = nullptr;
1043 break;
1044 case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1045 rv = (struct tcpcb *)-1;
1046 break;
1047 }
1048
1049 exit:
1050 return rv;
1051 }
1052
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1053 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1054 struct tcpcb * aAccepted,
1055 struct in6_addr * aAddr,
1056 uint16_t aPort)
1057 {
1058 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1059 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1060 Tcp & tcp = endpoint.Get<Tcp>();
1061 bool accepted = true;
1062
1063 if (listener.mAcceptDoneCallback != nullptr)
1064 {
1065 otSockAddr addr;
1066
1067 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1068 addr.mPort = HostSwap16(aPort);
1069 listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1070
1071 if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1072 {
1073 accepted = false;
1074 }
1075 }
1076
1077 return accepted;
1078 }
1079
/*
 * Called by TCPlp when a connection is torn down. If the endpoint has a
 * disconnected callback, the TCPlp error code aErrNum is mapped to an
 * otTcpDisconnectedReason and reported:
 *   CONN_LOST_NORMAL -> NORMAL, ECONNREFUSED -> REFUSED,
 *   ETIMEDOUT -> TIMED_OUT, ECONNRESET or anything else -> RESET.
 */
void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
{
    Tcp::Endpoint &         endpoint = Tcp::Endpoint::FromTcb(*aTcb);
    otTcpDisconnectedReason reason   = OT_TCP_DISCONNECTED_REASON_RESET;

    VerifyOrExit(endpoint.mDisconnectedCallback != nullptr);

    if (aErrNum == CONN_LOST_NORMAL)
    {
        reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
    }
    else if (aErrNum == ECONNREFUSED)
    {
        reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
    }
    else if (aErrNum == ETIMEDOUT)
    {
        reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
    }
    /* ECONNRESET and unrecognized codes keep the RESET default. */

    endpoint.mDisconnectedCallback(&endpoint, reason);

exit:
    return;
}
1107
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1108 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1109 {
1110 if (aNewState == TCP6S_CLOSED)
1111 {
1112 /* Re-initialize the TCB. */
1113 cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1114 aTcb->accepted_from = nullptr;
1115 initialize_tcb(aTcb);
1116 }
1117 /* Any adaptive changes to the sleep interval would go here. */
1118 }
1119
tcplp_sys_log(const char * aFormat,...)1120 void tcplp_sys_log(const char *aFormat, ...)
1121 {
1122 char buffer[128];
1123 va_list args;
1124 va_start(args, aFormat);
1125 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1126 va_end(args);
1127
1128 LogDebg(buffer);
1129 }
1130
tcplp_sys_panic(const char * aFormat,...)1131 void tcplp_sys_panic(const char *aFormat, ...)
1132 {
1133 char buffer[128];
1134 va_list args;
1135 va_start(args, aFormat);
1136 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1137 va_end(args);
1138
1139 LogCrit("%s", buffer);
1140
1141 OT_ASSERT(false);
1142 }
1143
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1144 bool tcplp_sys_autobind(otInstance * aInstance,
1145 const otSockAddr *aPeer,
1146 otSockAddr * aToBind,
1147 bool aBindAddress,
1148 bool aBindPort)
1149 {
1150 Instance &instance = AsCoreType(aInstance);
1151
1152 return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1153 aBindAddress, aBindPort);
1154 }
1155
tcplp_sys_generate_isn()1156 uint32_t tcplp_sys_generate_isn()
1157 {
1158 uint32_t isn;
1159 IgnoreError(Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&isn), sizeof(isn)));
1160 return isn;
1161 }
1162
tcplp_sys_hostswap16(uint16_t aHostPort)1163 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort)
1164 {
1165 return HostSwap16(aHostPort);
1166 }
1167
tcplp_sys_hostswap32(uint32_t aHostPort)1168 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort)
1169 {
1170 return HostSwap32(aHostPort);
1171 }
1172 }
1173
1174 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1175