1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <cstring>
17 #include <iostream>
18 #include <securec.h>
19 #include <string>
20
21 #include "netstack_log.h"
22 #include "websocket_client_innerapi.h"
23
// --- URL / HTTP formatting ---
static constexpr const char *PATH_START = "/";   // prepended to the path produced by ParseUrl
static constexpr const char *NAME_END = ":";     // appended to header names given to lws_add_http_header_by_name
static constexpr const char *STATUS_LINE_SEP = " ";
static constexpr size_t STATUS_LINE_ELEM_NUM = 2; // status line is split into at most two parts
static constexpr const char *PREFIX_HTTPS = "https";
static constexpr const char *PREFIX_WSS = "wss";

// --- Size limits ---
static constexpr int MAX_URI_LENGTH = 1024;       // size of each URL-part buffer used by ParseUrl
static constexpr int MAX_HDR_LENGTH = 1024;       // size of the status-line copy buffer
static constexpr int MAX_HEADER_LENGTH = 8192;    // NOTE: Connect compares this with the header *count*
static constexpr size_t MAX_DATA_LENGTH = 4 * 1024 * 1024; // largest payload accepted by Send()
static constexpr int FD_LIMIT_PER_THREAD = 1 + 1 + 1;

// --- Close handling ---
static constexpr int CLOSE_RESULT_FROM_SERVER_CODE = 1001;
static constexpr int CLOSE_RESULT_FROM_CLIENT_CODE = 1000;
static constexpr const char *LINK_DOWN = "The link is down";
static constexpr const char *CLOSE_REASON_FORM_SERVER = "websocket close from server";

static constexpr int FUNCTION_PARAM_TWO = 2;

