• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks_client_socket.h"
6 
7 #include "base/memory/scoped_ptr.h"
8 #include "net/base/address_list.h"
9 #include "net/base/net_log.h"
10 #include "net/base/net_log_unittest.h"
11 #include "net/base/test_completion_callback.h"
12 #include "net/base/winsock_init.h"
13 #include "net/dns/host_resolver.h"
14 #include "net/dns/mock_host_resolver.h"
15 #include "net/socket/client_socket_factory.h"
16 #include "net/socket/socket_test_util.h"
17 #include "net/socket/tcp_client_socket.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "testing/platform_test.h"
20 
21 //-----------------------------------------------------------------------------
22 
23 namespace net {
24 
// SOCKS4 CONNECT request the client is expected to emit for "localhost":80,
// assuming the mock resolver maps "localhost" to 127.0.0.1.
const char kSOCKSOkRequest[] = {
  0x04, 0x01,              // Version 4, command CONNECT.
  0x00, 0x50,              // Destination port 80, big-endian.
  0x7F, 0x00, 0x00, 0x01,  // Destination IPv4 address 127.0.0.1.
  0x00,                    // Terminator of an empty user ID.
};

// SOCKS4 reply that grants the request: a leading null byte, the 0x5A
// status, then six zero port/address bytes.
const char kSOCKSOkReply[] = {
  0x00, 0x5A,              // Null byte, "granted" status.
  0x00, 0x00,              // Port bytes.
  0x00, 0x00, 0x00, 0x00,  // Address bytes.
};
27 
28 class SOCKSClientSocketTest : public PlatformTest {
29  public:
30   SOCKSClientSocketTest();
31   // Create a SOCKSClientSocket on top of a MockSocket.
32   scoped_ptr<SOCKSClientSocket> BuildMockSocket(
33       MockRead reads[], size_t reads_count,
34       MockWrite writes[], size_t writes_count,
35       HostResolver* host_resolver,
36       const std::string& hostname, int port,
37       NetLog* net_log);
38   virtual void SetUp();
39 
40  protected:
41   scoped_ptr<SOCKSClientSocket> user_sock_;
42   AddressList address_list_;
43   // Filled in by BuildMockSocket() and owned by its return value
44   // (which |user_sock| is set to).
45   StreamSocket* tcp_sock_;
46   TestCompletionCallback callback_;
47   scoped_ptr<MockHostResolver> host_resolver_;
48   scoped_ptr<SocketDataProvider> data_;
49 };
50 
SOCKSClientSocketTest()51 SOCKSClientSocketTest::SOCKSClientSocketTest()
52   : host_resolver_(new MockHostResolver) {
53 }
54 
55 // Set up platform before every test case
SetUp()56 void SOCKSClientSocketTest::SetUp() {
57   PlatformTest::SetUp();
58 }
59 
BuildMockSocket(MockRead reads[],size_t reads_count,MockWrite writes[],size_t writes_count,HostResolver * host_resolver,const std::string & hostname,int port,NetLog * net_log)60 scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
61     MockRead reads[],
62     size_t reads_count,
63     MockWrite writes[],
64     size_t writes_count,
65     HostResolver* host_resolver,
66     const std::string& hostname,
67     int port,
68     NetLog* net_log) {
69 
70   TestCompletionCallback callback;
71   data_.reset(new StaticSocketDataProvider(reads, reads_count,
72                                            writes, writes_count));
73   tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
74 
75   int rv = tcp_sock_->Connect(callback.callback());
76   EXPECT_EQ(ERR_IO_PENDING, rv);
77   rv = callback.WaitForResult();
78   EXPECT_EQ(OK, rv);
79   EXPECT_TRUE(tcp_sock_->IsConnected());
80 
81   scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
82   // |connection| takes ownership of |tcp_sock_|, but keep a
83   // non-owning pointer to it.
84   connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
85   return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
86       connection.Pass(),
87       HostResolver::RequestInfo(HostPortPair(hostname, port)),
88       DEFAULT_PRIORITY,
89       host_resolver));
90 }
91 
92 // Implementation of HostResolver that never completes its resolve request.
93 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
94 // sure that the outstanding resolve request gets cancelled.
95 class HangingHostResolverWithCancel : public HostResolver {
96  public:
HangingHostResolverWithCancel()97   HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
98 
Resolve(const RequestInfo & info,RequestPriority priority,AddressList * addresses,const CompletionCallback & callback,RequestHandle * out_req,const BoundNetLog & net_log)99   virtual int Resolve(const RequestInfo& info,
100                       RequestPriority priority,
101                       AddressList* addresses,
102                       const CompletionCallback& callback,
103                       RequestHandle* out_req,
104                       const BoundNetLog& net_log) OVERRIDE {
105     DCHECK(addresses);
106     DCHECK_EQ(false, callback.is_null());
107     EXPECT_FALSE(HasOutstandingRequest());
108     outstanding_request_ = reinterpret_cast<RequestHandle>(1);
109     *out_req = outstanding_request_;
110     return ERR_IO_PENDING;
111   }
112 
ResolveFromCache(const RequestInfo & info,AddressList * addresses,const BoundNetLog & net_log)113   virtual int ResolveFromCache(const RequestInfo& info,
114                                AddressList* addresses,
115                                const BoundNetLog& net_log) OVERRIDE {
116     NOTIMPLEMENTED();
117     return ERR_UNEXPECTED;
118   }
119 
CancelRequest(RequestHandle req)120   virtual void CancelRequest(RequestHandle req) OVERRIDE {
121     EXPECT_TRUE(HasOutstandingRequest());
122     EXPECT_EQ(outstanding_request_, req);
123     outstanding_request_ = NULL;
124   }
125 
HasOutstandingRequest()126   bool HasOutstandingRequest() {
127     return outstanding_request_ != NULL;
128   }
129 
130  private:
131   RequestHandle outstanding_request_;
132 
133   DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
134 };
135 
136 // Tests a complete handshake and the disconnection.
TEST_F(SOCKSClientSocketTest,CompleteHandshake)137 TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
138   const std::string payload_write = "random data";
139   const std::string payload_read = "moar random data";
140 
141   MockWrite data_writes[] = {
142       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
143       MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
144   MockRead data_reads[] = {
145       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
146       MockRead(ASYNC, payload_read.data(), payload_read.size()) };
147   CapturingNetLog log;
148 
149   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
150                                data_writes, arraysize(data_writes),
151                                host_resolver_.get(),
152                                "localhost", 80,
153                                &log);
154 
155   // At this state the TCP connection is completed but not the SOCKS handshake.
156   EXPECT_TRUE(tcp_sock_->IsConnected());
157   EXPECT_FALSE(user_sock_->IsConnected());
158 
159   int rv = user_sock_->Connect(callback_.callback());
160   EXPECT_EQ(ERR_IO_PENDING, rv);
161 
162   CapturingNetLog::CapturedEntryList entries;
163   log.GetEntries(&entries);
164   EXPECT_TRUE(
165       LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
166   EXPECT_FALSE(user_sock_->IsConnected());
167 
168   rv = callback_.WaitForResult();
169   EXPECT_EQ(OK, rv);
170   EXPECT_TRUE(user_sock_->IsConnected());
171   log.GetEntries(&entries);
172   EXPECT_TRUE(LogContainsEndEvent(
173       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
174 
175   scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
176   memcpy(buffer->data(), payload_write.data(), payload_write.size());
177   rv = user_sock_->Write(
178       buffer.get(), payload_write.size(), callback_.callback());
179   EXPECT_EQ(ERR_IO_PENDING, rv);
180   rv = callback_.WaitForResult();
181   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
182 
183   buffer = new IOBuffer(payload_read.size());
184   rv =
185       user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
186   EXPECT_EQ(ERR_IO_PENDING, rv);
187   rv = callback_.WaitForResult();
188   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
189   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
190 
191   user_sock_->Disconnect();
192   EXPECT_FALSE(tcp_sock_->IsConnected());
193   EXPECT_FALSE(user_sock_->IsConnected());
194 }
195 
196 // List of responses from the socks server and the errors they should
197 // throw up are tested here.
TEST_F(SOCKSClientSocketTest,HandshakeFailures)198 TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
199   const struct {
200     const char fail_reply[8];
201     Error fail_code;
202   } tests[] = {
203     // Failure of the server response code
204     {
205       { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
206       ERR_SOCKS_CONNECTION_FAILED,
207     },
208     // Failure of the null byte
209     {
210       { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
211       ERR_SOCKS_CONNECTION_FAILED,
212     },
213   };
214 
215   //---------------------------------------
216 
217   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
218     MockWrite data_writes[] = {
219         MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
220     MockRead data_reads[] = {
221         MockRead(SYNCHRONOUS, tests[i].fail_reply,
222                  arraysize(tests[i].fail_reply)) };
223     CapturingNetLog log;
224 
225     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
226                                  data_writes, arraysize(data_writes),
227                                  host_resolver_.get(),
228                                  "localhost", 80,
229                                  &log);
230 
231     int rv = user_sock_->Connect(callback_.callback());
232     EXPECT_EQ(ERR_IO_PENDING, rv);
233 
234     CapturingNetLog::CapturedEntryList entries;
235     log.GetEntries(&entries);
236     EXPECT_TRUE(LogContainsBeginEvent(
237         entries, 0, NetLog::TYPE_SOCKS_CONNECT));
238 
239     rv = callback_.WaitForResult();
240     EXPECT_EQ(tests[i].fail_code, rv);
241     EXPECT_FALSE(user_sock_->IsConnected());
242     EXPECT_TRUE(tcp_sock_->IsConnected());
243     log.GetEntries(&entries);
244     EXPECT_TRUE(LogContainsEndEvent(
245         entries, -1, NetLog::TYPE_SOCKS_CONNECT));
246   }
247 }
248 
249 // Tests scenario when the server sends the handshake response in
250 // more than one packet.
TEST_F(SOCKSClientSocketTest,PartialServerReads)251 TEST_F(SOCKSClientSocketTest, PartialServerReads) {
252   const char kSOCKSPartialReply1[] = { 0x00 };
253   const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
254 
255   MockWrite data_writes[] = {
256       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
257   MockRead data_reads[] = {
258       MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
259       MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
260   CapturingNetLog log;
261 
262   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
263                                data_writes, arraysize(data_writes),
264                                host_resolver_.get(),
265                                "localhost", 80,
266                                &log);
267 
268   int rv = user_sock_->Connect(callback_.callback());
269   EXPECT_EQ(ERR_IO_PENDING, rv);
270   CapturingNetLog::CapturedEntryList entries;
271   log.GetEntries(&entries);
272   EXPECT_TRUE(LogContainsBeginEvent(
273       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
274 
275   rv = callback_.WaitForResult();
276   EXPECT_EQ(OK, rv);
277   EXPECT_TRUE(user_sock_->IsConnected());
278   log.GetEntries(&entries);
279   EXPECT_TRUE(LogContainsEndEvent(
280       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
281 }
282 
283 // Tests scenario when the client sends the handshake request in
284 // more than one packet.
TEST_F(SOCKSClientSocketTest,PartialClientWrites)285 TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
286   const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
287   const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
288 
289   MockWrite data_writes[] = {
290       MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
291       // simulate some empty writes
292       MockWrite(ASYNC, 0),
293       MockWrite(ASYNC, 0),
294       MockWrite(ASYNC, kSOCKSPartialRequest2,
295                 arraysize(kSOCKSPartialRequest2)) };
296   MockRead data_reads[] = {
297       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
298   CapturingNetLog log;
299 
300   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
301                                data_writes, arraysize(data_writes),
302                                host_resolver_.get(),
303                                "localhost", 80,
304                                &log);
305 
306   int rv = user_sock_->Connect(callback_.callback());
307   EXPECT_EQ(ERR_IO_PENDING, rv);
308   CapturingNetLog::CapturedEntryList entries;
309   log.GetEntries(&entries);
310   EXPECT_TRUE(LogContainsBeginEvent(
311       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
312 
313   rv = callback_.WaitForResult();
314   EXPECT_EQ(OK, rv);
315   EXPECT_TRUE(user_sock_->IsConnected());
316   log.GetEntries(&entries);
317   EXPECT_TRUE(LogContainsEndEvent(
318       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
319 }
320 
321 // Tests the case when the server sends a smaller sized handshake data
322 // and closes the connection.
TEST_F(SOCKSClientSocketTest,FailedSocketRead)323 TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
324   MockWrite data_writes[] = {
325       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
326   MockRead data_reads[] = {
327       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
328       // close connection unexpectedly
329       MockRead(SYNCHRONOUS, 0) };
330   CapturingNetLog log;
331 
332   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
333                                data_writes, arraysize(data_writes),
334                                host_resolver_.get(),
335                                "localhost", 80,
336                                &log);
337 
338   int rv = user_sock_->Connect(callback_.callback());
339   EXPECT_EQ(ERR_IO_PENDING, rv);
340   CapturingNetLog::CapturedEntryList entries;
341   log.GetEntries(&entries);
342   EXPECT_TRUE(LogContainsBeginEvent(
343       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
344 
345   rv = callback_.WaitForResult();
346   EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
347   EXPECT_FALSE(user_sock_->IsConnected());
348   log.GetEntries(&entries);
349   EXPECT_TRUE(LogContainsEndEvent(
350       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
351 }
352 
353 // Tries to connect to an unknown hostname. Should fail rather than
354 // falling back to SOCKS4a.
TEST_F(SOCKSClientSocketTest,FailedDNS)355 TEST_F(SOCKSClientSocketTest, FailedDNS) {
356   const char hostname[] = "unresolved.ipv4.address";
357 
358   host_resolver_->rules()->AddSimulatedFailure(hostname);
359 
360   CapturingNetLog log;
361 
362   user_sock_ = BuildMockSocket(NULL, 0,
363                                NULL, 0,
364                                host_resolver_.get(),
365                                hostname, 80,
366                                &log);
367 
368   int rv = user_sock_->Connect(callback_.callback());
369   EXPECT_EQ(ERR_IO_PENDING, rv);
370   CapturingNetLog::CapturedEntryList entries;
371   log.GetEntries(&entries);
372   EXPECT_TRUE(LogContainsBeginEvent(
373       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
374 
375   rv = callback_.WaitForResult();
376   EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
377   EXPECT_FALSE(user_sock_->IsConnected());
378   log.GetEntries(&entries);
379   EXPECT_TRUE(LogContainsEndEvent(
380       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
381 }
382 
383 // Calls Disconnect() while a host resolve is in progress. The outstanding host
384 // resolve should be cancelled.
TEST_F(SOCKSClientSocketTest,DisconnectWhileHostResolveInProgress)385 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
386   scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
387     new HangingHostResolverWithCancel());
388 
389   // Doesn't matter what the socket data is, we will never use it -- garbage.
390   MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
391   MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
392 
393   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
394                                data_writes, arraysize(data_writes),
395                                hanging_resolver.get(),
396                                "foo", 80,
397                                NULL);
398 
399   // Start connecting (will get stuck waiting for the host to resolve).
400   int rv = user_sock_->Connect(callback_.callback());
401   EXPECT_EQ(ERR_IO_PENDING, rv);
402 
403   EXPECT_FALSE(user_sock_->IsConnected());
404   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
405 
406   // The host resolver should have received the resolve request.
407   EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
408 
409   // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
410   user_sock_->Disconnect();
411 
412   EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
413 
414   EXPECT_FALSE(user_sock_->IsConnected());
415   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
416 }
417 
418 // Tries to connect to an IPv6 IP.  Should fail, as SOCKS4 does not support
419 // IPv6.
TEST_F(SOCKSClientSocketTest,NoIPv6)420 TEST_F(SOCKSClientSocketTest, NoIPv6) {
421   const char kHostName[] = "::1";
422 
423   user_sock_ = BuildMockSocket(NULL, 0,
424                                NULL, 0,
425                                host_resolver_.get(),
426                                kHostName, 80,
427                                NULL);
428 
429   EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
430             callback_.GetResult(user_sock_->Connect(callback_.callback())));
431 }
432 
433 // Same as above, but with a real resolver, to protect against regressions.
TEST_F(SOCKSClientSocketTest,NoIPv6RealResolver)434 TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) {
435   const char kHostName[] = "::1";
436 
437   scoped_ptr<HostResolver> host_resolver(
438       HostResolver::CreateSystemResolver(HostResolver::Options(), NULL));
439 
440   user_sock_ = BuildMockSocket(NULL, 0,
441                                NULL, 0,
442                                host_resolver.get(),
443                                kHostName, 80,
444                                NULL);
445 
446   EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
447             callback_.GetResult(user_sock_->Connect(callback_.callback())));
448 }
449 
450 }  // namespace net
451