1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements TCP/IPv6 sockets.
32 */
33
34 #include "tcp6.hpp"
35
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37
38 #include "instance/instance.hpp"
39
40 #include "../../third_party/tcplp/tcplp.h"
41
42 namespace ot {
43 namespace Ip6 {
44
45 RegisterLogModule("Tcp");
46
47 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
48 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
49 "mTcb field in otTcpEndpoint is aligned incorrectly");
50 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
51
52 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
53 "mTcbListen field in otTcpListener is sized incorrectly");
54 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
55 "mTcbListen field in otTcpListener is aligned incorrectly");
56 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
57
Tcp(Instance & aInstance)58 Tcp::Tcp(Instance &aInstance)
59 : InstanceLocator(aInstance)
60 , mTimer(aInstance)
61 , mTasklet(aInstance)
62 , mEphemeralPort(kDynamicPortMin)
63 {
64 OT_UNUSED_VARIABLE(mEphemeralPort);
65 }
66
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)67 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
68 {
69 Error error;
70 struct tcpcb &tp = GetTcb();
71
72 ClearAllBytes(tp);
73
74 SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
75
76 mContext = aArgs.mContext;
77 mEstablishedCallback = aArgs.mEstablishedCallback;
78 mSendDoneCallback = aArgs.mSendDoneCallback;
79 mForwardProgressCallback = aArgs.mForwardProgressCallback;
80 mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
81 mDisconnectedCallback = aArgs.mDisconnectedCallback;
82
83 ClearAllBytes(mTimers);
84 ClearAllBytes(mSockAddr);
85 mPendingCallbacks = 0;
86
87 /*
88 * Initialize buffers --- formerly in initialize_tcb.
89 */
90 {
91 uint8_t *recvbuf = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
92 size_t recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
93 uint8_t *reassbmp = recvbuf + recvbuflen;
94
95 lbuf_init(&tp.sendbuf);
96 cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
97 tp.reassbmp = reassbmp;
98 bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
99 }
100
101 tp.accepted_from = nullptr;
102 initialize_tcb(&tp);
103
104 /* Note that we do not need to zero-initialize mReceiveLinks. */
105
106 tp.instance = &aInstance;
107
108 exit:
109 return error;
110 }
111
GetInstance(void) const112 Instance &Tcp::Endpoint::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcb().instance)); }
113
GetLocalAddress(void) const114 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
115 {
116 const struct tcpcb &tp = GetTcb();
117
118 static otSockAddr temp;
119
120 memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
121 temp.mPort = BigEndian::HostSwap16(tp.lport);
122
123 return AsCoreType(&temp);
124 }
125
GetPeerAddress(void) const126 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
127 {
128 const struct tcpcb &tp = GetTcb();
129
130 static otSockAddr temp;
131
132 memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
133 temp.mPort = BigEndian::HostSwap16(tp.fport);
134
135 return AsCoreType(&temp);
136 }
137
Bind(const SockAddr & aSockName)138 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
139 {
140 Error error;
141 struct tcpcb &tp = GetTcb();
142
143 VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
144 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
145
146 memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
147 tp.lport = BigEndian::HostSwap16(aSockName.mPort);
148 error = kErrorNone;
149
150 exit:
151 return error;
152 }
153
Connect(const SockAddr & aSockName,uint32_t aFlags)154 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
155 {
156 Error error = kErrorNone;
157 struct tcpcb &tp = GetTcb();
158
159 VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
160
161 if (aFlags & OT_TCP_CONNECT_NO_FAST_OPEN)
162 {
163 struct sockaddr_in6 sin6p;
164
165 tp.t_flags &= ~TF_FASTOPEN;
166 memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
167 sin6p.sin6_port = BigEndian::HostSwap16(aSockName.mPort);
168 error = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
169 }
170 else
171 {
172 tp.t_flags |= TF_FASTOPEN;
173
174 /* Stash the destination address in tp. */
175 memcpy(&tp.faddr, &aSockName.mAddress, sizeof(tp.faddr));
176 tp.fport = BigEndian::HostSwap16(aSockName.mPort);
177 }
178
179 exit:
180 return error;
181 }
182
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)183 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
184 {
185 Error error;
186 struct tcpcb &tp = GetTcb();
187
188 size_t backlogBefore = GetBacklogBytes();
189 size_t sent = aBuffer.mLength;
190
191 struct sockaddr_in6 sin6p;
192 struct sockaddr_in6 *name = nullptr;
193
194 if (IS_FASTOPEN(tp.t_flags))
195 {
196 memcpy(&sin6p.sin6_addr, &tp.faddr, sizeof(sin6p.sin6_addr));
197 sin6p.sin6_port = tp.fport;
198 name = &sin6p;
199 }
200 SuccessOrExit(
201 error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0, name)));
202
203 PostCallbacksAfterSend(sent, backlogBefore);
204
205 exit:
206 return error;
207 }
208
SendByExtension(size_t aNumBytes,uint32_t aFlags)209 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
210 {
211 Error error;
212 bool moreToCome = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
213 struct tcpcb &tp = GetTcb();
214 size_t backlogBefore = GetBacklogBytes();
215 int bsdError;
216
217 struct sockaddr_in6 sin6p;
218 struct sockaddr_in6 *name = nullptr;
219
220 VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
221
222 if (IS_FASTOPEN(tp.t_flags))
223 {
224 memcpy(&sin6p.sin6_addr, &tp.faddr, sizeof(sin6p.sin6_addr));
225 sin6p.sin6_port = tp.fport;
226 name = &sin6p;
227 }
228
229 bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes, name);
230 SuccessOrExit(error = BsdErrorToOtError(bsdError));
231
232 PostCallbacksAfterSend(aNumBytes, backlogBefore);
233
234 exit:
235 return error;
236 }
237
ReceiveByReference(const otLinkedBuffer * & aBuffer)238 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
239 {
240 struct tcpcb &tp = GetTcb();
241
242 cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
243 aBuffer = &mReceiveLinks[0];
244
245 return kErrorNone;
246 }
247
ReceiveContiguify(void)248 Error Tcp::Endpoint::ReceiveContiguify(void)
249 {
250 struct tcpcb &tp = GetTcb();
251
252 cbuf_contiguify(&tp.recvbuf, tp.reassbmp);
253
254 return kErrorNone;
255 }
256
CommitReceive(size_t aNumBytes,uint32_t aFlags)257 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
258 {
259 Error error = kErrorNone;
260 struct tcpcb &tp = GetTcb();
261
262 OT_UNUSED_VARIABLE(aFlags);
263
264 VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
265 VerifyOrExit(aNumBytes > 0, error = kErrorNone);
266
267 cbuf_pop(&tp.recvbuf, aNumBytes);
268 error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
269
270 exit:
271 return error;
272 }
273
SendEndOfStream(void)274 Error Tcp::Endpoint::SendEndOfStream(void)
275 {
276 struct tcpcb &tp = GetTcb();
277
278 return BsdErrorToOtError(tcp_usr_shutdown(&tp));
279 }
280
Abort(void)281 Error Tcp::Endpoint::Abort(void)
282 {
283 struct tcpcb &tp = GetTcb();
284
285 tcp_usr_abort(&tp);
286 /* connection_lost will do any reinitialization work for this socket. */
287 return kErrorNone;
288 }
289
Deinitialize(void)290 Error Tcp::Endpoint::Deinitialize(void)
291 {
292 Error error;
293
294 SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
295 SetNext(nullptr);
296
297 SuccessOrExit(error = Abort());
298
299 exit:
300 return error;
301 }
302
IsClosed(void) const303 bool Tcp::Endpoint::IsClosed(void) const { return GetTcb().t_state == TCP6S_CLOSED; }
304
TimerFlagToIndex(uint8_t aTimerFlag)305 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
306 {
307 uint8_t timerIndex = 0;
308
309 switch (aTimerFlag)
310 {
311 case TT_DELACK:
312 timerIndex = kTimerDelack;
313 break;
314 case TT_REXMT:
315 case TT_PERSIST:
316 timerIndex = kTimerRexmtPersist;
317 break;
318 case TT_KEEP:
319 timerIndex = kTimerKeep;
320 break;
321 case TT_2MSL:
322 timerIndex = kTimer2Msl;
323 break;
324 }
325
326 return timerIndex;
327 }
328
IsTimerActive(uint8_t aTimerIndex)329 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
330 {
331 bool active = false;
332 struct tcpcb *tp = &GetTcb();
333
334 OT_ASSERT(aTimerIndex < kNumTimers);
335 switch (aTimerIndex)
336 {
337 case kTimerDelack:
338 active = tcp_timer_active(tp, TT_DELACK);
339 break;
340 case kTimerRexmtPersist:
341 active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
342 break;
343 case kTimerKeep:
344 active = tcp_timer_active(tp, TT_KEEP);
345 break;
346 case kTimer2Msl:
347 active = tcp_timer_active(tp, TT_2MSL);
348 break;
349 }
350
351 return active;
352 }
353
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)354 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
355 {
356 /*
357 * TCPlp has already set the flag for this timer to record that it's
358 * running. So, all that's left to do is record the expiry time and
359 * (re)set the main timer as appropriate.
360 */
361
362 TimeMilli now = TimerMilli::GetNow();
363 TimeMilli newFireTime = now + aDelay;
364 uint8_t timerIndex = TimerFlagToIndex(aTimerFlag);
365
366 mTimers[timerIndex] = newFireTime.GetValue();
367 LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
368 static_cast<unsigned int>(aDelay));
369
370 Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
371 }
372
CancelTimer(uint8_t aTimerFlag)373 void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
374 {
375 /*
376 * TCPlp has already cleared the timer flag before calling this. Since the
377 * main timer's callback properly handles the case where no timers are
378 * actually due, there's actually no work to be done here.
379 */
380
381 OT_UNUSED_VARIABLE(aTimerFlag);
382
383 LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
384 static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
385 }
386
/*
 * Runs the TCPlp handler for every active timer of this endpoint whose
 * recorded expiry is at or before aNow. The earliest remaining expiry is
 * collected for the caller.
 *
 * @param[in]     aNow                   The current time.
 * @param[in,out] aHasFutureTimer        Set to true if an active timer expires after aNow.
 * @param[in,out] aEarliestFutureExpiry  Lowered to the earliest such future expiry.
 *
 * @returns true if a timer handler reported that the connection was dropped
 *          (which involves a user callback). Processing stops immediately in
 *          that case, so aHasFutureTimer/aEarliestFutureExpiry may be incomplete.
 */
bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
{
    bool          calledUserCallback = false;
    struct tcpcb *tp                 = &GetTcb();

    /*
     * NOTE: Firing a timer might potentially activate/deactivate other timers.
     * If timers x and y expire at the same time, but the callback for timer x
     * (for x < y) cancels or postpones timer y, should timer y's callback be
     * called? Our answer is no, since timer x's callback has updated the
     * TCP stack's state in such a way that it no longer expects timer y's
     * callback to be called. Because the TCP stack thinks that timer y
     * has been cancelled, calling timer y's callback could potentially cause
     * problems. This is why IsTimerActive() is re-checked for every slot
     * inside the loop rather than snapshotted up front.
     *
     * If the timer callbacks set other timers, then they may not be taken
     * into account when setting aEarliestFutureExpiry. But mTimer's expiry
     * time will be updated by those, so we can just compare against mTimer's
     * expiry time when resetting mTimer.
     */
    for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
    {
        if (IsTimerActive(timerIndex))
        {
            TimeMilli expiry(mTimers[timerIndex]);

            if (expiry <= aNow)
            {
                /*
                 * If a user callback is called, then return true. For TCPlp,
                 * this only happens if the connection is dropped (e.g., it
                 * times out).
                 */
                int dropped = 0;

                switch (timerIndex)
                {
                case kTimerDelack:
                    dropped = tcp_timer_delack(tp);
                    break;
                case kTimerRexmtPersist:
                    // The shared slot runs whichever of rexmt/persist is currently armed.
                    if (tcp_timer_active(tp, TT_REXMT))
                    {
                        dropped = tcp_timer_rexmt(tp);
                    }
                    else
                    {
                        dropped = tcp_timer_persist(tp);
                    }
                    break;
                case kTimerKeep:
                    dropped = tcp_timer_keep(tp);
                    break;
                case kTimer2Msl:
                    dropped = tcp_timer_2msl(tp);
                    break;
                }
                // Stop touching this endpoint: the user callback may have changed or deinitialized it.
                VerifyOrExit(dropped == 0, calledUserCallback = true);
            }
            else
            {
                aHasFutureTimer       = true;
                aEarliestFutureExpiry = Min(aEarliestFutureExpiry, expiry);
            }
        }
    }

exit:
    return calledUserCallback;
}
457
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)458 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
459 {
460 size_t backlogAfter = GetBacklogBytes();
461
462 if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
463 {
464 mPendingCallbacks |= kForwardProgressCallbackFlag;
465 Get<Tcp>().mTasklet.Post();
466 }
467 }
468
FirePendingCallbacks(void)469 bool Tcp::Endpoint::FirePendingCallbacks(void)
470 {
471 bool calledUserCallback = false;
472
473 if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
474 {
475 mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
476 calledUserCallback = true;
477 }
478
479 mPendingCallbacks = 0;
480
481 return calledUserCallback;
482 }
483
GetSendBufferBytes(void) const484 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
485 {
486 const struct tcpcb &tp = GetTcb();
487 return lbuf_used_space(&tp.sendbuf);
488 }
489
GetInFlightBytes(void) const490 size_t Tcp::Endpoint::GetInFlightBytes(void) const
491 {
492 const struct tcpcb &tp = GetTcb();
493 return tp.snd_max - tp.snd_una;
494 }
495
GetBacklogBytes(void) const496 size_t Tcp::Endpoint::GetBacklogBytes(void) const { return GetSendBufferBytes() - GetInFlightBytes(); }
497
GetLocalIp6Address(void)498 Address &Tcp::Endpoint::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().laddr); }
499
GetLocalIp6Address(void) const500 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
501 {
502 return *reinterpret_cast<const Address *>(&GetTcb().laddr);
503 }
504
GetForeignIp6Address(void)505 Address &Tcp::Endpoint::GetForeignIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().faddr); }
506
GetForeignIp6Address(void) const507 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
508 {
509 return *reinterpret_cast<const Address *>(&GetTcb().faddr);
510 }
511
Matches(const MessageInfo & aMessageInfo) const512 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
513 {
514 bool matches = false;
515 const struct tcpcb *tp = &GetTcb();
516
517 VerifyOrExit(tp->t_state != TCP6S_CLOSED);
518 VerifyOrExit(tp->lport == BigEndian::HostSwap16(aMessageInfo.GetSockPort()));
519 VerifyOrExit(tp->fport == BigEndian::HostSwap16(aMessageInfo.GetPeerPort()));
520 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
521 VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
522
523 matches = true;
524
525 exit:
526 return matches;
527 }
528
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)529 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
530 {
531 Error error;
532 struct tcpcb_listen *tpl = &GetTcbListen();
533
534 SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
535
536 mContext = aArgs.mContext;
537 mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
538 mAcceptDoneCallback = aArgs.mAcceptDoneCallback;
539
540 ClearAllBytes(*tpl);
541 tpl->instance = &aInstance;
542
543 exit:
544 return error;
545 }
546
GetInstance(void) const547 Instance &Tcp::Listener::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcbListen().instance)); }
548
Listen(const SockAddr & aSockName)549 Error Tcp::Listener::Listen(const SockAddr &aSockName)
550 {
551 Error error;
552 uint16_t port = BigEndian::HostSwap16(aSockName.mPort);
553 struct tcpcb_listen *tpl = &GetTcbListen();
554
555 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
556
557 memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
558 tpl->lport = port;
559 tpl->t_state = TCP6S_LISTEN;
560 error = kErrorNone;
561
562 exit:
563 return error;
564 }
565
StopListening(void)566 Error Tcp::Listener::StopListening(void)
567 {
568 struct tcpcb_listen *tpl = &GetTcbListen();
569
570 ClearAllBytes(tpl->laddr);
571 tpl->lport = 0;
572 tpl->t_state = TCP6S_CLOSED;
573 return kErrorNone;
574 }
575
Deinitialize(void)576 Error Tcp::Listener::Deinitialize(void)
577 {
578 Error error;
579
580 SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
581 SetNext(nullptr);
582
583 exit:
584 return error;
585 }
586
IsClosed(void) const587 bool Tcp::Listener::IsClosed(void) const { return GetTcbListen().t_state == TCP6S_CLOSED; }
588
GetLocalIp6Address(void)589 Address &Tcp::Listener::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcbListen().laddr); }
590
GetLocalIp6Address(void) const591 const Address &Tcp::Listener::GetLocalIp6Address(void) const
592 {
593 return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
594 }
595
Matches(const MessageInfo & aMessageInfo) const596 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
597 {
598 bool matches = false;
599 const struct tcpcb_listen *tpl = &GetTcbListen();
600
601 VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
602 VerifyOrExit(tpl->lport == BigEndian::HostSwap16(aMessageInfo.GetSockPort()));
603 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
604
605 matches = true;
606
607 exit:
608 return matches;
609 }
610
/*
 * Processes one inbound TCP segment.
 *
 * Validates the payload length and TCP header size, verifies the checksum,
 * and copies the header into an aligned local buffer. The segment is then
 * dispatched to:
 *   1. a matching active endpoint,
 *   2. otherwise (or if TCPlp asks for a re-lookup) a matching listener,
 *   3. otherwise a RST is sent with ECONNREFUSED.
 *
 * Side effect: aMessageInfo's peer/sock ports are filled in from the header.
 *
 * @returns kErrorParse for an inconsistent length or header size, or the
 *          error from reading the message / verifying the checksum.
 *          Otherwise kErrorNone. The initial kErrorNotImplemented is always
 *          overwritten before exit.
 */
Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
{
    Error error = kErrorNotImplemented;

    /*
     * The type uint32_t was chosen for alignment purposes. The size is the
     * maximum TCP header size, including options (15 32-bit words = 60 bytes).
     */
    uint32_t header[15];

    uint16_t length = aIp6Header.GetPayloadLength();
    uint8_t  headerSize;

    struct ip6_hdr *ip6Header;
    struct tcphdr  *tcpHeader;

    Endpoint *endpoint;
    Listener *listener;

    struct tcplp_signals sig;
    int                  nextAction;

    // The IPv6 payload length must agree with the bytes remaining in the message.
    VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
    VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
    // Read the data-offset byte and convert the 32-bit-word count into bytes.
    SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
    headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
    VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
                     static_cast<uint16_t>(headerSize) <= length,
                 error = kErrorParse);
    SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
    SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));

    ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
    tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
    tcp_fields_to_host(tcpHeader);

    // Ports are byte-swapped explicitly here, so presumably tcp_fields_to_host()
    // leaves th_sport/th_dport in network order (NOTE(review): confirm in TCPlp).
    aMessageInfo.mPeerPort = BigEndian::HostSwap16(tcpHeader->th_sport);
    aMessageInfo.mSockPort = BigEndian::HostSwap16(tcpHeader->th_dport);

    endpoint = mEndpoints.FindMatching(aMessageInfo);

    if (endpoint != nullptr)
    {
        struct tcpcb *tp = &endpoint->GetTcb();

        // Snapshot send-side state so ProcessSignals can report completed buffers and progress.
        otLinkedBuffer *priorHead    = lbuf_head(&tp->sendbuf);
        size_t          priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();

        ClearAllBytes(sig);
        nextAction = tcplp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
        if (nextAction != RELOOKUP_REQUIRED)
        {
            ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
            ExitNow();
        }
        /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
    }

    listener = mListeners.FindMatching(aMessageInfo);

    if (listener != nullptr)
    {
        struct tcpcb_listen *tpl = &listener->GetTcbListen();

        ClearAllBytes(sig);
        nextAction = tcplp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, &sig);
        OT_ASSERT(nextAction != RELOOKUP_REQUIRED);
        // Signals from a listener are delivered on the endpoint that accepted the connection, if any.
        if (sig.accepted_connection != nullptr)
        {
            ProcessSignals(Tcp::Endpoint::FromTcb(*sig.accepted_connection), nullptr, 0, sig);
        }
        ExitNow();
    }

    // No endpoint or listener wants this segment: respond with a reset.
    tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
                      ECONNREFUSED);

