• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements TCP/IPv6 sockets.
32  */
33 
34 #include "tcp6.hpp"
35 
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37 
38 #include "instance/instance.hpp"
39 
40 #include "../../third_party/tcplp/tcplp.h"
41 
42 namespace ot {
43 namespace Ip6 {
44 
45 RegisterLogModule("Tcp");
46 
47 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
48 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
49               "mTcb field in otTcpEndpoint is aligned incorrectly");
50 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
51 
52 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
53               "mTcbListen field in otTcpListener is sized incorrectly");
54 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
55               "mTcbListen field in otTcpListener is aligned incorrectly");
56 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
57 
Tcp(Instance & aInstance)58 Tcp::Tcp(Instance &aInstance)
59     : InstanceLocator(aInstance)
60     , mTimer(aInstance)
61     , mTasklet(aInstance)
62     , mEphemeralPort(kDynamicPortMin)
63 {
64     OT_UNUSED_VARIABLE(mEphemeralPort);
65 }
66 
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)67 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
68 {
69     Error         error;
70     struct tcpcb &tp = GetTcb();
71 
72     ClearAllBytes(tp);
73 
74     SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
75 
76     mContext                  = aArgs.mContext;
77     mEstablishedCallback      = aArgs.mEstablishedCallback;
78     mSendDoneCallback         = aArgs.mSendDoneCallback;
79     mForwardProgressCallback  = aArgs.mForwardProgressCallback;
80     mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
81     mDisconnectedCallback     = aArgs.mDisconnectedCallback;
82 
83     ClearAllBytes(mTimers);
84     ClearAllBytes(mSockAddr);
85     mPendingCallbacks = 0;
86 
87     /*
88      * Initialize buffers --- formerly in initialize_tcb.
89      */
90     {
91         uint8_t *recvbuf    = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
92         size_t   recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
93         uint8_t *reassbmp   = recvbuf + recvbuflen;
94 
95         lbuf_init(&tp.sendbuf);
96         cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
97         tp.reassbmp = reassbmp;
98         bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
99     }
100 
101     tp.accepted_from = nullptr;
102     initialize_tcb(&tp);
103 
104     /* Note that we do not need to zero-initialize mReceiveLinks. */
105 
106     tp.instance = &aInstance;
107 
108 exit:
109     return error;
110 }
111 
GetInstance(void) const112 Instance &Tcp::Endpoint::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcb().instance)); }
113 
GetLocalAddress(void) const114 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
115 {
116     const struct tcpcb &tp = GetTcb();
117 
118     static otSockAddr temp;
119 
120     memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
121     temp.mPort = BigEndian::HostSwap16(tp.lport);
122 
123     return AsCoreType(&temp);
124 }
125 
GetPeerAddress(void) const126 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
127 {
128     const struct tcpcb &tp = GetTcb();
129 
130     static otSockAddr temp;
131 
132     memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
133     temp.mPort = BigEndian::HostSwap16(tp.fport);
134 
135     return AsCoreType(&temp);
136 }
137 
Bind(const SockAddr & aSockName)138 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
139 {
140     Error         error;
141     struct tcpcb &tp = GetTcb();
142 
143     VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
144     VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
145 
146     memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
147     tp.lport = BigEndian::HostSwap16(aSockName.mPort);
148     error    = kErrorNone;
149 
150 exit:
151     return error;
152 }
153 
Connect(const SockAddr & aSockName,uint32_t aFlags)154 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
155 {
156     Error         error = kErrorNone;
157     struct tcpcb &tp    = GetTcb();
158 
159     VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
160 
161     if (aFlags & OT_TCP_CONNECT_NO_FAST_OPEN)
162     {
163         struct sockaddr_in6 sin6p;
164 
165         tp.t_flags &= ~TF_FASTOPEN;
166         memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
167         sin6p.sin6_port = BigEndian::HostSwap16(aSockName.mPort);
168         error           = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
169     }
170     else
171     {
172         tp.t_flags |= TF_FASTOPEN;
173 
174         /* Stash the destination address in tp. */
175         memcpy(&tp.faddr, &aSockName.mAddress, sizeof(tp.faddr));
176         tp.fport = BigEndian::HostSwap16(aSockName.mPort);
177     }
178 
179 exit:
180     return error;
181 }
182 
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)183 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
184 {
185     Error         error;
186     struct tcpcb &tp = GetTcb();
187 
188     size_t backlogBefore = GetBacklogBytes();
189     size_t sent          = aBuffer.mLength;
190 
191     struct sockaddr_in6  sin6p;
192     struct sockaddr_in6 *name = nullptr;
193 
194     if (IS_FASTOPEN(tp.t_flags))
195     {
196         memcpy(&sin6p.sin6_addr, &tp.faddr, sizeof(sin6p.sin6_addr));
197         sin6p.sin6_port = tp.fport;
198         name            = &sin6p;
199     }
200     SuccessOrExit(
201         error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0, name)));
202 
203     PostCallbacksAfterSend(sent, backlogBefore);
204 
205 exit:
206     return error;
207 }
208 
SendByExtension(size_t aNumBytes,uint32_t aFlags)209 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
210 {
211     Error         error;
212     bool          moreToCome    = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
213     struct tcpcb &tp            = GetTcb();
214     size_t        backlogBefore = GetBacklogBytes();
215     int           bsdError;
216 
217     struct sockaddr_in6  sin6p;
218     struct sockaddr_in6 *name = nullptr;
219 
220     VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
221 
222     if (IS_FASTOPEN(tp.t_flags))
223     {
224         memcpy(&sin6p.sin6_addr, &tp.faddr, sizeof(sin6p.sin6_addr));
225         sin6p.sin6_port = tp.fport;
226         name            = &sin6p;
227     }
228 
229     bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes, name);
230     SuccessOrExit(error = BsdErrorToOtError(bsdError));
231 
232     PostCallbacksAfterSend(aNumBytes, backlogBefore);
233 
234 exit:
235     return error;
236 }
237 
ReceiveByReference(const otLinkedBuffer * & aBuffer)238 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
239 {
240     struct tcpcb &tp = GetTcb();
241 
242     cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
243     aBuffer = &mReceiveLinks[0];
244 
245     return kErrorNone;
246 }
247 
ReceiveContiguify(void)248 Error Tcp::Endpoint::ReceiveContiguify(void)
249 {
250     struct tcpcb &tp = GetTcb();
251 
252     cbuf_contiguify(&tp.recvbuf, tp.reassbmp);
253 
254     return kErrorNone;
255 }
256 
CommitReceive(size_t aNumBytes,uint32_t aFlags)257 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
258 {
259     Error         error = kErrorNone;
260     struct tcpcb &tp    = GetTcb();
261 
262     OT_UNUSED_VARIABLE(aFlags);
263 
264     VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
265     VerifyOrExit(aNumBytes > 0, error = kErrorNone);
266 
267     cbuf_pop(&tp.recvbuf, aNumBytes);
268     error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
269 
270 exit:
271     return error;
272 }
273 
SendEndOfStream(void)274 Error Tcp::Endpoint::SendEndOfStream(void)
275 {
276     struct tcpcb &tp = GetTcb();
277 
278     return BsdErrorToOtError(tcp_usr_shutdown(&tp));
279 }
280 
Abort(void)281 Error Tcp::Endpoint::Abort(void)
282 {
283     struct tcpcb &tp = GetTcb();
284 
285     tcp_usr_abort(&tp);
286     /* connection_lost will do any reinitialization work for this socket. */
287     return kErrorNone;
288 }
289 
Deinitialize(void)290 Error Tcp::Endpoint::Deinitialize(void)
291 {
292     Error error;
293 
294     SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
295     SetNext(nullptr);
296 
297     SuccessOrExit(error = Abort());
298 
299 exit:
300     return error;
301 }
302 
IsClosed(void) const303 bool Tcp::Endpoint::IsClosed(void) const { return GetTcb().t_state == TCP6S_CLOSED; }
304 
TimerFlagToIndex(uint8_t aTimerFlag)305 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
306 {
307     uint8_t timerIndex = 0;
308 
309     switch (aTimerFlag)
310     {
311     case TT_DELACK:
312         timerIndex = kTimerDelack;
313         break;
314     case TT_REXMT:
315     case TT_PERSIST:
316         timerIndex = kTimerRexmtPersist;
317         break;
318     case TT_KEEP:
319         timerIndex = kTimerKeep;
320         break;
321     case TT_2MSL:
322         timerIndex = kTimer2Msl;
323         break;
324     }
325 
326     return timerIndex;
327 }
328 
IsTimerActive(uint8_t aTimerIndex)329 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
330 {
331     bool          active = false;
332     struct tcpcb *tp     = &GetTcb();
333 
334     OT_ASSERT(aTimerIndex < kNumTimers);
335     switch (aTimerIndex)
336     {
337     case kTimerDelack:
338         active = tcp_timer_active(tp, TT_DELACK);
339         break;
340     case kTimerRexmtPersist:
341         active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
342         break;
343     case kTimerKeep:
344         active = tcp_timer_active(tp, TT_KEEP);
345         break;
346     case kTimer2Msl:
347         active = tcp_timer_active(tp, TT_2MSL);
348         break;
349     }
350 
351     return active;
352 }
353 
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)354 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
355 {
356     /*
357      * TCPlp has already set the flag for this timer to record that it's
358      * running. So, all that's left to do is record the expiry time and
359      * (re)set the main timer as appropriate.
360      */
361 
362     TimeMilli now         = TimerMilli::GetNow();
363     TimeMilli newFireTime = now + aDelay;
364     uint8_t   timerIndex  = TimerFlagToIndex(aTimerFlag);
365 
366     mTimers[timerIndex] = newFireTime.GetValue();
367     LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
368             static_cast<unsigned int>(aDelay));
369 
370     Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
371 }
372 
CancelTimer(uint8_t aTimerFlag)373 void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
374 {
375     /*
376      * TCPlp has already cleared the timer flag before calling this. Since the
377      * main timer's callback properly handles the case where no timers are
378      * actually due, there's actually no work to be done here.
379      */
380 
381     OT_UNUSED_VARIABLE(aTimerFlag);
382 
383     LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
384             static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
385 }
386 
FirePendingTimers(TimeMilli aNow,bool & aHasFutureTimer,TimeMilli & aEarliestFutureExpiry)387 bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
388 {
389     bool          calledUserCallback = false;
390     struct tcpcb *tp                 = &GetTcb();
391 
392     /*
393      * NOTE: Firing a timer might potentially activate/deactivate other timers.
394      * If timers x and y expire at the same time, but the callback for timer x
395      * (for x < y) cancels or postpones timer y, should timer y's callback be
396      * called? Our answer is no, since timer x's callback has updated the
397      * TCP stack's state in such a way that it no longer expects timer y's
398      * callback to to be called. Because the TCP stack thinks that timer y
399      * has been cancelled, calling timer y's callback could potentially cause
400      * problems.
401      *
402      * If the timer callbacks set other timers, then they may not be taken
403      * into account when setting aEarliestFutureExpiry. But mTimer's expiry
404      * time will be updated by those, so we can just compare against mTimer's
405      * expiry time when resetting mTimer.
406      */
407     for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
408     {
409         if (IsTimerActive(timerIndex))
410         {
411             TimeMilli expiry(mTimers[timerIndex]);
412 
413             if (expiry <= aNow)
414             {
415                 /*
416                  * If a user callback is called, then return true. For TCPlp,
417                  * this only happens if the connection is dropped (e.g., it
418                  * times out).
419                  */
420                 int dropped = 0;
421 
422                 switch (timerIndex)
423                 {
424                 case kTimerDelack:
425                     dropped = tcp_timer_delack(tp);
426                     break;
427                 case kTimerRexmtPersist:
428                     if (tcp_timer_active(tp, TT_REXMT))
429                     {
430                         dropped = tcp_timer_rexmt(tp);
431                     }
432                     else
433                     {
434                         dropped = tcp_timer_persist(tp);
435                     }
436                     break;
437                 case kTimerKeep:
438                     dropped = tcp_timer_keep(tp);
439                     break;
440                 case kTimer2Msl:
441                     dropped = tcp_timer_2msl(tp);
442                     break;
443                 }
444                 VerifyOrExit(dropped == 0, calledUserCallback = true);
445             }
446             else
447             {
448                 aHasFutureTimer       = true;
449                 aEarliestFutureExpiry = Min(aEarliestFutureExpiry, expiry);
450             }
451         }
452     }
453 
454 exit:
455     return calledUserCallback;
456 }
457 
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)458 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
459 {
460     size_t backlogAfter = GetBacklogBytes();
461 
462     if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
463     {
464         mPendingCallbacks |= kForwardProgressCallbackFlag;
465         Get<Tcp>().mTasklet.Post();
466     }
467 }
468 
FirePendingCallbacks(void)469 bool Tcp::Endpoint::FirePendingCallbacks(void)
470 {
471     bool calledUserCallback = false;
472 
473     if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
474     {
475         mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
476         calledUserCallback = true;
477     }
478 
479     mPendingCallbacks = 0;
480 
481     return calledUserCallback;
482 }
483 
GetSendBufferBytes(void) const484 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
485 {
486     const struct tcpcb &tp = GetTcb();
487     return lbuf_used_space(&tp.sendbuf);
488 }
489 
GetInFlightBytes(void) const490 size_t Tcp::Endpoint::GetInFlightBytes(void) const
491 {
492     const struct tcpcb &tp = GetTcb();
493     return tp.snd_max - tp.snd_una;
494 }
495 
GetBacklogBytes(void) const496 size_t Tcp::Endpoint::GetBacklogBytes(void) const { return GetSendBufferBytes() - GetInFlightBytes(); }
497 
GetLocalIp6Address(void)498 Address &Tcp::Endpoint::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().laddr); }
499 
GetLocalIp6Address(void) const500 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
501 {
502     return *reinterpret_cast<const Address *>(&GetTcb().laddr);
503 }
504 
GetForeignIp6Address(void)505 Address &Tcp::Endpoint::GetForeignIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().faddr); }
506 
GetForeignIp6Address(void) const507 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
508 {
509     return *reinterpret_cast<const Address *>(&GetTcb().faddr);
510 }
511 
Matches(const MessageInfo & aMessageInfo) const512 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
513 {
514     bool                matches = false;
515     const struct tcpcb *tp      = &GetTcb();
516 
517     VerifyOrExit(tp->t_state != TCP6S_CLOSED);
518     VerifyOrExit(tp->lport == BigEndian::HostSwap16(aMessageInfo.GetSockPort()));
519     VerifyOrExit(tp->fport == BigEndian::HostSwap16(aMessageInfo.GetPeerPort()));
520     VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
521     VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
522 
523     matches = true;
524 
525 exit:
526     return matches;
527 }
528 
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)529 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
530 {
531     Error                error;
532     struct tcpcb_listen *tpl = &GetTcbListen();
533 
534     SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
535 
536     mContext             = aArgs.mContext;
537     mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
538     mAcceptDoneCallback  = aArgs.mAcceptDoneCallback;
539 
540     ClearAllBytes(*tpl);
541     tpl->instance = &aInstance;
542 
543 exit:
544     return error;
545 }
546 
GetInstance(void) const547 Instance &Tcp::Listener::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcbListen().instance)); }
548 
Listen(const SockAddr & aSockName)549 Error Tcp::Listener::Listen(const SockAddr &aSockName)
550 {
551     Error                error;
552     uint16_t             port = BigEndian::HostSwap16(aSockName.mPort);
553     struct tcpcb_listen *tpl  = &GetTcbListen();
554 
555     VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
556 
557     memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
558     tpl->lport   = port;
559     tpl->t_state = TCP6S_LISTEN;
560     error        = kErrorNone;
561 
562 exit:
563     return error;
564 }
565 
StopListening(void)566 Error Tcp::Listener::StopListening(void)
567 {
568     struct tcpcb_listen *tpl = &GetTcbListen();
569 
570     ClearAllBytes(tpl->laddr);
571     tpl->lport   = 0;
572     tpl->t_state = TCP6S_CLOSED;
573     return kErrorNone;
574 }
575 
Deinitialize(void)576 Error Tcp::Listener::Deinitialize(void)
577 {
578     Error error;
579 
580     SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
581     SetNext(nullptr);
582 
583 exit:
584     return error;
585 }
586 
IsClosed(void) const587 bool Tcp::Listener::IsClosed(void) const { return GetTcbListen().t_state == TCP6S_CLOSED; }
588 
GetLocalIp6Address(void)589 Address &Tcp::Listener::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcbListen().laddr); }
590 
GetLocalIp6Address(void) const591 const Address &Tcp::Listener::GetLocalIp6Address(void) const
592 {
593     return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
594 }
595 
Matches(const MessageInfo & aMessageInfo) const596 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
597 {
598     bool                       matches = false;
599     const struct tcpcb_listen *tpl     = &GetTcbListen();
600 
601     VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
602     VerifyOrExit(tpl->lport == BigEndian::HostSwap16(aMessageInfo.GetSockPort()));
603     VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
604 
605     matches = true;
606 
607 exit:
608     return matches;
609 }
610 
HandleMessage(ot::Ip6::Header & aIp6Header,Message & aMessage,MessageInfo & aMessageInfo)611 Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
612 {
613     Error error = kErrorNotImplemented;
614 
615     /*
616      * The type uint32_t was chosen for alignment purposes. The size is the
617      * maximum TCP header size, including options.
618      */
619     uint32_t header[15];
620 
621     uint16_t length = aIp6Header.GetPayloadLength();
622     uint8_t  headerSize;
623 
624     struct ip6_hdr *ip6Header;
625     struct tcphdr  *tcpHeader;
626 
627     Endpoint *endpoint;
628     Listener *listener;
629 
630     struct tcplp_signals sig;
631     int                  nextAction;
632 
633     VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
634     VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
635     SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
636     headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
637     VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
638                      static_cast<uint16_t>(headerSize) <= length,
639                  error = kErrorParse);
640     SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
641     SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));
642 
643     ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
644     tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
645     tcp_fields_to_host(tcpHeader);
646 
647     aMessageInfo.mPeerPort = BigEndian::HostSwap16(tcpHeader->th_sport);
648     aMessageInfo.mSockPort = BigEndian::HostSwap16(tcpHeader->th_dport);
649 
650     endpoint = mEndpoints.FindMatching(aMessageInfo);
651 
652     if (endpoint != nullptr)
653     {
654         struct tcpcb *tp = &endpoint->GetTcb();
655 
656         otLinkedBuffer *priorHead    = lbuf_head(&tp->sendbuf);
657         size_t          priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();
658 
659         ClearAllBytes(sig);
660         nextAction = tcplp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
661         if (nextAction != RELOOKUP_REQUIRED)
662         {
663             ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
664             ExitNow();
665         }
666         /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
667     }
668 
669     listener = mListeners.FindMatching(aMessageInfo);
670 
671     if (listener != nullptr)
672     {
673         struct tcpcb_listen *tpl = &listener->GetTcbListen();
674 
675         ClearAllBytes(sig);
676         nextAction = tcplp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, &sig);
677         OT_ASSERT(nextAction != RELOOKUP_REQUIRED);
678         if (sig.accepted_connection != nullptr)
679         {
680             ProcessSignals(Tcp::Endpoint::FromTcb(*sig.accepted_connection), nullptr, 0, sig);
681         }
682         ExitNow();
683     }
684 
685     tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
686                       ECONNREFUSED);
687 
688 exit:
689     return error;
690 }
691 
ProcessSignals(Endpoint & aEndpoint,otLinkedBuffer * aPriorHead,size_t aPriorBacklog,struct tcplp_signals & aSignals) const692 void Tcp::ProcessSignals(Endpoint             &aEndpoint,
693                          otLinkedBuffer       *aPriorHead,
694                          size_t                aPriorBacklog,
695                          struct tcplp_signals &aSignals) const
696 {
697     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
698     if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
699     {
700         aEndpoint.mEstablishedCallback(&aEndpoint);
701     }
702 
703     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
704     if (aEndpoint.mSendDoneCallback != nullptr)
705     {
706         otLinkedBuffer *curr = aPriorHead;
707 
708         OT_ASSERT(curr != nullptr || aSignals.links_popped == 0);
709 
710         for (uint32_t i = 0; i != aSignals.links_popped; i++)
711         {
712             otLinkedBuffer *next = curr->mNext;
713 
714             VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));
715 
716             curr->mNext = nullptr;
717             aEndpoint.mSendDoneCallback(&aEndpoint, curr);
718             curr = next;
719         }
720     }
721 
722     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
723     if (aEndpoint.mForwardProgressCallback != nullptr)
724     {
725         size_t backlogBytes = aEndpoint.GetBacklogBytes();
726 
727         if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
728         {
729             aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
730             aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
731         }
732     }
733 
734     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
735     if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
736     {
737         aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
738                                             aEndpoint.GetTcb().reass_fin_index != -1,
739                                             cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
740     }
741 
742     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
743     if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
744     {
745         aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
746     }
747 
748 exit:
749     return;
750 }
751 
BsdErrorToOtError(int aBsdError)752 Error Tcp::BsdErrorToOtError(int aBsdError)
753 {
754     Error error = kErrorFailed;
755 
756     switch (aBsdError)
757     {
758     case 0:
759         error = kErrorNone;
760         break;
761     }
762 
763     return error;
764 }
765 
CanBind(const SockAddr & aSockName)766 bool Tcp::CanBind(const SockAddr &aSockName)
767 {
768     uint16_t port    = BigEndian::HostSwap16(aSockName.mPort);
769     bool     allowed = false;
770 
771     for (Endpoint &endpoint : mEndpoints)
772     {
773         struct tcpcb *tp = &endpoint.GetTcb();
774 
775         if (tp->lport == port)
776         {
777             VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
778             VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
779             VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
780         }
781     }
782 
783     for (Listener &listener : mListeners)
784     {
785         struct tcpcb_listen *tpl = &listener.GetTcbListen();
786 
787         if (tpl->lport == port)
788         {
789             VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
790             VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
791             VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
792         }
793     }
794 
795     allowed = true;
796 
797 exit:
798     return allowed;
799 }
800 
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)801 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
802 {
803     bool success;
804 
805     if (aBindAddress)
806     {
807         const Address *source;
808 
809         source = Get<Ip6>().SelectSourceAddress(aPeer.GetAddress());
810         VerifyOrExit(source != nullptr, success = false);
811         aToBind.SetAddress(*source);
812     }
813 
814     if (aBindPort)
815     {
816         /*
817          * TODO: Use a less naive algorithm to allocate ephemeral ports. For
818          * example, see RFC 6056.
819          */
820 
821         for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
822         {
823             aToBind.SetPort(mEphemeralPort);
824 
825             if (mEphemeralPort == kDynamicPortMax)
826             {
827                 mEphemeralPort = kDynamicPortMin;
828             }
829             else
830             {
831                 mEphemeralPort++;
832             }
833 
834             if (CanBind(aToBind))
835             {
836                 ExitNow(success = true);
837             }
838         }
839 
840         ExitNow(success = false);
841     }
842 
843     success = CanBind(aToBind);
844 
845 exit:
846     return success;
847 }
848 
HandleTimer(void)849 void Tcp::HandleTimer(void)
850 {
851     TimeMilli now = TimerMilli::GetNow();
852     bool      pendingTimer;
853     TimeMilli earliestPendingTimerExpiry;
854 
855     LogDebg("Main TCP timer expired");
856 
857     /*
858      * The timer callbacks could potentially set/reset/cancel timers.
859      * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
860      * this function to recompute the timer. If they did, we'd have a
861      * re-entrancy problem, where the callbacks called in this function could
862      * wind up re-entering this function in a nested call frame.
863      *
864      * In general, calling this function from Endpoint::SetTimer and
865      * Endpoint::CancelTimer could be inefficient, since those functions are
866      * called multiple times on each received TCP segment. If we want to
867      * prevent the main timer from firing except when an actual TCP timer
868      * expires, a better alternative is to reset the main timer in
869      * HandleMessage, right before processing signals. That would achieve that
870      * objective while avoiding re-entrancy issues altogether.
871      */
872 restart:
873     pendingTimer               = false;
874     earliestPendingTimerExpiry = now.GetDistantFuture();
875 
876     for (Endpoint &endpoint : mEndpoints)
877     {
878         if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
879         {
880             /*
881              * If a non-OpenThread callback is called --- which, in practice,
882              * happens if the connection times out and the user-defined
883              * connection lost callback is called --- then we might have to
884              * start over. The reason is that the user might deinitialize
885              * endpoints, changing the structure of the linked list. For
886              * example, if the user deinitializes both this endpoint and the
887              * next one in the linked list, then we can't continue traversing
888              * the linked list.
889              */
890             goto restart;
891         }
892     }
893 
894     if (pendingTimer)
895     {
896         /*
897          * We need to use Timer::FireAtIfEarlier instead of timer::FireAt
898          * because one of the earlier callbacks might have set TCP timers,
899          * in which case `mTimer` would have been set to the earliest of those
900          * timers.
901          */
902         mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
903         LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
904     }
905     else
906     {
907         LogDebg("Did not reset main TCP timer");
908     }
909 }
910 
ProcessCallbacks(void)911 void Tcp::ProcessCallbacks(void)
912 {
913     for (Endpoint &endpoint : mEndpoints)
914     {
915         if (endpoint.FirePendingCallbacks())
916         {
917             mTasklet.Post();
918             break;
919         }
920     }
921 }
922 
923 } // namespace Ip6
924 } // namespace ot
925 
/*
 * Implement TCPlp system stubs declared in tcplp.h.
 *
 * Because these functions have C linkage, it is important that only one
 * definition is given for each function name, regardless of the namespace it
 * is in. For example, if we give two definitions of tcplp_sys_new_message, we
 * will get errors, even if they are in different namespaces. To avoid
 * confusion, I've put these functions outside of any namespace.
 */
