1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/tcp_client_socket.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "base/check_op.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/metrics/histogram_macros.h"
15 #include "base/notreached.h"
16 #include "base/time/time.h"
17 #include "net/base/features.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/nqe/network_quality_estimator.h"
22 #include "net/socket/socket_performance_watcher.h"
23 #include "net/traffic_annotation/network_traffic_annotation.h"
24
25 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
26 #include "base/power_monitor/power_monitor.h"
27 #endif
28
29 namespace net {
30
31 class NetLogWithSource;
32
TCPClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,net::NetLog * net_log,const net::NetLogSource & source,handles::NetworkHandle network)33 TCPClientSocket::TCPClientSocket(
34 const AddressList& addresses,
35 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
36 NetworkQualityEstimator* network_quality_estimator,
37 net::NetLog* net_log,
38 const net::NetLogSource& source,
39 handles::NetworkHandle network)
40 : TCPClientSocket(TCPSocket::Create(std::move(socket_performance_watcher),
41 net_log,
42 source),
43 addresses,
44 -1 /* current_address_index */,
45 nullptr /* bind_address */,
46 network_quality_estimator,
47 network) {}
48
TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,const IPEndPoint & peer_address)49 TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,
50 const IPEndPoint& peer_address)
51 : TCPClientSocket(std::move(connected_socket),
52 AddressList(peer_address),
53 0 /* current_address_index */,
54 nullptr /* bind_address */,
55 // TODO(https://crbug.com/1123197: Pass non-null
56 // NetworkQualityEstimator
57 nullptr /* network_quality_estimator */,
58 handles::kInvalidNetworkHandle) {}
59
TCPClientSocket(std::unique_ptr<TCPSocket> unconnected_socket,const AddressList & addresses,std::unique_ptr<IPEndPoint> bound_address,NetworkQualityEstimator * network_quality_estimator)60 TCPClientSocket::TCPClientSocket(
61 std::unique_ptr<TCPSocket> unconnected_socket,
62 const AddressList& addresses,
63 std::unique_ptr<IPEndPoint> bound_address,
64 NetworkQualityEstimator* network_quality_estimator)
65 : TCPClientSocket(std::move(unconnected_socket),
66 addresses,
67 -1 /* current_address_index */,
68 std::move(bound_address),
69 network_quality_estimator,
70 handles::kInvalidNetworkHandle) {}
71
~TCPClientSocket()72 TCPClientSocket::~TCPClientSocket() {
73 Disconnect();
74 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
75 base::PowerMonitor::GetInstance()->RemovePowerSuspendObserver(this);
76 #endif // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
77 }
78
CreateFromBoundSocket(std::unique_ptr<TCPSocket> bound_socket,const AddressList & addresses,const IPEndPoint & bound_address,NetworkQualityEstimator * network_quality_estimator)79 std::unique_ptr<TCPClientSocket> TCPClientSocket::CreateFromBoundSocket(
80 std::unique_ptr<TCPSocket> bound_socket,
81 const AddressList& addresses,
82 const IPEndPoint& bound_address,
83 NetworkQualityEstimator* network_quality_estimator) {
84 return base::WrapUnique(new TCPClientSocket(
85 std::move(bound_socket), addresses, -1 /* current_address_index */,
86 std::make_unique<IPEndPoint>(bound_address), network_quality_estimator,
87 handles::kInvalidNetworkHandle));
88 }
89
Bind(const IPEndPoint & address)90 int TCPClientSocket::Bind(const IPEndPoint& address) {
91 if (current_address_index_ >= 0 || bind_address_) {
92 // Cannot bind the socket if we are already connected or connecting.
93 NOTREACHED();
94 }
95
96 int result = OK;
97 if (!socket_->IsValid()) {
98 result = OpenSocket(address.GetFamily());
99 if (result != OK)
100 return result;
101 }
102
103 result = socket_->Bind(address);
104 if (result != OK)
105 return result;
106
107 bind_address_ = std::make_unique<IPEndPoint>(address);
108 return OK;
109 }
110
SetKeepAlive(bool enable,int delay)111 bool TCPClientSocket::SetKeepAlive(bool enable, int delay) {
112 return socket_->SetKeepAlive(enable, delay);
113 }
114
SetNoDelay(bool no_delay)115 bool TCPClientSocket::SetNoDelay(bool no_delay) {
116 return socket_->SetNoDelay(no_delay);
117 }
118
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)119 void TCPClientSocket::SetBeforeConnectCallback(
120 const BeforeConnectCallback& before_connect_callback) {
121 DCHECK_EQ(CONNECT_STATE_NONE, next_connect_state_);
122 before_connect_callback_ = before_connect_callback;
123 }
124
Connect(CompletionOnceCallback callback)125 int TCPClientSocket::Connect(CompletionOnceCallback callback) {
126 DCHECK(!callback.is_null());
127
128 // If connecting or already connected, then just return OK.
129 if (socket_->IsValid() && current_address_index_ >= 0)
130 return OK;
131
132 DCHECK(!read_callback_);
133 DCHECK(!write_callback_);
134
135 if (was_disconnected_on_suspend_) {
136 Disconnect();
137 was_disconnected_on_suspend_ = false;
138 }
139
140 socket_->StartLoggingMultipleConnectAttempts(addresses_);
141
142 // We will try to connect to each address in addresses_. Start with the
143 // first one in the list.
144 next_connect_state_ = CONNECT_STATE_CONNECT;
145 current_address_index_ = 0;
146
147 int rv = DoConnectLoop(OK);
148 if (rv == ERR_IO_PENDING) {
149 connect_callback_ = std::move(callback);
150 } else {
151 socket_->EndLoggingMultipleConnectAttempts(rv);
152 }
153
154 return rv;
155 }
156
TCPClientSocket(std::unique_ptr<TCPSocket> socket,const AddressList & addresses,int current_address_index,std::unique_ptr<IPEndPoint> bind_address,NetworkQualityEstimator * network_quality_estimator,handles::NetworkHandle network)157 TCPClientSocket::TCPClientSocket(
158 std::unique_ptr<TCPSocket> socket,
159 const AddressList& addresses,
160 int current_address_index,
161 std::unique_ptr<IPEndPoint> bind_address,
162 NetworkQualityEstimator* network_quality_estimator,
163 handles::NetworkHandle network)
164 : socket_(std::move(socket)),
165 bind_address_(std::move(bind_address)),
166 addresses_(addresses),
167 current_address_index_(current_address_index),
168 network_quality_estimator_(network_quality_estimator),
169 network_(network) {
170 DCHECK(socket_);
171 if (socket_->IsValid())
172 socket_->SetDefaultOptionsForClient();
173 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
174 base::PowerMonitor::GetInstance()->AddPowerSuspendObserver(this);
175 #endif // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
176 }
177
ReadCommon(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,bool read_if_ready)178 int TCPClientSocket::ReadCommon(IOBuffer* buf,
179 int buf_len,
180 CompletionOnceCallback callback,
181 bool read_if_ready) {
182 DCHECK(!callback.is_null());
183 DCHECK(read_callback_.is_null());
184
185 if (was_disconnected_on_suspend_)
186 return ERR_NETWORK_IO_SUSPENDED;
187
188 // |socket_| is owned by |this| and the callback won't be run once |socket_|
189 // is gone/closed. Therefore, it is safe to use base::Unretained() here.
190 CompletionOnceCallback complete_read_callback =
191 base::BindOnce(&TCPClientSocket::DidCompleteRead, base::Unretained(this));
192 int result =
193 read_if_ready
194 ? socket_->ReadIfReady(buf, buf_len,
195 std::move(complete_read_callback))
196 : socket_->Read(buf, buf_len, std::move(complete_read_callback));
197 if (result == ERR_IO_PENDING) {
198 read_callback_ = std::move(callback);
199 } else if (result > 0) {
200 was_ever_used_ = true;
201 total_received_bytes_ += result;
202 }
203
204 return result;
205 }
206
DoConnectLoop(int result)207 int TCPClientSocket::DoConnectLoop(int result) {
208 DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
209
210 int rv = result;
211 do {
212 ConnectState state = next_connect_state_;
213 next_connect_state_ = CONNECT_STATE_NONE;
214 switch (state) {
215 case CONNECT_STATE_CONNECT:
216 DCHECK_EQ(OK, rv);
217 rv = DoConnect();
218 break;
219 case CONNECT_STATE_CONNECT_COMPLETE:
220 rv = DoConnectComplete(rv);
221 break;
222 default:
223 NOTREACHED() << "bad state " << state;
224 }
225 } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
226
227 return rv;
228 }
229
DoConnect()230 int TCPClientSocket::DoConnect() {
231 DCHECK_GE(current_address_index_, 0);
232 DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
233
234 const IPEndPoint& endpoint = addresses_[current_address_index_];
235
236 if (previously_disconnected_) {
237 was_ever_used_ = false;
238 previously_disconnected_ = false;
239 }
240
241 next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
242
243 if (!socket_->IsValid()) {
244 int result = OpenSocket(endpoint.GetFamily());
245 if (result != OK)
246 return result;
247
248 if (bind_address_) {
249 result = socket_->Bind(*bind_address_);
250 if (result != OK) {
251 socket_->Close();
252 return result;
253 }
254 }
255 }
256
257 if (before_connect_callback_) {
258 int result = before_connect_callback_.Run();
259 DCHECK_NE(ERR_IO_PENDING, result);
260 if (result != net::OK)
261 return result;
262 }
263
264 // Notify |socket_performance_watcher_| only if the |socket_| is reused to
265 // connect to a different IP Address.
266 if (socket_->socket_performance_watcher() && current_address_index_ != 0)
267 socket_->socket_performance_watcher()->OnConnectionChanged();
268
269 start_connect_attempt_ = base::TimeTicks::Now();
270
271 // Start a timer to fail the connect attempt if it takes too long.
272 base::TimeDelta attempt_timeout = GetConnectAttemptTimeout();
273 if (!attempt_timeout.is_max()) {
274 DCHECK(!connect_attempt_timer_.IsRunning());
275 connect_attempt_timer_.Start(
276 FROM_HERE, attempt_timeout,
277 base::BindOnce(&TCPClientSocket::OnConnectAttemptTimeout,
278 base::Unretained(this)));
279 }
280
281 return ConnectInternal(endpoint);
282 }
283
DoConnectComplete(int result)284 int TCPClientSocket::DoConnectComplete(int result) {
285 if (start_connect_attempt_) {
286 EmitConnectAttemptHistograms(result);
287 start_connect_attempt_ = std::nullopt;
288 connect_attempt_timer_.Stop();
289 }
290
291 if (result == OK)
292 return OK; // Done!
293
294 // Don't try the next address if entering suspend mode.
295 if (result == ERR_NETWORK_IO_SUSPENDED)
296 return result;
297
298 // Close whatever partially connected socket we currently have.
299 DoDisconnect();
300
301 // Try to fall back to the next address in the list.
302 if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
303 next_connect_state_ = CONNECT_STATE_CONNECT;
304 ++current_address_index_;
305 return OK;
306 }
307
308 // Otherwise there is nothing to fall back to, so give up.
309 return result;
310 }
311
OnConnectAttemptTimeout()312 void TCPClientSocket::OnConnectAttemptTimeout() {
313 DidCompleteConnect(ERR_TIMED_OUT);
314 }
315
ConnectInternal(const IPEndPoint & endpoint)316 int TCPClientSocket::ConnectInternal(const IPEndPoint& endpoint) {
317 // |socket_| is owned by this class and the callback won't be run once
318 // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
319 return socket_->Connect(endpoint,
320 base::BindOnce(&TCPClientSocket::DidCompleteConnect,
321 base::Unretained(this)));
322 }
323
Disconnect()324 void TCPClientSocket::Disconnect() {
325 DoDisconnect();
326 current_address_index_ = -1;
327 bind_address_.reset();
328
329 // Cancel any pending callbacks. Not done in DoDisconnect() because that's
330 // called on connection failure, when the connect callback will need to be
331 // invoked.
332 was_disconnected_on_suspend_ = false;
333 connect_callback_.Reset();
334 read_callback_.Reset();
335 write_callback_.Reset();
336 }
337
DoDisconnect()338 void TCPClientSocket::DoDisconnect() {
339 if (start_connect_attempt_) {
340 EmitConnectAttemptHistograms(ERR_ABORTED);
341 start_connect_attempt_ = std::nullopt;
342 connect_attempt_timer_.Stop();
343 }
344
345 total_received_bytes_ = 0;
346
347 // If connecting or already connected, record that the socket has been
348 // disconnected.
349 previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0;
350 socket_->Close();
351
352 // Invalidate weak pointers, so if in the middle of a callback in OnSuspend,
353 // and something destroys this, no other callback is invoked.
354 weak_ptr_factory_.InvalidateWeakPtrs();
355 }
356
IsConnected() const357 bool TCPClientSocket::IsConnected() const {
358 return socket_->IsConnected();
359 }
360
IsConnectedAndIdle() const361 bool TCPClientSocket::IsConnectedAndIdle() const {
362 return socket_->IsConnectedAndIdle();
363 }
364
GetPeerAddress(IPEndPoint * address) const365 int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
366 return socket_->GetPeerAddress(address);
367 }
368
GetLocalAddress(IPEndPoint * address) const369 int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const {
370 DCHECK(address);
371
372 if (!socket_->IsValid()) {
373 if (bind_address_) {
374 *address = *bind_address_;
375 return OK;
376 }
377 return ERR_SOCKET_NOT_CONNECTED;
378 }
379
380 return socket_->GetLocalAddress(address);
381 }
382
NetLog() const383 const NetLogWithSource& TCPClientSocket::NetLog() const {
384 return socket_->net_log();
385 }
386
WasEverUsed() const387 bool TCPClientSocket::WasEverUsed() const {
388 return was_ever_used_;
389 }
390
GetNegotiatedProtocol() const391 NextProto TCPClientSocket::GetNegotiatedProtocol() const {
392 return kProtoUnknown;
393 }
394
GetSSLInfo(SSLInfo * ssl_info)395 bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
396 return false;
397 }
398
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)399 int TCPClientSocket::Read(IOBuffer* buf,
400 int buf_len,
401 CompletionOnceCallback callback) {
402 return ReadCommon(buf, buf_len, std::move(callback), /*read_if_ready=*/false);
403 }
404
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)405 int TCPClientSocket::ReadIfReady(IOBuffer* buf,
406 int buf_len,
407 CompletionOnceCallback callback) {
408 return ReadCommon(buf, buf_len, std::move(callback), /*read_if_ready=*/true);
409 }
410
CancelReadIfReady()411 int TCPClientSocket::CancelReadIfReady() {
412 DCHECK(read_callback_);
413 read_callback_.Reset();
414 return socket_->CancelReadIfReady();
415 }
416
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)417 int TCPClientSocket::Write(
418 IOBuffer* buf,
419 int buf_len,
420 CompletionOnceCallback callback,
421 const NetworkTrafficAnnotationTag& traffic_annotation) {
422 DCHECK(!callback.is_null());
423 DCHECK(write_callback_.is_null());
424
425 if (was_disconnected_on_suspend_)
426 return ERR_NETWORK_IO_SUSPENDED;
427
428 // |socket_| is owned by this class and the callback won't be run once
429 // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
430 CompletionOnceCallback complete_write_callback = base::BindOnce(
431 &TCPClientSocket::DidCompleteWrite, base::Unretained(this));
432 int result = socket_->Write(buf, buf_len, std::move(complete_write_callback),
433 traffic_annotation);
434 if (result == ERR_IO_PENDING) {
435 write_callback_ = std::move(callback);
436 } else if (result > 0) {
437 was_ever_used_ = true;
438 }
439
440 return result;
441 }
442
SetReceiveBufferSize(int32_t size)443 int TCPClientSocket::SetReceiveBufferSize(int32_t size) {
444 return socket_->SetReceiveBufferSize(size);
445 }
446
SetSendBufferSize(int32_t size)447 int TCPClientSocket::SetSendBufferSize(int32_t size) {
448 return socket_->SetSendBufferSize(size);
449 }
450
SocketDescriptorForTesting() const451 SocketDescriptor TCPClientSocket::SocketDescriptorForTesting() const {
452 return socket_->SocketDescriptorForTesting();
453 }
454
GetTotalReceivedBytes() const455 int64_t TCPClientSocket::GetTotalReceivedBytes() const {
456 return total_received_bytes_;
457 }
458
ApplySocketTag(const SocketTag & tag)459 void TCPClientSocket::ApplySocketTag(const SocketTag& tag) {
460 socket_->ApplySocketTag(tag);
461 }
462
// Power-suspend notification. The body only exists when
// TCP_CLIENT_SOCKET_OBSERVES_SUSPEND is defined (the constructor registers
// |this| with base::PowerMonitor in that configuration); otherwise it does
// nothing.
//
// - If a connect is in progress, the socket is closed and the connect
//   completes with ERR_NETWORK_IO_SUSPENDED. DoConnectComplete() returns that
//   error without falling back to the remaining addresses.
// - If the socket is open, it is closed and |was_disconnected_on_suspend_| is
//   set. Later reads and writes then fail with ERR_NETWORK_IO_SUSPENDED until
//   Connect() or Disconnect() clears the flag. A pending read callback is run
//   with ERR_NETWORK_IO_SUSPENDED. A pending write callback is run with it too,
//   but only if |this| was not destroyed or disconnected by the read callback.
void TCPClientSocket::OnSuspend() {
#if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
  // If the socket is connected, or connecting, act as if current and future
  // operations on the socket fail with ERR_NETWORK_IO_SUSPENDED, until the
  // socket is reconnected.

  if (next_connect_state_ != CONNECT_STATE_NONE) {
    socket_->Close();
    DidCompleteConnect(ERR_NETWORK_IO_SUSPENDED);
    return;
  }

  // Nothing to do. Use IsValid() rather than IsConnected() because it results
  // in more testable code: otherwise, calling OnSuspend() on two sockets
  // connected to each other could make the two sockets behave differently
  // from each other.
  if (!socket_->IsValid())
    return;

  // Use Close() rather than Disconnect() / DoDisconnect() to avoid mutating
  // state, which more closely matches normal read/write error behavior.
  socket_->Close();

  was_disconnected_on_suspend_ = true;

  // Grab a weak pointer in case running the read callback results in |this|
  // being destroyed or disconnected. In either case, the write callback
  // should not be run.
  base::WeakPtr<TCPClientSocket> weak_this = weak_ptr_factory_.GetWeakPtr();

  // Grab the write callback now. It is theoretically possible for the read
  // callback to reconnect the socket, for that reconnection to complete
  // synchronously, and then for it to start a new write. That is also why
  // this code can't use DidCompleteWrite().
  CompletionOnceCallback write_callback = std::move(write_callback_);
  if (read_callback_)
    DidCompleteRead(ERR_NETWORK_IO_SUSPENDED);
  if (weak_this && write_callback)
    std::move(write_callback).Run(ERR_NETWORK_IO_SUSPENDED);
#endif  // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
}
504
DidCompleteConnect(int result)505 void TCPClientSocket::DidCompleteConnect(int result) {
506 DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
507 DCHECK_NE(result, ERR_IO_PENDING);
508 DCHECK(!connect_callback_.is_null());
509
510 result = DoConnectLoop(result);
511 if (result != ERR_IO_PENDING) {
512 socket_->EndLoggingMultipleConnectAttempts(result);
513 std::move(connect_callback_).Run(result);
514 }
515 }
516
DidCompleteRead(int result)517 void TCPClientSocket::DidCompleteRead(int result) {
518 DCHECK(!read_callback_.is_null());
519
520 if (result > 0)
521 total_received_bytes_ += result;
522 DidCompleteReadWrite(std::move(read_callback_), result);
523 }
524
DidCompleteWrite(int result)525 void TCPClientSocket::DidCompleteWrite(int result) {
526 DCHECK(!write_callback_.is_null());
527
528 DidCompleteReadWrite(std::move(write_callback_), result);
529 }
530
DidCompleteReadWrite(CompletionOnceCallback callback,int result)531 void TCPClientSocket::DidCompleteReadWrite(CompletionOnceCallback callback,
532 int result) {
533 if (result > 0)
534 was_ever_used_ = true;
535 std::move(callback).Run(result);
536 }
537
OpenSocket(AddressFamily family)538 int TCPClientSocket::OpenSocket(AddressFamily family) {
539 DCHECK(!socket_->IsValid());
540
541 int result = socket_->Open(family);
542 if (result != OK)
543 return result;
544
545 if (network_ != handles::kInvalidNetworkHandle) {
546 result = socket_->BindToNetwork(network_);
547 if (result != OK) {
548 socket_->Close();
549 return result;
550 }
551 }
552
553 socket_->SetDefaultOptionsForClient();
554
555 return OK;
556 }
557
EmitConnectAttemptHistograms(int result)558 void TCPClientSocket::EmitConnectAttemptHistograms(int result) {
559 // This should only be called in response to completing a connect attempt.
560 DCHECK(start_connect_attempt_);
561
562 base::TimeDelta duration =
563 base::TimeTicks::Now() - start_connect_attempt_.value();
564
565 // Histogram the total time the connect attempt took, grouped by success and
566 // failure. Note that failures also include cases when the connect attempt
567 // was cancelled by the client before the handshake completed.
568 if (result == OK) {
569 DEPRECATED_UMA_HISTOGRAM_MEDIUM_TIMES(
570 "Net.TcpConnectAttempt.Latency.Success", duration);
571 } else {
572 DEPRECATED_UMA_HISTOGRAM_MEDIUM_TIMES("Net.TcpConnectAttempt.Latency.Error",
573 duration);
574 }
575 }
576
GetConnectAttemptTimeout()577 base::TimeDelta TCPClientSocket::GetConnectAttemptTimeout() {
578 if (!base::FeatureList::IsEnabled(features::kTimeoutTcpConnectAttempt))
579 return base::TimeDelta::Max();
580
581 std::optional<base::TimeDelta> transport_rtt = std::nullopt;
582 if (network_quality_estimator_)
583 transport_rtt = network_quality_estimator_->GetTransportRTT();
584
585 base::TimeDelta min_timeout = features::kTimeoutTcpConnectAttemptMin.Get();
586 base::TimeDelta max_timeout = features::kTimeoutTcpConnectAttemptMax.Get();
587
588 if (!transport_rtt)
589 return max_timeout;
590
591 base::TimeDelta adaptive_timeout =
592 transport_rtt.value() *
593 features::kTimeoutTcpConnectAttemptRTTMultiplier.Get();
594
595 if (adaptive_timeout <= min_timeout)
596 return min_timeout;
597
598 if (adaptive_timeout >= max_timeout)
599 return max_timeout;
600
601 return adaptive_timeout;
602 }
603
604 } // namespace net
605