• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements TCP/IPv6 sockets.
32  */
33 
34 #include "openthread-core-config.h"
35 
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37 
38 #include "tcp6.hpp"
39 
40 #include "common/as_core_type.hpp"
41 #include "common/code_utils.hpp"
42 #include "common/error.hpp"
43 #include "common/instance.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/log.hpp"
46 #include "common/random.hpp"
47 #include "net/checksum.hpp"
48 #include "net/ip6.hpp"
49 #include "net/netif.hpp"
50 
51 #include "../../third_party/tcplp/tcplp.h"
52 
53 namespace ot {
54 namespace Ip6 {
55 
56 using ot::Encoding::BigEndian::HostSwap16;
57 using ot::Encoding::BigEndian::HostSwap32;
58 
59 RegisterLogModule("Tcp");
60 
61 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
62 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
63               "mTcb field in otTcpEndpoint is aligned incorrectly");
64 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
65 
66 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
67               "mTcbListen field in otTcpListener is sized incorrectly");
68 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
69               "mTcbListen field in otTcpListener is aligned incorrectly");
70 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
71 
Tcp(Instance & aInstance)72 Tcp::Tcp(Instance &aInstance)
73     : InstanceLocator(aInstance)
74     , mTimer(aInstance, Tcp::HandleTimer)
75     , mTasklet(aInstance, Tcp::HandleTasklet)
76     , mEphemeralPort(kDynamicPortMin)
77 {
78     OT_UNUSED_VARIABLE(mEphemeralPort);
79 }
80 
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)81 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
82 {
83     Error         error;
84     struct tcpcb &tp = GetTcb();
85 
86     memset(&tp, 0x00, sizeof(tp));
87 
88     SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
89 
90     mContext                  = aArgs.mContext;
91     mEstablishedCallback      = aArgs.mEstablishedCallback;
92     mSendDoneCallback         = aArgs.mSendDoneCallback;
93     mForwardProgressCallback  = aArgs.mForwardProgressCallback;
94     mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
95     mDisconnectedCallback     = aArgs.mDisconnectedCallback;
96 
97     memset(mTimers, 0x00, sizeof(mTimers));
98     memset(&mSockAddr, 0x00, sizeof(mSockAddr));
99     mPendingCallbacks = 0;
100 
101     /*
102      * Initialize buffers --- formerly in initialize_tcb.
103      */
104     {
105         uint8_t *recvbuf    = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
106         size_t   recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
107         uint8_t *reassbmp   = recvbuf + recvbuflen;
108 
109         lbuf_init(&tp.sendbuf);
110         cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
111         tp.reassbmp = reassbmp;
112         bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
113     }
114 
115     tp.accepted_from = nullptr;
116     initialize_tcb(&tp);
117 
118     /* Note that we do not need to zero-initialize mReceiveLinks. */
119 
120     tp.instance = &aInstance;
121 
122 exit:
123     return error;
124 }
125 
GetInstance(void) const126 Instance &Tcp::Endpoint::GetInstance(void) const
127 {
128     return AsNonConst(AsCoreType(GetTcb().instance));
129 }
130 
GetLocalAddress(void) const131 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
132 {
133     const struct tcpcb &tp = GetTcb();
134 
135     static otSockAddr temp;
136 
137     memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
138     temp.mPort = HostSwap16(tp.lport);
139 
140     return AsCoreType(&temp);
141 }
142 
GetPeerAddress(void) const143 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
144 {
145     const struct tcpcb &tp = GetTcb();
146 
147     static otSockAddr temp;
148 
149     memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
150     temp.mPort = HostSwap16(tp.fport);
151 
152     return AsCoreType(&temp);
153 }
154 
Bind(const SockAddr & aSockName)155 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
156 {
157     Error         error;
158     struct tcpcb &tp = GetTcb();
159 
160     VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
161     VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
162 
163     memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
164     tp.lport = HostSwap16(aSockName.mPort);
165     error    = kErrorNone;
166 
167 exit:
168     return error;
169 }
170 
Connect(const SockAddr & aSockName,uint32_t aFlags)171 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
172 {
173     Error               error = kErrorNone;
174     struct tcpcb &      tp    = GetTcb();
175     struct sockaddr_in6 sin6p;
176 
177     OT_UNUSED_VARIABLE(aFlags);
178 
179     VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
180 
181     memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
182     sin6p.sin6_port = HostSwap16(aSockName.mPort);
183     error           = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
184 
185 exit:
186     return error;
187 }
188 
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)189 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
190 {
191     Error         error;
192     struct tcpcb &tp = GetTcb();
193 
194     size_t backlogBefore = GetBacklogBytes();
195     size_t sent          = aBuffer.mLength;
196 
197     SuccessOrExit(error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0)));
198 
199     PostCallbacksAfterSend(sent, backlogBefore);
200 
201 exit:
202     return error;
203 }
204 
SendByExtension(size_t aNumBytes,uint32_t aFlags)205 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
206 {
207     Error         error;
208     bool          moreToCome    = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
209     struct tcpcb &tp            = GetTcb();
210     size_t        backlogBefore = GetBacklogBytes();
211     int           bsdError;
212 
213     VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
214 
215     bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes);
216     SuccessOrExit(error = BsdErrorToOtError(bsdError));
217 
218     PostCallbacksAfterSend(aNumBytes, backlogBefore);
219 
220 exit:
221     return error;
222 }
223 
ReceiveByReference(const otLinkedBuffer * & aBuffer)224 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
225 {
226     struct tcpcb &tp = GetTcb();
227 
228     cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
229     aBuffer = &mReceiveLinks[0];
230 
231     return kErrorNone;
232 }
233 
ReceiveContiguify(void)234 Error Tcp::Endpoint::ReceiveContiguify(void)
235 {
236     return kErrorNotImplemented;
237 }
238 
CommitReceive(size_t aNumBytes,uint32_t aFlags)239 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
240 {
241     Error         error = kErrorNone;
242     struct tcpcb &tp    = GetTcb();
243 
244     OT_UNUSED_VARIABLE(aFlags);
245 
246     VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
247     VerifyOrExit(aNumBytes > 0, error = kErrorNone);
248 
249     cbuf_pop(&tp.recvbuf, aNumBytes);
250     error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
251 
252 exit:
253     return error;
254 }
255 
SendEndOfStream(void)256 Error Tcp::Endpoint::SendEndOfStream(void)
257 {
258     struct tcpcb &tp = GetTcb();
259 
260     return BsdErrorToOtError(tcp_usr_shutdown(&tp));
261 }
262 
Abort(void)263 Error Tcp::Endpoint::Abort(void)
264 {
265     struct tcpcb &tp = GetTcb();
266 
267     tcp_usr_abort(&tp);
268     /* connection_lost will do any reinitialization work for this socket. */
269     return kErrorNone;
270 }
271 
Deinitialize(void)272 Error Tcp::Endpoint::Deinitialize(void)
273 {
274     Error error;
275 
276     SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
277     SetNext(nullptr);
278 
279     SuccessOrExit(error = Abort());
280 
281 exit:
282     return error;
283 }
284 
IsClosed(void) const285 bool Tcp::Endpoint::IsClosed(void) const
286 {
287     return GetTcb().t_state == TCP6S_CLOSED;
288 }
289 
TimerFlagToIndex(uint8_t aTimerFlag)290 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
291 {
292     uint8_t timerIndex = 0;
293 
294     switch (aTimerFlag)
295     {
296     case TT_DELACK:
297         timerIndex = kTimerDelack;
298         break;
299     case TT_REXMT:
300     case TT_PERSIST:
301         timerIndex = kTimerRexmtPersist;
302         break;
303     case TT_KEEP:
304         timerIndex = kTimerKeep;
305         break;
306     case TT_2MSL:
307         timerIndex = kTimer2Msl;
308         break;
309     }
310 
311     return timerIndex;
312 }
313 
IsTimerActive(uint8_t aTimerIndex)314 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
315 {
316     bool          active = false;
317     struct tcpcb *tp     = &GetTcb();
318 
319     OT_ASSERT(aTimerIndex < kNumTimers);
320     switch (aTimerIndex)
321     {
322     case kTimerDelack:
323         active = tcp_timer_active(tp, TT_DELACK);
324         break;
325     case kTimerRexmtPersist:
326         active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
327         break;
328     case kTimerKeep:
329         active = tcp_timer_active(tp, TT_KEEP);
330         break;
331     case kTimer2Msl:
332         active = tcp_timer_active(tp, TT_2MSL);
333         break;
334     }
335 
336     return active;
337 }
338 
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)339 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
340 {
341     /*
342      * TCPlp has already set the flag for this timer to record that it's
343      * running. So, all that's left to do is record the expiry time and
344      * (re)set the main timer as appropriate.
345      */
346 
347     TimeMilli now         = TimerMilli::GetNow();
348     TimeMilli newFireTime = now + aDelay;
349     uint8_t   timerIndex  = TimerFlagToIndex(aTimerFlag);
350 
351     mTimers[timerIndex] = newFireTime.GetValue();
352     LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
353             static_cast<unsigned int>(aDelay));
354 
355     Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
356 }
357 
CancelTimer(uint8_t aTimerFlag)358 void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
359 {
360     /*
361      * TCPlp has already cleared the timer flag before calling this. Since the
362      * main timer's callback properly handles the case where no timers are
363      * actually due, there's actually no work to be done here.
364      */
365 
366     OT_UNUSED_VARIABLE(aTimerFlag);
367 
368     LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
369             static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
370 }
371 
FirePendingTimers(TimeMilli aNow,bool & aHasFutureTimer,TimeMilli & aEarliestFutureExpiry)372 bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
373 {
374     bool          calledUserCallback = false;
375     struct tcpcb *tp                 = &GetTcb();
376 
377     /*
378      * NOTE: Firing a timer might potentially activate/deactivate other timers.
379      * If timers x and y expire at the same time, but the callback for timer x
380      * (for x < y) cancels or postpones timer y, should timer y's callback be
381      * called? Our answer is no, since timer x's callback has updated the
382      * TCP stack's state in such a way that it no longer expects timer y's
383      * callback to to be called. Because the TCP stack thinks that timer y
384      * has been cancelled, calling timer y's callback could potentially cause
385      * problems.
386      *
387      * If the timer callbacks set other timers, then they may not be taken
388      * into account when setting aEarliestFutureExpiry. But mTimer's expiry
389      * time will be updated by those, so we can just compare against mTimer's
390      * expiry time when resetting mTimer.
391      */
392     for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
393     {
394         if (IsTimerActive(timerIndex))
395         {
396             TimeMilli expiry(mTimers[timerIndex]);
397 
398             if (expiry <= aNow)
399             {
400                 /*
401                  * If a user callback is called, then return true. For TCPlp,
402                  * this only happens if the connection is dropped (e.g., it
403                  * times out).
404                  */
405                 int dropped;
406 
407                 switch (timerIndex)
408                 {
409                 case kTimerDelack:
410                     dropped = tcp_timer_delack(tp);
411                     break;
412                 case kTimerRexmtPersist:
413                     if (tcp_timer_active(tp, TT_REXMT))
414                     {
415                         dropped = tcp_timer_rexmt(tp);
416                     }
417                     else
418                     {
419                         dropped = tcp_timer_persist(tp);
420                     }
421                     break;
422                 case kTimerKeep:
423                     dropped = tcp_timer_keep(tp);
424                     break;
425                 case kTimer2Msl:
426                     dropped = tcp_timer_2msl(tp);
427                     break;
428                 }
429                 VerifyOrExit(dropped == 0, calledUserCallback = true);
430             }
431             else
432             {
433                 aHasFutureTimer       = true;
434                 aEarliestFutureExpiry = OT_MIN(aEarliestFutureExpiry, expiry);
435             }
436         }
437     }
438 
439 exit:
440     return calledUserCallback;
441 }
442 
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)443 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
444 {
445     size_t backlogAfter = GetBacklogBytes();
446 
447     if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
448     {
449         mPendingCallbacks |= kForwardProgressCallbackFlag;
450         Get<Tcp>().mTasklet.Post();
451     }
452 }
453 
FirePendingCallbacks(void)454 bool Tcp::Endpoint::FirePendingCallbacks(void)
455 {
456     bool calledUserCallback = false;
457 
458     if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
459     {
460         mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
461         calledUserCallback = true;
462     }
463 
464     mPendingCallbacks = 0;
465 
466     return calledUserCallback;
467 }
468 
GetSendBufferBytes(void) const469 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
470 {
471     const struct tcpcb &tp = GetTcb();
472     return lbuf_used_space(&tp.sendbuf);
473 }
474 
GetInFlightBytes(void) const475 size_t Tcp::Endpoint::GetInFlightBytes(void) const
476 {
477     const struct tcpcb &tp = GetTcb();
478     return tp.snd_max - tp.snd_una;
479 }
480 
GetBacklogBytes(void) const481 size_t Tcp::Endpoint::GetBacklogBytes(void) const
482 {
483     return GetSendBufferBytes() - GetInFlightBytes();
484 }
485 
GetLocalIp6Address(void)486 Address &Tcp::Endpoint::GetLocalIp6Address(void)
487 {
488     return *reinterpret_cast<Address *>(&GetTcb().laddr);
489 }
490 
GetLocalIp6Address(void) const491 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
492 {
493     return *reinterpret_cast<const Address *>(&GetTcb().laddr);
494 }
495 
GetForeignIp6Address(void)496 Address &Tcp::Endpoint::GetForeignIp6Address(void)
497 {
498     return *reinterpret_cast<Address *>(&GetTcb().faddr);
499 }
500 
GetForeignIp6Address(void) const501 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
502 {
503     return *reinterpret_cast<const Address *>(&GetTcb().faddr);
504 }
505 
Matches(const MessageInfo & aMessageInfo) const506 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
507 {
508     bool                matches = false;
509     const struct tcpcb *tp      = &GetTcb();
510 
511     VerifyOrExit(tp->t_state != TCP6S_CLOSED);
512     VerifyOrExit(tp->lport == HostSwap16(aMessageInfo.GetSockPort()));
513     VerifyOrExit(tp->fport == HostSwap16(aMessageInfo.GetPeerPort()));
514     VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
515     VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
516 
517     matches = true;
518 
519 exit:
520     return matches;
521 }
522 
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)523 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
524 {
525     Error                error;
526     struct tcpcb_listen *tpl = &GetTcbListen();
527 
528     SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
529 
530     mContext             = aArgs.mContext;
531     mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
532     mAcceptDoneCallback  = aArgs.mAcceptDoneCallback;
533 
534     memset(tpl, 0x00, sizeof(struct tcpcb_listen));
535     tpl->instance = &aInstance;
536 
537 exit:
538     return error;
539 }
540 
GetInstance(void) const541 Instance &Tcp::Listener::GetInstance(void) const
542 {
543     return AsNonConst(AsCoreType(GetTcbListen().instance));
544 }
545 
Listen(const SockAddr & aSockName)546 Error Tcp::Listener::Listen(const SockAddr &aSockName)
547 {
548     Error                error;
549     uint16_t             port = HostSwap16(aSockName.mPort);
550     struct tcpcb_listen *tpl  = &GetTcbListen();
551 
552     VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
553 
554     memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
555     tpl->lport   = port;
556     tpl->t_state = TCP6S_LISTEN;
557     error        = kErrorNone;
558 
559 exit:
560     return error;
561 }
562 
StopListening(void)563 Error Tcp::Listener::StopListening(void)
564 {
565     struct tcpcb_listen *tpl = &GetTcbListen();
566 
567     memset(&tpl->laddr, 0x00, sizeof(tpl->laddr));
568     tpl->lport   = 0;
569     tpl->t_state = TCP6S_CLOSED;
570     return kErrorNone;
571 }
572 
Deinitialize(void)573 Error Tcp::Listener::Deinitialize(void)
574 {
575     Error error;
576 
577     SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
578     SetNext(nullptr);
579 
580 exit:
581     return error;
582 }
583 
IsClosed(void) const584 bool Tcp::Listener::IsClosed(void) const
585 {
586     return GetTcbListen().t_state == TCP6S_CLOSED;
587 }
588 
GetLocalIp6Address(void)589 Address &Tcp::Listener::GetLocalIp6Address(void)
590 {
591     return *reinterpret_cast<Address *>(&GetTcbListen().laddr);
592 }
593 
GetLocalIp6Address(void) const594 const Address &Tcp::Listener::GetLocalIp6Address(void) const
595 {
596     return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
597 }
598 
Matches(const MessageInfo & aMessageInfo) const599 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
600 {
601     bool                       matches = false;
602     const struct tcpcb_listen *tpl     = &GetTcbListen();
603 
604     VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
605     VerifyOrExit(tpl->lport == HostSwap16(aMessageInfo.GetSockPort()));
606     VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
607 
608     matches = true;
609 
610 exit:
611     return matches;
612 }
613 
HandleMessage(ot::Ip6::Header & aIp6Header,Message & aMessage,MessageInfo & aMessageInfo)614 Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
615 {
616     Error error = kErrorNotImplemented;
617 
618     /*
619      * The type uint32_t was chosen for alignment purposes. The size is the
620      * maximum TCP header size, including options.
621      */
622     uint32_t header[15];
623 
624     uint16_t length = aIp6Header.GetPayloadLength();
625     uint8_t  headerSize;
626 
627     struct ip6_hdr *ip6Header;
628     struct tcphdr * tcpHeader;
629 
630     Endpoint *endpoint;
631     Endpoint *endpointPrev;
632 
633     Listener *listener;
634     Listener *listenerPrev;
635 
636     VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
637     VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
638     SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
639     headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
640     VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
641                      static_cast<uint16_t>(headerSize) <= length,
642                  error = kErrorParse);
643     SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
644     SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));
645 
646     ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
647     tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
648     tcp_fields_to_host(tcpHeader);
649 
650     aMessageInfo.mPeerPort = HostSwap16(tcpHeader->th_sport);
651     aMessageInfo.mSockPort = HostSwap16(tcpHeader->th_dport);
652 
653     endpoint = mEndpoints.FindMatching(aMessageInfo, endpointPrev);
654     if (endpoint != nullptr)
655     {
656         struct tcplp_signals sig;
657         int                  nextAction;
658         struct tcpcb *       tp = &endpoint->GetTcb();
659 
660         otLinkedBuffer *priorHead    = lbuf_head(&tp->sendbuf);
661         size_t          priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();
662 
663         memset(&sig, 0x00, sizeof(sig));
664         nextAction = tcp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
665         if (nextAction != RELOOKUP_REQUIRED)
666         {
667             ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
668             ExitNow();
669         }
670         /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
671     }
672 
673     listener = mListeners.FindMatching(aMessageInfo, listenerPrev);
674     if (listener != nullptr)
675     {
676         struct tcpcb_listen *tpl = &listener->GetTcbListen();
677 
678         tcp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, nullptr);
679         ExitNow();
680     }
681 
682     tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
683                       ECONNREFUSED);
684 
685 exit:
686     return error;
687 }
688 
ProcessSignals(Endpoint & aEndpoint,otLinkedBuffer * aPriorHead,size_t aPriorBacklog,struct tcplp_signals & aSignals)689 void Tcp::ProcessSignals(Endpoint &            aEndpoint,
690                          otLinkedBuffer *      aPriorHead,
691                          size_t                aPriorBacklog,
692                          struct tcplp_signals &aSignals)
693 {
694     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
695     if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
696     {
697         aEndpoint.mEstablishedCallback(&aEndpoint);
698     }
699 
700     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
701     if (aEndpoint.mSendDoneCallback != nullptr)
702     {
703         otLinkedBuffer *curr = aPriorHead;
704 
705         for (uint32_t i = 0; i != aSignals.links_popped; i++)
706         {
707             otLinkedBuffer *next = curr->mNext;
708 
709             VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));
710 
711             curr->mNext = nullptr;
712             aEndpoint.mSendDoneCallback(&aEndpoint, curr);
713             curr = next;
714         }
715     }
716 
717     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
718     if (aEndpoint.mForwardProgressCallback != nullptr)
719     {
720         size_t backlogBytes = aEndpoint.GetBacklogBytes();
721 
722         if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
723         {
724             aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
725             aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
726         }
727     }
728 
729     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
730     if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
731     {
732         aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
733                                             aEndpoint.GetTcb().reass_fin_index != -1,
734                                             cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
735     }
736 
737     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
738     if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
739     {
740         aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
741     }
742 
743 exit:
744     return;
745 }
746 
BsdErrorToOtError(int aBsdError)747 Error Tcp::BsdErrorToOtError(int aBsdError)
748 {
749     Error error = kErrorFailed;
750 
751     switch (aBsdError)
752     {
753     case 0:
754         error = kErrorNone;
755         break;
756     }
757 
758     return error;
759 }
760 
CanBind(const SockAddr & aSockName)761 bool Tcp::CanBind(const SockAddr &aSockName)
762 {
763     uint16_t port    = HostSwap16(aSockName.mPort);
764     bool     allowed = false;
765 
766     for (Endpoint &endpoint : mEndpoints)
767     {
768         struct tcpcb *tp = &endpoint.GetTcb();
769 
770         if (tp->lport == port)
771         {
772             VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
773             VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
774             VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
775         }
776     }
777 
778     for (Listener &listener : mListeners)
779     {
780         struct tcpcb_listen *tpl = &listener.GetTcbListen();
781 
782         if (tpl->lport == port)
783         {
784             VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
785             VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
786             VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
787         }
788     }
789 
790     allowed = true;
791 
792 exit:
793     return allowed;
794 }
795 
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)796 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
797 {
798     bool success;
799 
800     if (aBindAddress)
801     {
802         MessageInfo                  peerInfo;
803         const Netif::UnicastAddress *netifAddress;
804 
805         peerInfo.Clear();
806         peerInfo.SetPeerAddr(aPeer.GetAddress());
807         netifAddress = Get<Ip6>().SelectSourceAddress(peerInfo);
808         VerifyOrExit(netifAddress != nullptr, success = false);
809         aToBind.GetAddress() = netifAddress->GetAddress();
810     }
811 
812     if (aBindPort)
813     {
814         /*
815          * TODO: Use a less naive algorithm to allocate ephemeral ports. For
816          * example, see RFC 6056.
817          */
818 
819         for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
820         {
821             aToBind.SetPort(mEphemeralPort);
822 
823             if (mEphemeralPort == kDynamicPortMax)
824             {
825                 mEphemeralPort = kDynamicPortMin;
826             }
827             else
828             {
829                 mEphemeralPort++;
830             }
831 
832             if (CanBind(aToBind))
833             {
834                 ExitNow(success = true);
835             }
836         }
837 
838         ExitNow(success = false);
839     }
840 
841     success = CanBind(aToBind);
842 
843 exit:
844     return success;
845 }
846 
HandleTimer(Timer & aTimer)847 void Tcp::HandleTimer(Timer &aTimer)
848 {
849     OT_ASSERT(&aTimer == &aTimer.Get<Tcp>().mTimer);
850     LogDebg("Main TCP timer expired");
851     aTimer.Get<Tcp>().ProcessTimers();
852 }
853 
ProcessTimers(void)854 void Tcp::ProcessTimers(void)
855 {
856     TimeMilli now = TimerMilli::GetNow();
857     bool      pendingTimer;
858     TimeMilli earliestPendingTimerExpiry;
859 
860     OT_ASSERT(!mTimer.IsRunning());
861 
862     /*
863      * The timer callbacks could potentially set/reset/cancel timers.
864      * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
865      * this function to recompute the timer. If they did, we'd have a
866      * re-entrancy problem, where the callbacks called in this function could
867      * wind up re-entering this function in a nested call frame.
868      *
869      * In general, calling this function from Endpoint::SetTimer and
870      * Endpoint::CancelTimer could be inefficient, since those functions are
871      * called multiple times on each received TCP segment. If we want to
872      * prevent the main timer from firing except when an actual TCP timer
873      * expires, a better alternative is to reset the main timer in
874      * HandleMessage, right before processing signals. That would achieve that
875      * objective while avoiding re-entrancy issues altogether.
876      */
877 restart:
878     pendingTimer               = false;
879     earliestPendingTimerExpiry = now.GetDistantFuture();
880 
881     for (Endpoint &endpoint : mEndpoints)
882     {
883         if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
884         {
885             /*
886              * If a non-OpenThread callback is called --- which, in practice,
887              * happens if the connection times out and the user-defined
888              * connection lost callback is called --- then we might have to
889              * start over. The reason is that the user might deinitialize
890              * endpoints, changing the structure of the linked list. For
891              * example, if the user deinitializes both this endpoint and the
892              * next one in the linked list, then we can't continue traversing
893              * the linked list.
894              */
895             goto restart;
896         }
897     }
898 
899     if (pendingTimer)
900     {
901         /*
902          * We need to use Timer::FireAtIfEarlier instead of timer::FireAt
903          * because one of the earlier callbacks might have set TCP timers,
904          * in which case `mTimer` would have been set to the earliest of those
905          * timers.
906          */
907         mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
908         LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
909     }
910     else
911     {
912         LogDebg("Did not reset main TCP timer");
913     }
914 }
915 
HandleTasklet(Tasklet & aTasklet)916 void Tcp::HandleTasklet(Tasklet &aTasklet)
917 {
918     OT_ASSERT(&aTasklet == &aTasklet.Get<Tcp>().mTasklet);
919     LogDebg("TCP tasklet invoked");
920     aTasklet.Get<Tcp>().ProcessCallbacks();
921 }
922 
ProcessCallbacks(void)923 void Tcp::ProcessCallbacks(void)
924 {
925     for (Endpoint &endpoint : mEndpoints)
926     {
927         if (endpoint.FirePendingCallbacks())
928         {
929             mTasklet.Post();
930             break;
931         }
932     }
933 }
934 
935 } // namespace Ip6
936 } // namespace ot
937 
/*
 * Implement TCPlp system stubs declared in tcplp.h.
 *
 * Because these functions have C linkage, it is important that only one
 * definition is given for each function name, regardless of the namespace it
 * is in. For example, if we give two definitions of tcplp_sys_new_message, we
 * will get errors, even if they are in different namespaces. To avoid
 * confusion, these functions are placed outside of any namespace.
 */