exit:
    return error;
}
691
/*
 * Invokes the user callbacks for the events TCPlp reported in aSignals while
 * processing a segment for aEndpoint. The order is: established, send-done
 * (one per popped buffer), forward progress, receive-available, and
 * TIME-WAIT disconnect.
 *
 * Any user callback may deinitialize or close the endpoint. The endpoint is
 * therefore re-validated (still initialized and not closed) before every
 * callback, and processing stops as soon as that check fails.
 *
 * @param[in] aEndpoint      The endpoint the segment was delivered to.
 * @param[in] aPriorHead     Head of the send buffer before the segment was processed.
 *                           The first aSignals.links_popped buffers from it are completed.
 * @param[in] aPriorBacklog  Backlog size before the segment was processed.
 * @param[in] aSignals       Events reported by tcplp_input().
 */
void Tcp::ProcessSignals(Endpoint             &aEndpoint,
                         otLinkedBuffer       *aPriorHead,
                         size_t                aPriorBacklog,
                         struct tcplp_signals &aSignals) const
{
    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
    {
        aEndpoint.mEstablishedCallback(&aEndpoint);
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aEndpoint.mSendDoneCallback != nullptr)
    {
        otLinkedBuffer *curr = aPriorHead;

        OT_ASSERT(curr != nullptr || aSignals.links_popped == 0);

        for (uint32_t i = 0; i != aSignals.links_popped; i++)
        {
            // Read the successor first: the callback hands the buffer back to the user.
            otLinkedBuffer *next = curr->mNext;

            // The first iteration was already validated above.
            VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));

