1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/udp_socket.h"
6
7 #include <algorithm>
8
9 #include "base/containers/circular_deque.h"
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/run_loop.h"
15 #include "base/scoped_clear_last_error.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/test/scoped_feature_list.h"
19 #include "base/threading/thread.h"
20 #include "base/time/time.h"
21 #include "build/build_config.h"
22 #include "build/chromeos_buildflags.h"
23 #include "net/base/features.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/ip_address.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_interfaces.h"
29 #include "net/base/test_completion_callback.h"
30 #include "net/log/net_log_event_type.h"
31 #include "net/log/net_log_source.h"
32 #include "net/log/test_net_log.h"
33 #include "net/log/test_net_log_util.h"
34 #include "net/socket/socket_test_util.h"
35 #include "net/socket/udp_client_socket.h"
36 #include "net/socket/udp_server_socket.h"
37 #include "net/socket/udp_socket_global_limits.h"
38 #include "net/test/gtest_util.h"
39 #include "net/test/test_with_task_environment.h"
40 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
41 #include "testing/gmock/include/gmock/gmock.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43 #include "testing/platform_test.h"
44
45 #if !BUILDFLAG(IS_WIN)
46 #include <netinet/in.h>
47 #include <sys/socket.h>
48 #else
49 #include <winsock2.h>
50 #endif
51
52 #if BUILDFLAG(IS_ANDROID)
53 #include "base/android/build_info.h"
54 #include "net/android/network_change_notifier_factory_android.h"
55 #include "net/base/network_change_notifier.h"
56 #endif
57
58 #if BUILDFLAG(IS_IOS)
59 #include <TargetConditionals.h>
60 #endif
61
62 #if BUILDFLAG(IS_MAC)
63 #include "base/mac/mac_util.h"
64 #endif // BUILDFLAG(IS_MAC)
65
66 using net::test::IsError;
67 using net::test::IsOk;
68 using testing::DoAll;
69 using testing::Not;
70
71 namespace net {
72
73 namespace {
74
75 // Creates an address from ip address and port and writes it to |*address|.
CreateUDPAddress(const std::string & ip_str,uint16_t port,IPEndPoint * address)76 bool CreateUDPAddress(const std::string& ip_str,
77 uint16_t port,
78 IPEndPoint* address) {
79 IPAddress ip_address;
80 if (!ip_address.AssignFromIPLiteral(ip_str))
81 return false;
82
83 *address = IPEndPoint(ip_address, port);
84 return true;
85 }
86
87 class UDPSocketTest : public PlatformTest, public WithTaskEnvironment {
88 public:
UDPSocketTest()89 UDPSocketTest() : buffer_(base::MakeRefCounted<IOBufferWithSize>(kMaxRead)) {}
90
91 // Blocks until data is read from the socket.
RecvFromSocket(UDPServerSocket * socket)92 std::string RecvFromSocket(UDPServerSocket* socket) {
93 TestCompletionCallback callback;
94
95 int rv = socket->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
96 callback.callback());
97 rv = callback.GetResult(rv);
98 if (rv < 0)
99 return std::string();
100 return std::string(buffer_->data(), rv);
101 }
102
103 // Sends UDP packet.
104 // If |address| is specified, then it is used for the destination
105 // to send to. Otherwise, will send to the last socket this server
106 // received from.
SendToSocket(UDPServerSocket * socket,const std::string & msg)107 int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
108 return SendToSocket(socket, msg, recv_from_address_);
109 }
110
SendToSocket(UDPServerSocket * socket,std::string msg,const IPEndPoint & address)111 int SendToSocket(UDPServerSocket* socket,
112 std::string msg,
113 const IPEndPoint& address) {
114 scoped_refptr<StringIOBuffer> io_buffer =
115 base::MakeRefCounted<StringIOBuffer>(msg);
116 TestCompletionCallback callback;
117 int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
118 callback.callback());
119 return callback.GetResult(rv);
120 }
121
ReadSocket(UDPClientSocket * socket)122 std::string ReadSocket(UDPClientSocket* socket) {
123 TestCompletionCallback callback;
124
125 int rv = socket->Read(buffer_.get(), kMaxRead, callback.callback());
126 rv = callback.GetResult(rv);
127 if (rv < 0)
128 return std::string();
129 return std::string(buffer_->data(), rv);
130 }
131
132 // Writes specified message to the socket.
WriteSocket(UDPClientSocket * socket,const std::string & msg)133 int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
134 scoped_refptr<StringIOBuffer> io_buffer =
135 base::MakeRefCounted<StringIOBuffer>(msg);
136 TestCompletionCallback callback;
137 int rv = socket->Write(io_buffer.get(), io_buffer->size(),
138 callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
139 return callback.GetResult(rv);
140 }
141
WriteSocketIgnoreResult(UDPClientSocket * socket,const std::string & msg)142 void WriteSocketIgnoreResult(UDPClientSocket* socket,
143 const std::string& msg) {
144 WriteSocket(socket, msg);
145 }
146
147 // And again for a bare socket
SendToSocket(UDPSocket * socket,std::string msg,const IPEndPoint & address)148 int SendToSocket(UDPSocket* socket,
149 std::string msg,
150 const IPEndPoint& address) {
151 auto io_buffer = base::MakeRefCounted<StringIOBuffer>(msg);
152 TestCompletionCallback callback;
153 int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
154 callback.callback());
155 return callback.GetResult(rv);
156 }
157
158 // Run unit test for a connection test.
159 // |use_nonblocking_io| is used to switch between overlapped and non-blocking
160 // IO on Windows. It has no effect in other ports.
161 void ConnectTest(bool use_nonblocking_io, bool use_async);
162
163 protected:
164 static const int kMaxRead = 1024;
165 scoped_refptr<IOBufferWithSize> buffer_;
166 IPEndPoint recv_from_address_;
167 };
168
169 const int UDPSocketTest::kMaxRead;
170
ReadCompleteCallback(int * result_out,base::OnceClosure callback,int result)171 void ReadCompleteCallback(int* result_out,
172 base::OnceClosure callback,
173 int result) {
174 *result_out = result;
175 std::move(callback).Run();
176 }
177
ConnectTest(bool use_nonblocking_io,bool use_async)178 void UDPSocketTest::ConnectTest(bool use_nonblocking_io, bool use_async) {
179 std::string simple_message("hello world!");
180 RecordingNetLogObserver net_log_observer;
181 // Setup the server to listen.
182 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
183 auto server =
184 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
185 if (use_nonblocking_io)
186 server->UseNonBlockingIO();
187 server->AllowAddressReuse();
188 ASSERT_THAT(server->Listen(server_address), IsOk());
189 // Get bound port.
190 ASSERT_THAT(server->GetLocalAddress(&server_address), IsOk());
191
192 // Setup the client.
193 auto client = std::make_unique<UDPClientSocket>(
194 DatagramSocket::DEFAULT_BIND, NetLog::Get(), NetLogSource());
195 if (use_nonblocking_io)
196 client->UseNonBlockingIO();
197
198 if (!use_async) {
199 EXPECT_THAT(client->Connect(server_address), IsOk());
200 } else {
201 TestCompletionCallback callback;
202 int rv = client->ConnectAsync(server_address, callback.callback());
203 if (rv != OK) {
204 ASSERT_EQ(rv, ERR_IO_PENDING);
205 rv = callback.WaitForResult();
206 EXPECT_EQ(rv, OK);
207 } else {
208 EXPECT_EQ(rv, OK);
209 }
210 }
211 // Client sends to the server.
212 EXPECT_EQ(simple_message.length(),
213 static_cast<size_t>(WriteSocket(client.get(), simple_message)));
214
215 // Server waits for message.
216 std::string str = RecvFromSocket(server.get());
217 EXPECT_EQ(simple_message, str);
218
219 // Server echoes reply.
220 EXPECT_EQ(simple_message.length(),
221 static_cast<size_t>(SendToSocket(server.get(), simple_message)));
222
223 // Client waits for response.
224 str = ReadSocket(client.get());
225 EXPECT_EQ(simple_message, str);
226
227 // Test asynchronous read. Server waits for message.
228 base::RunLoop run_loop;
229 int read_result = 0;
230 int rv = server->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
231 base::BindOnce(&ReadCompleteCallback, &read_result,
232 run_loop.QuitClosure()));
233 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
234
235 // Client sends to the server.
236 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
237 FROM_HERE,
238 base::BindOnce(&UDPSocketTest::WriteSocketIgnoreResult,
239 base::Unretained(this), client.get(), simple_message));
240 run_loop.Run();
241 EXPECT_EQ(simple_message.length(), static_cast<size_t>(read_result));
242 EXPECT_EQ(simple_message, std::string(buffer_->data(), read_result));
243
244 NetLogSource server_net_log_source = server->NetLog().source();
245 NetLogSource client_net_log_source = client->NetLog().source();
246
247 // Delete sockets so they log their final events.
248 server.reset();
249 client.reset();
250
251 // Check the server's log.
252 auto server_entries =
253 net_log_observer.GetEntriesForSource(server_net_log_source);
254 ASSERT_EQ(6u, server_entries.size());
255 EXPECT_TRUE(
256 LogContainsBeginEvent(server_entries, 0, NetLogEventType::SOCKET_ALIVE));
257 EXPECT_TRUE(LogContainsEvent(server_entries, 1,
258 NetLogEventType::UDP_LOCAL_ADDRESS,
259 NetLogEventPhase::NONE));
260 EXPECT_TRUE(LogContainsEvent(server_entries, 2,
261 NetLogEventType::UDP_BYTES_RECEIVED,
262 NetLogEventPhase::NONE));
263 EXPECT_TRUE(LogContainsEvent(server_entries, 3,
264 NetLogEventType::UDP_BYTES_SENT,
265 NetLogEventPhase::NONE));
266 EXPECT_TRUE(LogContainsEvent(server_entries, 4,
267 NetLogEventType::UDP_BYTES_RECEIVED,
268 NetLogEventPhase::NONE));
269 EXPECT_TRUE(
270 LogContainsEndEvent(server_entries, 5, NetLogEventType::SOCKET_ALIVE));
271
272 // Check the client's log.
273 auto client_entries =
274 net_log_observer.GetEntriesForSource(client_net_log_source);
275 EXPECT_EQ(7u, client_entries.size());
276 EXPECT_TRUE(
277 LogContainsBeginEvent(client_entries, 0, NetLogEventType::SOCKET_ALIVE));
278 EXPECT_TRUE(
279 LogContainsBeginEvent(client_entries, 1, NetLogEventType::UDP_CONNECT));
280 EXPECT_TRUE(
281 LogContainsEndEvent(client_entries, 2, NetLogEventType::UDP_CONNECT));
282 EXPECT_TRUE(LogContainsEvent(client_entries, 3,
283 NetLogEventType::UDP_BYTES_SENT,
284 NetLogEventPhase::NONE));
285 EXPECT_TRUE(LogContainsEvent(client_entries, 4,
286 NetLogEventType::UDP_BYTES_RECEIVED,
287 NetLogEventPhase::NONE));
288 EXPECT_TRUE(LogContainsEvent(client_entries, 5,
289 NetLogEventType::UDP_BYTES_SENT,
290 NetLogEventPhase::NONE));
291 EXPECT_TRUE(
292 LogContainsEndEvent(client_entries, 6, NetLogEventType::SOCKET_ALIVE));
293 }
294
// Runs the echo round trip with blocking-style IO, first using a synchronous
// Connect() and then ConnectAsync(). |use_nonblocking_io| has no effect
// outside Windows.
TEST_F(UDPSocketTest, Connect) {
  for (bool use_async : {false, true})
    ConnectTest(/*use_nonblocking_io=*/false, use_async);
}
301
302 #if BUILDFLAG(IS_WIN)
TEST_F(UDPSocketTest,ConnectNonBlocking)303 TEST_F(UDPSocketTest, ConnectNonBlocking) {
304 ConnectTest(true, false);
305 ConnectTest(true, true);
306 }
307 #endif
308
// Receiving into a buffer smaller than the datagram must fail with
// ERR_MSG_TOO_BIG and must not leave the truncated remainder queued for the
// next receive.
TEST_F(UDPSocketTest, PartialRecv) {
  UDPServerSocket server_socket(nullptr, NetLogSource());
  ASSERT_THAT(server_socket.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server_socket.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
                                NetLogSource());
  ASSERT_THAT(client_socket.Connect(server_address), IsOk());

  std::string test_packet("hello world!");
  ASSERT_EQ(static_cast<int>(test_packet.size()),
            WriteSocket(&client_socket, test_packet));

  TestCompletionCallback recv_callback;

  // Offer only 2 bytes of buffer space. RecvFrom() is expected to report
  // ERR_MSG_TOO_BIG instead of returning a truncated payload.
  const int kPartialReadSize = 2;
  scoped_refptr<IOBuffer> buffer =
      base::MakeRefCounted<IOBuffer>(kPartialReadSize);
  int rv =
      server_socket.RecvFrom(buffer.get(), kPartialReadSize,
                             &recv_from_address_, recv_callback.callback());
  rv = recv_callback.GetResult(rv);

  EXPECT_EQ(rv, ERR_MSG_TOO_BIG);

  // Send a different message again.
  std::string second_packet("Second packet");
  ASSERT_EQ(static_cast<int>(second_packet.size()),
            WriteSocket(&client_socket, second_packet));

  // Read the whole second packet; receiving it intact shows the oversized
  // first datagram was dropped rather than left partially queued.
  std::string received = RecvFromSocket(&server_socket);
  EXPECT_EQ(second_packet, received);
}
347
348 #if BUILDFLAG(IS_APPLE) || BUILDFLAG(IS_ANDROID)
349 // - MacOS: requires root permissions on OSX 10.7+.
350 // - Android: devices attached to testbots don't have default network, so
351 // broadcasting to 255.255.255.255 returns error -109 (Address not reachable).
352 // crbug.com/139144.
353 #define MAYBE_LocalBroadcast DISABLED_LocalBroadcast
354 #else
355 #define MAYBE_LocalBroadcast LocalBroadcast
356 #endif
TEST_F(UDPSocketTest,MAYBE_LocalBroadcast)357 TEST_F(UDPSocketTest, MAYBE_LocalBroadcast) {
358 std::string first_message("first message"), second_message("second message");
359
360 IPEndPoint listen_address;
361 ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &listen_address));
362
363 auto server1 =
364 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
365 auto server2 =
366 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
367 server1->AllowAddressReuse();
368 server1->AllowBroadcast();
369 server2->AllowAddressReuse();
370 server2->AllowBroadcast();
371
372 EXPECT_THAT(server1->Listen(listen_address), IsOk());
373 // Get bound port.
374 EXPECT_THAT(server1->GetLocalAddress(&listen_address), IsOk());
375 EXPECT_THAT(server2->Listen(listen_address), IsOk());
376
377 IPEndPoint broadcast_address;
378 ASSERT_TRUE(CreateUDPAddress("127.255.255.255", listen_address.port(),
379 &broadcast_address));
380 ASSERT_EQ(static_cast<int>(first_message.size()),
381 SendToSocket(server1.get(), first_message, broadcast_address));
382 std::string str = RecvFromSocket(server1.get());
383 ASSERT_EQ(first_message, str);
384 str = RecvFromSocket(server2.get());
385 ASSERT_EQ(first_message, str);
386
387 ASSERT_EQ(static_cast<int>(second_message.size()),
388 SendToSocket(server2.get(), second_message, broadcast_address));
389 str = RecvFromSocket(server1.get());
390 ASSERT_EQ(second_message, str);
391 str = RecvFromSocket(server2.get());
392 ASSERT_EQ(second_message, str);
393 }
394
395 // ConnectRandomBind verifies RANDOM_BIND is handled correctly. It connects
396 // 1000 sockets and then verifies that the allocated port numbers satisfy the
397 // following 2 conditions:
398 // 1. Range from min port value to max is greater than 10000.
399 // 2. There is at least one port in the 5 buckets in the [min, max] range.
400 //
401 // These conditions are not enough to verify that the port numbers are truly
402 // random, but they are enough to protect from most common non-random port
403 // allocation strategies (e.g. counter, pool of available ports, etc.) False
404 // positive result is theoretically possible, but its probability is negligible.
TEST_F(UDPSocketTest,ConnectRandomBind)405 TEST_F(UDPSocketTest, ConnectRandomBind) {
406 const int kIterations = 1000;
407
408 std::vector<int> used_ports;
409 for (int i = 0; i < kIterations; ++i) {
410 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
411 NetLogSource());
412 EXPECT_THAT(socket.Connect(IPEndPoint(IPAddress::IPv4Localhost(), 53)),
413 IsOk());
414
415 IPEndPoint client_address;
416 EXPECT_THAT(socket.GetLocalAddress(&client_address), IsOk());
417 used_ports.push_back(client_address.port());
418 }
419
420 int min_port = *std::min_element(used_ports.begin(), used_ports.end());
421 int max_port = *std::max_element(used_ports.begin(), used_ports.end());
422 int range = max_port - min_port + 1;
423
424 // Verify that the range of ports used by the random port allocator is wider
425 // than 10k. Assuming that socket implementation limits port range to 16k
426 // ports (default on Fuchsia) probability of false negative is below
427 // 10^-200.
428 static int kMinRange = 10000;
429 EXPECT_GT(range, kMinRange);
430
431 static int kBuckets = 5;
432 std::vector<int> bucket_sizes(kBuckets, 0);
433 for (int port : used_ports) {
434 bucket_sizes[(port - min_port) * kBuckets / range] += 1;
435 }
436
437 // Verify that there is at least one value in each bucket. Probability of
438 // false negative is below (kBuckets * (1 - 1 / kBuckets) ^ kIterations),
439 // which is less than 10^-96.
440 for (int size : bucket_sizes) {
441 EXPECT_GT(size, 0);
442 }
443 }
444
// A socket opened for IPv4 cannot connect to an IPv6 peer, and the failed
// Connect() must leave it unconnected.
TEST_F(UDPSocketTest, ConnectFail) {
  UDPSocket udp_socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  EXPECT_THAT(udp_socket.Open(ADDRESS_FAMILY_IPV4), IsOk());

  // Address family mismatch: IPv4 socket, IPv6 destination.
  const IPEndPoint ipv6_peer(IPAddress::IPv6Localhost(), 53);
  EXPECT_THAT(udp_socket.Connect(ipv6_peer), Not(IsOk()));

  EXPECT_FALSE(udp_socket.is_connected());
}
458
459 // Similar to ConnectFail but UDPSocket adopts an opened socket instead of
460 // opening one directly.
TEST_F(UDPSocketTest,AdoptedSocket)461 TEST_F(UDPSocketTest, AdoptedSocket) {
462 auto socketfd =
463 CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
464 SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
465 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
466
467 EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd), IsOk());
468
469 // Connect to an IPv6 address should fail since the socket was created for
470 // IPv4.
471 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
472 Not(IsOk()));
473
474 // Make sure that UDPSocket actually closed the socket.
475 EXPECT_FALSE(socket.is_connected());
476 }
477
478 // Tests that UDPSocket updates the global counter correctly.
TEST_F(UDPSocketTest,LimitAdoptSocket)479 TEST_F(UDPSocketTest, LimitAdoptSocket) {
480 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
481 {
482 // Creating a platform socket does not increase count.
483 auto socketfd =
484 CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
485 SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
486 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
487
488 // Simply allocating a UDPSocket does not increase count.
489 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
490 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
491
492 // Calling AdoptOpenedSocket() allocates the socket and increases the global
493 // counter.
494 EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd),
495 IsOk());
496 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
497
498 // Connect to an IPv6 address should fail since the socket was created for
499 // IPv4.
500 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
501 Not(IsOk()));
502
503 // That Connect() failed doesn't change the global counter.
504 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
505 }
506 // Finally, destroying UDPSocket decrements the global counter.
507 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
508 }
509
510 // In this test, we verify that connect() on a socket will have the effect
511 // of filtering reads on this socket only to data read from the destination
512 // we connected to.
513 //
514 // The purpose of this test is that some documentation indicates that connect
515 // binds the client's sends to send to a particular server endpoint, but does
516 // not bind the client's reads to only be from that endpoint, and that we need
517 // to always use recvfrom() to disambiguate.
TEST_F(UDPSocketTest,VerifyConnectBindsAddr)518 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
519 std::string simple_message("hello world!");
520 std::string foreign_message("BAD MESSAGE TO GET!!");
521
522 // Setup the first server to listen.
523 IPEndPoint server1_address(IPAddress::IPv4Localhost(), 0 /* port */);
524 UDPServerSocket server1(nullptr, NetLogSource());
525 ASSERT_THAT(server1.Listen(server1_address), IsOk());
526 // Get the bound port.
527 ASSERT_THAT(server1.GetLocalAddress(&server1_address), IsOk());
528
529 // Setup the second server to listen.
530 IPEndPoint server2_address(IPAddress::IPv4Localhost(), 0 /* port */);
531 UDPServerSocket server2(nullptr, NetLogSource());
532 ASSERT_THAT(server2.Listen(server2_address), IsOk());
533
534 // Setup the client, connected to server 1.
535 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
536 EXPECT_THAT(client.Connect(server1_address), IsOk());
537
538 // Client sends to server1.
539 EXPECT_EQ(simple_message.length(),
540 static_cast<size_t>(WriteSocket(&client, simple_message)));
541
542 // Server1 waits for message.
543 std::string str = RecvFromSocket(&server1);
544 EXPECT_EQ(simple_message, str);
545
546 // Get the client's address.
547 IPEndPoint client_address;
548 EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
549
550 // Server2 sends reply.
551 EXPECT_EQ(foreign_message.length(),
552 static_cast<size_t>(
553 SendToSocket(&server2, foreign_message, client_address)));
554
555 // Server1 sends reply.
556 EXPECT_EQ(simple_message.length(),
557 static_cast<size_t>(
558 SendToSocket(&server1, simple_message, client_address)));
559
560 // Client waits for response.
561 str = ReadSocket(&client);
562 EXPECT_EQ(simple_message, str);
563 }
564
TEST_F(UDPSocketTest,ClientGetLocalPeerAddresses)565 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
566 struct TestData {
567 std::string remote_address;
568 std::string local_address;
569 bool may_fail;
570 } tests[] = {
571 {"127.0.00.1", "127.0.0.1", false},
572 {"::1", "::1", true},
573 #if !BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_IOS)
574 // Addresses below are disabled on Android. See crbug.com/161248
575 // They are also disabled on iOS. See https://crbug.com/523225
576 {"192.168.1.1", "127.0.0.1", false},
577 {"2001:db8:0::42", "::1", true},
578 #endif
579 };
580 for (const auto& test : tests) {
581 SCOPED_TRACE(std::string("Connecting from ") + test.local_address +
582 std::string(" to ") + test.remote_address);
583
584 IPAddress ip_address;
585 EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.remote_address));
586 IPEndPoint remote_address(ip_address, 80);
587 EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.local_address));
588 IPEndPoint local_address(ip_address, 80);
589
590 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
591 NetLogSource());
592 int rv = client.Connect(remote_address);
593 if (test.may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
594 // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
595 // addresses if IPv6 is not configured.
596 continue;
597 }
598
599 EXPECT_LE(ERR_IO_PENDING, rv);
600
601 IPEndPoint fetched_local_address;
602 rv = client.GetLocalAddress(&fetched_local_address);
603 EXPECT_THAT(rv, IsOk());
604
605 // TODO(mbelshe): figure out how to verify the IP and port.
606 // The port is dynamically generated by the udp stack.
607 // The IP is the real IP of the client, not necessarily
608 // loopback.
609 // EXPECT_EQ(local_address.address(), fetched_local_address.address());
610
611 IPEndPoint fetched_remote_address;
612 rv = client.GetPeerAddress(&fetched_remote_address);
613 EXPECT_THAT(rv, IsOk());
614
615 EXPECT_EQ(remote_address, fetched_remote_address);
616 }
617 }
618
// After listening on localhost port 0, the server reports a real, non-zero
// port on the same address it was bound to.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  const IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  UDPServerSocket server(nullptr, NetLogSource());
  EXPECT_THAT(server.Listen(bind_address), IsOk());

  IPEndPoint local_address;
  EXPECT_THAT(server.GetLocalAddress(&local_address), IsOk());

  // Port 0 requests an OS-assigned port, so a concrete one must now be set.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(local_address.address(), bind_address.address());
}
633
// A listening (unconnected) server socket has no peer, so GetPeerAddress()
// must report ERR_SOCKET_NOT_CONNECTED.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  UDPServerSocket server(nullptr, NetLogSource());
  EXPECT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());

  IPEndPoint peer_address;
  EXPECT_EQ(server.GetPeerAddress(&peer_address), ERR_SOCKET_NOT_CONNECTED);
}
644
TEST_F(UDPSocketTest,ClientSetDoNotFragment)645 TEST_F(UDPSocketTest, ClientSetDoNotFragment) {
646 for (std::string ip : {"127.0.0.1", "::1"}) {
647 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
648 NetLogSource());
649 IPAddress ip_address;
650 EXPECT_TRUE(ip_address.AssignFromIPLiteral(ip));
651 IPEndPoint remote_address(ip_address, 80);
652 int rv = client.Connect(remote_address);
653 // May fail on IPv6 is IPv6 is not configured.
654 if (ip_address.IsIPv6() && rv == ERR_ADDRESS_UNREACHABLE)
655 return;
656 EXPECT_THAT(rv, IsOk());
657
658 rv = client.SetDoNotFragment();
659 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
660 // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
661 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
662 #elif BUILDFLAG(IS_MAC)
663 if (base::mac::IsAtLeastOS11()) {
664 EXPECT_THAT(rv, IsOk());
665 } else {
666 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
667 }
668 #else
669 EXPECT_THAT(rv, IsOk());
670 #endif
671 }
672 }
673
TEST_F(UDPSocketTest,ServerSetDoNotFragment)674 TEST_F(UDPSocketTest, ServerSetDoNotFragment) {
675 for (std::string ip : {"127.0.0.1", "::1"}) {
676 IPEndPoint bind_address;
677 ASSERT_TRUE(CreateUDPAddress(ip, 0, &bind_address));
678 UDPServerSocket server(nullptr, NetLogSource());
679 int rv = server.Listen(bind_address);
680 // May fail on IPv6 is IPv6 is not configure
681 if (bind_address.address().IsIPv6() &&
682 (rv == ERR_ADDRESS_INVALID || rv == ERR_ADDRESS_UNREACHABLE))
683 return;
684 EXPECT_THAT(rv, IsOk());
685
686 rv = server.SetDoNotFragment();
687 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
688 // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
689 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
690 #elif BUILDFLAG(IS_MAC)
691 if (base::mac::IsAtLeastOS11()) {
692 EXPECT_THAT(rv, IsOk());
693 } else {
694 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
695 }
696 #else
697 EXPECT_THAT(rv, IsOk());
698 #endif
699 }
700 }
701
702 // Close the socket while read is pending.
TEST_F(UDPSocketTest,CloseWithPendingRead)703 TEST_F(UDPSocketTest, CloseWithPendingRead) {
704 IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
705 UDPServerSocket server(nullptr, NetLogSource());
706 int rv = server.Listen(bind_address);
707 EXPECT_THAT(rv, IsOk());
708
709 TestCompletionCallback callback;
710 IPEndPoint from;
711 rv = server.RecvFrom(buffer_.get(), kMaxRead, &from, callback.callback());
712 EXPECT_EQ(rv, ERR_IO_PENDING);
713
714 server.Close();
715
716 EXPECT_FALSE(callback.have_result());
717 }
718
719 // Some Android devices do not support multicast.
// The ones supporting multicast need WifiManager.MulticastLock to enable it.
721 // http://goo.gl/jjAk9
722 #if !BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,JoinMulticastGroup)723 TEST_F(UDPSocketTest, JoinMulticastGroup) {
724 const char kGroup[] = "237.132.100.17";
725
726 IPAddress group_ip;
727 EXPECT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
728 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
729 // OS_FUCHSIA.
730 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
731 IPEndPoint bind_address(IPAddress::AllZeros(group_ip.size()), 0 /* port */);
732 #else
733 IPEndPoint bind_address(group_ip, 0 /* port */);
734 #endif // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
735
736 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
737 EXPECT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
738
739 EXPECT_THAT(socket.Bind(bind_address), IsOk());
740 EXPECT_THAT(socket.JoinGroup(group_ip), IsOk());
741 // Joining group multiple times.
742 EXPECT_NE(OK, socket.JoinGroup(group_ip));
743 EXPECT_THAT(socket.LeaveGroup(group_ip), IsOk());
744 // Leaving group multiple times.
745 EXPECT_NE(OK, socket.LeaveGroup(group_ip));
746
747 socket.Close();
748 }
749
750 // TODO(https://crbug.com/947115): failing on device on iOS 12.2.
751 // TODO(https://crbug.com/1227554): flaky on Mac 11.
752 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_MAC)
753 #define MAYBE_SharedMulticastAddress DISABLED_SharedMulticastAddress
754 #else
755 #define MAYBE_SharedMulticastAddress SharedMulticastAddress
756 #endif
TEST_F(UDPSocketTest,MAYBE_SharedMulticastAddress)757 TEST_F(UDPSocketTest, MAYBE_SharedMulticastAddress) {
758 const char kGroup[] = "224.0.0.251";
759
760 IPAddress group_ip;
761 ASSERT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
762 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
763 // OS_FUCHSIA.
764 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
765 IPEndPoint receive_address(IPAddress::AllZeros(group_ip.size()),
766 0 /* port */);
767 #else
768 IPEndPoint receive_address(group_ip, 0 /* port */);
769 #endif // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
770
771 NetworkInterfaceList interfaces;
772 ASSERT_TRUE(GetNetworkList(&interfaces, 0));
773 // The test fails with the Hyper-V switch interface (on the host side).
774 interfaces.erase(std::remove_if(interfaces.begin(), interfaces.end(),
775 [](const auto& iface) {
776 return iface.friendly_name.rfind(
777 "vEthernet", 0) == 0;
778 }),
779 interfaces.end());
780 ASSERT_FALSE(interfaces.empty());
781
782 // Setup first receiving socket.
783 UDPServerSocket socket1(nullptr, NetLogSource());
784 socket1.AllowAddressSharingForMulticast();
785 ASSERT_THAT(socket1.SetMulticastInterface(interfaces[0].interface_index),
786 IsOk());
787 ASSERT_THAT(socket1.Listen(receive_address), IsOk());
788 ASSERT_THAT(socket1.JoinGroup(group_ip), IsOk());
789 // Get the bound port.
790 ASSERT_THAT(socket1.GetLocalAddress(&receive_address), IsOk());
791
792 // Setup second receiving socket.
793 UDPServerSocket socket2(nullptr, NetLogSource());
794 socket2.AllowAddressSharingForMulticast(), IsOk();
795 ASSERT_THAT(socket2.SetMulticastInterface(interfaces[0].interface_index),
796 IsOk());
797 ASSERT_THAT(socket2.Listen(receive_address), IsOk());
798 ASSERT_THAT(socket2.JoinGroup(group_ip), IsOk());
799
800 // Setup client socket.
801 IPEndPoint send_address(group_ip, receive_address.port());
802 UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
803 NetLogSource());
804 ASSERT_THAT(client_socket.Connect(send_address), IsOk());
805
806 #if !BUILDFLAG(IS_CHROMEOS_ASH)
807 // Send a message via the multicast group. That message is expected be be
808 // received by both receving sockets.
809 //
810 // Skip on ChromeOS where it's known to sometimes not work.
811 // TODO(crbug.com/898964): If possible, fix and reenable.
812 const char kMessage[] = "hello!";
813 ASSERT_GE(WriteSocket(&client_socket, kMessage), 0);
814 EXPECT_EQ(kMessage, RecvFromSocket(&socket1));
815 EXPECT_EQ(kMessage, RecvFromSocket(&socket2));
816 #endif // !BUILDFLAG(IS_CHROMEOS_ASH)
817 }
818 #endif // !BUILDFLAG(IS_ANDROID)
819
// Multicast options are accepted while the socket is still unbound (except for
// a negative TTL). After the socket is opened and bound, the setters exercised
// here all fail.
TEST_F(UDPSocketTest, MulticastOptions) {
  IPEndPoint any_address;
  ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &any_address));

  UDPSocket udp_socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());

  // Phase 1: before binding.
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(false), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(true), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(0), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(3), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(-1), Not(IsOk()));
  EXPECT_THAT(udp_socket.SetMulticastInterface(0), IsOk());

  EXPECT_THAT(udp_socket.Open(any_address.GetFamily()), IsOk());
  EXPECT_THAT(udp_socket.Bind(any_address), IsOk());

  // Phase 2: after binding.
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(false), Not(IsOk()));
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(0), Not(IsOk()));
  EXPECT_THAT(udp_socket.SetMulticastInterface(0), Not(IsOk()));

  udp_socket.Close();
}
842
843 // Checking that DSCP bits are set correctly is difficult,
844 // but let's check that the code doesn't crash at least.
TEST_F(UDPSocketTest,SetDSCP)845 TEST_F(UDPSocketTest, SetDSCP) {
846 // Setup the server to listen.
847 IPEndPoint bind_address;
848 UDPSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
849 // We need a real IP, but we won't actually send anything to it.
850 ASSERT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
851 int rv = client.Open(bind_address.GetFamily());
852 EXPECT_THAT(rv, IsOk());
853
854 rv = client.Connect(bind_address);
855 if (rv != OK) {
856 // Let's try localhost then.
857 bind_address = IPEndPoint(IPAddress::IPv4Localhost(), 9999);
858 rv = client.Connect(bind_address);
859 }
860 EXPECT_THAT(rv, IsOk());
861
862 client.SetDiffServCodePoint(DSCP_NO_CHANGE);
863 client.SetDiffServCodePoint(DSCP_AF41);
864 client.SetDiffServCodePoint(DSCP_DEFAULT);
865 client.SetDiffServCodePoint(DSCP_CS2);
866 client.SetDiffServCodePoint(DSCP_NO_CHANGE);
867 client.SetDiffServCodePoint(DSCP_DEFAULT);
868 client.Close();
869 }
870
TEST_F(UDPSocketTest,ConnectUsingNetwork)871 TEST_F(UDPSocketTest, ConnectUsingNetwork) {
872 // The specific value of this address doesn't really matter, and no
873 // server needs to be running here. The test only needs to call
874 // ConnectUsingNetwork() and won't send any datagrams.
875 const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
876 const handles::NetworkHandle wrong_network_handle = 65536;
877 #if BUILDFLAG(IS_ANDROID)
878 NetworkChangeNotifierFactoryAndroid ncn_factory;
879 NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
880 std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
881 if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
882 GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
883
884 {
885 // Connecting using a not existing network should fail but not report
886 // ERR_NOT_IMPLEMENTED when network handles are supported.
887 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
888 NetLogSource());
889 int rv =
890 socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address);
891 EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
892 EXPECT_NE(OK, rv);
893 EXPECT_NE(wrong_network_handle, socket.GetBoundNetwork());
894 }
895
896 {
897 // Connecting using an existing network should succeed when
898 // NetworkChangeNotifier returns a valid default network.
899 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
900 NetLogSource());
901 const handles::NetworkHandle network_handle =
902 NetworkChangeNotifier::GetDefaultNetwork();
903 if (network_handle != handles::kInvalidNetworkHandle) {
904 EXPECT_EQ(
905 OK, socket.ConnectUsingNetwork(network_handle, fake_server_address));
906 EXPECT_EQ(network_handle, socket.GetBoundNetwork());
907 }
908 }
909 #else
910 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
911 EXPECT_EQ(
912 ERR_NOT_IMPLEMENTED,
913 socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address));
914 #endif // BUILDFLAG(IS_ANDROID)
915 }
916
TEST_F(UDPSocketTest,ConnectUsingNetworkAsync)917 TEST_F(UDPSocketTest, ConnectUsingNetworkAsync) {
918 // The specific value of this address doesn't really matter, and no
919 // server needs to be running here. The test only needs to call
920 // ConnectUsingNetwork() and won't send any datagrams.
921 const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
922 const handles::NetworkHandle wrong_network_handle = 65536;
923 #if BUILDFLAG(IS_ANDROID)
924 NetworkChangeNotifierFactoryAndroid ncn_factory;
925 NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
926 std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
927 if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
928 GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
929
930 {
931 // Connecting using a not existing network should fail but not report
932 // ERR_NOT_IMPLEMENTED when network handles are supported.
933 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
934 NetLogSource());
935 TestCompletionCallback callback;
936 int rv = socket.ConnectUsingNetworkAsync(
937 wrong_network_handle, fake_server_address, callback.callback());
938
939 if (rv == ERR_IO_PENDING) {
940 rv = callback.WaitForResult();
941 }
942 EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
943 EXPECT_NE(OK, rv);
944 }
945
946 {
947 // Connecting using an existing network should succeed when
948 // NetworkChangeNotifier returns a valid default network.
949 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
950 NetLogSource());
951 TestCompletionCallback callback;
952 const handles::NetworkHandle network_handle =
953 NetworkChangeNotifier::GetDefaultNetwork();
954 if (network_handle != handles::kInvalidNetworkHandle) {
955 int rv = socket.ConnectUsingNetworkAsync(
956 network_handle, fake_server_address, callback.callback());
957 if (rv == ERR_IO_PENDING) {
958 rv = callback.WaitForResult();
959 }
960 EXPECT_EQ(OK, rv);
961 EXPECT_EQ(network_handle, socket.GetBoundNetwork());
962 }
963 }
964 #else
965 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
966 TestCompletionCallback callback;
967 EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket.ConnectUsingNetworkAsync(
968 wrong_network_handle, fake_server_address,
969 callback.callback()));
970 #endif // BUILDFLAG(IS_ANDROID)
971 }
972
973 } // namespace
974
975 #if BUILDFLAG(IS_WIN)
976
977 namespace {
978
979 const HANDLE kFakeHandle1 = (HANDLE)12;
980 const HANDLE kFakeHandle2 = (HANDLE)13;
981
982 const QOS_FLOWID kFakeFlowId1 = (QOS_FLOWID)27;
983 const QOS_FLOWID kFakeFlowId2 = (QOS_FLOWID)38;
984
985 class TestUDPSocketWin : public UDPSocketWin {
986 public:
TestUDPSocketWin(QwaveApi * qos,DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)987 TestUDPSocketWin(QwaveApi* qos,
988 DatagramSocket::BindType bind_type,
989 net::NetLog* net_log,
990 const net::NetLogSource& source)
991 : UDPSocketWin(bind_type, net_log, source), qos_(qos) {}
992
993 TestUDPSocketWin(const TestUDPSocketWin&) = delete;
994 TestUDPSocketWin& operator=(const TestUDPSocketWin&) = delete;
995
996 // Overriding GetQwaveApi causes the test class to use the injected mock
997 // QwaveApi instance instead of the singleton.
GetQwaveApi() const998 QwaveApi* GetQwaveApi() const override { return qos_; }
999
1000 private:
1001 raw_ptr<QwaveApi> qos_;
1002 };
1003
1004 class MockQwaveApi : public QwaveApi {
1005 public:
1006 MOCK_CONST_METHOD0(qwave_supported, bool());
1007 MOCK_METHOD0(OnFatalError, void());
1008 MOCK_METHOD2(CreateHandle, BOOL(PQOS_VERSION version, PHANDLE handle));
1009 MOCK_METHOD1(CloseHandle, BOOL(HANDLE handle));
1010 MOCK_METHOD6(AddSocketToFlow,
1011 BOOL(HANDLE handle,
1012 SOCKET socket,
1013 PSOCKADDR addr,
1014 QOS_TRAFFIC_TYPE traffic_type,
1015 DWORD flags,
1016 PQOS_FLOWID flow_id));
1017
1018 MOCK_METHOD4(
1019 RemoveSocketFromFlow,
1020 BOOL(HANDLE handle, SOCKET socket, QOS_FLOWID flow_id, DWORD reserved));
1021 MOCK_METHOD7(SetFlow,
1022 BOOL(HANDLE handle,
1023 QOS_FLOWID flow_id,
1024 QOS_SET_FLOW op,
1025 ULONG size,
1026 PVOID data,
1027 DWORD reserved,
1028 LPOVERLAPPED overlapped));
1029 };
1030
OpenedDscpTestClient(QwaveApi * api,IPEndPoint bind_address)1031 std::unique_ptr<UDPSocket> OpenedDscpTestClient(QwaveApi* api,
1032 IPEndPoint bind_address) {
1033 auto client = std::make_unique<TestUDPSocketWin>(
1034 api, DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1035 int rv = client->Open(bind_address.GetFamily());
1036 EXPECT_THAT(rv, IsOk());
1037
1038 return client;
1039 }
1040
ConnectedDscpTestClient(QwaveApi * api)1041 std::unique_ptr<UDPSocket> ConnectedDscpTestClient(QwaveApi* api) {
1042 IPEndPoint bind_address;
1043 // We need a real IP, but we won't actually send anything to it.
1044 EXPECT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
1045 auto client = OpenedDscpTestClient(api, bind_address);
1046 EXPECT_THAT(client->Connect(bind_address), IsOk());
1047 return client;
1048 }
1049
UnconnectedDscpTestClient(QwaveApi * api)1050 std::unique_ptr<UDPSocket> UnconnectedDscpTestClient(QwaveApi* api) {
1051 IPEndPoint bind_address;
1052 EXPECT_TRUE(CreateUDPAddress("0.0.0.0", 9999, &bind_address));
1053 auto client = OpenedDscpTestClient(api, bind_address);
1054 EXPECT_THAT(client->Bind(bind_address), IsOk());
1055 return client;
1056 }
1057
1058 } // namespace
1059
1060 using ::testing::Return;
1061 using ::testing::SetArgPointee;
1062 using ::testing::_;
1063
// DSCP_NO_CHANGE succeeds without adding the socket to any QWAVE flow.
TEST_F(UDPSocketTest, SetDSCPNoopIfPassedNoChange) {
  MockQwaveApi qwave;
  EXPECT_CALL(qwave, qwave_supported()).WillRepeatedly(Return(true));

  EXPECT_CALL(qwave, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  std::unique_ptr<UDPSocket> udp_socket = ConnectedDscpTestClient(&qwave);
  EXPECT_THAT(udp_socket->SetDiffServCodePoint(DSCP_NO_CHANGE), IsOk());
}
1072
// When QWAVE reports itself unsupported, no handle is requested and setting a
// real codepoint yields ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfQOSDoesntLink) {
  MockQwaveApi qwave;
  EXPECT_CALL(qwave, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_CALL(qwave, CreateHandle(_, _)).Times(0);

  std::unique_ptr<UDPSocket> udp_socket = ConnectedDscpTestClient(&qwave);
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, udp_socket->SetDiffServCodePoint(DSCP_AF41));
}
1081
// If QWAVE reports support but CreateHandle() fails:
// - SetDiffServCodePoint() returns ERR_INVALID_HANDLE.
// - OnFatalError() is expected exactly once.
// Once QWAVE then reports itself unsupported, later calls return
// ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfHandleCantBeCreated) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _)).WillOnce(Return(false));
  EXPECT_CALL(api, OnFatalError()).Times(1);

  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));

  // Let posted tasks run (presumably the handle-creation attempt whose failure
  // triggers OnFatalError(); confirm against DscpManager).
  RunUntilIdle();

  // gMock gives newer expectations precedence, so this overrides the
  // Return(true) above. Presumably it mirrors a real QwaveApi reporting itself
  // unsupported after OnFatalError().
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, client->SetDiffServCodePoint(DSCP_AF41));
}
1096
1097 MATCHER_P(DscpPointee, dscp, "") {
1098 return *(DWORD*)arg == (DWORD)dscp;
1099 }
1100
// DSCP on a connected socket. QWAVE handle creation is asynchronous, so the
// first SetDiffServCodePoint() fails until posted work runs. After that,
// changing the codepoint replaces the socket's flow.
TEST_F(UDPSocketTest, ConnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // Once the handle exists, setting a codepoint creates a flow (reported as
  // kFakeFlowId1) and configures it via SetFlow().
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));

  // First set on connected sockets will fail since init is async and
  // we haven't given the runloop a chance to execute the callback.
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));
  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // New dscp value should reset the flow. DSCP_DEFAULT is expected to use the
  // best-effort traffic type, and SetFlow() must receive DSCP_DEFAULT as the
  // outgoing DSCP value.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeBestEffort, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, QOSSetOutgoingDSCPValue, _,
                           DscpPointee(DSCP_DEFAULT), _, _));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_DEFAULT), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1130
