1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "websocket.h"
17
18 #include "define.h"
19 #include "log_wrapper.h"
20 #include "securec.h"
21
22 namespace OHOS::ArkCompiler::Toolchain {
/**
 * SendReply in WebSocket handles 3 situations:
 * 1. message's length <= 125
 * 2. message's length >= 126 && message's length < 65536
 * 3. message's length >= 65536
 */
29
SendReply(const std::string & message) const30 void WebSocket::SendReply(const std::string& message) const
31 {
32 if (socketState_ != SocketState::CONNECTED) {
33 LOGE("SendReply failed, websocket not connected");
34 return;
35 }
36 const size_t msgLen = message.length();
37 std::unique_ptr<char []> msgBuf = std::make_unique<char []>(msgLen + 11); // 11: the maximum expandable length
38 char* sendBuf = msgBuf.get();
39 uint32_t sendMsgLen;
40 sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'.
41
42 // Depending on the length of the messages, server will use shift operation to get the res
43 // and store them in the buffer.
44 if (msgLen <= 125) { // 125: situation 1 when message's length <= 125
45 sendBuf[1] = msgLen;
46 sendMsgLen = 2; // 2: the length of header frame is 2
47 } else if (msgLen < 65536) { // 65536: message's length
48 sendBuf[1] = 126; // 126: payloadLen according to the spec
49 sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1)
50 sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0)
51 sendMsgLen = 4; // 4: the length of header frame is 4
52 } else {
53 sendBuf[1] = 127; // 127: payloadLen according to the spec
54 for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits
55 sendBuf[i] = 0;
56 }
57 sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3)
58 sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2)
59 sendBuf[8] = ((msgLen & 0x0000ff00) >> 8); // 8: shift 8 bits => res * (256^1)
60 sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0)
61 sendMsgLen = 10; // 10: the length of header frame is 10
62 }
63 if (memcpy_s(sendBuf + sendMsgLen, msgLen, message.c_str(), msgLen) != EOK) {
64 LOGE("SendReply: memcpy_s failed");
65 return;
66 }
67 sendBuf[sendMsgLen + msgLen] = '\0';
68 if (!Send(client_, sendBuf, sendMsgLen + msgLen, 0)) {
69 LOGE("SendReply: send failed");
70 return;
71 }
72 }
73
HttpProtocolDecode(const std::string & request,HttpProtocol & req)74 bool WebSocket::HttpProtocolDecode(const std::string& request, HttpProtocol& req)
75 {
76 if (request.find("GET") == std::string::npos) {
77 LOGE("Handshake failed: lack of necessary info");
78 return false;
79 }
80 std::vector<std::string> reqStr = ProtocolSplit(request, EOL);
81 for (size_t i = 0; i < reqStr.size(); i++) {
82 if (i == 0) {
83 std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), " ");
84 req.version = headers.at(2); // 2: to get the version param
85 } else if (i < reqStr.size() - 1) {
86 std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), ": ");
87 if (reqStr.at(i).find("Connection") != std::string::npos) {
88 req.connection = headers.at(1); // 1: to get the connection param
89 } else if (reqStr.at(i).find("Upgrade") != std::string::npos) {
90 req.upgrade = headers.at(1); // 1: to get the upgrade param
91 } else if (reqStr.at(i).find("Sec-WebSocket-Key") != std::string::npos) {
92 req.secWebSocketKey = headers.at(1); // 1: to get the secWebSocketKey param
93 }
94 }
95 }
96 return true;
97 }
98
/**
 * The wire format of a data frame is described in detail using ABNF (RFC 5234) in RFC 6455, section 5.2.
 * A received message must be decoded according to that spec. The frame structure is as follows:
 *  0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
 * +-+-+-+-+-------+-+-------------+-------------------------------+
 * |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
 * |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
 * |N|V|V|V|       |S|             |   (if payload len==126/127)   |
 * | |1|2|3|       |K|             |                               |
 * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
 * |     Extended payload length continued, if payload len == 127  |
 * + - - - - - - - - - - - - - - - +-------------------------------+
 * |                               |Masking-key, if MASK set to 1  |
 * +-------------------------------+-------------------------------+
 * | Masking-key (continued)       |          Payload Data         |
 * +-------------------------------- - - - - - - - - - - - - - - - +
 * :                     Payload Data continued ...                :
 * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
 * |                     Payload Data continued ...                |
 * +---------------------------------------------------------------+
 */
