• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks5_client_socket.h"
6 
7 #include <utility>
8 
9 #include "base/compiler_specific.h"
10 #include "base/format_macros.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/strings/string_util.h"
14 #include "base/sys_byteorder.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/sys_addrinfo.h"
17 #include "net/base/tracing.h"
18 #include "net/log/net_log.h"
19 #include "net/log/net_log_event_type.h"
20 #include "net/traffic_annotation/network_traffic_annotation.h"
21 
22 namespace net {
23 
24 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
25 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
26 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
27 const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
28 const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
29 const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;
30 
31 static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
32 static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
33 
// Wraps |transport_socket| and records |destination|, the host/port that the
// SOCKS5 CONNECT request will name. No network I/O happens here; Connect()
// drives the handshake. |transport_socket| must be non-null: it is
// dereferenced below to obtain the NetLog.
SOCKS5ClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkTrafficAnnotationTag & traffic_annotation)34 SOCKS5ClientSocket::SOCKS5ClientSocket(
35     std::unique_ptr<StreamSocket> transport_socket,
36     const HostPortPair& destination,
37     const NetworkTrafficAnnotationTag& traffic_annotation)
// io_callback_ resumes the handshake state machine (OnIOComplete) after an
// asynchronous transport operation. NOTE(review): Unretained(this) presumably
// relies on the callback only being handed to transport_socket_, which this
// object owns — confirm.
38     : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
39                                        base::Unretained(this))),
40       transport_socket_(std::move(transport_socket)),
41       read_header_size(kReadHeaderSize),
42       destination_(destination),
// Copies the transport's NetLog source so SOCKS events are logged under it.
// Requires transport_socket_ to be declared (and thus initialized) before
// net_log_ in the header.
43       net_log_(transport_socket_->NetLog()),
44       traffic_annotation_(traffic_annotation) {}
45 
~SOCKS5ClientSocket()46 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
47   Disconnect();
48 }
49 
Connect(CompletionOnceCallback callback)50 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
51   DCHECK(transport_socket_);
52   DCHECK_EQ(STATE_NONE, next_state_);
53   DCHECK(user_callback_.is_null());
54 
55   // If already connected, then just return OK.
56   if (completed_handshake_)
57     return OK;
58 
59   net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
60 
61   next_state_ = STATE_GREET_WRITE;
62   buffer_.clear();
63 
64   int rv = DoLoop(OK);
65   if (rv == ERR_IO_PENDING) {
66     user_callback_ = std::move(callback);
67   } else {
68     net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
69   }
70   return rv;
71 }
72 
Disconnect()73 void SOCKS5ClientSocket::Disconnect() {
74   completed_handshake_ = false;
75   transport_socket_->Disconnect();
76 
77   // Reset other states to make sure they aren't mistakenly used later.
78   // These are the states initialized by Connect().
79   next_state_ = STATE_NONE;
80   user_callback_.Reset();
81 }
82 
IsConnected() const83 bool SOCKS5ClientSocket::IsConnected() const {
84   return completed_handshake_ && transport_socket_->IsConnected();
85 }
86 
IsConnectedAndIdle() const87 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
88   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
89 }
90 
NetLog() const91 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
92   return net_log_;
93 }
94 
WasEverUsed() const95 bool SOCKS5ClientSocket::WasEverUsed() const {
96   return was_ever_used_;
97 }
98 
WasAlpnNegotiated() const99 bool SOCKS5ClientSocket::WasAlpnNegotiated() const {
100   if (transport_socket_)
101     return transport_socket_->WasAlpnNegotiated();
102   NOTREACHED();
103   return false;
104 }
105 
GetNegotiatedProtocol() const106 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
107   if (transport_socket_)
108     return transport_socket_->GetNegotiatedProtocol();
109   NOTREACHED();
110   return kProtoUnknown;
111 }
112 
GetSSLInfo(SSLInfo * ssl_info)113 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
114   if (transport_socket_)
115     return transport_socket_->GetSSLInfo(ssl_info);
116   NOTREACHED();
117   return false;
118 }
119 
GetTotalReceivedBytes() const120 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
121   return transport_socket_->GetTotalReceivedBytes();
122 }
123 
ApplySocketTag(const SocketTag & tag)124 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
125   return transport_socket_->ApplySocketTag(tag);
126 }
127 
128 // Read is called by the transport layer above to read. This can only be done
129 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)130 int SOCKS5ClientSocket::Read(IOBuffer* buf,
131                              int buf_len,
132                              CompletionOnceCallback callback) {
133   DCHECK(completed_handshake_);
134   DCHECK_EQ(STATE_NONE, next_state_);
135   DCHECK(user_callback_.is_null());
136   DCHECK(!callback.is_null());
137 
138   int rv = transport_socket_->Read(
139       buf, buf_len,
140       base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
141                      base::Unretained(this), std::move(callback)));
142   if (rv > 0)
143     was_ever_used_ = true;
144   return rv;
145 }
146 
147 // Write is called by the transport layer. This can only be done if the
148 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)149 int SOCKS5ClientSocket::Write(
150     IOBuffer* buf,
151     int buf_len,
152     CompletionOnceCallback callback,
153     const NetworkTrafficAnnotationTag& traffic_annotation) {
154   DCHECK(completed_handshake_);
155   DCHECK_EQ(STATE_NONE, next_state_);
156   DCHECK(user_callback_.is_null());
157   DCHECK(!callback.is_null());
158 
159   int rv = transport_socket_->Write(
160       buf, buf_len,
161       base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
162                      base::Unretained(this), std::move(callback)),
163       traffic_annotation);
164   if (rv > 0)
165     was_ever_used_ = true;
166   return rv;
167 }
168 
SetReceiveBufferSize(int32_t size)169 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
170   return transport_socket_->SetReceiveBufferSize(size);
171 }
172 
SetSendBufferSize(int32_t size)173 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
174   return transport_socket_->SetSendBufferSize(size);
175 }
176 
DoCallback(int result)177 void SOCKS5ClientSocket::DoCallback(int result) {
178   DCHECK_NE(ERR_IO_PENDING, result);
179   DCHECK(!user_callback_.is_null());
180 
181   // Since Run() may result in Read being called,
182   // clear user_callback_ up front.
183   std::move(user_callback_).Run(result);
184 }
185 
OnIOComplete(int result)186 void SOCKS5ClientSocket::OnIOComplete(int result) {
187   DCHECK_NE(STATE_NONE, next_state_);
188   int rv = DoLoop(result);
189   if (rv != ERR_IO_PENDING) {
190     net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
191     DoCallback(rv);
192   }
193 }
194 
OnReadWriteComplete(CompletionOnceCallback callback,int result)195 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
196                                              int result) {
197   DCHECK_NE(ERR_IO_PENDING, result);
198   DCHECK(!callback.is_null());
199 
200   if (result > 0)
201     was_ever_used_ = true;
202   std::move(callback).Run(result);
203 }
204 
DoLoop(int last_io_result)205 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
206   DCHECK_NE(next_state_, STATE_NONE);
207   int rv = last_io_result;
208   do {
209     State state = next_state_;
210     next_state_ = STATE_NONE;
211     switch (state) {
212       case STATE_GREET_WRITE:
213         DCHECK_EQ(OK, rv);
214         net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
215         rv = DoGreetWrite();
216         break;
217       case STATE_GREET_WRITE_COMPLETE:
218         rv = DoGreetWriteComplete(rv);
219         net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
220                                           rv);
221         break;
222       case STATE_GREET_READ:
223         DCHECK_EQ(OK, rv);
224         net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
225         rv = DoGreetRead();
226         break;
227       case STATE_GREET_READ_COMPLETE:
228         rv = DoGreetReadComplete(rv);
229         net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
230                                           rv);
231         break;
232       case STATE_HANDSHAKE_WRITE:
233         DCHECK_EQ(OK, rv);
234         net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
235         rv = DoHandshakeWrite();
236         break;
237       case STATE_HANDSHAKE_WRITE_COMPLETE:
238         rv = DoHandshakeWriteComplete(rv);
239         net_log_.EndEventWithNetErrorCode(
240             NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
241         break;
242       case STATE_HANDSHAKE_READ:
243         DCHECK_EQ(OK, rv);
244         net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
245         rv = DoHandshakeRead();
246         break;
247       case STATE_HANDSHAKE_READ_COMPLETE:
248         rv = DoHandshakeReadComplete(rv);
249         net_log_.EndEventWithNetErrorCode(
250             NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
251         break;
252       default:
253         NOTREACHED() << "bad state";
254         rv = ERR_UNEXPECTED;
255         break;
256     }
257   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
258   return rv;
259 }
260 
// Method-selection greeting sent first: protocol version 5, one method
// offered, and that method is 0x00 (no authentication).
constexpr char kSOCKS5GreetWriteData[] = {0x05, 0x01, 0x00};
262 
DoGreetWrite()263 int SOCKS5ClientSocket::DoGreetWrite() {
264   // Since we only have 1 byte to send the hostname length in, if the
265   // URL has a hostname longer than 255 characters we can't send it.
266   if (0xFF < destination_.host().size()) {
267     net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
268     return ERR_SOCKS_CONNECTION_FAILED;
269   }
270 
271   if (buffer_.empty()) {
272     buffer_ =
273         std::string(kSOCKS5GreetWriteData, std::size(kSOCKS5GreetWriteData));
274     bytes_sent_ = 0;
275   }
276 
277   next_state_ = STATE_GREET_WRITE_COMPLETE;
278   size_t handshake_buf_len = buffer_.size() - bytes_sent_;
279   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
280   memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
281          handshake_buf_len);
282   return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
283                                   io_callback_, traffic_annotation_);
284 }
285 
DoGreetWriteComplete(int result)286 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
287   if (result < 0)
288     return result;
289 
290   bytes_sent_ += result;
291   if (bytes_sent_ == buffer_.size()) {
292     buffer_.clear();
293     bytes_received_ = 0;
294     next_state_ = STATE_GREET_READ;
295   } else {
296     next_state_ = STATE_GREET_WRITE;
297   }
298   return OK;
299 }
300 
DoGreetRead()301 int SOCKS5ClientSocket::DoGreetRead() {
302   next_state_ = STATE_GREET_READ_COMPLETE;
303   size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
304   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
305   return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
306                                  io_callback_);
307 }
308 
DoGreetReadComplete(int result)309 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
310   if (result < 0)
311     return result;
312 
313   if (result == 0) {
314     net_log_.AddEvent(
315         NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
316     return ERR_SOCKS_CONNECTION_FAILED;
317   }
318 
319   bytes_received_ += result;
320   buffer_.append(handshake_buf_->data(), result);
321   if (bytes_received_ < kGreetReadHeaderSize) {
322     next_state_ = STATE_GREET_READ;
323     return OK;
324   }
325 
326   // Got the greet data.
327   if (buffer_[0] != kSOCKS5Version) {
328     net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
329                                    "version", buffer_[0]);
330     return ERR_SOCKS_CONNECTION_FAILED;
331   }
332   if (buffer_[1] != 0x00) {
333     net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
334                                    "method", buffer_[1]);
335     return ERR_SOCKS_CONNECTION_FAILED;
336   }
337 
338   buffer_.clear();
339   next_state_ = STATE_HANDSHAKE_WRITE;
340   return OK;
341 }
342 
BuildHandshakeWriteBuffer(std::string * handshake) const343 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
344     const {
345   DCHECK(handshake->empty());
346 
347   handshake->push_back(kSOCKS5Version);
348   handshake->push_back(kTunnelCommand);  // Connect command
349   handshake->push_back(kNullByte);  // Reserved null
350 
351   handshake->push_back(kEndPointDomain);  // The type of the address.
352 
353   DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
354 
355   // First add the size of the hostname, followed by the hostname.
356   handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
357   handshake->append(destination_.host());
358 
359   uint16_t nw_port = base::HostToNet16(destination_.port());
360   handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
361   return OK;
362 }
363 
364 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()365 int SOCKS5ClientSocket::DoHandshakeWrite() {
366   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
367 
368   if (buffer_.empty()) {
369     int rv = BuildHandshakeWriteBuffer(&buffer_);
370     if (rv != OK)
371       return rv;
372     bytes_sent_ = 0;
373   }
374 
375   int handshake_buf_len = buffer_.size() - bytes_sent_;
376   DCHECK_LT(0, handshake_buf_len);
377   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
378   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
379          handshake_buf_len);
380   return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
381                                   io_callback_, traffic_annotation_);
382 }
383 
DoHandshakeWriteComplete(int result)384 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
385   if (result < 0)
386     return result;
387 
388   // We ignore the case when result is 0, since the underlying Write
389   // may return spurious writes while waiting on the socket.
390 
391   bytes_sent_ += result;
392   if (bytes_sent_ == buffer_.size()) {
393     next_state_ = STATE_HANDSHAKE_READ;
394     buffer_.clear();
395   } else if (bytes_sent_ < buffer_.size()) {
396     next_state_ = STATE_HANDSHAKE_WRITE;
397   } else {
398     NOTREACHED();
399   }
400 
401   return OK;
402 }
403 
DoHandshakeRead()404 int SOCKS5ClientSocket::DoHandshakeRead() {
405   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
406 
407   if (buffer_.empty()) {
408     bytes_received_ = 0;
409     read_header_size = kReadHeaderSize;
410   }
411 
412   int handshake_buf_len = read_header_size - bytes_received_;
413   handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
414   return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
415                                  io_callback_);
416 }
417 
DoHandshakeReadComplete(int result)418 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
419   if (result < 0)
420     return result;
421 
422   // The underlying socket closed unexpectedly.
423   if (result == 0) {
424     net_log_.AddEvent(
425         NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
426     return ERR_SOCKS_CONNECTION_FAILED;
427   }
428 
429   buffer_.append(handshake_buf_->data(), result);
430   bytes_received_ += result;
431 
432   // When the first few bytes are read, check how many more are required
433   // and accordingly increase them
434   if (bytes_received_ == kReadHeaderSize) {
435     if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
436       net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
437                                      "version", buffer_[0]);
438       return ERR_SOCKS_CONNECTION_FAILED;
439     }
440     if (buffer_[1] != 0x00) {
441       net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
442                                      "error_code", buffer_[1]);
443       return ERR_SOCKS_CONNECTION_FAILED;
444     }
445 
446     // We check the type of IP/Domain the server returns and accordingly
447     // increase the size of the response. For domains, we need to read the
448     // size of the domain, so the initial request size is upto the domain
449     // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
450     // read, we substract 1 byte from the additional request size.
451     SocksEndPointAddressType address_type =
452         static_cast<SocksEndPointAddressType>(buffer_[3]);
453     if (address_type == kEndPointDomain) {
454       read_header_size += static_cast<uint8_t>(buffer_[4]);
455     } else if (address_type == kEndPointResolvedIPv4) {
456       read_header_size += sizeof(struct in_addr) - 1;
457     } else if (address_type == kEndPointResolvedIPv6) {
458       read_header_size += sizeof(struct in6_addr) - 1;
459     } else {
460       net_log_.AddEventWithIntParams(
461           NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
462           buffer_[3]);
463       return ERR_SOCKS_CONNECTION_FAILED;
464     }
465 
466     read_header_size += 2;  // for the port.
467     next_state_ = STATE_HANDSHAKE_READ;
468     return OK;
469   }
470 
471   // When the final bytes are read, setup handshake. We ignore the rest
472   // of the response since they represent the SOCKSv5 endpoint and have
473   // no use when doing a tunnel connection.
474   if (bytes_received_ == read_header_size) {
475     completed_handshake_ = true;
476     buffer_.clear();
477     next_state_ = STATE_NONE;
478     return OK;
479   }
480 
481   next_state_ = STATE_HANDSHAKE_READ;
482   return OK;
483 }
484 
GetPeerAddress(IPEndPoint * address) const485 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
486   return transport_socket_->GetPeerAddress(address);
487 }
488 
GetLocalAddress(IPEndPoint * address) const489 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
490   return transport_socket_->GetLocalAddress(address);
491 }
492 
493 }  // namespace net
494