• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_basic_stream_adapters.h"
6 
7 #include <stdint.h>
8 
9 #include <string>
10 #include <string_view>
11 #include <utility>
12 #include <vector>
13 
14 #include "base/check.h"
15 #include "base/containers/span.h"
16 #include "base/functional/bind.h"
17 #include "base/functional/callback.h"
18 #include "base/memory/raw_ptr.h"
19 #include "base/memory/scoped_refptr.h"
20 #include "base/memory/weak_ptr.h"
21 #include "base/run_loop.h"
22 #include "base/strings/strcat.h"
23 #include "base/task/single_thread_task_runner.h"
24 #include "base/time/default_tick_clock.h"
25 #include "base/time/time.h"
26 #include "net/base/host_port_pair.h"
27 #include "net/base/io_buffer.h"
28 #include "net/base/ip_address.h"
29 #include "net/base/ip_endpoint.h"
30 #include "net/base/net_errors.h"
31 #include "net/base/network_anonymization_key.h"
32 #include "net/base/network_handle.h"
33 #include "net/base/privacy_mode.h"
34 #include "net/base/proxy_chain.h"
35 #include "net/base/request_priority.h"
36 #include "net/base/session_usage.h"
37 #include "net/base/test_completion_callback.h"
38 #include "net/cert/cert_verify_result.h"
39 #include "net/dns/public/host_resolver_results.h"
40 #include "net/dns/public/secure_dns_policy.h"
41 #include "net/http/http_network_session.h"
42 #include "net/http/transport_security_state.h"
43 #include "net/log/net_log.h"
44 #include "net/log/net_log_with_source.h"
45 #include "net/quic/address_utils.h"
46 #include "net/quic/crypto/proof_verifier_chromium.h"
47 #include "net/quic/mock_crypto_client_stream_factory.h"
48 #include "net/quic/mock_quic_data.h"
49 #include "net/quic/quic_chromium_alarm_factory.h"
50 #include "net/quic/quic_chromium_client_session.h"
51 #include "net/quic/quic_chromium_client_session_peer.h"
52 #include "net/quic/quic_chromium_connection_helper.h"
53 #include "net/quic/quic_chromium_packet_reader.h"
54 #include "net/quic/quic_chromium_packet_writer.h"
55 #include "net/quic/quic_context.h"
56 #include "net/quic/quic_http_utils.h"
57 #include "net/quic/quic_server_info.h"
58 #include "net/quic/quic_session_alias_key.h"
59 #include "net/quic/quic_session_key.h"
60 #include "net/quic/quic_test_packet_maker.h"
61 #include "net/quic/test_quic_crypto_client_config_handle.h"
62 #include "net/quic/test_task_runner.h"
63 #include "net/socket/client_socket_handle.h"
64 #include "net/socket/client_socket_pool.h"
65 #include "net/socket/next_proto.h"
66 #include "net/socket/socket_tag.h"
67 #include "net/socket/socket_test_util.h"
68 #include "net/socket/stream_socket.h"
69 #include "net/spdy/spdy_session_key.h"
70 #include "net/spdy/spdy_test_util_common.h"
71 #include "net/ssl/ssl_config.h"
72 #include "net/ssl/ssl_config_service_defaults.h"
73 #include "net/ssl/ssl_info.h"
74 #include "net/test/cert_test_util.h"
75 #include "net/test/gtest_util.h"
76 #include "net/test/test_data_directory.h"
77 #include "net/test/test_with_task_environment.h"
78 #include "net/third_party/quiche/src/quiche/common/http/http_header_block.h"
79 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
80 #include "net/third_party/quiche/src/quiche/common/quiche_buffer_allocator.h"
81 #include "net/third_party/quiche/src/quiche/common/simple_buffer_allocator.h"
82 #include "net/third_party/quiche/src/quiche/http2/core/spdy_protocol.h"
83 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
84 #include "net/third_party/quiche/src/quiche/quic/core/http/http_encoder.h"
85 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
86 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
87 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
88 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
89 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
90 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
91 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
92 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
93 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
94 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
95 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
96 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
97 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
98 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
99 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
100 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
101 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
102 #include "net/websockets/websocket_test_util.h"
103 #include "testing/gmock/include/gmock/gmock.h"
104 #include "testing/gtest/include/gtest/gtest.h"
105 #include "url/gurl.h"
106 #include "url/scheme_host_port.h"
107 #include "url/url_constants.h"
108 
109 namespace net {
110 class QuicChromiumClientStream;
111 class SpdySession;
112 class WebSocketEndpointLockManager;
113 class X509Certificate;
114 }  // namespace net
115 
116 using testing::_;
117 using testing::AnyNumber;
118 using testing::Invoke;
119 using testing::Return;
120 using testing::StrictMock;
121 using testing::Test;
122 
123 namespace net::test {
124 
125 class WebSocketClientSocketHandleAdapterTest : public TestWithTaskEnvironment {
126  protected:
WebSocketClientSocketHandleAdapterTest()127   WebSocketClientSocketHandleAdapterTest()
128       : network_session_(
129             SpdySessionDependencies::SpdyCreateSession(&session_deps_)),
130         websocket_endpoint_lock_manager_(
131             network_session_->websocket_endpoint_lock_manager()) {}
132 
133   ~WebSocketClientSocketHandleAdapterTest() override = default;
134 
InitClientSocketHandle(ClientSocketHandle * connection)135   bool InitClientSocketHandle(ClientSocketHandle* connection) {
136     scoped_refptr<ClientSocketPool::SocketParams> socks_params =
137         base::MakeRefCounted<ClientSocketPool::SocketParams>(
138             /*allowed_bad_certs=*/std::vector<SSLConfig::CertAndStatus>());
139     TestCompletionCallback callback;
140     int rv = connection->Init(
141         ClientSocketPool::GroupId(
142             url::SchemeHostPort(url::kHttpsScheme, "www.example.org", 443),
143             PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
144             SecureDnsPolicy::kAllow, /*disable_cert_network_fetches=*/false),
145         socks_params, /*proxy_annotation_tag=*/TRAFFIC_ANNOTATION_FOR_TESTS,
146         MEDIUM, SocketTag(), ClientSocketPool::RespectLimits::ENABLED,
147         callback.callback(), ClientSocketPool::ProxyAuthCallback(),
148         network_session_->GetSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL,
149                                         ProxyChain::Direct()),
150         NetLogWithSource());
151     rv = callback.GetResult(rv);
152     return rv == OK;
153   }
154 
155   SpdySessionDependencies session_deps_;
156   std::unique_ptr<HttpNetworkSession> network_session_;
157   raw_ptr<WebSocketEndpointLockManager> websocket_endpoint_lock_manager_;
158 };
159 
// An adapter wrapping a handle that was never connected reports itself as
// uninitialized.
TEST_F(WebSocketClientSocketHandleAdapterTest, Uninitialized) {
  WebSocketClientSocketHandleAdapter adapter(
      std::make_unique<ClientSocketHandle>());
  EXPECT_FALSE(adapter.is_initialized());
}
165 
// is_initialized() tracks the wrapped handle's current state: it becomes true
// once the handle, already owned by the adapter, finishes connecting.
TEST_F(WebSocketClientSocketHandleAdapterTest, IsInitialized) {
  StaticSocketDataProvider socket_data;
  session_deps_.socket_factory->AddSocketDataProvider(&socket_data);
  SSLSocketDataProvider ssl_data(ASYNC, OK);
  session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_data);

  auto owned_handle = std::make_unique<ClientSocketHandle>();
  // Non-owning alias so the handle can still be initialized after ownership
  // moves into |adapter|.
  ClientSocketHandle* const handle = owned_handle.get();

  WebSocketClientSocketHandleAdapter adapter(std::move(owned_handle));
  EXPECT_FALSE(adapter.is_initialized());

  EXPECT_TRUE(InitClientSocketHandle(handle));

  EXPECT_TRUE(adapter.is_initialized());
}
182 
// Disconnect() on the adapter closes the underlying socket.
TEST_F(WebSocketClientSocketHandleAdapterTest, Disconnect) {
  StaticSocketDataProvider data;
  session_deps_.socket_factory->AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // ASSERT rather than EXPECT: if initialization fails, connection->socket()
  // is null and the dereference below would crash the test binary.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));

  StreamSocket* const socket = connection->socket();

  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  EXPECT_TRUE(socket->IsConnected());
  adapter.Disconnect();
  EXPECT_FALSE(socket->IsConnected());
}
201 
// Reads are forwarded to the socket. The SYNCHRONOUS MockRead completes
// inline; the second one is expected to go pending and complete through the
// callback.
TEST_F(WebSocketClientSocketHandleAdapterTest, Read) {
  MockRead reads[] = {MockRead(SYNCHRONOUS, "foo"), MockRead("bar")};
  StaticSocketDataProvider data(reads, base::span<MockWrite>());
  session_deps_.socket_factory->AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Reading through an uninitialized handle would dereference a null socket.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));

  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  // Buffer larger than each MockRead.
  constexpr int kReadBufSize = 1024;
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
  int rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
  ASSERT_EQ(3, rv);
  EXPECT_EQ("foo", std::string_view(read_buf->data(), rv));

  TestCompletionCallback callback;
  rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback());
  // Must be pending, otherwise WaitForResult() below would never return.
  ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
  rv = callback.WaitForResult();
  ASSERT_EQ(3, rv);
  EXPECT_EQ("bar", std::string_view(read_buf->data(), rv));

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
232 
// When the caller's buffer is smaller than a MockRead, the data is handed out
// across successive Read() calls without loss.
TEST_F(WebSocketClientSocketHandleAdapterTest, ReadIntoSmallBuffer) {
  MockRead reads[] = {MockRead(SYNCHRONOUS, "foo"), MockRead("bar")};
  StaticSocketDataProvider data(reads, base::span<MockWrite>());
  session_deps_.socket_factory->AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Reading through an uninitialized handle would dereference a null socket.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));

  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  // Buffer smaller than each MockRead.
  constexpr int kReadBufSize = 2;
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
  int rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
  ASSERT_EQ(2, rv);
  EXPECT_EQ("fo", std::string_view(read_buf->data(), rv));

  rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
  ASSERT_EQ(1, rv);
  EXPECT_EQ("o", std::string_view(read_buf->data(), rv));

  TestCompletionCallback callback1;
  rv = adapter.Read(read_buf.get(), kReadBufSize, callback1.callback());
  // Must be pending, otherwise WaitForResult() below would never return.
  ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
  rv = callback1.WaitForResult();
  ASSERT_EQ(2, rv);
  EXPECT_EQ("ba", std::string_view(read_buf->data(), rv));

  rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
  ASSERT_EQ(1, rv);
  EXPECT_EQ("r", std::string_view(read_buf->data(), rv));

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
271 
// Writes are forwarded to the socket. The SYNCHRONOUS MockWrite completes
// inline; the second one is expected to go pending and complete through the
// callback.
TEST_F(WebSocketClientSocketHandleAdapterTest, Write) {
  MockWrite writes[] = {MockWrite(SYNCHRONOUS, "foo"), MockWrite("bar")};
  StaticSocketDataProvider data(base::span<MockRead>(), writes);
  session_deps_.socket_factory->AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Writing through an uninitialized handle would dereference a null socket.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));

  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  auto write_buf1 = base::MakeRefCounted<StringIOBuffer>("foo");
  int rv =
      adapter.Write(write_buf1.get(), write_buf1->size(),
                    CompletionOnceCallback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  ASSERT_EQ(3, rv);

  auto write_buf2 = base::MakeRefCounted<StringIOBuffer>("bar");
  TestCompletionCallback callback;
  rv = adapter.Write(write_buf2.get(), write_buf2->size(), callback.callback(),
                     TRAFFIC_ANNOTATION_FOR_TESTS);
  // Must be pending, otherwise WaitForResult() below would never return.
  ASSERT_THAT(rv, IsError(ERR_IO_PENDING));
  rv = callback.WaitForResult();
  ASSERT_EQ(3, rv);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
302 
303 // Test that if both Read() and Write() returns asynchronously,
304 // the two callbacks are handled correctly.
TEST_F(WebSocketClientSocketHandleAdapterTest,AsyncReadAndWrite)305 TEST_F(WebSocketClientSocketHandleAdapterTest, AsyncReadAndWrite) {
306   MockRead reads[] = {MockRead("foobar")};
307   MockWrite writes[] = {MockWrite("baz")};
308   StaticSocketDataProvider data(reads, writes);
309   session_deps_.socket_factory->AddSocketDataProvider(&data);
310   SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
311   session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_socket_data);
312 
313   auto connection = std::make_unique<ClientSocketHandle>();
314   EXPECT_TRUE(InitClientSocketHandle(connection.get()));
315 
316   WebSocketClientSocketHandleAdapter adapter(std::move(connection));
317   EXPECT_TRUE(adapter.is_initialized());
318 
319   constexpr int kReadBufSize = 1024;
320   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
321   TestCompletionCallback read_callback;
322   int rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback());
323   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
324 
325   auto write_buf = base::MakeRefCounted<StringIOBuffer>("baz");
326   TestCompletionCallback write_callback;
327   rv = adapter.Write(write_buf.get(), write_buf->size(),
328                      write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
329   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
330 
331   rv = read_callback.WaitForResult();
332   ASSERT_EQ(6, rv);
333   EXPECT_EQ("foobar", std::string_view(read_buf->data(), rv));
334 
335   rv = write_callback.WaitForResult();
336   ASSERT_EQ(3, rv);
337 
338   EXPECT_TRUE(data.AllReadDataConsumed());
339   EXPECT_TRUE(data.AllWriteDataConsumed());
340 }
341 
342 class MockDelegate : public WebSocketSpdyStreamAdapter::Delegate {
343  public:
344   ~MockDelegate() override = default;
345   MOCK_METHOD(void, OnHeadersSent, (), (override));
346   MOCK_METHOD(void,
347               OnHeadersReceived,
348               (const quiche::HttpHeaderBlock&),
349               (override));
350   MOCK_METHOD(void, OnClose, (int), (override));
351 };
352 
353 class WebSocketSpdyStreamAdapterTest : public TestWithTaskEnvironment {
354  protected:
WebSocketSpdyStreamAdapterTest()355   WebSocketSpdyStreamAdapterTest()
356       : url_("wss://www.example.org/"),
357         key_(HostPortPair::FromURL(url_),
358              PRIVACY_MODE_DISABLED,
359              ProxyChain::Direct(),
360              SessionUsage::kDestination,
361              SocketTag(),
362              NetworkAnonymizationKey(),
363              SecureDnsPolicy::kAllow,
364              /*disable_cert_verification_network_fetches=*/false),
365         session_(SpdySessionDependencies::SpdyCreateSession(&session_deps_)),
366         ssl_(SYNCHRONOUS, OK) {}
367 
368   ~WebSocketSpdyStreamAdapterTest() override = default;
369 
RequestHeaders()370   static quiche::HttpHeaderBlock RequestHeaders() {
371     return WebSocketHttp2Request("/", "www.example.org:443",
372                                  "http://www.example.org", {});
373   }
374 
ResponseHeaders()375   static quiche::HttpHeaderBlock ResponseHeaders() {
376     return WebSocketHttp2Response({});
377   }
378 
AddSocketData(SocketDataProvider * data)379   void AddSocketData(SocketDataProvider* data) {
380     session_deps_.socket_factory->AddSocketDataProvider(data);
381   }
382 
AddSSLSocketData()383   void AddSSLSocketData() {
384     ssl_.ssl_info.cert =
385         ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
386     ASSERT_TRUE(ssl_.ssl_info.cert);
387     session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl_);
388   }
389 
CreateSpdySession()390   base::WeakPtr<SpdySession> CreateSpdySession() {
391     return ::net::CreateSpdySession(session_.get(), key_, NetLogWithSource());
392   }
393 
CreateSpdyStream(base::WeakPtr<SpdySession> session)394   base::WeakPtr<SpdyStream> CreateSpdyStream(
395       base::WeakPtr<SpdySession> session) {
396     return CreateStreamSynchronously(SPDY_BIDIRECTIONAL_STREAM, session, url_,
397                                      LOWEST, NetLogWithSource());
398   }
399 
400   SpdyTestUtil spdy_util_;
401   StrictMock<MockDelegate> mock_delegate_;
402 
403  private:
404   const GURL url_;
405   const SpdySessionKey key_;
406   SpdySessionDependencies session_deps_;
407   std::unique_ptr<HttpNetworkSession> session_;
408   SSLSocketDataProvider ssl_;
409 };
410 
// Disconnect() drops an idle stream immediately, while the session survives
// until it reads EOF from the socket.
TEST_F(WebSocketSpdyStreamAdapterTest, Disconnect) {
  MockRead mock_reads[] = {MockRead(ASYNC, ERR_IO_PENDING, 0),
                           MockRead(ASYNC, 0, 1)};
  SequencedSocketData socket_data(mock_reads, base::span<MockWrite>());
  AddSocketData(&socket_data);
  AddSSLSocketData();

  base::WeakPtr<SpdySession> spdy_session = CreateSpdySession();
  base::WeakPtr<SpdyStream> spdy_stream = CreateSpdyStream(spdy_session);
  WebSocketSpdyStreamAdapter adapter(spdy_stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  // Drain pending tasks; the socket data stays paused at the first read.
  base::RunLoop().RunUntilIdle();

  EXPECT_TRUE(spdy_stream);
  adapter.Disconnect();
  EXPECT_FALSE(spdy_stream);

  // Unpause so the session reads EOF and shuts down.
  EXPECT_TRUE(spdy_session);
  socket_data.Resume();
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(spdy_session);

  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
439 
// Disconnect() while the HEADERS write is still blocked: the delegate never
// sees OnHeadersSent() (|mock_delegate_| is strict and has no expectations),
// and the HEADERS and RST_STREAM frames are flushed once |data| resumes.
TEST_F(WebSocketSpdyStreamAdapterTest, SendRequestHeadersThenDisconnect) {
  MockRead reads[] = {MockRead(ASYNC, ERR_IO_PENDING, 0),
                      MockRead(ASYNC, 0, 3)};
  spdy::SpdySerializedFrame headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  spdy::SpdySerializedFrame rst(
      spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL));
  MockWrite writes[] = {CreateMockWrite(headers, 1), CreateMockWrite(rst, 2)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  // Fail cleanly instead of crashing on the stream-> dereference below.
  ASSERT_TRUE(stream);
  WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  // First read is a pause and it has lower sequence number than first write.
  // Therefore writing headers does not complete while |data| is paused.
  base::RunLoop().RunUntilIdle();

  // Reset the stream before writing completes.
  // OnHeadersSent() will never be called.
  EXPECT_TRUE(stream);
  adapter.Disconnect();
  EXPECT_FALSE(stream);

  // Resume |data|, finish writing headers, and read EOF.
  EXPECT_TRUE(session);
  data.Resume();
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
480 
// Disconnect() after the request headers were written: OnHeadersSent() fires
// once, then the stream is reset (RST_STREAM) and the session reads EOF.
TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersSentThenDisconnect) {
  MockRead reads[] = {MockRead(ASYNC, 0, 2)};
  spdy::SpdySerializedFrame headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  spdy::SpdySerializedFrame rst(
      spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL));
  MockWrite writes[] = {CreateMockWrite(headers, 0), CreateMockWrite(rst, 1)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  // Fail cleanly instead of crashing on the stream-> dereference below.
  ASSERT_TRUE(stream);
  WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  // Finish asynchronous write of headers.  This calls OnHeadersSent().
  base::RunLoop().RunUntilIdle();

  EXPECT_TRUE(stream);
  adapter.Disconnect();
  EXPECT_FALSE(stream);

  // Read EOF.
  EXPECT_TRUE(session);
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
518 
// Disconnect() after response headers arrived: both header callbacks fire,
// then the stream is reset (RST_STREAM) and the session reads EOF.
TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersReceivedThenDisconnect) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  MockRead reads[] = {CreateMockRead(response_headers, 1),
                      MockRead(ASYNC, 0, 3)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  spdy::SpdySerializedFrame rst(
      spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL));
  MockWrite writes[] = {CreateMockWrite(request_headers, 0),
                        CreateMockWrite(rst, 2)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());
  EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  // Fail cleanly instead of crashing on the stream-> dereference below.
  ASSERT_TRUE(stream);
  WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  base::RunLoop().RunUntilIdle();

  EXPECT_TRUE(stream);
  adapter.Disconnect();
  EXPECT_FALSE(stream);

  // Read EOF.
  EXPECT_TRUE(session);
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
560 
// An immediate EOF from the server tears down both the session and the stream
// and reports OnClose(ERR_CONNECTION_CLOSED) to the delegate.
TEST_F(WebSocketSpdyStreamAdapterTest, ServerClosesConnection) {
  MockRead eof_only[] = {MockRead(ASYNC, 0, 0)};
  SequencedSocketData socket_data(eof_only, base::span<MockWrite>());
  AddSocketData(&socket_data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED));

  base::WeakPtr<SpdySession> spdy_session = CreateSpdySession();
  base::WeakPtr<SpdyStream> spdy_stream = CreateSpdyStream(spdy_session);
  WebSocketSpdyStreamAdapter adapter(spdy_stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  EXPECT_TRUE(spdy_session);
  EXPECT_TRUE(spdy_stream);
  // Processing the EOF closes everything.
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(spdy_session);
  EXPECT_FALSE(spdy_stream);

  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
584 
// The server closes the connection right after the request headers were
// written: OnHeadersSent() then OnClose(ERR_CONNECTION_CLOSED) are reported.
TEST_F(WebSocketSpdyStreamAdapterTest,
       SendRequestHeadersThenServerClosesConnection) {
  MockRead reads[] = {MockRead(ASYNC, 0, 1)};
  spdy::SpdySerializedFrame headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  MockWrite writes[] = {CreateMockWrite(headers, 0)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());
  EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED));

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  // Fail cleanly instead of crashing on the stream-> dereference below.
  ASSERT_TRUE(stream);
  WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  EXPECT_TRUE(session);
  EXPECT_TRUE(stream);
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);
  EXPECT_FALSE(stream);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
616 
// The server sends response headers and then EOF: all three delegate
// callbacks fire, ending with OnClose(ERR_CONNECTION_CLOSED).
TEST_F(WebSocketSpdyStreamAdapterTest,
       OnHeadersReceivedThenServerClosesConnection) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  MockRead reads[] = {CreateMockRead(response_headers, 1),
                      MockRead(ASYNC, 0, 2)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());
  EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));
  EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED));

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  // Fail cleanly instead of crashing on the stream-> dereference below.
  ASSERT_TRUE(stream);
  WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  EXPECT_TRUE(session);
  EXPECT_TRUE(stream);
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);
  EXPECT_FALSE(stream);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
