• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/server/http_server.h"
6 
7 #include <stdint.h>
8 
9 #include <algorithm>
10 #include <memory>
11 #include <unordered_map>
12 #include <utility>
13 #include <vector>
14 
15 #include "base/auto_reset.h"
16 #include "base/check_op.h"
17 #include "base/compiler_specific.h"
18 #include "base/format_macros.h"
19 #include "base/functional/bind.h"
20 #include "base/functional/callback_helpers.h"
21 #include "base/location.h"
22 #include "base/memory/ptr_util.h"
23 #include "base/memory/ref_counted.h"
24 #include "base/memory/weak_ptr.h"
25 #include "base/notreached.h"
26 #include "base/run_loop.h"
27 #include "base/strings/string_split.h"
28 #include "base/strings/string_util.h"
29 #include "base/strings/stringprintf.h"
30 #include "base/task/single_thread_task_runner.h"
31 #include "base/test/test_future.h"
32 #include "base/time/time.h"
33 #include "net/base/address_list.h"
34 #include "net/base/io_buffer.h"
35 #include "net/base/ip_endpoint.h"
36 #include "net/base/net_errors.h"
37 #include "net/base/test_completion_callback.h"
38 #include "net/http/http_response_headers.h"
39 #include "net/http/http_util.h"
40 #include "net/log/net_log_source.h"
41 #include "net/log/net_log_with_source.h"
42 #include "net/server/http_server_request_info.h"
43 #include "net/socket/tcp_client_socket.h"
44 #include "net/socket/tcp_server_socket.h"
45 #include "net/test/gtest_util.h"
46 #include "net/test/test_with_task_environment.h"
47 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
48 #include "net/websockets/websocket_frame.h"
49 #include "testing/gmock/include/gmock/gmock.h"
50 #include "testing/gtest/include/gtest/gtest.h"
51 
52 using net::test::IsOk;
53 
54 namespace net {
55 
56 namespace {
57 
// Capacity of the buffer TestHttpClient allocates for each socket Read(), and
// therefore the most bytes a single read can deliver.
constexpr int kMaxExpectedResponseLength = 2048;
59 
60 class TestHttpClient {
61  public:
62   TestHttpClient() = default;
63 
ConnectAndWait(const IPEndPoint & address)64   int ConnectAndWait(const IPEndPoint& address) {
65     AddressList addresses(address);
66     NetLogSource source;
67     socket_ = std::make_unique<TCPClientSocket>(addresses, nullptr, nullptr,
68                                                 nullptr, source);
69 
70     TestCompletionCallback callback;
71     int rv = socket_->Connect(callback.callback());
72     return callback.GetResult(rv);
73   }
74 
Send(const std::string & data)75   void Send(const std::string& data) {
76     write_buffer_ = base::MakeRefCounted<DrainableIOBuffer>(
77         base::MakeRefCounted<StringIOBuffer>(data), data.length());
78     Write();
79   }
80 
Read(std::string * message,int expected_bytes)81   bool Read(std::string* message, int expected_bytes) {
82     int total_bytes_received = 0;
83     message->clear();
84     while (total_bytes_received < expected_bytes) {
85       TestCompletionCallback callback;
86       ReadInternal(&callback);
87       int bytes_received = callback.WaitForResult();
88       if (bytes_received <= 0) {
89         return false;
90       }
91 
92       total_bytes_received += bytes_received;
93       message->append(read_buffer_->data(), bytes_received);
94     }
95     return true;
96   }
97 
ReadResponse(std::string * message)98   bool ReadResponse(std::string* message) {
99     if (!Read(message, 1)) {
100       return false;
101     }
102     while (!IsCompleteResponse(*message)) {
103       std::string chunk;
104       if (!Read(&chunk, 1)) {
105         return false;
106       }
107       message->append(chunk);
108     }
109     return true;
110   }
111 
ExpectUsedThenDisconnectedWithNoData()112   void ExpectUsedThenDisconnectedWithNoData() {
113     // Check that the socket was opened...
114     ASSERT_TRUE(socket_->WasEverUsed());
115 
116     // ...then closed when the server disconnected. Verify that the socket was
117     // closed by checking that a Read() fails.
118     std::string response;
119     ASSERT_FALSE(Read(&response, 1u));
120     ASSERT_TRUE(response.empty());
121   }
122 
socket()123   TCPClientSocket& socket() { return *socket_; }
124 
125  private:
Write()126   void Write() {
127     int result = socket_->Write(
128         write_buffer_.get(), write_buffer_->BytesRemaining(),
129         base::BindOnce(&TestHttpClient::OnWrite, base::Unretained(this)),
130         TRAFFIC_ANNOTATION_FOR_TESTS);
131     if (result != ERR_IO_PENDING) {
132       OnWrite(result);
133     }
134   }
135 
OnWrite(int result)136   void OnWrite(int result) {
137     ASSERT_GT(result, 0);
138     write_buffer_->DidConsume(result);
139     if (write_buffer_->BytesRemaining()) {
140       Write();
141     }
142   }
143 
ReadInternal(TestCompletionCallback * callback)144   void ReadInternal(TestCompletionCallback* callback) {
145     read_buffer_ =
146         base::MakeRefCounted<IOBufferWithSize>(kMaxExpectedResponseLength);
147     int result = socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength,
148                                callback->callback());
149     if (result != ERR_IO_PENDING) {
150       callback->callback().Run(result);
151     }
152   }
153 
IsCompleteResponse(const std::string & response)154   bool IsCompleteResponse(const std::string& response) {
155     // Check end of headers first.
156     size_t end_of_headers =
157         HttpUtil::LocateEndOfHeaders(response.data(), response.size());
158     if (end_of_headers == std::string::npos) {
159       return false;
160     }
161 
162     // Return true if response has data equal to or more than content length.
163     int64_t body_size = static_cast<int64_t>(response.size()) - end_of_headers;
164     DCHECK_LE(0, body_size);
165     auto headers =
166         base::MakeRefCounted<HttpResponseHeaders>(HttpUtil::AssembleRawHeaders(
167             base::StringPiece(response.data(), end_of_headers)));
168     return body_size >= headers->GetContentLength();
169   }
170 
171   scoped_refptr<IOBufferWithSize> read_buffer_;
172   scoped_refptr<DrainableIOBuffer> write_buffer_;
173   std::unique_ptr<TCPClientSocket> socket_;
174 };
175 
176 struct ReceivedRequest {
177   HttpServerRequestInfo info;
178   int connection_id;
179 };
180 
181 }  // namespace
182 
183 class HttpServerTest : public TestWithTaskEnvironment,
184                        public HttpServer::Delegate {
185  public:
186   HttpServerTest() = default;
187 
SetUp()188   void SetUp() override {
189     auto server_socket =
190         std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
191     server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
192     server_ = std::make_unique<HttpServer>(std::move(server_socket), this);
193     ASSERT_THAT(server_->GetLocalAddress(&server_address_), IsOk());
194   }
195 
TearDown()196   void TearDown() override {
197     // Run the event loop some to make sure that the memory handed over to
198     // DeleteSoon gets fully freed.
199     base::RunLoop().RunUntilIdle();
200   }
201 
OnConnect(int connection_id)202   void OnConnect(int connection_id) override {
203     DCHECK(connection_map_.find(connection_id) == connection_map_.end());
204     connection_map_[connection_id] = true;
205     // This is set in CreateConnection(), which must be invoked once for every
206     // expected connection.
207     quit_on_create_loop_->Quit();
208   }
209 
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)210   void OnHttpRequest(int connection_id,
211                      const HttpServerRequestInfo& info) override {
212     received_request_.SetValue({.info = info, .connection_id = connection_id});
213   }
214 
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)215   void OnWebSocketRequest(int connection_id,
216                           const HttpServerRequestInfo& info) override {
217     NOTREACHED();
218   }
219 
OnWebSocketMessage(int connection_id,std::string data)220   void OnWebSocketMessage(int connection_id, std::string data) override {
221     NOTREACHED();
222   }
223 
OnClose(int connection_id)224   void OnClose(int connection_id) override {
225     DCHECK(connection_map_.find(connection_id) != connection_map_.end());
226     connection_map_[connection_id] = false;
227     if (connection_id == quit_on_close_connection_) {
228       std::move(run_loop_quit_func_).Run();
229     }
230   }
231 
WaitForRequest()232   ReceivedRequest WaitForRequest() { return received_request_.Take(); }
233 
HasRequest() const234   bool HasRequest() const { return received_request_.IsReady(); }
235 
236   // Connections should only be created using this method, which waits until
237   // both the server and the client have received the connected socket.
CreateConnection(TestHttpClient * client)238   void CreateConnection(TestHttpClient* client) {
239     ASSERT_FALSE(quit_on_create_loop_);
240     quit_on_create_loop_ = std::make_unique<base::RunLoop>();
241     EXPECT_THAT(client->ConnectAndWait(server_address_), IsOk());
242     quit_on_create_loop_->Run();
243     quit_on_create_loop_.reset();
244   }
245 
RunUntilConnectionIdClosed(int connection_id)246   void RunUntilConnectionIdClosed(int connection_id) {
247     quit_on_close_connection_ = connection_id;
248     auto iter = connection_map_.find(connection_id);
249     if (iter != connection_map_.end() && !iter->second) {
250       // Already disconnected.
251       return;
252     }
253 
254     base::RunLoop run_loop;
255     base::AutoReset<base::OnceClosure> run_loop_quit_func(
256         &run_loop_quit_func_, run_loop.QuitClosure());
257     run_loop.Run();
258 
259     iter = connection_map_.find(connection_id);
260     ASSERT_TRUE(iter != connection_map_.end());
261     ASSERT_FALSE(iter->second);
262   }
263 
HandleAcceptResult(std::unique_ptr<StreamSocket> socket)264   void HandleAcceptResult(std::unique_ptr<StreamSocket> socket) {
265     ASSERT_FALSE(quit_on_create_loop_);
266     quit_on_create_loop_ = std::make_unique<base::RunLoop>();
267     server_->accepted_socket_ = std::move(socket);
268     server_->HandleAcceptResult(OK);
269     quit_on_create_loop_->Run();
270     quit_on_create_loop_.reset();
271   }
272 
connection_map()273   std::unordered_map<int, bool>& connection_map() { return connection_map_; }
274 
275  protected:
276   std::unique_ptr<HttpServer> server_;
277   IPEndPoint server_address_;
278   base::OnceClosure run_loop_quit_func_;
279   std::unordered_map<int /* connection_id */, bool /* connected */>
280       connection_map_;
281 
282  private:
283   base::test::TestFuture<ReceivedRequest> received_request_;
284   std::unique_ptr<base::RunLoop> quit_on_create_loop_;
285   int quit_on_close_connection_ = -1;
286 };
287 
288 namespace {
289 
290 class WebSocketTest : public HttpServerTest {
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)291   void OnHttpRequest(int connection_id,
292                      const HttpServerRequestInfo& info) override {
293     NOTREACHED();
294   }
295 
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)296   void OnWebSocketRequest(int connection_id,
297                           const HttpServerRequestInfo& info) override {
298     HttpServerTest::OnHttpRequest(connection_id, info);
299   }
300 
OnWebSocketMessage(int connection_id,std::string data)301   void OnWebSocketMessage(int connection_id, std::string data) override {}
302 };
303 
304 class WebSocketAcceptingTest : public WebSocketTest {
305  public:
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)306   void OnWebSocketRequest(int connection_id,
307                           const HttpServerRequestInfo& info) override {
308     HttpServerTest::OnHttpRequest(connection_id, info);
309     server_->AcceptWebSocket(connection_id, info, TRAFFIC_ANNOTATION_FOR_TESTS);
310   }
311 
OnWebSocketMessage(int connection_id,std::string data)312   void OnWebSocketMessage(int connection_id, std::string data) override {
313     last_message_.SetValue(data);
314   }
315 
GetMessage()316   std::string GetMessage() { return last_message_.Take(); }
317 
318  private:
319   base::test::TestFuture<std::string> last_message_;
320 };
321 
EncodeFrame(std::string message,WebSocketFrameHeader::OpCodeEnum op_code,bool mask,bool finish)322 std::string EncodeFrame(std::string message,
323                         WebSocketFrameHeader::OpCodeEnum op_code,
324                         bool mask,
325                         bool finish) {
326   WebSocketFrameHeader header(op_code);
327   header.final = finish;
328   header.masked = mask;
329   header.payload_length = message.size();
330   const int header_size = GetWebSocketFrameHeaderSize(header);
331   std::string frame_header;
332   frame_header.resize(header_size);
333   if (mask) {
334     WebSocketMaskingKey masking_key = GenerateWebSocketMaskingKey();
335     WriteWebSocketFrameHeader(header, &masking_key, &frame_header[0],
336                               header_size);
337     MaskWebSocketFramePayload(masking_key, 0, &message[0], message.size());
338   } else {
339     WriteWebSocketFrameHeader(header, nullptr, &frame_header[0], header_size);
340   }
341   return frame_header + message;
342 }
343 
// A bare GET with no headers or body is parsed into method and path. The
// peer is reported as the loopback client.
TEST_F(HttpServerTest, Request) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r\n");

  const ReceivedRequest received = WaitForRequest();
  const HttpServerRequestInfo& info = received.info;
  ASSERT_EQ("GET", info.method);
  ASSERT_EQ("/test", info.path);
  ASSERT_EQ("", info.data);
  ASSERT_EQ(0u, info.headers.size());
  ASSERT_TRUE(info.peer.ToString().starts_with("127.0.0.1"));
}
355 
// A request whose header terminator is malformed is never delivered to the
// delegate. The server closes the connection without sending anything.
TEST_F(HttpServerTest, RequestBrokenTermination) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r)");

  // Assumes the only connection made by this fixture is given id 1.
  constexpr int kConnectionId = 1;
  RunUntilConnectionIdClosed(kConnectionId);

  EXPECT_FALSE(HasRequest());
  client.ExpectUsedThenDisconnectedWithNoData();
}
364 
// Header names are lower-cased. Whitespace around the ':' separator and after
// the value is trimmed, while embedded whitespace, colons and non-ASCII bytes
// in the value are kept.
TEST_F(HttpServerTest, RequestWithHeaders) {
  TestHttpClient client;
  CreateConnection(&client);
  // Each row is {name, separator, value}.
  const char* const kHeaders[][3] = {
      {"Header", ": ", "1"},
      {"HeaderWithNoWhitespace", ":", "1"},
      {"HeaderWithWhitespace", "   :  \t   ", "1 1 1 \t  "},
      {"HeaderWithColon", ": ", "1:1"},
      {"EmptyHeader", ":", ""},
      {"EmptyHeaderWithWhitespace", ":  \t  ", ""},
      {"HeaderWithNonASCII", ":  ", "\xf7"},
  };

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (const auto& [name, separator, value] : kHeaders) {
    request_text.append(name).append(separator).append(value).append("\r\n");
  }
  request_text.append("\r\n");
  client.Send(request_text);

  ReceivedRequest request = WaitForRequest();
  ASSERT_EQ("", request.info.data);

  for (const auto& row : kHeaders) {
    const std::string field = base::ToLowerASCII(std::string(row[0]));
    const std::string expected = row[2];
    ASSERT_EQ(1u, request.info.headers.count(field)) << field;
    ASSERT_EQ(expected, request.info.headers[field]) << row[0];
  }
}
393 
// When a header name repeats, its values are joined with ',' in arrival
// order. Every other header keeps its single value.
TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
  TestHttpClient client;
  CreateConnection(&client);
  // Each row is {name, separator, value}.
  const char* const kHeaders[][3] = {
      // clang-format off
      {"FirstHeader", ": ", "1"},
      {"DuplicateHeader", ": ", "2"},
      {"MiddleHeader", ": ", "3"},
      {"DuplicateHeader", ": ", "4"},
      {"LastHeader", ": ", "5"},
      // clang-format on
  };

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (const auto& [name, separator, value] : kHeaders) {
    request_text.append(name).append(separator).append(value).append("\r\n");
  }
  request_text.append("\r\n");
  client.Send(request_text);

  ReceivedRequest request = WaitForRequest();
  ASSERT_EQ("", request.info.data);

  for (const auto& row : kHeaders) {
    const std::string field = base::ToLowerASCII(std::string(row[0]));
    const std::string expected =
        field == "duplicateheader" ? std::string("2,4") : std::string(row[2]);
    ASSERT_EQ(1u, request.info.headers.count(field)) << field;
    ASSERT_EQ(expected, request.info.headers[field]) << row[0];
  }
}
422 
// HasHeaderValue() matches whole comma-separated tokens case-insensitively
// after trimming whitespace. It looks across duplicate headers and never
// matches an empty value.
TEST_F(HttpServerTest, HasHeaderValueTest) {
  TestHttpClient client;
  CreateConnection(&client);
  const char* const kHeaderLines[] = {
      "Header: Abcd",
      "HeaderWithNoWhitespace:E",
      "HeaderWithWhitespace   :  \t   f \t  ",
      "DuplicateHeader: g",
      "HeaderWithComma: h, i ,j",
      "DuplicateHeader: k",
      "EmptyHeader:",
      "EmptyHeaderWithWhitespace:  \t  ",
      "HeaderWithNonASCII:  \xf7",
  };

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (const char* line : kHeaderLines) {
    request_text.append(line).append("\r\n");
  }
  request_text.append("\r\n");
  client.Send(request_text);

  const ReceivedRequest request = WaitForRequest();
  ASSERT_EQ("", request.info.data);

  struct Expectation {
    const char* name;
    const char* value;
    bool present;
  };
  const Expectation kExpectations[] = {
      {"header", "abcd", true},
      {"header", "bc", false},
      {"headerwithnowhitespace", "e", true},
      {"headerwithwhitespace", "f", true},
      {"duplicateheader", "g", true},
      {"headerwithcomma", "h", true},
      {"headerwithcomma", "i", true},
      {"headerwithcomma", "j", true},
      {"duplicateheader", "k", true},
      {"emptyheader", "x", false},
      {"emptyheaderwithwhitespace", "x", false},
      {"headerwithnonascii", "\xf7", true},
  };
  for (const Expectation& expectation : kExpectations) {
    ASSERT_EQ(expectation.present,
              request.info.HasHeaderValue(expectation.name, expectation.value))
        << expectation.name << ": " << expectation.value;
  }
}
459 
// A request carrying a Content-Length body delivers the whole body, byte for
// byte, in HttpServerRequestInfo::data.
TEST_F(HttpServerTest, RequestWithBody) {
  TestHttpClient client;
  CreateConnection(&client);
  std::string body = "a" + std::string(1 << 10, 'b') + "c";
  client.Send(
      base::StringPrintf("GET /test HTTP/1.1\r\n"
                         "SomeHeader: 1\r\n"
                         "Content-Length: %" PRIuS "\r\n\r\n%s",
                         body.length(), body.c_str()));
  auto request = WaitForRequest();
  ASSERT_EQ(2u, request.info.headers.size());
  ASSERT_EQ(body.length(), request.info.data.length());
  // Inspect the received payload, not the local |body| (checking |body| is
  // tautological). The distinct first and last bytes catch truncation or
  // misalignment at either end.
  ASSERT_EQ('a', request.info.data.front());
  ASSERT_EQ('c', request.info.data.back());
  ASSERT_EQ(body, request.info.data);
}
475 
476 // Tests that |HttpServer::HandleReadResult| will ignore Upgrade header if value
477 // is not WebSocket.
TEST_F(HttpServerTest,UpgradeIgnored)478 TEST_F(HttpServerTest, UpgradeIgnored) {
479   TestHttpClient client;
480   CreateConnection(&client);
481   client.Send(
482       "GET /test HTTP/1.1\r\n"
483       "Upgrade: h2c\r\n"
484       "Connection: SomethingElse, Upgrade\r\n"
485       "\r\n");
486   WaitForRequest();
487 }
488 
// A well-formed WebSocket upgrade reaches OnWebSocketRequest(), which this
// fixture forwards to the request future.
TEST_F(WebSocketTest, RequestWebSocket) {
  TestHttpClient client;
  CreateConnection(&client);
  const std::string upgrade_request =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\n";
  client.Send(upgrade_request);
  WaitForRequest();
}
501 
// Bytes after the blank line that ends the upgrade request cause the server
// to drop the connection without replying.
TEST_F(WebSocketTest, RequestWebSocketTrailingJunk) {
  TestHttpClient client;
  CreateConnection(&client);
  const std::string upgrade_with_junk =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\nHello? Anyone";
  client.Send(upgrade_with_junk);

  // Assumes the only connection made by this fixture is given id 1.
  constexpr int kConnectionId = 1;
  RunUntilConnectionIdClosed(kConnectionId);
  client.ExpectUsedThenDisconnectedWithNoData();
}
515 
// An empty masked ping is answered with an empty unmasked pong.
TEST_F(WebSocketAcceptingTest, SendPingFrameWithNoMessage) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  const std::string payload;
  const std::string ping = EncodeFrame(payload, Op::kOpCodePing,
                                       /* mask= */ true, /* finish= */ true);
  const std::string expected_pong =
      EncodeFrame(payload, Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);
  client.Send(ping);

  std::string reply;
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
}
539 
// A masked ping carrying a payload is answered with an unmasked pong that
// echoes the same payload.
TEST_F(WebSocketAcceptingTest, SendPingFrameWithMessage) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  const std::string payload = "hello";
  const std::string ping = EncodeFrame(payload, Op::kOpCodePing,
                                       /* mask= */ true, /* finish= */ true);
  const std::string expected_pong =
      EncodeFrame(payload, Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);
  client.Send(ping);

  std::string reply;
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
}
563 
// An unsolicited pong from the client produces no reply. The first frame read
// back is therefore the pong that answers the ping sent afterwards.
TEST_F(WebSocketAcceptingTest, SendPongFrame) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  const std::string ping =
      EncodeFrame(/* message= */ "", Op::kOpCodePing,
                  /* mask= */ true, /* finish= */ true);
  const std::string unsolicited_pong =
      EncodeFrame(/* message= */ "", Op::kOpCodePong,
                  /* mask= */ true, /* finish= */ true);
  const std::string expected_pong =
      EncodeFrame(/* message= */ "", Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);

  client.Send(unsolicited_pong);
  client.Send(ping);

  std::string reply;
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
}
590 
// A message split into two 100000-byte fragments (text + continuation) is
// reassembled and delivered as one message, in order.
TEST_F(WebSocketAcceptingTest, SendLongTextFrame) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  constexpr int kFragmentSize = 100000;
  const std::string head(kFragmentSize, 'a');
  const std::string tail(kFragmentSize, 'b');
  const std::string head_frame = EncodeFrame(
      head, Op::kOpCodeText, /* mask= */ true, /* finish= */ false);
  const std::string tail_frame = EncodeFrame(
      tail, Op::kOpCodeContinuation, /* mask= */ true, /* finish= */ true);

  client.Send(head_frame);
  client.Send(tail_frame);

  const std::string assembled = GetMessage();
  EXPECT_EQ(assembled.size(), head.size() + tail.size());
  EXPECT_EQ(assembled, head + tail);
}
620 
// Two fragmented messages sent one after another are each reassembled and
// delivered separately, in order.
TEST_F(WebSocketAcceptingTest, SendTwoTextFrame) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  // Sends |head| as a non-final text frame followed by |tail| as the final
  // continuation frame.
  const auto send_fragmented = [&client](const std::string& head,
                                         const std::string& tail) {
    client.Send(EncodeFrame(head, Op::kOpCodeText,
                            /* mask= */ true, /* finish= */ false));
    client.Send(EncodeFrame(tail, Op::kOpCodeContinuation,
                            /* mask= */ true, /* finish= */ true));
  };

  send_fragmented("foo", "bar");
  EXPECT_EQ(GetMessage(), "foobar");

  send_fragmented("FOO", "BAR");
  EXPECT_EQ(GetMessage(), "FOOBAR");
}
665 
// Sequence: ping (answered) -> unsolicited pong (ignored) -> ping with a
// payload (answered with an echoing pong).
TEST_F(WebSocketAcceptingTest, SendPingPongFrame) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  const std::string empty_payload = "";
  const std::string first_ping =
      EncodeFrame(empty_payload, Op::kOpCodePing,
                  /* mask= */ true, /* finish= */ true);
  const std::string first_expected_pong =
      EncodeFrame(empty_payload, Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);
  const std::string unsolicited_pong =
      EncodeFrame(/* message= */ "", Op::kOpCodePong,
                  /* mask= */ true, /* finish= */ true);
  const std::string hello_payload = "hello";
  const std::string second_ping =
      EncodeFrame(hello_payload, Op::kOpCodePing,
                  /* mask= */ true, /* finish= */ true);
  const std::string second_expected_pong =
      EncodeFrame(hello_payload, Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);

  std::string reply;

  client.Send(first_ping);
  ASSERT_TRUE(client.Read(&reply, first_expected_pong.length()));
  EXPECT_EQ(reply, first_expected_pong);

  client.Send(unsolicited_pong);
  client.Send(second_ping);
  ASSERT_TRUE(client.Read(&reply, second_expected_pong.length()));
  EXPECT_EQ(reply, second_expected_pong);
}
706 
// A ping sent between the fragments of a text message is answered with a
// pong, and the fragments are still reassembled into the full message.
TEST_F(WebSocketAcceptingTest, SendTextAndPingFrame) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  const std::string head_frame = EncodeFrame(
      "foo", Op::kOpCodeText, /* mask= */ true, /* finish= */ false);
  const std::string tail_frame = EncodeFrame(
      "bar", Op::kOpCodeContinuation, /* mask= */ true, /* finish= */ true);
  const std::string ping_payload = "ping";
  const std::string ping = EncodeFrame(ping_payload, Op::kOpCodePing,
                                       /* mask= */ true, /* finish= */ true);
  const std::string expected_pong =
      EncodeFrame(ping_payload, Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);

  // Order on the wire: head_frame -> ping -> tail_frame.
  client.Send(head_frame);
  client.Send(ping);
  client.Send(tail_frame);

  std::string reply;
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
  EXPECT_EQ(GetMessage(), "foobar");
}
746 
// Same as SendTextAndPingFrame, but the interleaved ping carries the payload
// "hello", which the pong must echo.
TEST_F(WebSocketAcceptingTest, SendTextAndPingFrameWithMessage) {
  using Op = WebSocketFrameHeader::OpCodeEnum;
  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n");
  WaitForRequest();

  std::string handshake;
  ASSERT_TRUE(client.ReadResponse(&handshake));

  const std::string head_frame = EncodeFrame(
      "foo", Op::kOpCodeText, /* mask= */ true, /* finish= */ false);
  const std::string tail_frame = EncodeFrame(
      "bar", Op::kOpCodeContinuation, /* mask= */ true, /* finish= */ true);
  const std::string ping_payload = "hello";
  const std::string ping = EncodeFrame(ping_payload, Op::kOpCodePing,
                                       /* mask= */ true, /* finish= */ true);
  const std::string expected_pong =
      EncodeFrame(ping_payload, Op::kOpCodePong,
                  /* mask= */ false, /* finish= */ true);

  // Order on the wire: head_frame -> ping -> tail_frame.
  client.Send(head_frame);
  client.Send(ping);
  client.Send(tail_frame);

  std::string reply;
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
  EXPECT_EQ(GetMessage(), "foobar");
}
786 
// An unsolicited pong arriving between the two fragments of a text message is
// consumed by the server without breaking reassembly of "foo" + "bar".
TEST_F(WebSocketAcceptingTest, SendTextAndPongFrame) {
  constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  client.Send(kUpgradeRequest);
  WaitForRequest();
  std::string handshake_reply;
  ASSERT_TRUE(client.ReadResponse(&handshake_reply));

  using OpCode = WebSocketFrameHeader::OpCodeEnum;
  // First fragment: text opcode, FIN bit clear.
  const std::string head_fragment =
      EncodeFrame("foo", OpCode::kOpCodeText, /* mask= */ true,
                  /* finish= */ false);
  // Final fragment: continuation opcode, FIN bit set.
  const std::string tail_fragment =
      EncodeFrame("bar", OpCode::kOpCodeContinuation, /* mask= */ true,
                  /* finish= */ true);
  const std::string stray_pong =
      EncodeFrame("pong", OpCode::kOpCodePong, /* mask= */ true,
                  /* finish= */ true);

  // head_fragment -> stray_pong -> tail_fragment
  client.Send(head_fragment);
  client.Send(stray_pong);
  client.Send(tail_fragment);

  const std::string reassembled = GetMessage();
  EXPECT_EQ(reassembled, "foobar");
}
821 
// Two pings interleaved between the fragments of a text message are each
// answered with an unmasked pong echoing their payload, and the text message
// is still reassembled as "foo" + "bar".
TEST_F(WebSocketAcceptingTest, SendTextPingPongFrame) {
  constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  client.Send(kUpgradeRequest);
  WaitForRequest();
  std::string reply;
  ASSERT_TRUE(client.ReadResponse(&reply));

  using OpCode = WebSocketFrameHeader::OpCodeEnum;
  const std::string head_fragment =
      EncodeFrame("foo", OpCode::kOpCodeText, /* mask= */ true,
                  /* finish= */ false);
  const std::string tail_fragment =
      EncodeFrame("bar", OpCode::kOpCodeContinuation, /* mask= */ true,
                  /* finish= */ true);

  // head_fragment -> ping("hello") -> ping("HELLO") -> tail_fragment.
  // Wait for each pong before sending the next frame.
  client.Send(head_fragment);
  for (const char* payload : {"hello", "HELLO"}) {
    const std::string ping = EncodeFrame(payload, OpCode::kOpCodePing,
                                         /* mask= */ true, /* finish= */ true);
    const std::string expected_pong =
        EncodeFrame(payload, OpCode::kOpCodePong, /* mask= */ false,
                    /* finish= */ true);
    client.Send(ping);
    ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
    EXPECT_EQ(reply, expected_pong);
  }
  client.Send(tail_fragment);

  const std::string reassembled = GetMessage();
  EXPECT_EQ(reassembled, "foobar");
}
874 
// A request advertising a 1073741824-byte (1 GiB) body is answered with a 500
// whose body reports the content length as too big or unknown.
TEST_F(HttpServerTest, RequestWithTooLargeBody) {
  constexpr char kExpectedResponse[] =
      "HTTP/1.1 500 Internal Server Error\r\n"
      "Content-Length:42\r\n"
      "Content-Type:text/html\r\n\r\n"
      "request content-length too big or unknown.";

  TestHttpClient client;
  CreateConnection(&client);
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Content-Length: 1073741824\r\n\r\n");

  std::string reply;
  ASSERT_TRUE(client.ReadResponse(&reply));
  EXPECT_EQ(kExpectedResponse, reply);
}
890 
// Send200() produces a "200 OK" response whose body is the supplied content.
TEST_F(HttpServerTest, Send200) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r\n");

  const int connection_id = WaitForRequest().connection_id;
  server_->Send200(connection_id, "Response!", "text/plain",
                   TRAFFIC_ANNOTATION_FOR_TESTS);

  std::string reply;
  ASSERT_TRUE(client.ReadResponse(&reply));
  ASSERT_TRUE(reply.starts_with("HTTP/1.1 200 OK"));
  ASSERT_TRUE(reply.ends_with("Response!"));
}
904 
// Consecutive SendRaw() calls reach the client as the plain, in-order
// concatenation of their payloads.
TEST_F(HttpServerTest, SendRaw) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  const int connection_id = WaitForRequest().connection_id;

  const char* const kPieces[] = {"Raw Data ", "More Data",
                                 "Third Piece of Data"};
  std::string expected_response;
  for (const char* piece : kPieces) {
    server_->SendRaw(connection_id, piece, TRAFFIC_ANNOTATION_FOR_TESTS);
    expected_response += piece;
  }

  std::string received;
  ASSERT_TRUE(client.Read(&received, expected_response.length()));
  ASSERT_EQ(expected_response, received);
}
922 
// Each malformed or non-HTTP/1.1 request line below must cause the server to
// drop the connection without sending data or dispatching a request. The
// delegate's connection map must then hold a single entry whose value is false.
TEST_F(HttpServerTest, WrongProtocolRequest) {
  const char* const kBadProtocolRequests[] = {
      "GET /test HTTP/1.0\r\n\r\n",
      "GET /test foo\r\n\r\n",
      "GET /test \r\n\r\n",
  };

  for (const char* bad_request : kBadProtocolRequests) {
    // Identify the offending input in any failure reported in this iteration.
    SCOPED_TRACE(bad_request);

    TestHttpClient client;
    CreateConnection(&client);

    client.Send(bad_request);
    client.ExpectUsedThenDisconnectedWithNoData();

    // Assert that the delegate was updated properly.
    ASSERT_EQ(1u, connection_map().size());
    ASSERT_FALSE(connection_map().begin()->second);
    EXPECT_FALSE(HasRequest());

    // Reset the state of the connection map so the next iteration starts clean.
    connection_map().clear();
  }
}
946 
947 class MockStreamSocket : public StreamSocket {
948  public:
949   MockStreamSocket() = default;
950 
951   MockStreamSocket(const MockStreamSocket&) = delete;
952   MockStreamSocket& operator=(const MockStreamSocket&) = delete;
953 
954   ~MockStreamSocket() override = default;
955 
956   // StreamSocket
Connect(CompletionOnceCallback callback)957   int Connect(CompletionOnceCallback callback) override {
958     return ERR_NOT_IMPLEMENTED;
959   }
Disconnect()960   void Disconnect() override {
961     connected_ = false;
962     if (!read_callback_.is_null()) {
963       read_buf_ = nullptr;
964       read_buf_len_ = 0;
965       std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
966     }
967   }
IsConnected() const968   bool IsConnected() const override { return connected_; }
IsConnectedAndIdle() const969   bool IsConnectedAndIdle() const override { return IsConnected(); }
GetPeerAddress(IPEndPoint * address) const970   int GetPeerAddress(IPEndPoint* address) const override {
971     return ERR_NOT_IMPLEMENTED;
972   }
GetLocalAddress(IPEndPoint * address) const973   int GetLocalAddress(IPEndPoint* address) const override {
974     return ERR_NOT_IMPLEMENTED;
975   }
NetLog() const976   const NetLogWithSource& NetLog() const override { return net_log_; }
WasEverUsed() const977   bool WasEverUsed() const override { return true; }
GetNegotiatedProtocol() const978   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)979   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const980   int64_t GetTotalReceivedBytes() const override {
981     NOTIMPLEMENTED();
982     return 0;
983   }
ApplySocketTag(const SocketTag & tag)984   void ApplySocketTag(const SocketTag& tag) override {}
985 
986   // Socket
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)987   int Read(IOBuffer* buf,
988            int buf_len,
989            CompletionOnceCallback callback) override {
990     if (!connected_) {
991       return ERR_SOCKET_NOT_CONNECTED;
992     }
993     if (pending_read_data_.empty()) {
994       read_buf_ = buf;
995       read_buf_len_ = buf_len;
996       read_callback_ = std::move(callback);
997       return ERR_IO_PENDING;
998     }
999     DCHECK_GT(buf_len, 0);
1000     int read_len =
1001         std::min(static_cast<int>(pending_read_data_.size()), buf_len);
1002     memcpy(buf->data(), pending_read_data_.data(), read_len);
1003     pending_read_data_.erase(0, read_len);
1004     return read_len;
1005   }
1006 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1007   int Write(IOBuffer* buf,
1008             int buf_len,
1009             CompletionOnceCallback callback,
1010             const NetworkTrafficAnnotationTag& traffic_annotation) override {
1011     return ERR_NOT_IMPLEMENTED;
1012   }
SetReceiveBufferSize(int32_t size)1013   int SetReceiveBufferSize(int32_t size) override {
1014     return ERR_NOT_IMPLEMENTED;
1015   }
SetSendBufferSize(int32_t size)1016   int SetSendBufferSize(int32_t size) override { return ERR_NOT_IMPLEMENTED; }
1017 
DidRead(const char * data,int data_len)1018   void DidRead(const char* data, int data_len) {
1019     if (!read_buf_.get()) {
1020       pending_read_data_.append(data, data_len);
1021       return;
1022     }
1023     int read_len = std::min(data_len, read_buf_len_);
1024     memcpy(read_buf_->data(), data, read_len);
1025     pending_read_data_.assign(data + read_len, data_len - read_len);
1026     read_buf_ = nullptr;
1027     read_buf_len_ = 0;
1028     std::move(read_callback_).Run(read_len);
1029   }
1030 
1031  private:
1032   bool connected_ = true;
1033   scoped_refptr<IOBuffer> read_buf_;
1034   int read_buf_len_ = 0;
1035   CompletionOnceCallback read_callback_;
1036   std::string pending_read_data_;
1037   NetLogWithSource net_log_;
1038 };
1039 
// A request whose body arrives in two reads is dispatched only after the
// second read completes the body.
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
  auto mock_socket = std::make_unique<MockStreamSocket>();
  MockStreamSocket* const raw_socket = mock_socket.get();
  HandleAcceptResult(std::move(mock_socket));

  const std::string payload("body");
  const std::string full_request = base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      payload.length(), payload.c_str());

  // Everything except the last two body bytes: no request yet.
  const size_t split_at = full_request.length() - 2;
  raw_socket->DidRead(full_request.c_str(), split_at);
  ASSERT_FALSE(HasRequest());

  // The final two bytes complete the body.
  raw_socket->DidRead(full_request.c_str() + split_at, 2);
  ASSERT_TRUE(HasRequest());
  ASSERT_EQ(payload, WaitForRequest().info.data);
}
1056 
// Three requests on one connection, with and without bodies, must each parse
// correctly and must all report the same connection id.
TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
  TestHttpClient client;
  CreateConnection(&client);

  // Request 1: carries a body; answered with a 200.
  const std::string payload = "body";
  client.Send(base::StringPrintf("GET /test HTTP/1.1\r\n"
                                 "Content-Length: %" PRIuS "\r\n\r\n%s",
                                 payload.length(), payload.c_str()));
  auto with_body = WaitForRequest();
  ASSERT_EQ(payload, with_body.info.data);

  const int conn_id = with_body.connection_id;
  server_->Send200(conn_id, "Content for /test", "text/plain",
                   TRAFFIC_ANNOTATION_FOR_TESTS);
  std::string ok_reply;
  ASSERT_TRUE(client.ReadResponse(&ok_reply));
  ASSERT_TRUE(ok_reply.starts_with("HTTP/1.1 200 OK"));
  ASSERT_TRUE(ok_reply.ends_with("Content for /test"));

  // Request 2: no body; answered with a 404.
  client.Send("GET /test2 HTTP/1.1\r\n\r\n");
  auto without_body = WaitForRequest();
  ASSERT_EQ("/test2", without_body.info.path);
  ASSERT_EQ(conn_id, without_body.connection_id);
  server_->Send404(conn_id, TRAFFIC_ANNOTATION_FOR_TESTS);
  std::string not_found_reply;
  ASSERT_TRUE(client.ReadResponse(&not_found_reply));
  ASSERT_TRUE(not_found_reply.starts_with("HTTP/1.1 404 Not Found"));

  // Request 3: no body; answered with a 200 again.
  client.Send("GET /test3 HTTP/1.1\r\n\r\n");
  auto last = WaitForRequest();
  ASSERT_EQ("/test3", last.info.path);
  ASSERT_EQ(conn_id, last.connection_id);
  server_->Send200(conn_id, "Content for /test3", "text/plain",
                   TRAFFIC_ANNOTATION_FOR_TESTS);
  std::string final_reply;
  ASSERT_TRUE(client.ReadResponse(&final_reply));
  ASSERT_TRUE(final_reply.starts_with("HTTP/1.1 200 OK"));
  ASSERT_TRUE(final_reply.ends_with("Content for /test3"));
}
1100 
1101 class CloseOnConnectHttpServerTest : public HttpServerTest {
1102  public:
OnConnect(int connection_id)1103   void OnConnect(int connection_id) override {
1104     HttpServerTest::OnConnect(connection_id);
1105     connection_ids_.push_back(connection_id);
1106     server_->Close(connection_id);
1107   }
1108 
1109  protected:
1110   std::vector<int> connection_ids_;
1111 };
1112 
// A connection that the delegate closes in OnConnect() is torn down without a
// response, and no request is ever dispatched for it.
TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET / HTTP/1.1\r\n\r\n");

  // The socket must be closed without any response bytes.
  client.ExpectUsedThenDisconnectedWithNoData();

  // Flush any tasks that were posted while the connection was being closed.
  base::RunLoop().RunUntilIdle();

  EXPECT_EQ(1u, connection_ids_.size());
  // The connection was closed before anything was read from it, so
  // OnHttpRequest() must never have run.
  EXPECT_FALSE(HasRequest());
}
1129 
1130 }  // namespace
1131 
1132 }  // namespace net
1133