120
HandleFrame(WebSocketFrame & wsFrame)121 bool WebSocket::HandleFrame(WebSocketFrame& wsFrame)
122 {
123 if (wsFrame.payloadLen == 126) { // 126: the payloadLen read from frame
124 char recvbuf[PAYLOAD_LEN + 1] = {0};
125 if (!Recv(client_, recvbuf, PAYLOAD_LEN, 0)) {
126 LOGE("HandleFrame: Recv payloadLen == 126 failed");
127 return false;
128 }
129
130 uint16_t msgLen = 0;
131 if (memcpy_s(&msgLen, sizeof(recvbuf), recvbuf, sizeof(recvbuf) - 1) != EOK) {
132 LOGE("HandleFrame: memcpy_s failed");
133 return false;
134 }
135 wsFrame.payloadLen = ntohs(msgLen);
136 } else if (wsFrame.payloadLen > 126) { // 126: the payloadLen read from frame
137 char recvbuf[EXTEND_PAYLOAD_LEN + 1] = {0};
138 if (!Recv(client_, recvbuf, EXTEND_PAYLOAD_LEN, 0)) {
139 LOGE("HandleFrame: Recv payloadLen > 127 failed");
140 return false;
141 }
142 wsFrame.payloadLen = NetToHostLongLong(recvbuf, EXTEND_PAYLOAD_LEN);
143 }
144 return DecodeMessage(wsFrame);
145 }
146
DecodeMessage(WebSocketFrame & wsFrame)147 bool WebSocket::DecodeMessage(WebSocketFrame& wsFrame)
148 {
149 if (wsFrame.payloadLen == 0 || wsFrame.payloadLen > UINT64_MAX) {
150 LOGE("ReadMsg length error, expected greater than zero and less than UINT64_MAX");
151 return false;
152 }
153 uint64_t msgLen = wsFrame.payloadLen;
154 wsFrame.payload = std::make_unique<char []>(msgLen + 1);
155 if (wsFrame.mask == 1) {
156 char buf[msgLen + 1];
157 if (!Recv(client_, wsFrame.maskingkey, SOCKET_MASK_LEN, 0)) {
158 LOGE("DecodeMessage: Recv maskingkey failed");
159 return false;
160 }
161
162 if (!Recv(client_, buf, msgLen, 0)) {
163 LOGE("DecodeMessage: Recv message with mask failed");
164 return false;
165 }
166
167 for (uint64_t i = 0; i < msgLen; i++) {
168 uint64_t j = i % SOCKET_MASK_LEN;
169 wsFrame.payload.get()[i] = buf[i] ^ wsFrame.maskingkey[j];
170 }
171 } else {
172 char buf[msgLen + 1];
173 if (!Recv(client_, buf, msgLen, 0)) {
174 LOGE("DecodeMessage: Recv message without mask failed");
175 return false;
176 }
177
178 if (memcpy_s(wsFrame.payload.get(), msgLen, buf, msgLen) != EOK) {
179 LOGE("DecodeMessage: memcpy_s failed");
180 return false;
181 }
182 }
183 wsFrame.payload.get()[msgLen] = '\0';
184 return true;
185 }
186
ProtocolUpgrade(const HttpProtocol & req)187 bool WebSocket::ProtocolUpgrade(const HttpProtocol& req)
188 {
189 std::string rawKey = req.secWebSocketKey + WEB_SOCKET_GUID;
190 unsigned const char* webSocketKey = reinterpret_cast<unsigned const char*>(std::move(rawKey).c_str());
191 unsigned char hash[SHA_DIGEST_LENGTH + 1];
192 SHA1(webSocketKey, strlen(reinterpret_cast<const char*>(webSocketKey)), hash);
193 hash[SHA_DIGEST_LENGTH] = '\0';
194 unsigned char encodedKey[ENCODED_KEY_LEN];
195 EVP_EncodeBlock(encodedKey, reinterpret_cast<const unsigned char*>(hash), SHA_DIGEST_LENGTH);
196 std::string response;
197
198 std::ostringstream sstream;
199 sstream << "HTTP/1.1 101 Switching Protocols" << EOL;
200 sstream << "Connection: upgrade" << EOL;
201 sstream << "Upgrade: websocket" << EOL;
202 sstream << "Sec-WebSocket-Accept: " << encodedKey << EOL;
203 sstream << EOL;
204 response = sstream.str();
205 if (!Send(client_, response.c_str(), response.length(), 0)) {
206 LOGE("ProtocolUpgrade: Send failed");
207 return false;
208 }
209 return true;
210 }
211
Decode()212 std::string WebSocket::Decode()
213 {
214 if (socketState_ != SocketState::CONNECTED) {
215 LOGE("Decode failed, websocket not connected!");
216 return "";
217 }
218 char recvbuf[SOCKET_HEADER_LEN + 1];
219 if (!Recv(client_, recvbuf, SOCKET_HEADER_LEN, 0)) {
220 LOGE("Decode failed, client websocket disconnect");
221 socketState_ = SocketState::INITED;
222 #if defined(OHOS_PLATFORM)
223 shutdown(client_, SHUT_RDWR);
224 close(client_);
225 client_ = -1;
226 #else
227 close(client_);
228 client_ = -1;
229 #endif
230 return "";
231 }
232 WebSocketFrame wsFrame;
233 int32_t index = 0;
234 wsFrame.fin = static_cast<uint8_t>(recvbuf[index] >> 7); // 7: shift right by 7 bits to get the fin
235 wsFrame.opcode = static_cast<uint8_t>(recvbuf[index] & 0xf);
236 if (wsFrame.opcode == 0x1) { // 0x1: 0x1 means a text frame
237 index++;
238 wsFrame.mask = static_cast<uint8_t>((recvbuf[index] >> 7) & 0x1); // 7: to get the mask
239 wsFrame.payloadLen = recvbuf[index] & 0x7f;
240 if (HandleFrame(wsFrame)) {
241 return wsFrame.payload.get();
242 }
243 return "";
244 }
245 return "";
246 }
247
HttpHandShake()248 bool WebSocket::HttpHandShake()
249 {
250 char msgBuf[SOCKET_HANDSHAKE_LEN];
251 int32_t msgLen = recv(client_, msgBuf, SOCKET_HANDSHAKE_LEN, 0);
252 if (msgLen <= 0) {
253 LOGE("ReadMsg failed readRet=%{public}d", msgLen);
254 return false;
255 } else {
256 msgBuf[msgLen - 1] = '\0';
257 HttpProtocol req;
258 if (!HttpProtocolDecode(msgBuf, req)) {
259 LOGE("HttpHandShake: Upgrade failed");
260 return false;
261 } else if (req.connection.find("Upgrade") != std::string::npos &&
262 req.upgrade.find("websocket") != std::string::npos && req.version.compare("HTTP/1.1") == 0) {
263 ProtocolUpgrade(req);
264 }
265 }
266 return true;
267 }
268
269 #if !defined(OHOS_PLATFORM)
InitTcpWebSocket(uint32_t timeoutLimit)270 bool WebSocket::InitTcpWebSocket(uint32_t timeoutLimit)
271 {
272 if (socketState_ != SocketState::UNINITED) {
273 LOGI("InitTcpWebSocket websocket has inited");
274 return true;
275 }
276 #if defined(WINDOWS_PLATFORM)
277 WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
278 WSADATA wsaData;
279 if (WSAStartup(sockVersion, &wsaData) != 0) {
280 LOGE("InitTcpWebSocket WSA init failed");
281 return false;
282 }
283 #endif
284 fd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
285 if (fd_ < SOCKET_SUCCESS) {
286 LOGE("InitTcpWebSocket socket init failed");
287 return false;
288 }
289 // allow specified port can be used at once and not wait TIME_WAIT status ending
290 int sockOptVal = 1;
291 if ((setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &sockOptVal, sizeof(sockOptVal))) != SOCKET_SUCCESS) {
292 LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed");
293 close(fd_);
294 fd_ = -1;
295 return false;
296 }
297
298 // set send and recv timeout
299 if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
300 LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
301 close(fd_);
302 fd_ = -1;
303 return false;
304 }
305
306 sockaddr_in addr_sin = {0};
307 addr_sin.sin_family = AF_INET;
308 addr_sin.sin_port = htons(9230); // 9230: sockName for tcp
309 addr_sin.sin_addr.s_addr = INADDR_ANY;
310 if (bind(fd_, reinterpret_cast<struct sockaddr*>(&addr_sin), sizeof(addr_sin)) < SOCKET_SUCCESS) {
311 LOGE("InitTcpWebSocket bind failed");
312 close(fd_);
313 fd_ = -1;
314 return false;
315 }
316 if (listen(fd_, 1) < SOCKET_SUCCESS) {
317 LOGE("InitTcpWebSocket listen failed");
318 close(fd_);
319 fd_ = -1;
320 return false;
321 }
322 socketState_ = SocketState::INITED;
323 return true;
324 }
325
ConnectTcpWebSocket()326 bool WebSocket::ConnectTcpWebSocket()
327 {
328 if (socketState_ == SocketState::UNINITED) {
329 LOGE("ConnectTcpWebSocket failed, websocket not inited");
330 return false;
331 }
332 if (socketState_ == SocketState::CONNECTED) {
333 LOGI("ConnectTcpWebSocket websocket has connected");
334 return true;
335 }
336
337 if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
338 LOGE("ConnectTcpWebSocket accept failed");
339 socketState_ = SocketState::UNINITED;
340 close(fd_);
341 fd_ = -1;
342 return false;
343 }
344
345 if (!HttpHandShake()) {
346 LOGE("ConnectTcpWebSocket HttpHandShake failed");
347 socketState_ = SocketState::UNINITED;
348 close(client_);
349 client_ = -1;
350 close(fd_);
351 fd_ = -1;
352 return false;
353 }
354 socketState_ = SocketState::CONNECTED;
355 return true;
356 }
357 #else
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)358 bool WebSocket::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
359 {
360 if (socketState_ != SocketState::UNINITED) {
361 LOGI("InitUnixWebSocket websocket has inited");
362 return true;
363 }
364 fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: defautlt protocol
365 if (fd_ < SOCKET_SUCCESS) {
366 LOGE("InitUnixWebSocket socket init failed");
367 return false;
368 }
369 // set send and recv timeout
370 if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
371 LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
372 close(fd_);
373 fd_ = -1;
374 return false;
375 }
376
377 struct sockaddr_un un;
378 if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
379 LOGE("InitUnixWebSocket memset_s failed");
380 close(fd_);
381 fd_ = -1;
382 return false;
383 }
384 un.sun_family = AF_UNIX;
385 if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
386 LOGE("InitUnixWebSocket strcpy_s failed");
387 close(fd_);
388 fd_ = -1;
389 return false;
390 }
391 un.sun_path[0] = '\0';
392 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
393 if (bind(fd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) < SOCKET_SUCCESS) {
394 LOGE("InitUnixWebSocket bind failed");
395 close(fd_);
396 fd_ = -1;
397 return false;
398 }
399 if (listen(fd_, 1) < SOCKET_SUCCESS) { // 1: connection num
400 LOGE("InitUnixWebSocket listen failed");
401 close(fd_);
402 fd_ = -1;
403 return false;
404 }
405 socketState_ = SocketState::INITED;
406 return true;
407 }
408
ConnectUnixWebSocket()409 bool WebSocket::ConnectUnixWebSocket()
410 {
411 if (socketState_ == SocketState::UNINITED) {
412 LOGE("ConnectUnixWebSocket failed, websocket not inited");
413 return false;
414 }
415 if (socketState_ == SocketState::CONNECTED) {
416 LOGI("ConnectUnixWebSocket websocket has connected");
417 return true;
418 }
419
420 if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
421 LOGE("ConnectUnixWebSocket accept failed");
422 socketState_ = SocketState::UNINITED;
423 close(fd_);
424 fd_ = -1;
425 return false;
426 }
427 if (!HttpHandShake()) {
428 LOGE("ConnectUnixWebSocket HttpHandShake failed");
429 socketState_ = SocketState::UNINITED;
430 shutdown(client_, SHUT_RDWR);
431 close(client_);
432 client_ = -1;
433 shutdown(fd_, SHUT_RDWR);
434 close(fd_);
435 fd_ = -1;
436 return false;
437 }
438 socketState_ = SocketState::CONNECTED;
439 return true;
440 }
441 #endif
442
IsConnected()443 bool WebSocket::IsConnected()
444 {
445 return socketState_ == SocketState::CONNECTED;
446 }
447
Close()448 void WebSocket::Close()
449 {
450 if (socketState_ == SocketState::UNINITED) {
451 return;
452 }
453 if (socketState_ == SocketState::CONNECTED) {
454 #if defined(OHOS_PLATFORM)
455 shutdown(client_, SHUT_RDWR);
456 #endif
457 close(client_);
458 client_ = -1;
459 }
460 socketState_ = SocketState::UNINITED;
461 usleep(10000); // 10000: time for websocket to enter the accept
462 #if defined(OHOS_PLATFORM)
463 shutdown(fd_, SHUT_RDWR);
464 #endif
465 close(fd_);
466 fd_ = -1;
467 }
468
NetToHostLongLong(char * buf,uint32_t len)469 uint64_t WebSocket::NetToHostLongLong(char* buf, uint32_t len)
470 {
471 uint64_t result = 0;
472 for (uint32_t i = 0; i < len; i++) {
473 result |= static_cast<unsigned char>(buf[i]);
474 if ((i + 1) < len) {
475 result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int
476 }
477 }
478 return result;
479 }
480
Recv(int32_t client,char * buf,size_t totalLen,int32_t flags) const481 bool WebSocket::Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) const
482 {
483 size_t recvLen = 0;
484 while (recvLen < totalLen) {
485 ssize_t len = recv(client, buf + recvLen, totalLen - recvLen, flags);
486 if (len <= 0) {
487 LOGE("Recv payload in while failed, websocket disconnect");
488 return false;
489 }
490 recvLen += static_cast<size_t>(len);
491 }
492 buf[totalLen] = '\0';
493 return true;
494 }
495
Send(int32_t client,const char * buf,size_t totalLen,int32_t flags) const496 bool WebSocket::Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) const
497 {
498 size_t sendLen = 0;
499 while (sendLen < totalLen) {
500 ssize_t len = send(client, buf + sendLen, totalLen - sendLen, flags);
501 if (len <= 0) {
502 LOGE("Send Message in while failed, websocket disconnect");
503 return false;
504 }
505 sendLen += static_cast<size_t>(len);
506 }
507 return true;
508 }
509
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)510 bool WebSocket::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
511 {
512 if (timeoutLimit > 0) {
513 struct timeval timeout = {timeoutLimit, 0};
514 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
515 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed");
516 return false;
517 }
518 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
519 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed");
520 return false;
521 }
522 }
523 return true;
524 }
525 } // namespace OHOS::ArkCompiler::Toolchain
526