1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks_client_socket.h"
6
7 #include <utility>
8
9 #include "base/compiler_specific.h"
10 #include "base/functional/bind.h"
11 #include "base/functional/callback_helpers.h"
12 #include "base/sys_byteorder.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/dns/public/dns_query_type.h"
16 #include "net/dns/public/secure_dns_policy.h"
17 #include "net/log/net_log.h"
18 #include "net/log/net_log_event_type.h"
19 #include "net/traffic_annotation/network_traffic_annotation.h"
20
21 namespace net {
22
namespace {

// SOCKS4 requests always carry a user id; we send an empty one. Appending it
// with std::size() includes the terminating NUL, so a single '\0' byte goes on
// the wire after the fixed header.
constexpr char kEmptyUserId[] = "";

// Size of the fixed portion of a SOCKS4 request (version, command, port and
// IPv4 address). The user id is appended after it.
constexpr unsigned int kWriteHeaderSize = 8;

// Size of the reply a SOCKS4 server sends back.
constexpr unsigned int kReadHeaderSize = 8;

// Reply codes a SOCKS4 server may place in the reply's code byte.
constexpr uint8_t kServerResponseOk = 0x5A;
constexpr uint8_t kServerResponseRejected = 0x5B;
constexpr uint8_t kServerResponseNotReachable = 0x5C;
constexpr uint8_t kServerResponseMismatchedUserId = 0x5D;

// Values for the version and command fields of the request header.
constexpr uint8_t kSOCKSVersion4 = 0x04;
constexpr uint8_t kSOCKSStreamRequest = 0x01;

}  // namespace

// In-memory image of the fixed SOCKS4 request header, in wire order.
// |nw_port| must already be in network byte order.
struct SOCKS4ServerRequest {
  uint8_t version;
  uint8_t command;
  uint16_t nw_port;
  uint8_t ip[4];
};
static_assert(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
              "socks4 server request struct has incorrect size");

// In-memory image of the SOCKS4 server reply, in wire order.
struct SOCKS4ServerResponse {
  uint8_t reserved_null;
  uint8_t code;
  uint16_t port;
  uint8_t ip[4];
};
static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
              "socks4 server response struct has incorrect size");
62
// Wraps |transport_socket| and records everything later needed to resolve
// |destination| and run the SOCKS4 handshake over that transport. Takes
// ownership of |transport_socket|. |host_resolver| is not owned; presumably the
// caller keeps it alive for this object's lifetime (TODO confirm).
SOCKSClientSocket::SOCKSClientSocket(
    std::unique_ptr<StreamSocket> transport_socket,
    const HostPortPair& destination,
    const NetworkAnonymizationKey& network_anonymization_key,
    RequestPriority priority,
    HostResolver* host_resolver,
    SecureDnsPolicy secure_dns_policy,
    const NetworkTrafficAnnotationTag& traffic_annotation)
    : transport_socket_(std::move(transport_socket)),
      host_resolver_(host_resolver),
      secure_dns_policy_(secure_dns_policy),
      destination_(destination),
      network_anonymization_key_(network_anonymization_key),
      priority_(priority),
      // Log under the transport's NetLog source. This reads transport_socket_,
      // so it relies on transport_socket_ being declared before net_log_ in the
      // class: members are initialized in declaration order, not list order.
      net_log_(transport_socket_->NetLog()),
      traffic_annotation_(traffic_annotation) {}
