1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/udp/udp_client_socket.h"
6 #include "net/udp/udp_server_socket.h"
7
8 #include "base/basictypes.h"
9 #include "base/metrics/histogram.h"
10 #include "net/base/io_buffer.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/net_test_suite.h"
14 #include "net/base/net_util.h"
15 #include "net/base/sys_addrinfo.h"
16 #include "net/base/test_completion_callback.h"
17 #include "testing/gtest/include/gtest/gtest.h"
18 #include "testing/platform_test.h"
19
20 namespace net {
21
22 namespace {
23
24 class UDPSocketTest : public PlatformTest {
25 public:
UDPSocketTest()26 UDPSocketTest()
27 : buffer_(new IOBufferWithSize(kMaxRead)) {
28 }
29
30 // Blocks until data is read from the socket.
RecvFromSocket(UDPServerSocket * socket)31 std::string RecvFromSocket(UDPServerSocket* socket) {
32 TestCompletionCallback callback;
33
34 int rv = socket->RecvFrom(buffer_, kMaxRead, &recv_from_address_,
35 &callback);
36 if (rv == ERR_IO_PENDING)
37 rv = callback.WaitForResult();
38 if (rv < 0)
39 return ""; // error!
40 return std::string(buffer_->data(), rv);
41 }
42
43 // Loop until |msg| has been written to the socket or until an
44 // error occurs.
45 // If |address| is specified, then it is used for the destination
46 // to send to. Otherwise, will send to the last socket this server
47 // received from.
SendToSocket(UDPServerSocket * socket,std::string msg)48 int SendToSocket(UDPServerSocket* socket, std::string msg) {
49 return SendToSocket(socket, msg, recv_from_address_);
50 }
51
SendToSocket(UDPServerSocket * socket,std::string msg,const IPEndPoint & address)52 int SendToSocket(UDPServerSocket* socket,
53 std::string msg,
54 const IPEndPoint& address) {
55 TestCompletionCallback callback;
56
57 int length = msg.length();
58 scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
59 scoped_refptr<DrainableIOBuffer> buffer(
60 new DrainableIOBuffer(io_buffer, length));
61
62 int bytes_sent = 0;
63 while (buffer->BytesRemaining()) {
64 int rv = socket->SendTo(buffer, buffer->BytesRemaining(),
65 address, &callback);
66 if (rv == ERR_IO_PENDING)
67 rv = callback.WaitForResult();
68 if (rv <= 0)
69 return bytes_sent > 0 ? bytes_sent : rv;
70 bytes_sent += rv;
71 buffer->DidConsume(rv);
72 }
73 return bytes_sent;
74 }
75
ReadSocket(UDPClientSocket * socket)76 std::string ReadSocket(UDPClientSocket* socket) {
77 TestCompletionCallback callback;
78
79 int rv = socket->Read(buffer_, kMaxRead, &callback);
80 if (rv == ERR_IO_PENDING)
81 rv = callback.WaitForResult();
82 if (rv < 0)
83 return ""; // error!
84 return std::string(buffer_->data(), rv);
85 }
86
87 // Loop until |msg| has been written to the socket or until an
88 // error occurs.
WriteSocket(UDPClientSocket * socket,std::string msg)89 int WriteSocket(UDPClientSocket* socket, std::string msg) {
90 TestCompletionCallback callback;
91
92 int length = msg.length();
93 scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
94 scoped_refptr<DrainableIOBuffer> buffer(
95 new DrainableIOBuffer(io_buffer, length));
96
97 int bytes_sent = 0;
98 while (buffer->BytesRemaining()) {
99 int rv = socket->Write(buffer, buffer->BytesRemaining(), &callback);
100 if (rv == ERR_IO_PENDING)
101 rv = callback.WaitForResult();
102 if (rv <= 0)
103 return bytes_sent > 0 ? bytes_sent : rv;
104 bytes_sent += rv;
105 buffer->DidConsume(rv);
106 }
107 return bytes_sent;
108 }
109
110 protected:
111 static const int kMaxRead = 1024;
112 scoped_refptr<IOBufferWithSize> buffer_;
113 IPEndPoint recv_from_address_;
114 };
115
116 // Creates and address from an ip/port and returns it in |address|.
CreateUDPAddress(std::string ip_str,int port,IPEndPoint * address)117 void CreateUDPAddress(std::string ip_str, int port, IPEndPoint* address) {
118 IPAddressNumber ip_number;
119 bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
120 if (!rv)
121 return;
122 *address = IPEndPoint(ip_number, port);
123 }
124
// Round trip between a connected client socket and a listening server
// socket: the client writes a message, the server receives it and echoes it
// back to the address it received from, and the client reads the echo.
TEST_F(UDPSocketTest, Connect) {
  const int kPort = 9999;
  std::string simple_message("hello world!");

  // Setup the server to listen.
  IPEndPoint bind_address;
  CreateUDPAddress("0.0.0.0", kPort, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  // Nothing below can succeed without a listening server, so stop here.
  ASSERT_EQ(OK, rv);

  // Setup the client.
  IPEndPoint server_address;
  CreateUDPAddress("127.0.0.1", kPort, &server_address);
  UDPClientSocket client(NULL, NetLog::Source());
  rv = client.Connect(server_address);
  ASSERT_EQ(OK, rv);

  // Client sends to the server.
  rv = WriteSocket(&client, simple_message);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Server waits for message.
  // Use a gtest expectation rather than DCHECK so that the payload is also
  // verified in builds where DCHECK is compiled out, and so that a mismatch
  // is reported as a test failure.
  std::string str = RecvFromSocket(&server);
  EXPECT_EQ(simple_message, str);

  // Server echoes reply.
  rv = SendToSocket(&server, simple_message);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Client waits for response.
  str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);
}
159
// In this test, we verify that connect() on a socket has the effect of
// filtering reads on this socket to only the data sent from the destination
// we connected to.
//
// The reason for this test is that some documentation indicates that
// connect() binds the client's sends to a particular server endpoint, but
// does not restrict the client's reads to that endpoint, and that we would
// therefore always need to use recvfrom() to disambiguate.
TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
  const int kPort1 = 9999;
  const int kPort2 = 10000;
  std::string simple_message("hello world!");
  std::string foreign_message("BAD MESSAGE TO GET!!");

  // Setup the first server to listen.
  IPEndPoint bind_address;
  CreateUDPAddress("0.0.0.0", kPort1, &bind_address);
  UDPServerSocket server1(NULL, NetLog::Source());
  int rv = server1.Listen(bind_address);
  // The rest of the test depends on both servers and the client being set
  // up, so setup failures are fatal.
  ASSERT_EQ(OK, rv);

  // Setup the second server to listen.
  CreateUDPAddress("0.0.0.0", kPort2, &bind_address);
  UDPServerSocket server2(NULL, NetLog::Source());
  rv = server2.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  // Setup the client, connected to server 1.
  IPEndPoint server_address;
  CreateUDPAddress("127.0.0.1", kPort1, &server_address);
  UDPClientSocket client(NULL, NetLog::Source());
  rv = client.Connect(server_address);
  ASSERT_EQ(OK, rv);

  // Client sends to server1.
  rv = WriteSocket(&client, simple_message);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Server1 waits for message.
  std::string str = RecvFromSocket(&server1);
  EXPECT_EQ(simple_message, str);

  // Get the client's address.
  IPEndPoint client_address;
  rv = client.GetLocalAddress(&client_address);
  EXPECT_EQ(OK, rv);

  // Server2 sends reply. The client is connected to server1, so this
  // message is the one the client is expected NOT to read.
  rv = SendToSocket(&server2, foreign_message,
                    client_address);
  EXPECT_EQ(foreign_message.length(), static_cast<size_t>(rv));

  // Server1 sends reply.
  rv = SendToSocket(&server1, simple_message,
                    client_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Client waits for response.
  // This is the check the test exists for: the first thing the client reads
  // must be server1's message, not the foreign one sent earlier. It is a
  // gtest expectation (not a DCHECK) so that it is also enforced in builds
  // where DCHECK is compiled out.
  str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);
}
221
// For each remote address in the table, connects a client socket to it and
// checks that GetLocalAddress() succeeds and that GetPeerAddress() returns
// the address that was passed to Connect().
TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
  struct TestData {
    std::string remote_address;
    std::string local_address;
    bool may_fail;  // Whether Connect() is allowed to fail as unreachable.
  } tests[] = {
    { "127.0.00.1", "127.0.0.1", false },
    { "192.168.1.1", "127.0.0.1", false },
    { "::1", "::1", true },
    { "2001:db8:0::42", "::1", true },
  };
  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); i++) {
    SCOPED_TRACE(std::string("Connecting from ") + tests[i].local_address +
                 std::string(" to ") + tests[i].remote_address);

    // A literal that fails to parse would make every later check in this
    // iteration meaningless, so treat it as fatal.
    IPAddressNumber remote_ip;
    ASSERT_TRUE(ParseIPLiteralToNumber(tests[i].remote_address, &remote_ip));
    IPEndPoint remote_address(remote_ip, 80);
    IPAddressNumber local_ip;
    ASSERT_TRUE(ParseIPLiteralToNumber(tests[i].local_address, &local_ip));
    // Only referenced by the disabled expectation below.
    IPEndPoint local_address(local_ip, 80);

    UDPClientSocket client(NULL, NetLog::Source());
    int rv = client.Connect(remote_address);
    if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
      // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
      // addresses if IPv6 is not configured.
      continue;
    }

    // Any result lower than ERR_IO_PENDING is treated as a failure.
    EXPECT_LE(ERR_IO_PENDING, rv);

    IPEndPoint fetched_local_address;
    rv = client.GetLocalAddress(&fetched_local_address);
    EXPECT_EQ(OK, rv);

    // TODO(mbelshe): figure out how to verify the IP and port.
    // The port is dynamically generated by the udp stack.
    // The IP is the real IP of the client, not necessarily
    // loopback.
    //EXPECT_EQ(local_address.address(), fetched_local_address.address());

    IPEndPoint fetched_remote_address;
    rv = client.GetPeerAddress(&fetched_remote_address);
    EXPECT_EQ(OK, rv);

    EXPECT_EQ(remote_address, fetched_remote_address);
  }
}
270
// Listens on 127.0.0.1 with port 0 and checks that GetLocalAddress() reports
// the same IP address together with a non-zero port.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  IPEndPoint local_address;
  rv = server.GetLocalAddress(&local_address);
  // Expected value first and the named OK constant, matching the other
  // checks in this file.
  EXPECT_EQ(OK, rv);

  // Verify that port was allocated.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(bind_address.address(), local_address.address());
}
286
// A server socket that is only listening has no peer: GetPeerAddress() is
// expected to fail with ERR_SOCKET_NOT_CONNECTED.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  EXPECT_EQ(OK, rv);

  IPEndPoint peer_address;
  rv = server.GetPeerAddress(&peer_address);
  // Expected value first, matching the other checks in this file.
  EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED, rv);
}
298
299 // Close the socket while read is pending.
TEST_F(UDPSocketTest,CloseWithPendingRead)300 TEST_F(UDPSocketTest, CloseWithPendingRead) {
301 IPEndPoint bind_address;
302 CreateUDPAddress("127.0.0.1", 0, &bind_address);
303 UDPServerSocket server(NULL, NetLog::Source());
304 int rv = server.Listen(bind_address);
305 EXPECT_EQ(OK, rv);
306
307 TestCompletionCallback callback;
308 IPEndPoint from;
309 rv = server.RecvFrom(buffer_, kMaxRead, &from, &callback);
310 EXPECT_EQ(rv, ERR_IO_PENDING);
311
312 server.Close();
313
314 EXPECT_FALSE(callback.have_result());
315 }
316
317 } // namespace
318
319 } // namespace net
320