1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ 6 #define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ 7 8 #include <map> 9 #include <string> 10 #include <vector> 11 12 #include "base/memory/ref_counted.h" 13 #include "base/memory/scoped_ptr.h" 14 #include "base/time/time.h" 15 #include "net/base/privacy_mode.h" 16 #include "net/dns/host_resolver.h" 17 #include "net/http/http_response_info.h" 18 #include "net/socket/client_socket_pool.h" 19 #include "net/socket/client_socket_pool_base.h" 20 #include "net/socket/client_socket_pool_histograms.h" 21 #include "net/socket/ssl_client_socket.h" 22 #include "net/ssl/ssl_config_service.h" 23 24 namespace net { 25 26 class CertVerifier; 27 class ClientSocketFactory; 28 class ConnectJobFactory; 29 class CTVerifier; 30 class HostPortPair; 31 class HttpProxyClientSocketPool; 32 class HttpProxySocketParams; 33 class SOCKSClientSocketPool; 34 class SOCKSSocketParams; 35 class SSLClientSocket; 36 class TransportClientSocketPool; 37 class TransportSecurityState; 38 class TransportSocketParams; 39 40 class NET_EXPORT_PRIVATE SSLSocketParams 41 : public base::RefCounted<SSLSocketParams> { 42 public: 43 enum ConnectionType { DIRECT, SOCKS_PROXY, HTTP_PROXY }; 44 45 // Exactly one of |direct_params|, |socks_proxy_params|, and 46 // |http_proxy_params| must be non-NULL. 47 SSLSocketParams( 48 const scoped_refptr<TransportSocketParams>& direct_params, 49 const scoped_refptr<SOCKSSocketParams>& socks_proxy_params, 50 const scoped_refptr<HttpProxySocketParams>& http_proxy_params, 51 const HostPortPair& host_and_port, 52 const SSLConfig& ssl_config, 53 PrivacyMode privacy_mode, 54 int load_flags, 55 bool force_spdy_over_ssl, 56 bool want_spdy_over_npn); 57 58 // Returns the type of the underlying connection. 59 ConnectionType GetConnectionType() const; 60 61 // Must be called only when GetConnectionType() returns DIRECT. 62 const scoped_refptr<TransportSocketParams>& 63 GetDirectConnectionParams() const; 64 65 // Must be called only when GetConnectionType() returns SOCKS_PROXY. 66 const scoped_refptr<SOCKSSocketParams>& 67 GetSocksProxyConnectionParams() const; 68 69 // Must be called only when GetConnectionType() returns HTTP_PROXY. 70 const scoped_refptr<HttpProxySocketParams>& 71 GetHttpProxyConnectionParams() const; 72 host_and_port()73 const HostPortPair& host_and_port() const { return host_and_port_; } ssl_config()74 const SSLConfig& ssl_config() const { return ssl_config_; } privacy_mode()75 PrivacyMode privacy_mode() const { return privacy_mode_; } load_flags()76 int load_flags() const { return load_flags_; } force_spdy_over_ssl()77 bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; } want_spdy_over_npn()78 bool want_spdy_over_npn() const { return want_spdy_over_npn_; } ignore_limits()79 bool ignore_limits() const { return ignore_limits_; } 80 81 private: 82 friend class base::RefCounted<SSLSocketParams>; 83 ~SSLSocketParams(); 84 85 const scoped_refptr<TransportSocketParams> direct_params_; 86 const scoped_refptr<SOCKSSocketParams> socks_proxy_params_; 87 const scoped_refptr<HttpProxySocketParams> http_proxy_params_; 88 const HostPortPair host_and_port_; 89 const SSLConfig ssl_config_; 90 const PrivacyMode privacy_mode_; 91 const int load_flags_; 92 const bool force_spdy_over_ssl_; 93 const bool want_spdy_over_npn_; 94 bool ignore_limits_; 95 96 DISALLOW_COPY_AND_ASSIGN(SSLSocketParams); 97 }; 98 99 // SSLConnectJobMessenger handles communication between concurrent 100 // SSLConnectJobs that share the same SSL session cache key. 101 // 102 // SSLConnectJobMessengers tell the session cache when a certain 103 // connection should be monitored for success or failure, and 104 // tell SSLConnectJobs when to pause or resume their connections. 105 class SSLConnectJobMessenger { 106 public: 107 struct SocketAndCallback { 108 SocketAndCallback(SSLClientSocket* ssl_socket, 109 const base::Closure& job_resumption_callback); 110 ~SocketAndCallback(); 111 112 SSLClientSocket* socket; 113 base::Closure callback; 114 }; 115 116 typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks; 117 118 // |messenger_finished_callback| is run when a connection monitored by the 119 // SSLConnectJobMessenger has completed and we are finished with the 120 // SSLConnectJobMessenger. 121 explicit SSLConnectJobMessenger( 122 const base::Closure& messenger_finished_callback); 123 ~SSLConnectJobMessenger(); 124 125 // Removes |socket| from the set of sockets being monitored. This 126 // guarantees that |job_resumption_callback| will not be called for 127 // the socket. 128 void RemovePendingSocket(SSLClientSocket* ssl_socket); 129 130 // Returns true if |ssl_socket|'s Connect() method should be called. 131 bool CanProceed(SSLClientSocket* ssl_socket); 132 133 // Configures the SSLConnectJobMessenger to begin monitoring |ssl_socket|'s 134 // connection status. After a successful connection, or an error, 135 // the messenger will determine which sockets that have been added 136 // via AddPendingSocket() to allow to proceed. 137 void MonitorConnectionResult(SSLClientSocket* ssl_socket); 138 139 // Adds |socket| to the list of sockets waiting to Connect(). When 140 // the messenger has determined that it's an appropriate time for |socket| 141 // to connect, it will invoke |callback|. 142 // 143 // Note: It is an error to call AddPendingSocket() without having first 144 // called MonitorConnectionResult() and configuring a socket that WILL 145 // have Connect() called on it. 146 void AddPendingSocket(SSLClientSocket* ssl_socket, 147 const base::Closure& callback); 148 149 private: 150 // Processes pending callbacks when a socket completes its SSL handshake -- 151 // either successfully or unsuccessfully. 152 void OnSSLHandshakeCompleted(); 153 154 // Runs all callbacks stored in |pending_sockets_and_callbacks_|. 155 void RunAllCallbacks( 156 const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks); 157 158 SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_; 159 // Note: this field is a vector to allow for future design changes. Currently, 160 // this vector should only ever have one entry. 161 std::vector<SSLClientSocket*> connecting_sockets_; 162 163 base::Closure messenger_finished_callback_; 164 165 base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_; 166 }; 167 168 // SSLConnectJob handles the SSL handshake after setting up the underlying 169 // connection as specified in the params. 170 class SSLConnectJob : public ConnectJob { 171 public: 172 // Callback to allow the SSLConnectJob to obtain an SSLConnectJobMessenger to 173 // coordinate connecting. The SSLConnectJob will supply a unique identifer 174 // (ex: the SSL session cache key), with the expectation that the same 175 // Messenger will be returned for all such ConnectJobs. 176 // 177 // Note: It will only be called for situations where the SSL session cache 178 // does not already have a candidate session to resume. 179 typedef base::Callback<SSLConnectJobMessenger*(const std::string&)> 180 GetMessengerCallback; 181 182 // Note: the SSLConnectJob does not own |messenger| so it must outlive the 183 // job. 184 SSLConnectJob(const std::string& group_name, 185 RequestPriority priority, 186 const scoped_refptr<SSLSocketParams>& params, 187 const base::TimeDelta& timeout_duration, 188 TransportClientSocketPool* transport_pool, 189 SOCKSClientSocketPool* socks_pool, 190 HttpProxyClientSocketPool* http_proxy_pool, 191 ClientSocketFactory* client_socket_factory, 192 HostResolver* host_resolver, 193 const SSLClientSocketContext& context, 194 const GetMessengerCallback& get_messenger_callback, 195 Delegate* delegate, 196 NetLog* net_log); 197 virtual ~SSLConnectJob(); 198 199 // ConnectJob methods. 200 virtual LoadState GetLoadState() const OVERRIDE; 201 202 virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE; 203 204 private: 205 enum State { 206 STATE_TRANSPORT_CONNECT, 207 STATE_TRANSPORT_CONNECT_COMPLETE, 208 STATE_SOCKS_CONNECT, 209 STATE_SOCKS_CONNECT_COMPLETE, 210 STATE_TUNNEL_CONNECT, 211 STATE_TUNNEL_CONNECT_COMPLETE, 212 STATE_CREATE_SSL_SOCKET, 213 STATE_CHECK_FOR_RESUME, 214 STATE_SSL_CONNECT, 215 STATE_SSL_CONNECT_COMPLETE, 216 STATE_NONE, 217 }; 218 219 void OnIOComplete(int result); 220 221 // Runs the state transition loop. 222 int DoLoop(int result); 223 224 int DoTransportConnect(); 225 int DoTransportConnectComplete(int result); 226 int DoSOCKSConnect(); 227 int DoSOCKSConnectComplete(int result); 228 int DoTunnelConnect(); 229 int DoTunnelConnectComplete(int result); 230 int DoCreateSSLSocket(); 231 int DoCheckForResume(); 232 int DoSSLConnect(); 233 int DoSSLConnectComplete(int result); 234 235 // Tells a waiting SSLConnectJob to resume its SSL connection. 236 void ResumeSSLConnection(); 237 238 // Returns the initial state for the state machine based on the 239 // |connection_type|. 240 static State GetInitialState(SSLSocketParams::ConnectionType connection_type); 241 242 // Starts the SSL connection process. Returns OK on success and 243 // ERR_IO_PENDING if it cannot immediately service the request. 244 // Otherwise, it returns a net error code. 245 virtual int ConnectInternal() OVERRIDE; 246 247 scoped_refptr<SSLSocketParams> params_; 248 TransportClientSocketPool* const transport_pool_; 249 SOCKSClientSocketPool* const socks_pool_; 250 HttpProxyClientSocketPool* const http_proxy_pool_; 251 ClientSocketFactory* const client_socket_factory_; 252 HostResolver* const host_resolver_; 253 254 const SSLClientSocketContext context_; 255 256 State next_state_; 257 CompletionCallback io_callback_; 258 scoped_ptr<ClientSocketHandle> transport_socket_handle_; 259 scoped_ptr<SSLClientSocket> ssl_socket_; 260 261 SSLConnectJobMessenger* messenger_; 262 HttpResponseInfo error_response_info_; 263 264 GetMessengerCallback get_messenger_callback_; 265 266 base::WeakPtrFactory<SSLConnectJob> weak_factory_; 267 268 DISALLOW_COPY_AND_ASSIGN(SSLConnectJob); 269 }; 270 271 class NET_EXPORT_PRIVATE SSLClientSocketPool 272 : public ClientSocketPool, 273 public HigherLayeredPool, 274 public SSLConfigService::Observer { 275 public: 276 typedef SSLSocketParams SocketParams; 277 278 // Only the pools that will be used are required. i.e. if you never 279 // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. 280 SSLClientSocketPool(int max_sockets, 281 int max_sockets_per_group, 282 ClientSocketPoolHistograms* histograms, 283 HostResolver* host_resolver, 284 CertVerifier* cert_verifier, 285 ChannelIDService* channel_id_service, 286 TransportSecurityState* transport_security_state, 287 CTVerifier* cert_transparency_verifier, 288 const std::string& ssl_session_cache_shard, 289 ClientSocketFactory* client_socket_factory, 290 TransportClientSocketPool* transport_pool, 291 SOCKSClientSocketPool* socks_pool, 292 HttpProxyClientSocketPool* http_proxy_pool, 293 SSLConfigService* ssl_config_service, 294 bool enable_ssl_connect_job_waiting, 295 NetLog* net_log); 296 297 virtual ~SSLClientSocketPool(); 298 299 // ClientSocketPool implementation. 300 virtual int RequestSocket(const std::string& group_name, 301 const void* connect_params, 302 RequestPriority priority, 303 ClientSocketHandle* handle, 304 const CompletionCallback& callback, 305 const BoundNetLog& net_log) OVERRIDE; 306 307 virtual void RequestSockets(const std::string& group_name, 308 const void* params, 309 int num_sockets, 310 const BoundNetLog& net_log) OVERRIDE; 311 312 virtual void CancelRequest(const std::string& group_name, 313 ClientSocketHandle* handle) OVERRIDE; 314 315 virtual void ReleaseSocket(const std::string& group_name, 316 scoped_ptr<StreamSocket> socket, 317 int id) OVERRIDE; 318 319 virtual void FlushWithError(int error) OVERRIDE; 320 321 virtual void CloseIdleSockets() OVERRIDE; 322 323 virtual int IdleSocketCount() const OVERRIDE; 324 325 virtual int IdleSocketCountInGroup( 326 const std::string& group_name) const OVERRIDE; 327 328 virtual LoadState GetLoadState( 329 const std::string& group_name, 330 const ClientSocketHandle* handle) const OVERRIDE; 331 332 virtual base::DictionaryValue* GetInfoAsValue( 333 const std::string& name, 334 const std::string& type, 335 bool include_nested_pools) const OVERRIDE; 336 337 virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; 338 339 virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; 340 341 // LowerLayeredPool implementation. 342 virtual bool IsStalled() const OVERRIDE; 343 344 virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; 345 346 virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; 347 348 // HigherLayeredPool implementation. 349 virtual bool CloseOneIdleConnection() OVERRIDE; 350 351 // Gets the SSLConnectJobMessenger for the given ssl session |cache_key|. If 352 // none exits, it creates one and stores it in |messenger_map_|. 353 SSLConnectJobMessenger* GetOrCreateSSLConnectJobMessenger( 354 const std::string& cache_key); 355 void DeleteSSLConnectJobMessenger(const std::string& cache_key); 356 357 private: 358 typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; 359 // Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects. 360 typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap; 361 362 // SSLConfigService::Observer implementation. 363 364 // When the user changes the SSL config, we flush all idle sockets so they 365 // won't get re-used. 366 virtual void OnSSLConfigChanged() OVERRIDE; 367 368 class SSLConnectJobFactory : public PoolBase::ConnectJobFactory { 369 public: 370 SSLConnectJobFactory( 371 TransportClientSocketPool* transport_pool, 372 SOCKSClientSocketPool* socks_pool, 373 HttpProxyClientSocketPool* http_proxy_pool, 374 ClientSocketFactory* client_socket_factory, 375 HostResolver* host_resolver, 376 const SSLClientSocketContext& context, 377 const SSLConnectJob::GetMessengerCallback& get_messenger_callback, 378 NetLog* net_log); 379 380 virtual ~SSLConnectJobFactory(); 381 382 // ClientSocketPoolBase::ConnectJobFactory methods. 383 virtual scoped_ptr<ConnectJob> NewConnectJob( 384 const std::string& group_name, 385 const PoolBase::Request& request, 386 ConnectJob::Delegate* delegate) const OVERRIDE; 387 388 virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; 389 390 private: 391 TransportClientSocketPool* const transport_pool_; 392 SOCKSClientSocketPool* const socks_pool_; 393 HttpProxyClientSocketPool* const http_proxy_pool_; 394 ClientSocketFactory* const client_socket_factory_; 395 HostResolver* const host_resolver_; 396 const SSLClientSocketContext context_; 397 base::TimeDelta timeout_; 398 SSLConnectJob::GetMessengerCallback get_messenger_callback_; 399 NetLog* net_log_; 400 401 DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory); 402 }; 403 404 TransportClientSocketPool* const transport_pool_; 405 SOCKSClientSocketPool* const socks_pool_; 406 HttpProxyClientSocketPool* const http_proxy_pool_; 407 PoolBase base_; 408 const scoped_refptr<SSLConfigService> ssl_config_service_; 409 MessengerMap messenger_map_; 410 bool enable_ssl_connect_job_waiting_; 411 412 DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); 413 }; 414 415 } // namespace net 416 417 #endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ 418