• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/quic/quic_socket_data_provider.h"
6 
7 #include <memory>
8 
9 #include "base/strings/string_number_conversions.h"
10 #include "base/task/sequenced_task_runner.h"
11 #include "base/test/bind.h"
12 #include "base/test/gtest_util.h"
13 #include "net/base/io_buffer.h"
14 #include "net/quic/mock_quic_context.h"
15 #include "net/quic/quic_test_packet_maker.h"
16 #include "net/socket/datagram_client_socket.h"
17 #include "net/socket/diff_serv_code_point.h"
18 #include "net/socket/socket_test_util.h"
19 #include "net/test/test_with_task_environment.h"
20 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
21 #include "testing/gtest/include/gtest/gtest-spi.h"
22 #include "testing/gtest/include/gtest/gtest.h"
23 
24 namespace net::test {
25 
26 class QuicSocketDataProviderTest : public TestWithTaskEnvironment {
27  public:
QuicSocketDataProviderTest()28   QuicSocketDataProviderTest()
29       : packet_maker_(std::make_unique<QuicTestPacketMaker>(
30             version_,
31             quic::QuicUtils::CreateRandomConnectionId(
32                 context_.random_generator()),
33             context_.clock(),
34             "hostname",
35             quic::Perspective::IS_CLIENT,
36             /*client_priority_uses_incremental=*/true,
37             /*use_priority_header=*/true)) {}
38 
39   // Create a simple test packet.
TestPacket(uint64_t packet_number)40   std::unique_ptr<quic::QuicReceivedPacket> TestPacket(uint64_t packet_number) {
41     return packet_maker_->Packet(packet_number)
42         .AddMessageFrame(base::NumberToString(packet_number))
43         .Build();
44   }
45 
46  protected:
47   NetLogWithSource net_log_with_source_{
48       NetLogWithSource::Make(NetLogSourceType::NONE)};
49   quic::ParsedQuicVersion version_ = quic::ParsedQuicVersion::RFCv1();
50   MockQuicContext context_;
51   std::unique_ptr<QuicTestPacketMaker> packet_maker_;
52 };
53 
54 // A linear sequence of sync expectations completes.
TEST_F(QuicSocketDataProviderTest,LinearSequenceSync)55 TEST_F(QuicSocketDataProviderTest, LinearSequenceSync) {
56   QuicSocketDataProvider socket_data(version_);
57   MockClientSocketFactory socket_factory;
58 
59   socket_data.AddWrite("p1", TestPacket(1)).Sync();
60   socket_data.AddWrite("p2", TestPacket(2)).Sync();
61   socket_data.AddWrite("p3", TestPacket(3)).Sync();
62 
63   socket_factory.AddSocketDataProvider(&socket_data);
64   base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
65       FROM_HERE, base::BindLambdaForTesting([&]() {
66         std::unique_ptr<DatagramClientSocket> socket =
67             socket_factory.CreateDatagramClientSocket(
68                 DatagramSocket::BindType::DEFAULT_BIND, nullptr,
69                 net_log_with_source_.source());
70         socket->Connect(IPEndPoint());
71 
72         for (uint64_t packet_number = 1; packet_number < 4; packet_number++) {
73           std::unique_ptr<quic::QuicReceivedPacket> packet =
74               TestPacket(packet_number);
75           scoped_refptr<StringIOBuffer> buffer =
76               base::MakeRefCounted<StringIOBuffer>(
77                   std::string(packet->data(), packet->length()));
78           EXPECT_EQ(
79               static_cast<int>(packet->length()),
80               socket->Write(buffer.get(), packet->length(), base::DoNothing(),
81                             TRAFFIC_ANNOTATION_FOR_TESTS));
82         }
83       }));
84 
85   socket_data.RunUntilAllConsumed();
86 }
87 
88 // A linear sequence of async expectations completes.
TEST_F(QuicSocketDataProviderTest,LinearSequenceAsync)89 TEST_F(QuicSocketDataProviderTest, LinearSequenceAsync) {
90   QuicSocketDataProvider socket_data(version_);
91   MockClientSocketFactory socket_factory;
92 
93   socket_data.AddWrite("p1", TestPacket(1));
94   socket_data.AddWrite("p2", TestPacket(2));
95   socket_data.AddWrite("p3", TestPacket(3));
96 
97   socket_factory.AddSocketDataProvider(&socket_data);
98   std::unique_ptr<DatagramClientSocket> socket =
99       socket_factory.CreateDatagramClientSocket(
100           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
101           net_log_with_source_.source());
102   socket->Connect(IPEndPoint());
103 
104   int next_packet = 1;
105   base::RepeatingCallback<void(int)> callback =
106       base::BindLambdaForTesting([&](int result) {
107         EXPECT_GT(result, 0);  // Bytes written or, on the first call, one.
108         if (next_packet <= 3) {
109           std::unique_ptr<quic::QuicReceivedPacket> packet =
110               TestPacket(next_packet++);
111           scoped_refptr<StringIOBuffer> buffer =
112               base::MakeRefCounted<StringIOBuffer>(
113                   std::string(packet->data(), packet->length()));
114           EXPECT_EQ(ERR_IO_PENDING,
115                     socket->Write(buffer.get(), packet->length(), callback,
116                                   TRAFFIC_ANNOTATION_FOR_TESTS));
117         }
118       });
119   callback.Run(1);
120   socket_data.RunUntilAllConsumed();
121 }
122 
123 // The `TosByte` builder method results in a correct TOS byte in the read.
TEST_F(QuicSocketDataProviderTest,ReadTos)124 TEST_F(QuicSocketDataProviderTest, ReadTos) {
125   QuicSocketDataProvider socket_data(version_);
126   MockClientSocketFactory socket_factory;
127   const uint8_t kTestTos = (DSCP_CS1 << 2) + ECN_CE;
128 
129   socket_data.AddRead("p1", TestPacket(1)).Sync().TosByte(kTestTos);
130 
131   socket_factory.AddSocketDataProvider(&socket_data);
132   std::unique_ptr<DatagramClientSocket> socket =
133       socket_factory.CreateDatagramClientSocket(
134           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
135           net_log_with_source_.source());
136   socket->Connect(IPEndPoint());
137 
138   scoped_refptr<GrowableIOBuffer> read_buffer =
139       base::MakeRefCounted<GrowableIOBuffer>();
140   read_buffer->SetCapacity(100);
141   EXPECT_EQ(static_cast<int>(TestPacket(1)->length()),
142             socket->Read(read_buffer.get(), 100, base::DoNothing()));
143   DscpAndEcn dscp_and_ecn = socket->GetLastTos();
144   EXPECT_EQ(dscp_and_ecn.dscp, DSCP_CS1);
145   EXPECT_EQ(dscp_and_ecn.ecn, ECN_CE);
146 
147   socket_data.RunUntilAllConsumed();
148 }
149 
150 // AddReadError creates a read returning an error.
TEST_F(QuicSocketDataProviderTest,AddReadError)151 TEST_F(QuicSocketDataProviderTest, AddReadError) {
152   QuicSocketDataProvider socket_data(version_);
153   MockClientSocketFactory socket_factory;
154 
155   socket_data.AddReadError("p1", ERR_CONNECTION_ABORTED).Sync();
156 
157   socket_factory.AddSocketDataProvider(&socket_data);
158   std::unique_ptr<DatagramClientSocket> socket =
159       socket_factory.CreateDatagramClientSocket(
160           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
161           net_log_with_source_.source());
162   socket->Connect(IPEndPoint());
163 
164   scoped_refptr<GrowableIOBuffer> read_buffer =
165       base::MakeRefCounted<GrowableIOBuffer>();
166   read_buffer->SetCapacity(100);
167   EXPECT_EQ(ERR_CONNECTION_ABORTED,
168             socket->Read(read_buffer.get(), 100, base::DoNothing()));
169 
170   socket_data.RunUntilAllConsumed();
171 }
172 
173 // AddRead with a QuicReceivedPacket correctly sets the ECN.
TEST_F(QuicSocketDataProviderTest,AddReadQuicReceivedPacketGetsEcn)174 TEST_F(QuicSocketDataProviderTest, AddReadQuicReceivedPacketGetsEcn) {
175   QuicSocketDataProvider socket_data(version_);
176   MockClientSocketFactory socket_factory;
177 
178   packet_maker_->set_ecn_codepoint(quic::QuicEcnCodepoint::ECN_ECT0);
179   socket_data.AddRead("p1", TestPacket(1)).Sync();
180 
181   socket_factory.AddSocketDataProvider(&socket_data);
182   std::unique_ptr<DatagramClientSocket> socket =
183       socket_factory.CreateDatagramClientSocket(
184           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
185           net_log_with_source_.source());
186   socket->Connect(IPEndPoint());
187 
188   scoped_refptr<GrowableIOBuffer> read_buffer =
189       base::MakeRefCounted<GrowableIOBuffer>();
190   read_buffer->SetCapacity(100);
191   EXPECT_EQ(static_cast<int>(TestPacket(1)->length()),
192             socket->Read(read_buffer.get(), 100, base::DoNothing()));
193   DscpAndEcn dscp_and_ecn = socket->GetLastTos();
194   EXPECT_EQ(dscp_and_ecn.ecn, ECN_ECT0);
195 
196   socket_data.RunUntilAllConsumed();
197   EXPECT_TRUE(socket_data.AllReadDataConsumed());
198   EXPECT_TRUE(socket_data.AllWriteDataConsumed());
199 }
200 
201 // A write of data different from the expectation generates a failure.
TEST_F(QuicSocketDataProviderTest,MismatchedWrite)202 TEST_F(QuicSocketDataProviderTest, MismatchedWrite) {
203   QuicSocketDataProvider socket_data(version_);
204   MockClientSocketFactory socket_factory;
205 
206   socket_data.AddWrite("p1", TestPacket(1)).Sync();
207 
208   socket_factory.AddSocketDataProvider(&socket_data);
209   std::unique_ptr<DatagramClientSocket> socket =
210       socket_factory.CreateDatagramClientSocket(
211           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
212           net_log_with_source_.source());
213   socket->Connect(IPEndPoint());
214 
215   std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(999);
216   scoped_refptr<StringIOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
217       std::string(packet->data(), packet->length()));
218   EXPECT_NONFATAL_FAILURE(
219       EXPECT_EQ(ERR_UNEXPECTED,
220                 socket->Write(buffer.get(), packet->length(), base::DoNothing(),
221                               TRAFFIC_ANNOTATION_FOR_TESTS)),
222       "Expectation 'p1' not met.");
223 }
224 
225 // AllDataConsumed is false if there are still pending expectations.
TEST_F(QuicSocketDataProviderTest,NotAllConsumed)226 TEST_F(QuicSocketDataProviderTest, NotAllConsumed) {
227   QuicSocketDataProvider socket_data(version_);
228   MockClientSocketFactory socket_factory;
229 
230   socket_data.AddWrite("p1", TestPacket(1)).Sync();
231   socket_data.AddWrite("p2", TestPacket(2)).Sync();
232 
233   socket_factory.AddSocketDataProvider(&socket_data);
234   std::unique_ptr<DatagramClientSocket> socket =
235       socket_factory.CreateDatagramClientSocket(
236           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
237           net_log_with_source_.source());
238   socket->Connect(IPEndPoint());
239 
240   std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(1);
241   scoped_refptr<StringIOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
242       std::string(packet->data(), packet->length()));
243   EXPECT_EQ(static_cast<int>(packet->length()),
244             socket->Write(buffer.get(), packet->length(), base::DoNothing(),
245                           TRAFFIC_ANNOTATION_FOR_TESTS));
246 
247   EXPECT_FALSE(socket_data.AllDataConsumed());
248 }
249 
250 // When a Write call occurs with no matching expectation, that is treated as an
251 // error.
TEST_F(QuicSocketDataProviderTest,ReadBlocksWrite)252 TEST_F(QuicSocketDataProviderTest, ReadBlocksWrite) {
253   QuicSocketDataProvider socket_data(version_);
254   MockClientSocketFactory socket_factory;
255 
256   socket_data.AddRead("p1", TestPacket(1)).Sync();
257   socket_data.AddWrite("p2", TestPacket(2)).Sync();
258 
259   socket_factory.AddSocketDataProvider(&socket_data);
260   std::unique_ptr<DatagramClientSocket> socket =
261       socket_factory.CreateDatagramClientSocket(
262           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
263           net_log_with_source_.source());
264   socket->Connect(IPEndPoint());
265 
266   std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(1);
267   scoped_refptr<StringIOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
268       std::string(packet->data(), packet->length()));
269   EXPECT_NONFATAL_FAILURE(
270       EXPECT_EQ(ERR_UNEXPECTED,
271                 socket->Write(buffer.get(), packet->length(), base::DoNothing(),
272                               TRAFFIC_ANNOTATION_FOR_TESTS)),
273       "Write call when none is expected:");
274 }
275 
276 // When a Read call occurs with no matching expectation, it waits for a matching
277 // expectation to become read.
TEST_F(QuicSocketDataProviderTest,WriteDelaysRead)278 TEST_F(QuicSocketDataProviderTest, WriteDelaysRead) {
279   QuicSocketDataProvider socket_data(version_);
280   MockClientSocketFactory socket_factory;
281 
282   socket_data.AddWrite("p1", TestPacket(1)).Sync();
283   socket_data.AddRead("p2", TestPacket(22222)).Sync();
284 
285   socket_factory.AddSocketDataProvider(&socket_data);
286   std::unique_ptr<DatagramClientSocket> socket =
287       socket_factory.CreateDatagramClientSocket(
288           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
289           net_log_with_source_.source());
290   socket->Connect(IPEndPoint());
291 
292   // Begin a read operation which should not complete yet.
293   bool read_completed = false;
294   base::OnceCallback<void(int)> read_callback =
295       base::BindLambdaForTesting([&](int result) {
296         EXPECT_EQ(result, static_cast<int>(TestPacket(22222)->length()));
297         read_completed = true;
298       });
299   scoped_refptr<GrowableIOBuffer> read_buffer =
300       base::MakeRefCounted<GrowableIOBuffer>();
301   read_buffer->SetCapacity(100);
302   EXPECT_EQ(ERR_IO_PENDING,
303             socket->Read(read_buffer.get(), 100, std::move(read_callback)));
304 
305   EXPECT_FALSE(read_completed);
306 
307   // Perform the write on which the read depends.
308   std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(1);
309   scoped_refptr<StringIOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
310       std::string(packet->data(), packet->length()));
311   EXPECT_EQ(static_cast<int>(packet->length()),
312             socket->Write(buffer.get(), packet->length(), base::DoNothing(),
313                           TRAFFIC_ANNOTATION_FOR_TESTS));
314 
315   socket_data.RunUntilAllConsumed();
316   EXPECT_TRUE(read_completed);
317 }
318 
319 // When a pause becomes ready, subsequent calls are delayed.
TEST_F(QuicSocketDataProviderTest,PauseDelaysCalls)320 TEST_F(QuicSocketDataProviderTest, PauseDelaysCalls) {
321   QuicSocketDataProvider socket_data(version_);
322   MockClientSocketFactory socket_factory;
323 
324   socket_data.AddWrite("p1", TestPacket(1)).Sync();
325   auto pause = socket_data.AddPause("pause");
326   socket_data.AddRead("p2", TestPacket(2)).After("pause");
327   socket_data.AddWrite("p3", TestPacket(3)).After("pause");
328 
329   socket_factory.AddSocketDataProvider(&socket_data);
330   std::unique_ptr<DatagramClientSocket> socket =
331       socket_factory.CreateDatagramClientSocket(
332           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
333           net_log_with_source_.source());
334   socket->Connect(IPEndPoint());
335 
336   // Perform a write in another task, and wait for the pause.
337   bool write_completed = false;
338   base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
339       FROM_HERE, base::BindLambdaForTesting([&]() {
340         std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(1);
341         scoped_refptr<StringIOBuffer> buffer =
342             base::MakeRefCounted<StringIOBuffer>(
343                 std::string(packet->data(), packet->length()));
344         EXPECT_EQ(
345             static_cast<int>(packet->length()),
346             socket->Write(buffer.get(), packet->length(), base::DoNothing(),
347                           TRAFFIC_ANNOTATION_FOR_TESTS));
348         write_completed = true;
349       }));
350 
351   EXPECT_FALSE(write_completed);
352   socket_data.RunUntilPause(pause);
353   EXPECT_TRUE(write_completed);
354 
355   // Begin a read operation which should not complete yet.
356   bool read_completed = false;
357   base::OnceCallback<void(int)> read_callback =
358       base::BindLambdaForTesting([&](int result) {
359         EXPECT_EQ(result, static_cast<int>(TestPacket(2)->length()));
360         read_completed = true;
361       });
362   scoped_refptr<GrowableIOBuffer> read_buffer =
363       base::MakeRefCounted<GrowableIOBuffer>();
364   read_buffer->SetCapacity(100);
365   EXPECT_EQ(ERR_IO_PENDING,
366             socket->Read(read_buffer.get(), 100, std::move(read_callback)));
367 
368   // Begin a write operation which should not complete yet.
369   write_completed = false;
370   base::OnceCallback<void(int)> write_callback =
371       base::BindLambdaForTesting([&](int result) {
372         EXPECT_EQ(result, static_cast<int>(TestPacket(3)->length()));
373         write_completed = true;
374       });
375   std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(3);
376   scoped_refptr<StringIOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
377       std::string(packet->data(), packet->length()));
378   EXPECT_EQ(ERR_IO_PENDING, socket->Write(buffer.get(), packet->length(),
379                                           std::move(write_callback),
380                                           TRAFFIC_ANNOTATION_FOR_TESTS));
381 
382   EXPECT_FALSE(read_completed);
383   EXPECT_FALSE(write_completed);
384 
385   socket_data.Resume();
386   socket_data.RunUntilAllConsumed();
387   RunUntilIdle();
388 
389   EXPECT_TRUE(read_completed);
390   EXPECT_TRUE(write_completed);
391 }
392 
393 // Using `After`, a `Read` and `Write` can be allowed in either order.
TEST_F(QuicSocketDataProviderTest,ParallelReadAndWrite)394 TEST_F(QuicSocketDataProviderTest, ParallelReadAndWrite) {
395   for (bool read_first : {false, true}) {
396     SCOPED_TRACE(::testing::Message() << "read_first: " << read_first);
397     QuicSocketDataProvider socket_data(version_);
398     MockClientSocketFactory socket_factory;
399 
400     socket_data.AddWrite("p1", TestPacket(1)).Sync();
401     socket_data.AddRead("p2", TestPacket(2)).Sync().After("p1");
402     socket_data.AddWrite("p3", TestPacket(3)).Sync().After("p1");
403 
404     socket_factory.AddSocketDataProvider(&socket_data);
405     std::unique_ptr<DatagramClientSocket> socket =
406         socket_factory.CreateDatagramClientSocket(
407             DatagramSocket::BindType::DEFAULT_BIND, nullptr,
408             net_log_with_source_.source());
409     socket->Connect(IPEndPoint());
410 
411     // Write p1 to get things started.
412     std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(1);
413     scoped_refptr<IOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
414         std::string(packet->data(), packet->length()));
415     EXPECT_EQ(static_cast<int>(packet->length()),
416               socket->Write(buffer.get(), packet->length(), base::DoNothing(),
417                             TRAFFIC_ANNOTATION_FOR_TESTS));
418 
419     scoped_refptr<GrowableIOBuffer> read_buffer =
420         base::MakeRefCounted<GrowableIOBuffer>();
421     read_buffer->SetCapacity(100);
422     auto do_read = [&]() {
423       EXPECT_EQ(static_cast<int>(TestPacket(2)->length()),
424                 socket->Read(read_buffer.get(), 100, base::DoNothing()));
425     };
426 
427     std::unique_ptr<quic::QuicReceivedPacket> write_packet = TestPacket(3);
428     buffer = base::MakeRefCounted<StringIOBuffer>(
429         std::string(write_packet->data(), write_packet->length()));
430 
431     auto do_write = [&]() {
432       EXPECT_EQ(static_cast<int>(write_packet->length()),
433                 socket->Write(buffer.get(), write_packet->length(),
434                               base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS));
435     };
436 
437     // Read p2 and write p3 in both orders.
438     if (read_first) {
439       do_read();
440       do_write();
441     } else {
442       do_write();
443       do_read();
444     }
445 
446     socket_data.RunUntilAllConsumed();
447   }
448 }
449 
450 // When multiple Read expectations become ready at the same time, fail with a
451 // CHECK error.
TEST_F(QuicSocketDataProviderTest,MultipleReadsReady)452 TEST_F(QuicSocketDataProviderTest, MultipleReadsReady) {
453   QuicSocketDataProvider socket_data(version_);
454   MockClientSocketFactory socket_factory;
455 
456   socket_data.AddWrite("p1", TestPacket(1)).Sync();
457   socket_data.AddRead("p2", TestPacket(2)).After("p1");
458   socket_data.AddRead("p3", TestPacket(3)).After("p1");
459 
460   socket_factory.AddSocketDataProvider(&socket_data);
461   std::unique_ptr<DatagramClientSocket> socket =
462       socket_factory.CreateDatagramClientSocket(
463           DatagramSocket::BindType::DEFAULT_BIND, nullptr,
464           net_log_with_source_.source());
465   socket->Connect(IPEndPoint());
466 
467   std::unique_ptr<quic::QuicReceivedPacket> packet = TestPacket(1);
468   scoped_refptr<StringIOBuffer> buffer = base::MakeRefCounted<StringIOBuffer>(
469       std::string(packet->data(), packet->length()));
470   EXPECT_EQ(static_cast<int>(packet->length()),
471             socket->Write(buffer.get(), packet->length(), base::DoNothing(),
472                           TRAFFIC_ANNOTATION_FOR_TESTS));
473   EXPECT_CHECK_DEATH(
474       socket->Read(buffer.get(), buffer->size(), base::DoNothing()));
475 }
476 
477 }  // namespace net::test
478