1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/socket/socks5_client_socket.h"
11
12 #include <utility>
13
14 #include "base/compiler_specific.h"
15 #include "base/format_macros.h"
16 #include "base/functional/bind.h"
17 #include "base/functional/callback_helpers.h"
18 #include "base/strings/string_util.h"
19 #include "base/sys_byteorder.h"
20 #include "net/base/io_buffer.h"
21 #include "net/base/sys_addrinfo.h"
22 #include "net/base/tracing.h"
23 #include "net/log/net_log.h"
24 #include "net/log/net_log_event_type.h"
25 #include "net/traffic_annotation/network_traffic_annotation.h"
26
27 namespace net {
28
// Number of bytes in the server's reply to the greeting. DoGreetReadComplete()
// checks the first byte (version) and the second (selected auth method).
const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
// NOTE(review): not referenced in this file; confirm whether it is still used.
const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
// Bytes of the CONNECT reply read before its full length is known: version,
// reply code, reserved byte, address type, and the first address byte (the
// length prefix when the address is a domain name).
const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
// Protocol version byte sent in requests and expected in replies.
const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
// Command byte sent in the handshake request to ask for a CONNECT tunnel.
const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
// Value written to, and expected in, the reserved byte of the request/reply.
const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;

// DoHandshakeReadComplete() sizes the IPv4/IPv6 bound address in the reply
// with sizeof(in_addr) / sizeof(in6_addr), so these must equal the wire sizes.
static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
38
// Creates a SOCKS5 client that will tunnel to `destination` over
// `transport_socket`. Connect() starts writing the greeting to the transport
// without connecting it first, so the transport is presumably expected to be
// connected to the proxy already -- confirm with callers.
SOCKS5ClientSocket::SOCKS5ClientSocket(
    std::unique_ptr<StreamSocket> transport_socket,
    const HostPortPair& destination,
    const NetworkTrafficAnnotationTag& traffic_annotation)
    // `io_callback_` is handed to every handshake read/write and resumes the
    // state machine via OnIOComplete().
    : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
                                       base::Unretained(this))),
      transport_socket_(std::move(transport_socket)),
      read_header_size(kReadHeaderSize),
      destination_(destination),
      // Dereferences `transport_socket_`, so this relies on that member being
      // declared (and therefore initialized) before `net_log_`.
      net_log_(transport_socket_->NetLog()),
      traffic_annotation_(traffic_annotation) {}
