• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <algorithm>
6 #include <utility>
7 #include <vector>
8 
9 #include "base/bind.h"
10 #include "base/bind_helpers.h"
11 #include "base/callback_helpers.h"
12 #include "base/compiler_specific.h"
13 #include "base/format_macros.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/memory/scoped_ptr.h"
16 #include "base/memory/weak_ptr.h"
17 #include "base/message_loop/message_loop.h"
18 #include "base/message_loop/message_loop_proxy.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_split.h"
21 #include "base/strings/string_util.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/time/time.h"
24 #include "net/base/address_list.h"
25 #include "net/base/io_buffer.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/net_log.h"
29 #include "net/base/net_util.h"
30 #include "net/base/test_completion_callback.h"
31 #include "net/http/http_response_headers.h"
32 #include "net/http/http_util.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/socket/tcp_client_socket.h"
36 #include "net/socket/tcp_server_socket.h"
37 #include "net/url_request/url_fetcher.h"
38 #include "net/url_request/url_fetcher_delegate.h"
39 #include "net/url_request/url_request_context.h"
40 #include "net/url_request/url_request_context_getter.h"
41 #include "net/url_request/url_request_test_util.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43 
44 namespace net {
45 
46 namespace {
47 
// Capacity of the buffer TestHttpClient allocates for every socket read, and
// therefore the largest number of bytes a single read can hand back.
const int kMaxExpectedResponseLength = 2 * 1024;
49 
SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,const base::Closure & quit_loop_func)50 void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,
51                             const base::Closure& quit_loop_func) {
52   if (timed_out) {
53     *timed_out = true;
54     quit_loop_func.Run();
55   }
56 }
57 
RunLoopWithTimeout(base::RunLoop * run_loop)58 bool RunLoopWithTimeout(base::RunLoop* run_loop) {
59   bool timed_out = false;
60   base::WeakPtrFactory<bool> timed_out_weak_factory(&timed_out);
61   base::MessageLoop::current()->PostDelayedTask(
62       FROM_HERE,
63       base::Bind(&SetTimedOutAndQuitLoop,
64                  timed_out_weak_factory.GetWeakPtr(),
65                  run_loop->QuitClosure()),
66       base::TimeDelta::FromSeconds(1));
67   run_loop->Run();
68   return !timed_out;
69 }
70 
71 class TestHttpClient {
72  public:
TestHttpClient()73   TestHttpClient() : connect_result_(OK) {}
74 
ConnectAndWait(const IPEndPoint & address)75   int ConnectAndWait(const IPEndPoint& address) {
76     AddressList addresses(address);
77     NetLog::Source source;
78     socket_.reset(new TCPClientSocket(addresses, NULL, source));
79 
80     base::RunLoop run_loop;
81     connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
82                                                   base::Unretained(this),
83                                                   run_loop.QuitClosure()));
84     if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING)
85       return connect_result_;
86 
87     if (!RunLoopWithTimeout(&run_loop))
88       return ERR_TIMED_OUT;
89     return connect_result_;
90   }
91 
Send(const std::string & data)92   void Send(const std::string& data) {
93     write_buffer_ =
94         new DrainableIOBuffer(new StringIOBuffer(data), data.length());
95     Write();
96   }
97 
Read(std::string * message,int expected_bytes)98   bool Read(std::string* message, int expected_bytes) {
99     int total_bytes_received = 0;
100     message->clear();
101     while (total_bytes_received < expected_bytes) {
102       net::TestCompletionCallback callback;
103       ReadInternal(callback.callback());
104       int bytes_received = callback.WaitForResult();
105       if (bytes_received <= 0)
106         return false;
107 
108       total_bytes_received += bytes_received;
109       message->append(read_buffer_->data(), bytes_received);
110     }
111     return true;
112   }
113 
ReadResponse(std::string * message)114   bool ReadResponse(std::string* message) {
115     if (!Read(message, 1))
116       return false;
117     while (!IsCompleteResponse(*message)) {
118       std::string chunk;
119       if (!Read(&chunk, 1))
120         return false;
121       message->append(chunk);
122     }
123     return true;
124   }
125 
126  private:
OnConnect(const base::Closure & quit_loop,int result)127   void OnConnect(const base::Closure& quit_loop, int result) {
128     connect_result_ = result;
129     quit_loop.Run();
130   }
131 
Write()132   void Write() {
133     int result = socket_->Write(
134         write_buffer_.get(),
135         write_buffer_->BytesRemaining(),
136         base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
137     if (result != ERR_IO_PENDING)
138       OnWrite(result);
139   }
140 
OnWrite(int result)141   void OnWrite(int result) {
142     ASSERT_GT(result, 0);
143     write_buffer_->DidConsume(result);
144     if (write_buffer_->BytesRemaining())
145       Write();
146   }
147 
ReadInternal(const net::CompletionCallback & callback)148   void ReadInternal(const net::CompletionCallback& callback) {
149     read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength);
150     int result =
151         socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback);
152     if (result != ERR_IO_PENDING)
153       callback.Run(result);
154   }
155 
IsCompleteResponse(const std::string & response)156   bool IsCompleteResponse(const std::string& response) {
157     // Check end of headers first.
158     int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(),
159                                                       response.size());
160     if (end_of_headers < 0)
161       return false;
162 
163     // Return true if response has data equal to or more than content length.
164     int64 body_size = static_cast<int64>(response.size()) - end_of_headers;
165     DCHECK_LE(0, body_size);
166     scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders(
167         HttpUtil::AssembleRawHeaders(response.data(), end_of_headers)));
168     return body_size >= headers->GetContentLength();
169   }
170 
171   scoped_refptr<IOBufferWithSize> read_buffer_;
172   scoped_refptr<DrainableIOBuffer> write_buffer_;
173   scoped_ptr<TCPClientSocket> socket_;
174   int connect_result_;
175 };
176 
177 }  // namespace
178 
179 class HttpServerTest : public testing::Test,
180                        public HttpServer::Delegate {
181  public:
HttpServerTest()182   HttpServerTest() : quit_after_request_count_(0) {}
183 
SetUp()184   virtual void SetUp() OVERRIDE {
185     scoped_ptr<ServerSocket> server_socket(
186         new TCPServerSocket(NULL, net::NetLog::Source()));
187     server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
188     server_.reset(new HttpServer(server_socket.Pass(), this));
189     ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
190   }
191 
OnConnect(int connection_id)192   virtual void OnConnect(int connection_id) OVERRIDE {}
193 
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)194   virtual void OnHttpRequest(int connection_id,
195                              const HttpServerRequestInfo& info) OVERRIDE {
196     requests_.push_back(std::make_pair(info, connection_id));
197     if (requests_.size() == quit_after_request_count_)
198       run_loop_quit_func_.Run();
199   }
200 
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)201   virtual void OnWebSocketRequest(int connection_id,
202                                   const HttpServerRequestInfo& info) OVERRIDE {
203     NOTREACHED();
204   }
205 
OnWebSocketMessage(int connection_id,const std::string & data)206   virtual void OnWebSocketMessage(int connection_id,
207                                   const std::string& data) OVERRIDE {
208     NOTREACHED();
209   }
210 
OnClose(int connection_id)211   virtual void OnClose(int connection_id) OVERRIDE {}
212 
RunUntilRequestsReceived(size_t count)213   bool RunUntilRequestsReceived(size_t count) {
214     quit_after_request_count_ = count;
215     if (requests_.size() == count)
216       return true;
217 
218     base::RunLoop run_loop;
219     run_loop_quit_func_ = run_loop.QuitClosure();
220     bool success = RunLoopWithTimeout(&run_loop);
221     run_loop_quit_func_.Reset();
222     return success;
223   }
224 
GetRequest(size_t request_index)225   HttpServerRequestInfo GetRequest(size_t request_index) {
226     return requests_[request_index].first;
227   }
228 
GetConnectionId(size_t request_index)229   int GetConnectionId(size_t request_index) {
230     return requests_[request_index].second;
231   }
232 
HandleAcceptResult(scoped_ptr<StreamSocket> socket)233   void HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
234     server_->accepted_socket_.reset(socket.release());
235     server_->HandleAcceptResult(OK);
236   }
237 
238  protected:
239   scoped_ptr<HttpServer> server_;
240   IPEndPoint server_address_;
241   base::Closure run_loop_quit_func_;
242   std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
243 
244  private:
245   size_t quit_after_request_count_;
246 };
247 
248 namespace {
249 
250 class WebSocketTest : public HttpServerTest {
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)251   virtual void OnHttpRequest(int connection_id,
252                              const HttpServerRequestInfo& info) OVERRIDE {
253     NOTREACHED();
254   }
255 
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)256   virtual void OnWebSocketRequest(int connection_id,
257                                   const HttpServerRequestInfo& info) OVERRIDE {
258     HttpServerTest::OnHttpRequest(connection_id, info);
259   }
260 
OnWebSocketMessage(int connection_id,const std::string & data)261   virtual void OnWebSocketMessage(int connection_id,
262                                   const std::string& data) OVERRIDE {
263   }
264 };
265 
// A bare GET with no headers or body is parsed into method and path, and
// records the loopback peer address.
TEST_F(HttpServerTest, Request) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("GET", request.method);
  ASSERT_EQ("/test", request.path);
  ASSERT_EQ("", request.data);
  ASSERT_EQ(0u, request.headers.size());
  const std::string peer = request.peer.ToString();
  ASSERT_TRUE(StartsWithASCII(peer, "127.0.0.1", true));
}
279 
// Header names are stored lower-cased. Each value is stored as given in the
// third column below, despite unusual separators, embedded colons, empty
// values and a non-ASCII byte.
TEST_F(HttpServerTest, RequestWithHeaders) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  // Each row is {name, separator, value}.
  const char* kHeaders[][3] = {
      {"Header", ": ", "1"},
      {"HeaderWithNoWhitespace", ":", "1"},
      {"HeaderWithWhitespace", "   :  \t   ", "1 1 1 \t  "},
      {"HeaderWithColon", ": ", "1:1"},
      {"EmptyHeader", ":", ""},
      {"EmptyHeaderWithWhitespace", ":  \t  ", ""},
      {"HeaderWithNonASCII", ":  ", "\xf7"},
  };
  const size_t kNumHeaders = arraysize(kHeaders);

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t row = 0; row < kNumHeaders; ++row) {
    request_text += kHeaders[row][0];
    request_text += kHeaders[row][1];
    request_text += kHeaders[row][2];
    request_text += "\r\n";
  }
  request_text += "\r\n";

  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);

  for (size_t row = 0; row < kNumHeaders; ++row) {
    const std::string name =
        base::StringToLowerASCII(std::string(kHeaders[row][0]));
    const std::string expected_value = kHeaders[row][2];
    ASSERT_EQ(1u, request.headers.count(name)) << name;
    ASSERT_EQ(expected_value, request.headers[name]) << kHeaders[row][0];
  }
}
309 
// A header that appears twice is expected to end up as one entry whose values
// are joined with a comma ("2,4"). The headers around it are unaffected.
TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  // Each row is {name, separator, value}.
  const char* kHeaders[][3] = {
      {"FirstHeader", ": ", "1"},
      {"DuplicateHeader", ": ", "2"},
      {"MiddleHeader", ": ", "3"},
      {"DuplicateHeader", ": ", "4"},
      {"LastHeader", ": ", "5"},
  };
  const size_t kNumHeaders = arraysize(kHeaders);

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t row = 0; row < kNumHeaders; ++row) {
    request_text += kHeaders[row][0];
    request_text += kHeaders[row][1];
    request_text += kHeaders[row][2];
    request_text += "\r\n";
  }
  request_text += "\r\n";

  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);

  for (size_t row = 0; row < kNumHeaders; ++row) {
    const std::string name =
        base::StringToLowerASCII(std::string(kHeaders[row][0]));
    std::string expected_value;
    if (name == "duplicateheader")
      expected_value = "2,4";
    else
      expected_value = kHeaders[row][2];
    ASSERT_EQ(1u, request.headers.count(name)) << name;
    ASSERT_EQ(expected_value, request.headers[name]) << kHeaders[row][0];
  }
}
337 
// HasHeaderValue() is expected to:
//  - match case-insensitively ("Abcd" matches "abcd");
//  - ignore surrounding whitespace;
//  - split on commas;
//  - see every value of a repeated header;
//  - never match a mere substring.
TEST_F(HttpServerTest, HasHeaderValueTest) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char* kHeaders[] = {
      "Header: Abcd",
      "HeaderWithNoWhitespace:E",
      "HeaderWithWhitespace   :  \t   f \t  ",
      "DuplicateHeader: g",
      "HeaderWithComma: h, i ,j",
      "DuplicateHeader: k",
      "EmptyHeader:",
      "EmptyHeaderWithWhitespace:  \t  ",
      "HeaderWithNonASCII:  \xf7",
  };

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t line = 0; line < arraysize(kHeaders); ++line) {
    request_text += kHeaders[line];
    request_text += "\r\n";
  }
  request_text += "\r\n";

  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);

  ASSERT_TRUE(request.HasHeaderValue("header", "abcd"));
  ASSERT_FALSE(request.HasHeaderValue("header", "bc"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithnowhitespace", "e"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithwhitespace", "f"));
  ASSERT_TRUE(request.HasHeaderValue("duplicateheader", "g"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "h"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "i"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "j"));
  ASSERT_TRUE(request.HasHeaderValue("duplicateheader", "k"));
  ASSERT_FALSE(request.HasHeaderValue("emptyheader", "x"));
  ASSERT_FALSE(request.HasHeaderValue("emptyheaderwithwhitespace", "x"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithnonascii", "\xf7"));
}
374 
// A request carrying a Content-Length body is delivered with the complete
// body attached.
TEST_F(HttpServerTest, RequestWithBody) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  std::string body = "a" + std::string(1 << 10, 'b') + "c";
  client.Send(base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str()));
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  ASSERT_EQ(2u, GetRequest(0).headers.size());
  // Inspect what the server received. The previous version checked the local
  // |body| we sent, so the first/last-byte assertions could never fail.
  const std::string received = GetRequest(0).data;
  ASSERT_EQ(body.length(), received.length());
  ASSERT_EQ('a', received[0]);
  ASSERT_EQ('c', *received.rbegin());
}
391 
// A WebSocket upgrade handshake is routed to OnWebSocketRequest(). The
// Connection header lists "Upgrade" after another token.
TEST_F(WebSocketTest, RequestWebSocket) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char kHandshake[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\n";
  client.Send(kHandshake);
  ASSERT_TRUE(RunUntilRequestsReceived(1));
}
404 
// A request that announces an oversized body is expected to be answered with
// 500 Internal Server Error, and never reported to the delegate.
TEST_F(HttpServerTest, RequestWithTooLargeBody) {
  // Checks the response code, then quits the waiting run loop.
  class TestURLFetcherDelegate : public URLFetcherDelegate {
   public:
    explicit TestURLFetcherDelegate(const base::Closure& quit_loop_func)
        : quit_loop_func_(quit_loop_func) {}
    virtual ~TestURLFetcherDelegate() {}

    virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE {
      EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode());
      quit_loop_func_.Run();
    }

   private:
    base::Closure quit_loop_func_;
  };

  base::RunLoop run_loop;
  TestURLFetcherDelegate delegate(run_loop.QuitClosure());

  scoped_refptr<URLRequestContextGetter> request_context_getter(
      new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
  scoped_ptr<URLFetcher> fetcher(
      URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
                                                 server_address_.port())),
                         URLFetcher::GET,
                         &delegate));
  fetcher->SetRequestContext(request_context_getter.get());
  // Announce a 1 GiB (1 << 30 byte) body without sending it.
  fetcher->AddExtraRequestHeader(
      base::StringPrintf("content-length:%d", 1 << 30));
  fetcher->Start();

  ASSERT_TRUE(RunLoopWithTimeout(&run_loop));
  ASSERT_EQ(0u, requests_.size());
}
439 
// Send200() produces a "200 OK" response that ends with the supplied body.
TEST_F(HttpServerTest, Send200) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const int connection_id = GetConnectionId(0);
  server_->Send200(connection_id, "Response!", "text/plain");

  std::string reply;
  ASSERT_TRUE(client.ReadResponse(&reply));
  ASSERT_TRUE(StartsWithASCII(reply, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(reply, "Response!", true));
}
452 
// Consecutive SendRaw() calls reach the client verbatim and in order, with no
// added framing.
TEST_F(HttpServerTest, SendRaw) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const char* const kPieces[] = {
      "Raw Data ",
      "More Data",
      "Third Piece of Data",
  };
  const int connection_id = GetConnectionId(0);
  for (size_t piece = 0; piece < arraysize(kPieces); ++piece)
    server_->SendRaw(connection_id, kPieces[piece]);

  const std::string expected_response("Raw Data More DataThird Piece of Data");
  std::string received;
  ASSERT_TRUE(client.Read(&received, expected_response.length()));
  ASSERT_EQ(expected_response, received);
}
467 
468 class MockStreamSocket : public StreamSocket {
469  public:
MockStreamSocket()470   MockStreamSocket()
471       : connected_(true),
472         read_buf_(NULL),
473         read_buf_len_(0) {}
474 
475   // StreamSocket
Connect(const CompletionCallback & callback)476   virtual int Connect(const CompletionCallback& callback) OVERRIDE {
477     return ERR_NOT_IMPLEMENTED;
478   }
Disconnect()479   virtual void Disconnect() OVERRIDE {
480     connected_ = false;
481     if (!read_callback_.is_null()) {
482       read_buf_ = NULL;
483       read_buf_len_ = 0;
484       base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
485     }
486   }
IsConnected() const487   virtual bool IsConnected() const OVERRIDE { return connected_; }
IsConnectedAndIdle() const488   virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); }
GetPeerAddress(IPEndPoint * address) const489   virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
490     return ERR_NOT_IMPLEMENTED;
491   }
GetLocalAddress(IPEndPoint * address) const492   virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
493     return ERR_NOT_IMPLEMENTED;
494   }
NetLog() const495   virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
SetSubresourceSpeculation()496   virtual void SetSubresourceSpeculation() OVERRIDE {}
SetOmniboxSpeculation()497   virtual void SetOmniboxSpeculation() OVERRIDE {}
WasEverUsed() const498   virtual bool WasEverUsed() const OVERRIDE { return true; }
UsingTCPFastOpen() const499   virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
WasNpnNegotiated() const500   virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
GetNegotiatedProtocol() const501   virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
502     return kProtoUnknown;
503   }
GetSSLInfo(SSLInfo * ssl_info)504   virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
505 
506   // Socket
Read(IOBuffer * buf,int buf_len,const CompletionCallback & callback)507   virtual int Read(IOBuffer* buf, int buf_len,
508                    const CompletionCallback& callback) OVERRIDE {
509     if (!connected_) {
510       return ERR_SOCKET_NOT_CONNECTED;
511     }
512     if (pending_read_data_.empty()) {
513       read_buf_ = buf;
514       read_buf_len_ = buf_len;
515       read_callback_ = callback;
516       return ERR_IO_PENDING;
517     }
518     DCHECK_GT(buf_len, 0);
519     int read_len = std::min(static_cast<int>(pending_read_data_.size()),
520                             buf_len);
521     memcpy(buf->data(), pending_read_data_.data(), read_len);
522     pending_read_data_.erase(0, read_len);
523     return read_len;
524   }
Write(IOBuffer * buf,int buf_len,const CompletionCallback & callback)525   virtual int Write(IOBuffer* buf, int buf_len,
526                     const CompletionCallback& callback) OVERRIDE {
527     return ERR_NOT_IMPLEMENTED;
528   }
SetReceiveBufferSize(int32 size)529   virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
530     return ERR_NOT_IMPLEMENTED;
531   }
SetSendBufferSize(int32 size)532   virtual int SetSendBufferSize(int32 size) OVERRIDE {
533     return ERR_NOT_IMPLEMENTED;
534   }
535 
DidRead(const char * data,int data_len)536   void DidRead(const char* data, int data_len) {
537     if (!read_buf_.get()) {
538       pending_read_data_.append(data, data_len);
539       return;
540     }
541     int read_len = std::min(data_len, read_buf_len_);
542     memcpy(read_buf_->data(), data, read_len);
543     pending_read_data_.assign(data + read_len, data_len - read_len);
544     read_buf_ = NULL;
545     read_buf_len_ = 0;
546     base::ResetAndReturn(&read_callback_).Run(read_len);
547   }
548 
549  private:
~MockStreamSocket()550   virtual ~MockStreamSocket() {}
551 
552   bool connected_;
553   scoped_refptr<IOBuffer> read_buf_;
554   int read_buf_len_;
555   CompletionCallback read_callback_;
556   std::string pending_read_data_;
557   BoundNetLog net_log_;
558 
559   DISALLOW_COPY_AND_ASSIGN(MockStreamSocket);
560 };
561 
// A request whose body arrives in two pieces is held back until the last byte
// arrives, and is then delivered with the whole body.
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
  // HandleAcceptResult() passes ownership to the server. Keep a raw pointer
  // so the test can still push data into the socket.
  MockStreamSocket* socket = new MockStreamSocket();
  HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket));

  const std::string body("body");
  const std::string request_text = base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str());

  // Withhold the final two body bytes: the request must not be reported yet.
  const size_t split_point = request_text.length() - 2;
  socket->DidRead(request_text.c_str(), split_point);
  ASSERT_EQ(0u, requests_.size());

  socket->DidRead(request_text.c_str() + split_point, 2);
  ASSERT_EQ(1u, requests_.size());
  ASSERT_EQ(body, GetRequest(0).data);
}
578 
// Requests sent back-to-back on one connection, with or without bodies, must
// each be parsed correctly. A body must not break parsing of the next request.
TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));

  // Request 1: carries a body; answered with 200.
  const std::string body = "body";
  client.Send(base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str()));
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  ASSERT_EQ(body, GetRequest(0).data);

  const int connection_id = GetConnectionId(0);
  server_->Send200(connection_id, "Content for /test", "text/plain");
  std::string first_reply;
  ASSERT_TRUE(client.ReadResponse(&first_reply));
  ASSERT_TRUE(StartsWithASCII(first_reply, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(first_reply, "Content for /test", true));

  // Request 2: no body; answered with 404.
  client.Send("GET /test2 HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(2));
  ASSERT_EQ("/test2", GetRequest(1).path);

  ASSERT_EQ(connection_id, GetConnectionId(1));
  server_->Send404(connection_id);
  std::string second_reply;
  ASSERT_TRUE(client.ReadResponse(&second_reply));
  ASSERT_TRUE(StartsWithASCII(second_reply, "HTTP/1.1 404 Not Found", true));

  // Request 3: no body; answered with 200.
  client.Send("GET /test3 HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(3));
  ASSERT_EQ("/test3", GetRequest(2).path);

  ASSERT_EQ(connection_id, GetConnectionId(2));
  server_->Send200(connection_id, "Content for /test3", "text/plain");
  std::string third_reply;
  ASSERT_TRUE(client.ReadResponse(&third_reply));
  ASSERT_TRUE(StartsWithASCII(third_reply, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(third_reply, "Content for /test3", true));
}
621 
622 class CloseOnConnectHttpServerTest : public HttpServerTest {
623  public:
OnConnect(int connection_id)624   virtual void OnConnect(int connection_id) OVERRIDE {
625     connection_ids_.push_back(connection_id);
626     server_->Close(connection_id);
627   }
628 
629  protected:
630   std::vector<int> connection_ids_;
631 };
632 
// The fixture's OnConnect() closes every connection, so the request sent here
// is never delivered and waiting for it must time out.
TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET / HTTP/1.1\r\n\r\n");
  ASSERT_FALSE(RunUntilRequestsReceived(1));
  // Use the same unsigned literal style as the other size() checks in this
  // file, instead of 'ul'.
  ASSERT_EQ(1u, connection_ids_.size());
  ASSERT_EQ(0u, requests_.size());
}
641 
642 }  // namespace
643 
644 }  // namespace net
645