1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/quic/quic_socket_data_provider.h"
6
7 #include <algorithm>
8 #include <map>
9 #include <memory>
10 #include <optional>
11 #include <set>
12 #include <sstream>
13 #include <string>
14
15 #include "base/functional/callback.h"
16 #include "base/run_loop.h"
17 #include "base/strings/string_number_conversions.h"
18 #include "base/task/sequenced_task_runner.h"
19 #include "net/base/hex_utils.h"
20 #include "net/socket/socket_test_util.h"
21 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
22 #include "testing/gtest/include/gtest/gtest.h"
23
24 namespace net::test {
25
Expectation(std::string name,Type type,int rv,std::unique_ptr<quic::QuicEncryptedPacket> packet)26 QuicSocketDataProvider::Expectation::Expectation(
27 std::string name,
28 Type type,
29 int rv,
30 std::unique_ptr<quic::QuicEncryptedPacket> packet)
31 : name_(std::move(name)),
32 type_(type),
33 rv_(rv),
34 packet_(std::move(packet)) {}
35
36 QuicSocketDataProvider::Expectation::Expectation(
37 QuicSocketDataProvider::Expectation&&) = default;
38
39 QuicSocketDataProvider::Expectation::~Expectation() = default;
40
After(std::string name)41 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::Expectation::After(
42 std::string name) {
43 after_.insert(std::move(name));
44 return *this;
45 }
46
TypeToString(QuicSocketDataProvider::Expectation::Type type)47 std::string QuicSocketDataProvider::Expectation::TypeToString(
48 QuicSocketDataProvider::Expectation::Type type) {
49 switch (type) {
50 case Expectation::Type::READ:
51 return "READ";
52 case Expectation::Type::WRITE:
53 return "WRITE";
54 case Expectation::Type::PAUSE:
55 return "PAUSE";
56 }
57 NOTREACHED();
58 }
59
Consume()60 void QuicSocketDataProvider::Expectation::Consume() {
61 CHECK(!consumed_);
62 VLOG(1) << "Consuming " << TypeToString(type_) << " expectation " << name_;
63 consumed_ = true;
64 }
65
QuicSocketDataProvider(quic::ParsedQuicVersion version)66 QuicSocketDataProvider::QuicSocketDataProvider(quic::ParsedQuicVersion version)
67 : printer_(version) {}
68
69 QuicSocketDataProvider::~QuicSocketDataProvider() = default;
70
AddRead(std::string name,std::unique_ptr<quic::QuicEncryptedPacket> packet)71 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddRead(
72 std::string name,
73 std::unique_ptr<quic::QuicEncryptedPacket> packet) {
74 expectations_.push_back(Expectation(std::move(name), Expectation::Type::READ,
75 OK, std::move(packet)));
76 return expectations_.back();
77 }
78
AddRead(std::string name,std::unique_ptr<quic::QuicReceivedPacket> packet)79 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddRead(
80 std::string name,
81 std::unique_ptr<quic::QuicReceivedPacket> packet) {
82 uint8_t tos_byte = static_cast<uint8_t>(packet->ecn_codepoint());
83 return AddRead(std::move(name),
84 static_cast<std::unique_ptr<quic::QuicEncryptedPacket>>(
85 std::move(packet)))
86 .TosByte(tos_byte);
87 }
88
AddReadError(std::string name,int rv)89 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddReadError(
90 std::string name,
91 int rv) {
92 CHECK_NE(rv, OK);
93 CHECK_NE(rv, ERR_IO_PENDING);
94 expectations_.push_back(
95 Expectation(std::move(name), Expectation::Type::READ, rv, nullptr));
96 return expectations_.back();
97 }
98
AddWrite(std::string name,std::unique_ptr<quic::QuicEncryptedPacket> packet,int rv)99 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddWrite(
100 std::string name,
101 std::unique_ptr<quic::QuicEncryptedPacket> packet,
102 int rv) {
103 expectations_.push_back(Expectation(std::move(name), Expectation::Type::WRITE,
104 rv, std::move(packet)));
105 return expectations_.back();
106 }
107
AddWriteError(std::string name,int rv)108 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddWriteError(
109 std::string name,
110 int rv) {
111 CHECK_NE(rv, OK);
112 CHECK_NE(rv, ERR_IO_PENDING);
113 expectations_.push_back(
114 Expectation(std::move(name), Expectation::Type::WRITE, rv, nullptr));
115 return expectations_.back();
116 }
117
AddPause(std::string name)118 QuicSocketDataProvider::PausePoint QuicSocketDataProvider::AddPause(
119 std::string name) {
120 expectations_.push_back(
121 Expectation(std::move(name), Expectation::Type::PAUSE, OK, nullptr));
122 return expectations_.size() - 1;
123 }
124
AllDataConsumed() const125 bool QuicSocketDataProvider::AllDataConsumed() const {
126 return std::all_of(
127 expectations_.begin(), expectations_.end(),
128 [](const Expectation& expectation) { return expectation.consumed(); });
129 }
130
RunUntilPause(QuicSocketDataProvider::PausePoint pause_point)131 void QuicSocketDataProvider::RunUntilPause(
132 QuicSocketDataProvider::PausePoint pause_point) {
133 if (!paused_at_.has_value()) {
134 run_until_run_loop_ = std::make_unique<base::RunLoop>();
135 run_until_run_loop_->Run();
136 run_until_run_loop_.reset();
137 }
138 CHECK(paused_at_.has_value() && *paused_at_ == pause_point)
139 << "Did not pause at '" << expectations_[pause_point].name() << "'.";
140 }
141
Resume()142 void QuicSocketDataProvider::Resume() {
143 CHECK(paused_at_.has_value());
144 VLOG(1) << "Resuming from pause point " << expectations_[*paused_at_].name();
145 expectations_[*paused_at_].Consume();
146 paused_at_ = std::nullopt;
147 ExpectationConsumed();
148 }
149
RunUntilAllConsumed()150 void QuicSocketDataProvider::RunUntilAllConsumed() {
151 if (!AllDataConsumed()) {
152 run_until_run_loop_ = std::make_unique<base::RunLoop>();
153 run_until_run_loop_->Run();
154 run_until_run_loop_.reset();
155 }
156
157 // If that run timed out, then there will still be un-consumed data.
158 if (!AllDataConsumed()) {
159 std::vector<size_t> unconsumed;
160 for (size_t i = 0; i < expectations_.size(); i++) {
161 if (!expectations_[i].consumed()) {
162 unconsumed.push_back(i);
163 }
164 }
165 FAIL() << "All expectations were not consumed; remaining: "
166 << ExpectationList(unconsumed);
167 }
168 }
169
OnRead()170 MockRead QuicSocketDataProvider::OnRead() {
171 CHECK(!read_pending_);
172 read_pending_ = true;
173 std::optional<MockRead> next_read = ConsumeNextRead();
174 if (!next_read.has_value()) {
175 return MockRead(ASYNC, ERR_IO_PENDING);
176 }
177
178 read_pending_ = false;
179 return *next_read;
180 }
181
OnWrite(const std::string & data)182 MockWriteResult QuicSocketDataProvider::OnWrite(const std::string& data) {
183 CHECK(!write_pending_.has_value());
184 write_pending_ = data;
185 std::optional<MockWriteResult> next_write = ConsumeNextWrite();
186 if (!next_write.has_value()) {
187 // If Write() was called when no corresponding expectation exists, that's an
188 // error unless execution is currently paused, in which case it's just
189 // pending. This rarely occurs because the only other type of expectation
190 // that might be blocking a WRITE is a READ, and QUIC implementations
191 // typically eagerly consume READs.
192 if (paused_at_.has_value()) {
193 return MockWriteResult(ASYNC, ERR_IO_PENDING);
194 } else {
195 ADD_FAILURE() << "Write call when none is expected:\n"
196 << printer_.PrintWrite(data);
197 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
198 }
199 }
200
201 write_pending_ = std::nullopt;
202 return *next_write;
203 }
204
AllReadDataConsumed() const205 bool QuicSocketDataProvider::AllReadDataConsumed() const {
206 return AllDataConsumed();
207 }
208
AllWriteDataConsumed() const209 bool QuicSocketDataProvider::AllWriteDataConsumed() const {
210 return AllDataConsumed();
211 }
212
CancelPendingRead()213 void QuicSocketDataProvider::CancelPendingRead() {
214 read_pending_ = false;
215 }
216
Reset()217 void QuicSocketDataProvider::Reset() {
218 // Note that `Reset` is a parent-class method with a confusing name. It is
219 // used to initialize the socket data provider before it is used.
220
221 // Map names to index, and incidentally check for duplicate names.
222 std::map<std::string, size_t> names;
223 for (size_t i = 0; i < expectations_.size(); i++) {
224 Expectation& expectation = expectations_[i];
225 auto [_, inserted] = names.insert({expectation.name(), i});
226 CHECK(inserted) << "Another expectation named " << expectation.name()
227 << " exists.";
228 }
229
230 // Calculate `dependencies_` mapping indices in `expectations_` to indices of
231 // the expectations they depend on.
232 dependencies_.clear();
233 for (size_t i = 0; i < expectations_.size(); i++) {
234 Expectation& expectation = expectations_[i];
235 if (expectation.after().empty()) {
236 // If no other dependencies are given, make the expectation depend on the
237 // previous expectation.
238 if (i > 0) {
239 dependencies_[i].insert(i - 1);
240 }
241 } else {
242 for (auto& after : expectation.after()) {
243 const auto dep = names.find(after);
244 CHECK(dep != names.end()) << "No expectation named " << after;
245 dependencies_[i].insert(dep->second);
246 }
247 }
248 }
249
250 pending_maybe_consume_expectations_ = false;
251 read_pending_ = false;
252 write_pending_ = std::nullopt;
253 MaybeConsumeExpectations();
254 }
255
FindReadyExpectations(Expectation::Type type)256 std::optional<size_t> QuicSocketDataProvider::FindReadyExpectations(
257 Expectation::Type type) {
258 std::vector<size_t> matches;
259 for (size_t i = 0; i < expectations_.size(); i++) {
260 const Expectation& expectation = expectations_[i];
261 if (expectation.consumed() || expectation.type() != type) {
262 continue;
263 }
264 bool found_unconsumed = false;
265 for (auto dep : dependencies_[i]) {
266 if (!expectations_[dep].consumed_) {
267 found_unconsumed = true;
268 break;
269 }
270 }
271 if (!found_unconsumed) {
272 matches.push_back(i);
273 }
274 }
275
276 if (matches.size() > 1) {
277 std::string exp_type = Expectation::TypeToString(type);
278 std::string names = ExpectationList(matches);
279 CHECK(matches.size() <= 1)
280 << "Multiple expectations of type " << exp_type
281 << " are ready: " << names << ". Use .After() to disambiguate.";
282 }
283
284 return matches.empty() ? std::nullopt : std::make_optional(matches[0]);
285 }
286
ConsumeNextRead()287 std::optional<MockRead> QuicSocketDataProvider::ConsumeNextRead() {
288 CHECK(read_pending_);
289 std::optional<size_t> ready = FindReadyExpectations(Expectation::Type::READ);
290 if (!ready.has_value()) {
291 return std::nullopt;
292 }
293
294 // If there's exactly one matching expectation, return it.
295 Expectation& ready_expectation = expectations_[*ready];
296 MockRead read(ready_expectation.mode(), ready_expectation.rv());
297 if (ready_expectation.packet()) {
298 read.data = ready_expectation.packet()->data();
299 read.data_len = ready_expectation.packet()->length();
300 }
301 read.tos = ready_expectation.tos_byte();
302 ready_expectation.Consume();
303 ExpectationConsumed();
304 return read;
305 }
306
ConsumeNextWrite()307 std::optional<MockWriteResult> QuicSocketDataProvider::ConsumeNextWrite() {
308 CHECK(write_pending_.has_value());
309 std::optional<size_t> ready = FindReadyExpectations(Expectation::Type::WRITE);
310 if (!ready.has_value()) {
311 return std::nullopt;
312 }
313
314 // If there's exactly one matching expectation, check if it matches the write
315 // and return it.
316 Expectation& ready_expectation = expectations_[*ready];
317 if (ready_expectation.packet()) {
318 if (!VerifyWriteData(ready_expectation)) {
319 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
320 }
321 }
322 MockWriteResult write(ready_expectation.mode(),
323 ready_expectation.packet()
324 ? ready_expectation.packet()->length()
325 : ready_expectation.rv());
326 ready_expectation.Consume();
327 ExpectationConsumed();
328 return write;
329 }
330
MaybeConsumeExpectations()331 void QuicSocketDataProvider::MaybeConsumeExpectations() {
332 pending_maybe_consume_expectations_ = false;
333 if (read_pending_) {
334 std::optional<MockRead> next_read = ConsumeNextRead();
335 if (next_read.has_value()) {
336 read_pending_ = false;
337 if (socket()) {
338 socket()->OnReadComplete(*next_read);
339 }
340 }
341 }
342
343 if (write_pending_.has_value()) {
344 std::optional<MockWriteResult> next_write = ConsumeNextWrite();
345 if (next_write.has_value()) {
346 write_pending_ = std::nullopt;
347 if (socket()) {
348 socket()->OnWriteComplete(next_write->result);
349 }
350 }
351 }
352
353 if (!paused_at_) {
354 std::optional<size_t> ready =
355 FindReadyExpectations(Expectation::Type::PAUSE);
356 if (ready.has_value()) {
357 VLOG(1) << "Pausing at " << expectations_[*ready].name();
358 paused_at_ = *ready;
359 if (run_until_run_loop_) {
360 run_until_run_loop_->Quit();
361 }
362 }
363 }
364
365 if (run_until_run_loop_ && AllDataConsumed()) {
366 run_until_run_loop_->Quit();
367 }
368 }
369
ExpectationConsumed()370 void QuicSocketDataProvider::ExpectationConsumed() {
371 if (pending_maybe_consume_expectations_) {
372 return;
373 }
374 pending_maybe_consume_expectations_ = true;
375
376 // Call `MaybeConsumeExpectations` in a task. That method may trigger
377 // consumption of other expectations, and that consumption must happen _after_
378 // the current call to `Read` or `Write` has finished.
379 base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
380 FROM_HERE,
381 base::BindOnce(&QuicSocketDataProvider::MaybeConsumeExpectations,
382 weak_factory_.GetWeakPtr()));
383 }
384
VerifyWriteData(QuicSocketDataProvider::Expectation & expectation)385 bool QuicSocketDataProvider::VerifyWriteData(
386 QuicSocketDataProvider::Expectation& expectation) {
387 std::string expected_data(expectation.packet()->data(),
388 expectation.packet()->length());
389 std::string& actual_data = *write_pending_;
390 bool write_matches = actual_data == expected_data;
391 EXPECT_TRUE(write_matches)
392 << "Expectation '" << expectation.name()
393 << "' not met. Actual formatted write data:\n"
394 << printer_.PrintWrite(actual_data) << "But expectation '"
395 << expectation.name() << "' expected formatted write data:\n"
396 << printer_.PrintWrite(expected_data) << "Actual raw write data:\n"
397 << HexDump(actual_data) << "Expected raw write data:\n"
398 << HexDump(expected_data);
399 return write_matches;
400 }
401
ExpectationList(const std::vector<size_t> & indices)402 std::string QuicSocketDataProvider::ExpectationList(
403 const std::vector<size_t>& indices) {
404 std::ostringstream names;
405 bool first = true;
406 for (auto i : indices) {
407 names << (first ? "" : ", ") << expectations_[i].name();
408 first = false;
409 }
410 return names.str();
411 }
412
413 } // namespace net::test
414