1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/test_tools/quic_test_utils.h"
6
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
13
14 #include "absl/base/macros.h"
15 #include "absl/strings/string_view.h"
16 #include "openssl/chacha.h"
17 #include "openssl/sha.h"
18 #include "quiche/quic/core/crypto/crypto_framer.h"
19 #include "quiche/quic/core/crypto/crypto_handshake.h"
20 #include "quiche/quic/core/crypto/crypto_utils.h"
21 #include "quiche/quic/core/crypto/null_decrypter.h"
22 #include "quiche/quic/core/crypto/null_encrypter.h"
23 #include "quiche/quic/core/crypto/quic_decrypter.h"
24 #include "quiche/quic/core/crypto/quic_encrypter.h"
25 #include "quiche/quic/core/http/quic_spdy_client_session.h"
26 #include "quiche/quic/core/quic_config.h"
27 #include "quiche/quic/core/quic_data_writer.h"
28 #include "quiche/quic/core/quic_framer.h"
29 #include "quiche/quic/core/quic_packet_creator.h"
30 #include "quiche/quic/core/quic_types.h"
31 #include "quiche/quic/core/quic_utils.h"
32 #include "quiche/quic/core/quic_versions.h"
33 #include "quiche/quic/platform/api/quic_flags.h"
34 #include "quiche/quic/platform/api/quic_logging.h"
35 #include "quiche/quic/test_tools/crypto_test_utils.h"
36 #include "quiche/quic/test_tools/quic_config_peer.h"
37 #include "quiche/quic/test_tools/quic_connection_peer.h"
38 #include "quiche/common/quiche_buffer_allocator.h"
39 #include "quiche/common/quiche_endian.h"
40 #include "quiche/common/simple_buffer_allocator.h"
41 #include "quiche/spdy/core/spdy_frame_builder.h"
42
43 using testing::_;
44 using testing::Invoke;
45 using testing::Return;
46
47 namespace quic {
48 namespace test {
49
TestConnectionId()50 QuicConnectionId TestConnectionId() {
51 // Chosen by fair dice roll.
52 // Guaranteed to be random.
53 return TestConnectionId(42);
54 }
55
TestConnectionId(uint64_t connection_number)56 QuicConnectionId TestConnectionId(uint64_t connection_number) {
57 const uint64_t connection_id64_net =
58 quiche::QuicheEndian::HostToNet64(connection_number);
59 return QuicConnectionId(reinterpret_cast<const char*>(&connection_id64_net),
60 sizeof(connection_id64_net));
61 }
62
TestConnectionIdNineBytesLong(uint64_t connection_number)63 QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) {
64 const uint64_t connection_number_net =
65 quiche::QuicheEndian::HostToNet64(connection_number);
66 char connection_id_bytes[9] = {};
67 static_assert(
68 sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net),
69 "bad lengths");
70 memcpy(connection_id_bytes + 1, &connection_number_net,
71 sizeof(connection_number_net));
72 return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
73 }
74
TestConnectionIdToUInt64(QuicConnectionId connection_id)75 uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) {
76 QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength);
77 uint64_t connection_id64_net = 0;
78 memcpy(&connection_id64_net, connection_id.data(),
79 std::min<size_t>(static_cast<size_t>(connection_id.length()),
80 sizeof(connection_id64_net)));
81 return quiche::QuicheEndian::NetToHost64(connection_id64_net);
82 }
83
// Returns a fixed 16-byte stateless reset token whose bytes run from 0x90
// through 0x9F.
std::vector<uint8_t> CreateStatelessResetTokenForTest() {
  static constexpr uint8_t kTokenBytes[] = {
      0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
      0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F};
  static_assert(sizeof(kTokenBytes) == 16, "token must be 16 bytes");
  return std::vector<uint8_t>(std::begin(kTokenBytes), std::end(kTokenBytes));
}
92
// Hostname shared by test server IDs.
std::string TestHostname() {
  static constexpr char kHostname[] = "test.example.com";
  return std::string(kHostname);
}
94
TestServerId()95 QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); }
96
InitAckFrame(const std::vector<QuicAckBlock> & ack_blocks)97 QuicAckFrame InitAckFrame(const std::vector<QuicAckBlock>& ack_blocks) {
98 QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
99
100 QuicAckFrame ack;
101 QuicPacketNumber end_of_previous_block(1);
102 for (const QuicAckBlock& block : ack_blocks) {
103 QUICHE_DCHECK_GE(block.start, end_of_previous_block);
104 QUICHE_DCHECK_GT(block.limit, block.start);
105 ack.packets.AddRange(block.start, block.limit);
106 end_of_previous_block = block.limit;
107 }
108
109 ack.largest_acked = ack.packets.Max();
110
111 return ack;
112 }
113
InitAckFrame(uint64_t largest_acked)114 QuicAckFrame InitAckFrame(uint64_t largest_acked) {
115 return InitAckFrame(QuicPacketNumber(largest_acked));
116 }
117
InitAckFrame(QuicPacketNumber largest_acked)118 QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) {
119 return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}});
120 }
121
MakeAckFrameWithAckBlocks(size_t num_ack_blocks,uint64_t least_unacked)122 QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks,
123 uint64_t least_unacked) {
124 QuicAckFrame ack;
125 ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked);
126 // Add enough received packets to get num_ack_blocks ack blocks.
127 for (QuicPacketNumber i = QuicPacketNumber(2);
128 i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) {
129 ack.packets.Add(i + least_unacked);
130 }
131 return ack;
132 }
133
MakeAckFrameWithGaps(uint64_t gap_size,size_t max_num_gaps,uint64_t largest_acked)134 QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps,
135 uint64_t largest_acked) {
136 QuicAckFrame ack;
137 ack.largest_acked = QuicPacketNumber(largest_acked);
138 ack.packets.Add(QuicPacketNumber(largest_acked));
139 for (size_t i = 0; i < max_num_gaps; ++i) {
140 if (largest_acked <= gap_size) {
141 break;
142 }
143 largest_acked -= gap_size;
144 ack.packets.Add(QuicPacketNumber(largest_acked));
145 }
146 return ack;
147 }
148
HeaderToEncryptionLevel(const QuicPacketHeader & header)149 EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) {
150 if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
151 return ENCRYPTION_FORWARD_SECURE;
152 } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) {
153 if (header.long_packet_type == HANDSHAKE) {
154 return ENCRYPTION_HANDSHAKE;
155 } else if (header.long_packet_type == ZERO_RTT_PROTECTED) {
156 return ENCRYPTION_ZERO_RTT;
157 }
158 }
159 return ENCRYPTION_INITIAL;
160 }
161
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames)162 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
163 QuicFramer* framer, const QuicPacketHeader& header,
164 const QuicFrames& frames) {
165 const size_t max_plaintext_size =
166 framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize);
167 size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header);
168 for (size_t i = 0; i < frames.size(); ++i) {
169 QUICHE_DCHECK_LE(packet_size, max_plaintext_size);
170 bool first_frame = i == 0;
171 bool last_frame = i == frames.size() - 1;
172 const size_t frame_size = framer->GetSerializedFrameLength(
173 frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
174 header.packet_number_length);
175 QUICHE_DCHECK(frame_size);
176 packet_size += frame_size;
177 }
178 return BuildUnsizedDataPacket(framer, header, frames, packet_size);
179 }
180
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames,size_t packet_size)181 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
182 QuicFramer* framer, const QuicPacketHeader& header,
183 const QuicFrames& frames, size_t packet_size) {
184 char* buffer = new char[packet_size];
185 EncryptionLevel level = HeaderToEncryptionLevel(header);
186 size_t length =
187 framer->BuildDataPacket(header, frames, buffer, packet_size, level);
188
189 if (length == 0) {
190 delete[] buffer;
191 return nullptr;
192 }
193 // Re-construct the data packet with data ownership.
194 return std::make_unique<QuicPacket>(
195 buffer, length, /* owns_buffer */ true,
196 GetIncludedDestinationConnectionIdLength(header),
197 GetIncludedSourceConnectionIdLength(header), header.version_flag,
198 header.nonce != nullptr, header.packet_number_length,
199 header.retry_token_length_length, header.retry_token.length(),
200 header.length_length);
201 }
202
Sha1Hash(absl::string_view data)203 std::string Sha1Hash(absl::string_view data) {
204 char buffer[SHA_DIGEST_LENGTH];
205 SHA1(reinterpret_cast<const uint8_t*>(data.data()), data.size(),
206 reinterpret_cast<uint8_t*>(buffer));
207 return std::string(buffer, ABSL_ARRAYSIZE(buffer));
208 }
209
ClearControlFrame(const QuicFrame & frame)210 bool ClearControlFrame(const QuicFrame& frame) {
211 DeleteFrame(&const_cast<QuicFrame&>(frame));
212 return true;
213 }
214
ClearControlFrameWithTransmissionType(const QuicFrame & frame,TransmissionType)215 bool ClearControlFrameWithTransmissionType(const QuicFrame& frame,
216 TransmissionType /*type*/) {
217 return ClearControlFrame(frame);
218 }
219
RandUint64()220 uint64_t SimpleRandom::RandUint64() {
221 uint64_t result;
222 RandBytes(&result, sizeof(result));
223 return result;
224 }
225
RandBytes(void * data,size_t len)226 void SimpleRandom::RandBytes(void* data, size_t len) {
227 uint8_t* data_bytes = reinterpret_cast<uint8_t*>(data);
228 while (len > 0) {
229 const size_t buffer_left = sizeof(buffer_) - buffer_offset_;
230 const size_t to_copy = std::min(buffer_left, len);
231 memcpy(data_bytes, buffer_ + buffer_offset_, to_copy);
232 data_bytes += to_copy;
233 buffer_offset_ += to_copy;
234 len -= to_copy;
235
236 if (buffer_offset_ == sizeof(buffer_)) {
237 FillBuffer();
238 }
239 }
240 }
241
InsecureRandBytes(void * data,size_t len)242 void SimpleRandom::InsecureRandBytes(void* data, size_t len) {
243 RandBytes(data, len);
244 }
245
InsecureRandUint64()246 uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); }
247
FillBuffer()248 void SimpleRandom::FillBuffer() {
249 uint8_t nonce[12];
250 memcpy(nonce, buffer_, sizeof(nonce));
251 CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0);
252 buffer_offset_ = 0;
253 }
254
set_seed(uint64_t seed)255 void SimpleRandom::set_seed(uint64_t seed) {
256 static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits");
257 SHA256(reinterpret_cast<const uint8_t*>(&seed), sizeof(seed), key_);
258
259 memset(buffer_, 0, sizeof(buffer_));
260 FillBuffer();
261 }
262
MockFramerVisitor()263 MockFramerVisitor::MockFramerVisitor() {
264 // By default, we want to accept packets.
265 ON_CALL(*this, OnProtocolVersionMismatch(_))
266 .WillByDefault(testing::Return(false));
267
268 // By default, we want to accept packets.
269 ON_CALL(*this, OnUnauthenticatedHeader(_))
270 .WillByDefault(testing::Return(true));
271
272 ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
273 .WillByDefault(testing::Return(true));
274
275 ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true));
276
277 ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true));
278
279 ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true));
280
281 ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true));
282
283 ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true));
284
285 ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true));
286
287 ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true));
288
289 ON_CALL(*this, OnConnectionCloseFrame(_))
290 .WillByDefault(testing::Return(true));
291
292 ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true));
293
294 ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true));
295
296 ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true));
297
298 ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true));
299 ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true));
300 ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true));
301 }
302
~MockFramerVisitor()303 MockFramerVisitor::~MockFramerVisitor() {}
304
OnProtocolVersionMismatch(ParsedQuicVersion)305 bool NoOpFramerVisitor::OnProtocolVersionMismatch(
306 ParsedQuicVersion /*version*/) {
307 return false;
308 }
309
OnUnauthenticatedPublicHeader(const QuicPacketHeader &)310 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
311 const QuicPacketHeader& /*header*/) {
312 return true;
313 }
314
OnUnauthenticatedHeader(const QuicPacketHeader &)315 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
316 const QuicPacketHeader& /*header*/) {
317 return true;
318 }
319
OnPacketHeader(const QuicPacketHeader &)320 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) {
321 return true;
322 }
323
OnCoalescedPacket(const QuicEncryptedPacket &)324 void NoOpFramerVisitor::OnCoalescedPacket(
325 const QuicEncryptedPacket& /*packet*/) {}
326
OnUndecryptablePacket(const QuicEncryptedPacket &,EncryptionLevel,bool)327 void NoOpFramerVisitor::OnUndecryptablePacket(
328 const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/,
329 bool /*has_decryption_key*/) {}
330
OnStreamFrame(const QuicStreamFrame &)331 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) {
332 return true;
333 }
334
OnCryptoFrame(const QuicCryptoFrame &)335 bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {
336 return true;
337 }
338
OnAckFrameStart(QuicPacketNumber,QuicTime::Delta)339 bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
340 QuicTime::Delta /*ack_delay_time*/) {
341 return true;
342 }
343
OnAckRange(QuicPacketNumber,QuicPacketNumber)344 bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/,
345 QuicPacketNumber /*end*/) {
346 return true;
347 }
348
OnAckTimestamp(QuicPacketNumber,QuicTime)349 bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/,
350 QuicTime /*timestamp*/) {
351 return true;
352 }
353
OnAckFrameEnd(QuicPacketNumber,const std::optional<QuicEcnCounts> &)354 bool NoOpFramerVisitor::OnAckFrameEnd(
355 QuicPacketNumber /*start*/,
356 const std::optional<QuicEcnCounts>& /*ecn_counts*/) {
357 return true;
358 }
359
OnStopWaitingFrame(const QuicStopWaitingFrame &)360 bool NoOpFramerVisitor::OnStopWaitingFrame(
361 const QuicStopWaitingFrame& /*frame*/) {
362 return true;
363 }
364
OnPaddingFrame(const QuicPaddingFrame &)365 bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {
366 return true;
367 }
368
OnPingFrame(const QuicPingFrame &)369 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) {
370 return true;
371 }
372
OnRstStreamFrame(const QuicRstStreamFrame &)373 bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {
374 return true;
375 }
376
OnConnectionCloseFrame(const QuicConnectionCloseFrame &)377 bool NoOpFramerVisitor::OnConnectionCloseFrame(
378 const QuicConnectionCloseFrame& /*frame*/) {
379 return true;
380 }
381
OnNewConnectionIdFrame(const QuicNewConnectionIdFrame &)382 bool NoOpFramerVisitor::OnNewConnectionIdFrame(
383 const QuicNewConnectionIdFrame& /*frame*/) {
384 return true;
385 }
386
OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame &)387 bool NoOpFramerVisitor::OnRetireConnectionIdFrame(
388 const QuicRetireConnectionIdFrame& /*frame*/) {
389 return true;
390 }
391
OnNewTokenFrame(const QuicNewTokenFrame &)392 bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {
393 return true;
394 }
395
OnStopSendingFrame(const QuicStopSendingFrame &)396 bool NoOpFramerVisitor::OnStopSendingFrame(
397 const QuicStopSendingFrame& /*frame*/) {
398 return true;
399 }
400
OnPathChallengeFrame(const QuicPathChallengeFrame &)401 bool NoOpFramerVisitor::OnPathChallengeFrame(
402 const QuicPathChallengeFrame& /*frame*/) {
403 return true;
404 }
405
OnPathResponseFrame(const QuicPathResponseFrame &)406 bool NoOpFramerVisitor::OnPathResponseFrame(
407 const QuicPathResponseFrame& /*frame*/) {
408 return true;
409 }
410
OnGoAwayFrame(const QuicGoAwayFrame &)411 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {
412 return true;
413 }
414
OnMaxStreamsFrame(const QuicMaxStreamsFrame &)415 bool NoOpFramerVisitor::OnMaxStreamsFrame(
416 const QuicMaxStreamsFrame& /*frame*/) {
417 return true;
418 }
419
OnStreamsBlockedFrame(const QuicStreamsBlockedFrame &)420 bool NoOpFramerVisitor::OnStreamsBlockedFrame(
421 const QuicStreamsBlockedFrame& /*frame*/) {
422 return true;
423 }
424
OnWindowUpdateFrame(const QuicWindowUpdateFrame &)425 bool NoOpFramerVisitor::OnWindowUpdateFrame(
426 const QuicWindowUpdateFrame& /*frame*/) {
427 return true;
428 }
429
OnBlockedFrame(const QuicBlockedFrame &)430 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {
431 return true;
432 }
433
OnMessageFrame(const QuicMessageFrame &)434 bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) {
435 return true;
436 }
437
OnHandshakeDoneFrame(const QuicHandshakeDoneFrame &)438 bool NoOpFramerVisitor::OnHandshakeDoneFrame(
439 const QuicHandshakeDoneFrame& /*frame*/) {
440 return true;
441 }
442
OnAckFrequencyFrame(const QuicAckFrequencyFrame &)443 bool NoOpFramerVisitor::OnAckFrequencyFrame(
444 const QuicAckFrequencyFrame& /*frame*/) {
445 return true;
446 }
447
IsValidStatelessResetToken(const StatelessResetToken &) const448 bool NoOpFramerVisitor::IsValidStatelessResetToken(
449 const StatelessResetToken& /*token*/) const {
450 return false;
451 }
452
MockQuicConnectionVisitor()453 MockQuicConnectionVisitor::MockQuicConnectionVisitor() {
454 ON_CALL(*this, GetFlowControlSendWindowSize(_))
455 .WillByDefault(Return(std::numeric_limits<QuicByteCount>::max()));
456 }
457
~MockQuicConnectionVisitor()458 MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {}
459
MockQuicConnectionHelper()460 MockQuicConnectionHelper::MockQuicConnectionHelper() {}
461
~MockQuicConnectionHelper()462 MockQuicConnectionHelper::~MockQuicConnectionHelper() {}
463
GetClock() const464 const MockClock* MockQuicConnectionHelper::GetClock() const { return &clock_; }
465
GetClock()466 MockClock* MockQuicConnectionHelper::GetClock() { return &clock_; }
467
GetRandomGenerator()468 QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() {
469 return &random_generator_;
470 }
471
CreateAlarm(QuicAlarm::Delegate * delegate)472 QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) {
473 return new MockAlarmFactory::TestAlarm(
474 QuicArenaScopedPtr<QuicAlarm::Delegate>(delegate));
475 }
476
CreateAlarm(QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,QuicConnectionArena * arena)477 QuicArenaScopedPtr<QuicAlarm> MockAlarmFactory::CreateAlarm(
478 QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,
479 QuicConnectionArena* arena) {
480 if (arena != nullptr) {
481 return arena->New<TestAlarm>(std::move(delegate));
482 } else {
483 return QuicArenaScopedPtr<TestAlarm>(new TestAlarm(std::move(delegate)));
484 }
485 }
486
487 quiche::QuicheBufferAllocator*
GetStreamSendBufferAllocator()488 MockQuicConnectionHelper::GetStreamSendBufferAllocator() {
489 return &buffer_allocator_;
490 }
491
AdvanceTime(QuicTime::Delta delta)492 void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) {
493 clock_.AdvanceTime(delta);
494 }
495
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)496 MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper,
497 QuicAlarmFactory* alarm_factory,
498 Perspective perspective)
499 : MockQuicConnection(TestConnectionId(),
500 QuicSocketAddress(TestPeerIPAddress(), kTestPort),
501 helper, alarm_factory, perspective,
502 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
503
MockQuicConnection(QuicSocketAddress address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)504 MockQuicConnection::MockQuicConnection(QuicSocketAddress address,
505 QuicConnectionHelperInterface* helper,
506 QuicAlarmFactory* alarm_factory,
507 Perspective perspective)
508 : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory,
509 perspective,
510 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
511
MockQuicConnection(QuicConnectionId connection_id,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)512 MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id,
513 QuicConnectionHelperInterface* helper,
514 QuicAlarmFactory* alarm_factory,
515 Perspective perspective)
516 : MockQuicConnection(connection_id,
517 QuicSocketAddress(TestPeerIPAddress(), kTestPort),
518 helper, alarm_factory, perspective,
519 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
520
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)521 MockQuicConnection::MockQuicConnection(
522 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
523 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
524 : MockQuicConnection(
525 TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort),
526 helper, alarm_factory, perspective, supported_versions) {}
527
MockQuicConnection(QuicConnectionId connection_id,QuicSocketAddress initial_peer_address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)528 MockQuicConnection::MockQuicConnection(
529 QuicConnectionId connection_id, QuicSocketAddress initial_peer_address,
530 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
531 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
532 : QuicConnection(
533 connection_id,
534 /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5),
535 initial_peer_address, helper, alarm_factory,
536 new testing::NiceMock<MockPacketWriter>(),
537 /* owns_writer= */ true, perspective, supported_versions,
538 connection_id_generator_) {
539 ON_CALL(*this, OnError(_))
540 .WillByDefault(
541 Invoke(this, &PacketSavingConnection::QuicConnection_OnError));
542 ON_CALL(*this, SendCryptoData(_, _, _))
543 .WillByDefault(
544 Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData));
545
546 SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5));
547 }
548
~MockQuicConnection()549 MockQuicConnection::~MockQuicConnection() {}
550
AdvanceTime(QuicTime::Delta delta)551 void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) {
552 static_cast<MockQuicConnectionHelper*>(helper())->AdvanceTime(delta);
553 }
554
OnProtocolVersionMismatch(ParsedQuicVersion)555 bool MockQuicConnection::OnProtocolVersionMismatch(
556 ParsedQuicVersion /*version*/) {
557 return false;
558 }
559
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)560 PacketSavingConnection::PacketSavingConnection(MockQuicConnectionHelper* helper,
561 QuicAlarmFactory* alarm_factory,
562 Perspective perspective)
563 : MockQuicConnection(helper, alarm_factory, perspective),
564 mock_helper_(helper) {}
565
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)566 PacketSavingConnection::PacketSavingConnection(
567 MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
568 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
569 : MockQuicConnection(helper, alarm_factory, perspective,
570 supported_versions),
571 mock_helper_(helper) {}
572
~PacketSavingConnection()573 PacketSavingConnection::~PacketSavingConnection() {}
574
GetSerializedPacketFate(bool,EncryptionLevel)575 SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate(
576 bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) {
577 return SEND_TO_WRITER;
578 }
579
SendOrQueuePacket(SerializedPacket packet)580 void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) {
581 encrypted_packets_.push_back(std::make_unique<QuicEncryptedPacket>(
582 CopyBuffer(packet), packet.encrypted_length, true));
583 MockClock& clock = *mock_helper_->GetClock();
584 clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
585 // Transfer ownership of the packet to the SentPacketManager and the
586 // ack notifier to the AckNotifierManager.
587 OnPacketSent(packet.encryption_level, packet.transmission_type);
588 QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent(
589 &packet, clock.ApproximateNow(), NOT_RETRANSMISSION,
590 HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT);
591 }
592
GetPackets() const593 std::vector<const QuicEncryptedPacket*> PacketSavingConnection::GetPackets()
594 const {
595 std::vector<const QuicEncryptedPacket*> packets;
596 for (size_t i = num_cleared_packets_; i < encrypted_packets_.size(); ++i) {
597 packets.push_back(encrypted_packets_[i].get());
598 }
599 return packets;
600 }
601
ClearPackets()602 void PacketSavingConnection::ClearPackets() {
603 num_cleared_packets_ = encrypted_packets_.size();
604 }
605
MockQuicSession(QuicConnection * connection)606 MockQuicSession::MockQuicSession(QuicConnection* connection)
607 : MockQuicSession(connection, true) {}
608
MockQuicSession(QuicConnection * connection,bool create_mock_crypto_stream)609 MockQuicSession::MockQuicSession(QuicConnection* connection,
610 bool create_mock_crypto_stream)
611 : QuicSession(connection, nullptr, DefaultQuicConfig(),
612 connection->supported_versions(),
613 /*num_expected_unidirectional_static_streams = */ 0) {
614 if (create_mock_crypto_stream) {
615 crypto_stream_ =
616 std::make_unique<testing::NiceMock<MockQuicCryptoStream>>(this);
617 }
618 ON_CALL(*this, WritevData(_, _, _, _, _, _))
619 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
620 }
621
~MockQuicSession()622 MockQuicSession::~MockQuicSession() { DeleteConnection(); }
623
GetMutableCryptoStream()624 QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() {
625 return crypto_stream_.get();
626 }
627
GetCryptoStream() const628 const QuicCryptoStream* MockQuicSession::GetCryptoStream() const {
629 return crypto_stream_.get();
630 }
631
SetCryptoStream(QuicCryptoStream * crypto_stream)632 void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
633 crypto_stream_.reset(crypto_stream);
634 }
635
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)636 QuicConsumedData MockQuicSession::ConsumeData(
637 QuicStreamId id, size_t write_length, QuicStreamOffset offset,
638 StreamSendingState state, TransmissionType /*type*/,
639 std::optional<EncryptionLevel> /*level*/) {
640 if (write_length > 0) {
641 auto buf = std::make_unique<char[]>(write_length);
642 QuicStream* stream = GetOrCreateStream(id);
643 QUICHE_DCHECK(stream);
644 QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
645 stream->WriteStreamData(offset, write_length, &writer);
646 } else {
647 QUICHE_DCHECK(state != NO_FIN);
648 }
649 return QuicConsumedData(write_length, state != NO_FIN);
650 }
651
MockQuicCryptoStream(QuicSession * session)652 MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session)
653 : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {}
654
~MockQuicCryptoStream()655 MockQuicCryptoStream::~MockQuicCryptoStream() {}
656
EarlyDataReason() const657 ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const {
658 return ssl_early_data_unknown;
659 }
660
one_rtt_keys_available() const661 bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; }
662
663 const QuicCryptoNegotiatedParameters&
crypto_negotiated_params() const664 MockQuicCryptoStream::crypto_negotiated_params() const {
665 return *params_;
666 }
667
crypto_message_parser()668 CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() {
669 return &crypto_framer_;
670 }
671
MockQuicSpdySession(QuicConnection * connection)672 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection)
673 : MockQuicSpdySession(connection, true) {}
674
MockQuicSpdySession(QuicConnection * connection,bool create_mock_crypto_stream)675 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection,
676 bool create_mock_crypto_stream)
677 : QuicSpdySession(connection, nullptr, DefaultQuicConfig(),
678 connection->supported_versions()) {
679 if (create_mock_crypto_stream) {
680 crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
681 }
682
683 ON_CALL(*this, WritevData(_, _, _, _, _, _))
684 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
685
686 ON_CALL(*this, SendWindowUpdate(_, _))
687 .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
688 return QuicSpdySession::SendWindowUpdate(id, byte_offset);
689 });
690
691 ON_CALL(*this, SendBlocked(_, _))
692 .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
693 return QuicSpdySession::SendBlocked(id, byte_offset);
694 });
695
696 ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return());
697 }
698
~MockQuicSpdySession()699 MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); }
700
GetMutableCryptoStream()701 QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() {
702 return crypto_stream_.get();
703 }
704
GetCryptoStream() const705 const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const {
706 return crypto_stream_.get();
707 }
708
SetCryptoStream(QuicCryptoStream * crypto_stream)709 void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
710 crypto_stream_.reset(crypto_stream);
711 }
712
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)713 QuicConsumedData MockQuicSpdySession::ConsumeData(
714 QuicStreamId id, size_t write_length, QuicStreamOffset offset,
715 StreamSendingState state, TransmissionType /*type*/,
716 std::optional<EncryptionLevel> /*level*/) {
717 if (write_length > 0) {
718 auto buf = std::make_unique<char[]>(write_length);
719 QuicStream* stream = GetOrCreateStream(id);
720 QUICHE_DCHECK(stream);
721 QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
722 stream->WriteStreamData(offset, write_length, &writer);
723 } else {
724 QUICHE_DCHECK(state != NO_FIN);
725 }
726 return QuicConsumedData(write_length, state != NO_FIN);
727 }
728
TestQuicSpdyServerSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)729 TestQuicSpdyServerSession::TestQuicSpdyServerSession(
730 QuicConnection* connection, const QuicConfig& config,
731 const ParsedQuicVersionVector& supported_versions,
732 const QuicCryptoServerConfig* crypto_config,
733 QuicCompressedCertsCache* compressed_certs_cache)
734 : QuicServerSessionBase(config, supported_versions, connection, &visitor_,
735 &helper_, crypto_config, compressed_certs_cache) {
736 ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _))
737 .WillByDefault(testing::Return(true));
738 }
739
~TestQuicSpdyServerSession()740 TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); }
741
742 std::unique_ptr<QuicCryptoServerStreamBase>
CreateQuicCryptoServerStream(const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)743 TestQuicSpdyServerSession::CreateQuicCryptoServerStream(
744 const QuicCryptoServerConfig* crypto_config,
745 QuicCompressedCertsCache* compressed_certs_cache) {
746 return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this,
747 &helper_);
748 }
749
750 QuicCryptoServerStreamBase*
GetMutableCryptoStream()751 TestQuicSpdyServerSession::GetMutableCryptoStream() {
752 return QuicServerSessionBase::GetMutableCryptoStream();
753 }
754
GetCryptoStream() const755 const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream()
756 const {
757 return QuicServerSessionBase::GetCryptoStream();
758 }
759
TestQuicSpdyClientSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicServerId & server_id,QuicCryptoClientConfig * crypto_config,std::optional<QuicSSLConfig> ssl_config)760 TestQuicSpdyClientSession::TestQuicSpdyClientSession(
761 QuicConnection* connection, const QuicConfig& config,
762 const ParsedQuicVersionVector& supported_versions,
763 const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config,
764 std::optional<QuicSSLConfig> ssl_config)
765 : QuicSpdyClientSessionBase(connection, nullptr, config,
766 supported_versions),
767 ssl_config_(std::move(ssl_config)) {
768 // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
769 // calls in tests and set |has_application_state| to true.
770 crypto_stream_ = std::make_unique<QuicCryptoClientStream>(
771 server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
772 crypto_config, this, /*has_application_state = */ false);
773 Initialize();
774 ON_CALL(*this, OnConfigNegotiated())
775 .WillByDefault(
776 Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
777 }
778
~TestQuicSpdyClientSession()779 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
780
GetMutableCryptoStream()781 QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() {
782 return crypto_stream_.get();
783 }
784
GetCryptoStream() const785 const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream()
786 const {
787 return crypto_stream_.get();
788 }
789
RealOnConfigNegotiated()790 void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
791 QuicSpdyClientSessionBase::OnConfigNegotiated();
792 }
793
MockPacketWriter()794 MockPacketWriter::MockPacketWriter() {
795 ON_CALL(*this, GetMaxPacketSize(_))
796 .WillByDefault(testing::Return(kMaxOutgoingPacketSize));
797 ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false));
798 ON_CALL(*this, GetNextWriteLocation(_, _))
799 .WillByDefault(testing::Return(QuicPacketBuffer()));
800 ON_CALL(*this, Flush())
801 .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0)));
802 ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false));
803 }
804
~MockPacketWriter()805 MockPacketWriter::~MockPacketWriter() {}
806
MockSendAlgorithm()807 MockSendAlgorithm::MockSendAlgorithm() {
808 ON_CALL(*this, PacingRate(_))
809 .WillByDefault(testing::Return(QuicBandwidth::Zero()));
810 ON_CALL(*this, BandwidthEstimate())
811 .WillByDefault(testing::Return(QuicBandwidth::Zero()));
812 }
813
~MockSendAlgorithm()814 MockSendAlgorithm::~MockSendAlgorithm() {}
815
MockLossAlgorithm()816 MockLossAlgorithm::MockLossAlgorithm() {}
817
~MockLossAlgorithm()818 MockLossAlgorithm::~MockLossAlgorithm() {}
819
MockAckListener()820 MockAckListener::MockAckListener() {}
821
~MockAckListener()822 MockAckListener::~MockAckListener() {}
823
MockNetworkChangeVisitor()824 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {}
825
~MockNetworkChangeVisitor()826 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {}
827
TestPeerIPAddress()828 QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); }
829
QuicVersionMax()830 ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); }
831
QuicVersionMin()832 ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); }
833
DisableQuicVersionsWithTls()834 void DisableQuicVersionsWithTls() {
835 for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) {
836 QuicDisableVersion(version);
837 }
838 }
839
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data)840 QuicEncryptedPacket* ConstructEncryptedPacket(
841 QuicConnectionId destination_connection_id,
842 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
843 uint64_t packet_number, const std::string& data) {
844 return ConstructEncryptedPacket(
845 destination_connection_id, source_connection_id, version_flag, reset_flag,
846 packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT,
847 PACKET_4BYTE_PACKET_NUMBER);
848 }
849
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length)850 QuicEncryptedPacket* ConstructEncryptedPacket(
851 QuicConnectionId destination_connection_id,
852 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
853 uint64_t packet_number, const std::string& data,
854 QuicConnectionIdIncluded destination_connection_id_included,
855 QuicConnectionIdIncluded source_connection_id_included,
856 QuicPacketNumberLength packet_number_length) {
857 return ConstructEncryptedPacket(
858 destination_connection_id, source_connection_id, version_flag, reset_flag,
859 packet_number, data, destination_connection_id_included,
860 source_connection_id_included, packet_number_length, nullptr);
861 }
862
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)863 QuicEncryptedPacket* ConstructEncryptedPacket(
864 QuicConnectionId destination_connection_id,
865 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
866 uint64_t packet_number, const std::string& data,
867 QuicConnectionIdIncluded destination_connection_id_included,
868 QuicConnectionIdIncluded source_connection_id_included,
869 QuicPacketNumberLength packet_number_length,
870 ParsedQuicVersionVector* versions) {
871 return ConstructEncryptedPacket(
872 destination_connection_id, source_connection_id, version_flag, reset_flag,
873 packet_number, data, false, destination_connection_id_included,
874 source_connection_id_included, packet_number_length, versions,
875 Perspective::IS_CLIENT);
876 }
877
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)878 QuicEncryptedPacket* ConstructEncryptedPacket(
879 QuicConnectionId destination_connection_id,
880 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
881 uint64_t packet_number, const std::string& data, bool full_padding,
882 QuicConnectionIdIncluded destination_connection_id_included,
883 QuicConnectionIdIncluded source_connection_id_included,
884 QuicPacketNumberLength packet_number_length,
885 ParsedQuicVersionVector* versions) {
886 return ConstructEncryptedPacket(
887 destination_connection_id, source_connection_id, version_flag, reset_flag,
888 packet_number, data, full_padding, destination_connection_id_included,
889 source_connection_id_included, packet_number_length, versions,
890 Perspective::IS_CLIENT);
891 }
892
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions,Perspective perspective)893 QuicEncryptedPacket* ConstructEncryptedPacket(
894 QuicConnectionId destination_connection_id,
895 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
896 uint64_t packet_number, const std::string& data, bool full_padding,
897 QuicConnectionIdIncluded destination_connection_id_included,
898 QuicConnectionIdIncluded source_connection_id_included,
899 QuicPacketNumberLength packet_number_length,
900 ParsedQuicVersionVector* versions, Perspective perspective) {
901 QuicPacketHeader header;
902 header.destination_connection_id = destination_connection_id;
903 header.destination_connection_id_included =
904 destination_connection_id_included;
905 header.source_connection_id = source_connection_id;
906 header.source_connection_id_included = source_connection_id_included;
907 header.version_flag = version_flag;
908 header.reset_flag = reset_flag;
909 header.packet_number_length = packet_number_length;
910 header.packet_number = QuicPacketNumber(packet_number);
911 ParsedQuicVersionVector supported_versions = CurrentSupportedVersions();
912 if (!versions) {
913 versions = &supported_versions;
914 }
915 EXPECT_FALSE(versions->empty());
916 ParsedQuicVersion version = (*versions)[0];
917 if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
918 version_flag) {
919 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
920 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
921 }
922
923 QuicFrames frames;
924 QuicFramer framer(*versions, QuicTime::Zero(), perspective,
925 kQuicDefaultConnectionIdLength);
926 framer.SetInitialObfuscators(destination_connection_id);
927 EncryptionLevel level =
928 header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
929 if (level != ENCRYPTION_INITIAL) {
930 framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
931 }
932 if (!QuicVersionUsesCryptoFrames(version.transport_version)) {
933 QuicFrame frame(
934 QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version),
935 false, 0, absl::string_view(data)));
936 frames.push_back(frame);
937 } else {
938 QuicFrame frame(new QuicCryptoFrame(level, 0, data));
939 frames.push_back(frame);
940 }
941 if (full_padding) {
942 frames.push_back(QuicFrame(QuicPaddingFrame(-1)));
943 } else {
944 // We need a minimum number of bytes of encrypted payload. This will
945 // guarantee that we have at least that much. (It ignores the overhead of
946 // the stream/crypto framing, so it overpads slightly.)
947 size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize(
948 version, packet_number_length);
949 if (data.length() < min_plaintext_size) {
950 size_t padding_length = min_plaintext_size - data.length();
951 frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
952 }
953 }
954
955 std::unique_ptr<QuicPacket> packet(
956 BuildUnsizedDataPacket(&framer, header, frames));
957 EXPECT_TRUE(packet != nullptr);
958 char* buffer = new char[kMaxOutgoingPacketSize];
959 size_t encrypted_length =
960 framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
961 buffer, kMaxOutgoingPacketSize);
962 EXPECT_NE(0u, encrypted_length);
963 DeleteFrames(&frames);
964 return new QuicEncryptedPacket(buffer, encrypted_length, true);
965 }
966
GetUndecryptableEarlyPacket(const ParsedQuicVersion & version,const QuicConnectionId & server_connection_id)967 std::unique_ptr<QuicEncryptedPacket> GetUndecryptableEarlyPacket(
968 const ParsedQuicVersion& version,
969 const QuicConnectionId& server_connection_id) {
970 QuicPacketHeader header;
971 header.destination_connection_id = server_connection_id;
972 header.destination_connection_id_included = CONNECTION_ID_PRESENT;
973 header.source_connection_id = EmptyQuicConnectionId();
974 header.source_connection_id_included = CONNECTION_ID_PRESENT;
975 if (!version.SupportsClientConnectionIds()) {
976 header.source_connection_id_included = CONNECTION_ID_ABSENT;
977 }
978 header.version_flag = true;
979 header.reset_flag = false;
980 header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
981 header.packet_number = QuicPacketNumber(33);
982 header.long_packet_type = ZERO_RTT_PROTECTED;
983 if (version.HasLongHeaderLengths()) {
984 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
985 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
986 }
987
988 QuicFrames frames;
989 frames.push_back(QuicFrame(QuicPingFrame()));
990 frames.push_back(QuicFrame(QuicPaddingFrame(100)));
991 QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT,
992 kQuicDefaultConnectionIdLength);
993 framer.SetInitialObfuscators(server_connection_id);
994
995 framer.SetEncrypter(ENCRYPTION_ZERO_RTT,
996 std::make_unique<TaggingEncrypter>(ENCRYPTION_ZERO_RTT));
997 std::unique_ptr<QuicPacket> packet(
998 BuildUnsizedDataPacket(&framer, header, frames));
999 EXPECT_TRUE(packet != nullptr);
1000 char* buffer = new char[kMaxOutgoingPacketSize];
1001 size_t encrypted_length =
1002 framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet,
1003 buffer, kMaxOutgoingPacketSize);
1004 EXPECT_NE(0u, encrypted_length);
1005 DeleteFrames(&frames);
1006 return std::make_unique<QuicEncryptedPacket>(buffer, encrypted_length,
1007 /*owns_buffer=*/true);
1008 }
1009
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time)1010 QuicReceivedPacket* ConstructReceivedPacket(
1011 const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
1012 char* buffer = new char[encrypted_packet.length()];
1013 memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
1014 return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
1015 true);
1016 }
1017
ConstructMisFramedEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersion version,Perspective perspective)1018 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
1019 QuicConnectionId destination_connection_id,
1020 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
1021 uint64_t packet_number, const std::string& data,
1022 QuicConnectionIdIncluded destination_connection_id_included,
1023 QuicConnectionIdIncluded source_connection_id_included,
1024 QuicPacketNumberLength packet_number_length, ParsedQuicVersion version,
1025 Perspective perspective) {
1026 QuicPacketHeader header;
1027 header.destination_connection_id = destination_connection_id;
1028 header.destination_connection_id_included =
1029 destination_connection_id_included;
1030 header.source_connection_id = source_connection_id;
1031 header.source_connection_id_included = source_connection_id_included;
1032 header.version_flag = version_flag;
1033 header.reset_flag = reset_flag;
1034 header.packet_number_length = packet_number_length;
1035 header.packet_number = QuicPacketNumber(packet_number);
1036 if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
1037 version_flag) {
1038 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
1039 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
1040 }
1041 QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data)));
1042 QuicFrames frames;
1043 frames.push_back(frame);
1044 QuicFramer framer({version}, QuicTime::Zero(), perspective,
1045 kQuicDefaultConnectionIdLength);
1046 framer.SetInitialObfuscators(destination_connection_id);
1047 EncryptionLevel level =
1048 version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
1049 if (level != ENCRYPTION_INITIAL) {
1050 framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
1051 }
1052 // We need a minimum of 7 bytes of encrypted payload. This will guarantee that
1053 // we have at least that much. (It ignores the overhead of the stream/crypto
1054 // framing, so it overpads slightly.)
1055 if (data.length() < 7) {
1056 size_t padding_length = 7 - data.length();
1057 frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
1058 }
1059
1060 std::unique_ptr<QuicPacket> packet(
1061 BuildUnsizedDataPacket(&framer, header, frames));
1062 EXPECT_TRUE(packet != nullptr);
1063
1064 // Now set the frame type to 0x1F, which is an invalid frame type.
1065 reinterpret_cast<unsigned char*>(
1066 packet->mutable_data())[GetStartOfEncryptedData(
1067 framer.transport_version(),
1068 GetIncludedDestinationConnectionIdLength(header),
1069 GetIncludedSourceConnectionIdLength(header), version_flag,
1070 false /* no diversification nonce */, packet_number_length,
1071 header.retry_token_length_length, 0, header.length_length)] = 0x1F;
1072
1073 char* buffer = new char[kMaxOutgoingPacketSize];
1074 size_t encrypted_length =
1075 framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
1076 buffer, kMaxOutgoingPacketSize);
1077 EXPECT_NE(0u, encrypted_length);
1078 return new QuicEncryptedPacket(buffer, encrypted_length, true);
1079 }
1080
DefaultQuicConfig()1081 QuicConfig DefaultQuicConfig() {
1082 QuicConfig config;
1083 config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(
1084 kInitialStreamFlowControlWindowForTest);
1085 config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(
1086 kInitialStreamFlowControlWindowForTest);
1087 config.SetInitialMaxStreamDataBytesUnidirectionalToSend(
1088 kInitialStreamFlowControlWindowForTest);
1089 config.SetInitialStreamFlowControlWindowToSend(
1090 kInitialStreamFlowControlWindowForTest);
1091 config.SetInitialSessionFlowControlWindowToSend(
1092 kInitialSessionFlowControlWindowForTest);
1093 QuicConfigPeer::SetReceivedMaxBidirectionalStreams(
1094 &config, kDefaultMaxStreamsPerConnection);
1095 // Default enable NSTP.
1096 // This is unnecessary for versions > 44
1097 if (!config.HasClientSentConnectionOption(quic::kNSTP,
1098 quic::Perspective::IS_CLIENT)) {
1099 quic::QuicTagVector connection_options;
1100 connection_options.push_back(quic::kNSTP);
1101 config.SetConnectionOptionsToSend(connection_options);
1102 }
1103 return config;
1104 }
1105
SupportedVersions(ParsedQuicVersion version)1106 ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) {
1107 ParsedQuicVersionVector versions;
1108 versions.push_back(version);
1109 return versions;
1110 }
1111
MockQuicConnectionDebugVisitor()1112 MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {}
1113
~MockQuicConnectionDebugVisitor()1114 MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {}
1115
MockReceivedPacketManager(QuicConnectionStats * stats)1116 MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats)
1117 : QuicReceivedPacketManager(stats) {}
1118
~MockReceivedPacketManager()1119 MockReceivedPacketManager::~MockReceivedPacketManager() {}
1120
MockPacketCreatorDelegate()1121 MockPacketCreatorDelegate::MockPacketCreatorDelegate() {}
~MockPacketCreatorDelegate()1122 MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {}
1123
MockSessionNotifier()1124 MockSessionNotifier::MockSessionNotifier() {}
~MockSessionNotifier()1125 MockSessionNotifier::~MockSessionNotifier() {}
1126
1127 // static
1128 QuicCryptoClientStream::HandshakerInterface*
GetHandshaker(QuicCryptoClientStream * stream)1129 QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) {
1130 return stream->handshaker_.get();
1131 }
1132
CreateClientSessionForTest(QuicServerId server_id,QuicTime::Delta connection_start_time,const ParsedQuicVersionVector & supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoClientConfig * crypto_client_config,PacketSavingConnection ** client_connection,TestQuicSpdyClientSession ** client_session)1133 void CreateClientSessionForTest(
1134 QuicServerId server_id, QuicTime::Delta connection_start_time,
1135 const ParsedQuicVersionVector& supported_versions,
1136 MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1137 QuicCryptoClientConfig* crypto_client_config,
1138 PacketSavingConnection** client_connection,
1139 TestQuicSpdyClientSession** client_session) {
1140 QUICHE_CHECK(crypto_client_config);
1141 QUICHE_CHECK(client_connection);
1142 QUICHE_CHECK(client_session);
1143 QUICHE_CHECK(!connection_start_time.IsZero())
1144 << "Connections must start at non-zero times, otherwise the "
1145 << "strike-register will be unhappy.";
1146
1147 QuicConfig config = DefaultQuicConfig();
1148 *client_connection = new PacketSavingConnection(
1149 helper, alarm_factory, Perspective::IS_CLIENT, supported_versions);
1150 *client_session = new TestQuicSpdyClientSession(*client_connection, config,
1151 supported_versions, server_id,
1152 crypto_client_config);
1153 (*client_connection)->AdvanceTime(connection_start_time);
1154 }
1155
CreateServerSessionForTest(QuicServerId,QuicTime::Delta connection_start_time,ParsedQuicVersionVector supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoServerConfig * server_crypto_config,QuicCompressedCertsCache * compressed_certs_cache,PacketSavingConnection ** server_connection,TestQuicSpdyServerSession ** server_session)1156 void CreateServerSessionForTest(
1157 QuicServerId /*server_id*/, QuicTime::Delta connection_start_time,
1158 ParsedQuicVersionVector supported_versions,
1159 MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1160 QuicCryptoServerConfig* server_crypto_config,
1161 QuicCompressedCertsCache* compressed_certs_cache,
1162 PacketSavingConnection** server_connection,
1163 TestQuicSpdyServerSession** server_session) {
1164 QUICHE_CHECK(server_crypto_config);
1165 QUICHE_CHECK(server_connection);
1166 QUICHE_CHECK(server_session);
1167 QUICHE_CHECK(!connection_start_time.IsZero())
1168 << "Connections must start at non-zero times, otherwise the "
1169 << "strike-register will be unhappy.";
1170
1171 *server_connection =
1172 new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER,
1173 ParsedVersionOfIndex(supported_versions, 0));
1174 *server_session = new TestQuicSpdyServerSession(
1175 *server_connection, DefaultQuicConfig(), supported_versions,
1176 server_crypto_config, compressed_certs_cache);
1177 (*server_session)->Initialize();
1178
1179 // We advance the clock initially because the default time is zero and the
1180 // strike register worries that we've just overflowed a uint32_t time.
1181 (*server_connection)->AdvanceTime(connection_start_time);
1182 }
1183
GetNthClientInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1184 QuicStreamId GetNthClientInitiatedBidirectionalStreamId(
1185 QuicTransportVersion version, int n) {
1186 int num = n;
1187 if (!VersionUsesHttp3(version)) {
1188 num++;
1189 }
1190 return QuicUtils::GetFirstBidirectionalStreamId(version,
1191 Perspective::IS_CLIENT) +
1192 QuicUtils::StreamIdDelta(version) * num;
1193 }
1194
GetNthServerInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1195 QuicStreamId GetNthServerInitiatedBidirectionalStreamId(
1196 QuicTransportVersion version, int n) {
1197 return QuicUtils::GetFirstBidirectionalStreamId(version,
1198 Perspective::IS_SERVER) +
1199 QuicUtils::StreamIdDelta(version) * n;
1200 }
1201
GetNthServerInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1202 QuicStreamId GetNthServerInitiatedUnidirectionalStreamId(
1203 QuicTransportVersion version, int n) {
1204 return QuicUtils::GetFirstUnidirectionalStreamId(version,
1205 Perspective::IS_SERVER) +
1206 QuicUtils::StreamIdDelta(version) * n;
1207 }
1208
GetNthClientInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1209 QuicStreamId GetNthClientInitiatedUnidirectionalStreamId(
1210 QuicTransportVersion version, int n) {
1211 return QuicUtils::GetFirstUnidirectionalStreamId(version,
1212 Perspective::IS_CLIENT) +
1213 QuicUtils::StreamIdDelta(version) * n;
1214 }
1215
DetermineStreamType(QuicStreamId id,ParsedQuicVersion version,Perspective perspective,bool is_incoming,StreamType default_type)1216 StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version,
1217 Perspective perspective, bool is_incoming,
1218 StreamType default_type) {
1219 return version.HasIetfQuicFrames()
1220 ? QuicUtils::GetStreamType(id, perspective, is_incoming, version)
1221 : default_type;
1222 }
1223
MemSliceFromString(absl::string_view data)1224 quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) {
1225 if (data.empty()) {
1226 return quiche::QuicheMemSlice();
1227 }
1228
1229 static quiche::SimpleBufferAllocator* allocator =
1230 new quiche::SimpleBufferAllocator();
1231 return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data));
1232 }
1233
EncryptPacket(uint64_t,absl::string_view,absl::string_view plaintext,char * output,size_t * output_length,size_t max_output_length)1234 bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/,
1235 absl::string_view /*associated_data*/,
1236 absl::string_view plaintext, char* output,
1237 size_t* output_length,
1238 size_t max_output_length) {
1239 const size_t len = plaintext.size() + kTagSize;
1240 if (max_output_length < len) {
1241 return false;
1242 }
1243 // Memmove is safe for inplace encryption.
1244 memmove(output, plaintext.data(), plaintext.size());
1245 output += plaintext.size();
1246 memset(output, tag_, kTagSize);
1247 *output_length = len;
1248 return true;
1249 }
1250
DecryptPacket(uint64_t,absl::string_view,absl::string_view ciphertext,char * output,size_t * output_length,size_t)1251 bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/,
1252 absl::string_view /*associated_data*/,
1253 absl::string_view ciphertext, char* output,
1254 size_t* output_length,
1255 size_t /*max_output_length*/) {
1256 if (ciphertext.size() < kTagSize) {
1257 return false;
1258 }
1259 if (!CheckTag(ciphertext, GetTag(ciphertext))) {
1260 return false;
1261 }
1262 *output_length = ciphertext.size() - kTagSize;
1263 memcpy(output, ciphertext.data(), *output_length);
1264 return true;
1265 }
1266
CheckTag(absl::string_view ciphertext,uint8_t tag)1267 bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) {
1268 for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
1269 if (ciphertext.data()[i] != tag) {
1270 return false;
1271 }
1272 }
1273
1274 return true;
1275 }
1276
TestPacketWriter(ParsedQuicVersion version,MockClock * clock,Perspective perspective)1277 TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock,
1278 Perspective perspective)
1279 : version_(version),
1280 framer_(SupportedVersions(version_),
1281 QuicUtils::InvertPerspective(perspective)),
1282 clock_(clock) {
1283 QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
1284 TestConnectionId());
1285 framer_.framer()->SetInitialObfuscators(TestConnectionId());
1286
1287 for (int i = 0; i < 128; ++i) {
1288 PacketBuffer* p = new PacketBuffer();
1289 packet_buffer_pool_.push_back(p);
1290 packet_buffer_pool_index_[p->buffer] = p;
1291 packet_buffer_free_list_.push_back(p);
1292 }
1293 }
1294
~TestPacketWriter()1295 TestPacketWriter::~TestPacketWriter() {
1296 EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
1297 << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
1298 << " out of " << packet_buffer_pool_.size()
1299 << " packet buffers have been leaked.";
1300 for (auto p : packet_buffer_pool_) {
1301 delete p;
1302 }
1303 }
1304
WritePacket(const char * buffer,size_t buf_len,const QuicIpAddress & self_address,const QuicSocketAddress & peer_address,PerPacketOptions *,const QuicPacketWriterParams & params)1305 WriteResult TestPacketWriter::WritePacket(
1306 const char* buffer, size_t buf_len, const QuicIpAddress& self_address,
1307 const QuicSocketAddress& peer_address, PerPacketOptions* /*options*/,
1308 const QuicPacketWriterParams& params) {
1309 last_write_source_address_ = self_address;
1310 last_write_peer_address_ = peer_address;
1311 // If the buffer is allocated from the pool, return it back to the pool.
1312 // Note the buffer content doesn't change.
1313 if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
1314 packet_buffer_pool_index_.end()) {
1315 FreePacketBuffer(buffer);
1316 }
1317
1318 QuicEncryptedPacket packet(buffer, buf_len);
1319 ++packets_write_attempts_;
1320
1321 if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
1322 final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
1323 memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
1324 sizeof(final_bytes_of_last_packet_));
1325 }
1326 if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
1327 framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
1328 std::make_unique<TaggingDecrypter>());
1329 framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
1330 std::make_unique<TaggingDecrypter>());
1331 framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
1332 std::make_unique<TaggingDecrypter>());
1333 } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel(
1334 ENCRYPTION_FORWARD_SECURE) &&
1335 !framer_.framer()->HasDecrypterOfEncryptionLevel(
1336 ENCRYPTION_ZERO_RTT)) {
1337 framer_.framer()->SetAlternativeDecrypter(
1338 ENCRYPTION_FORWARD_SECURE,
1339 std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE),
1340 false);
1341 }
1342 EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet))
1343 << framer_.framer()->detailed_error() << " perspective "
1344 << framer_.framer()->perspective();
1345 next_packet_processable_ = true;
1346 if (block_on_next_write_) {
1347 write_blocked_ = true;
1348 block_on_next_write_ = false;
1349 }
1350 if (next_packet_too_large_) {
1351 next_packet_too_large_ = false;
1352 return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1353 }
1354 if (always_get_packet_too_large_) {
1355 return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1356 }
1357 if (IsWriteBlocked()) {
1358 return WriteResult(is_write_blocked_data_buffered_
1359 ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
1360 : WRITE_STATUS_BLOCKED,
1361 0);
1362 }
1363
1364 if (ShouldWriteFail()) {
1365 return WriteResult(WRITE_STATUS_ERROR, write_error_code_);
1366 }
1367
1368 last_packet_size_ = packet.length();
1369 total_bytes_written_ += packet.length();
1370 last_packet_header_ = framer_.header();
1371 if (!framer_.connection_close_frames().empty()) {
1372 ++connection_close_packets_;
1373 }
1374 if (!write_pause_time_delta_.IsZero()) {
1375 clock_->AdvanceTime(write_pause_time_delta_);
1376 }
1377 if (is_batch_mode_) {
1378 bytes_buffered_ += last_packet_size_;
1379 return WriteResult(WRITE_STATUS_OK, 0);
1380 }
1381 last_ecn_sent_ = params.ecn_codepoint;
1382 return WriteResult(WRITE_STATUS_OK, last_packet_size_);
1383 }
1384
GetNextWriteLocation(const QuicIpAddress &,const QuicSocketAddress &)1385 QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
1386 const QuicIpAddress& /*self_address*/,
1387 const QuicSocketAddress& /*peer_address*/) {
1388 return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
1389 }
1390
Flush()1391 WriteResult TestPacketWriter::Flush() {
1392 flush_attempts_++;
1393 if (block_on_next_flush_) {
1394 block_on_next_flush_ = false;
1395 SetWriteBlocked();
1396 return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
1397 }
1398 if (write_should_fail_) {
1399 return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
1400 }
1401 int bytes_flushed = bytes_buffered_;
1402 bytes_buffered_ = 0;
1403 return WriteResult(WRITE_STATUS_OK, bytes_flushed);
1404 }
1405
AllocPacketBuffer()1406 char* TestPacketWriter::AllocPacketBuffer() {
1407 PacketBuffer* p = packet_buffer_free_list_.front();
1408 EXPECT_FALSE(p->in_use);
1409 p->in_use = true;
1410 packet_buffer_free_list_.pop_front();
1411 return p->buffer;
1412 }
1413
FreePacketBuffer(const char * buffer)1414 void TestPacketWriter::FreePacketBuffer(const char* buffer) {
1415 auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
1416 ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
1417 PacketBuffer* p = iter->second;
1418 ASSERT_TRUE(p->in_use);
1419 p->in_use = false;
1420 packet_buffer_free_list_.push_back(p);
1421 }
1422
WriteServerVersionNegotiationProbeResponse(char * packet_bytes,size_t * packet_length_out,const char * source_connection_id_bytes,uint8_t source_connection_id_length)1423 bool WriteServerVersionNegotiationProbeResponse(
1424 char* packet_bytes, size_t* packet_length_out,
1425 const char* source_connection_id_bytes,
1426 uint8_t source_connection_id_length) {
1427 if (packet_bytes == nullptr) {
1428 QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes";
1429 return false;
1430 }
1431 if (packet_length_out == nullptr) {
1432 QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out";
1433 return false;
1434 }
1435 QuicConnectionId source_connection_id(source_connection_id_bytes,
1436 source_connection_id_length);
1437 std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
1438 QuicFramer::BuildVersionNegotiationPacket(
1439 source_connection_id, EmptyQuicConnectionId(),
1440 /*ietf_quic=*/true, /*use_length_prefix=*/true,
1441 ParsedQuicVersionVector{});
1442 if (!encrypted_packet) {
1443 QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet";
1444 return false;
1445 }
1446 if (*packet_length_out < encrypted_packet->length()) {
1447 QUIC_BUG(quic_bug_10256_4)
1448 << "Invalid *packet_length_out " << *packet_length_out << " < "
1449 << encrypted_packet->length();
1450 return false;
1451 }
1452 *packet_length_out = encrypted_packet->length();
1453 memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
1454 return true;
1455 }
1456
ParseClientVersionNegotiationProbePacket(const char * packet_bytes,size_t packet_length,char * destination_connection_id_bytes,uint8_t * destination_connection_id_length_out)1457 bool ParseClientVersionNegotiationProbePacket(
1458 const char* packet_bytes, size_t packet_length,
1459 char* destination_connection_id_bytes,
1460 uint8_t* destination_connection_id_length_out) {
1461 if (packet_bytes == nullptr) {
1462 QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes";
1463 return false;
1464 }
1465 if (packet_length < kMinPacketSizeForVersionNegotiation ||
1466 packet_length > 65535) {
1467 QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length";
1468 return false;
1469 }
1470 if (destination_connection_id_bytes == nullptr) {
1471 QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes";
1472 return false;
1473 }
1474 if (destination_connection_id_length_out == nullptr) {
1475 QUIC_BUG(quic_bug_10256_8)
1476 << "Invalid destination_connection_id_length_out";
1477 return false;
1478 }
1479
1480 QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
1481 PacketHeaderFormat format;
1482 QuicLongHeaderType long_packet_type;
1483 bool version_present, has_length_prefix;
1484 QuicVersionLabel version_label;
1485 ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
1486 QuicConnectionId destination_connection_id, source_connection_id;
1487 std::optional<absl::string_view> retry_token;
1488 std::string detailed_error;
1489 QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
1490 encrypted_packet,
1491 /*expected_destination_connection_id_length=*/0, &format,
1492 &long_packet_type, &version_present, &has_length_prefix, &version_label,
1493 &parsed_version, &destination_connection_id, &source_connection_id,
1494 &retry_token, &detailed_error);
1495 if (error != QUIC_NO_ERROR) {
1496 QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error;
1497 return false;
1498 }
1499 if (!version_present) {
1500 QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header";
1501 return false;
1502 }
1503 if (*destination_connection_id_length_out <
1504 destination_connection_id.length()) {
1505 QUIC_BUG(quic_bug_10256_11)
1506 << "destination_connection_id_length_out too small";
1507 return false;
1508 }
1509 *destination_connection_id_length_out = destination_connection_id.length();
1510 memcpy(destination_connection_id_bytes, destination_connection_id.data(),
1511 *destination_connection_id_length_out);
1512 return true;
1513 }
1514
1515 } // namespace test
1516 } // namespace quic
1517