1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/udp_socket.h"
6
7 #include <algorithm>
8
9 #include "base/containers/circular_deque.h"
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/run_loop.h"
15 #include "base/scoped_clear_last_error.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/test/scoped_feature_list.h"
19 #include "base/threading/thread.h"
20 #include "base/time/time.h"
21 #include "build/build_config.h"
22 #include "build/chromeos_buildflags.h"
23 #include "net/base/features.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/ip_address.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_interfaces.h"
29 #include "net/base/test_completion_callback.h"
30 #include "net/log/net_log_event_type.h"
31 #include "net/log/net_log_source.h"
32 #include "net/log/test_net_log.h"
33 #include "net/log/test_net_log_util.h"
34 #include "net/socket/socket_test_util.h"
35 #include "net/socket/udp_client_socket.h"
36 #include "net/socket/udp_server_socket.h"
37 #include "net/socket/udp_socket_global_limits.h"
38 #include "net/test/gtest_util.h"
39 #include "net/test/test_with_task_environment.h"
40 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
41 #include "testing/gmock/include/gmock/gmock.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43 #include "testing/platform_test.h"
44
45 #if !BUILDFLAG(IS_WIN)
46 #include <netinet/in.h>
47 #include <sys/socket.h>
48 #else
49 #include <winsock2.h>
50 #endif
51
52 #if BUILDFLAG(IS_ANDROID)
53 #include "base/android/build_info.h"
54 #include "net/android/network_change_notifier_factory_android.h"
55 #include "net/base/network_change_notifier.h"
56 #endif
57
58 #if BUILDFLAG(IS_IOS)
59 #include <TargetConditionals.h>
60 #endif
61
62 #if BUILDFLAG(IS_MAC)
63 #include "base/mac/mac_util.h"
64 #endif // BUILDFLAG(IS_MAC)
65
66 using net::test::IsError;
67 using net::test::IsOk;
68 using testing::DoAll;
69 using testing::Not;
70
71 namespace net {
72
73 namespace {
74
75 // Creates an address from ip address and port and writes it to |*address|.
CreateUDPAddress(const std::string & ip_str,uint16_t port,IPEndPoint * address)76 bool CreateUDPAddress(const std::string& ip_str,
77 uint16_t port,
78 IPEndPoint* address) {
79 IPAddress ip_address;
80 if (!ip_address.AssignFromIPLiteral(ip_str))
81 return false;
82
83 *address = IPEndPoint(ip_address, port);
84 return true;
85 }
86
87 class UDPSocketTest : public PlatformTest, public WithTaskEnvironment {
88 public:
UDPSocketTest()89 UDPSocketTest() : buffer_(base::MakeRefCounted<IOBufferWithSize>(kMaxRead)) {}
90
91 // Blocks until data is read from the socket.
RecvFromSocket(UDPServerSocket * socket)92 std::string RecvFromSocket(UDPServerSocket* socket) {
93 TestCompletionCallback callback;
94
95 int rv = socket->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
96 callback.callback());
97 rv = callback.GetResult(rv);
98 if (rv < 0)
99 return std::string();
100 return std::string(buffer_->data(), rv);
101 }
102
103 // Sends UDP packet.
104 // If |address| is specified, then it is used for the destination
105 // to send to. Otherwise, will send to the last socket this server
106 // received from.
SendToSocket(UDPServerSocket * socket,const std::string & msg)107 int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
108 return SendToSocket(socket, msg, recv_from_address_);
109 }
110
SendToSocket(UDPServerSocket * socket,std::string msg,const IPEndPoint & address)111 int SendToSocket(UDPServerSocket* socket,
112 std::string msg,
113 const IPEndPoint& address) {
114 scoped_refptr<StringIOBuffer> io_buffer =
115 base::MakeRefCounted<StringIOBuffer>(msg);
116 TestCompletionCallback callback;
117 int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
118 callback.callback());
119 return callback.GetResult(rv);
120 }
121
ReadSocket(UDPClientSocket * socket)122 std::string ReadSocket(UDPClientSocket* socket) {
123 TestCompletionCallback callback;
124
125 int rv = socket->Read(buffer_.get(), kMaxRead, callback.callback());
126 rv = callback.GetResult(rv);
127 if (rv < 0)
128 return std::string();
129 return std::string(buffer_->data(), rv);
130 }
131
132 // Writes specified message to the socket.
WriteSocket(UDPClientSocket * socket,const std::string & msg)133 int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
134 scoped_refptr<StringIOBuffer> io_buffer =
135 base::MakeRefCounted<StringIOBuffer>(msg);
136 TestCompletionCallback callback;
137 int rv = socket->Write(io_buffer.get(), io_buffer->size(),
138 callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
139 return callback.GetResult(rv);
140 }
141
WriteSocketIgnoreResult(UDPClientSocket * socket,const std::string & msg)142 void WriteSocketIgnoreResult(UDPClientSocket* socket,
143 const std::string& msg) {
144 WriteSocket(socket, msg);
145 }
146
147 // And again for a bare socket
SendToSocket(UDPSocket * socket,std::string msg,const IPEndPoint & address)148 int SendToSocket(UDPSocket* socket,
149 std::string msg,
150 const IPEndPoint& address) {
151 auto io_buffer = base::MakeRefCounted<StringIOBuffer>(msg);
152 TestCompletionCallback callback;
153 int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
154 callback.callback());
155 return callback.GetResult(rv);
156 }
157
158 // Run unit test for a connection test.
159 // |use_nonblocking_io| is used to switch between overlapped and non-blocking
160 // IO on Windows. It has no effect in other ports.
161 void ConnectTest(bool use_nonblocking_io, bool use_async);
162
163 protected:
164 static const int kMaxRead = 1024;
165 scoped_refptr<IOBufferWithSize> buffer_;
166 IPEndPoint recv_from_address_;
167 };
168
169 const int UDPSocketTest::kMaxRead;
170
ReadCompleteCallback(int * result_out,base::OnceClosure callback,int result)171 void ReadCompleteCallback(int* result_out,
172 base::OnceClosure callback,
173 int result) {
174 *result_out = result;
175 std::move(callback).Run();
176 }
177
ConnectTest(bool use_nonblocking_io,bool use_async)178 void UDPSocketTest::ConnectTest(bool use_nonblocking_io, bool use_async) {
179 std::string simple_message("hello world!");
180 RecordingNetLogObserver net_log_observer;
181 // Setup the server to listen.
182 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
183 auto server =
184 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
185 if (use_nonblocking_io)
186 server->UseNonBlockingIO();
187 server->AllowAddressReuse();
188 ASSERT_THAT(server->Listen(server_address), IsOk());
189 // Get bound port.
190 ASSERT_THAT(server->GetLocalAddress(&server_address), IsOk());
191
192 // Setup the client.
193 auto client = std::make_unique<UDPClientSocket>(
194 DatagramSocket::DEFAULT_BIND, NetLog::Get(), NetLogSource());
195 if (use_nonblocking_io)
196 client->UseNonBlockingIO();
197
198 if (!use_async) {
199 EXPECT_THAT(client->Connect(server_address), IsOk());
200 } else {
201 TestCompletionCallback callback;
202 int rv = client->ConnectAsync(server_address, callback.callback());
203 if (rv != OK) {
204 ASSERT_EQ(rv, ERR_IO_PENDING);
205 rv = callback.WaitForResult();
206 EXPECT_EQ(rv, OK);
207 } else {
208 EXPECT_EQ(rv, OK);
209 }
210 }
211 // Client sends to the server.
212 EXPECT_EQ(simple_message.length(),
213 static_cast<size_t>(WriteSocket(client.get(), simple_message)));
214
215 // Server waits for message.
216 std::string str = RecvFromSocket(server.get());
217 EXPECT_EQ(simple_message, str);
218
219 // Server echoes reply.
220 EXPECT_EQ(simple_message.length(),
221 static_cast<size_t>(SendToSocket(server.get(), simple_message)));
222
223 // Client waits for response.
224 str = ReadSocket(client.get());
225 EXPECT_EQ(simple_message, str);
226
227 // Test asynchronous read. Server waits for message.
228 base::RunLoop run_loop;
229 int read_result = 0;
230 int rv = server->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
231 base::BindOnce(&ReadCompleteCallback, &read_result,
232 run_loop.QuitClosure()));
233 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
234
235 // Client sends to the server.
236 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
237 FROM_HERE,
238 base::BindOnce(&UDPSocketTest::WriteSocketIgnoreResult,
239 base::Unretained(this), client.get(), simple_message));
240 run_loop.Run();
241 EXPECT_EQ(simple_message.length(), static_cast<size_t>(read_result));
242 EXPECT_EQ(simple_message, std::string(buffer_->data(), read_result));
243
244 NetLogSource server_net_log_source = server->NetLog().source();
245 NetLogSource client_net_log_source = client->NetLog().source();
246
247 // Delete sockets so they log their final events.
248 server.reset();
249 client.reset();
250
251 // Check the server's log.
252 auto server_entries =
253 net_log_observer.GetEntriesForSource(server_net_log_source);
254 ASSERT_EQ(6u, server_entries.size());
255 EXPECT_TRUE(
256 LogContainsBeginEvent(server_entries, 0, NetLogEventType::SOCKET_ALIVE));
257 EXPECT_TRUE(LogContainsEvent(server_entries, 1,
258 NetLogEventType::UDP_LOCAL_ADDRESS,
259 NetLogEventPhase::NONE));
260 EXPECT_TRUE(LogContainsEvent(server_entries, 2,
261 NetLogEventType::UDP_BYTES_RECEIVED,
262 NetLogEventPhase::NONE));
263 EXPECT_TRUE(LogContainsEvent(server_entries, 3,
264 NetLogEventType::UDP_BYTES_SENT,
265 NetLogEventPhase::NONE));
266 EXPECT_TRUE(LogContainsEvent(server_entries, 4,
267 NetLogEventType::UDP_BYTES_RECEIVED,
268 NetLogEventPhase::NONE));
269 EXPECT_TRUE(
270 LogContainsEndEvent(server_entries, 5, NetLogEventType::SOCKET_ALIVE));
271
272 // Check the client's log.
273 auto client_entries =
274 net_log_observer.GetEntriesForSource(client_net_log_source);
275 EXPECT_EQ(7u, client_entries.size());
276 EXPECT_TRUE(
277 LogContainsBeginEvent(client_entries, 0, NetLogEventType::SOCKET_ALIVE));
278 EXPECT_TRUE(
279 LogContainsBeginEvent(client_entries, 1, NetLogEventType::UDP_CONNECT));
280 EXPECT_TRUE(
281 LogContainsEndEvent(client_entries, 2, NetLogEventType::UDP_CONNECT));
282 EXPECT_TRUE(LogContainsEvent(client_entries, 3,
283 NetLogEventType::UDP_BYTES_SENT,
284 NetLogEventPhase::NONE));
285 EXPECT_TRUE(LogContainsEvent(client_entries, 4,
286 NetLogEventType::UDP_BYTES_RECEIVED,
287 NetLogEventPhase::NONE));
288 EXPECT_TRUE(LogContainsEvent(client_entries, 5,
289 NetLogEventType::UDP_BYTES_SENT,
290 NetLogEventPhase::NONE));
291 EXPECT_TRUE(
292 LogContainsEndEvent(client_entries, 6, NetLogEventType::SOCKET_ALIVE));
293 }
294
TEST_F(UDPSocketTest, Connect) {
  // |use_nonblocking_io| only matters on Windows, so it stays false here.
  // Exercise the synchronous connect path first, then the asynchronous one.
  for (bool use_async : {false, true})
    ConnectTest(/*use_nonblocking_io=*/false, use_async);
}
301
302 #if BUILDFLAG(IS_WIN)
TEST_F(UDPSocketTest,ConnectNonBlocking)303 TEST_F(UDPSocketTest, ConnectNonBlocking) {
304 ConnectTest(true, false);
305 ConnectTest(true, true);
306 }
307 #endif
308
// Verifies that receiving a datagram into a buffer smaller than the datagram
// fails with ERR_MSG_TOO_BIG, and that the socket stays usable afterwards:
// the next datagram is received in full.
TEST_F(UDPSocketTest, PartialRecv) {
  UDPServerSocket server_socket(nullptr, NetLogSource());
  ASSERT_THAT(server_socket.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server_socket.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
                                NetLogSource());
  ASSERT_THAT(client_socket.Connect(server_address), IsOk());

  std::string test_packet("hello world!");
  ASSERT_EQ(static_cast<int>(test_packet.size()),
            WriteSocket(&client_socket, test_packet));

  TestCompletionCallback recv_callback;

  // Receive into a 2-byte buffer. The datagram does not fit, so RecvFrom() is
  // expected to report ERR_MSG_TOO_BIG instead of returning a truncated
  // payload.
  const int kPartialReadSize = 2;
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(kPartialReadSize);
  int rv =
      server_socket.RecvFrom(buffer.get(), kPartialReadSize,
                             &recv_from_address_, recv_callback.callback());
  rv = recv_callback.GetResult(rv);

  EXPECT_EQ(rv, ERR_MSG_TOO_BIG);

  // Send a different message again.
  std::string second_packet("Second packet");
  ASSERT_EQ(static_cast<int>(second_packet.size()),
            WriteSocket(&client_socket, second_packet));

  // Read the whole packet now. Getting the second packet (rather than the
  // first one again) shows the oversized datagram was not left queued.
  std::string received = RecvFromSocket(&server_socket);
  EXPECT_EQ(second_packet, received);
}
346
347 #if BUILDFLAG(IS_APPLE) || BUILDFLAG(IS_ANDROID)
348 // - MacOS: requires root permissions on OSX 10.7+.
349 // - Android: devices attached to testbots don't have default network, so
350 // broadcasting to 255.255.255.255 returns error -109 (Address not reachable).
351 // crbug.com/139144.
352 #define MAYBE_LocalBroadcast DISABLED_LocalBroadcast
353 #else
354 #define MAYBE_LocalBroadcast LocalBroadcast
355 #endif
TEST_F(UDPSocketTest,MAYBE_LocalBroadcast)356 TEST_F(UDPSocketTest, MAYBE_LocalBroadcast) {
357 std::string first_message("first message"), second_message("second message");
358
359 IPEndPoint listen_address;
360 ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &listen_address));
361
362 auto server1 =
363 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
364 auto server2 =
365 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
366 server1->AllowAddressReuse();
367 server1->AllowBroadcast();
368 server2->AllowAddressReuse();
369 server2->AllowBroadcast();
370
371 EXPECT_THAT(server1->Listen(listen_address), IsOk());
372 // Get bound port.
373 EXPECT_THAT(server1->GetLocalAddress(&listen_address), IsOk());
374 EXPECT_THAT(server2->Listen(listen_address), IsOk());
375
376 IPEndPoint broadcast_address;
377 ASSERT_TRUE(CreateUDPAddress("127.255.255.255", listen_address.port(),
378 &broadcast_address));
379 ASSERT_EQ(static_cast<int>(first_message.size()),
380 SendToSocket(server1.get(), first_message, broadcast_address));
381 std::string str = RecvFromSocket(server1.get());
382 ASSERT_EQ(first_message, str);
383 str = RecvFromSocket(server2.get());
384 ASSERT_EQ(first_message, str);
385
386 ASSERT_EQ(static_cast<int>(second_message.size()),
387 SendToSocket(server2.get(), second_message, broadcast_address));
388 str = RecvFromSocket(server1.get());
389 ASSERT_EQ(second_message, str);
390 str = RecvFromSocket(server2.get());
391 ASSERT_EQ(second_message, str);
392 }
393
394 // ConnectRandomBind verifies RANDOM_BIND is handled correctly. It connects
395 // 1000 sockets and then verifies that the allocated port numbers satisfy the
396 // following 2 conditions:
397 // 1. Range from min port value to max is greater than 10000.
398 // 2. There is at least one port in the 5 buckets in the [min, max] range.
399 //
400 // These conditions are not enough to verify that the port numbers are truly
401 // random, but they are enough to protect from most common non-random port
402 // allocation strategies (e.g. counter, pool of available ports, etc.) False
403 // positive result is theoretically possible, but its probability is negligible.
TEST_F(UDPSocketTest,ConnectRandomBind)404 TEST_F(UDPSocketTest, ConnectRandomBind) {
405 const int kIterations = 1000;
406
407 std::vector<int> used_ports;
408 for (int i = 0; i < kIterations; ++i) {
409 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
410 NetLogSource());
411 EXPECT_THAT(socket.Connect(IPEndPoint(IPAddress::IPv4Localhost(), 53)),
412 IsOk());
413
414 IPEndPoint client_address;
415 EXPECT_THAT(socket.GetLocalAddress(&client_address), IsOk());
416 used_ports.push_back(client_address.port());
417 }
418
419 int min_port = *std::min_element(used_ports.begin(), used_ports.end());
420 int max_port = *std::max_element(used_ports.begin(), used_ports.end());
421 int range = max_port - min_port + 1;
422
423 // Verify that the range of ports used by the random port allocator is wider
424 // than 10k. Assuming that socket implementation limits port range to 16k
425 // ports (default on Fuchsia) probability of false negative is below
426 // 10^-200.
427 static int kMinRange = 10000;
428 EXPECT_GT(range, kMinRange);
429
430 static int kBuckets = 5;
431 std::vector<int> bucket_sizes(kBuckets, 0);
432 for (int port : used_ports) {
433 bucket_sizes[(port - min_port) * kBuckets / range] += 1;
434 }
435
436 // Verify that there is at least one value in each bucket. Probability of
437 // false negative is below (kBuckets * (1 - 1 / kBuckets) ^ kIterations),
438 // which is less than 10^-96.
439 for (int size : bucket_sizes) {
440 EXPECT_GT(size, 0);
441 }
442 }
443
// Connecting an IPv4 socket to an IPv6 destination must fail, and the failed
// Connect() must leave the socket in the unconnected state.
TEST_F(UDPSocketTest, ConnectFail) {
  UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());

  // Address family mismatch: the socket was opened for IPv4.
  const net::IPEndPoint ipv6_destination(IPAddress::IPv6Localhost(), 53);
  const int connect_result = socket.Connect(ipv6_destination);
  EXPECT_THAT(connect_result, Not(IsOk()));

  // After the failure the socket must not report itself as connected.
  EXPECT_FALSE(socket.is_connected());
}
457
458 // Similar to ConnectFail but UDPSocket adopts an opened socket instead of
459 // opening one directly.
TEST_F(UDPSocketTest,AdoptedSocket)460 TEST_F(UDPSocketTest, AdoptedSocket) {
461 auto socketfd =
462 CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
463 SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
464 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
465
466 EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd), IsOk());
467
468 // Connect to an IPv6 address should fail since the socket was created for
469 // IPv4.
470 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
471 Not(IsOk()));
472
473 // Make sure that UDPSocket actually closed the socket.
474 EXPECT_FALSE(socket.is_connected());
475 }
476
477 // Tests that UDPSocket updates the global counter correctly.
TEST_F(UDPSocketTest,LimitAdoptSocket)478 TEST_F(UDPSocketTest, LimitAdoptSocket) {
479 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
480 {
481 // Creating a platform socket does not increase count.
482 auto socketfd =
483 CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
484 SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
485 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
486
487 // Simply allocating a UDPSocket does not increase count.
488 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
489 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
490
491 // Calling AdoptOpenedSocket() allocates the socket and increases the global
492 // counter.
493 EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd),
494 IsOk());
495 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
496
497 // Connect to an IPv6 address should fail since the socket was created for
498 // IPv4.
499 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
500 Not(IsOk()));
501
502 // That Connect() failed doesn't change the global counter.
503 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
504 }
505 // Finally, destroying UDPSocket decrements the global counter.
506 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
507 }
508
509 // In this test, we verify that connect() on a socket will have the effect
510 // of filtering reads on this socket only to data read from the destination
511 // we connected to.
512 //
513 // The purpose of this test is that some documentation indicates that connect
514 // binds the client's sends to send to a particular server endpoint, but does
515 // not bind the client's reads to only be from that endpoint, and that we need
516 // to always use recvfrom() to disambiguate.
TEST_F(UDPSocketTest,VerifyConnectBindsAddr)517 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
518 std::string simple_message("hello world!");
519 std::string foreign_message("BAD MESSAGE TO GET!!");
520
521 // Setup the first server to listen.
522 IPEndPoint server1_address(IPAddress::IPv4Localhost(), 0 /* port */);
523 UDPServerSocket server1(nullptr, NetLogSource());
524 ASSERT_THAT(server1.Listen(server1_address), IsOk());
525 // Get the bound port.
526 ASSERT_THAT(server1.GetLocalAddress(&server1_address), IsOk());
527
528 // Setup the second server to listen.
529 IPEndPoint server2_address(IPAddress::IPv4Localhost(), 0 /* port */);
530 UDPServerSocket server2(nullptr, NetLogSource());
531 ASSERT_THAT(server2.Listen(server2_address), IsOk());
532
533 // Setup the client, connected to server 1.
534 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
535 EXPECT_THAT(client.Connect(server1_address), IsOk());
536
537 // Client sends to server1.
538 EXPECT_EQ(simple_message.length(),
539 static_cast<size_t>(WriteSocket(&client, simple_message)));
540
541 // Server1 waits for message.
542 std::string str = RecvFromSocket(&server1);
543 EXPECT_EQ(simple_message, str);
544
545 // Get the client's address.
546 IPEndPoint client_address;
547 EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
548
549 // Server2 sends reply.
550 EXPECT_EQ(foreign_message.length(),
551 static_cast<size_t>(
552 SendToSocket(&server2, foreign_message, client_address)));
553
554 // Server1 sends reply.
555 EXPECT_EQ(simple_message.length(),
556 static_cast<size_t>(
557 SendToSocket(&server1, simple_message, client_address)));
558
559 // Client waits for response.
560 str = ReadSocket(&client);
561 EXPECT_EQ(simple_message, str);
562 }
563
TEST_F(UDPSocketTest,ClientGetLocalPeerAddresses)564 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
565 struct TestData {
566 std::string remote_address;
567 std::string local_address;
568 bool may_fail;
569 } tests[] = {
570 {"127.0.00.1", "127.0.0.1", false},
571 {"::1", "::1", true},
572 #if !BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_IOS)
573 // Addresses below are disabled on Android. See crbug.com/161248
574 // They are also disabled on iOS. See https://crbug.com/523225
575 {"192.168.1.1", "127.0.0.1", false},
576 {"2001:db8:0::42", "::1", true},
577 #endif
578 };
579 for (const auto& test : tests) {
580 SCOPED_TRACE(std::string("Connecting from ") + test.local_address +
581 std::string(" to ") + test.remote_address);
582
583 IPAddress ip_address;
584 EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.remote_address));
585 IPEndPoint remote_address(ip_address, 80);
586 EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.local_address));
587 IPEndPoint local_address(ip_address, 80);
588
589 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
590 NetLogSource());
591 int rv = client.Connect(remote_address);
592 if (test.may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
593 // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
594 // addresses if IPv6 is not configured.
595 continue;
596 }
597
598 EXPECT_LE(ERR_IO_PENDING, rv);
599
600 IPEndPoint fetched_local_address;
601 rv = client.GetLocalAddress(&fetched_local_address);
602 EXPECT_THAT(rv, IsOk());
603
604 // TODO(mbelshe): figure out how to verify the IP and port.
605 // The port is dynamically generated by the udp stack.
606 // The IP is the real IP of the client, not necessarily
607 // loopback.
608 // EXPECT_EQ(local_address.address(), fetched_local_address.address());
609
610 IPEndPoint fetched_remote_address;
611 rv = client.GetPeerAddress(&fetched_remote_address);
612 EXPECT_THAT(rv, IsOk());
613
614 EXPECT_EQ(remote_address, fetched_remote_address);
615 }
616 }
617
// Checks that a server bound to IPv4 localhost with port 0 reports the bound
// address together with an allocated, nonzero port.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  UDPServerSocket server(nullptr, NetLogSource());
  int rv = server.Listen(bind_address);
  ASSERT_THAT(rv, IsOk());

  IPEndPoint local_address;
  rv = server.GetLocalAddress(&local_address);
  ASSERT_THAT(rv, IsOk());

  // Verify that port was allocated.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(local_address.address(), bind_address.address());
}
632
// A listening but unconnected server socket has no peer, so GetPeerAddress()
// must report ERR_SOCKET_NOT_CONNECTED.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  UDPServerSocket server(nullptr, NetLogSource());
  const IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  EXPECT_THAT(server.Listen(bind_address), IsOk());

  IPEndPoint peer_address;
  const int peer_rv = server.GetPeerAddress(&peer_address);
  EXPECT_EQ(peer_rv, ERR_SOCKET_NOT_CONNECTED);
}
643
TEST_F(UDPSocketTest,ClientSetDoNotFragment)644 TEST_F(UDPSocketTest, ClientSetDoNotFragment) {
645 for (std::string ip : {"127.0.0.1", "::1"}) {
646 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
647 NetLogSource());
648 IPAddress ip_address;
649 EXPECT_TRUE(ip_address.AssignFromIPLiteral(ip));
650 IPEndPoint remote_address(ip_address, 80);
651 int rv = client.Connect(remote_address);
652 // May fail on IPv6 is IPv6 is not configured.
653 if (ip_address.IsIPv6() && rv == ERR_ADDRESS_UNREACHABLE)
654 return;
655 EXPECT_THAT(rv, IsOk());
656
657 rv = client.SetDoNotFragment();
658 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
659 // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
660 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
661 #elif BUILDFLAG(IS_MAC)
662 if (base::mac::MacOSMajorVersion() >= 11) {
663 EXPECT_THAT(rv, IsOk());
664 } else {
665 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
666 }
667 #else
668 EXPECT_THAT(rv, IsOk());
669 #endif
670 }
671 }
672
TEST_F(UDPSocketTest,ServerSetDoNotFragment)673 TEST_F(UDPSocketTest, ServerSetDoNotFragment) {
674 for (std::string ip : {"127.0.0.1", "::1"}) {
675 IPEndPoint bind_address;
676 ASSERT_TRUE(CreateUDPAddress(ip, 0, &bind_address));
677 UDPServerSocket server(nullptr, NetLogSource());
678 int rv = server.Listen(bind_address);
679 // May fail on IPv6 is IPv6 is not configure
680 if (bind_address.address().IsIPv6() &&
681 (rv == ERR_ADDRESS_INVALID || rv == ERR_ADDRESS_UNREACHABLE))
682 return;
683 EXPECT_THAT(rv, IsOk());
684
685 rv = server.SetDoNotFragment();
686 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
687 // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
688 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
689 #elif BUILDFLAG(IS_MAC)
690 if (base::mac::MacOSMajorVersion() >= 11) {
691 EXPECT_THAT(rv, IsOk());
692 } else {
693 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
694 }
695 #else
696 EXPECT_THAT(rv, IsOk());
697 #endif
698 }
699 }
700
701 // Close the socket while read is pending.
TEST_F(UDPSocketTest,CloseWithPendingRead)702 TEST_F(UDPSocketTest, CloseWithPendingRead) {
703 IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
704 UDPServerSocket server(nullptr, NetLogSource());
705 int rv = server.Listen(bind_address);
706 EXPECT_THAT(rv, IsOk());
707
708 TestCompletionCallback callback;
709 IPEndPoint from;
710 rv = server.RecvFrom(buffer_.get(), kMaxRead, &from, callback.callback());
711 EXPECT_EQ(rv, ERR_IO_PENDING);
712
713 server.Close();
714
715 EXPECT_FALSE(callback.have_result());
716 }
717
718 // Some Android devices do not support multicast.
// The ones that do support multicast need a WifiManager.MulticastLock to enable it.
720 // http://goo.gl/jjAk9
721 #if !BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,JoinMulticastGroup)722 TEST_F(UDPSocketTest, JoinMulticastGroup) {
723 const char kGroup[] = "237.132.100.17";
724
725 IPAddress group_ip;
726 EXPECT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
727 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
728 // OS_FUCHSIA.
729 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
730 IPEndPoint bind_address(IPAddress::AllZeros(group_ip.size()), 0 /* port */);
731 #else
732 IPEndPoint bind_address(group_ip, 0 /* port */);
733 #endif // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
734
735 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
736 EXPECT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
737
738 EXPECT_THAT(socket.Bind(bind_address), IsOk());
739 EXPECT_THAT(socket.JoinGroup(group_ip), IsOk());
740 // Joining group multiple times.
741 EXPECT_NE(OK, socket.JoinGroup(group_ip));
742 EXPECT_THAT(socket.LeaveGroup(group_ip), IsOk());
743 // Leaving group multiple times.
744 EXPECT_NE(OK, socket.LeaveGroup(group_ip));
745
746 socket.Close();
747 }
748
749 // TODO(https://crbug.com/947115): failing on device on iOS 12.2.
750 // TODO(https://crbug.com/1227554): flaky on Mac 11.
751 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_MAC)
752 #define MAYBE_SharedMulticastAddress DISABLED_SharedMulticastAddress
753 #else
754 #define MAYBE_SharedMulticastAddress SharedMulticastAddress
755 #endif
TEST_F(UDPSocketTest,MAYBE_SharedMulticastAddress)756 TEST_F(UDPSocketTest, MAYBE_SharedMulticastAddress) {
757 const char kGroup[] = "224.0.0.251";
758
759 IPAddress group_ip;
760 ASSERT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
761 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
762 // OS_FUCHSIA.
763 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
764 IPEndPoint receive_address(IPAddress::AllZeros(group_ip.size()),
765 0 /* port */);
766 #else
767 IPEndPoint receive_address(group_ip, 0 /* port */);
768 #endif // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
769
770 NetworkInterfaceList interfaces;
771 ASSERT_TRUE(GetNetworkList(&interfaces, 0));
772 // The test fails with the Hyper-V switch interface (on the host side).
773 interfaces.erase(std::remove_if(interfaces.begin(), interfaces.end(),
774 [](const auto& iface) {
775 return iface.friendly_name.rfind(
776 "vEthernet", 0) == 0;
777 }),
778 interfaces.end());
779 ASSERT_FALSE(interfaces.empty());
780
781 // Setup first receiving socket.
782 UDPServerSocket socket1(nullptr, NetLogSource());
783 socket1.AllowAddressSharingForMulticast();
784 ASSERT_THAT(socket1.SetMulticastInterface(interfaces[0].interface_index),
785 IsOk());
786 ASSERT_THAT(socket1.Listen(receive_address), IsOk());
787 ASSERT_THAT(socket1.JoinGroup(group_ip), IsOk());
788 // Get the bound port.
789 ASSERT_THAT(socket1.GetLocalAddress(&receive_address), IsOk());
790
791 // Setup second receiving socket.
792 UDPServerSocket socket2(nullptr, NetLogSource());
793 socket2.AllowAddressSharingForMulticast(), IsOk();
794 ASSERT_THAT(socket2.SetMulticastInterface(interfaces[0].interface_index),
795 IsOk());
796 ASSERT_THAT(socket2.Listen(receive_address), IsOk());
797 ASSERT_THAT(socket2.JoinGroup(group_ip), IsOk());
798
799 // Setup client socket.
800 IPEndPoint send_address(group_ip, receive_address.port());
801 UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
802 NetLogSource());
803 ASSERT_THAT(client_socket.Connect(send_address), IsOk());
804
805 #if !BUILDFLAG(IS_CHROMEOS_ASH)
806 // Send a message via the multicast group. That message is expected be be
807 // received by both receving sockets.
808 //
809 // Skip on ChromeOS where it's known to sometimes not work.
810 // TODO(crbug.com/898964): If possible, fix and reenable.
811 const char kMessage[] = "hello!";
812 ASSERT_GE(WriteSocket(&client_socket, kMessage), 0);
813 EXPECT_EQ(kMessage, RecvFromSocket(&socket1));
814 EXPECT_EQ(kMessage, RecvFromSocket(&socket2));
815 #endif // !BUILDFLAG(IS_CHROMEOS_ASH)
816 }
817 #endif // !BUILDFLAG(IS_ANDROID)
818
// Multicast options (loopback mode, TTL, interface) can only be changed before
// the socket is bound. On an unbound socket valid values succeed and a
// negative TTL is rejected. After Bind() every change is rejected.
TEST_F(UDPSocketTest, MulticastOptions) {
  IPEndPoint bind_address;
  ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &bind_address));

  UDPSocket udp_socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());

  // Not yet bound: options are accepted.
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(false), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastLoopbackMode(true), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(0), IsOk());
  EXPECT_THAT(udp_socket.SetMulticastTimeToLive(3), IsOk());
  EXPECT_NE(OK, udp_socket.SetMulticastTimeToLive(-1));
  EXPECT_THAT(udp_socket.SetMulticastInterface(0), IsOk());

  // Open and bind.
  EXPECT_THAT(udp_socket.Open(bind_address.GetFamily()), IsOk());
  EXPECT_THAT(udp_socket.Bind(bind_address), IsOk());

  // Bound: every option change is rejected.
  EXPECT_NE(OK, udp_socket.SetMulticastLoopbackMode(false));
  EXPECT_NE(OK, udp_socket.SetMulticastTimeToLive(0));
  EXPECT_NE(OK, udp_socket.SetMulticastInterface(0));

  udp_socket.Close();
}
841
842 // Checking that DSCP bits are set correctly is difficult,
843 // but let's check that the code doesn't crash at least.
TEST_F(UDPSocketTest,SetDSCP)844 TEST_F(UDPSocketTest, SetDSCP) {
845 // Setup the server to listen.
846 IPEndPoint bind_address;
847 UDPSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
848 // We need a real IP, but we won't actually send anything to it.
849 ASSERT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
850 int rv = client.Open(bind_address.GetFamily());
851 EXPECT_THAT(rv, IsOk());
852
853 rv = client.Connect(bind_address);
854 if (rv != OK) {
855 // Let's try localhost then.
856 bind_address = IPEndPoint(IPAddress::IPv4Localhost(), 9999);
857 rv = client.Connect(bind_address);
858 }
859 EXPECT_THAT(rv, IsOk());
860
861 client.SetDiffServCodePoint(DSCP_NO_CHANGE);
862 client.SetDiffServCodePoint(DSCP_AF41);
863 client.SetDiffServCodePoint(DSCP_DEFAULT);
864 client.SetDiffServCodePoint(DSCP_CS2);
865 client.SetDiffServCodePoint(DSCP_NO_CHANGE);
866 client.SetDiffServCodePoint(DSCP_DEFAULT);
867 client.Close();
868 }
869
TEST_F(UDPSocketTest,ConnectUsingNetwork)870 TEST_F(UDPSocketTest, ConnectUsingNetwork) {
871 // The specific value of this address doesn't really matter, and no
872 // server needs to be running here. The test only needs to call
873 // ConnectUsingNetwork() and won't send any datagrams.
874 const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
875 const handles::NetworkHandle wrong_network_handle = 65536;
876 #if BUILDFLAG(IS_ANDROID)
877 NetworkChangeNotifierFactoryAndroid ncn_factory;
878 NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
879 std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
880 if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
881 GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
882
883 {
884 // Connecting using a not existing network should fail but not report
885 // ERR_NOT_IMPLEMENTED when network handles are supported.
886 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
887 NetLogSource());
888 int rv =
889 socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address);
890 EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
891 EXPECT_NE(OK, rv);
892 EXPECT_NE(wrong_network_handle, socket.GetBoundNetwork());
893 }
894
895 {
896 // Connecting using an existing network should succeed when
897 // NetworkChangeNotifier returns a valid default network.
898 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
899 NetLogSource());
900 const handles::NetworkHandle network_handle =
901 NetworkChangeNotifier::GetDefaultNetwork();
902 if (network_handle != handles::kInvalidNetworkHandle) {
903 EXPECT_EQ(
904 OK, socket.ConnectUsingNetwork(network_handle, fake_server_address));
905 EXPECT_EQ(network_handle, socket.GetBoundNetwork());
906 }
907 }
908 #else
909 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
910 EXPECT_EQ(
911 ERR_NOT_IMPLEMENTED,
912 socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address));
913 #endif // BUILDFLAG(IS_ANDROID)
914 }
915
TEST_F(UDPSocketTest,ConnectUsingNetworkAsync)916 TEST_F(UDPSocketTest, ConnectUsingNetworkAsync) {
917 // The specific value of this address doesn't really matter, and no
918 // server needs to be running here. The test only needs to call
919 // ConnectUsingNetwork() and won't send any datagrams.
920 const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
921 const handles::NetworkHandle wrong_network_handle = 65536;
922 #if BUILDFLAG(IS_ANDROID)
923 NetworkChangeNotifierFactoryAndroid ncn_factory;
924 NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
925 std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
926 if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
927 GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
928
929 {
930 // Connecting using a not existing network should fail but not report
931 // ERR_NOT_IMPLEMENTED when network handles are supported.
932 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
933 NetLogSource());
934 TestCompletionCallback callback;
935 int rv = socket.ConnectUsingNetworkAsync(
936 wrong_network_handle, fake_server_address, callback.callback());
937
938 if (rv == ERR_IO_PENDING) {
939 rv = callback.WaitForResult();
940 }
941 EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
942 EXPECT_NE(OK, rv);
943 }
944
945 {
946 // Connecting using an existing network should succeed when
947 // NetworkChangeNotifier returns a valid default network.
948 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
949 NetLogSource());
950 TestCompletionCallback callback;
951 const handles::NetworkHandle network_handle =
952 NetworkChangeNotifier::GetDefaultNetwork();
953 if (network_handle != handles::kInvalidNetworkHandle) {
954 int rv = socket.ConnectUsingNetworkAsync(
955 network_handle, fake_server_address, callback.callback());
956 if (rv == ERR_IO_PENDING) {
957 rv = callback.WaitForResult();
958 }
959 EXPECT_EQ(OK, rv);
960 EXPECT_EQ(network_handle, socket.GetBoundNetwork());
961 }
962 }
963 #else
964 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
965 TestCompletionCallback callback;
966 EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket.ConnectUsingNetworkAsync(
967 wrong_network_handle, fake_server_address,
968 callback.callback()));
969 #endif // BUILDFLAG(IS_ANDROID)
970 }
971
972 } // namespace
973
974 #if BUILDFLAG(IS_WIN)
975
976 namespace {
977
978 const HANDLE kFakeHandle1 = (HANDLE)12;
979 const HANDLE kFakeHandle2 = (HANDLE)13;
980
981 const QOS_FLOWID kFakeFlowId1 = (QOS_FLOWID)27;
982 const QOS_FLOWID kFakeFlowId2 = (QOS_FLOWID)38;
983
984 class TestUDPSocketWin : public UDPSocketWin {
985 public:
TestUDPSocketWin(QwaveApi * qos,DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)986 TestUDPSocketWin(QwaveApi* qos,
987 DatagramSocket::BindType bind_type,
988 net::NetLog* net_log,
989 const net::NetLogSource& source)
990 : UDPSocketWin(bind_type, net_log, source), qos_(qos) {}
991
992 TestUDPSocketWin(const TestUDPSocketWin&) = delete;
993 TestUDPSocketWin& operator=(const TestUDPSocketWin&) = delete;
994
995 // Overriding GetQwaveApi causes the test class to use the injected mock
996 // QwaveApi instance instead of the singleton.
GetQwaveApi() const997 QwaveApi* GetQwaveApi() const override { return qos_; }
998
999 private:
1000 raw_ptr<QwaveApi> qos_;
1001 };
1002
1003 class MockQwaveApi : public QwaveApi {
1004 public:
1005 MOCK_CONST_METHOD0(qwave_supported, bool());
1006 MOCK_METHOD0(OnFatalError, void());
1007 MOCK_METHOD2(CreateHandle, BOOL(PQOS_VERSION version, PHANDLE handle));
1008 MOCK_METHOD1(CloseHandle, BOOL(HANDLE handle));
1009 MOCK_METHOD6(AddSocketToFlow,
1010 BOOL(HANDLE handle,
1011 SOCKET socket,
1012 PSOCKADDR addr,
1013 QOS_TRAFFIC_TYPE traffic_type,
1014 DWORD flags,
1015 PQOS_FLOWID flow_id));
1016
1017 MOCK_METHOD4(
1018 RemoveSocketFromFlow,
1019 BOOL(HANDLE handle, SOCKET socket, QOS_FLOWID flow_id, DWORD reserved));
1020 MOCK_METHOD7(SetFlow,
1021 BOOL(HANDLE handle,
1022 QOS_FLOWID flow_id,
1023 QOS_SET_FLOW op,
1024 ULONG size,
1025 PVOID data,
1026 DWORD reserved,
1027 LPOVERLAPPED overlapped));
1028 };
1029
OpenedDscpTestClient(QwaveApi * api,IPEndPoint bind_address)1030 std::unique_ptr<UDPSocket> OpenedDscpTestClient(QwaveApi* api,
1031 IPEndPoint bind_address) {
1032 auto client = std::make_unique<TestUDPSocketWin>(
1033 api, DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1034 int rv = client->Open(bind_address.GetFamily());
1035 EXPECT_THAT(rv, IsOk());
1036
1037 return client;
1038 }
1039
ConnectedDscpTestClient(QwaveApi * api)1040 std::unique_ptr<UDPSocket> ConnectedDscpTestClient(QwaveApi* api) {
1041 IPEndPoint bind_address;
1042 // We need a real IP, but we won't actually send anything to it.
1043 EXPECT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
1044 auto client = OpenedDscpTestClient(api, bind_address);
1045 EXPECT_THAT(client->Connect(bind_address), IsOk());
1046 return client;
1047 }
1048
UnconnectedDscpTestClient(QwaveApi * api)1049 std::unique_ptr<UDPSocket> UnconnectedDscpTestClient(QwaveApi* api) {
1050 IPEndPoint bind_address;
1051 EXPECT_TRUE(CreateUDPAddress("0.0.0.0", 9999, &bind_address));
1052 auto client = OpenedDscpTestClient(api, bind_address);
1053 EXPECT_THAT(client->Bind(bind_address), IsOk());
1054 return client;
1055 }
1056
1057 } // namespace
1058
1059 using ::testing::Return;
1060 using ::testing::SetArgPointee;
1061 using ::testing::_;
1062
// Setting DSCP_NO_CHANGE succeeds without adding the socket to any qWAVE flow.
TEST_F(UDPSocketTest, SetDSCPNoopIfPassedNoChange) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));

  // The socket must never be added to a flow.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_NO_CHANGE), IsOk());
}
1071
// When qWAVE is reported as unsupported, SetDiffServCodePoint() returns
// ERR_NOT_IMPLEMENTED and no QoS handle is ever created.
TEST_F(UDPSocketTest, SetDSCPFailsIfQOSDoesntLink) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_CALL(api, CreateHandle(_, _)).Times(0);

  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, client->SetDiffServCodePoint(DSCP_AF41));
}
1080
// When CreateHandle() fails, OnFatalError() is expected to be called exactly
// once, and SetDiffServCodePoint() never succeeds.
TEST_F(UDPSocketTest, SetDSCPFailsIfHandleCantBeCreated) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _)).WillOnce(Return(false));
  EXPECT_CALL(api, OnFatalError()).Times(1);

  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));

  // Let the asynchronous qWAVE initialization (and its failure handling) run.
  RunUntilIdle();

  // Make the mock report qWAVE as unsupported from now on. Presumably this
  // mirrors the real QwaveApi after OnFatalError() (TODO confirm). Later sets
  // then report ERR_NOT_IMPLEMENTED.
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, client->SetDiffServCodePoint(DSCP_AF41));
}
1095
1096 MATCHER_P(DscpPointee, dscp, "") {
1097 return *(DWORD*)arg == (DWORD)dscp;
1098 }
1099
// For a connected socket, SetDiffServCodePoint() fails until the asynchronous
// qWAVE initialization has run. Afterwards it succeeds, and setting a new code
// point replaces the socket's existing flow.
TEST_F(UDPSocketTest, ConnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // The first successful Set is expected to add the socket to a flow
  // (kFakeFlowId1) and configure that flow once.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));

  // First set on connected sockets will fail since init is async and
  // we haven't given the runloop a chance to execute the callback.
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));
  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // New dscp value should reset the flow. The old flow is removed and a new
  // one (kFakeFlowId2) is created with the DSCP_DEFAULT value.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeBestEffort, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, QOSSetOutgoingDSCPValue, _,
                           DscpPointee(DSCP_DEFAULT), _, _));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_DEFAULT), IsOk());

  // Called from DscpManager destructor (when |client| goes out of scope, before
  // |api|, which was declared first).
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1129
// For an unconnected socket, SetDiffServCodePoint() succeeds both before and
// after the asynchronous qWAVE initialization completes.
// NOTE(review): "Unonnected" in the test name is a typo for "Unconnected".
// It is left as-is since renaming could break test filters that reference it.
TEST_F(UDPSocketTest, UnonnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // CreateHandle won't have completed yet. Set passes.
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // After initialization, changing the code point still succeeds. No
  // AddSocketToFlow() is expected because nothing has been sent yet.
  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF42), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1146
