1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "perfetto/ext/base/http/http_server.h"
18
19 #include <initializer_list>
20 #include <string>
21
22 #include "perfetto/ext/base/string_utils.h"
23 #include "perfetto/ext/base/unix_socket.h"
24 #include "src/base/test/test_task_runner.h"
25 #include "test/gtest_and_gmock.h"
26
27 namespace perfetto {
28 namespace base {
29 namespace {
30
31 using testing::_;
32 using testing::Invoke;
33 using testing::InvokeWithoutArgs;
34 using testing::NiceMock;
35
// Port used by every test in this file: the fixture passes it to
// HttpServer::Start() and HttpCli connects to it on 127.0.0.1.
constexpr int kTestPort = 5127; // Chosen with a fair dice roll.
37
38 class MockHttpHandler : public HttpRequestHandler {
39 public:
40 MOCK_METHOD1(OnHttpRequest, void(const HttpRequest&));
41 MOCK_METHOD1(OnHttpConnectionClosed, void(HttpServerConnection*));
42 MOCK_METHOD1(OnWebsocketMessage, void(const WebsocketMessage&));
43 };
44
45 class HttpCli {
46 public:
HttpCli(TestTaskRunner * ttr)47 explicit HttpCli(TestTaskRunner* ttr) : task_runner_(ttr) {
48 sock = UnixSocketRaw::CreateMayFail(SockFamily::kInet, SockType::kStream);
49 sock.SetBlocking(true);
50 sock.Connect("127.0.0.1:" + std::to_string(kTestPort));
51 }
52
SendHttpReq(std::initializer_list<std::string> headers,const std::string & body="")53 void SendHttpReq(std::initializer_list<std::string> headers,
54 const std::string& body = "") {
55 for (auto& header : headers)
56 sock.SendStr(header + "\r\n");
57 if (!body.empty())
58 sock.SendStr("Content-Length: " + std::to_string(body.size()) + "\r\n");
59 sock.SendStr("\r\n");
60 sock.SendStr(body);
61 }
62
Recv(size_t min_bytes)63 std::string Recv(size_t min_bytes) {
64 static int n = 0;
65 auto checkpoint_name = "rx_" + std::to_string(n++);
66 auto checkpoint = task_runner_->CreateCheckpoint(checkpoint_name);
67 std::string rxbuf;
68 sock.SetBlocking(false);
69 task_runner_->AddFileDescriptorWatch(sock.watch_handle(), [&] {
70 char buf[1024]{};
71 auto rsize = PERFETTO_EINTR(sock.Receive(buf, sizeof(buf)));
72 if (rsize < 0)
73 return;
74 rxbuf.append(buf, static_cast<size_t>(rsize));
75 if (rsize == 0 || (min_bytes && rxbuf.length() >= min_bytes))
76 checkpoint();
77 });
78 task_runner_->RunUntilCheckpoint(checkpoint_name);
79 task_runner_->RemoveFileDescriptorWatch(sock.watch_handle());
80 return rxbuf;
81 }
82
RecvAndWaitConnClose()83 std::string RecvAndWaitConnClose() { return Recv(0); }
84
85 TestTaskRunner* task_runner_;
86 UnixSocketRaw sock;
87 };
88
89 class HttpServerTest : public ::testing::Test {
90 public:
HttpServerTest()91 HttpServerTest() : srv_(&task_runner_, &handler_) { srv_.Start(kTestPort); }
92
93 TestTaskRunner task_runner_;
94 MockHttpHandler handler_;
95 HttpServer srv_;
96 };
97
// Issues kIterations GET requests, each on a fresh connection, and checks that
// the handler sees the parsed URI, method, origin and custom headers, and that
// the client receives the handler's reply followed by a connection close.
TEST_F(HttpServerTest, GET) {
  const int kIterations = 3;
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .Times(kIterations)
      .WillRepeatedly(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/foo/bar");
        EXPECT_EQ(req.method.ToStdString(), "GET");
        EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("42",
                  req.GetHeader("X-header").value_or("N/A").ToStdString());
        EXPECT_EQ("foo",
                  req.GetHeader("X-header2").value_or("N/A").ToStdString());
        EXPECT_FALSE(req.is_websocket_handshake);
        req.conn->SendResponseAndClose("200 OK", {}, "<html>");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(kIterations);

  // The loop bound must be the same constant used in the expectations above,
  // otherwise changing kIterations would silently desynchronize the two.
  for (int i = 0; i < kIterations; i++) {
    HttpCli cli(&task_runner_);
    cli.SendHttpReq(
        {
            "GET /foo/bar HTTP/1.1",        //
            "Origin: https://example.com",  //
            "X-header: 42",                 //
            "X-header2: foo",               //
        },
        "");
    EXPECT_EQ(cli.RecvAndWaitConnClose(),
              "HTTP/1.1 200 OK\r\n"
              "Content-Length: 6\r\n"
              "Connection: close\r\n"
              "\r\n<html>");
  }
}
132
// The handler answers with a body-less 404; the client must receive exactly
// that status with a zero Content-Length and then see the connection close.
TEST_F(HttpServerTest, GET_404) {
  HttpCli cli(&task_runner_);

  auto reply_not_found = [](const HttpRequest& req) {
    EXPECT_EQ(req.uri.ToStdString(), "/404");
    EXPECT_EQ(req.method.ToStdString(), "GET");
    req.conn->SendResponseAndClose("404 Not Found");
  };
  EXPECT_CALL(handler_, OnHttpRequest(_)).WillOnce(Invoke(reply_not_found));

  cli.SendHttpReq({"GET /404 HTTP/1.1"}, "");
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  const std::string expected_resp =
      "HTTP/1.1 404 Not Found\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_resp);
}
149
// Sends a POST whose body itself contains CR/LF sequences and checks that the
// handler receives the body unmodified along with the URI, origin and header.
TEST_F(HttpServerTest, POST) {
  HttpCli cli(&task_runner_);
  const std::string post_body = "the\r\npost\nbody\r\n\r\n";

  auto check_and_reply = [&](const HttpRequest& req) {
    EXPECT_EQ(req.uri.ToStdString(), "/rpc");
    EXPECT_EQ(req.method.ToStdString(), "POST");
    EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
    EXPECT_EQ("foo", req.GetHeader("X-1").value_or("N/A").ToStdString());
    EXPECT_EQ(req.body.ToStdString(), post_body);
    req.conn->SendResponseAndClose("200 OK");
  };
  EXPECT_CALL(handler_, OnHttpRequest(_)).WillOnce(Invoke(check_and_reply));

  cli.SendHttpReq(
      {"POST /rpc HTTP/1.1", "Origin: https://example.com", "X-1: foo"},
      post_body);
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  const std::string expected_resp =
      "HTTP/1.1 200 OK\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_resp);
}
173
174 // An unhandled request should cause a HTTP 500.
TEST_F(HttpServerTest,Unhadled_500)175 TEST_F(HttpServerTest, Unhadled_500) {
176 HttpCli cli(&task_runner_);
177 EXPECT_CALL(handler_, OnHttpRequest(_));
178 cli.SendHttpReq({"GET /unhandled HTTP/1.1"});
179 EXPECT_CALL(handler_, OnHttpConnectionClosed(_));
180 EXPECT_EQ(cli.RecvAndWaitConnClose(),
181 "HTTP/1.1 500 Internal Server Error\r\n"
182 "Content-Length: 0\r\n"
183 "Connection: close\r\n"
184 "\r\n");
185 }
186
187 // Send three requests within the same keepalive connection.
TEST_F(HttpServerTest,POST_Keepalive)188 TEST_F(HttpServerTest, POST_Keepalive) {
189 HttpCli cli(&task_runner_);
190 static const int kNumRequests = 3;
191 int req_num = 0;
192 EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(1);
193 EXPECT_CALL(handler_, OnHttpRequest(_))
194 .Times(3)
195 .WillRepeatedly(Invoke([&](const HttpRequest& req) {
196 EXPECT_EQ(req.uri.ToStdString(), "/" + std::to_string(req_num));
197 EXPECT_EQ(req.method.ToStdString(), "POST");
198 EXPECT_EQ(req.body.ToStdString(), "body" + std::to_string(req_num));
199 req.conn->SendResponseHeaders("200 OK");
200 if (++req_num == kNumRequests)
201 req.conn->Close();
202 }));
203
204 for (int i = 0; i < kNumRequests; i++) {
205 auto i_str = std::to_string(i);
206 cli.SendHttpReq({"POST /" + i_str + " HTTP/1.1", "Connection: keep-alive"},
207 "body" + i_str);
208 }
209
210 std::string expected_response;
211 for (int i = 0; i < kNumRequests; i++) {
212 expected_response +=
213 "HTTP/1.1 200 OK\r\n"
214 "Content-Length: 0\r\n"
215 "Connection: keep-alive\r\n"
216 "\r\n";
217 }
218 EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_response);
219 }
220
// Full websocket round trip, repeated on three successive connections: HTTP
// upgrade handshake from an allowed origin, three masked text frames each
// answered by the handler with a "PONG<i>" message, then a client-side
// shutdown that must be reported via OnHttpConnectionClosed.
TEST_F(HttpServerTest, Websocket) {
  srv_.AddAllowedOrigin("http://foo.com");
  srv_.AddAllowedOrigin("http://websocket.com");

  // A frame from a real tcpdump capture:
  //   1... .... = Fin: True
  //   .000 .... = Reserved: 0x0
  //   .... 0001 = Opcode: Text (1)
  //   1... .... = Mask: True
  //   .000 1100 = Payload length: 12
  //   Masking-Key: e17e8eb9
  //   Masked payload: "test message"
  const std::string masked_text_frame =
      "\x81\x8c\xe1\x7e\x8e\xb9\x95\x1b\xfd\xcd\xc1\x13\xeb\xca\x92\x1f\xe9"
      "\xdc";

  const std::string handshake_resp =
      "HTTP/1.1 101 Switching Protocols\r\n"
      "Upgrade: websocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
      "Access-Control-Allow-Origin: http://websocket.com\r\n"
      "Vary: Origin\r\n"
      "\r\n";

  for (int rep = 0; rep < 3; rep++) {
    HttpCli cli(&task_runner_);

    // Step 1: the upgrade handshake.
    auto accept_upgrade = [](const HttpRequest& req) {
      EXPECT_EQ(req.uri.ToStdString(), "/websocket");
      EXPECT_EQ(req.method.ToStdString(), "GET");
      EXPECT_EQ(req.origin.ToStdString(), "http://websocket.com");
      EXPECT_TRUE(req.is_websocket_handshake);
      req.conn->UpgradeToWebsocket(req);
    };
    EXPECT_CALL(handler_, OnHttpRequest(_)).WillOnce(Invoke(accept_upgrade));

    cli.SendHttpReq({
        "GET /websocket HTTP/1.1",                      //
        "Origin: http://websocket.com",                 //
        "Connection: upgrade",                          //
        "Upgrade: websocket",                           //
        "Sec-WebSocket-Version: 13",                    //
        "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
    });
    EXPECT_EQ(cli.Recv(handshake_resp.size()), handshake_resp);

    // Step 2: three message/reply exchanges over the upgraded connection.
    for (int msg_idx = 0; msg_idx < 3; msg_idx++) {
      auto reply_pong = [msg_idx](const WebsocketMessage& msg) {
        EXPECT_EQ(msg.data.ToStdString(), "test message");
        StackString<6> resp("PONG%d", msg_idx);
        msg.conn->SendWebsocketMessage(resp.c_str(), resp.len());
      };
      EXPECT_CALL(handler_, OnWebsocketMessage(_))
          .WillOnce(Invoke(reply_pong));

      cli.sock.SendStr(masked_text_frame);

      // The reply is a 2-byte frame header followed by the 5-byte payload.
      EXPECT_EQ(cli.Recv(2 + 5), "\x82\x05PONG" + std::to_string(msg_idx));
    }

    // Step 3: closing the client side must surface as a connection-closed
    // callback; run the task runner until it fires.
    cli.sock.Shutdown();
    const std::string close_cp_name = "ws_close_" + std::to_string(rep);
    auto on_closed = task_runner_.CreateCheckpoint(close_cp_name);
    EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
        .WillOnce(InvokeWithoutArgs(on_closed));
    task_runner_.RunUntilCheckpoint(close_cp_name);
  }
}
283
// A websocket upgrade coming from an origin that is not in the allow-list must
// be answered with a 403 instead of the 101 handshake response, even when the
// list contains several near-miss entries.
TEST_F(HttpServerTest, Websocket_OriginNotAllowed) {
  // None of these equals "http://notallowed.com". The origin must match in
  // full, including scheme, so the last entry won't match either.
  for (const char* origin : {
           "http://websocket.com",        //
           "http://notallowed.commando",  //
           "http://iamnotallowed.com",    //
           "iamnotallowed.com",           //
           "notallowed.com",              //
       }) {
    srv_.AddAllowedOrigin(origin);
  }

  HttpCli cli(&task_runner_);
  const std::string close_cp_name = "close";
  auto on_closed = task_runner_.CreateCheckpoint(close_cp_name);
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
      .WillOnce(InvokeWithoutArgs(on_closed));

  // The handler still sees the request and tries to upgrade it; the rejection
  // is expected to show up in the response checked below.
  auto try_upgrade = [](const HttpRequest& req) {
    EXPECT_EQ(req.origin.ToStdString(), "http://notallowed.com");
    EXPECT_TRUE(req.is_websocket_handshake);
    req.conn->UpgradeToWebsocket(req);
  };
  EXPECT_CALL(handler_, OnHttpRequest(_)).WillOnce(Invoke(try_upgrade));

  cli.SendHttpReq({
      "GET /websocket HTTP/1.1",                      //
      "Origin: http://notallowed.com",                //
      "Connection: upgrade",                          //
      "Upgrade: websocket",                           //
      "Sec-WebSocket-Version: 13",                    //
      "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
  });

  const std::string expected_resp =
      "HTTP/1.1 403 Forbidden\r\n"
      "Content-Length: 18\r\n"
      "Connection: close\r\n"
      "\r\n"
      "Origin not allowed";
  EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);

  cli.sock.Shutdown();
  task_runner_.RunUntilCheckpoint(close_cp_name);
}
322
323 } // namespace
324 } // namespace base
325 } // namespace perfetto
326