79
~SOCKSClientSocket()80 SOCKSClientSocket::~SOCKSClientSocket() {
81 Disconnect();
82 }
83
Connect(CompletionOnceCallback callback)84 int SOCKSClientSocket::Connect(CompletionOnceCallback callback) {
85 DCHECK(transport_socket_);
86 DCHECK_EQ(STATE_NONE, next_state_);
87 DCHECK(user_callback_.is_null());
88
89 // If already connected, then just return OK.
90 if (completed_handshake_)
91 return OK;
92
93 next_state_ = STATE_RESOLVE_HOST;
94
95 net_log_.BeginEvent(NetLogEventType::SOCKS_CONNECT);
96
97 int rv = DoLoop(OK);
98 if (rv == ERR_IO_PENDING) {
99 user_callback_ = std::move(callback);
100 } else {
101 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
102 }
103 return rv;
104 }
105
Disconnect()106 void SOCKSClientSocket::Disconnect() {
107 completed_handshake_ = false;
108 resolve_host_request_.reset();
109 transport_socket_->Disconnect();
110
111 // Reset other states to make sure they aren't mistakenly used later.
112 // These are the states initialized by Connect().
113 next_state_ = STATE_NONE;
114 user_callback_.Reset();
115 }
116
IsConnected() const117 bool SOCKSClientSocket::IsConnected() const {
118 return completed_handshake_ && transport_socket_->IsConnected();
119 }
120
IsConnectedAndIdle() const121 bool SOCKSClientSocket::IsConnectedAndIdle() const {
122 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
123 }
124
NetLog() const125 const NetLogWithSource& SOCKSClientSocket::NetLog() const {
126 return net_log_;
127 }
128
WasEverUsed() const129 bool SOCKSClientSocket::WasEverUsed() const {
130 return was_ever_used_;
131 }
132
GetNegotiatedProtocol() const133 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
134 if (transport_socket_)
135 return transport_socket_->GetNegotiatedProtocol();
136 NOTREACHED();
137 }
138
GetSSLInfo(SSLInfo * ssl_info)139 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
140 if (transport_socket_)
141 return transport_socket_->GetSSLInfo(ssl_info);
142 NOTREACHED();
143 }
144
GetTotalReceivedBytes() const145 int64_t SOCKSClientSocket::GetTotalReceivedBytes() const {
146 return transport_socket_->GetTotalReceivedBytes();
147 }
148
ApplySocketTag(const SocketTag & tag)149 void SOCKSClientSocket::ApplySocketTag(const SocketTag& tag) {
150 return transport_socket_->ApplySocketTag(tag);
151 }
152
153 // Read is called by the transport layer above to read. This can only be done
154 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)155 int SOCKSClientSocket::Read(IOBuffer* buf,
156 int buf_len,
157 CompletionOnceCallback callback) {
158 DCHECK(completed_handshake_);
159 DCHECK_EQ(STATE_NONE, next_state_);
160 DCHECK(user_callback_.is_null());
161 DCHECK(!callback.is_null());
162
163 int rv = transport_socket_->Read(
164 buf, buf_len,
165 base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
166 base::Unretained(this), std::move(callback)));
167 if (rv > 0)
168 was_ever_used_ = true;
169 return rv;
170 }
171
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)172 int SOCKSClientSocket::ReadIfReady(IOBuffer* buf,
173 int buf_len,
174 CompletionOnceCallback callback) {
175 DCHECK(completed_handshake_);
176 DCHECK_EQ(STATE_NONE, next_state_);
177 DCHECK(user_callback_.is_null());
178 DCHECK(!callback.is_null());
179
180 // Pass |callback| directly instead of wrapping it with OnReadWriteComplete.
181 // This is to avoid setting |was_ever_used_| unless data is actually read.
182 int rv = transport_socket_->ReadIfReady(buf, buf_len, std::move(callback));
183 if (rv > 0)
184 was_ever_used_ = true;
185 return rv;
186 }
187
CancelReadIfReady()188 int SOCKSClientSocket::CancelReadIfReady() {
189 return transport_socket_->CancelReadIfReady();
190 }
191
192 // Write is called by the transport layer. This can only be done if the
193 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)194 int SOCKSClientSocket::Write(
195 IOBuffer* buf,
196 int buf_len,
197 CompletionOnceCallback callback,
198 const NetworkTrafficAnnotationTag& traffic_annotation) {
199 DCHECK(completed_handshake_);
200 DCHECK_EQ(STATE_NONE, next_state_);
201 DCHECK(user_callback_.is_null());
202 DCHECK(!callback.is_null());
203
204 int rv = transport_socket_->Write(
205 buf, buf_len,
206 base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
207 base::Unretained(this), std::move(callback)),
208 traffic_annotation);
209 if (rv > 0)
210 was_ever_used_ = true;
211 return rv;
212 }
213
SetReceiveBufferSize(int32_t size)214 int SOCKSClientSocket::SetReceiveBufferSize(int32_t size) {
215 return transport_socket_->SetReceiveBufferSize(size);
216 }
217
SetSendBufferSize(int32_t size)218 int SOCKSClientSocket::SetSendBufferSize(int32_t size) {
219 return transport_socket_->SetSendBufferSize(size);
220 }
221
DoCallback(int result)222 void SOCKSClientSocket::DoCallback(int result) {
223 DCHECK_NE(ERR_IO_PENDING, result);
224 DCHECK(!user_callback_.is_null());
225
226 // Since Run() may result in Read being called,
227 // clear user_callback_ up front.
228 DVLOG(1) << "Finished setting up SOCKS handshake";
229 std::move(user_callback_).Run(result);
230 }
231
OnIOComplete(int result)232 void SOCKSClientSocket::OnIOComplete(int result) {
233 DCHECK_NE(STATE_NONE, next_state_);
234 int rv = DoLoop(result);
235 if (rv != ERR_IO_PENDING) {
236 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
237 DoCallback(rv);
238 }
239 }
240
OnReadWriteComplete(CompletionOnceCallback callback,int result)241 void SOCKSClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
242 int result) {
243 DCHECK_NE(ERR_IO_PENDING, result);
244 DCHECK(!callback.is_null());
245
246 if (result > 0)
247 was_ever_used_ = true;
248 std::move(callback).Run(result);
249 }
250
// Drives the connect state machine (resolve -> write request -> read reply).
// |last_io_result| is the result of the most recently completed operation (OK
// when starting from Connect()). Each DoXxx() step picks the next state by
// assigning next_state_. The loop stops when a step returns ERR_IO_PENDING or
// leaves next_state_ as STATE_NONE. Returns the last step's result.
int SOCKSClientSocket::DoLoop(int last_io_result) {
  DCHECK_NE(next_state_, STATE_NONE);
  int rv = last_io_result;
  do {
    State state = next_state_;
    // Cleared before dispatch so that a step which does not choose a
    // successor (typically on error, or after the final reply) ends the loop.
    next_state_ = STATE_NONE;
    switch (state) {
      // Steps that start a new operation expect the previous step to have
      // succeeded. The *_COMPLETE steps receive the previous operation's
      // result (possibly an error) in |rv|.
      case STATE_RESOLVE_HOST:
        DCHECK_EQ(OK, rv);
        rv = DoResolveHost();
        break;
      case STATE_RESOLVE_HOST_COMPLETE:
        rv = DoResolveHostComplete(rv);
        break;
      case STATE_HANDSHAKE_WRITE:
        DCHECK_EQ(OK, rv);
        rv = DoHandshakeWrite();
        break;
      case STATE_HANDSHAKE_WRITE_COMPLETE:
        rv = DoHandshakeWriteComplete(rv);
        break;
      case STATE_HANDSHAKE_READ:
        DCHECK_EQ(OK, rv);
        rv = DoHandshakeRead();
        break;
      case STATE_HANDSHAKE_READ_COMPLETE:
        rv = DoHandshakeReadComplete(rv);
        break;
      default:
        NOTREACHED() << "bad state";
    }
  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
  return rv;
}
285
DoResolveHost()286 int SOCKSClientSocket::DoResolveHost() {
287 next_state_ = STATE_RESOLVE_HOST_COMPLETE;
288 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
289 // addresses for the target host.
290 HostResolver::ResolveHostParameters parameters;
291 parameters.dns_query_type = DnsQueryType::A;
292 parameters.initial_priority = priority_;
293 parameters.secure_dns_policy = secure_dns_policy_;
294 resolve_host_request_ = host_resolver_->CreateRequest(
295 destination_, network_anonymization_key_, net_log_, parameters);
296
297 return resolve_host_request_->Start(
298 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
299 }
300
DoResolveHostComplete(int result)301 int SOCKSClientSocket::DoResolveHostComplete(int result) {
302 resolve_error_info_ = resolve_host_request_->GetResolveErrorInfo();
303 if (result != OK) {
304 // Resolving the hostname failed; fail the request rather than automatically
305 // falling back to SOCKS4a (since it can be confusing to see invalid IP
306 // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
307 return result;
308 }
309
310 next_state_ = STATE_HANDSHAKE_WRITE;
311 return OK;
312 }
313
314 // Builds the buffer that is to be sent to the server.
BuildHandshakeWriteBuffer() const315 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
316 SOCKS4ServerRequest request;
317 request.version = kSOCKSVersion4;
318 request.command = kSOCKSStreamRequest;
319 request.nw_port = base::HostToNet16(destination_.port());
320
321 DCHECK(resolve_host_request_->GetAddressResults() &&
322 !resolve_host_request_->GetAddressResults()->empty());
323 const IPEndPoint& endpoint =
324 resolve_host_request_->GetAddressResults()->front();
325
326 // We disabled IPv6 results when resolving the hostname, so none of the
327 // results in the list will be IPv6.
328 // TODO(eroman): we only ever use the first address in the list. It would be
329 // more robust to try all the IP addresses we have before
330 // failing the connect attempt.
331 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
332 CHECK_LE(endpoint.address().size(), sizeof(request.ip));
333 memcpy(&request.ip, &endpoint.address().bytes()[0],
334 endpoint.address().size());
335
336 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
337
338 std::string handshake_data(reinterpret_cast<char*>(&request),
339 sizeof(request));
340 handshake_data.append(kEmptyUserId, std::size(kEmptyUserId));
341
342 return handshake_data;
343 }
344
345 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()346 int SOCKSClientSocket::DoHandshakeWrite() {
347 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
348
349 if (buffer_.empty()) {
350 buffer_ = BuildHandshakeWriteBuffer();
351 bytes_sent_ = 0;
352 }
353
354 int handshake_buf_len = buffer_.size() - bytes_sent_;
355 DCHECK_GT(handshake_buf_len, 0);
356 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
357 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
358 handshake_buf_len);
359 return transport_socket_->Write(
360 handshake_buf_.get(), handshake_buf_len,
361 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
362 traffic_annotation_);
363 }
364
DoHandshakeWriteComplete(int result)365 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
366 if (result < 0)
367 return result;
368
369 // We ignore the case when result is 0, since the underlying Write
370 // may return spurious writes while waiting on the socket.
371
372 bytes_sent_ += result;
373 if (bytes_sent_ == buffer_.size()) {
374 next_state_ = STATE_HANDSHAKE_READ;
375 buffer_.clear();
376 } else if (bytes_sent_ < buffer_.size()) {
377 next_state_ = STATE_HANDSHAKE_WRITE;
378 } else {
379 return ERR_UNEXPECTED;
380 }
381
382 return OK;
383 }
384
DoHandshakeRead()385 int SOCKSClientSocket::DoHandshakeRead() {
386 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
387
388 if (buffer_.empty()) {
389 bytes_received_ = 0;
390 }
391
392 int handshake_buf_len = kReadHeaderSize - bytes_received_;
393 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
394 return transport_socket_->Read(
395 handshake_buf_.get(), handshake_buf_len,
396 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
397 }
398
DoHandshakeReadComplete(int result)399 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
400 if (result < 0)
401 return result;
402
403 // The underlying socket closed unexpectedly.
404 if (result == 0)
405 return ERR_CONNECTION_CLOSED;
406
407 if (bytes_received_ + result > kReadHeaderSize) {
408 // TODO(eroman): Describe failure in NetLog.
409 return ERR_SOCKS_CONNECTION_FAILED;
410 }
411
412 buffer_.append(handshake_buf_->data(), result);
413 bytes_received_ += result;
414 if (bytes_received_ < kReadHeaderSize) {
415 next_state_ = STATE_HANDSHAKE_READ;
416 return OK;
417 }
418
419 const SOCKS4ServerResponse* response =
420 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
421
422 if (response->reserved_null != 0x00) {
423 DVLOG(1) << "Unknown response from SOCKS server.";
424 return ERR_SOCKS_CONNECTION_FAILED;
425 }
426
427 switch (response->code) {
428 case kServerResponseOk:
429 completed_handshake_ = true;
430 return OK;
431 case kServerResponseRejected:
432 DVLOG(1) << "SOCKS request rejected or failed";
433 return ERR_SOCKS_CONNECTION_FAILED;
434 case kServerResponseNotReachable:
435 DVLOG(1) << "SOCKS request failed because client is not running "
436 << "identd (or not reachable from the server)";
437 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
438 case kServerResponseMismatchedUserId:
439 DVLOG(1) << "SOCKS request failed because client's identd could "
440 << "not confirm the user ID string in the request";
441 return ERR_SOCKS_CONNECTION_FAILED;
442 default:
443 DVLOG(1) << "SOCKS server sent unknown response";
444 return ERR_SOCKS_CONNECTION_FAILED;
445 }
446
447 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
448 }
449
GetPeerAddress(IPEndPoint * address) const450 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
451 return transport_socket_->GetPeerAddress(address);
452 }
453
GetLocalAddress(IPEndPoint * address) const454 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
455 return transport_socket_->GetLocalAddress(address);
456 }
457
GetResolveErrorInfo() const458 ResolveErrorInfo SOCKSClientSocket::GetResolveErrorInfo() const {
459 return resolve_error_info_;
460 }
461
462 } // namespace net
463