// DSCP on an unconnected (bound) socket: SetDiffServCodePoint() succeeds both
// before and after the asynchronous QWAVE handle creation has run.
// NOTE(review): "Unonnected" is a typo for "Unconnected". It is kept as-is
// because test names may be referenced externally (e.g. filter files); verify
// before renaming.
TEST_F(UDPSocketTest, UnonnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // CreateHandle won't have completed yet. Set passes.
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF42), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1147
// TODO(zstein): Mocking out DscpManager might be simpler here
// (just verify that DscpManager::Set and DscpManager::PrepareForSend are
// called).
//
// Checks SendTo() on an unconnected socket with a DSCP value set:
// - AddSocketToFlow() is expected once per destination address.
// - SetFlow() is expected only once overall.
TEST_F(UDPSocketTest, SendToCallsQwaveApis) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());
  RunUntilIdle();

  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));
  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // TODO(zstein): Move to second test case (Qwave APIs called once per address)
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // TODO(zstein): Move to third test case (Qwave APIs called for each
  // destination address).
  // This expectation does not set a flow id. Presumably the existing flow
  // (kFakeFlowId1) is reused, which is why no new SetFlow() is expected.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(true));
  IPEndPoint server_address2(IPAddress::IPv4Localhost(), 9439);

  rv = SendToSocket(client.get(), simple_message, server_address2);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, _, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1184