// TODO(zstein): Mocking out DscpManager might be simpler here
// (just verify that DscpManager::Set and DscpManager::PrepareForSend are
// called).
//
// Verifies that, on an unconnected socket with a code point set, SendTo()
// creates and configures the qWAVE flow on the first send. Repeated sends to
// the same address make no further qWAVE calls, and a new destination is added
// to the existing flow.
TEST_F(UDPSocketTest, SendToCallsQwaveApis) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());
  RunUntilIdle();

  // The first send to |server_address| adds the socket to a new flow and sets
  // the flow's DSCP value.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));
  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // TODO(zstein): Move to second test case (Qwave APIs called once per address)
  // A second send to the same address must not call the qWAVE APIs again.
  // The single-call expectations above would be exceeded.
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // TODO(zstein): Move to third test case (Qwave APIs called for each
  // destination address).
  // A new destination adds the socket to the flow again. SetFlow() is not
  // expected again.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(true));
  IPEndPoint server_address2(IPAddress::IPv4Localhost(), 9439);

  rv = SendToSocket(client.get(), simple_message, server_address2);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, _, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1183
// Verifies that a code point set before qWAVE initialization finishes is not
// applied to sends made before initialization. It is applied on the first
// SendTo() after initialization.
TEST_F(UDPSocketTest, SendToCallsApisAfterDeferredInit) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // SetDiffServCodepoint works even if qos api hasn't finished initing.
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_CS7), IsOk());

  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);

  // SendTo works, but doesn't yet apply TOS
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  RunUntilIdle();
  // Now we're initialized, SendTo triggers qos calls with correct codepoint.
  // This newer expectation takes precedence over the Times(0) one above, and
  // expects DSCP_CS7 to be applied as QOSTrafficTypeControl.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _)).WillOnce(Return(true));
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1214
1215 class DscpManagerTest : public TestWithTaskEnvironment {
1216 protected:
DscpManagerTest()1217 DscpManagerTest() {
1218 EXPECT_CALL(api_, qwave_supported()).WillRepeatedly(Return(true));
1219 EXPECT_CALL(api_, CreateHandle(_, _))
1220 .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
1221 dscp_manager_ = std::make_unique<DscpManager>(&api_, INVALID_SOCKET);
1222
1223 CreateUDPAddress("1.2.3.4", 9001, &address1_);
1224 CreateUDPAddress("1234:5678:90ab:cdef:1234:5678:90ab:cdef", 9002,
1225 &address2_);
1226 }
1227
1228 MockQwaveApi api_;
1229 std::unique_ptr<DscpManager> dscp_manager_;
1230
1231 IPEndPoint address1_;
1232 IPEndPoint address2_;
1233 };
1234
// Without a prior Set(), PrepareForSend() is expected to make no flow calls.
// NOTE(review): api_ is a plain (non-strict) mock with no AddSocketToFlow()
// expectation here. An unexpected call would only log an "uninteresting call"
// warning, not fail. Consider an explicit Times(0) expectation.
TEST_F(DscpManagerTest, PrepareForSendIsNoopIfNoSet) {
  // Let the fixture's asynchronous CreateHandle() complete first.
  RunUntilIdle();
  dscp_manager_->PrepareForSend(address1_);
}
1239
// After Set(), PrepareForSend() adds the socket to a flow for every new
// destination. The flow itself is configured only when it is first created.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisAfterSet) {
  // Let the fixture's asynchronous CreateHandle() complete first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // AddSocketToFlow should be called for each address.
  // SetFlow should only be called when the flow is first created.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // The second destination joins the same flow (the mock reports kFakeFlowId1
  // again), so SetFlow() must not be called.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address2_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1260
