• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/test_tools/quic_test_utils.h"
6 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "absl/strings/string_view.h"
#include "openssl/chacha.h"
#include "openssl/sha.h"
#include "quiche/quic/core/crypto/crypto_framer.h"
#include "quiche/quic/core/crypto/crypto_handshake.h"
#include "quiche/quic/core/crypto/crypto_utils.h"
#include "quiche/quic/core/crypto/null_decrypter.h"
#include "quiche/quic/core/crypto/null_encrypter.h"
#include "quiche/quic/core/crypto/quic_decrypter.h"
#include "quiche/quic/core/crypto/quic_encrypter.h"
#include "quiche/quic/core/http/quic_spdy_client_session.h"
#include "quiche/quic/core/quic_config.h"
#include "quiche/quic/core/quic_data_writer.h"
#include "quiche/quic/core/quic_framer.h"
#include "quiche/quic/core/quic_packet_creator.h"
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/core/quic_utils.h"
#include "quiche/quic/core/quic_versions.h"
#include "quiche/quic/platform/api/quic_flags.h"
#include "quiche/quic/platform/api/quic_logging.h"
#include "quiche/quic/test_tools/crypto_test_utils.h"
#include "quiche/quic/test_tools/quic_config_peer.h"
#include "quiche/quic/test_tools/quic_connection_peer.h"
#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_endian.h"
#include "quiche/common/simple_buffer_allocator.h"
#include "quiche/spdy/core/spdy_frame_builder.h"
42 
43 using testing::_;
44 using testing::Invoke;
45 using testing::Return;
46 
47 namespace quic {
48 namespace test {
49 
TestConnectionId()50 QuicConnectionId TestConnectionId() {
51   // Chosen by fair dice roll.
52   // Guaranteed to be random.
53   return TestConnectionId(42);
54 }
55 
TestConnectionId(uint64_t connection_number)56 QuicConnectionId TestConnectionId(uint64_t connection_number) {
57   const uint64_t connection_id64_net =
58       quiche::QuicheEndian::HostToNet64(connection_number);
59   return QuicConnectionId(reinterpret_cast<const char*>(&connection_id64_net),
60                           sizeof(connection_id64_net));
61 }
62 
TestConnectionIdNineBytesLong(uint64_t connection_number)63 QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) {
64   const uint64_t connection_number_net =
65       quiche::QuicheEndian::HostToNet64(connection_number);
66   char connection_id_bytes[9] = {};
67   static_assert(
68       sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net),
69       "bad lengths");
70   memcpy(connection_id_bytes + 1, &connection_number_net,
71          sizeof(connection_number_net));
72   return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
73 }
74 
TestConnectionIdToUInt64(QuicConnectionId connection_id)75 uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) {
76   QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength);
77   uint64_t connection_id64_net = 0;
78   memcpy(&connection_id64_net, connection_id.data(),
79          std::min<size_t>(static_cast<size_t>(connection_id.length()),
80                           sizeof(connection_id64_net)));
81   return quiche::QuicheEndian::NetToHost64(connection_id64_net);
82 }
83 
// Returns a fixed, deterministic 16-byte stateless reset token whose bytes are
// 0x90, 0x91, ..., 0x9F.
std::vector<uint8_t> CreateStatelessResetTokenForTest() {
  static constexpr uint8_t kTokenBytes[] = {
      0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
      0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F};
  static_assert(sizeof(kTokenBytes) == 16, "token must be 16 bytes");
  return std::vector<uint8_t>(std::begin(kTokenBytes), std::end(kTokenBytes));
}
92 
// Hostname used by TestServerId().
std::string TestHostname() {
  static constexpr char kTestHostname[] = "test.example.com";
  return kTestHostname;
}
94 
TestServerId()95 QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); }
96 
InitAckFrame(const std::vector<QuicAckBlock> & ack_blocks)97 QuicAckFrame InitAckFrame(const std::vector<QuicAckBlock>& ack_blocks) {
98   QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
99 
100   QuicAckFrame ack;
101   QuicPacketNumber end_of_previous_block(1);
102   for (const QuicAckBlock& block : ack_blocks) {
103     QUICHE_DCHECK_GE(block.start, end_of_previous_block);
104     QUICHE_DCHECK_GT(block.limit, block.start);
105     ack.packets.AddRange(block.start, block.limit);
106     end_of_previous_block = block.limit;
107   }
108 
109   ack.largest_acked = ack.packets.Max();
110 
111   return ack;
112 }
113 
InitAckFrame(uint64_t largest_acked)114 QuicAckFrame InitAckFrame(uint64_t largest_acked) {
115   return InitAckFrame(QuicPacketNumber(largest_acked));
116 }
117 
InitAckFrame(QuicPacketNumber largest_acked)118 QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) {
119   return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}});
120 }
121 
MakeAckFrameWithAckBlocks(size_t num_ack_blocks,uint64_t least_unacked)122 QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks,
123                                        uint64_t least_unacked) {
124   QuicAckFrame ack;
125   ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked);
126   // Add enough received packets to get num_ack_blocks ack blocks.
127   for (QuicPacketNumber i = QuicPacketNumber(2);
128        i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) {
129     ack.packets.Add(i + least_unacked);
130   }
131   return ack;
132 }
133 
MakeAckFrameWithGaps(uint64_t gap_size,size_t max_num_gaps,uint64_t largest_acked)134 QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps,
135                                   uint64_t largest_acked) {
136   QuicAckFrame ack;
137   ack.largest_acked = QuicPacketNumber(largest_acked);
138   ack.packets.Add(QuicPacketNumber(largest_acked));
139   for (size_t i = 0; i < max_num_gaps; ++i) {
140     if (largest_acked <= gap_size) {
141       break;
142     }
143     largest_acked -= gap_size;
144     ack.packets.Add(QuicPacketNumber(largest_acked));
145   }
146   return ack;
147 }
148 
HeaderToEncryptionLevel(const QuicPacketHeader & header)149 EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) {
150   if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
151     return ENCRYPTION_FORWARD_SECURE;
152   } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) {
153     if (header.long_packet_type == HANDSHAKE) {
154       return ENCRYPTION_HANDSHAKE;
155     } else if (header.long_packet_type == ZERO_RTT_PROTECTED) {
156       return ENCRYPTION_ZERO_RTT;
157     }
158   }
159   return ENCRYPTION_INITIAL;
160 }
161 
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames)162 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
163     QuicFramer* framer, const QuicPacketHeader& header,
164     const QuicFrames& frames) {
165   const size_t max_plaintext_size =
166       framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize);
167   size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header);
168   for (size_t i = 0; i < frames.size(); ++i) {
169     QUICHE_DCHECK_LE(packet_size, max_plaintext_size);
170     bool first_frame = i == 0;
171     bool last_frame = i == frames.size() - 1;
172     const size_t frame_size = framer->GetSerializedFrameLength(
173         frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
174         header.packet_number_length);
175     QUICHE_DCHECK(frame_size);
176     packet_size += frame_size;
177   }
178   return BuildUnsizedDataPacket(framer, header, frames, packet_size);
179 }
180 
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames,size_t packet_size)181 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
182     QuicFramer* framer, const QuicPacketHeader& header,
183     const QuicFrames& frames, size_t packet_size) {
184   char* buffer = new char[packet_size];
185   EncryptionLevel level = HeaderToEncryptionLevel(header);
186   size_t length =
187       framer->BuildDataPacket(header, frames, buffer, packet_size, level);
188 
189   if (length == 0) {
190     delete[] buffer;
191     return nullptr;
192   }
193   // Re-construct the data packet with data ownership.
194   return std::make_unique<QuicPacket>(
195       buffer, length, /* owns_buffer */ true,
196       GetIncludedDestinationConnectionIdLength(header),
197       GetIncludedSourceConnectionIdLength(header), header.version_flag,
198       header.nonce != nullptr, header.packet_number_length,
199       header.retry_token_length_length, header.retry_token.length(),
200       header.length_length);
201 }
202 
Sha1Hash(absl::string_view data)203 std::string Sha1Hash(absl::string_view data) {
204   char buffer[SHA_DIGEST_LENGTH];
205   SHA1(reinterpret_cast<const uint8_t*>(data.data()), data.size(),
206        reinterpret_cast<uint8_t*>(buffer));
207   return std::string(buffer, ABSL_ARRAYSIZE(buffer));
208 }
209 
ClearControlFrame(const QuicFrame & frame)210 bool ClearControlFrame(const QuicFrame& frame) {
211   DeleteFrame(&const_cast<QuicFrame&>(frame));
212   return true;
213 }
214 
ClearControlFrameWithTransmissionType(const QuicFrame & frame,TransmissionType)215 bool ClearControlFrameWithTransmissionType(const QuicFrame& frame,
216                                            TransmissionType /*type*/) {
217   return ClearControlFrame(frame);
218 }
219 
RandUint64()220 uint64_t SimpleRandom::RandUint64() {
221   uint64_t result;
222   RandBytes(&result, sizeof(result));
223   return result;
224 }
225 
RandBytes(void * data,size_t len)226 void SimpleRandom::RandBytes(void* data, size_t len) {
227   uint8_t* data_bytes = reinterpret_cast<uint8_t*>(data);
228   while (len > 0) {
229     const size_t buffer_left = sizeof(buffer_) - buffer_offset_;
230     const size_t to_copy = std::min(buffer_left, len);
231     memcpy(data_bytes, buffer_ + buffer_offset_, to_copy);
232     data_bytes += to_copy;
233     buffer_offset_ += to_copy;
234     len -= to_copy;
235 
236     if (buffer_offset_ == sizeof(buffer_)) {
237       FillBuffer();
238     }
239   }
240 }
241 
InsecureRandBytes(void * data,size_t len)242 void SimpleRandom::InsecureRandBytes(void* data, size_t len) {
243   RandBytes(data, len);
244 }
245 
InsecureRandUint64()246 uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); }
247 
FillBuffer()248 void SimpleRandom::FillBuffer() {
249   uint8_t nonce[12];
250   memcpy(nonce, buffer_, sizeof(nonce));
251   CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0);
252   buffer_offset_ = 0;
253 }
254 
set_seed(uint64_t seed)255 void SimpleRandom::set_seed(uint64_t seed) {
256   static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits");
257   SHA256(reinterpret_cast<const uint8_t*>(&seed), sizeof(seed), key_);
258 
259   memset(buffer_, 0, sizeof(buffer_));
260   FillBuffer();
261 }
262 
MockFramerVisitor()263 MockFramerVisitor::MockFramerVisitor() {
264   // By default, we want to accept packets.
265   ON_CALL(*this, OnProtocolVersionMismatch(_))
266       .WillByDefault(testing::Return(false));
267 
268   // By default, we want to accept packets.
269   ON_CALL(*this, OnUnauthenticatedHeader(_))
270       .WillByDefault(testing::Return(true));
271 
272   ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
273       .WillByDefault(testing::Return(true));
274 
275   ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true));
276 
277   ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true));
278 
279   ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true));
280 
281   ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true));
282 
283   ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true));
284 
285   ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true));
286 
287   ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true));
288 
289   ON_CALL(*this, OnConnectionCloseFrame(_))
290       .WillByDefault(testing::Return(true));
291 
292   ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true));
293 
294   ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true));
295 
296   ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true));
297 
298   ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true));
299   ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true));
300   ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true));
301 }
302 
~MockFramerVisitor()303 MockFramerVisitor::~MockFramerVisitor() {}
304 
OnProtocolVersionMismatch(ParsedQuicVersion)305 bool NoOpFramerVisitor::OnProtocolVersionMismatch(
306     ParsedQuicVersion /*version*/) {
307   return false;
308 }
309 
OnUnauthenticatedPublicHeader(const QuicPacketHeader &)310 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
311     const QuicPacketHeader& /*header*/) {
312   return true;
313 }
314 
OnUnauthenticatedHeader(const QuicPacketHeader &)315 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
316     const QuicPacketHeader& /*header*/) {
317   return true;
318 }
319 
OnPacketHeader(const QuicPacketHeader &)320 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) {
321   return true;
322 }
323 
OnCoalescedPacket(const QuicEncryptedPacket &)324 void NoOpFramerVisitor::OnCoalescedPacket(
325     const QuicEncryptedPacket& /*packet*/) {}
326 
OnUndecryptablePacket(const QuicEncryptedPacket &,EncryptionLevel,bool)327 void NoOpFramerVisitor::OnUndecryptablePacket(
328     const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/,
329     bool /*has_decryption_key*/) {}
330 
OnStreamFrame(const QuicStreamFrame &)331 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) {
332   return true;
333 }
334 
OnCryptoFrame(const QuicCryptoFrame &)335 bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {
336   return true;
337 }
338 
OnAckFrameStart(QuicPacketNumber,QuicTime::Delta)339 bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
340                                         QuicTime::Delta /*ack_delay_time*/) {
341   return true;
342 }
343 
OnAckRange(QuicPacketNumber,QuicPacketNumber)344 bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/,
345                                    QuicPacketNumber /*end*/) {
346   return true;
347 }
348 
OnAckTimestamp(QuicPacketNumber,QuicTime)349 bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/,
350                                        QuicTime /*timestamp*/) {
351   return true;
352 }
353 
OnAckFrameEnd(QuicPacketNumber,const std::optional<QuicEcnCounts> &)354 bool NoOpFramerVisitor::OnAckFrameEnd(
355     QuicPacketNumber /*start*/,
356     const std::optional<QuicEcnCounts>& /*ecn_counts*/) {
357   return true;
358 }
359 
OnStopWaitingFrame(const QuicStopWaitingFrame &)360 bool NoOpFramerVisitor::OnStopWaitingFrame(
361     const QuicStopWaitingFrame& /*frame*/) {
362   return true;
363 }
364 
OnPaddingFrame(const QuicPaddingFrame &)365 bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {
366   return true;
367 }
368 
OnPingFrame(const QuicPingFrame &)369 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) {
370   return true;
371 }
372 
OnRstStreamFrame(const QuicRstStreamFrame &)373 bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {
374   return true;
375 }
376 
OnConnectionCloseFrame(const QuicConnectionCloseFrame &)377 bool NoOpFramerVisitor::OnConnectionCloseFrame(
378     const QuicConnectionCloseFrame& /*frame*/) {
379   return true;
380 }
381 
OnNewConnectionIdFrame(const QuicNewConnectionIdFrame &)382 bool NoOpFramerVisitor::OnNewConnectionIdFrame(
383     const QuicNewConnectionIdFrame& /*frame*/) {
384   return true;
385 }
386 
OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame &)387 bool NoOpFramerVisitor::OnRetireConnectionIdFrame(
388     const QuicRetireConnectionIdFrame& /*frame*/) {
389   return true;
390 }
391 
OnNewTokenFrame(const QuicNewTokenFrame &)392 bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {
393   return true;
394 }
395 
OnStopSendingFrame(const QuicStopSendingFrame &)396 bool NoOpFramerVisitor::OnStopSendingFrame(
397     const QuicStopSendingFrame& /*frame*/) {
398   return true;
399 }
400 
OnPathChallengeFrame(const QuicPathChallengeFrame &)401 bool NoOpFramerVisitor::OnPathChallengeFrame(
402     const QuicPathChallengeFrame& /*frame*/) {
403   return true;
404 }
405 
OnPathResponseFrame(const QuicPathResponseFrame &)406 bool NoOpFramerVisitor::OnPathResponseFrame(
407     const QuicPathResponseFrame& /*frame*/) {
408   return true;
409 }
410 
OnGoAwayFrame(const QuicGoAwayFrame &)411 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {
412   return true;
413 }
414 
OnMaxStreamsFrame(const QuicMaxStreamsFrame &)415 bool NoOpFramerVisitor::OnMaxStreamsFrame(
416     const QuicMaxStreamsFrame& /*frame*/) {
417   return true;
418 }
419 
OnStreamsBlockedFrame(const QuicStreamsBlockedFrame &)420 bool NoOpFramerVisitor::OnStreamsBlockedFrame(
421     const QuicStreamsBlockedFrame& /*frame*/) {
422   return true;
423 }
424 
OnWindowUpdateFrame(const QuicWindowUpdateFrame &)425 bool NoOpFramerVisitor::OnWindowUpdateFrame(
426     const QuicWindowUpdateFrame& /*frame*/) {
427   return true;
428 }
429 
OnBlockedFrame(const QuicBlockedFrame &)430 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {
431   return true;
432 }
433 
OnMessageFrame(const QuicMessageFrame &)434 bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) {
435   return true;
436 }
437 
OnHandshakeDoneFrame(const QuicHandshakeDoneFrame &)438 bool NoOpFramerVisitor::OnHandshakeDoneFrame(
439     const QuicHandshakeDoneFrame& /*frame*/) {
440   return true;
441 }
442 
OnAckFrequencyFrame(const QuicAckFrequencyFrame &)443 bool NoOpFramerVisitor::OnAckFrequencyFrame(
444     const QuicAckFrequencyFrame& /*frame*/) {
445   return true;
446 }
447 
IsValidStatelessResetToken(const StatelessResetToken &) const448 bool NoOpFramerVisitor::IsValidStatelessResetToken(
449     const StatelessResetToken& /*token*/) const {
450   return false;
451 }
452 
MockQuicConnectionVisitor()453 MockQuicConnectionVisitor::MockQuicConnectionVisitor() {
454   ON_CALL(*this, GetFlowControlSendWindowSize(_))
455       .WillByDefault(Return(std::numeric_limits<QuicByteCount>::max()));
456 }
457 
~MockQuicConnectionVisitor()458 MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {}
459 
MockQuicConnectionHelper()460 MockQuicConnectionHelper::MockQuicConnectionHelper() {}
461 
~MockQuicConnectionHelper()462 MockQuicConnectionHelper::~MockQuicConnectionHelper() {}
463 
GetClock() const464 const MockClock* MockQuicConnectionHelper::GetClock() const { return &clock_; }
465 
GetClock()466 MockClock* MockQuicConnectionHelper::GetClock() { return &clock_; }
467 
GetRandomGenerator()468 QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() {
469   return &random_generator_;
470 }
471 
CreateAlarm(QuicAlarm::Delegate * delegate)472 QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) {
473   return new MockAlarmFactory::TestAlarm(
474       QuicArenaScopedPtr<QuicAlarm::Delegate>(delegate));
475 }
476 
CreateAlarm(QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,QuicConnectionArena * arena)477 QuicArenaScopedPtr<QuicAlarm> MockAlarmFactory::CreateAlarm(
478     QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,
479     QuicConnectionArena* arena) {
480   if (arena != nullptr) {
481     return arena->New<TestAlarm>(std::move(delegate));
482   } else {
483     return QuicArenaScopedPtr<TestAlarm>(new TestAlarm(std::move(delegate)));
484   }
485 }
486 
487 quiche::QuicheBufferAllocator*
GetStreamSendBufferAllocator()488 MockQuicConnectionHelper::GetStreamSendBufferAllocator() {
489   return &buffer_allocator_;
490 }
491 
AdvanceTime(QuicTime::Delta delta)492 void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) {
493   clock_.AdvanceTime(delta);
494 }
495 
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)496 MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper,
497                                        QuicAlarmFactory* alarm_factory,
498                                        Perspective perspective)
499     : MockQuicConnection(TestConnectionId(),
500                          QuicSocketAddress(TestPeerIPAddress(), kTestPort),
501                          helper, alarm_factory, perspective,
502                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
503 
MockQuicConnection(QuicSocketAddress address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)504 MockQuicConnection::MockQuicConnection(QuicSocketAddress address,
505                                        QuicConnectionHelperInterface* helper,
506                                        QuicAlarmFactory* alarm_factory,
507                                        Perspective perspective)
508     : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory,
509                          perspective,
510                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
511 
MockQuicConnection(QuicConnectionId connection_id,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)512 MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id,
513                                        QuicConnectionHelperInterface* helper,
514                                        QuicAlarmFactory* alarm_factory,
515                                        Perspective perspective)
516     : MockQuicConnection(connection_id,
517                          QuicSocketAddress(TestPeerIPAddress(), kTestPort),
518                          helper, alarm_factory, perspective,
519                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
520 
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)521 MockQuicConnection::MockQuicConnection(
522     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
523     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
524     : MockQuicConnection(
525           TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort),
526           helper, alarm_factory, perspective, supported_versions) {}
527 
MockQuicConnection(QuicConnectionId connection_id,QuicSocketAddress initial_peer_address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)528 MockQuicConnection::MockQuicConnection(
529     QuicConnectionId connection_id, QuicSocketAddress initial_peer_address,
530     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
531     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
532     : QuicConnection(
533           connection_id,
534           /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5),
535           initial_peer_address, helper, alarm_factory,
536           new testing::NiceMock<MockPacketWriter>(),
537           /* owns_writer= */ true, perspective, supported_versions,
538           connection_id_generator_) {
539   ON_CALL(*this, OnError(_))
540       .WillByDefault(
541           Invoke(this, &PacketSavingConnection::QuicConnection_OnError));
542   ON_CALL(*this, SendCryptoData(_, _, _))
543       .WillByDefault(
544           Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData));
545 
546   SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5));
547 }
548 
~MockQuicConnection()549 MockQuicConnection::~MockQuicConnection() {}
550 
AdvanceTime(QuicTime::Delta delta)551 void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) {
552   static_cast<MockQuicConnectionHelper*>(helper())->AdvanceTime(delta);
553 }
554 
OnProtocolVersionMismatch(ParsedQuicVersion)555 bool MockQuicConnection::OnProtocolVersionMismatch(
556     ParsedQuicVersion /*version*/) {
557   return false;
558 }
559 
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)560 PacketSavingConnection::PacketSavingConnection(MockQuicConnectionHelper* helper,
561                                                QuicAlarmFactory* alarm_factory,
562                                                Perspective perspective)
563     : MockQuicConnection(helper, alarm_factory, perspective),
564       mock_helper_(helper) {}
565 
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)566 PacketSavingConnection::PacketSavingConnection(
567     MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
568     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
569     : MockQuicConnection(helper, alarm_factory, perspective,
570                          supported_versions),
571       mock_helper_(helper) {}
572 
~PacketSavingConnection()573 PacketSavingConnection::~PacketSavingConnection() {}
574 
GetSerializedPacketFate(bool,EncryptionLevel)575 SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate(
576     bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) {
577   return SEND_TO_WRITER;
578 }
579 
SendOrQueuePacket(SerializedPacket packet)580 void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) {
581   encrypted_packets_.push_back(std::make_unique<QuicEncryptedPacket>(
582       CopyBuffer(packet), packet.encrypted_length, true));
583   MockClock& clock = *mock_helper_->GetClock();
584   clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
585   // Transfer ownership of the packet to the SentPacketManager and the
586   // ack notifier to the AckNotifierManager.
587   OnPacketSent(packet.encryption_level, packet.transmission_type);
588   QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent(
589       &packet, clock.ApproximateNow(), NOT_RETRANSMISSION,
590       HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT);
591 }
592 
GetPackets() const593 std::vector<const QuicEncryptedPacket*> PacketSavingConnection::GetPackets()
594     const {
595   std::vector<const QuicEncryptedPacket*> packets;
596   for (size_t i = num_cleared_packets_; i < encrypted_packets_.size(); ++i) {
597     packets.push_back(encrypted_packets_[i].get());
598   }
599   return packets;
600 }
601 
ClearPackets()602 void PacketSavingConnection::ClearPackets() {
603   num_cleared_packets_ = encrypted_packets_.size();
604 }
605 
MockQuicSession(QuicConnection * connection)606 MockQuicSession::MockQuicSession(QuicConnection* connection)
607     : MockQuicSession(connection, true) {}
608 
MockQuicSession(QuicConnection * connection,bool create_mock_crypto_stream)609 MockQuicSession::MockQuicSession(QuicConnection* connection,
610                                  bool create_mock_crypto_stream)
611     : QuicSession(connection, nullptr, DefaultQuicConfig(),
612                   connection->supported_versions(),
613                   /*num_expected_unidirectional_static_streams = */ 0) {
614   if (create_mock_crypto_stream) {
615     crypto_stream_ =
616         std::make_unique<testing::NiceMock<MockQuicCryptoStream>>(this);
617   }
618   ON_CALL(*this, WritevData(_, _, _, _, _, _))
619       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
620 }
621 
~MockQuicSession()622 MockQuicSession::~MockQuicSession() { DeleteConnection(); }
623 
GetMutableCryptoStream()624 QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() {
625   return crypto_stream_.get();
626 }
627 
GetCryptoStream() const628 const QuicCryptoStream* MockQuicSession::GetCryptoStream() const {
629   return crypto_stream_.get();
630 }
631 
SetCryptoStream(QuicCryptoStream * crypto_stream)632 void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
633   crypto_stream_.reset(crypto_stream);
634 }
635 
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)636 QuicConsumedData MockQuicSession::ConsumeData(
637     QuicStreamId id, size_t write_length, QuicStreamOffset offset,
638     StreamSendingState state, TransmissionType /*type*/,
639     std::optional<EncryptionLevel> /*level*/) {
640   if (write_length > 0) {
641     auto buf = std::make_unique<char[]>(write_length);
642     QuicStream* stream = GetOrCreateStream(id);
643     QUICHE_DCHECK(stream);
644     QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
645     stream->WriteStreamData(offset, write_length, &writer);
646   } else {
647     QUICHE_DCHECK(state != NO_FIN);
648   }
649   return QuicConsumedData(write_length, state != NO_FIN);
650 }
651 
MockQuicCryptoStream(QuicSession * session)652 MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session)
653     : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {}
654 
~MockQuicCryptoStream()655 MockQuicCryptoStream::~MockQuicCryptoStream() {}
656 
EarlyDataReason() const657 ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const {
658   return ssl_early_data_unknown;
659 }
660 
one_rtt_keys_available() const661 bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; }
662 
663 const QuicCryptoNegotiatedParameters&
crypto_negotiated_params() const664 MockQuicCryptoStream::crypto_negotiated_params() const {
665   return *params_;
666 }
667 
crypto_message_parser()668 CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() {
669   return &crypto_framer_;
670 }
671 
MockQuicSpdySession(QuicConnection * connection)672 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection)
673     : MockQuicSpdySession(connection, true) {}
674 
MockQuicSpdySession(QuicConnection * connection,bool create_mock_crypto_stream)675 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection,
676                                          bool create_mock_crypto_stream)
677     : QuicSpdySession(connection, nullptr, DefaultQuicConfig(),
678                       connection->supported_versions()) {
679   if (create_mock_crypto_stream) {
680     crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
681   }
682 
683   ON_CALL(*this, WritevData(_, _, _, _, _, _))
684       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
685 
686   ON_CALL(*this, SendWindowUpdate(_, _))
687       .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
688         return QuicSpdySession::SendWindowUpdate(id, byte_offset);
689       });
690 
691   ON_CALL(*this, SendBlocked(_, _))
692       .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
693         return QuicSpdySession::SendBlocked(id, byte_offset);
694       });
695 
696   ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return());
697 }
698 
~MockQuicSpdySession()699 MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); }
700 
GetMutableCryptoStream()701 QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() {
702   return crypto_stream_.get();
703 }
704 
GetCryptoStream() const705 const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const {
706   return crypto_stream_.get();
707 }
708 
SetCryptoStream(QuicCryptoStream * crypto_stream)709 void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
710   crypto_stream_.reset(crypto_stream);
711 }
712 
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)713 QuicConsumedData MockQuicSpdySession::ConsumeData(
714     QuicStreamId id, size_t write_length, QuicStreamOffset offset,
715     StreamSendingState state, TransmissionType /*type*/,
716     std::optional<EncryptionLevel> /*level*/) {
717   if (write_length > 0) {
718     auto buf = std::make_unique<char[]>(write_length);
719     QuicStream* stream = GetOrCreateStream(id);
720     QUICHE_DCHECK(stream);
721     QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
722     stream->WriteStreamData(offset, write_length, &writer);
723   } else {
724     QUICHE_DCHECK(state != NO_FIN);
725   }
726   return QuicConsumedData(write_length, state != NO_FIN);
727 }
728 
TestQuicSpdyServerSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)729 TestQuicSpdyServerSession::TestQuicSpdyServerSession(
730     QuicConnection* connection, const QuicConfig& config,
731     const ParsedQuicVersionVector& supported_versions,
732     const QuicCryptoServerConfig* crypto_config,
733     QuicCompressedCertsCache* compressed_certs_cache)
734     : QuicServerSessionBase(config, supported_versions, connection, &visitor_,
735                             &helper_, crypto_config, compressed_certs_cache) {
736   ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _))
737       .WillByDefault(testing::Return(true));
738 }
739 
~TestQuicSpdyServerSession()740 TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); }
741 
742 std::unique_ptr<QuicCryptoServerStreamBase>
CreateQuicCryptoServerStream(const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)743 TestQuicSpdyServerSession::CreateQuicCryptoServerStream(
744     const QuicCryptoServerConfig* crypto_config,
745     QuicCompressedCertsCache* compressed_certs_cache) {
746   return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this,
747                                   &helper_);
748 }
749 
750 QuicCryptoServerStreamBase*
GetMutableCryptoStream()751 TestQuicSpdyServerSession::GetMutableCryptoStream() {
752   return QuicServerSessionBase::GetMutableCryptoStream();
753 }
754 
GetCryptoStream() const755 const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream()
756     const {
757   return QuicServerSessionBase::GetCryptoStream();
758 }
759 
TestQuicSpdyClientSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicServerId & server_id,QuicCryptoClientConfig * crypto_config,std::optional<QuicSSLConfig> ssl_config)760 TestQuicSpdyClientSession::TestQuicSpdyClientSession(
761     QuicConnection* connection, const QuicConfig& config,
762     const ParsedQuicVersionVector& supported_versions,
763     const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config,
764     std::optional<QuicSSLConfig> ssl_config)
765     : QuicSpdyClientSessionBase(connection, nullptr, config,
766                                 supported_versions),
767       ssl_config_(std::move(ssl_config)) {
768   // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
769   // calls in tests and set |has_application_state| to true.
770   crypto_stream_ = std::make_unique<QuicCryptoClientStream>(
771       server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
772       crypto_config, this, /*has_application_state = */ false);
773   Initialize();
774   ON_CALL(*this, OnConfigNegotiated())
775       .WillByDefault(
776           Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
777 }
778 
~TestQuicSpdyClientSession()779 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
780 
GetMutableCryptoStream()781 QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() {
782   return crypto_stream_.get();
783 }
784 
GetCryptoStream() const785 const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream()
786     const {
787   return crypto_stream_.get();
788 }
789 
RealOnConfigNegotiated()790 void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
791   QuicSpdyClientSessionBase::OnConfigNegotiated();
792 }
793 
MockPacketWriter()794 MockPacketWriter::MockPacketWriter() {
795   ON_CALL(*this, GetMaxPacketSize(_))
796       .WillByDefault(testing::Return(kMaxOutgoingPacketSize));
797   ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false));
798   ON_CALL(*this, GetNextWriteLocation(_, _))
799       .WillByDefault(testing::Return(QuicPacketBuffer()));
800   ON_CALL(*this, Flush())
801       .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0)));
802   ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false));
803 }
804 
~MockPacketWriter()805 MockPacketWriter::~MockPacketWriter() {}
806 
MockSendAlgorithm()807 MockSendAlgorithm::MockSendAlgorithm() {
808   ON_CALL(*this, PacingRate(_))
809       .WillByDefault(testing::Return(QuicBandwidth::Zero()));
810   ON_CALL(*this, BandwidthEstimate())
811       .WillByDefault(testing::Return(QuicBandwidth::Zero()));
812 }
813 
~MockSendAlgorithm()814 MockSendAlgorithm::~MockSendAlgorithm() {}
815 
MockLossAlgorithm()816 MockLossAlgorithm::MockLossAlgorithm() {}
817 
~MockLossAlgorithm()818 MockLossAlgorithm::~MockLossAlgorithm() {}
819 
MockAckListener()820 MockAckListener::MockAckListener() {}
821 
~MockAckListener()822 MockAckListener::~MockAckListener() {}
823 
MockNetworkChangeVisitor()824 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {}
825 
~MockNetworkChangeVisitor()826 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {}
827 
TestPeerIPAddress()828 QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); }
829 
QuicVersionMax()830 ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); }
831 
QuicVersionMin()832 ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); }
833 
DisableQuicVersionsWithTls()834 void DisableQuicVersionsWithTls() {
835   for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) {
836     QuicDisableVersion(version);
837   }
838 }
839 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data)840 QuicEncryptedPacket* ConstructEncryptedPacket(
841     QuicConnectionId destination_connection_id,
842     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
843     uint64_t packet_number, const std::string& data) {
844   return ConstructEncryptedPacket(
845       destination_connection_id, source_connection_id, version_flag, reset_flag,
846       packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT,
847       PACKET_4BYTE_PACKET_NUMBER);
848 }
849 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length)850 QuicEncryptedPacket* ConstructEncryptedPacket(
851     QuicConnectionId destination_connection_id,
852     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
853     uint64_t packet_number, const std::string& data,
854     QuicConnectionIdIncluded destination_connection_id_included,
855     QuicConnectionIdIncluded source_connection_id_included,
856     QuicPacketNumberLength packet_number_length) {
857   return ConstructEncryptedPacket(
858       destination_connection_id, source_connection_id, version_flag, reset_flag,
859       packet_number, data, destination_connection_id_included,
860       source_connection_id_included, packet_number_length, nullptr);
861 }
862 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)863 QuicEncryptedPacket* ConstructEncryptedPacket(
864     QuicConnectionId destination_connection_id,
865     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
866     uint64_t packet_number, const std::string& data,
867     QuicConnectionIdIncluded destination_connection_id_included,
868     QuicConnectionIdIncluded source_connection_id_included,
869     QuicPacketNumberLength packet_number_length,
870     ParsedQuicVersionVector* versions) {
871   return ConstructEncryptedPacket(
872       destination_connection_id, source_connection_id, version_flag, reset_flag,
873       packet_number, data, false, destination_connection_id_included,
874       source_connection_id_included, packet_number_length, versions,
875       Perspective::IS_CLIENT);
876 }
877 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)878 QuicEncryptedPacket* ConstructEncryptedPacket(
879     QuicConnectionId destination_connection_id,
880     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
881     uint64_t packet_number, const std::string& data, bool full_padding,
882     QuicConnectionIdIncluded destination_connection_id_included,
883     QuicConnectionIdIncluded source_connection_id_included,
884     QuicPacketNumberLength packet_number_length,
885     ParsedQuicVersionVector* versions) {
886   return ConstructEncryptedPacket(
887       destination_connection_id, source_connection_id, version_flag, reset_flag,
888       packet_number, data, full_padding, destination_connection_id_included,
889       source_connection_id_included, packet_number_length, versions,
890       Perspective::IS_CLIENT);
891 }
892 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions,Perspective perspective)893 QuicEncryptedPacket* ConstructEncryptedPacket(
894     QuicConnectionId destination_connection_id,
895     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
896     uint64_t packet_number, const std::string& data, bool full_padding,
897     QuicConnectionIdIncluded destination_connection_id_included,
898     QuicConnectionIdIncluded source_connection_id_included,
899     QuicPacketNumberLength packet_number_length,
900     ParsedQuicVersionVector* versions, Perspective perspective) {
901   QuicPacketHeader header;
902   header.destination_connection_id = destination_connection_id;
903   header.destination_connection_id_included =
904       destination_connection_id_included;
905   header.source_connection_id = source_connection_id;
906   header.source_connection_id_included = source_connection_id_included;
907   header.version_flag = version_flag;
908   header.reset_flag = reset_flag;
909   header.packet_number_length = packet_number_length;
910   header.packet_number = QuicPacketNumber(packet_number);
911   ParsedQuicVersionVector supported_versions = CurrentSupportedVersions();
912   if (!versions) {
913     versions = &supported_versions;
914   }
915   EXPECT_FALSE(versions->empty());
916   ParsedQuicVersion version = (*versions)[0];
917   if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
918       version_flag) {
919     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
920     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
921   }
922 
923   QuicFrames frames;
924   QuicFramer framer(*versions, QuicTime::Zero(), perspective,
925                     kQuicDefaultConnectionIdLength);
926   framer.SetInitialObfuscators(destination_connection_id);
927   EncryptionLevel level =
928       header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
929   if (level != ENCRYPTION_INITIAL) {
930     framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
931   }
932   if (!QuicVersionUsesCryptoFrames(version.transport_version)) {
933     QuicFrame frame(
934         QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version),
935                         false, 0, absl::string_view(data)));
936     frames.push_back(frame);
937   } else {
938     QuicFrame frame(new QuicCryptoFrame(level, 0, data));
939     frames.push_back(frame);
940   }
941   if (full_padding) {
942     frames.push_back(QuicFrame(QuicPaddingFrame(-1)));
943   } else {
944     // We need a minimum number of bytes of encrypted payload. This will
945     // guarantee that we have at least that much. (It ignores the overhead of
946     // the stream/crypto framing, so it overpads slightly.)
947     size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize(
948         version, packet_number_length);
949     if (data.length() < min_plaintext_size) {
950       size_t padding_length = min_plaintext_size - data.length();
951       frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
952     }
953   }
954 
955   std::unique_ptr<QuicPacket> packet(
956       BuildUnsizedDataPacket(&framer, header, frames));
957   EXPECT_TRUE(packet != nullptr);
958   char* buffer = new char[kMaxOutgoingPacketSize];
959   size_t encrypted_length =
960       framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
961                             buffer, kMaxOutgoingPacketSize);
962   EXPECT_NE(0u, encrypted_length);
963   DeleteFrames(&frames);
964   return new QuicEncryptedPacket(buffer, encrypted_length, true);
965 }
966 
GetUndecryptableEarlyPacket(const ParsedQuicVersion & version,const QuicConnectionId & server_connection_id)967 std::unique_ptr<QuicEncryptedPacket> GetUndecryptableEarlyPacket(
968     const ParsedQuicVersion& version,
969     const QuicConnectionId& server_connection_id) {
970   QuicPacketHeader header;
971   header.destination_connection_id = server_connection_id;
972   header.destination_connection_id_included = CONNECTION_ID_PRESENT;
973   header.source_connection_id = EmptyQuicConnectionId();
974   header.source_connection_id_included = CONNECTION_ID_PRESENT;
975   if (!version.SupportsClientConnectionIds()) {
976     header.source_connection_id_included = CONNECTION_ID_ABSENT;
977   }
978   header.version_flag = true;
979   header.reset_flag = false;
980   header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
981   header.packet_number = QuicPacketNumber(33);
982   header.long_packet_type = ZERO_RTT_PROTECTED;
983   if (version.HasLongHeaderLengths()) {
984     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
985     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
986   }
987 
988   QuicFrames frames;
989   frames.push_back(QuicFrame(QuicPingFrame()));
990   frames.push_back(QuicFrame(QuicPaddingFrame(100)));
991   QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT,
992                     kQuicDefaultConnectionIdLength);
993   framer.SetInitialObfuscators(server_connection_id);
994 
995   framer.SetEncrypter(ENCRYPTION_ZERO_RTT,
996                       std::make_unique<TaggingEncrypter>(ENCRYPTION_ZERO_RTT));
997   std::unique_ptr<QuicPacket> packet(
998       BuildUnsizedDataPacket(&framer, header, frames));
999   EXPECT_TRUE(packet != nullptr);
1000   char* buffer = new char[kMaxOutgoingPacketSize];
1001   size_t encrypted_length =
1002       framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet,
1003                             buffer, kMaxOutgoingPacketSize);
1004   EXPECT_NE(0u, encrypted_length);
1005   DeleteFrames(&frames);
1006   return std::make_unique<QuicEncryptedPacket>(buffer, encrypted_length,
1007                                                /*owns_buffer=*/true);
1008 }
1009 
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time)1010 QuicReceivedPacket* ConstructReceivedPacket(
1011     const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
1012   char* buffer = new char[encrypted_packet.length()];
1013   memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
1014   return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
1015                                 true);
1016 }
1017 
ConstructMisFramedEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersion version,Perspective perspective)1018 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
1019     QuicConnectionId destination_connection_id,
1020     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
1021     uint64_t packet_number, const std::string& data,
1022     QuicConnectionIdIncluded destination_connection_id_included,
1023     QuicConnectionIdIncluded source_connection_id_included,
1024     QuicPacketNumberLength packet_number_length, ParsedQuicVersion version,
1025     Perspective perspective) {
1026   QuicPacketHeader header;
1027   header.destination_connection_id = destination_connection_id;
1028   header.destination_connection_id_included =
1029       destination_connection_id_included;
1030   header.source_connection_id = source_connection_id;
1031   header.source_connection_id_included = source_connection_id_included;
1032   header.version_flag = version_flag;
1033   header.reset_flag = reset_flag;
1034   header.packet_number_length = packet_number_length;
1035   header.packet_number = QuicPacketNumber(packet_number);
1036   if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
1037       version_flag) {
1038     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
1039     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
1040   }
1041   QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data)));
1042   QuicFrames frames;
1043   frames.push_back(frame);
1044   QuicFramer framer({version}, QuicTime::Zero(), perspective,
1045                     kQuicDefaultConnectionIdLength);
1046   framer.SetInitialObfuscators(destination_connection_id);
1047   EncryptionLevel level =
1048       version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
1049   if (level != ENCRYPTION_INITIAL) {
1050     framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
1051   }
1052   // We need a minimum of 7 bytes of encrypted payload. This will guarantee that
1053   // we have at least that much. (It ignores the overhead of the stream/crypto
1054   // framing, so it overpads slightly.)
1055   if (data.length() < 7) {
1056     size_t padding_length = 7 - data.length();
1057     frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
1058   }
1059 
1060   std::unique_ptr<QuicPacket> packet(
1061       BuildUnsizedDataPacket(&framer, header, frames));
1062   EXPECT_TRUE(packet != nullptr);
1063 
1064   // Now set the frame type to 0x1F, which is an invalid frame type.
1065   reinterpret_cast<unsigned char*>(
1066       packet->mutable_data())[GetStartOfEncryptedData(
1067       framer.transport_version(),
1068       GetIncludedDestinationConnectionIdLength(header),
1069       GetIncludedSourceConnectionIdLength(header), version_flag,
1070       false /* no diversification nonce */, packet_number_length,
1071       header.retry_token_length_length, 0, header.length_length)] = 0x1F;
1072 
1073   char* buffer = new char[kMaxOutgoingPacketSize];
1074   size_t encrypted_length =
1075       framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
1076                             buffer, kMaxOutgoingPacketSize);
1077   EXPECT_NE(0u, encrypted_length);
1078   return new QuicEncryptedPacket(buffer, encrypted_length, true);
1079 }
1080 
DefaultQuicConfig()1081 QuicConfig DefaultQuicConfig() {
1082   QuicConfig config;
1083   config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(
1084       kInitialStreamFlowControlWindowForTest);
1085   config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(
1086       kInitialStreamFlowControlWindowForTest);
1087   config.SetInitialMaxStreamDataBytesUnidirectionalToSend(
1088       kInitialStreamFlowControlWindowForTest);
1089   config.SetInitialStreamFlowControlWindowToSend(
1090       kInitialStreamFlowControlWindowForTest);
1091   config.SetInitialSessionFlowControlWindowToSend(
1092       kInitialSessionFlowControlWindowForTest);
1093   QuicConfigPeer::SetReceivedMaxBidirectionalStreams(
1094       &config, kDefaultMaxStreamsPerConnection);
1095   // Default enable NSTP.
1096   // This is unnecessary for versions > 44
1097   if (!config.HasClientSentConnectionOption(quic::kNSTP,
1098                                             quic::Perspective::IS_CLIENT)) {
1099     quic::QuicTagVector connection_options;
1100     connection_options.push_back(quic::kNSTP);
1101     config.SetConnectionOptionsToSend(connection_options);
1102   }
1103   return config;
1104 }
1105 
SupportedVersions(ParsedQuicVersion version)1106 ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) {
1107   ParsedQuicVersionVector versions;
1108   versions.push_back(version);
1109   return versions;
1110 }
1111 
MockQuicConnectionDebugVisitor()1112 MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {}
1113 
~MockQuicConnectionDebugVisitor()1114 MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {}
1115 
MockReceivedPacketManager(QuicConnectionStats * stats)1116 MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats)
1117     : QuicReceivedPacketManager(stats) {}
1118 
~MockReceivedPacketManager()1119 MockReceivedPacketManager::~MockReceivedPacketManager() {}
1120 
MockPacketCreatorDelegate()1121 MockPacketCreatorDelegate::MockPacketCreatorDelegate() {}
~MockPacketCreatorDelegate()1122 MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {}
1123 
MockSessionNotifier()1124 MockSessionNotifier::MockSessionNotifier() {}
~MockSessionNotifier()1125 MockSessionNotifier::~MockSessionNotifier() {}
1126 
1127 // static
1128 QuicCryptoClientStream::HandshakerInterface*
GetHandshaker(QuicCryptoClientStream * stream)1129 QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) {
1130   return stream->handshaker_.get();
1131 }
1132 
CreateClientSessionForTest(QuicServerId server_id,QuicTime::Delta connection_start_time,const ParsedQuicVersionVector & supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoClientConfig * crypto_client_config,PacketSavingConnection ** client_connection,TestQuicSpdyClientSession ** client_session)1133 void CreateClientSessionForTest(
1134     QuicServerId server_id, QuicTime::Delta connection_start_time,
1135     const ParsedQuicVersionVector& supported_versions,
1136     MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1137     QuicCryptoClientConfig* crypto_client_config,
1138     PacketSavingConnection** client_connection,
1139     TestQuicSpdyClientSession** client_session) {
1140   QUICHE_CHECK(crypto_client_config);
1141   QUICHE_CHECK(client_connection);
1142   QUICHE_CHECK(client_session);
1143   QUICHE_CHECK(!connection_start_time.IsZero())
1144       << "Connections must start at non-zero times, otherwise the "
1145       << "strike-register will be unhappy.";
1146 
1147   QuicConfig config = DefaultQuicConfig();
1148   *client_connection = new PacketSavingConnection(
1149       helper, alarm_factory, Perspective::IS_CLIENT, supported_versions);
1150   *client_session = new TestQuicSpdyClientSession(*client_connection, config,
1151                                                   supported_versions, server_id,
1152                                                   crypto_client_config);
1153   (*client_connection)->AdvanceTime(connection_start_time);
1154 }
1155 
CreateServerSessionForTest(QuicServerId,QuicTime::Delta connection_start_time,ParsedQuicVersionVector supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoServerConfig * server_crypto_config,QuicCompressedCertsCache * compressed_certs_cache,PacketSavingConnection ** server_connection,TestQuicSpdyServerSession ** server_session)1156 void CreateServerSessionForTest(
1157     QuicServerId /*server_id*/, QuicTime::Delta connection_start_time,
1158     ParsedQuicVersionVector supported_versions,
1159     MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1160     QuicCryptoServerConfig* server_crypto_config,
1161     QuicCompressedCertsCache* compressed_certs_cache,
1162     PacketSavingConnection** server_connection,
1163     TestQuicSpdyServerSession** server_session) {
1164   QUICHE_CHECK(server_crypto_config);
1165   QUICHE_CHECK(server_connection);
1166   QUICHE_CHECK(server_session);
1167   QUICHE_CHECK(!connection_start_time.IsZero())
1168       << "Connections must start at non-zero times, otherwise the "
1169       << "strike-register will be unhappy.";
1170 
1171   *server_connection =
1172       new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER,
1173                                  ParsedVersionOfIndex(supported_versions, 0));
1174   *server_session = new TestQuicSpdyServerSession(
1175       *server_connection, DefaultQuicConfig(), supported_versions,
1176       server_crypto_config, compressed_certs_cache);
1177   (*server_session)->Initialize();
1178 
1179   // We advance the clock initially because the default time is zero and the
1180   // strike register worries that we've just overflowed a uint32_t time.
1181   (*server_connection)->AdvanceTime(connection_start_time);
1182 }
1183 
GetNthClientInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1184 QuicStreamId GetNthClientInitiatedBidirectionalStreamId(
1185     QuicTransportVersion version, int n) {
1186   int num = n;
1187   if (!VersionUsesHttp3(version)) {
1188     num++;
1189   }
1190   return QuicUtils::GetFirstBidirectionalStreamId(version,
1191                                                   Perspective::IS_CLIENT) +
1192          QuicUtils::StreamIdDelta(version) * num;
1193 }
1194 
GetNthServerInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1195 QuicStreamId GetNthServerInitiatedBidirectionalStreamId(
1196     QuicTransportVersion version, int n) {
1197   return QuicUtils::GetFirstBidirectionalStreamId(version,
1198                                                   Perspective::IS_SERVER) +
1199          QuicUtils::StreamIdDelta(version) * n;
1200 }
1201 
GetNthServerInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1202 QuicStreamId GetNthServerInitiatedUnidirectionalStreamId(
1203     QuicTransportVersion version, int n) {
1204   return QuicUtils::GetFirstUnidirectionalStreamId(version,
1205                                                    Perspective::IS_SERVER) +
1206          QuicUtils::StreamIdDelta(version) * n;
1207 }
1208 
GetNthClientInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1209 QuicStreamId GetNthClientInitiatedUnidirectionalStreamId(
1210     QuicTransportVersion version, int n) {
1211   return QuicUtils::GetFirstUnidirectionalStreamId(version,
1212                                                    Perspective::IS_CLIENT) +
1213          QuicUtils::StreamIdDelta(version) * n;
1214 }
1215 
DetermineStreamType(QuicStreamId id,ParsedQuicVersion version,Perspective perspective,bool is_incoming,StreamType default_type)1216 StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version,
1217                                Perspective perspective, bool is_incoming,
1218                                StreamType default_type) {
1219   return version.HasIetfQuicFrames()
1220              ? QuicUtils::GetStreamType(id, perspective, is_incoming, version)
1221              : default_type;
1222 }
1223 
MemSliceFromString(absl::string_view data)1224 quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) {
1225   if (data.empty()) {
1226     return quiche::QuicheMemSlice();
1227   }
1228 
1229   static quiche::SimpleBufferAllocator* allocator =
1230       new quiche::SimpleBufferAllocator();
1231   return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data));
1232 }
1233 
EncryptPacket(uint64_t,absl::string_view,absl::string_view plaintext,char * output,size_t * output_length,size_t max_output_length)1234 bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/,
1235                                      absl::string_view /*associated_data*/,
1236                                      absl::string_view plaintext, char* output,
1237                                      size_t* output_length,
1238                                      size_t max_output_length) {
1239   const size_t len = plaintext.size() + kTagSize;
1240   if (max_output_length < len) {
1241     return false;
1242   }
1243   // Memmove is safe for inplace encryption.
1244   memmove(output, plaintext.data(), plaintext.size());
1245   output += plaintext.size();
1246   memset(output, tag_, kTagSize);
1247   *output_length = len;
1248   return true;
1249 }
1250 
DecryptPacket(uint64_t,absl::string_view,absl::string_view ciphertext,char * output,size_t * output_length,size_t)1251 bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/,
1252                                      absl::string_view /*associated_data*/,
1253                                      absl::string_view ciphertext, char* output,
1254                                      size_t* output_length,
1255                                      size_t /*max_output_length*/) {
1256   if (ciphertext.size() < kTagSize) {
1257     return false;
1258   }
1259   if (!CheckTag(ciphertext, GetTag(ciphertext))) {
1260     return false;
1261   }
1262   *output_length = ciphertext.size() - kTagSize;
1263   memcpy(output, ciphertext.data(), *output_length);
1264   return true;
1265 }
1266 
CheckTag(absl::string_view ciphertext,uint8_t tag)1267 bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) {
1268   for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
1269     if (ciphertext.data()[i] != tag) {
1270       return false;
1271     }
1272   }
1273 
1274   return true;
1275 }
1276 
TestPacketWriter(ParsedQuicVersion version,MockClock * clock,Perspective perspective)1277 TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock,
1278                                    Perspective perspective)
1279     : version_(version),
1280       framer_(SupportedVersions(version_),
1281               QuicUtils::InvertPerspective(perspective)),
1282       clock_(clock) {
1283   QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
1284                                                       TestConnectionId());
1285   framer_.framer()->SetInitialObfuscators(TestConnectionId());
1286 
1287   for (int i = 0; i < 128; ++i) {
1288     PacketBuffer* p = new PacketBuffer();
1289     packet_buffer_pool_.push_back(p);
1290     packet_buffer_pool_index_[p->buffer] = p;
1291     packet_buffer_free_list_.push_back(p);
1292   }
1293 }
1294 
~TestPacketWriter()1295 TestPacketWriter::~TestPacketWriter() {
1296   EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
1297       << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
1298       << " out of " << packet_buffer_pool_.size()
1299       << " packet buffers have been leaked.";
1300   for (auto p : packet_buffer_pool_) {
1301     delete p;
1302   }
1303 }
1304 
WritePacket(const char * buffer,size_t buf_len,const QuicIpAddress & self_address,const QuicSocketAddress & peer_address,PerPacketOptions *,const QuicPacketWriterParams & params)1305 WriteResult TestPacketWriter::WritePacket(
1306     const char* buffer, size_t buf_len, const QuicIpAddress& self_address,
1307     const QuicSocketAddress& peer_address, PerPacketOptions* /*options*/,
1308     const QuicPacketWriterParams& params) {
1309   last_write_source_address_ = self_address;
1310   last_write_peer_address_ = peer_address;
1311   // If the buffer is allocated from the pool, return it back to the pool.
1312   // Note the buffer content doesn't change.
1313   if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
1314       packet_buffer_pool_index_.end()) {
1315     FreePacketBuffer(buffer);
1316   }
1317 
1318   QuicEncryptedPacket packet(buffer, buf_len);
1319   ++packets_write_attempts_;
1320 
1321   if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
1322     final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
1323     memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
1324            sizeof(final_bytes_of_last_packet_));
1325   }
1326   if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
1327     framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
1328                                        std::make_unique<TaggingDecrypter>());
1329     framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
1330                                        std::make_unique<TaggingDecrypter>());
1331     framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
1332                                        std::make_unique<TaggingDecrypter>());
1333   } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel(
1334                  ENCRYPTION_FORWARD_SECURE) &&
1335              !framer_.framer()->HasDecrypterOfEncryptionLevel(
1336                  ENCRYPTION_ZERO_RTT)) {
1337     framer_.framer()->SetAlternativeDecrypter(
1338         ENCRYPTION_FORWARD_SECURE,
1339         std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE),
1340         false);
1341   }
1342   EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet))
1343       << framer_.framer()->detailed_error() << " perspective "
1344       << framer_.framer()->perspective();
1345   next_packet_processable_ = true;
1346   if (block_on_next_write_) {
1347     write_blocked_ = true;
1348     block_on_next_write_ = false;
1349   }
1350   if (next_packet_too_large_) {
1351     next_packet_too_large_ = false;
1352     return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1353   }
1354   if (always_get_packet_too_large_) {
1355     return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1356   }
1357   if (IsWriteBlocked()) {
1358     return WriteResult(is_write_blocked_data_buffered_
1359                            ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
1360                            : WRITE_STATUS_BLOCKED,
1361                        0);
1362   }
1363 
1364   if (ShouldWriteFail()) {
1365     return WriteResult(WRITE_STATUS_ERROR, write_error_code_);
1366   }
1367 
1368   last_packet_size_ = packet.length();
1369   total_bytes_written_ += packet.length();
1370   last_packet_header_ = framer_.header();
1371   if (!framer_.connection_close_frames().empty()) {
1372     ++connection_close_packets_;
1373   }
1374   if (!write_pause_time_delta_.IsZero()) {
1375     clock_->AdvanceTime(write_pause_time_delta_);
1376   }
1377   if (is_batch_mode_) {
1378     bytes_buffered_ += last_packet_size_;
1379     return WriteResult(WRITE_STATUS_OK, 0);
1380   }
1381   last_ecn_sent_ = params.ecn_codepoint;
1382   return WriteResult(WRITE_STATUS_OK, last_packet_size_);
1383 }
1384 
GetNextWriteLocation(const QuicIpAddress &,const QuicSocketAddress &)1385 QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
1386     const QuicIpAddress& /*self_address*/,
1387     const QuicSocketAddress& /*peer_address*/) {
1388   return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
1389 }
1390 
Flush()1391 WriteResult TestPacketWriter::Flush() {
1392   flush_attempts_++;
1393   if (block_on_next_flush_) {
1394     block_on_next_flush_ = false;
1395     SetWriteBlocked();
1396     return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
1397   }
1398   if (write_should_fail_) {
1399     return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
1400   }
1401   int bytes_flushed = bytes_buffered_;
1402   bytes_buffered_ = 0;
1403   return WriteResult(WRITE_STATUS_OK, bytes_flushed);
1404 }
1405 
AllocPacketBuffer()1406 char* TestPacketWriter::AllocPacketBuffer() {
1407   PacketBuffer* p = packet_buffer_free_list_.front();
1408   EXPECT_FALSE(p->in_use);
1409   p->in_use = true;
1410   packet_buffer_free_list_.pop_front();
1411   return p->buffer;
1412 }
1413 
FreePacketBuffer(const char * buffer)1414 void TestPacketWriter::FreePacketBuffer(const char* buffer) {
1415   auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
1416   ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
1417   PacketBuffer* p = iter->second;
1418   ASSERT_TRUE(p->in_use);
1419   p->in_use = false;
1420   packet_buffer_free_list_.push_back(p);
1421 }
1422 
WriteServerVersionNegotiationProbeResponse(char * packet_bytes,size_t * packet_length_out,const char * source_connection_id_bytes,uint8_t source_connection_id_length)1423 bool WriteServerVersionNegotiationProbeResponse(
1424     char* packet_bytes, size_t* packet_length_out,
1425     const char* source_connection_id_bytes,
1426     uint8_t source_connection_id_length) {
1427   if (packet_bytes == nullptr) {
1428     QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes";
1429     return false;
1430   }
1431   if (packet_length_out == nullptr) {
1432     QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out";
1433     return false;
1434   }
1435   QuicConnectionId source_connection_id(source_connection_id_bytes,
1436                                         source_connection_id_length);
1437   std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
1438       QuicFramer::BuildVersionNegotiationPacket(
1439           source_connection_id, EmptyQuicConnectionId(),
1440           /*ietf_quic=*/true, /*use_length_prefix=*/true,
1441           ParsedQuicVersionVector{});
1442   if (!encrypted_packet) {
1443     QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet";
1444     return false;
1445   }
1446   if (*packet_length_out < encrypted_packet->length()) {
1447     QUIC_BUG(quic_bug_10256_4)
1448         << "Invalid *packet_length_out " << *packet_length_out << " < "
1449         << encrypted_packet->length();
1450     return false;
1451   }
1452   *packet_length_out = encrypted_packet->length();
1453   memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
1454   return true;
1455 }
1456 
ParseClientVersionNegotiationProbePacket(const char * packet_bytes,size_t packet_length,char * destination_connection_id_bytes,uint8_t * destination_connection_id_length_out)1457 bool ParseClientVersionNegotiationProbePacket(
1458     const char* packet_bytes, size_t packet_length,
1459     char* destination_connection_id_bytes,
1460     uint8_t* destination_connection_id_length_out) {
1461   if (packet_bytes == nullptr) {
1462     QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes";
1463     return false;
1464   }
1465   if (packet_length < kMinPacketSizeForVersionNegotiation ||
1466       packet_length > 65535) {
1467     QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length";
1468     return false;
1469   }
1470   if (destination_connection_id_bytes == nullptr) {
1471     QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes";
1472     return false;
1473   }
1474   if (destination_connection_id_length_out == nullptr) {
1475     QUIC_BUG(quic_bug_10256_8)
1476         << "Invalid destination_connection_id_length_out";
1477     return false;
1478   }
1479 
1480   QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
1481   PacketHeaderFormat format;
1482   QuicLongHeaderType long_packet_type;
1483   bool version_present, has_length_prefix;
1484   QuicVersionLabel version_label;
1485   ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
1486   QuicConnectionId destination_connection_id, source_connection_id;
1487   std::optional<absl::string_view> retry_token;
1488   std::string detailed_error;
1489   QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
1490       encrypted_packet,
1491       /*expected_destination_connection_id_length=*/0, &format,
1492       &long_packet_type, &version_present, &has_length_prefix, &version_label,
1493       &parsed_version, &destination_connection_id, &source_connection_id,
1494       &retry_token, &detailed_error);
1495   if (error != QUIC_NO_ERROR) {
1496     QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error;
1497     return false;
1498   }
1499   if (!version_present) {
1500     QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header";
1501     return false;
1502   }
1503   if (*destination_connection_id_length_out <
1504       destination_connection_id.length()) {
1505     QUIC_BUG(quic_bug_10256_11)
1506         << "destination_connection_id_length_out too small";
1507     return false;
1508   }
1509   *destination_connection_id_length_out = destination_connection_id.length();
1510   memcpy(destination_connection_id_bytes, destination_connection_id.data(),
1511          *destination_connection_id_length_out);
1512   return true;
1513 }
1514 
1515 }  // namespace test
1516 }  // namespace quic
1517