1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6
7 #include <string>
8 #include <utility>
9 #include <vector>
10
11 #include "base/containers/span.h"
12 #include "base/functional/callback.h"
13 #include "base/memory/scoped_refptr.h"
14 #include "base/notreached.h"
15 #include "base/strings/string_piece.h"
16 #include "base/task/single_thread_task_runner.h"
17 #include "base/time/default_tick_clock.h"
18 #include "base/time/time.h"
19 #include "net/base/auth.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/host_port_pair.h"
22 #include "net/base/ip_address.h"
23 #include "net/base/ip_endpoint.h"
24 #include "net/base/load_flags.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/network_anonymization_key.h"
27 #include "net/base/network_handle.h"
28 #include "net/base/privacy_mode.h"
29 #include "net/base/proxy_server.h"
30 #include "net/base/request_priority.h"
31 #include "net/base/test_completion_callback.h"
32 #include "net/cert/cert_verify_result.h"
33 #include "net/dns/public/host_resolver_results.h"
34 #include "net/dns/public/secure_dns_policy.h"
35 #include "net/http/http_request_info.h"
36 #include "net/http/http_response_headers.h"
37 #include "net/http/http_response_info.h"
38 #include "net/http/transport_security_state.h"
39 #include "net/log/net_log.h"
40 #include "net/log/net_log_with_source.h"
41 #include "net/quic/address_utils.h"
42 #include "net/quic/crypto/proof_verifier_chromium.h"
43 #include "net/quic/mock_crypto_client_stream_factory.h"
44 #include "net/quic/mock_quic_data.h"
45 #include "net/quic/quic_chromium_alarm_factory.h"
46 #include "net/quic/quic_chromium_connection_helper.h"
47 #include "net/quic/quic_chromium_packet_reader.h"
48 #include "net/quic/quic_chromium_packet_writer.h"
49 #include "net/quic/quic_context.h"
50 #include "net/quic/quic_http_utils.h"
51 #include "net/quic/quic_server_info.h"
52 #include "net/quic/quic_session_key.h"
53 #include "net/quic/quic_test_packet_maker.h"
54 #include "net/quic/test_quic_crypto_client_config_handle.h"
55 #include "net/quic/test_task_runner.h"
56 #include "net/socket/client_socket_handle.h"
57 #include "net/socket/client_socket_pool.h"
58 #include "net/socket/connect_job.h"
59 #include "net/socket/socket_tag.h"
60 #include "net/socket/socket_test_util.h"
61 #include "net/socket/websocket_endpoint_lock_manager.h"
62 #include "net/spdy/spdy_session_key.h"
63 #include "net/spdy/spdy_test_util_common.h"
64 #include "net/ssl/ssl_config_service_defaults.h"
65 #include "net/ssl/ssl_info.h"
66 #include "net/test/cert_test_util.h"
67 #include "net/test/gtest_util.h"
68 #include "net/test/test_data_directory.h"
69 #include "net/test/test_with_task_environment.h"
70 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
71 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
72 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
73 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
74 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
75 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
76 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
77 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
78 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
79 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
80 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
81 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
82 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
83 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
84 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
85 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
86 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
87 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
88 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
89 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
90 #include "net/traffic_annotation/network_traffic_annotation.h"
91 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
92 #include "net/websockets/websocket_basic_handshake_stream.h"
93 #include "net/websockets/websocket_event_interface.h"
94 #include "net/websockets/websocket_stream.h"
95 #include "net/websockets/websocket_test_util.h"
96 #include "testing/gmock/include/gmock/gmock.h"
97 #include "testing/gtest/include/gtest/gtest.h"
98 #include "third_party/abseil-cpp/absl/types/optional.h"
99 #include "url/gurl.h"
100 #include "url/origin.h"
101 #include "url/scheme_host_port.h"
102 #include "url/url_constants.h"
103
// Forward declarations for net types this test names only through pointers,
// smart pointers, or mock signatures.
namespace net {
class HttpNetworkSession;
class URLRequest;
struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;
class WebSocketHttp2HandshakeStream;
class WebSocketHttp3HandshakeStream;
class X509Certificate;
}  // namespace net
113
114 using ::net::test::IsError;
115 using ::net::test::IsOk;
116 using ::testing::StrictMock;
117 using ::testing::TestWithParam;
118 using ::testing::Values;
119 using ::testing::_;
120
121 namespace net {
122 namespace {
123
// Test parameter choosing which WebSocketHandshakeStreamCreateHelper factory
// method the fixture exercises.
enum HandshakeStreamType {
  // CreateBasicStream(): WebSocket over a plain HTTP/1.1 connection.
  BASIC_HANDSHAKE_STREAM,
  // CreateHttp2Stream(): WebSocket over an HTTP/2 (SPDY) session.
  HTTP2_HANDSHAKE_STREAM,
  // CreateHttp3Stream(): WebSocket over an HTTP/3 (QUIC) session.
  HTTP3_HANDSHAKE_STREAM,
};
129
130 // This class encapsulates the details of creating a mock ClientSocketHandle.
131 class MockClientSocketHandleFactory {
132 public:
MockClientSocketHandleFactory()133 MockClientSocketHandleFactory()
134 : common_connect_job_params_(
135 socket_factory_maker_.factory(),
136 /*host_resolver=*/nullptr,
137 /*http_auth_cache=*/nullptr,
138 /*http_auth_handler_factory=*/nullptr,
139 /*spdy_session_pool=*/nullptr,
140 /*quic_supported_versions=*/nullptr,
141 /*quic_stream_factory=*/nullptr,
142 /*proxy_delegate=*/nullptr,
143 /*http_user_agent_settings=*/nullptr,
144 /*ssl_client_context=*/nullptr,
145 /*socket_performance_watcher_factory=*/nullptr,
146 /*network_quality_estimator=*/nullptr,
147 /*net_log=*/nullptr,
148 /*websocket_endpoint_lock_manager=*/nullptr,
149 /*http_server_properties=*/nullptr,
150 /*alpn_protos=*/nullptr,
151 /*application_settings=*/nullptr,
152 /*ignore_certificate_errors=*/nullptr),
153 pool_(1, 1, &common_connect_job_params_) {}
154
155 MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
156 MockClientSocketHandleFactory& operator=(
157 const MockClientSocketHandleFactory&) = delete;
158
159 // The created socket expects |expect_written| to be written to the socket,
160 // and will respond with |return_to_read|. The test will fail if the expected
161 // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)162 std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
163 const std::string& expect_written,
164 const std::string& return_to_read) {
165 socket_factory_maker_.SetExpectations(expect_written, return_to_read);
166 auto socket_handle = std::make_unique<ClientSocketHandle>();
167 socket_handle->Init(
168 ClientSocketPool::GroupId(
169 url::SchemeHostPort(url::kHttpScheme, "a", 80),
170 PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
171 SecureDnsPolicy::kAllow),
172 scoped_refptr<ClientSocketPool::SocketParams>(),
173 absl::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
174 ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
175 ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
176 return socket_handle;
177 }
178
179 private:
180 WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
181 const CommonConnectJobParams common_connect_job_params_;
182 MockTransportClientSocketPool pool_;
183 };
184
185 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
186 public:
187 ~TestConnectDelegate() override = default;
188
OnCreateRequest(URLRequest * request)189 void OnCreateRequest(URLRequest* request) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)190 void OnSuccess(
191 std::unique_ptr<WebSocketStream> stream,
192 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message,int net_error,absl::optional<int> response_code)193 void OnFailure(const std::string& failure_message,
194 int net_error,
195 absl::optional<int> response_code) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)196 void OnStartOpeningHandshake(
197 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)198 void OnSSLCertificateError(
199 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
200 ssl_error_callbacks,
201 int net_error,
202 const SSLInfo& ssl_info,
203 bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)204 int OnAuthRequired(const AuthChallengeInfo& auth_info,
205 scoped_refptr<HttpResponseHeaders> response_headers,
206 const IPEndPoint& host_port_pair,
207 base::OnceCallback<void(const AuthCredentials*)> callback,
208 absl::optional<AuthCredentials>* credentials) override {
209 *credentials = absl::nullopt;
210 return OK;
211 }
212 };
213
214 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
215 public:
216 ~MockWebSocketStreamRequestAPI() override = default;
217
218 MOCK_METHOD1(OnBasicHandshakeStreamCreated,
219 void(WebSocketBasicHandshakeStream* handshake_stream));
220 MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
221 void(WebSocketHttp2HandshakeStream* handshake_stream));
222 MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
223 void(WebSocketHttp3HandshakeStream* handshake_stream));
224 MOCK_METHOD3(OnFailure,
225 void(const std::string& message,
226 int net_error,
227 absl::optional<int> response_code));
228 };
229
230 class WebSocketHandshakeStreamCreateHelperTest
231 : public TestWithParam<HandshakeStreamType>,
232 public WithTaskEnvironment {
233 protected:
WebSocketHandshakeStreamCreateHelperTest()234 WebSocketHandshakeStreamCreateHelperTest()
235 : quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
236 quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
237 mock_quic_data_(quic_version_) {}
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)238 std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
239 const std::vector<std::string>& sub_protocols,
240 const WebSocketExtraHeaders& extra_request_headers,
241 const WebSocketExtraHeaders& extra_response_headers) {
242 const char kPath[] = "/";
243 const char kOrigin[] = "http://origin.example.org";
244 const GURL url("wss://www.example.org/");
245 NetLogWithSource net_log;
246
247 WebSocketHandshakeStreamCreateHelper create_helper(
248 &connect_delegate_, sub_protocols, &stream_request_);
249
250 switch (GetParam()) {
251 case BASIC_HANDSHAKE_STREAM:
252 EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
253 break;
254
255 case HTTP2_HANDSHAKE_STREAM:
256 EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
257 break;
258
259 case HTTP3_HANDSHAKE_STREAM:
260 EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
261 break;
262
263 default:
264 NOTREACHED();
265 }
266
267 EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
268
269 HttpRequestInfo request_info;
270 request_info.url = url;
271 request_info.method = "GET";
272 request_info.load_flags = LOAD_DISABLE_CACHE;
273 request_info.traffic_annotation =
274 MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
275
276 auto headers = WebSocketCommonTestHeaders();
277
278 switch (GetParam()) {
279 case BASIC_HANDSHAKE_STREAM: {
280 std::unique_ptr<ClientSocketHandle> socket_handle =
281 socket_handle_factory_.CreateClientSocketHandle(
282 WebSocketStandardRequest(kPath, "www.example.org",
283 url::Origin::Create(GURL(kOrigin)),
284 /*send_additional_request_headers=*/{},
285 extra_request_headers),
286 WebSocketStandardResponse(
287 WebSocketExtraHeadersToString(extra_response_headers)));
288
289 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
290 create_helper.CreateBasicStream(std::move(socket_handle), false,
291 &websocket_endpoint_lock_manager_);
292
293 // If in future the implementation type returned by CreateBasicStream()
294 // changes, this static_cast will be wrong. However, in that case the
295 // test will fail and AddressSanitizer should identify the issue.
296 static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
297 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
298
299 handshake->RegisterRequest(&request_info);
300 int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
301 CompletionOnceCallback());
302 EXPECT_THAT(rv, IsOk());
303
304 HttpResponseInfo response;
305 TestCompletionCallback request_callback;
306 rv = handshake->SendRequest(headers, &response,
307 request_callback.callback());
308 EXPECT_THAT(rv, IsOk());
309
310 TestCompletionCallback response_callback;
311 rv = handshake->ReadResponseHeaders(response_callback.callback());
312 EXPECT_THAT(rv, IsOk());
313 EXPECT_EQ(101, response.headers->response_code());
314 EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
315 EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
316 return handshake->Upgrade();
317 }
318 case HTTP2_HANDSHAKE_STREAM: {
319 SpdyTestUtil spdy_util;
320 spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
321 kPath, "www.example.org", kOrigin, extra_request_headers);
322 spdy::SpdySerializedFrame request_headers(
323 spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
324 DEFAULT_PRIORITY, false));
325 MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
326
327 spdy::Http2HeaderBlock response_header_block =
328 WebSocketHttp2Response(extra_response_headers);
329 spdy::SpdySerializedFrame response_headers(
330 spdy_util.ConstructSpdyResponseHeaders(
331 1, std::move(response_header_block), false));
332 MockRead reads[] = {CreateMockRead(response_headers, 1),
333 MockRead(ASYNC, 0, 2)};
334
335 SequencedSocketData data(reads, writes);
336
337 SSLSocketDataProvider ssl(ASYNC, OK);
338 ssl.ssl_info.cert =
339 ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
340
341 SpdySessionDependencies session_deps;
342 session_deps.socket_factory->AddSocketDataProvider(&data);
343 session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
344
345 std::unique_ptr<HttpNetworkSession> http_network_session =
346 SpdySessionDependencies::SpdyCreateSession(&session_deps);
347 const SpdySessionKey key(
348 HostPortPair::FromURL(url), ProxyChain::Direct(),
349 PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse,
350 SocketTag(), NetworkAnonymizationKey(), SecureDnsPolicy::kAllow);
351 base::WeakPtr<SpdySession> spdy_session =
352 CreateSpdySession(http_network_session.get(), key, net_log);
353 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
354 create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
355
356 handshake->RegisterRequest(&request_info);
357 int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
358 NetLogWithSource(),
359 CompletionOnceCallback());
360 EXPECT_THAT(rv, IsOk());
361
362 HttpResponseInfo response;
363 TestCompletionCallback request_callback;
364 rv = handshake->SendRequest(headers, &response,
365 request_callback.callback());
366 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
367 rv = request_callback.WaitForResult();
368 EXPECT_THAT(rv, IsOk());
369
370 TestCompletionCallback response_callback;
371 rv = handshake->ReadResponseHeaders(response_callback.callback());
372 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
373 rv = response_callback.WaitForResult();
374 EXPECT_THAT(rv, IsOk());
375
376 EXPECT_EQ(200, response.headers->response_code());
377 return handshake->Upgrade();
378 }
379 case HTTP3_HANDSHAKE_STREAM: {
380 const quic::QuicStreamId client_data_stream_id(
381 quic::QuicUtils::GetFirstBidirectionalStreamId(
382 quic_version_.transport_version, quic::Perspective::IS_CLIENT));
383 quic::QuicCryptoClientConfig crypto_config(
384 quic::test::crypto_test_utils::ProofVerifierForTesting());
385
386 const quic::QuicConnectionId connection_id(
387 quic::test::TestConnectionId(2));
388 test::QuicTestPacketMaker client_maker(
389 quic_version_, connection_id, &clock_, "mail.example.org",
390 quic::Perspective::IS_CLIENT,
391 /*client_headers_include_h2_stream_dependency_=*/false);
392 test::QuicTestPacketMaker server_maker(
393 quic_version_, connection_id, &clock_, "mail.example.org",
394 quic::Perspective::IS_SERVER,
395 /*client_headers_include_h2_stream_dependency_=*/false);
396 IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
397 quic::test::MockConnectionIdGenerator connection_id_generator;
398
399 testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
400 ProofVerifyDetailsChromium verify_details;
401 MockCryptoClientStreamFactory crypto_client_stream_factory;
402 TransportSecurityState transport_security_state;
403 SSLConfigServiceDefaults ssl_config_service;
404
405 FLAGS_quic_enable_http3_grease_randomness = false;
406 clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
407 quic::QuicEnableVersion(quic_version_);
408 quic::test::MockRandom random_generator{0};
409
410 spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
411 kPath, "www.example.org", kOrigin, extra_request_headers);
412
413 int packet_number = 1;
414 mock_quic_data_.AddWrite(
415 SYNCHRONOUS,
416 client_maker.MakeInitialSettingsPacket(packet_number++));
417
418 mock_quic_data_.AddWrite(
419 ASYNC,
420 client_maker.MakeRequestHeadersPacket(
421 packet_number++, client_data_stream_id,
422 /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
423 std::move(request_header_block), nullptr));
424
425 spdy::Http2HeaderBlock response_header_block =
426 WebSocketHttp2Response(extra_response_headers);
427
428 mock_quic_data_.AddRead(
429 ASYNC, server_maker.MakeResponseHeadersPacket(
430 /*packet_number=*/1, client_data_stream_id,
431 /*fin=*/false, std::move(response_header_block),
432 /*spdy_headers_frame_length=*/nullptr));
433
434 mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
435
436 mock_quic_data_.AddWrite(SYNCHRONOUS,
437 client_maker.MakeAckAndRstPacket(
438 packet_number++, client_data_stream_id,
439 quic::QUIC_STREAM_CANCELLED, 1, 0,
440 /*include_stop_sending_if_v99=*/true));
441 auto socket = std::make_unique<MockUDPClientSocket>(
442 mock_quic_data_.InitializeAndGetSequencedSocketData(),
443 NetLog::Get());
444 socket->Connect(peer_addr);
445
446 scoped_refptr<test::TestTaskRunner> runner =
447 base::MakeRefCounted<test::TestTaskRunner>(&clock_);
448 auto helper = std::make_unique<QuicChromiumConnectionHelper>(
449 &clock_, &random_generator);
450 auto alarm_factory =
451 std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
452 // Ownership of 'writer' is passed to 'QuicConnection'.
453 QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
454 socket.get(),
455 base::SingleThreadTaskRunner::GetCurrentDefault().get());
456 quic::QuicConnection* connection = new quic::QuicConnection(
457 connection_id, quic::QuicSocketAddress(),
458 net::ToQuicSocketAddress(peer_addr), helper.get(),
459 alarm_factory.get(), writer, true /* owns_writer */,
460 quic::Perspective::IS_CLIENT,
461 quic::test::SupportedVersions(quic_version_),
462 connection_id_generator);
463 connection->set_visitor(&visitor);
464
465 // Load a certificate that is valid for *.example.org
466 scoped_refptr<X509Certificate> test_cert(
467 ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
468 EXPECT_TRUE(test_cert.get());
469
470 verify_details.cert_verify_result.verified_cert = test_cert;
471 verify_details.cert_verify_result.is_issued_by_known_root = true;
472 crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);
473
474 base::TimeTicks dns_end = base::TimeTicks::Now();
475 base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
476
477 session_ = std::make_unique<QuicChromiumClientSession>(
478 connection, std::move(socket),
479 /*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
480 &transport_security_state, &ssl_config_service,
481 /*server_info=*/nullptr,
482 QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
483 SocketTag(), NetworkAnonymizationKey(),
484 SecureDnsPolicy::kAllow,
485 /*require_dns_https_alpn=*/false),
486 /*require_confirmation=*/false,
487 /*migrate_session_early_v2=*/false,
488 /*migrate_session_on_network_change_v2=*/false,
489 /*default_network=*/handles::kInvalidNetworkHandle,
490 quic::QuicTime::Delta::FromMilliseconds(
491 kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
492 /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
493 kDefaultIdleSessionMigrationPeriod,
494 /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
495 kMaxMigrationsToNonDefaultNetworkOnWriteError,
496 kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
497 kQuicYieldAfterPacketsRead,
498 quic::QuicTime::Delta::FromMilliseconds(
499 kQuicYieldAfterDurationMilliseconds),
500 /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
501 std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
502 dns_start, dns_end, base::DefaultTickClock::GetInstance(),
503 base::SingleThreadTaskRunner::GetCurrentDefault().get(),
504 /*socket_performance_watcher=*/nullptr,
505 HostResolverEndpointResult(), NetLog::Get());
506
507 session_->Initialize();
508
509 // Blackhole QPACK decoder stream instead of constructing mock writes.
510 session_->qpack_decoder()->set_qpack_stream_sender_delegate(
511 &noop_qpack_stream_sender_delegate_);
512 TestCompletionCallback callback;
513 EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
514 EXPECT_TRUE(session_->OneRttKeysAvailable());
515 std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
516 session_->CreateHandle(
517 url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
518
519 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
520 create_helper.CreateHttp3Stream(std::move(session_handle),
521 {} /* dns_aliases */);
522
523 handshake->RegisterRequest(&request_info);
524 int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
525 CompletionOnceCallback());
526 EXPECT_THAT(rv, IsOk());
527
528 HttpResponseInfo response;
529 TestCompletionCallback request_callback;
530 rv = handshake->SendRequest(headers, &response,
531 request_callback.callback());
532 EXPECT_THAT(rv, IsOk());
533
534 session_->StartReading();
535
536 TestCompletionCallback response_callback;
537 rv = handshake->ReadResponseHeaders(response_callback.callback());
538 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
539 rv = response_callback.WaitForResult();
540 EXPECT_THAT(rv, IsOk());
541
542 EXPECT_EQ(200, response.headers->response_code());
543
544 return handshake->Upgrade();
545 }
546 default:
547 NOTREACHED();
548 return nullptr;
549 }
550 }
551
552 private:
553 MockClientSocketHandleFactory socket_handle_factory_;
554 TestConnectDelegate connect_delegate_;
555 StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
556 WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
557
558 // For HTTP3_HANDSHAKE_STREAM
559 quic::ParsedQuicVersion quic_version_;
560 quic::MockClock clock_;
561 std::unique_ptr<QuicChromiumClientSession> session_;
562 test::MockQuicData mock_quic_data_;
563 quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
564 };
565
566 INSTANTIATE_TEST_SUITE_P(All,
567 WebSocketHandshakeStreamCreateHelperTest,
568 Values(BASIC_HANDSHAKE_STREAM,
569 HTTP2_HANDSHAKE_STREAM,
570 HTTP3_HANDSHAKE_STREAM));
571
572 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)573 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
574 std::unique_ptr<WebSocketStream> stream =
575 CreateAndInitializeStream({}, {}, {});
576 EXPECT_EQ("", stream->GetExtensions());
577 EXPECT_EQ("", stream->GetSubProtocol());
578 }
579
580 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)581 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
582 std::vector<std::string> sub_protocols;
583 sub_protocols.push_back("chat");
584 sub_protocols.push_back("superchat");
585 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
586 sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
587 {{"Sec-WebSocket-Protocol", "superchat"}});
588 EXPECT_EQ("superchat", stream->GetSubProtocol());
589 }
590
591 // Verify that extension name is available. Bad extension names are tested in
592 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)593 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
594 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
595 {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
596 EXPECT_EQ("permessage-deflate", stream->GetExtensions());
597 }
598
599 // Verify that extension parameters are available. Bad parameters are tested in
600 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)601 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
602 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
603 {}, {},
604 {{"Sec-WebSocket-Extensions",
605 "permessage-deflate;"
606 " client_max_window_bits=14; server_max_window_bits=14;"
607 " server_no_context_takeover; client_no_context_takeover"}});
608
609 EXPECT_EQ(
610 "permessage-deflate;"
611 " client_max_window_bits=14; server_max_window_bits=14;"
612 " server_no_context_takeover; client_no_context_takeover",
613 stream->GetExtensions());
614 }
615
616 } // namespace
617
618 } // namespace net
619