1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks_client_socket.h"
6
7 #include <utility>
8
9 #include "base/compiler_specific.h"
10 #include "base/functional/bind.h"
11 #include "base/functional/callback_helpers.h"
12 #include "base/sys_byteorder.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/dns/public/dns_query_type.h"
16 #include "net/dns/public/secure_dns_policy.h"
17 #include "net/log/net_log.h"
18 #include "net/log/net_log_event_type.h"
19 #include "net/traffic_annotation/network_traffic_annotation.h"
20
21 namespace net {
22
// A SOCKS4 request carries a NUL-terminated user-id. The field may be empty,
// so this client sends only the terminating NUL byte.
constexpr char kEmptyUserId[] = "";

// For SOCKS4, the client sends a fixed 8-byte header followed by the user-id.
constexpr unsigned int kWriteHeaderSize = 8;

// For SOCKS4, the server acknowledges with a fixed 8-byte reply.
constexpr unsigned int kReadHeaderSize = 8;

// Status values the SOCKS4 server may place in the reply's second byte.
constexpr uint8_t kServerResponseOk = 0x5A;
constexpr uint8_t kServerResponseRejected = 0x5B;
constexpr uint8_t kServerResponseNotReachable = 0x5C;
constexpr uint8_t kServerResponseMismatchedUserId = 0x5D;

// Request header fields written by this client.
constexpr uint8_t kSOCKSVersion4 = 0x04;
constexpr uint8_t kSOCKSStreamRequest = 0x01;
41
// A struct holding the essential details of the SOCKS4 Server Request.
// The port in the header is stored in network byte order.
//
// BuildHandshakeWriteBuffer() copies an instance of this struct byte-for-byte
// into the outgoing handshake, so field order and widths ARE the wire format.
// Do not reorder or resize fields; the static_assert below guards against any
// compiler-inserted padding changing the on-wire size.
struct SOCKS4ServerRequest {
  uint8_t version;   // Protocol version; set to kSOCKSVersion4.
  uint8_t command;   // Request type; set to kSOCKSStreamRequest.
  uint16_t nw_port;  // Destination port, already converted with HostToNet16.
  uint8_t ip[4];     // Destination IPv4 address bytes, copied from the
                     // resolved IPEndPoint.
};
static_assert(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
              "socks4 server request struct has incorrect size");
52
// A struct holding details of the SOCKS4 Server Response.
//
// Mirrors the fixed-size reply header read by DoHandshakeRead(); the
// static_assert below keeps it equal to kReadHeaderSize. Only |reserved_null|
// and |code| are inspected by DoHandshakeReadComplete(); |port| and |ip| are
// not read by this client.
struct SOCKS4ServerResponse {
  uint8_t reserved_null;  // Must be 0x00 or the reply is rejected.
  uint8_t code;           // Status byte, compared with kServerResponse*.
  uint16_t port;          // Unused by this client.
  uint8_t ip[4];          // Unused by this client.
};
static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
              "socks4 server response struct has incorrect size");
