• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
15 
16 #include "host/frontend/webrtc/lib/ws_connection.h"
17 
18 #include <android-base/logging.h>
19 #include <libwebsockets.h>
20 
21 class WsConnectionContextImpl;
22 
23 class WsConnectionImpl : public WsConnection,
24                          public std::enable_shared_from_this<WsConnectionImpl> {
25  public:
26   struct CreateConnectionSul {
27     lws_sorted_usec_list_t sul = {};
28     std::weak_ptr<WsConnectionImpl> weak_this;
29   };
30 
31   WsConnectionImpl(
32       int port, const std::string& addr, const std::string& path,
33       Security secure,
34       const std::vector<std::pair<std::string, std::string>>& headers,
35       std::weak_ptr<WsConnectionObserver> observer,
36       std::shared_ptr<WsConnectionContextImpl> context);
37 
38   ~WsConnectionImpl() override;
39 
40   void Connect() override;
41   void ConnectInner();
42 
43   bool Send(const uint8_t* data, size_t len, bool binary = false) override;
44 
45   void OnError(const std::string& error);
46   void OnReceive(const uint8_t* data, size_t len, bool is_binary);
47   void OnOpen();
48   void OnClose();
49   void OnWriteable();
50 
51   void AddHttpHeaders(unsigned char** p, unsigned char* end) const;
52 
53  private:
54   struct WsBuffer {
55     WsBuffer() = default;
WsBufferWsConnectionImpl::WsBuffer56     WsBuffer(const uint8_t* data, size_t len, bool binary)
57         : buffer_(LWS_PRE + len), is_binary_(binary) {
58       memcpy(&buffer_[LWS_PRE], data, len);
59     }
60 
dataWsConnectionImpl::WsBuffer61     uint8_t* data() { return &buffer_[LWS_PRE]; }
is_binaryWsConnectionImpl::WsBuffer62     bool is_binary() const { return is_binary_; }
sizeWsConnectionImpl::WsBuffer63     size_t size() const { return buffer_.size() - LWS_PRE; }
64 
65    private:
66     std::vector<uint8_t> buffer_;
67     bool is_binary_;
68   };
69 
70   CreateConnectionSul extended_sul_;
71   struct lws* wsi_;
72   const int port_;
73   const std::string addr_;
74   const std::string path_;
75   const Security security_;
76   const std::vector<std::pair<std::string, std::string>> headers_;
77 
78   std::weak_ptr<WsConnectionObserver> observer_;
79 
80   // each element contains the data to be sent and whether it's binary or not
81   std::deque<WsBuffer> write_queue_;
82   std::mutex write_queue_mutex_;
83   // The connection object should not outlive the context object. This reference
84   // guarantees it.
85   std::shared_ptr<WsConnectionContextImpl> context_;
86 };
87 
88 class WsConnectionContextImpl
89     : public WsConnectionContext,
90       public std::enable_shared_from_this<WsConnectionContextImpl> {
91  public:
92   WsConnectionContextImpl(struct lws_context* lws_ctx);
93   ~WsConnectionContextImpl() override;
94 
95   std::shared_ptr<WsConnection> CreateConnection(
96       int port, const std::string& addr, const std::string& path,
97       WsConnection::Security secure,
98       std::weak_ptr<WsConnectionObserver> observer,
99       const std::vector<std::pair<std::string, std::string>>& headers) override;
100 
101   void RememberConnection(void*, std::weak_ptr<WsConnectionImpl>);
102   void ForgetConnection(void*);
103   std::shared_ptr<WsConnectionImpl> GetConnection(void*);
104 
lws_context()105   struct lws_context* lws_context() {
106     return lws_context_;
107   }
108 
109  private:
110   void Start();
111 
112   std::map<void*, std::weak_ptr<WsConnectionImpl>> weak_by_ptr_;
113   std::mutex map_mutex_;
114   struct lws_context* lws_context_;
115   std::thread message_loop_;
116 };
117 
118 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
119                 void* in, size_t len);
120 void CreateConnectionCallback(lws_sorted_usec_list_t* sul);
121 
122 namespace {
123 
124 constexpr char kProtocolName[] = "cf-webrtc-device";
125 constexpr int kBufferSize = 65536;
126 
127 const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000};
128 
129 const lws_retry_bo_t kRetry = {
130     .retry_ms_table = backoff_ms,
131     .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms),
132     .conceal_count = LWS_ARRAY_SIZE(backoff_ms),
133 
134     .secs_since_valid_ping = 3,    /* force PINGs after secs idle */
135     .secs_since_valid_hangup = 10, /* hangup after secs idle */
136 
137     .jitter_percent = 20,
138 };
139 
140 const struct lws_protocols kProtocols[2] = {
141     {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0},
142     {NULL, NULL, 0, 0, 0, NULL, 0}};
143 
144 }  // namespace
145 
Create()146 std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() {
147   struct lws_context_creation_info context_info = {};
148   context_info.port = CONTEXT_PORT_NO_LISTEN;
149   context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
150   context_info.protocols = kProtocols;
151   struct lws_context* lws_ctx = lws_create_context(&context_info);
152   if (!lws_ctx) {
153     return nullptr;
154   }
155   return std::shared_ptr<WsConnectionContext>(
156       new WsConnectionContextImpl(lws_ctx));
157 }
158 
WsConnectionContextImpl(struct lws_context * lws_ctx)159 WsConnectionContextImpl::WsConnectionContextImpl(struct lws_context* lws_ctx)
160     : lws_context_(lws_ctx) {
161   Start();
162 }
163 
~WsConnectionContextImpl()164 WsConnectionContextImpl::~WsConnectionContextImpl() {
165   lws_context_destroy(lws_context_);
166   if (message_loop_.joinable()) {
167     message_loop_.join();
168   }
169 }
170 
Start()171 void WsConnectionContextImpl::Start() {
172   message_loop_ = std::thread([this]() {
173     for (;;) {
174       if (lws_service(lws_context_, 0) < 0) {
175         break;
176       }
177     }
178   });
179 }
180 
CreateConnection(int port,const std::string & addr,const std::string & path,WsConnection::Security security,std::weak_ptr<WsConnectionObserver> observer,const std::vector<std::pair<std::string,std::string>> & headers)181 std::shared_ptr<WsConnection> WsConnectionContextImpl::CreateConnection(
182     int port, const std::string& addr, const std::string& path,
183     WsConnection::Security security,
184     std::weak_ptr<WsConnectionObserver> observer,
185     const std::vector<std::pair<std::string, std::string>>& headers) {
186   return std::shared_ptr<WsConnection>(new WsConnectionImpl(
187       port, addr, path, security, headers, observer, shared_from_this()));
188 }
189 
GetConnection(void * raw)190 std::shared_ptr<WsConnectionImpl> WsConnectionContextImpl::GetConnection(
191     void* raw) {
192   std::shared_ptr<WsConnectionImpl> connection;
193   {
194     std::lock_guard<std::mutex> lock(map_mutex_);
195     if (weak_by_ptr_.count(raw) == 0) {
196       return nullptr;
197     }
198     connection = weak_by_ptr_[raw].lock();
199     if (!connection) {
200       weak_by_ptr_.erase(raw);
201     }
202   }
203   return connection;
204 }
205 
RememberConnection(void * raw,std::weak_ptr<WsConnectionImpl> conn)206 void WsConnectionContextImpl::RememberConnection(
207     void* raw, std::weak_ptr<WsConnectionImpl> conn) {
208   std::lock_guard<std::mutex> lock(map_mutex_);
209   weak_by_ptr_.emplace(
210       std::pair<void*, std::weak_ptr<WsConnectionImpl>>(raw, conn));
211 }
212 
ForgetConnection(void * raw)213 void WsConnectionContextImpl::ForgetConnection(void* raw) {
214   std::lock_guard<std::mutex> lock(map_mutex_);
215   weak_by_ptr_.erase(raw);
216 }
217 
WsConnectionImpl(int port,const std::string & addr,const std::string & path,Security security,const std::vector<std::pair<std::string,std::string>> & headers,std::weak_ptr<WsConnectionObserver> observer,std::shared_ptr<WsConnectionContextImpl> context)218 WsConnectionImpl::WsConnectionImpl(
219     int port, const std::string& addr, const std::string& path,
220     Security security,
221     const std::vector<std::pair<std::string, std::string>>& headers,
222     std::weak_ptr<WsConnectionObserver> observer,
223     std::shared_ptr<WsConnectionContextImpl> context)
224     : port_(port),
225       addr_(addr),
226       path_(path),
227       security_(security),
228       headers_(headers),
229       observer_(observer),
230       context_(context) {}
231 
~WsConnectionImpl()232 WsConnectionImpl::~WsConnectionImpl() {
233   context_->ForgetConnection(this);
234   // This will cause the callback to be called which will drop the connection
235   // after seeing the context doesn't remember this object
236   lws_callback_on_writable(wsi_);
237 }
238 
Connect()239 void WsConnectionImpl::Connect() {
240   memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul));
241   extended_sul_.weak_this = weak_from_this();
242   lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul,
243                    CreateConnectionCallback, 1);
244 }
245 
AddHttpHeaders(unsigned char ** p,unsigned char * end) const246 void WsConnectionImpl::AddHttpHeaders(unsigned char** p,
247                                       unsigned char* end) const {
248   for (const auto& header_entry: headers_) {
249     const auto& name = header_entry.first;
250     const auto& value = header_entry.second;
251     auto res = lws_add_http_header_by_name(
252         wsi_, reinterpret_cast<const unsigned char*>(name.c_str()),
253         reinterpret_cast<const unsigned char*>(value.c_str()), value.size(), p,
254         end);
255     if (res != 0) {
256       LOG(ERROR) << "Unable to add header: " << name;
257     }
258   }
259   if (!headers_.empty()) {
260     // Let LWS know we added some headers.
261     lws_client_http_body_pending(wsi_, 1);
262   }
263 }
264 
OnError(const std::string & error)265 void WsConnectionImpl::OnError(const std::string& error) {
266   auto observer = observer_.lock();
267   if (observer) {
268     observer->OnError(error);
269   }
270 }
OnReceive(const uint8_t * data,size_t len,bool is_binary)271 void WsConnectionImpl::OnReceive(const uint8_t* data, size_t len,
272                                  bool is_binary) {
273   auto observer = observer_.lock();
274   if (observer) {
275     observer->OnReceive(data, len, is_binary);
276   }
277 }
OnOpen()278 void WsConnectionImpl::OnOpen() {
279   auto observer = observer_.lock();
280   if (observer) {
281     observer->OnOpen();
282   }
283 }
OnClose()284 void WsConnectionImpl::OnClose() {
285   auto observer = observer_.lock();
286   if (observer) {
287     observer->OnClose();
288   }
289 }
290 
OnWriteable()291 void WsConnectionImpl::OnWriteable() {
292   WsBuffer buffer;
293   {
294     std::lock_guard<std::mutex> lock(write_queue_mutex_);
295     if (write_queue_.size() == 0) {
296       return;
297     }
298     buffer = std::move(write_queue_.front());
299     write_queue_.pop_front();
300   }
301   auto flags = lws_write_ws_flags(
302       buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true);
303   auto res = lws_write(wsi_, buffer.data(), buffer.size(),
304                        (enum lws_write_protocol)flags);
305   if (res != buffer.size()) {
306     LOG(WARNING) << "Unable to send the entire message!";
307   }
308 }
309 
Send(const uint8_t * data,size_t len,bool binary)310 bool WsConnectionImpl::Send(const uint8_t* data, size_t len, bool binary) {
311   if (!wsi_) {
312     LOG(WARNING) << "Send called on an uninitialized connection!!";
313     return false;
314   }
315   WsBuffer buffer(data, len, binary);
316   {
317     std::lock_guard<std::mutex> lock(write_queue_mutex_);
318     write_queue_.emplace_back(std::move(buffer));
319   }
320 
321   lws_callback_on_writable(wsi_);
322   return true;
323 }
324 
LwsCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)325 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
326                 void* in, size_t len) {
327   constexpr int DROP = -1;
328   constexpr int OK = 0;
329 
330   // For some values of `reason`, `user` doesn't point to the value provided
331   // when the connection was created. This function object should be used with
332   // care.
333   auto with_connection =
334       [wsi, user](std::function<void(std::shared_ptr<WsConnectionImpl>)> cb) {
335         auto context = reinterpret_cast<WsConnectionContextImpl*>(user);
336         auto connection = context->GetConnection(wsi);
337         if (!connection) {
338           return DROP;
339         }
340         cb(connection);
341         return OK;
342       };
343 
344   switch (reason) {
345     case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
346       return with_connection(
347           [in](std::shared_ptr<WsConnectionImpl> connection) {
348             connection->OnError(in ? (char*)in : "(null)");
349           });
350 
351     case LWS_CALLBACK_CLIENT_RECEIVE:
352       return with_connection(
353           [in, len, wsi](std::shared_ptr<WsConnectionImpl> connection) {
354             connection->OnReceive((const uint8_t*)in, len,
355                                   lws_frame_is_binary(wsi));
356           });
357 
358     case LWS_CALLBACK_CLIENT_ESTABLISHED:
359       return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
360         connection->OnOpen();
361       });
362 
363     case LWS_CALLBACK_CLIENT_CLOSED:
364       return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
365         connection->OnClose();
366       });
367 
368     case LWS_CALLBACK_CLIENT_WRITEABLE:
369       return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
370         connection->OnWriteable();
371       });
372 
373     case LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER:
374       return with_connection(
375           [in, len](std::shared_ptr<WsConnectionImpl> connection) {
376             auto p = reinterpret_cast<unsigned char**>(in);
377             auto end = (*p) + len;
378             connection->AddHttpHeaders(p, end);
379           });
380 
381     case LWS_CALLBACK_CLIENT_HTTP_WRITEABLE:
382       // This callback is only called when we add additional HTTP headers, let
383       // LWS know we're done modifying the HTTP request.
384       lws_client_http_body_pending(wsi, 0);
385       return 0;
386 
387     default:
388       LOG(VERBOSE) << "Unhandled value: " << reason;
389       return lws_callback_http_dummy(wsi, reason, user, in, len);
390   }
391 }
392 
CreateConnectionCallback(lws_sorted_usec_list_t * sul)393 void CreateConnectionCallback(lws_sorted_usec_list_t* sul) {
394   std::shared_ptr<WsConnectionImpl> connection =
395       reinterpret_cast<WsConnectionImpl::CreateConnectionSul*>(sul)
396           ->weak_this.lock();
397   if (!connection) {
398     LOG(WARNING) << "The object was already destroyed by the time of the first "
399                  << "connection attempt. That's unusual.";
400     return;
401   }
402   connection->ConnectInner();
403 }
404 
ConnectInner()405 void WsConnectionImpl::ConnectInner() {
406   struct lws_client_connect_info connect_info;
407 
408   memset(&connect_info, 0, sizeof(connect_info));
409 
410   connect_info.context = context_->lws_context();
411   connect_info.port = port_;
412   connect_info.address = addr_.c_str();
413   connect_info.path = path_.c_str();
414   connect_info.host = connect_info.address;
415   connect_info.origin = connect_info.address;
416   switch (security_) {
417     case Security::kAllowSelfSigned:
418       connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED |
419                                     LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK |
420                                     LCCSCF_USE_SSL;
421       break;
422     case Security::kStrict:
423       connect_info.ssl_connection = LCCSCF_USE_SSL;
424       break;
425     case Security::kInsecure:
426       connect_info.ssl_connection = 0;
427       break;
428   }
429   connect_info.protocol = "webrtc-operator";
430   connect_info.local_protocol_name = kProtocolName;
431   connect_info.pwsi = &wsi_;
432   connect_info.retry_and_idle_policy = &kRetry;
433   // There is no guarantee the connection object still exists when the callback
434   // is called. Put the context instead as the user data which is guaranteed to
435   // still exist and holds a weak ptr to the connection.
436   connect_info.userdata = context_.get();
437 
438   if (lws_client_connect_via_info(&connect_info)) {
439     // wsi_ is not initialized until after the call to
440     // lws_client_connect_via_info(). Luckily, this is guaranteed to run before
441     // the protocol callback is called because it runs in the same loop.
442     context_->RememberConnection(wsi_, weak_from_this());
443   } else {
444     LOG(ERROR) << "Connection failed!";
445   }
446 }
447