• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6 
7 #include <string>
8 #include <utility>
9 #include <vector>
10 
11 #include "base/containers/span.h"
12 #include "base/functional/callback.h"
13 #include "base/memory/scoped_refptr.h"
14 #include "base/notreached.h"
15 #include "base/strings/string_piece.h"
16 #include "base/task/single_thread_task_runner.h"
17 #include "base/time/default_tick_clock.h"
18 #include "base/time/time.h"
19 #include "net/base/auth.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/host_port_pair.h"
22 #include "net/base/ip_address.h"
23 #include "net/base/ip_endpoint.h"
24 #include "net/base/load_flags.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/network_anonymization_key.h"
27 #include "net/base/network_handle.h"
28 #include "net/base/privacy_mode.h"
29 #include "net/base/proxy_server.h"
30 #include "net/base/request_priority.h"
31 #include "net/base/test_completion_callback.h"
32 #include "net/cert/cert_verify_result.h"
33 #include "net/dns/public/host_resolver_results.h"
34 #include "net/dns/public/secure_dns_policy.h"
35 #include "net/http/http_request_info.h"
36 #include "net/http/http_response_headers.h"
37 #include "net/http/http_response_info.h"
38 #include "net/http/transport_security_state.h"
39 #include "net/log/net_log.h"
40 #include "net/log/net_log_with_source.h"
41 #include "net/quic/address_utils.h"
42 #include "net/quic/crypto/proof_verifier_chromium.h"
43 #include "net/quic/mock_crypto_client_stream_factory.h"
44 #include "net/quic/mock_quic_data.h"
45 #include "net/quic/quic_chromium_alarm_factory.h"
46 #include "net/quic/quic_chromium_connection_helper.h"
47 #include "net/quic/quic_chromium_packet_reader.h"
48 #include "net/quic/quic_chromium_packet_writer.h"
49 #include "net/quic/quic_context.h"
50 #include "net/quic/quic_http_utils.h"
51 #include "net/quic/quic_server_info.h"
52 #include "net/quic/quic_session_key.h"
53 #include "net/quic/quic_test_packet_maker.h"
54 #include "net/quic/test_quic_crypto_client_config_handle.h"
55 #include "net/quic/test_task_runner.h"
56 #include "net/socket/client_socket_handle.h"
57 #include "net/socket/client_socket_pool.h"
58 #include "net/socket/connect_job.h"
59 #include "net/socket/socket_tag.h"
60 #include "net/socket/socket_test_util.h"
61 #include "net/socket/websocket_endpoint_lock_manager.h"
62 #include "net/spdy/spdy_session_key.h"
63 #include "net/spdy/spdy_test_util_common.h"
64 #include "net/ssl/ssl_config_service_defaults.h"
65 #include "net/ssl/ssl_info.h"
66 #include "net/test/cert_test_util.h"
67 #include "net/test/gtest_util.h"
68 #include "net/test/test_data_directory.h"
69 #include "net/test/test_with_task_environment.h"
70 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
71 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
72 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
73 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
74 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
75 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
76 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
77 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
78 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
79 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
80 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
81 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
82 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
83 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
84 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
85 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
86 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
87 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
88 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
89 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
90 #include "net/traffic_annotation/network_traffic_annotation.h"
91 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
92 #include "net/websockets/websocket_basic_handshake_stream.h"
93 #include "net/websockets/websocket_event_interface.h"
94 #include "net/websockets/websocket_stream.h"
95 #include "net/websockets/websocket_test_util.h"
96 #include "testing/gmock/include/gmock/gmock.h"
97 #include "testing/gtest/include/gtest/gtest.h"
98 #include "third_party/abseil-cpp/absl/types/optional.h"
99 #include "url/gurl.h"
100 #include "url/origin.h"
101 #include "url/scheme_host_port.h"
102 #include "url/url_constants.h"
103 
104 namespace net {
105 class HttpNetworkSession;
106 class URLRequest;
107 class WebSocketHttp2HandshakeStream;
108 class WebSocketHttp3HandshakeStream;
109 class X509Certificate;
110 struct WebSocketHandshakeRequestInfo;
111 struct WebSocketHandshakeResponseInfo;
112 }  // namespace net
113 
114 using ::net::test::IsError;
115 using ::net::test::IsOk;
116 using ::testing::StrictMock;
117 using ::testing::TestWithParam;
118 using ::testing::Values;
119 using ::testing::_;
120 
121 namespace net {
122 namespace {
123 
// Selects which kind of WebSocket handshake stream a parameterized test
// exercises.
enum HandshakeStreamType {
  // WebSocketBasicHandshakeStream.
  BASIC_HANDSHAKE_STREAM,
  // WebSocketHttp2HandshakeStream.
  HTTP2_HANDSHAKE_STREAM,
  // WebSocketHttp3HandshakeStream.
  HTTP3_HANDSHAKE_STREAM,
};
129 
130 // This class encapsulates the details of creating a mock ClientSocketHandle.
131 class MockClientSocketHandleFactory {
132  public:
MockClientSocketHandleFactory()133   MockClientSocketHandleFactory()
134       : common_connect_job_params_(
135             socket_factory_maker_.factory(),
136             /*host_resolver=*/nullptr,
137             /*http_auth_cache=*/nullptr,
138             /*http_auth_handler_factory=*/nullptr,
139             /*spdy_session_pool=*/nullptr,
140             /*quic_supported_versions=*/nullptr,
141             /*quic_stream_factory=*/nullptr,
142             /*proxy_delegate=*/nullptr,
143             /*http_user_agent_settings=*/nullptr,
144             /*ssl_client_context=*/nullptr,
145             /*socket_performance_watcher_factory=*/nullptr,
146             /*network_quality_estimator=*/nullptr,
147             /*net_log=*/nullptr,
148             /*websocket_endpoint_lock_manager=*/nullptr,
149             /*http_server_properties=*/nullptr,
150             /*alpn_protos=*/nullptr,
151             /*application_settings=*/nullptr,
152             /*ignore_certificate_errors=*/nullptr),
153         pool_(1, 1, &common_connect_job_params_) {}
154 
155   MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
156   MockClientSocketHandleFactory& operator=(
157       const MockClientSocketHandleFactory&) = delete;
158 
159   // The created socket expects |expect_written| to be written to the socket,
160   // and will respond with |return_to_read|. The test will fail if the expected
161   // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)162   std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
163       const std::string& expect_written,
164       const std::string& return_to_read) {
165     socket_factory_maker_.SetExpectations(expect_written, return_to_read);
166     auto socket_handle = std::make_unique<ClientSocketHandle>();
167     socket_handle->Init(
168         ClientSocketPool::GroupId(
169             url::SchemeHostPort(url::kHttpScheme, "a", 80),
170             PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
171             SecureDnsPolicy::kAllow),
172         scoped_refptr<ClientSocketPool::SocketParams>(),
173         absl::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
174         ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
175         ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
176     return socket_handle;
177   }
178 
179  private:
180   WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
181   const CommonConnectJobParams common_connect_job_params_;
182   MockTransportClientSocketPool pool_;
183 };
184 
185 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
186  public:
187   ~TestConnectDelegate() override = default;
188 
OnCreateRequest(URLRequest * request)189   void OnCreateRequest(URLRequest* request) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)190   void OnSuccess(
191       std::unique_ptr<WebSocketStream> stream,
192       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message,int net_error,absl::optional<int> response_code)193   void OnFailure(const std::string& failure_message,
194                  int net_error,
195                  absl::optional<int> response_code) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)196   void OnStartOpeningHandshake(
197       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)198   void OnSSLCertificateError(
199       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
200           ssl_error_callbacks,
201       int net_error,
202       const SSLInfo& ssl_info,
203       bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)204   int OnAuthRequired(const AuthChallengeInfo& auth_info,
205                      scoped_refptr<HttpResponseHeaders> response_headers,
206                      const IPEndPoint& host_port_pair,
207                      base::OnceCallback<void(const AuthCredentials*)> callback,
208                      absl::optional<AuthCredentials>* credentials) override {
209     *credentials = absl::nullopt;
210     return OK;
211   }
212 };
213 
214 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
215  public:
216   ~MockWebSocketStreamRequestAPI() override = default;
217 
218   MOCK_METHOD1(OnBasicHandshakeStreamCreated,
219                void(WebSocketBasicHandshakeStream* handshake_stream));
220   MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
221                void(WebSocketHttp2HandshakeStream* handshake_stream));
222   MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
223                void(WebSocketHttp3HandshakeStream* handshake_stream));
224   MOCK_METHOD3(OnFailure,
225                void(const std::string& message,
226                     int net_error,
227                     absl::optional<int> response_code));
228 };
229 
230 class WebSocketHandshakeStreamCreateHelperTest
231     : public TestWithParam<HandshakeStreamType>,
232       public WithTaskEnvironment {
233  protected:
WebSocketHandshakeStreamCreateHelperTest()234   WebSocketHandshakeStreamCreateHelperTest()
235       : quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
236                       quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
237         mock_quic_data_(quic_version_) {}
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)238   std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
239       const std::vector<std::string>& sub_protocols,
240       const WebSocketExtraHeaders& extra_request_headers,
241       const WebSocketExtraHeaders& extra_response_headers) {
242     const char kPath[] = "/";
243     const char kOrigin[] = "http://origin.example.org";
244     const GURL url("wss://www.example.org/");
245     NetLogWithSource net_log;
246 
247     WebSocketHandshakeStreamCreateHelper create_helper(
248         &connect_delegate_, sub_protocols, &stream_request_);
249 
250     switch (GetParam()) {
251       case BASIC_HANDSHAKE_STREAM:
252         EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
253         break;
254 
255       case HTTP2_HANDSHAKE_STREAM:
256         EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
257         break;
258 
259       case HTTP3_HANDSHAKE_STREAM:
260         EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
261         break;
262 
263       default:
264         NOTREACHED();
265     }
266 
267     EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
268 
269     HttpRequestInfo request_info;
270     request_info.url = url;
271     request_info.method = "GET";
272     request_info.load_flags = LOAD_DISABLE_CACHE;
273     request_info.traffic_annotation =
274         MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
275 
276     auto headers = WebSocketCommonTestHeaders();
277 
278     switch (GetParam()) {
279       case BASIC_HANDSHAKE_STREAM: {
280         std::unique_ptr<ClientSocketHandle> socket_handle =
281             socket_handle_factory_.CreateClientSocketHandle(
282                 WebSocketStandardRequest(kPath, "www.example.org",
283                                          url::Origin::Create(GURL(kOrigin)),
284                                          /*send_additional_request_headers=*/{},
285                                          extra_request_headers),
286                 WebSocketStandardResponse(
287                     WebSocketExtraHeadersToString(extra_response_headers)));
288 
289         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
290             create_helper.CreateBasicStream(std::move(socket_handle), false,
291                                             &websocket_endpoint_lock_manager_);
292 
293         // If in future the implementation type returned by CreateBasicStream()
294         // changes, this static_cast will be wrong. However, in that case the
295         // test will fail and AddressSanitizer should identify the issue.
296         static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
297             ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
298 
299         handshake->RegisterRequest(&request_info);
300         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
301                                              CompletionOnceCallback());
302         EXPECT_THAT(rv, IsOk());
303 
304         HttpResponseInfo response;
305         TestCompletionCallback request_callback;
306         rv = handshake->SendRequest(headers, &response,
307                                     request_callback.callback());
308         EXPECT_THAT(rv, IsOk());
309 
310         TestCompletionCallback response_callback;
311         rv = handshake->ReadResponseHeaders(response_callback.callback());
312         EXPECT_THAT(rv, IsOk());
313         EXPECT_EQ(101, response.headers->response_code());
314         EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
315         EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
316         return handshake->Upgrade();
317       }
318       case HTTP2_HANDSHAKE_STREAM: {
319         SpdyTestUtil spdy_util;
320         spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
321             kPath, "www.example.org", kOrigin, extra_request_headers);
322         spdy::SpdySerializedFrame request_headers(
323             spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
324                                            DEFAULT_PRIORITY, false));
325         MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
326 
327         spdy::Http2HeaderBlock response_header_block =
328             WebSocketHttp2Response(extra_response_headers);
329         spdy::SpdySerializedFrame response_headers(
330             spdy_util.ConstructSpdyResponseHeaders(
331                 1, std::move(response_header_block), false));
332         MockRead reads[] = {CreateMockRead(response_headers, 1),
333                             MockRead(ASYNC, 0, 2)};
334 
335         SequencedSocketData data(reads, writes);
336 
337         SSLSocketDataProvider ssl(ASYNC, OK);
338         ssl.ssl_info.cert =
339             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
340 
341         SpdySessionDependencies session_deps;
342         session_deps.socket_factory->AddSocketDataProvider(&data);
343         session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
344 
345         std::unique_ptr<HttpNetworkSession> http_network_session =
346             SpdySessionDependencies::SpdyCreateSession(&session_deps);
347         const SpdySessionKey key(
348             HostPortPair::FromURL(url), ProxyChain::Direct(),
349             PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse,
350             SocketTag(), NetworkAnonymizationKey(), SecureDnsPolicy::kAllow);
351         base::WeakPtr<SpdySession> spdy_session =
352             CreateSpdySession(http_network_session.get(), key, net_log);
353         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
354             create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
355 
356         handshake->RegisterRequest(&request_info);
357         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
358                                              NetLogWithSource(),
359                                              CompletionOnceCallback());
360         EXPECT_THAT(rv, IsOk());
361 
362         HttpResponseInfo response;
363         TestCompletionCallback request_callback;
364         rv = handshake->SendRequest(headers, &response,
365                                     request_callback.callback());
366         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
367         rv = request_callback.WaitForResult();
368         EXPECT_THAT(rv, IsOk());
369 
370         TestCompletionCallback response_callback;
371         rv = handshake->ReadResponseHeaders(response_callback.callback());
372         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
373         rv = response_callback.WaitForResult();
374         EXPECT_THAT(rv, IsOk());
375 
376         EXPECT_EQ(200, response.headers->response_code());
377         return handshake->Upgrade();
378       }
379       case HTTP3_HANDSHAKE_STREAM: {
380         const quic::QuicStreamId client_data_stream_id(
381             quic::QuicUtils::GetFirstBidirectionalStreamId(
382                 quic_version_.transport_version, quic::Perspective::IS_CLIENT));
383         quic::QuicCryptoClientConfig crypto_config(
384             quic::test::crypto_test_utils::ProofVerifierForTesting());
385 
386         const quic::QuicConnectionId connection_id(
387             quic::test::TestConnectionId(2));
388         test::QuicTestPacketMaker client_maker(
389             quic_version_, connection_id, &clock_, "mail.example.org",
390             quic::Perspective::IS_CLIENT,
391             /*client_headers_include_h2_stream_dependency_=*/false);
392         test::QuicTestPacketMaker server_maker(
393             quic_version_, connection_id, &clock_, "mail.example.org",
394             quic::Perspective::IS_SERVER,
395             /*client_headers_include_h2_stream_dependency_=*/false);
396         IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
397         quic::test::MockConnectionIdGenerator connection_id_generator;
398 
399         testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
400         ProofVerifyDetailsChromium verify_details;
401         MockCryptoClientStreamFactory crypto_client_stream_factory;
402         TransportSecurityState transport_security_state;
403         SSLConfigServiceDefaults ssl_config_service;
404 
405         FLAGS_quic_enable_http3_grease_randomness = false;
406         clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
407         quic::QuicEnableVersion(quic_version_);
408         quic::test::MockRandom random_generator{0};
409 
410         spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
411             kPath, "www.example.org", kOrigin, extra_request_headers);
412 
413         int packet_number = 1;
414         mock_quic_data_.AddWrite(
415             SYNCHRONOUS,
416             client_maker.MakeInitialSettingsPacket(packet_number++));
417 
418         mock_quic_data_.AddWrite(
419             ASYNC,
420             client_maker.MakeRequestHeadersPacket(
421                 packet_number++, client_data_stream_id,
422                 /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
423                 std::move(request_header_block), nullptr));
424 
425         spdy::Http2HeaderBlock response_header_block =
426             WebSocketHttp2Response(extra_response_headers);
427 
428         mock_quic_data_.AddRead(
429             ASYNC, server_maker.MakeResponseHeadersPacket(
430                        /*packet_number=*/1, client_data_stream_id,
431                        /*fin=*/false, std::move(response_header_block),
432                        /*spdy_headers_frame_length=*/nullptr));
433 
434         mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
435 
436         mock_quic_data_.AddWrite(SYNCHRONOUS,
437                                  client_maker.MakeAckAndRstPacket(
438                                      packet_number++, client_data_stream_id,
439                                      quic::QUIC_STREAM_CANCELLED, 1, 0,
440                                      /*include_stop_sending_if_v99=*/true));
441         auto socket = std::make_unique<MockUDPClientSocket>(
442             mock_quic_data_.InitializeAndGetSequencedSocketData(),
443             NetLog::Get());
444         socket->Connect(peer_addr);
445 
446         scoped_refptr<test::TestTaskRunner> runner =
447             base::MakeRefCounted<test::TestTaskRunner>(&clock_);
448         auto helper = std::make_unique<QuicChromiumConnectionHelper>(
449             &clock_, &random_generator);
450         auto alarm_factory =
451             std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
452         // Ownership of 'writer' is passed to 'QuicConnection'.
453         QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
454             socket.get(),
455             base::SingleThreadTaskRunner::GetCurrentDefault().get());
456         quic::QuicConnection* connection = new quic::QuicConnection(
457             connection_id, quic::QuicSocketAddress(),
458             net::ToQuicSocketAddress(peer_addr), helper.get(),
459             alarm_factory.get(), writer, true /* owns_writer */,
460             quic::Perspective::IS_CLIENT,
461             quic::test::SupportedVersions(quic_version_),
462             connection_id_generator);
463         connection->set_visitor(&visitor);
464 
465         // Load a certificate that is valid for *.example.org
466         scoped_refptr<X509Certificate> test_cert(
467             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
468         EXPECT_TRUE(test_cert.get());
469 
470         verify_details.cert_verify_result.verified_cert = test_cert;
471         verify_details.cert_verify_result.is_issued_by_known_root = true;
472         crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);
473 
474         base::TimeTicks dns_end = base::TimeTicks::Now();
475         base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
476 
477         session_ = std::make_unique<QuicChromiumClientSession>(
478             connection, std::move(socket),
479             /*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
480             &transport_security_state, &ssl_config_service,
481             /*server_info=*/nullptr,
482             QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
483                            SocketTag(), NetworkAnonymizationKey(),
484                            SecureDnsPolicy::kAllow,
485                            /*require_dns_https_alpn=*/false),
486             /*require_confirmation=*/false,
487             /*migrate_session_early_v2=*/false,
488             /*migrate_session_on_network_change_v2=*/false,
489             /*default_network=*/handles::kInvalidNetworkHandle,
490             quic::QuicTime::Delta::FromMilliseconds(
491                 kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
492             /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
493             kDefaultIdleSessionMigrationPeriod,
494             /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
495             kMaxMigrationsToNonDefaultNetworkOnWriteError,
496             kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
497             kQuicYieldAfterPacketsRead,
498             quic::QuicTime::Delta::FromMilliseconds(
499                 kQuicYieldAfterDurationMilliseconds),
500             /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
501             std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
502             dns_start, dns_end, base::DefaultTickClock::GetInstance(),
503             base::SingleThreadTaskRunner::GetCurrentDefault().get(),
504             /*socket_performance_watcher=*/nullptr,
505             HostResolverEndpointResult(), NetLog::Get());
506 
507         session_->Initialize();
508 
509         // Blackhole QPACK decoder stream instead of constructing mock writes.
510         session_->qpack_decoder()->set_qpack_stream_sender_delegate(
511             &noop_qpack_stream_sender_delegate_);
512         TestCompletionCallback callback;
513         EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
514         EXPECT_TRUE(session_->OneRttKeysAvailable());
515         std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
516             session_->CreateHandle(
517                 url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
518 
519         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
520             create_helper.CreateHttp3Stream(std::move(session_handle),
521                                             {} /* dns_aliases */);
522 
523         handshake->RegisterRequest(&request_info);
524         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
525                                              CompletionOnceCallback());
526         EXPECT_THAT(rv, IsOk());
527 
528         HttpResponseInfo response;
529         TestCompletionCallback request_callback;
530         rv = handshake->SendRequest(headers, &response,
531                                     request_callback.callback());
532         EXPECT_THAT(rv, IsOk());
533 
534         session_->StartReading();
535 
536         TestCompletionCallback response_callback;
537         rv = handshake->ReadResponseHeaders(response_callback.callback());
538         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
539         rv = response_callback.WaitForResult();
540         EXPECT_THAT(rv, IsOk());
541 
542         EXPECT_EQ(200, response.headers->response_code());
543 
544         return handshake->Upgrade();
545       }
546       default:
547         NOTREACHED();
548         return nullptr;
549     }
550   }
551 
552  private:
553   MockClientSocketHandleFactory socket_handle_factory_;
554   TestConnectDelegate connect_delegate_;
555   StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
556   WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
557 
558   // For HTTP3_HANDSHAKE_STREAM
559   quic::ParsedQuicVersion quic_version_;
560   quic::MockClock clock_;
561   std::unique_ptr<QuicChromiumClientSession> session_;
562   test::MockQuicData mock_quic_data_;
563   quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
564 };
565 
566 INSTANTIATE_TEST_SUITE_P(All,
567                          WebSocketHandshakeStreamCreateHelperTest,
568                          Values(BASIC_HANDSHAKE_STREAM,
569                                 HTTP2_HANDSHAKE_STREAM,
570                                 HTTP3_HANDSHAKE_STREAM));
571 
572 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)573 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
574   std::unique_ptr<WebSocketStream> stream =
575       CreateAndInitializeStream({}, {}, {});
576   EXPECT_EQ("", stream->GetExtensions());
577   EXPECT_EQ("", stream->GetSubProtocol());
578 }
579 
580 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)581 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
582   std::vector<std::string> sub_protocols;
583   sub_protocols.push_back("chat");
584   sub_protocols.push_back("superchat");
585   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
586       sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
587       {{"Sec-WebSocket-Protocol", "superchat"}});
588   EXPECT_EQ("superchat", stream->GetSubProtocol());
589 }
590 
591 // Verify that extension name is available. Bad extension names are tested in
592 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)593 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
594   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
595       {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
596   EXPECT_EQ("permessage-deflate", stream->GetExtensions());
597 }
598 
599 // Verify that extension parameters are available. Bad parameters are tested in
600 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)601 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
602   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
603       {}, {},
604       {{"Sec-WebSocket-Extensions",
605         "permessage-deflate;"
606         " client_max_window_bits=14; server_max_window_bits=14;"
607         " server_no_context_takeover; client_no_context_takeover"}});
608 
609   EXPECT_EQ(
610       "permessage-deflate;"
611       " client_max_window_bits=14; server_max_window_bits=14;"
612       " server_no_context_takeover; client_no_context_takeover",
613       stream->GetExtensions());
614 }
615 
616 }  // namespace
617 
618 }  // namespace net
619