1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <host/libs/websocket/websocket_server.h>
18
19 #include <string>
20 #include <unordered_map>
21
22 #include <android-base/logging.h>
23 #include <libwebsockets.h>
24
25 #include <host/libs/websocket/websocket_handler.h>
26
27 namespace cuttlefish {
WebSocketServer(const char * protocol_name,const std::string & certs_dir,const std::string & assets_dir,int server_port)28 WebSocketServer::WebSocketServer(
29 const char* protocol_name,
30 const std::string &certs_dir,
31 const std::string &assets_dir,
32 int server_port) {
33 std::string cert_file = certs_dir + "/server.crt";
34 std::string key_file = certs_dir + "/server.key";
35
36 retry_ = {
37 .secs_since_valid_ping = 3,
38 .secs_since_valid_hangup = 10,
39 };
40
41 struct lws_protocols protocols[] = {
42 {protocol_name, ServerCallback, 4096, 0, 0, nullptr, 0},
43 {nullptr, nullptr, 0, 0, 0, nullptr, 0}};
44
45 mount_ = {
46 .mount_next = nullptr,
47 .mountpoint = "/",
48 .mountpoint_len = 1,
49 .origin = assets_dir.c_str(),
50 .def = "index.html",
51 .protocol = nullptr,
52 .cgienv = nullptr,
53 .extra_mimetypes = nullptr,
54 .interpret = nullptr,
55 .cgi_timeout = 0,
56 .cache_max_age = 0,
57 .auth_mask = 0,
58 .cache_reusable = 0,
59 .cache_revalidate = 0,
60 .cache_intermediaries = 0,
61 .origin_protocol = LWSMPRO_FILE, // files in a dir
62 .basic_auth_login_file = nullptr,
63 };
64
65 struct lws_context_creation_info info;
66 headers_ = {NULL, NULL,
67 "content-security-policy:",
68 "default-src 'self'; "
69 "style-src 'self' https://fonts.googleapis.com/; "
70 "font-src https://fonts.gstatic.com/; "};
71
72 memset(&info, 0, sizeof info);
73 info.port = server_port;
74 info.mounts = &mount_;
75 info.protocols = protocols;
76 info.vhost_name = "localhost";
77 info.ws_ping_pong_interval = 10;
78 info.headers = &headers_;
79 info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
80 info.ssl_cert_filepath = cert_file.c_str();
81 info.ssl_private_key_filepath = key_file.c_str();
82 info.retry_and_idle_policy = &retry_;
83
84 context_ = lws_create_context(&info);
85 if (!context_) {
86 LOG(FATAL) << "Failed to create websocket context";
87 }
88 }
89
RegisterHandlerFactory(const std::string & path,std::unique_ptr<WebSocketHandlerFactory> handler_factory_p)90 void WebSocketServer::RegisterHandlerFactory(
91 const std::string &path,
92 std::unique_ptr<WebSocketHandlerFactory> handler_factory_p) {
93 handler_factories_[path] = std::move(handler_factory_p);
94 }
95
Serve()96 void WebSocketServer::Serve() {
97 int n = 0;
98 while (n >= 0) {
99 n = lws_service(context_, 0);
100 }
101 lws_context_destroy(context_);
102 }
103
104 std::unordered_map<struct lws*, std::shared_ptr<WebSocketHandler>> WebSocketServer::handlers_ = {};
105 std::unordered_map<std::string, std::unique_ptr<WebSocketHandlerFactory>>
106 WebSocketServer::handler_factories_ = {};
107
GetPath(struct lws * wsi)108 std::string WebSocketServer::GetPath(struct lws* wsi) {
109 auto len = lws_hdr_total_length(wsi, WSI_TOKEN_GET_URI);
110 std::string path(len + 1, '\0');
111 auto ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_GET_URI);
112 if (ret <= 0) {
113 len = lws_hdr_total_length(wsi, WSI_TOKEN_HTTP_COLON_PATH);
114 path.resize(len + 1, '\0');
115 ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_HTTP_COLON_PATH);
116 }
117 if (ret < 0) {
118 LOG(FATAL) << "Something went wrong getting the path";
119 }
120 path.resize(len);
121 return path;
122 }
123
ServerCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)124 int WebSocketServer::ServerCallback(struct lws* wsi, enum lws_callback_reasons reason,
125 void* user, void* in, size_t len) {
126 switch (reason) {
127 case LWS_CALLBACK_ESTABLISHED: {
128 auto path = GetPath(wsi);
129 auto handler = InstantiateHandler(path, wsi);
130 if (!handler) {
131 // This message came on an unexpected uri, close the connection.
132 lws_close_reason(wsi, LWS_CLOSE_STATUS_NOSTATUS, (uint8_t*)"404", 3);
133 return -1;
134 }
135 handlers_[wsi] = handler;
136 handler->OnConnected();
137 break;
138 }
139 case LWS_CALLBACK_CLOSED: {
140 auto handler = handlers_[wsi];
141 if (handler) {
142 handler->OnClosed();
143 handlers_.erase(wsi);
144 }
145 break;
146 }
147 case LWS_CALLBACK_SERVER_WRITEABLE: {
148 auto handler = handlers_[wsi];
149 if (handler) {
150 auto should_close = handler->OnWritable();
151 if (should_close) {
152 lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, nullptr, 0);
153 return 1;
154 }
155 } else {
156 LOG(WARNING) << "Unknown wsi became writable";
157 return -1;
158 }
159 break;
160 }
161 case LWS_CALLBACK_RECEIVE: {
162 auto handler = handlers_[wsi];
163 if (handler) {
164 bool is_final = (lws_remaining_packet_payload(wsi) == 0) &&
165 lws_is_final_fragment(wsi);
166 handler->OnReceive(reinterpret_cast<const uint8_t*>(in), len,
167 lws_frame_is_binary(wsi), is_final);
168 } else {
169 LOG(WARNING) << "Unkwnown wsi sent data";
170 }
171 break;
172 }
173 default:
174 return lws_callback_http_dummy(wsi, reason, user, in, len);
175 }
176 return 0;
177 }
178
InstantiateHandler(const std::string & uri_path,struct lws * wsi)179 std::shared_ptr<WebSocketHandler> WebSocketServer::InstantiateHandler(
180 const std::string& uri_path, struct lws* wsi) {
181 auto it = handler_factories_.find(uri_path);
182 if (it == handler_factories_.end()) {
183 LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
184 return nullptr;
185 } else {
186 LOG(INFO) << "Creating handler for " << uri_path;
187 return it->second->Build(wsi);
188 }
189 }
190
191 } // namespace cuttlefish
192