62
SOCKSClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkAnonymizationKey & network_anonymization_key,RequestPriority priority,HostResolver * host_resolver,SecureDnsPolicy secure_dns_policy,const NetworkTrafficAnnotationTag & traffic_annotation)63 SOCKSClientSocket::SOCKSClientSocket(
64 std::unique_ptr<StreamSocket> transport_socket,
65 const HostPortPair& destination,
66 const NetworkAnonymizationKey& network_anonymization_key,
67 RequestPriority priority,
68 HostResolver* host_resolver,
69 SecureDnsPolicy secure_dns_policy,
70 const NetworkTrafficAnnotationTag& traffic_annotation)
71 : transport_socket_(std::move(transport_socket)),
72 host_resolver_(host_resolver),
73 secure_dns_policy_(secure_dns_policy),
74 destination_(destination),
75 network_anonymization_key_(network_anonymization_key),
76 priority_(priority),
77 net_log_(transport_socket_->NetLog()),
78 traffic_annotation_(traffic_annotation) {}
79
~SOCKSClientSocket()80 SOCKSClientSocket::~SOCKSClientSocket() {
81 Disconnect();
82 }
83
Connect(CompletionOnceCallback callback)84 int SOCKSClientSocket::Connect(CompletionOnceCallback callback) {
85 DCHECK(transport_socket_);
86 DCHECK_EQ(STATE_NONE, next_state_);
87 DCHECK(user_callback_.is_null());
88
89 // If already connected, then just return OK.
90 if (completed_handshake_)
91 return OK;
92
93 next_state_ = STATE_RESOLVE_HOST;
94
95 net_log_.BeginEvent(NetLogEventType::SOCKS_CONNECT);
96
97 int rv = DoLoop(OK);
98 if (rv == ERR_IO_PENDING) {
99 user_callback_ = std::move(callback);
100 } else {
101 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
102 }
103 return rv;
104 }
105
Disconnect()106 void SOCKSClientSocket::Disconnect() {
107 completed_handshake_ = false;
108 resolve_host_request_.reset();
109 transport_socket_->Disconnect();
110
111 // Reset other states to make sure they aren't mistakenly used later.
112 // These are the states initialized by Connect().
113 next_state_ = STATE_NONE;
114 user_callback_.Reset();
115 }
116
IsConnected() const117 bool SOCKSClientSocket::IsConnected() const {
118 return completed_handshake_ && transport_socket_->IsConnected();
119 }
120
IsConnectedAndIdle() const121 bool SOCKSClientSocket::IsConnectedAndIdle() const {
122 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
123 }
124
NetLog() const125 const NetLogWithSource& SOCKSClientSocket::NetLog() const {
126 return net_log_;
127 }
128
WasEverUsed() const129 bool SOCKSClientSocket::WasEverUsed() const {
130 return was_ever_used_;
131 }
132
WasAlpnNegotiated() const133 bool SOCKSClientSocket::WasAlpnNegotiated() const {
134 if (transport_socket_)
135 return transport_socket_->WasAlpnNegotiated();
136 NOTREACHED();
137 return false;
138 }
139
GetNegotiatedProtocol() const140 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
141 if (transport_socket_)
142 return transport_socket_->GetNegotiatedProtocol();
143 NOTREACHED();
144 return kProtoUnknown;
145 }
146
GetSSLInfo(SSLInfo * ssl_info)147 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
148 if (transport_socket_)
149 return transport_socket_->GetSSLInfo(ssl_info);
150 NOTREACHED();
151 return false;
152 }
153
GetTotalReceivedBytes() const154 int64_t SOCKSClientSocket::GetTotalReceivedBytes() const {
155 return transport_socket_->GetTotalReceivedBytes();
156 }
157
ApplySocketTag(const SocketTag & tag)158 void SOCKSClientSocket::ApplySocketTag(const SocketTag& tag) {
159 return transport_socket_->ApplySocketTag(tag);
160 }
161
162 // Read is called by the transport layer above to read. This can only be done
163 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)164 int SOCKSClientSocket::Read(IOBuffer* buf,
165 int buf_len,
166 CompletionOnceCallback callback) {
167 DCHECK(completed_handshake_);
168 DCHECK_EQ(STATE_NONE, next_state_);
169 DCHECK(user_callback_.is_null());
170 DCHECK(!callback.is_null());
171
172 int rv = transport_socket_->Read(
173 buf, buf_len,
174 base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
175 base::Unretained(this), std::move(callback)));
176 if (rv > 0)
177 was_ever_used_ = true;
178 return rv;
179 }
180
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)181 int SOCKSClientSocket::ReadIfReady(IOBuffer* buf,
182 int buf_len,
183 CompletionOnceCallback callback) {
184 DCHECK(completed_handshake_);
185 DCHECK_EQ(STATE_NONE, next_state_);
186 DCHECK(user_callback_.is_null());
187 DCHECK(!callback.is_null());
188
189 // Pass |callback| directly instead of wrapping it with OnReadWriteComplete.
190 // This is to avoid setting |was_ever_used_| unless data is actually read.
191 int rv = transport_socket_->ReadIfReady(buf, buf_len, std::move(callback));
192 if (rv > 0)
193 was_ever_used_ = true;
194 return rv;
195 }
196
CancelReadIfReady()197 int SOCKSClientSocket::CancelReadIfReady() {
198 return transport_socket_->CancelReadIfReady();
199 }
200
201 // Write is called by the transport layer. This can only be done if the
202 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)203 int SOCKSClientSocket::Write(
204 IOBuffer* buf,
205 int buf_len,
206 CompletionOnceCallback callback,
207 const NetworkTrafficAnnotationTag& traffic_annotation) {
208 DCHECK(completed_handshake_);
209 DCHECK_EQ(STATE_NONE, next_state_);
210 DCHECK(user_callback_.is_null());
211 DCHECK(!callback.is_null());
212
213 int rv = transport_socket_->Write(
214 buf, buf_len,
215 base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
216 base::Unretained(this), std::move(callback)),
217 traffic_annotation);
218 if (rv > 0)
219 was_ever_used_ = true;
220 return rv;
221 }
222
SetReceiveBufferSize(int32_t size)223 int SOCKSClientSocket::SetReceiveBufferSize(int32_t size) {
224 return transport_socket_->SetReceiveBufferSize(size);
225 }
226
SetSendBufferSize(int32_t size)227 int SOCKSClientSocket::SetSendBufferSize(int32_t size) {
228 return transport_socket_->SetSendBufferSize(size);
229 }
230
DoCallback(int result)231 void SOCKSClientSocket::DoCallback(int result) {
232 DCHECK_NE(ERR_IO_PENDING, result);
233 DCHECK(!user_callback_.is_null());
234
235 // Since Run() may result in Read being called,
236 // clear user_callback_ up front.
237 DVLOG(1) << "Finished setting up SOCKS handshake";
238 std::move(user_callback_).Run(result);
239 }
240
OnIOComplete(int result)241 void SOCKSClientSocket::OnIOComplete(int result) {
242 DCHECK_NE(STATE_NONE, next_state_);
243 int rv = DoLoop(result);
244 if (rv != ERR_IO_PENDING) {
245 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
246 DoCallback(rv);
247 }
248 }
249
OnReadWriteComplete(CompletionOnceCallback callback,int result)250 void SOCKSClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
251 int result) {
252 DCHECK_NE(ERR_IO_PENDING, result);
253 DCHECK(!callback.is_null());
254
255 if (result > 0)
256 was_ever_used_ = true;
257 std::move(callback).Run(result);
258 }
259
DoLoop(int last_io_result)260 int SOCKSClientSocket::DoLoop(int last_io_result) {
261 DCHECK_NE(next_state_, STATE_NONE);
262 int rv = last_io_result;
263 do {
264 State state = next_state_;
265 next_state_ = STATE_NONE;
266 switch (state) {
267 case STATE_RESOLVE_HOST:
268 DCHECK_EQ(OK, rv);
269 rv = DoResolveHost();
270 break;
271 case STATE_RESOLVE_HOST_COMPLETE:
272 rv = DoResolveHostComplete(rv);
273 break;
274 case STATE_HANDSHAKE_WRITE:
275 DCHECK_EQ(OK, rv);
276 rv = DoHandshakeWrite();
277 break;
278 case STATE_HANDSHAKE_WRITE_COMPLETE:
279 rv = DoHandshakeWriteComplete(rv);
280 break;
281 case STATE_HANDSHAKE_READ:
282 DCHECK_EQ(OK, rv);
283 rv = DoHandshakeRead();
284 break;
285 case STATE_HANDSHAKE_READ_COMPLETE:
286 rv = DoHandshakeReadComplete(rv);
287 break;
288 default:
289 NOTREACHED() << "bad state";
290 rv = ERR_UNEXPECTED;
291 break;
292 }
293 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
294 return rv;
295 }
296
DoResolveHost()297 int SOCKSClientSocket::DoResolveHost() {
298 next_state_ = STATE_RESOLVE_HOST_COMPLETE;
299 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
300 // addresses for the target host.
301 HostResolver::ResolveHostParameters parameters;
302 parameters.dns_query_type = DnsQueryType::A;
303 parameters.initial_priority = priority_;
304 parameters.secure_dns_policy = secure_dns_policy_;
305 resolve_host_request_ = host_resolver_->CreateRequest(
306 destination_, network_anonymization_key_, net_log_, parameters);
307
308 return resolve_host_request_->Start(
309 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
310 }
311
DoResolveHostComplete(int result)312 int SOCKSClientSocket::DoResolveHostComplete(int result) {
313 resolve_error_info_ = resolve_host_request_->GetResolveErrorInfo();
314 if (result != OK) {
315 // Resolving the hostname failed; fail the request rather than automatically
316 // falling back to SOCKS4a (since it can be confusing to see invalid IP
317 // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
318 return result;
319 }
320
321 next_state_ = STATE_HANDSHAKE_WRITE;
322 return OK;
323 }
324
325 // Builds the buffer that is to be sent to the server.
BuildHandshakeWriteBuffer() const326 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
327 SOCKS4ServerRequest request;
328 request.version = kSOCKSVersion4;
329 request.command = kSOCKSStreamRequest;
330 request.nw_port = base::HostToNet16(destination_.port());
331
332 DCHECK(resolve_host_request_->GetAddressResults() &&
333 !resolve_host_request_->GetAddressResults()->empty());
334 const IPEndPoint& endpoint =
335 resolve_host_request_->GetAddressResults()->front();
336
337 // We disabled IPv6 results when resolving the hostname, so none of the
338 // results in the list will be IPv6.
339 // TODO(eroman): we only ever use the first address in the list. It would be
340 // more robust to try all the IP addresses we have before
341 // failing the connect attempt.
342 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
343 CHECK_LE(endpoint.address().size(), sizeof(request.ip));
344 memcpy(&request.ip, &endpoint.address().bytes()[0],
345 endpoint.address().size());
346
347 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
348
349 std::string handshake_data(reinterpret_cast<char*>(&request),
350 sizeof(request));
351 handshake_data.append(kEmptyUserId, std::size(kEmptyUserId));
352
353 return handshake_data;
354 }
355
356 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()357 int SOCKSClientSocket::DoHandshakeWrite() {
358 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
359
360 if (buffer_.empty()) {
361 buffer_ = BuildHandshakeWriteBuffer();
362 bytes_sent_ = 0;
363 }
364
365 int handshake_buf_len = buffer_.size() - bytes_sent_;
366 DCHECK_GT(handshake_buf_len, 0);
367 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
368 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
369 handshake_buf_len);
370 return transport_socket_->Write(
371 handshake_buf_.get(), handshake_buf_len,
372 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
373 traffic_annotation_);
374 }
375
DoHandshakeWriteComplete(int result)376 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
377 if (result < 0)
378 return result;
379
380 // We ignore the case when result is 0, since the underlying Write
381 // may return spurious writes while waiting on the socket.
382
383 bytes_sent_ += result;
384 if (bytes_sent_ == buffer_.size()) {
385 next_state_ = STATE_HANDSHAKE_READ;
386 buffer_.clear();
387 } else if (bytes_sent_ < buffer_.size()) {
388 next_state_ = STATE_HANDSHAKE_WRITE;
389 } else {
390 return ERR_UNEXPECTED;
391 }
392
393 return OK;
394 }
395
DoHandshakeRead()396 int SOCKSClientSocket::DoHandshakeRead() {
397 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
398
399 if (buffer_.empty()) {
400 bytes_received_ = 0;
401 }
402
403 int handshake_buf_len = kReadHeaderSize - bytes_received_;
404 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
405 return transport_socket_->Read(
406 handshake_buf_.get(), handshake_buf_len,
407 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
408 }
409
DoHandshakeReadComplete(int result)410 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
411 if (result < 0)
412 return result;
413
414 // The underlying socket closed unexpectedly.
415 if (result == 0)
416 return ERR_CONNECTION_CLOSED;
417
418 if (bytes_received_ + result > kReadHeaderSize) {
419 // TODO(eroman): Describe failure in NetLog.
420 return ERR_SOCKS_CONNECTION_FAILED;
421 }
422
423 buffer_.append(handshake_buf_->data(), result);
424 bytes_received_ += result;
425 if (bytes_received_ < kReadHeaderSize) {
426 next_state_ = STATE_HANDSHAKE_READ;
427 return OK;
428 }
429
430 const SOCKS4ServerResponse* response =
431 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
432
433 if (response->reserved_null != 0x00) {
434 DVLOG(1) << "Unknown response from SOCKS server.";
435 return ERR_SOCKS_CONNECTION_FAILED;
436 }
437
438 switch (response->code) {
439 case kServerResponseOk:
440 completed_handshake_ = true;
441 return OK;
442 case kServerResponseRejected:
443 DVLOG(1) << "SOCKS request rejected or failed";
444 return ERR_SOCKS_CONNECTION_FAILED;
445 case kServerResponseNotReachable:
446 DVLOG(1) << "SOCKS request failed because client is not running "
447 << "identd (or not reachable from the server)";
448 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
449 case kServerResponseMismatchedUserId:
450 DVLOG(1) << "SOCKS request failed because client's identd could "
451 << "not confirm the user ID string in the request";
452 return ERR_SOCKS_CONNECTION_FAILED;
453 default:
454 DVLOG(1) << "SOCKS server sent unknown response";
455 return ERR_SOCKS_CONNECTION_FAILED;
456 }
457
458 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
459 }
460
GetPeerAddress(IPEndPoint * address) const461 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
462 return transport_socket_->GetPeerAddress(address);
463 }
464
GetLocalAddress(IPEndPoint * address) const465 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
466 return transport_socket_->GetLocalAddress(address);
467 }
468
GetResolveErrorInfo() const469 ResolveErrorInfo SOCKSClientSocket::GetResolveErrorInfo() const {
470 return resolve_error_info_;
471 }
472
473 } // namespace net
474