1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/http/http_response_body_drainer.h"
6
7 #include <stdint.h>
8
9 #include <cstring>
10 #include <set>
11 #include <utility>
12
13 #include "base/compiler_specific.h"
14 #include "base/functional/bind.h"
15 #include "base/location.h"
16 #include "base/memory/raw_ptr.h"
17 #include "base/memory/weak_ptr.h"
18 #include "base/no_destructor.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_piece.h"
21 #include "base/task/single_thread_task_runner.h"
22 #include "net/base/completion_once_callback.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/net_errors.h"
25 #include "net/base/test_completion_callback.h"
26 #include "net/cert/ct_policy_enforcer.h"
27 #include "net/cert/mock_cert_verifier.h"
28 #include "net/http/http_network_session.h"
29 #include "net/http/http_server_properties.h"
30 #include "net/http/http_stream.h"
31 #include "net/http/transport_security_state.h"
32 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
33 #include "net/quic/quic_context.h"
34 #include "net/socket/socket_test_util.h"
35 #include "net/ssl/ssl_config_service_defaults.h"
36 #include "net/test/test_with_task_environment.h"
37 #include "testing/gtest/include/gtest/gtest.h"
38
39 namespace net {
40
41 namespace {
42
43 const int kMagicChunkSize = 1024;
44 static_assert((HttpResponseBodyDrainer::kDrainBodyBufferSize %
45 kMagicChunkSize) == 0,
46 "chunk size needs to divide evenly into buffer size");
47
48 class CloseResultWaiter {
49 public:
50 CloseResultWaiter() = default;
51
52 CloseResultWaiter(const CloseResultWaiter&) = delete;
53 CloseResultWaiter& operator=(const CloseResultWaiter&) = delete;
54
WaitForResult()55 int WaitForResult() {
56 CHECK(!waiting_for_result_);
57 while (!have_result_) {
58 waiting_for_result_ = true;
59 base::RunLoop().Run();
60 waiting_for_result_ = false;
61 }
62 return result_;
63 }
64
set_result(bool result)65 void set_result(bool result) {
66 result_ = result;
67 have_result_ = true;
68 if (waiting_for_result_)
69 base::RunLoop::QuitCurrentWhenIdleDeprecated();
70 }
71
72 private:
73 int result_ = false;
74 bool have_result_ = false;
75 bool waiting_for_result_ = false;
76 };
77
78 class MockHttpStream : public HttpStream {
79 public:
MockHttpStream(CloseResultWaiter * result_waiter)80 explicit MockHttpStream(CloseResultWaiter* result_waiter)
81 : result_waiter_(result_waiter) {}
82
83 MockHttpStream(const MockHttpStream&) = delete;
84 MockHttpStream& operator=(const MockHttpStream&) = delete;
85
86 ~MockHttpStream() override = default;
87
88 // HttpStream implementation.
RegisterRequest(const HttpRequestInfo * request_info)89 void RegisterRequest(const HttpRequestInfo* request_info) override {}
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)90 int InitializeStream(bool can_send_early,
91 RequestPriority priority,
92 const NetLogWithSource& net_log,
93 CompletionOnceCallback callback) override {
94 return ERR_UNEXPECTED;
95 }
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,CompletionOnceCallback callback)96 int SendRequest(const HttpRequestHeaders& request_headers,
97 HttpResponseInfo* response,
98 CompletionOnceCallback callback) override {
99 return ERR_UNEXPECTED;
100 }
ReadResponseHeaders(CompletionOnceCallback callback)101 int ReadResponseHeaders(CompletionOnceCallback callback) override {
102 return ERR_UNEXPECTED;
103 }
104
IsConnectionReused() const105 bool IsConnectionReused() const override { return false; }
SetConnectionReused()106 void SetConnectionReused() override {}
CanReuseConnection() const107 bool CanReuseConnection() const override { return can_reuse_connection_; }
GetTotalReceivedBytes() const108 int64_t GetTotalReceivedBytes() const override { return 0; }
GetTotalSentBytes() const109 int64_t GetTotalSentBytes() const override { return 0; }
GetAlternativeService(AlternativeService * alternative_service) const110 bool GetAlternativeService(
111 AlternativeService* alternative_service) const override {
112 return false;
113 }
GetSSLInfo(SSLInfo * ssl_info)114 void GetSSLInfo(SSLInfo* ssl_info) override {}
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)115 void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override {}
GetRemoteEndpoint(IPEndPoint * endpoint)116 int GetRemoteEndpoint(IPEndPoint* endpoint) override {
117 return ERR_UNEXPECTED;
118 }
119
120 // Mocked API
121 int ReadResponseBody(IOBuffer* buf,
122 int buf_len,
123 CompletionOnceCallback callback) override;
Close(bool not_reusable)124 void Close(bool not_reusable) override {
125 CHECK(!closed_);
126 closed_ = true;
127 result_waiter_->set_result(not_reusable);
128 }
129
RenewStreamForAuth()130 std::unique_ptr<HttpStream> RenewStreamForAuth() override { return nullptr; }
131
IsResponseBodyComplete() const132 bool IsResponseBodyComplete() const override { return is_complete_; }
133
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const134 bool GetLoadTimingInfo(LoadTimingInfo* load_timing_info) const override {
135 return false;
136 }
137
Drain(HttpNetworkSession *)138 void Drain(HttpNetworkSession*) override {}
139
PopulateNetErrorDetails(NetErrorDetails * details)140 void PopulateNetErrorDetails(NetErrorDetails* details) override { return; }
141
SetPriority(RequestPriority priority)142 void SetPriority(RequestPriority priority) override {}
143
GetDnsAliases() const144 const std::set<std::string>& GetDnsAliases() const override {
145 static const base::NoDestructor<std::set<std::string>> nullset_result;
146 return *nullset_result;
147 }
148
GetAcceptChViaAlps() const149 base::StringPiece GetAcceptChViaAlps() const override { return {}; }
150
151 // Methods to tweak/observer mock behavior:
set_stall_reads_forever()152 void set_stall_reads_forever() { stall_reads_forever_ = true; }
153
set_num_chunks(int num_chunks)154 void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
155
set_sync()156 void set_sync() { is_sync_ = true; }
157
set_is_last_chunk_zero_size()158 void set_is_last_chunk_zero_size() { is_last_chunk_zero_size_ = true; }
159
160 // Sets result value of CanReuseConnection. Defaults to true.
set_can_reuse_connection(bool can_reuse_connection)161 void set_can_reuse_connection(bool can_reuse_connection) {
162 can_reuse_connection_ = can_reuse_connection;
163 }
164
SetRequestHeadersCallback(RequestHeadersCallback callback)165 void SetRequestHeadersCallback(RequestHeadersCallback callback) override {}
166
167 private:
168 int ReadResponseBodyImpl(IOBuffer* buf, int buf_len);
169 void CompleteRead();
170
closed() const171 bool closed() const { return closed_; }
172
173 const raw_ptr<CloseResultWaiter> result_waiter_;
174 scoped_refptr<IOBuffer> user_buf_;
175 CompletionOnceCallback callback_;
176 int buf_len_ = 0;
177 bool closed_ = false;
178 bool stall_reads_forever_ = false;
179 int num_chunks_ = 0;
180 bool is_sync_ = false;
181 bool is_last_chunk_zero_size_ = false;
182 bool is_complete_ = false;
183 bool can_reuse_connection_ = true;
184
185 base::WeakPtrFactory<MockHttpStream> weak_factory_{this};
186 };
187
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)188 int MockHttpStream::ReadResponseBody(IOBuffer* buf,
189 int buf_len,
190 CompletionOnceCallback callback) {
191 CHECK(!callback.is_null());
192 CHECK(callback_.is_null());
193 CHECK(buf);
194
195 if (stall_reads_forever_)
196 return ERR_IO_PENDING;
197
198 if (is_complete_)
199 return ERR_UNEXPECTED;
200
201 if (!is_sync_) {
202 user_buf_ = buf;
203 buf_len_ = buf_len;
204 callback_ = std::move(callback);
205 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
206 FROM_HERE, base::BindOnce(&MockHttpStream::CompleteRead,
207 weak_factory_.GetWeakPtr()));
208 return ERR_IO_PENDING;
209 } else {
210 return ReadResponseBodyImpl(buf, buf_len);
211 }
212 }
213
ReadResponseBodyImpl(IOBuffer * buf,int buf_len)214 int MockHttpStream::ReadResponseBodyImpl(IOBuffer* buf, int buf_len) {
215 if (is_last_chunk_zero_size_ && num_chunks_ == 1) {
216 buf_len = 0;
217 } else {
218 if (buf_len > kMagicChunkSize)
219 buf_len = kMagicChunkSize;
220 std::memset(buf->data(), 1, buf_len);
221 }
222 num_chunks_--;
223 if (!num_chunks_)
224 is_complete_ = true;
225
226 return buf_len;
227 }
228
CompleteRead()229 void MockHttpStream::CompleteRead() {
230 int result = ReadResponseBodyImpl(user_buf_.get(), buf_len_);
231 user_buf_ = nullptr;
232 std::move(callback_).Run(result);
233 }
234
235 class HttpResponseBodyDrainerTest : public TestWithTaskEnvironment {
236 protected:
HttpResponseBodyDrainerTest()237 HttpResponseBodyDrainerTest()
238 : proxy_resolution_service_(
239 ConfiguredProxyResolutionService::CreateDirect()),
240 ssl_config_service_(std::make_unique<SSLConfigServiceDefaults>()),
241 http_server_properties_(std::make_unique<HttpServerProperties>()),
242 session_(CreateNetworkSession()),
243 mock_stream_(new MockHttpStream(&result_waiter_)) {
244 drainer_ = std::make_unique<HttpResponseBodyDrainer>(mock_stream_);
245 }
246
247 ~HttpResponseBodyDrainerTest() override = default;
248
CreateNetworkSession()249 std::unique_ptr<HttpNetworkSession> CreateNetworkSession() {
250 HttpNetworkSessionContext context;
251 context.client_socket_factory = &socket_factory_;
252 context.proxy_resolution_service = proxy_resolution_service_.get();
253 context.ssl_config_service = ssl_config_service_.get();
254 context.http_server_properties = http_server_properties_.get();
255 context.cert_verifier = &cert_verifier_;
256 context.transport_security_state = &transport_security_state_;
257 context.ct_policy_enforcer = &ct_policy_enforcer_;
258 context.quic_context = &quic_context_;
259 return std::make_unique<HttpNetworkSession>(HttpNetworkSessionParams(),
260 context);
261 }
262
263 std::unique_ptr<ProxyResolutionService> proxy_resolution_service_;
264 std::unique_ptr<SSLConfigService> ssl_config_service_;
265 std::unique_ptr<HttpServerProperties> http_server_properties_;
266 MockCertVerifier cert_verifier_;
267 TransportSecurityState transport_security_state_;
268 DefaultCTPolicyEnforcer ct_policy_enforcer_;
269 QuicContext quic_context_;
270 MockClientSocketFactory socket_factory_;
271 const std::unique_ptr<HttpNetworkSession> session_;
272 CloseResultWaiter result_waiter_;
273 const raw_ptr<MockHttpStream> mock_stream_; // Owned by |drainer_|.
274 std::unique_ptr<HttpResponseBodyDrainer> drainer_;
275 };
276
// A single synchronous chunk: the stream is expected to be closed with
// not_reusable == false.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncSingleOK) {
  mock_stream_->set_sync();
  mock_stream_->set_num_chunks(1);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_FALSE(not_reusable);
}
283
// Several synchronous chunks: the stream is expected to be closed with
// not_reusable == false.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
  mock_stream_->set_sync();
  mock_stream_->set_num_chunks(3);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_FALSE(not_reusable);
}
290
// Several chunks delivered from posted tasks: the stream is expected to be
// closed with not_reusable == false.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
  mock_stream_->set_num_chunks(3);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_FALSE(not_reusable);
}
296
297 // Test the case when the final chunk is 0 bytes. This can happen when
298 // the final 0-byte chunk of a chunk-encoded http response is read in a last
299 // call to ReadResponseBody, after all data were returned from HttpStream.
TEST_F(HttpResponseBodyDrainerTest,DrainBodyAsyncEmptyChunk)300 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncEmptyChunk) {
301 mock_stream_->set_num_chunks(4);
302 mock_stream_->set_is_last_chunk_zero_size();
303 session_->StartResponseDrainer(std::move(drainer_));
304 EXPECT_FALSE(result_waiter_.WaitForResult());
305 }
306
// Synchronous variant of DrainBodyAsyncEmptyChunk: the last of four reads
// returns 0 bytes.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncEmptyChunk) {
  mock_stream_->set_is_last_chunk_zero_size();
  mock_stream_->set_sync();
  mock_stream_->set_num_chunks(4);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_FALSE(not_reusable);
}
314
// A body exactly as large as the drain buffer is still drained with
// not_reusable == false.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
  const int chunks_filling_buffer =
      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
  mock_stream_->set_num_chunks(chunks_filling_buffer);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_FALSE(not_reusable);
}
321
// Reads never complete, so the drainer must give up on its own (presumably via
// its timeout, per the test name) and close the stream as not reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
  mock_stream_->set_stall_reads_forever();
  mock_stream_->set_num_chunks(2);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_TRUE(not_reusable);
}
328
// Leaves a stalled drainer in the session and never waits on it. Destroying
// |session_| at fixture teardown is expected to delete the drainer (and the
// mock stream it owns) cleanly.
TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
  mock_stream_->set_stall_reads_forever();
  mock_stream_->set_num_chunks(2);
  session_->StartResponseDrainer(std::move(drainer_));
}
335
// A body one chunk larger than the drain buffer cannot be drained, so the
// stream is expected to be closed as not reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
  const int too_many_chunks =
      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize + 1;

  mock_stream_->set_num_chunks(too_many_chunks);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_TRUE(not_reusable);
}
345
// Even a fully drained body must not be reused when the stream reports that
// its connection cannot be reused.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyCantReuse) {
  mock_stream_->set_can_reuse_connection(false);
  mock_stream_->set_num_chunks(1);
  session_->StartResponseDrainer(std::move(drainer_));
  const bool not_reusable = result_waiter_.WaitForResult();
  EXPECT_TRUE(not_reusable);
}
352
353 } // namespace
354
355 } // namespace net
356