50
~SOCKS5ClientSocket()51 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
52 Disconnect();
53 }
54
Connect(CompletionOnceCallback callback)55 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
56 DCHECK(transport_socket_);
57 DCHECK_EQ(STATE_NONE, next_state_);
58 DCHECK(user_callback_.is_null());
59
60 // If already connected, then just return OK.
61 if (completed_handshake_)
62 return OK;
63
64 net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
65
66 next_state_ = STATE_GREET_WRITE;
67 buffer_.clear();
68
69 int rv = DoLoop(OK);
70 if (rv == ERR_IO_PENDING) {
71 user_callback_ = std::move(callback);
72 } else {
73 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
74 }
75 return rv;
76 }
77
// Marks the handshake as not completed and disconnects the transport socket,
// then resets the state machine and drops any pending Connect() callback.
// Also invoked from the destructor.
void SOCKS5ClientSocket::Disconnect() {
  completed_handshake_ = false;
  transport_socket_->Disconnect();

  // Reset other states to make sure they aren't mistakenly used later.
  // These are the states initialized by Connect().
  next_state_ = STATE_NONE;
  user_callback_.Reset();
}
87
IsConnected() const88 bool SOCKS5ClientSocket::IsConnected() const {
89 return completed_handshake_ && transport_socket_->IsConnected();
90 }
91
IsConnectedAndIdle() const92 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
93 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
94 }
95
NetLog() const96 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
97 return net_log_;
98 }
99
WasEverUsed() const100 bool SOCKS5ClientSocket::WasEverUsed() const {
101 return was_ever_used_;
102 }
103
GetNegotiatedProtocol() const104 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
105 if (transport_socket_)
106 return transport_socket_->GetNegotiatedProtocol();
107 NOTREACHED();
108 }
109
GetSSLInfo(SSLInfo * ssl_info)110 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
111 if (transport_socket_)
112 return transport_socket_->GetSSLInfo(ssl_info);
113 NOTREACHED();
114 }
115
GetTotalReceivedBytes() const116 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
117 return transport_socket_->GetTotalReceivedBytes();
118 }
119
ApplySocketTag(const SocketTag & tag)120 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
121 return transport_socket_->ApplySocketTag(tag);
122 }
123
124 // Read is called by the transport layer above to read. This can only be done
125 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)126 int SOCKS5ClientSocket::Read(IOBuffer* buf,
127 int buf_len,
128 CompletionOnceCallback callback) {
129 DCHECK(completed_handshake_);
130 DCHECK_EQ(STATE_NONE, next_state_);
131 DCHECK(user_callback_.is_null());
132 DCHECK(!callback.is_null());
133
134 int rv = transport_socket_->Read(
135 buf, buf_len,
136 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
137 base::Unretained(this), std::move(callback)));
138 if (rv > 0)
139 was_ever_used_ = true;
140 return rv;
141 }
142
143 // Write is called by the transport layer. This can only be done if the
144 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)145 int SOCKS5ClientSocket::Write(
146 IOBuffer* buf,
147 int buf_len,
148 CompletionOnceCallback callback,
149 const NetworkTrafficAnnotationTag& traffic_annotation) {
150 DCHECK(completed_handshake_);
151 DCHECK_EQ(STATE_NONE, next_state_);
152 DCHECK(user_callback_.is_null());
153 DCHECK(!callback.is_null());
154
155 int rv = transport_socket_->Write(
156 buf, buf_len,
157 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
158 base::Unretained(this), std::move(callback)),
159 traffic_annotation);
160 if (rv > 0)
161 was_ever_used_ = true;
162 return rv;
163 }
164
SetReceiveBufferSize(int32_t size)165 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
166 return transport_socket_->SetReceiveBufferSize(size);
167 }
168
SetSendBufferSize(int32_t size)169 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
170 return transport_socket_->SetSendBufferSize(size);
171 }
172
DoCallback(int result)173 void SOCKS5ClientSocket::DoCallback(int result) {
174 DCHECK_NE(ERR_IO_PENDING, result);
175 DCHECK(!user_callback_.is_null());
176
177 // Since Run() may result in Read being called,
178 // clear user_callback_ up front.
179 std::move(user_callback_).Run(result);
180 }
181
OnIOComplete(int result)182 void SOCKS5ClientSocket::OnIOComplete(int result) {
183 DCHECK_NE(STATE_NONE, next_state_);
184 int rv = DoLoop(result);
185 if (rv != ERR_IO_PENDING) {
186 net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
187 DoCallback(rv);
188 }
189 }
190
OnReadWriteComplete(CompletionOnceCallback callback,int result)191 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
192 int result) {
193 DCHECK_NE(ERR_IO_PENDING, result);
194 DCHECK(!callback.is_null());
195
196 if (result > 0)
197 was_ever_used_ = true;
198 std::move(callback).Run(result);
199 }
200
// Runs the connect state machine. `last_io_result` is the result of the
// operation that just finished; Connect() passes OK to start. The loop stops
// when an operation is pending (ERR_IO_PENDING is returned) or when no next
// state has been set, i.e. the handshake completed or a step failed.
//
// Each *_WRITE / *_READ state opens a net-log event and issues one transport
// operation. The matching *_COMPLETE state handles the result and closes the
// event with EndEventWithNetErrorCode().
int SOCKS5ClientSocket::DoLoop(int last_io_result) {
  DCHECK_NE(next_state_, STATE_NONE);
  int rv = last_io_result;
  do {
    State state = next_state_;
    // Handlers set next_state_ themselves when the loop should continue, so
    // clearing it here makes an early error return terminate the loop.
    next_state_ = STATE_NONE;
    switch (state) {
      case STATE_GREET_WRITE:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
        rv = DoGreetWrite();
        break;
      case STATE_GREET_WRITE_COMPLETE:
        rv = DoGreetWriteComplete(rv);
        net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
                                          rv);
        break;
      case STATE_GREET_READ:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
        rv = DoGreetRead();
        break;
      case STATE_GREET_READ_COMPLETE:
        rv = DoGreetReadComplete(rv);
        net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
                                          rv);
        break;
      case STATE_HANDSHAKE_WRITE:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
        rv = DoHandshakeWrite();
        break;
      case STATE_HANDSHAKE_WRITE_COMPLETE:
        rv = DoHandshakeWriteComplete(rv);
        net_log_.EndEventWithNetErrorCode(
            NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
        break;
      case STATE_HANDSHAKE_READ:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
        rv = DoHandshakeRead();
        break;
      case STATE_HANDSHAKE_READ_COMPLETE:
        rv = DoHandshakeReadComplete(rv);
        net_log_.EndEventWithNetErrorCode(
            NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
        break;
      default:
        NOTREACHED() << "bad state";
    }
  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
  return rv;
}
254
// Client greeting: version 5, one authentication method offered, and that
// method is 0x00 (no authentication).
constexpr char kSOCKS5GreetWriteData[] = {0x05, 0x01, 0x00};
256
DoGreetWrite()257 int SOCKS5ClientSocket::DoGreetWrite() {
258 // Since we only have 1 byte to send the hostname length in, if the
259 // URL has a hostname longer than 255 characters we can't send it.
260 if (0xFF < destination_.host().size()) {
261 net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
262 return ERR_SOCKS_CONNECTION_FAILED;
263 }
264
265 if (buffer_.empty()) {
266 buffer_ =
267 std::string(kSOCKS5GreetWriteData, std::size(kSOCKS5GreetWriteData));
268 bytes_sent_ = 0;
269 }
270
271 next_state_ = STATE_GREET_WRITE_COMPLETE;
272 size_t handshake_buf_len = buffer_.size() - bytes_sent_;
273 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
274 memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
275 handshake_buf_len);
276 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
277 io_callback_, traffic_annotation_);
278 }
279
DoGreetWriteComplete(int result)280 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
281 if (result < 0)
282 return result;
283
284 bytes_sent_ += result;
285 if (bytes_sent_ == buffer_.size()) {
286 buffer_.clear();
287 bytes_received_ = 0;
288 next_state_ = STATE_GREET_READ;
289 } else {
290 next_state_ = STATE_GREET_WRITE;
291 }
292 return OK;
293 }
294
DoGreetRead()295 int SOCKS5ClientSocket::DoGreetRead() {
296 next_state_ = STATE_GREET_READ_COMPLETE;
297 size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
298 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
299 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
300 io_callback_);
301 }
302
DoGreetReadComplete(int result)303 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
304 if (result < 0)
305 return result;
306
307 if (result == 0) {
308 net_log_.AddEvent(
309 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
310 return ERR_SOCKS_CONNECTION_FAILED;
311 }
312
313 bytes_received_ += result;
314 buffer_.append(handshake_buf_->data(), result);
315 if (bytes_received_ < kGreetReadHeaderSize) {
316 next_state_ = STATE_GREET_READ;
317 return OK;
318 }
319
320 // Got the greet data.
321 if (buffer_[0] != kSOCKS5Version) {
322 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
323 "version", buffer_[0]);
324 return ERR_SOCKS_CONNECTION_FAILED;
325 }
326 if (buffer_[1] != 0x00) {
327 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
328 "method", buffer_[1]);
329 return ERR_SOCKS_CONNECTION_FAILED;
330 }
331
332 buffer_.clear();
333 next_state_ = STATE_HANDSHAKE_WRITE;
334 return OK;
335 }
336
BuildHandshakeWriteBuffer(std::string * handshake) const337 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
338 const {
339 DCHECK(handshake->empty());
340
341 handshake->push_back(kSOCKS5Version);
342 handshake->push_back(kTunnelCommand); // Connect command
343 handshake->push_back(kNullByte); // Reserved null
344
345 handshake->push_back(kEndPointDomain); // The type of the address.
346
347 DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
348
349 // First add the size of the hostname, followed by the hostname.
350 handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
351 handshake->append(destination_.host());
352
353 uint16_t nw_port = base::HostToNet16(destination_.port());
354 handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
355 return OK;
356 }
357
358 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()359 int SOCKS5ClientSocket::DoHandshakeWrite() {
360 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
361
362 if (buffer_.empty()) {
363 int rv = BuildHandshakeWriteBuffer(&buffer_);
364 if (rv != OK)
365 return rv;
366 bytes_sent_ = 0;
367 }
368
369 int handshake_buf_len = buffer_.size() - bytes_sent_;
370 DCHECK_LT(0, handshake_buf_len);
371 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
372 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
373 handshake_buf_len);
374 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
375 io_callback_, traffic_annotation_);
376 }
377
DoHandshakeWriteComplete(int result)378 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
379 if (result < 0)
380 return result;
381
382 // We ignore the case when result is 0, since the underlying Write
383 // may return spurious writes while waiting on the socket.
384
385 bytes_sent_ += result;
386 if (bytes_sent_ == buffer_.size()) {
387 next_state_ = STATE_HANDSHAKE_READ;
388 buffer_.clear();
389 } else if (bytes_sent_ < buffer_.size()) {
390 next_state_ = STATE_HANDSHAKE_WRITE;
391 } else {
392 NOTREACHED();
393 }
394
395 return OK;
396 }
397
DoHandshakeRead()398 int SOCKS5ClientSocket::DoHandshakeRead() {
399 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
400
401 if (buffer_.empty()) {
402 bytes_received_ = 0;
403 read_header_size = kReadHeaderSize;
404 }
405
406 int handshake_buf_len = read_header_size - bytes_received_;
407 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
408 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
409 io_callback_);
410 }
411
DoHandshakeReadComplete(int result)412 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
413 if (result < 0)
414 return result;
415
416 // The underlying socket closed unexpectedly.
417 if (result == 0) {
418 net_log_.AddEvent(
419 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
420 return ERR_SOCKS_CONNECTION_FAILED;
421 }
422
423 buffer_.append(handshake_buf_->data(), result);
424 bytes_received_ += result;
425
426 // When the first few bytes are read, check how many more are required
427 // and accordingly increase them
428 if (bytes_received_ == kReadHeaderSize) {
429 if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
430 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
431 "version", buffer_[0]);
432 return ERR_SOCKS_CONNECTION_FAILED;
433 }
434 if (buffer_[1] != 0x00) {
435 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
436 "error_code", buffer_[1]);
437 return ERR_SOCKS_CONNECTION_FAILED;
438 }
439
440 // We check the type of IP/Domain the server returns and accordingly
441 // increase the size of the response. For domains, we need to read the
442 // size of the domain, so the initial request size is upto the domain
443 // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
444 // read, we substract 1 byte from the additional request size.
445 SocksEndPointAddressType address_type =
446 static_cast<SocksEndPointAddressType>(buffer_[3]);
447 if (address_type == kEndPointDomain) {
448 read_header_size += static_cast<uint8_t>(buffer_[4]);
449 } else if (address_type == kEndPointResolvedIPv4) {
450 read_header_size += sizeof(struct in_addr) - 1;
451 } else if (address_type == kEndPointResolvedIPv6) {
452 read_header_size += sizeof(struct in6_addr) - 1;
453 } else {
454 net_log_.AddEventWithIntParams(
455 NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
456 buffer_[3]);
457 return ERR_SOCKS_CONNECTION_FAILED;
458 }
459
460 read_header_size += 2; // for the port.
461 next_state_ = STATE_HANDSHAKE_READ;
462 return OK;
463 }
464
465 // When the final bytes are read, setup handshake. We ignore the rest
466 // of the response since they represent the SOCKSv5 endpoint and have
467 // no use when doing a tunnel connection.
468 if (bytes_received_ == read_header_size) {
469 completed_handshake_ = true;
470 buffer_.clear();
471 next_state_ = STATE_NONE;
472 return OK;
473 }
474
475 next_state_ = STATE_HANDSHAKE_READ;
476 return OK;
477 }
478
GetPeerAddress(IPEndPoint * address) const479 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
480 return transport_socket_->GetPeerAddress(address);
481 }
482
GetLocalAddress(IPEndPoint * address) const483 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
484 return transport_socket_->GetLocalAddress(address);
485 }
486
487 } // namespace net
488