947 
948 using namespace ot;
949 using namespace ot::Ip6;
950 
951 extern "C" {
952 
tcplp_sys_new_message(otInstance * aInstance)953 otMessage *tcplp_sys_new_message(otInstance *aInstance)
954 {
955     Instance &instance = AsCoreType(aInstance);
956     Message * message  = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
957 
958     if (message)
959     {
960         message->SetLinkSecurityEnabled(true);
961     }
962 
963     return message;
964 }
965 
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)966 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
967 {
968     OT_UNUSED_VARIABLE(aInstance);
969     Message &message = AsCoreType(aMessage);
970     message.Free();
971 }
972 
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)973 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
974 {
975     Instance &   instance = AsCoreType(aInstance);
976     Message &    message  = AsCoreType(aMessage);
977     MessageInfo &info     = AsCoreType(aMessageInfo);
978 
979     LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
980 
981     IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
982 }
983 
tcplp_sys_get_ticks(void)984 uint32_t tcplp_sys_get_ticks(void)
985 {
986     return TimerMilli::GetNow().GetValue();
987 }
988 
tcplp_sys_get_millis(void)989 uint32_t tcplp_sys_get_millis(void)
990 {
991     return TimerMilli::GetNow().GetValue();
992 }
993 
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)994 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
995 {
996     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
997     endpoint.SetTimer(aTimerFlag, aDelay);
998 }
999 
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)1000 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
1001 {
1002     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1003     endpoint.CancelTimer(aTimerFlag);
1004 }
1005 
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)1006 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
1007 {
1008     Tcp::Listener &               listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1009     Tcp &                         tcp      = listener.Get<Tcp>();
1010     struct tcpcb *                rv       = (struct tcpcb *)-1;
1011     otSockAddr                    addr;
1012     otTcpEndpoint *               endpointPtr;
1013     otTcpIncomingConnectionAction action;
1014 
1015     VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
1016 
1017     memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1018     addr.mPort = HostSwap16(aPort);
1019     action     = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
1020 
1021     VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
1022 
1023     switch (action)
1024     {
1025     case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
1026     {
1027         Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
1028 
1029         /*
1030          * The documentation says that the user must initialize the
1031          * endpoint before passing it here, so we do a sanity check to make
1032          * sure the endpoint is initialized and closed. That check may not
1033          * be necessary, but we do it anyway.
1034          */
1035         VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
1036 
1037         rv = &endpoint.GetTcb();
1038 
1039         break;
1040     }
1041     case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1042         rv = nullptr;
1043         break;
1044     case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1045         rv = (struct tcpcb *)-1;
1046         break;
1047     }
1048 
1049 exit:
1050     return rv;
1051 }
1052 
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1053 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1054                                    struct tcpcb *       aAccepted,
1055                                    struct in6_addr *    aAddr,
1056                                    uint16_t             aPort)
1057 {
1058     Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1059     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1060     Tcp &          tcp      = endpoint.Get<Tcp>();
1061     bool           accepted = true;
1062 
1063     if (listener.mAcceptDoneCallback != nullptr)
1064     {
1065         otSockAddr addr;
1066 
1067         memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1068         addr.mPort = HostSwap16(aPort);
1069         listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1070 
1071         if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1072         {
1073             accepted = false;
1074         }
1075     }
1076 
1077     return accepted;
1078 }
1079 
tcplp_sys_connection_lost(struct tcpcb * aTcb,uint8_t aErrNum)1080 void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
1081 {
1082     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1083 
1084     if (endpoint.mDisconnectedCallback != nullptr)
1085     {
1086         otTcpDisconnectedReason reason;
1087 
1088         switch (aErrNum)
1089         {
1090         case CONN_LOST_NORMAL:
1091             reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
1092             break;
1093         case ECONNREFUSED:
1094             reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
1095             break;
1096         case ETIMEDOUT:
1097             reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
1098             break;
1099         case ECONNRESET:
1100         default:
1101             reason = OT_TCP_DISCONNECTED_REASON_RESET;
1102             break;
1103         }
1104         endpoint.mDisconnectedCallback(&endpoint, reason);
1105     }
1106 }
1107 
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1108 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1109 {
1110     if (aNewState == TCP6S_CLOSED)
1111     {
1112         /* Re-initialize the TCB. */
1113         cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1114         aTcb->accepted_from = nullptr;
1115         initialize_tcb(aTcb);
1116     }
1117     /* Any adaptive changes to the sleep interval would go here. */
1118 }
1119 
tcplp_sys_log(const char * aFormat,...)1120 void tcplp_sys_log(const char *aFormat, ...)
1121 {
1122     char    buffer[128];
1123     va_list args;
1124     va_start(args, aFormat);
1125     vsnprintf(buffer, sizeof(buffer), aFormat, args);
1126     va_end(args);
1127 
1128     LogDebg(buffer);
1129 }
1130 
tcplp_sys_panic(const char * aFormat,...)1131 void tcplp_sys_panic(const char *aFormat, ...)
1132 {
1133     char    buffer[128];
1134     va_list args;
1135     va_start(args, aFormat);
1136     vsnprintf(buffer, sizeof(buffer), aFormat, args);
1137     va_end(args);
1138 
1139     LogCrit("%s", buffer);
1140 
1141     OT_ASSERT(false);
1142 }
1143 
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1144 bool tcplp_sys_autobind(otInstance *      aInstance,
1145                         const otSockAddr *aPeer,
1146                         otSockAddr *      aToBind,
1147                         bool              aBindAddress,
1148                         bool              aBindPort)
1149 {
1150     Instance &instance = AsCoreType(aInstance);
1151 
1152     return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1153                                         aBindAddress, aBindPort);
1154 }
1155 
tcplp_sys_generate_isn()1156 uint32_t tcplp_sys_generate_isn()
1157 {
1158     uint32_t isn;
1159     IgnoreError(Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&isn), sizeof(isn)));
1160     return isn;
1161 }
1162 
tcplp_sys_hostswap16(uint16_t aHostPort)1163 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort)
1164 {
1165     return HostSwap16(aHostPort);
1166 }
1167 
tcplp_sys_hostswap32(uint32_t aHostPort)1168 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort)
1169 {
1170     return HostSwap32(aHostPort);
1171 }
1172 }
1173 
1174 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1175