1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/websocket_transport_client_socket_pool.h"
6
7 #include <algorithm>
8
9 #include "base/check_op.h"
10 #include "base/compiler_specific.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/location.h"
14 #include "base/notreached.h"
15 #include "base/numerics/safe_conversions.h"
16 #include "base/strings/string_util.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/values.h"
19 #include "net/base/net_errors.h"
20 #include "net/log/net_log_event_type.h"
21 #include "net/log/net_log_source.h"
22 #include "net/log/net_log_source_type.h"
23 #include "net/socket/client_socket_handle.h"
24 #include "net/socket/connect_job.h"
25 #include "net/socket/connect_job_factory.h"
26 #include "net/socket/websocket_endpoint_lock_manager.h"
27 #include "net/traffic_annotation/network_traffic_annotation.h"
28
29 namespace net {
30
// Creates a WebSocket socket pool. The sum of connecting and handed-out
// sockets is capped at |max_sockets| (see ReachedMaxSocketsLimit()).
//
// |max_sockets_per_group| is not used by this constructor; the pool applies
// only the pool-wide limit. |proxy_server| is passed to every connect job the
// pool creates. |common_connect_job_params| must provide a non-null
// websocket_endpoint_lock_manager (DCHECKed below).
WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool(
    int max_sockets,
    int max_sockets_per_group,
    const ProxyServer& proxy_server,
    const CommonConnectJobParams* common_connect_job_params)
    : ClientSocketPool(/*is_for_websockets=*/true,
                       common_connect_job_params,
                       std::make_unique<ConnectJobFactory>()),
      proxy_server_(proxy_server),
      max_sockets_(max_sockets) {
  DCHECK(common_connect_job_params->websocket_endpoint_lock_manager);
}
43
// Cancels every pending connect job and stalled request through
// FlushWithError(ERR_ABORTED), then checks that nothing is left. Every socket
// that was handed out must already have been returned via ReleaseSocket().
WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() {
  // Clean up any pending connect jobs.
  FlushWithError(ERR_ABORTED, "");
  DCHECK(pending_connects_.empty());
  DCHECK_EQ(0, handed_out_socket_count_);
  DCHECK(stalled_request_queue_.empty());
  DCHECK(stalled_request_map_.empty());
}
52
53 // static
UnlockEndpoint(ClientSocketHandle * handle,WebSocketEndpointLockManager * websocket_endpoint_lock_manager)54 void WebSocketTransportClientSocketPool::UnlockEndpoint(
55 ClientSocketHandle* handle,
56 WebSocketEndpointLockManager* websocket_endpoint_lock_manager) {
57 DCHECK(handle->is_initialized());
58 DCHECK(handle->socket());
59 IPEndPoint address;
60 if (handle->socket()->GetPeerAddress(&address) == OK)
61 websocket_endpoint_lock_manager->UnlockEndpoint(address);
62 }
63
// Requests a socket for |handle|. This pool never reuses idle sockets. Each
// request is bound to its own connect job as soon as the job is created
// (early binding).
//
// If the pool is at its limit and |respect_limits| is ENABLED, the request is
// queued in |stalled_request_queue_| and ERR_IO_PENDING is returned. It is
// started later by ActivateStalledRequest().
//
// Otherwise a connect job is created and started:
//  - If the job is still in progress, its delegate is stored in
//    |pending_connects_| and ERR_IO_PENDING is returned. |callback| runs later
//    from OnConnectJobComplete().
//  - If the job finished synchronously, the outcome is applied to |handle|
//    immediately and the result is returned. |callback| is not run in this
//    case; it is dropped together with the delegate.
int WebSocketTransportClientSocketPool::RequestSocket(
    const GroupId& group_id,
    scoped_refptr<SocketParams> params,
    const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
    RequestPriority priority,
    const SocketTag& socket_tag,
    RespectLimits respect_limits,
    ClientSocketHandle* handle,
    CompletionOnceCallback callback,
    const ProxyAuthCallback& proxy_auth_callback,
    const NetLogWithSource& request_net_log) {
  DCHECK(params);
  CHECK(!callback.is_null());
  CHECK(handle);
  // Only the default tag is accepted; connect jobs below are always created
  // with SocketTag().
  DCHECK(socket_tag == SocketTag());

  NetLogTcpClientSocketPoolRequestedSocket(request_net_log, group_id);
  // Ended in TryHandOutSocket() once the connect job finishes.
  request_net_log.BeginEvent(NetLogEventType::SOCKET_POOL);

  if (ReachedMaxSocketsLimit() &&
      respect_limits == ClientSocketPool::RespectLimits::ENABLED) {
    request_net_log.AddEvent(NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS);
    stalled_request_queue_.emplace_back(group_id, params, proxy_annotation_tag,
                                        priority, handle, std::move(callback),
                                        proxy_auth_callback, request_net_log);
    // Point at the element that was just appended.
    auto iterator = stalled_request_queue_.end();
    --iterator;
    DCHECK_EQ(handle, iterator->handle);
    // Because StalledRequestQueue is a std::list, its iterators are guaranteed
    // to remain valid as long as the elements are not removed. As long as
    // stalled_request_queue_ and stalled_request_map_ are updated in sync, it
    // is safe to dereference an iterator in stalled_request_map_ to find the
    // corresponding list element.
    stalled_request_map_.insert(
        StalledRequestMap::value_type(handle, iterator));
    return ERR_IO_PENDING;
  }

  // The delegate owns |callback| and, after Connect(), the connect job itself.
  std::unique_ptr<ConnectJobDelegate> connect_job_delegate =
      std::make_unique<ConnectJobDelegate>(this, std::move(callback), handle,
                                           request_net_log);

  std::unique_ptr<ConnectJob> connect_job =
      CreateConnectJob(group_id, params, proxy_server_, proxy_annotation_tag,
                       priority, SocketTag(), connect_job_delegate.get());

  int result = connect_job_delegate->Connect(std::move(connect_job));

  // Regardless of the outcome of |connect_job|, it will always be bound to
  // |handle|, since this pool uses early-binding. So the binding is logged
  // here, without waiting for the result.
  request_net_log.AddEventReferencingSource(
      NetLogEventType::SOCKET_POOL_BOUND_TO_CONNECT_JOB,
      connect_job_delegate->connect_job_net_log().source());

  if (result == ERR_IO_PENDING) {
    // TODO(ricea): Implement backup job timer?
    AddJob(handle, std::move(connect_job_delegate));
  } else {
    // Synchronous completion: apply the outcome now. The delegate (and the
    // callback it holds) is destroyed when this function returns.
    TryHandOutSocket(result, connect_job_delegate.get());
  }

  return result;
}
128
// This pool does not support requesting multiple sockets up front. The call
// logs NOTIMPLEMENTED(), creates no sockets and returns OK. |callback| is
// never run.
int WebSocketTransportClientSocketPool::RequestSockets(
    const GroupId& group_id,
    scoped_refptr<SocketParams> params,
    const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
    int num_sockets,
    CompletionOnceCallback callback,
    const NetLogWithSource& net_log) {
  NOTIMPLEMENTED();
  return OK;
}
139
// Intentionally a no-op: a priority change has no effect on requests in this
// pool.
void WebSocketTransportClientSocketPool::SetPriority(const GroupId& group_id,
                                                     ClientSocketHandle* handle,
                                                     RequestPriority priority) {
  // Since sockets requested by RequestSocket are bound early and
  // stalled_request_{queue,map} don't take priorities into account, there's
  // nothing to do within the pool to change priority of the request.
  // TODO(rdsmith, ricea): Make stalled_request_{queue,map} take priorities
  // into account.
  // TODO(rdsmith, chlily): Investigate plumbing the reprioritization request to
  // the connect job.
}
151
CancelRequest(const GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)152 void WebSocketTransportClientSocketPool::CancelRequest(
153 const GroupId& group_id,
154 ClientSocketHandle* handle,
155 bool cancel_connect_job) {
156 DCHECK(!handle->is_initialized());
157 if (DeleteStalledRequest(handle))
158 return;
159 std::unique_ptr<StreamSocket> socket = handle->PassSocket();
160 if (socket)
161 ReleaseSocket(handle->group_id(), std::move(socket),
162 handle->group_generation());
163 if (!DeleteJob(handle))
164 pending_callbacks_.erase(reinterpret_cast<ClientSocketHandleID>(handle));
165
166 ActivateStalledRequest();
167 }
168
// Returns a socket previously bound by HandOutSocket(). The pool keeps no idle
// sockets, so |socket| is not retained: the by-value parameter is simply
// destroyed. |group_id| and |generation| are unused. The freed slot is offered
// to stalled requests.
void WebSocketTransportClientSocketPool::ReleaseSocket(
    const GroupId& group_id,
    std::unique_ptr<StreamSocket> socket,
    int64_t generation) {
  // Every release must pair with an earlier HandOutSocket().
  CHECK_GT(handed_out_socket_count_, 0);
  --handed_out_socket_count_;

  ActivateStalledRequest();
}
178
// Fails every pending connect job and every stalled request with |error|,
// which must not be OK. Their callbacks are scheduled through
// InvokeUserCallbackLater() rather than run inline. |net_log_reason_utf8| is
// recorded on each closed connect job's net log.
void WebSocketTransportClientSocketPool::FlushWithError(
    int error,
    const char* net_log_reason_utf8) {
  DCHECK_NE(error, OK);

  // Deleting a connect job that holds an endpoint lock can let another job
  // waiting for that lock proceed. If that job then completes synchronously,
  // it calls OnConnectJobComplete() while this loop is still iterating.
  // The |flushing_| flag tells OnConnectJobComplete() to ignore such calls.
  // Ignoring them is safe because this method deletes every job and
  // schedules its callback anyway.
  flushing_ = true;
  for (auto it = pending_connects_.begin(); it != pending_connects_.end();) {
    InvokeUserCallbackLater(it->second->socket_handle(),
                            it->second->release_callback(), error);
    it->second->connect_job_net_log().AddEventWithStringParams(
        NetLogEventType::SOCKET_POOL_CLOSING_SOCKET, "reason",
        net_log_reason_utf8);
    // Destroys the delegate and the connect job it owns.
    it = pending_connects_.erase(it);
  }
  for (auto& stalled_request : stalled_request_queue_) {
    InvokeUserCallbackLater(stalled_request.handle,
                            std::move(stalled_request.callback), error);
  }
  // The map holds iterators into the queue, so both are cleared together.
  stalled_request_map_.clear();
  stalled_request_queue_.clear();
  flushing_ = false;
}
207
// No-op: released sockets are destroyed immediately (see ReleaseSocket()), so
// there is never anything idle to close.
void WebSocketTransportClientSocketPool::CloseIdleSockets(
    const char* net_log_reason_utf8) {
  // We have no idle sockets.
}
212
// No-op: this pool keeps no idle sockets in any group.
void WebSocketTransportClientSocketPool::CloseIdleSocketsInGroup(
    const GroupId& group_id,
    const char* net_log_reason_utf8) {
  // We have no idle sockets.
}
218
// Always 0: this pool keeps no idle sockets.
int WebSocketTransportClientSocketPool::IdleSocketCount() const {
  return 0;
}
222
// Always 0 for every group: this pool keeps no idle sockets.
size_t WebSocketTransportClientSocketPool::IdleSocketCountInGroup(
    const GroupId& group_id) const {
  return 0;
}
227
GetLoadState(const GroupId & group_id,const ClientSocketHandle * handle) const228 LoadState WebSocketTransportClientSocketPool::GetLoadState(
229 const GroupId& group_id,
230 const ClientSocketHandle* handle) const {
231 if (stalled_request_map_.find(handle) != stalled_request_map_.end())
232 return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
233 if (pending_callbacks_.count(reinterpret_cast<ClientSocketHandleID>(handle)))
234 return LOAD_STATE_CONNECTING;
235 return LookupConnectJob(handle)->GetLoadState();
236 }
237
GetInfoAsValue(const std::string & name,const std::string & type) const238 base::Value WebSocketTransportClientSocketPool::GetInfoAsValue(
239 const std::string& name,
240 const std::string& type) const {
241 base::Value::Dict dict;
242 dict.Set("name", name);
243 dict.Set("type", type);
244 dict.Set("handed_out_socket_count", handed_out_socket_count_);
245 dict.Set("connecting_socket_count",
246 static_cast<int>(pending_connects_.size()));
247 dict.Set("idle_socket_count", 0);
248 dict.Set("max_socket_count", max_sockets_);
249 dict.Set("max_sockets_per_group", max_sockets_);
250 return base::Value(std::move(dict));
251 }
252
// Not supported for WebSocket pools and must not be called. NOTREACHED() flags
// any call. If execution continues past it, false is returned.
bool WebSocketTransportClientSocketPool::HasActiveSocket(
    const GroupId& group_id) const {
  // This method is not supported for WebSocket.
  NOTREACHED();
  return false;
}
259
// Returns true while at least one request is queued in
// |stalled_request_queue_| waiting for capacity.
bool WebSocketTransportClientSocketPool::IsStalled() const {
  return !stalled_request_queue_.empty();
}
263
// Intentionally a no-op; |higher_pool| is not recorded.
void WebSocketTransportClientSocketPool::AddHigherLayeredPool(
    HigherLayeredPool* higher_pool) {
  // This class doesn't use connection limits like the pools for HTTP do, so no
  // need to track higher layered pools.
}
269
// Intentionally a no-op; higher layered pools are never recorded (see
// AddHigherLayeredPool()).
void WebSocketTransportClientSocketPool::RemoveHigherLayeredPool(
    HigherLayeredPool* higher_pool) {
  // This class doesn't use connection limits like the pools for HTTP do, so no
  // need to track higher layered pools.
}
275
TryHandOutSocket(int result,ConnectJobDelegate * connect_job_delegate)276 bool WebSocketTransportClientSocketPool::TryHandOutSocket(
277 int result,
278 ConnectJobDelegate* connect_job_delegate) {
279 DCHECK_NE(result, ERR_IO_PENDING);
280
281 std::unique_ptr<StreamSocket> socket =
282 connect_job_delegate->connect_job()->PassSocket();
283 LoadTimingInfo::ConnectTiming connect_timing =
284 connect_job_delegate->connect_job()->connect_timing();
285 ClientSocketHandle* const handle = connect_job_delegate->socket_handle();
286 NetLogWithSource request_net_log = connect_job_delegate->request_net_log();
287
288 if (result == OK) {
289 DCHECK(socket);
290
291 HandOutSocket(std::move(socket), connect_timing, handle, request_net_log);
292
293 request_net_log.EndEvent(NetLogEventType::SOCKET_POOL);
294
295 return true;
296 }
297
298 bool handed_out_socket = false;
299
300 // If we got a socket, it must contain error information so pass that
301 // up so that the caller can retrieve it.
302 handle->SetAdditionalErrorState(connect_job_delegate->connect_job());
303 if (socket) {
304 HandOutSocket(std::move(socket), connect_timing, handle, request_net_log);
305 handed_out_socket = true;
306 }
307
308 request_net_log.EndEventWithNetErrorCode(NetLogEventType::SOCKET_POOL,
309 result);
310
311 return handed_out_socket;
312 }
313
// Called via ConnectJobDelegate::OnConnectJobComplete() when a pending connect
// job finishes. This method:
//  1. applies the outcome to the handle,
//  2. deletes the job (which also destroys |connect_job_delegate|),
//  3. offers any freed capacity to stalled requests, and
//  4. schedules the user callback with |result|.
void WebSocketTransportClientSocketPool::OnConnectJobComplete(
    int result,
    ConnectJobDelegate* connect_job_delegate) {
  DCHECK_NE(ERR_IO_PENDING, result);

  // See comment in FlushWithError.
  if (flushing_) {
    // Just delete the socket.
    std::unique_ptr<StreamSocket> socket =
        connect_job_delegate->connect_job()->PassSocket();
    return;
  }

  bool handed_out_socket = TryHandOutSocket(result, connect_job_delegate);

  // Take what is still needed out of the delegate before DeleteJob() destroys
  // it.
  CompletionOnceCallback callback = connect_job_delegate->release_callback();

  ClientSocketHandle* const handle = connect_job_delegate->socket_handle();

  // Removing the entry from |pending_connects_| frees the capacity counted by
  // ReachedMaxSocketsLimit(), and it destroys the delegate that owns the job.
  bool delete_succeeded = DeleteJob(handle);
  DCHECK(delete_succeeded);

  // The delegate was owned by |pending_connects_| and is now gone.
  connect_job_delegate = nullptr;

  // A handed-out socket still occupies a slot (|handed_out_socket_count_|), so
  // only a job that produced no socket frees capacity for a stalled request.
  if (!handed_out_socket)
    ActivateStalledRequest();

  InvokeUserCallbackLater(handle, std::move(callback), result);
}
343
InvokeUserCallbackLater(ClientSocketHandle * handle,CompletionOnceCallback callback,int rv)344 void WebSocketTransportClientSocketPool::InvokeUserCallbackLater(
345 ClientSocketHandle* handle,
346 CompletionOnceCallback callback,
347 int rv) {
348 const auto handle_id = reinterpret_cast<ClientSocketHandleID>(handle);
349 DCHECK(!pending_callbacks_.count(handle_id));
350 pending_callbacks_.insert(handle_id);
351 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
352 FROM_HERE,
353 base::BindOnce(&WebSocketTransportClientSocketPool::InvokeUserCallback,
354 weak_factory_.GetWeakPtr(), handle_id, std::move(callback),
355 rv));
356 }
357
InvokeUserCallback(ClientSocketHandleID handle_id,CompletionOnceCallback callback,int rv)358 void WebSocketTransportClientSocketPool::InvokeUserCallback(
359 ClientSocketHandleID handle_id,
360 CompletionOnceCallback callback,
361 int rv) {
362 if (pending_callbacks_.erase(handle_id))
363 std::move(callback).Run(rv);
364 }
365
ReachedMaxSocketsLimit() const366 bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const {
367 return handed_out_socket_count_ >= max_sockets_ ||
368 base::checked_cast<int>(pending_connects_.size()) >=
369 max_sockets_ - handed_out_socket_count_;
370 }
371
HandOutSocket(std::unique_ptr<StreamSocket> socket,const LoadTimingInfo::ConnectTiming & connect_timing,ClientSocketHandle * handle,const NetLogWithSource & net_log)372 void WebSocketTransportClientSocketPool::HandOutSocket(
373 std::unique_ptr<StreamSocket> socket,
374 const LoadTimingInfo::ConnectTiming& connect_timing,
375 ClientSocketHandle* handle,
376 const NetLogWithSource& net_log) {
377 DCHECK(socket);
378 DCHECK_EQ(ClientSocketHandle::UNUSED, handle->reuse_type());
379 DCHECK_EQ(0, handle->idle_time().InMicroseconds());
380
381 handle->SetSocket(std::move(socket));
382 handle->set_group_generation(0);
383 handle->set_connect_timing(connect_timing);
384
385 net_log.AddEventReferencingSource(
386 NetLogEventType::SOCKET_POOL_BOUND_TO_SOCKET,
387 handle->socket()->NetLog().source());
388
389 ++handed_out_socket_count_;
390 }
391
AddJob(ClientSocketHandle * handle,std::unique_ptr<ConnectJobDelegate> delegate)392 void WebSocketTransportClientSocketPool::AddJob(
393 ClientSocketHandle* handle,
394 std::unique_ptr<ConnectJobDelegate> delegate) {
395 bool inserted =
396 pending_connects_
397 .insert(PendingConnectsMap::value_type(handle, std::move(delegate)))
398 .second;
399 DCHECK(inserted);
400 }
401
DeleteJob(ClientSocketHandle * handle)402 bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) {
403 auto it = pending_connects_.find(handle);
404 if (it == pending_connects_.end())
405 return false;
406 // Deleting a ConnectJob which holds an endpoint lock can lead to a different
407 // ConnectJob proceeding to connect. If the connect proceeds synchronously
408 // (usually because of a failure) then it can trigger that job to be
409 // deleted.
410 pending_connects_.erase(it);
411 return true;
412 }
413
LookupConnectJob(const ClientSocketHandle * handle) const414 const ConnectJob* WebSocketTransportClientSocketPool::LookupConnectJob(
415 const ClientSocketHandle* handle) const {
416 auto it = pending_connects_.find(handle);
417 CHECK(it != pending_connects_.end());
418 return it->second->connect_job();
419 }
420
// Starts queued (stalled) requests, oldest first, for as long as the pool has
// capacity.
void WebSocketTransportClientSocketPool::ActivateStalledRequest() {
  // Usually we will only be able to activate one stalled request at a time,
  // however if all the connects fail synchronously for some reason, we may be
  // able to clear the whole queue at once.
  while (!stalled_request_queue_.empty() && !ReachedMaxSocketsLimit()) {
    // Dequeue first, keeping the queue and the map in sync, so that the
    // RequestSocket() call below never sees this entry as stalled.
    StalledRequest request = std::move(stalled_request_queue_.front());
    stalled_request_queue_.pop_front();
    stalled_request_map_.erase(request.handle);

    // One half goes to RequestSocket() and is used if the connect is
    // asynchronous. The other half reports a synchronous result below.
    auto split_callback = base::SplitOnceCallback(std::move(request.callback));

    int rv = RequestSocket(
        request.group_id, request.params, request.proxy_annotation_tag,
        request.priority, SocketTag(),
        // Stalled requests can't have |respect_limits|
        // DISABLED.
        RespectLimits::ENABLED, request.handle, std::move(split_callback.first),
        request.proxy_auth_callback, request.net_log);

    // A synchronous result is delivered through a posted task, never run
    // inline. The user callback therefore cannot re-enter this loop.
    if (rv != ERR_IO_PENDING)
      InvokeUserCallbackLater(request.handle, std::move(split_callback.second),
                              rv);
  }
}
447
DeleteStalledRequest(ClientSocketHandle * handle)448 bool WebSocketTransportClientSocketPool::DeleteStalledRequest(
449 ClientSocketHandle* handle) {
450 auto it = stalled_request_map_.find(handle);
451 if (it == stalled_request_map_.end())
452 return false;
453 stalled_request_queue_.erase(it->second);
454 stalled_request_map_.erase(it);
455 return true;
456 }
457
// Per-request bridge between a ConnectJob and |owner|. It holds the user's
// |callback| until the pool takes it via release_callback(). It also forwards
// job completion to |owner|.
WebSocketTransportClientSocketPool::ConnectJobDelegate::ConnectJobDelegate(
    WebSocketTransportClientSocketPool* owner,
    CompletionOnceCallback callback,
    ClientSocketHandle* socket_handle,
    const NetLogWithSource& request_net_log)
    : owner_(owner),
      callback_(std::move(callback)),
      socket_handle_(socket_handle),
      request_net_log_(request_net_log) {}
