• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "net/socket/socks_client_socket.h"

#include <string.h>

#include "base/basictypes.h"
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/compiler_specific.h"
#include "base/sys_byteorder.h"
#include "net/base/io_buffer.h"
#include "net/base/net_log.h"
#include "net/base/net_util.h"
#include "net/socket/client_socket_handle.h"
16 
17 namespace net {
18 
19 // Every SOCKS server requests a user-id from the client. It is optional
20 // and we send an empty string.
21 static const char kEmptyUserId[] = "";
22 
23 // For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
24 static const unsigned int kWriteHeaderSize = 8;
25 
26 // For SOCKS4 the server sends 8 bytes for acknowledgement.
27 static const unsigned int kReadHeaderSize = 8;
28 
29 // Server Response codes for SOCKS.
30 static const uint8 kServerResponseOk  = 0x5A;
31 static const uint8 kServerResponseRejected = 0x5B;
32 static const uint8 kServerResponseNotReachable = 0x5C;
33 static const uint8 kServerResponseMismatchedUserId = 0x5D;
34 
35 static const uint8 kSOCKSVersion4 = 0x04;
36 static const uint8 kSOCKSStreamRequest = 0x01;
37 
38 // A struct holding the essential details of the SOCKS4 Server Request.
39 // The port in the header is stored in network byte order.
40 struct SOCKS4ServerRequest {
41   uint8 version;
42   uint8 command;
43   uint16 nw_port;
44   uint8 ip[4];
45 };
46 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
47                socks4_server_request_struct_wrong_size);
48 
49 // A struct holding details of the SOCKS4 Server Response.
50 struct SOCKS4ServerResponse {
51   uint8 reserved_null;
52   uint8 code;
53   uint16 port;
54   uint8 ip[4];
55 };
56 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
57                socks4_server_response_struct_wrong_size);
58 
SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,const HostResolver::RequestInfo & req_info,RequestPriority priority,HostResolver * host_resolver)59 SOCKSClientSocket::SOCKSClientSocket(
60     scoped_ptr<ClientSocketHandle> transport_socket,
61     const HostResolver::RequestInfo& req_info,
62     RequestPriority priority,
63     HostResolver* host_resolver)
64     : transport_(transport_socket.Pass()),
65       next_state_(STATE_NONE),
66       completed_handshake_(false),
67       bytes_sent_(0),
68       bytes_received_(0),
69       was_ever_used_(false),
70       host_resolver_(host_resolver),
71       host_request_info_(req_info),
72       priority_(priority),
73       net_log_(transport_->socket()->NetLog()) {}
74 
~SOCKSClientSocket()75 SOCKSClientSocket::~SOCKSClientSocket() {
76   Disconnect();
77 }
78 
Connect(const CompletionCallback & callback)79 int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
80   DCHECK(transport_.get());
81   DCHECK(transport_->socket());
82   DCHECK_EQ(STATE_NONE, next_state_);
83   DCHECK(user_callback_.is_null());
84 
85   // If already connected, then just return OK.
86   if (completed_handshake_)
87     return OK;
88 
89   next_state_ = STATE_RESOLVE_HOST;
90 
91   net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
92 
93   int rv = DoLoop(OK);
94   if (rv == ERR_IO_PENDING) {
95     user_callback_ = callback;
96   } else {
97     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
98   }
99   return rv;
100 }
101 
Disconnect()102 void SOCKSClientSocket::Disconnect() {
103   completed_handshake_ = false;
104   host_resolver_.Cancel();
105   transport_->socket()->Disconnect();
106 
107   // Reset other states to make sure they aren't mistakenly used later.
108   // These are the states initialized by Connect().
109   next_state_ = STATE_NONE;
110   user_callback_.Reset();
111 }
112 
IsConnected() const113 bool SOCKSClientSocket::IsConnected() const {
114   return completed_handshake_ && transport_->socket()->IsConnected();
115 }
116 
IsConnectedAndIdle() const117 bool SOCKSClientSocket::IsConnectedAndIdle() const {
118   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
119 }
120 
NetLog() const121 const BoundNetLog& SOCKSClientSocket::NetLog() const {
122   return net_log_;
123 }
124 
SetSubresourceSpeculation()125 void SOCKSClientSocket::SetSubresourceSpeculation() {
126   if (transport_.get() && transport_->socket()) {
127     transport_->socket()->SetSubresourceSpeculation();
128   } else {
129     NOTREACHED();
130   }
131 }
132 
SetOmniboxSpeculation()133 void SOCKSClientSocket::SetOmniboxSpeculation() {
134   if (transport_.get() && transport_->socket()) {
135     transport_->socket()->SetOmniboxSpeculation();
136   } else {
137     NOTREACHED();
138   }
139 }
140 
WasEverUsed() const141 bool SOCKSClientSocket::WasEverUsed() const {
142   return was_ever_used_;
143 }
144 
UsingTCPFastOpen() const145 bool SOCKSClientSocket::UsingTCPFastOpen() const {
146   if (transport_.get() && transport_->socket()) {
147     return transport_->socket()->UsingTCPFastOpen();
148   }
149   NOTREACHED();
150   return false;
151 }
152 
WasNpnNegotiated() const153 bool SOCKSClientSocket::WasNpnNegotiated() const {
154   if (transport_.get() && transport_->socket()) {
155     return transport_->socket()->WasNpnNegotiated();
156   }
157   NOTREACHED();
158   return false;
159 }
160 
GetNegotiatedProtocol() const161 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
162   if (transport_.get() && transport_->socket()) {
163     return transport_->socket()->GetNegotiatedProtocol();
164   }
165   NOTREACHED();
166   return kProtoUnknown;
167 }
168 
GetSSLInfo(SSLInfo * ssl_info)169 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
170   if (transport_.get() && transport_->socket()) {
171     return transport_->socket()->GetSSLInfo(ssl_info);
172   }
173   NOTREACHED();
174   return false;
175 
176 }
177 
178 // Read is called by the transport layer above to read. This can only be done
179 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,const CompletionCallback & callback)180 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
181                             const CompletionCallback& callback) {
182   DCHECK(completed_handshake_);
183   DCHECK_EQ(STATE_NONE, next_state_);
184   DCHECK(user_callback_.is_null());
185   DCHECK(!callback.is_null());
186 
187   int rv = transport_->socket()->Read(
188       buf, buf_len,
189       base::Bind(&SOCKSClientSocket::OnReadWriteComplete,
190                  base::Unretained(this), callback));
191   if (rv > 0)
192     was_ever_used_ = true;
193   return rv;
194 }
195 
196 // Write is called by the transport layer. This can only be done if the
197 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,const CompletionCallback & callback)198 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
199                              const CompletionCallback& callback) {
200   DCHECK(completed_handshake_);
201   DCHECK_EQ(STATE_NONE, next_state_);
202   DCHECK(user_callback_.is_null());
203   DCHECK(!callback.is_null());
204 
205   int rv = transport_->socket()->Write(
206       buf, buf_len,
207       base::Bind(&SOCKSClientSocket::OnReadWriteComplete,
208                  base::Unretained(this), callback));
209   if (rv > 0)
210     was_ever_used_ = true;
211   return rv;
212 }
213 
SetReceiveBufferSize(int32 size)214 int SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
215   return transport_->socket()->SetReceiveBufferSize(size);
216 }
217 
SetSendBufferSize(int32 size)218 int SOCKSClientSocket::SetSendBufferSize(int32 size) {
219   return transport_->socket()->SetSendBufferSize(size);
220 }
221 
DoCallback(int result)222 void SOCKSClientSocket::DoCallback(int result) {
223   DCHECK_NE(ERR_IO_PENDING, result);
224   DCHECK(!user_callback_.is_null());
225 
226   // Since Run() may result in Read being called,
227   // clear user_callback_ up front.
228   DVLOG(1) << "Finished setting up SOCKS handshake";
229   base::ResetAndReturn(&user_callback_).Run(result);
230 }
231 
OnIOComplete(int result)232 void SOCKSClientSocket::OnIOComplete(int result) {
233   DCHECK_NE(STATE_NONE, next_state_);
234   int rv = DoLoop(result);
235   if (rv != ERR_IO_PENDING) {
236     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
237     DoCallback(rv);
238   }
239 }
240 
OnReadWriteComplete(const CompletionCallback & callback,int result)241 void SOCKSClientSocket::OnReadWriteComplete(const CompletionCallback& callback,
242                                             int result) {
243   DCHECK_NE(ERR_IO_PENDING, result);
244   DCHECK(!callback.is_null());
245 
246   if (result > 0)
247     was_ever_used_ = true;
248   callback.Run(result);
249 }
250 
DoLoop(int last_io_result)251 int SOCKSClientSocket::DoLoop(int last_io_result) {
252   DCHECK_NE(next_state_, STATE_NONE);
253   int rv = last_io_result;
254   do {
255     State state = next_state_;
256     next_state_ = STATE_NONE;
257     switch (state) {
258       case STATE_RESOLVE_HOST:
259         DCHECK_EQ(OK, rv);
260         rv = DoResolveHost();
261         break;
262       case STATE_RESOLVE_HOST_COMPLETE:
263         rv = DoResolveHostComplete(rv);
264         break;
265       case STATE_HANDSHAKE_WRITE:
266         DCHECK_EQ(OK, rv);
267         rv = DoHandshakeWrite();
268         break;
269       case STATE_HANDSHAKE_WRITE_COMPLETE:
270         rv = DoHandshakeWriteComplete(rv);
271         break;
272       case STATE_HANDSHAKE_READ:
273         DCHECK_EQ(OK, rv);
274         rv = DoHandshakeRead();
275         break;
276       case STATE_HANDSHAKE_READ_COMPLETE:
277         rv = DoHandshakeReadComplete(rv);
278         break;
279       default:
280         NOTREACHED() << "bad state";
281         rv = ERR_UNEXPECTED;
282         break;
283     }
284   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
285   return rv;
286 }
287 
DoResolveHost()288 int SOCKSClientSocket::DoResolveHost() {
289   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
290   // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
291   // addresses for the target host.
292   host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
293   return host_resolver_.Resolve(
294       host_request_info_,
295       priority_,
296       &addresses_,
297       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
298       net_log_);
299 }
300 
DoResolveHostComplete(int result)301 int SOCKSClientSocket::DoResolveHostComplete(int result) {
302   if (result != OK) {
303     // Resolving the hostname failed; fail the request rather than automatically
304     // falling back to SOCKS4a (since it can be confusing to see invalid IP
305     // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
306     return result;
307   }
308 
309   next_state_ = STATE_HANDSHAKE_WRITE;
310   return OK;
311 }
312 
313 // Builds the buffer that is to be sent to the server.
BuildHandshakeWriteBuffer() const314 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
315   SOCKS4ServerRequest request;
316   request.version = kSOCKSVersion4;
317   request.command = kSOCKSStreamRequest;
318   request.nw_port = base::HostToNet16(host_request_info_.port());
319 
320   DCHECK(!addresses_.empty());
321   const IPEndPoint& endpoint = addresses_.front();
322 
323   // We disabled IPv6 results when resolving the hostname, so none of the
324   // results in the list will be IPv6.
325   // TODO(eroman): we only ever use the first address in the list. It would be
326   //               more robust to try all the IP addresses we have before
327   //               failing the connect attempt.
328   CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
329   CHECK_LE(endpoint.address().size(), sizeof(request.ip));
330   memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
331 
332   DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
333 
334   std::string handshake_data(reinterpret_cast<char*>(&request),
335                              sizeof(request));
336   handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
337 
338   return handshake_data;
339 }
340 
341 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()342 int SOCKSClientSocket::DoHandshakeWrite() {
343   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
344 
345   if (buffer_.empty()) {
346     buffer_ = BuildHandshakeWriteBuffer();
347     bytes_sent_ = 0;
348   }
349 
350   int handshake_buf_len = buffer_.size() - bytes_sent_;
351   DCHECK_GT(handshake_buf_len, 0);
352   handshake_buf_ = new IOBuffer(handshake_buf_len);
353   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
354          handshake_buf_len);
355   return transport_->socket()->Write(
356       handshake_buf_.get(),
357       handshake_buf_len,
358       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
359 }
360 
DoHandshakeWriteComplete(int result)361 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
362   if (result < 0)
363     return result;
364 
365   // We ignore the case when result is 0, since the underlying Write
366   // may return spurious writes while waiting on the socket.
367 
368   bytes_sent_ += result;
369   if (bytes_sent_ == buffer_.size()) {
370     next_state_ = STATE_HANDSHAKE_READ;
371     buffer_.clear();
372   } else if (bytes_sent_ < buffer_.size()) {
373     next_state_ = STATE_HANDSHAKE_WRITE;
374   } else {
375     return ERR_UNEXPECTED;
376   }
377 
378   return OK;
379 }
380 
DoHandshakeRead()381 int SOCKSClientSocket::DoHandshakeRead() {
382   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
383 
384   if (buffer_.empty()) {
385     bytes_received_ = 0;
386   }
387 
388   int handshake_buf_len = kReadHeaderSize - bytes_received_;
389   handshake_buf_ = new IOBuffer(handshake_buf_len);
390   return transport_->socket()->Read(
391       handshake_buf_.get(),
392       handshake_buf_len,
393       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
394 }
395 
DoHandshakeReadComplete(int result)396 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
397   if (result < 0)
398     return result;
399 
400   // The underlying socket closed unexpectedly.
401   if (result == 0)
402     return ERR_CONNECTION_CLOSED;
403 
404   if (bytes_received_ + result > kReadHeaderSize) {
405     // TODO(eroman): Describe failure in NetLog.
406     return ERR_SOCKS_CONNECTION_FAILED;
407   }
408 
409   buffer_.append(handshake_buf_->data(), result);
410   bytes_received_ += result;
411   if (bytes_received_ < kReadHeaderSize) {
412     next_state_ = STATE_HANDSHAKE_READ;
413     return OK;
414   }
415 
416   const SOCKS4ServerResponse* response =
417       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
418 
419   if (response->reserved_null != 0x00) {
420     LOG(ERROR) << "Unknown response from SOCKS server.";
421     return ERR_SOCKS_CONNECTION_FAILED;
422   }
423 
424   switch (response->code) {
425     case kServerResponseOk:
426       completed_handshake_ = true;
427       return OK;
428     case kServerResponseRejected:
429       LOG(ERROR) << "SOCKS request rejected or failed";
430       return ERR_SOCKS_CONNECTION_FAILED;
431     case kServerResponseNotReachable:
432       LOG(ERROR) << "SOCKS request failed because client is not running "
433                  << "identd (or not reachable from the server)";
434       return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
435     case kServerResponseMismatchedUserId:
436       LOG(ERROR) << "SOCKS request failed because client's identd could "
437                  << "not confirm the user ID string in the request";
438       return ERR_SOCKS_CONNECTION_FAILED;
439     default:
440       LOG(ERROR) << "SOCKS server sent unknown response";
441       return ERR_SOCKS_CONNECTION_FAILED;
442   }
443 
444   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
445 }
446 
GetPeerAddress(IPEndPoint * address) const447 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
448   return transport_->socket()->GetPeerAddress(address);
449 }
450 
GetLocalAddress(IPEndPoint * address) const451 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
452   return transport_->socket()->GetLocalAddress(address);
453 }
454 
455 }  // namespace net
456