• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/socket/socks5_client_socket.h"
11 
12 #include <algorithm>
13 #include <iterator>
14 #include <map>
15 #include <memory>
16 #include <utility>
17 
18 #include "base/containers/span.h"
19 #include "base/memory/ptr_util.h"
20 #include "base/memory/raw_ptr.h"
21 #include "base/sys_byteorder.h"
22 #include "build/build_config.h"
23 #include "net/base/address_list.h"
24 #include "net/base/test_completion_callback.h"
25 #include "net/base/winsock_init.h"
26 #include "net/log/net_log_event_type.h"
27 #include "net/log/test_net_log.h"
28 #include "net/log/test_net_log_util.h"
29 #include "net/socket/client_socket_factory.h"
30 #include "net/socket/socket_test_util.h"
31 #include "net/socket/tcp_client_socket.h"
32 #include "net/test/gtest_util.h"
33 #include "net/test/test_with_task_environment.h"
34 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
35 #include "testing/gmock/include/gmock/gmock.h"
36 #include "testing/gtest/include/gtest/gtest.h"
37 #include "testing/platform_test.h"
38 
39 using net::test::IsError;
40 using net::test::IsOk;
41 
42 //-----------------------------------------------------------------------------
43 
44 namespace net {
45 
46 class NetLog;
47 
48 namespace {
49 
50 // Base class to test SOCKS5ClientSocket
51 class SOCKS5ClientSocketTest : public PlatformTest, public WithTaskEnvironment {
52  public:
53   SOCKS5ClientSocketTest();
54 
55   SOCKS5ClientSocketTest(const SOCKS5ClientSocketTest&) = delete;
56   SOCKS5ClientSocketTest& operator=(const SOCKS5ClientSocketTest&) = delete;
57 
58   // Create a SOCKSClientSocket on top of a MockSocket.
59   std::unique_ptr<SOCKS5ClientSocket> BuildMockSocket(
60       base::span<const MockRead> reads,
61       base::span<const MockWrite> writes,
62       const std::string& hostname,
63       int port,
64       NetLog* net_log);
65 
66   void SetUp() override;
67 
68  protected:
69   const uint16_t kNwPort;
70   RecordingNetLogObserver net_log_observer_;
71   std::unique_ptr<SOCKS5ClientSocket> user_sock_;
72   AddressList address_list_;
73   // Filled in by BuildMockSocket() and owned by its return value
74   // (which |user_sock| is set to).
75   raw_ptr<StreamSocket> tcp_sock_;
76   TestCompletionCallback callback_;
77   std::unique_ptr<SocketDataProvider> data_;
78 };
79 
SOCKS5ClientSocketTest()80 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
81     : kNwPort(base::HostToNet16(80)) {}
82 
83 // Set up platform before every test case
SetUp()84 void SOCKS5ClientSocketTest::SetUp() {
85   PlatformTest::SetUp();
86 
87   // Create the "localhost" AddressList used by the TCP connection to connect.
88   address_list_ =
89       AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), 1080);
90 }
91 
BuildMockSocket(base::span<const MockRead> reads,base::span<const MockWrite> writes,const std::string & hostname,int port,NetLog * net_log)92 std::unique_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
93     base::span<const MockRead> reads,
94     base::span<const MockWrite> writes,
95     const std::string& hostname,
96     int port,
97     NetLog* net_log) {
98   TestCompletionCallback callback;
99   data_ = std::make_unique<StaticSocketDataProvider>(reads, writes);
100   auto tcp_sock = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
101                                                         data_.get());
102   tcp_sock_ = tcp_sock.get();
103 
104   int rv = tcp_sock_->Connect(callback.callback());
105   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
106   rv = callback.WaitForResult();
107   EXPECT_THAT(rv, IsOk());
108   EXPECT_TRUE(tcp_sock_->IsConnected());
109 
110   // The SOCKS5ClientSocket takes ownership of |tcp_sock_|, but keep a
111   // non-owning pointer to it.
112   return std::make_unique<SOCKS5ClientSocket>(std::move(tcp_sock),
113                                               HostPortPair(hostname, port),
114                                               TRAFFIC_ANNOTATION_FOR_TESTS);
115 }
116 
117 // Tests a complete SOCKS5 handshake and the disconnection.
TEST_F(SOCKS5ClientSocketTest,CompleteHandshake)118 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
119   const std::string payload_write = "random data";
120   const std::string payload_read = "moar random data";
121 
122   const char kOkRequest[] = {
123     0x05,  // Version
124     0x01,  // Command (CONNECT)
125     0x00,  // Reserved.
126     0x03,  // Address type (DOMAINNAME).
127     0x09,  // Length of domain (9)
128     // Domain string:
129     'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
130     0x00, 0x50,  // 16-bit port (80)
131   };
132 
133   MockWrite data_writes[] = {
134       MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
135       MockWrite(ASYNC, kOkRequest, std::size(kOkRequest)),
136       MockWrite(ASYNC, payload_write.data(), payload_write.size())};
137   MockRead data_reads[] = {
138       MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
139       MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
140       MockRead(ASYNC, payload_read.data(), payload_read.size()) };
141 
142   user_sock_ =
143       BuildMockSocket(data_reads, data_writes, "localhost", 80, NetLog::Get());
144 
145   // At this state the TCP connection is completed but not the SOCKS handshake.
146   EXPECT_TRUE(tcp_sock_->IsConnected());
147   EXPECT_FALSE(user_sock_->IsConnected());
148 
149   int rv = user_sock_->Connect(callback_.callback());
150   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
151   EXPECT_FALSE(user_sock_->IsConnected());
152 
153   auto net_log_entries = net_log_observer_.GetEntries();
154   EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
155                                     NetLogEventType::SOCKS5_CONNECT));
156 
157   rv = callback_.WaitForResult();
158 
159   EXPECT_THAT(rv, IsOk());
160   EXPECT_TRUE(user_sock_->IsConnected());
161 
162   net_log_entries = net_log_observer_.GetEntries();
163   EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
164                                   NetLogEventType::SOCKS5_CONNECT));
165 
166   auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload_write.size());
167   memcpy(buffer->data(), payload_write.data(), payload_write.size());
168   rv = user_sock_->Write(buffer.get(), payload_write.size(),
169                          callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
170   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
171   rv = callback_.WaitForResult();
172   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
173 
174   buffer = base::MakeRefCounted<IOBufferWithSize>(payload_read.size());
175   rv =
176       user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
177   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
178   rv = callback_.WaitForResult();
179   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
180   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
181 
182   user_sock_->Disconnect();
183   EXPECT_FALSE(tcp_sock_->IsConnected());
184   EXPECT_FALSE(user_sock_->IsConnected());
185 }
186 
187 // Test that you can call Connect() again after having called Disconnect().
TEST_F(SOCKS5ClientSocketTest,ConnectAndDisconnectTwice)188 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
189   const std::string hostname = "my-host-name";
190   const char kSOCKS5DomainRequest[] = {
191       0x05,  // VER
192       0x01,  // CMD
193       0x00,  // RSV
194       0x03,  // ATYPE
195   };
196 
197   std::string request(kSOCKS5DomainRequest, std::size(kSOCKS5DomainRequest));
198   request.push_back(static_cast<char>(hostname.size()));
199   request.append(hostname);
200   request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
201 
202   for (int i = 0; i < 2; ++i) {
203     MockWrite data_writes[] = {
204         MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
205         MockWrite(SYNCHRONOUS, request.data(), request.size())
206     };
207     MockRead data_reads[] = {
208         MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
209         MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
210     };
211 
212     user_sock_ =
213         BuildMockSocket(data_reads, data_writes, hostname, 80, nullptr);
214 
215     int rv = user_sock_->Connect(callback_.callback());
216     EXPECT_THAT(rv, IsOk());
217     EXPECT_TRUE(user_sock_->IsConnected());
218 
219     user_sock_->Disconnect();
220     EXPECT_FALSE(user_sock_->IsConnected());
221   }
222 }
223 
224 // Test that we fail trying to connect to a hostname longer than 255 bytes.
TEST_F(SOCKS5ClientSocketTest,LargeHostNameFails)225 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
226   // Create a string of length 256, where each character is 'x'.
227   std::string large_host_name;
228   std::fill_n(std::back_inserter(large_host_name), 256, 'x');
229 
230   // Create a SOCKS socket, with mock transport socket.
231   MockWrite data_writes[] = {MockWrite()};
232   MockRead data_reads[] = {MockRead()};
233   user_sock_ =
234       BuildMockSocket(data_reads, data_writes, large_host_name, 80, nullptr);
235 
236   // Try to connect -- should fail (without having read/written anything to
237   // the transport socket first) because the hostname is too long.
238   TestCompletionCallback callback;
239   int rv = user_sock_->Connect(callback.callback());
240   EXPECT_THAT(rv, IsError(ERR_SOCKS_CONNECTION_FAILED));
241 }
242 
TEST_F(SOCKS5ClientSocketTest,PartialReadWrites)243 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
244   const std::string hostname = "www.google.com";
245 
246   const char kOkRequest[] = {
247     0x05,  // Version
248     0x01,  // Command (CONNECT)
249     0x00,  // Reserved.
250     0x03,  // Address type (DOMAINNAME).
251     0x0E,  // Length of domain (14)
252     // Domain string:
253     'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
254     0x00, 0x50,  // 16-bit port (80)
255   };
256 
257   // Test for partial greet request write
258   {
259     const char partial1[] = { 0x05, 0x01 };
260     const char partial2[] = { 0x00 };
261     MockWrite data_writes[] = {
262         MockWrite(ASYNC, partial1, std::size(partial1)),
263         MockWrite(ASYNC, partial2, std::size(partial2)),
264         MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
265     MockRead data_reads[] = {
266         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
267         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
268     user_sock_ =
269         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
270     int rv = user_sock_->Connect(callback_.callback());
271     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
272 
273     auto net_log_entries = net_log_observer_.GetEntries();
274     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
275                                       NetLogEventType::SOCKS5_CONNECT));
276 
277     rv = callback_.WaitForResult();
278     EXPECT_THAT(rv, IsOk());
279     EXPECT_TRUE(user_sock_->IsConnected());
280 
281     net_log_entries = net_log_observer_.GetEntries();
282     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
283                                     NetLogEventType::SOCKS5_CONNECT));
284   }
285 
286   // Test for partial greet response read
287   {
288     const char partial1[] = { 0x05 };
289     const char partial2[] = { 0x00 };
290     MockWrite data_writes[] = {
291         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
292         MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
293     MockRead data_reads[] = {
294         MockRead(ASYNC, partial1, std::size(partial1)),
295         MockRead(ASYNC, partial2, std::size(partial2)),
296         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
297     user_sock_ =
298         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
299     int rv = user_sock_->Connect(callback_.callback());
300     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
301 
302     auto net_log_entries = net_log_observer_.GetEntries();
303     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
304                                       NetLogEventType::SOCKS5_CONNECT));
305     rv = callback_.WaitForResult();
306     EXPECT_THAT(rv, IsOk());
307     EXPECT_TRUE(user_sock_->IsConnected());
308     net_log_entries = net_log_observer_.GetEntries();
309     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
310                                     NetLogEventType::SOCKS5_CONNECT));
311   }
312 
313   // Test for partial handshake request write.
314   {
315     const int kSplitPoint = 3;  // Break handshake write into two parts.
316     MockWrite data_writes[] = {
317         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
318         MockWrite(ASYNC, kOkRequest, kSplitPoint),
319         MockWrite(ASYNC, kOkRequest + kSplitPoint,
320                   std::size(kOkRequest) - kSplitPoint)};
321     MockRead data_reads[] = {
322         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
323         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
324     user_sock_ =
325         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
326     int rv = user_sock_->Connect(callback_.callback());
327     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
328     auto net_log_entries = net_log_observer_.GetEntries();
329     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
330                                       NetLogEventType::SOCKS5_CONNECT));
331     rv = callback_.WaitForResult();
332     EXPECT_THAT(rv, IsOk());
333     EXPECT_TRUE(user_sock_->IsConnected());
334     net_log_entries = net_log_observer_.GetEntries();
335     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
336                                     NetLogEventType::SOCKS5_CONNECT));
337   }
338 
339   // Test for partial handshake response read
340   {
341     const int kSplitPoint = 6;  // Break the handshake read into two parts.
342     MockWrite data_writes[] = {
343         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
344         MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
345     MockRead data_reads[] = {
346         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
347         MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
348         MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
349                  kSOCKS5OkResponseLength - kSplitPoint)
350     };
351 
352     user_sock_ =
353         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
354     int rv = user_sock_->Connect(callback_.callback());
355     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
356     auto net_log_entries = net_log_observer_.GetEntries();
357     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
358                                       NetLogEventType::SOCKS5_CONNECT));
359     rv = callback_.WaitForResult();
360     EXPECT_THAT(rv, IsOk());
361     EXPECT_TRUE(user_sock_->IsConnected());
362     net_log_entries = net_log_observer_.GetEntries();
363     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
364                                     NetLogEventType::SOCKS5_CONNECT));
365   }
366 }
367 
TEST_F(SOCKS5ClientSocketTest,Tag)368 TEST_F(SOCKS5ClientSocketTest, Tag) {
369   StaticSocketDataProvider data;
370   auto tagging_sock = std::make_unique<MockTaggingStreamSocket>(
371       std::make_unique<MockTCPClientSocket>(address_list_, NetLog::Get(),
372                                             &data));
373   auto* tagging_sock_ptr = tagging_sock.get();
374 
375   // |socket| takes ownership of |tagging_sock|, but keep a non-owning pointer
376   // to it.
377   SOCKS5ClientSocket socket(std::move(tagging_sock),
378                             HostPortPair("localhost", 80),
379                             TRAFFIC_ANNOTATION_FOR_TESTS);
380 
381   EXPECT_EQ(tagging_sock_ptr->tag(), SocketTag());
382 #if BUILDFLAG(IS_ANDROID)
383   SocketTag tag(0x12345678, 0x87654321);
384   socket.ApplySocketTag(tag);
385   EXPECT_EQ(tagging_sock_ptr->tag(), tag);
386 #endif  // BUILDFLAG(IS_ANDROID)
387 }
388 
389 }  // namespace
390 
391 }  // namespace net
392