• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/test_tools/quic_test_utils.h"
6 
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <utility>
11 
12 #include "absl/base/macros.h"
13 #include "absl/strings/string_view.h"
14 #include "openssl/chacha.h"
15 #include "openssl/sha.h"
16 #include "quiche/quic/core/crypto/crypto_framer.h"
17 #include "quiche/quic/core/crypto/crypto_handshake.h"
18 #include "quiche/quic/core/crypto/crypto_utils.h"
19 #include "quiche/quic/core/crypto/null_decrypter.h"
20 #include "quiche/quic/core/crypto/null_encrypter.h"
21 #include "quiche/quic/core/crypto/quic_decrypter.h"
22 #include "quiche/quic/core/crypto/quic_encrypter.h"
23 #include "quiche/quic/core/http/quic_spdy_client_session.h"
24 #include "quiche/quic/core/quic_config.h"
25 #include "quiche/quic/core/quic_data_writer.h"
26 #include "quiche/quic/core/quic_framer.h"
27 #include "quiche/quic/core/quic_packet_creator.h"
28 #include "quiche/quic/core/quic_types.h"
29 #include "quiche/quic/core/quic_utils.h"
30 #include "quiche/quic/core/quic_versions.h"
31 #include "quiche/quic/platform/api/quic_flags.h"
32 #include "quiche/quic/platform/api/quic_logging.h"
33 #include "quiche/quic/test_tools/crypto_test_utils.h"
34 #include "quiche/quic/test_tools/quic_config_peer.h"
35 #include "quiche/quic/test_tools/quic_connection_peer.h"
36 #include "quiche/common/quiche_buffer_allocator.h"
37 #include "quiche/common/quiche_endian.h"
38 #include "quiche/common/simple_buffer_allocator.h"
39 #include "quiche/spdy/core/spdy_frame_builder.h"
40 
41 using testing::_;
42 using testing::Invoke;
43 
44 namespace quic {
45 namespace test {
46 
TestConnectionId()47 QuicConnectionId TestConnectionId() {
48   // Chosen by fair dice roll.
49   // Guaranteed to be random.
50   return TestConnectionId(42);
51 }
52 
TestConnectionId(uint64_t connection_number)53 QuicConnectionId TestConnectionId(uint64_t connection_number) {
54   const uint64_t connection_id64_net =
55       quiche::QuicheEndian::HostToNet64(connection_number);
56   return QuicConnectionId(reinterpret_cast<const char*>(&connection_id64_net),
57                           sizeof(connection_id64_net));
58 }
59 
TestConnectionIdNineBytesLong(uint64_t connection_number)60 QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) {
61   const uint64_t connection_number_net =
62       quiche::QuicheEndian::HostToNet64(connection_number);
63   char connection_id_bytes[9] = {};
64   static_assert(
65       sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net),
66       "bad lengths");
67   memcpy(connection_id_bytes + 1, &connection_number_net,
68          sizeof(connection_number_net));
69   return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
70 }
71 
TestConnectionIdToUInt64(QuicConnectionId connection_id)72 uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) {
73   QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength);
74   uint64_t connection_id64_net = 0;
75   memcpy(&connection_id64_net, connection_id.data(),
76          std::min<size_t>(static_cast<size_t>(connection_id.length()),
77                           sizeof(connection_id64_net)));
78   return quiche::QuicheEndian::NetToHost64(connection_id64_net);
79 }
80 
// Returns a fixed 16-byte stateless reset token for tests: the bytes
// 0x90 through 0x9F in ascending order.
std::vector<uint8_t> CreateStatelessResetTokenForTest() {
  static constexpr uint8_t kTokenBytes[16] = {
      0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
      0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F};
  std::vector<uint8_t> token;
  // sizeof equals the element count because each element is one byte.
  token.assign(kTokenBytes, kTokenBytes + sizeof(kTokenBytes));
  return token;
}
89 
// Returns the host name used for the test server.
std::string TestHostname() {
  return std::string("test.example.com");
}
91 
TestServerId()92 QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); }
93 
InitAckFrame(const std::vector<QuicAckBlock> & ack_blocks)94 QuicAckFrame InitAckFrame(const std::vector<QuicAckBlock>& ack_blocks) {
95   QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
96 
97   QuicAckFrame ack;
98   QuicPacketNumber end_of_previous_block(1);
99   for (const QuicAckBlock& block : ack_blocks) {
100     QUICHE_DCHECK_GE(block.start, end_of_previous_block);
101     QUICHE_DCHECK_GT(block.limit, block.start);
102     ack.packets.AddRange(block.start, block.limit);
103     end_of_previous_block = block.limit;
104   }
105 
106   ack.largest_acked = ack.packets.Max();
107 
108   return ack;
109 }
110 
InitAckFrame(uint64_t largest_acked)111 QuicAckFrame InitAckFrame(uint64_t largest_acked) {
112   return InitAckFrame(QuicPacketNumber(largest_acked));
113 }
114 
InitAckFrame(QuicPacketNumber largest_acked)115 QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) {
116   return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}});
117 }
118 
MakeAckFrameWithAckBlocks(size_t num_ack_blocks,uint64_t least_unacked)119 QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks,
120                                        uint64_t least_unacked) {
121   QuicAckFrame ack;
122   ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked);
123   // Add enough received packets to get num_ack_blocks ack blocks.
124   for (QuicPacketNumber i = QuicPacketNumber(2);
125        i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) {
126     ack.packets.Add(i + least_unacked);
127   }
128   return ack;
129 }
130 
MakeAckFrameWithGaps(uint64_t gap_size,size_t max_num_gaps,uint64_t largest_acked)131 QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps,
132                                   uint64_t largest_acked) {
133   QuicAckFrame ack;
134   ack.largest_acked = QuicPacketNumber(largest_acked);
135   ack.packets.Add(QuicPacketNumber(largest_acked));
136   for (size_t i = 0; i < max_num_gaps; ++i) {
137     if (largest_acked <= gap_size) {
138       break;
139     }
140     largest_acked -= gap_size;
141     ack.packets.Add(QuicPacketNumber(largest_acked));
142   }
143   return ack;
144 }
145 
HeaderToEncryptionLevel(const QuicPacketHeader & header)146 EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) {
147   if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
148     return ENCRYPTION_FORWARD_SECURE;
149   } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) {
150     if (header.long_packet_type == HANDSHAKE) {
151       return ENCRYPTION_HANDSHAKE;
152     } else if (header.long_packet_type == ZERO_RTT_PROTECTED) {
153       return ENCRYPTION_ZERO_RTT;
154     }
155   }
156   return ENCRYPTION_INITIAL;
157 }
158 
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames)159 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
160     QuicFramer* framer, const QuicPacketHeader& header,
161     const QuicFrames& frames) {
162   const size_t max_plaintext_size =
163       framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize);
164   size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header);
165   for (size_t i = 0; i < frames.size(); ++i) {
166     QUICHE_DCHECK_LE(packet_size, max_plaintext_size);
167     bool first_frame = i == 0;
168     bool last_frame = i == frames.size() - 1;
169     const size_t frame_size = framer->GetSerializedFrameLength(
170         frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
171         header.packet_number_length);
172     QUICHE_DCHECK(frame_size);
173     packet_size += frame_size;
174   }
175   return BuildUnsizedDataPacket(framer, header, frames, packet_size);
176 }
177 
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames,size_t packet_size)178 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
179     QuicFramer* framer, const QuicPacketHeader& header,
180     const QuicFrames& frames, size_t packet_size) {
181   char* buffer = new char[packet_size];
182   EncryptionLevel level = HeaderToEncryptionLevel(header);
183   size_t length =
184       framer->BuildDataPacket(header, frames, buffer, packet_size, level);
185 
186   if (length == 0) {
187     delete[] buffer;
188     return nullptr;
189   }
190   // Re-construct the data packet with data ownership.
191   return std::make_unique<QuicPacket>(
192       buffer, length, /* owns_buffer */ true,
193       GetIncludedDestinationConnectionIdLength(header),
194       GetIncludedSourceConnectionIdLength(header), header.version_flag,
195       header.nonce != nullptr, header.packet_number_length,
196       header.retry_token_length_length, header.retry_token.length(),
197       header.length_length);
198 }
199 
Sha1Hash(absl::string_view data)200 std::string Sha1Hash(absl::string_view data) {
201   char buffer[SHA_DIGEST_LENGTH];
202   SHA1(reinterpret_cast<const uint8_t*>(data.data()), data.size(),
203        reinterpret_cast<uint8_t*>(buffer));
204   return std::string(buffer, ABSL_ARRAYSIZE(buffer));
205 }
206 
ClearControlFrame(const QuicFrame & frame)207 bool ClearControlFrame(const QuicFrame& frame) {
208   DeleteFrame(&const_cast<QuicFrame&>(frame));
209   return true;
210 }
211 
ClearControlFrameWithTransmissionType(const QuicFrame & frame,TransmissionType)212 bool ClearControlFrameWithTransmissionType(const QuicFrame& frame,
213                                            TransmissionType /*type*/) {
214   return ClearControlFrame(frame);
215 }
216 
RandUint64()217 uint64_t SimpleRandom::RandUint64() {
218   uint64_t result;
219   RandBytes(&result, sizeof(result));
220   return result;
221 }
222 
RandBytes(void * data,size_t len)223 void SimpleRandom::RandBytes(void* data, size_t len) {
224   uint8_t* data_bytes = reinterpret_cast<uint8_t*>(data);
225   while (len > 0) {
226     const size_t buffer_left = sizeof(buffer_) - buffer_offset_;
227     const size_t to_copy = std::min(buffer_left, len);
228     memcpy(data_bytes, buffer_ + buffer_offset_, to_copy);
229     data_bytes += to_copy;
230     buffer_offset_ += to_copy;
231     len -= to_copy;
232 
233     if (buffer_offset_ == sizeof(buffer_)) {
234       FillBuffer();
235     }
236   }
237 }
238 
InsecureRandBytes(void * data,size_t len)239 void SimpleRandom::InsecureRandBytes(void* data, size_t len) {
240   RandBytes(data, len);
241 }
242 
InsecureRandUint64()243 uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); }
244 
FillBuffer()245 void SimpleRandom::FillBuffer() {
246   uint8_t nonce[12];
247   memcpy(nonce, buffer_, sizeof(nonce));
248   CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0);
249   buffer_offset_ = 0;
250 }
251 
set_seed(uint64_t seed)252 void SimpleRandom::set_seed(uint64_t seed) {
253   static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits");
254   SHA256(reinterpret_cast<const uint8_t*>(&seed), sizeof(seed), key_);
255 
256   memset(buffer_, 0, sizeof(buffer_));
257   FillBuffer();
258 }
259 
MockFramerVisitor()260 MockFramerVisitor::MockFramerVisitor() {
261   // By default, we want to accept packets.
262   ON_CALL(*this, OnProtocolVersionMismatch(_))
263       .WillByDefault(testing::Return(false));
264 
265   // By default, we want to accept packets.
266   ON_CALL(*this, OnUnauthenticatedHeader(_))
267       .WillByDefault(testing::Return(true));
268 
269   ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
270       .WillByDefault(testing::Return(true));
271 
272   ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true));
273 
274   ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true));
275 
276   ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true));
277 
278   ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true));
279 
280   ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true));
281 
282   ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true));
283 
284   ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true));
285 
286   ON_CALL(*this, OnConnectionCloseFrame(_))
287       .WillByDefault(testing::Return(true));
288 
289   ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true));
290 
291   ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true));
292 
293   ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true));
294 
295   ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true));
296   ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true));
297   ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true));
298 }
299 
~MockFramerVisitor()300 MockFramerVisitor::~MockFramerVisitor() {}
301 
OnProtocolVersionMismatch(ParsedQuicVersion)302 bool NoOpFramerVisitor::OnProtocolVersionMismatch(
303     ParsedQuicVersion /*version*/) {
304   return false;
305 }
306 
OnUnauthenticatedPublicHeader(const QuicPacketHeader &)307 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
308     const QuicPacketHeader& /*header*/) {
309   return true;
310 }
311 
OnUnauthenticatedHeader(const QuicPacketHeader &)312 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
313     const QuicPacketHeader& /*header*/) {
314   return true;
315 }
316 
OnPacketHeader(const QuicPacketHeader &)317 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) {
318   return true;
319 }
320 
OnCoalescedPacket(const QuicEncryptedPacket &)321 void NoOpFramerVisitor::OnCoalescedPacket(
322     const QuicEncryptedPacket& /*packet*/) {}
323 
OnUndecryptablePacket(const QuicEncryptedPacket &,EncryptionLevel,bool)324 void NoOpFramerVisitor::OnUndecryptablePacket(
325     const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/,
326     bool /*has_decryption_key*/) {}
327 
OnStreamFrame(const QuicStreamFrame &)328 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) {
329   return true;
330 }
331 
OnCryptoFrame(const QuicCryptoFrame &)332 bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {
333   return true;
334 }
335 
OnAckFrameStart(QuicPacketNumber,QuicTime::Delta)336 bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
337                                         QuicTime::Delta /*ack_delay_time*/) {
338   return true;
339 }
340 
OnAckRange(QuicPacketNumber,QuicPacketNumber)341 bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/,
342                                    QuicPacketNumber /*end*/) {
343   return true;
344 }
345 
OnAckTimestamp(QuicPacketNumber,QuicTime)346 bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/,
347                                        QuicTime /*timestamp*/) {
348   return true;
349 }
350 
OnAckFrameEnd(QuicPacketNumber,const absl::optional<QuicEcnCounts> &)351 bool NoOpFramerVisitor::OnAckFrameEnd(
352     QuicPacketNumber /*start*/,
353     const absl::optional<QuicEcnCounts>& /*ecn_counts*/) {
354   return true;
355 }
356 
OnStopWaitingFrame(const QuicStopWaitingFrame &)357 bool NoOpFramerVisitor::OnStopWaitingFrame(
358     const QuicStopWaitingFrame& /*frame*/) {
359   return true;
360 }
361 
OnPaddingFrame(const QuicPaddingFrame &)362 bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {
363   return true;
364 }
365 
OnPingFrame(const QuicPingFrame &)366 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) {
367   return true;
368 }
369 
OnRstStreamFrame(const QuicRstStreamFrame &)370 bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {
371   return true;
372 }
373 
OnConnectionCloseFrame(const QuicConnectionCloseFrame &)374 bool NoOpFramerVisitor::OnConnectionCloseFrame(
375     const QuicConnectionCloseFrame& /*frame*/) {
376   return true;
377 }
378 
OnNewConnectionIdFrame(const QuicNewConnectionIdFrame &)379 bool NoOpFramerVisitor::OnNewConnectionIdFrame(
380     const QuicNewConnectionIdFrame& /*frame*/) {
381   return true;
382 }
383 
OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame &)384 bool NoOpFramerVisitor::OnRetireConnectionIdFrame(
385     const QuicRetireConnectionIdFrame& /*frame*/) {
386   return true;
387 }
388 
OnNewTokenFrame(const QuicNewTokenFrame &)389 bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {
390   return true;
391 }
392 
OnStopSendingFrame(const QuicStopSendingFrame &)393 bool NoOpFramerVisitor::OnStopSendingFrame(
394     const QuicStopSendingFrame& /*frame*/) {
395   return true;
396 }
397 
OnPathChallengeFrame(const QuicPathChallengeFrame &)398 bool NoOpFramerVisitor::OnPathChallengeFrame(
399     const QuicPathChallengeFrame& /*frame*/) {
400   return true;
401 }
402 
OnPathResponseFrame(const QuicPathResponseFrame &)403 bool NoOpFramerVisitor::OnPathResponseFrame(
404     const QuicPathResponseFrame& /*frame*/) {
405   return true;
406 }
407 
OnGoAwayFrame(const QuicGoAwayFrame &)408 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {
409   return true;
410 }
411 
OnMaxStreamsFrame(const QuicMaxStreamsFrame &)412 bool NoOpFramerVisitor::OnMaxStreamsFrame(
413     const QuicMaxStreamsFrame& /*frame*/) {
414   return true;
415 }
416 
OnStreamsBlockedFrame(const QuicStreamsBlockedFrame &)417 bool NoOpFramerVisitor::OnStreamsBlockedFrame(
418     const QuicStreamsBlockedFrame& /*frame*/) {
419   return true;
420 }
421 
OnWindowUpdateFrame(const QuicWindowUpdateFrame &)422 bool NoOpFramerVisitor::OnWindowUpdateFrame(
423     const QuicWindowUpdateFrame& /*frame*/) {
424   return true;
425 }
426 
OnBlockedFrame(const QuicBlockedFrame &)427 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {
428   return true;
429 }
430 
OnMessageFrame(const QuicMessageFrame &)431 bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) {
432   return true;
433 }
434 
OnHandshakeDoneFrame(const QuicHandshakeDoneFrame &)435 bool NoOpFramerVisitor::OnHandshakeDoneFrame(
436     const QuicHandshakeDoneFrame& /*frame*/) {
437   return true;
438 }
439 
OnAckFrequencyFrame(const QuicAckFrequencyFrame &)440 bool NoOpFramerVisitor::OnAckFrequencyFrame(
441     const QuicAckFrequencyFrame& /*frame*/) {
442   return true;
443 }
444 
IsValidStatelessResetToken(const StatelessResetToken &) const445 bool NoOpFramerVisitor::IsValidStatelessResetToken(
446     const StatelessResetToken& /*token*/) const {
447   return false;
448 }
449 
MockQuicConnectionVisitor()450 MockQuicConnectionVisitor::MockQuicConnectionVisitor() {}
451 
~MockQuicConnectionVisitor()452 MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {}
453 
MockQuicConnectionHelper()454 MockQuicConnectionHelper::MockQuicConnectionHelper() {}
455 
~MockQuicConnectionHelper()456 MockQuicConnectionHelper::~MockQuicConnectionHelper() {}
457 
GetClock() const458 const QuicClock* MockQuicConnectionHelper::GetClock() const { return &clock_; }
459 
GetClock()460 QuicClock* MockQuicConnectionHelper::GetClock() { return &clock_; }
461 
GetRandomGenerator()462 QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() {
463   return &random_generator_;
464 }
465 
CreateAlarm(QuicAlarm::Delegate * delegate)466 QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) {
467   return new MockAlarmFactory::TestAlarm(
468       QuicArenaScopedPtr<QuicAlarm::Delegate>(delegate));
469 }
470 
CreateAlarm(QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,QuicConnectionArena * arena)471 QuicArenaScopedPtr<QuicAlarm> MockAlarmFactory::CreateAlarm(
472     QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,
473     QuicConnectionArena* arena) {
474   if (arena != nullptr) {
475     return arena->New<TestAlarm>(std::move(delegate));
476   } else {
477     return QuicArenaScopedPtr<TestAlarm>(new TestAlarm(std::move(delegate)));
478   }
479 }
480 
481 quiche::QuicheBufferAllocator*
GetStreamSendBufferAllocator()482 MockQuicConnectionHelper::GetStreamSendBufferAllocator() {
483   return &buffer_allocator_;
484 }
485 
AdvanceTime(QuicTime::Delta delta)486 void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) {
487   clock_.AdvanceTime(delta);
488 }
489 
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)490 MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper,
491                                        QuicAlarmFactory* alarm_factory,
492                                        Perspective perspective)
493     : MockQuicConnection(TestConnectionId(),
494                          QuicSocketAddress(TestPeerIPAddress(), kTestPort),
495                          helper, alarm_factory, perspective,
496                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
497 
MockQuicConnection(QuicSocketAddress address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)498 MockQuicConnection::MockQuicConnection(QuicSocketAddress address,
499                                        QuicConnectionHelperInterface* helper,
500                                        QuicAlarmFactory* alarm_factory,
501                                        Perspective perspective)
502     : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory,
503                          perspective,
504                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
505 
MockQuicConnection(QuicConnectionId connection_id,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)506 MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id,
507                                        QuicConnectionHelperInterface* helper,
508                                        QuicAlarmFactory* alarm_factory,
509                                        Perspective perspective)
510     : MockQuicConnection(connection_id,
511                          QuicSocketAddress(TestPeerIPAddress(), kTestPort),
512                          helper, alarm_factory, perspective,
513                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
514 
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)515 MockQuicConnection::MockQuicConnection(
516     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
517     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
518     : MockQuicConnection(
519           TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort),
520           helper, alarm_factory, perspective, supported_versions) {}
521 
MockQuicConnection(QuicConnectionId connection_id,QuicSocketAddress initial_peer_address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)522 MockQuicConnection::MockQuicConnection(
523     QuicConnectionId connection_id, QuicSocketAddress initial_peer_address,
524     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
525     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
526     : QuicConnection(
527           connection_id,
528           /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5),
529           initial_peer_address, helper, alarm_factory,
530           new testing::NiceMock<MockPacketWriter>(),
531           /* owns_writer= */ true, perspective, supported_versions,
532           connection_id_generator_) {
533   ON_CALL(*this, OnError(_))
534       .WillByDefault(
535           Invoke(this, &PacketSavingConnection::QuicConnection_OnError));
536   ON_CALL(*this, SendCryptoData(_, _, _))
537       .WillByDefault(
538           Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData));
539 
540   SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5));
541 }
542 
~MockQuicConnection()543 MockQuicConnection::~MockQuicConnection() {}
544 
AdvanceTime(QuicTime::Delta delta)545 void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) {
546   static_cast<MockQuicConnectionHelper*>(helper())->AdvanceTime(delta);
547 }
548 
OnProtocolVersionMismatch(ParsedQuicVersion)549 bool MockQuicConnection::OnProtocolVersionMismatch(
550     ParsedQuicVersion /*version*/) {
551   return false;
552 }
553 
PacketSavingConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)554 PacketSavingConnection::PacketSavingConnection(
555     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
556     Perspective perspective)
557     : MockQuicConnection(helper, alarm_factory, perspective) {}
558 
PacketSavingConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)559 PacketSavingConnection::PacketSavingConnection(
560     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
561     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
562     : MockQuicConnection(helper, alarm_factory, perspective,
563                          supported_versions) {}
564 
~PacketSavingConnection()565 PacketSavingConnection::~PacketSavingConnection() {}
566 
GetSerializedPacketFate(bool,EncryptionLevel)567 SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate(
568     bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) {
569   return SEND_TO_WRITER;
570 }
571 
SendOrQueuePacket(SerializedPacket packet)572 void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) {
573   encrypted_packets_.push_back(std::make_unique<QuicEncryptedPacket>(
574       CopyBuffer(packet), packet.encrypted_length, true));
575   clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
576   // Transfer ownership of the packet to the SentPacketManager and the
577   // ack notifier to the AckNotifierManager.
578   OnPacketSent(packet.encryption_level, packet.transmission_type);
579   QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent(
580       &packet, clock_.ApproximateNow(), NOT_RETRANSMISSION,
581       HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT);
582 }
583 
MockQuicSession(QuicConnection * connection)584 MockQuicSession::MockQuicSession(QuicConnection* connection)
585     : MockQuicSession(connection, true) {}
586 
MockQuicSession(QuicConnection * connection,bool create_mock_crypto_stream)587 MockQuicSession::MockQuicSession(QuicConnection* connection,
588                                  bool create_mock_crypto_stream)
589     : QuicSession(connection, nullptr, DefaultQuicConfig(),
590                   connection->supported_versions(),
591                   /*num_expected_unidirectional_static_streams = */ 0) {
592   if (create_mock_crypto_stream) {
593     crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
594   }
595   ON_CALL(*this, WritevData(_, _, _, _, _, _))
596       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
597 }
598 
~MockQuicSession()599 MockQuicSession::~MockQuicSession() { DeleteConnection(); }
600 
GetMutableCryptoStream()601 QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() {
602   return crypto_stream_.get();
603 }
604 
GetCryptoStream() const605 const QuicCryptoStream* MockQuicSession::GetCryptoStream() const {
606   return crypto_stream_.get();
607 }
608 
SetCryptoStream(QuicCryptoStream * crypto_stream)609 void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
610   crypto_stream_.reset(crypto_stream);
611 }
612 
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,absl::optional<EncryptionLevel>)613 QuicConsumedData MockQuicSession::ConsumeData(
614     QuicStreamId id, size_t write_length, QuicStreamOffset offset,
615     StreamSendingState state, TransmissionType /*type*/,
616     absl::optional<EncryptionLevel> /*level*/) {
617   if (write_length > 0) {
618     auto buf = std::make_unique<char[]>(write_length);
619     QuicStream* stream = GetOrCreateStream(id);
620     QUICHE_DCHECK(stream);
621     QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
622     stream->WriteStreamData(offset, write_length, &writer);
623   } else {
624     QUICHE_DCHECK(state != NO_FIN);
625   }
626   return QuicConsumedData(write_length, state != NO_FIN);
627 }
628 
MockQuicCryptoStream(QuicSession * session)629 MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session)
630     : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {}
631 
~MockQuicCryptoStream()632 MockQuicCryptoStream::~MockQuicCryptoStream() {}
633 
EarlyDataReason() const634 ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const {
635   return ssl_early_data_unknown;
636 }
637 
encryption_established() const638 bool MockQuicCryptoStream::encryption_established() const { return false; }
639 
one_rtt_keys_available() const640 bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; }
641 
642 const QuicCryptoNegotiatedParameters&
crypto_negotiated_params() const643 MockQuicCryptoStream::crypto_negotiated_params() const {
644   return *params_;
645 }
646 
crypto_message_parser()647 CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() {
648   return &crypto_framer_;
649 }
650 
MockQuicSpdySession(QuicConnection * connection)651 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection)
652     : MockQuicSpdySession(connection, true) {}
653 
MockQuicSpdySession(QuicConnection * connection,bool create_mock_crypto_stream)654 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection,
655                                          bool create_mock_crypto_stream)
656     : QuicSpdySession(connection, nullptr, DefaultQuicConfig(),
657                       connection->supported_versions()) {
658   if (create_mock_crypto_stream) {
659     crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
660   }
661 
662   ON_CALL(*this, WritevData(_, _, _, _, _, _))
663       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
664 
665   ON_CALL(*this, SendWindowUpdate(_, _))
666       .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
667         return QuicSpdySession::SendWindowUpdate(id, byte_offset);
668       });
669 
670   ON_CALL(*this, SendBlocked(_, _))
671       .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
672         return QuicSpdySession::SendBlocked(id, byte_offset);
673       });
674 
675   ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return());
676 }
677 
~MockQuicSpdySession()678 MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); }
679 
GetMutableCryptoStream()680 QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() {
681   return crypto_stream_.get();
682 }
683 
GetCryptoStream() const684 const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const {
685   return crypto_stream_.get();
686 }
687 
SetCryptoStream(QuicCryptoStream * crypto_stream)688 void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
689   crypto_stream_.reset(crypto_stream);
690 }
691 
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,absl::optional<EncryptionLevel>)692 QuicConsumedData MockQuicSpdySession::ConsumeData(
693     QuicStreamId id, size_t write_length, QuicStreamOffset offset,
694     StreamSendingState state, TransmissionType /*type*/,
695     absl::optional<EncryptionLevel> /*level*/) {
696   if (write_length > 0) {
697     auto buf = std::make_unique<char[]>(write_length);
698     QuicStream* stream = GetOrCreateStream(id);
699     QUICHE_DCHECK(stream);
700     QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
701     stream->WriteStreamData(offset, write_length, &writer);
702   } else {
703     QUICHE_DCHECK(state != NO_FIN);
704   }
705   return QuicConsumedData(write_length, state != NO_FIN);
706 }
707 
TestQuicSpdyServerSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)708 TestQuicSpdyServerSession::TestQuicSpdyServerSession(
709     QuicConnection* connection, const QuicConfig& config,
710     const ParsedQuicVersionVector& supported_versions,
711     const QuicCryptoServerConfig* crypto_config,
712     QuicCompressedCertsCache* compressed_certs_cache)
713     : QuicServerSessionBase(config, supported_versions, connection, &visitor_,
714                             &helper_, crypto_config, compressed_certs_cache) {
715   ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _))
716       .WillByDefault(testing::Return(true));
717 }
718 
~TestQuicSpdyServerSession()719 TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); }
720 
721 std::unique_ptr<QuicCryptoServerStreamBase>
CreateQuicCryptoServerStream(const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)722 TestQuicSpdyServerSession::CreateQuicCryptoServerStream(
723     const QuicCryptoServerConfig* crypto_config,
724     QuicCompressedCertsCache* compressed_certs_cache) {
725   return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this,
726                                   &helper_);
727 }
728 
729 QuicCryptoServerStreamBase*
GetMutableCryptoStream()730 TestQuicSpdyServerSession::GetMutableCryptoStream() {
731   return QuicServerSessionBase::GetMutableCryptoStream();
732 }
733 
GetCryptoStream() const734 const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream()
735     const {
736   return QuicServerSessionBase::GetCryptoStream();
737 }
738 
TestQuicSpdyClientSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicServerId & server_id,QuicCryptoClientConfig * crypto_config,absl::optional<QuicSSLConfig> ssl_config)739 TestQuicSpdyClientSession::TestQuicSpdyClientSession(
740     QuicConnection* connection, const QuicConfig& config,
741     const ParsedQuicVersionVector& supported_versions,
742     const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config,
743     absl::optional<QuicSSLConfig> ssl_config)
744     : QuicSpdyClientSessionBase(connection, nullptr, &push_promise_index_,
745                                 config, supported_versions),
746       ssl_config_(std::move(ssl_config)) {
747   // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
748   // calls in tests and set |has_application_state| to true.
749   crypto_stream_ = std::make_unique<QuicCryptoClientStream>(
750       server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
751       crypto_config, this, /*has_application_state = */ false);
752   Initialize();
753   ON_CALL(*this, OnConfigNegotiated())
754       .WillByDefault(
755           Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
756 }
757 
~TestQuicSpdyClientSession()758 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
759 
IsAuthorized(const std::string &)760 bool TestQuicSpdyClientSession::IsAuthorized(const std::string& /*authority*/) {
761   return true;
762 }
763 
GetMutableCryptoStream()764 QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() {
765   return crypto_stream_.get();
766 }
767 
GetCryptoStream() const768 const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream()
769     const {
770   return crypto_stream_.get();
771 }
772 
RealOnConfigNegotiated()773 void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
774   QuicSpdyClientSessionBase::OnConfigNegotiated();
775 }
776 
TestPushPromiseDelegate(bool match)777 TestPushPromiseDelegate::TestPushPromiseDelegate(bool match)
778     : match_(match), rendezvous_fired_(false), rendezvous_stream_(nullptr) {}
779 
CheckVary(const spdy::Http2HeaderBlock &,const spdy::Http2HeaderBlock &,const spdy::Http2HeaderBlock &)780 bool TestPushPromiseDelegate::CheckVary(
781     const spdy::Http2HeaderBlock& /*client_request*/,
782     const spdy::Http2HeaderBlock& /*promise_request*/,
783     const spdy::Http2HeaderBlock& /*promise_response*/) {
784   QUIC_DVLOG(1) << "match " << match_;
785   return match_;
786 }
787 
OnRendezvousResult(QuicSpdyStream * stream)788 void TestPushPromiseDelegate::OnRendezvousResult(QuicSpdyStream* stream) {
789   rendezvous_fired_ = true;
790   rendezvous_stream_ = stream;
791 }
792 
MockPacketWriter()793 MockPacketWriter::MockPacketWriter() {
794   ON_CALL(*this, GetMaxPacketSize(_))
795       .WillByDefault(testing::Return(kMaxOutgoingPacketSize));
796   ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false));
797   ON_CALL(*this, GetNextWriteLocation(_, _))
798       .WillByDefault(testing::Return(QuicPacketBuffer()));
799   ON_CALL(*this, Flush())
800       .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0)));
801   ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false));
802 }
803 
~MockPacketWriter()804 MockPacketWriter::~MockPacketWriter() {}
805 
MockSendAlgorithm()806 MockSendAlgorithm::MockSendAlgorithm() {
807   ON_CALL(*this, PacingRate(_))
808       .WillByDefault(testing::Return(QuicBandwidth::Zero()));
809   ON_CALL(*this, BandwidthEstimate())
810       .WillByDefault(testing::Return(QuicBandwidth::Zero()));
811 }
812 
~MockSendAlgorithm()813 MockSendAlgorithm::~MockSendAlgorithm() {}
814 
MockLossAlgorithm()815 MockLossAlgorithm::MockLossAlgorithm() {}
816 
~MockLossAlgorithm()817 MockLossAlgorithm::~MockLossAlgorithm() {}
818 
MockAckListener()819 MockAckListener::MockAckListener() {}
820 
~MockAckListener()821 MockAckListener::~MockAckListener() {}
822 
MockNetworkChangeVisitor()823 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {}
824 
~MockNetworkChangeVisitor()825 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {}
826 
TestPeerIPAddress()827 QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); }
828 
QuicVersionMax()829 ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); }
830 
QuicVersionMin()831 ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); }
832 
DisableQuicVersionsWithTls()833 void DisableQuicVersionsWithTls() {
834   for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) {
835     QuicDisableVersion(version);
836   }
837 }
838 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data)839 QuicEncryptedPacket* ConstructEncryptedPacket(
840     QuicConnectionId destination_connection_id,
841     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
842     uint64_t packet_number, const std::string& data) {
843   return ConstructEncryptedPacket(
844       destination_connection_id, source_connection_id, version_flag, reset_flag,
845       packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT,
846       PACKET_4BYTE_PACKET_NUMBER);
847 }
848 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length)849 QuicEncryptedPacket* ConstructEncryptedPacket(
850     QuicConnectionId destination_connection_id,
851     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
852     uint64_t packet_number, const std::string& data,
853     QuicConnectionIdIncluded destination_connection_id_included,
854     QuicConnectionIdIncluded source_connection_id_included,
855     QuicPacketNumberLength packet_number_length) {
856   return ConstructEncryptedPacket(
857       destination_connection_id, source_connection_id, version_flag, reset_flag,
858       packet_number, data, destination_connection_id_included,
859       source_connection_id_included, packet_number_length, nullptr);
860 }
861 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)862 QuicEncryptedPacket* ConstructEncryptedPacket(
863     QuicConnectionId destination_connection_id,
864     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
865     uint64_t packet_number, const std::string& data,
866     QuicConnectionIdIncluded destination_connection_id_included,
867     QuicConnectionIdIncluded source_connection_id_included,
868     QuicPacketNumberLength packet_number_length,
869     ParsedQuicVersionVector* versions) {
870   return ConstructEncryptedPacket(
871       destination_connection_id, source_connection_id, version_flag, reset_flag,
872       packet_number, data, false, destination_connection_id_included,
873       source_connection_id_included, packet_number_length, versions,
874       Perspective::IS_CLIENT);
875 }
876 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)877 QuicEncryptedPacket* ConstructEncryptedPacket(
878     QuicConnectionId destination_connection_id,
879     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
880     uint64_t packet_number, const std::string& data, bool full_padding,
881     QuicConnectionIdIncluded destination_connection_id_included,
882     QuicConnectionIdIncluded source_connection_id_included,
883     QuicPacketNumberLength packet_number_length,
884     ParsedQuicVersionVector* versions) {
885   return ConstructEncryptedPacket(
886       destination_connection_id, source_connection_id, version_flag, reset_flag,
887       packet_number, data, full_padding, destination_connection_id_included,
888       source_connection_id_included, packet_number_length, versions,
889       Perspective::IS_CLIENT);
890 }
891 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions,Perspective perspective)892 QuicEncryptedPacket* ConstructEncryptedPacket(
893     QuicConnectionId destination_connection_id,
894     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
895     uint64_t packet_number, const std::string& data, bool full_padding,
896     QuicConnectionIdIncluded destination_connection_id_included,
897     QuicConnectionIdIncluded source_connection_id_included,
898     QuicPacketNumberLength packet_number_length,
899     ParsedQuicVersionVector* versions, Perspective perspective) {
900   QuicPacketHeader header;
901   header.destination_connection_id = destination_connection_id;
902   header.destination_connection_id_included =
903       destination_connection_id_included;
904   header.source_connection_id = source_connection_id;
905   header.source_connection_id_included = source_connection_id_included;
906   header.version_flag = version_flag;
907   header.reset_flag = reset_flag;
908   header.packet_number_length = packet_number_length;
909   header.packet_number = QuicPacketNumber(packet_number);
910   ParsedQuicVersionVector supported_versions = CurrentSupportedVersions();
911   if (!versions) {
912     versions = &supported_versions;
913   }
914   EXPECT_FALSE(versions->empty());
915   ParsedQuicVersion version = (*versions)[0];
916   if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
917       version_flag) {
918     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
919     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
920   }
921 
922   QuicFrames frames;
923   QuicFramer framer(*versions, QuicTime::Zero(), perspective,
924                     kQuicDefaultConnectionIdLength);
925   framer.SetInitialObfuscators(destination_connection_id);
926   EncryptionLevel level =
927       header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
928   if (level != ENCRYPTION_INITIAL) {
929     framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
930   }
931   if (!QuicVersionUsesCryptoFrames(version.transport_version)) {
932     QuicFrame frame(
933         QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version),
934                         false, 0, absl::string_view(data)));
935     frames.push_back(frame);
936   } else {
937     QuicFrame frame(new QuicCryptoFrame(level, 0, data));
938     frames.push_back(frame);
939   }
940   if (full_padding) {
941     frames.push_back(QuicFrame(QuicPaddingFrame(-1)));
942   } else {
943     // We need a minimum number of bytes of encrypted payload. This will
944     // guarantee that we have at least that much. (It ignores the overhead of
945     // the stream/crypto framing, so it overpads slightly.)
946     size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize(
947         version, packet_number_length);
948     if (data.length() < min_plaintext_size) {
949       size_t padding_length = min_plaintext_size - data.length();
950       frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
951     }
952   }
953 
954   std::unique_ptr<QuicPacket> packet(
955       BuildUnsizedDataPacket(&framer, header, frames));
956   EXPECT_TRUE(packet != nullptr);
957   char* buffer = new char[kMaxOutgoingPacketSize];
958   size_t encrypted_length =
959       framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
960                             buffer, kMaxOutgoingPacketSize);
961   EXPECT_NE(0u, encrypted_length);
962   DeleteFrames(&frames);
963   return new QuicEncryptedPacket(buffer, encrypted_length, true);
964 }
965 
GetUndecryptableEarlyPacket(const ParsedQuicVersion & version,const QuicConnectionId & server_connection_id)966 std::unique_ptr<QuicEncryptedPacket> GetUndecryptableEarlyPacket(
967     const ParsedQuicVersion& version,
968     const QuicConnectionId& server_connection_id) {
969   QuicPacketHeader header;
970   header.destination_connection_id = server_connection_id;
971   header.destination_connection_id_included = CONNECTION_ID_PRESENT;
972   header.source_connection_id = EmptyQuicConnectionId();
973   header.source_connection_id_included = CONNECTION_ID_PRESENT;
974   if (!version.SupportsClientConnectionIds()) {
975     header.source_connection_id_included = CONNECTION_ID_ABSENT;
976   }
977   header.version_flag = true;
978   header.reset_flag = false;
979   header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
980   header.packet_number = QuicPacketNumber(33);
981   header.long_packet_type = ZERO_RTT_PROTECTED;
982   if (version.HasLongHeaderLengths()) {
983     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
984     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
985   }
986 
987   QuicFrames frames;
988   frames.push_back(QuicFrame(QuicPingFrame()));
989   frames.push_back(QuicFrame(QuicPaddingFrame(100)));
990   QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT,
991                     kQuicDefaultConnectionIdLength);
992   framer.SetInitialObfuscators(server_connection_id);
993 
994   framer.SetEncrypter(ENCRYPTION_ZERO_RTT,
995                       std::make_unique<TaggingEncrypter>(ENCRYPTION_ZERO_RTT));
996   std::unique_ptr<QuicPacket> packet(
997       BuildUnsizedDataPacket(&framer, header, frames));
998   EXPECT_TRUE(packet != nullptr);
999   char* buffer = new char[kMaxOutgoingPacketSize];
1000   size_t encrypted_length =
1001       framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet,
1002                             buffer, kMaxOutgoingPacketSize);
1003   EXPECT_NE(0u, encrypted_length);
1004   DeleteFrames(&frames);
1005   return std::make_unique<QuicEncryptedPacket>(buffer, encrypted_length,
1006                                                /*owns_buffer=*/true);
1007 }
1008 
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time)1009 QuicReceivedPacket* ConstructReceivedPacket(
1010     const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
1011   char* buffer = new char[encrypted_packet.length()];
1012   memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
1013   return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
1014                                 true);
1015 }
1016 
ConstructMisFramedEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersion version,Perspective perspective)1017 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
1018     QuicConnectionId destination_connection_id,
1019     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
1020     uint64_t packet_number, const std::string& data,
1021     QuicConnectionIdIncluded destination_connection_id_included,
1022     QuicConnectionIdIncluded source_connection_id_included,
1023     QuicPacketNumberLength packet_number_length, ParsedQuicVersion version,
1024     Perspective perspective) {
1025   QuicPacketHeader header;
1026   header.destination_connection_id = destination_connection_id;
1027   header.destination_connection_id_included =
1028       destination_connection_id_included;
1029   header.source_connection_id = source_connection_id;
1030   header.source_connection_id_included = source_connection_id_included;
1031   header.version_flag = version_flag;
1032   header.reset_flag = reset_flag;
1033   header.packet_number_length = packet_number_length;
1034   header.packet_number = QuicPacketNumber(packet_number);
1035   if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
1036       version_flag) {
1037     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
1038     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
1039   }
1040   QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data)));
1041   QuicFrames frames;
1042   frames.push_back(frame);
1043   QuicFramer framer({version}, QuicTime::Zero(), perspective,
1044                     kQuicDefaultConnectionIdLength);
1045   framer.SetInitialObfuscators(destination_connection_id);
1046   EncryptionLevel level =
1047       version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
1048   if (level != ENCRYPTION_INITIAL) {
1049     framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
1050   }
1051   // We need a minimum of 7 bytes of encrypted payload. This will guarantee that
1052   // we have at least that much. (It ignores the overhead of the stream/crypto
1053   // framing, so it overpads slightly.)
1054   if (data.length() < 7) {
1055     size_t padding_length = 7 - data.length();
1056     frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
1057   }
1058 
1059   std::unique_ptr<QuicPacket> packet(
1060       BuildUnsizedDataPacket(&framer, header, frames));
1061   EXPECT_TRUE(packet != nullptr);
1062 
1063   // Now set the frame type to 0x1F, which is an invalid frame type.
1064   reinterpret_cast<unsigned char*>(
1065       packet->mutable_data())[GetStartOfEncryptedData(
1066       framer.transport_version(),
1067       GetIncludedDestinationConnectionIdLength(header),
1068       GetIncludedSourceConnectionIdLength(header), version_flag,
1069       false /* no diversification nonce */, packet_number_length,
1070       header.retry_token_length_length, 0, header.length_length)] = 0x1F;
1071 
1072   char* buffer = new char[kMaxOutgoingPacketSize];
1073   size_t encrypted_length =
1074       framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
1075                             buffer, kMaxOutgoingPacketSize);
1076   EXPECT_NE(0u, encrypted_length);
1077   return new QuicEncryptedPacket(buffer, encrypted_length, true);
1078 }
1079 
DefaultQuicConfig()1080 QuicConfig DefaultQuicConfig() {
1081   QuicConfig config;
1082   config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(
1083       kInitialStreamFlowControlWindowForTest);
1084   config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(
1085       kInitialStreamFlowControlWindowForTest);
1086   config.SetInitialMaxStreamDataBytesUnidirectionalToSend(
1087       kInitialStreamFlowControlWindowForTest);
1088   config.SetInitialStreamFlowControlWindowToSend(
1089       kInitialStreamFlowControlWindowForTest);
1090   config.SetInitialSessionFlowControlWindowToSend(
1091       kInitialSessionFlowControlWindowForTest);
1092   QuicConfigPeer::SetReceivedMaxBidirectionalStreams(
1093       &config, kDefaultMaxStreamsPerConnection);
1094   // Default enable NSTP.
1095   // This is unnecessary for versions > 44
1096   if (!config.HasClientSentConnectionOption(quic::kNSTP,
1097                                             quic::Perspective::IS_CLIENT)) {
1098     quic::QuicTagVector connection_options;
1099     connection_options.push_back(quic::kNSTP);
1100     config.SetConnectionOptionsToSend(connection_options);
1101   }
1102   return config;
1103 }
1104 
SupportedVersions(ParsedQuicVersion version)1105 ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) {
1106   ParsedQuicVersionVector versions;
1107   versions.push_back(version);
1108   return versions;
1109 }
1110 
MockQuicConnectionDebugVisitor()1111 MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {}
1112 
~MockQuicConnectionDebugVisitor()1113 MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {}
1114 
MockReceivedPacketManager(QuicConnectionStats * stats)1115 MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats)
1116     : QuicReceivedPacketManager(stats) {}
1117 
~MockReceivedPacketManager()1118 MockReceivedPacketManager::~MockReceivedPacketManager() {}
1119 
MockPacketCreatorDelegate()1120 MockPacketCreatorDelegate::MockPacketCreatorDelegate() {}
~MockPacketCreatorDelegate()1121 MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {}
1122 
MockSessionNotifier()1123 MockSessionNotifier::MockSessionNotifier() {}
~MockSessionNotifier()1124 MockSessionNotifier::~MockSessionNotifier() {}
1125 
1126 // static
1127 QuicCryptoClientStream::HandshakerInterface*
GetHandshaker(QuicCryptoClientStream * stream)1128 QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) {
1129   return stream->handshaker_.get();
1130 }
1131 
CreateClientSessionForTest(QuicServerId server_id,QuicTime::Delta connection_start_time,const ParsedQuicVersionVector & supported_versions,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,QuicCryptoClientConfig * crypto_client_config,PacketSavingConnection ** client_connection,TestQuicSpdyClientSession ** client_session)1132 void CreateClientSessionForTest(
1133     QuicServerId server_id, QuicTime::Delta connection_start_time,
1134     const ParsedQuicVersionVector& supported_versions,
1135     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
1136     QuicCryptoClientConfig* crypto_client_config,
1137     PacketSavingConnection** client_connection,
1138     TestQuicSpdyClientSession** client_session) {
1139   QUICHE_CHECK(crypto_client_config);
1140   QUICHE_CHECK(client_connection);
1141   QUICHE_CHECK(client_session);
1142   QUICHE_CHECK(!connection_start_time.IsZero())
1143       << "Connections must start at non-zero times, otherwise the "
1144       << "strike-register will be unhappy.";
1145 
1146   QuicConfig config = DefaultQuicConfig();
1147   *client_connection = new PacketSavingConnection(
1148       helper, alarm_factory, Perspective::IS_CLIENT, supported_versions);
1149   *client_session = new TestQuicSpdyClientSession(*client_connection, config,
1150                                                   supported_versions, server_id,
1151                                                   crypto_client_config);
1152   (*client_connection)->AdvanceTime(connection_start_time);
1153 }
1154 
CreateServerSessionForTest(QuicServerId,QuicTime::Delta connection_start_time,ParsedQuicVersionVector supported_versions,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,QuicCryptoServerConfig * server_crypto_config,QuicCompressedCertsCache * compressed_certs_cache,PacketSavingConnection ** server_connection,TestQuicSpdyServerSession ** server_session)1155 void CreateServerSessionForTest(
1156     QuicServerId /*server_id*/, QuicTime::Delta connection_start_time,
1157     ParsedQuicVersionVector supported_versions,
1158     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
1159     QuicCryptoServerConfig* server_crypto_config,
1160     QuicCompressedCertsCache* compressed_certs_cache,
1161     PacketSavingConnection** server_connection,
1162     TestQuicSpdyServerSession** server_session) {
1163   QUICHE_CHECK(server_crypto_config);
1164   QUICHE_CHECK(server_connection);
1165   QUICHE_CHECK(server_session);
1166   QUICHE_CHECK(!connection_start_time.IsZero())
1167       << "Connections must start at non-zero times, otherwise the "
1168       << "strike-register will be unhappy.";
1169 
1170   *server_connection =
1171       new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER,
1172                                  ParsedVersionOfIndex(supported_versions, 0));
1173   *server_session = new TestQuicSpdyServerSession(
1174       *server_connection, DefaultQuicConfig(), supported_versions,
1175       server_crypto_config, compressed_certs_cache);
1176   (*server_session)->Initialize();
1177 
1178   // We advance the clock initially because the default time is zero and the
1179   // strike register worries that we've just overflowed a uint32_t time.
1180   (*server_connection)->AdvanceTime(connection_start_time);
1181 }
1182 
GetNthClientInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1183 QuicStreamId GetNthClientInitiatedBidirectionalStreamId(
1184     QuicTransportVersion version, int n) {
1185   int num = n;
1186   if (!VersionUsesHttp3(version)) {
1187     num++;
1188   }
1189   return QuicUtils::GetFirstBidirectionalStreamId(version,
1190                                                   Perspective::IS_CLIENT) +
1191          QuicUtils::StreamIdDelta(version) * num;
1192 }
1193 
GetNthServerInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1194 QuicStreamId GetNthServerInitiatedBidirectionalStreamId(
1195     QuicTransportVersion version, int n) {
1196   return QuicUtils::GetFirstBidirectionalStreamId(version,
1197                                                   Perspective::IS_SERVER) +
1198          QuicUtils::StreamIdDelta(version) * n;
1199 }
1200 
GetNthServerInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1201 QuicStreamId GetNthServerInitiatedUnidirectionalStreamId(
1202     QuicTransportVersion version, int n) {
1203   return QuicUtils::GetFirstUnidirectionalStreamId(version,
1204                                                    Perspective::IS_SERVER) +
1205          QuicUtils::StreamIdDelta(version) * n;
1206 }
1207 
GetNthClientInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1208 QuicStreamId GetNthClientInitiatedUnidirectionalStreamId(
1209     QuicTransportVersion version, int n) {
1210   return QuicUtils::GetFirstUnidirectionalStreamId(version,
1211                                                    Perspective::IS_CLIENT) +
1212          QuicUtils::StreamIdDelta(version) * n;
1213 }
1214 
DetermineStreamType(QuicStreamId id,ParsedQuicVersion version,Perspective perspective,bool is_incoming,StreamType default_type)1215 StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version,
1216                                Perspective perspective, bool is_incoming,
1217                                StreamType default_type) {
1218   return version.HasIetfQuicFrames()
1219              ? QuicUtils::GetStreamType(id, perspective, is_incoming, version)
1220              : default_type;
1221 }
1222 
MemSliceFromString(absl::string_view data)1223 quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) {
1224   if (data.empty()) {
1225     return quiche::QuicheMemSlice();
1226   }
1227 
1228   static quiche::SimpleBufferAllocator* allocator =
1229       new quiche::SimpleBufferAllocator();
1230   return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data));
1231 }
1232 
EncryptPacket(uint64_t,absl::string_view,absl::string_view plaintext,char * output,size_t * output_length,size_t max_output_length)1233 bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/,
1234                                      absl::string_view /*associated_data*/,
1235                                      absl::string_view plaintext, char* output,
1236                                      size_t* output_length,
1237                                      size_t max_output_length) {
1238   const size_t len = plaintext.size() + kTagSize;
1239   if (max_output_length < len) {
1240     return false;
1241   }
1242   // Memmove is safe for inplace encryption.
1243   memmove(output, plaintext.data(), plaintext.size());
1244   output += plaintext.size();
1245   memset(output, tag_, kTagSize);
1246   *output_length = len;
1247   return true;
1248 }
1249 
DecryptPacket(uint64_t,absl::string_view,absl::string_view ciphertext,char * output,size_t * output_length,size_t)1250 bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/,
1251                                      absl::string_view /*associated_data*/,
1252                                      absl::string_view ciphertext, char* output,
1253                                      size_t* output_length,
1254                                      size_t /*max_output_length*/) {
1255   if (ciphertext.size() < kTagSize) {
1256     return false;
1257   }
1258   if (!CheckTag(ciphertext, GetTag(ciphertext))) {
1259     return false;
1260   }
1261   *output_length = ciphertext.size() - kTagSize;
1262   memcpy(output, ciphertext.data(), *output_length);
1263   return true;
1264 }
1265 
CheckTag(absl::string_view ciphertext,uint8_t tag)1266 bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) {
1267   for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
1268     if (ciphertext.data()[i] != tag) {
1269       return false;
1270     }
1271   }
1272 
1273   return true;
1274 }
1275 
TestPacketWriter(ParsedQuicVersion version,MockClock * clock,Perspective perspective)1276 TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock,
1277                                    Perspective perspective)
1278     : version_(version),
1279       framer_(SupportedVersions(version_),
1280               QuicUtils::InvertPerspective(perspective)),
1281       clock_(clock) {
1282   QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
1283                                                       TestConnectionId());
1284   framer_.framer()->SetInitialObfuscators(TestConnectionId());
1285 
1286   for (int i = 0; i < 128; ++i) {
1287     PacketBuffer* p = new PacketBuffer();
1288     packet_buffer_pool_.push_back(p);
1289     packet_buffer_pool_index_[p->buffer] = p;
1290     packet_buffer_free_list_.push_back(p);
1291   }
1292 }
1293 
~TestPacketWriter()1294 TestPacketWriter::~TestPacketWriter() {
1295   EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
1296       << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
1297       << " out of " << packet_buffer_pool_.size()
1298       << " packet buffers have been leaked.";
1299   for (auto p : packet_buffer_pool_) {
1300     delete p;
1301   }
1302 }
1303 
WritePacket(const char * buffer,size_t buf_len,const QuicIpAddress & self_address,const QuicSocketAddress & peer_address,PerPacketOptions * options)1304 WriteResult TestPacketWriter::WritePacket(const char* buffer, size_t buf_len,
1305                                           const QuicIpAddress& self_address,
1306                                           const QuicSocketAddress& peer_address,
1307                                           PerPacketOptions* options) {
1308   last_write_source_address_ = self_address;
1309   last_write_peer_address_ = peer_address;
1310   // If the buffer is allocated from the pool, return it back to the pool.
1311   // Note the buffer content doesn't change.
1312   if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
1313       packet_buffer_pool_index_.end()) {
1314     FreePacketBuffer(buffer);
1315   }
1316 
1317   QuicEncryptedPacket packet(buffer, buf_len);
1318   ++packets_write_attempts_;
1319 
1320   if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
1321     final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
1322     memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
1323            sizeof(final_bytes_of_last_packet_));
1324   }
1325   if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
1326     framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
1327                                        std::make_unique<TaggingDecrypter>());
1328     framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
1329                                        std::make_unique<TaggingDecrypter>());
1330     framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
1331                                        std::make_unique<TaggingDecrypter>());
1332   } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel(
1333                  ENCRYPTION_FORWARD_SECURE) &&
1334              !framer_.framer()->HasDecrypterOfEncryptionLevel(
1335                  ENCRYPTION_ZERO_RTT)) {
1336     framer_.framer()->SetAlternativeDecrypter(
1337         ENCRYPTION_FORWARD_SECURE,
1338         std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE),
1339         false);
1340   }
1341   EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet))
1342       << framer_.framer()->detailed_error() << " perspective "
1343       << framer_.framer()->perspective();
1344   next_packet_processable_ = true;
1345   if (block_on_next_write_) {
1346     write_blocked_ = true;
1347     block_on_next_write_ = false;
1348   }
1349   if (next_packet_too_large_) {
1350     next_packet_too_large_ = false;
1351     return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1352   }
1353   if (always_get_packet_too_large_) {
1354     return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1355   }
1356   if (IsWriteBlocked()) {
1357     return WriteResult(is_write_blocked_data_buffered_
1358                            ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
1359                            : WRITE_STATUS_BLOCKED,
1360                        0);
1361   }
1362 
1363   if (ShouldWriteFail()) {
1364     return WriteResult(WRITE_STATUS_ERROR, write_error_code_);
1365   }
1366 
1367   last_packet_size_ = packet.length();
1368   total_bytes_written_ += packet.length();
1369   last_packet_header_ = framer_.header();
1370   if (!framer_.connection_close_frames().empty()) {
1371     ++connection_close_packets_;
1372   }
1373   if (!write_pause_time_delta_.IsZero()) {
1374     clock_->AdvanceTime(write_pause_time_delta_);
1375   }
1376   if (is_batch_mode_) {
1377     bytes_buffered_ += last_packet_size_;
1378     return WriteResult(WRITE_STATUS_OK, 0);
1379   }
1380   last_ecn_sent_ = (options == nullptr) ? ECN_NOT_ECT : options->ecn_codepoint;
1381   return WriteResult(WRITE_STATUS_OK, last_packet_size_);
1382 }
1383 
GetNextWriteLocation(const QuicIpAddress &,const QuicSocketAddress &)1384 QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
1385     const QuicIpAddress& /*self_address*/,
1386     const QuicSocketAddress& /*peer_address*/) {
1387   return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
1388 }
1389 
Flush()1390 WriteResult TestPacketWriter::Flush() {
1391   flush_attempts_++;
1392   if (block_on_next_flush_) {
1393     block_on_next_flush_ = false;
1394     SetWriteBlocked();
1395     return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
1396   }
1397   if (write_should_fail_) {
1398     return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
1399   }
1400   int bytes_flushed = bytes_buffered_;
1401   bytes_buffered_ = 0;
1402   return WriteResult(WRITE_STATUS_OK, bytes_flushed);
1403 }
1404 
AllocPacketBuffer()1405 char* TestPacketWriter::AllocPacketBuffer() {
1406   PacketBuffer* p = packet_buffer_free_list_.front();
1407   EXPECT_FALSE(p->in_use);
1408   p->in_use = true;
1409   packet_buffer_free_list_.pop_front();
1410   return p->buffer;
1411 }
1412 
FreePacketBuffer(const char * buffer)1413 void TestPacketWriter::FreePacketBuffer(const char* buffer) {
1414   auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
1415   ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
1416   PacketBuffer* p = iter->second;
1417   ASSERT_TRUE(p->in_use);
1418   p->in_use = false;
1419   packet_buffer_free_list_.push_back(p);
1420 }
1421 
WriteServerVersionNegotiationProbeResponse(char * packet_bytes,size_t * packet_length_out,const char * source_connection_id_bytes,uint8_t source_connection_id_length)1422 bool WriteServerVersionNegotiationProbeResponse(
1423     char* packet_bytes, size_t* packet_length_out,
1424     const char* source_connection_id_bytes,
1425     uint8_t source_connection_id_length) {
1426   if (packet_bytes == nullptr) {
1427     QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes";
1428     return false;
1429   }
1430   if (packet_length_out == nullptr) {
1431     QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out";
1432     return false;
1433   }
1434   QuicConnectionId source_connection_id(source_connection_id_bytes,
1435                                         source_connection_id_length);
1436   std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
1437       QuicFramer::BuildVersionNegotiationPacket(
1438           source_connection_id, EmptyQuicConnectionId(),
1439           /*ietf_quic=*/true, /*use_length_prefix=*/true,
1440           ParsedQuicVersionVector{});
1441   if (!encrypted_packet) {
1442     QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet";
1443     return false;
1444   }
1445   if (*packet_length_out < encrypted_packet->length()) {
1446     QUIC_BUG(quic_bug_10256_4)
1447         << "Invalid *packet_length_out " << *packet_length_out << " < "
1448         << encrypted_packet->length();
1449     return false;
1450   }
1451   *packet_length_out = encrypted_packet->length();
1452   memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
1453   return true;
1454 }
1455 
ParseClientVersionNegotiationProbePacket(const char * packet_bytes,size_t packet_length,char * destination_connection_id_bytes,uint8_t * destination_connection_id_length_out)1456 bool ParseClientVersionNegotiationProbePacket(
1457     const char* packet_bytes, size_t packet_length,
1458     char* destination_connection_id_bytes,
1459     uint8_t* destination_connection_id_length_out) {
1460   if (packet_bytes == nullptr) {
1461     QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes";
1462     return false;
1463   }
1464   if (packet_length < kMinPacketSizeForVersionNegotiation ||
1465       packet_length > 65535) {
1466     QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length";
1467     return false;
1468   }
1469   if (destination_connection_id_bytes == nullptr) {
1470     QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes";
1471     return false;
1472   }
1473   if (destination_connection_id_length_out == nullptr) {
1474     QUIC_BUG(quic_bug_10256_8)
1475         << "Invalid destination_connection_id_length_out";
1476     return false;
1477   }
1478 
1479   QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
1480   PacketHeaderFormat format;
1481   QuicLongHeaderType long_packet_type;
1482   bool version_present, has_length_prefix;
1483   QuicVersionLabel version_label;
1484   ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
1485   QuicConnectionId destination_connection_id, source_connection_id;
1486   absl::optional<absl::string_view> retry_token;
1487   std::string detailed_error;
1488   QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
1489       encrypted_packet,
1490       /*expected_destination_connection_id_length=*/0, &format,
1491       &long_packet_type, &version_present, &has_length_prefix, &version_label,
1492       &parsed_version, &destination_connection_id, &source_connection_id,
1493       &retry_token, &detailed_error);
1494   if (error != QUIC_NO_ERROR) {
1495     QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error;
1496     return false;
1497   }
1498   if (!version_present) {
1499     QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header";
1500     return false;
1501   }
1502   if (*destination_connection_id_length_out <
1503       destination_connection_id.length()) {
1504     QUIC_BUG(quic_bug_10256_11)
1505         << "destination_connection_id_length_out too small";
1506     return false;
1507   }
1508   *destination_connection_id_length_out = destination_connection_id.length();
1509   memcpy(destination_connection_id_bytes, destination_connection_id.data(),
1510          *destination_connection_id_length_out);
1511   return true;
1512 }
1513 
1514 }  // namespace test
1515 }  // namespace quic
1516