1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/http/http_response_body_drainer.h"
6
7 #include <stdint.h>
8
9 #include <cstring>
10 #include <set>
11 #include <string_view>
12 #include <utility>
13
14 #include "base/compiler_specific.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/memory/weak_ptr.h"
19 #include "base/no_destructor.h"
20 #include "base/run_loop.h"
21 #include "base/task/single_thread_task_runner.h"
22 #include "net/base/completion_once_callback.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/net_errors.h"
25 #include "net/base/test_completion_callback.h"
26 #include "net/cert/mock_cert_verifier.h"
27 #include "net/http/http_network_session.h"
28 #include "net/http/http_server_properties.h"
29 #include "net/http/http_stream.h"
30 #include "net/http/transport_security_state.h"
31 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
32 #include "net/quic/quic_context.h"
33 #include "net/socket/socket_test_util.h"
34 #include "net/ssl/ssl_config_service_defaults.h"
35 #include "net/test/test_with_task_environment.h"
36 #include "net/url_request/static_http_user_agent_settings.h"
37 #include "testing/gtest/include/gtest/gtest.h"
38
39 namespace net {
40
41 namespace {
42
43 const int kMagicChunkSize = 1024;
44 static_assert((HttpResponseBodyDrainer::kDrainBodyBufferSize %
45 kMagicChunkSize) == 0,
46 "chunk size needs to divide evenly into buffer size");
47
48 class CloseResultWaiter {
49 public:
50 CloseResultWaiter() = default;
51
52 CloseResultWaiter(const CloseResultWaiter&) = delete;
53 CloseResultWaiter& operator=(const CloseResultWaiter&) = delete;
54
WaitForResult()55 int WaitForResult() {
56 CHECK(!waiting_for_result_);
57 while (!have_result_) {
58 waiting_for_result_ = true;
59 loop_.Run();
60 waiting_for_result_ = false;
61 }
62 return result_;
63 }
64
set_result(bool result)65 void set_result(bool result) {
66 result_ = result;
67 have_result_ = true;
68 if (waiting_for_result_) {
69 loop_.Quit();
70 }
71 }
72
73 private:
74 int result_ = false;
75 bool have_result_ = false;
76 bool waiting_for_result_ = false;
77 base::RunLoop loop_;
78 };
79
80 class MockHttpStream : public HttpStream {
81 public:
MockHttpStream(CloseResultWaiter * result_waiter)82 explicit MockHttpStream(CloseResultWaiter* result_waiter)
83 : result_waiter_(result_waiter) {}
84
85 MockHttpStream(const MockHttpStream&) = delete;
86 MockHttpStream& operator=(const MockHttpStream&) = delete;
87
88 ~MockHttpStream() override = default;
89
90 // HttpStream implementation.
RegisterRequest(const HttpRequestInfo * request_info)91 void RegisterRequest(const HttpRequestInfo* request_info) override {}
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)92 int InitializeStream(bool can_send_early,
93 RequestPriority priority,
94 const NetLogWithSource& net_log,
95 CompletionOnceCallback callback) override {
96 return ERR_UNEXPECTED;
97 }
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,CompletionOnceCallback callback)98 int SendRequest(const HttpRequestHeaders& request_headers,
99 HttpResponseInfo* response,
100 CompletionOnceCallback callback) override {
101 return ERR_UNEXPECTED;
102 }
ReadResponseHeaders(CompletionOnceCallback callback)103 int ReadResponseHeaders(CompletionOnceCallback callback) override {
104 return ERR_UNEXPECTED;
105 }
106
IsConnectionReused() const107 bool IsConnectionReused() const override { return false; }
SetConnectionReused()108 void SetConnectionReused() override {}
CanReuseConnection() const109 bool CanReuseConnection() const override { return can_reuse_connection_; }
GetTotalReceivedBytes() const110 int64_t GetTotalReceivedBytes() const override { return 0; }
GetTotalSentBytes() const111 int64_t GetTotalSentBytes() const override { return 0; }
GetAlternativeService(AlternativeService * alternative_service) const112 bool GetAlternativeService(
113 AlternativeService* alternative_service) const override {
114 return false;
115 }
GetSSLInfo(SSLInfo * ssl_info)116 void GetSSLInfo(SSLInfo* ssl_info) override {}
GetRemoteEndpoint(IPEndPoint * endpoint)117 int GetRemoteEndpoint(IPEndPoint* endpoint) override {
118 return ERR_UNEXPECTED;
119 }
120
121 // Mocked API
122 int ReadResponseBody(IOBuffer* buf,
123 int buf_len,
124 CompletionOnceCallback callback) override;
Close(bool not_reusable)125 void Close(bool not_reusable) override {
126 CHECK(!closed_);
127 closed_ = true;
128 result_waiter_->set_result(not_reusable);
129 }
130
RenewStreamForAuth()131 std::unique_ptr<HttpStream> RenewStreamForAuth() override { return nullptr; }
132
IsResponseBodyComplete() const133 bool IsResponseBodyComplete() const override { return is_complete_; }
134
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const135 bool GetLoadTimingInfo(LoadTimingInfo* load_timing_info) const override {
136 return false;
137 }
138
Drain(HttpNetworkSession *)139 void Drain(HttpNetworkSession*) override {}
140
PopulateNetErrorDetails(NetErrorDetails * details)141 void PopulateNetErrorDetails(NetErrorDetails* details) override { return; }
142
SetPriority(RequestPriority priority)143 void SetPriority(RequestPriority priority) override {}
144
GetDnsAliases() const145 const std::set<std::string>& GetDnsAliases() const override {
146 static const base::NoDestructor<std::set<std::string>> nullset_result;
147 return *nullset_result;
148 }
149
GetAcceptChViaAlps() const150 std::string_view GetAcceptChViaAlps() const override { return {}; }
151
152 // Methods to tweak/observer mock behavior:
set_stall_reads_forever()153 void set_stall_reads_forever() { stall_reads_forever_ = true; }
154
set_num_chunks(int num_chunks)155 void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
156
set_sync()157 void set_sync() { is_sync_ = true; }
158
set_is_last_chunk_zero_size()159 void set_is_last_chunk_zero_size() { is_last_chunk_zero_size_ = true; }
160
161 // Sets result value of CanReuseConnection. Defaults to true.
set_can_reuse_connection(bool can_reuse_connection)162 void set_can_reuse_connection(bool can_reuse_connection) {
163 can_reuse_connection_ = can_reuse_connection;
164 }
165
SetRequestHeadersCallback(RequestHeadersCallback callback)166 void SetRequestHeadersCallback(RequestHeadersCallback callback) override {}
167
168 private:
169 int ReadResponseBodyImpl(IOBuffer* buf, int buf_len);
170 void CompleteRead();
171
closed() const172 bool closed() const { return closed_; }
173
174 const raw_ptr<CloseResultWaiter> result_waiter_;
175 scoped_refptr<IOBuffer> user_buf_;
176 CompletionOnceCallback callback_;
177 int buf_len_ = 0;
178 bool closed_ = false;
179 bool stall_reads_forever_ = false;
180 int num_chunks_ = 0;
181 bool is_sync_ = false;
182 bool is_last_chunk_zero_size_ = false;
183 bool is_complete_ = false;
184 bool can_reuse_connection_ = true;
185
186 base::WeakPtrFactory<MockHttpStream> weak_factory_{this};
187 };
188
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)189 int MockHttpStream::ReadResponseBody(IOBuffer* buf,
190 int buf_len,
191 CompletionOnceCallback callback) {
192 CHECK(!callback.is_null());
193 CHECK(callback_.is_null());
194 CHECK(buf);
195
196 if (stall_reads_forever_)
197 return ERR_IO_PENDING;
198
199 if (is_complete_)
200 return ERR_UNEXPECTED;
201
202 if (!is_sync_) {
203 user_buf_ = buf;
204 buf_len_ = buf_len;
205 callback_ = std::move(callback);
206 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
207 FROM_HERE, base::BindOnce(&MockHttpStream::CompleteRead,
208 weak_factory_.GetWeakPtr()));
209 return ERR_IO_PENDING;
210 } else {
211 return ReadResponseBodyImpl(buf, buf_len);
212 }
213 }
214
ReadResponseBodyImpl(IOBuffer * buf,int buf_len)215 int MockHttpStream::ReadResponseBodyImpl(IOBuffer* buf, int buf_len) {
216 if (is_last_chunk_zero_size_ && num_chunks_ == 1) {
217 buf_len = 0;
218 } else {
219 if (buf_len > kMagicChunkSize)
220 buf_len = kMagicChunkSize;
221 std::memset(buf->data(), 1, buf_len);
222 }
223 num_chunks_--;
224 if (!num_chunks_)
225 is_complete_ = true;
226
227 return buf_len;
228 }
229
CompleteRead()230 void MockHttpStream::CompleteRead() {
231 int result = ReadResponseBodyImpl(user_buf_.get(), buf_len_);
232 user_buf_ = nullptr;
233 std::move(callback_).Run(result);
234 }
235
236 class HttpResponseBodyDrainerTest : public TestWithTaskEnvironment {
237 protected:
HttpResponseBodyDrainerTest()238 HttpResponseBodyDrainerTest()
239 : proxy_resolution_service_(
240 ConfiguredProxyResolutionService::CreateDirect()),
241 ssl_config_service_(std::make_unique<SSLConfigServiceDefaults>()),
242 http_server_properties_(std::make_unique<HttpServerProperties>()),
243 session_(CreateNetworkSession()),
244 mock_stream_(new MockHttpStream(&result_waiter_)) {
245 drainer_ = std::make_unique<HttpResponseBodyDrainer>(mock_stream_);
246 }
247
248 ~HttpResponseBodyDrainerTest() override = default;
249
CreateNetworkSession()250 std::unique_ptr<HttpNetworkSession> CreateNetworkSession() {
251 HttpNetworkSessionContext context;
252 context.client_socket_factory = &socket_factory_;
253 context.proxy_resolution_service = proxy_resolution_service_.get();
254 context.ssl_config_service = ssl_config_service_.get();
255 context.http_user_agent_settings = &http_user_agent_settings_;
256 context.http_server_properties = http_server_properties_.get();
257 context.cert_verifier = &cert_verifier_;
258 context.transport_security_state = &transport_security_state_;
259 context.quic_context = &quic_context_;
260 return std::make_unique<HttpNetworkSession>(HttpNetworkSessionParams(),
261 context);
262 }
263
264 std::unique_ptr<ProxyResolutionService> proxy_resolution_service_;
265 std::unique_ptr<SSLConfigService> ssl_config_service_;
266 StaticHttpUserAgentSettings http_user_agent_settings_ = {"*", "test-ua"};
267 std::unique_ptr<HttpServerProperties> http_server_properties_;
268 MockCertVerifier cert_verifier_;
269 TransportSecurityState transport_security_state_;
270 QuicContext quic_context_;
271 MockClientSocketFactory socket_factory_;
272 const std::unique_ptr<HttpNetworkSession> session_;
273 CloseResultWaiter result_waiter_;
274 const raw_ptr<MockHttpStream, AcrossTasksDanglingUntriaged>
275 mock_stream_; // Owned by |drainer_|.
276 std::unique_ptr<HttpResponseBodyDrainer> drainer_;
277 };
278
TEST_F(HttpResponseBodyDrainerTest,DrainBodySyncSingleOK)279 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncSingleOK) {
280 mock_stream_->set_num_chunks(1);
281 mock_stream_->set_sync();
282 session_->StartResponseDrainer(std::move(drainer_));
283 EXPECT_FALSE(result_waiter_.WaitForResult());
284 }
285
TEST_F(HttpResponseBodyDrainerTest,DrainBodySyncOK)286 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
287 mock_stream_->set_num_chunks(3);
288 mock_stream_->set_sync();
289 session_->StartResponseDrainer(std::move(drainer_));
290 EXPECT_FALSE(result_waiter_.WaitForResult());
291 }
292
TEST_F(HttpResponseBodyDrainerTest,DrainBodyAsyncOK)293 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
294 mock_stream_->set_num_chunks(3);
295 session_->StartResponseDrainer(std::move(drainer_));
296 EXPECT_FALSE(result_waiter_.WaitForResult());
297 }
298
299 // Test the case when the final chunk is 0 bytes. This can happen when
300 // the final 0-byte chunk of a chunk-encoded http response is read in a last
301 // call to ReadResponseBody, after all data were returned from HttpStream.
TEST_F(HttpResponseBodyDrainerTest,DrainBodyAsyncEmptyChunk)302 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncEmptyChunk) {
303 mock_stream_->set_num_chunks(4);
304 mock_stream_->set_is_last_chunk_zero_size();
305 session_->StartResponseDrainer(std::move(drainer_));
306 EXPECT_FALSE(result_waiter_.WaitForResult());
307 }
308
TEST_F(HttpResponseBodyDrainerTest,DrainBodySyncEmptyChunk)309 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncEmptyChunk) {
310 mock_stream_->set_num_chunks(4);
311 mock_stream_->set_sync();
312 mock_stream_->set_is_last_chunk_zero_size();
313 session_->StartResponseDrainer(std::move(drainer_));
314 EXPECT_FALSE(result_waiter_.WaitForResult());
315 }
316
TEST_F(HttpResponseBodyDrainerTest,DrainBodySizeEqualsDrainBuffer)317 TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
318 mock_stream_->set_num_chunks(
319 HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize);
320 session_->StartResponseDrainer(std::move(drainer_));
321 EXPECT_FALSE(result_waiter_.WaitForResult());
322 }
323
TEST_F(HttpResponseBodyDrainerTest,DrainBodyTimeOut)324 TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
325 mock_stream_->set_num_chunks(2);
326 mock_stream_->set_stall_reads_forever();
327 session_->StartResponseDrainer(std::move(drainer_));
328 EXPECT_TRUE(result_waiter_.WaitForResult());
329 }
330
TEST_F(HttpResponseBodyDrainerTest,CancelledBySession)331 TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
332 mock_stream_->set_num_chunks(2);
333 mock_stream_->set_stall_reads_forever();
334 session_->StartResponseDrainer(std::move(drainer_));
335 // HttpNetworkSession should delete |drainer_|.
336 }
337
TEST_F(HttpResponseBodyDrainerTest,DrainBodyTooLarge)338 TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
339 int too_many_chunks =
340 HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
341 too_many_chunks += 1; // Now it's too large.
342
343 mock_stream_->set_num_chunks(too_many_chunks);
344 session_->StartResponseDrainer(std::move(drainer_));
345 EXPECT_TRUE(result_waiter_.WaitForResult());
346 }
347
TEST_F(HttpResponseBodyDrainerTest,DrainBodyCantReuse)348 TEST_F(HttpResponseBodyDrainerTest, DrainBodyCantReuse) {
349 mock_stream_->set_num_chunks(1);
350 mock_stream_->set_can_reuse_connection(false);
351 session_->StartResponseDrainer(std::move(drainer_));
352 EXPECT_TRUE(result_waiter_.WaitForResult());
353 }
354
355 } // namespace
356
357 } // namespace net
358