• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6 
7 #include <optional>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 #include "base/containers/span.h"
13 #include "base/functional/callback.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/notreached.h"
16 #include "base/task/single_thread_task_runner.h"
17 #include "base/time/default_tick_clock.h"
18 #include "base/time/time.h"
19 #include "net/base/auth.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/connection_endpoint_metadata.h"
22 #include "net/base/host_port_pair.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/load_flags.h"
26 #include "net/base/net_errors.h"
27 #include "net/base/network_anonymization_key.h"
28 #include "net/base/network_handle.h"
29 #include "net/base/privacy_mode.h"
30 #include "net/base/proxy_chain.h"
31 #include "net/base/proxy_server.h"
32 #include "net/base/request_priority.h"
33 #include "net/base/session_usage.h"
34 #include "net/base/test_completion_callback.h"
35 #include "net/cert/cert_verify_result.h"
36 #include "net/dns/public/host_resolver_results.h"
37 #include "net/dns/public/secure_dns_policy.h"
38 #include "net/http/http_request_info.h"
39 #include "net/http/http_response_headers.h"
40 #include "net/http/http_response_info.h"
41 #include "net/http/transport_security_state.h"
42 #include "net/log/net_log.h"
43 #include "net/log/net_log_with_source.h"
44 #include "net/quic/address_utils.h"
45 #include "net/quic/crypto/proof_verifier_chromium.h"
46 #include "net/quic/mock_crypto_client_stream_factory.h"
47 #include "net/quic/mock_quic_data.h"
48 #include "net/quic/quic_chromium_alarm_factory.h"
49 #include "net/quic/quic_chromium_connection_helper.h"
50 #include "net/quic/quic_chromium_packet_reader.h"
51 #include "net/quic/quic_chromium_packet_writer.h"
52 #include "net/quic/quic_context.h"
53 #include "net/quic/quic_http_utils.h"
54 #include "net/quic/quic_server_info.h"
55 #include "net/quic/quic_session_alias_key.h"
56 #include "net/quic/quic_session_key.h"
57 #include "net/quic/quic_test_packet_maker.h"
58 #include "net/quic/test_quic_crypto_client_config_handle.h"
59 #include "net/quic/test_task_runner.h"
60 #include "net/socket/client_socket_handle.h"
61 #include "net/socket/client_socket_pool.h"
62 #include "net/socket/connect_job.h"
63 #include "net/socket/socket_tag.h"
64 #include "net/socket/socket_test_util.h"
65 #include "net/socket/websocket_endpoint_lock_manager.h"
66 #include "net/spdy/multiplexed_session_creation_initiator.h"
67 #include "net/spdy/spdy_session_key.h"
68 #include "net/spdy/spdy_test_util_common.h"
69 #include "net/ssl/ssl_config_service_defaults.h"
70 #include "net/ssl/ssl_info.h"
71 #include "net/test/cert_test_util.h"
72 #include "net/test/gtest_util.h"
73 #include "net/test/test_data_directory.h"
74 #include "net/test/test_with_task_environment.h"
75 #include "net/third_party/quiche/src/quiche/common/http/http_header_block.h"
76 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
77 #include "net/third_party/quiche/src/quiche/http2/core/spdy_protocol.h"
78 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
79 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
80 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
81 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
82 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
83 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
84 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
85 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
86 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
87 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
88 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
89 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
90 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
91 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
92 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
93 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
94 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
95 #include "net/traffic_annotation/network_traffic_annotation.h"
96 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
97 #include "net/websockets/websocket_basic_handshake_stream.h"
98 #include "net/websockets/websocket_event_interface.h"
99 #include "net/websockets/websocket_stream.h"
100 #include "net/websockets/websocket_test_util.h"
101 #include "testing/gmock/include/gmock/gmock.h"
102 #include "testing/gtest/include/gtest/gtest.h"
103 #include "url/gurl.h"
104 #include "url/origin.h"
105 #include "url/scheme_host_port.h"
106 #include "url/url_constants.h"
107 
// Forward declarations of net types that this file names but does not get a
// declaration for from the includes above.
namespace net {
struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;
class HttpNetworkSession;
class URLRequest;
class WebSocketHttp2HandshakeStream;
class WebSocketHttp3HandshakeStream;
class X509Certificate;
}  // namespace net
117 
118 using ::net::test::IsError;
119 using ::net::test::IsOk;
120 using ::testing::_;
121 using ::testing::StrictMock;
122 using ::testing::TestWithParam;
123 using ::testing::Values;
124 
125 namespace net {
126 namespace {
127 
// Selects which WebSocket handshake stream the parameterized tests below
// exercise: the basic (HTTP/1.1 Upgrade) stream, the HTTP/2 stream, or the
// HTTP/3 stream.
enum HandshakeStreamType {
  BASIC_HANDSHAKE_STREAM = 0,
  HTTP2_HANDSHAKE_STREAM = 1,
  HTTP3_HANDSHAKE_STREAM = 2,
};
133 
134 // This class encapsulates the details of creating a mock ClientSocketHandle.
135 class MockClientSocketHandleFactory {
136  public:
MockClientSocketHandleFactory()137   MockClientSocketHandleFactory()
138       : common_connect_job_params_(
139             socket_factory_maker_.factory(),
140             /*host_resolver=*/nullptr,
141             /*http_auth_cache=*/nullptr,
142             /*http_auth_handler_factory=*/nullptr,
143             /*spdy_session_pool=*/nullptr,
144             /*quic_supported_versions=*/nullptr,
145             /*quic_session_pool=*/nullptr,
146             /*proxy_delegate=*/nullptr,
147             /*http_user_agent_settings=*/nullptr,
148             /*ssl_client_context=*/nullptr,
149             /*socket_performance_watcher_factory=*/nullptr,
150             /*network_quality_estimator=*/nullptr,
151             /*net_log=*/nullptr,
152             /*websocket_endpoint_lock_manager=*/nullptr,
153             /*http_server_properties=*/nullptr,
154             /*alpn_protos=*/nullptr,
155             /*application_settings=*/nullptr,
156             /*ignore_certificate_errors=*/nullptr,
157             /*early_data_enabled=*/nullptr),
158         pool_(1, 1, &common_connect_job_params_) {}
159 
160   MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
161   MockClientSocketHandleFactory& operator=(
162       const MockClientSocketHandleFactory&) = delete;
163 
164   // The created socket expects |expect_written| to be written to the socket,
165   // and will respond with |return_to_read|. The test will fail if the expected
166   // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)167   std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
168       const std::string& expect_written,
169       const std::string& return_to_read) {
170     socket_factory_maker_.SetExpectations(expect_written, return_to_read);
171     auto socket_handle = std::make_unique<ClientSocketHandle>();
172     socket_handle->Init(
173         ClientSocketPool::GroupId(
174             url::SchemeHostPort(url::kHttpScheme, "a", 80),
175             PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
176             SecureDnsPolicy::kAllow, /*disable_cert_network_fetches=*/false),
177         scoped_refptr<ClientSocketPool::SocketParams>(),
178         std::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
179         ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
180         ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
181     return socket_handle;
182   }
183 
184  private:
185   WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
186   const CommonConnectJobParams common_connect_job_params_;
187   MockTransportClientSocketPool pool_;
188 };
189 
190 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
191  public:
192   ~TestConnectDelegate() override = default;
193 
OnCreateRequest(URLRequest * request)194   void OnCreateRequest(URLRequest* request) override {}
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)195   void OnURLRequestConnected(URLRequest* request,
196                              const TransportInfo& info) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)197   void OnSuccess(
198       std::unique_ptr<WebSocketStream> stream,
199       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message,int net_error,std::optional<int> response_code)200   void OnFailure(const std::string& failure_message,
201                  int net_error,
202                  std::optional<int> response_code) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)203   void OnStartOpeningHandshake(
204       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)205   void OnSSLCertificateError(
206       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
207           ssl_error_callbacks,
208       int net_error,
209       const SSLInfo& ssl_info,
210       bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)211   int OnAuthRequired(const AuthChallengeInfo& auth_info,
212                      scoped_refptr<HttpResponseHeaders> response_headers,
213                      const IPEndPoint& host_port_pair,
214                      base::OnceCallback<void(const AuthCredentials*)> callback,
215                      std::optional<AuthCredentials>* credentials) override {
216     *credentials = std::nullopt;
217     return OK;
218   }
219 };
220 
221 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
222  public:
223   ~MockWebSocketStreamRequestAPI() override = default;
224 
225   MOCK_METHOD1(OnBasicHandshakeStreamCreated,
226                void(WebSocketBasicHandshakeStream* handshake_stream));
227   MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
228                void(WebSocketHttp2HandshakeStream* handshake_stream));
229   MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
230                void(WebSocketHttp3HandshakeStream* handshake_stream));
231   MOCK_METHOD3(OnFailure,
232                void(const std::string& message,
233                     int net_error,
234                     std::optional<int> response_code));
235 };
236 
237 class WebSocketHandshakeStreamCreateHelperTest
238     : public TestWithParam<HandshakeStreamType>,
239       public WithTaskEnvironment {
240  protected:
WebSocketHandshakeStreamCreateHelperTest()241   WebSocketHandshakeStreamCreateHelperTest()
242       : quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
243                       quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
244         mock_quic_data_(quic_version_) {}
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)245   std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
246       const std::vector<std::string>& sub_protocols,
247       const WebSocketExtraHeaders& extra_request_headers,
248       const WebSocketExtraHeaders& extra_response_headers) {
249     constexpr char kPath[] = "/";
250     constexpr char kOrigin[] = "http://origin.example.org";
251     const GURL url("wss://www.example.org/");
252     NetLogWithSource net_log;
253 
254     WebSocketHandshakeStreamCreateHelper create_helper(
255         &connect_delegate_, sub_protocols, &stream_request_);
256 
257     switch (GetParam()) {
258       case BASIC_HANDSHAKE_STREAM:
259         EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
260         break;
261 
262       case HTTP2_HANDSHAKE_STREAM:
263         EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
264         break;
265 
266       case HTTP3_HANDSHAKE_STREAM:
267         EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
268         break;
269 
270       default:
271         NOTREACHED();
272     }
273 
274     EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
275 
276     HttpRequestInfo request_info;
277     request_info.url = url;
278     request_info.method = "GET";
279     request_info.load_flags = LOAD_DISABLE_CACHE;
280     request_info.traffic_annotation =
281         MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
282 
283     auto headers = WebSocketCommonTestHeaders();
284 
285     switch (GetParam()) {
286       case BASIC_HANDSHAKE_STREAM: {
287         std::unique_ptr<ClientSocketHandle> socket_handle =
288             socket_handle_factory_.CreateClientSocketHandle(
289                 WebSocketStandardRequest(kPath, "www.example.org",
290                                          url::Origin::Create(GURL(kOrigin)),
291                                          /*send_additional_request_headers=*/{},
292                                          extra_request_headers),
293                 WebSocketStandardResponse(
294                     WebSocketExtraHeadersToString(extra_response_headers)));
295 
296         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
297             create_helper.CreateBasicStream(std::move(socket_handle), false,
298                                             &websocket_endpoint_lock_manager_);
299 
300         // If in future the implementation type returned by CreateBasicStream()
301         // changes, this static_cast will be wrong. However, in that case the
302         // test will fail and AddressSanitizer should identify the issue.
303         static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
304             ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
305 
306         handshake->RegisterRequest(&request_info);
307         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
308                                              CompletionOnceCallback());
309         EXPECT_THAT(rv, IsOk());
310 
311         HttpResponseInfo response;
312         TestCompletionCallback request_callback;
313         rv = handshake->SendRequest(headers, &response,
314                                     request_callback.callback());
315         EXPECT_THAT(rv, IsOk());
316 
317         TestCompletionCallback response_callback;
318         rv = handshake->ReadResponseHeaders(response_callback.callback());
319         EXPECT_THAT(rv, IsOk());
320         EXPECT_EQ(101, response.headers->response_code());
321         EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
322         EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
323         return handshake->Upgrade();
324       }
325       case HTTP2_HANDSHAKE_STREAM: {
326         SpdyTestUtil spdy_util;
327         quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
328             kPath, "www.example.org", kOrigin, extra_request_headers);
329         spdy::SpdySerializedFrame request_headers(
330             spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
331                                            DEFAULT_PRIORITY, false));
332         MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
333 
334         quiche::HttpHeaderBlock response_header_block =
335             WebSocketHttp2Response(extra_response_headers);
336         spdy::SpdySerializedFrame response_headers(
337             spdy_util.ConstructSpdyResponseHeaders(
338                 1, std::move(response_header_block), false));
339         MockRead reads[] = {CreateMockRead(response_headers, 1),
340                             MockRead(ASYNC, 0, 2)};
341 
342         SequencedSocketData data(reads, writes);
343 
344         SSLSocketDataProvider ssl(ASYNC, OK);
345         ssl.ssl_info.cert =
346             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
347 
348         SpdySessionDependencies session_deps;
349         session_deps.socket_factory->AddSocketDataProvider(&data);
350         session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
351 
352         std::unique_ptr<HttpNetworkSession> http_network_session =
353             SpdySessionDependencies::SpdyCreateSession(&session_deps);
354         const SpdySessionKey key(
355             HostPortPair::FromURL(url), PRIVACY_MODE_DISABLED,
356             ProxyChain::Direct(), SessionUsage::kDestination, SocketTag(),
357             NetworkAnonymizationKey(), SecureDnsPolicy::kAllow,
358             /*disable_cert_verification_network_fetches=*/false);
359         base::WeakPtr<SpdySession> spdy_session =
360             CreateSpdySession(http_network_session.get(), key, net_log);
361         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
362             create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
363 
364         handshake->RegisterRequest(&request_info);
365         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
366                                              NetLogWithSource(),
367                                              CompletionOnceCallback());
368         EXPECT_THAT(rv, IsOk());
369 
370         HttpResponseInfo response;
371         TestCompletionCallback request_callback;
372         rv = handshake->SendRequest(headers, &response,
373                                     request_callback.callback());
374         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
375         rv = request_callback.WaitForResult();
376         EXPECT_THAT(rv, IsOk());
377 
378         TestCompletionCallback response_callback;
379         rv = handshake->ReadResponseHeaders(response_callback.callback());
380         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
381         rv = response_callback.WaitForResult();
382         EXPECT_THAT(rv, IsOk());
383 
384         EXPECT_EQ(200, response.headers->response_code());
385         return handshake->Upgrade();
386       }
387       case HTTP3_HANDSHAKE_STREAM: {
388         const quic::QuicStreamId client_data_stream_id(
389             quic::QuicUtils::GetFirstBidirectionalStreamId(
390                 quic_version_.transport_version, quic::Perspective::IS_CLIENT));
391         quic::QuicCryptoClientConfig crypto_config(
392             quic::test::crypto_test_utils::ProofVerifierForTesting());
393 
394         const quic::QuicConnectionId connection_id(
395             quic::test::TestConnectionId(2));
396         test::QuicTestPacketMaker client_maker(
397             quic_version_, connection_id, &clock_, "mail.example.org",
398             quic::Perspective::IS_CLIENT,
399             /*client_headers_include_h2_stream_dependency_=*/false);
400         test::QuicTestPacketMaker server_maker(
401             quic_version_, connection_id, &clock_, "mail.example.org",
402             quic::Perspective::IS_SERVER,
403             /*client_headers_include_h2_stream_dependency_=*/false);
404         IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
405         quic::test::MockConnectionIdGenerator connection_id_generator;
406 
407         testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
408         ProofVerifyDetailsChromium verify_details;
409         MockCryptoClientStreamFactory crypto_client_stream_factory;
410         TransportSecurityState transport_security_state;
411         SSLConfigServiceDefaults ssl_config_service;
412 
413         FLAGS_quic_enable_http3_grease_randomness = false;
414         clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
415         quic::QuicEnableVersion(quic_version_);
416         quic::test::MockRandom random_generator{0};
417 
418         quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
419             kPath, "www.example.org", kOrigin, extra_request_headers);
420 
421         int packet_number = 1;
422         mock_quic_data_.AddWrite(
423             SYNCHRONOUS,
424             client_maker.MakeInitialSettingsPacket(packet_number++));
425 
426         mock_quic_data_.AddWrite(
427             ASYNC,
428             client_maker.MakeRequestHeadersPacket(
429                 packet_number++, client_data_stream_id,
430                 /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
431                 std::move(request_header_block), nullptr));
432 
433         quiche::HttpHeaderBlock response_header_block =
434             WebSocketHttp2Response(extra_response_headers);
435 
436         mock_quic_data_.AddRead(
437             ASYNC, server_maker.MakeResponseHeadersPacket(
438                        /*packet_number=*/1, client_data_stream_id,
439                        /*fin=*/false, std::move(response_header_block),
440                        /*spdy_headers_frame_length=*/nullptr));
441 
442         mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
443 
444         mock_quic_data_.AddWrite(
445             SYNCHRONOUS,
446             client_maker.Packet(packet_number++)
447                 .AddAckFrame(/*first_received=*/1, /*largest_received=*/1,
448                              /*smallest_received=*/0)
449                 .AddStopSendingFrame(client_data_stream_id,
450                                      quic::QUIC_STREAM_CANCELLED)
451                 .AddRstStreamFrame(client_data_stream_id,
452                                    quic::QUIC_STREAM_CANCELLED)
453                 .Build());
454         auto socket = std::make_unique<MockUDPClientSocket>(
455             mock_quic_data_.InitializeAndGetSequencedSocketData(),
456             NetLog::Get());
457         socket->Connect(peer_addr);
458 
459         scoped_refptr<test::TestTaskRunner> runner =
460             base::MakeRefCounted<test::TestTaskRunner>(&clock_);
461         auto helper = std::make_unique<QuicChromiumConnectionHelper>(
462             &clock_, &random_generator);
463         auto alarm_factory =
464             std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
465         // Ownership of 'writer' is passed to 'QuicConnection'.
466         QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
467             socket.get(),
468             base::SingleThreadTaskRunner::GetCurrentDefault().get());
469         quic::QuicConnection* connection = new quic::QuicConnection(
470             connection_id, quic::QuicSocketAddress(),
471             net::ToQuicSocketAddress(peer_addr), helper.get(),
472             alarm_factory.get(), writer, true /* owns_writer */,
473             quic::Perspective::IS_CLIENT,
474             quic::test::SupportedVersions(quic_version_),
475             connection_id_generator);
476         connection->set_visitor(&visitor);
477 
478         // Load a certificate that is valid for *.example.org
479         scoped_refptr<X509Certificate> test_cert(
480             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
481         EXPECT_TRUE(test_cert.get());
482 
483         verify_details.cert_verify_result.verified_cert = test_cert;
484         verify_details.cert_verify_result.is_issued_by_known_root = true;
485         crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);
486 
487         base::TimeTicks dns_end = base::TimeTicks::Now();
488         base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
489 
490         session_ = std::make_unique<QuicChromiumClientSession>(
491             connection, std::move(socket),
492             /*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
493             &transport_security_state, &ssl_config_service,
494             /*server_info=*/nullptr,
495             QuicSessionAliasKey(
496                 url::SchemeHostPort(),
497                 QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
498                                ProxyChain::Direct(), SessionUsage::kDestination,
499                                SocketTag(), NetworkAnonymizationKey(),
500                                SecureDnsPolicy::kAllow,
501                                /*require_dns_https_alpn=*/false)),
502             /*require_confirmation=*/false,
503             /*migrate_session_early_v2=*/false,
504             /*migrate_session_on_network_change_v2=*/false,
505             /*default_network=*/handles::kInvalidNetworkHandle,
506             quic::QuicTime::Delta::FromMilliseconds(
507                 kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
508             /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
509             kDefaultIdleSessionMigrationPeriod,
510             /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
511             kMaxMigrationsToNonDefaultNetworkOnWriteError,
512             kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
513             kQuicYieldAfterPacketsRead,
514             quic::QuicTime::Delta::FromMilliseconds(
515                 kQuicYieldAfterDurationMilliseconds),
516             /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
517             std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
518             "CONNECTION_UNKNOWN", dns_start, dns_end,
519             base::DefaultTickClock::GetInstance(),
520             base::SingleThreadTaskRunner::GetCurrentDefault().get(),
521             /*socket_performance_watcher=*/nullptr,
522             ConnectionEndpointMetadata(), /*report_ecn=*/true,
523             /*enable_origin_frame=*/true,
524             /*allow_server_preferred_address=*/true,
525             MultiplexedSessionCreationInitiator::kUnknown,
526             NetLogWithSource::Make(NetLogSourceType::NONE));
527 
528         session_->Initialize();
529 
530         // Blackhole QPACK decoder stream instead of constructing mock writes.
531         session_->qpack_decoder()->set_qpack_stream_sender_delegate(
532             &noop_qpack_stream_sender_delegate_);
533         TestCompletionCallback callback;
534         EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
535         EXPECT_TRUE(session_->OneRttKeysAvailable());
536         std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
537             session_->CreateHandle(
538                 url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
539 
540         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
541             create_helper.CreateHttp3Stream(std::move(session_handle),
542                                             {} /* dns_aliases */);
543 
544         handshake->RegisterRequest(&request_info);
545         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
546                                              CompletionOnceCallback());
547         EXPECT_THAT(rv, IsOk());
548 
549         HttpResponseInfo response;
550         TestCompletionCallback request_callback;
551         rv = handshake->SendRequest(headers, &response,
552                                     request_callback.callback());
553         EXPECT_THAT(rv, IsOk());
554 
555         session_->StartReading();
556 
557         TestCompletionCallback response_callback;
558         rv = handshake->ReadResponseHeaders(response_callback.callback());
559         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
560         rv = response_callback.WaitForResult();
561         EXPECT_THAT(rv, IsOk());
562 
563         EXPECT_EQ(200, response.headers->response_code());
564 
565         return handshake->Upgrade();
566       }
567       default:
568         NOTREACHED();
569     }
570   }
571 
572  private:
573   MockClientSocketHandleFactory socket_handle_factory_;
574   TestConnectDelegate connect_delegate_;
575   StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
576   WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
577 
578   // For HTTP3_HANDSHAKE_STREAM
579   quic::ParsedQuicVersion quic_version_;
580   quic::MockClock clock_;
581   std::unique_ptr<QuicChromiumClientSession> session_;
582   test::MockQuicData mock_quic_data_;
583   quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
584 };
585 
586 INSTANTIATE_TEST_SUITE_P(All,
587                          WebSocketHandshakeStreamCreateHelperTest,
588                          Values(BASIC_HANDSHAKE_STREAM,
589                                 HTTP2_HANDSHAKE_STREAM,
590                                 HTTP3_HANDSHAKE_STREAM));
591 
592 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)593 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
594   std::unique_ptr<WebSocketStream> stream =
595       CreateAndInitializeStream({}, {}, {});
596   EXPECT_EQ("", stream->GetExtensions());
597   EXPECT_EQ("", stream->GetSubProtocol());
598 }
599 
600 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)601 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
602   std::vector<std::string> sub_protocols;
603   sub_protocols.push_back("chat");
604   sub_protocols.push_back("superchat");
605   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
606       sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
607       {{"Sec-WebSocket-Protocol", "superchat"}});
608   EXPECT_EQ("superchat", stream->GetSubProtocol());
609 }
610 
611 // Verify that extension name is available. Bad extension names are tested in
612 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)613 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
614   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
615       {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
616   EXPECT_EQ("permessage-deflate", stream->GetExtensions());
617 }
618 
619 // Verify that extension parameters are available. Bad parameters are tested in
620 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)621 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
622   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
623       {}, {},
624       {{"Sec-WebSocket-Extensions",
625         "permessage-deflate;"
626         " client_max_window_bits=14; server_max_window_bits=14;"
627         " server_no_context_takeover; client_no_context_takeover"}});
628 
629   EXPECT_EQ(
630       "permessage-deflate;"
631       " client_max_window_bits=14; server_max_window_bits=14;"
632       " server_no_context_takeover; client_no_context_takeover",
633       stream->GetExtensions());
634 }
635 
636 }  // namespace
637 
638 }  // namespace net
639