• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/udp_socket.h"
6 
7 #include <algorithm>
8 
9 #include "base/containers/circular_deque.h"
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/run_loop.h"
15 #include "base/scoped_clear_last_error.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/test/scoped_feature_list.h"
19 #include "base/threading/thread.h"
20 #include "base/time/time.h"
21 #include "build/build_config.h"
22 #include "build/chromeos_buildflags.h"
23 #include "net/base/features.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/ip_address.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_interfaces.h"
29 #include "net/base/test_completion_callback.h"
30 #include "net/log/net_log_event_type.h"
31 #include "net/log/net_log_source.h"
32 #include "net/log/test_net_log.h"
33 #include "net/log/test_net_log_util.h"
34 #include "net/socket/socket_test_util.h"
35 #include "net/socket/udp_client_socket.h"
36 #include "net/socket/udp_server_socket.h"
37 #include "net/socket/udp_socket_global_limits.h"
38 #include "net/test/gtest_util.h"
39 #include "net/test/test_with_task_environment.h"
40 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
41 #include "testing/gmock/include/gmock/gmock.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43 #include "testing/platform_test.h"
44 
45 #if !BUILDFLAG(IS_WIN)
46 #include <netinet/in.h>
47 #include <sys/socket.h>
48 #else
49 #include <winsock2.h>
50 #endif
51 
52 #if BUILDFLAG(IS_ANDROID)
53 #include "base/android/build_info.h"
54 #include "net/android/network_change_notifier_factory_android.h"
55 #include "net/base/network_change_notifier.h"
56 #endif
57 
58 #if BUILDFLAG(IS_IOS)
59 #include <TargetConditionals.h>
60 #endif
61 
62 #if BUILDFLAG(IS_MAC)
63 #include "base/mac/mac_util.h"
64 #endif  // BUILDFLAG(IS_MAC)
65 
66 using net::test::IsError;
67 using net::test::IsOk;
68 using testing::DoAll;
69 using testing::Not;
70 
71 namespace net {
72 
73 namespace {
74 
75 // Creates an address from ip address and port and writes it to |*address|.
CreateUDPAddress(const std::string & ip_str,uint16_t port,IPEndPoint * address)76 bool CreateUDPAddress(const std::string& ip_str,
77                       uint16_t port,
78                       IPEndPoint* address) {
79   IPAddress ip_address;
80   if (!ip_address.AssignFromIPLiteral(ip_str))
81     return false;
82 
83   *address = IPEndPoint(ip_address, port);
84   return true;
85 }
86 
87 class UDPSocketTest : public PlatformTest, public WithTaskEnvironment {
88  public:
UDPSocketTest()89   UDPSocketTest() : buffer_(base::MakeRefCounted<IOBufferWithSize>(kMaxRead)) {}
90 
91   // Blocks until data is read from the socket.
RecvFromSocket(UDPServerSocket * socket)92   std::string RecvFromSocket(UDPServerSocket* socket) {
93     TestCompletionCallback callback;
94 
95     int rv = socket->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
96                               callback.callback());
97     rv = callback.GetResult(rv);
98     if (rv < 0)
99       return std::string();
100     return std::string(buffer_->data(), rv);
101   }
102 
103   // Sends UDP packet.
104   // If |address| is specified, then it is used for the destination
105   // to send to. Otherwise, will send to the last socket this server
106   // received from.
SendToSocket(UDPServerSocket * socket,const std::string & msg)107   int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
108     return SendToSocket(socket, msg, recv_from_address_);
109   }
110 
SendToSocket(UDPServerSocket * socket,std::string msg,const IPEndPoint & address)111   int SendToSocket(UDPServerSocket* socket,
112                    std::string msg,
113                    const IPEndPoint& address) {
114     scoped_refptr<StringIOBuffer> io_buffer =
115         base::MakeRefCounted<StringIOBuffer>(msg);
116     TestCompletionCallback callback;
117     int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
118                             callback.callback());
119     return callback.GetResult(rv);
120   }
121 
ReadSocket(UDPClientSocket * socket)122   std::string ReadSocket(UDPClientSocket* socket) {
123     TestCompletionCallback callback;
124 
125     int rv = socket->Read(buffer_.get(), kMaxRead, callback.callback());
126     rv = callback.GetResult(rv);
127     if (rv < 0)
128       return std::string();
129     return std::string(buffer_->data(), rv);
130   }
131 
132   // Writes specified message to the socket.
WriteSocket(UDPClientSocket * socket,const std::string & msg)133   int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
134     scoped_refptr<StringIOBuffer> io_buffer =
135         base::MakeRefCounted<StringIOBuffer>(msg);
136     TestCompletionCallback callback;
137     int rv = socket->Write(io_buffer.get(), io_buffer->size(),
138                            callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
139     return callback.GetResult(rv);
140   }
141 
WriteSocketIgnoreResult(UDPClientSocket * socket,const std::string & msg)142   void WriteSocketIgnoreResult(UDPClientSocket* socket,
143                                const std::string& msg) {
144     WriteSocket(socket, msg);
145   }
146 
147   // And again for a bare socket
SendToSocket(UDPSocket * socket,std::string msg,const IPEndPoint & address)148   int SendToSocket(UDPSocket* socket,
149                    std::string msg,
150                    const IPEndPoint& address) {
151     auto io_buffer = base::MakeRefCounted<StringIOBuffer>(msg);
152     TestCompletionCallback callback;
153     int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
154                             callback.callback());
155     return callback.GetResult(rv);
156   }
157 
158   // Run unit test for a connection test.
159   // |use_nonblocking_io| is used to switch between overlapped and non-blocking
160   // IO on Windows. It has no effect in other ports.
161   void ConnectTest(bool use_nonblocking_io, bool use_async);
162 
163  protected:
164   static const int kMaxRead = 1024;
165   scoped_refptr<IOBufferWithSize> buffer_;
166   IPEndPoint recv_from_address_;
167 };
168 
169 const int UDPSocketTest::kMaxRead;
170 
ReadCompleteCallback(int * result_out,base::OnceClosure callback,int result)171 void ReadCompleteCallback(int* result_out,
172                           base::OnceClosure callback,
173                           int result) {
174   *result_out = result;
175   std::move(callback).Run();
176 }
177 
ConnectTest(bool use_nonblocking_io,bool use_async)178 void UDPSocketTest::ConnectTest(bool use_nonblocking_io, bool use_async) {
179   std::string simple_message("hello world!");
180   RecordingNetLogObserver net_log_observer;
181   // Setup the server to listen.
182   IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
183   auto server =
184       std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
185   if (use_nonblocking_io)
186     server->UseNonBlockingIO();
187   server->AllowAddressReuse();
188   ASSERT_THAT(server->Listen(server_address), IsOk());
189   // Get bound port.
190   ASSERT_THAT(server->GetLocalAddress(&server_address), IsOk());
191 
192   // Setup the client.
193   auto client = std::make_unique<UDPClientSocket>(
194       DatagramSocket::DEFAULT_BIND, NetLog::Get(), NetLogSource());
195   if (use_nonblocking_io)
196     client->UseNonBlockingIO();
197 
198   if (!use_async) {
199     EXPECT_THAT(client->Connect(server_address), IsOk());
200   } else {
201     TestCompletionCallback callback;
202     int rv = client->ConnectAsync(server_address, callback.callback());
203     if (rv != OK) {
204       ASSERT_EQ(rv, ERR_IO_PENDING);
205       rv = callback.WaitForResult();
206       EXPECT_EQ(rv, OK);
207     } else {
208       EXPECT_EQ(rv, OK);
209     }
210   }
211   // Client sends to the server.
212   EXPECT_EQ(simple_message.length(),
213             static_cast<size_t>(WriteSocket(client.get(), simple_message)));
214 
215   // Server waits for message.
216   std::string str = RecvFromSocket(server.get());
217   EXPECT_EQ(simple_message, str);
218 
219   // Server echoes reply.
220   EXPECT_EQ(simple_message.length(),
221             static_cast<size_t>(SendToSocket(server.get(), simple_message)));
222 
223   // Client waits for response.
224   str = ReadSocket(client.get());
225   EXPECT_EQ(simple_message, str);
226 
227   // Test asynchronous read. Server waits for message.
228   base::RunLoop run_loop;
229   int read_result = 0;
230   int rv = server->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
231                             base::BindOnce(&ReadCompleteCallback, &read_result,
232                                            run_loop.QuitClosure()));
233   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
234 
235   // Client sends to the server.
236   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
237       FROM_HERE,
238       base::BindOnce(&UDPSocketTest::WriteSocketIgnoreResult,
239                      base::Unretained(this), client.get(), simple_message));
240   run_loop.Run();
241   EXPECT_EQ(simple_message.length(), static_cast<size_t>(read_result));
242   EXPECT_EQ(simple_message, std::string(buffer_->data(), read_result));
243 
244   NetLogSource server_net_log_source = server->NetLog().source();
245   NetLogSource client_net_log_source = client->NetLog().source();
246 
247   // Delete sockets so they log their final events.
248   server.reset();
249   client.reset();
250 
251   // Check the server's log.
252   auto server_entries =
253       net_log_observer.GetEntriesForSource(server_net_log_source);
254   ASSERT_EQ(6u, server_entries.size());
255   EXPECT_TRUE(
256       LogContainsBeginEvent(server_entries, 0, NetLogEventType::SOCKET_ALIVE));
257   EXPECT_TRUE(LogContainsEvent(server_entries, 1,
258                                NetLogEventType::UDP_LOCAL_ADDRESS,
259                                NetLogEventPhase::NONE));
260   EXPECT_TRUE(LogContainsEvent(server_entries, 2,
261                                NetLogEventType::UDP_BYTES_RECEIVED,
262                                NetLogEventPhase::NONE));
263   EXPECT_TRUE(LogContainsEvent(server_entries, 3,
264                                NetLogEventType::UDP_BYTES_SENT,
265                                NetLogEventPhase::NONE));
266   EXPECT_TRUE(LogContainsEvent(server_entries, 4,
267                                NetLogEventType::UDP_BYTES_RECEIVED,
268                                NetLogEventPhase::NONE));
269   EXPECT_TRUE(
270       LogContainsEndEvent(server_entries, 5, NetLogEventType::SOCKET_ALIVE));
271 
272   // Check the client's log.
273   auto client_entries =
274       net_log_observer.GetEntriesForSource(client_net_log_source);
275   EXPECT_EQ(7u, client_entries.size());
276   EXPECT_TRUE(
277       LogContainsBeginEvent(client_entries, 0, NetLogEventType::SOCKET_ALIVE));
278   EXPECT_TRUE(
279       LogContainsBeginEvent(client_entries, 1, NetLogEventType::UDP_CONNECT));
280   EXPECT_TRUE(
281       LogContainsEndEvent(client_entries, 2, NetLogEventType::UDP_CONNECT));
282   EXPECT_TRUE(LogContainsEvent(client_entries, 3,
283                                NetLogEventType::UDP_BYTES_SENT,
284                                NetLogEventPhase::NONE));
285   EXPECT_TRUE(LogContainsEvent(client_entries, 4,
286                                NetLogEventType::UDP_BYTES_RECEIVED,
287                                NetLogEventPhase::NONE));
288   EXPECT_TRUE(LogContainsEvent(client_entries, 5,
289                                NetLogEventType::UDP_BYTES_SENT,
290                                NetLogEventPhase::NONE));
291   EXPECT_TRUE(
292       LogContainsEndEvent(client_entries, 6, NetLogEventType::SOCKET_ALIVE));
293 }
294 
TEST_F(UDPSocketTest, Connect) {
  // |use_nonblocking_io| only has an effect on Windows, so it stays false.
  // Exercise the synchronous connect first, then the asynchronous one.
  for (bool use_async : {false, true})
    ConnectTest(/*use_nonblocking_io=*/false, use_async);
}
301 
302 #if BUILDFLAG(IS_WIN)
TEST_F(UDPSocketTest,ConnectNonBlocking)303 TEST_F(UDPSocketTest, ConnectNonBlocking) {
304   ConnectTest(true, false);
305   ConnectTest(true, true);
306 }
307 #endif
308 
// Verifies that receiving a datagram into a buffer that is too small fails
// with ERR_MSG_TOO_BIG, and that the next datagram is then received intact.
TEST_F(UDPSocketTest, PartialRecv) {
  UDPServerSocket server_socket(nullptr, NetLogSource());
  ASSERT_THAT(server_socket.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server_socket.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
                                NetLogSource());
  ASSERT_THAT(client_socket.Connect(server_address), IsOk());

  std::string test_packet("hello world!");
  ASSERT_EQ(static_cast<int>(test_packet.size()),
            WriteSocket(&client_socket, test_packet));

  TestCompletionCallback recv_callback;

  // Receive into a buffer of only 2 bytes. The datagram does not fit, so
  // RecvFrom() is expected to fail with ERR_MSG_TOO_BIG rather than return a
  // truncated payload. The oversized datagram is consumed: the receive below
  // must get the second packet, not the remainder of the first.
  const int kPartialReadSize = 2;
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(kPartialReadSize);
  int rv =
      server_socket.RecvFrom(buffer.get(), kPartialReadSize,
                             &recv_from_address_, recv_callback.callback());
  rv = recv_callback.GetResult(rv);

  EXPECT_THAT(rv, IsError(ERR_MSG_TOO_BIG));

  // Send a different message again.
  std::string second_packet("Second packet");
  ASSERT_EQ(static_cast<int>(second_packet.size()),
            WriteSocket(&client_socket, second_packet));

  // Read whole packet now.
  std::string received = RecvFromSocket(&server_socket);
  EXPECT_EQ(second_packet, received);
}
346 
347 #if BUILDFLAG(IS_APPLE) || BUILDFLAG(IS_ANDROID)
348 // - MacOS: requires root permissions on OSX 10.7+.
349 // - Android: devices attached to testbots don't have default network, so
350 // broadcasting to 255.255.255.255 returns error -109 (Address not reachable).
351 // crbug.com/139144.
352 #define MAYBE_LocalBroadcast DISABLED_LocalBroadcast
353 #else
354 #define MAYBE_LocalBroadcast LocalBroadcast
355 #endif
TEST_F(UDPSocketTest,MAYBE_LocalBroadcast)356 TEST_F(UDPSocketTest, MAYBE_LocalBroadcast) {
357   std::string first_message("first message"), second_message("second message");
358 
359   IPEndPoint listen_address;
360   ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &listen_address));
361 
362   auto server1 =
363       std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
364   auto server2 =
365       std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
366   server1->AllowAddressReuse();
367   server1->AllowBroadcast();
368   server2->AllowAddressReuse();
369   server2->AllowBroadcast();
370 
371   EXPECT_THAT(server1->Listen(listen_address), IsOk());
372   // Get bound port.
373   EXPECT_THAT(server1->GetLocalAddress(&listen_address), IsOk());
374   EXPECT_THAT(server2->Listen(listen_address), IsOk());
375 
376   IPEndPoint broadcast_address;
377   ASSERT_TRUE(CreateUDPAddress("127.255.255.255", listen_address.port(),
378                                &broadcast_address));
379   ASSERT_EQ(static_cast<int>(first_message.size()),
380             SendToSocket(server1.get(), first_message, broadcast_address));
381   std::string str = RecvFromSocket(server1.get());
382   ASSERT_EQ(first_message, str);
383   str = RecvFromSocket(server2.get());
384   ASSERT_EQ(first_message, str);
385 
386   ASSERT_EQ(static_cast<int>(second_message.size()),
387             SendToSocket(server2.get(), second_message, broadcast_address));
388   str = RecvFromSocket(server1.get());
389   ASSERT_EQ(second_message, str);
390   str = RecvFromSocket(server2.get());
391   ASSERT_EQ(second_message, str);
392 }
393 
394 // ConnectRandomBind verifies RANDOM_BIND is handled correctly. It connects
395 // 1000 sockets and then verifies that the allocated port numbers satisfy the
396 // following 2 conditions:
397 //  1. Range from min port value to max is greater than 10000.
398 //  2. There is at least one port in the 5 buckets in the [min, max] range.
399 //
400 // These conditions are not enough to verify that the port numbers are truly
401 // random, but they are enough to protect from most common non-random port
402 // allocation strategies (e.g. counter, pool of available ports, etc.) False
403 // positive result is theoretically possible, but its probability is negligible.
TEST_F(UDPSocketTest,ConnectRandomBind)404 TEST_F(UDPSocketTest, ConnectRandomBind) {
405   const int kIterations = 1000;
406 
407   std::vector<int> used_ports;
408   for (int i = 0; i < kIterations; ++i) {
409     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
410                            NetLogSource());
411     EXPECT_THAT(socket.Connect(IPEndPoint(IPAddress::IPv4Localhost(), 53)),
412                 IsOk());
413 
414     IPEndPoint client_address;
415     EXPECT_THAT(socket.GetLocalAddress(&client_address), IsOk());
416     used_ports.push_back(client_address.port());
417   }
418 
419   int min_port = *std::min_element(used_ports.begin(), used_ports.end());
420   int max_port = *std::max_element(used_ports.begin(), used_ports.end());
421   int range = max_port - min_port + 1;
422 
423   // Verify that the range of ports used by the random port allocator is wider
424   // than 10k. Assuming that socket implementation limits port range to 16k
425   // ports (default on Fuchsia) probability of false negative is below
426   // 10^-200.
427   static int kMinRange = 10000;
428   EXPECT_GT(range, kMinRange);
429 
430   static int kBuckets = 5;
431   std::vector<int> bucket_sizes(kBuckets, 0);
432   for (int port : used_ports) {
433     bucket_sizes[(port - min_port) * kBuckets / range] += 1;
434   }
435 
436   // Verify that there is at least one value in each bucket. Probability of
437   // false negative is below (kBuckets * (1 - 1 / kBuckets) ^ kIterations),
438   // which is less than 10^-96.
439   for (int size : bucket_sizes) {
440     EXPECT_GT(size, 0);
441   }
442 }
443 
TEST_F(UDPSocketTest, ConnectFail) {
  UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());

  // The socket was opened for IPv4, so connecting it to an IPv6 endpoint
  // must be rejected.
  const IPEndPoint ipv6_endpoint(IPAddress::IPv6Localhost(), 53);
  EXPECT_THAT(socket.Connect(ipv6_endpoint), Not(IsOk()));

  // The failed Connect() must not leave the socket marked as connected.
  EXPECT_FALSE(socket.is_connected());
}
457 
458 // Similar to ConnectFail but UDPSocket adopts an opened socket instead of
459 // opening one directly.
TEST_F(UDPSocketTest,AdoptedSocket)460 TEST_F(UDPSocketTest, AdoptedSocket) {
461   auto socketfd =
462       CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
463                            SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
464   UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
465 
466   EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd), IsOk());
467 
468   // Connect to an IPv6 address should fail since the socket was created for
469   // IPv4.
470   EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
471               Not(IsOk()));
472 
473   // Make sure that UDPSocket actually closed the socket.
474   EXPECT_FALSE(socket.is_connected());
475 }
476 
477 // Tests that UDPSocket updates the global counter correctly.
TEST_F(UDPSocketTest,LimitAdoptSocket)478 TEST_F(UDPSocketTest, LimitAdoptSocket) {
479   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
480   {
481     // Creating a platform socket does not increase count.
482     auto socketfd =
483         CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
484                              SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
485     ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
486 
487     // Simply allocating a UDPSocket does not increase count.
488     UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
489     EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
490 
491     // Calling AdoptOpenedSocket() allocates the socket and increases the global
492     // counter.
493     EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd),
494                 IsOk());
495     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
496 
497     // Connect to an IPv6 address should fail since the socket was created for
498     // IPv4.
499     EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
500                 Not(IsOk()));
501 
502     // That Connect() failed doesn't change the global counter.
503     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
504   }
505   // Finally, destroying UDPSocket decrements the global counter.
506   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
507 }
508 
509 // In this test, we verify that connect() on a socket will have the effect
510 // of filtering reads on this socket only to data read from the destination
511 // we connected to.
512 //
513 // The purpose of this test is that some documentation indicates that connect
514 // binds the client's sends to send to a particular server endpoint, but does
515 // not bind the client's reads to only be from that endpoint, and that we need
516 // to always use recvfrom() to disambiguate.
TEST_F(UDPSocketTest,VerifyConnectBindsAddr)517 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
518   std::string simple_message("hello world!");
519   std::string foreign_message("BAD MESSAGE TO GET!!");
520 
521   // Setup the first server to listen.
522   IPEndPoint server1_address(IPAddress::IPv4Localhost(), 0 /* port */);
523   UDPServerSocket server1(nullptr, NetLogSource());
524   ASSERT_THAT(server1.Listen(server1_address), IsOk());
525   // Get the bound port.
526   ASSERT_THAT(server1.GetLocalAddress(&server1_address), IsOk());
527 
528   // Setup the second server to listen.
529   IPEndPoint server2_address(IPAddress::IPv4Localhost(), 0 /* port */);
530   UDPServerSocket server2(nullptr, NetLogSource());
531   ASSERT_THAT(server2.Listen(server2_address), IsOk());
532 
533   // Setup the client, connected to server 1.
534   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
535   EXPECT_THAT(client.Connect(server1_address), IsOk());
536 
537   // Client sends to server1.
538   EXPECT_EQ(simple_message.length(),
539             static_cast<size_t>(WriteSocket(&client, simple_message)));
540 
541   // Server1 waits for message.
542   std::string str = RecvFromSocket(&server1);
543   EXPECT_EQ(simple_message, str);
544 
545   // Get the client's address.
546   IPEndPoint client_address;
547   EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
548 
549   // Server2 sends reply.
550   EXPECT_EQ(foreign_message.length(),
551             static_cast<size_t>(
552                 SendToSocket(&server2, foreign_message, client_address)));
553 
554   // Server1 sends reply.
555   EXPECT_EQ(simple_message.length(),
556             static_cast<size_t>(
557                 SendToSocket(&server1, simple_message, client_address)));
558 
559   // Client waits for response.
560   str = ReadSocket(&client);
561   EXPECT_EQ(simple_message, str);
562 }
563 
TEST_F(UDPSocketTest,ClientGetLocalPeerAddresses)564 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
565   struct TestData {
566     std::string remote_address;
567     std::string local_address;
568     bool may_fail;
569   } tests[] = {
570     {"127.0.00.1", "127.0.0.1", false},
571     {"::1", "::1", true},
572 #if !BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_IOS)
573     // Addresses below are disabled on Android. See crbug.com/161248
574     // They are also disabled on iOS. See https://crbug.com/523225
575     {"192.168.1.1", "127.0.0.1", false},
576     {"2001:db8:0::42", "::1", true},
577 #endif
578   };
579   for (const auto& test : tests) {
580     SCOPED_TRACE(std::string("Connecting from ") + test.local_address +
581                  std::string(" to ") + test.remote_address);
582 
583     IPAddress ip_address;
584     EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.remote_address));
585     IPEndPoint remote_address(ip_address, 80);
586     EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.local_address));
587     IPEndPoint local_address(ip_address, 80);
588 
589     UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
590                            NetLogSource());
591     int rv = client.Connect(remote_address);
592     if (test.may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
593       // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
594       // addresses if IPv6 is not configured.
595       continue;
596     }
597 
598     EXPECT_LE(ERR_IO_PENDING, rv);
599 
600     IPEndPoint fetched_local_address;
601     rv = client.GetLocalAddress(&fetched_local_address);
602     EXPECT_THAT(rv, IsOk());
603 
604     // TODO(mbelshe): figure out how to verify the IP and port.
605     //                The port is dynamically generated by the udp stack.
606     //                The IP is the real IP of the client, not necessarily
607     //                loopback.
608     // EXPECT_EQ(local_address.address(), fetched_local_address.address());
609 
610     IPEndPoint fetched_remote_address;
611     rv = client.GetPeerAddress(&fetched_remote_address);
612     EXPECT_THAT(rv, IsOk());
613 
614     EXPECT_EQ(remote_address, fetched_remote_address);
615   }
616 }
617 
// A server bound to an ephemeral port on localhost should report a nonzero
// port and the address it was bound to.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  UDPServerSocket server(nullptr, NetLogSource());
  int rv = server.Listen(bind_address);
  ASSERT_THAT(rv, IsOk());

  IPEndPoint local_address;
  rv = server.GetLocalAddress(&local_address);
  ASSERT_THAT(rv, IsOk());

  // Verify that port was allocated.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(local_address.address(), bind_address.address());
}
632 
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  UDPServerSocket server(nullptr, NetLogSource());
  EXPECT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());

  // A server socket that has only been put into listening mode is not
  // connected, so there is no peer to report.
  IPEndPoint peer_address;
  EXPECT_THAT(server.GetPeerAddress(&peer_address),
              IsError(ERR_SOCKET_NOT_CONNECTED));
}
643 
TEST_F(UDPSocketTest,ClientSetDoNotFragment)644 TEST_F(UDPSocketTest, ClientSetDoNotFragment) {
645   for (std::string ip : {"127.0.0.1", "::1"}) {
646     UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
647                            NetLogSource());
648     IPAddress ip_address;
649     EXPECT_TRUE(ip_address.AssignFromIPLiteral(ip));
650     IPEndPoint remote_address(ip_address, 80);
651     int rv = client.Connect(remote_address);
652     // May fail on IPv6 is IPv6 is not configured.
653     if (ip_address.IsIPv6() && rv == ERR_ADDRESS_UNREACHABLE)
654       return;
655     EXPECT_THAT(rv, IsOk());
656 
657     rv = client.SetDoNotFragment();
658 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
659     // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
660     EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
661 #elif BUILDFLAG(IS_MAC)
662     if (base::mac::MacOSMajorVersion() >= 11) {
663       EXPECT_THAT(rv, IsOk());
664     } else {
665       EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
666     }
667 #else
668     EXPECT_THAT(rv, IsOk());
669 #endif
670   }
671 }
672 
TEST_F(UDPSocketTest,ServerSetDoNotFragment)673 TEST_F(UDPSocketTest, ServerSetDoNotFragment) {
674   for (std::string ip : {"127.0.0.1", "::1"}) {
675     IPEndPoint bind_address;
676     ASSERT_TRUE(CreateUDPAddress(ip, 0, &bind_address));
677     UDPServerSocket server(nullptr, NetLogSource());
678     int rv = server.Listen(bind_address);
679     // May fail on IPv6 is IPv6 is not configure
680     if (bind_address.address().IsIPv6() &&
681         (rv == ERR_ADDRESS_INVALID || rv == ERR_ADDRESS_UNREACHABLE))
682       return;
683     EXPECT_THAT(rv, IsOk());
684 
685     rv = server.SetDoNotFragment();
686 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
687     // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
688     EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
689 #elif BUILDFLAG(IS_MAC)
690     if (base::mac::MacOSMajorVersion() >= 11) {
691       EXPECT_THAT(rv, IsOk());
692     } else {
693       EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
694     }
695 #else
696     EXPECT_THAT(rv, IsOk());
697 #endif
698   }
699 }
700 
701 // Close the socket while read is pending.
TEST_F(UDPSocketTest,CloseWithPendingRead)702 TEST_F(UDPSocketTest, CloseWithPendingRead) {
703   IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
704   UDPServerSocket server(nullptr, NetLogSource());
705   int rv = server.Listen(bind_address);
706   EXPECT_THAT(rv, IsOk());
707 
708   TestCompletionCallback callback;
709   IPEndPoint from;
710   rv = server.RecvFrom(buffer_.get(), kMaxRead, &from, callback.callback());
711   EXPECT_EQ(rv, ERR_IO_PENDING);
712 
713   server.Close();
714 
715   EXPECT_FALSE(callback.have_result());
716 }
717 
718 // Some Android devices do not support multicast.
719 // The ones supporting multicast need WifiManager.MulitcastLock to enable it.
720 // http://goo.gl/jjAk9
721 #if !BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,JoinMulticastGroup)722 TEST_F(UDPSocketTest, JoinMulticastGroup) {
723   const char kGroup[] = "237.132.100.17";
724 
725   IPAddress group_ip;
726   EXPECT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
727 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
728 // OS_FUCHSIA.
729 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
730   IPEndPoint bind_address(IPAddress::AllZeros(group_ip.size()), 0 /* port */);
731 #else
732   IPEndPoint bind_address(group_ip, 0 /* port */);
733 #endif  // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
734 
735   UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
736   EXPECT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
737 
738   EXPECT_THAT(socket.Bind(bind_address), IsOk());
739   EXPECT_THAT(socket.JoinGroup(group_ip), IsOk());
740   // Joining group multiple times.
741   EXPECT_NE(OK, socket.JoinGroup(group_ip));
742   EXPECT_THAT(socket.LeaveGroup(group_ip), IsOk());
743   // Leaving group multiple times.
744   EXPECT_NE(OK, socket.LeaveGroup(group_ip));
745 
746   socket.Close();
747 }
748 
749 // TODO(https://crbug.com/947115): failing on device on iOS 12.2.
750 // TODO(https://crbug.com/1227554): flaky on Mac 11.
751 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_MAC)
752 #define MAYBE_SharedMulticastAddress DISABLED_SharedMulticastAddress
753 #else
754 #define MAYBE_SharedMulticastAddress SharedMulticastAddress
755 #endif
TEST_F(UDPSocketTest,MAYBE_SharedMulticastAddress)756 TEST_F(UDPSocketTest, MAYBE_SharedMulticastAddress) {
757   const char kGroup[] = "224.0.0.251";
758 
759   IPAddress group_ip;
760   ASSERT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
761 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
762 // OS_FUCHSIA.
763 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
764   IPEndPoint receive_address(IPAddress::AllZeros(group_ip.size()),
765                              0 /* port */);
766 #else
767   IPEndPoint receive_address(group_ip, 0 /* port */);
768 #endif  // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
769 
770   NetworkInterfaceList interfaces;
771   ASSERT_TRUE(GetNetworkList(&interfaces, 0));
772   // The test fails with the Hyper-V switch interface (on the host side).
773   interfaces.erase(std::remove_if(interfaces.begin(), interfaces.end(),
774                                   [](const auto& iface) {
775                                     return iface.friendly_name.rfind(
776                                                "vEthernet", 0) == 0;
777                                   }),
778                    interfaces.end());
779   ASSERT_FALSE(interfaces.empty());
780 
781   // Setup first receiving socket.
782   UDPServerSocket socket1(nullptr, NetLogSource());
783   socket1.AllowAddressSharingForMulticast();
784   ASSERT_THAT(socket1.SetMulticastInterface(interfaces[0].interface_index),
785               IsOk());
786   ASSERT_THAT(socket1.Listen(receive_address), IsOk());
787   ASSERT_THAT(socket1.JoinGroup(group_ip), IsOk());
788   // Get the bound port.
789   ASSERT_THAT(socket1.GetLocalAddress(&receive_address), IsOk());
790 
791   // Setup second receiving socket.
792   UDPServerSocket socket2(nullptr, NetLogSource());
793   socket2.AllowAddressSharingForMulticast(), IsOk();
794   ASSERT_THAT(socket2.SetMulticastInterface(interfaces[0].interface_index),
795               IsOk());
796   ASSERT_THAT(socket2.Listen(receive_address), IsOk());
797   ASSERT_THAT(socket2.JoinGroup(group_ip), IsOk());
798 
799   // Setup client socket.
800   IPEndPoint send_address(group_ip, receive_address.port());
801   UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
802                                 NetLogSource());
803   ASSERT_THAT(client_socket.Connect(send_address), IsOk());
804 
805 #if !BUILDFLAG(IS_CHROMEOS_ASH)
806   // Send a message via the multicast group. That message is expected be be
807   // received by both receving sockets.
808   //
809   // Skip on ChromeOS where it's known to sometimes not work.
810   // TODO(crbug.com/898964): If possible, fix and reenable.
811   const char kMessage[] = "hello!";
812   ASSERT_GE(WriteSocket(&client_socket, kMessage), 0);
813   EXPECT_EQ(kMessage, RecvFromSocket(&socket1));
814   EXPECT_EQ(kMessage, RecvFromSocket(&socket2));
815 #endif  // !BUILDFLAG(IS_CHROMEOS_ASH)
816 }
817 #endif  // !BUILDFLAG(IS_ANDROID)
818 
// The multicast option setters succeed on a socket that has not been bound
// yet and fail once the socket is bound.
TEST_F(UDPSocketTest, MulticastOptions) {
  IPEndPoint bind_address;
  ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &bind_address));

  UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  // Before binding: all option setters are accepted, except a negative TTL.
  EXPECT_THAT(socket.SetMulticastLoopbackMode(false), IsOk());
  EXPECT_THAT(socket.SetMulticastLoopbackMode(true), IsOk());
  EXPECT_THAT(socket.SetMulticastTimeToLive(0), IsOk());
  EXPECT_THAT(socket.SetMulticastTimeToLive(3), IsOk());
  EXPECT_NE(OK, socket.SetMulticastTimeToLive(-1));
  EXPECT_THAT(socket.SetMulticastInterface(0), IsOk());

  // Bind() needs an opened socket, and the "after binding" checks below only
  // mean something if the bind really happened, so these steps are fatal.
  ASSERT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
  ASSERT_THAT(socket.Bind(bind_address), IsOk());

  // After binding: the same setters are rejected.
  EXPECT_NE(OK, socket.SetMulticastLoopbackMode(false));
  EXPECT_NE(OK, socket.SetMulticastTimeToLive(0));
  EXPECT_NE(OK, socket.SetMulticastInterface(0));

  socket.Close();
}
841 
842 // Checking that DSCP bits are set correctly is difficult,
843 // but let's check that the code doesn't crash at least.
TEST_F(UDPSocketTest,SetDSCP)844 TEST_F(UDPSocketTest, SetDSCP) {
845   // Setup the server to listen.
846   IPEndPoint bind_address;
847   UDPSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
848   // We need a real IP, but we won't actually send anything to it.
849   ASSERT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
850   int rv = client.Open(bind_address.GetFamily());
851   EXPECT_THAT(rv, IsOk());
852 
853   rv = client.Connect(bind_address);
854   if (rv != OK) {
855     // Let's try localhost then.
856     bind_address = IPEndPoint(IPAddress::IPv4Localhost(), 9999);
857     rv = client.Connect(bind_address);
858   }
859   EXPECT_THAT(rv, IsOk());
860 
861   client.SetDiffServCodePoint(DSCP_NO_CHANGE);
862   client.SetDiffServCodePoint(DSCP_AF41);
863   client.SetDiffServCodePoint(DSCP_DEFAULT);
864   client.SetDiffServCodePoint(DSCP_CS2);
865   client.SetDiffServCodePoint(DSCP_NO_CHANGE);
866   client.SetDiffServCodePoint(DSCP_DEFAULT);
867   client.Close();
868 }
869 
TEST_F(UDPSocketTest,ConnectUsingNetwork)870 TEST_F(UDPSocketTest, ConnectUsingNetwork) {
871   // The specific value of this address doesn't really matter, and no
872   // server needs to be running here. The test only needs to call
873   // ConnectUsingNetwork() and won't send any datagrams.
874   const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
875   const handles::NetworkHandle wrong_network_handle = 65536;
876 #if BUILDFLAG(IS_ANDROID)
877   NetworkChangeNotifierFactoryAndroid ncn_factory;
878   NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
879   std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
880   if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
881     GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
882 
883   {
884     // Connecting using a not existing network should fail but not report
885     // ERR_NOT_IMPLEMENTED when network handles are supported.
886     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
887                            NetLogSource());
888     int rv =
889         socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address);
890     EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
891     EXPECT_NE(OK, rv);
892     EXPECT_NE(wrong_network_handle, socket.GetBoundNetwork());
893   }
894 
895   {
896     // Connecting using an existing network should succeed when
897     // NetworkChangeNotifier returns a valid default network.
898     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
899                            NetLogSource());
900     const handles::NetworkHandle network_handle =
901         NetworkChangeNotifier::GetDefaultNetwork();
902     if (network_handle != handles::kInvalidNetworkHandle) {
903       EXPECT_EQ(
904           OK, socket.ConnectUsingNetwork(network_handle, fake_server_address));
905       EXPECT_EQ(network_handle, socket.GetBoundNetwork());
906     }
907   }
908 #else
909   UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
910   EXPECT_EQ(
911       ERR_NOT_IMPLEMENTED,
912       socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address));
913 #endif  // BUILDFLAG(IS_ANDROID)
914 }
915 
TEST_F(UDPSocketTest,ConnectUsingNetworkAsync)916 TEST_F(UDPSocketTest, ConnectUsingNetworkAsync) {
917   // The specific value of this address doesn't really matter, and no
918   // server needs to be running here. The test only needs to call
919   // ConnectUsingNetwork() and won't send any datagrams.
920   const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
921   const handles::NetworkHandle wrong_network_handle = 65536;
922 #if BUILDFLAG(IS_ANDROID)
923   NetworkChangeNotifierFactoryAndroid ncn_factory;
924   NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
925   std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
926   if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
927     GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
928 
929   {
930     // Connecting using a not existing network should fail but not report
931     // ERR_NOT_IMPLEMENTED when network handles are supported.
932     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
933                            NetLogSource());
934     TestCompletionCallback callback;
935     int rv = socket.ConnectUsingNetworkAsync(
936         wrong_network_handle, fake_server_address, callback.callback());
937 
938     if (rv == ERR_IO_PENDING) {
939       rv = callback.WaitForResult();
940     }
941     EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
942     EXPECT_NE(OK, rv);
943   }
944 
945   {
946     // Connecting using an existing network should succeed when
947     // NetworkChangeNotifier returns a valid default network.
948     UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
949                            NetLogSource());
950     TestCompletionCallback callback;
951     const handles::NetworkHandle network_handle =
952         NetworkChangeNotifier::GetDefaultNetwork();
953     if (network_handle != handles::kInvalidNetworkHandle) {
954       int rv = socket.ConnectUsingNetworkAsync(
955           network_handle, fake_server_address, callback.callback());
956       if (rv == ERR_IO_PENDING) {
957         rv = callback.WaitForResult();
958       }
959       EXPECT_EQ(OK, rv);
960       EXPECT_EQ(network_handle, socket.GetBoundNetwork());
961     }
962   }
963 #else
964   UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
965   TestCompletionCallback callback;
966   EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket.ConnectUsingNetworkAsync(
967                                      wrong_network_handle, fake_server_address,
968                                      callback.callback()));
969 #endif  // BUILDFLAG(IS_ANDROID)
970 }
971 
972 }  // namespace
973 
974 #if BUILDFLAG(IS_WIN)
975 
976 namespace {
977 
978 const HANDLE kFakeHandle1 = (HANDLE)12;
979 const HANDLE kFakeHandle2 = (HANDLE)13;
980 
981 const QOS_FLOWID kFakeFlowId1 = (QOS_FLOWID)27;
982 const QOS_FLOWID kFakeFlowId2 = (QOS_FLOWID)38;
983 
984 class TestUDPSocketWin : public UDPSocketWin {
985  public:
TestUDPSocketWin(QwaveApi * qos,DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)986   TestUDPSocketWin(QwaveApi* qos,
987                    DatagramSocket::BindType bind_type,
988                    net::NetLog* net_log,
989                    const net::NetLogSource& source)
990       : UDPSocketWin(bind_type, net_log, source), qos_(qos) {}
991 
992   TestUDPSocketWin(const TestUDPSocketWin&) = delete;
993   TestUDPSocketWin& operator=(const TestUDPSocketWin&) = delete;
994 
995   // Overriding GetQwaveApi causes the test class to use the injected mock
996   // QwaveApi instance instead of the singleton.
GetQwaveApi() const997   QwaveApi* GetQwaveApi() const override { return qos_; }
998 
999  private:
1000   raw_ptr<QwaveApi> qos_;
1001 };
1002 
1003 class MockQwaveApi : public QwaveApi {
1004  public:
1005   MOCK_CONST_METHOD0(qwave_supported, bool());
1006   MOCK_METHOD0(OnFatalError, void());
1007   MOCK_METHOD2(CreateHandle, BOOL(PQOS_VERSION version, PHANDLE handle));
1008   MOCK_METHOD1(CloseHandle, BOOL(HANDLE handle));
1009   MOCK_METHOD6(AddSocketToFlow,
1010                BOOL(HANDLE handle,
1011                     SOCKET socket,
1012                     PSOCKADDR addr,
1013                     QOS_TRAFFIC_TYPE traffic_type,
1014                     DWORD flags,
1015                     PQOS_FLOWID flow_id));
1016 
1017   MOCK_METHOD4(
1018       RemoveSocketFromFlow,
1019       BOOL(HANDLE handle, SOCKET socket, QOS_FLOWID flow_id, DWORD reserved));
1020   MOCK_METHOD7(SetFlow,
1021                BOOL(HANDLE handle,
1022                     QOS_FLOWID flow_id,
1023                     QOS_SET_FLOW op,
1024                     ULONG size,
1025                     PVOID data,
1026                     DWORD reserved,
1027                     LPOVERLAPPED overlapped));
1028 };
1029 
OpenedDscpTestClient(QwaveApi * api,IPEndPoint bind_address)1030 std::unique_ptr<UDPSocket> OpenedDscpTestClient(QwaveApi* api,
1031                                                 IPEndPoint bind_address) {
1032   auto client = std::make_unique<TestUDPSocketWin>(
1033       api, DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1034   int rv = client->Open(bind_address.GetFamily());
1035   EXPECT_THAT(rv, IsOk());
1036 
1037   return client;
1038 }
1039 
ConnectedDscpTestClient(QwaveApi * api)1040 std::unique_ptr<UDPSocket> ConnectedDscpTestClient(QwaveApi* api) {
1041   IPEndPoint bind_address;
1042   // We need a real IP, but we won't actually send anything to it.
1043   EXPECT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
1044   auto client = OpenedDscpTestClient(api, bind_address);
1045   EXPECT_THAT(client->Connect(bind_address), IsOk());
1046   return client;
1047 }
1048 
UnconnectedDscpTestClient(QwaveApi * api)1049 std::unique_ptr<UDPSocket> UnconnectedDscpTestClient(QwaveApi* api) {
1050   IPEndPoint bind_address;
1051   EXPECT_TRUE(CreateUDPAddress("0.0.0.0", 9999, &bind_address));
1052   auto client = OpenedDscpTestClient(api, bind_address);
1053   EXPECT_THAT(client->Bind(bind_address), IsOk());
1054   return client;
1055 }
1056 
1057 }  // namespace
1058 
1059 using ::testing::Return;
1060 using ::testing::SetArgPointee;
1061 using ::testing::_;
1062 
// Setting DSCP_NO_CHANGE succeeds and never adds the socket to a QoS flow.
TEST_F(UDPSocketTest, SetDSCPNoopIfPassedNoChange) {
  MockQwaveApi mock_api;
  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(mock_api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_api);
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_NO_CHANGE), IsOk());
}
1071 
// When QWAVE reports itself unsupported, no handle is created and setting a
// real codepoint yields ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfQOSDoesntLink) {
  MockQwaveApi mock_api;
  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_CALL(mock_api, CreateHandle(_, _)).Times(0);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_api);
  const int result = socket->SetDiffServCodePoint(DSCP_AF41);
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, result);
}
1080 
// QWAVE claims support but CreateHandle() fails. Setting a codepoint then
// reports ERR_INVALID_HANDLE, and OnFatalError() is expected exactly once.
// After QWAVE is reported unsupported, Set returns ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfHandleCantBeCreated) {
  MockQwaveApi mock_api;
  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(mock_api, CreateHandle(_, _)).WillOnce(Return(false));
  EXPECT_CALL(mock_api, OnFatalError()).Times(1);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_api);
  EXPECT_EQ(ERR_INVALID_HANDLE, socket->SetDiffServCodePoint(DSCP_AF41));

  // Drain pending tasks before changing what qwave_supported() returns.
  RunUntilIdle();

  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket->SetDiffServCodePoint(DSCP_AF41));
}
1095 
1096 MATCHER_P(DscpPointee, dscp, "") {
1097   return *(DWORD*)arg == (DWORD)dscp;
1098 }
1099 
// For a connected socket, the first Set fails while the asynchronous QWAVE
// initialization is still pending, and succeeds once it has run. A later Set
// with a different codepoint replaces the flow.
TEST_F(UDPSocketTest, ConnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi mock_api;
  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&mock_api);
  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(mock_api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  EXPECT_CALL(mock_api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(mock_api, SetFlow(_, _, _, _, _, _, _));

  // The run loop has not been pumped yet, so initialization has not happened
  // and this first attempt is rejected.
  EXPECT_EQ(ERR_INVALID_HANDLE, socket->SetDiffServCodePoint(DSCP_AF41));
  RunUntilIdle();
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // Switching to DSCP_DEFAULT removes the old flow and sets up a best-effort
  // one that carries the new value.
  EXPECT_CALL(mock_api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(mock_api,
              AddSocketToFlow(_, _, _, QOSTrafficTypeBestEffort, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(mock_api, SetFlow(_, _, QOSSetOutgoingDSCPValue, _,
                                DscpPointee(DSCP_DEFAULT), _, _));
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_DEFAULT), IsOk());

  // Expected during teardown, when |socket| (and its DscpManager) is
  // destroyed at the end of the test.
  EXPECT_CALL(mock_api, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(mock_api, CloseHandle(kFakeHandle1));
}
1129 
// Unconnected sockets accept SetDiffServCodePoint() both before and after the
// QWAVE handle creation has run.
TEST_F(UDPSocketTest, UnonnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi mock_api;
  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(mock_api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // Handle creation has not completed yet, but Set still succeeds.
  std::unique_ptr<UDPSocket> socket = UnconnectedDscpTestClient(&mock_api);
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_AF41), IsOk());

  RunUntilIdle();
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_AF42), IsOk());

  // Expected when |socket|'s DscpManager is destroyed at the end of the test.
  EXPECT_CALL(mock_api, CloseHandle(kFakeHandle1));
}
1146 
1147 // TODO(zstein): Mocking out DscpManager might be simpler here
1148 // (just verify that DscpManager::Set and DscpManager::PrepareForSend are
1149 // called).
TEST_F(UDPSocketTest,SendToCallsQwaveApis)1150 TEST_F(UDPSocketTest, SendToCallsQwaveApis) {
1151   MockQwaveApi api;
1152   std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
1153   EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
1154   EXPECT_CALL(api, CreateHandle(_, _))
1155       .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
1156   EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());
1157   RunUntilIdle();
1158 
1159   EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
1160       .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
1161   EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));
1162   std::string simple_message("hello world");
1163   IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);
1164   int rv = SendToSocket(client.get(), simple_message, server_address);
1165   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1166 
1167   // TODO(zstein): Move to second test case (Qwave APIs called once per address)
1168   rv = SendToSocket(client.get(), simple_message, server_address);
1169   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1170 
1171   // TODO(zstein): Move to third test case (Qwave APIs called for each
1172   // destination address).
1173   EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(true));
1174   IPEndPoint server_address2(IPAddress::IPv4Localhost(), 9439);
1175 
1176   rv = SendToSocket(client.get(), simple_message, server_address2);
1177   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1178 
1179   // Called from DscpManager destructor.
1180   EXPECT_CALL(api, RemoveSocketFromFlow(_, _, _, _));
1181   EXPECT_CALL(api, CloseHandle(kFakeHandle1));
1182 }
1183 
// On an unconnected socket, the codepoint can be set before QWAVE
// initialization completes. A send made before then skips the flow APIs. The
// first send afterwards joins a flow for the stored codepoint.
TEST_F(UDPSocketTest, SendToCallsApisAfterDeferredInit) {
  MockQwaveApi mock_api;
  std::unique_ptr<UDPSocket> socket = UnconnectedDscpTestClient(&mock_api);
  EXPECT_CALL(mock_api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(mock_api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // Accepted even though the QoS API has not finished initializing.
  EXPECT_THAT(socket->SetDiffServCodePoint(DSCP_CS7), IsOk());

  std::string payload("hello world");
  IPEndPoint destination(IPAddress::IPv4Localhost(), 9438);

  // Not initialized yet: the datagram is sent without touching any flow.
  EXPECT_CALL(mock_api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  int sent = SendToSocket(socket.get(), payload, destination);
  EXPECT_EQ(payload.length(), static_cast<size_t>(sent));

  RunUntilIdle();
  // Initialized now: the send joins a control-class flow, matching DSCP_CS7.
  EXPECT_CALL(mock_api,
              AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(mock_api, SetFlow(_, _, _, _, _, _, _)).WillOnce(Return(true));
  sent = SendToSocket(socket.get(), payload, destination);
  EXPECT_EQ(payload.length(), static_cast<size_t>(sent));

  // Expected when |socket|'s DscpManager is destroyed at the end of the test.
  EXPECT_CALL(mock_api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(mock_api, CloseHandle(kFakeHandle1));
}
1214 
1215 class DscpManagerTest : public TestWithTaskEnvironment {
1216  protected:
DscpManagerTest()1217   DscpManagerTest() {
1218     EXPECT_CALL(api_, qwave_supported()).WillRepeatedly(Return(true));
1219     EXPECT_CALL(api_, CreateHandle(_, _))
1220         .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
1221     dscp_manager_ = std::make_unique<DscpManager>(&api_, INVALID_SOCKET);
1222 
1223     CreateUDPAddress("1.2.3.4", 9001, &address1_);
1224     CreateUDPAddress("1234:5678:90ab:cdef:1234:5678:90ab:cdef", 9002,
1225                      &address2_);
1226   }
1227 
1228   MockQwaveApi api_;
1229   std::unique_ptr<DscpManager> dscp_manager_;
1230 
1231   IPEndPoint address1_;
1232   IPEndPoint address2_;
1233 };
1234 
// Exercises PrepareForSend() when no codepoint has ever been Set(). No flow
// API expectations are registered for this case.
TEST_F(DscpManagerTest, PrepareForSendIsNoopIfNoSet) {
  // Let pending tasks (such as the asynchronous handle initialization) run.
  RunUntilIdle();

  dscp_manager_->PrepareForSend(address1_);
}
1239 
// After Set(), each new destination is added to the flow. SetFlow() is only
// issued for the send that creates the flow.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisAfterSet) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // |address1_| creates the flow, so the DSCP value is applied to it once.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // |address2_| joins the existing flow without another SetFlow().
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address2_);

  // Expected when the fixture destroys |dscp_manager_|.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1260 
// Sending to the same destination twice only drives the flow APIs the first
// time.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisOncePerAddress) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // First send to |address1_|: the flow is created and configured.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Second send to |address1_|: nothing further is called.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address1_);

  // Expected when the fixture destroys |dscp_manager_|.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1277 
// Changing the codepoint removes the socket from its current flow. The next
// send then builds a fresh flow.
TEST_F(DscpManagerTest, SetDestroysExistingFlow) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // Establish a flow for |address1_| under DSCP_CS2.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Set() must tear that flow down. The socket argument is matched against
  // NULL, i.e. the removal is not tied to a specific socket value.
  // TODO(zstein): Verify that RemoveSocketFromFlow with no address
  // destroys the flow for all destinations.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, NULL, kFakeFlowId1, _));
  dscp_manager_->Set(DSCP_CS5);

  // The next send creates and configures a new flow.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Expected when the fixture destroys |dscp_manager_|.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1302 
// A flow operation fails with ERROR_DEVICE_REINITIALIZATION_NEEDED. The
// manager must then close its QWAVE handle and create a new one. The next
// send re-adds the socket with the previously set codepoint, without another
// Set().
TEST_F(DscpManagerTest, SocketReAddedOnRecreateHandle) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // The initial Set and Send succeed normally.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Change the codepoint so the next send has to add the socket to a flow
  // again.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _))
      .WillOnce(Return(true));
  dscp_manager_->Set(DSCP_CS7);

  {
    // The ScopedClearLastError lives only until PrepareForSend() returns, so
    // it is released before RunUntilIdle() below.
    base::ScopedClearLastError scoped_error;
    ::SetLastError(ERROR_DEVICE_REINITIALIZATION_NEEDED);
    // The failing AddSocketToFlow must close kFakeHandle1 and request a
    // replacement (kFakeHandle2).
    EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
        .WillOnce(Return(false));
    EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
    EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
    EXPECT_CALL(api_, CreateHandle(_, _))
        .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle2), Return(true)));
    EXPECT_EQ(ERR_INVALID_HANDLE, dscp_manager_->PrepareForSend(address1_));
  }
  RunUntilIdle();

  // Next Send should work fine, without requiring another Set
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Expected when the fixture destroys |dscp_manager_|.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle2));
}
1341 
1342 #endif
1343 
// A client socket with the receive optimization enabled still reads a
// datagram sent to it by the server.
TEST_F(UDPSocketTest, ReadWithSocketOptimization) {
  std::string payload("hello world!");

  // Bind the server to an ephemeral localhost port and learn which one.
  IPEndPoint listen_address(IPAddress::IPv4Localhost(), 0 /* port */);
  UDPServerSocket server(nullptr, NetLogSource());
  server.AllowAddressReuse();
  ASSERT_THAT(server.Listen(listen_address), IsOk());
  ASSERT_THAT(server.GetLocalAddress(&listen_address), IsOk());

  // Connect the optimized client to the server.
  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  client.EnableRecvOptimization();
  EXPECT_THAT(client.Connect(listen_address), IsOk());

  // The server needs the client's address to send back to it.
  IPEndPoint client_address;
  EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());

  const int sent = SendToSocket(&server, payload, client_address);
  EXPECT_EQ(payload.length(), static_cast<size_t>(sent));

  EXPECT_EQ(payload, ReadSocket(&client));

  server.Close();
  client.Close();
}
1377 
1378 // Tests that read from a socket correctly returns
1379 // |ERR_MSG_TOO_BIG| when the buffer is too small and
1380 // returns the actual message when it fits the buffer.
1381 // For the optimized path, the buffer size should be at least
1382 // 1 byte greater than the message.
TEST_F(UDPSocketTest,ReadWithSocketOptimizationTruncation)1383 TEST_F(UDPSocketTest, ReadWithSocketOptimizationTruncation) {
1384   std::string too_long_message(kMaxRead + 1, 'A');
1385   std::string right_length_message(kMaxRead - 1, 'B');
1386   std::string exact_length_message(kMaxRead, 'C');
1387 
1388   // Setup the server to listen.
1389   IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
1390   UDPServerSocket server(nullptr, NetLogSource());
1391   server.AllowAddressReuse();
1392   ASSERT_THAT(server.Listen(server_address), IsOk());
1393   // Get bound port.
1394   ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1395 
1396   // Setup the client, enable experimental optimization and connected to the
1397   // server.
1398   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1399   client.EnableRecvOptimization();
1400   EXPECT_THAT(client.Connect(server_address), IsOk());
1401 
1402   // Get the client's address.
1403   IPEndPoint client_address;
1404   EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
1405 
1406   // Send messages to the client.
1407   EXPECT_EQ(too_long_message.length(),
1408             static_cast<size_t>(
1409                 SendToSocket(&server, too_long_message, client_address)));
1410   EXPECT_EQ(right_length_message.length(),
1411             static_cast<size_t>(
1412                 SendToSocket(&server, right_length_message, client_address)));
1413   EXPECT_EQ(exact_length_message.length(),
1414             static_cast<size_t>(
1415                 SendToSocket(&server, exact_length_message, client_address)));
1416 
1417   // Client receives the messages.
1418 
1419   // 1. The first message is |too_long_message|. Its size exceeds the buffer.
1420   // In that case, the client is expected to get |ERR_MSG_TOO_BIG| when the
1421   // data is read.
1422   TestCompletionCallback callback;
1423   int rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1424   EXPECT_EQ(ERR_MSG_TOO_BIG, callback.GetResult(rv));
1425 
1426   // 2. The second message is |right_length_message|. Its size is
1427   // one byte smaller than the size of the buffer. In that case, the client
1428   // is expected to read the whole message successfully.
1429   rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1430   rv = callback.GetResult(rv);
1431   EXPECT_EQ(static_cast<int>(right_length_message.length()), rv);
1432   EXPECT_EQ(right_length_message, std::string(buffer_->data(), rv));
1433 
1434   // 3. The third message is |exact_length_message|. Its size is equal to
1435   // the read buffer size. In that case, the client expects to get
1436   // |ERR_MSG_TOO_BIG| when the socket is read. Internally, the optimized
1437   // path uses read() system call that requires one extra byte to detect
1438   // truncated messages; therefore, messages that fill the buffer exactly
1439   // are considered truncated.
1440   // The optimization is only enabled on POSIX platforms. On Windows,
1441   // the optimization is turned off; therefore, the client
1442   // should be able to read the whole message without encountering
1443   // |ERR_MSG_TOO_BIG|.
1444   rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1445   rv = callback.GetResult(rv);
1446 #if BUILDFLAG(IS_POSIX)
1447   EXPECT_EQ(ERR_MSG_TOO_BIG, rv);
1448 #else
1449   EXPECT_EQ(static_cast<int>(exact_length_message.length()), rv);
1450   EXPECT_EQ(exact_length_message, std::string(buffer_->data(), rv));
1451 #endif
1452   server.Close();
1453   client.Close();
1454 }
1455 
// On Android, where socket tagging is supported, verify that
// UDPClientSocket::ApplySocketTag() tags and counts traffic as expected.
1458 #if BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,Tag)1459 TEST_F(UDPSocketTest, Tag) {
1460   if (!CanGetTaggedBytes()) {
1461     DVLOG(0) << "Skipping test - GetTaggedBytes unsupported.";
1462     return;
1463   }
1464 
1465   UDPServerSocket server(nullptr, NetLogSource());
1466   ASSERT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
1467   IPEndPoint server_address;
1468   ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1469 
1470   UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1471   ASSERT_THAT(client.Connect(server_address), IsOk());
1472 
1473   // Verify UDP packets are tagged and counted properly.
1474   int32_t tag_val1 = 0x12345678;
1475   uint64_t old_traffic = GetTaggedBytes(tag_val1);
1476   SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
1477   client.ApplySocketTag(tag1);
1478   // Client sends to the server.
1479   std::string simple_message("hello world!");
1480   int rv = WriteSocket(&client, simple_message);
1481   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1482   // Server waits for message.
1483   std::string str = RecvFromSocket(&server);
1484   EXPECT_EQ(simple_message, str);
1485   // Server echoes reply.
1486   rv = SendToSocket(&server, simple_message);
1487   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1488   // Client waits for response.
1489   str = ReadSocket(&client);
1490   EXPECT_EQ(simple_message, str);
1491   EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
1492 
1493   // Verify socket can be retagged with a new value and the current process's
1494   // UID.
1495   int32_t tag_val2 = 0x87654321;
1496   old_traffic = GetTaggedBytes(tag_val2);
1497   SocketTag tag2(getuid(), tag_val2);
1498   client.ApplySocketTag(tag2);
1499   // Client sends to the server.
1500   rv = WriteSocket(&client, simple_message);
1501   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1502   // Server waits for message.
1503   str = RecvFromSocket(&server);
1504   EXPECT_EQ(simple_message, str);
1505   // Server echoes reply.
1506   rv = SendToSocket(&server, simple_message);
1507   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1508   // Client waits for response.
1509   str = ReadSocket(&client);
1510   EXPECT_EQ(simple_message, str);
1511   EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);
1512 
1513   // Verify socket can be retagged with a new value and the current process's
1514   // UID.
1515   old_traffic = GetTaggedBytes(tag_val1);
1516   client.ApplySocketTag(tag1);
1517   // Client sends to the server.
1518   rv = WriteSocket(&client, simple_message);
1519   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1520   // Server waits for message.
1521   str = RecvFromSocket(&server);
1522   EXPECT_EQ(simple_message, str);
1523   // Server echoes reply.
1524   rv = SendToSocket(&server, simple_message);
1525   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
1526   // Client waits for response.
1527   str = ReadSocket(&client);
1528   EXPECT_EQ(simple_message, str);
1529   EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
1530 }
1531 
// Verifies that a UDPClientSocket constructed with a network handle fails to
// connect when the handle names a nonexistent network, and binds to the
// default network when one is available.
TEST_F(UDPSocketTest, BindToNetwork) {
  // Only Connect() is called and no datagram is ever sent, so the concrete
  // address is irrelevant and nothing has to be listening on it.
  const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);

  // Swap in an Android NetworkChangeNotifier for the duration of the test; the
  // network-handle queries below go through NetworkChangeNotifier.
  NetworkChangeNotifierFactoryAndroid notifier_factory;
  NetworkChangeNotifier::DisableForTest disable_existing_notifier;
  std::unique_ptr<NetworkChangeNotifier> notifier(
      notifier_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // A socket bound to a nonexistent network must fail at connect time. The
  // exact error differs between Android versions, so only the outcomes that
  // must never happen are ruled out.
  const handles::NetworkHandle bogus_network = 65536;
  UDPClientSocket bogus_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), bogus_network);
  const int connect_result = bogus_socket.Connect(fake_server_address);
  EXPECT_NE(OK, connect_result);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, connect_result);
  EXPECT_NE(bogus_network, bogus_socket.GetBoundNetwork());

  // A socket bound to an existing network must connect and report that
  // network. Without a default network there is nothing further to check.
  const handles::NetworkHandle default_network =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (default_network == handles::kInvalidNetworkHandle)
    return;

  UDPClientSocket bound_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), default_network);
  EXPECT_EQ(OK, bound_socket.Connect(fake_server_address));
  EXPECT_EQ(default_network, bound_socket.GetBoundNetwork());
}
1564 
1565 #endif  // BUILDFLAG(IS_ANDROID)
1566 
1567 // Scoped helper to override the process-wide UDP socket limit.
1568 class OverrideUDPSocketLimit {
1569  public:
OverrideUDPSocketLimit(int new_limit)1570   explicit OverrideUDPSocketLimit(int new_limit) {
1571     base::FieldTrialParams params;
1572     params[features::kLimitOpenUDPSocketsMax.name] =
1573         base::NumberToString(new_limit);
1574 
1575     scoped_feature_list_.InitAndEnableFeatureWithParameters(
1576         features::kLimitOpenUDPSockets, params);
1577   }
1578 
1579  private:
1580   base::test::ScopedFeatureList scoped_feature_list_;
1581 };
1582 
1583 // Tests that UDPClientSocket respects the global UDP socket limits.
TEST_F(UDPSocketTest,LimitClientSocket)1584 TEST_F(UDPSocketTest, LimitClientSocket) {
1585   // Reduce the global UDP limit to 2.
1586   OverrideUDPSocketLimit set_limit(2);
1587 
1588   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1589 
1590   auto socket1 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1591                                                    nullptr, NetLogSource());
1592   auto socket2 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1593                                                    nullptr, NetLogSource());
1594 
1595   // Simply constructing a UDPClientSocket does not increase the limit (no
1596   // Connect() or Bind() has been called yet).
1597   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1598 
1599   // The specific value of this address doesn't really matter, and no server
1600   // needs to be running here. The test only needs to call Connect() and won't
1601   // send any datagrams.
1602   IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1603 
1604   // Successful Connect() on socket1 increases socket count.
1605   EXPECT_THAT(socket1->Connect(server_address), IsOk());
1606   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1607 
1608   // Successful Connect() on socket2 increases socket count.
1609   EXPECT_THAT(socket2->Connect(server_address), IsOk());
1610   EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1611 
1612   // Attempting a third Connect() should fail with ERR_INSUFFICIENT_RESOURCES,
1613   // as the limit is currently 2.
1614   auto socket3 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1615                                                    nullptr, NetLogSource());
1616   EXPECT_THAT(socket3->Connect(server_address),
1617               IsError(ERR_INSUFFICIENT_RESOURCES));
1618   EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1619 
1620   // Check that explicitly closing socket2 free up a count.
1621   socket2->Close();
1622   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1623 
1624   // Since the socket was already closed, deleting it will not affect the count.
1625   socket2.reset();
1626   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1627 
1628   // Now that the count is below limit, try to connect another socket. This time
1629   // it will work.
1630   auto socket4 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1631                                                    nullptr, NetLogSource());
1632   EXPECT_THAT(socket4->Connect(server_address), IsOk());
1633   EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1634 
1635   // Verify that closing the two remaining sockets brings the open count back to
1636   // 0.
1637   socket1.reset();
1638   EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1639   socket4.reset();
1640   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1641 }
1642 
1643 // Tests that UDPSocketClient updates the global counter
1644 // correctly when Connect() fails.
TEST_F(UDPSocketTest,LimitConnectFail)1645 TEST_F(UDPSocketTest, LimitConnectFail) {
1646   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1647 
1648   {
1649     // Simply allocating a UDPSocket does not increase count.
1650     UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1651     EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1652 
1653     // Calling Open() allocates the socket and increases the global counter.
1654     EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
1655     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1656 
1657     // Connect to an IPv6 address should fail since the socket was created for
1658     // IPv4.
1659     EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
1660                 Not(IsOk()));
1661 
1662     // That Connect() failed doesn't change the global counter.
1663     EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1664   }
1665 
1666   // Finally, destroying UDPSocket decrements the global counter.
1667   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1668 }
1669 
1670 // Tests allocating UDPClientSockets and Connect()ing them in parallel.
1671 //
1672 // This is primarily intended for coverage under TSAN, to check for races
1673 // enforcing the global socket counter.
TEST_F(UDPSocketTest,LimitConnectMultithreaded)1674 TEST_F(UDPSocketTest, LimitConnectMultithreaded) {
1675   ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1676 
1677   // Start up some threads.
1678   std::vector<std::unique_ptr<base::Thread>> threads;
1679   for (size_t i = 0; i < 5; ++i) {
1680     threads.push_back(std::make_unique<base::Thread>("Worker thread"));
1681     ASSERT_TRUE(threads.back()->Start());
1682   }
1683 
1684   // Post tasks to each of the threads.
1685   for (const auto& thread : threads) {
1686     thread->task_runner()->PostTask(
1687         FROM_HERE, base::BindOnce([] {
1688           // The specific value of this address doesn't really matter, and no
1689           // server needs to be running here. The test only needs to call
1690           // Connect() and won't send any datagrams.
1691           IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1692 
1693           UDPClientSocket socket(DatagramSocket::DEFAULT_BIND, nullptr,
1694                                  NetLogSource());
1695           EXPECT_THAT(socket.Connect(server_address), IsOk());
1696         }));
1697   }
1698 
1699   // Complete all the tasks.
1700   threads.clear();
1701 
1702   EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1703 }
1704 
1705 }  // namespace net
1706