1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks_client_socket.h"
6
7 #include "base/basictypes.h"
8 #include "base/compiler_specific.h"
9 #include "net/base/io_buffer.h"
10 #include "net/base/net_log.h"
11 #include "net/base/net_util.h"
12 #include "net/base/sys_addrinfo.h"
13 #include "net/socket/client_socket_handle.h"
14
15 namespace net {
16
17 // Every SOCKS server requests a user-id from the client. It is optional
18 // and we send an empty string.
19 static const char kEmptyUserId[] = "";
20
21 // For SOCKS4, the client sends 8 bytes plus the size of the user-id.
22 static const unsigned int kWriteHeaderSize = 8;
23
24 // For SOCKS4 the server sends 8 bytes for acknowledgement.
25 static const unsigned int kReadHeaderSize = 8;
26
27 // Server Response codes for SOCKS.
28 static const uint8 kServerResponseOk = 0x5A;
29 static const uint8 kServerResponseRejected = 0x5B;
30 static const uint8 kServerResponseNotReachable = 0x5C;
31 static const uint8 kServerResponseMismatchedUserId = 0x5D;
32
33 static const uint8 kSOCKSVersion4 = 0x04;
34 static const uint8 kSOCKSStreamRequest = 0x01;
35
36 // A struct holding the essential details of the SOCKS4 Server Request.
37 // The port in the header is stored in network byte order.
38 struct SOCKS4ServerRequest {
39 uint8 version;
40 uint8 command;
41 uint16 nw_port;
42 uint8 ip[4];
43 };
44 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
45 socks4_server_request_struct_wrong_size);
46
47 // A struct holding details of the SOCKS4 Server Response.
48 struct SOCKS4ServerResponse {
49 uint8 reserved_null;
50 uint8 code;
51 uint16 port;
52 uint8 ip[4];
53 };
54 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
55 socks4_server_response_struct_wrong_size);
56
SOCKSClientSocket(ClientSocketHandle * transport_socket,const HostResolver::RequestInfo & req_info,HostResolver * host_resolver)57 SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket,
58 const HostResolver::RequestInfo& req_info,
59 HostResolver* host_resolver)
60 : ALLOW_THIS_IN_INITIALIZER_LIST(
61 io_callback_(this, &SOCKSClientSocket::OnIOComplete)),
62 transport_(transport_socket),
63 next_state_(STATE_NONE),
64 user_callback_(NULL),
65 completed_handshake_(false),
66 bytes_sent_(0),
67 bytes_received_(0),
68 host_resolver_(host_resolver),
69 host_request_info_(req_info),
70 net_log_(transport_socket->socket()->NetLog()) {
71 }
72
SOCKSClientSocket(ClientSocket * transport_socket,const HostResolver::RequestInfo & req_info,HostResolver * host_resolver)73 SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket,
74 const HostResolver::RequestInfo& req_info,
75 HostResolver* host_resolver)
76 : ALLOW_THIS_IN_INITIALIZER_LIST(
77 io_callback_(this, &SOCKSClientSocket::OnIOComplete)),
78 transport_(new ClientSocketHandle()),
79 next_state_(STATE_NONE),
80 user_callback_(NULL),
81 completed_handshake_(false),
82 bytes_sent_(0),
83 bytes_received_(0),
84 host_resolver_(host_resolver),
85 host_request_info_(req_info),
86 net_log_(transport_socket->NetLog()) {
87 transport_->set_socket(transport_socket);
88 }
89
~SOCKSClientSocket()90 SOCKSClientSocket::~SOCKSClientSocket() {
91 Disconnect();
92 }
93
94 #ifdef ANDROID
95 // TODO(kristianm): find out if Connect should block
96 #endif
Connect(CompletionCallback * callback,bool wait_for_connect,bool valid_uid,uid_t calling_uid)97 int SOCKSClientSocket::Connect(CompletionCallback* callback
98 #ifdef ANDROID
99 , bool wait_for_connect
100 , bool valid_uid
101 , uid_t calling_uid
102 #endif
103 ) {
104 DCHECK(transport_.get());
105 DCHECK(transport_->socket());
106 DCHECK_EQ(STATE_NONE, next_state_);
107 DCHECK(!user_callback_);
108
109 // If already connected, then just return OK.
110 if (completed_handshake_)
111 return OK;
112
113 next_state_ = STATE_RESOLVE_HOST;
114
115 net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT, NULL);
116
117 int rv = DoLoop(OK);
118 if (rv == ERR_IO_PENDING) {
119 user_callback_ = callback;
120 } else {
121 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
122 }
123 return rv;
124 }
125
Disconnect()126 void SOCKSClientSocket::Disconnect() {
127 completed_handshake_ = false;
128 host_resolver_.Cancel();
129 transport_->socket()->Disconnect();
130
131 // Reset other states to make sure they aren't mistakenly used later.
132 // These are the states initialized by Connect().
133 next_state_ = STATE_NONE;
134 user_callback_ = NULL;
135 }
136
IsConnected() const137 bool SOCKSClientSocket::IsConnected() const {
138 return completed_handshake_ && transport_->socket()->IsConnected();
139 }
140
IsConnectedAndIdle() const141 bool SOCKSClientSocket::IsConnectedAndIdle() const {
142 return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
143 }
144
NetLog() const145 const BoundNetLog& SOCKSClientSocket::NetLog() const {
146 return net_log_;
147 }
148
SetSubresourceSpeculation()149 void SOCKSClientSocket::SetSubresourceSpeculation() {
150 if (transport_.get() && transport_->socket()) {
151 transport_->socket()->SetSubresourceSpeculation();
152 } else {
153 NOTREACHED();
154 }
155 }
156
SetOmniboxSpeculation()157 void SOCKSClientSocket::SetOmniboxSpeculation() {
158 if (transport_.get() && transport_->socket()) {
159 transport_->socket()->SetOmniboxSpeculation();
160 } else {
161 NOTREACHED();
162 }
163 }
164
WasEverUsed() const165 bool SOCKSClientSocket::WasEverUsed() const {
166 if (transport_.get() && transport_->socket()) {
167 return transport_->socket()->WasEverUsed();
168 }
169 NOTREACHED();
170 return false;
171 }
172
UsingTCPFastOpen() const173 bool SOCKSClientSocket::UsingTCPFastOpen() const {
174 if (transport_.get() && transport_->socket()) {
175 return transport_->socket()->UsingTCPFastOpen();
176 }
177 NOTREACHED();
178 return false;
179 }
180
181
182 // Read is called by the transport layer above to read. This can only be done
183 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionCallback * callback)184 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
185 CompletionCallback* callback) {
186 DCHECK(completed_handshake_);
187 DCHECK_EQ(STATE_NONE, next_state_);
188 DCHECK(!user_callback_);
189
190 return transport_->socket()->Read(buf, buf_len, callback);
191 }
192
193 // Write is called by the transport layer. This can only be done if the
194 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionCallback * callback)195 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
196 CompletionCallback* callback) {
197 DCHECK(completed_handshake_);
198 DCHECK_EQ(STATE_NONE, next_state_);
199 DCHECK(!user_callback_);
200
201 return transport_->socket()->Write(buf, buf_len, callback);
202 }
203
SetReceiveBufferSize(int32 size)204 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
205 return transport_->socket()->SetReceiveBufferSize(size);
206 }
207
SetSendBufferSize(int32 size)208 bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
209 return transport_->socket()->SetSendBufferSize(size);
210 }
211
DoCallback(int result)212 void SOCKSClientSocket::DoCallback(int result) {
213 DCHECK_NE(ERR_IO_PENDING, result);
214 DCHECK(user_callback_);
215
216 // Since Run() may result in Read being called,
217 // clear user_callback_ up front.
218 CompletionCallback* c = user_callback_;
219 user_callback_ = NULL;
220 DVLOG(1) << "Finished setting up SOCKS handshake";
221 c->Run(result);
222 }
223
OnIOComplete(int result)224 void SOCKSClientSocket::OnIOComplete(int result) {
225 DCHECK_NE(STATE_NONE, next_state_);
226 int rv = DoLoop(result);
227 if (rv != ERR_IO_PENDING) {
228 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
229 DoCallback(rv);
230 }
231 }
232
DoLoop(int last_io_result)233 int SOCKSClientSocket::DoLoop(int last_io_result) {
234 DCHECK_NE(next_state_, STATE_NONE);
235 int rv = last_io_result;
236 do {
237 State state = next_state_;
238 next_state_ = STATE_NONE;
239 switch (state) {
240 case STATE_RESOLVE_HOST:
241 DCHECK_EQ(OK, rv);
242 rv = DoResolveHost();
243 break;
244 case STATE_RESOLVE_HOST_COMPLETE:
245 rv = DoResolveHostComplete(rv);
246 break;
247 case STATE_HANDSHAKE_WRITE:
248 DCHECK_EQ(OK, rv);
249 rv = DoHandshakeWrite();
250 break;
251 case STATE_HANDSHAKE_WRITE_COMPLETE:
252 rv = DoHandshakeWriteComplete(rv);
253 break;
254 case STATE_HANDSHAKE_READ:
255 DCHECK_EQ(OK, rv);
256 rv = DoHandshakeRead();
257 break;
258 case STATE_HANDSHAKE_READ_COMPLETE:
259 rv = DoHandshakeReadComplete(rv);
260 break;
261 default:
262 NOTREACHED() << "bad state";
263 rv = ERR_UNEXPECTED;
264 break;
265 }
266 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
267 return rv;
268 }
269
DoResolveHost()270 int SOCKSClientSocket::DoResolveHost() {
271 next_state_ = STATE_RESOLVE_HOST_COMPLETE;
272 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
273 // addresses for the target host.
274 host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
275 return host_resolver_.Resolve(
276 host_request_info_, &addresses_, &io_callback_, net_log_);
277 }
278
DoResolveHostComplete(int result)279 int SOCKSClientSocket::DoResolveHostComplete(int result) {
280 if (result != OK) {
281 // Resolving the hostname failed; fail the request rather than automatically
282 // falling back to SOCKS4a (since it can be confusing to see invalid IP
283 // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
284 return result;
285 }
286
287 next_state_ = STATE_HANDSHAKE_WRITE;
288 return OK;
289 }
290
291 // Builds the buffer that is to be sent to the server.
BuildHandshakeWriteBuffer() const292 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
293 SOCKS4ServerRequest request;
294 request.version = kSOCKSVersion4;
295 request.command = kSOCKSStreamRequest;
296 request.nw_port = htons(host_request_info_.port());
297
298 const struct addrinfo* ai = addresses_.head();
299 DCHECK(ai);
300
301 // We disabled IPv6 results when resolving the hostname, so none of the
302 // results in the list will be IPv6.
303 // TODO(eroman): we only ever use the first address in the list. It would be
304 // more robust to try all the IP addresses we have before
305 // failing the connect attempt.
306 CHECK_EQ(AF_INET, ai->ai_addr->sa_family);
307 struct sockaddr_in* ipv4_host =
308 reinterpret_cast<struct sockaddr_in*>(ai->ai_addr);
309 memcpy(&request.ip, &ipv4_host->sin_addr, sizeof(ipv4_host->sin_addr));
310
311 DVLOG(1) << "Resolved Host is : " << NetAddressToString(ai);
312
313 std::string handshake_data(reinterpret_cast<char*>(&request),
314 sizeof(request));
315 handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
316
317 return handshake_data;
318 }
319
320 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()321 int SOCKSClientSocket::DoHandshakeWrite() {
322 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
323
324 if (buffer_.empty()) {
325 buffer_ = BuildHandshakeWriteBuffer();
326 bytes_sent_ = 0;
327 }
328
329 int handshake_buf_len = buffer_.size() - bytes_sent_;
330 DCHECK_GT(handshake_buf_len, 0);
331 handshake_buf_ = new IOBuffer(handshake_buf_len);
332 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
333 handshake_buf_len);
334 return transport_->socket()->Write(handshake_buf_, handshake_buf_len,
335 &io_callback_);
336 }
337
DoHandshakeWriteComplete(int result)338 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
339 if (result < 0)
340 return result;
341
342 // We ignore the case when result is 0, since the underlying Write
343 // may return spurious writes while waiting on the socket.
344
345 bytes_sent_ += result;
346 if (bytes_sent_ == buffer_.size()) {
347 next_state_ = STATE_HANDSHAKE_READ;
348 buffer_.clear();
349 } else if (bytes_sent_ < buffer_.size()) {
350 next_state_ = STATE_HANDSHAKE_WRITE;
351 } else {
352 return ERR_UNEXPECTED;
353 }
354
355 return OK;
356 }
357
DoHandshakeRead()358 int SOCKSClientSocket::DoHandshakeRead() {
359 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
360
361 if (buffer_.empty()) {
362 bytes_received_ = 0;
363 }
364
365 int handshake_buf_len = kReadHeaderSize - bytes_received_;
366 handshake_buf_ = new IOBuffer(handshake_buf_len);
367 return transport_->socket()->Read(handshake_buf_, handshake_buf_len,
368 &io_callback_);
369 }
370
DoHandshakeReadComplete(int result)371 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
372 if (result < 0)
373 return result;
374
375 // The underlying socket closed unexpectedly.
376 if (result == 0)
377 return ERR_CONNECTION_CLOSED;
378
379 if (bytes_received_ + result > kReadHeaderSize) {
380 // TODO(eroman): Describe failure in NetLog.
381 return ERR_SOCKS_CONNECTION_FAILED;
382 }
383
384 buffer_.append(handshake_buf_->data(), result);
385 bytes_received_ += result;
386 if (bytes_received_ < kReadHeaderSize) {
387 next_state_ = STATE_HANDSHAKE_READ;
388 return OK;
389 }
390
391 const SOCKS4ServerResponse* response =
392 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
393
394 if (response->reserved_null != 0x00) {
395 LOG(ERROR) << "Unknown response from SOCKS server.";
396 return ERR_SOCKS_CONNECTION_FAILED;
397 }
398
399 switch (response->code) {
400 case kServerResponseOk:
401 completed_handshake_ = true;
402 return OK;
403 case kServerResponseRejected:
404 LOG(ERROR) << "SOCKS request rejected or failed";
405 return ERR_SOCKS_CONNECTION_FAILED;
406 case kServerResponseNotReachable:
407 LOG(ERROR) << "SOCKS request failed because client is not running "
408 << "identd (or not reachable from the server)";
409 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
410 case kServerResponseMismatchedUserId:
411 LOG(ERROR) << "SOCKS request failed because client's identd could "
412 << "not confirm the user ID string in the request";
413 return ERR_SOCKS_CONNECTION_FAILED;
414 default:
415 LOG(ERROR) << "SOCKS server sent unknown response";
416 return ERR_SOCKS_CONNECTION_FAILED;
417 }
418
419 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
420 }
421
GetPeerAddress(AddressList * address) const422 int SOCKSClientSocket::GetPeerAddress(AddressList* address) const {
423 return transport_->socket()->GetPeerAddress(address);
424 }
425
GetLocalAddress(IPEndPoint * address) const426 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
427 return transport_->socket()->GetLocalAddress(address);
428 }
429
430 } // namespace net
431