• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "perfetto/ext/base/http/http_server.h"
18 
19 #include <initializer_list>
20 #include <string>
21 
22 #include "perfetto/ext/base/string_utils.h"
23 #include "perfetto/ext/base/unix_socket.h"
24 #include "src/base/test/test_task_runner.h"
25 #include "test/gtest_and_gmock.h"
26 
27 namespace perfetto {
28 namespace base {
29 namespace {
30 
31 using testing::_;
32 using testing::Invoke;
33 using testing::InvokeWithoutArgs;
34 using testing::NiceMock;
35 
// Fixed TCP port for the server under test; the value itself is arbitrary.
constexpr int kTestPort{5127};
37 
38 class MockHttpHandler : public HttpRequestHandler {
39  public:
40   MOCK_METHOD1(OnHttpRequest, void(const HttpRequest&));
41   MOCK_METHOD1(OnHttpConnectionClosed, void(HttpServerConnection*));
42   MOCK_METHOD1(OnWebsocketMessage, void(const WebsocketMessage&));
43 };
44 
45 class HttpCli {
46  public:
HttpCli(TestTaskRunner * ttr)47   explicit HttpCli(TestTaskRunner* ttr) : task_runner_(ttr) {
48     sock = UnixSocketRaw::CreateMayFail(SockFamily::kInet, SockType::kStream);
49     sock.SetBlocking(true);
50     sock.Connect("127.0.0.1:" + std::to_string(kTestPort));
51   }
52 
SendHttpReq(std::initializer_list<std::string> headers,const std::string & body="")53   void SendHttpReq(std::initializer_list<std::string> headers,
54                    const std::string& body = "") {
55     for (auto& header : headers)
56       sock.SendStr(header + "\r\n");
57     if (!body.empty())
58       sock.SendStr("Content-Length: " + std::to_string(body.size()) + "\r\n");
59     sock.SendStr("\r\n");
60     sock.SendStr(body);
61   }
62 
Recv(size_t min_bytes)63   std::string Recv(size_t min_bytes) {
64     static int n = 0;
65     auto checkpoint_name = "rx_" + std::to_string(n++);
66     auto checkpoint = task_runner_->CreateCheckpoint(checkpoint_name);
67     std::string rxbuf;
68     sock.SetBlocking(false);
69     task_runner_->AddFileDescriptorWatch(sock.watch_handle(), [&] {
70       char buf[1024]{};
71       auto rsize = PERFETTO_EINTR(sock.Receive(buf, sizeof(buf)));
72       if (rsize < 0)
73         return;
74       rxbuf.append(buf, static_cast<size_t>(rsize));
75       if (rsize == 0 || (min_bytes && rxbuf.length() >= min_bytes))
76         checkpoint();
77     });
78     task_runner_->RunUntilCheckpoint(checkpoint_name);
79     task_runner_->RemoveFileDescriptorWatch(sock.watch_handle());
80     return rxbuf;
81   }
82 
RecvAndWaitConnClose()83   std::string RecvAndWaitConnClose() { return Recv(0); }
84 
85   TestTaskRunner* task_runner_;
86   UnixSocketRaw sock;
87 };
88 
89 class HttpServerTest : public ::testing::Test {
90  public:
HttpServerTest()91   HttpServerTest() : srv_(&task_runner_, &handler_) { srv_.Start(kTestPort); }
92 
93   TestTaskRunner task_runner_;
94   MockHttpHandler handler_;
95   HttpServer srv_;
96 };
97 
// Issues several GET requests, each on its own connection. Checks that the
// request line, Origin and custom headers reach the handler, and that the
// response is delivered before the server closes the connection.
TEST_F(HttpServerTest, GET) {
  constexpr int kIterations = 3;
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .Times(kIterations)
      .WillRepeatedly(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/foo/bar");
        EXPECT_EQ(req.method.ToStdString(), "GET");
        EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("42",
                  req.GetHeader("X-header").value_or("N/A").ToStdString());
        EXPECT_EQ("foo",
                  req.GetHeader("X-header2").value_or("N/A").ToStdString());
        EXPECT_FALSE(req.is_websocket_handshake);
        req.conn->SendResponseAndClose("200 OK", {}, "<html>");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(kIterations);

  // Bound by kIterations, not a literal, so the number of requests sent
  // cannot drift from the .Times() expectations above.
  for (int i = 0; i < kIterations; i++) {
    HttpCli cli(&task_runner_);
    cli.SendHttpReq(
        {
            "GET /foo/bar HTTP/1.1",        //
            "Origin: https://example.com",  //
            "X-header: 42",                 //
            "X-header2: foo",               //
        },
        "");
    EXPECT_EQ(cli.RecvAndWaitConnClose(),
              "HTTP/1.1 200 OK\r\n"
              "Content-Length: 6\r\n"
              "Connection: close\r\n"
              "\r\n<html>");
  }
}
132 
// A handler-generated error status is sent back with an empty body and the
// connection is closed.
TEST_F(HttpServerTest, GET_404) {
  HttpCli cli(&task_runner_);
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/404");
        EXPECT_EQ(req.method.ToStdString(), "GET");
        req.conn->SendResponseAndClose("404 Not Found");
      }));
  cli.SendHttpReq({"GET /404 HTTP/1.1"}, "");
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  const std::string expected_resp =
      "HTTP/1.1 404 Not Found\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_resp);
}
149 
// A POST body that deliberately contains CRLF and blank-line sequences must
// reach the handler unchanged, together with Origin and custom headers.
TEST_F(HttpServerTest, POST) {
  HttpCli cli(&task_runner_);
  const std::string post_body = "the\r\npost\nbody\r\n\r\n";

  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([&](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/rpc");
        EXPECT_EQ(req.method.ToStdString(), "POST");
        EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("foo", req.GetHeader("X-1").value_or("N/A").ToStdString());
        EXPECT_EQ(req.body.ToStdString(), post_body);
        req.conn->SendResponseAndClose("200 OK");
      }));

  cli.SendHttpReq(
      {"POST /rpc HTTP/1.1", "Origin: https://example.com", "X-1: foo"},
      post_body);
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  const std::string expected_resp =
      "HTTP/1.1 200 OK\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_resp);
}
173 
174 // An unhandled request should cause a HTTP 500.
TEST_F(HttpServerTest,Unhadled_500)175 TEST_F(HttpServerTest, Unhadled_500) {
176   HttpCli cli(&task_runner_);
177   EXPECT_CALL(handler_, OnHttpRequest(_));
178   cli.SendHttpReq({"GET /unhandled HTTP/1.1"});
179   EXPECT_CALL(handler_, OnHttpConnectionClosed(_));
180   EXPECT_EQ(cli.RecvAndWaitConnClose(),
181             "HTTP/1.1 500 Internal Server Error\r\n"
182             "Content-Length: 0\r\n"
183             "Connection: close\r\n"
184             "\r\n");
185 }
186 
187 // Send three requests within the same keepalive connection.
TEST_F(HttpServerTest,POST_Keepalive)188 TEST_F(HttpServerTest, POST_Keepalive) {
189   HttpCli cli(&task_runner_);
190   static const int kNumRequests = 3;
191   int req_num = 0;
192   EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(1);
193   EXPECT_CALL(handler_, OnHttpRequest(_))
194       .Times(3)
195       .WillRepeatedly(Invoke([&](const HttpRequest& req) {
196         EXPECT_EQ(req.uri.ToStdString(), "/" + std::to_string(req_num));
197         EXPECT_EQ(req.method.ToStdString(), "POST");
198         EXPECT_EQ(req.body.ToStdString(), "body" + std::to_string(req_num));
199         req.conn->SendResponseHeaders("200 OK");
200         if (++req_num == kNumRequests)
201           req.conn->Close();
202       }));
203 
204   for (int i = 0; i < kNumRequests; i++) {
205     auto i_str = std::to_string(i);
206     cli.SendHttpReq({"POST /" + i_str + " HTTP/1.1", "Connection: keep-alive"},
207                     "body" + i_str);
208   }
209 
210   std::string expected_response;
211   for (int i = 0; i < kNumRequests; i++) {
212     expected_response +=
213         "HTTP/1.1 200 OK\r\n"
214         "Content-Length: 0\r\n"
215         "Connection: keep-alive\r\n"
216         "\r\n";
217   }
218   EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_response);
219 }
220 
TEST_F(HttpServerTest,Websocket)221 TEST_F(HttpServerTest, Websocket) {
222   srv_.AddAllowedOrigin("http://foo.com");
223   srv_.AddAllowedOrigin("http://websocket.com");
224   for (int rep = 0; rep < 3; rep++) {
225     HttpCli cli(&task_runner_);
226     EXPECT_CALL(handler_, OnHttpRequest(_))
227         .WillOnce(Invoke([&](const HttpRequest& req) {
228           EXPECT_EQ(req.uri.ToStdString(), "/websocket");
229           EXPECT_EQ(req.method.ToStdString(), "GET");
230           EXPECT_EQ(req.origin.ToStdString(), "http://websocket.com");
231           EXPECT_TRUE(req.is_websocket_handshake);
232           req.conn->UpgradeToWebsocket(req);
233         }));
234 
235     cli.SendHttpReq({
236         "GET /websocket HTTP/1.1",                      //
237         "Origin: http://websocket.com",                 //
238         "Connection: upgrade",                          //
239         "Upgrade: websocket",                           //
240         "Sec-WebSocket-Version: 13",                    //
241         "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
242     });
243     std::string expected_resp =
244         "HTTP/1.1 101 Switching Protocols\r\n"
245         "Upgrade: websocket\r\n"
246         "Connection: Upgrade\r\n"
247         "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
248         "Access-Control-Allow-Origin: http://websocket.com\r\n"
249         "Vary: Origin\r\n"
250         "\r\n";
251     EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);
252 
253     for (int i = 0; i < 3; i++) {
254       EXPECT_CALL(handler_, OnWebsocketMessage(_))
255           .WillOnce(Invoke([i](const WebsocketMessage& msg) {
256             EXPECT_EQ(msg.data.ToStdString(), "test message");
257             StackString<6> resp("PONG%d", i);
258             msg.conn->SendWebsocketMessage(resp.c_str(), resp.len());
259           }));
260 
261       // A frame from a real tcpdump capture:
262       //   1... .... = Fin: True
263       //   .000 .... = Reserved: 0x0
264       //   .... 0001 = Opcode: Text (1)
265       //   1... .... = Mask: True
266       //   .000 1100 = Payload length: 12
267       //   Masking-Key: e17e8eb9
268       //   Masked payload: "test message"
269       cli.sock.SendStr(
270           "\x81\x8c\xe1\x7e\x8e\xb9\x95\x1b\xfd\xcd\xc1\x13\xeb\xca\x92\x1f\xe9"
271           "\xdc");
272       EXPECT_EQ(cli.Recv(2 + 5), "\x82\x05PONG" + std::to_string(i));
273     }
274 
275     cli.sock.Shutdown();
276     auto checkpoint_name = "ws_close_" + std::to_string(rep);
277     auto ws_close = task_runner_.CreateCheckpoint(checkpoint_name);
278     EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
279         .WillOnce(InvokeWithoutArgs(ws_close));
280     task_runner_.RunUntilCheckpoint(checkpoint_name);
281   }
282 }
283 
// A websocket handshake whose Origin is not in the allow-list is rejected with
// "403 Forbidden", even though several near-miss origins are allowed.
TEST_F(HttpServerTest, Websocket_OriginNotAllowed) {
  srv_.AddAllowedOrigin("http://websocket.com");
  srv_.AddAllowedOrigin("http://notallowed.commando");
  srv_.AddAllowedOrigin("http://iamnotallowed.com");
  srv_.AddAllowedOrigin("iamnotallowed.com");
  // Origins are compared in full, scheme included, so this entry does not
  // match the "http://notallowed.com" origin sent below.
  srv_.AddAllowedOrigin("notallowed.com");

  HttpCli cli(&task_runner_);
  auto on_closed = task_runner_.CreateCheckpoint("close");
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
      .WillOnce(InvokeWithoutArgs(on_closed));
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.origin.ToStdString(), "http://notallowed.com");
        EXPECT_TRUE(req.is_websocket_handshake);
        req.conn->UpgradeToWebsocket(req);
      }));

  cli.SendHttpReq({
      "GET /websocket HTTP/1.1",                      //
      "Origin: http://notallowed.com",                //
      "Connection: upgrade",                          //
      "Upgrade: websocket",                           //
      "Sec-WebSocket-Version: 13",                    //
      "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
  });
  const std::string forbidden_resp =
      "HTTP/1.1 403 Forbidden\r\n"
      "Content-Length: 18\r\n"
      "Connection: close\r\n"
      "\r\n"
      "Origin not allowed";

  EXPECT_EQ(cli.Recv(forbidden_resp.size()), forbidden_resp);
  cli.sock.Shutdown();
  task_runner_.RunUntilCheckpoint("close");
}
322 
323 }  // namespace
324 }  // namespace base
325 }  // namespace perfetto
326