1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/socket/socks5_client_socket.h"
11
12 #include <algorithm>
13 #include <iterator>
14 #include <map>
15 #include <memory>
16 #include <utility>
17
18 #include "base/containers/span.h"
19 #include "base/memory/ptr_util.h"
20 #include "base/memory/raw_ptr.h"
21 #include "base/sys_byteorder.h"
22 #include "build/build_config.h"
23 #include "net/base/address_list.h"
24 #include "net/base/test_completion_callback.h"
25 #include "net/base/winsock_init.h"
26 #include "net/log/net_log_event_type.h"
27 #include "net/log/test_net_log.h"
28 #include "net/log/test_net_log_util.h"
29 #include "net/socket/client_socket_factory.h"
30 #include "net/socket/socket_test_util.h"
31 #include "net/socket/tcp_client_socket.h"
32 #include "net/test/gtest_util.h"
33 #include "net/test/test_with_task_environment.h"
34 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
35 #include "testing/gmock/include/gmock/gmock.h"
36 #include "testing/gtest/include/gtest/gtest.h"
37 #include "testing/platform_test.h"
38
39 using net::test::IsError;
40 using net::test::IsOk;
41
42 //-----------------------------------------------------------------------------
43
44 namespace net {
45
46 class NetLog;
47
48 namespace {
49
// Test fixture for SOCKS5ClientSocket. Each test builds a SOCKS5ClientSocket
// on top of a MockTCPClientSocket whose traffic is scripted with
// MockRead/MockWrite arrays.
class SOCKS5ClientSocketTest : public PlatformTest, public WithTaskEnvironment {
 public:
  SOCKS5ClientSocketTest();

  SOCKS5ClientSocketTest(const SOCKS5ClientSocketTest&) = delete;
  SOCKS5ClientSocketTest& operator=(const SOCKS5ClientSocketTest&) = delete;

  // Creates a SOCKS5ClientSocket targeting |hostname|:|port| on top of a
  // MockTCPClientSocket that replays |reads| and |writes|. The transport is
  // connected before this returns. |data_| and |tcp_sock_| are updated to
  // refer to the new data provider and transport. |net_log| is handed to the
  // transport; tests here pass either NetLog::Get() or nullptr.
  std::unique_ptr<SOCKS5ClientSocket> BuildMockSocket(
      base::span<const MockRead> reads,
      base::span<const MockWrite> writes,
      const std::string& hostname,
      int port,
      NetLog* net_log);

  // Sets |address_list_| to the IPv4 localhost address, port 1080.
  void SetUp() override;

 protected:
  // Port 80 in network byte order (set by the constructor). Used when tests
  // assemble raw SOCKS5 request bytes.
  const uint16_t kNwPort;
  // Presumably records entries logged to NetLog::Get(): tests read it after
  // passing NetLog::Get() to BuildMockSocket(). TODO confirm.
  RecordingNetLogObserver net_log_observer_;
  // The socket under test.
  std::unique_ptr<SOCKS5ClientSocket> user_sock_;
  // Endpoint the mock transport connects to; filled in by SetUp().
  AddressList address_list_;
  // Filled in by BuildMockSocket() and owned by its return value
  // (which |user_sock_| is set to).
  raw_ptr<StreamSocket> tcp_sock_;
  // Completion callback shared by tests for Connect(), Read() and Write().
  TestCompletionCallback callback_;
  // Data provider backing the transport from the most recent
  // BuildMockSocket() call.
  // NOTE(review): this member is declared after |user_sock_|, so it is
  // destroyed first even though that socket's transport was given
  // data_.get(). Presumably the mock classes tolerate either teardown order;
  // confirm before reordering members.
  std::unique_ptr<SocketDataProvider> data_;
};
79
SOCKS5ClientSocketTest()80 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
81 : kNwPort(base::HostToNet16(80)) {}
82
83 // Set up platform before every test case
SetUp()84 void SOCKS5ClientSocketTest::SetUp() {
85 PlatformTest::SetUp();
86
87 // Create the "localhost" AddressList used by the TCP connection to connect.
88 address_list_ =
89 AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), 1080);
90 }
91
BuildMockSocket(base::span<const MockRead> reads,base::span<const MockWrite> writes,const std::string & hostname,int port,NetLog * net_log)92 std::unique_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
93 base::span<const MockRead> reads,
94 base::span<const MockWrite> writes,
95 const std::string& hostname,
96 int port,
97 NetLog* net_log) {
98 TestCompletionCallback callback;
99 data_ = std::make_unique<StaticSocketDataProvider>(reads, writes);
100 auto tcp_sock = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
101 data_.get());
102 tcp_sock_ = tcp_sock.get();
103
104 int rv = tcp_sock_->Connect(callback.callback());
105 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
106 rv = callback.WaitForResult();
107 EXPECT_THAT(rv, IsOk());
108 EXPECT_TRUE(tcp_sock_->IsConnected());
109
110 // The SOCKS5ClientSocket takes ownership of |tcp_sock_|, but keep a
111 // non-owning pointer to it.
112 return std::make_unique<SOCKS5ClientSocket>(std::move(tcp_sock),
113 HostPortPair(hostname, port),
114 TRAFFIC_ANNOTATION_FOR_TESTS);
115 }
116
117 // Tests a complete SOCKS5 handshake and the disconnection.
TEST_F(SOCKS5ClientSocketTest,CompleteHandshake)118 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
119 const std::string payload_write = "random data";
120 const std::string payload_read = "moar random data";
121
122 const char kOkRequest[] = {
123 0x05, // Version
124 0x01, // Command (CONNECT)
125 0x00, // Reserved.
126 0x03, // Address type (DOMAINNAME).
127 0x09, // Length of domain (9)
128 // Domain string:
129 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
130 0x00, 0x50, // 16-bit port (80)
131 };
132
133 MockWrite data_writes[] = {
134 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
135 MockWrite(ASYNC, kOkRequest, std::size(kOkRequest)),
136 MockWrite(ASYNC, payload_write.data(), payload_write.size())};
137 MockRead data_reads[] = {
138 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
139 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
140 MockRead(ASYNC, payload_read.data(), payload_read.size()) };
141
142 user_sock_ =
143 BuildMockSocket(data_reads, data_writes, "localhost", 80, NetLog::Get());
144
145 // At this state the TCP connection is completed but not the SOCKS handshake.
146 EXPECT_TRUE(tcp_sock_->IsConnected());
147 EXPECT_FALSE(user_sock_->IsConnected());
148
149 int rv = user_sock_->Connect(callback_.callback());
150 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
151 EXPECT_FALSE(user_sock_->IsConnected());
152
153 auto net_log_entries = net_log_observer_.GetEntries();
154 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
155 NetLogEventType::SOCKS5_CONNECT));
156
157 rv = callback_.WaitForResult();
158
159 EXPECT_THAT(rv, IsOk());
160 EXPECT_TRUE(user_sock_->IsConnected());
161
162 net_log_entries = net_log_observer_.GetEntries();
163 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
164 NetLogEventType::SOCKS5_CONNECT));
165
166 auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload_write.size());
167 memcpy(buffer->data(), payload_write.data(), payload_write.size());
168 rv = user_sock_->Write(buffer.get(), payload_write.size(),
169 callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
170 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
171 rv = callback_.WaitForResult();
172 EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
173
174 buffer = base::MakeRefCounted<IOBufferWithSize>(payload_read.size());
175 rv =
176 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
177 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
178 rv = callback_.WaitForResult();
179 EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
180 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
181
182 user_sock_->Disconnect();
183 EXPECT_FALSE(tcp_sock_->IsConnected());
184 EXPECT_FALSE(user_sock_->IsConnected());
185 }
186
187 // Test that you can call Connect() again after having called Disconnect().
TEST_F(SOCKS5ClientSocketTest,ConnectAndDisconnectTwice)188 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
189 const std::string hostname = "my-host-name";
190 const char kSOCKS5DomainRequest[] = {
191 0x05, // VER
192 0x01, // CMD
193 0x00, // RSV
194 0x03, // ATYPE
195 };
196
197 std::string request(kSOCKS5DomainRequest, std::size(kSOCKS5DomainRequest));
198 request.push_back(static_cast<char>(hostname.size()));
199 request.append(hostname);
200 request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
201
202 for (int i = 0; i < 2; ++i) {
203 MockWrite data_writes[] = {
204 MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
205 MockWrite(SYNCHRONOUS, request.data(), request.size())
206 };
207 MockRead data_reads[] = {
208 MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
209 MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
210 };
211
212 user_sock_ =
213 BuildMockSocket(data_reads, data_writes, hostname, 80, nullptr);
214
215 int rv = user_sock_->Connect(callback_.callback());
216 EXPECT_THAT(rv, IsOk());
217 EXPECT_TRUE(user_sock_->IsConnected());
218
219 user_sock_->Disconnect();
220 EXPECT_FALSE(user_sock_->IsConnected());
221 }
222 }
223
224 // Test that we fail trying to connect to a hostname longer than 255 bytes.
TEST_F(SOCKS5ClientSocketTest,LargeHostNameFails)225 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
226 // Create a string of length 256, where each character is 'x'.
227 std::string large_host_name;
228 std::fill_n(std::back_inserter(large_host_name), 256, 'x');
229
230 // Create a SOCKS socket, with mock transport socket.
231 MockWrite data_writes[] = {MockWrite()};
232 MockRead data_reads[] = {MockRead()};
233 user_sock_ =
234 BuildMockSocket(data_reads, data_writes, large_host_name, 80, nullptr);
235
236 // Try to connect -- should fail (without having read/written anything to
237 // the transport socket first) because the hostname is too long.
238 TestCompletionCallback callback;
239 int rv = user_sock_->Connect(callback.callback());
240 EXPECT_THAT(rv, IsError(ERR_SOCKS_CONNECTION_FAILED));
241 }
242
// Tests that the SOCKS5 handshake still succeeds when the greeting or the
// CONNECT exchange is split across several partial transport writes or reads.
TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
  const std::string hostname = "www.google.com";

  const char kOkRequest[] = {
      0x05,  // Version
      0x01,  // Command (CONNECT)
      0x00,  // Reserved.
      0x03,  // Address type (DOMAINNAME).
      0x0E,  // Length of domain (14)
      // Domain string:
      'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
      0x00, 0x50,  // 16-bit port (80)
  };

  // Runs Connect() on the |user_sock_| built by the current sub-case. It
  // checks that Connect() goes asynchronous, completes successfully, and is
  // bracketed by SOCKS5_CONNECT begin/end net log events.
  //
  // |net_log_observer_| is shared by every sub-case, so it is cleared first.
  // Without the reset, the index-0 BEGIN check would keep matching the first
  // sub-case's entry and would verify nothing for the later sub-cases.
  auto connect_and_verify = [this]() {
    net_log_observer_.Clear();

    int rv = user_sock_->Connect(callback_.callback());
    EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
    auto net_log_entries = net_log_observer_.GetEntries();
    EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
                                      NetLogEventType::SOCKS5_CONNECT));

    rv = callback_.WaitForResult();
    EXPECT_THAT(rv, IsOk());
    EXPECT_TRUE(user_sock_->IsConnected());
    net_log_entries = net_log_observer_.GetEntries();
    EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
                                    NetLogEventType::SOCKS5_CONNECT));
  };

  // Test for partial greet request write.
  {
    SCOPED_TRACE("partial greet request write");
    const char partial1[] = {0x05, 0x01};
    const char partial2[] = {0x00};
    MockWrite data_writes[] = {
        MockWrite(ASYNC, partial1, std::size(partial1)),
        MockWrite(ASYNC, partial2, std::size(partial2)),
        MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
    MockRead data_reads[] = {
        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
    user_sock_ =
        BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
    connect_and_verify();
  }

  // Test for partial greet response read.
  {
    SCOPED_TRACE("partial greet response read");
    const char partial1[] = {0x05};
    const char partial2[] = {0x00};
    MockWrite data_writes[] = {
        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
        MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
    MockRead data_reads[] = {
        MockRead(ASYNC, partial1, std::size(partial1)),
        MockRead(ASYNC, partial2, std::size(partial2)),
        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
    user_sock_ =
        BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
    connect_and_verify();
  }

  // Test for partial handshake request write.
  {
    SCOPED_TRACE("partial handshake request write");
    const int kSplitPoint = 3;  // Break handshake write into two parts.
    MockWrite data_writes[] = {
        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
        MockWrite(ASYNC, kOkRequest, kSplitPoint),
        MockWrite(ASYNC, kOkRequest + kSplitPoint,
                  std::size(kOkRequest) - kSplitPoint)};
    MockRead data_reads[] = {
        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
    user_sock_ =
        BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
    connect_and_verify();
  }

  // Test for partial handshake response read.
  {
    SCOPED_TRACE("partial handshake response read");
    const int kSplitPoint = 6;  // Break the handshake read into two parts.
    MockWrite data_writes[] = {
        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
        MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
    MockRead data_reads[] = {
        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
        MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
        MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
                 kSOCKS5OkResponseLength - kSplitPoint)};
    user_sock_ =
        BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
    connect_and_verify();
  }
}
367
TEST_F(SOCKS5ClientSocketTest,Tag)368 TEST_F(SOCKS5ClientSocketTest, Tag) {
369 StaticSocketDataProvider data;
370 auto tagging_sock = std::make_unique<MockTaggingStreamSocket>(
371 std::make_unique<MockTCPClientSocket>(address_list_, NetLog::Get(),
372 &data));
373 auto* tagging_sock_ptr = tagging_sock.get();
374
375 // |socket| takes ownership of |tagging_sock|, but keep a non-owning pointer
376 // to it.
377 SOCKS5ClientSocket socket(std::move(tagging_sock),
378 HostPortPair("localhost", 80),
379 TRAFFIC_ANNOTATION_FOR_TESTS);
380
381 EXPECT_EQ(tagging_sock_ptr->tag(), SocketTag());
382 #if BUILDFLAG(IS_ANDROID)
383 SocketTag tag(0x12345678, 0x87654321);
384 socket.ApplySocketTag(tag);
385 EXPECT_EQ(tagging_sock_ptr->tag(), tag);
386 #endif // BUILDFLAG(IS_ANDROID)
387 }
388
389 } // namespace
390
391 } // namespace net
392