1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks5_client_socket.h"
6
7 #include <utility>
8
9 #include "base/compiler_specific.h"
10 #include "base/format_macros.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/strings/string_util.h"
14 #include "base/sys_byteorder.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/sys_addrinfo.h"
17 #include "net/base/tracing.h"
18 #include "net/log/net_log.h"
19 #include "net/log/net_log_event_type.h"
20 #include "net/traffic_annotation/network_traffic_annotation.h"
21
22 namespace net {
23
24 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
25 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
26 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
27 const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
28 const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
29 const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;
30
31 static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
32 static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
33
SOCKS5ClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkTrafficAnnotationTag & traffic_annotation)34 SOCKS5ClientSocket::SOCKS5ClientSocket(
35 std::unique_ptr<StreamSocket> transport_socket,
36 const HostPortPair& destination,
37 const NetworkTrafficAnnotationTag& traffic_annotation)
38 : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
39 base::Unretained(this))),
40 transport_socket_(std::move(transport_socket)),
41 read_header_size(kReadHeaderSize),
42 destination_(destination),
43 net_log_(transport_socket_->NetLog()),
44 traffic_annotation_(traffic_annotation) {}
45
~SOCKS5ClientSocket()46 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
47 Disconnect();
48 }
49
Connect(CompletionOnceCallback callback)50 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
51 DCHECK(transport_socket_);
52 DCHECK_EQ(STATE_NONE, next_state_);
53 DCHECK(user_callback_.is_null());
54
55 // If already connected, then just return OK.
56 if (completed_handshake_)
57 return OK;
58
59 net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
60
61 next_state_ = STATE_GREET_WRITE;
62 buffer_.clear();
63
64 int rv = DoLoop(OK);
65 if (rv == ERR_IO_PENDING) {
66 user_callback_ = std::move(callback);
67 } else {
68 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
69 }
70 return rv;
71 }
72
Disconnect()73 void SOCKS5ClientSocket::Disconnect() {
74 completed_handshake_ = false;
75 transport_socket_->Disconnect();
76
77 // Reset other states to make sure they aren't mistakenly used later.
78 // These are the states initialized by Connect().
79 next_state_ = STATE_NONE;
80 user_callback_.Reset();
81 }
82
IsConnected() const83 bool SOCKS5ClientSocket::IsConnected() const {
84 return completed_handshake_ && transport_socket_->IsConnected();
85 }
86
IsConnectedAndIdle() const87 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
88 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
89 }
90
NetLog() const91 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
92 return net_log_;
93 }
94
WasEverUsed() const95 bool SOCKS5ClientSocket::WasEverUsed() const {
96 return was_ever_used_;
97 }
98
WasAlpnNegotiated() const99 bool SOCKS5ClientSocket::WasAlpnNegotiated() const {
100 if (transport_socket_)
101 return transport_socket_->WasAlpnNegotiated();
102 NOTREACHED();
103 return false;
104 }
105
GetNegotiatedProtocol() const106 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
107 if (transport_socket_)
108 return transport_socket_->GetNegotiatedProtocol();
109 NOTREACHED();
110 return kProtoUnknown;
111 }
112
GetSSLInfo(SSLInfo * ssl_info)113 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
114 if (transport_socket_)
115 return transport_socket_->GetSSLInfo(ssl_info);
116 NOTREACHED();
117 return false;
118 }
119
GetTotalReceivedBytes() const120 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
121 return transport_socket_->GetTotalReceivedBytes();
122 }
123
ApplySocketTag(const SocketTag & tag)124 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
125 return transport_socket_->ApplySocketTag(tag);
126 }
127
128 // Read is called by the transport layer above to read. This can only be done
129 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)130 int SOCKS5ClientSocket::Read(IOBuffer* buf,
131 int buf_len,
132 CompletionOnceCallback callback) {
133 DCHECK(completed_handshake_);
134 DCHECK_EQ(STATE_NONE, next_state_);
135 DCHECK(user_callback_.is_null());
136 DCHECK(!callback.is_null());
137
138 int rv = transport_socket_->Read(
139 buf, buf_len,
140 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
141 base::Unretained(this), std::move(callback)));
142 if (rv > 0)
143 was_ever_used_ = true;
144 return rv;
145 }
146
147 // Write is called by the transport layer. This can only be done if the
148 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)149 int SOCKS5ClientSocket::Write(
150 IOBuffer* buf,
151 int buf_len,
152 CompletionOnceCallback callback,
153 const NetworkTrafficAnnotationTag& traffic_annotation) {
154 DCHECK(completed_handshake_);
155 DCHECK_EQ(STATE_NONE, next_state_);
156 DCHECK(user_callback_.is_null());
157 DCHECK(!callback.is_null());
158
159 int rv = transport_socket_->Write(
160 buf, buf_len,
161 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
162 base::Unretained(this), std::move(callback)),
163 traffic_annotation);
164 if (rv > 0)
165 was_ever_used_ = true;
166 return rv;
167 }
168
SetReceiveBufferSize(int32_t size)169 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
170 return transport_socket_->SetReceiveBufferSize(size);
171 }
172
SetSendBufferSize(int32_t size)173 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
174 return transport_socket_->SetSendBufferSize(size);
175 }
176
DoCallback(int result)177 void SOCKS5ClientSocket::DoCallback(int result) {
178 DCHECK_NE(ERR_IO_PENDING, result);
179 DCHECK(!user_callback_.is_null());
180
181 // Since Run() may result in Read being called,
182 // clear user_callback_ up front.
183 std::move(user_callback_).Run(result);
184 }
185
OnIOComplete(int result)186 void SOCKS5ClientSocket::OnIOComplete(int result) {
187 DCHECK_NE(STATE_NONE, next_state_);
188 int rv = DoLoop(result);
189 if (rv != ERR_IO_PENDING) {
190 net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
191 DoCallback(rv);
192 }
193 }
194
OnReadWriteComplete(CompletionOnceCallback callback,int result)195 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
196 int result) {
197 DCHECK_NE(ERR_IO_PENDING, result);
198 DCHECK(!callback.is_null());
199
200 if (result > 0)
201 was_ever_used_ = true;
202 std::move(callback).Run(result);
203 }
204
DoLoop(int last_io_result)205 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
206 DCHECK_NE(next_state_, STATE_NONE);
207 int rv = last_io_result;
208 do {
209 State state = next_state_;
210 next_state_ = STATE_NONE;
211 switch (state) {
212 case STATE_GREET_WRITE:
213 DCHECK_EQ(OK, rv);
214 net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
215 rv = DoGreetWrite();
216 break;
217 case STATE_GREET_WRITE_COMPLETE:
218 rv = DoGreetWriteComplete(rv);
219 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
220 rv);
221 break;
222 case STATE_GREET_READ:
223 DCHECK_EQ(OK, rv);
224 net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
225 rv = DoGreetRead();
226 break;
227 case STATE_GREET_READ_COMPLETE:
228 rv = DoGreetReadComplete(rv);
229 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
230 rv);
231 break;
232 case STATE_HANDSHAKE_WRITE:
233 DCHECK_EQ(OK, rv);
234 net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
235 rv = DoHandshakeWrite();
236 break;
237 case STATE_HANDSHAKE_WRITE_COMPLETE:
238 rv = DoHandshakeWriteComplete(rv);
239 net_log_.EndEventWithNetErrorCode(
240 NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
241 break;
242 case STATE_HANDSHAKE_READ:
243 DCHECK_EQ(OK, rv);
244 net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
245 rv = DoHandshakeRead();
246 break;
247 case STATE_HANDSHAKE_READ_COMPLETE:
248 rv = DoHandshakeReadComplete(rv);
249 net_log_.EndEventWithNetErrorCode(
250 NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
251 break;
252 default:
253 NOTREACHED() << "bad state";
254 rv = ERR_UNEXPECTED;
255 break;
256 }
257 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
258 return rv;
259 }
260
261 const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 }; // no authentication
262
DoGreetWrite()263 int SOCKS5ClientSocket::DoGreetWrite() {
264 // Since we only have 1 byte to send the hostname length in, if the
265 // URL has a hostname longer than 255 characters we can't send it.
266 if (0xFF < destination_.host().size()) {
267 net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
268 return ERR_SOCKS_CONNECTION_FAILED;
269 }
270
271 if (buffer_.empty()) {
272 buffer_ =
273 std::string(kSOCKS5GreetWriteData, std::size(kSOCKS5GreetWriteData));
274 bytes_sent_ = 0;
275 }
276
277 next_state_ = STATE_GREET_WRITE_COMPLETE;
278 size_t handshake_buf_len = buffer_.size() - bytes_sent_;
279 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
280 memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
281 handshake_buf_len);
282 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
283 io_callback_, traffic_annotation_);
284 }
285
DoGreetWriteComplete(int result)286 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
287 if (result < 0)
288 return result;
289
290 bytes_sent_ += result;
291 if (bytes_sent_ == buffer_.size()) {
292 buffer_.clear();
293 bytes_received_ = 0;
294 next_state_ = STATE_GREET_READ;
295 } else {
296 next_state_ = STATE_GREET_WRITE;
297 }
298 return OK;
299 }
300
DoGreetRead()301 int SOCKS5ClientSocket::DoGreetRead() {
302 next_state_ = STATE_GREET_READ_COMPLETE;
303 size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
304 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
305 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
306 io_callback_);
307 }
308
DoGreetReadComplete(int result)309 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
310 if (result < 0)
311 return result;
312
313 if (result == 0) {
314 net_log_.AddEvent(
315 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
316 return ERR_SOCKS_CONNECTION_FAILED;
317 }
318
319 bytes_received_ += result;
320 buffer_.append(handshake_buf_->data(), result);
321 if (bytes_received_ < kGreetReadHeaderSize) {
322 next_state_ = STATE_GREET_READ;
323 return OK;
324 }
325
326 // Got the greet data.
327 if (buffer_[0] != kSOCKS5Version) {
328 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
329 "version", buffer_[0]);
330 return ERR_SOCKS_CONNECTION_FAILED;
331 }
332 if (buffer_[1] != 0x00) {
333 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
334 "method", buffer_[1]);
335 return ERR_SOCKS_CONNECTION_FAILED;
336 }
337
338 buffer_.clear();
339 next_state_ = STATE_HANDSHAKE_WRITE;
340 return OK;
341 }
342
BuildHandshakeWriteBuffer(std::string * handshake) const343 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
344 const {
345 DCHECK(handshake->empty());
346
347 handshake->push_back(kSOCKS5Version);
348 handshake->push_back(kTunnelCommand); // Connect command
349 handshake->push_back(kNullByte); // Reserved null
350
351 handshake->push_back(kEndPointDomain); // The type of the address.
352
353 DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
354
355 // First add the size of the hostname, followed by the hostname.
356 handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
357 handshake->append(destination_.host());
358
359 uint16_t nw_port = base::HostToNet16(destination_.port());
360 handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
361 return OK;
362 }
363
364 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()365 int SOCKS5ClientSocket::DoHandshakeWrite() {
366 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
367
368 if (buffer_.empty()) {
369 int rv = BuildHandshakeWriteBuffer(&buffer_);
370 if (rv != OK)
371 return rv;
372 bytes_sent_ = 0;
373 }
374
375 int handshake_buf_len = buffer_.size() - bytes_sent_;
376 DCHECK_LT(0, handshake_buf_len);
377 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
378 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
379 handshake_buf_len);
380 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
381 io_callback_, traffic_annotation_);
382 }
383
DoHandshakeWriteComplete(int result)384 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
385 if (result < 0)
386 return result;
387
388 // We ignore the case when result is 0, since the underlying Write
389 // may return spurious writes while waiting on the socket.
390
391 bytes_sent_ += result;
392 if (bytes_sent_ == buffer_.size()) {
393 next_state_ = STATE_HANDSHAKE_READ;
394 buffer_.clear();
395 } else if (bytes_sent_ < buffer_.size()) {
396 next_state_ = STATE_HANDSHAKE_WRITE;
397 } else {
398 NOTREACHED();
399 }
400
401 return OK;
402 }
403
DoHandshakeRead()404 int SOCKS5ClientSocket::DoHandshakeRead() {
405 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
406
407 if (buffer_.empty()) {
408 bytes_received_ = 0;
409 read_header_size = kReadHeaderSize;
410 }
411
412 int handshake_buf_len = read_header_size - bytes_received_;
413 handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
414 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
415 io_callback_);
416 }
417
DoHandshakeReadComplete(int result)418 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
419 if (result < 0)
420 return result;
421
422 // The underlying socket closed unexpectedly.
423 if (result == 0) {
424 net_log_.AddEvent(
425 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
426 return ERR_SOCKS_CONNECTION_FAILED;
427 }
428
429 buffer_.append(handshake_buf_->data(), result);
430 bytes_received_ += result;
431
432 // When the first few bytes are read, check how many more are required
433 // and accordingly increase them
434 if (bytes_received_ == kReadHeaderSize) {
435 if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
436 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
437 "version", buffer_[0]);
438 return ERR_SOCKS_CONNECTION_FAILED;
439 }
440 if (buffer_[1] != 0x00) {
441 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
442 "error_code", buffer_[1]);
443 return ERR_SOCKS_CONNECTION_FAILED;
444 }
445
446 // We check the type of IP/Domain the server returns and accordingly
447 // increase the size of the response. For domains, we need to read the
448 // size of the domain, so the initial request size is upto the domain
449 // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
450 // read, we substract 1 byte from the additional request size.
451 SocksEndPointAddressType address_type =
452 static_cast<SocksEndPointAddressType>(buffer_[3]);
453 if (address_type == kEndPointDomain) {
454 read_header_size += static_cast<uint8_t>(buffer_[4]);
455 } else if (address_type == kEndPointResolvedIPv4) {
456 read_header_size += sizeof(struct in_addr) - 1;
457 } else if (address_type == kEndPointResolvedIPv6) {
458 read_header_size += sizeof(struct in6_addr) - 1;
459 } else {
460 net_log_.AddEventWithIntParams(
461 NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
462 buffer_[3]);
463 return ERR_SOCKS_CONNECTION_FAILED;
464 }
465
466 read_header_size += 2; // for the port.
467 next_state_ = STATE_HANDSHAKE_READ;
468 return OK;
469 }
470
471 // When the final bytes are read, setup handshake. We ignore the rest
472 // of the response since they represent the SOCKSv5 endpoint and have
473 // no use when doing a tunnel connection.
474 if (bytes_received_ == read_header_size) {
475 completed_handshake_ = true;
476 buffer_.clear();
477 next_state_ = STATE_NONE;
478 return OK;
479 }
480
481 next_state_ = STATE_HANDSHAKE_READ;
482 return OK;
483 }
484
GetPeerAddress(IPEndPoint * address) const485 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
486 return transport_socket_->GetPeerAddress(address);
487 }
488
GetLocalAddress(IPEndPoint * address) const489 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
490 return transport_socket_->GetLocalAddress(address);
491 }
492
493 } // namespace net
494