652 
653 // Previously we failed to detect a half-close by the server that indicated the
654 // stream should be closed. This test ensures a half-close is correctly
655 // detected. See https://crbug.com/1151393.
TEST_F(WebSocketSpdyStreamAdapterTest,OnHeadersReceivedThenStreamEnd)656 TEST_F(WebSocketSpdyStreamAdapterTest, OnHeadersReceivedThenStreamEnd) {
657   spdy::SpdySerializedFrame response_headers(
658       spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
659   spdy::SpdySerializedFrame stream_end(
660       spdy_util_.ConstructSpdyDataFrame(1, "", true));
661   MockRead reads[] = {CreateMockRead(response_headers, 1),
662                       CreateMockRead(stream_end, 2),
663                       MockRead(ASYNC, ERR_IO_PENDING, 3),  // pause here
664                       MockRead(ASYNC, 0, 4)};
665   spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
666       1, RequestHeaders(), DEFAULT_PRIORITY, /* fin = */ false));
667   MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
668   SequencedSocketData data(reads, writes);
669   AddSocketData(&data);
670   AddSSLSocketData();
671 
672   EXPECT_CALL(mock_delegate_, OnHeadersSent());
673   EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));
674   EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED));
675 
676   // Must create buffer before `adapter`, since `adapter` doesn't hold onto a
677   // reference to it.
678   constexpr int kReadBufSize = 1024;
679   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
680 
681   base::WeakPtr<SpdySession> session = CreateSpdySession();
682   base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
683   WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
684                                      NetLogWithSource());
685   EXPECT_TRUE(adapter.is_initialized());
686 
687   int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
688   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
689 
690   TestCompletionCallback read_callback;
691   rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback());
692   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
693 
694   EXPECT_TRUE(session);
695   EXPECT_TRUE(stream);
696   rv = read_callback.WaitForResult();
697   EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
698   EXPECT_TRUE(session);
699   EXPECT_FALSE(stream);
700 
701   // Close the session.
702   data.Resume();
703 
704   base::RunLoop().RunUntilIdle();
705 
706   EXPECT_TRUE(data.AllReadDataConsumed());
707   EXPECT_TRUE(data.AllWriteDataConsumed());
708 }
709 
// After DetachDelegate(), the adapter must not call into the (strict, and
// expectation-free) delegate. This holds even though the stream sends
// headers, receives response headers and is then closed by EOF.
TEST_F(WebSocketSpdyStreamAdapterTest, DetachDelegate) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  MockRead reads[] = {CreateMockRead(response_headers, 1),
                      MockRead(ASYNC, 0, 2)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  // Fail cleanly instead of crashing on the stream-> dereference below.
  ASSERT_TRUE(stream);
  WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  // No Delegate methods shall be called after this.
  adapter.DetachDelegate();

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  EXPECT_TRUE(session);
  EXPECT_TRUE(stream);
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);
  EXPECT_FALSE(stream);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
