• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <host/libs/websocket/websocket_server.h>

#include <cstring>
#include <string>
#include <unordered_map>

#include <android-base/logging.h>
#include <libwebsockets.h>

#include <host/libs/websocket/websocket_handler.h>
26 
27 namespace cuttlefish {
WebSocketServer(const char * protocol_name,const std::string & certs_dir,const std::string & assets_dir,int server_port)28 WebSocketServer::WebSocketServer(
29     const char* protocol_name,
30     const std::string &certs_dir,
31     const std::string &assets_dir,
32     int server_port) {
33   std::string cert_file = certs_dir + "/server.crt";
34   std::string key_file = certs_dir + "/server.key";
35 
36   retry_ = {
37       .secs_since_valid_ping = 3,
38       .secs_since_valid_hangup = 10,
39   };
40 
41   struct lws_protocols protocols[] = {
42       {protocol_name, ServerCallback, 4096, 0, 0, nullptr, 0},
43       {nullptr, nullptr, 0, 0, 0, nullptr, 0}};
44 
45   mount_ = {
46       .mount_next = nullptr,
47       .mountpoint = "/",
48       .mountpoint_len = 1,
49       .origin = assets_dir.c_str(),
50       .def = "index.html",
51       .protocol = nullptr,
52       .cgienv = nullptr,
53       .extra_mimetypes = nullptr,
54       .interpret = nullptr,
55       .cgi_timeout = 0,
56       .cache_max_age = 0,
57       .auth_mask = 0,
58       .cache_reusable = 0,
59       .cache_revalidate = 0,
60       .cache_intermediaries = 0,
61       .origin_protocol = LWSMPRO_FILE,  // files in a dir
62       .basic_auth_login_file = nullptr,
63   };
64 
65   struct lws_context_creation_info info;
66   headers_ = {NULL, NULL,
67     "content-security-policy:",
68       "default-src 'self'; "
69       "style-src 'self' https://fonts.googleapis.com/; "
70       "font-src  https://fonts.gstatic.com/; "};
71 
72   memset(&info, 0, sizeof info);
73   info.port = server_port;
74   info.mounts = &mount_;
75   info.protocols = protocols;
76   info.vhost_name = "localhost";
77   info.ws_ping_pong_interval = 10;
78   info.headers = &headers_;
79   info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
80   info.ssl_cert_filepath = cert_file.c_str();
81   info.ssl_private_key_filepath = key_file.c_str();
82   info.retry_and_idle_policy = &retry_;
83 
84   context_ = lws_create_context(&info);
85   if (!context_) {
86     LOG(FATAL) << "Failed to create websocket context";
87   }
88 }
89 
RegisterHandlerFactory(const std::string & path,std::unique_ptr<WebSocketHandlerFactory> handler_factory_p)90 void WebSocketServer::RegisterHandlerFactory(
91     const std::string &path,
92     std::unique_ptr<WebSocketHandlerFactory> handler_factory_p) {
93   handler_factories_[path] = std::move(handler_factory_p);
94 }
95 
Serve()96 void WebSocketServer::Serve() {
97   int n = 0;
98   while (n >= 0) {
99     n = lws_service(context_, 0);
100   }
101   lws_context_destroy(context_);
102 }
103 
104 std::unordered_map<struct lws*, std::shared_ptr<WebSocketHandler>> WebSocketServer::handlers_ = {};
105 std::unordered_map<std::string, std::unique_ptr<WebSocketHandlerFactory>>
106     WebSocketServer::handler_factories_ = {};
107 
GetPath(struct lws * wsi)108 std::string WebSocketServer::GetPath(struct lws* wsi) {
109   auto len = lws_hdr_total_length(wsi, WSI_TOKEN_GET_URI);
110   std::string path(len + 1, '\0');
111   auto ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_GET_URI);
112   if (ret <= 0) {
113       len = lws_hdr_total_length(wsi, WSI_TOKEN_HTTP_COLON_PATH);
114       path.resize(len + 1, '\0');
115       ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_HTTP_COLON_PATH);
116   }
117   if (ret < 0) {
118     LOG(FATAL) << "Something went wrong getting the path";
119   }
120   path.resize(len);
121   return path;
122 }
123 
ServerCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)124 int WebSocketServer::ServerCallback(struct lws* wsi, enum lws_callback_reasons reason,
125                                     void* user, void* in, size_t len) {
126   switch (reason) {
127     case LWS_CALLBACK_ESTABLISHED: {
128       auto path = GetPath(wsi);
129       auto handler = InstantiateHandler(path, wsi);
130       if (!handler) {
131         // This message came on an unexpected uri, close the connection.
132         lws_close_reason(wsi, LWS_CLOSE_STATUS_NOSTATUS, (uint8_t*)"404", 3);
133         return -1;
134       }
135       handlers_[wsi] = handler;
136       handler->OnConnected();
137       break;
138     }
139     case LWS_CALLBACK_CLOSED: {
140       auto handler = handlers_[wsi];
141       if (handler) {
142         handler->OnClosed();
143         handlers_.erase(wsi);
144       }
145       break;
146     }
147     case LWS_CALLBACK_SERVER_WRITEABLE: {
148       auto handler = handlers_[wsi];
149       if (handler) {
150         auto should_close = handler->OnWritable();
151         if (should_close) {
152           lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, nullptr, 0);
153           return 1;
154         }
155       } else {
156         LOG(WARNING) << "Unknown wsi became writable";
157         return -1;
158       }
159       break;
160     }
161     case LWS_CALLBACK_RECEIVE: {
162       auto handler = handlers_[wsi];
163       if (handler) {
164         bool is_final = (lws_remaining_packet_payload(wsi) == 0) &&
165                         lws_is_final_fragment(wsi);
166         handler->OnReceive(reinterpret_cast<const uint8_t*>(in), len,
167                            lws_frame_is_binary(wsi), is_final);
168       } else {
169         LOG(WARNING) << "Unkwnown wsi sent data";
170       }
171       break;
172     }
173     default:
174       return lws_callback_http_dummy(wsi, reason, user, in, len);
175   }
176   return 0;
177 }
178 
InstantiateHandler(const std::string & uri_path,struct lws * wsi)179 std::shared_ptr<WebSocketHandler> WebSocketServer::InstantiateHandler(
180     const std::string& uri_path, struct lws* wsi) {
181   auto it = handler_factories_.find(uri_path);
182   if (it == handler_factories_.end()) {
183     LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
184     return nullptr;
185   } else {
186     LOG(INFO) << "Creating handler for " << uri_path;
187     return it->second->Build(wsi);
188   }
189 }
190 
191 }  // namespace cuttlefish
192