// On an unconnected socket, a DSCP value set before the QWAVE handle exists is
// remembered. Sends before initialization skip QWAVE; the first send
// afterwards applies the codepoint.
TEST_F(UDPSocketTest, SendToCallsApisAfterDeferredInit) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // SetDiffServCodepoint works even if qos api hasn't finished initing.
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_CS7), IsOk());

  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);

  // SendTo works, but doesn't yet apply TOS
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  RunUntilIdle();
  // Now we're initialized, SendTo triggers qos calls with correct codepoint.
  // The newer expectation below takes precedence over the Times(0) one above.
  // DSCP_CS7 is expected to map to QOSTrafficTypeControl.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _)).WillOnce(Return(true));
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1215
1216 class DscpManagerTest : public TestWithTaskEnvironment {
1217 protected:
DscpManagerTest()1218 DscpManagerTest() {
1219 EXPECT_CALL(api_, qwave_supported()).WillRepeatedly(Return(true));
1220 EXPECT_CALL(api_, CreateHandle(_, _))
1221 .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
1222 dscp_manager_ = std::make_unique<DscpManager>(&api_, INVALID_SOCKET);
1223
1224 CreateUDPAddress("1.2.3.4", 9001, &address1_);
1225 CreateUDPAddress("1234:5678:90ab:cdef:1234:5678:90ab:cdef", 9002,
1226 &address2_);
1227 }
1228
1229 MockQwaveApi api_;
1230 std::unique_ptr<DscpManager> dscp_manager_;
1231
1232 IPEndPoint address1_;
1233 IPEndPoint address2_;
1234 };
1235
// Without a prior Set(), PrepareForSend() must not touch QWAVE flows.
// Explicit Times(0) expectations make stray calls fail the test; otherwise
// gMock would only log them as uninteresting calls.
TEST_F(DscpManagerTest, PrepareForSendIsNoopIfNoSet) {
  RunUntilIdle();

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1240
// After Set(), PrepareForSend() adds the socket to a QWAVE flow for each new
// destination. SetFlow() is called only when the flow is first created.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisAfterSet) {
  // Let the asynchronous QWAVE handle initialization finish first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // AddSocketToFlow should be called for each address.
  // SetFlow should only be called when the flow is first created.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // The second destination joins the same flow (kFakeFlowId1), so no SetFlow().
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address2_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1261
// Repeated PrepareForSend() calls for the same destination hit QWAVE only the
// first time.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisOncePerAddress) {
  // Let the asynchronous QWAVE handle initialization finish first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);
  // Same address again: no further QWAVE calls are allowed.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1278
// Calling Set() with a new codepoint removes the socket from its existing
// flow. The next PrepareForSend() then creates and configures a new flow.
TEST_F(DscpManagerTest, SetDestroysExistingFlow) {
  // Let the asynchronous QWAVE handle initialization finish first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Calling Set should destroy the existing flow.
  // TODO(zstein): Verify that RemoveSocketFromFlow with no address
  // destroys the flow for all destinations.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, NULL, kFakeFlowId1, _));
  dscp_manager_->Set(DSCP_CS5);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1303
// Recovery after AddSocketToFlow() fails with the last error set to
// ERROR_DEVICE_REINITIALIZATION_NEEDED:
// - PrepareForSend() returns ERR_INVALID_HANDLE.
// - The old handle is closed and a new one is created.
// - After the new handle is ready, PrepareForSend() re-adds the socket with
//   the previously Set() codepoint.
TEST_F(DscpManagerTest, SocketReAddedOnRecreateHandle) {
  // Let the asynchronous QWAVE handle initialization finish first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // First Set and Send work fine.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Make Second flow operation fail (requires resetting the codepoint).
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _))
      .WillOnce(Return(true));
  dscp_manager_->Set(DSCP_CS7);

  // ScopedClearLastError restores the thread's previous last-error value when
  // destroyed (error = nullptr below). Meanwhile the failing AddSocketToFlow()
  // leaves ERROR_DEVICE_REINITIALIZATION_NEEDED as the last error, presumably
  // read by DscpManager to decide to recreate the handle.
  auto error = std::make_unique<base::ScopedClearLastError>();
  ::SetLastError(ERROR_DEVICE_REINITIALIZATION_NEEDED);
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(false));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
  EXPECT_CALL(api_, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle2), Return(true)));
  EXPECT_EQ(ERR_INVALID_HANDLE, dscp_manager_->PrepareForSend(address1_));
  error = nullptr;
  RunUntilIdle();

  // Next Send should work fine, without requiring another Set
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle2));
}
1342
1343 #endif
1344
// A client with the receive optimization enabled can read a datagram sent by
// the server.
TEST_F(UDPSocketTest, ReadWithSocketOptimization) {
  std::string simple_message("hello world!");

  // Set up the server to listen.
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
  UDPServerSocket server(nullptr, NetLogSource());
  server.AllowAddressReuse();
  ASSERT_THAT(server.Listen(server_address), IsOk());
  // Get bound port.
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  // Set up the client, enable the experimental optimization and connect to
  // the server. Later steps depend on these succeeding, so ASSERT to avoid
  // operating on an unusable socket.
  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  client.EnableRecvOptimization();
  ASSERT_THAT(client.Connect(server_address), IsOk());

  // Get the client's address.
  IPEndPoint client_address;
  ASSERT_THAT(client.GetLocalAddress(&client_address), IsOk());

  // Server sends the message to the client. If the send fails, the read below
  // could wait for a datagram that never arrives, so stop here.
  ASSERT_EQ(simple_message.length(),
            static_cast<size_t>(
                SendToSocket(&server, simple_message, client_address)));

  // Client receives the message.
  std::string str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);

  server.Close();
  client.Close();
}
1378
1379 // Tests that read from a socket correctly returns
1380 // |ERR_MSG_TOO_BIG| when the buffer is too small and
1381 // returns the actual message when it fits the buffer.
1382 // For the optimized path, the buffer size should be at least
1383 // 1 byte greater than the message.
TEST_F(UDPSocketTest,ReadWithSocketOptimizationTruncation)1384 TEST_F(UDPSocketTest, ReadWithSocketOptimizationTruncation) {
1385 std::string too_long_message(kMaxRead + 1, 'A');
1386 std::string right_length_message(kMaxRead - 1, 'B');
1387 std::string exact_length_message(kMaxRead, 'C');
1388
1389 // Setup the server to listen.
1390 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
1391 UDPServerSocket server(nullptr, NetLogSource());
1392 server.AllowAddressReuse();
1393 ASSERT_THAT(server.Listen(server_address), IsOk());
1394 // Get bound port.
1395 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1396
1397 // Setup the client, enable experimental optimization and connected to the
1398 // server.
1399 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1400 client.EnableRecvOptimization();
1401 EXPECT_THAT(client.Connect(server_address), IsOk());
1402
1403 // Get the client's address.
1404 IPEndPoint client_address;
1405 EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
1406
1407 // Send messages to the client.
1408 EXPECT_EQ(too_long_message.length(),
1409 static_cast<size_t>(
1410 SendToSocket(&server, too_long_message, client_address)));
1411 EXPECT_EQ(right_length_message.length(),
1412 static_cast<size_t>(
1413 SendToSocket(&server, right_length_message, client_address)));
1414 EXPECT_EQ(exact_length_message.length(),
1415 static_cast<size_t>(
1416 SendToSocket(&server, exact_length_message, client_address)));
1417
1418 // Client receives the messages.
1419
1420 // 1. The first message is |too_long_message|. Its size exceeds the buffer.
1421 // In that case, the client is expected to get |ERR_MSG_TOO_BIG| when the
1422 // data is read.
1423 TestCompletionCallback callback;
1424 int rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1425 EXPECT_EQ(ERR_MSG_TOO_BIG, callback.GetResult(rv));
1426
1427 // 2. The second message is |right_length_message|. Its size is
1428 // one byte smaller than the size of the buffer. In that case, the client
1429 // is expected to read the whole message successfully.
1430 rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1431 rv = callback.GetResult(rv);
1432 EXPECT_EQ(static_cast<int>(right_length_message.length()), rv);
1433 EXPECT_EQ(right_length_message, std::string(buffer_->data(), rv));
1434
1435 // 3. The third message is |exact_length_message|. Its size is equal to
1436 // the read buffer size. In that case, the client expects to get
1437 // |ERR_MSG_TOO_BIG| when the socket is read. Internally, the optimized
1438 // path uses read() system call that requires one extra byte to detect
1439 // truncated messages; therefore, messages that fill the buffer exactly
1440 // are considered truncated.
1441 // The optimization is only enabled on POSIX platforms. On Windows,
1442 // the optimization is turned off; therefore, the client
1443 // should be able to read the whole message without encountering
1444 // |ERR_MSG_TOO_BIG|.
1445 rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1446 rv = callback.GetResult(rv);
1447 #if BUILDFLAG(IS_POSIX)
1448 EXPECT_EQ(ERR_MSG_TOO_BIG, rv);
1449 #else
1450 EXPECT_EQ(static_cast<int>(exact_length_message.length()), rv);
1451 EXPECT_EQ(exact_length_message, std::string(buffer_->data(), rv));
1452 #endif
1453 server.Close();
1454 client.Close();
1455 }
1456
// On Android, where socket tagging is supported, verify that
// UDPClientSocket::ApplySocketTag() tags and counts UDP traffic as expected.
1459 #if BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest, Tag) {
  // Report the test as skipped (not passed) when tagged-byte counters are
  // unavailable, consistent with the other Android-only tests in this file.
  if (!CanGetTaggedBytes()) {
    GTEST_SKIP() << "GetTaggedBytes unsupported.";
  }

  UDPServerSocket server(nullptr, NetLogSource());
  ASSERT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  ASSERT_THAT(client.Connect(server_address), IsOk());

  const std::string simple_message("hello world!");

  // Sends |simple_message| from the client to the server, has the server echo
  // it back, and checks that both legs deliver the payload intact.
  auto round_trip = [&]() {
    // Client sends to the server.
    int rv = WriteSocket(&client, simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    // Server waits for message.
    EXPECT_EQ(simple_message, RecvFromSocket(&server));
    // Server echoes reply.
    rv = SendToSocket(&server, simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    // Client waits for response.
    EXPECT_EQ(simple_message, ReadSocket(&client));
  };

  // Verify UDP packets are tagged and counted properly.
  int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  client.ApplySocketTag(tag1);
  round_trip();
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Verify socket can be retagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  client.ApplySocketTag(tag2);
  round_trip();
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged back to the original tag (unset UID,
  // |tag_val1|).
  old_traffic = GetTaggedBytes(tag_val1);
  client.ApplySocketTag(tag1);
  round_trip();
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
}
1532
TEST_F(UDPSocketTest, BindToNetwork) {
  // Connect() is only called to exercise the network binding. No server needs
  // to listen at this address, and no datagrams are sent.
  const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);

  NetworkChangeNotifierFactoryAndroid ncn_factory;
  NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
  std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // A socket bound to a nonexistent network must fail at connect time. The
  // exact error differs between Android versions, so only rule out success
  // and ERR_NOT_IMPLEMENTED.
  const handles::NetworkHandle bogus_network = 65536;
  UDPClientSocket bogus_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), bogus_network);
  const int connect_result = bogus_socket.Connect(fake_server_address);
  EXPECT_NE(OK, connect_result);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, connect_result);
  EXPECT_NE(bogus_network, bogus_socket.GetBoundNetwork());

  // A socket bound to the current default network (when there is one) must
  // connect and report that network as its bound network.
  const handles::NetworkHandle default_network =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (default_network == handles::kInvalidNetworkHandle) {
    return;
  }
  UDPClientSocket bound_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), default_network);
  EXPECT_EQ(OK, bound_socket.Connect(fake_server_address));
  EXPECT_EQ(default_network, bound_socket.GetBoundNetwork());
}
1565
1566 #endif // BUILDFLAG(IS_ANDROID)
1567
1568 // Scoped helper to override the process-wide UDP socket limit.
1569 class OverrideUDPSocketLimit {
1570 public:
OverrideUDPSocketLimit(int new_limit)1571 explicit OverrideUDPSocketLimit(int new_limit) {
1572 base::FieldTrialParams params;
1573 params[features::kLimitOpenUDPSocketsMax.name] =
1574 base::NumberToString(new_limit);
1575
1576 scoped_feature_list_.InitAndEnableFeatureWithParameters(
1577 features::kLimitOpenUDPSockets, params);
1578 }
1579
1580 private:
1581 base::test::ScopedFeatureList scoped_feature_list_;
1582 };
1583
1584 // Tests that UDPClientSocket respects the global UDP socket limits.
TEST_F(UDPSocketTest,LimitClientSocket)1585 TEST_F(UDPSocketTest, LimitClientSocket) {
1586 // Reduce the global UDP limit to 2.
1587 OverrideUDPSocketLimit set_limit(2);
1588
1589 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1590
1591 auto socket1 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1592 nullptr, NetLogSource());
1593 auto socket2 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1594 nullptr, NetLogSource());
1595
1596 // Simply constructing a UDPClientSocket does not increase the limit (no
1597 // Connect() or Bind() has been called yet).
1598 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1599
1600 // The specific value of this address doesn't really matter, and no server
1601 // needs to be running here. The test only needs to call Connect() and won't
1602 // send any datagrams.
1603 IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1604
1605 // Successful Connect() on socket1 increases socket count.
1606 EXPECT_THAT(socket1->Connect(server_address), IsOk());
1607 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1608
1609 // Successful Connect() on socket2 increases socket count.
1610 EXPECT_THAT(socket2->Connect(server_address), IsOk());
1611 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1612
1613 // Attempting a third Connect() should fail with ERR_INSUFFICIENT_RESOURCES,
1614 // as the limit is currently 2.
1615 auto socket3 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1616 nullptr, NetLogSource());
1617 EXPECT_THAT(socket3->Connect(server_address),
1618 IsError(ERR_INSUFFICIENT_RESOURCES));
1619 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1620
1621 // Check that explicitly closing socket2 free up a count.
1622 socket2->Close();
1623 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1624
1625 // Since the socket was already closed, deleting it will not affect the count.
1626 socket2.reset();
1627 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1628
1629 // Now that the count is below limit, try to connect another socket. This time
1630 // it will work.
1631 auto socket4 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1632 nullptr, NetLogSource());
1633 EXPECT_THAT(socket4->Connect(server_address), IsOk());
1634 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1635
1636 // Verify that closing the two remaining sockets brings the open count back to
1637 // 0.
1638 socket1.reset();
1639 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1640 socket4.reset();
1641 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1642 }
1643
1644 // Tests that UDPSocketClient updates the global counter
1645 // correctly when Connect() fails.
TEST_F(UDPSocketTest,LimitConnectFail)1646 TEST_F(UDPSocketTest, LimitConnectFail) {
1647 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1648
1649 {
1650 // Simply allocating a UDPSocket does not increase count.
1651 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1652 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1653
1654 // Calling Open() allocates the socket and increases the global counter.
1655 EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
1656 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1657
1658 // Connect to an IPv6 address should fail since the socket was created for
1659 // IPv4.
1660 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
1661 Not(IsOk()));
1662
1663 // That Connect() failed doesn't change the global counter.
1664 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1665 }
1666
1667 // Finally, destroying UDPSocket decrements the global counter.
1668 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1669 }
1670
1671 // Tests allocating UDPClientSockets and Connect()ing them in parallel.
1672 //
1673 // This is primarily intended for coverage under TSAN, to check for races
1674 // enforcing the global socket counter.
TEST_F(UDPSocketTest,LimitConnectMultithreaded)1675 TEST_F(UDPSocketTest, LimitConnectMultithreaded) {
1676 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1677
1678 // Start up some threads.
1679 std::vector<std::unique_ptr<base::Thread>> threads;
1680 for (size_t i = 0; i < 5; ++i) {
1681 threads.push_back(std::make_unique<base::Thread>("Worker thread"));
1682 ASSERT_TRUE(threads.back()->Start());
1683 }
1684
1685 // Post tasks to each of the threads.
1686 for (const auto& thread : threads) {
1687 thread->task_runner()->PostTask(
1688 FROM_HERE, base::BindOnce([] {
1689 // The specific value of this address doesn't really matter, and no
1690 // server needs to be running here. The test only needs to call
1691 // Connect() and won't send any datagrams.
1692 IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1693
1694 UDPClientSocket socket(DatagramSocket::DEFAULT_BIND, nullptr,
1695 NetLogSource());
1696 EXPECT_THAT(socket.Connect(server_address), IsOk());
1697 }));
1698 }
1699
1700 // Complete all the tasks.
1701 threads.clear();
1702
1703 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1704 }
1705
1706 } // namespace net
1707