1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16 #include "host/frontend/webrtc/lib/ws_connection.h"
17
18 #include <android-base/logging.h>
19 #include <libwebsockets.h>
20
21 class WsConnectionContextImpl;
22
23 class WsConnectionImpl : public WsConnection,
24 public std::enable_shared_from_this<WsConnectionImpl> {
25 public:
26 struct CreateConnectionSul {
27 lws_sorted_usec_list_t sul = {};
28 std::weak_ptr<WsConnectionImpl> weak_this;
29 };
30
31 WsConnectionImpl(
32 int port, const std::string& addr, const std::string& path,
33 Security secure,
34 const std::vector<std::pair<std::string, std::string>>& headers,
35 std::weak_ptr<WsConnectionObserver> observer,
36 std::shared_ptr<WsConnectionContextImpl> context);
37
38 ~WsConnectionImpl() override;
39
40 void Connect() override;
41 void ConnectInner();
42
43 bool Send(const uint8_t* data, size_t len, bool binary = false) override;
44
45 void OnError(const std::string& error);
46 void OnReceive(const uint8_t* data, size_t len, bool is_binary);
47 void OnOpen();
48 void OnClose();
49 void OnWriteable();
50
51 void AddHttpHeaders(unsigned char** p, unsigned char* end) const;
52
53 private:
54 struct WsBuffer {
55 WsBuffer() = default;
WsBufferWsConnectionImpl::WsBuffer56 WsBuffer(const uint8_t* data, size_t len, bool binary)
57 : buffer_(LWS_PRE + len), is_binary_(binary) {
58 memcpy(&buffer_[LWS_PRE], data, len);
59 }
60
dataWsConnectionImpl::WsBuffer61 uint8_t* data() { return &buffer_[LWS_PRE]; }
is_binaryWsConnectionImpl::WsBuffer62 bool is_binary() const { return is_binary_; }
sizeWsConnectionImpl::WsBuffer63 size_t size() const { return buffer_.size() - LWS_PRE; }
64
65 private:
66 std::vector<uint8_t> buffer_;
67 bool is_binary_;
68 };
69
70 CreateConnectionSul extended_sul_;
71 struct lws* wsi_;
72 const int port_;
73 const std::string addr_;
74 const std::string path_;
75 const Security security_;
76 const std::vector<std::pair<std::string, std::string>> headers_;
77
78 std::weak_ptr<WsConnectionObserver> observer_;
79
80 // each element contains the data to be sent and whether it's binary or not
81 std::deque<WsBuffer> write_queue_;
82 std::mutex write_queue_mutex_;
83 // The connection object should not outlive the context object. This reference
84 // guarantees it.
85 std::shared_ptr<WsConnectionContextImpl> context_;
86 };
87
88 class WsConnectionContextImpl
89 : public WsConnectionContext,
90 public std::enable_shared_from_this<WsConnectionContextImpl> {
91 public:
92 WsConnectionContextImpl(struct lws_context* lws_ctx);
93 ~WsConnectionContextImpl() override;
94
95 std::shared_ptr<WsConnection> CreateConnection(
96 int port, const std::string& addr, const std::string& path,
97 WsConnection::Security secure,
98 std::weak_ptr<WsConnectionObserver> observer,
99 const std::vector<std::pair<std::string, std::string>>& headers) override;
100
101 void RememberConnection(void*, std::weak_ptr<WsConnectionImpl>);
102 void ForgetConnection(void*);
103 std::shared_ptr<WsConnectionImpl> GetConnection(void*);
104
lws_context()105 struct lws_context* lws_context() {
106 return lws_context_;
107 }
108
109 private:
110 void Start();
111
112 std::map<void*, std::weak_ptr<WsConnectionImpl>> weak_by_ptr_;
113 std::mutex map_mutex_;
114 struct lws_context* lws_context_;
115 std::thread message_loop_;
116 };
117
118 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
119 void* in, size_t len);
120 void CreateConnectionCallback(lws_sorted_usec_list_t* sul);
121
122 namespace {
123
124 constexpr char kProtocolName[] = "cf-webrtc-device";
125 constexpr int kBufferSize = 65536;
126
127 const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000};
128
129 const lws_retry_bo_t kRetry = {
130 .retry_ms_table = backoff_ms,
131 .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms),
132 .conceal_count = LWS_ARRAY_SIZE(backoff_ms),
133
134 .secs_since_valid_ping = 3, /* force PINGs after secs idle */
135 .secs_since_valid_hangup = 10, /* hangup after secs idle */
136
137 .jitter_percent = 20,
138 };
139
140 const struct lws_protocols kProtocols[2] = {
141 {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0},
142 {NULL, NULL, 0, 0, 0, NULL, 0}};
143
144 } // namespace
145
Create()146 std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() {
147 struct lws_context_creation_info context_info = {};
148 context_info.port = CONTEXT_PORT_NO_LISTEN;
149 context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
150 context_info.protocols = kProtocols;
151 struct lws_context* lws_ctx = lws_create_context(&context_info);
152 if (!lws_ctx) {
153 return nullptr;
154 }
155 return std::shared_ptr<WsConnectionContext>(
156 new WsConnectionContextImpl(lws_ctx));
157 }
158
WsConnectionContextImpl(struct lws_context * lws_ctx)159 WsConnectionContextImpl::WsConnectionContextImpl(struct lws_context* lws_ctx)
160 : lws_context_(lws_ctx) {
161 Start();
162 }
163
~WsConnectionContextImpl()164 WsConnectionContextImpl::~WsConnectionContextImpl() {
165 lws_context_destroy(lws_context_);
166 if (message_loop_.joinable()) {
167 message_loop_.join();
168 }
169 }
170
Start()171 void WsConnectionContextImpl::Start() {
172 message_loop_ = std::thread([this]() {
173 for (;;) {
174 if (lws_service(lws_context_, 0) < 0) {
175 break;
176 }
177 }
178 });
179 }
180
CreateConnection(int port,const std::string & addr,const std::string & path,WsConnection::Security security,std::weak_ptr<WsConnectionObserver> observer,const std::vector<std::pair<std::string,std::string>> & headers)181 std::shared_ptr<WsConnection> WsConnectionContextImpl::CreateConnection(
182 int port, const std::string& addr, const std::string& path,
183 WsConnection::Security security,
184 std::weak_ptr<WsConnectionObserver> observer,
185 const std::vector<std::pair<std::string, std::string>>& headers) {
186 return std::shared_ptr<WsConnection>(new WsConnectionImpl(
187 port, addr, path, security, headers, observer, shared_from_this()));
188 }
189
GetConnection(void * raw)190 std::shared_ptr<WsConnectionImpl> WsConnectionContextImpl::GetConnection(
191 void* raw) {
192 std::shared_ptr<WsConnectionImpl> connection;
193 {
194 std::lock_guard<std::mutex> lock(map_mutex_);
195 if (weak_by_ptr_.count(raw) == 0) {
196 return nullptr;
197 }
198 connection = weak_by_ptr_[raw].lock();
199 if (!connection) {
200 weak_by_ptr_.erase(raw);
201 }
202 }
203 return connection;
204 }
205
RememberConnection(void * raw,std::weak_ptr<WsConnectionImpl> conn)206 void WsConnectionContextImpl::RememberConnection(
207 void* raw, std::weak_ptr<WsConnectionImpl> conn) {
208 std::lock_guard<std::mutex> lock(map_mutex_);
209 weak_by_ptr_.emplace(
210 std::pair<void*, std::weak_ptr<WsConnectionImpl>>(raw, conn));
211 }
212
ForgetConnection(void * raw)213 void WsConnectionContextImpl::ForgetConnection(void* raw) {
214 std::lock_guard<std::mutex> lock(map_mutex_);
215 weak_by_ptr_.erase(raw);
216 }
217
WsConnectionImpl(int port,const std::string & addr,const std::string & path,Security security,const std::vector<std::pair<std::string,std::string>> & headers,std::weak_ptr<WsConnectionObserver> observer,std::shared_ptr<WsConnectionContextImpl> context)218 WsConnectionImpl::WsConnectionImpl(
219 int port, const std::string& addr, const std::string& path,
220 Security security,
221 const std::vector<std::pair<std::string, std::string>>& headers,
222 std::weak_ptr<WsConnectionObserver> observer,
223 std::shared_ptr<WsConnectionContextImpl> context)
224 : port_(port),
225 addr_(addr),
226 path_(path),
227 security_(security),
228 headers_(headers),
229 observer_(observer),
230 context_(context) {}
231
~WsConnectionImpl()232 WsConnectionImpl::~WsConnectionImpl() {
233 context_->ForgetConnection(this);
234 // This will cause the callback to be called which will drop the connection
235 // after seeing the context doesn't remember this object
236 lws_callback_on_writable(wsi_);
237 }
238
Connect()239 void WsConnectionImpl::Connect() {
240 memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul));
241 extended_sul_.weak_this = weak_from_this();
242 lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul,
243 CreateConnectionCallback, 1);
244 }
245
AddHttpHeaders(unsigned char ** p,unsigned char * end) const246 void WsConnectionImpl::AddHttpHeaders(unsigned char** p,
247 unsigned char* end) const {
248 for (const auto& header_entry: headers_) {
249 const auto& name = header_entry.first;
250 const auto& value = header_entry.second;
251 auto res = lws_add_http_header_by_name(
252 wsi_, reinterpret_cast<const unsigned char*>(name.c_str()),
253 reinterpret_cast<const unsigned char*>(value.c_str()), value.size(), p,
254 end);
255 if (res != 0) {
256 LOG(ERROR) << "Unable to add header: " << name;
257 }
258 }
259 if (!headers_.empty()) {
260 // Let LWS know we added some headers.
261 lws_client_http_body_pending(wsi_, 1);
262 }
263 }
264
OnError(const std::string & error)265 void WsConnectionImpl::OnError(const std::string& error) {
266 auto observer = observer_.lock();
267 if (observer) {
268 observer->OnError(error);
269 }
270 }
OnReceive(const uint8_t * data,size_t len,bool is_binary)271 void WsConnectionImpl::OnReceive(const uint8_t* data, size_t len,
272 bool is_binary) {
273 auto observer = observer_.lock();
274 if (observer) {
275 observer->OnReceive(data, len, is_binary);
276 }
277 }
OnOpen()278 void WsConnectionImpl::OnOpen() {
279 auto observer = observer_.lock();
280 if (observer) {
281 observer->OnOpen();
282 }
283 }
OnClose()284 void WsConnectionImpl::OnClose() {
285 auto observer = observer_.lock();
286 if (observer) {
287 observer->OnClose();
288 }
289 }
290
OnWriteable()291 void WsConnectionImpl::OnWriteable() {
292 WsBuffer buffer;
293 {
294 std::lock_guard<std::mutex> lock(write_queue_mutex_);
295 if (write_queue_.size() == 0) {
296 return;
297 }
298 buffer = std::move(write_queue_.front());
299 write_queue_.pop_front();
300 }
301 auto flags = lws_write_ws_flags(
302 buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true);
303 auto res = lws_write(wsi_, buffer.data(), buffer.size(),
304 (enum lws_write_protocol)flags);
305 if (res != buffer.size()) {
306 LOG(WARNING) << "Unable to send the entire message!";
307 }
308 }
309
Send(const uint8_t * data,size_t len,bool binary)310 bool WsConnectionImpl::Send(const uint8_t* data, size_t len, bool binary) {
311 if (!wsi_) {
312 LOG(WARNING) << "Send called on an uninitialized connection!!";
313 return false;
314 }
315 WsBuffer buffer(data, len, binary);
316 {
317 std::lock_guard<std::mutex> lock(write_queue_mutex_);
318 write_queue_.emplace_back(std::move(buffer));
319 }
320
321 lws_callback_on_writable(wsi_);
322 return true;
323 }
324
LwsCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)325 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
326 void* in, size_t len) {
327 constexpr int DROP = -1;
328 constexpr int OK = 0;
329
330 // For some values of `reason`, `user` doesn't point to the value provided
331 // when the connection was created. This function object should be used with
332 // care.
333 auto with_connection =
334 [wsi, user](std::function<void(std::shared_ptr<WsConnectionImpl>)> cb) {
335 auto context = reinterpret_cast<WsConnectionContextImpl*>(user);
336 auto connection = context->GetConnection(wsi);
337 if (!connection) {
338 return DROP;
339 }
340 cb(connection);
341 return OK;
342 };
343
344 switch (reason) {
345 case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
346 return with_connection(
347 [in](std::shared_ptr<WsConnectionImpl> connection) {
348 connection->OnError(in ? (char*)in : "(null)");
349 });
350
351 case LWS_CALLBACK_CLIENT_RECEIVE:
352 return with_connection(
353 [in, len, wsi](std::shared_ptr<WsConnectionImpl> connection) {
354 connection->OnReceive((const uint8_t*)in, len,
355 lws_frame_is_binary(wsi));
356 });
357
358 case LWS_CALLBACK_CLIENT_ESTABLISHED:
359 return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
360 connection->OnOpen();
361 });
362
363 case LWS_CALLBACK_CLIENT_CLOSED:
364 return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
365 connection->OnClose();
366 });
367
368 case LWS_CALLBACK_CLIENT_WRITEABLE:
369 return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
370 connection->OnWriteable();
371 });
372
373 case LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER:
374 return with_connection(
375 [in, len](std::shared_ptr<WsConnectionImpl> connection) {
376 auto p = reinterpret_cast<unsigned char**>(in);
377 auto end = (*p) + len;
378 connection->AddHttpHeaders(p, end);
379 });
380
381 case LWS_CALLBACK_CLIENT_HTTP_WRITEABLE:
382 // This callback is only called when we add additional HTTP headers, let
383 // LWS know we're done modifying the HTTP request.
384 lws_client_http_body_pending(wsi, 0);
385 return 0;
386
387 default:
388 LOG(VERBOSE) << "Unhandled value: " << reason;
389 return lws_callback_http_dummy(wsi, reason, user, in, len);
390 }
391 }
392
CreateConnectionCallback(lws_sorted_usec_list_t * sul)393 void CreateConnectionCallback(lws_sorted_usec_list_t* sul) {
394 std::shared_ptr<WsConnectionImpl> connection =
395 reinterpret_cast<WsConnectionImpl::CreateConnectionSul*>(sul)
396 ->weak_this.lock();
397 if (!connection) {
398 LOG(WARNING) << "The object was already destroyed by the time of the first "
399 << "connection attempt. That's unusual.";
400 return;
401 }
402 connection->ConnectInner();
403 }
404
ConnectInner()405 void WsConnectionImpl::ConnectInner() {
406 struct lws_client_connect_info connect_info;
407
408 memset(&connect_info, 0, sizeof(connect_info));
409
410 connect_info.context = context_->lws_context();
411 connect_info.port = port_;
412 connect_info.address = addr_.c_str();
413 connect_info.path = path_.c_str();
414 connect_info.host = connect_info.address;
415 connect_info.origin = connect_info.address;
416 switch (security_) {
417 case Security::kAllowSelfSigned:
418 connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED |
419 LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK |
420 LCCSCF_USE_SSL;
421 break;
422 case Security::kStrict:
423 connect_info.ssl_connection = LCCSCF_USE_SSL;
424 break;
425 case Security::kInsecure:
426 connect_info.ssl_connection = 0;
427 break;
428 }
429 connect_info.protocol = "webrtc-operator";
430 connect_info.local_protocol_name = kProtocolName;
431 connect_info.pwsi = &wsi_;
432 connect_info.retry_and_idle_policy = &kRetry;
433 // There is no guarantee the connection object still exists when the callback
434 // is called. Put the context instead as the user data which is guaranteed to
435 // still exist and holds a weak ptr to the connection.
436 connect_info.userdata = context_.get();
437
438 if (lws_client_connect_via_info(&connect_info)) {
439 // wsi_ is not initialized until after the call to
440 // lws_client_connect_via_info(). Luckily, this is guaranteed to run before
441 // the protocol callback is called because it runs in the same loop.
442 context_->RememberConnection(wsi_, weak_from_this());
443 } else {
444 LOG(ERROR) << "Connection failed!";
445 }
446 }
447