1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/test_tools/quic_test_utils.h"
6
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <utility>
11
12 #include "absl/base/macros.h"
13 #include "absl/strings/string_view.h"
14 #include "openssl/chacha.h"
15 #include "openssl/sha.h"
16 #include "quiche/quic/core/crypto/crypto_framer.h"
17 #include "quiche/quic/core/crypto/crypto_handshake.h"
18 #include "quiche/quic/core/crypto/crypto_utils.h"
19 #include "quiche/quic/core/crypto/null_decrypter.h"
20 #include "quiche/quic/core/crypto/null_encrypter.h"
21 #include "quiche/quic/core/crypto/quic_decrypter.h"
22 #include "quiche/quic/core/crypto/quic_encrypter.h"
23 #include "quiche/quic/core/http/quic_spdy_client_session.h"
24 #include "quiche/quic/core/quic_config.h"
25 #include "quiche/quic/core/quic_data_writer.h"
26 #include "quiche/quic/core/quic_framer.h"
27 #include "quiche/quic/core/quic_packet_creator.h"
28 #include "quiche/quic/core/quic_types.h"
29 #include "quiche/quic/core/quic_utils.h"
30 #include "quiche/quic/core/quic_versions.h"
31 #include "quiche/quic/platform/api/quic_flags.h"
32 #include "quiche/quic/platform/api/quic_logging.h"
33 #include "quiche/quic/test_tools/crypto_test_utils.h"
34 #include "quiche/quic/test_tools/quic_config_peer.h"
35 #include "quiche/quic/test_tools/quic_connection_peer.h"
36 #include "quiche/common/quiche_buffer_allocator.h"
37 #include "quiche/common/quiche_endian.h"
38 #include "quiche/common/simple_buffer_allocator.h"
39 #include "quiche/spdy/core/spdy_frame_builder.h"
40
41 using testing::_;
42 using testing::Invoke;
43
44 namespace quic {
45 namespace test {
46
TestConnectionId()47 QuicConnectionId TestConnectionId() {
48 // Chosen by fair dice roll.
49 // Guaranteed to be random.
50 return TestConnectionId(42);
51 }
52
TestConnectionId(uint64_t connection_number)53 QuicConnectionId TestConnectionId(uint64_t connection_number) {
54 const uint64_t connection_id64_net =
55 quiche::QuicheEndian::HostToNet64(connection_number);
56 return QuicConnectionId(reinterpret_cast<const char*>(&connection_id64_net),
57 sizeof(connection_id64_net));
58 }
59
TestConnectionIdNineBytesLong(uint64_t connection_number)60 QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) {
61 const uint64_t connection_number_net =
62 quiche::QuicheEndian::HostToNet64(connection_number);
63 char connection_id_bytes[9] = {};
64 static_assert(
65 sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net),
66 "bad lengths");
67 memcpy(connection_id_bytes + 1, &connection_number_net,
68 sizeof(connection_number_net));
69 return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
70 }
71
TestConnectionIdToUInt64(QuicConnectionId connection_id)72 uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) {
73 QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength);
74 uint64_t connection_id64_net = 0;
75 memcpy(&connection_id64_net, connection_id.data(),
76 std::min<size_t>(static_cast<size_t>(connection_id.length()),
77 sizeof(connection_id64_net)));
78 return quiche::QuicheEndian::NetToHost64(connection_id64_net);
79 }
80
// Returns a fixed 16-byte token (0x90 through 0x9F) for use in tests.
std::vector<uint8_t> CreateStatelessResetTokenForTest() {
  static constexpr uint8_t kTokenBytes[16] = {
      0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
      0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F};
  std::vector<uint8_t> token(std::begin(kTokenBytes), std::end(kTokenBytes));
  return token;
}
89
// Returns the hostname used by tests.
std::string TestHostname() {
  return std::string("test.example.com");
}
91
TestServerId()92 QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); }
93
InitAckFrame(const std::vector<QuicAckBlock> & ack_blocks)94 QuicAckFrame InitAckFrame(const std::vector<QuicAckBlock>& ack_blocks) {
95 QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
96
97 QuicAckFrame ack;
98 QuicPacketNumber end_of_previous_block(1);
99 for (const QuicAckBlock& block : ack_blocks) {
100 QUICHE_DCHECK_GE(block.start, end_of_previous_block);
101 QUICHE_DCHECK_GT(block.limit, block.start);
102 ack.packets.AddRange(block.start, block.limit);
103 end_of_previous_block = block.limit;
104 }
105
106 ack.largest_acked = ack.packets.Max();
107
108 return ack;
109 }
110
InitAckFrame(uint64_t largest_acked)111 QuicAckFrame InitAckFrame(uint64_t largest_acked) {
112 return InitAckFrame(QuicPacketNumber(largest_acked));
113 }
114
InitAckFrame(QuicPacketNumber largest_acked)115 QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) {
116 return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}});
117 }
118
MakeAckFrameWithAckBlocks(size_t num_ack_blocks,uint64_t least_unacked)119 QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks,
120 uint64_t least_unacked) {
121 QuicAckFrame ack;
122 ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked);
123 // Add enough received packets to get num_ack_blocks ack blocks.
124 for (QuicPacketNumber i = QuicPacketNumber(2);
125 i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) {
126 ack.packets.Add(i + least_unacked);
127 }
128 return ack;
129 }
130
MakeAckFrameWithGaps(uint64_t gap_size,size_t max_num_gaps,uint64_t largest_acked)131 QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps,
132 uint64_t largest_acked) {
133 QuicAckFrame ack;
134 ack.largest_acked = QuicPacketNumber(largest_acked);
135 ack.packets.Add(QuicPacketNumber(largest_acked));
136 for (size_t i = 0; i < max_num_gaps; ++i) {
137 if (largest_acked <= gap_size) {
138 break;
139 }
140 largest_acked -= gap_size;
141 ack.packets.Add(QuicPacketNumber(largest_acked));
142 }
143 return ack;
144 }
145
HeaderToEncryptionLevel(const QuicPacketHeader & header)146 EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) {
147 if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
148 return ENCRYPTION_FORWARD_SECURE;
149 } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) {
150 if (header.long_packet_type == HANDSHAKE) {
151 return ENCRYPTION_HANDSHAKE;
152 } else if (header.long_packet_type == ZERO_RTT_PROTECTED) {
153 return ENCRYPTION_ZERO_RTT;
154 }
155 }
156 return ENCRYPTION_INITIAL;
157 }
158
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames)159 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
160 QuicFramer* framer, const QuicPacketHeader& header,
161 const QuicFrames& frames) {
162 const size_t max_plaintext_size =
163 framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize);
164 size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header);
165 for (size_t i = 0; i < frames.size(); ++i) {
166 QUICHE_DCHECK_LE(packet_size, max_plaintext_size);
167 bool first_frame = i == 0;
168 bool last_frame = i == frames.size() - 1;
169 const size_t frame_size = framer->GetSerializedFrameLength(
170 frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
171 header.packet_number_length);
172 QUICHE_DCHECK(frame_size);
173 packet_size += frame_size;
174 }
175 return BuildUnsizedDataPacket(framer, header, frames, packet_size);
176 }
177
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames,size_t packet_size)178 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
179 QuicFramer* framer, const QuicPacketHeader& header,
180 const QuicFrames& frames, size_t packet_size) {
181 char* buffer = new char[packet_size];
182 EncryptionLevel level = HeaderToEncryptionLevel(header);
183 size_t length =
184 framer->BuildDataPacket(header, frames, buffer, packet_size, level);
185
186 if (length == 0) {
187 delete[] buffer;
188 return nullptr;
189 }
190 // Re-construct the data packet with data ownership.
191 return std::make_unique<QuicPacket>(
192 buffer, length, /* owns_buffer */ true,
193 GetIncludedDestinationConnectionIdLength(header),
194 GetIncludedSourceConnectionIdLength(header), header.version_flag,
195 header.nonce != nullptr, header.packet_number_length,
196 header.retry_token_length_length, header.retry_token.length(),
197 header.length_length);
198 }
199
Sha1Hash(absl::string_view data)200 std::string Sha1Hash(absl::string_view data) {
201 char buffer[SHA_DIGEST_LENGTH];
202 SHA1(reinterpret_cast<const uint8_t*>(data.data()), data.size(),
203 reinterpret_cast<uint8_t*>(buffer));
204 return std::string(buffer, ABSL_ARRAYSIZE(buffer));
205 }
206
ClearControlFrame(const QuicFrame & frame)207 bool ClearControlFrame(const QuicFrame& frame) {
208 DeleteFrame(&const_cast<QuicFrame&>(frame));
209 return true;
210 }
211
ClearControlFrameWithTransmissionType(const QuicFrame & frame,TransmissionType)212 bool ClearControlFrameWithTransmissionType(const QuicFrame& frame,
213 TransmissionType /*type*/) {
214 return ClearControlFrame(frame);
215 }
216
RandUint64()217 uint64_t SimpleRandom::RandUint64() {
218 uint64_t result;
219 RandBytes(&result, sizeof(result));
220 return result;
221 }
222
RandBytes(void * data,size_t len)223 void SimpleRandom::RandBytes(void* data, size_t len) {
224 uint8_t* data_bytes = reinterpret_cast<uint8_t*>(data);
225 while (len > 0) {
226 const size_t buffer_left = sizeof(buffer_) - buffer_offset_;
227 const size_t to_copy = std::min(buffer_left, len);
228 memcpy(data_bytes, buffer_ + buffer_offset_, to_copy);
229 data_bytes += to_copy;
230 buffer_offset_ += to_copy;
231 len -= to_copy;
232
233 if (buffer_offset_ == sizeof(buffer_)) {
234 FillBuffer();
235 }
236 }
237 }
238
InsecureRandBytes(void * data,size_t len)239 void SimpleRandom::InsecureRandBytes(void* data, size_t len) {
240 RandBytes(data, len);
241 }
242
InsecureRandUint64()243 uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); }
244
FillBuffer()245 void SimpleRandom::FillBuffer() {
246 uint8_t nonce[12];
247 memcpy(nonce, buffer_, sizeof(nonce));
248 CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0);
249 buffer_offset_ = 0;
250 }
251
set_seed(uint64_t seed)252 void SimpleRandom::set_seed(uint64_t seed) {
253 static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits");
254 SHA256(reinterpret_cast<const uint8_t*>(&seed), sizeof(seed), key_);
255
256 memset(buffer_, 0, sizeof(buffer_));
257 FillBuffer();
258 }
259
MockFramerVisitor()260 MockFramerVisitor::MockFramerVisitor() {
261 // By default, we want to accept packets.
262 ON_CALL(*this, OnProtocolVersionMismatch(_))
263 .WillByDefault(testing::Return(false));
264
265 // By default, we want to accept packets.
266 ON_CALL(*this, OnUnauthenticatedHeader(_))
267 .WillByDefault(testing::Return(true));
268
269 ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
270 .WillByDefault(testing::Return(true));
271
272 ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true));
273
274 ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true));
275
276 ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true));
277
278 ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true));
279
280 ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true));
281
282 ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true));
283
284 ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true));
285
286 ON_CALL(*this, OnConnectionCloseFrame(_))
287 .WillByDefault(testing::Return(true));
288
289 ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true));
290
291 ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true));
292
293 ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true));
294
295 ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true));
296 ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true));
297 ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true));
298 }
299
~MockFramerVisitor()300 MockFramerVisitor::~MockFramerVisitor() {}
301
OnProtocolVersionMismatch(ParsedQuicVersion)302 bool NoOpFramerVisitor::OnProtocolVersionMismatch(
303 ParsedQuicVersion /*version*/) {
304 return false;
305 }
306
OnUnauthenticatedPublicHeader(const QuicPacketHeader &)307 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
308 const QuicPacketHeader& /*header*/) {
309 return true;
310 }
311
OnUnauthenticatedHeader(const QuicPacketHeader &)312 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
313 const QuicPacketHeader& /*header*/) {
314 return true;
315 }
316
OnPacketHeader(const QuicPacketHeader &)317 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) {
318 return true;
319 }
320
OnCoalescedPacket(const QuicEncryptedPacket &)321 void NoOpFramerVisitor::OnCoalescedPacket(
322 const QuicEncryptedPacket& /*packet*/) {}
323
OnUndecryptablePacket(const QuicEncryptedPacket &,EncryptionLevel,bool)324 void NoOpFramerVisitor::OnUndecryptablePacket(
325 const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/,
326 bool /*has_decryption_key*/) {}
327
OnStreamFrame(const QuicStreamFrame &)328 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) {
329 return true;
330 }
331
OnCryptoFrame(const QuicCryptoFrame &)332 bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {
333 return true;
334 }
335
OnAckFrameStart(QuicPacketNumber,QuicTime::Delta)336 bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
337 QuicTime::Delta /*ack_delay_time*/) {
338 return true;
339 }
340
OnAckRange(QuicPacketNumber,QuicPacketNumber)341 bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/,
342 QuicPacketNumber /*end*/) {
343 return true;
344 }
345
OnAckTimestamp(QuicPacketNumber,QuicTime)346 bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/,
347 QuicTime /*timestamp*/) {
348 return true;
349 }
350
OnAckFrameEnd(QuicPacketNumber,const absl::optional<QuicEcnCounts> &)351 bool NoOpFramerVisitor::OnAckFrameEnd(
352 QuicPacketNumber /*start*/,
353 const absl::optional<QuicEcnCounts>& /*ecn_counts*/) {
354 return true;
355 }
356
OnStopWaitingFrame(const QuicStopWaitingFrame &)357 bool NoOpFramerVisitor::OnStopWaitingFrame(
358 const QuicStopWaitingFrame& /*frame*/) {
359 return true;
360 }
361
OnPaddingFrame(const QuicPaddingFrame &)362 bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {
363 return true;
364 }
365
OnPingFrame(const QuicPingFrame &)366 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) {
367 return true;
368 }
369
OnRstStreamFrame(const QuicRstStreamFrame &)370 bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {
371 return true;
372 }
373
OnConnectionCloseFrame(const QuicConnectionCloseFrame &)374 bool NoOpFramerVisitor::OnConnectionCloseFrame(
375 const QuicConnectionCloseFrame& /*frame*/) {
376 return true;
377 }
378
OnNewConnectionIdFrame(const QuicNewConnectionIdFrame &)379 bool NoOpFramerVisitor::OnNewConnectionIdFrame(
380 const QuicNewConnectionIdFrame& /*frame*/) {
381 return true;
382 }
383
OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame &)384 bool NoOpFramerVisitor::OnRetireConnectionIdFrame(
385 const QuicRetireConnectionIdFrame& /*frame*/) {
386 return true;
387 }
388
OnNewTokenFrame(const QuicNewTokenFrame &)389 bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {
390 return true;
391 }
392
OnStopSendingFrame(const QuicStopSendingFrame &)393 bool NoOpFramerVisitor::OnStopSendingFrame(
394 const QuicStopSendingFrame& /*frame*/) {
395 return true;
396 }
397
OnPathChallengeFrame(const QuicPathChallengeFrame &)398 bool NoOpFramerVisitor::OnPathChallengeFrame(
399 const QuicPathChallengeFrame& /*frame*/) {
400 return true;
401 }
402
OnPathResponseFrame(const QuicPathResponseFrame &)403 bool NoOpFramerVisitor::OnPathResponseFrame(
404 const QuicPathResponseFrame& /*frame*/) {
405 return true;
406 }
407
OnGoAwayFrame(const QuicGoAwayFrame &)408 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {
409 return true;
410 }
411
OnMaxStreamsFrame(const QuicMaxStreamsFrame &)412 bool NoOpFramerVisitor::OnMaxStreamsFrame(
413 const QuicMaxStreamsFrame& /*frame*/) {
414 return true;
415 }
416
OnStreamsBlockedFrame(const QuicStreamsBlockedFrame &)417 bool NoOpFramerVisitor::OnStreamsBlockedFrame(
418 const QuicStreamsBlockedFrame& /*frame*/) {
419 return true;
420 }
421
OnWindowUpdateFrame(const QuicWindowUpdateFrame &)422 bool NoOpFramerVisitor::OnWindowUpdateFrame(
423 const QuicWindowUpdateFrame& /*frame*/) {
424 return true;
425 }
426
OnBlockedFrame(const QuicBlockedFrame &)427 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {
428 return true;
429 }
430
OnMessageFrame(const QuicMessageFrame &)431 bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) {
432 return true;
433 }
434
OnHandshakeDoneFrame(const QuicHandshakeDoneFrame &)435 bool NoOpFramerVisitor::OnHandshakeDoneFrame(
436 const QuicHandshakeDoneFrame& /*frame*/) {
437 return true;
438 }
439
OnAckFrequencyFrame(const QuicAckFrequencyFrame &)440 bool NoOpFramerVisitor::OnAckFrequencyFrame(
441 const QuicAckFrequencyFrame& /*frame*/) {
442 return true;
443 }
444
IsValidStatelessResetToken(const StatelessResetToken &) const445 bool NoOpFramerVisitor::IsValidStatelessResetToken(
446 const StatelessResetToken& /*token*/) const {
447 return false;
448 }
449
MockQuicConnectionVisitor()450 MockQuicConnectionVisitor::MockQuicConnectionVisitor() {}
451
~MockQuicConnectionVisitor()452 MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {}
453
MockQuicConnectionHelper()454 MockQuicConnectionHelper::MockQuicConnectionHelper() {}
455
~MockQuicConnectionHelper()456 MockQuicConnectionHelper::~MockQuicConnectionHelper() {}
457
GetClock() const458 const QuicClock* MockQuicConnectionHelper::GetClock() const { return &clock_; }
459
GetClock()460 QuicClock* MockQuicConnectionHelper::GetClock() { return &clock_; }
461
GetRandomGenerator()462 QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() {
463 return &random_generator_;
464 }
465
CreateAlarm(QuicAlarm::Delegate * delegate)466 QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) {
467 return new MockAlarmFactory::TestAlarm(
468 QuicArenaScopedPtr<QuicAlarm::Delegate>(delegate));
469 }
470
CreateAlarm(QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,QuicConnectionArena * arena)471 QuicArenaScopedPtr<QuicAlarm> MockAlarmFactory::CreateAlarm(
472 QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,
473 QuicConnectionArena* arena) {
474 if (arena != nullptr) {
475 return arena->New<TestAlarm>(std::move(delegate));
476 } else {
477 return QuicArenaScopedPtr<TestAlarm>(new TestAlarm(std::move(delegate)));
478 }
479 }
480
481 quiche::QuicheBufferAllocator*
GetStreamSendBufferAllocator()482 MockQuicConnectionHelper::GetStreamSendBufferAllocator() {
483 return &buffer_allocator_;
484 }
485
AdvanceTime(QuicTime::Delta delta)486 void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) {
487 clock_.AdvanceTime(delta);
488 }
489
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)490 MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper,
491 QuicAlarmFactory* alarm_factory,
492 Perspective perspective)
493 : MockQuicConnection(TestConnectionId(),
494 QuicSocketAddress(TestPeerIPAddress(), kTestPort),
495 helper, alarm_factory, perspective,
496 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
497
MockQuicConnection(QuicSocketAddress address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)498 MockQuicConnection::MockQuicConnection(QuicSocketAddress address,
499 QuicConnectionHelperInterface* helper,
500 QuicAlarmFactory* alarm_factory,
501 Perspective perspective)
502 : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory,
503 perspective,
504 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
505
MockQuicConnection(QuicConnectionId connection_id,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)506 MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id,
507 QuicConnectionHelperInterface* helper,
508 QuicAlarmFactory* alarm_factory,
509 Perspective perspective)
510 : MockQuicConnection(connection_id,
511 QuicSocketAddress(TestPeerIPAddress(), kTestPort),
512 helper, alarm_factory, perspective,
513 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
514
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)515 MockQuicConnection::MockQuicConnection(
516 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
517 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
518 : MockQuicConnection(
519 TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort),
520 helper, alarm_factory, perspective, supported_versions) {}
521
MockQuicConnection(QuicConnectionId connection_id,QuicSocketAddress initial_peer_address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)522 MockQuicConnection::MockQuicConnection(
523 QuicConnectionId connection_id, QuicSocketAddress initial_peer_address,
524 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
525 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
526 : QuicConnection(
527 connection_id,
528 /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5),
529 initial_peer_address, helper, alarm_factory,
530 new testing::NiceMock<MockPacketWriter>(),
531 /* owns_writer= */ true, perspective, supported_versions,
532 connection_id_generator_) {
533 ON_CALL(*this, OnError(_))
534 .WillByDefault(
535 Invoke(this, &PacketSavingConnection::QuicConnection_OnError));
536 ON_CALL(*this, SendCryptoData(_, _, _))
537 .WillByDefault(
538 Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData));
539
540 SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5));
541 }
542
~MockQuicConnection()543 MockQuicConnection::~MockQuicConnection() {}
544
AdvanceTime(QuicTime::Delta delta)545 void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) {
546 static_cast<MockQuicConnectionHelper*>(helper())->AdvanceTime(delta);
547 }
548
OnProtocolVersionMismatch(ParsedQuicVersion)549 bool MockQuicConnection::OnProtocolVersionMismatch(
550 ParsedQuicVersion /*version*/) {
551 return false;
552 }
553
PacketSavingConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)554 PacketSavingConnection::PacketSavingConnection(
555 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
556 Perspective perspective)
557 : MockQuicConnection(helper, alarm_factory, perspective) {}
558
PacketSavingConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)559 PacketSavingConnection::PacketSavingConnection(
560 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
561 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
562 : MockQuicConnection(helper, alarm_factory, perspective,
563 supported_versions) {}
564
~PacketSavingConnection()565 PacketSavingConnection::~PacketSavingConnection() {}
566
GetSerializedPacketFate(bool,EncryptionLevel)567 SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate(
568 bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) {
569 return SEND_TO_WRITER;
570 }
571
SendOrQueuePacket(SerializedPacket packet)572 void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) {
573 encrypted_packets_.push_back(std::make_unique<QuicEncryptedPacket>(
574 CopyBuffer(packet), packet.encrypted_length, true));
575 clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
576 // Transfer ownership of the packet to the SentPacketManager and the
577 // ack notifier to the AckNotifierManager.
578 OnPacketSent(packet.encryption_level, packet.transmission_type);
579 QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent(
580 &packet, clock_.ApproximateNow(), NOT_RETRANSMISSION,
581 HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT);
582 }
583
MockQuicSession(QuicConnection * connection)584 MockQuicSession::MockQuicSession(QuicConnection* connection)
585 : MockQuicSession(connection, true) {}
586
MockQuicSession(QuicConnection * connection,bool create_mock_crypto_stream)587 MockQuicSession::MockQuicSession(QuicConnection* connection,
588 bool create_mock_crypto_stream)
589 : QuicSession(connection, nullptr, DefaultQuicConfig(),
590 connection->supported_versions(),
591 /*num_expected_unidirectional_static_streams = */ 0) {
592 if (create_mock_crypto_stream) {
593 crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
594 }
595 ON_CALL(*this, WritevData(_, _, _, _, _, _))
596 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
597 }
598
~MockQuicSession()599 MockQuicSession::~MockQuicSession() { DeleteConnection(); }
600
GetMutableCryptoStream()601 QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() {
602 return crypto_stream_.get();
603 }
604
GetCryptoStream() const605 const QuicCryptoStream* MockQuicSession::GetCryptoStream() const {
606 return crypto_stream_.get();
607 }
608
SetCryptoStream(QuicCryptoStream * crypto_stream)609 void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
610 crypto_stream_.reset(crypto_stream);
611 }
612
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,absl::optional<EncryptionLevel>)613 QuicConsumedData MockQuicSession::ConsumeData(
614 QuicStreamId id, size_t write_length, QuicStreamOffset offset,
615 StreamSendingState state, TransmissionType /*type*/,
616 absl::optional<EncryptionLevel> /*level*/) {
617 if (write_length > 0) {
618 auto buf = std::make_unique<char[]>(write_length);
619 QuicStream* stream = GetOrCreateStream(id);
620 QUICHE_DCHECK(stream);
621 QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
622 stream->WriteStreamData(offset, write_length, &writer);
623 } else {
624 QUICHE_DCHECK(state != NO_FIN);
625 }
626 return QuicConsumedData(write_length, state != NO_FIN);
627 }
628
MockQuicCryptoStream(QuicSession * session)629 MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session)
630 : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {}
631
~MockQuicCryptoStream()632 MockQuicCryptoStream::~MockQuicCryptoStream() {}
633
EarlyDataReason() const634 ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const {
635 return ssl_early_data_unknown;
636 }
637
encryption_established() const638 bool MockQuicCryptoStream::encryption_established() const { return false; }
639
one_rtt_keys_available() const640 bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; }
641
642 const QuicCryptoNegotiatedParameters&
crypto_negotiated_params() const643 MockQuicCryptoStream::crypto_negotiated_params() const {
644 return *params_;
645 }
646
crypto_message_parser()647 CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() {
648 return &crypto_framer_;
649 }
650
MockQuicSpdySession(QuicConnection * connection)651 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection)
652 : MockQuicSpdySession(connection, true) {}
653
MockQuicSpdySession(QuicConnection * connection,bool create_mock_crypto_stream)654 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection,
655 bool create_mock_crypto_stream)
656 : QuicSpdySession(connection, nullptr, DefaultQuicConfig(),
657 connection->supported_versions()) {
658 if (create_mock_crypto_stream) {
659 crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
660 }
661
662 ON_CALL(*this, WritevData(_, _, _, _, _, _))
663 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
664
665 ON_CALL(*this, SendWindowUpdate(_, _))
666 .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
667 return QuicSpdySession::SendWindowUpdate(id, byte_offset);
668 });
669
670 ON_CALL(*this, SendBlocked(_, _))
671 .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
672 return QuicSpdySession::SendBlocked(id, byte_offset);
673 });
674
675 ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return());
676 }
677
~MockQuicSpdySession()678 MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); }
679
GetMutableCryptoStream()680 QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() {
681 return crypto_stream_.get();
682 }
683
GetCryptoStream() const684 const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const {
685 return crypto_stream_.get();
686 }
687
SetCryptoStream(QuicCryptoStream * crypto_stream)688 void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
689 crypto_stream_.reset(crypto_stream);
690 }
691
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,absl::optional<EncryptionLevel>)692 QuicConsumedData MockQuicSpdySession::ConsumeData(
693 QuicStreamId id, size_t write_length, QuicStreamOffset offset,
694 StreamSendingState state, TransmissionType /*type*/,
695 absl::optional<EncryptionLevel> /*level*/) {
696 if (write_length > 0) {
697 auto buf = std::make_unique<char[]>(write_length);
698 QuicStream* stream = GetOrCreateStream(id);
699 QUICHE_DCHECK(stream);
700 QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
701 stream->WriteStreamData(offset, write_length, &writer);
702 } else {
703 QUICHE_DCHECK(state != NO_FIN);
704 }
705 return QuicConsumedData(write_length, state != NO_FIN);
706 }
707
TestQuicSpdyServerSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)708 TestQuicSpdyServerSession::TestQuicSpdyServerSession(
709 QuicConnection* connection, const QuicConfig& config,
710 const ParsedQuicVersionVector& supported_versions,
711 const QuicCryptoServerConfig* crypto_config,
712 QuicCompressedCertsCache* compressed_certs_cache)
713 : QuicServerSessionBase(config, supported_versions, connection, &visitor_,
714 &helper_, crypto_config, compressed_certs_cache) {
715 ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _))
716 .WillByDefault(testing::Return(true));
717 }
718
~TestQuicSpdyServerSession()719 TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); }
720
721 std::unique_ptr<QuicCryptoServerStreamBase>
CreateQuicCryptoServerStream(const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)722 TestQuicSpdyServerSession::CreateQuicCryptoServerStream(
723 const QuicCryptoServerConfig* crypto_config,
724 QuicCompressedCertsCache* compressed_certs_cache) {
725 return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this,
726 &helper_);
727 }
728
729 QuicCryptoServerStreamBase*
GetMutableCryptoStream()730 TestQuicSpdyServerSession::GetMutableCryptoStream() {
731 return QuicServerSessionBase::GetMutableCryptoStream();
732 }
733
GetCryptoStream() const734 const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream()
735 const {
736 return QuicServerSessionBase::GetCryptoStream();
737 }
738
TestQuicSpdyClientSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicServerId & server_id,QuicCryptoClientConfig * crypto_config,absl::optional<QuicSSLConfig> ssl_config)739 TestQuicSpdyClientSession::TestQuicSpdyClientSession(
740 QuicConnection* connection, const QuicConfig& config,
741 const ParsedQuicVersionVector& supported_versions,
742 const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config,
743 absl::optional<QuicSSLConfig> ssl_config)
744 : QuicSpdyClientSessionBase(connection, nullptr, &push_promise_index_,
745 config, supported_versions),
746 ssl_config_(std::move(ssl_config)) {
747 // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
748 // calls in tests and set |has_application_state| to true.
749 crypto_stream_ = std::make_unique<QuicCryptoClientStream>(
750 server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
751 crypto_config, this, /*has_application_state = */ false);
752 Initialize();
753 ON_CALL(*this, OnConfigNegotiated())
754 .WillByDefault(
755 Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
756 }
757
~TestQuicSpdyClientSession()758 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
759
IsAuthorized(const std::string &)760 bool TestQuicSpdyClientSession::IsAuthorized(const std::string& /*authority*/) {
761 return true;
762 }
763
GetMutableCryptoStream()764 QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() {
765 return crypto_stream_.get();
766 }
767
GetCryptoStream() const768 const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream()
769 const {
770 return crypto_stream_.get();
771 }
772
RealOnConfigNegotiated()773 void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
774 QuicSpdyClientSessionBase::OnConfigNegotiated();
775 }
776
TestPushPromiseDelegate(bool match)777 TestPushPromiseDelegate::TestPushPromiseDelegate(bool match)
778 : match_(match), rendezvous_fired_(false), rendezvous_stream_(nullptr) {}
779
CheckVary(const spdy::Http2HeaderBlock &,const spdy::Http2HeaderBlock &,const spdy::Http2HeaderBlock &)780 bool TestPushPromiseDelegate::CheckVary(
781 const spdy::Http2HeaderBlock& /*client_request*/,
782 const spdy::Http2HeaderBlock& /*promise_request*/,
783 const spdy::Http2HeaderBlock& /*promise_response*/) {
784 QUIC_DVLOG(1) << "match " << match_;
785 return match_;
786 }
787
OnRendezvousResult(QuicSpdyStream * stream)788 void TestPushPromiseDelegate::OnRendezvousResult(QuicSpdyStream* stream) {
789 rendezvous_fired_ = true;
790 rendezvous_stream_ = stream;
791 }
792
MockPacketWriter()793 MockPacketWriter::MockPacketWriter() {
794 ON_CALL(*this, GetMaxPacketSize(_))
795 .WillByDefault(testing::Return(kMaxOutgoingPacketSize));
796 ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false));
797 ON_CALL(*this, GetNextWriteLocation(_, _))
798 .WillByDefault(testing::Return(QuicPacketBuffer()));
799 ON_CALL(*this, Flush())
800 .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0)));
801 ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false));
802 }
803
~MockPacketWriter()804 MockPacketWriter::~MockPacketWriter() {}
805
MockSendAlgorithm()806 MockSendAlgorithm::MockSendAlgorithm() {
807 ON_CALL(*this, PacingRate(_))
808 .WillByDefault(testing::Return(QuicBandwidth::Zero()));
809 ON_CALL(*this, BandwidthEstimate())
810 .WillByDefault(testing::Return(QuicBandwidth::Zero()));
811 }
812
~MockSendAlgorithm()813 MockSendAlgorithm::~MockSendAlgorithm() {}
814
MockLossAlgorithm()815 MockLossAlgorithm::MockLossAlgorithm() {}
816
~MockLossAlgorithm()817 MockLossAlgorithm::~MockLossAlgorithm() {}
818
MockAckListener()819 MockAckListener::MockAckListener() {}
820
~MockAckListener()821 MockAckListener::~MockAckListener() {}
822
MockNetworkChangeVisitor()823 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {}
824
~MockNetworkChangeVisitor()825 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {}
826
TestPeerIPAddress()827 QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); }
828
QuicVersionMax()829 ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); }
830
QuicVersionMin()831 ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); }
832
DisableQuicVersionsWithTls()833 void DisableQuicVersionsWithTls() {
834 for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) {
835 QuicDisableVersion(version);
836 }
837 }
838
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data)839 QuicEncryptedPacket* ConstructEncryptedPacket(
840 QuicConnectionId destination_connection_id,
841 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
842 uint64_t packet_number, const std::string& data) {
843 return ConstructEncryptedPacket(
844 destination_connection_id, source_connection_id, version_flag, reset_flag,
845 packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT,
846 PACKET_4BYTE_PACKET_NUMBER);
847 }
848
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length)849 QuicEncryptedPacket* ConstructEncryptedPacket(
850 QuicConnectionId destination_connection_id,
851 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
852 uint64_t packet_number, const std::string& data,
853 QuicConnectionIdIncluded destination_connection_id_included,
854 QuicConnectionIdIncluded source_connection_id_included,
855 QuicPacketNumberLength packet_number_length) {
856 return ConstructEncryptedPacket(
857 destination_connection_id, source_connection_id, version_flag, reset_flag,
858 packet_number, data, destination_connection_id_included,
859 source_connection_id_included, packet_number_length, nullptr);
860 }
861
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)862 QuicEncryptedPacket* ConstructEncryptedPacket(
863 QuicConnectionId destination_connection_id,
864 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
865 uint64_t packet_number, const std::string& data,
866 QuicConnectionIdIncluded destination_connection_id_included,
867 QuicConnectionIdIncluded source_connection_id_included,
868 QuicPacketNumberLength packet_number_length,
869 ParsedQuicVersionVector* versions) {
870 return ConstructEncryptedPacket(
871 destination_connection_id, source_connection_id, version_flag, reset_flag,
872 packet_number, data, false, destination_connection_id_included,
873 source_connection_id_included, packet_number_length, versions,
874 Perspective::IS_CLIENT);
875 }
876
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)877 QuicEncryptedPacket* ConstructEncryptedPacket(
878 QuicConnectionId destination_connection_id,
879 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
880 uint64_t packet_number, const std::string& data, bool full_padding,
881 QuicConnectionIdIncluded destination_connection_id_included,
882 QuicConnectionIdIncluded source_connection_id_included,
883 QuicPacketNumberLength packet_number_length,
884 ParsedQuicVersionVector* versions) {
885 return ConstructEncryptedPacket(
886 destination_connection_id, source_connection_id, version_flag, reset_flag,
887 packet_number, data, full_padding, destination_connection_id_included,
888 source_connection_id_included, packet_number_length, versions,
889 Perspective::IS_CLIENT);
890 }
891
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions,Perspective perspective)892 QuicEncryptedPacket* ConstructEncryptedPacket(
893 QuicConnectionId destination_connection_id,
894 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
895 uint64_t packet_number, const std::string& data, bool full_padding,
896 QuicConnectionIdIncluded destination_connection_id_included,
897 QuicConnectionIdIncluded source_connection_id_included,
898 QuicPacketNumberLength packet_number_length,
899 ParsedQuicVersionVector* versions, Perspective perspective) {
900 QuicPacketHeader header;
901 header.destination_connection_id = destination_connection_id;
902 header.destination_connection_id_included =
903 destination_connection_id_included;
904 header.source_connection_id = source_connection_id;
905 header.source_connection_id_included = source_connection_id_included;
906 header.version_flag = version_flag;
907 header.reset_flag = reset_flag;
908 header.packet_number_length = packet_number_length;
909 header.packet_number = QuicPacketNumber(packet_number);
910 ParsedQuicVersionVector supported_versions = CurrentSupportedVersions();
911 if (!versions) {
912 versions = &supported_versions;
913 }
914 EXPECT_FALSE(versions->empty());
915 ParsedQuicVersion version = (*versions)[0];
916 if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
917 version_flag) {
918 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
919 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
920 }
921
922 QuicFrames frames;
923 QuicFramer framer(*versions, QuicTime::Zero(), perspective,
924 kQuicDefaultConnectionIdLength);
925 framer.SetInitialObfuscators(destination_connection_id);
926 EncryptionLevel level =
927 header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
928 if (level != ENCRYPTION_INITIAL) {
929 framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
930 }
931 if (!QuicVersionUsesCryptoFrames(version.transport_version)) {
932 QuicFrame frame(
933 QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version),
934 false, 0, absl::string_view(data)));
935 frames.push_back(frame);
936 } else {
937 QuicFrame frame(new QuicCryptoFrame(level, 0, data));
938 frames.push_back(frame);
939 }
940 if (full_padding) {
941 frames.push_back(QuicFrame(QuicPaddingFrame(-1)));
942 } else {
943 // We need a minimum number of bytes of encrypted payload. This will
944 // guarantee that we have at least that much. (It ignores the overhead of
945 // the stream/crypto framing, so it overpads slightly.)
946 size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize(
947 version, packet_number_length);
948 if (data.length() < min_plaintext_size) {
949 size_t padding_length = min_plaintext_size - data.length();
950 frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
951 }
952 }
953
954 std::unique_ptr<QuicPacket> packet(
955 BuildUnsizedDataPacket(&framer, header, frames));
956 EXPECT_TRUE(packet != nullptr);
957 char* buffer = new char[kMaxOutgoingPacketSize];
958 size_t encrypted_length =
959 framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
960 buffer, kMaxOutgoingPacketSize);
961 EXPECT_NE(0u, encrypted_length);
962 DeleteFrames(&frames);
963 return new QuicEncryptedPacket(buffer, encrypted_length, true);
964 }
965
GetUndecryptableEarlyPacket(const ParsedQuicVersion & version,const QuicConnectionId & server_connection_id)966 std::unique_ptr<QuicEncryptedPacket> GetUndecryptableEarlyPacket(
967 const ParsedQuicVersion& version,
968 const QuicConnectionId& server_connection_id) {
969 QuicPacketHeader header;
970 header.destination_connection_id = server_connection_id;
971 header.destination_connection_id_included = CONNECTION_ID_PRESENT;
972 header.source_connection_id = EmptyQuicConnectionId();
973 header.source_connection_id_included = CONNECTION_ID_PRESENT;
974 if (!version.SupportsClientConnectionIds()) {
975 header.source_connection_id_included = CONNECTION_ID_ABSENT;
976 }
977 header.version_flag = true;
978 header.reset_flag = false;
979 header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
980 header.packet_number = QuicPacketNumber(33);
981 header.long_packet_type = ZERO_RTT_PROTECTED;
982 if (version.HasLongHeaderLengths()) {
983 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
984 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
985 }
986
987 QuicFrames frames;
988 frames.push_back(QuicFrame(QuicPingFrame()));
989 frames.push_back(QuicFrame(QuicPaddingFrame(100)));
990 QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT,
991 kQuicDefaultConnectionIdLength);
992 framer.SetInitialObfuscators(server_connection_id);
993
994 framer.SetEncrypter(ENCRYPTION_ZERO_RTT,
995 std::make_unique<TaggingEncrypter>(ENCRYPTION_ZERO_RTT));
996 std::unique_ptr<QuicPacket> packet(
997 BuildUnsizedDataPacket(&framer, header, frames));
998 EXPECT_TRUE(packet != nullptr);
999 char* buffer = new char[kMaxOutgoingPacketSize];
1000 size_t encrypted_length =
1001 framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet,
1002 buffer, kMaxOutgoingPacketSize);
1003 EXPECT_NE(0u, encrypted_length);
1004 DeleteFrames(&frames);
1005 return std::make_unique<QuicEncryptedPacket>(buffer, encrypted_length,
1006 /*owns_buffer=*/true);
1007 }
1008
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time)1009 QuicReceivedPacket* ConstructReceivedPacket(
1010 const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
1011 char* buffer = new char[encrypted_packet.length()];
1012 memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
1013 return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
1014 true);
1015 }
1016
ConstructMisFramedEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersion version,Perspective perspective)1017 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
1018 QuicConnectionId destination_connection_id,
1019 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
1020 uint64_t packet_number, const std::string& data,
1021 QuicConnectionIdIncluded destination_connection_id_included,
1022 QuicConnectionIdIncluded source_connection_id_included,
1023 QuicPacketNumberLength packet_number_length, ParsedQuicVersion version,
1024 Perspective perspective) {
1025 QuicPacketHeader header;
1026 header.destination_connection_id = destination_connection_id;
1027 header.destination_connection_id_included =
1028 destination_connection_id_included;
1029 header.source_connection_id = source_connection_id;
1030 header.source_connection_id_included = source_connection_id_included;
1031 header.version_flag = version_flag;
1032 header.reset_flag = reset_flag;
1033 header.packet_number_length = packet_number_length;
1034 header.packet_number = QuicPacketNumber(packet_number);
1035 if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
1036 version_flag) {
1037 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
1038 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
1039 }
1040 QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data)));
1041 QuicFrames frames;
1042 frames.push_back(frame);
1043 QuicFramer framer({version}, QuicTime::Zero(), perspective,
1044 kQuicDefaultConnectionIdLength);
1045 framer.SetInitialObfuscators(destination_connection_id);
1046 EncryptionLevel level =
1047 version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
1048 if (level != ENCRYPTION_INITIAL) {
1049 framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
1050 }
1051 // We need a minimum of 7 bytes of encrypted payload. This will guarantee that
1052 // we have at least that much. (It ignores the overhead of the stream/crypto
1053 // framing, so it overpads slightly.)
1054 if (data.length() < 7) {
1055 size_t padding_length = 7 - data.length();
1056 frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
1057 }
1058
1059 std::unique_ptr<QuicPacket> packet(
1060 BuildUnsizedDataPacket(&framer, header, frames));
1061 EXPECT_TRUE(packet != nullptr);
1062
1063 // Now set the frame type to 0x1F, which is an invalid frame type.
1064 reinterpret_cast<unsigned char*>(
1065 packet->mutable_data())[GetStartOfEncryptedData(
1066 framer.transport_version(),
1067 GetIncludedDestinationConnectionIdLength(header),
1068 GetIncludedSourceConnectionIdLength(header), version_flag,
1069 false /* no diversification nonce */, packet_number_length,
1070 header.retry_token_length_length, 0, header.length_length)] = 0x1F;
1071
1072 char* buffer = new char[kMaxOutgoingPacketSize];
1073 size_t encrypted_length =
1074 framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
1075 buffer, kMaxOutgoingPacketSize);
1076 EXPECT_NE(0u, encrypted_length);
1077 return new QuicEncryptedPacket(buffer, encrypted_length, true);
1078 }
1079
DefaultQuicConfig()1080 QuicConfig DefaultQuicConfig() {
1081 QuicConfig config;
1082 config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(
1083 kInitialStreamFlowControlWindowForTest);
1084 config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(
1085 kInitialStreamFlowControlWindowForTest);
1086 config.SetInitialMaxStreamDataBytesUnidirectionalToSend(
1087 kInitialStreamFlowControlWindowForTest);
1088 config.SetInitialStreamFlowControlWindowToSend(
1089 kInitialStreamFlowControlWindowForTest);
1090 config.SetInitialSessionFlowControlWindowToSend(
1091 kInitialSessionFlowControlWindowForTest);
1092 QuicConfigPeer::SetReceivedMaxBidirectionalStreams(
1093 &config, kDefaultMaxStreamsPerConnection);
1094 // Default enable NSTP.
1095 // This is unnecessary for versions > 44
1096 if (!config.HasClientSentConnectionOption(quic::kNSTP,
1097 quic::Perspective::IS_CLIENT)) {
1098 quic::QuicTagVector connection_options;
1099 connection_options.push_back(quic::kNSTP);
1100 config.SetConnectionOptionsToSend(connection_options);
1101 }
1102 return config;
1103 }
1104
SupportedVersions(ParsedQuicVersion version)1105 ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) {
1106 ParsedQuicVersionVector versions;
1107 versions.push_back(version);
1108 return versions;
1109 }
1110
MockQuicConnectionDebugVisitor()1111 MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {}
1112
~MockQuicConnectionDebugVisitor()1113 MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {}
1114
MockReceivedPacketManager(QuicConnectionStats * stats)1115 MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats)
1116 : QuicReceivedPacketManager(stats) {}
1117
~MockReceivedPacketManager()1118 MockReceivedPacketManager::~MockReceivedPacketManager() {}
1119
MockPacketCreatorDelegate()1120 MockPacketCreatorDelegate::MockPacketCreatorDelegate() {}
~MockPacketCreatorDelegate()1121 MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {}
1122
MockSessionNotifier()1123 MockSessionNotifier::MockSessionNotifier() {}
~MockSessionNotifier()1124 MockSessionNotifier::~MockSessionNotifier() {}
1125
1126 // static
1127 QuicCryptoClientStream::HandshakerInterface*
GetHandshaker(QuicCryptoClientStream * stream)1128 QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) {
1129 return stream->handshaker_.get();
1130 }
1131
CreateClientSessionForTest(QuicServerId server_id,QuicTime::Delta connection_start_time,const ParsedQuicVersionVector & supported_versions,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,QuicCryptoClientConfig * crypto_client_config,PacketSavingConnection ** client_connection,TestQuicSpdyClientSession ** client_session)1132 void CreateClientSessionForTest(
1133 QuicServerId server_id, QuicTime::Delta connection_start_time,
1134 const ParsedQuicVersionVector& supported_versions,
1135 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
1136 QuicCryptoClientConfig* crypto_client_config,
1137 PacketSavingConnection** client_connection,
1138 TestQuicSpdyClientSession** client_session) {
1139 QUICHE_CHECK(crypto_client_config);
1140 QUICHE_CHECK(client_connection);
1141 QUICHE_CHECK(client_session);
1142 QUICHE_CHECK(!connection_start_time.IsZero())
1143 << "Connections must start at non-zero times, otherwise the "
1144 << "strike-register will be unhappy.";
1145
1146 QuicConfig config = DefaultQuicConfig();
1147 *client_connection = new PacketSavingConnection(
1148 helper, alarm_factory, Perspective::IS_CLIENT, supported_versions);
1149 *client_session = new TestQuicSpdyClientSession(*client_connection, config,
1150 supported_versions, server_id,
1151 crypto_client_config);
1152 (*client_connection)->AdvanceTime(connection_start_time);
1153 }
1154
CreateServerSessionForTest(QuicServerId,QuicTime::Delta connection_start_time,ParsedQuicVersionVector supported_versions,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,QuicCryptoServerConfig * server_crypto_config,QuicCompressedCertsCache * compressed_certs_cache,PacketSavingConnection ** server_connection,TestQuicSpdyServerSession ** server_session)1155 void CreateServerSessionForTest(
1156 QuicServerId /*server_id*/, QuicTime::Delta connection_start_time,
1157 ParsedQuicVersionVector supported_versions,
1158 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
1159 QuicCryptoServerConfig* server_crypto_config,
1160 QuicCompressedCertsCache* compressed_certs_cache,
1161 PacketSavingConnection** server_connection,
1162 TestQuicSpdyServerSession** server_session) {
1163 QUICHE_CHECK(server_crypto_config);
1164 QUICHE_CHECK(server_connection);
1165 QUICHE_CHECK(server_session);
1166 QUICHE_CHECK(!connection_start_time.IsZero())
1167 << "Connections must start at non-zero times, otherwise the "
1168 << "strike-register will be unhappy.";
1169
1170 *server_connection =
1171 new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER,
1172 ParsedVersionOfIndex(supported_versions, 0));
1173 *server_session = new TestQuicSpdyServerSession(
1174 *server_connection, DefaultQuicConfig(), supported_versions,
1175 server_crypto_config, compressed_certs_cache);
1176 (*server_session)->Initialize();
1177
1178 // We advance the clock initially because the default time is zero and the
1179 // strike register worries that we've just overflowed a uint32_t time.
1180 (*server_connection)->AdvanceTime(connection_start_time);
1181 }
1182
GetNthClientInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1183 QuicStreamId GetNthClientInitiatedBidirectionalStreamId(
1184 QuicTransportVersion version, int n) {
1185 int num = n;
1186 if (!VersionUsesHttp3(version)) {
1187 num++;
1188 }
1189 return QuicUtils::GetFirstBidirectionalStreamId(version,
1190 Perspective::IS_CLIENT) +
1191 QuicUtils::StreamIdDelta(version) * num;
1192 }
1193
GetNthServerInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1194 QuicStreamId GetNthServerInitiatedBidirectionalStreamId(
1195 QuicTransportVersion version, int n) {
1196 return QuicUtils::GetFirstBidirectionalStreamId(version,
1197 Perspective::IS_SERVER) +
1198 QuicUtils::StreamIdDelta(version) * n;
1199 }
1200
GetNthServerInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1201 QuicStreamId GetNthServerInitiatedUnidirectionalStreamId(
1202 QuicTransportVersion version, int n) {
1203 return QuicUtils::GetFirstUnidirectionalStreamId(version,
1204 Perspective::IS_SERVER) +
1205 QuicUtils::StreamIdDelta(version) * n;
1206 }
1207
GetNthClientInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1208 QuicStreamId GetNthClientInitiatedUnidirectionalStreamId(
1209 QuicTransportVersion version, int n) {
1210 return QuicUtils::GetFirstUnidirectionalStreamId(version,
1211 Perspective::IS_CLIENT) +
1212 QuicUtils::StreamIdDelta(version) * n;
1213 }
1214
DetermineStreamType(QuicStreamId id,ParsedQuicVersion version,Perspective perspective,bool is_incoming,StreamType default_type)1215 StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version,
1216 Perspective perspective, bool is_incoming,
1217 StreamType default_type) {
1218 return version.HasIetfQuicFrames()
1219 ? QuicUtils::GetStreamType(id, perspective, is_incoming, version)
1220 : default_type;
1221 }
1222
MemSliceFromString(absl::string_view data)1223 quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) {
1224 if (data.empty()) {
1225 return quiche::QuicheMemSlice();
1226 }
1227
1228 static quiche::SimpleBufferAllocator* allocator =
1229 new quiche::SimpleBufferAllocator();
1230 return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data));
1231 }
1232
EncryptPacket(uint64_t,absl::string_view,absl::string_view plaintext,char * output,size_t * output_length,size_t max_output_length)1233 bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/,
1234 absl::string_view /*associated_data*/,
1235 absl::string_view plaintext, char* output,
1236 size_t* output_length,
1237 size_t max_output_length) {
1238 const size_t len = plaintext.size() + kTagSize;
1239 if (max_output_length < len) {
1240 return false;
1241 }
1242 // Memmove is safe for inplace encryption.
1243 memmove(output, plaintext.data(), plaintext.size());
1244 output += plaintext.size();
1245 memset(output, tag_, kTagSize);
1246 *output_length = len;
1247 return true;
1248 }
1249
DecryptPacket(uint64_t,absl::string_view,absl::string_view ciphertext,char * output,size_t * output_length,size_t)1250 bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/,
1251 absl::string_view /*associated_data*/,
1252 absl::string_view ciphertext, char* output,
1253 size_t* output_length,
1254 size_t /*max_output_length*/) {
1255 if (ciphertext.size() < kTagSize) {
1256 return false;
1257 }
1258 if (!CheckTag(ciphertext, GetTag(ciphertext))) {
1259 return false;
1260 }
1261 *output_length = ciphertext.size() - kTagSize;
1262 memcpy(output, ciphertext.data(), *output_length);
1263 return true;
1264 }
1265
CheckTag(absl::string_view ciphertext,uint8_t tag)1266 bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) {
1267 for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
1268 if (ciphertext.data()[i] != tag) {
1269 return false;
1270 }
1271 }
1272
1273 return true;
1274 }
1275
TestPacketWriter(ParsedQuicVersion version,MockClock * clock,Perspective perspective)1276 TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock,
1277 Perspective perspective)
1278 : version_(version),
1279 framer_(SupportedVersions(version_),
1280 QuicUtils::InvertPerspective(perspective)),
1281 clock_(clock) {
1282 QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
1283 TestConnectionId());
1284 framer_.framer()->SetInitialObfuscators(TestConnectionId());
1285
1286 for (int i = 0; i < 128; ++i) {
1287 PacketBuffer* p = new PacketBuffer();
1288 packet_buffer_pool_.push_back(p);
1289 packet_buffer_pool_index_[p->buffer] = p;
1290 packet_buffer_free_list_.push_back(p);
1291 }
1292 }
1293
~TestPacketWriter()1294 TestPacketWriter::~TestPacketWriter() {
1295 EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
1296 << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
1297 << " out of " << packet_buffer_pool_.size()
1298 << " packet buffers have been leaked.";
1299 for (auto p : packet_buffer_pool_) {
1300 delete p;
1301 }
1302 }
1303
WritePacket(const char * buffer,size_t buf_len,const QuicIpAddress & self_address,const QuicSocketAddress & peer_address,PerPacketOptions * options)1304 WriteResult TestPacketWriter::WritePacket(const char* buffer, size_t buf_len,
1305 const QuicIpAddress& self_address,
1306 const QuicSocketAddress& peer_address,
1307 PerPacketOptions* options) {
1308 last_write_source_address_ = self_address;
1309 last_write_peer_address_ = peer_address;
1310 // If the buffer is allocated from the pool, return it back to the pool.
1311 // Note the buffer content doesn't change.
1312 if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
1313 packet_buffer_pool_index_.end()) {
1314 FreePacketBuffer(buffer);
1315 }
1316
1317 QuicEncryptedPacket packet(buffer, buf_len);
1318 ++packets_write_attempts_;
1319
1320 if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
1321 final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
1322 memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
1323 sizeof(final_bytes_of_last_packet_));
1324 }
1325 if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
1326 framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
1327 std::make_unique<TaggingDecrypter>());
1328 framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
1329 std::make_unique<TaggingDecrypter>());
1330 framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
1331 std::make_unique<TaggingDecrypter>());
1332 } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel(
1333 ENCRYPTION_FORWARD_SECURE) &&
1334 !framer_.framer()->HasDecrypterOfEncryptionLevel(
1335 ENCRYPTION_ZERO_RTT)) {
1336 framer_.framer()->SetAlternativeDecrypter(
1337 ENCRYPTION_FORWARD_SECURE,
1338 std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE),
1339 false);
1340 }
1341 EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet))
1342 << framer_.framer()->detailed_error() << " perspective "
1343 << framer_.framer()->perspective();
1344 next_packet_processable_ = true;
1345 if (block_on_next_write_) {
1346 write_blocked_ = true;
1347 block_on_next_write_ = false;
1348 }
1349 if (next_packet_too_large_) {
1350 next_packet_too_large_ = false;
1351 return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1352 }
1353 if (always_get_packet_too_large_) {
1354 return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1355 }
1356 if (IsWriteBlocked()) {
1357 return WriteResult(is_write_blocked_data_buffered_
1358 ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
1359 : WRITE_STATUS_BLOCKED,
1360 0);
1361 }
1362
1363 if (ShouldWriteFail()) {
1364 return WriteResult(WRITE_STATUS_ERROR, write_error_code_);
1365 }
1366
1367 last_packet_size_ = packet.length();
1368 total_bytes_written_ += packet.length();
1369 last_packet_header_ = framer_.header();
1370 if (!framer_.connection_close_frames().empty()) {
1371 ++connection_close_packets_;
1372 }
1373 if (!write_pause_time_delta_.IsZero()) {
1374 clock_->AdvanceTime(write_pause_time_delta_);
1375 }
1376 if (is_batch_mode_) {
1377 bytes_buffered_ += last_packet_size_;
1378 return WriteResult(WRITE_STATUS_OK, 0);
1379 }
1380 last_ecn_sent_ = (options == nullptr) ? ECN_NOT_ECT : options->ecn_codepoint;
1381 return WriteResult(WRITE_STATUS_OK, last_packet_size_);
1382 }
1383
GetNextWriteLocation(const QuicIpAddress &,const QuicSocketAddress &)1384 QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
1385 const QuicIpAddress& /*self_address*/,
1386 const QuicSocketAddress& /*peer_address*/) {
1387 return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
1388 }
1389
Flush()1390 WriteResult TestPacketWriter::Flush() {
1391 flush_attempts_++;
1392 if (block_on_next_flush_) {
1393 block_on_next_flush_ = false;
1394 SetWriteBlocked();
1395 return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
1396 }
1397 if (write_should_fail_) {
1398 return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
1399 }
1400 int bytes_flushed = bytes_buffered_;
1401 bytes_buffered_ = 0;
1402 return WriteResult(WRITE_STATUS_OK, bytes_flushed);
1403 }
1404
AllocPacketBuffer()1405 char* TestPacketWriter::AllocPacketBuffer() {
1406 PacketBuffer* p = packet_buffer_free_list_.front();
1407 EXPECT_FALSE(p->in_use);
1408 p->in_use = true;
1409 packet_buffer_free_list_.pop_front();
1410 return p->buffer;
1411 }
1412
FreePacketBuffer(const char * buffer)1413 void TestPacketWriter::FreePacketBuffer(const char* buffer) {
1414 auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
1415 ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
1416 PacketBuffer* p = iter->second;
1417 ASSERT_TRUE(p->in_use);
1418 p->in_use = false;
1419 packet_buffer_free_list_.push_back(p);
1420 }
1421
WriteServerVersionNegotiationProbeResponse(char * packet_bytes,size_t * packet_length_out,const char * source_connection_id_bytes,uint8_t source_connection_id_length)1422 bool WriteServerVersionNegotiationProbeResponse(
1423 char* packet_bytes, size_t* packet_length_out,
1424 const char* source_connection_id_bytes,
1425 uint8_t source_connection_id_length) {
1426 if (packet_bytes == nullptr) {
1427 QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes";
1428 return false;
1429 }
1430 if (packet_length_out == nullptr) {
1431 QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out";
1432 return false;
1433 }
1434 QuicConnectionId source_connection_id(source_connection_id_bytes,
1435 source_connection_id_length);
1436 std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
1437 QuicFramer::BuildVersionNegotiationPacket(
1438 source_connection_id, EmptyQuicConnectionId(),
1439 /*ietf_quic=*/true, /*use_length_prefix=*/true,
1440 ParsedQuicVersionVector{});
1441 if (!encrypted_packet) {
1442 QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet";
1443 return false;
1444 }
1445 if (*packet_length_out < encrypted_packet->length()) {
1446 QUIC_BUG(quic_bug_10256_4)
1447 << "Invalid *packet_length_out " << *packet_length_out << " < "
1448 << encrypted_packet->length();
1449 return false;
1450 }
1451 *packet_length_out = encrypted_packet->length();
1452 memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
1453 return true;
1454 }
1455
ParseClientVersionNegotiationProbePacket(const char * packet_bytes,size_t packet_length,char * destination_connection_id_bytes,uint8_t * destination_connection_id_length_out)1456 bool ParseClientVersionNegotiationProbePacket(
1457 const char* packet_bytes, size_t packet_length,
1458 char* destination_connection_id_bytes,
1459 uint8_t* destination_connection_id_length_out) {
1460 if (packet_bytes == nullptr) {
1461 QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes";
1462 return false;
1463 }
1464 if (packet_length < kMinPacketSizeForVersionNegotiation ||
1465 packet_length > 65535) {
1466 QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length";
1467 return false;
1468 }
1469 if (destination_connection_id_bytes == nullptr) {
1470 QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes";
1471 return false;
1472 }
1473 if (destination_connection_id_length_out == nullptr) {
1474 QUIC_BUG(quic_bug_10256_8)
1475 << "Invalid destination_connection_id_length_out";
1476 return false;
1477 }
1478
1479 QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
1480 PacketHeaderFormat format;
1481 QuicLongHeaderType long_packet_type;
1482 bool version_present, has_length_prefix;
1483 QuicVersionLabel version_label;
1484 ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
1485 QuicConnectionId destination_connection_id, source_connection_id;
1486 absl::optional<absl::string_view> retry_token;
1487 std::string detailed_error;
1488 QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
1489 encrypted_packet,
1490 /*expected_destination_connection_id_length=*/0, &format,
1491 &long_packet_type, &version_present, &has_length_prefix, &version_label,
1492 &parsed_version, &destination_connection_id, &source_connection_id,
1493 &retry_token, &detailed_error);
1494 if (error != QUIC_NO_ERROR) {
1495 QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error;
1496 return false;
1497 }
1498 if (!version_present) {
1499 QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header";
1500 return false;
1501 }
1502 if (*destination_connection_id_length_out <
1503 destination_connection_id.length()) {
1504 QUIC_BUG(quic_bug_10256_11)
1505 << "destination_connection_id_length_out too small";
1506 return false;
1507 }
1508 *destination_connection_id_length_out = destination_connection_id.length();
1509 memcpy(destination_connection_id_bytes, destination_connection_id.data(),
1510 *destination_connection_id_length_out);
1511 return true;
1512 }
1513
1514 } // namespace test
1515 } // namespace quic
1516