935 
936 using namespace ot;
937 using namespace ot::Ip6;
938 
939 extern "C" {
940 
tcplp_sys_new_message(otInstance * aInstance)941 otMessage *tcplp_sys_new_message(otInstance *aInstance)
942 {
943     Instance &instance = AsCoreType(aInstance);
944     Message  *message  = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
945 
946     if (message)
947     {
948         message->SetLinkSecurityEnabled(true);
949     }
950 
951     return message;
952 }
953 
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)954 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
955 {
956     OT_UNUSED_VARIABLE(aInstance);
957     Message &message = AsCoreType(aMessage);
958     message.Free();
959 }
960 
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)961 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
962 {
963     Instance    &instance = AsCoreType(aInstance);
964     Message     &message  = AsCoreType(aMessage);
965     MessageInfo &info     = AsCoreType(aMessageInfo);
966 
967     LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
968 
969     IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
970 }
971 
tcplp_sys_get_ticks(void)972 uint32_t tcplp_sys_get_ticks(void) { return TimerMilli::GetNow().GetValue(); }
973 
tcplp_sys_get_millis(void)974 uint32_t tcplp_sys_get_millis(void) { return TimerMilli::GetNow().GetValue(); }
975 
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)976 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
977 {
978     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
979     endpoint.SetTimer(aTimerFlag, aDelay);
980 }
981 
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)982 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
983 {
984     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
985     endpoint.CancelTimer(aTimerFlag);
986 }
987 
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)988 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
989 {
990     Tcp::Listener                &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
991     Tcp                          &tcp      = listener.Get<Tcp>();
992     struct tcpcb                 *rv       = (struct tcpcb *)-1;
993     otSockAddr                    addr;
994     otTcpEndpoint                *endpointPtr;
995     otTcpIncomingConnectionAction action;
996 
997     VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
998 
999     memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1000     addr.mPort = BigEndian::HostSwap16(aPort);
1001     action     = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
1002 
1003     VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
1004 
1005     switch (action)
1006     {
1007     case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
1008     {
1009         Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
1010 
1011         /*
1012          * The documentation says that the user must initialize the
1013          * endpoint before passing it here, so we do a sanity check to make
1014          * sure the endpoint is initialized and closed. That check may not
1015          * be necessary, but we do it anyway.
1016          */
1017         VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
1018 
1019         rv = &endpoint.GetTcb();
1020 
1021         break;
1022     }
1023     case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1024         rv = nullptr;
1025         break;
1026     case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1027         rv = (struct tcpcb *)-1;
1028         break;
1029     }
1030 
1031 exit:
1032     return rv;
1033 }
1034 
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1035 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1036                                    struct tcpcb        *aAccepted,
1037                                    struct in6_addr     *aAddr,
1038                                    uint16_t             aPort)
1039 {
1040     Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1041     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1042     Tcp           &tcp      = endpoint.Get<Tcp>();
1043     bool           accepted = true;
1044 
1045     if (listener.mAcceptDoneCallback != nullptr)
1046     {
1047         otSockAddr addr;
1048 
1049         memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1050         addr.mPort = BigEndian::HostSwap16(aPort);
1051         listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1052 
1053         if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1054         {
1055             accepted = false;
1056         }
1057     }
1058 
1059     return accepted;
1060 }
1061 
tcplp_sys_connection_lost(struct tcpcb * aTcb,uint8_t aErrNum)1062 void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
1063 {
1064     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1065 
1066     if (endpoint.mDisconnectedCallback != nullptr)
1067     {
1068         otTcpDisconnectedReason reason;
1069 
1070         switch (aErrNum)
1071         {
1072         case CONN_LOST_NORMAL:
1073             reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
1074             break;
1075         case ECONNREFUSED:
1076             reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
1077             break;
1078         case ETIMEDOUT:
1079             reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
1080             break;
1081         case ECONNRESET:
1082         default:
1083             reason = OT_TCP_DISCONNECTED_REASON_RESET;
1084             break;
1085         }
1086         endpoint.mDisconnectedCallback(&endpoint, reason);
1087     }
1088 }
1089 
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1090 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1091 {
1092     if (aNewState == TCP6S_CLOSED)
1093     {
1094         /* Re-initialize the TCB. */
1095         cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1096         aTcb->accepted_from = nullptr;
1097         initialize_tcb(aTcb);
1098     }
1099     /* Any adaptive changes to the sleep interval would go here. */
1100 }
1101 
tcplp_sys_log(const char * aFormat,...)1102 void tcplp_sys_log(const char *aFormat, ...)
1103 {
1104     char    buffer[128];
1105     va_list args;
1106     va_start(args, aFormat);
1107     vsnprintf(buffer, sizeof(buffer), aFormat, args);
1108     va_end(args);
1109 
1110     LogDebg("%s", buffer);
1111 }
1112 
tcplp_sys_panic(const char * aFormat,...)1113 void tcplp_sys_panic(const char *aFormat, ...)
1114 {
1115     char    buffer[128];
1116     va_list args;
1117     va_start(args, aFormat);
1118     vsnprintf(buffer, sizeof(buffer), aFormat, args);
1119     va_end(args);
1120 
1121     LogCrit("%s", buffer);
1122 
1123     OT_ASSERT(false);
1124 }
1125 
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1126 bool tcplp_sys_autobind(otInstance       *aInstance,
1127                         const otSockAddr *aPeer,
1128                         otSockAddr       *aToBind,
1129                         bool              aBindAddress,
1130                         bool              aBindPort)
1131 {
1132     Instance &instance = AsCoreType(aInstance);
1133 
1134     return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1135                                         aBindAddress, aBindPort);
1136 }
1137 
tcplp_sys_generate_isn()1138 uint32_t tcplp_sys_generate_isn()
1139 {
1140     uint32_t isn;
1141     IgnoreError(Random::Crypto::Fill(isn));
1142     return isn;
1143 }
1144 
tcplp_sys_hostswap16(uint16_t aHostPort)1145 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort) { return BigEndian::HostSwap16(aHostPort); }
1146 
tcplp_sys_hostswap32(uint32_t aHostPort)1147 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort) { return BigEndian::HostSwap32(aHostPort); }
1148 }
1149 
1150 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1151