1 // Copyright 2015 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "webservd/protocol_handler.h"
16
17 #include <linux/tcp.h>
18 #include <microhttpd.h>
19 #include <netinet/in.h>
20 #include <sys/socket.h>
21
22 #include <algorithm>
23 #include <limits>
24 #include <vector>
25
26 #include <base/bind.h>
27 #include <base/guid.h>
28 #include <base/logging.h>
29 #include <base/message_loop/message_loop.h>
30
31 #include "webservd/request.h"
32 #include "webservd/request_handler_interface.h"
33 #include "webservd/server_interface.h"
34
35 namespace webservd {
36
37 // Helper class to provide static callback methods to libmicrohttpd library,
38 // with the ability to access private methods of Server class.
39 class ServerHelper final {
40 public:
ConnectionHandler(void * cls,MHD_Connection * connection,const char * url,const char * method,const char * version,const char * upload_data,size_t * upload_data_size,void ** con_cls)41 static int ConnectionHandler(void *cls,
42 MHD_Connection* connection,
43 const char* url,
44 const char* method,
45 const char* version,
46 const char* upload_data,
47 size_t* upload_data_size,
48 void** con_cls) {
49 auto handler = reinterpret_cast<ProtocolHandler*>(cls);
50 if (nullptr == *con_cls) {
51 std::string request_handler_id = handler->FindRequestHandler(url, method);
52 std::unique_ptr<Request> request{new Request{
53 request_handler_id, url, method, version, connection, handler
54 }};
55 if (!request->BeginRequestData())
56 return MHD_NO;
57
58 // Pass the raw pointer here in order to interface with libmicrohttpd's
59 // old-style C API.
60 *con_cls = request.release();
61 } else {
62 auto request = reinterpret_cast<Request*>(*con_cls);
63 if (*upload_data_size) {
64 if (!request->AddRequestData(upload_data, upload_data_size))
65 return MHD_NO;
66 } else {
67 request->EndRequestData();
68 }
69 }
70 return MHD_YES;
71 }
72
RequestCompleted(void *,MHD_Connection *,void ** con_cls,MHD_RequestTerminationCode toe)73 static void RequestCompleted(void* /* cls */,
74 MHD_Connection* /* connection */,
75 void** con_cls,
76 MHD_RequestTerminationCode toe) {
77 if (toe != MHD_REQUEST_TERMINATED_COMPLETED_OK) {
78 LOG(ERROR) << "Web request terminated abnormally with error code: "
79 << toe;
80 }
81 auto request = reinterpret_cast<Request*>(*con_cls);
82 *con_cls = nullptr;
83 delete request;
84 }
85 };
86
ProtocolHandler(const std::string & name,ServerInterface * server_interface)87 ProtocolHandler::ProtocolHandler(const std::string& name,
88 ServerInterface* server_interface)
89 : id_{base::GenerateGUID()},
90 name_{name},
91 server_interface_{server_interface} {}
92
~ProtocolHandler()93 ProtocolHandler::~ProtocolHandler() {
94 Stop();
95 }
96
AddRequestHandler(const std::string & url,const std::string & method,std::unique_ptr<RequestHandlerInterface> handler)97 std::string ProtocolHandler::AddRequestHandler(
98 const std::string& url,
99 const std::string& method,
100 std::unique_ptr<RequestHandlerInterface> handler) {
101 std::string handler_id = base::GenerateGUID();
102 request_handlers_.emplace(handler_id,
103 HandlerMapEntry{url, method, std::move(handler)});
104 return handler_id;
105 }
106
RemoveRequestHandler(const std::string & handler_id)107 bool ProtocolHandler::RemoveRequestHandler(const std::string& handler_id) {
108 return request_handlers_.erase(handler_id) == 1;
109 }
110
FindRequestHandler(const base::StringPiece & url,const base::StringPiece & method) const111 std::string ProtocolHandler::FindRequestHandler(
112 const base::StringPiece& url,
113 const base::StringPiece& method) const {
114 size_t score = std::numeric_limits<size_t>::max();
115 std::string handler_id;
116 for (const auto& pair : request_handlers_) {
117 std::string handler_url = pair.second.url;
118 bool url_match = (handler_url == url);
119 bool method_match = (pair.second.method == method);
120
121 // Try exact match first. If everything matches, we have our handler.
122 if (url_match && method_match)
123 return pair.first;
124
125 // Calculate the current handler's similarity score. The lower the score
126 // the better the match is...
127 size_t current_score = 0;
128 if (!url_match && !handler_url.empty() && handler_url.back() == '/') {
129 if (url.starts_with(handler_url)) {
130 url_match = true;
131 // Use the difference in URL length as URL match quality proxy.
132 // The longer URL, the more specific (better) match is.
133 // Multiply by 2 to allow for extra score point for matching the method.
134 current_score = (url.size() - handler_url.size()) * 2;
135 }
136 }
137
138 if (!method_match && pair.second.method.empty()) {
139 // If the handler didn't specify the method it handles, this means
140 // it doesn't care. However this isn't the exact match, so bump
141 // the score up one point.
142 method_match = true;
143 ++current_score;
144 }
145
146 if (url_match && method_match && current_score < score) {
147 score = current_score;
148 handler_id = pair.first;
149 }
150 }
151
152 return handler_id;
153 }
154
Start(Config::ProtocolHandler * config)155 bool ProtocolHandler::Start(Config::ProtocolHandler* config) {
156 if (server_) {
157 LOG(ERROR) << "Protocol handler is already running.";
158 return false;
159 }
160
161 // If using TLS, the certificate, private key and fingerprint must be
162 // provided.
163 CHECK_EQ(config->use_tls, !config->private_key.empty());
164 CHECK_EQ(config->use_tls, !config->certificate.empty());
165 CHECK_EQ(config->use_tls, !config->certificate_fingerprint.empty());
166
167 LOG(INFO) << "Starting " << (config->use_tls ? "HTTPS" : "HTTP")
168 << " protocol handler on port: " << config->port;
169
170 port_ = config->port;
171 protocol_ = (config->use_tls ? "https" : "http");
172 certificate_fingerprint_ = config->certificate_fingerprint;
173
174 auto callback_addr =
175 reinterpret_cast<intptr_t>(&ServerHelper::RequestCompleted);
176 uint32_t flags = MHD_NO_FLAG;
177 if (server_interface_->GetConfig().use_debug)
178 flags |= MHD_USE_DEBUG;
179
180 // Enable IPv6 if supported.
181 if (server_interface_->GetConfig().use_ipv6)
182 flags |= MHD_USE_DUAL_STACK;
183 flags |= MHD_USE_TCP_FASTOPEN; // Use TCP Fast Open (see RFC 7413).
184 flags |= MHD_USE_SUSPEND_RESUME; // Allow suspending/resuming connections.
185
186 // MHD uses timeout of 0 to mean there is no timeout.
187 int timeout = server_interface_->GetConfig().default_request_timeout_seconds;
188 if (timeout < 0)
189 timeout = 0;
190
191 std::vector<MHD_OptionItem> options{
192 {MHD_OPTION_CONNECTION_LIMIT, 10, nullptr},
193 {MHD_OPTION_CONNECTION_TIMEOUT, timeout, nullptr},
194 {MHD_OPTION_NOTIFY_COMPLETED, callback_addr, nullptr},
195 };
196
197 if (config->socket_fd != -1) {
198 // Take ownership of the socket.
199 int socket_fd = config->socket_fd;
200 config->socket_fd = -1;
201
202 // Set some more socket options. These options were set in libmicrohttpd.
203 int on = 1;
204 if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
205 // Treat this as a non-fatal failure. Just continue after logging.
206 PLOG(WARNING) << "Failed to set SO_REUSEADDR option on listening socket.";
207 }
208 on = (MHD_USE_DUAL_STACK != (flags & MHD_USE_DUAL_STACK));
209 if (setsockopt(socket_fd, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
210 PLOG(WARNING) << "Failed to set IPV6_V6ONLY option on listening socket.";
211 close(socket_fd);
212 return false;
213 }
214
215 // Bind socket to the port.
216 sockaddr_in6 addr = {};
217 addr.sin6_family = AF_INET6;
218 addr.sin6_port = htons(config->port);
219 if (bind(socket_fd, reinterpret_cast<const sockaddr*>(&addr),
220 sizeof(addr)) < 0) {
221 PLOG(ERROR) << "Failed to bind the socket to port " << config->port;
222 close(socket_fd);
223 return false;
224 }
225 if ((flags & MHD_USE_TCP_FASTOPEN) != 0) {
226 // This is the default value from libmicrohttpd.
227 int fastopen_queue_size = 10;
228 if (setsockopt(socket_fd, IPPROTO_TCP, TCP_FASTOPEN,
229 &fastopen_queue_size, sizeof(fastopen_queue_size)) < 0) {
230 // Treat this as a non-fatal failure. Just continue after logging.
231 PLOG(WARNING) << "Failed to set TCP_FASTOPEN option on socket.";
232 }
233 }
234
235 // Start listening on the socket.
236 // 32 connections is the value used by libmicrohttpd.
237 if (listen(socket_fd, 32) < 0) {
238 PLOG(ERROR) << "Failed to listen for connections on the socket.";
239 close(socket_fd);
240 return false;
241 }
242
243 // Finally, pass the socket to libmicrohttpd.
244 options.push_back(
245 MHD_OptionItem{MHD_OPTION_LISTEN_SOCKET, socket_fd, nullptr});
246 }
247
248 // libmicrohttpd expects both the key and certificate to be zero-terminated
249 // strings. Make sure they are terminated properly.
250 brillo::SecureBlob private_key_copy = config->private_key;
251 brillo::Blob certificate_copy = config->certificate;
252 private_key_copy.push_back(0);
253 certificate_copy.push_back(0);
254
255 if (config->use_tls) {
256 flags |= MHD_USE_SSL;
257 options.push_back(
258 MHD_OptionItem{MHD_OPTION_HTTPS_MEM_KEY, 0, private_key_copy.data()});
259 options.push_back(
260 MHD_OptionItem{MHD_OPTION_HTTPS_MEM_CERT, 0, certificate_copy.data()});
261 }
262
263 options.push_back(MHD_OptionItem{MHD_OPTION_END, 0, nullptr});
264
265 server_ = MHD_start_daemon(flags, config->port, nullptr, nullptr,
266 &ServerHelper::ConnectionHandler, this,
267 MHD_OPTION_ARRAY, options.data(), MHD_OPTION_END);
268 if (!server_) {
269 PLOG(ERROR) << "Failed to create protocol handler on port " << config->port;
270 return false;
271 }
272 server_interface_->ProtocolHandlerStarted(this);
273 DoWork();
274 LOG(INFO) << "Protocol handler started";
275 return true;
276 }
277
Stop()278 bool ProtocolHandler::Stop() {
279 if (server_) {
280 LOG(INFO) << "Shutting down the protocol handler...";
281 MHD_stop_daemon(server_);
282 server_ = nullptr;
283 server_interface_->ProtocolHandlerStopped(this);
284 LOG(INFO) << "Protocol handler shutdown complete";
285 }
286 port_ = 0;
287 protocol_.clear();
288 certificate_fingerprint_.clear();
289 return true;
290 }
291
AddRequest(Request * request)292 void ProtocolHandler::AddRequest(Request* request) {
293 requests_.emplace(request->GetID(), request);
294 }
295
RemoveRequest(Request * request)296 void ProtocolHandler::RemoveRequest(Request* request) {
297 requests_.erase(request->GetID());
298 }
299
GetRequest(const std::string & request_id) const300 Request* ProtocolHandler::GetRequest(const std::string& request_id) const {
301 auto p = requests_.find(request_id);
302 return (p != requests_.end()) ? p->second : nullptr;
303 }
304
305 // A file descriptor watcher class that oversees I/O operation notification
306 // on particular socket file descriptor.
307 class ProtocolHandler::Watcher final : public base::MessageLoopForIO::Watcher {
308 public:
Watcher(ProtocolHandler * handler,int fd)309 Watcher(ProtocolHandler* handler, int fd) : fd_{fd}, handler_{handler} {}
310
Watch(bool read,bool write)311 void Watch(bool read, bool write) {
312 if (read == watching_read_ && write == watching_write_ && !triggered_)
313 return;
314
315 controller_.StopWatchingFileDescriptor();
316 watching_read_ = read;
317 watching_write_ = write;
318 triggered_ = false;
319
320 auto mode = base::MessageLoopForIO::WATCH_READ_WRITE;
321 if (watching_read_ && watching_write_)
322 mode = base::MessageLoopForIO::WATCH_READ_WRITE;
323 else if (watching_read_)
324 mode = base::MessageLoopForIO::WATCH_READ;
325 else if (watching_write_)
326 mode = base::MessageLoopForIO::WATCH_WRITE;
327 base::MessageLoopForIO::current()->WatchFileDescriptor(fd_, false, mode,
328 &controller_, this);
329 }
330
331 // Overrides from base::MessageLoopForIO::Watcher.
OnFileCanReadWithoutBlocking(int)332 void OnFileCanReadWithoutBlocking(int /* fd */) override {
333 triggered_ = true;
334 handler_->ScheduleWork();
335 }
336
OnFileCanWriteWithoutBlocking(int)337 void OnFileCanWriteWithoutBlocking(int /* fd */) override {
338 triggered_ = true;
339 handler_->ScheduleWork();
340 }
341
GetFileDescriptor() const342 int GetFileDescriptor() const { return fd_; }
343
344 private:
345 int fd_{-1};
346 ProtocolHandler* handler_{nullptr};
347 bool watching_read_{false};
348 bool watching_write_{false};
349 bool triggered_{false};
350 base::MessageLoopForIO::FileDescriptorWatcher controller_;
351
352 DISALLOW_COPY_AND_ASSIGN(Watcher);
353 };
354
ScheduleWork()355 void ProtocolHandler::ScheduleWork() {
356 if (work_scheduled_)
357 return;
358
359 work_scheduled_ = true;
360 base::MessageLoopForIO::current()->PostTask(
361 FROM_HERE,
362 base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()));
363 }
364
DoWork()365 void ProtocolHandler::DoWork() {
366 work_scheduled_ = false;
367 weak_ptr_factory_.InvalidateWeakPtrs();
368
369 // Check if there is any pending work to be done in libmicrohttpd.
370 MHD_run(server_);
371
372 // Get all the file descriptors from libmicrohttpd and watch for I/O
373 // operations on them.
374 fd_set rs;
375 fd_set ws;
376 fd_set es;
377 int max_fd = MHD_INVALID_SOCKET;
378 FD_ZERO(&rs);
379 FD_ZERO(&ws);
380 FD_ZERO(&es);
381 CHECK_EQ(MHD_YES, MHD_get_fdset(server_, &rs, &ws, &es, &max_fd));
382
383 for (auto& watcher : watchers_) {
384 int fd = watcher->GetFileDescriptor();
385 if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) {
386 watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws));
387 FD_CLR(fd, &rs);
388 FD_CLR(fd, &ws);
389 } else {
390 watcher.reset();
391 }
392 }
393
394 watchers_.erase(std::remove(watchers_.begin(), watchers_.end(), nullptr),
395 watchers_.end());
396
397 for (int fd = 0; fd <= max_fd; fd++) {
398 // libmicrohttpd is not using exception FDs, so lets put our expectations
399 // upfront.
400 CHECK(!FD_ISSET(fd, &es));
401 if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) {
402 // libmicrohttpd should never use any of stdin/stdout/stderr descriptors.
403 CHECK_GT(fd, STDERR_FILENO);
404 std::unique_ptr<Watcher> watcher{new Watcher{this, fd}};
405 watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws));
406 watchers_.push_back(std::move(watcher));
407 }
408 }
409
410 // Schedule a time-out timer, if asked by libmicrohttpd.
411 MHD_UNSIGNED_LONG_LONG mhd_timeout = 0;
412 if (MHD_get_timeout(server_, &mhd_timeout) == MHD_YES) {
413 base::MessageLoopForIO::current()->PostDelayedTask(
414 FROM_HERE,
415 base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()),
416 base::TimeDelta::FromMilliseconds(mhd_timeout));
417 }
418 }
419
420 } // namespace webservd
421