// Process-wide id source; pre-incremented once per constructed WebSocketClient.
static std::atomic<int> g_clientID{0};
41 namespace OHOS::NetStack::WebSocketClient {
/*
 * Retry / idle policy handed to libwebsockets for every client connection
 * (assigned to connectInfo.retry_and_idle_policy in CreatConnectInfo).
 * Fields not named here are zero-initialized.
 * NOTE(review): secs_since_valid_ping is 0; confirm against the bundled lws version
 * whether 0 disables keep-alive PINGs rather than forcing them.
 */
static const lws_retry_bo_t RETRY = {
    .secs_since_valid_ping = 0, /* force PINGs after secs idle */
    .secs_since_valid_hangup = 10, /* hangup after secs idle */
    .jitter_percent = 20, /* presumably randomizes retry timing by up to 20% -- confirm in lws docs */
};
47
WebSocketClient()48 WebSocketClient::WebSocketClient()
49 {
50 clientContext = new ClientContext();
51 clientContext->SetClientId(++g_clientID);
52 }
53
~WebSocketClient()54 WebSocketClient::~WebSocketClient()
55 {
56 delete clientContext;
57 clientContext = nullptr;
58 }
59
GetClientContext() const60 ClientContext *WebSocketClient::GetClientContext() const
61 {
62 return clientContext;
63 }
64
RunService(WebSocketClient * Client)65 void RunService(WebSocketClient *Client)
66 {
67 if (Client->GetClientContext()->GetContext() == nullptr) {
68 return;
69 }
70 while (!Client->GetClientContext()->IsThreadStop()) {
71 lws_service(Client->GetClientContext()->GetContext(), 0);
72 }
73 }
74
HttpDummy(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)75 int HttpDummy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
76 {
77 int ret = lws_callback_http_dummy(wsi, reason, user, in, len);
78 return ret;
79 }
80
81 struct CallbackDispatcher {
82 lws_callback_reasons reason;
83 int (*callback)(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len);
84 };
85
LwsCallbackClientAppendHandshakeHeader(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)86 int LwsCallbackClientAppendHandshakeHeader(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
87 {
88 WebSocketClient *client = static_cast<WebSocketClient *>(user);
89 if (client->GetClientContext() == nullptr) {
90 NETSTACK_LOGE("Callback ClientContext is nullptr");
91 return -1;
92 }
93 NETSTACK_LOGD("ClientId:%{public}d, Lws Callback AppendHandshakeHeader,",
94 client->GetClientContext()->GetClientId());
95 auto payload = reinterpret_cast<unsigned char **>(in);
96 if (payload == nullptr || (*payload) == nullptr || len == 0) {
97 return -1;
98 }
99 auto payloadEnd = (*payload) + len;
100 for (const auto &pair : client->GetClientContext()->header) {
101 std::string name = pair.first + NAME_END;
102 if (lws_add_http_header_by_name(wsi, reinterpret_cast<const unsigned char *>(name.c_str()),
103 reinterpret_cast<const unsigned char *>(pair.second.c_str()),
104 static_cast<int>(strlen(pair.second.c_str())), payload, payloadEnd)) {
105 return -1;
106 }
107 }
108 return HttpDummy(wsi, reason, user, in, len);
109 }
110
LwsCallbackWsPeerInitiatedClose(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)111 int LwsCallbackWsPeerInitiatedClose(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
112 {
113 WebSocketClient *client = static_cast<WebSocketClient *>(user);
114 if (client->GetClientContext() == nullptr) {
115 NETSTACK_LOGE("Lws Callback ClientContext is nullptr");
116 return -1;
117 }
118 NETSTACK_LOGD("ClientId:%{public}d,Callback WsPeerInitiatedClose", client->GetClientContext()->GetClientId());
119 if (in == nullptr || len < sizeof(uint16_t)) {
120 NETSTACK_LOGE("Lws Callback WsPeerInitiatedClose");
121 client->GetClientContext()->Close(LWS_CLOSE_STATUS_NORMAL, "");
122 return HttpDummy(wsi, reason, user, in, len);
123 }
124 uint16_t closeStatus = ntohs(*reinterpret_cast<uint16_t *>(in));
125 std::string closeReason;
126 closeReason.append(reinterpret_cast<char *>(in) + sizeof(uint16_t), len - sizeof(uint16_t));
127 client->GetClientContext()->Close(static_cast<lws_close_status>(closeStatus), closeReason);
128 return HttpDummy(wsi, reason, user, in, len);
129 }
130
LwsCallbackClientWritable(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)131 int LwsCallbackClientWritable(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
132 {
133 WebSocketClient *client = static_cast<WebSocketClient *>(user);
134 if (client->GetClientContext() == nullptr) {
135 NETSTACK_LOGE("Lws Callback ClientContext is nullptr");
136 return -1;
137 }
138 NETSTACK_LOGD("ClientId:%{public}d,Callback CallbackClientWritable,",
139 client->GetClientContext()->GetClientId());
140 if (client->GetClientContext()->IsClosed()) {
141 NETSTACK_LOGD("ClientId:%{public}d,Callback ClientWritable need to close",
142 client->GetClientContext()->GetClientId());
143 lws_close_reason(
144 wsi, client->GetClientContext()->closeStatus,
145 reinterpret_cast<unsigned char *>(const_cast<char *>(client->GetClientContext()->closeReason.c_str())),
146 strlen(client->GetClientContext()->closeReason.c_str()));
147 // here do not emit error, because we close it
148 return -1;
149 }
150 SendData sendData = client->GetClientContext()->Pop();
151 if (sendData.data == nullptr || sendData.length == 0) {
152 return HttpDummy(wsi, reason, user, in, len);
153 }
154 const char *message = sendData.data;
155 size_t messageLen = strlen(message);
156 auto buffer = std::make_unique<unsigned char[]>(LWS_PRE + messageLen);
157 if (buffer == nullptr) {
158 return -1;
159 }
160 int result = memcpy_s(buffer.get() + LWS_PRE, LWS_PRE + messageLen, message, messageLen);
161 if (result != 0) {
162 return -1;
163 }
164 int bytesSent = lws_write(wsi, buffer.get() + LWS_PRE, messageLen, LWS_WRITE_TEXT);
165 NETSTACK_LOGD("ClientId:%{public}d,Client Writable send data length = %{public}d",
166 client->GetClientContext()->GetClientId(), bytesSent);
167 return HttpDummy(wsi, reason, user, in, len);
168 }
169
LwsCallbackClientConnectionError(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)170 int LwsCallbackClientConnectionError(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
171 {
172 WebSocketClient *client = static_cast<WebSocketClient *>(user);
173 NETSTACK_LOGE("ClientId:%{public}d,Callback ClientConnectionError", client->GetClientContext()->GetClientId());
174 std::string buf;
175 char *data = static_cast<char *>(in);
176 buf.assign(data, len);
177 ErrorResult errorResult;
178 errorResult.errorCode = WebSocketErrorCode::WEBSOCKET_CONNECTION_ERROR;
179 errorResult.errorMessage = data;
180 client->onErrorCallback_(client, errorResult);
181 return HttpDummy(wsi, reason, user, in, len);
182 }
183
LwsCallbackClientReceive(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)184 int LwsCallbackClientReceive(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
185 {
186 WebSocketClient *client = static_cast<WebSocketClient *>(user);
187 NETSTACK_LOGD("ClientId:%{public}d,Callback ClientReceive", client->GetClientContext()->GetClientId());
188 std::string buf;
189 char *data = static_cast<char *>(in);
190 buf.assign(data, len);
191 client->onMessageCallback_(client, data, len);
192 return HttpDummy(wsi, reason, user, in, len);
193 }
194
/**
 * Splits `str` on occurrences of `sep` into at most `size` pieces.
 *
 * - `size == 0` means no limit. When the limit is reached, the last piece holds the rest
 *   of the string unsplit.
 * - An empty `str` yields no pieces.
 * - A trailing separator does not produce a trailing empty piece.
 * - Leading or consecutive separators produce empty pieces.
 * - An empty `sep` cannot split anything, so `str` is returned as the only piece. The
 *   previous version looped forever in that case.
 *
 * Scans by index instead of re-copying the remainder, so it is linear in `str.size()`.
 */
std::vector<std::string> Split(const std::string &str, const std::string &sep, size_t size)
{
    std::vector<std::string> res;
    if (str.empty()) {
        return res;
    }
    if (sep.empty()) {
        res.emplace_back(str);
        return res;
    }
    size_t start = 0;
    while (start < str.size()) {
        if (res.size() + 1 == size) {
            res.emplace_back(str.substr(start));
            break;
        }
        const size_t pos = str.find(sep, start);
        if (pos == std::string::npos) {
            res.emplace_back(str.substr(start));
            break;
        }
        res.emplace_back(str.substr(start, pos - start));
        start = pos + sep.size();
    }
    return res;
}
214
LwsCallbackClientFilterPreEstablish(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)215 int LwsCallbackClientFilterPreEstablish(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
216 {
217 WebSocketClient *client = static_cast<WebSocketClient *>(user);
218 if (client->GetClientContext() == nullptr) {
219 NETSTACK_LOGE("Callback ClientContext is nullptr");
220 return -1;
221 }
222 client->GetClientContext()->openStatus = lws_http_client_http_response(wsi);
223 NETSTACK_LOGD("ClientId:%{public}d, libwebsockets Callback ClientFilterPreEstablish openStatus = %{public}d",
224 client->GetClientContext()->GetClientId(), client->GetClientContext()->openStatus);
225 char statusLine[MAX_HDR_LENGTH] = {0};
226 if (lws_hdr_copy(wsi, statusLine, MAX_HDR_LENGTH, WSI_TOKEN_HTTP) < 0 || strlen(statusLine) == 0) {
227 return HttpDummy(wsi, reason, user, in, len);
228 }
229 auto vec = Split(statusLine, STATUS_LINE_SEP, STATUS_LINE_ELEM_NUM);
230 if (vec.size() >= FUNCTION_PARAM_TWO) {
231 client->GetClientContext()->openMessage = vec[1];
232 }
233 return HttpDummy(wsi, reason, user, in, len);
234 }
235
LwsCallbackClientEstablished(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)236 int LwsCallbackClientEstablished(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
237 {
238 WebSocketClient *client = static_cast<WebSocketClient *>(user);
239 if (client->GetClientContext() == nullptr) {
240 NETSTACK_LOGE("libwebsockets Callback ClientContext is nullptr");
241 return -1;
242 }
243 NETSTACK_LOGI("ClientId:%{public}d,Callback ClientEstablished", client->GetClientContext()->GetClientId());
244 OpenResult openResult;
245 openResult.status = client->GetClientContext()->openStatus;
246 openResult.message = client->GetClientContext()->openMessage.c_str();
247 client->onOpenCallback_(client, openResult);
248
249 return HttpDummy(wsi, reason, user, in, len);
250 }
251
LwsCallbackClientClosed(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)252 int LwsCallbackClientClosed(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
253 {
254 WebSocketClient *client = static_cast<WebSocketClient *>(user);
255 if (client->GetClientContext() == nullptr) {
256 NETSTACK_LOGE("Callback ClientContext is nullptr");
257 return -1;
258 }
259 NETSTACK_LOGI("ClientId:%{public}d,Callback ClientClosed", client->GetClientContext()->GetClientId());
260 std::string buf;
261 char *data = static_cast<char *>(in);
262 buf.assign(data, len);
263 CloseResult closeResult;
264 closeResult.code = CLOSE_RESULT_FROM_SERVER_CODE;
265 closeResult.reason = CLOSE_REASON_FORM_SERVER;
266 client->onCloseCallback_(client, closeResult);
267 client->GetClientContext()->SetThreadStop(true);
268 if ((client->GetClientContext()->closeReason).empty()) {
269 client->GetClientContext()->Close(client->GetClientContext()->closeStatus, LINK_DOWN);
270 }
271 return HttpDummy(wsi, reason, user, in, len);
272 }
273
LwsCallbackWsiDestroy(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)274 int LwsCallbackWsiDestroy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
275 {
276 WebSocketClient *client = static_cast<WebSocketClient *>(user);
277 if (client->GetClientContext() == nullptr) {
278 NETSTACK_LOGE("Callback ClientContext is nullptr");
279 return -1;
280 }
281 NETSTACK_LOGI("Lws Callback LwsCallbackWsiDestroy");
282 return HttpDummy(wsi, reason, user, in, len);
283 }
284
LwsCallbackProtocolDestroy(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)285 int LwsCallbackProtocolDestroy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
286 {
287 NETSTACK_LOGI("Lws Callback ProtocolDestroy");
288 return HttpDummy(wsi, reason, user, in, len);
289 }
290
LwsCallback(lws * wsi,lws_callback_reasons reason,void * user,void * in,size_t len)291 int LwsCallback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len)
292 {
293 constexpr CallbackDispatcher dispatchers[] = {
294 {LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER, LwsCallbackClientAppendHandshakeHeader},
295 {LWS_CALLBACK_WS_PEER_INITIATED_CLOSE, LwsCallbackWsPeerInitiatedClose},
296 {LWS_CALLBACK_CLIENT_WRITEABLE, LwsCallbackClientWritable},
297 {LWS_CALLBACK_CLIENT_CONNECTION_ERROR, LwsCallbackClientConnectionError},
298 {LWS_CALLBACK_CLIENT_RECEIVE, LwsCallbackClientReceive},
299 {LWS_CALLBACK_CLIENT_FILTER_PRE_ESTABLISH, LwsCallbackClientFilterPreEstablish},
300 {LWS_CALLBACK_CLIENT_ESTABLISHED, LwsCallbackClientEstablished},
301 {LWS_CALLBACK_CLIENT_CLOSED, LwsCallbackClientClosed},
302 {LWS_CALLBACK_WSI_DESTROY, LwsCallbackWsiDestroy},
303 {LWS_CALLBACK_PROTOCOL_DESTROY, LwsCallbackProtocolDestroy},
304 };
305 auto it = std::find_if(std::begin(dispatchers), std::end(dispatchers),
306 [&reason](const CallbackDispatcher &dispatcher) { return dispatcher.reason == reason; });
307 if (it != std::end(dispatchers)) {
308 return it->callback(wsi, reason, user, in, len);
309 }
310 return HttpDummy(wsi, reason, user, in, len);
311 }
312
313 static struct lws_protocols protocols[] = {{"lws-minimal-client1", LwsCallback, 0, 0, 0, NULL, 0},
314 LWS_PROTOCOL_LIST_TERM};
315
FillContextInfo(lws_context_creation_info & info)316 static void FillContextInfo(lws_context_creation_info &info)
317 {
318 info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
319 info.port = CONTEXT_PORT_NO_LISTEN;
320 info.protocols = protocols;
321 info.fd_limit_per_thread = FD_LIMIT_PER_THREAD;
322 }
323
ParseUrl(const std::string url,char * prefix,char * address,char * path,int * port)324 bool ParseUrl(const std::string url, char *prefix, char *address, char *path, int *port)
325 {
326 char uri[MAX_URI_LENGTH] = {0};
327 if (strcpy_s(uri, MAX_URI_LENGTH, url.c_str()) < 0) {
328 NETSTACK_LOGE("strcpy_s failed");
329 return false;
330 }
331 const char *tempPrefix = nullptr;
332 const char *tempAddress = nullptr;
333 const char *tempPath = nullptr;
334 (void)lws_parse_uri(uri, &tempPrefix, &tempAddress, port, &tempPath);
335 if (strcpy_s(prefix, MAX_URI_LENGTH, tempPrefix) < 0) {
336 NETSTACK_LOGE("strcpy_s failed");
337 return false;
338 }
339 if (strcpy_s(address, MAX_URI_LENGTH, tempAddress) < 0) {
340 NETSTACK_LOGE("strcpy_s failed");
341 return false;
342 }
343 if (strcpy_s(path, MAX_URI_LENGTH, tempPath) < 0) {
344 NETSTACK_LOGE("strcpy_s failed");
345 return false;
346 }
347 return true;
348 }
349
CreatConnectInfo(const std::string url,lws_context * lwsContext,WebSocketClient * client)350 int CreatConnectInfo(const std::string url, lws_context *lwsContext, WebSocketClient *client)
351 {
352 lws_client_connect_info connectInfo = {};
353 char prefix[MAX_URI_LENGTH] = {0};
354 char address[MAX_URI_LENGTH] = {0};
355 char pathWithoutStart[MAX_URI_LENGTH] = {0};
356 int port = 0;
357 if (!ParseUrl(url, prefix, address, pathWithoutStart, &port)) {
358 return WebSocketErrorCode::WEBSOCKET_CONNECTION_PARSEURL_ERROR;
359 }
360 std::string path = PATH_START + std::string(pathWithoutStart);
361
362 connectInfo.context = lwsContext;
363 connectInfo.address = address;
364 connectInfo.port = port;
365 connectInfo.path = path.c_str();
366 connectInfo.host = address;
367 connectInfo.origin = address;
368
369 connectInfo.local_protocol_name = "lws-minimal-client1";
370 connectInfo.retry_and_idle_policy = &RETRY;
371 if (strcmp(prefix, PREFIX_HTTPS) == 0 || strcmp(prefix, PREFIX_WSS) == 0) {
372 connectInfo.ssl_connection =
373 LCCSCF_USE_SSL | LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK | LCCSCF_ALLOW_INSECURE | LCCSCF_ALLOW_SELFSIGNED;
374 }
375 lws *wsi = nullptr;
376 connectInfo.pwsi = &wsi;
377 connectInfo.userdata = client;
378 if (lws_client_connect_via_info(&connectInfo) == nullptr) {
379 NETSTACK_LOGE("Connect lws_context_destroy");
380 return WebSocketErrorCode::WEBSOCKET_CONNECTION_TO_SERVER_FAIL;
381 }
382 return WebSocketErrorCode::WEBSOCKET_NONE_ERR;
383 }
384
Connect(std::string url,struct OpenOptions options)385 int WebSocketClient::Connect(std::string url, struct OpenOptions options)
386 {
387 NETSTACK_LOGI("ClientId:%{public}d, Connect start", this->GetClientContext()->GetClientId());
388 if (!options.headers.empty()) {
389 if (options.headers.size() > MAX_HEADER_LENGTH) {
390 return WebSocketErrorCode::WEBSOCKET_ERROR_NO_HEADR_EXCEEDS;
391 }
392 for (const auto &item : options.headers) {
393 const std::string &key = item.first;
394 const std::string &value = item.second;
395 this->GetClientContext()->header[key] = value;
396 }
397 }
398 lws_context_creation_info info = {};
399 FillContextInfo(info);
400 lws_context *lwsContext = lws_create_context(&info);
401 if (lwsContext == nullptr) {
402 return WebSocketErrorCode::WEBSOCKET_CONNECTION_NO_MEMOERY;
403 }
404 this->GetClientContext()->SetContext(lwsContext);
405 int ret = CreatConnectInfo(url, lwsContext, this);
406 if (ret != WEBSOCKET_NONE_ERR) {
407 NETSTACK_LOGE("websocket CreatConnectInfo error");
408 GetClientContext()->SetContext(nullptr);
409 lws_context_destroy(lwsContext);
410 return ret;
411 }
412 std::thread serviceThread(RunService, this);
413 serviceThread.detach();
414 return WebSocketErrorCode::WEBSOCKET_NONE_ERR;
415 }
416
Send(char * data,size_t length)417 int WebSocketClient::Send(char *data, size_t length)
418 {
419 if (data == nullptr) {
420 return WebSocketErrorCode::WEBSOCKET_SEND_DATA_NULL;
421 }
422 if (length > MAX_DATA_LENGTH) {
423 return WebSocketErrorCode::WEBSOCKET_DATA_LENGTH_EXCEEDS;
424 }
425 if (this->GetClientContext() == nullptr) {
426 return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX;
427 }
428 this->GetClientContext()->Push(data, length, LWS_WRITE_TEXT);
429 return WebSocketErrorCode::WEBSOCKET_NONE_ERR;
430 }
431
Close(CloseOption options)432 int WebSocketClient::Close(CloseOption options)
433 {
434 NETSTACK_LOGI("Close start");
435 if (this->GetClientContext() == nullptr) {
436 return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX;
437 }
438 if (this->GetClientContext()->openStatus == 0)
439 return WebSocketErrorCode::WEBSOCKET_ERROR_HAVE_NO_CONNECT;
440
441 if (options.reason == nullptr || options.code == 0) {
442 options.reason = "";
443 options.code = CLOSE_RESULT_FROM_CLIENT_CODE;
444 }
445 this->GetClientContext()->Close(static_cast<lws_close_status>(options.code), options.reason);
446 return WebSocketErrorCode::WEBSOCKET_NONE_ERR;
447 }
448
Registcallback(OnOpenCallback onOpen,OnMessageCallback onMessage,OnErrorCallback onError,OnCloseCallback onClose)449 int WebSocketClient::Registcallback(OnOpenCallback onOpen, OnMessageCallback onMessage, OnErrorCallback onError,
450 OnCloseCallback onClose)
451 {
452 onMessageCallback_ = onMessage;
453 onCloseCallback_ = onClose;
454 onErrorCallback_ = onError;
455 onOpenCallback_ = onOpen;
456 return WebSocketErrorCode::WEBSOCKET_NONE_ERR;
457 }
458
Destroy()459 int WebSocketClient::Destroy()
460 {
461 NETSTACK_LOGI("Destroy start");
462 if (this->GetClientContext()->GetContext() == nullptr) {
463 return WebSocketErrorCode::WEBSOCKET_ERROR_HAVE_NO_CONNECT_CONTEXT;
464 }
465 this->GetClientContext()->SetContext(nullptr);
466 lws_context_destroy(this->GetClientContext()->GetContext());
467 return WebSocketErrorCode::WEBSOCKET_NONE_ERR;
468 }
469
470 } // namespace OHOS::NetStack::WebSocketClient