1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6
7 #include <optional>
8 #include <string>
9 #include <utility>
10 #include <vector>
11
12 #include "base/containers/span.h"
13 #include "base/functional/callback.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/notreached.h"
16 #include "base/task/single_thread_task_runner.h"
17 #include "base/time/default_tick_clock.h"
18 #include "base/time/time.h"
19 #include "net/base/auth.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/connection_endpoint_metadata.h"
22 #include "net/base/host_port_pair.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/load_flags.h"
26 #include "net/base/net_errors.h"
27 #include "net/base/network_anonymization_key.h"
28 #include "net/base/network_handle.h"
29 #include "net/base/privacy_mode.h"
30 #include "net/base/proxy_chain.h"
31 #include "net/base/proxy_server.h"
32 #include "net/base/request_priority.h"
33 #include "net/base/session_usage.h"
34 #include "net/base/test_completion_callback.h"
35 #include "net/cert/cert_verify_result.h"
36 #include "net/dns/public/host_resolver_results.h"
37 #include "net/dns/public/secure_dns_policy.h"
38 #include "net/http/http_request_info.h"
39 #include "net/http/http_response_headers.h"
40 #include "net/http/http_response_info.h"
41 #include "net/http/transport_security_state.h"
42 #include "net/log/net_log.h"
43 #include "net/log/net_log_with_source.h"
44 #include "net/quic/address_utils.h"
45 #include "net/quic/crypto/proof_verifier_chromium.h"
46 #include "net/quic/mock_crypto_client_stream_factory.h"
47 #include "net/quic/mock_quic_data.h"
48 #include "net/quic/quic_chromium_alarm_factory.h"
49 #include "net/quic/quic_chromium_connection_helper.h"
50 #include "net/quic/quic_chromium_packet_reader.h"
51 #include "net/quic/quic_chromium_packet_writer.h"
52 #include "net/quic/quic_context.h"
53 #include "net/quic/quic_http_utils.h"
54 #include "net/quic/quic_server_info.h"
55 #include "net/quic/quic_session_alias_key.h"
56 #include "net/quic/quic_session_key.h"
57 #include "net/quic/quic_test_packet_maker.h"
58 #include "net/quic/test_quic_crypto_client_config_handle.h"
59 #include "net/quic/test_task_runner.h"
60 #include "net/socket/client_socket_handle.h"
61 #include "net/socket/client_socket_pool.h"
62 #include "net/socket/connect_job.h"
63 #include "net/socket/socket_tag.h"
64 #include "net/socket/socket_test_util.h"
65 #include "net/socket/websocket_endpoint_lock_manager.h"
66 #include "net/spdy/multiplexed_session_creation_initiator.h"
67 #include "net/spdy/spdy_session_key.h"
68 #include "net/spdy/spdy_test_util_common.h"
69 #include "net/ssl/ssl_config_service_defaults.h"
70 #include "net/ssl/ssl_info.h"
71 #include "net/test/cert_test_util.h"
72 #include "net/test/gtest_util.h"
73 #include "net/test/test_data_directory.h"
74 #include "net/test/test_with_task_environment.h"
75 #include "net/third_party/quiche/src/quiche/common/http/http_header_block.h"
76 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
77 #include "net/third_party/quiche/src/quiche/http2/core/spdy_protocol.h"
78 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
79 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
80 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
81 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
82 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
83 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
84 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
85 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
86 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
87 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
88 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
89 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
90 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
91 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
92 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
93 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
94 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
95 #include "net/traffic_annotation/network_traffic_annotation.h"
96 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
97 #include "net/websockets/websocket_basic_handshake_stream.h"
98 #include "net/websockets/websocket_event_interface.h"
99 #include "net/websockets/websocket_stream.h"
100 #include "net/websockets/websocket_test_util.h"
101 #include "testing/gmock/include/gmock/gmock.h"
102 #include "testing/gtest/include/gtest/gtest.h"
103 #include "url/gurl.h"
104 #include "url/origin.h"
105 #include "url/scheme_host_port.h"
106 #include "url/url_constants.h"
107
108 namespace net {
109 class HttpNetworkSession;
110 class URLRequest;
111 class WebSocketHttp2HandshakeStream;
112 class WebSocketHttp3HandshakeStream;
113 class X509Certificate;
114 struct WebSocketHandshakeRequestInfo;
115 struct WebSocketHandshakeResponseInfo;
116 } // namespace net
117
118 using ::net::test::IsError;
119 using ::net::test::IsOk;
120 using ::testing::_;
121 using ::testing::StrictMock;
122 using ::testing::TestWithParam;
123 using ::testing::Values;
124
125 namespace net {
126 namespace {
127
// Selects which kind of WebSocket handshake stream a parameterized test
// instance exercises:
//   BASIC_HANDSHAKE_STREAM - upgrade over a plain ClientSocketHandle.
//   HTTP2_HANDSHAKE_STREAM - handshake over an HTTP/2 (SPDY) session.
//   HTTP3_HANDSHAKE_STREAM - handshake over an HTTP/3 (QUIC) session.
enum HandshakeStreamType {
  BASIC_HANDSHAKE_STREAM = 0,
  HTTP2_HANDSHAKE_STREAM = 1,
  HTTP3_HANDSHAKE_STREAM = 2,
};
133
134 // This class encapsulates the details of creating a mock ClientSocketHandle.
135 class MockClientSocketHandleFactory {
136 public:
MockClientSocketHandleFactory()137 MockClientSocketHandleFactory()
138 : common_connect_job_params_(
139 socket_factory_maker_.factory(),
140 /*host_resolver=*/nullptr,
141 /*http_auth_cache=*/nullptr,
142 /*http_auth_handler_factory=*/nullptr,
143 /*spdy_session_pool=*/nullptr,
144 /*quic_supported_versions=*/nullptr,
145 /*quic_session_pool=*/nullptr,
146 /*proxy_delegate=*/nullptr,
147 /*http_user_agent_settings=*/nullptr,
148 /*ssl_client_context=*/nullptr,
149 /*socket_performance_watcher_factory=*/nullptr,
150 /*network_quality_estimator=*/nullptr,
151 /*net_log=*/nullptr,
152 /*websocket_endpoint_lock_manager=*/nullptr,
153 /*http_server_properties=*/nullptr,
154 /*alpn_protos=*/nullptr,
155 /*application_settings=*/nullptr,
156 /*ignore_certificate_errors=*/nullptr,
157 /*early_data_enabled=*/nullptr),
158 pool_(1, 1, &common_connect_job_params_) {}
159
160 MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
161 MockClientSocketHandleFactory& operator=(
162 const MockClientSocketHandleFactory&) = delete;
163
164 // The created socket expects |expect_written| to be written to the socket,
165 // and will respond with |return_to_read|. The test will fail if the expected
166 // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)167 std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
168 const std::string& expect_written,
169 const std::string& return_to_read) {
170 socket_factory_maker_.SetExpectations(expect_written, return_to_read);
171 auto socket_handle = std::make_unique<ClientSocketHandle>();
172 socket_handle->Init(
173 ClientSocketPool::GroupId(
174 url::SchemeHostPort(url::kHttpScheme, "a", 80),
175 PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
176 SecureDnsPolicy::kAllow, /*disable_cert_network_fetches=*/false),
177 scoped_refptr<ClientSocketPool::SocketParams>(),
178 std::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
179 ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
180 ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
181 return socket_handle;
182 }
183
184 private:
185 WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
186 const CommonConnectJobParams common_connect_job_params_;
187 MockTransportClientSocketPool pool_;
188 };
189
190 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
191 public:
192 ~TestConnectDelegate() override = default;
193
OnCreateRequest(URLRequest * request)194 void OnCreateRequest(URLRequest* request) override {}
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)195 void OnURLRequestConnected(URLRequest* request,
196 const TransportInfo& info) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)197 void OnSuccess(
198 std::unique_ptr<WebSocketStream> stream,
199 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message,int net_error,std::optional<int> response_code)200 void OnFailure(const std::string& failure_message,
201 int net_error,
202 std::optional<int> response_code) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)203 void OnStartOpeningHandshake(
204 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)205 void OnSSLCertificateError(
206 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
207 ssl_error_callbacks,
208 int net_error,
209 const SSLInfo& ssl_info,
210 bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)211 int OnAuthRequired(const AuthChallengeInfo& auth_info,
212 scoped_refptr<HttpResponseHeaders> response_headers,
213 const IPEndPoint& host_port_pair,
214 base::OnceCallback<void(const AuthCredentials*)> callback,
215 std::optional<AuthCredentials>* credentials) override {
216 *credentials = std::nullopt;
217 return OK;
218 }
219 };
220
221 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
222 public:
223 ~MockWebSocketStreamRequestAPI() override = default;
224
225 MOCK_METHOD1(OnBasicHandshakeStreamCreated,
226 void(WebSocketBasicHandshakeStream* handshake_stream));
227 MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
228 void(WebSocketHttp2HandshakeStream* handshake_stream));
229 MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
230 void(WebSocketHttp3HandshakeStream* handshake_stream));
231 MOCK_METHOD3(OnFailure,
232 void(const std::string& message,
233 int net_error,
234 std::optional<int> response_code));
235 };
236
// Parameterized fixture: each instance exercises one HandshakeStreamType.
// CreateAndInitializeStream() builds a mocked transport for that type, drives
// the opening handshake through WebSocketHandshakeStreamCreateHelper, and
// returns the stream produced by the handshake stream's Upgrade().
class WebSocketHandshakeStreamCreateHelperTest
    : public TestWithParam<HandshakeStreamType>,
      public WithTaskEnvironment {
 protected:
  // quic_version_ (IETF RFC v1 over TLS 1.3) is only consumed by the
  // HTTP3_HANDSHAKE_STREAM case. mock_quic_data_ is built from it, which is
  // valid because quic_version_ is declared before mock_quic_data_.
  WebSocketHandshakeStreamCreateHelperTest()
      : quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
                      quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
        mock_quic_data_(quic_version_) {}

  // Performs a complete opening handshake for the stream type selected by
  // GetParam() and returns the result of Upgrade().
  //
  // |sub_protocols| is passed to the create helper.
  // |extra_request_headers| is added to the request the mocked transport
  // expects to see written.
  // |extra_response_headers| is added to the response the mock returns.
  //
  // Sets gMock expectations of exactly one On*HandshakeStreamCreated() call
  // matching the parameter and no OnFailure() call.
  std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
      const std::vector<std::string>& sub_protocols,
      const WebSocketExtraHeaders& extra_request_headers,
      const WebSocketExtraHeaders& extra_response_headers) {
    constexpr char kPath[] = "/";
    constexpr char kOrigin[] = "http://origin.example.org";
    const GURL url("wss://www.example.org/");
    NetLogWithSource net_log;

    WebSocketHandshakeStreamCreateHelper create_helper(
        &connect_delegate_, sub_protocols, &stream_request_);

    // The helper must announce exactly one stream of the parameterized type.
    switch (GetParam()) {
      case BASIC_HANDSHAKE_STREAM:
        EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
        break;

      case HTTP2_HANDSHAKE_STREAM:
        EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
        break;

      case HTTP3_HANDSHAKE_STREAM:
        EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
        break;

      default:
        NOTREACHED();
    }

    EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);

    HttpRequestInfo request_info;
    request_info.url = url;
    request_info.method = "GET";
    request_info.load_flags = LOAD_DISABLE_CACHE;
    request_info.traffic_annotation =
        MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);

    auto headers = WebSocketCommonTestHeaders();

    switch (GetParam()) {
      case BASIC_HANDSHAKE_STREAM: {
        // Script the exact upgrade request and response on a mock socket
        // handle. Every step below is expected to complete synchronously.
        std::unique_ptr<ClientSocketHandle> socket_handle =
            socket_handle_factory_.CreateClientSocketHandle(
                WebSocketStandardRequest(kPath, "www.example.org",
                                         url::Origin::Create(GURL(kOrigin)),
                                         /*send_additional_request_headers=*/{},
                                         extra_request_headers),
                WebSocketStandardResponse(
                    WebSocketExtraHeadersToString(extra_response_headers)));

        std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
            create_helper.CreateBasicStream(std::move(socket_handle), false,
                                            &websocket_endpoint_lock_manager_);

        // If in future the implementation type returned by CreateBasicStream()
        // changes, this static_cast will be wrong. However, in that case the
        // test will fail and AddressSanitizer should identify the issue.
        // The key is pinned so that the written request is deterministic;
        // presumably it matches the key WebSocketStandardRequest() expects.
        // TODO(review): confirm.
        static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
            ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");

        handshake->RegisterRequest(&request_info);
        int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
                                             CompletionOnceCallback());
        EXPECT_THAT(rv, IsOk());

        HttpResponseInfo response;
        TestCompletionCallback request_callback;
        rv = handshake->SendRequest(headers, &response,
                                    request_callback.callback());
        EXPECT_THAT(rv, IsOk());

        TestCompletionCallback response_callback;
        rv = handshake->ReadResponseHeaders(response_callback.callback());
        EXPECT_THAT(rv, IsOk());
        // NOTE(review): response.headers is dereferenced without a null check;
        // a missing response would crash rather than fail the test.
        EXPECT_EQ(101, response.headers->response_code());
        EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
        EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
        return handshake->Upgrade();
      }
      case HTTP2_HANDSHAKE_STREAM: {
        // Expected traffic on a SPDY session:
        //   seq 0: client writes the request HEADERS frame on stream 1.
        //   seq 1: server replies with a response HEADERS frame.
        //   seq 2: the read side reaches EOF (result 0).
        SpdyTestUtil spdy_util;
        quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
            kPath, "www.example.org", kOrigin, extra_request_headers);
        spdy::SpdySerializedFrame request_headers(
            spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
                                           DEFAULT_PRIORITY, false));
        MockWrite writes[] = {CreateMockWrite(request_headers, 0)};

        quiche::HttpHeaderBlock response_header_block =
            WebSocketHttp2Response(extra_response_headers);
        spdy::SpdySerializedFrame response_headers(
            spdy_util.ConstructSpdyResponseHeaders(
                1, std::move(response_header_block), false));
        MockRead reads[] = {CreateMockRead(response_headers, 1),
                            MockRead(ASYNC, 0, 2)};

        SequencedSocketData data(reads, writes);

        // The TLS layer reports a wildcard certificate.
        SSLSocketDataProvider ssl(ASYNC, OK);
        ssl.ssl_info.cert =
            ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");

        SpdySessionDependencies session_deps;
        session_deps.socket_factory->AddSocketDataProvider(&data);
        session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);

        std::unique_ptr<HttpNetworkSession> http_network_session =
            SpdySessionDependencies::SpdyCreateSession(&session_deps);
        const SpdySessionKey key(
            HostPortPair::FromURL(url), PRIVACY_MODE_DISABLED,
            ProxyChain::Direct(), SessionUsage::kDestination, SocketTag(),
            NetworkAnonymizationKey(), SecureDnsPolicy::kAllow,
            /*disable_cert_verification_network_fetches=*/false);
        base::WeakPtr<SpdySession> spdy_session =
            CreateSpdySession(http_network_session.get(), key, net_log);
        std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
            create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);

        handshake->RegisterRequest(&request_info);
        int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
                                             NetLogWithSource(),
                                             CompletionOnceCallback());
        EXPECT_THAT(rv, IsOk());

        // Unlike the basic case, sending and reading complete asynchronously.
        HttpResponseInfo response;
        TestCompletionCallback request_callback;
        rv = handshake->SendRequest(headers, &response,
                                    request_callback.callback());
        EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
        rv = request_callback.WaitForResult();
        EXPECT_THAT(rv, IsOk());

        TestCompletionCallback response_callback;
        rv = handshake->ReadResponseHeaders(response_callback.callback());
        EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
        rv = response_callback.WaitForResult();
        EXPECT_THAT(rv, IsOk());

        // Over HTTP/2 a successful handshake is reported as 200, not 101.
        EXPECT_EQ(200, response.headers->response_code());
        return handshake->Upgrade();
      }
      case HTTP3_HANDSHAKE_STREAM: {
        const quic::QuicStreamId client_data_stream_id(
            quic::QuicUtils::GetFirstBidirectionalStreamId(
                quic_version_.transport_version, quic::Perspective::IS_CLIENT));
        quic::QuicCryptoClientConfig crypto_config(
            quic::test::crypto_test_utils::ProofVerifierForTesting());

        // Packet makers for both directions share one connection id.
        const quic::QuicConnectionId connection_id(
            quic::test::TestConnectionId(2));
        test::QuicTestPacketMaker client_maker(
            quic_version_, connection_id, &clock_, "mail.example.org",
            quic::Perspective::IS_CLIENT,
            /*client_headers_include_h2_stream_dependency_=*/false);
        test::QuicTestPacketMaker server_maker(
            quic_version_, connection_id, &clock_, "mail.example.org",
            quic::Perspective::IS_SERVER,
            /*client_headers_include_h2_stream_dependency_=*/false);
        IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
        quic::test::MockConnectionIdGenerator connection_id_generator;

        testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
        ProofVerifyDetailsChromium verify_details;
        MockCryptoClientStreamFactory crypto_client_stream_factory;
        TransportSecurityState transport_security_state;
        SSLConfigServiceDefaults ssl_config_service;

        // NOTE(review): this assigns a process-wide QUICHE flag and nothing in
        // this file restores it, so the value may leak into later tests in the
        // same process. QuicEnableVersion() below likewise looks like global
        // state. Consider a scoped save/restore.
        FLAGS_quic_enable_http3_grease_randomness = false;
        clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
        quic::QuicEnableVersion(quic_version_);
        quic::test::MockRandom random_generator{0};

        quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
            kPath, "www.example.org", kOrigin, extra_request_headers);

        // Scripted QUIC traffic:
        //   write: client initial SETTINGS packet.
        //   write: request HEADERS on the first client bidirectional stream.
        //   read:  server response HEADERS (server packet 1).
        //   read:  ERR_IO_PENDING, so no further server data arrives.
        //   write: ACK of server packet 1 plus STOP_SENDING and RST_STREAM
        //          (QUIC_STREAM_CANCELLED) for the data stream. Presumably
        //          this is emitted when the stream is torn down; TODO confirm.
        int packet_number = 1;
        mock_quic_data_.AddWrite(
            SYNCHRONOUS,
            client_maker.MakeInitialSettingsPacket(packet_number++));

        mock_quic_data_.AddWrite(
            ASYNC,
            client_maker.MakeRequestHeadersPacket(
                packet_number++, client_data_stream_id,
                /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
                std::move(request_header_block), nullptr));

        quiche::HttpHeaderBlock response_header_block =
            WebSocketHttp2Response(extra_response_headers);

        mock_quic_data_.AddRead(
            ASYNC, server_maker.MakeResponseHeadersPacket(
                       /*packet_number=*/1, client_data_stream_id,
                       /*fin=*/false, std::move(response_header_block),
                       /*spdy_headers_frame_length=*/nullptr));

        mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);

        mock_quic_data_.AddWrite(
            SYNCHRONOUS,
            client_maker.Packet(packet_number++)
                .AddAckFrame(/*first_received=*/1, /*largest_received=*/1,
                             /*smallest_received=*/0)
                .AddStopSendingFrame(client_data_stream_id,
                                     quic::QUIC_STREAM_CANCELLED)
                .AddRstStreamFrame(client_data_stream_id,
                                   quic::QUIC_STREAM_CANCELLED)
                .Build());
        auto socket = std::make_unique<MockUDPClientSocket>(
            mock_quic_data_.InitializeAndGetSequencedSocketData(),
            NetLog::Get());
        socket->Connect(peer_addr);

        scoped_refptr<test::TestTaskRunner> runner =
            base::MakeRefCounted<test::TestTaskRunner>(&clock_);
        auto helper = std::make_unique<QuicChromiumConnectionHelper>(
            &clock_, &random_generator);
        auto alarm_factory =
            std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
        // Ownership of 'writer' is passed to 'QuicConnection'.
        QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
            socket.get(),
            base::SingleThreadTaskRunner::GetCurrentDefault().get());
        // |connection| is handed to the session constructed below; presumably
        // the session takes ownership of it. TODO(review): confirm.
        quic::QuicConnection* connection = new quic::QuicConnection(
            connection_id, quic::QuicSocketAddress(),
            net::ToQuicSocketAddress(peer_addr), helper.get(),
            alarm_factory.get(), writer, true /* owns_writer */,
            quic::Perspective::IS_CLIENT,
            quic::test::SupportedVersions(quic_version_),
            connection_id_generator);
        connection->set_visitor(&visitor);

        // Load a certificate that is valid for *.example.org
        scoped_refptr<X509Certificate> test_cert(
            ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
        EXPECT_TRUE(test_cert.get());

        verify_details.cert_verify_result.verified_cert = test_cert;
        verify_details.cert_verify_result.is_issued_by_known_root = true;
        crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);

        // Synthetic DNS timing: resolution "took" 1 ms and ended now.
        base::TimeTicks dns_end = base::TimeTicks::Now();
        base::TimeTicks dns_start = dns_end - base::Milliseconds(1);

        // The session is stored in a member rather than a local, presumably
        // so it outlives this function and the stream returned from it.
        session_ = std::make_unique<QuicChromiumClientSession>(
            connection, std::move(socket),
            /*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
            &transport_security_state, &ssl_config_service,
            /*server_info=*/nullptr,
            QuicSessionAliasKey(
                url::SchemeHostPort(),
                QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
                               ProxyChain::Direct(), SessionUsage::kDestination,
                               SocketTag(), NetworkAnonymizationKey(),
                               SecureDnsPolicy::kAllow,
                               /*require_dns_https_alpn=*/false)),
            /*require_confirmation=*/false,
            /*migrate_session_early_v2=*/false,
            /*migrate_session_on_network_change_v2=*/false,
            /*default_network=*/handles::kInvalidNetworkHandle,
            quic::QuicTime::Delta::FromMilliseconds(
                kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
            /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
            kDefaultIdleSessionMigrationPeriod,
            /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
            kMaxMigrationsToNonDefaultNetworkOnWriteError,
            kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
            kQuicYieldAfterPacketsRead,
            quic::QuicTime::Delta::FromMilliseconds(
                kQuicYieldAfterDurationMilliseconds),
            /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
            std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
            "CONNECTION_UNKNOWN", dns_start, dns_end,
            base::DefaultTickClock::GetInstance(),
            base::SingleThreadTaskRunner::GetCurrentDefault().get(),
            /*socket_performance_watcher=*/nullptr,
            ConnectionEndpointMetadata(), /*report_ecn=*/true,
            /*enable_origin_frame=*/true,
            /*allow_server_preferred_address=*/true,
            MultiplexedSessionCreationInitiator::kUnknown,
            NetLogWithSource::Make(NetLogSourceType::NONE));

        session_->Initialize();

        // Blackhole QPACK decoder stream instead of constructing mock writes.
        session_->qpack_decoder()->set_qpack_stream_sender_delegate(
            &noop_qpack_stream_sender_delegate_);
        // With the mock crypto stream factory the handshake is expected to
        // finish synchronously and yield 1-RTT keys.
        TestCompletionCallback callback;
        EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
        EXPECT_TRUE(session_->OneRttKeysAvailable());
        std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
            session_->CreateHandle(
                url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));

        std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
            create_helper.CreateHttp3Stream(std::move(session_handle),
                                            {} /* dns_aliases */);

        handshake->RegisterRequest(&request_info);
        int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
                                             CompletionOnceCallback());
        EXPECT_THAT(rv, IsOk());

        HttpResponseInfo response;
        TestCompletionCallback request_callback;
        rv = handshake->SendRequest(headers, &response,
                                    request_callback.callback());
        EXPECT_THAT(rv, IsOk());

        // Only start consuming the scripted server packets once the request
        // has been sent.
        session_->StartReading();

        TestCompletionCallback response_callback;
        rv = handshake->ReadResponseHeaders(response_callback.callback());
        EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
        rv = response_callback.WaitForResult();
        EXPECT_THAT(rv, IsOk());

        // As with HTTP/2, success is reported as 200 rather than 101.
        EXPECT_EQ(200, response.headers->response_code());

        return handshake->Upgrade();
      }
      default:
        NOTREACHED();
    }
  }

 private:
  // Used by the BASIC_HANDSHAKE_STREAM case.
  MockClientSocketHandleFactory socket_handle_factory_;
  // Passed to the create helper; its notifications are ignored.
  TestConnectDelegate connect_delegate_;
  // StrictMock: any call without a matching EXPECT_CALL fails the test.
  StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
  WebSocketEndpointLockManager websocket_endpoint_lock_manager_;

  // For HTTP3_HANDSHAKE_STREAM
  quic::ParsedQuicVersion quic_version_;
  quic::MockClock clock_;
  std::unique_ptr<QuicChromiumClientSession> session_;
  test::MockQuicData mock_quic_data_;
  quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
};
585
586 INSTANTIATE_TEST_SUITE_P(All,
587 WebSocketHandshakeStreamCreateHelperTest,
588 Values(BASIC_HANDSHAKE_STREAM,
589 HTTP2_HANDSHAKE_STREAM,
590 HTTP3_HANDSHAKE_STREAM));
591
592 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)593 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
594 std::unique_ptr<WebSocketStream> stream =
595 CreateAndInitializeStream({}, {}, {});
596 EXPECT_EQ("", stream->GetExtensions());
597 EXPECT_EQ("", stream->GetSubProtocol());
598 }
599
600 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)601 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
602 std::vector<std::string> sub_protocols;
603 sub_protocols.push_back("chat");
604 sub_protocols.push_back("superchat");
605 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
606 sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
607 {{"Sec-WebSocket-Protocol", "superchat"}});
608 EXPECT_EQ("superchat", stream->GetSubProtocol());
609 }
610
611 // Verify that extension name is available. Bad extension names are tested in
612 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)613 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
614 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
615 {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
616 EXPECT_EQ("permessage-deflate", stream->GetExtensions());
617 }
618
619 // Verify that extension parameters are available. Bad parameters are tested in
620 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)621 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
622 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
623 {}, {},
624 {{"Sec-WebSocket-Extensions",
625 "permessage-deflate;"
626 " client_max_window_bits=14; server_max_window_bits=14;"
627 " server_no_context_takeover; client_no_context_takeover"}});
628
629 EXPECT_EQ(
630 "permessage-deflate;"
631 " client_max_window_bits=14; server_max_window_bits=14;"
632 " server_no_context_takeover; client_no_context_takeover",
633 stream->GetExtensions());
634 }
635
636 } // namespace
637
638 } // namespace net
639