1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks_client_socket.h"
6
7 #include "base/memory/scoped_ptr.h"
8 #include "net/base/address_list.h"
9 #include "net/base/net_log.h"
10 #include "net/base/net_log_unittest.h"
11 #include "net/base/test_completion_callback.h"
12 #include "net/base/winsock_init.h"
13 #include "net/dns/host_resolver.h"
14 #include "net/dns/mock_host_resolver.h"
15 #include "net/socket/client_socket_factory.h"
16 #include "net/socket/socket_test_util.h"
17 #include "net/socket/tcp_client_socket.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "testing/platform_test.h"
20
21 //-----------------------------------------------------------------------------
22
23 namespace net {
24
// SOCKS4 CONNECT request the client is expected to emit for localhost:80.
const char kSOCKSOkRequest[] = {
  0x04, 0x01,    // version 4, command 1 (CONNECT)
  0x00, 0x50,    // destination port 80, big-endian
  127, 0, 0, 1,  // destination address 127.0.0.1
  0,             // empty user ID, NUL-terminated
};
// Successful SOCKS4 reply: null first byte followed by status 0x5A.
const char kSOCKSOkReply[] = {
  0x00, 0x5A,    // null byte, status
  0x00, 0x00,    // port field
  0, 0, 0, 0,    // address field
};
27
28 class SOCKSClientSocketTest : public PlatformTest {
29 public:
30 SOCKSClientSocketTest();
31 // Create a SOCKSClientSocket on top of a MockSocket.
32 scoped_ptr<SOCKSClientSocket> BuildMockSocket(
33 MockRead reads[], size_t reads_count,
34 MockWrite writes[], size_t writes_count,
35 HostResolver* host_resolver,
36 const std::string& hostname, int port,
37 NetLog* net_log);
38 virtual void SetUp();
39
40 protected:
41 scoped_ptr<SOCKSClientSocket> user_sock_;
42 AddressList address_list_;
43 // Filled in by BuildMockSocket() and owned by its return value
44 // (which |user_sock| is set to).
45 StreamSocket* tcp_sock_;
46 TestCompletionCallback callback_;
47 scoped_ptr<MockHostResolver> host_resolver_;
48 scoped_ptr<SocketDataProvider> data_;
49 };
50
SOCKSClientSocketTest()51 SOCKSClientSocketTest::SOCKSClientSocketTest()
52 : host_resolver_(new MockHostResolver) {
53 }
54
55 // Set up platform before every test case
SetUp()56 void SOCKSClientSocketTest::SetUp() {
57 PlatformTest::SetUp();
58 }
59
BuildMockSocket(MockRead reads[],size_t reads_count,MockWrite writes[],size_t writes_count,HostResolver * host_resolver,const std::string & hostname,int port,NetLog * net_log)60 scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
61 MockRead reads[],
62 size_t reads_count,
63 MockWrite writes[],
64 size_t writes_count,
65 HostResolver* host_resolver,
66 const std::string& hostname,
67 int port,
68 NetLog* net_log) {
69
70 TestCompletionCallback callback;
71 data_.reset(new StaticSocketDataProvider(reads, reads_count,
72 writes, writes_count));
73 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
74
75 int rv = tcp_sock_->Connect(callback.callback());
76 EXPECT_EQ(ERR_IO_PENDING, rv);
77 rv = callback.WaitForResult();
78 EXPECT_EQ(OK, rv);
79 EXPECT_TRUE(tcp_sock_->IsConnected());
80
81 scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
82 // |connection| takes ownership of |tcp_sock_|, but keep a
83 // non-owning pointer to it.
84 connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
85 return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
86 connection.Pass(),
87 HostResolver::RequestInfo(HostPortPair(hostname, port)),
88 DEFAULT_PRIORITY,
89 host_resolver));
90 }
91
92 // Implementation of HostResolver that never completes its resolve request.
93 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
94 // sure that the outstanding resolve request gets cancelled.
95 class HangingHostResolverWithCancel : public HostResolver {
96 public:
HangingHostResolverWithCancel()97 HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
98
Resolve(const RequestInfo & info,RequestPriority priority,AddressList * addresses,const CompletionCallback & callback,RequestHandle * out_req,const BoundNetLog & net_log)99 virtual int Resolve(const RequestInfo& info,
100 RequestPriority priority,
101 AddressList* addresses,
102 const CompletionCallback& callback,
103 RequestHandle* out_req,
104 const BoundNetLog& net_log) OVERRIDE {
105 DCHECK(addresses);
106 DCHECK_EQ(false, callback.is_null());
107 EXPECT_FALSE(HasOutstandingRequest());
108 outstanding_request_ = reinterpret_cast<RequestHandle>(1);
109 *out_req = outstanding_request_;
110 return ERR_IO_PENDING;
111 }
112
ResolveFromCache(const RequestInfo & info,AddressList * addresses,const BoundNetLog & net_log)113 virtual int ResolveFromCache(const RequestInfo& info,
114 AddressList* addresses,
115 const BoundNetLog& net_log) OVERRIDE {
116 NOTIMPLEMENTED();
117 return ERR_UNEXPECTED;
118 }
119
CancelRequest(RequestHandle req)120 virtual void CancelRequest(RequestHandle req) OVERRIDE {
121 EXPECT_TRUE(HasOutstandingRequest());
122 EXPECT_EQ(outstanding_request_, req);
123 outstanding_request_ = NULL;
124 }
125
HasOutstandingRequest()126 bool HasOutstandingRequest() {
127 return outstanding_request_ != NULL;
128 }
129
130 private:
131 RequestHandle outstanding_request_;
132
133 DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
134 };
135
136 // Tests a complete handshake and the disconnection.
TEST_F(SOCKSClientSocketTest,CompleteHandshake)137 TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
138 const std::string payload_write = "random data";
139 const std::string payload_read = "moar random data";
140
141 MockWrite data_writes[] = {
142 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
143 MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
144 MockRead data_reads[] = {
145 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
146 MockRead(ASYNC, payload_read.data(), payload_read.size()) };
147 CapturingNetLog log;
148
149 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
150 data_writes, arraysize(data_writes),
151 host_resolver_.get(),
152 "localhost", 80,
153 &log);
154
155 // At this state the TCP connection is completed but not the SOCKS handshake.
156 EXPECT_TRUE(tcp_sock_->IsConnected());
157 EXPECT_FALSE(user_sock_->IsConnected());
158
159 int rv = user_sock_->Connect(callback_.callback());
160 EXPECT_EQ(ERR_IO_PENDING, rv);
161
162 CapturingNetLog::CapturedEntryList entries;
163 log.GetEntries(&entries);
164 EXPECT_TRUE(
165 LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
166 EXPECT_FALSE(user_sock_->IsConnected());
167
168 rv = callback_.WaitForResult();
169 EXPECT_EQ(OK, rv);
170 EXPECT_TRUE(user_sock_->IsConnected());
171 log.GetEntries(&entries);
172 EXPECT_TRUE(LogContainsEndEvent(
173 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
174
175 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
176 memcpy(buffer->data(), payload_write.data(), payload_write.size());
177 rv = user_sock_->Write(
178 buffer.get(), payload_write.size(), callback_.callback());
179 EXPECT_EQ(ERR_IO_PENDING, rv);
180 rv = callback_.WaitForResult();
181 EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
182
183 buffer = new IOBuffer(payload_read.size());
184 rv =
185 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
186 EXPECT_EQ(ERR_IO_PENDING, rv);
187 rv = callback_.WaitForResult();
188 EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
189 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
190
191 user_sock_->Disconnect();
192 EXPECT_FALSE(tcp_sock_->IsConnected());
193 EXPECT_FALSE(user_sock_->IsConnected());
194 }
195
196 // List of responses from the socks server and the errors they should
197 // throw up are tested here.
TEST_F(SOCKSClientSocketTest,HandshakeFailures)198 TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
199 const struct {
200 const char fail_reply[8];
201 Error fail_code;
202 } tests[] = {
203 // Failure of the server response code
204 {
205 { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
206 ERR_SOCKS_CONNECTION_FAILED,
207 },
208 // Failure of the null byte
209 {
210 { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
211 ERR_SOCKS_CONNECTION_FAILED,
212 },
213 };
214
215 //---------------------------------------
216
217 for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
218 MockWrite data_writes[] = {
219 MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
220 MockRead data_reads[] = {
221 MockRead(SYNCHRONOUS, tests[i].fail_reply,
222 arraysize(tests[i].fail_reply)) };
223 CapturingNetLog log;
224
225 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
226 data_writes, arraysize(data_writes),
227 host_resolver_.get(),
228 "localhost", 80,
229 &log);
230
231 int rv = user_sock_->Connect(callback_.callback());
232 EXPECT_EQ(ERR_IO_PENDING, rv);
233
234 CapturingNetLog::CapturedEntryList entries;
235 log.GetEntries(&entries);
236 EXPECT_TRUE(LogContainsBeginEvent(
237 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
238
239 rv = callback_.WaitForResult();
240 EXPECT_EQ(tests[i].fail_code, rv);
241 EXPECT_FALSE(user_sock_->IsConnected());
242 EXPECT_TRUE(tcp_sock_->IsConnected());
243 log.GetEntries(&entries);
244 EXPECT_TRUE(LogContainsEndEvent(
245 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
246 }
247 }
248
249 // Tests scenario when the server sends the handshake response in
250 // more than one packet.
TEST_F(SOCKSClientSocketTest,PartialServerReads)251 TEST_F(SOCKSClientSocketTest, PartialServerReads) {
252 const char kSOCKSPartialReply1[] = { 0x00 };
253 const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
254
255 MockWrite data_writes[] = {
256 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
257 MockRead data_reads[] = {
258 MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
259 MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
260 CapturingNetLog log;
261
262 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
263 data_writes, arraysize(data_writes),
264 host_resolver_.get(),
265 "localhost", 80,
266 &log);
267
268 int rv = user_sock_->Connect(callback_.callback());
269 EXPECT_EQ(ERR_IO_PENDING, rv);
270 CapturingNetLog::CapturedEntryList entries;
271 log.GetEntries(&entries);
272 EXPECT_TRUE(LogContainsBeginEvent(
273 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
274
275 rv = callback_.WaitForResult();
276 EXPECT_EQ(OK, rv);
277 EXPECT_TRUE(user_sock_->IsConnected());
278 log.GetEntries(&entries);
279 EXPECT_TRUE(LogContainsEndEvent(
280 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
281 }
282
283 // Tests scenario when the client sends the handshake request in
284 // more than one packet.
TEST_F(SOCKSClientSocketTest,PartialClientWrites)285 TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
286 const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
287 const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
288
289 MockWrite data_writes[] = {
290 MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
291 // simulate some empty writes
292 MockWrite(ASYNC, 0),
293 MockWrite(ASYNC, 0),
294 MockWrite(ASYNC, kSOCKSPartialRequest2,
295 arraysize(kSOCKSPartialRequest2)) };
296 MockRead data_reads[] = {
297 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
298 CapturingNetLog log;
299
300 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
301 data_writes, arraysize(data_writes),
302 host_resolver_.get(),
303 "localhost", 80,
304 &log);
305
306 int rv = user_sock_->Connect(callback_.callback());
307 EXPECT_EQ(ERR_IO_PENDING, rv);
308 CapturingNetLog::CapturedEntryList entries;
309 log.GetEntries(&entries);
310 EXPECT_TRUE(LogContainsBeginEvent(
311 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
312
313 rv = callback_.WaitForResult();
314 EXPECT_EQ(OK, rv);
315 EXPECT_TRUE(user_sock_->IsConnected());
316 log.GetEntries(&entries);
317 EXPECT_TRUE(LogContainsEndEvent(
318 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
319 }
320
321 // Tests the case when the server sends a smaller sized handshake data
322 // and closes the connection.
TEST_F(SOCKSClientSocketTest,FailedSocketRead)323 TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
324 MockWrite data_writes[] = {
325 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
326 MockRead data_reads[] = {
327 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
328 // close connection unexpectedly
329 MockRead(SYNCHRONOUS, 0) };
330 CapturingNetLog log;
331
332 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
333 data_writes, arraysize(data_writes),
334 host_resolver_.get(),
335 "localhost", 80,
336 &log);
337
338 int rv = user_sock_->Connect(callback_.callback());
339 EXPECT_EQ(ERR_IO_PENDING, rv);
340 CapturingNetLog::CapturedEntryList entries;
341 log.GetEntries(&entries);
342 EXPECT_TRUE(LogContainsBeginEvent(
343 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
344
345 rv = callback_.WaitForResult();
346 EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
347 EXPECT_FALSE(user_sock_->IsConnected());
348 log.GetEntries(&entries);
349 EXPECT_TRUE(LogContainsEndEvent(
350 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
351 }
352
353 // Tries to connect to an unknown hostname. Should fail rather than
354 // falling back to SOCKS4a.
TEST_F(SOCKSClientSocketTest,FailedDNS)355 TEST_F(SOCKSClientSocketTest, FailedDNS) {
356 const char hostname[] = "unresolved.ipv4.address";
357
358 host_resolver_->rules()->AddSimulatedFailure(hostname);
359
360 CapturingNetLog log;
361
362 user_sock_ = BuildMockSocket(NULL, 0,
363 NULL, 0,
364 host_resolver_.get(),
365 hostname, 80,
366 &log);
367
368 int rv = user_sock_->Connect(callback_.callback());
369 EXPECT_EQ(ERR_IO_PENDING, rv);
370 CapturingNetLog::CapturedEntryList entries;
371 log.GetEntries(&entries);
372 EXPECT_TRUE(LogContainsBeginEvent(
373 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
374
375 rv = callback_.WaitForResult();
376 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
377 EXPECT_FALSE(user_sock_->IsConnected());
378 log.GetEntries(&entries);
379 EXPECT_TRUE(LogContainsEndEvent(
380 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
381 }
382
383 // Calls Disconnect() while a host resolve is in progress. The outstanding host
384 // resolve should be cancelled.
TEST_F(SOCKSClientSocketTest,DisconnectWhileHostResolveInProgress)385 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
386 scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
387 new HangingHostResolverWithCancel());
388
389 // Doesn't matter what the socket data is, we will never use it -- garbage.
390 MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
391 MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
392
393 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
394 data_writes, arraysize(data_writes),
395 hanging_resolver.get(),
396 "foo", 80,
397 NULL);
398
399 // Start connecting (will get stuck waiting for the host to resolve).
400 int rv = user_sock_->Connect(callback_.callback());
401 EXPECT_EQ(ERR_IO_PENDING, rv);
402
403 EXPECT_FALSE(user_sock_->IsConnected());
404 EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
405
406 // The host resolver should have received the resolve request.
407 EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
408
409 // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
410 user_sock_->Disconnect();
411
412 EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
413
414 EXPECT_FALSE(user_sock_->IsConnected());
415 EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
416 }
417
418 // Tries to connect to an IPv6 IP. Should fail, as SOCKS4 does not support
419 // IPv6.
TEST_F(SOCKSClientSocketTest,NoIPv6)420 TEST_F(SOCKSClientSocketTest, NoIPv6) {
421 const char kHostName[] = "::1";
422
423 user_sock_ = BuildMockSocket(NULL, 0,
424 NULL, 0,
425 host_resolver_.get(),
426 kHostName, 80,
427 NULL);
428
429 EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
430 callback_.GetResult(user_sock_->Connect(callback_.callback())));
431 }
432
433 // Same as above, but with a real resolver, to protect against regressions.
TEST_F(SOCKSClientSocketTest,NoIPv6RealResolver)434 TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) {
435 const char kHostName[] = "::1";
436
437 scoped_ptr<HostResolver> host_resolver(
438 HostResolver::CreateSystemResolver(HostResolver::Options(), NULL));
439
440 user_sock_ = BuildMockSocket(NULL, 0,
441 NULL, 0,
442 host_resolver.get(),
443 kHostName, 80,
444 NULL);
445
446 EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
447 callback_.GetResult(user_sock_->Connect(callback_.callback())));
448 }
449
450 } // namespace net
451