743 
TEST_F(WebSocketSpdyStreamAdapterTest,Read)744 TEST_F(WebSocketSpdyStreamAdapterTest, Read) {
745   spdy::SpdySerializedFrame response_headers(
746       spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
747   // First read is the same size as the buffer, next is smaller, last is larger.
748   spdy::SpdySerializedFrame data_frame1(
749       spdy_util_.ConstructSpdyDataFrame(1, "foo", false));
750   spdy::SpdySerializedFrame data_frame2(
751       spdy_util_.ConstructSpdyDataFrame(1, "ba", false));
752   spdy::SpdySerializedFrame data_frame3(
753       spdy_util_.ConstructSpdyDataFrame(1, "rbaz", true));
754   MockRead reads[] = {CreateMockRead(response_headers, 1),
755                       CreateMockRead(data_frame1, 2),
756                       CreateMockRead(data_frame2, 3),
757                       CreateMockRead(data_frame3, 4), MockRead(ASYNC, 0, 5)};
758   spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
759       1, RequestHeaders(), DEFAULT_PRIORITY, false));
760   MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
761   SequencedSocketData data(reads, writes);
762   AddSocketData(&data);
763   AddSSLSocketData();
764 
765   EXPECT_CALL(mock_delegate_, OnHeadersSent());
766   EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));
767 
768   base::WeakPtr<SpdySession> session = CreateSpdySession();
769   base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
770   WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
771                                      NetLogWithSource());
772   EXPECT_TRUE(adapter.is_initialized());
773 
774   int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
775   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
776 
777   constexpr int kReadBufSize = 3;
778   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
779   TestCompletionCallback callback;
780   rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback());
781   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
782   rv = callback.WaitForResult();
783   ASSERT_EQ(3, rv);
784   EXPECT_EQ("foo", std::string_view(read_buf->data(), rv));
785 
786   // Read EOF to destroy the connection and the stream.
787   // This calls SpdySession::Delegate::OnClose().
788   EXPECT_TRUE(session);
789   EXPECT_TRUE(stream);
790   base::RunLoop().RunUntilIdle();
791   EXPECT_FALSE(session);
792   EXPECT_FALSE(stream);
793 
794   // Two socket reads are concatenated by WebSocketSpdyStreamAdapter.
795   rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
796   ASSERT_EQ(3, rv);
797   EXPECT_EQ("bar", std::string_view(read_buf->data(), rv));
798 
799   rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
800   ASSERT_EQ(3, rv);
801   EXPECT_EQ("baz", std::string_view(read_buf->data(), rv));
802 
803   // Even though connection and stream are already closed,
804   // WebSocketSpdyStreamAdapter::Delegate::OnClose() is only called after all
805   // buffered data are read.
806   EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED));
807 
808   base::RunLoop().RunUntilIdle();
809 
810   EXPECT_TRUE(data.AllReadDataConsumed());
811   EXPECT_TRUE(data.AllWriteDataConsumed());
812 }
813 
TEST_F(WebSocketSpdyStreamAdapterTest,CallDelegateOnCloseShouldNotCrash)814 TEST_F(WebSocketSpdyStreamAdapterTest, CallDelegateOnCloseShouldNotCrash) {
815   spdy::SpdySerializedFrame response_headers(
816       spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
817   spdy::SpdySerializedFrame data_frame1(
818       spdy_util_.ConstructSpdyDataFrame(1, "foo", false));
819   spdy::SpdySerializedFrame data_frame2(
820       spdy_util_.ConstructSpdyDataFrame(1, "bar", false));
821   spdy::SpdySerializedFrame rst(
822       spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_CANCEL));
823   MockRead reads[] = {CreateMockRead(response_headers, 1),
824                       CreateMockRead(data_frame1, 2),
825                       CreateMockRead(data_frame2, 3), CreateMockRead(rst, 4),
826                       MockRead(ASYNC, 0, 5)};
827   spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
828       1, RequestHeaders(), DEFAULT_PRIORITY, false));
829   MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
830   SequencedSocketData data(reads, writes);
831   AddSocketData(&data);
832   AddSSLSocketData();
833 
834   EXPECT_CALL(mock_delegate_, OnHeadersSent());
835   EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));
836 
837   base::WeakPtr<SpdySession> session = CreateSpdySession();
838   base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
839   WebSocketSpdyStreamAdapter adapter(stream, &mock_delegate_,
840                                      NetLogWithSource());
841   EXPECT_TRUE(adapter.is_initialized());
842 
843   int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
844   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
845 
846   // Buffer larger than each MockRead.
847   constexpr int kReadBufSize = 1024;
848   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
849   TestCompletionCallback callback;
850   rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback());
851   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
852   rv = callback.WaitForResult();
853   ASSERT_EQ(3, rv);
854   EXPECT_EQ("foo", std::string_view(read_buf->data(), rv));
855 
856   // Read RST_STREAM to destroy the stream.
857   // This calls SpdySession::Delegate::OnClose().
858   EXPECT_TRUE(session);
859   EXPECT_TRUE(stream);
860   base::RunLoop().RunUntilIdle();
861   EXPECT_FALSE(session);
862   EXPECT_FALSE(stream);
863 
864   // Read remaining buffered data.  This will PostTask CallDelegateOnClose().
865   rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
866   ASSERT_EQ(3, rv);
867   EXPECT_EQ("bar", std::string_view(read_buf->data(), rv));
868 
869   adapter.DetachDelegate();
870 
871   // Run CallDelegateOnClose(), which should not crash
872   // even if |delegate_| is null.
873   base::RunLoop().RunUntilIdle();
874 
875   EXPECT_TRUE(data.AllReadDataConsumed());
876   EXPECT_TRUE(data.AllWriteDataConsumed());
877 }
878 
// Writes "foo" through an adapter that has no delegate. The payload must go
// out as a single DATA frame on stream 1.
TEST_F(WebSocketSpdyStreamAdapterTest, Write) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  MockRead mock_reads[] = {CreateMockRead(response_headers, 1),
                           MockRead(ASYNC, 0, 3)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  spdy::SpdySerializedFrame data_frame(
      spdy_util_.ConstructSpdyDataFrame(1, "foo", false));
  MockWrite mock_writes[] = {CreateMockWrite(request_headers, 0),
                             CreateMockWrite(data_frame, 2)};
  SequencedSocketData socket_data(mock_reads, mock_writes);
  AddSocketData(&socket_data);
  AddSSLSocketData();

  base::WeakPtr<SpdySession> spdy_session = CreateSpdySession();
  base::WeakPtr<SpdyStream> spdy_stream = CreateSpdyStream(spdy_session);
  WebSocketSpdyStreamAdapter adapter(spdy_stream, nullptr, NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  int result =
      spdy_stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(result, IsError(ERR_IO_PENDING));

  // Flush the request headers (seq 0) and read the response headers (seq 1).
  base::RunLoop().RunUntilIdle();

  auto write_buf = base::MakeRefCounted<StringIOBuffer>("foo");
  TestCompletionCallback write_callback;
  result = adapter.Write(write_buf.get(), write_buf->size(),
                         write_callback.callback(),
                         TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_THAT(result, IsError(ERR_IO_PENDING));
  result = write_callback.WaitForResult();
  ASSERT_EQ(3, result);

  // Consume the trailing EOF.
  base::RunLoop().RunUntilIdle();

  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
918 
919 // Test that if both Read() and Write() returns asynchronously,
920 // the two callbacks are handled correctly.
TEST_F(WebSocketSpdyStreamAdapterTest,AsyncReadAndWrite)921 TEST_F(WebSocketSpdyStreamAdapterTest, AsyncReadAndWrite) {
922   spdy::SpdySerializedFrame response_headers(
923       spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
924   spdy::SpdySerializedFrame read_data_frame(
925       spdy_util_.ConstructSpdyDataFrame(1, "foobar", true));
926   MockRead reads[] = {CreateMockRead(response_headers, 1),
927                       CreateMockRead(read_data_frame, 3),
928                       MockRead(ASYNC, 0, 4)};
929   spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
930       1, RequestHeaders(), DEFAULT_PRIORITY, false));
931   spdy::SpdySerializedFrame write_data_frame(
932       spdy_util_.ConstructSpdyDataFrame(1, "baz", false));
933   MockWrite writes[] = {CreateMockWrite(request_headers, 0),
934                         CreateMockWrite(write_data_frame, 2)};
935   SequencedSocketData data(reads, writes);
936   AddSocketData(&data);
937   AddSSLSocketData();
938 
939   base::WeakPtr<SpdySession> session = CreateSpdySession();
940   base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
941   WebSocketSpdyStreamAdapter adapter(stream, nullptr, NetLogWithSource());
942   EXPECT_TRUE(adapter.is_initialized());
943 
944   int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
945   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
946 
947   base::RunLoop().RunUntilIdle();
948 
949   constexpr int kReadBufSize = 1024;
950   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
951   TestCompletionCallback read_callback;
952   rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback());
953   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
954 
955   auto write_buf = base::MakeRefCounted<StringIOBuffer>("baz");
956   TestCompletionCallback write_callback;
957   rv = adapter.Write(write_buf.get(), write_buf->size(),
958                      write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
959   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
960 
961   rv = read_callback.WaitForResult();
962   ASSERT_EQ(6, rv);
963   EXPECT_EQ("foobar", std::string_view(read_buf->data(), rv));
964 
965   rv = write_callback.WaitForResult();
966   ASSERT_EQ(3, rv);
967 
968   // Read EOF.
969   base::RunLoop().RunUntilIdle();
970 
971   EXPECT_TRUE(data.AllReadDataConsumed());
972   EXPECT_TRUE(data.AllWriteDataConsumed());
973 }
974 
975 // A helper class that will delete |adapter| when the callback is invoked.
976 class KillerCallback : public TestCompletionCallbackBase {
977  public:
KillerCallback(std::unique_ptr<WebSocketSpdyStreamAdapter> adapter)978   explicit KillerCallback(std::unique_ptr<WebSocketSpdyStreamAdapter> adapter)
979       : adapter_(std::move(adapter)) {}
980 
981   ~KillerCallback() override = default;
982 
callback()983   CompletionOnceCallback callback() {
984     return base::BindOnce(&KillerCallback::OnComplete, base::Unretained(this));
985   }
986 
987  private:
OnComplete(int result)988   void OnComplete(int result) {
989     adapter_.reset();
990     SetResult(result);
991   }
992 
993   std::unique_ptr<WebSocketSpdyStreamAdapter> adapter_;
994 };
995 
// A pending Read() is completed with ERR_CONNECTION_CLOSED when the session
// reads EOF. The read callback deletes the adapter, and
// WebSocketSpdyStreamAdapter::OnClose() must not crash when that happens.
TEST_F(WebSocketSpdyStreamAdapterTest, ReadCallbackDestroysAdapter) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  MockRead reads[] = {CreateMockRead(response_headers, 1),
                      MockRead(ASYNC, ERR_IO_PENDING, 2),
                      MockRead(ASYNC, 0, 3)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
  SequencedSocketData data(reads, writes);
  AddSocketData(&data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());
  EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));

  // Must create buffer before the adapter (and before `callback`, which ends
  // up owning it), since the adapter doesn't hold onto a reference to it.
  constexpr int kReadBufSize = 1024;
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);

  base::WeakPtr<SpdySession> session = CreateSpdySession();
  base::WeakPtr<SpdyStream> stream = CreateSpdyStream(session);
  auto adapter = std::make_unique<WebSocketSpdyStreamAdapter>(
      stream, &mock_delegate_, NetLogWithSource());
  EXPECT_TRUE(adapter->is_initialized());

  int rv = stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  // Send headers.
  base::RunLoop().RunUntilIdle();

  WebSocketSpdyStreamAdapter* adapter_raw = adapter.get();
  KillerCallback callback(std::move(adapter));

  rv = adapter_raw->Read(read_buf.get(), kReadBufSize, callback.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  // Read EOF while read is pending.  WebSocketSpdyStreamAdapter::OnClose()
  // should not crash if read callback destroys |adapter|.
  data.Resume();
  rv = callback.WaitForResult();
  EXPECT_THAT(rv, IsError(ERR_CONNECTION_CLOSED));

  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(session);
  EXPECT_FALSE(stream);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
1045 
// A Write() is still pending when the session reads EOF and is completed with
// ERR_CONNECTION_CLOSED. The write callback deletes the adapter, and
// WebSocketSpdyStreamAdapter::OnClose() must not crash when that happens.
TEST_F(WebSocketSpdyStreamAdapterTest, WriteCallbackDestroysAdapter) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  MockRead mock_reads[] = {CreateMockRead(response_headers, 1),
                           MockRead(ASYNC, ERR_IO_PENDING, 2),
                           MockRead(ASYNC, 0, 3)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  MockWrite mock_writes[] = {CreateMockWrite(request_headers, 0)};
  SequencedSocketData socket_data(mock_reads, mock_writes);
  AddSocketData(&socket_data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());
  EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));

  base::WeakPtr<SpdySession> spdy_session = CreateSpdySession();
  base::WeakPtr<SpdyStream> spdy_stream = CreateSpdyStream(spdy_session);
  auto owned_adapter = std::make_unique<WebSocketSpdyStreamAdapter>(
      spdy_stream, &mock_delegate_, NetLogWithSource());
  EXPECT_TRUE(owned_adapter->is_initialized());

  int result =
      spdy_stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(result, IsError(ERR_IO_PENDING));

  // Let the request headers go out before issuing the write.
  base::RunLoop().RunUntilIdle();

  // Ownership moves into `killer`; keep a non-owning pointer for the call.
  WebSocketSpdyStreamAdapter* unowned_adapter = owned_adapter.get();
  KillerCallback killer(std::move(owned_adapter));

  auto write_buf = base::MakeRefCounted<StringIOBuffer>("foo");
  result = unowned_adapter->Write(write_buf.get(), write_buf->size(),
                                  killer.callback(),
                                  TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_THAT(result, IsError(ERR_IO_PENDING));

  // Unpause the socket so EOF is read while the write is still pending.
  socket_data.Resume();
  result = killer.WaitForResult();
  EXPECT_THAT(result, IsError(ERR_CONNECTION_CLOSED));

  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(spdy_session);
  EXPECT_FALSE(spdy_stream);

  EXPECT_TRUE(socket_data.AllReadDataConsumed());
  EXPECT_TRUE(socket_data.AllWriteDataConsumed());
}
1095 
// A RST_STREAM carrying NO_ERROR must be reported to both the pending reader
// and the delegate as ERR_CONNECTION_CLOSED, not as OK.
TEST_F(WebSocketSpdyStreamAdapterTest,
       OnCloseOkShouldBeTranslatedToConnectionClose) {
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyResponseHeaders(1, ResponseHeaders(), false));
  spdy::SpdySerializedFrame rst_no_error(
      spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_NO_ERROR));
  MockRead mock_reads[] = {CreateMockRead(response_headers, 1),
                           CreateMockRead(rst_no_error, 2),
                           MockRead(ASYNC, 0, 3)};
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyHeaders(
      1, RequestHeaders(), DEFAULT_PRIORITY, false));
  MockWrite mock_writes[] = {CreateMockWrite(request_headers, 0)};
  SequencedSocketData socket_data(mock_reads, mock_writes);
  AddSocketData(&socket_data);
  AddSSLSocketData();

  EXPECT_CALL(mock_delegate_, OnHeadersSent());
  EXPECT_CALL(mock_delegate_, OnHeadersReceived(_));

  // Must create buffer before `adapter`, since `adapter` doesn't hold onto a
  // reference to it.
  constexpr int kReadBufSize = 1024;
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);

  base::WeakPtr<SpdySession> spdy_session = CreateSpdySession();
  base::WeakPtr<SpdyStream> spdy_stream = CreateSpdyStream(spdy_session);
  WebSocketSpdyStreamAdapter adapter(spdy_stream, &mock_delegate_,
                                     NetLogWithSource());
  EXPECT_TRUE(adapter.is_initialized());

  EXPECT_CALL(mock_delegate_, OnClose(ERR_CONNECTION_CLOSED));

  int result =
      spdy_stream->SendRequestHeaders(RequestHeaders(), MORE_DATA_TO_SEND);
  EXPECT_THAT(result, IsError(ERR_IO_PENDING));

  TestCompletionCallback read_callback;
  result = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback());
  EXPECT_THAT(result, IsError(ERR_IO_PENDING));
  result = read_callback.WaitForResult();
  ASSERT_EQ(ERR_CONNECTION_CLOSED, result);
}
1136 
1137 class MockQuicDelegate : public WebSocketQuicStreamAdapter::Delegate {
1138  public:
1139   ~MockQuicDelegate() override = default;
1140   MOCK_METHOD(void, OnHeadersSent, (), (override));
1141   MOCK_METHOD(void,
1142               OnHeadersReceived,
1143               (const quiche::HttpHeaderBlock&),
1144               (override));
1145   MOCK_METHOD(void, OnClose, (int), (override));
1146 };
1147 
1148 class WebSocketQuicStreamAdapterTest
1149     : public TestWithTaskEnvironment,
1150       public ::testing::WithParamInterface<quic::ParsedQuicVersion> {
1151  protected:
RequestHeaders()1152   static quiche::HttpHeaderBlock RequestHeaders() {
1153     return WebSocketHttp2Request("/", "www.example.org:443",
1154                                  "http://www.example.org", {});
1155   }
WebSocketQuicStreamAdapterTest()1156   WebSocketQuicStreamAdapterTest()
1157       : version_(GetParam()),
1158         mock_quic_data_(version_),
1159         client_data_stream_id1_(quic::QuicUtils::GetFirstBidirectionalStreamId(
1160             version_.transport_version,
1161             quic::Perspective::IS_CLIENT)),
1162         crypto_config_(
1163             quic::test::crypto_test_utils::ProofVerifierForTesting()),
1164         connection_id_(quic::test::TestConnectionId(2)),
1165         client_maker_(version_,
1166                       connection_id_,
1167                       &clock_,
1168                       "mail.example.org",
1169                       quic::Perspective::IS_CLIENT),
1170         server_maker_(version_,
1171                       connection_id_,
1172                       &clock_,
1173                       "mail.example.org",
1174                       quic::Perspective::IS_SERVER),
1175         peer_addr_(IPAddress(192, 0, 2, 23), 443),
1176         destination_endpoint_(url::kHttpsScheme, "mail.example.org", 80) {}
1177 
1178   ~WebSocketQuicStreamAdapterTest() override = default;
1179 
SetUp()1180   void SetUp() override {
1181     FLAGS_quic_enable_http3_grease_randomness = false;
1182     clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
1183     quic::QuicEnableVersion(version_);
1184   }
1185 
TearDown()1186   void TearDown() override {
1187     EXPECT_TRUE(mock_quic_data_.AllReadDataConsumed());
1188     EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed());
1189   }
1190 
GetQuicSessionHandle()1191   net::QuicChromiumClientSession::Handle* GetQuicSessionHandle() {
1192     return session_handle_.get();
1193   }
1194 
1195   // Helper functions for constructing packets sent by the client
1196 
ConstructSettingsPacket(uint64_t packet_number)1197   std::unique_ptr<quic::QuicReceivedPacket> ConstructSettingsPacket(
1198       uint64_t packet_number) {
1199     return client_maker_.MakeInitialSettingsPacket(packet_number);
1200   }
1201 
ConstructServerDataPacket(uint64_t packet_number,std::string_view data)1202   std::unique_ptr<quic::QuicReceivedPacket> ConstructServerDataPacket(
1203       uint64_t packet_number,
1204       std::string_view data) {
1205     quiche::QuicheBuffer buffer = quic::HttpEncoder::SerializeDataFrameHeader(
1206         data.size(), quiche::SimpleBufferAllocator::Get());
1207     return server_maker_.Packet(packet_number)
1208         .AddStreamFrame(
1209             client_data_stream_id1_, /*fin=*/false,
1210             base::StrCat(
1211                 {std::string_view(buffer.data(), buffer.size()), data}))
1212         .Build();
1213   }
1214 
ConstructRstPacket(uint64_t packet_number,quic::QuicRstStreamErrorCode error_code)1215   std::unique_ptr<quic::QuicReceivedPacket> ConstructRstPacket(
1216       uint64_t packet_number,
1217       quic::QuicRstStreamErrorCode error_code) {
1218     return client_maker_.Packet(packet_number)
1219         .AddStopSendingFrame(client_data_stream_id1_, error_code)
1220         .AddRstStreamFrame(client_data_stream_id1_, error_code)
1221         .Build();
1222   }
1223 
ConstructClientAckPacket(uint64_t packet_number,uint64_t largest_received,uint64_t smallest_received)1224   std::unique_ptr<quic::QuicEncryptedPacket> ConstructClientAckPacket(
1225       uint64_t packet_number,
1226       uint64_t largest_received,
1227       uint64_t smallest_received) {
1228     return client_maker_.Packet(packet_number)
1229         .AddAckFrame(1, largest_received, smallest_received)
1230         .Build();
1231   }
1232 
ConstructAckAndRstPacket(uint64_t packet_number,quic::QuicRstStreamErrorCode error_code,uint64_t largest_received,uint64_t smallest_received)1233   std::unique_ptr<quic::QuicReceivedPacket> ConstructAckAndRstPacket(
1234       uint64_t packet_number,
1235       quic::QuicRstStreamErrorCode error_code,
1236       uint64_t largest_received,
1237       uint64_t smallest_received) {
1238     return client_maker_.Packet(packet_number)
1239         .AddAckFrame(/*first_received=*/1, largest_received, smallest_received)
1240         .AddStopSendingFrame(client_data_stream_id1_, error_code)
1241         .AddRstStreamFrame(client_data_stream_id1_, error_code)
1242         .Build();
1243   }
1244 
Initialize()1245   void Initialize() {
1246     auto socket = std::make_unique<MockUDPClientSocket>(
1247         mock_quic_data_.InitializeAndGetSequencedSocketData(), NetLog::Get());
1248     socket->Connect(peer_addr_);
1249 
1250     runner_ = base::MakeRefCounted<TestTaskRunner>(&clock_);
1251     helper_ = std::make_unique<QuicChromiumConnectionHelper>(
1252         &clock_, &random_generator_);
1253     alarm_factory_ =
1254         std::make_unique<QuicChromiumAlarmFactory>(runner_.get(), &clock_);
1255     // Ownership of 'writer' is passed to 'QuicConnection'.
1256     QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
1257         socket.get(), base::SingleThreadTaskRunner::GetCurrentDefault().get());
1258     quic::QuicConnection* connection = new quic::QuicConnection(
1259         connection_id_, quic::QuicSocketAddress(),
1260         net::ToQuicSocketAddress(peer_addr_), helper_.get(),
1261         alarm_factory_.get(), writer, true /* owns_writer */,
1262         quic::Perspective::IS_CLIENT, quic::test::SupportedVersions(version_),
1263         connection_id_generator_);
1264     connection->set_visitor(&visitor_);
1265 
1266     // Load a certificate that is valid for *.example.org
1267     scoped_refptr<X509Certificate> test_cert(
1268         ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
1269     EXPECT_TRUE(test_cert.get());
1270 
1271     verify_details_.cert_verify_result.verified_cert = test_cert;
1272     verify_details_.cert_verify_result.is_issued_by_known_root = true;
1273     crypto_client_stream_factory_.AddProofVerifyDetails(&verify_details_);
1274 
1275     base::TimeTicks dns_end = base::TimeTicks::Now();
1276     base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
1277 
1278     session_ = std::make_unique<QuicChromiumClientSession>(
1279         connection, std::move(socket),
1280         /*stream_factory=*/nullptr, &crypto_client_stream_factory_, &clock_,
1281         &transport_security_state_, &ssl_config_service_,
1282         /*server_info=*/nullptr,
1283         QuicSessionAliasKey(
1284             url::SchemeHostPort(),
1285             QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
1286                            ProxyChain::Direct(), SessionUsage::kDestination,
1287                            SocketTag(), NetworkAnonymizationKey(),
1288                            SecureDnsPolicy::kAllow,
1289                            /*require_dns_https_alpn=*/false)),
1290         /*require_confirmation=*/false,
1291         /*migrate_session_early_v2=*/false,
1292         /*migrate_session_on_network_change_v2=*/false,
1293         /*default_network=*/handles::kInvalidNetworkHandle,
1294         quic::QuicTime::Delta::FromMilliseconds(
1295             kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
1296         /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
1297         kDefaultIdleSessionMigrationPeriod, /*multi_port_probing_interval=*/0,
1298         kMaxTimeOnNonDefaultNetwork,
1299         kMaxMigrationsToNonDefaultNetworkOnWriteError,
1300         kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
1301         kQuicYieldAfterPacketsRead,
1302         quic::QuicTime::Delta::FromMilliseconds(
1303             kQuicYieldAfterDurationMilliseconds),
1304         /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
1305         std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config_),
1306         "CONNECTION_UNKNOWN", dns_start, dns_end,
1307         base::DefaultTickClock::GetInstance(),
1308         base::SingleThreadTaskRunner::GetCurrentDefault().get(),
1309         /*socket_performance_watcher=*/nullptr, ConnectionEndpointMetadata(),
1310         /*report_ecn=*/true, /*enable_origin_frame=*/true,
1311         /*allow_server_preferred_address=*/true,
1312         MultiplexedSessionCreationInitiator::kUnknown,
1313         NetLogWithSource::Make(NetLogSourceType::NONE));
1314 
1315     session_->Initialize();
1316 
1317     // Blackhole QPACK decoder stream instead of constructing mock writes.
1318     session_->qpack_decoder()->set_qpack_stream_sender_delegate(
1319         &noop_qpack_stream_sender_delegate_);
1320     TestCompletionCallback callback;
1321     EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
1322     EXPECT_TRUE(session_->OneRttKeysAvailable());
1323     session_handle_ = session_->CreateHandle(
1324         url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
1325   }
1326 
1327   const quic::ParsedQuicVersion version_;
1328   MockQuicData mock_quic_data_;
1329   StrictMock<MockQuicDelegate> mock_delegate_;
1330   const quic::QuicStreamId client_data_stream_id1_;
1331 
1332  private:
1333   quic::QuicCryptoClientConfig crypto_config_;
1334   const quic::QuicConnectionId connection_id_;
1335 
1336  protected:
1337   QuicTestPacketMaker client_maker_;
1338   QuicTestPacketMaker server_maker_;
1339   std::unique_ptr<QuicChromiumClientSession> session_;
1340 
1341  private:
1342   quic::MockClock clock_;
1343   std::unique_ptr<QuicChromiumClientSession::Handle> session_handle_;
1344   scoped_refptr<TestTaskRunner> runner_;
1345   ProofVerifyDetailsChromium verify_details_;
1346   MockCryptoClientStreamFactory crypto_client_stream_factory_;
1347   SSLConfigServiceDefaults ssl_config_service_;
1348   quic::test::MockConnectionIdGenerator connection_id_generator_;
1349   std::unique_ptr<QuicChromiumConnectionHelper> helper_;
1350   std::unique_ptr<QuicChromiumAlarmFactory> alarm_factory_;
1351   testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor_;
1352   TransportSecurityState transport_security_state_;
1353   IPAddress ip_;
1354   IPEndPoint peer_addr_;
1355   quic::test::MockRandom random_generator_{0};
1356   url::SchemeHostPort destination_endpoint_;
1357   quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
1358 };
1359 
1360 // Like net::TestCompletionCallback, but for a callback that takes an unbound
1361 // parameter of type WebSocketQuicStreamAdapter.
1362 struct WebSocketQuicStreamAdapterIsPendingHelper {
operator ()net::test::WebSocketQuicStreamAdapterIsPendingHelper1363   bool operator()(
1364       const std::unique_ptr<WebSocketQuicStreamAdapter>& adapter) const {
1365     return !adapter;
1366   }
1367 };
1368 
1369 using TestWebSocketQuicStreamAdapterCompletionCallbackBase =
1370     net::internal::TestCompletionCallbackTemplate<
1371         std::unique_ptr<WebSocketQuicStreamAdapter>,
1372         WebSocketQuicStreamAdapterIsPendingHelper>;
1373 
1374 class TestWebSocketQuicStreamAdapterCompletionCallback
1375     : public TestWebSocketQuicStreamAdapterCompletionCallbackBase {
1376  public:
1377   base::OnceCallback<void(std::unique_ptr<WebSocketQuicStreamAdapter>)>
1378   callback();
1379 };
1380 
1381 base::OnceCallback<void(std::unique_ptr<WebSocketQuicStreamAdapter>)>
callback()1382 TestWebSocketQuicStreamAdapterCompletionCallback::callback() {
1383   return base::BindOnce(
1384       &TestWebSocketQuicStreamAdapterCompletionCallback::SetResult,
1385       base::Unretained(this));
1386 }
1387 
1388 INSTANTIATE_TEST_SUITE_P(QuicVersion,
1389                          WebSocketQuicStreamAdapterTest,
1390                          ::testing::ValuesIn(AllSupportedQuicVersions()),
1391                          ::testing::PrintToStringParamName());
1392 
// Creates an adapter on a freshly connected session and disconnects it. The
// client must write its SETTINGS packet and then a STOP_SENDING + RST_STREAM
// packet with QUIC_STREAM_CANCELLED for the adapter's stream.
TEST_P(WebSocketQuicStreamAdapterTest, Disconnect) {
  int next_packet = 1;
  mock_quic_data_.AddWrite(SYNCHRONOUS, ConstructSettingsPacket(next_packet++));
  mock_quic_data_.AddWrite(
      SYNCHRONOUS,
      ConstructRstPacket(next_packet++, quic::QUIC_STREAM_CANCELLED));

  Initialize();

  net::QuicChromiumClientSession::Handle* const handle =
      GetQuicSessionHandle();
  ASSERT_TRUE(handle);

  TestWebSocketQuicStreamAdapterCompletionCallback adapter_callback;
  std::unique_ptr<WebSocketQuicStreamAdapter> adapter =
      handle->CreateWebSocketQuicStreamAdapter(&mock_delegate_,
                                               adapter_callback.callback(),
                                               TRAFFIC_ANNOTATION_FOR_TESTS);
  ASSERT_TRUE(adapter);
  EXPECT_TRUE(adapter->is_initialized());
  adapter->Disconnect();
  // TODO(momoka): Add tests to test both destruction orders.
}
1417 
TEST_P(WebSocketQuicStreamAdapterTest,AsyncAdapterCreation)1418 TEST_P(WebSocketQuicStreamAdapterTest, AsyncAdapterCreation) {
1419   constexpr size_t kMaxOpenStreams = 50;
1420 
1421   int packet_number = 1;
1422   mock_quic_data_.AddWrite(SYNCHRONOUS,
1423                            ConstructSettingsPacket(packet_number++));
1424 
1425   mock_quic_data_.AddWrite(
1426       SYNCHRONOUS, client_maker_.Packet(packet_number++)
1427                        .AddStreamsBlockedFrame(/*control_frame_id=*/1,
1428                                                /*stream_count=*/kMaxOpenStreams,
1429                                                /* unidirectional = */ false)
1430                        .Build());
1431 
1432   mock_quic_data_.AddRead(
1433       ASYNC, server_maker_.Packet(1)
1434                  .AddMaxStreamsFrame(/*control_frame_id=*/1,
1435                                      /*stream_count=*/kMaxOpenStreams + 2,
1436                                      /* unidirectional = */ false)
1437                  .Build());
1438 
1439   mock_quic_data_.AddRead(ASYNC, ERR_IO_PENDING);
1440   mock_quic_data_.AddRead(ASYNC, ERR_CONNECTION_CLOSED);
1441 
1442   Initialize();
1443 
1444   std::vector<QuicChromiumClientStream*> streams;
1445 
1446   for (size_t i = 0; i < kMaxOpenStreams; i++) {
1447     QuicChromiumClientStream* stream =
1448         QuicChromiumClientSessionPeer::CreateOutgoingStream(session_.get());
1449     ASSERT_TRUE(stream);
1450     streams.push_back(stream);
1451     EXPECT_EQ(i + 1, session_->GetNumActiveStreams());
1452   }
1453 
1454   net::QuicChromiumClientSession::Handle* session_handle =
1455       GetQuicSessionHandle();
1456   ASSERT_TRUE(session_handle);
1457 
1458   // Creating an adapter should fail because of the stream limit.
1459   TestWebSocketQuicStreamAdapterCompletionCallback callback;
1460   std::unique_ptr<WebSocketQuicStreamAdapter> adapter =
1461       session_handle->CreateWebSocketQuicStreamAdapter(
1462           &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
1463   ASSERT_EQ(adapter, nullptr);
1464   EXPECT_FALSE(callback.have_result());
1465   EXPECT_EQ(kMaxOpenStreams, session_->GetNumActiveStreams());
1466 
1467   // Read MAX_STREAMS frame that makes it possible to open WebSocket stream.
1468   session_->StartReading();
1469   callback.WaitForResult();
1470   EXPECT_EQ(kMaxOpenStreams + 1, session_->GetNumActiveStreams());
1471 
1472   // Close connection.
1473   mock_quic_data_.Resume();
1474   base::RunLoop().RunUntilIdle();
1475 }
1476 
// Sends the WebSocket request headers on a freshly created adapter, then
// disconnects; the stream is expected to be reset with QUIC_STREAM_CANCELLED.
TEST_P(WebSocketQuicStreamAdapterTest, SendRequestHeadersThenDisconnect) {
  int packet_number = 1;
  mock_quic_data_.AddWrite(SYNCHRONOUS,
                           ConstructSettingsPacket(packet_number++));
  quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
      "/", "www.example.org:443", "http://www.example.org", {});
  mock_quic_data_.AddWrite(
      SYNCHRONOUS,
      client_maker_.MakeRequestHeadersPacket(
          packet_number++, client_data_stream_id1_,
          /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
          std::move(request_header_block), nullptr));

  mock_quic_data_.AddWrite(
      SYNCHRONOUS,
      ConstructRstPacket(packet_number++, quic::QUIC_STREAM_CANCELLED));

  Initialize();

  net::QuicChromiumClientSession::Handle* session_handle =
      GetQuicSessionHandle();
  ASSERT_TRUE(session_handle);
  TestWebSocketQuicStreamAdapterCompletionCallback callback;
  std::unique_ptr<WebSocketQuicStreamAdapter> adapter =
      session_handle->CreateWebSocketQuicStreamAdapter(
          &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  ASSERT_TRUE(adapter);
  EXPECT_TRUE(adapter->is_initialized());

  adapter->WriteHeaders(RequestHeaders(), false);

  adapter->Disconnect();

  // Verify that the SETTINGS, request HEADERS and RST writes were consumed.
  EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed());
}
1511 
// Sends request headers, waits until the delegate sees the response headers,
// then disconnects; the reset is expected to be bundled with an ACK of the
// server's first packet.
TEST_P(WebSocketQuicStreamAdapterTest, OnHeadersReceivedThenDisconnect) {
  int packet_number = 1;
  mock_quic_data_.AddWrite(SYNCHRONOUS,
                           ConstructSettingsPacket(packet_number++));

  quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
      "/", "www.example.org:443", "http://www.example.org", {});
  mock_quic_data_.AddWrite(
      SYNCHRONOUS,
      client_maker_.MakeRequestHeadersPacket(
          packet_number++, client_data_stream_id1_,
          /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
          std::move(request_header_block), nullptr));

  quiche::HttpHeaderBlock response_header_block = WebSocketHttp2Response({});
  mock_quic_data_.AddRead(
      ASYNC, server_maker_.MakeResponseHeadersPacket(
                 /*packet_number=*/1, client_data_stream_id1_, /*fin=*/false,
                 std::move(response_header_block),
                 /*spdy_headers_frame_length=*/nullptr));
  mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
  mock_quic_data_.AddWrite(
      SYNCHRONOUS, ConstructAckAndRstPacket(packet_number++,
                                            quic::QUIC_STREAM_CANCELLED, 1, 0));
  base::RunLoop run_loop;
  auto quit_closure = run_loop.QuitClosure();
  EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)).WillOnce(Invoke([&]() {
    std::move(quit_closure).Run();
  }));

  Initialize();

  net::QuicChromiumClientSession::Handle* session_handle =
      GetQuicSessionHandle();
  ASSERT_TRUE(session_handle);

  TestWebSocketQuicStreamAdapterCompletionCallback callback;
  std::unique_ptr<WebSocketQuicStreamAdapter> adapter =
      session_handle->CreateWebSocketQuicStreamAdapter(
          &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  ASSERT_TRUE(adapter);
  EXPECT_TRUE(adapter->is_initialized());

  adapter->WriteHeaders(RequestHeaders(), false);

  session_->StartReading();
  run_loop.Run();

  adapter->Disconnect();

  // Verify that the SETTINGS, request HEADERS and ACK+RST writes were consumed.
  EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed());
}
1563 
TEST_P(WebSocketQuicStreamAdapterTest,Read)1564 TEST_P(WebSocketQuicStreamAdapterTest, Read) {
1565   int packet_number = 1;
1566   mock_quic_data_.AddWrite(SYNCHRONOUS,
1567                            ConstructSettingsPacket(packet_number++));
1568 
1569   SpdyTestUtil spdy_util;
1570   quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
1571       "/", "www.example.org:443", "http://www.example.org", {});
1572   mock_quic_data_.AddWrite(
1573       SYNCHRONOUS,
1574       client_maker_.MakeRequestHeadersPacket(
1575           packet_number++, client_data_stream_id1_,
1576           /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
1577           std::move(request_header_block), nullptr));
1578 
1579   quiche::HttpHeaderBlock response_header_block = WebSocketHttp2Response({});
1580   mock_quic_data_.AddRead(
1581       ASYNC, server_maker_.MakeResponseHeadersPacket(
1582                  /*packet_number=*/1, client_data_stream_id1_, /*fin=*/false,
1583                  std::move(response_header_block),
1584                  /*spdy_headers_frame_length=*/nullptr));
1585   mock_quic_data_.AddRead(ASYNC, ERR_IO_PENDING);
1586 
1587   mock_quic_data_.AddRead(ASYNC, ConstructServerDataPacket(2, "foo"));
1588   mock_quic_data_.AddRead(SYNCHRONOUS,
1589                           ConstructServerDataPacket(3, "hogehoge"));
1590   mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
1591 
1592   mock_quic_data_.AddWrite(ASYNC,
1593                            ConstructClientAckPacket(packet_number++, 2, 0));
1594   mock_quic_data_.AddWrite(
1595       SYNCHRONOUS, ConstructAckAndRstPacket(packet_number++,
1596                                             quic::QUIC_STREAM_CANCELLED, 3, 0));
1597 
1598   base::RunLoop run_loop;
1599   EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)).WillOnce(Invoke([&]() {
1600     run_loop.Quit();
1601   }));
1602 
1603   Initialize();
1604 
1605   net::QuicChromiumClientSession::Handle* session_handle =
1606       GetQuicSessionHandle();
1607   ASSERT_TRUE(session_handle);
1608 
1609   TestWebSocketQuicStreamAdapterCompletionCallback callback;
1610   std::unique_ptr<WebSocketQuicStreamAdapter> adapter =
1611       session_handle->CreateWebSocketQuicStreamAdapter(
1612           &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
1613   ASSERT_TRUE(adapter);
1614   EXPECT_TRUE(adapter->is_initialized());
1615 
1616   adapter->WriteHeaders(RequestHeaders(), false);
1617 
1618   session_->StartReading();
1619   run_loop.Run();
1620 
1621   // Buffer larger than each MockRead.
1622   constexpr int kReadBufSize = 1024;
1623   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
1624   TestCompletionCallback read_callback;
1625 
1626   int rv =
1627       adapter->Read(read_buf.get(), kReadBufSize, read_callback.callback());
1628 
1629   ASSERT_EQ(ERR_IO_PENDING, rv);
1630 
1631   mock_quic_data_.Resume();
1632   base::RunLoop().RunUntilIdle();
1633 
1634   rv = read_callback.WaitForResult();
1635   ASSERT_EQ(3, rv);
1636   EXPECT_EQ("foo", std::string_view(read_buf->data(), rv));
1637 
1638   rv = adapter->Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
1639   ASSERT_EQ(8, rv);
1640   EXPECT_EQ("hogehoge", std::string_view(read_buf->data(), rv));
1641 
1642   adapter->Disconnect();
1643 
1644   EXPECT_TRUE(mock_quic_data_.AllReadDataConsumed());
1645   EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed());
1646 }
1647 
TEST_P(WebSocketQuicStreamAdapterTest,ReadIntoSmallBuffer)1648 TEST_P(WebSocketQuicStreamAdapterTest, ReadIntoSmallBuffer) {
1649   int packet_number = 1;
1650   mock_quic_data_.AddWrite(SYNCHRONOUS,
1651                            ConstructSettingsPacket(packet_number++));
1652 
1653   SpdyTestUtil spdy_util;
1654   quiche::HttpHeaderBlock request_header_block = WebSocketHttp2Request(
1655       "/", "www.example.org:443", "http://www.example.org", {});
1656   mock_quic_data_.AddWrite(
1657       SYNCHRONOUS,
1658       client_maker_.MakeRequestHeadersPacket(
1659           packet_number++, client_data_stream_id1_,
1660           /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
1661           std::move(request_header_block), nullptr));
1662 
1663   quiche::HttpHeaderBlock response_header_block = WebSocketHttp2Response({});
1664   mock_quic_data_.AddRead(
1665       ASYNC, server_maker_.MakeResponseHeadersPacket(
1666                  /*packet_number=*/1, client_data_stream_id1_, /*fin=*/false,
1667                  std::move(response_header_block),
1668                  /*spdy_headers_frame_length=*/nullptr));
1669   mock_quic_data_.AddRead(ASYNC, ERR_IO_PENDING);
1670   // First read is the same size as the buffer, next is smaller, last is larger.
1671   mock_quic_data_.AddRead(ASYNC, ConstructServerDataPacket(2, "abc"));
1672   mock_quic_data_.AddRead(SYNCHRONOUS, ConstructServerDataPacket(3, "12"));
1673   mock_quic_data_.AddRead(SYNCHRONOUS, ConstructServerDataPacket(4, "ABCD"));
1674   mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
1675 
1676   mock_quic_data_.AddWrite(ASYNC,
1677                            ConstructClientAckPacket(packet_number++, 2, 0));
1678   mock_quic_data_.AddWrite(
1679       SYNCHRONOUS, ConstructAckAndRstPacket(packet_number++,
1680                                             quic::QUIC_STREAM_CANCELLED, 4, 0));
1681 
1682   base::RunLoop run_loop;
1683   EXPECT_CALL(mock_delegate_, OnHeadersReceived(_)).WillOnce(Invoke([&]() {
1684     run_loop.Quit();
1685   }));
1686 
1687   Initialize();
1688 
1689   net::QuicChromiumClientSession::Handle* session_handle =
1690       GetQuicSessionHandle();
1691   ASSERT_TRUE(session_handle);
1692   TestWebSocketQuicStreamAdapterCompletionCallback callback;
1693   std::unique_ptr<WebSocketQuicStreamAdapter> adapter =
1694       session_handle->CreateWebSocketQuicStreamAdapter(
1695           &mock_delegate_, callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
1696   ASSERT_TRUE(adapter);
1697   EXPECT_TRUE(adapter->is_initialized());
1698 
1699   adapter->WriteHeaders(RequestHeaders(), false);
1700 
1701   session_->StartReading();
1702   run_loop.Run();
1703 
1704   constexpr int kReadBufSize = 3;
1705   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
1706   TestCompletionCallback read_callback;
1707 
1708   int rv =
1709       adapter->Read(read_buf.get(), kReadBufSize, read_callback.callback());
1710 
1711   ASSERT_EQ(ERR_IO_PENDING, rv);
1712 
1713   mock_quic_data_.Resume();
1714   base::RunLoop().RunUntilIdle();
1715 
1716   rv = read_callback.WaitForResult();
1717   ASSERT_EQ(3, rv);
1718   EXPECT_EQ("abc", std::string_view(read_buf->data(), rv));
1719 
1720   rv = adapter->Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
1721   ASSERT_EQ(3, rv);
1722   EXPECT_EQ("12A", std::string_view(read_buf->data(), rv));
1723 
1724   rv = adapter->Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
1725   ASSERT_EQ(3, rv);
1726   EXPECT_EQ("BCD", std::string_view(read_buf->data(), rv));
1727 
1728   adapter->Disconnect();
1729 
1730   EXPECT_TRUE(mock_quic_data_.AllReadDataConsumed());
1731   EXPECT_TRUE(mock_quic_data_.AllWriteDataConsumed());
1732 }
1733 
1734 }  // namespace net::test
1735