1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks5_client_socket.h"
6
7 #include <algorithm>
8 #include <map>
9
10 #include "net/base/address_list.h"
11 #include "net/base/net_log.h"
12 #include "net/base/net_log_unittest.h"
13 #include "net/base/mock_host_resolver.h"
14 #include "net/base/sys_addrinfo.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/base/winsock_init.h"
17 #include "net/socket/client_socket_factory.h"
18 #include "net/socket/socket_test_util.h"
19 #include "net/socket/tcp_client_socket.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 #include "testing/platform_test.h"
22
23 //-----------------------------------------------------------------------------
24
25 namespace net {
26
27 namespace {
28
29 // Base class to test SOCKS5ClientSocket
class SOCKS5ClientSocketTest : public PlatformTest {
 public:
  SOCKS5ClientSocketTest();
  // Creates a SOCKS5ClientSocket on top of a MockTCPClientSocket.
  //
  // The mock transport is backed by a StaticSocketDataProvider built from
  // |reads| / |writes| (kept alive in |data_|) and is connected before this
  // returns; the SOCKS5 handshake itself is not started. The returned socket
  // targets |hostname|:|port|. |net_log| (tests pass NULL or &net_log_) is
  // handed to the mock transport. The tests store the result in |user_sock_|.
  SOCKS5ClientSocket* BuildMockSocket(MockRead reads[],
                                      size_t reads_count,
                                      MockWrite writes[],
                                      size_t writes_count,
                                      const std::string& hostname,
                                      int port,
                                      NetLog* net_log);

  // Runs before every test; resolves the proxy address into |address_list_|.
  virtual void SetUp();

 protected:
  // Port 80 in network byte order (initialized with htons(80)).
  const uint16 kNwPort;
  // Unbounded capturing log. Nothing in this file clears it, so entries from
  // several BuildMockSocket() calls within one test accumulate here.
  CapturingNetLog net_log_;
  // SOCKS5 socket under test (the result of BuildMockSocket()).
  scoped_ptr<SOCKS5ClientSocket> user_sock_;
  // Filled in by SetUp(); used as the mock transport's address list.
  AddressList address_list_;
  // Transport created by the most recent BuildMockSocket() call. The fixture
  // never deletes it; presumably it is owned by the SOCKS5ClientSocket it is
  // passed to -- TODO confirm against SOCKS5ClientSocket's constructor.
  ClientSocket* tcp_sock_;
  // Completion callback shared by the tests' asynchronous operations.
  TestCompletionCallback callback_;
  // Resolver used by SetUp() to fill |address_list_|.
  scoped_ptr<MockHostResolver> host_resolver_;
  // Canned reads/writes backing |tcp_sock_|; replaced on every
  // BuildMockSocket() call.
  scoped_ptr<SocketDataProvider> data_;

 private:
  DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest);
};
57
SOCKS5ClientSocketTest()58 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
59 : kNwPort(htons(80)),
60 net_log_(CapturingNetLog::kUnbounded),
61 host_resolver_(new MockHostResolver) {
62 }
63
// Runs before every test: sets up the platform, then resolves the proxy
// address list that BuildMockSocket() gives to the mock TCP socket.
SetUp()65 void SOCKS5ClientSocketTest::SetUp() {
66 PlatformTest::SetUp();
67
68 // Resolve the "localhost" AddressList used by the TCP connection to connect.
69 HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
70 int rv = host_resolver_->Resolve(info, &address_list_, NULL, NULL,
71 BoundNetLog());
72 ASSERT_EQ(OK, rv);
73 }
74
BuildMockSocket(MockRead reads[],size_t reads_count,MockWrite writes[],size_t writes_count,const std::string & hostname,int port,NetLog * net_log)75 SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
76 MockRead reads[],
77 size_t reads_count,
78 MockWrite writes[],
79 size_t writes_count,
80 const std::string& hostname,
81 int port,
82 NetLog* net_log) {
83 TestCompletionCallback callback;
84 data_.reset(new StaticSocketDataProvider(reads, reads_count,
85 writes, writes_count));
86 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
87
88 int rv = tcp_sock_->Connect(&callback);
89 EXPECT_EQ(ERR_IO_PENDING, rv);
90 rv = callback.WaitForResult();
91 EXPECT_EQ(OK, rv);
92 EXPECT_TRUE(tcp_sock_->IsConnected());
93
94 return new SOCKS5ClientSocket(tcp_sock_,
95 HostResolver::RequestInfo(HostPortPair(hostname, port)));
96 }
97
// Tests a complete SOCKS5 handshake, a write and a read through the
// established tunnel, and the disconnection.
TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
  const std::string write_payload = "random data";
  const std::string read_payload = "moar random data";

  // CONNECT request the client must send for localhost:80.
  const char kExpectedConnectRequest[] = {
    0x05,  // Version
    0x01,  // Command (CONNECT)
    0x00,  // Reserved.
    0x03,  // Address type (DOMAINNAME).
    0x09,  // Length of domain (9)
    // Domain string:
    'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
    0x00, 0x50,  // 16-bit port (80)
  };

  MockWrite data_writes[] = {
    MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    MockWrite(true, kExpectedConnectRequest,
              arraysize(kExpectedConnectRequest)),
    MockWrite(true, write_payload.data(), write_payload.size())
  };
  MockRead data_reads[] = {
    MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
    MockRead(true, read_payload.data(), read_payload.size())
  };

  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
                                   data_writes, arraysize(data_writes),
                                   "localhost", 80, &net_log_));

  // The transport is up, but the SOCKS5 handshake has not started yet.
  EXPECT_TRUE(tcp_sock_->IsConnected());
  EXPECT_FALSE(user_sock_->IsConnected());

  // Start the handshake; it is expected to complete asynchronously.
  EXPECT_EQ(ERR_IO_PENDING, user_sock_->Connect(&callback_));
  EXPECT_FALSE(user_sock_->IsConnected());

  CapturingNetLog::EntryList entries_while_pending;
  net_log_.GetEntries(&entries_while_pending);
  EXPECT_TRUE(LogContainsBeginEvent(entries_while_pending, 0,
                                    NetLog::TYPE_SOCKS5_CONNECT));

  EXPECT_EQ(OK, callback_.WaitForResult());
  EXPECT_TRUE(user_sock_->IsConnected());

  CapturingNetLog::EntryList entries_after_connect;
  net_log_.GetEntries(&entries_after_connect);
  EXPECT_TRUE(LogContainsEndEvent(entries_after_connect, -1,
                                  NetLog::TYPE_SOCKS5_CONNECT));

  // Push the write payload through the established tunnel.
  const int write_size = static_cast<int>(write_payload.size());
  scoped_refptr<IOBuffer> write_buffer(new IOBuffer(write_size));
  std::copy(write_payload.begin(), write_payload.end(), write_buffer->data());
  EXPECT_EQ(ERR_IO_PENDING,
            user_sock_->Write(write_buffer, write_size, &callback_));
  EXPECT_EQ(write_size, callback_.WaitForResult());

  // Pull the read payload back out of the tunnel.
  const int read_size = static_cast<int>(read_payload.size());
  scoped_refptr<IOBuffer> read_buffer(new IOBuffer(read_size));
  EXPECT_EQ(ERR_IO_PENDING,
            user_sock_->Read(read_buffer, read_size, &callback_));
  EXPECT_EQ(read_size, callback_.WaitForResult());
  EXPECT_EQ(read_payload, std::string(read_buffer->data(), read_size));

  // Disconnecting the SOCKS5 socket also disconnects the transport.
  user_sock_->Disconnect();
  EXPECT_FALSE(tcp_sock_->IsConnected());
  EXPECT_FALSE(user_sock_->IsConnected());
}
167
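// Tests connecting and disconnecting with synchronous mock I/O, twice in a
// row. Each iteration builds a fresh SOCKS5ClientSocket, so Connect() is never
// called again on a socket that has already been disconnected.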
TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
  const std::string hostname = "my-host-name";
  const char kSOCKS5DomainRequest[] = {
    0x05,  // VER
    0x01,  // CMD
    0x00,  // RSV
    0x03,  // ATYPE
  };

  // Expected CONNECT request: the fixed header above, a one-byte host-name
  // length, the host name itself, then |kNwPort| (network byte order).
  std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest));
  request += static_cast<char>(hostname.size());
  request += hostname;
  request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));

  // All mock I/O below is synchronous (first argument false), so Connect() is
  // expected to finish inline. Note: each pass builds a brand-new socket.
  for (int attempt = 0; attempt < 2; ++attempt) {
    MockWrite data_writes[] = {
      MockWrite(false, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
      MockWrite(false, request.data(), request.size())
    };
    MockRead data_reads[] = {
      MockRead(false, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
      MockRead(false, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
    };

    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
                                     data_writes, arraysize(data_writes),
                                     hostname, 80, NULL));

    EXPECT_EQ(OK, user_sock_->Connect(&callback_));
    EXPECT_TRUE(user_sock_->IsConnected());

    user_sock_->Disconnect();
    EXPECT_FALSE(user_sock_->IsConnected());
  }
}
205
// Test that we fail trying to connect to a hostname longer than 255 bytes.
TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
  // The CONNECT request carries the host-name length in a single byte, so a
  // 256-character name (all 'x') is one byte too long.
  const size_t kTooLongHostNameLength = 256;
  const std::string large_host_name(kTooLongHostNameLength, 'x');

  // Mock transport with a single empty read and write entry.
  MockWrite data_writes[] = { MockWrite() };
  MockRead data_reads[] = { MockRead() };
  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
                                   data_writes, arraysize(data_writes),
                                   large_host_name, 80, NULL));

  // Connect() is expected to fail immediately because of the over-long name;
  // the intent is that nothing is read from or written to the transport.
  TestCompletionCallback connect_callback;
  EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED,
            user_sock_->Connect(&connect_callback));
}
225
TEST_F(SOCKS5ClientSocketTest,PartialReadWrites)226 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
227 const std::string hostname = "www.google.com";
228
229 const char kOkRequest[] = {
230 0x05, // Version
231 0x01, // Command (CONNECT)
232 0x00, // Reserved.
233 0x03, // Address type (DOMAINNAME).
234 0x0E, // Length of domain (14)
235 // Domain string:
236 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
237 0x00, 0x50, // 16-bit port (80)
238 };
239
240 // Test for partial greet request write
241 {
242 const char partial1[] = { 0x05, 0x01 };
243 const char partial2[] = { 0x00 };
244 MockWrite data_writes[] = {
245 MockWrite(true, arraysize(partial1)),
246 MockWrite(true, partial2, arraysize(partial2)),
247 MockWrite(true, kOkRequest, arraysize(kOkRequest)) };
248 MockRead data_reads[] = {
249 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
250 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
251 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
252 data_writes, arraysize(data_writes),
253 hostname, 80, &net_log_));
254 int rv = user_sock_->Connect(&callback_);
255 EXPECT_EQ(ERR_IO_PENDING, rv);
256
257 net::CapturingNetLog::EntryList net_log_entries;
258 net_log_.GetEntries(&net_log_entries);
259 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
260 NetLog::TYPE_SOCKS5_CONNECT));
261
262 rv = callback_.WaitForResult();
263 EXPECT_EQ(OK, rv);
264 EXPECT_TRUE(user_sock_->IsConnected());
265
266 net_log_.GetEntries(&net_log_entries);
267 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
268 NetLog::TYPE_SOCKS5_CONNECT));
269 }
270
271 // Test for partial greet response read
272 {
273 const char partial1[] = { 0x05 };
274 const char partial2[] = { 0x00 };
275 MockWrite data_writes[] = {
276 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
277 MockWrite(true, kOkRequest, arraysize(kOkRequest)) };
278 MockRead data_reads[] = {
279 MockRead(true, partial1, arraysize(partial1)),
280 MockRead(true, partial2, arraysize(partial2)),
281 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
282 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
283 data_writes, arraysize(data_writes),
284 hostname, 80, &net_log_));
285 int rv = user_sock_->Connect(&callback_);
286 EXPECT_EQ(ERR_IO_PENDING, rv);
287
288 net::CapturingNetLog::EntryList net_log_entries;
289 net_log_.GetEntries(&net_log_entries);
290 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
291 NetLog::TYPE_SOCKS5_CONNECT));
292 rv = callback_.WaitForResult();
293 EXPECT_EQ(OK, rv);
294 EXPECT_TRUE(user_sock_->IsConnected());
295 net_log_.GetEntries(&net_log_entries);
296 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
297 NetLog::TYPE_SOCKS5_CONNECT));
298 }
299
300 // Test for partial handshake request write.
301 {
302 const int kSplitPoint = 3; // Break handshake write into two parts.
303 MockWrite data_writes[] = {
304 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
305 MockWrite(true, kOkRequest, kSplitPoint),
306 MockWrite(true, kOkRequest + kSplitPoint,
307 arraysize(kOkRequest) - kSplitPoint)
308 };
309 MockRead data_reads[] = {
310 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
311 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
312 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
313 data_writes, arraysize(data_writes),
314 hostname, 80, &net_log_));
315 int rv = user_sock_->Connect(&callback_);
316 EXPECT_EQ(ERR_IO_PENDING, rv);
317 net::CapturingNetLog::EntryList net_log_entries;
318 net_log_.GetEntries(&net_log_entries);
319 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
320 NetLog::TYPE_SOCKS5_CONNECT));
321 rv = callback_.WaitForResult();
322 EXPECT_EQ(OK, rv);
323 EXPECT_TRUE(user_sock_->IsConnected());
324 net_log_.GetEntries(&net_log_entries);
325 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
326 NetLog::TYPE_SOCKS5_CONNECT));
327 }
328
329 // Test for partial handshake response read
330 {
331 const int kSplitPoint = 6; // Break the handshake read into two parts.
332 MockWrite data_writes[] = {
333 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
334 MockWrite(true, kOkRequest, arraysize(kOkRequest))
335 };
336 MockRead data_reads[] = {
337 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
338 MockRead(true, kSOCKS5OkResponse, kSplitPoint),
339 MockRead(true, kSOCKS5OkResponse + kSplitPoint,
340 kSOCKS5OkResponseLength - kSplitPoint)
341 };
342
343 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
344 data_writes, arraysize(data_writes),
345 hostname, 80, &net_log_));
346 int rv = user_sock_->Connect(&callback_);
347 EXPECT_EQ(ERR_IO_PENDING, rv);
348 net::CapturingNetLog::EntryList net_log_entries;
349 net_log_.GetEntries(&net_log_entries);
350 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
351 NetLog::TYPE_SOCKS5_CONNECT));
352 rv = callback_.WaitForResult();
353 EXPECT_EQ(OK, rv);
354 EXPECT_TRUE(user_sock_->IsConnected());
355 net_log_.GetEntries(&net_log_entries);
356 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
357 NetLog::TYPE_SOCKS5_CONNECT));
358 }
359 }
360
361 } // namespace
362
363 } // namespace net
364