            curr->mNext = nullptr;
            aEndpoint.mSendDoneCallback(&aEndpoint, curr);
            curr = next;
        }
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aEndpoint.mForwardProgressCallback != nullptr)
    {
        size_t backlogBytes = aEndpoint.GetBacklogBytes();

        if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
        {
            aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
            // Delivered now, so any deferred forward-progress notification is redundant.
            aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
        }
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
    {
        aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
                                            aEndpoint.GetTcb().reass_fin_index != -1,
                                            cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
    {
        aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
    }

exit:
    return;
}
751
BsdErrorToOtError(int aBsdError)752 Error Tcp::BsdErrorToOtError(int aBsdError)
753 {
754 Error error = kErrorFailed;
755
756 switch (aBsdError)
757 {
758 case 0:
759 error = kErrorNone;
760 break;
761 }
762
763 return error;
764 }
765
CanBind(const SockAddr & aSockName)766 bool Tcp::CanBind(const SockAddr &aSockName)
767 {
768 uint16_t port = BigEndian::HostSwap16(aSockName.mPort);
769 bool allowed = false;
770
771 for (Endpoint &endpoint : mEndpoints)
772 {
773 struct tcpcb *tp = &endpoint.GetTcb();
774
775 if (tp->lport == port)
776 {
777 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
778 VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
779 VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
780 }
781 }
782
783 for (Listener &listener : mListeners)
784 {
785 struct tcpcb_listen *tpl = &listener.GetTcbListen();
786
787 if (tpl->lport == port)
788 {
789 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
790 VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
791 VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
792 }
793 }
794
795 allowed = true;
796
797 exit:
798 return allowed;
799 }
800
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)801 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
802 {
803 bool success;
804
805 if (aBindAddress)
806 {
807 const Address *source;
808
809 source = Get<Ip6>().SelectSourceAddress(aPeer.GetAddress());
810 VerifyOrExit(source != nullptr, success = false);
811 aToBind.SetAddress(*source);
812 }
813
814 if (aBindPort)
815 {
816 /*
817 * TODO: Use a less naive algorithm to allocate ephemeral ports. For
818 * example, see RFC 6056.
819 */
820
821 for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
822 {
823 aToBind.SetPort(mEphemeralPort);
824
825 if (mEphemeralPort == kDynamicPortMax)
826 {
827 mEphemeralPort = kDynamicPortMin;
828 }
829 else
830 {
831 mEphemeralPort++;
832 }
833
834 if (CanBind(aToBind))
835 {
836 ExitNow(success = true);
837 }
838 }
839
840 ExitNow(success = false);
841 }
842
843 success = CanBind(aToBind);
844
845 exit:
846 return success;
847 }
848
/*
 * Handler for the shared TCP timer (mTimer). It fires every due TCPlp timer
 * on every endpoint, then re-arms mTimer for the earliest remaining expiry,
 * if any.
 */
