• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/quic/quic_socket_data_provider.h"
6 
7 #include <algorithm>
8 #include <map>
9 #include <memory>
10 #include <optional>
11 #include <set>
12 #include <sstream>
13 #include <string>
14 
15 #include "base/functional/callback.h"
16 #include "base/run_loop.h"
17 #include "base/strings/string_number_conversions.h"
18 #include "base/task/sequenced_task_runner.h"
19 #include "net/base/hex_utils.h"
20 #include "net/socket/socket_test_util.h"
21 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
22 #include "testing/gtest/include/gtest/gtest.h"
23 
24 namespace net::test {
25 
Expectation(std::string name,Type type,int rv,std::unique_ptr<quic::QuicEncryptedPacket> packet)26 QuicSocketDataProvider::Expectation::Expectation(
27     std::string name,
28     Type type,
29     int rv,
30     std::unique_ptr<quic::QuicEncryptedPacket> packet)
31     : name_(std::move(name)),
32       type_(type),
33       rv_(rv),
34       packet_(std::move(packet)) {}
35 
36 QuicSocketDataProvider::Expectation::Expectation(
37     QuicSocketDataProvider::Expectation&&) = default;
38 
39 QuicSocketDataProvider::Expectation::~Expectation() = default;
40 
After(std::string name)41 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::Expectation::After(
42     std::string name) {
43   after_.insert(std::move(name));
44   return *this;
45 }
46 
TypeToString(QuicSocketDataProvider::Expectation::Type type)47 std::string QuicSocketDataProvider::Expectation::TypeToString(
48     QuicSocketDataProvider::Expectation::Type type) {
49   switch (type) {
50     case Expectation::Type::READ:
51       return "READ";
52     case Expectation::Type::WRITE:
53       return "WRITE";
54     case Expectation::Type::PAUSE:
55       return "PAUSE";
56   }
57   NOTREACHED();
58 }
59 
Consume()60 void QuicSocketDataProvider::Expectation::Consume() {
61   CHECK(!consumed_);
62   VLOG(1) << "Consuming " << TypeToString(type_) << " expectation " << name_;
63   consumed_ = true;
64 }
65 
QuicSocketDataProvider(quic::ParsedQuicVersion version)66 QuicSocketDataProvider::QuicSocketDataProvider(quic::ParsedQuicVersion version)
67     : printer_(version) {}
68 
69 QuicSocketDataProvider::~QuicSocketDataProvider() = default;
70 
AddRead(std::string name,std::unique_ptr<quic::QuicEncryptedPacket> packet)71 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddRead(
72     std::string name,
73     std::unique_ptr<quic::QuicEncryptedPacket> packet) {
74   expectations_.push_back(Expectation(std::move(name), Expectation::Type::READ,
75                                       OK, std::move(packet)));
76   return expectations_.back();
77 }
78 
AddRead(std::string name,std::unique_ptr<quic::QuicReceivedPacket> packet)79 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddRead(
80     std::string name,
81     std::unique_ptr<quic::QuicReceivedPacket> packet) {
82   uint8_t tos_byte = static_cast<uint8_t>(packet->ecn_codepoint());
83   return AddRead(std::move(name),
84                  static_cast<std::unique_ptr<quic::QuicEncryptedPacket>>(
85                      std::move(packet)))
86       .TosByte(tos_byte);
87 }
88 
AddReadError(std::string name,int rv)89 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddReadError(
90     std::string name,
91     int rv) {
92   CHECK_NE(rv, OK);
93   CHECK_NE(rv, ERR_IO_PENDING);
94   expectations_.push_back(
95       Expectation(std::move(name), Expectation::Type::READ, rv, nullptr));
96   return expectations_.back();
97 }
98 
AddWrite(std::string name,std::unique_ptr<quic::QuicEncryptedPacket> packet,int rv)99 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddWrite(
100     std::string name,
101     std::unique_ptr<quic::QuicEncryptedPacket> packet,
102     int rv) {
103   expectations_.push_back(Expectation(std::move(name), Expectation::Type::WRITE,
104                                       rv, std::move(packet)));
105   return expectations_.back();
106 }
107 
AddWriteError(std::string name,int rv)108 QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddWriteError(
109     std::string name,
110     int rv) {
111   CHECK_NE(rv, OK);
112   CHECK_NE(rv, ERR_IO_PENDING);
113   expectations_.push_back(
114       Expectation(std::move(name), Expectation::Type::WRITE, rv, nullptr));
115   return expectations_.back();
116 }
117 
AddPause(std::string name)118 QuicSocketDataProvider::PausePoint QuicSocketDataProvider::AddPause(
119     std::string name) {
120   expectations_.push_back(
121       Expectation(std::move(name), Expectation::Type::PAUSE, OK, nullptr));
122   return expectations_.size() - 1;
123 }
124 
AllDataConsumed() const125 bool QuicSocketDataProvider::AllDataConsumed() const {
126   return std::all_of(
127       expectations_.begin(), expectations_.end(),
128       [](const Expectation& expectation) { return expectation.consumed(); });
129 }
130 
RunUntilPause(QuicSocketDataProvider::PausePoint pause_point)131 void QuicSocketDataProvider::RunUntilPause(
132     QuicSocketDataProvider::PausePoint pause_point) {
133   if (!paused_at_.has_value()) {
134     run_until_run_loop_ = std::make_unique<base::RunLoop>();
135     run_until_run_loop_->Run();
136     run_until_run_loop_.reset();
137   }
138   CHECK(paused_at_.has_value() && *paused_at_ == pause_point)
139       << "Did not pause at '" << expectations_[pause_point].name() << "'.";
140 }
141 
Resume()142 void QuicSocketDataProvider::Resume() {
143   CHECK(paused_at_.has_value());
144   VLOG(1) << "Resuming from pause point " << expectations_[*paused_at_].name();
145   expectations_[*paused_at_].Consume();
146   paused_at_ = std::nullopt;
147   ExpectationConsumed();
148 }
149 
RunUntilAllConsumed()150 void QuicSocketDataProvider::RunUntilAllConsumed() {
151   if (!AllDataConsumed()) {
152     run_until_run_loop_ = std::make_unique<base::RunLoop>();
153     run_until_run_loop_->Run();
154     run_until_run_loop_.reset();
155   }
156 
157   // If that run timed out, then there will still be un-consumed data.
158   if (!AllDataConsumed()) {
159     std::vector<size_t> unconsumed;
160     for (size_t i = 0; i < expectations_.size(); i++) {
161       if (!expectations_[i].consumed()) {
162         unconsumed.push_back(i);
163       }
164     }
165     FAIL() << "All expectations were not consumed; remaining: "
166            << ExpectationList(unconsumed);
167   }
168 }
169 
OnRead()170 MockRead QuicSocketDataProvider::OnRead() {
171   CHECK(!read_pending_);
172   read_pending_ = true;
173   std::optional<MockRead> next_read = ConsumeNextRead();
174   if (!next_read.has_value()) {
175     return MockRead(ASYNC, ERR_IO_PENDING);
176   }
177 
178   read_pending_ = false;
179   return *next_read;
180 }
181 
OnWrite(const std::string & data)182 MockWriteResult QuicSocketDataProvider::OnWrite(const std::string& data) {
183   CHECK(!write_pending_.has_value());
184   write_pending_ = data;
185   std::optional<MockWriteResult> next_write = ConsumeNextWrite();
186   if (!next_write.has_value()) {
187     // If Write() was called when no corresponding expectation exists, that's an
188     // error unless execution is currently paused, in which case it's just
189     // pending. This rarely occurs because the only other type of expectation
190     // that might be blocking a WRITE is a READ, and QUIC implementations
191     // typically eagerly consume READs.
192     if (paused_at_.has_value()) {
193       return MockWriteResult(ASYNC, ERR_IO_PENDING);
194     } else {
195       ADD_FAILURE() << "Write call when none is expected:\n"
196                     << printer_.PrintWrite(data);
197       return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
198     }
199   }
200 
201   write_pending_ = std::nullopt;
202   return *next_write;
203 }
204 
AllReadDataConsumed() const205 bool QuicSocketDataProvider::AllReadDataConsumed() const {
206   return AllDataConsumed();
207 }
208 
AllWriteDataConsumed() const209 bool QuicSocketDataProvider::AllWriteDataConsumed() const {
210   return AllDataConsumed();
211 }
212 
CancelPendingRead()213 void QuicSocketDataProvider::CancelPendingRead() {
214   read_pending_ = false;
215 }
216 
Reset()217 void QuicSocketDataProvider::Reset() {
218   // Note that `Reset` is a parent-class method with a confusing name. It is
219   // used to initialize the socket data provider before it is used.
220 
221   // Map names to index, and incidentally check for duplicate names.
222   std::map<std::string, size_t> names;
223   for (size_t i = 0; i < expectations_.size(); i++) {
224     Expectation& expectation = expectations_[i];
225     auto [_, inserted] = names.insert({expectation.name(), i});
226     CHECK(inserted) << "Another expectation named " << expectation.name()
227                     << " exists.";
228   }
229 
230   // Calculate `dependencies_` mapping indices in `expectations_` to indices of
231   // the expectations they depend on.
232   dependencies_.clear();
233   for (size_t i = 0; i < expectations_.size(); i++) {
234     Expectation& expectation = expectations_[i];
235     if (expectation.after().empty()) {
236       // If no other dependencies are given, make the expectation depend on the
237       // previous expectation.
238       if (i > 0) {
239         dependencies_[i].insert(i - 1);
240       }
241     } else {
242       for (auto& after : expectation.after()) {
243         const auto dep = names.find(after);
244         CHECK(dep != names.end()) << "No expectation named " << after;
245         dependencies_[i].insert(dep->second);
246       }
247     }
248   }
249 
250   pending_maybe_consume_expectations_ = false;
251   read_pending_ = false;
252   write_pending_ = std::nullopt;
253   MaybeConsumeExpectations();
254 }
255 
FindReadyExpectations(Expectation::Type type)256 std::optional<size_t> QuicSocketDataProvider::FindReadyExpectations(
257     Expectation::Type type) {
258   std::vector<size_t> matches;
259   for (size_t i = 0; i < expectations_.size(); i++) {
260     const Expectation& expectation = expectations_[i];
261     if (expectation.consumed() || expectation.type() != type) {
262       continue;
263     }
264     bool found_unconsumed = false;
265     for (auto dep : dependencies_[i]) {
266       if (!expectations_[dep].consumed_) {
267         found_unconsumed = true;
268         break;
269       }
270     }
271     if (!found_unconsumed) {
272       matches.push_back(i);
273     }
274   }
275 
276   if (matches.size() > 1) {
277     std::string exp_type = Expectation::TypeToString(type);
278     std::string names = ExpectationList(matches);
279     CHECK(matches.size() <= 1)
280         << "Multiple expectations of type " << exp_type
281         << " are ready: " << names << ". Use .After() to disambiguate.";
282   }
283 
284   return matches.empty() ? std::nullopt : std::make_optional(matches[0]);
285 }
286 
ConsumeNextRead()287 std::optional<MockRead> QuicSocketDataProvider::ConsumeNextRead() {
288   CHECK(read_pending_);
289   std::optional<size_t> ready = FindReadyExpectations(Expectation::Type::READ);
290   if (!ready.has_value()) {
291     return std::nullopt;
292   }
293 
294   // If there's exactly one matching expectation, return it.
295   Expectation& ready_expectation = expectations_[*ready];
296   MockRead read(ready_expectation.mode(), ready_expectation.rv());
297   if (ready_expectation.packet()) {
298     read.data = ready_expectation.packet()->data();
299     read.data_len = ready_expectation.packet()->length();
300   }
301   read.tos = ready_expectation.tos_byte();
302   ready_expectation.Consume();
303   ExpectationConsumed();
304   return read;
305 }
306 
ConsumeNextWrite()307 std::optional<MockWriteResult> QuicSocketDataProvider::ConsumeNextWrite() {
308   CHECK(write_pending_.has_value());
309   std::optional<size_t> ready = FindReadyExpectations(Expectation::Type::WRITE);
310   if (!ready.has_value()) {
311     return std::nullopt;
312   }
313 
314   // If there's exactly one matching expectation, check if it matches the write
315   // and return it.
316   Expectation& ready_expectation = expectations_[*ready];
317   if (ready_expectation.packet()) {
318     if (!VerifyWriteData(ready_expectation)) {
319       return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
320     }
321   }
322   MockWriteResult write(ready_expectation.mode(),
323                         ready_expectation.packet()
324                             ? ready_expectation.packet()->length()
325                             : ready_expectation.rv());
326   ready_expectation.Consume();
327   ExpectationConsumed();
328   return write;
329 }
330 
MaybeConsumeExpectations()331 void QuicSocketDataProvider::MaybeConsumeExpectations() {
332   pending_maybe_consume_expectations_ = false;
333   if (read_pending_) {
334     std::optional<MockRead> next_read = ConsumeNextRead();
335     if (next_read.has_value()) {
336       read_pending_ = false;
337       if (socket()) {
338         socket()->OnReadComplete(*next_read);
339       }
340     }
341   }
342 
343   if (write_pending_.has_value()) {
344     std::optional<MockWriteResult> next_write = ConsumeNextWrite();
345     if (next_write.has_value()) {
346       write_pending_ = std::nullopt;
347       if (socket()) {
348         socket()->OnWriteComplete(next_write->result);
349       }
350     }
351   }
352 
353   if (!paused_at_) {
354     std::optional<size_t> ready =
355         FindReadyExpectations(Expectation::Type::PAUSE);
356     if (ready.has_value()) {
357       VLOG(1) << "Pausing at " << expectations_[*ready].name();
358       paused_at_ = *ready;
359       if (run_until_run_loop_) {
360         run_until_run_loop_->Quit();
361       }
362     }
363   }
364 
365   if (run_until_run_loop_ && AllDataConsumed()) {
366     run_until_run_loop_->Quit();
367   }
368 }
369 
ExpectationConsumed()370 void QuicSocketDataProvider::ExpectationConsumed() {
371   if (pending_maybe_consume_expectations_) {
372     return;
373   }
374   pending_maybe_consume_expectations_ = true;
375 
376   // Call `MaybeConsumeExpectations` in a task. That method may trigger
377   // consumption of other expectations, and that consumption must happen _after_
378   // the current call to `Read` or `Write` has finished.
379   base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
380       FROM_HERE,
381       base::BindOnce(&QuicSocketDataProvider::MaybeConsumeExpectations,
382                      weak_factory_.GetWeakPtr()));
383 }
384 
VerifyWriteData(QuicSocketDataProvider::Expectation & expectation)385 bool QuicSocketDataProvider::VerifyWriteData(
386     QuicSocketDataProvider::Expectation& expectation) {
387   std::string expected_data(expectation.packet()->data(),
388                             expectation.packet()->length());
389   std::string& actual_data = *write_pending_;
390   bool write_matches = actual_data == expected_data;
391   EXPECT_TRUE(write_matches)
392       << "Expectation '" << expectation.name()
393       << "' not met. Actual formatted write data:\n"
394       << printer_.PrintWrite(actual_data) << "But expectation '"
395       << expectation.name() << "' expected formatted write data:\n"
396       << printer_.PrintWrite(expected_data) << "Actual raw write data:\n"
397       << HexDump(actual_data) << "Expected raw write data:\n"
398       << HexDump(expected_data);
399   return write_matches;
400 }
401 
ExpectationList(const std::vector<size_t> & indices)402 std::string QuicSocketDataProvider::ExpectationList(
403     const std::vector<size_t>& indices) {
404   std::ostringstream names;
405   bool first = true;
406   for (auto i : indices) {
407     names << (first ? "" : ", ") << expectations_[i].name();
408     first = false;
409   }
410   return names.str();
411 }
412 
413 }  // namespace net::test
414