1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <algorithm>
6 #include <utility>
7 #include <vector>
8
9 #include "base/bind.h"
10 #include "base/bind_helpers.h"
11 #include "base/callback_helpers.h"
12 #include "base/compiler_specific.h"
13 #include "base/format_macros.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/memory/scoped_ptr.h"
16 #include "base/memory/weak_ptr.h"
17 #include "base/message_loop/message_loop.h"
18 #include "base/message_loop/message_loop_proxy.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_split.h"
21 #include "base/strings/string_util.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/time/time.h"
24 #include "net/base/address_list.h"
25 #include "net/base/io_buffer.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/net_log.h"
29 #include "net/base/net_util.h"
30 #include "net/base/test_completion_callback.h"
31 #include "net/http/http_response_headers.h"
32 #include "net/http/http_util.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/socket/tcp_client_socket.h"
36 #include "net/socket/tcp_server_socket.h"
37 #include "net/url_request/url_fetcher.h"
38 #include "net/url_request/url_fetcher_delegate.h"
39 #include "net/url_request/url_request_context.h"
40 #include "net/url_request/url_request_context_getter.h"
41 #include "net/url_request/url_request_test_util.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43
44 namespace net {
45
46 namespace {
47
// Size in bytes (2 KiB) of the buffer TestHttpClient allocates for each
// individual socket read.
const int kMaxExpectedResponseLength = 2 * 1024;
49
SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,const base::Closure & quit_loop_func)50 void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,
51 const base::Closure& quit_loop_func) {
52 if (timed_out) {
53 *timed_out = true;
54 quit_loop_func.Run();
55 }
56 }
57
RunLoopWithTimeout(base::RunLoop * run_loop)58 bool RunLoopWithTimeout(base::RunLoop* run_loop) {
59 bool timed_out = false;
60 base::WeakPtrFactory<bool> timed_out_weak_factory(&timed_out);
61 base::MessageLoop::current()->PostDelayedTask(
62 FROM_HERE,
63 base::Bind(&SetTimedOutAndQuitLoop,
64 timed_out_weak_factory.GetWeakPtr(),
65 run_loop->QuitClosure()),
66 base::TimeDelta::FromSeconds(1));
67 run_loop->Run();
68 return !timed_out;
69 }
70
71 class TestHttpClient {
72 public:
TestHttpClient()73 TestHttpClient() : connect_result_(OK) {}
74
ConnectAndWait(const IPEndPoint & address)75 int ConnectAndWait(const IPEndPoint& address) {
76 AddressList addresses(address);
77 NetLog::Source source;
78 socket_.reset(new TCPClientSocket(addresses, NULL, source));
79
80 base::RunLoop run_loop;
81 connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
82 base::Unretained(this),
83 run_loop.QuitClosure()));
84 if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING)
85 return connect_result_;
86
87 if (!RunLoopWithTimeout(&run_loop))
88 return ERR_TIMED_OUT;
89 return connect_result_;
90 }
91
Send(const std::string & data)92 void Send(const std::string& data) {
93 write_buffer_ =
94 new DrainableIOBuffer(new StringIOBuffer(data), data.length());
95 Write();
96 }
97
Read(std::string * message,int expected_bytes)98 bool Read(std::string* message, int expected_bytes) {
99 int total_bytes_received = 0;
100 message->clear();
101 while (total_bytes_received < expected_bytes) {
102 net::TestCompletionCallback callback;
103 ReadInternal(callback.callback());
104 int bytes_received = callback.WaitForResult();
105 if (bytes_received <= 0)
106 return false;
107
108 total_bytes_received += bytes_received;
109 message->append(read_buffer_->data(), bytes_received);
110 }
111 return true;
112 }
113
ReadResponse(std::string * message)114 bool ReadResponse(std::string* message) {
115 if (!Read(message, 1))
116 return false;
117 while (!IsCompleteResponse(*message)) {
118 std::string chunk;
119 if (!Read(&chunk, 1))
120 return false;
121 message->append(chunk);
122 }
123 return true;
124 }
125
126 private:
OnConnect(const base::Closure & quit_loop,int result)127 void OnConnect(const base::Closure& quit_loop, int result) {
128 connect_result_ = result;
129 quit_loop.Run();
130 }
131
Write()132 void Write() {
133 int result = socket_->Write(
134 write_buffer_.get(),
135 write_buffer_->BytesRemaining(),
136 base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
137 if (result != ERR_IO_PENDING)
138 OnWrite(result);
139 }
140
OnWrite(int result)141 void OnWrite(int result) {
142 ASSERT_GT(result, 0);
143 write_buffer_->DidConsume(result);
144 if (write_buffer_->BytesRemaining())
145 Write();
146 }
147
ReadInternal(const net::CompletionCallback & callback)148 void ReadInternal(const net::CompletionCallback& callback) {
149 read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength);
150 int result =
151 socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback);
152 if (result != ERR_IO_PENDING)
153 callback.Run(result);
154 }
155
IsCompleteResponse(const std::string & response)156 bool IsCompleteResponse(const std::string& response) {
157 // Check end of headers first.
158 int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(),
159 response.size());
160 if (end_of_headers < 0)
161 return false;
162
163 // Return true if response has data equal to or more than content length.
164 int64 body_size = static_cast<int64>(response.size()) - end_of_headers;
165 DCHECK_LE(0, body_size);
166 scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders(
167 HttpUtil::AssembleRawHeaders(response.data(), end_of_headers)));
168 return body_size >= headers->GetContentLength();
169 }
170
171 scoped_refptr<IOBufferWithSize> read_buffer_;
172 scoped_refptr<DrainableIOBuffer> write_buffer_;
173 scoped_ptr<TCPClientSocket> socket_;
174 int connect_result_;
175 };
176
177 } // namespace
178
179 class HttpServerTest : public testing::Test,
180 public HttpServer::Delegate {
181 public:
HttpServerTest()182 HttpServerTest() : quit_after_request_count_(0) {}
183
SetUp()184 virtual void SetUp() OVERRIDE {
185 scoped_ptr<ServerSocket> server_socket(
186 new TCPServerSocket(NULL, net::NetLog::Source()));
187 server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
188 server_.reset(new HttpServer(server_socket.Pass(), this));
189 ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
190 }
191
OnConnect(int connection_id)192 virtual void OnConnect(int connection_id) OVERRIDE {}
193
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)194 virtual void OnHttpRequest(int connection_id,
195 const HttpServerRequestInfo& info) OVERRIDE {
196 requests_.push_back(std::make_pair(info, connection_id));
197 if (requests_.size() == quit_after_request_count_)
198 run_loop_quit_func_.Run();
199 }
200
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)201 virtual void OnWebSocketRequest(int connection_id,
202 const HttpServerRequestInfo& info) OVERRIDE {
203 NOTREACHED();
204 }
205
OnWebSocketMessage(int connection_id,const std::string & data)206 virtual void OnWebSocketMessage(int connection_id,
207 const std::string& data) OVERRIDE {
208 NOTREACHED();
209 }
210
OnClose(int connection_id)211 virtual void OnClose(int connection_id) OVERRIDE {}
212
RunUntilRequestsReceived(size_t count)213 bool RunUntilRequestsReceived(size_t count) {
214 quit_after_request_count_ = count;
215 if (requests_.size() == count)
216 return true;
217
218 base::RunLoop run_loop;
219 run_loop_quit_func_ = run_loop.QuitClosure();
220 bool success = RunLoopWithTimeout(&run_loop);
221 run_loop_quit_func_.Reset();
222 return success;
223 }
224
GetRequest(size_t request_index)225 HttpServerRequestInfo GetRequest(size_t request_index) {
226 return requests_[request_index].first;
227 }
228
GetConnectionId(size_t request_index)229 int GetConnectionId(size_t request_index) {
230 return requests_[request_index].second;
231 }
232
HandleAcceptResult(scoped_ptr<StreamSocket> socket)233 void HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
234 server_->accepted_socket_.reset(socket.release());
235 server_->HandleAcceptResult(OK);
236 }
237
238 protected:
239 scoped_ptr<HttpServer> server_;
240 IPEndPoint server_address_;
241 base::Closure run_loop_quit_func_;
242 std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
243
244 private:
245 size_t quit_after_request_count_;
246 };
247
248 namespace {
249
250 class WebSocketTest : public HttpServerTest {
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)251 virtual void OnHttpRequest(int connection_id,
252 const HttpServerRequestInfo& info) OVERRIDE {
253 NOTREACHED();
254 }
255
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)256 virtual void OnWebSocketRequest(int connection_id,
257 const HttpServerRequestInfo& info) OVERRIDE {
258 HttpServerTest::OnHttpRequest(connection_id, info);
259 }
260
OnWebSocketMessage(int connection_id,const std::string & data)261 virtual void OnWebSocketMessage(int connection_id,
262 const std::string& data) OVERRIDE {
263 }
264 };
265
// A bare GET with no headers or body is parsed into method/path. The peer
// address is the loopback interface.
TEST_F(HttpServerTest, Request) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("GET", request.method);
  ASSERT_EQ("/test", request.path);
  ASSERT_EQ("", request.data);
  ASSERT_EQ(0u, request.headers.size());
  ASSERT_TRUE(StartsWithASCII(request.peer.ToString(), "127.0.0.1", true));
}
279
// Header fields arrive keyed by their lower-cased name with the exact value
// given in the table below.
TEST_F(HttpServerTest, RequestWithHeaders) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  // Each row is {field name, separator, value}. The rows cover varied
  // separators, whitespace around values, empty values, an embedded colon
  // and a non-ASCII byte.
  const char* kHeaders[][3] = {
    {"Header", ": ", "1"},
    {"HeaderWithNoWhitespace", ":", "1"},
    {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "},
    {"HeaderWithColon", ": ", "1:1"},
    {"EmptyHeader", ":", ""},
    {"EmptyHeaderWithWhitespace", ": \t ", ""},
    {"HeaderWithNonASCII", ": ", "\xf7"},
  };
  const size_t kNumHeaders = arraysize(kHeaders);

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t row = 0; row < kNumHeaders; ++row) {
    request_text.append(kHeaders[row][0]);
    request_text.append(kHeaders[row][1]);
    request_text.append(kHeaders[row][2]);
    request_text.append("\r\n");
  }
  request_text.append("\r\n");

  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);

  for (size_t row = 0; row < kNumHeaders; ++row) {
    const std::string field =
        base::StringToLowerASCII(std::string(kHeaders[row][0]));
    const std::string value = kHeaders[row][2];
    ASSERT_EQ(1u, request.headers.count(field)) << field;
    ASSERT_EQ(value, request.headers[field]) << kHeaders[row][0];
  }
}
309
// A field that appears twice is folded into one entry whose values are joined
// with a comma ("2,4"). The fields around it are unaffected.
TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char* kHeaders[][3] = {
    {"FirstHeader", ": ", "1"},
    {"DuplicateHeader", ": ", "2"},
    {"MiddleHeader", ": ", "3"},
    {"DuplicateHeader", ": ", "4"},
    {"LastHeader", ": ", "5"},
  };
  const size_t kNumHeaders = arraysize(kHeaders);

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t row = 0; row < kNumHeaders; ++row) {
    request_text.append(kHeaders[row][0]);
    request_text.append(kHeaders[row][1]);
    request_text.append(kHeaders[row][2]);
    request_text.append("\r\n");
  }
  request_text.append("\r\n");

  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);

  for (size_t row = 0; row < kNumHeaders; ++row) {
    const std::string field =
        base::StringToLowerASCII(std::string(kHeaders[row][0]));
    std::string expected = kHeaders[row][2];
    if (field == "duplicateheader")
      expected = "2,4";
    ASSERT_EQ(1u, request.headers.count(field)) << field;
    ASSERT_EQ(expected, request.headers[field]) << kHeaders[row][0];
  }
}
337
// Exercises HttpServerRequestInfo::HasHeaderValue(). The assertions expect
// lookups by lower-cased field name and case-insensitive matching of whole
// comma-separated tokens (so "bc" does not match "Abcd"). They also expect
// matches across duplicated fields and no match in empty values.
TEST_F(HttpServerTest, HasHeaderValueTest) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char* kHeaders[] = {
    "Header: Abcd",
    "HeaderWithNoWhitespace:E",
    "HeaderWithWhitespace : \t f \t ",
    "DuplicateHeader: g",
    "HeaderWithComma: h, i ,j",
    "DuplicateHeader: k",
    "EmptyHeader:",
    "EmptyHeaderWithWhitespace: \t ",
    "HeaderWithNonASCII: \xf7",
  };

  std::string request_text("GET /test HTTP/1.1\r\n");
  for (size_t line = 0; line < arraysize(kHeaders); ++line)
    request_text.append(kHeaders[line]).append("\r\n");
  request_text.append("\r\n");

  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);

  ASSERT_TRUE(request.HasHeaderValue("header", "abcd"));
  ASSERT_FALSE(request.HasHeaderValue("header", "bc"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithnowhitespace", "e"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithwhitespace", "f"));
  ASSERT_TRUE(request.HasHeaderValue("duplicateheader", "g"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "h"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "i"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "j"));
  ASSERT_TRUE(request.HasHeaderValue("duplicateheader", "k"));
  ASSERT_FALSE(request.HasHeaderValue("emptyheader", "x"));
  ASSERT_FALSE(request.HasHeaderValue("emptyheaderwithwhitespace", "x"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithnonascii", "\xf7"));
}
374
// A request carrying a Content-Length body delivers that body intact in
// HttpServerRequestInfo::data.
TEST_F(HttpServerTest, RequestWithBody) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  // Distinct first and last bytes make truncation at either end visible.
  std::string body = "a" + std::string(1 << 10, 'b') + "c";
  client.Send(base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str()));
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ(2u, request.headers.size());
  ASSERT_EQ(body.length(), request.data.length());
  // Inspect what the server received, not the local |body|. The previous
  // checks on |body| itself could never fail.
  ASSERT_EQ('a', request.data[0]);
  ASSERT_EQ('c', *request.data.rbegin());
  ASSERT_EQ(body, request.data);
}
391
// An upgrade handshake is routed to OnWebSocketRequest(). WebSocketTest
// records it like an HTTP request, so it shows up in the request count.
TEST_F(WebSocketTest, RequestWebSocket) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  // "Upgrade" is deliberately not the only token in the Connection header.
  const std::string handshake =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\n";
  client.Send(handshake);
  ASSERT_TRUE(RunUntilRequestsReceived(1));
}
404
// A request advertising a 1 GiB Content-Length must be answered with
// 500 Internal Server Error and must never reach OnHttpRequest().
TEST_F(HttpServerTest, RequestWithTooLargeBody) {
  class TestURLFetcherDelegate : public URLFetcherDelegate {
   public:
    // explicit: a bare Closure should not silently convert into a delegate.
    explicit TestURLFetcherDelegate(const base::Closure& quit_loop_func)
        : quit_loop_func_(quit_loop_func) {}
    virtual ~TestURLFetcherDelegate() {}

    // Checks the status code, then stops the test's run loop.
    virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE {
      EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode());
      quit_loop_func_.Run();
    }

   private:
    base::Closure quit_loop_func_;
  };

  base::RunLoop run_loop;
  TestURLFetcherDelegate delegate(run_loop.QuitClosure());

  scoped_refptr<URLRequestContextGetter> request_context_getter(
      new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
  scoped_ptr<URLFetcher> fetcher(
      URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
                                                 server_address_.port())),
                         URLFetcher::GET,
                         &delegate));
  fetcher->SetRequestContext(request_context_getter.get());
  // Announce a huge body without sending it; the server should reject the
  // request based on the header alone.
  fetcher->AddExtraRequestHeader(
      base::StringPrintf("content-length:%d", 1 << 30));
  fetcher->Start();

  ASSERT_TRUE(RunLoopWithTimeout(&run_loop));
  ASSERT_EQ(0u, requests_.size());
}
439
// HttpServer::Send200() produces a 200 status line followed by the body.
TEST_F(HttpServerTest, Send200) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const int connection_id = GetConnectionId(0);
  server_->Send200(connection_id, "Response!", "text/plain");

  std::string reply;
  ASSERT_TRUE(client.ReadResponse(&reply));
  ASSERT_TRUE(StartsWithASCII(reply, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(reply, "Response!", true));
}
452
// Successive SendRaw() calls reach the client verbatim and in order, with no
// HTTP framing added.
TEST_F(HttpServerTest, SendRaw) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const int connection_id = GetConnectionId(0);
  const char* const kPieces[] = {
    "Raw Data ",
    "More Data",
    "Third Piece of Data",
  };
  std::string expected_response;
  for (size_t piece = 0; piece < arraysize(kPieces); ++piece) {
    server_->SendRaw(connection_id, kPieces[piece]);
    expected_response += kPieces[piece];
  }

  std::string received;
  ASSERT_TRUE(client.Read(&received,
                          static_cast<int>(expected_response.length())));
  ASSERT_EQ(expected_response, received);
}
467
468 class MockStreamSocket : public StreamSocket {
469 public:
MockStreamSocket()470 MockStreamSocket()
471 : connected_(true),
472 read_buf_(NULL),
473 read_buf_len_(0) {}
474
475 // StreamSocket
Connect(const CompletionCallback & callback)476 virtual int Connect(const CompletionCallback& callback) OVERRIDE {
477 return ERR_NOT_IMPLEMENTED;
478 }
Disconnect()479 virtual void Disconnect() OVERRIDE {
480 connected_ = false;
481 if (!read_callback_.is_null()) {
482 read_buf_ = NULL;
483 read_buf_len_ = 0;
484 base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
485 }
486 }
IsConnected() const487 virtual bool IsConnected() const OVERRIDE { return connected_; }
IsConnectedAndIdle() const488 virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); }
GetPeerAddress(IPEndPoint * address) const489 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
490 return ERR_NOT_IMPLEMENTED;
491 }
GetLocalAddress(IPEndPoint * address) const492 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
493 return ERR_NOT_IMPLEMENTED;
494 }
NetLog() const495 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
SetSubresourceSpeculation()496 virtual void SetSubresourceSpeculation() OVERRIDE {}
SetOmniboxSpeculation()497 virtual void SetOmniboxSpeculation() OVERRIDE {}
WasEverUsed() const498 virtual bool WasEverUsed() const OVERRIDE { return true; }
UsingTCPFastOpen() const499 virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
WasNpnNegotiated() const500 virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
GetNegotiatedProtocol() const501 virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
502 return kProtoUnknown;
503 }
GetSSLInfo(SSLInfo * ssl_info)504 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
505
506 // Socket
Read(IOBuffer * buf,int buf_len,const CompletionCallback & callback)507 virtual int Read(IOBuffer* buf, int buf_len,
508 const CompletionCallback& callback) OVERRIDE {
509 if (!connected_) {
510 return ERR_SOCKET_NOT_CONNECTED;
511 }
512 if (pending_read_data_.empty()) {
513 read_buf_ = buf;
514 read_buf_len_ = buf_len;
515 read_callback_ = callback;
516 return ERR_IO_PENDING;
517 }
518 DCHECK_GT(buf_len, 0);
519 int read_len = std::min(static_cast<int>(pending_read_data_.size()),
520 buf_len);
521 memcpy(buf->data(), pending_read_data_.data(), read_len);
522 pending_read_data_.erase(0, read_len);
523 return read_len;
524 }
Write(IOBuffer * buf,int buf_len,const CompletionCallback & callback)525 virtual int Write(IOBuffer* buf, int buf_len,
526 const CompletionCallback& callback) OVERRIDE {
527 return ERR_NOT_IMPLEMENTED;
528 }
SetReceiveBufferSize(int32 size)529 virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
530 return ERR_NOT_IMPLEMENTED;
531 }
SetSendBufferSize(int32 size)532 virtual int SetSendBufferSize(int32 size) OVERRIDE {
533 return ERR_NOT_IMPLEMENTED;
534 }
535
DidRead(const char * data,int data_len)536 void DidRead(const char* data, int data_len) {
537 if (!read_buf_.get()) {
538 pending_read_data_.append(data, data_len);
539 return;
540 }
541 int read_len = std::min(data_len, read_buf_len_);
542 memcpy(read_buf_->data(), data, read_len);
543 pending_read_data_.assign(data + read_len, data_len - read_len);
544 read_buf_ = NULL;
545 read_buf_len_ = 0;
546 base::ResetAndReturn(&read_callback_).Run(read_len);
547 }
548
549 private:
~MockStreamSocket()550 virtual ~MockStreamSocket() {}
551
552 bool connected_;
553 scoped_refptr<IOBuffer> read_buf_;
554 int read_buf_len_;
555 CompletionCallback read_callback_;
556 std::string pending_read_data_;
557 BoundNetLog net_log_;
558
559 DISALLOW_COPY_AND_ASSIGN(MockStreamSocket);
560 };
561
// A request whose body arrives in two separate reads is only delivered once
// the final bytes are in, and then with the complete body.
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
  // HandleAcceptResult() takes ownership; the raw pointer is kept so the test
  // can inject data.
  MockStreamSocket* socket = new MockStreamSocket();
  HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket));

  std::string body("body");
  std::string request_text = base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str());

  // Everything except the last two body bytes: the request is incomplete.
  const size_t split_point = request_text.length() - 2;
  socket->DidRead(request_text.data(), split_point);
  ASSERT_EQ(0u, requests_.size());

  // The final two bytes complete it.
  socket->DidRead(request_text.data() + split_point, 2);
  ASSERT_EQ(1u, requests_.size());
  ASSERT_EQ(body, GetRequest(0).data);
}
578
// Three requests, with and without a body, are sent one after another on one
// connection. Each must be parsed and answered without the previous one
// breaking parsing of the next.
TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));

  // Request 1: carries a body, answered with 200.
  const std::string body = "body";
  client.Send(base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str()));
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  ASSERT_EQ(body, GetRequest(0).data);

  const int connection_id = GetConnectionId(0);
  server_->Send200(connection_id, "Content for /test", "text/plain");
  std::string first_reply;
  ASSERT_TRUE(client.ReadResponse(&first_reply));
  ASSERT_TRUE(StartsWithASCII(first_reply, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(first_reply, "Content for /test", true));

  // Request 2: no body, answered with 404.
  client.Send("GET /test2 HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(2));
  ASSERT_EQ("/test2", GetRequest(1).path);

  ASSERT_EQ(connection_id, GetConnectionId(1));
  server_->Send404(connection_id);
  std::string second_reply;
  ASSERT_TRUE(client.ReadResponse(&second_reply));
  ASSERT_TRUE(StartsWithASCII(second_reply, "HTTP/1.1 404 Not Found", true));

  // Request 3: no body, answered with 200.
  client.Send("GET /test3 HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(3));
  ASSERT_EQ("/test3", GetRequest(2).path);

  ASSERT_EQ(connection_id, GetConnectionId(2));
  server_->Send200(connection_id, "Content for /test3", "text/plain");
  std::string third_reply;
  ASSERT_TRUE(client.ReadResponse(&third_reply));
  ASSERT_TRUE(StartsWithASCII(third_reply, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(third_reply, "Content for /test3", true));
}
621
622 class CloseOnConnectHttpServerTest : public HttpServerTest {
623 public:
OnConnect(int connection_id)624 virtual void OnConnect(int connection_id) OVERRIDE {
625 connection_ids_.push_back(connection_id);
626 server_->Close(connection_id);
627 }
628
629 protected:
630 std::vector<int> connection_ids_;
631 };
632
// The fixture closes the connection in OnConnect(), so the request sent
// afterwards must never be delivered, and waiting for it times out.
TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET / HTTP/1.1\r\n\r\n");
  ASSERT_FALSE(RunUntilRequestsReceived(1));
  ASSERT_EQ(1u, connection_ids_.size());
  ASSERT_EQ(0u, requests_.size());
}
641
642 } // namespace
643
644 } // namespace net
645