void Tcp::HandleTimer(void)
{
    TimeMilli now = TimerMilli::GetNow();
    bool      pendingTimer;
    TimeMilli earliestPendingTimerExpiry;

    LogDebg("Main TCP timer expired");

    /*
     * The timer callbacks could potentially set/reset/cancel timers.
     * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
     * this function to recompute the timer. If they did, we'd have a
     * re-entrancy problem, where the callbacks called in this function could
     * wind up re-entering this function in a nested call frame.
     *
     * In general, calling this function from Endpoint::SetTimer and
     * Endpoint::CancelTimer could be inefficient, since those functions are
     * called multiple times on each received TCP segment. If we want to
     * prevent the main timer from firing except when an actual TCP timer
     * expires, a better alternative is to reset the main timer in
     * HandleMessage, right before processing signals. That would achieve that
     * objective while avoiding re-entrancy issues altogether.
     */
restart:
    // Reset the accumulators on every pass: a restart rescans all endpoints from scratch.
    pendingTimer               = false;
    earliestPendingTimerExpiry = now.GetDistantFuture();

    for (Endpoint &endpoint : mEndpoints)
    {
        if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
        {
            /*
             * If a non-OpenThread callback is called --- which, in practice,
             * happens if the connection times out and the user-defined
             * connection lost callback is called --- then we might have to
             * start over. The reason is that the user might deinitialize
             * endpoints, changing the structure of the linked list. For
             * example, if the user deinitializes both this endpoint and the
             * next one in the linked list, then we can't continue traversing
             * the linked list.
             */
            goto restart;
        }
    }

    if (pendingTimer)
    {
        /*
         * We need to use Timer::FireAtIfEarlier instead of Timer::FireAt
         * because one of the earlier callbacks might have set TCP timers,
         * in which case `mTimer` would have been set to the earliest of those
         * timers.
         */
        mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
        LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
    }
    else
    {
        LogDebg("Did not reset main TCP timer");
    }
}
910
ProcessCallbacks(void)911 void Tcp::ProcessCallbacks(void)
912 {
913 for (Endpoint &endpoint : mEndpoints)
914 {
915 if (endpoint.FirePendingCallbacks())
916 {
917 mTasklet.Post();
918 break;
919 }
920 }
921 }
922
923 } // namespace Ip6
924 } // namespace ot
925
926 /*
927 * Implement TCPlp system stubs declared in tcplp.h.
928 *
 * Because these functions have C linkage, it is important that only one
 * definition is given for each function name, regardless of the namespace it
 * is in. For example, if we give two definitions of tcplp_sys_new_message, we
 * will get errors, even if they are in different namespaces. To avoid
 * confusion, these functions are defined outside of any namespace.
934 */
935
936 using namespace ot;
937 using namespace ot::Ip6;
938
939 extern "C" {
940
tcplp_sys_new_message(otInstance * aInstance)941 otMessage *tcplp_sys_new_message(otInstance *aInstance)
942 {
943 Instance &instance = AsCoreType(aInstance);
944 Message *message = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
945
946 if (message)
947 {
948 message->SetLinkSecurityEnabled(true);
949 }
950
951 return message;
952 }
953
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)954 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
955 {
956 OT_UNUSED_VARIABLE(aInstance);
957 Message &message = AsCoreType(aMessage);
958 message.Free();
959 }
960
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)961 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
962 {
963 Instance &instance = AsCoreType(aInstance);
964 Message &message = AsCoreType(aMessage);
965 MessageInfo &info = AsCoreType(aMessageInfo);
966
967 LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
968
969 IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
970 }
971
tcplp_sys_get_ticks(void)972 uint32_t tcplp_sys_get_ticks(void) { return TimerMilli::GetNow().GetValue(); }
973
tcplp_sys_get_millis(void)974 uint32_t tcplp_sys_get_millis(void) { return TimerMilli::GetNow().GetValue(); }
975
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)976 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
977 {
978 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
979 endpoint.SetTimer(aTimerFlag, aDelay);
980 }
981
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)982 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
983 {
984 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
985 endpoint.CancelTimer(aTimerFlag);
986 }
987
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)988 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
989 {
990 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
991 Tcp &tcp = listener.Get<Tcp>();
992 struct tcpcb *rv = (struct tcpcb *)-1;
993 otSockAddr addr;
994 otTcpEndpoint *endpointPtr;
995 otTcpIncomingConnectionAction action;
996
997 VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
998
999 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1000 addr.mPort = BigEndian::HostSwap16(aPort);
1001 action = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
1002
1003 VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
1004
1005 switch (action)
1006 {
1007 case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
1008 {
1009 Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
1010
1011 /*
1012 * The documentation says that the user must initialize the
1013 * endpoint before passing it here, so we do a sanity check to make
1014 * sure the endpoint is initialized and closed. That check may not
1015 * be necessary, but we do it anyway.
1016 */
1017 VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
1018
1019 rv = &endpoint.GetTcb();
1020
1021 break;
1022 }
1023 case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1024 rv = nullptr;
1025 break;
1026 case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1027 rv = (struct tcpcb *)-1;
1028 break;
1029 }
1030
1031 exit:
1032 return rv;
1033 }
1034
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1035 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1036 struct tcpcb *aAccepted,
1037 struct in6_addr *aAddr,
1038 uint16_t aPort)
1039 {
1040 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1041 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1042 Tcp &tcp = endpoint.Get<Tcp>();
1043 bool accepted = true;
1044
1045 if (listener.mAcceptDoneCallback != nullptr)
1046 {
1047 otSockAddr addr;
1048
1049 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1050 addr.mPort = BigEndian::HostSwap16(aPort);
1051 listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1052
1053 if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1054 {
1055 accepted = false;
1056 }
1057 }
1058
1059 return accepted;
1060 }
1061
tcplp_sys_connection_lost(struct tcpcb * aTcb,uint8_t aErrNum)1062 void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
1063 {
1064 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1065
1066 if (endpoint.mDisconnectedCallback != nullptr)
1067 {
1068 otTcpDisconnectedReason reason;
1069
1070 switch (aErrNum)
1071 {
1072 case CONN_LOST_NORMAL:
1073 reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
1074 break;
1075 case ECONNREFUSED:
1076 reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
1077 break;
1078 case ETIMEDOUT:
1079 reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
1080 break;
1081 case ECONNRESET:
1082 default:
1083 reason = OT_TCP_DISCONNECTED_REASON_RESET;
1084 break;
1085 }
1086 endpoint.mDisconnectedCallback(&endpoint, reason);
1087 }
1088 }
1089
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1090 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1091 {
1092 if (aNewState == TCP6S_CLOSED)
1093 {
1094 /* Re-initialize the TCB. */
1095 cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1096 aTcb->accepted_from = nullptr;
1097 initialize_tcb(aTcb);
1098 }
1099 /* Any adaptive changes to the sleep interval would go here. */
1100 }
1101
tcplp_sys_log(const char * aFormat,...)1102 void tcplp_sys_log(const char *aFormat, ...)
1103 {
1104 char buffer[128];
1105 va_list args;
1106 va_start(args, aFormat);
1107 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1108 va_end(args);
1109
1110 LogDebg("%s", buffer);
1111 }
1112
tcplp_sys_panic(const char * aFormat,...)1113 void tcplp_sys_panic(const char *aFormat, ...)
1114 {
1115 char buffer[128];
1116 va_list args;
1117 va_start(args, aFormat);
1118 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1119 va_end(args);
1120
1121 LogCrit("%s", buffer);
1122
1123 OT_ASSERT(false);
1124 }
1125
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1126 bool tcplp_sys_autobind(otInstance *aInstance,
1127 const otSockAddr *aPeer,
1128 otSockAddr *aToBind,
1129 bool aBindAddress,
1130 bool aBindPort)
1131 {
1132 Instance &instance = AsCoreType(aInstance);
1133
1134 return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1135 aBindAddress, aBindPort);
1136 }
1137
tcplp_sys_generate_isn()1138 uint32_t tcplp_sys_generate_isn()
1139 {
1140 uint32_t isn;
1141 IgnoreError(Random::Crypto::Fill(isn));
1142 return isn;
1143 }
1144
tcplp_sys_hostswap16(uint16_t aHostPort)1145 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort) { return BigEndian::HostSwap16(aHostPort); }
1146
tcplp_sys_hostswap32(uint32_t aHostPort)1147 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort) { return BigEndian::HostSwap32(aHostPort); }
1148 }
1149
1150 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1151