• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/tcp_server_socket.h"
6 
7 #include <string>
8 #include <vector>
9 
10 #include "base/compiler_specific.h"
11 #include "base/memory/ref_counted.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/base/ip_endpoint.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/test_completion_callback.h"
18 #include "net/socket/tcp_client_socket.h"
19 #include "testing/gtest/include/gtest/gtest.h"
20 #include "testing/platform_test.h"
21 
22 namespace net {
23 
24 namespace {
25 const int kListenBacklog = 5;
26 
27 class TCPServerSocketTest : public PlatformTest {
28  protected:
TCPServerSocketTest()29   TCPServerSocketTest()
30       : socket_(NULL, NetLog::Source()) {
31   }
32 
SetUpIPv4()33   void SetUpIPv4() {
34     IPEndPoint address;
35     ParseAddress("127.0.0.1", 0, &address);
36     ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog));
37     ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
38   }
39 
SetUpIPv6(bool * success)40   void SetUpIPv6(bool* success) {
41     *success = false;
42     IPEndPoint address;
43     ParseAddress("::1", 0, &address);
44     if (socket_.Listen(address, kListenBacklog) != 0) {
45       LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
46           "disabled. Skipping the test";
47       return;
48     }
49     ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
50     *success = true;
51   }
52 
ParseAddress(std::string ip_str,int port,IPEndPoint * address)53   void ParseAddress(std::string ip_str, int port, IPEndPoint* address) {
54     IPAddressNumber ip_number;
55     bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
56     if (!rv)
57       return;
58     *address = IPEndPoint(ip_number, port);
59   }
60 
GetPeerAddress(StreamSocket * socket)61   static IPEndPoint GetPeerAddress(StreamSocket* socket) {
62     IPEndPoint address;
63     EXPECT_EQ(OK, socket->GetPeerAddress(&address));
64     return address;
65   }
66 
local_address_list() const67   AddressList local_address_list() const {
68     return AddressList(local_address_);
69   }
70 
71   TCPServerSocket socket_;
72   IPEndPoint local_address_;
73 };
74 
// Connects a client first, then accepts it; Accept() may complete either
// synchronously or via its callback, and both paths are handled.
TEST_F(TCPServerSocketTest, Accept) {
  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(),
                                    NULL, NetLog::Source());
  // Keep the return value: the completion callback only runs when Connect()
  // returns ERR_IO_PENDING, so an unconditional WaitForResult() would wait
  // forever if the connect completed synchronously.
  int connect_result = connecting_socket.Connect(connect_callback.callback());

  TestCompletionCallback accept_callback;
  scoped_ptr<StreamSocket> accepted_socket;
  int result = socket_.Accept(&accepted_socket, accept_callback.callback());
  if (result == ERR_IO_PENDING)
    result = accept_callback.WaitForResult();
  ASSERT_EQ(OK, result);

  ASSERT_TRUE(accepted_socket.get() != NULL);

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
            local_address_.address());

  EXPECT_EQ(OK, connect_callback.GetResult(connect_result));
}
98 
99 // Test Accept() callback.
TEST_F(TCPServerSocketTest,AcceptAsync)100 TEST_F(TCPServerSocketTest, AcceptAsync) {
101   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
102 
103   TestCompletionCallback accept_callback;
104   scoped_ptr<StreamSocket> accepted_socket;
105 
106   ASSERT_EQ(ERR_IO_PENDING,
107             socket_.Accept(&accepted_socket, accept_callback.callback()));
108 
109   TestCompletionCallback connect_callback;
110   TCPClientSocket connecting_socket(local_address_list(),
111                                     NULL, NetLog::Source());
112   connecting_socket.Connect(connect_callback.callback());
113 
114   EXPECT_EQ(OK, connect_callback.WaitForResult());
115   EXPECT_EQ(OK, accept_callback.WaitForResult());
116 
117   EXPECT_TRUE(accepted_socket != NULL);
118 
119   // Both sockets should be on the loopback network interface.
120   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
121             local_address_.address());
122 }
123 
124 // Accept two connections simultaneously.
TEST_F(TCPServerSocketTest,Accept2Connections)125 TEST_F(TCPServerSocketTest, Accept2Connections) {
126   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
127 
128   TestCompletionCallback accept_callback;
129   scoped_ptr<StreamSocket> accepted_socket;
130 
131   ASSERT_EQ(ERR_IO_PENDING,
132             socket_.Accept(&accepted_socket, accept_callback.callback()));
133 
134   TestCompletionCallback connect_callback;
135   TCPClientSocket connecting_socket(local_address_list(),
136                                     NULL, NetLog::Source());
137   connecting_socket.Connect(connect_callback.callback());
138 
139   TestCompletionCallback connect_callback2;
140   TCPClientSocket connecting_socket2(local_address_list(),
141                                      NULL, NetLog::Source());
142   connecting_socket2.Connect(connect_callback2.callback());
143 
144   EXPECT_EQ(OK, accept_callback.WaitForResult());
145 
146   TestCompletionCallback accept_callback2;
147   scoped_ptr<StreamSocket> accepted_socket2;
148   int result = socket_.Accept(&accepted_socket2, accept_callback2.callback());
149   if (result == ERR_IO_PENDING)
150     result = accept_callback2.WaitForResult();
151   ASSERT_EQ(OK, result);
152 
153   EXPECT_EQ(OK, connect_callback.WaitForResult());
154 
155   EXPECT_TRUE(accepted_socket != NULL);
156   EXPECT_TRUE(accepted_socket2 != NULL);
157   EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
158 
159   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
160             local_address_.address());
161   EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
162             local_address_.address());
163 }
164 
// Same as Accept, but over ::1. Silently skipped when SetUpIPv6() cannot
// listen on the IPv6 loopback address.
TEST_F(TCPServerSocketTest, AcceptIPv6) {
  bool initialized = false;
  ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized));
  if (!initialized)
    return;

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(),
                                    NULL, NetLog::Source());
  // Keep the result so a synchronous connect does not leave WaitForResult()
  // waiting for a callback that never runs.
  int connect_result = connecting_socket.Connect(connect_callback.callback());

  TestCompletionCallback accept_callback;
  scoped_ptr<StreamSocket> accepted_socket;
  int result = socket_.Accept(&accepted_socket, accept_callback.callback());
  if (result == ERR_IO_PENDING)
    result = accept_callback.WaitForResult();
  ASSERT_EQ(OK, result);

  ASSERT_TRUE(accepted_socket.get() != NULL);

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
            local_address_.address());

  EXPECT_EQ(OK, connect_callback.GetResult(connect_result));
}
191 
// Accepts a connection, writes a message from the accepted socket and checks
// that the client socket reads back exactly the same bytes.
TEST_F(TCPServerSocketTest, AcceptIO) {
  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(),
                                    NULL, NetLog::Source());
  // GetResult() below handles both synchronous and asynchronous completion.
  int connect_result = connecting_socket.Connect(connect_callback.callback());

  TestCompletionCallback accept_callback;
  scoped_ptr<StreamSocket> accepted_socket;
  int result = socket_.Accept(&accepted_socket, accept_callback.callback());
  ASSERT_EQ(OK, accept_callback.GetResult(result));

  ASSERT_TRUE(accepted_socket.get() != NULL);

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
            local_address_.address());

  EXPECT_EQ(OK, connect_callback.GetResult(connect_result));

  const std::string message("test message");
  std::vector<char> buffer(message.size());

  // Write() may consume fewer bytes than offered, so keep sending the unsent
  // tail of |message| until all of it has been written.
  size_t bytes_written = 0;
  while (bytes_written < message.size()) {
    const size_t remaining = message.size() - bytes_written;
    scoped_refptr<net::IOBufferWithSize> write_buffer(
        new net::IOBufferWithSize(remaining));
    // Copy only the unsent tail. Copying message.size() bytes from the start
    // would overflow |write_buffer| and resend the prefix after a partial
    // write.
    memmove(write_buffer->data(), message.data() + bytes_written, remaining);

    TestCompletionCallback write_callback;
    int write_result = accepted_socket->Write(
        write_buffer.get(), write_buffer->size(), write_callback.callback());
    write_result = write_callback.GetResult(write_result);
    // A zero-byte write would never make progress; fail instead of spinning.
    ASSERT_GT(write_result, 0);
    const size_t written = static_cast<size_t>(write_result);
    ASSERT_LE(bytes_written + written, message.size());
    bytes_written += written;
  }

  // Read() may return fewer bytes than requested, so accumulate into |buffer|
  // until the whole message has arrived.
  size_t bytes_read = 0;
  while (bytes_read < message.size()) {
    scoped_refptr<net::IOBufferWithSize> read_buffer(
        new net::IOBufferWithSize(message.size() - bytes_read));
    TestCompletionCallback read_callback;
    int read_result = connecting_socket.Read(
        read_buffer.get(), read_buffer->size(), read_callback.callback());
    read_result = read_callback.GetResult(read_result);
    // A 0 result (connection closed before the full message arrived) would
    // otherwise make this loop spin forever.
    ASSERT_GT(read_result, 0);
    const size_t received = static_cast<size_t>(read_result);
    ASSERT_LE(bytes_read + received, message.size());
    memmove(&buffer[bytes_read], read_buffer->data(), received);
    bytes_read += received;
  }

  std::string received_message(buffer.begin(), buffer.end());
  ASSERT_EQ(message, received_message);
}
248 
249 }  // namespace
250 
251 }  // namespace net
252