• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks_client_socket_pool.h"
6 
7 #include "base/bind.h"
8 #include "base/bind_helpers.h"
9 #include "base/time/time.h"
10 #include "base/values.h"
11 #include "net/base/net_errors.h"
12 #include "net/socket/client_socket_factory.h"
13 #include "net/socket/client_socket_handle.h"
14 #include "net/socket/client_socket_pool_base.h"
15 #include "net/socket/socks5_client_socket.h"
16 #include "net/socket/socks_client_socket.h"
17 #include "net/socket/transport_client_socket_pool.h"
18 
19 namespace net {
20 
SOCKSSocketParams(const scoped_refptr<TransportSocketParams> & proxy_server,bool socks_v5,const HostPortPair & host_port_pair)21 SOCKSSocketParams::SOCKSSocketParams(
22     const scoped_refptr<TransportSocketParams>& proxy_server,
23     bool socks_v5,
24     const HostPortPair& host_port_pair)
25     : transport_params_(proxy_server),
26       destination_(host_port_pair),
27       socks_v5_(socks_v5) {
28   if (transport_params_.get())
29     ignore_limits_ = transport_params_->ignore_limits();
30   else
31     ignore_limits_ = false;
32 }
33 
~SOCKSSocketParams()34 SOCKSSocketParams::~SOCKSSocketParams() {}
35 
namespace {

// Time budget, in seconds, granted to a SOCKSConnectJob for the SOCKS
// handshake. It is added on top of the transport socket's own connect
// timeout rather than replacing it.
const int kSOCKSConnectJobTimeoutInSeconds = 30;

}  // namespace
39 
SOCKSConnectJob(const std::string & group_name,RequestPriority priority,const scoped_refptr<SOCKSSocketParams> & socks_params,const base::TimeDelta & timeout_duration,TransportClientSocketPool * transport_pool,HostResolver * host_resolver,Delegate * delegate,NetLog * net_log)40 SOCKSConnectJob::SOCKSConnectJob(
41     const std::string& group_name,
42     RequestPriority priority,
43     const scoped_refptr<SOCKSSocketParams>& socks_params,
44     const base::TimeDelta& timeout_duration,
45     TransportClientSocketPool* transport_pool,
46     HostResolver* host_resolver,
47     Delegate* delegate,
48     NetLog* net_log)
49     : ConnectJob(group_name, timeout_duration, priority, delegate,
50                  BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
51       socks_params_(socks_params),
52       transport_pool_(transport_pool),
53       resolver_(host_resolver),
54       callback_(base::Bind(&SOCKSConnectJob::OnIOComplete,
55                            base::Unretained(this))) {
56 }
57 
~SOCKSConnectJob()58 SOCKSConnectJob::~SOCKSConnectJob() {
59   // We don't worry about cancelling the tcp socket since the destructor in
60   // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
61   // it.
62 }
63 
GetLoadState() const64 LoadState SOCKSConnectJob::GetLoadState() const {
65   switch (next_state_) {
66     case STATE_TRANSPORT_CONNECT:
67     case STATE_TRANSPORT_CONNECT_COMPLETE:
68       return transport_socket_handle_->GetLoadState();
69     case STATE_SOCKS_CONNECT:
70     case STATE_SOCKS_CONNECT_COMPLETE:
71       return LOAD_STATE_CONNECTING;
72     default:
73       NOTREACHED();
74       return LOAD_STATE_IDLE;
75   }
76 }
77 
OnIOComplete(int result)78 void SOCKSConnectJob::OnIOComplete(int result) {
79   int rv = DoLoop(result);
80   if (rv != ERR_IO_PENDING)
81     NotifyDelegateOfCompletion(rv);  // Deletes |this|
82 }
83 
DoLoop(int result)84 int SOCKSConnectJob::DoLoop(int result) {
85   DCHECK_NE(next_state_, STATE_NONE);
86 
87   int rv = result;
88   do {
89     State state = next_state_;
90     next_state_ = STATE_NONE;
91     switch (state) {
92       case STATE_TRANSPORT_CONNECT:
93         DCHECK_EQ(OK, rv);
94         rv = DoTransportConnect();
95         break;
96       case STATE_TRANSPORT_CONNECT_COMPLETE:
97         rv = DoTransportConnectComplete(rv);
98         break;
99       case STATE_SOCKS_CONNECT:
100         DCHECK_EQ(OK, rv);
101         rv = DoSOCKSConnect();
102         break;
103       case STATE_SOCKS_CONNECT_COMPLETE:
104         rv = DoSOCKSConnectComplete(rv);
105         break;
106       default:
107         NOTREACHED() << "bad state";
108         rv = ERR_FAILED;
109         break;
110     }
111   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
112 
113   return rv;
114 }
115 
DoTransportConnect()116 int SOCKSConnectJob::DoTransportConnect() {
117   next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
118   transport_socket_handle_.reset(new ClientSocketHandle());
119   return transport_socket_handle_->Init(group_name(),
120                                         socks_params_->transport_params(),
121                                         priority(),
122                                         callback_,
123                                         transport_pool_,
124                                         net_log());
125 }
126 
DoTransportConnectComplete(int result)127 int SOCKSConnectJob::DoTransportConnectComplete(int result) {
128   if (result != OK)
129     return ERR_PROXY_CONNECTION_FAILED;
130 
131   // Reset the timer to just the length of time allowed for SOCKS handshake
132   // so that a fast TCP connection plus a slow SOCKS failure doesn't take
133   // longer to timeout than it should.
134   ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds));
135   next_state_ = STATE_SOCKS_CONNECT;
136   return result;
137 }
138 
DoSOCKSConnect()139 int SOCKSConnectJob::DoSOCKSConnect() {
140   next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
141 
142   // Add a SOCKS connection on top of the tcp socket.
143   if (socks_params_->is_socks_v5()) {
144     socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(),
145                                          socks_params_->destination()));
146   } else {
147     socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(),
148                                         socks_params_->destination(),
149                                         priority(),
150                                         resolver_));
151   }
152   return socket_->Connect(
153       base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this)));
154 }
155 
DoSOCKSConnectComplete(int result)156 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
157   if (result != OK) {
158     socket_->Disconnect();
159     return result;
160   }
161 
162   SetSocket(socket_.Pass());
163   return result;
164 }
165 
ConnectInternal()166 int SOCKSConnectJob::ConnectInternal() {
167   next_state_ = STATE_TRANSPORT_CONNECT;
168   return DoLoop(OK);
169 }
170 
171 scoped_ptr<ConnectJob>
NewConnectJob(const std::string & group_name,const PoolBase::Request & request,ConnectJob::Delegate * delegate) const172 SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
173     const std::string& group_name,
174     const PoolBase::Request& request,
175     ConnectJob::Delegate* delegate) const {
176   return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name,
177                                                     request.priority(),
178                                                     request.params(),
179                                                     ConnectionTimeout(),
180                                                     transport_pool_,
181                                                     host_resolver_,
182                                                     delegate,
183                                                     net_log_));
184 }
185 
186 base::TimeDelta
ConnectionTimeout() const187 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
188   return transport_pool_->ConnectionTimeout() +
189       base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds);
190 }
191 
SOCKSClientSocketPool(int max_sockets,int max_sockets_per_group,ClientSocketPoolHistograms * histograms,HostResolver * host_resolver,TransportClientSocketPool * transport_pool,NetLog * net_log)192 SOCKSClientSocketPool::SOCKSClientSocketPool(
193     int max_sockets,
194     int max_sockets_per_group,
195     ClientSocketPoolHistograms* histograms,
196     HostResolver* host_resolver,
197     TransportClientSocketPool* transport_pool,
198     NetLog* net_log)
199     : transport_pool_(transport_pool),
200       base_(this, max_sockets, max_sockets_per_group, histograms,
201             ClientSocketPool::unused_idle_socket_timeout(),
202             ClientSocketPool::used_idle_socket_timeout(),
203             new SOCKSConnectJobFactory(transport_pool,
204                                        host_resolver,
205                                        net_log)) {
206   // We should always have a |transport_pool_| except in unit tests.
207   if (transport_pool_)
208     base_.AddLowerLayeredPool(transport_pool_);
209 }
210 
~SOCKSClientSocketPool()211 SOCKSClientSocketPool::~SOCKSClientSocketPool() {
212 }
213 
RequestSocket(const std::string & group_name,const void * socket_params,RequestPriority priority,ClientSocketHandle * handle,const CompletionCallback & callback,const BoundNetLog & net_log)214 int SOCKSClientSocketPool::RequestSocket(
215     const std::string& group_name, const void* socket_params,
216     RequestPriority priority, ClientSocketHandle* handle,
217     const CompletionCallback& callback, const BoundNetLog& net_log) {
218   const scoped_refptr<SOCKSSocketParams>* casted_socket_params =
219       static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params);
220 
221   return base_.RequestSocket(group_name, *casted_socket_params, priority,
222                              handle, callback, net_log);
223 }
224 
RequestSockets(const std::string & group_name,const void * params,int num_sockets,const BoundNetLog & net_log)225 void SOCKSClientSocketPool::RequestSockets(
226     const std::string& group_name,
227     const void* params,
228     int num_sockets,
229     const BoundNetLog& net_log) {
230   const scoped_refptr<SOCKSSocketParams>* casted_params =
231       static_cast<const scoped_refptr<SOCKSSocketParams>*>(params);
232 
233   base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
234 }
235 
CancelRequest(const std::string & group_name,ClientSocketHandle * handle)236 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
237                                           ClientSocketHandle* handle) {
238   base_.CancelRequest(group_name, handle);
239 }
240 
ReleaseSocket(const std::string & group_name,scoped_ptr<StreamSocket> socket,int id)241 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
242                                           scoped_ptr<StreamSocket> socket,
243                                           int id) {
244   base_.ReleaseSocket(group_name, socket.Pass(), id);
245 }
246 
FlushWithError(int error)247 void SOCKSClientSocketPool::FlushWithError(int error) {
248   base_.FlushWithError(error);
249 }
250 
CloseIdleSockets()251 void SOCKSClientSocketPool::CloseIdleSockets() {
252   base_.CloseIdleSockets();
253 }
254 
IdleSocketCount() const255 int SOCKSClientSocketPool::IdleSocketCount() const {
256   return base_.idle_socket_count();
257 }
258 
IdleSocketCountInGroup(const std::string & group_name) const259 int SOCKSClientSocketPool::IdleSocketCountInGroup(
260     const std::string& group_name) const {
261   return base_.IdleSocketCountInGroup(group_name);
262 }
263 
GetLoadState(const std::string & group_name,const ClientSocketHandle * handle) const264 LoadState SOCKSClientSocketPool::GetLoadState(
265     const std::string& group_name, const ClientSocketHandle* handle) const {
266   return base_.GetLoadState(group_name, handle);
267 }
268 
GetInfoAsValue(const std::string & name,const std::string & type,bool include_nested_pools) const269 base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
270     const std::string& name,
271     const std::string& type,
272     bool include_nested_pools) const {
273   base::DictionaryValue* dict = base_.GetInfoAsValue(name, type);
274   if (include_nested_pools) {
275     base::ListValue* list = new base::ListValue();
276     list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
277                                                  "transport_socket_pool",
278                                                  false));
279     dict->Set("nested_pools", list);
280   }
281   return dict;
282 }
283 
ConnectionTimeout() const284 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const {
285   return base_.ConnectionTimeout();
286 }
287 
histograms() const288 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
289   return base_.histograms();
290 };
291 
IsStalled() const292 bool SOCKSClientSocketPool::IsStalled() const {
293   return base_.IsStalled();
294 }
295 
AddHigherLayeredPool(HigherLayeredPool * higher_pool)296 void SOCKSClientSocketPool::AddHigherLayeredPool(
297     HigherLayeredPool* higher_pool) {
298   base_.AddHigherLayeredPool(higher_pool);
299 }
300 
RemoveHigherLayeredPool(HigherLayeredPool * higher_pool)301 void SOCKSClientSocketPool::RemoveHigherLayeredPool(
302     HigherLayeredPool* higher_pool) {
303   base_.RemoveHigherLayeredPool(higher_pool);
304 }
305 
CloseOneIdleConnection()306 bool SOCKSClientSocketPool::CloseOneIdleConnection() {
307   if (base_.CloseOneIdleSocket())
308     return true;
309   return base_.CloseOneIdleConnectionInHigherLayeredPool();
310 }
311 
312 }  // namespace net
313