• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2016-2019 Vinnie Falco (vinnie dot falco at gmail dot com)
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/boostorg/beast
8 //
9 
10 //------------------------------------------------------------------------------
11 //
12 // Example: Advanced server, flex (plain + SSL)
13 //
14 //------------------------------------------------------------------------------
15 
16 #include "example/common/server_certificate.hpp"
17 
18 #include <boost/beast/core.hpp>
19 #include <boost/beast/http.hpp>
20 #include <boost/beast/ssl.hpp>
21 #include <boost/beast/websocket.hpp>
22 #include <boost/beast/version.hpp>
23 #include <boost/asio/bind_executor.hpp>
24 #include <boost/asio/dispatch.hpp>
25 #include <boost/asio/signal_set.hpp>
26 #include <boost/asio/steady_timer.hpp>
27 #include <boost/asio/strand.hpp>
28 #include <boost/make_unique.hpp>
29 #include <boost/optional.hpp>
30 #include <algorithm>
31 #include <cstdlib>
32 #include <functional>
33 #include <iostream>
34 #include <memory>
35 #include <string>
36 #include <thread>
37 #include <vector>
38 
39 namespace beast = boost::beast;                 // from <boost/beast.hpp>
40 namespace http = beast::http;                   // from <boost/beast/http.hpp>
41 namespace websocket = beast::websocket;         // from <boost/beast/websocket.hpp>
42 namespace net = boost::asio;                    // from <boost/asio.hpp>
43 namespace ssl = boost::asio::ssl;               // from <boost/asio/ssl.hpp>
44 using tcp = boost::asio::ip::tcp;               // from <boost/asio/ip/tcp.hpp>
45 
46 // Return a reasonable mime type based on the extension of a file.
47 beast::string_view
mime_type(beast::string_view path)48 mime_type(beast::string_view path)
49 {
50     using beast::iequals;
51     auto const ext = [&path]
52     {
53         auto const pos = path.rfind(".");
54         if(pos == beast::string_view::npos)
55             return beast::string_view{};
56         return path.substr(pos);
57     }();
58     if(iequals(ext, ".htm"))  return "text/html";
59     if(iequals(ext, ".html")) return "text/html";
60     if(iequals(ext, ".php"))  return "text/html";
61     if(iequals(ext, ".css"))  return "text/css";
62     if(iequals(ext, ".txt"))  return "text/plain";
63     if(iequals(ext, ".js"))   return "application/javascript";
64     if(iequals(ext, ".json")) return "application/json";
65     if(iequals(ext, ".xml"))  return "application/xml";
66     if(iequals(ext, ".swf"))  return "application/x-shockwave-flash";
67     if(iequals(ext, ".flv"))  return "video/x-flv";
68     if(iequals(ext, ".png"))  return "image/png";
69     if(iequals(ext, ".jpe"))  return "image/jpeg";
70     if(iequals(ext, ".jpeg")) return "image/jpeg";
71     if(iequals(ext, ".jpg"))  return "image/jpeg";
72     if(iequals(ext, ".gif"))  return "image/gif";
73     if(iequals(ext, ".bmp"))  return "image/bmp";
74     if(iequals(ext, ".ico"))  return "image/vnd.microsoft.icon";
75     if(iequals(ext, ".tiff")) return "image/tiff";
76     if(iequals(ext, ".tif"))  return "image/tiff";
77     if(iequals(ext, ".svg"))  return "image/svg+xml";
78     if(iequals(ext, ".svgz")) return "image/svg+xml";
79     return "application/text";
80 }
81 
82 // Append an HTTP rel-path to a local filesystem path.
83 // The returned path is normalized for the platform.
84 std::string
path_cat(beast::string_view base,beast::string_view path)85 path_cat(
86     beast::string_view base,
87     beast::string_view path)
88 {
89     if(base.empty())
90         return std::string(path);
91     std::string result(base);
92 #ifdef BOOST_MSVC
93     char constexpr path_separator = '\\';
94     if(result.back() == path_separator)
95         result.resize(result.size() - 1);
96     result.append(path.data(), path.size());
97     for(auto& c : result)
98         if(c == '/')
99             c = path_separator;
100 #else
101     char constexpr path_separator = '/';
102     if(result.back() == path_separator)
103         result.resize(result.size() - 1);
104     result.append(path.data(), path.size());
105 #endif
106     return result;
107 }
108 
109 // This function produces an HTTP response for the given
110 // request. The type of the response object depends on the
111 // contents of the request, so the interface requires the
112 // caller to pass a generic lambda for receiving the response.
113 template<
114     class Body, class Allocator,
115     class Send>
116 void
handle_request(beast::string_view doc_root,http::request<Body,http::basic_fields<Allocator>> && req,Send && send)117 handle_request(
118     beast::string_view doc_root,
119     http::request<Body, http::basic_fields<Allocator>>&& req,
120     Send&& send)
121 {
122     // Returns a bad request response
123     auto const bad_request =
124     [&req](beast::string_view why)
125     {
126         http::response<http::string_body> res{http::status::bad_request, req.version()};
127         res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
128         res.set(http::field::content_type, "text/html");
129         res.keep_alive(req.keep_alive());
130         res.body() = std::string(why);
131         res.prepare_payload();
132         return res;
133     };
134 
135     // Returns a not found response
136     auto const not_found =
137     [&req](beast::string_view target)
138     {
139         http::response<http::string_body> res{http::status::not_found, req.version()};
140         res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
141         res.set(http::field::content_type, "text/html");
142         res.keep_alive(req.keep_alive());
143         res.body() = "The resource '" + std::string(target) + "' was not found.";
144         res.prepare_payload();
145         return res;
146     };
147 
148     // Returns a server error response
149     auto const server_error =
150     [&req](beast::string_view what)
151     {
152         http::response<http::string_body> res{http::status::internal_server_error, req.version()};
153         res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
154         res.set(http::field::content_type, "text/html");
155         res.keep_alive(req.keep_alive());
156         res.body() = "An error occurred: '" + std::string(what) + "'";
157         res.prepare_payload();
158         return res;
159     };
160 
161     // Make sure we can handle the method
162     if( req.method() != http::verb::get &&
163         req.method() != http::verb::head)
164         return send(bad_request("Unknown HTTP-method"));
165 
166     // Request path must be absolute and not contain "..".
167     if( req.target().empty() ||
168         req.target()[0] != '/' ||
169         req.target().find("..") != beast::string_view::npos)
170         return send(bad_request("Illegal request-target"));
171 
172     // Build the path to the requested file
173     std::string path = path_cat(doc_root, req.target());
174     if(req.target().back() == '/')
175         path.append("index.html");
176 
177     // Attempt to open the file
178     beast::error_code ec;
179     http::file_body::value_type body;
180     body.open(path.c_str(), beast::file_mode::scan, ec);
181 
182     // Handle the case where the file doesn't exist
183     if(ec == beast::errc::no_such_file_or_directory)
184         return send(not_found(req.target()));
185 
186     // Handle an unknown error
187     if(ec)
188         return send(server_error(ec.message()));
189 
190     // Cache the size since we need it after the move
191     auto const size = body.size();
192 
193     // Respond to HEAD request
194     if(req.method() == http::verb::head)
195     {
196         http::response<http::empty_body> res{http::status::ok, req.version()};
197         res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
198         res.set(http::field::content_type, mime_type(path));
199         res.content_length(size);
200         res.keep_alive(req.keep_alive());
201         return send(std::move(res));
202     }
203 
204     // Respond to GET request
205     http::response<http::file_body> res{
206         std::piecewise_construct,
207         std::make_tuple(std::move(body)),
208         std::make_tuple(http::status::ok, req.version())};
209     res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
210     res.set(http::field::content_type, mime_type(path));
211     res.content_length(size);
212     res.keep_alive(req.keep_alive());
213     return send(std::move(res));
214 }
215 
216 //------------------------------------------------------------------------------
217 
218 // Report a failure
219 void
fail(beast::error_code ec,char const * what)220 fail(beast::error_code ec, char const* what)
221 {
222     // ssl::error::stream_truncated, also known as an SSL "short read",
223     // indicates the peer closed the connection without performing the
224     // required closing handshake (for example, Google does this to
225     // improve performance). Generally this can be a security issue,
226     // but if your communication protocol is self-terminated (as
227     // it is with both HTTP and WebSocket) then you may simply
228     // ignore the lack of close_notify.
229     //
230     // https://github.com/boostorg/beast/issues/38
231     //
232     // https://security.stackexchange.com/questions/91435/how-to-handle-a-malicious-ssl-tls-shutdown
233     //
234     // When a short read would cut off the end of an HTTP message,
235     // Beast returns the error beast::http::error::partial_message.
236     // Therefore, if we see a short read here, it has occurred
237     // after the message has been completed, so it is safe to ignore it.
238 
239     if(ec == net::ssl::error::stream_truncated)
240         return;
241 
242     std::cerr << what << ": " << ec.message() << "\n";
243 }
244 
245 //------------------------------------------------------------------------------
246 
247 // Echoes back all received WebSocket messages.
248 // This uses the Curiously Recurring Template Pattern so that
249 // the same code works with both SSL streams and regular sockets.
250 template<class Derived>
251 class websocket_session
252 {
253     // Access the derived class, this is part of
254     // the Curiously Recurring Template Pattern idiom.
255     Derived&
derived()256     derived()
257     {
258         return static_cast<Derived&>(*this);
259     }
260 
261     beast::flat_buffer buffer_;
262 
263     // Start the asynchronous operation
264     template<class Body, class Allocator>
265     void
do_accept(http::request<Body,http::basic_fields<Allocator>> req)266     do_accept(http::request<Body, http::basic_fields<Allocator>> req)
267     {
268         // Set suggested timeout settings for the websocket
269         derived().ws().set_option(
270             websocket::stream_base::timeout::suggested(
271                 beast::role_type::server));
272 
273         // Set a decorator to change the Server of the handshake
274         derived().ws().set_option(
275             websocket::stream_base::decorator(
276             [](websocket::response_type& res)
277             {
278                 res.set(http::field::server,
279                     std::string(BOOST_BEAST_VERSION_STRING) +
280                         " advanced-server-flex");
281             }));
282 
283         // Accept the websocket handshake
284         derived().ws().async_accept(
285             req,
286             beast::bind_front_handler(
287                 &websocket_session::on_accept,
288                 derived().shared_from_this()));
289     }
290 
291     void
on_accept(beast::error_code ec)292     on_accept(beast::error_code ec)
293     {
294         if(ec)
295             return fail(ec, "accept");
296 
297         // Read a message
298         do_read();
299     }
300 
301     void
do_read()302     do_read()
303     {
304         // Read a message into our buffer
305         derived().ws().async_read(
306             buffer_,
307             beast::bind_front_handler(
308                 &websocket_session::on_read,
309                 derived().shared_from_this()));
310     }
311 
312     void
on_read(beast::error_code ec,std::size_t bytes_transferred)313     on_read(
314         beast::error_code ec,
315         std::size_t bytes_transferred)
316     {
317         boost::ignore_unused(bytes_transferred);
318 
319         // This indicates that the websocket_session was closed
320         if(ec == websocket::error::closed)
321             return;
322 
323         if(ec)
324             return fail(ec, "read");
325 
326         // Echo the message
327         derived().ws().text(derived().ws().got_text());
328         derived().ws().async_write(
329             buffer_.data(),
330             beast::bind_front_handler(
331                 &websocket_session::on_write,
332                 derived().shared_from_this()));
333     }
334 
335     void
on_write(beast::error_code ec,std::size_t bytes_transferred)336     on_write(
337         beast::error_code ec,
338         std::size_t bytes_transferred)
339     {
340         boost::ignore_unused(bytes_transferred);
341 
342         if(ec)
343             return fail(ec, "write");
344 
345         // Clear the buffer
346         buffer_.consume(buffer_.size());
347 
348         // Do another read
349         do_read();
350     }
351 
352 public:
353     // Start the asynchronous operation
354     template<class Body, class Allocator>
355     void
run(http::request<Body,http::basic_fields<Allocator>> req)356     run(http::request<Body, http::basic_fields<Allocator>> req)
357     {
358         // Accept the WebSocket upgrade request
359         do_accept(std::move(req));
360     }
361 };
362 
363 //------------------------------------------------------------------------------
364 
365 // Handles a plain WebSocket connection
366 class plain_websocket_session
367     : public websocket_session<plain_websocket_session>
368     , public std::enable_shared_from_this<plain_websocket_session>
369 {
370     websocket::stream<beast::tcp_stream> ws_;
371 
372 public:
373     // Create the session
374     explicit
plain_websocket_session(beast::tcp_stream && stream)375     plain_websocket_session(
376         beast::tcp_stream&& stream)
377         : ws_(std::move(stream))
378     {
379     }
380 
381     // Called by the base class
382     websocket::stream<beast::tcp_stream>&
ws()383     ws()
384     {
385         return ws_;
386     }
387 };
388 
389 //------------------------------------------------------------------------------
390 
391 // Handles an SSL WebSocket connection
392 class ssl_websocket_session
393     : public websocket_session<ssl_websocket_session>
394     , public std::enable_shared_from_this<ssl_websocket_session>
395 {
396     websocket::stream<
397         beast::ssl_stream<beast::tcp_stream>> ws_;
398 
399 public:
400     // Create the ssl_websocket_session
401     explicit
ssl_websocket_session(beast::ssl_stream<beast::tcp_stream> && stream)402     ssl_websocket_session(
403         beast::ssl_stream<beast::tcp_stream>&& stream)
404         : ws_(std::move(stream))
405     {
406     }
407 
408     // Called by the base class
409     websocket::stream<
410         beast::ssl_stream<beast::tcp_stream>>&
ws()411     ws()
412     {
413         return ws_;
414     }
415 };
416 
417 //------------------------------------------------------------------------------
418 
419 template<class Body, class Allocator>
420 void
make_websocket_session(beast::tcp_stream stream,http::request<Body,http::basic_fields<Allocator>> req)421 make_websocket_session(
422     beast::tcp_stream stream,
423     http::request<Body, http::basic_fields<Allocator>> req)
424 {
425     std::make_shared<plain_websocket_session>(
426         std::move(stream))->run(std::move(req));
427 }
428 
429 template<class Body, class Allocator>
430 void
make_websocket_session(beast::ssl_stream<beast::tcp_stream> stream,http::request<Body,http::basic_fields<Allocator>> req)431 make_websocket_session(
432     beast::ssl_stream<beast::tcp_stream> stream,
433     http::request<Body, http::basic_fields<Allocator>> req)
434 {
435     std::make_shared<ssl_websocket_session>(
436         std::move(stream))->run(std::move(req));
437 }
438 
439 //------------------------------------------------------------------------------
440 
441 // Handles an HTTP server connection.
442 // This uses the Curiously Recurring Template Pattern so that
443 // the same code works with both SSL streams and regular sockets.
444 template<class Derived>
445 class http_session
446 {
447     // Access the derived class, this is part of
448     // the Curiously Recurring Template Pattern idiom.
449     Derived&
derived()450     derived()
451     {
452         return static_cast<Derived&>(*this);
453     }
454 
455     // This queue is used for HTTP pipelining.
456     class queue
457     {
458         enum
459         {
460             // Maximum number of responses we will queue
461             limit = 8
462         };
463 
464         // The type-erased, saved work item
465         struct work
466         {
467             virtual ~work() = default;
468             virtual void operator()() = 0;
469         };
470 
471         http_session& self_;
472         std::vector<std::unique_ptr<work>> items_;
473 
474     public:
475         explicit
queue(http_session & self)476         queue(http_session& self)
477             : self_(self)
478         {
479             static_assert(limit > 0, "queue limit must be positive");
480             items_.reserve(limit);
481         }
482 
483         // Returns `true` if we have reached the queue limit
484         bool
is_full() const485         is_full() const
486         {
487             return items_.size() >= limit;
488         }
489 
490         // Called when a message finishes sending
491         // Returns `true` if the caller should initiate a read
492         bool
on_write()493         on_write()
494         {
495             BOOST_ASSERT(! items_.empty());
496             auto const was_full = is_full();
497             items_.erase(items_.begin());
498             if(! items_.empty())
499                 (*items_.front())();
500             return was_full;
501         }
502 
503         // Called by the HTTP handler to send a response.
504         template<bool isRequest, class Body, class Fields>
505         void
operator ()(http::message<isRequest,Body,Fields> && msg)506         operator()(http::message<isRequest, Body, Fields>&& msg)
507         {
508             // This holds a work item
509             struct work_impl : work
510             {
511                 http_session& self_;
512                 http::message<isRequest, Body, Fields> msg_;
513 
514                 work_impl(
515                     http_session& self,
516                     http::message<isRequest, Body, Fields>&& msg)
517                     : self_(self)
518                     , msg_(std::move(msg))
519                 {
520                 }
521 
522                 void
523                 operator()()
524                 {
525                     http::async_write(
526                         self_.derived().stream(),
527                         msg_,
528                         beast::bind_front_handler(
529                             &http_session::on_write,
530                             self_.derived().shared_from_this(),
531                             msg_.need_eof()));
532                 }
533             };
534 
535             // Allocate and store the work
536             items_.push_back(
537                 boost::make_unique<work_impl>(self_, std::move(msg)));
538 
539             // If there was no previous work, start this one
540             if(items_.size() == 1)
541                 (*items_.front())();
542         }
543     };
544 
545     std::shared_ptr<std::string const> doc_root_;
546     queue queue_;
547 
548     // The parser is stored in an optional container so we can
549     // construct it from scratch it at the beginning of each new message.
550     boost::optional<http::request_parser<http::string_body>> parser_;
551 
552 protected:
553     beast::flat_buffer buffer_;
554 
555 public:
556     // Construct the session
http_session(beast::flat_buffer buffer,std::shared_ptr<std::string const> const & doc_root)557     http_session(
558         beast::flat_buffer buffer,
559         std::shared_ptr<std::string const> const& doc_root)
560         : doc_root_(doc_root)
561         , queue_(*this)
562         , buffer_(std::move(buffer))
563     {
564     }
565 
566     void
do_read()567     do_read()
568     {
569         // Construct a new parser for each message
570         parser_.emplace();
571 
572         // Apply a reasonable limit to the allowed size
573         // of the body in bytes to prevent abuse.
574         parser_->body_limit(10000);
575 
576         // Set the timeout.
577         beast::get_lowest_layer(
578             derived().stream()).expires_after(std::chrono::seconds(30));
579 
580         // Read a request using the parser-oriented interface
581         http::async_read(
582             derived().stream(),
583             buffer_,
584             *parser_,
585             beast::bind_front_handler(
586                 &http_session::on_read,
587                 derived().shared_from_this()));
588     }
589 
590     void
on_read(beast::error_code ec,std::size_t bytes_transferred)591     on_read(beast::error_code ec, std::size_t bytes_transferred)
592     {
593         boost::ignore_unused(bytes_transferred);
594 
595         // This means they closed the connection
596         if(ec == http::error::end_of_stream)
597             return derived().do_eof();
598 
599         if(ec)
600             return fail(ec, "read");
601 
602         // See if it is a WebSocket Upgrade
603         if(websocket::is_upgrade(parser_->get()))
604         {
605             // Disable the timeout.
606             // The websocket::stream uses its own timeout settings.
607             beast::get_lowest_layer(derived().stream()).expires_never();
608 
609             // Create a websocket session, transferring ownership
610             // of both the socket and the HTTP request.
611             return make_websocket_session(
612                 derived().release_stream(),
613                 parser_->release());
614         }
615 
616         // Send the response
617         handle_request(*doc_root_, parser_->release(), queue_);
618 
619         // If we aren't at the queue limit, try to pipeline another request
620         if(! queue_.is_full())
621             do_read();
622     }
623 
624     void
on_write(bool close,beast::error_code ec,std::size_t bytes_transferred)625     on_write(bool close, beast::error_code ec, std::size_t bytes_transferred)
626     {
627         boost::ignore_unused(bytes_transferred);
628 
629         if(ec)
630             return fail(ec, "write");
631 
632         if(close)
633         {
634             // This means we should close the connection, usually because
635             // the response indicated the "Connection: close" semantic.
636             return derived().do_eof();
637         }
638 
639         // Inform the queue that a write completed
640         if(queue_.on_write())
641         {
642             // Read another request
643             do_read();
644         }
645     }
646 };
647 
648 //------------------------------------------------------------------------------
649 
650 // Handles a plain HTTP connection
651 class plain_http_session
652     : public http_session<plain_http_session>
653     , public std::enable_shared_from_this<plain_http_session>
654 {
655     beast::tcp_stream stream_;
656 
657 public:
658     // Create the session
plain_http_session(beast::tcp_stream && stream,beast::flat_buffer && buffer,std::shared_ptr<std::string const> const & doc_root)659     plain_http_session(
660         beast::tcp_stream&& stream,
661         beast::flat_buffer&& buffer,
662         std::shared_ptr<std::string const> const& doc_root)
663         : http_session<plain_http_session>(
664             std::move(buffer),
665             doc_root)
666         , stream_(std::move(stream))
667     {
668     }
669 
670     // Start the session
671     void
run()672     run()
673     {
674         this->do_read();
675     }
676 
677     // Called by the base class
678     beast::tcp_stream&
stream()679     stream()
680     {
681         return stream_;
682     }
683 
684     // Called by the base class
685     beast::tcp_stream
release_stream()686     release_stream()
687     {
688         return std::move(stream_);
689     }
690 
691     // Called by the base class
692     void
do_eof()693     do_eof()
694     {
695         // Send a TCP shutdown
696         beast::error_code ec;
697         stream_.socket().shutdown(tcp::socket::shutdown_send, ec);
698 
699         // At this point the connection is closed gracefully
700     }
701 };
702 
703 //------------------------------------------------------------------------------
704 
705 // Handles an SSL HTTP connection
706 class ssl_http_session
707     : public http_session<ssl_http_session>
708     , public std::enable_shared_from_this<ssl_http_session>
709 {
710     beast::ssl_stream<beast::tcp_stream> stream_;
711 
712 public:
713     // Create the http_session
ssl_http_session(beast::tcp_stream && stream,ssl::context & ctx,beast::flat_buffer && buffer,std::shared_ptr<std::string const> const & doc_root)714     ssl_http_session(
715         beast::tcp_stream&& stream,
716         ssl::context& ctx,
717         beast::flat_buffer&& buffer,
718         std::shared_ptr<std::string const> const& doc_root)
719         : http_session<ssl_http_session>(
720             std::move(buffer),
721             doc_root)
722         , stream_(std::move(stream), ctx)
723     {
724     }
725 
726     // Start the session
727     void
run()728     run()
729     {
730         // Set the timeout.
731         beast::get_lowest_layer(stream_).expires_after(std::chrono::seconds(30));
732 
733         // Perform the SSL handshake
734         // Note, this is the buffered version of the handshake.
735         stream_.async_handshake(
736             ssl::stream_base::server,
737             buffer_.data(),
738             beast::bind_front_handler(
739                 &ssl_http_session::on_handshake,
740                 shared_from_this()));
741     }
742 
743     // Called by the base class
744     beast::ssl_stream<beast::tcp_stream>&
stream()745     stream()
746     {
747         return stream_;
748     }
749 
750     // Called by the base class
751     beast::ssl_stream<beast::tcp_stream>
release_stream()752     release_stream()
753     {
754         return std::move(stream_);
755     }
756 
757     // Called by the base class
758     void
do_eof()759     do_eof()
760     {
761         // Set the timeout.
762         beast::get_lowest_layer(stream_).expires_after(std::chrono::seconds(30));
763 
764         // Perform the SSL shutdown
765         stream_.async_shutdown(
766             beast::bind_front_handler(
767                 &ssl_http_session::on_shutdown,
768                 shared_from_this()));
769     }
770 
771 private:
772     void
on_handshake(beast::error_code ec,std::size_t bytes_used)773     on_handshake(
774         beast::error_code ec,
775         std::size_t bytes_used)
776     {
777         if(ec)
778             return fail(ec, "handshake");
779 
780         // Consume the portion of the buffer used by the handshake
781         buffer_.consume(bytes_used);
782 
783         do_read();
784     }
785 
786     void
on_shutdown(beast::error_code ec)787     on_shutdown(beast::error_code ec)
788     {
789         if(ec)
790             return fail(ec, "shutdown");
791 
792         // At this point the connection is closed gracefully
793     }
794 };
795 
796 //------------------------------------------------------------------------------
797 
798 // Detects SSL handshakes
799 class detect_session : public std::enable_shared_from_this<detect_session>
800 {
801     beast::tcp_stream stream_;
802     ssl::context& ctx_;
803     std::shared_ptr<std::string const> doc_root_;
804     beast::flat_buffer buffer_;
805 
806 public:
807     explicit
detect_session(tcp::socket && socket,ssl::context & ctx,std::shared_ptr<std::string const> const & doc_root)808     detect_session(
809         tcp::socket&& socket,
810         ssl::context& ctx,
811         std::shared_ptr<std::string const> const& doc_root)
812         : stream_(std::move(socket))
813         , ctx_(ctx)
814         , doc_root_(doc_root)
815     {
816     }
817 
818     // Launch the detector
819     void
run()820     run()
821     {
822         // We need to be executing within a strand to perform async operations
823         // on the I/O objects in this session. Although not strictly necessary
824         // for single-threaded contexts, this example code is written to be
825         // thread-safe by default.
826         net::dispatch(
827             stream_.get_executor(),
828             beast::bind_front_handler(
829                 &detect_session::on_run,
830                 this->shared_from_this()));
831     }
832 
833     void
on_run()834     on_run()
835     {
836         // Set the timeout.
837         stream_.expires_after(std::chrono::seconds(30));
838 
839         beast::async_detect_ssl(
840             stream_,
841             buffer_,
842             beast::bind_front_handler(
843                 &detect_session::on_detect,
844                 this->shared_from_this()));
845     }
846 
847     void
on_detect(beast::error_code ec,bool result)848     on_detect(beast::error_code ec, bool result)
849     {
850         if(ec)
851             return fail(ec, "detect");
852 
853         if(result)
854         {
855             // Launch SSL session
856             std::make_shared<ssl_http_session>(
857                 std::move(stream_),
858                 ctx_,
859                 std::move(buffer_),
860                 doc_root_)->run();
861             return;
862         }
863 
864         // Launch plain session
865         std::make_shared<plain_http_session>(
866             std::move(stream_),
867             std::move(buffer_),
868             doc_root_)->run();
869     }
870 };
871 
872 // Accepts incoming connections and launches the sessions
873 class listener : public std::enable_shared_from_this<listener>
874 {
875     net::io_context& ioc_;
876     ssl::context& ctx_;
877     tcp::acceptor acceptor_;
878     std::shared_ptr<std::string const> doc_root_;
879 
880 public:
listener(net::io_context & ioc,ssl::context & ctx,tcp::endpoint endpoint,std::shared_ptr<std::string const> const & doc_root)881     listener(
882         net::io_context& ioc,
883         ssl::context& ctx,
884         tcp::endpoint endpoint,
885         std::shared_ptr<std::string const> const& doc_root)
886         : ioc_(ioc)
887         , ctx_(ctx)
888         , acceptor_(net::make_strand(ioc))
889         , doc_root_(doc_root)
890     {
891         beast::error_code ec;
892 
893         // Open the acceptor
894         acceptor_.open(endpoint.protocol(), ec);
895         if(ec)
896         {
897             fail(ec, "open");
898             return;
899         }
900 
901         // Allow address reuse
902         acceptor_.set_option(net::socket_base::reuse_address(true), ec);
903         if(ec)
904         {
905             fail(ec, "set_option");
906             return;
907         }
908 
909         // Bind to the server address
910         acceptor_.bind(endpoint, ec);
911         if(ec)
912         {
913             fail(ec, "bind");
914             return;
915         }
916 
917         // Start listening for connections
918         acceptor_.listen(
919             net::socket_base::max_listen_connections, ec);
920         if(ec)
921         {
922             fail(ec, "listen");
923             return;
924         }
925     }
926 
927     // Start accepting incoming connections
928     void
run()929     run()
930     {
931         do_accept();
932     }
933 
934 private:
935     void
do_accept()936     do_accept()
937     {
938         // The new connection gets its own strand
939         acceptor_.async_accept(
940             net::make_strand(ioc_),
941             beast::bind_front_handler(
942                 &listener::on_accept,
943                 shared_from_this()));
944     }
945 
946     void
on_accept(beast::error_code ec,tcp::socket socket)947     on_accept(beast::error_code ec, tcp::socket socket)
948     {
949         if(ec)
950         {
951             fail(ec, "accept");
952         }
953         else
954         {
955             // Create the detector http_session and run it
956             std::make_shared<detect_session>(
957                 std::move(socket),
958                 ctx_,
959                 doc_root_)->run();
960         }
961 
962         // Accept another connection
963         do_accept();
964     }
965 };
966 
967 //------------------------------------------------------------------------------
968 
main(int argc,char * argv[])969 int main(int argc, char* argv[])
970 {
971     // Check command line arguments.
972     if (argc != 5)
973     {
974         std::cerr <<
975             "Usage: advanced-server-flex <address> <port> <doc_root> <threads>\n" <<
976             "Example:\n" <<
977             "    advanced-server-flex 0.0.0.0 8080 . 1\n";
978         return EXIT_FAILURE;
979     }
980     auto const address = net::ip::make_address(argv[1]);
981     auto const port = static_cast<unsigned short>(std::atoi(argv[2]));
982     auto const doc_root = std::make_shared<std::string>(argv[3]);
983     auto const threads = std::max<int>(1, std::atoi(argv[4]));
984 
985     // The io_context is required for all I/O
986     net::io_context ioc{threads};
987 
988     // The SSL context is required, and holds certificates
989     ssl::context ctx{ssl::context::tlsv12};
990 
991     // This holds the self-signed certificate used by the server
992     load_server_certificate(ctx);
993 
994     // Create and launch a listening port
995     std::make_shared<listener>(
996         ioc,
997         ctx,
998         tcp::endpoint{address, port},
999         doc_root)->run();
1000 
1001     // Capture SIGINT and SIGTERM to perform a clean shutdown
1002     net::signal_set signals(ioc, SIGINT, SIGTERM);
1003     signals.async_wait(
1004         [&](beast::error_code const&, int)
1005         {
1006             // Stop the `io_context`. This will cause `run()`
1007             // to return immediately, eventually destroying the
1008             // `io_context` and all of the sockets in it.
1009             ioc.stop();
1010         });
1011 
1012     // Run the I/O service on the requested number of threads
1013     std::vector<std::thread> v;
1014     v.reserve(threads - 1);
1015     for(auto i = threads - 1; i > 0; --i)
1016         v.emplace_back(
1017         [&ioc]
1018         {
1019             ioc.run();
1020         });
1021     ioc.run();
1022 
1023     // (If we get here, it means we got a SIGINT or SIGTERM)
1024 
1025     // Block until all the threads exit
1026     for(auto& t : v)
1027         t.join();
1028 
1029     return EXIT_SUCCESS;
1030 }
1031