// Copyright 2014 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
#define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_

#include <cstddef>
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>

#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/timer/timer.h"
#include "net/base/net_export.h"
#include "net/base/proxy_chain.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/connect_job.h"
#include "net/socket/ssl_client_socket.h"

namespace net {

struct CommonConnectJobParams;
struct NetworkTrafficAnnotationTag;
class StreamSocketHandle;

// Identifier for a ClientSocketHandle to scope the lifetime of references.
// ClientSocketHandleID are derived from ClientSocketHandle*, used in
// comparison only, and are never dereferenced. We use an std::uintptr_t here to
// match the size of a pointer, and to prevent dereferencing. Also, our
// tooling complains about dangling pointers if we pass around a raw ptr.
38 using ClientSocketHandleID = std::uintptr_t; 39 40 class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool 41 : public ClientSocketPool { 42 public: 43 WebSocketTransportClientSocketPool( 44 int max_sockets, 45 int max_sockets_per_group, 46 const ProxyChain& proxy_chain, 47 const CommonConnectJobParams* common_connect_job_params); 48 49 WebSocketTransportClientSocketPool( 50 const WebSocketTransportClientSocketPool&) = delete; 51 WebSocketTransportClientSocketPool& operator=( 52 const WebSocketTransportClientSocketPool&) = delete; 53 54 ~WebSocketTransportClientSocketPool() override; 55 56 // Allow another connection to be started to the IPEndPoint that this |handle| 57 // is connected to. Used when the WebSocket handshake completes successfully. 58 // This only works if the socket is connected, however the caller does not 59 // need to explicitly check for this. Instead, ensure that dead sockets are 60 // returned to ReleaseSocket() in a timely fashion. 61 static void UnlockEndpoint( 62 StreamSocketHandle* handle, 63 WebSocketEndpointLockManager* websocket_endpoint_lock_manager); 64 65 // ClientSocketPool implementation. 
66 int RequestSocket( 67 const GroupId& group_id, 68 scoped_refptr<SocketParams> params, 69 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 70 RequestPriority priority, 71 const SocketTag& socket_tag, 72 RespectLimits respect_limits, 73 ClientSocketHandle* handle, 74 CompletionOnceCallback callback, 75 const ProxyAuthCallback& proxy_auth_callback, 76 const NetLogWithSource& net_log) override; 77 int RequestSockets( 78 const GroupId& group_id, 79 scoped_refptr<SocketParams> params, 80 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 81 int num_sockets, 82 CompletionOnceCallback callback, 83 const NetLogWithSource& net_log) override; 84 void SetPriority(const GroupId& group_id, 85 ClientSocketHandle* handle, 86 RequestPriority priority) override; 87 void CancelRequest(const GroupId& group_id, 88 ClientSocketHandle* handle, 89 bool cancel_connect_job) override; 90 void ReleaseSocket(const GroupId& group_id, 91 std::unique_ptr<StreamSocket> socket, 92 int64_t generation) override; 93 void FlushWithError(int error, const char* net_log_reason_utf8) override; 94 void CloseIdleSockets(const char* net_log_reason_utf8) override; 95 void CloseIdleSocketsInGroup(const GroupId& group_id, 96 const char* net_log_reason_utf8) override; 97 int IdleSocketCount() const override; 98 size_t IdleSocketCountInGroup(const GroupId& group_id) const override; 99 LoadState GetLoadState(const GroupId& group_id, 100 const ClientSocketHandle* handle) const override; 101 base::Value GetInfoAsValue(const std::string& name, 102 const std::string& type) const override; 103 bool HasActiveSocket(const GroupId& group_id) const override; 104 105 // HigherLayeredPool implementation. 
106 bool IsStalled() const override; 107 void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override; 108 void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override; 109 110 private: 111 class ConnectJobDelegate : public ConnectJob::Delegate { 112 public: 113 ConnectJobDelegate(WebSocketTransportClientSocketPool* owner, 114 CompletionOnceCallback callback, 115 ClientSocketHandle* socket_handle, 116 const NetLogWithSource& request_net_log); 117 118 ConnectJobDelegate(const ConnectJobDelegate&) = delete; 119 ConnectJobDelegate& operator=(const ConnectJobDelegate&) = delete; 120 121 ~ConnectJobDelegate() override; 122 123 // ConnectJob::Delegate implementation 124 void OnConnectJobComplete(int result, ConnectJob* job) override; 125 void OnNeedsProxyAuth(const HttpResponseInfo& response, 126 HttpAuthController* auth_controller, 127 base::OnceClosure restart_with_auth_callback, 128 ConnectJob* job) override; 129 130 // Calls Connect() on |connect_job|, and takes ownership. Returns Connect's 131 // return value. 132 int Connect(std::unique_ptr<ConnectJob> connect_job); 133 release_callback()134 CompletionOnceCallback release_callback() { return std::move(callback_); } connect_job()135 ConnectJob* connect_job() { return connect_job_.get(); } socket_handle()136 ClientSocketHandle* socket_handle() { return socket_handle_; } 137 request_net_log()138 const NetLogWithSource& request_net_log() { return request_net_log_; } 139 const NetLogWithSource& connect_job_net_log(); 140 141 private: 142 raw_ptr<WebSocketTransportClientSocketPool> owner_; 143 144 CompletionOnceCallback callback_; 145 std::unique_ptr<ConnectJob> connect_job_; 146 const raw_ptr<ClientSocketHandle> socket_handle_; 147 const NetLogWithSource request_net_log_; 148 }; 149 150 // Store the arguments from a call to RequestSocket() that has stalled so we 151 // can replay it when there are available socket slots. 
152 struct StalledRequest { 153 StalledRequest( 154 const GroupId& group_id, 155 const scoped_refptr<SocketParams>& params, 156 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 157 RequestPriority priority, 158 ClientSocketHandle* handle, 159 CompletionOnceCallback callback, 160 const ProxyAuthCallback& proxy_auth_callback, 161 const NetLogWithSource& net_log); 162 StalledRequest(StalledRequest&& other); 163 ~StalledRequest(); 164 165 const GroupId group_id; 166 const scoped_refptr<SocketParams> params; 167 const std::optional<NetworkTrafficAnnotationTag> proxy_annotation_tag; 168 const RequestPriority priority; 169 const raw_ptr<ClientSocketHandle> handle; 170 CompletionOnceCallback callback; 171 ProxyAuthCallback proxy_auth_callback; 172 const NetLogWithSource net_log; 173 }; 174 175 typedef std::map<const ClientSocketHandle*, 176 std::unique_ptr<ConnectJobDelegate>> 177 PendingConnectsMap; 178 // This is a list so that we can remove requests from the middle, and also 179 // so that iterators are not invalidated unless the corresponding request is 180 // removed. 181 typedef std::list<StalledRequest> StalledRequestQueue; 182 typedef std::map<const ClientSocketHandle*, StalledRequestQueue::iterator> 183 StalledRequestMap; 184 185 // Tries to hand out the socket connected by |job|. |result| must be (async) 186 // result of TransportConnectJob::Connect(). Returns true iff it has handed 187 // out a socket. 
188 bool TryHandOutSocket(int result, ConnectJobDelegate* connect_job_delegate); 189 void OnConnectJobComplete(int result, 190 ConnectJobDelegate* connect_job_delegate); 191 void InvokeUserCallbackLater(ClientSocketHandle* handle, 192 CompletionOnceCallback callback, 193 int rv); 194 void InvokeUserCallback(ClientSocketHandleID handle_id, 195 base::WeakPtr<ClientSocketHandle> weak_handle, 196 CompletionOnceCallback callback, 197 int rv); 198 bool ReachedMaxSocketsLimit() const; 199 void HandOutSocket(std::unique_ptr<StreamSocket> socket, 200 const LoadTimingInfo::ConnectTiming& connect_timing, 201 ClientSocketHandle* handle, 202 const NetLogWithSource& net_log); 203 void AddJob(ClientSocketHandle* handle, 204 std::unique_ptr<ConnectJobDelegate> delegate); 205 bool DeleteJob(ClientSocketHandle* handle); 206 const ConnectJob* LookupConnectJob(const ClientSocketHandle* handle) const; 207 void ActivateStalledRequest(); 208 bool DeleteStalledRequest(ClientSocketHandle* handle); 209 210 const ProxyChain proxy_chain_; 211 std::set<ClientSocketHandleID> pending_callbacks_; 212 PendingConnectsMap pending_connects_; 213 StalledRequestQueue stalled_request_queue_; 214 StalledRequestMap stalled_request_map_; 215 const int max_sockets_; 216 int handed_out_socket_count_ = 0; 217 bool flushing_ = false; 218 219 base::WeakPtrFactory<WebSocketTransportClientSocketPool> weak_factory_{this}; 220 }; 221 222 } // namespace net 223 224 #endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 225