467
// Destroying the delegate also destroys the connect job it owns, if Connect()
// stored one.
WebSocketTransportClientSocketPool::ConnectJobDelegate::~ConnectJobDelegate() =
    default;
470
// ConnectJob::Delegate hook: forwards completion of this delegate's own job to
// the owning pool.
void
WebSocketTransportClientSocketPool::ConnectJobDelegate::OnConnectJobComplete(
    int result,
    ConnectJob* job) {
  DCHECK_EQ(job, connect_job_.get());
  owner_->OnConnectJobComplete(result, this);
}
478
// ConnectJob::Delegate hook that must never fire for this pool.
void WebSocketTransportClientSocketPool::ConnectJobDelegate::OnNeedsProxyAuth(
    const HttpResponseInfo& response,
    HttpAuthController* auth_controller,
    base::OnceClosure restart_with_auth_callback,
    ConnectJob* job) {
  // This class isn't used for proxies.
  NOTREACHED();
}
487
// Takes ownership of |connect_job| and starts it. Returns the job's Connect()
// result: ERR_IO_PENDING if it will complete later via OnConnectJobComplete(),
// otherwise the synchronous result.
int WebSocketTransportClientSocketPool::ConnectJobDelegate::Connect(
    std::unique_ptr<ConnectJob> connect_job) {
  connect_job_ = std::move(connect_job);
  return connect_job_->Connect();
}
493
// Returns the net log of the owned connect job. Requires that Connect() has
// already stored a job.
const NetLogWithSource&
WebSocketTransportClientSocketPool::ConnectJobDelegate::connect_job_net_log() {
  return connect_job_->net_log();
}
498
// Captures the RequestSocket() arguments needed to replay the request later
// from ActivateStalledRequest(). The socket tag and |respect_limits| are not
// stored: replay always uses SocketTag() and RespectLimits::ENABLED.
WebSocketTransportClientSocketPool::StalledRequest::StalledRequest(
    const GroupId& group_id,
    const scoped_refptr<SocketParams>& params,
    const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
    RequestPriority priority,
    ClientSocketHandle* handle,
    CompletionOnceCallback callback,
    const ProxyAuthCallback& proxy_auth_callback,
    const NetLogWithSource& net_log)
    : group_id(group_id),
      params(params),
      proxy_annotation_tag(proxy_annotation_tag),
      priority(priority),
      handle(handle),
      callback(std::move(callback)),
      proxy_auth_callback(proxy_auth_callback),
      net_log(net_log) {}
516
// Move constructor. ActivateStalledRequest() uses it to move the front request
// out of the queue.
WebSocketTransportClientSocketPool::StalledRequest::StalledRequest(
    StalledRequest&& other) = default;
519
// Out-of-line defaulted destructor.
WebSocketTransportClientSocketPool::StalledRequest::~StalledRequest() = default;
521
522 } // namespace net
523