1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/quic/quic_socket_data_provider.h"
6
7 #include <memory>
8
9 #include "base/strings/string_number_conversions.h"
10 #include "base/task/sequenced_task_runner.h"
11 #include "base/test/bind.h"
12 #include "base/test/gtest_util.h"
13 #include "net/base/io_buffer.h"
14 #include "net/quic/mock_quic_context.h"
15 #include "net/quic/quic_test_packet_maker.h"
16 #include "net/socket/datagram_client_socket.h"
17 #include "net/socket/diff_serv_code_point.h"
18 #include "net/socket/socket_test_util.h"
19 #include "net/test/test_with_task_environment.h"
20 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
21 #include "testing/gtest/include/gtest/gtest-spi.h"
22 #include "testing/gtest/include/gtest/gtest.h"
23
24 namespace net::test {
25
26 class QuicSocketDataProviderTest : public TestWithTaskEnvironment {
27 public:
QuicSocketDataProviderTest()28 QuicSocketDataProviderTest()
29 : packet_maker_(std::make_unique<QuicTestPacketMaker>(
30 version_,
31 quic::QuicUtils::CreateRandomConnectionId(
32 context_.random_generator()),
33 context_.clock(),
34 "hostname",
35 quic::Perspective::IS_CLIENT,
36 /*client_priority_uses_incremental=*/true,
37 /*use_priority_header=*/true)) {}
38
39 // Create a simple test packet.
TestPacket(uint64_t packet_number)40 std::unique_ptr<quic::QuicReceivedPacket> TestPacket(uint64_t packet_number) {
41 return packet_maker_->Packet(packet_number)
42 .AddMessageFrame(base::NumberToString(packet_number))
43 .Build();
44 }
45
46 protected:
47 NetLogWithSource net_log_with_source_{
48 NetLogWithSource::Make(NetLogSourceType::NONE)};
49 quic::ParsedQuicVersion version_ = quic::ParsedQuicVersion::RFCv1();
50 MockQuicContext context_;
51 std::unique_ptr<QuicTestPacketMaker> packet_maker_;
52 };
53
54 // A linear sequence of sync expectations completes.
TEST_F(QuicSocketDataProviderTest, LinearSequenceSync) {
  constexpr uint64_t kLastPacketNumber = 3;

  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // Three writes are expected, each one completing synchronously.
  socket_data.AddWrite("p1", TestPacket(1)).Sync();
  socket_data.AddWrite("p2", TestPacket(2)).Sync();
  socket_data.AddWrite("p3", TestPacket(3)).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  // The socket is created and written to from a posted task; that task runs
  // while `RunUntilAllConsumed()` below spins the message loop.
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindLambdaForTesting([&]() {
        auto socket = socket_factory.CreateDatagramClientSocket(
            DatagramSocket::BindType::DEFAULT_BIND, nullptr,
            net_log_with_source_.source());
        socket->Connect(IPEndPoint());

        for (uint64_t number = 1; number <= kLastPacketNumber; ++number) {
          const auto expected = TestPacket(number);
          auto payload = base::MakeRefCounted<StringIOBuffer>(
              std::string(expected->data(), expected->length()));
          const int bytes_written =
              socket->Write(payload.get(), expected->length(),
                            base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS);
          EXPECT_EQ(static_cast<int>(expected->length()), bytes_written);
        }
      }));

  socket_data.RunUntilAllConsumed();
}
87
88 // A linear sequence of async expectations completes.
TEST_F(QuicSocketDataProviderTest, LinearSequenceAsync) {
  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // None of the writes is marked `Sync()`, so each is expected to return
  // ERR_IO_PENDING and complete through its callback.
  socket_data.AddWrite("p1", TestPacket(1));
  socket_data.AddWrite("p2", TestPacket(2));
  socket_data.AddWrite("p3", TestPacket(3));
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  // The callback chains the writes: every completion starts the next packet
  // until all three have been issued.
  int upcoming_packet = 1;
  base::RepeatingCallback<void(int)> write_next;
  write_next = base::BindLambdaForTesting([&](int result) {
    // `result` is the byte count of the previous write, or the seed value of
    // one passed by the initial manual invocation below.
    EXPECT_GT(result, 0);
    if (upcoming_packet > 3) {
      return;
    }
    const auto outgoing = TestPacket(upcoming_packet++);
    auto payload = base::MakeRefCounted<StringIOBuffer>(
        std::string(outgoing->data(), outgoing->length()));
    const int rv = socket->Write(payload.get(), outgoing->length(), write_next,
                                 TRAFFIC_ANNOTATION_FOR_TESTS);
    EXPECT_EQ(ERR_IO_PENDING, rv);
  });

  write_next.Run(1);
  socket_data.RunUntilAllConsumed();
}
122
123 // The `TosByte` builder method results in a correct TOS byte in the read.
TEST_F(QuicSocketDataProviderTest, ReadTos) {
  constexpr int kReadBufferSize = 100;
  // DSCP goes in the upper six bits of the TOS byte, ECN in the lower two.
  const uint8_t kTestTos = (DSCP_CS1 << 2) + ECN_CE;

  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  socket_data.AddRead("p1", TestPacket(1)).Sync().TosByte(kTestTos);
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  auto incoming = base::MakeRefCounted<GrowableIOBuffer>();
  incoming->SetCapacity(kReadBufferSize);

  // The synchronous read reports the full packet length.
  const int expected_length = static_cast<int>(TestPacket(1)->length());
  const int bytes_read =
      socket->Read(incoming.get(), kReadBufferSize, base::DoNothing());
  EXPECT_EQ(expected_length, bytes_read);

  // The socket exposes the TOS byte configured on the expectation, split into
  // its DSCP and ECN parts.
  const DscpAndEcn last_tos = socket->GetLastTos();
  EXPECT_EQ(last_tos.dscp, DSCP_CS1);
  EXPECT_EQ(last_tos.ecn, ECN_CE);

  socket_data.RunUntilAllConsumed();
}
149
150 // AddReadError creates a read returning an error.
TEST_F(QuicSocketDataProviderTest, AddReadError) {
  constexpr int kReadBufferSize = 100;

  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // A single synchronous read expectation that yields an error, not data.
  socket_data.AddReadError("p1", ERR_CONNECTION_ABORTED).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  auto incoming = base::MakeRefCounted<GrowableIOBuffer>();
  incoming->SetCapacity(kReadBufferSize);

  const int read_result =
      socket->Read(incoming.get(), kReadBufferSize, base::DoNothing());
  EXPECT_EQ(ERR_CONNECTION_ABORTED, read_result);

  socket_data.RunUntilAllConsumed();
}
172
173 // AddRead with a QuicReceivedPacket correctly sets the ECN.
TEST_F(QuicSocketDataProviderTest, AddReadQuicReceivedPacketGetsEcn) {
  constexpr int kReadBufferSize = 100;

  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // Set the ECN codepoint on the packet maker before building the packet that
  // backs the read expectation.
  packet_maker_->set_ecn_codepoint(quic::QuicEcnCodepoint::ECN_ECT0);
  socket_data.AddRead("p1", TestPacket(1)).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  auto incoming = base::MakeRefCounted<GrowableIOBuffer>();
  incoming->SetCapacity(kReadBufferSize);

  const int expected_length = static_cast<int>(TestPacket(1)->length());
  const int bytes_read =
      socket->Read(incoming.get(), kReadBufferSize, base::DoNothing());
  EXPECT_EQ(expected_length, bytes_read);

  // The ECN bits reported by the socket match the codepoint set above.
  const DscpAndEcn last_tos = socket->GetLastTos();
  EXPECT_EQ(last_tos.ecn, ECN_ECT0);

  socket_data.RunUntilAllConsumed();
  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
200
201 // A write of data different from the expectation generates a failure.
TEST_F(QuicSocketDataProviderTest, MismatchedWrite) {
  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  socket_data.AddWrite("p1", TestPacket(1)).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  // Packet 999 does not match the expected packet 1.
  const auto wrong_packet = TestPacket(999);
  auto payload = base::MakeRefCounted<StringIOBuffer>(
      std::string(wrong_packet->data(), wrong_packet->length()));

  // The write must return ERR_UNEXPECTED and raise a non-fatal test failure
  // that names the unmet expectation.
  EXPECT_NONFATAL_FAILURE(
      EXPECT_EQ(ERR_UNEXPECTED,
                socket->Write(payload.get(), wrong_packet->length(),
                              base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS)),
      "Expectation 'p1' not met.");
}
224
225 // AllDataConsumed is false if there are still pending expectations.
TEST_F(QuicSocketDataProviderTest, NotAllConsumed) {
  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // Two writes are expected, but only the first is performed below.
  socket_data.AddWrite("p1", TestPacket(1)).Sync();
  socket_data.AddWrite("p2", TestPacket(2)).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  const auto first_packet = TestPacket(1);
  auto payload = base::MakeRefCounted<StringIOBuffer>(
      std::string(first_packet->data(), first_packet->length()));
  const int bytes_written =
      socket->Write(payload.get(), first_packet->length(), base::DoNothing(),
                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_EQ(static_cast<int>(first_packet->length()), bytes_written);

  // "p2" is still outstanding, so the provider must not report completion.
  EXPECT_FALSE(socket_data.AllDataConsumed());
}
249
250 // When a Write call occurs with no matching expectation, that is treated as an
251 // error.
TEST_F(QuicSocketDataProviderTest, ReadBlocksWrite) {
  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // The read "p1" is queued ahead of the write "p2", and no read is performed
  // before the write attempt below.
  socket_data.AddRead("p1", TestPacket(1)).Sync();
  socket_data.AddWrite("p2", TestPacket(2)).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  const auto outgoing = TestPacket(1);
  auto payload = base::MakeRefCounted<StringIOBuffer>(
      std::string(outgoing->data(), outgoing->length()));

  // The write must return ERR_UNEXPECTED and raise a non-fatal test failure.
  EXPECT_NONFATAL_FAILURE(
      EXPECT_EQ(ERR_UNEXPECTED,
                socket->Write(payload.get(), outgoing->length(),
                              base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS)),
      "Write call when none is expected:");
}
275
// When a Read call occurs with no matching expectation, it waits for a matching
// expectation to become ready.
TEST_F(QuicSocketDataProviderTest, WriteDelaysRead) {
  constexpr int kReadBufferSize = 100;

  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // The read "p2" is queued behind the write "p1".
  socket_data.AddWrite("p1", TestPacket(1)).Sync();
  socket_data.AddRead("p2", TestPacket(22222)).Sync();
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  // Issue the read first. It must stay pending until the write has happened.
  bool read_done = false;
  base::OnceCallback<void(int)> on_read =
      base::BindLambdaForTesting([&](int result) {
        EXPECT_EQ(result, static_cast<int>(TestPacket(22222)->length()));
        read_done = true;
      });
  auto incoming = base::MakeRefCounted<GrowableIOBuffer>();
  incoming->SetCapacity(kReadBufferSize);
  const int read_rv =
      socket->Read(incoming.get(), kReadBufferSize, std::move(on_read));
  EXPECT_EQ(ERR_IO_PENDING, read_rv);
  EXPECT_FALSE(read_done);

  // Perform the write that the pending read is waiting on.
  const auto outgoing = TestPacket(1);
  auto payload = base::MakeRefCounted<StringIOBuffer>(
      std::string(outgoing->data(), outgoing->length()));
  const int write_rv =
      socket->Write(payload.get(), outgoing->length(), base::DoNothing(),
                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_EQ(static_cast<int>(outgoing->length()), write_rv);

  socket_data.RunUntilAllConsumed();
  EXPECT_TRUE(read_done);
}
318
319 // When a pause becomes ready, subsequent calls are delayed.
TEST_F(QuicSocketDataProviderTest, PauseDelaysCalls) {
  constexpr int kReadBufferSize = 100;

  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // "p1" comes before the pause; "p2" and "p3" are both ordered after it.
  socket_data.AddWrite("p1", TestPacket(1)).Sync();
  auto pause = socket_data.AddPause("pause");
  socket_data.AddRead("p2", TestPacket(2)).After("pause");
  socket_data.AddWrite("p3", TestPacket(3)).After("pause");
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  // Phase 1: write "p1" from a posted task, then run until the pause.
  bool first_write_done = false;
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindLambdaForTesting([&]() {
        const auto first_packet = TestPacket(1);
        auto first_payload = base::MakeRefCounted<StringIOBuffer>(
            std::string(first_packet->data(), first_packet->length()));
        const int rv = socket->Write(first_payload.get(),
                                     first_packet->length(), base::DoNothing(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
        EXPECT_EQ(static_cast<int>(first_packet->length()), rv);
        first_write_done = true;
      }));

  // The task has only been posted so far, not run.
  EXPECT_FALSE(first_write_done);
  socket_data.RunUntilPause(pause);
  EXPECT_TRUE(first_write_done);

  // Phase 2: while paused, start a read and a write. Both must stay pending.
  bool read_done = false;
  base::OnceCallback<void(int)> on_read =
      base::BindLambdaForTesting([&](int result) {
        EXPECT_EQ(result, static_cast<int>(TestPacket(2)->length()));
        read_done = true;
      });
  auto incoming = base::MakeRefCounted<GrowableIOBuffer>();
  incoming->SetCapacity(kReadBufferSize);
  const int read_rv =
      socket->Read(incoming.get(), kReadBufferSize, std::move(on_read));
  EXPECT_EQ(ERR_IO_PENDING, read_rv);

  bool second_write_done = false;
  base::OnceCallback<void(int)> on_write =
      base::BindLambdaForTesting([&](int result) {
        EXPECT_EQ(result, static_cast<int>(TestPacket(3)->length()));
        second_write_done = true;
      });
  const auto third_packet = TestPacket(3);
  auto third_payload = base::MakeRefCounted<StringIOBuffer>(
      std::string(third_packet->data(), third_packet->length()));
  const int write_rv =
      socket->Write(third_payload.get(), third_packet->length(),
                    std::move(on_write), TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_EQ(ERR_IO_PENDING, write_rv);

  EXPECT_FALSE(read_done);
  EXPECT_FALSE(second_write_done);

  // Phase 3: after resuming, both delayed operations complete.
  socket_data.Resume();
  socket_data.RunUntilAllConsumed();
  RunUntilIdle();

  EXPECT_TRUE(read_done);
  EXPECT_TRUE(second_write_done);
}
392
393 // Using `After`, a `Read` and `Write` can be allowed in either order.
TEST_F(QuicSocketDataProviderTest, ParallelReadAndWrite) {
  constexpr int kReadBufferSize = 100;

  // Run the scenario twice: once writing "p3" first, once reading "p2" first.
  for (bool read_first : {false, true}) {
    SCOPED_TRACE(::testing::Message() << "read_first: " << read_first);
    QuicSocketDataProvider socket_data(version_);
    MockClientSocketFactory socket_factory;

    // "p2" and "p3" both depend only on "p1", so neither is ordered with
    // respect to the other.
    socket_data.AddWrite("p1", TestPacket(1)).Sync();
    socket_data.AddRead("p2", TestPacket(2)).Sync().After("p1");
    socket_data.AddWrite("p3", TestPacket(3)).Sync().After("p1");
    socket_factory.AddSocketDataProvider(&socket_data);

    auto socket = socket_factory.CreateDatagramClientSocket(
        DatagramSocket::BindType::DEFAULT_BIND, nullptr,
        net_log_with_source_.source());
    socket->Connect(IPEndPoint());

    // Writing "p1" satisfies the dependency of the two later expectations.
    const auto first_packet = TestPacket(1);
    auto first_payload = base::MakeRefCounted<StringIOBuffer>(
        std::string(first_packet->data(), first_packet->length()));
    const int first_rv =
        socket->Write(first_payload.get(), first_packet->length(),
                      base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS);
    EXPECT_EQ(static_cast<int>(first_packet->length()), first_rv);

    auto incoming = base::MakeRefCounted<GrowableIOBuffer>();
    incoming->SetCapacity(kReadBufferSize);
    const auto read_p2 = [&]() {
      EXPECT_EQ(
          static_cast<int>(TestPacket(2)->length()),
          socket->Read(incoming.get(), kReadBufferSize, base::DoNothing()));
    };

    const auto third_packet = TestPacket(3);
    auto third_payload = base::MakeRefCounted<StringIOBuffer>(
        std::string(third_packet->data(), third_packet->length()));
    const auto write_p3 = [&]() {
      EXPECT_EQ(static_cast<int>(third_packet->length()),
                socket->Write(third_payload.get(), third_packet->length(),
                              base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS));
    };

    // Exactly one of the two `write_p3()` calls runs, on either side of the
    // read, depending on `read_first`.
    if (!read_first) {
      write_p3();
    }
    read_p2();
    if (read_first) {
      write_p3();
    }

    socket_data.RunUntilAllConsumed();
  }
}
449
450 // When multiple Read expectations become ready at the same time, fail with a
451 // CHECK error.
TEST_F(QuicSocketDataProviderTest, MultipleReadsReady) {
  QuicSocketDataProvider socket_data(version_);
  MockClientSocketFactory socket_factory;

  // Both reads depend only on "p1", so both become ready once it is written.
  socket_data.AddWrite("p1", TestPacket(1)).Sync();
  socket_data.AddRead("p2", TestPacket(2)).After("p1");
  socket_data.AddRead("p3", TestPacket(3)).After("p1");
  socket_factory.AddSocketDataProvider(&socket_data);

  auto socket = socket_factory.CreateDatagramClientSocket(
      DatagramSocket::BindType::DEFAULT_BIND, nullptr,
      net_log_with_source_.source());
  socket->Connect(IPEndPoint());

  const auto first_packet = TestPacket(1);
  auto payload = base::MakeRefCounted<StringIOBuffer>(
      std::string(first_packet->data(), first_packet->length()));
  const int bytes_written =
      socket->Write(payload.get(), first_packet->length(), base::DoNothing(),
                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_EQ(static_cast<int>(first_packet->length()), bytes_written);

  // The read is expected to hit a CHECK, so the write buffer is simply reused
  // as the read destination.
  EXPECT_CHECK_DEATH(
      socket->Read(payload.get(), payload->size(), base::DoNothing()));
}
476
477 } // namespace net::test
478