1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/http/http_response_body_drainer.h"
6
7 #include <cstring>
8
9 #include "base/bind.h"
10 #include "base/compiler_specific.h"
11 #include "base/memory/weak_ptr.h"
12 #include "base/message_loop/message_loop.h"
13 #include "net/base/io_buffer.h"
14 #include "net/base/net_errors.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/http/http_network_session.h"
17 #include "net/http/http_server_properties_impl.h"
18 #include "net/http/http_stream.h"
19 #include "net/proxy/proxy_service.h"
20 #include "net/ssl/ssl_config_service_defaults.h"
21 #include "testing/gtest/include/gtest/gtest.h"
22
23 namespace net {
24
25 namespace {
26
27 const int kMagicChunkSize = 1024;
28 COMPILE_ASSERT(
29 (HttpResponseBodyDrainer::kDrainBodyBufferSize % kMagicChunkSize) == 0,
30 chunk_size_needs_to_divide_evenly_into_buffer_size);
31
32 class CloseResultWaiter {
33 public:
CloseResultWaiter()34 CloseResultWaiter()
35 : result_(false),
36 have_result_(false),
37 waiting_for_result_(false) {}
38
WaitForResult()39 int WaitForResult() {
40 CHECK(!waiting_for_result_);
41 while (!have_result_) {
42 waiting_for_result_ = true;
43 base::MessageLoop::current()->Run();
44 waiting_for_result_ = false;
45 }
46 return result_;
47 }
48
set_result(bool result)49 void set_result(bool result) {
50 result_ = result;
51 have_result_ = true;
52 if (waiting_for_result_)
53 base::MessageLoop::current()->Quit();
54 }
55
56 private:
57 int result_;
58 bool have_result_;
59 bool waiting_for_result_;
60
61 DISALLOW_COPY_AND_ASSIGN(CloseResultWaiter);
62 };
63
64 class MockHttpStream : public HttpStream {
65 public:
MockHttpStream(CloseResultWaiter * result_waiter)66 MockHttpStream(CloseResultWaiter* result_waiter)
67 : result_waiter_(result_waiter),
68 buf_len_(0),
69 closed_(false),
70 stall_reads_forever_(false),
71 num_chunks_(0),
72 is_sync_(false),
73 is_last_chunk_zero_size_(false),
74 is_complete_(false),
75 weak_factory_(this) {}
~MockHttpStream()76 virtual ~MockHttpStream() {}
77
78 // HttpStream implementation.
InitializeStream(const HttpRequestInfo * request_info,RequestPriority priority,const BoundNetLog & net_log,const CompletionCallback & callback)79 virtual int InitializeStream(const HttpRequestInfo* request_info,
80 RequestPriority priority,
81 const BoundNetLog& net_log,
82 const CompletionCallback& callback) OVERRIDE {
83 return ERR_UNEXPECTED;
84 }
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,const CompletionCallback & callback)85 virtual int SendRequest(const HttpRequestHeaders& request_headers,
86 HttpResponseInfo* response,
87 const CompletionCallback& callback) OVERRIDE {
88 return ERR_UNEXPECTED;
89 }
GetUploadProgress() const90 virtual UploadProgress GetUploadProgress() const OVERRIDE {
91 return UploadProgress();
92 }
ReadResponseHeaders(const CompletionCallback & callback)93 virtual int ReadResponseHeaders(const CompletionCallback& callback) OVERRIDE {
94 return ERR_UNEXPECTED;
95 }
GetResponseInfo() const96 virtual const HttpResponseInfo* GetResponseInfo() const OVERRIDE {
97 return NULL;
98 }
99
CanFindEndOfResponse() const100 virtual bool CanFindEndOfResponse() const OVERRIDE { return true; }
IsConnectionReused() const101 virtual bool IsConnectionReused() const OVERRIDE { return false; }
SetConnectionReused()102 virtual void SetConnectionReused() OVERRIDE {}
IsConnectionReusable() const103 virtual bool IsConnectionReusable() const OVERRIDE { return false; }
GetTotalReceivedBytes() const104 virtual int64 GetTotalReceivedBytes() const OVERRIDE { return 0; }
GetSSLInfo(SSLInfo * ssl_info)105 virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {}
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)106 virtual void GetSSLCertRequestInfo(
107 SSLCertRequestInfo* cert_request_info) OVERRIDE {}
108
109 // Mocked API
110 virtual int ReadResponseBody(IOBuffer* buf, int buf_len,
111 const CompletionCallback& callback) OVERRIDE;
Close(bool not_reusable)112 virtual void Close(bool not_reusable) OVERRIDE {
113 CHECK(!closed_);
114 closed_ = true;
115 result_waiter_->set_result(not_reusable);
116 }
117
RenewStreamForAuth()118 virtual HttpStream* RenewStreamForAuth() OVERRIDE {
119 return NULL;
120 }
121
IsResponseBodyComplete() const122 virtual bool IsResponseBodyComplete() const OVERRIDE { return is_complete_; }
123
IsSpdyHttpStream() const124 virtual bool IsSpdyHttpStream() const OVERRIDE { return false; }
125
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const126 virtual bool GetLoadTimingInfo(
127 LoadTimingInfo* load_timing_info) const OVERRIDE { return false; }
128
Drain(HttpNetworkSession *)129 virtual void Drain(HttpNetworkSession*) OVERRIDE {}
130
SetPriority(RequestPriority priority)131 virtual void SetPriority(RequestPriority priority) OVERRIDE {}
132
133 // Methods to tweak/observer mock behavior:
set_stall_reads_forever()134 void set_stall_reads_forever() { stall_reads_forever_ = true; }
135
set_num_chunks(int num_chunks)136 void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
137
set_sync()138 void set_sync() { is_sync_ = true; }
139
set_is_last_chunk_zero_size()140 void set_is_last_chunk_zero_size() { is_last_chunk_zero_size_ = true; }
141
142 private:
143 int ReadResponseBodyImpl(IOBuffer* buf, int buf_len);
144 void CompleteRead();
145
closed() const146 bool closed() const { return closed_; }
147
148 CloseResultWaiter* const result_waiter_;
149 scoped_refptr<IOBuffer> user_buf_;
150 CompletionCallback callback_;
151 int buf_len_;
152 bool closed_;
153 bool stall_reads_forever_;
154 int num_chunks_;
155 bool is_sync_;
156 bool is_last_chunk_zero_size_;
157 bool is_complete_;
158 base::WeakPtrFactory<MockHttpStream> weak_factory_;
159 };
160
ReadResponseBody(IOBuffer * buf,int buf_len,const CompletionCallback & callback)161 int MockHttpStream::ReadResponseBody(IOBuffer* buf,
162 int buf_len,
163 const CompletionCallback& callback) {
164 CHECK(!callback.is_null());
165 CHECK(callback_.is_null());
166 CHECK(buf);
167
168 if (stall_reads_forever_)
169 return ERR_IO_PENDING;
170
171 if (is_complete_)
172 return ERR_UNEXPECTED;
173
174 if (!is_sync_) {
175 user_buf_ = buf;
176 buf_len_ = buf_len;
177 callback_ = callback;
178 base::MessageLoop::current()->PostTask(
179 FROM_HERE,
180 base::Bind(&MockHttpStream::CompleteRead, weak_factory_.GetWeakPtr()));
181 return ERR_IO_PENDING;
182 } else {
183 return ReadResponseBodyImpl(buf, buf_len);
184 }
185 }
186
ReadResponseBodyImpl(IOBuffer * buf,int buf_len)187 int MockHttpStream::ReadResponseBodyImpl(IOBuffer* buf, int buf_len) {
188 if (is_last_chunk_zero_size_ && num_chunks_ == 1) {
189 buf_len = 0;
190 } else {
191 if (buf_len > kMagicChunkSize)
192 buf_len = kMagicChunkSize;
193 std::memset(buf->data(), 1, buf_len);
194 }
195 num_chunks_--;
196 if (!num_chunks_)
197 is_complete_ = true;
198
199 return buf_len;
200 }
201
CompleteRead()202 void MockHttpStream::CompleteRead() {
203 int result = ReadResponseBodyImpl(user_buf_.get(), buf_len_);
204 user_buf_ = NULL;
205 CompletionCallback callback = callback_;
206 callback_.Reset();
207 callback.Run(result);
208 }
209
210 class HttpResponseBodyDrainerTest : public testing::Test {
211 protected:
HttpResponseBodyDrainerTest()212 HttpResponseBodyDrainerTest()
213 : proxy_service_(ProxyService::CreateDirect()),
214 ssl_config_service_(new SSLConfigServiceDefaults),
215 http_server_properties_(new HttpServerPropertiesImpl()),
216 session_(CreateNetworkSession()),
217 mock_stream_(new MockHttpStream(&result_waiter_)),
218 drainer_(new HttpResponseBodyDrainer(mock_stream_)) {}
219
~HttpResponseBodyDrainerTest()220 virtual ~HttpResponseBodyDrainerTest() {}
221
CreateNetworkSession() const222 HttpNetworkSession* CreateNetworkSession() const {
223 HttpNetworkSession::Params params;
224 params.proxy_service = proxy_service_.get();
225 params.ssl_config_service = ssl_config_service_.get();
226 params.http_server_properties = http_server_properties_->GetWeakPtr();
227 return new HttpNetworkSession(params);
228 }
229
230 scoped_ptr<ProxyService> proxy_service_;
231 scoped_refptr<SSLConfigService> ssl_config_service_;
232 scoped_ptr<HttpServerPropertiesImpl> http_server_properties_;
233 const scoped_refptr<HttpNetworkSession> session_;
234 CloseResultWaiter result_waiter_;
235 MockHttpStream* const mock_stream_; // Owned by |drainer_|.
236 HttpResponseBodyDrainer* const drainer_; // Deletes itself.
237 };
238
// A single synchronously returned chunk should leave the stream reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncSingleOK) {
  mock_stream_->set_sync();
  mock_stream_->set_num_chunks(1);
  drainer_->Start(session_.get());
  EXPECT_FALSE(result_waiter_.WaitForResult());
}
245
// Several synchronously returned chunks should leave the stream reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
  mock_stream_->set_sync();
  mock_stream_->set_num_chunks(3);
  drainer_->Start(session_.get());
  EXPECT_FALSE(result_waiter_.WaitForResult());
}
252
// Several asynchronously completed chunks should leave the stream reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
  const int kChunks = 3;
  mock_stream_->set_num_chunks(kChunks);
  drainer_->Start(session_.get());
  EXPECT_FALSE(result_waiter_.WaitForResult());
}
258
259 // Test the case when the final chunk is 0 bytes. This can happen when
260 // the final 0-byte chunk of a chunk-encoded http response is read in a last
261 // call to ReadResponseBody, after all data were returned from HttpStream.
TEST_F(HttpResponseBodyDrainerTest,DrainBodyAsyncEmptyChunk)262 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncEmptyChunk) {
263 mock_stream_->set_num_chunks(4);
264 mock_stream_->set_is_last_chunk_zero_size();
265 drainer_->Start(session_.get());
266 EXPECT_FALSE(result_waiter_.WaitForResult());
267 }
268
// Synchronous variant of DrainBodyAsyncEmptyChunk.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncEmptyChunk) {
  mock_stream_->set_is_last_chunk_zero_size();
  mock_stream_->set_sync();
  mock_stream_->set_num_chunks(4);
  drainer_->Start(session_.get());
  EXPECT_FALSE(result_waiter_.WaitForResult());
}
276
// A body that exactly fills the drain buffer still drains successfully.
TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
  const int chunks_to_fill_buffer =
      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
  mock_stream_->set_num_chunks(chunks_to_fill_buffer);
  drainer_->Start(session_.get());
  EXPECT_FALSE(result_waiter_.WaitForResult());
}
283
// Reads that never complete should eventually close the stream as not
// reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
  mock_stream_->set_stall_reads_forever();
  mock_stream_->set_num_chunks(2);
  drainer_->Start(session_.get());
  EXPECT_TRUE(result_waiter_.WaitForResult());
}
290
// Leaves the drainer stalled mid-read when the fixture tears down.
TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
  mock_stream_->set_stall_reads_forever();
  mock_stream_->set_num_chunks(2);
  drainer_->Start(session_.get());
  // The HttpNetworkSession is expected to delete |drainer_| on teardown.
}
297
// One chunk more than the drain buffer holds: the stream is closed as not
// reusable.
TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
  const int too_many_chunks =
      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize + 1;
  mock_stream_->set_num_chunks(too_many_chunks);
  drainer_->Start(session_.get());
  EXPECT_TRUE(result_waiter_.WaitForResult());
}
307
// Announcing a body one chunk larger than the drain buffer via StartWithSize()
// closes the stream as not reusable.
TEST_F(HttpResponseBodyDrainerTest, StartBodyTooLarge) {
  const int too_large_size =
      (HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize + 1) *
      kMagicChunkSize;
  mock_stream_->set_num_chunks(0);
  drainer_->StartWithSize(session_.get(), too_large_size);
  EXPECT_TRUE(result_waiter_.WaitForResult());
}
317
// A zero-sized body leaves the stream reusable.
TEST_F(HttpResponseBodyDrainerTest, StartWithNothingToDo) {
  const int kEmptyBodySize = 0;
  mock_stream_->set_num_chunks(0);
  drainer_->StartWithSize(session_.get(), kEmptyBodySize);
  EXPECT_FALSE(result_waiter_.WaitForResult());
}
323
324 } // namespace
325
326 } // namespace net
327