// Repeated PrepareForSend() calls for the same destination make the qWAVE
// calls only the first time.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisOncePerAddress) {
  // Let the fixture's asynchronous CreateHandle() complete first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);
  // Same address again: no further flow calls are allowed.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1277
// Calling Set() with a new code point removes the socket from its current flow.
// The next PrepareForSend() then creates and configures a new flow.
TEST_F(DscpManagerTest, SetDestroysExistingFlow) {
  // Let the fixture's asynchronous CreateHandle() complete first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Calling Set should destroy the existing flow.
  // TODO(zstein): Verify that RemoveSocketFromFlow with no address
  // destroys the flow for all destinations.
  // (The NULL here is the matcher for the SOCKET argument.)
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, NULL, kFakeFlowId1, _));
  dscp_manager_->Set(DSCP_CS5);

  // The next send creates and configures a new flow.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1302
// When AddSocketToFlow() fails with ERROR_DEVICE_REINITIALIZATION_NEEDED, the
// DscpManager is expected to close its handle and create a new one. The next
// PrepareForSend() then re-adds the socket using the previously set code point,
// without another Set().
TEST_F(DscpManagerTest, SocketReAddedOnRecreateHandle) {
  // Let the fixture's asynchronous CreateHandle() complete first.
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // First Set and Send work fine.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Make Second flow operation fail (requires resetting the codepoint).
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _))
      .WillOnce(Return(true));
  dscp_manager_->Set(DSCP_CS7);

  // Fake the Win32 last-error value that accompanies the failing
  // AddSocketToFlow() below. ScopedClearLastError scopes that fake value and is
  // released (|error| = nullptr) right after the failing call.
  auto error = std::make_unique<base::ScopedClearLastError>();
  ::SetLastError(ERROR_DEVICE_REINITIALIZATION_NEEDED);
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(false));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  // The old handle is closed and a replacement (kFakeHandle2) is created.
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
  EXPECT_CALL(api_, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle2), Return(true)));
  EXPECT_EQ(ERR_INVALID_HANDLE, dscp_manager_->PrepareForSend(address1_));
  error = nullptr;
  RunUntilIdle();

  // Next Send should work fine, without requiring another Set
  // (DSCP_CS7 is expected to be applied as QOSTrafficTypeControl).
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle2));
}
1341
1342 #endif
1343
// Verifies that a client with the receive optimization enabled still reads a
// datagram sent by the server correctly.
TEST_F(UDPSocketTest, ReadWithSocketOptimization) {
  std::string simple_message("hello world!");

  // Setup the server to listen.
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
  UDPServerSocket server(nullptr, NetLogSource());
  server.AllowAddressReuse();
  ASSERT_THAT(server.Listen(server_address), IsOk());
  // Get bound port.
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  // Set up the client, enable the experimental receive optimization, and
  // connect it to the server. The rest of the test needs a connected client,
  // so failures here are fatal.
  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  client.EnableRecvOptimization();
  ASSERT_THAT(client.Connect(server_address), IsOk());

  // Get the client's address.
  IPEndPoint client_address;
  ASSERT_THAT(client.GetLocalAddress(&client_address), IsOk());

  // Server sends the message to the client.
  EXPECT_EQ(simple_message.length(),
            static_cast<size_t>(
                SendToSocket(&server, simple_message, client_address)));

  // Client receives the message.
  std::string str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);

  server.Close();
  client.Close();
}
1377
1378 // Tests that read from a socket correctly returns
1379 // |ERR_MSG_TOO_BIG| when the buffer is too small and
1380 // returns the actual message when it fits the buffer.
1381 // For the optimized path, the buffer size should be at least
1382 // 1 byte greater than the message.
TEST_F(UDPSocketTest,ReadWithSocketOptimizationTruncation)1383 TEST_F(UDPSocketTest, ReadWithSocketOptimizationTruncation) {
1384 std::string too_long_message(kMaxRead + 1, 'A');
1385 std::string right_length_message(kMaxRead - 1, 'B');
1386 std::string exact_length_message(kMaxRead, 'C');
1387
1388 // Setup the server to listen.
1389 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
1390 UDPServerSocket server(nullptr, NetLogSource());
1391 server.AllowAddressReuse();
1392 ASSERT_THAT(server.Listen(server_address), IsOk());
1393 // Get bound port.
1394 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1395
1396 // Setup the client, enable experimental optimization and connected to the
1397 // server.
1398 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1399 client.EnableRecvOptimization();
1400 EXPECT_THAT(client.Connect(server_address), IsOk());
1401
1402 // Get the client's address.
1403 IPEndPoint client_address;
1404 EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
1405
1406 // Send messages to the client.
1407 EXPECT_EQ(too_long_message.length(),
1408 static_cast<size_t>(
1409 SendToSocket(&server, too_long_message, client_address)));
1410 EXPECT_EQ(right_length_message.length(),
1411 static_cast<size_t>(
1412 SendToSocket(&server, right_length_message, client_address)));
1413 EXPECT_EQ(exact_length_message.length(),
1414 static_cast<size_t>(
1415 SendToSocket(&server, exact_length_message, client_address)));
1416
1417 // Client receives the messages.
1418
1419 // 1. The first message is |too_long_message|. Its size exceeds the buffer.
1420 // In that case, the client is expected to get |ERR_MSG_TOO_BIG| when the
1421 // data is read.
1422 TestCompletionCallback callback;
1423 int rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1424 EXPECT_EQ(ERR_MSG_TOO_BIG, callback.GetResult(rv));
1425
1426 // 2. The second message is |right_length_message|. Its size is
1427 // one byte smaller than the size of the buffer. In that case, the client
1428 // is expected to read the whole message successfully.
1429 rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1430 rv = callback.GetResult(rv);
1431 EXPECT_EQ(static_cast<int>(right_length_message.length()), rv);
1432 EXPECT_EQ(right_length_message, std::string(buffer_->data(), rv));
1433
1434 // 3. The third message is |exact_length_message|. Its size is equal to
1435 // the read buffer size. In that case, the client expects to get
1436 // |ERR_MSG_TOO_BIG| when the socket is read. Internally, the optimized
1437 // path uses read() system call that requires one extra byte to detect
1438 // truncated messages; therefore, messages that fill the buffer exactly
1439 // are considered truncated.
1440 // The optimization is only enabled on POSIX platforms. On Windows,
1441 // the optimization is turned off; therefore, the client
1442 // should be able to read the whole message without encountering
1443 // |ERR_MSG_TOO_BIG|.
1444 rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1445 rv = callback.GetResult(rv);
1446 #if BUILDFLAG(IS_POSIX)
1447 EXPECT_EQ(ERR_MSG_TOO_BIG, rv);
1448 #else
1449 EXPECT_EQ(static_cast<int>(exact_length_message.length()), rv);
1450 EXPECT_EQ(exact_length_message, std::string(buffer_->data(), rv));
1451 #endif
1452 server.Close();
1453 client.Close();
1454 }
1455
1456 // On Android, where socket tagging is supported, verify that UDPSocket::Tag
1457 // works as expected.
1458 #if BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest, Tag) {
  // Report the test as skipped (not passed) when tagged-byte accounting is
  // unavailable, consistent with the other GTEST_SKIP() uses in this file.
  if (!CanGetTaggedBytes()) {
    GTEST_SKIP() << "GetTaggedBytes unsupported.";
  }

  UDPServerSocket server(nullptr, NetLogSource());
  ASSERT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  ASSERT_THAT(client.Connect(server_address), IsOk());

  std::string simple_message("hello world!");
  // Sends |simple_message| from the client to the server and has the server
  // echo it back, so traffic flows both ways through the client socket under
  // whatever tag is currently applied.
  auto exchange_message = [&]() {
    // Client sends to the server.
    int rv = WriteSocket(&client, simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    // Server waits for message.
    EXPECT_EQ(simple_message, RecvFromSocket(&server));
    // Server echoes reply.
    rv = SendToSocket(&server, simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    // Client waits for response.
    EXPECT_EQ(simple_message, ReadSocket(&client));
  };

  // Verify UDP packets are tagged and counted properly.
  int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  client.ApplySocketTag(tag1);
  {
    SCOPED_TRACE("initial tag1");
    exchange_message();
  }
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Verify socket can be retagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  client.ApplySocketTag(tag2);
  {
    SCOPED_TRACE("retagged with tag2");
    exchange_message();
  }
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged back to the original tag (unset UID).
  old_traffic = GetTaggedBytes(tag_val1);
  client.ApplySocketTag(tag1);
  {
    SCOPED_TRACE("retagged back to tag1");
    exchange_message();
  }
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
}
1531
// Verifies that a UDPClientSocket constructed with a network handle fails to
// connect for a nonexistent network. When a default network exists, it
// connects on that network and reports it as its bound network.
TEST_F(UDPSocketTest, BindToNetwork) {
  // No server runs at this address and no datagram is ever sent to it; the
  // test only needs a destination to pass to Connect().
  const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
  NetworkChangeNotifierFactoryAndroid ncn_factory;
  NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
  std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // A socket bound to a nonexistent network should fail at connect time. The
  // exact error differs between Android versions, so only rule out outcomes
  // that must not happen.
  const handles::NetworkHandle bogus_network = 65536;
  UDPClientSocket bogus_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), bogus_network);
  const int connect_result = bogus_socket.Connect(fake_server_address);
  EXPECT_NE(OK, connect_result);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, connect_result);
  EXPECT_NE(bogus_network, bogus_socket.GetBoundNetwork());

  // A socket bound to the default network should connect, provided the device
  // has a default network at all.
  const handles::NetworkHandle default_network =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (default_network == handles::kInvalidNetworkHandle)
    return;
  UDPClientSocket bound_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), default_network);
  EXPECT_EQ(OK, bound_socket.Connect(fake_server_address));
  EXPECT_EQ(default_network, bound_socket.GetBoundNetwork());
}
1564
1565 #endif // BUILDFLAG(IS_ANDROID)
1566
1567 // Scoped helper to override the process-wide UDP socket limit.
1568 class OverrideUDPSocketLimit {
1569 public:
OverrideUDPSocketLimit(int new_limit)1570 explicit OverrideUDPSocketLimit(int new_limit) {
1571 base::FieldTrialParams params;
1572 params[features::kLimitOpenUDPSocketsMax.name] =
1573 base::NumberToString(new_limit);
1574
1575 scoped_feature_list_.InitAndEnableFeatureWithParameters(
1576 features::kLimitOpenUDPSockets, params);
1577 }
1578
1579 private:
1580 base::test::ScopedFeatureList scoped_feature_list_;
1581 };
1582
1583 // Tests that UDPClientSocket respects the global UDP socket limits.
TEST_F(UDPSocketTest,LimitClientSocket)1584 TEST_F(UDPSocketTest, LimitClientSocket) {
1585 // Reduce the global UDP limit to 2.
1586 OverrideUDPSocketLimit set_limit(2);
1587
1588 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1589
1590 auto socket1 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1591 nullptr, NetLogSource());
1592 auto socket2 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1593 nullptr, NetLogSource());
1594
1595 // Simply constructing a UDPClientSocket does not increase the limit (no
1596 // Connect() or Bind() has been called yet).
1597 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1598
1599 // The specific value of this address doesn't really matter, and no server
1600 // needs to be running here. The test only needs to call Connect() and won't
1601 // send any datagrams.
1602 IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1603
1604 // Successful Connect() on socket1 increases socket count.
1605 EXPECT_THAT(socket1->Connect(server_address), IsOk());
1606 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1607
1608 // Successful Connect() on socket2 increases socket count.
1609 EXPECT_THAT(socket2->Connect(server_address), IsOk());
1610 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1611
1612 // Attempting a third Connect() should fail with ERR_INSUFFICIENT_RESOURCES,
1613 // as the limit is currently 2.
1614 auto socket3 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1615 nullptr, NetLogSource());
1616 EXPECT_THAT(socket3->Connect(server_address),
1617 IsError(ERR_INSUFFICIENT_RESOURCES));
1618 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1619
1620 // Check that explicitly closing socket2 free up a count.
1621 socket2->Close();
1622 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1623
1624 // Since the socket was already closed, deleting it will not affect the count.
1625 socket2.reset();
1626 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1627
1628 // Now that the count is below limit, try to connect another socket. This time
1629 // it will work.
1630 auto socket4 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1631 nullptr, NetLogSource());
1632 EXPECT_THAT(socket4->Connect(server_address), IsOk());
1633 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1634
1635 // Verify that closing the two remaining sockets brings the open count back to
1636 // 0.
1637 socket1.reset();
1638 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1639 socket4.reset();
1640 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1641 }
1642
1643 // Tests that UDPSocketClient updates the global counter
1644 // correctly when Connect() fails.
TEST_F(UDPSocketTest,LimitConnectFail)1645 TEST_F(UDPSocketTest, LimitConnectFail) {
1646 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1647
1648 {
1649 // Simply allocating a UDPSocket does not increase count.
1650 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1651 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1652
1653 // Calling Open() allocates the socket and increases the global counter.
1654 EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
1655 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1656
1657 // Connect to an IPv6 address should fail since the socket was created for
1658 // IPv4.
1659 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
1660 Not(IsOk()));
1661
1662 // That Connect() failed doesn't change the global counter.
1663 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1664 }
1665
1666 // Finally, destroying UDPSocket decrements the global counter.
1667 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1668 }
1669
1670 // Tests allocating UDPClientSockets and Connect()ing them in parallel.
1671 //
1672 // This is primarily intended for coverage under TSAN, to check for races
1673 // enforcing the global socket counter.
TEST_F(UDPSocketTest,LimitConnectMultithreaded)1674 TEST_F(UDPSocketTest, LimitConnectMultithreaded) {
1675 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1676
1677 // Start up some threads.
1678 std::vector<std::unique_ptr<base::Thread>> threads;
1679 for (size_t i = 0; i < 5; ++i) {
1680 threads.push_back(std::make_unique<base::Thread>("Worker thread"));
1681 ASSERT_TRUE(threads.back()->Start());
1682 }
1683
1684 // Post tasks to each of the threads.
1685 for (const auto& thread : threads) {
1686 thread->task_runner()->PostTask(
1687 FROM_HERE, base::BindOnce([] {
1688 // The specific value of this address doesn't really matter, and no
1689 // server needs to be running here. The test only needs to call
1690 // Connect() and won't send any datagrams.
1691 IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1692
1693 UDPClientSocket socket(DatagramSocket::DEFAULT_BIND, nullptr,
1694 NetLogSource());
1695 EXPECT_THAT(socket.Connect(server_address), IsOk());
1696 }));
1697 }
1698
1699 // Complete all the tasks.
1700 threads.clear();
1701
1702 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1703 }
1704
1705 } // namespace net
1706