1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "websocket.h"

#include <cerrno>
#include <cstdint>

#include "define.h"
#include "log_wrapper.h"
#include "securec.h"
21
22 namespace OHOS::ArkCompiler::Toolchain {
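/**
 * SendReply encodes the payload length of a text frame in one of three ways:
 * 1. message's length <= 125: stored directly in the 7-bit payload-length field;
 * 2. 126 <= message's length < 65536: field set to 126, followed by a 16-bit length;
 * 3. message's length >= 65536: field set to 127, followed by an 8-byte length.
 */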
29
SendReply(const std::string & message) const30 void WebSocket::SendReply(const std::string& message) const
31 {
32 if (socketState_ != SocketState::CONNECTED) {
33 LOGE("SendReply failed, websocket not connected");
34 return;
35 }
36 const size_t msgLen = message.length();
37 std::unique_ptr<char []> msgBuf = std::make_unique<char []>(msgLen + 11); // 11: the maximum expandable length
38 char* sendBuf = msgBuf.get();
39 uint32_t sendMsgLen;
40 sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'.
41
42 // Depending on the length of the messages, server will use shift operation to get the res
43 // and store them in the buffer.
44 if (msgLen <= 125) { // 125: situation 1 when message's length <= 125
45 sendBuf[1] = msgLen;
46 sendMsgLen = 2; // 2: the length of header frame is 2
47 } else if (msgLen < 65536) { // 65536: message's length
48 sendBuf[1] = 126; // 126: payloadLen according to the spec
49 sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1)
50 sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0)
51 sendMsgLen = 4; // 4: the length of header frame is 4
52 } else {
53 sendBuf[1] = 127; // 127: payloadLen according to the spec
54 for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits
55 sendBuf[i] = 0;
56 }
57 sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3)
58 sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2)
59 sendBuf[8] = ((msgLen & 0x0000ff00) >> 8); // 8: shift 8 bits => res * (256^1)
60 sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0)
61 sendMsgLen = 10; // 10: the length of header frame is 10
62 }
63 if (memcpy_s(sendBuf + sendMsgLen, msgLen, message.c_str(), msgLen) != EOK) {
64 LOGE("SendReply: memcpy_s failed");
65 return;
66 }
67 sendBuf[sendMsgLen + msgLen] = '\0';
68 if (!Send(client_, sendBuf, sendMsgLen + msgLen, 0)) {
69 LOGE("SendReply: send failed");
70 return;
71 }
72 }
73
HttpProtocolDecode(const std::string & request,HttpProtocol & req)74 bool WebSocket::HttpProtocolDecode(const std::string& request, HttpProtocol& req)
75 {
76 if (request.find("GET") == std::string::npos) {
77 LOGE("Handshake failed: lack of necessary info");
78 return false;
79 }
80 std::vector<std::string> reqStr = ProtocolSplit(request, EOL);
81 for (size_t i = 0; i < reqStr.size(); i++) {
82 if (i == 0) {
83 std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), " ");
84 req.version = headers.at(2); // 2: to get the version param
85 } else if (i < reqStr.size() - 1) {
86 std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), ": ");
87 if (reqStr.at(i).find("Connection") != std::string::npos) {
88 req.connection = headers.at(1); // 1: to get the connection param
89 } else if (reqStr.at(i).find("Upgrade") != std::string::npos) {
90 req.upgrade = headers.at(1); // 1: to get the upgrade param
91 } else if (reqStr.at(i).find("Sec-WebSocket-Key") != std::string::npos) {
92 req.secWebSocketKey = headers.at(1); // 1: to get the secWebSocketKey param
93 }
94 }
95 }
96 return true;
97 }
98
99 /**
 * The wire format of the data transmission section is described in detail using ABNF (RFC 5234).
 * When a message is received, it is decoded according to the spec. The frame structure is as follows:
102 * 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
103 * +-+-+-+-+-------+-+-------------+-------------------------------+
104 * |F|R|R|R| opcode|M| Payload len | Extended payload length |
105 * |I|S|S|S| (4) |A| (7) | (16/64) |
106 * |N|V|V|V| |S| | (if payload len==126/127) |
107 * | |1|2|3| |K| | |
108 * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
109 * | Extended payload length continued, if payload len == 127 |
110 * + - - - - - - - - - - - - - - - +-------------------------------+
111 * | |Masking-key, if MASK set to 1 |
112 * +-------------------------------+-------------------------------+
113 * | Masking-key (continued) | Payload Data |
114 * +-------------------------------- - - - - - - - - - - - - - - - +
115 * : Payload Data continued ... :
116 * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
117 * | Payload Data continued ... |
118 * +---------------------------------------------------------------+
119 */
120
HandleFrame(WebSocketFrame & wsFrame)121 bool WebSocket::HandleFrame(WebSocketFrame& wsFrame)
122 {
123 if (wsFrame.payloadLen == 126) { // 126: the payloadLen read from frame
124 char recvbuf[PAYLOAD_LEN + 1] = {0};
125 if (!Recv(client_, recvbuf, PAYLOAD_LEN, 0)) {
126 LOGE("HandleFrame: Recv payloadLen == 126 failed");
127 return false;
128 }
129
130 uint16_t msgLen = 0;
131 if (memcpy_s(&msgLen, sizeof(recvbuf), recvbuf, sizeof(recvbuf) - 1) != EOK) {
132 LOGE("HandleFrame: memcpy_s failed");
133 return false;
134 }
135 wsFrame.payloadLen = ntohs(msgLen);
136 } else if (wsFrame.payloadLen > 126) { // 126: the payloadLen read from frame
137 char recvbuf[EXTEND_PAYLOAD_LEN + 1] = {0};
138 if (!Recv(client_, recvbuf, EXTEND_PAYLOAD_LEN, 0)) {
139 LOGE("HandleFrame: Recv payloadLen > 127 failed");
140 return false;
141 }
142 wsFrame.payloadLen = NetToHostLongLong(recvbuf, EXTEND_PAYLOAD_LEN);
143 }
144 return DecodeMessage(wsFrame);
145 }
146
DecodeMessage(WebSocketFrame & wsFrame)147 bool WebSocket::DecodeMessage(WebSocketFrame& wsFrame)
148 {
149 if (wsFrame.payloadLen == 0 || wsFrame.payloadLen > UINT64_MAX) {
150 LOGE("ReadMsg length error, expected greater than zero and less than UINT64_MAX");
151 return false;
152 }
153 uint64_t msgLen = wsFrame.payloadLen;
154 wsFrame.payload = std::make_unique<char []>(msgLen + 1);
155 if (wsFrame.mask == 1) {
156 char buf[msgLen + 1];
157 if (!Recv(client_, wsFrame.maskingkey, SOCKET_MASK_LEN, 0)) {
158 LOGE("DecodeMessage: Recv maskingkey failed");
159 return false;
160 }
161
162 if (!Recv(client_, buf, msgLen, 0)) {
163 LOGE("DecodeMessage: Recv message with mask failed");
164 return false;
165 }
166
167 for (uint64_t i = 0; i < msgLen; i++) {
168 uint64_t j = i % SOCKET_MASK_LEN;
169 wsFrame.payload.get()[i] = buf[i] ^ wsFrame.maskingkey[j];
170 }
171 } else {
172 char buf[msgLen + 1];
173 if (!Recv(client_, buf, msgLen, 0)) {
174 LOGE("DecodeMessage: Recv message without mask failed");
175 return false;
176 }
177
178 if (memcpy_s(wsFrame.payload.get(), msgLen, buf, msgLen) != EOK) {
179 LOGE("DecodeMessage: memcpy_s failed");
180 return false;
181 }
182 }
183 wsFrame.payload.get()[msgLen] = '\0';
184 return true;
185 }
186
ProtocolUpgrade(const HttpProtocol & req)187 bool WebSocket::ProtocolUpgrade(const HttpProtocol& req)
188 {
189 std::string rawKey = req.secWebSocketKey + WEB_SOCKET_GUID;
190 unsigned const char* webSocketKey = reinterpret_cast<unsigned const char*>(std::move(rawKey).c_str());
191 unsigned char hash[SHA_DIGEST_LENGTH + 1];
192 SHA1(webSocketKey, strlen(reinterpret_cast<const char*>(webSocketKey)), hash);
193 hash[SHA_DIGEST_LENGTH] = '\0';
194 unsigned char encodedKey[ENCODED_KEY_LEN];
195 EVP_EncodeBlock(encodedKey, reinterpret_cast<const unsigned char*>(hash), SHA_DIGEST_LENGTH);
196 std::string response;
197
198 std::ostringstream sstream;
199 sstream << "HTTP/1.1 101 Switching Protocols" << EOL;
200 sstream << "Connection: upgrade" << EOL;
201 sstream << "Upgrade: websocket" << EOL;
202 sstream << "Sec-WebSocket-Accept: " << encodedKey << EOL;
203 sstream << EOL;
204 response = sstream.str();
205 if (!Send(client_, response.c_str(), response.length(), 0)) {
206 LOGE("ProtocolUpgrade: Send failed");
207 return false;
208 }
209 return true;
210 }
211
Decode()212 std::string WebSocket::Decode()
213 {
214 if (socketState_ != SocketState::CONNECTED) {
215 LOGE("Decode failed, websocket not connected!");
216 return "";
217 }
218 char recvbuf[SOCKET_HEADER_LEN + 1];
219 if (!Recv(client_, recvbuf, SOCKET_HEADER_LEN, 0)) {
220 LOGE("Decode failed, client websocket disconnect");
221 socketState_ = SocketState::INITED;
222 #if defined(OHOS_PLATFORM)
223 shutdown(client_, SHUT_RDWR);
224 close(client_);
225 client_ = -1;
226 #else
227 close(client_);
228 client_ = -1;
229 #endif
230 return "";
231 }
232 WebSocketFrame wsFrame;
233 int32_t index = 0;
234 wsFrame.fin = static_cast<uint8_t>(recvbuf[index] >> 7); // 7: shift right by 7 bits to get the fin
235 wsFrame.opcode = static_cast<uint8_t>(recvbuf[index] & 0xf);
236 if (wsFrame.opcode == 0x1) { // 0x1: 0x1 means a text frame
237 index++;
238 wsFrame.mask = static_cast<uint8_t>((recvbuf[index] >> 7) & 0x1); // 7: to get the mask
239 wsFrame.payloadLen = recvbuf[index] & 0x7f;
240 if (HandleFrame(wsFrame)) {
241 return wsFrame.payload.get();
242 }
243 return "";
244 } else if (wsFrame.opcode == 0x9) { // 0x9: 0x9 means a ping frame
245 // send pong frame
246 char pongFrame[SOCKET_HEADER_LEN] = {0};
247 pongFrame[0] = 0x8a; // 0x8a: 0x8a means a pong frame
248 pongFrame[1] = 0x0;
249 if (!Send(client_, pongFrame, SOCKET_HEADER_LEN, 0)) {
250 LOGE("Decode: Send pong frame failed");
251 return "";
252 }
253 }
254 return "";
255 }
256
HttpHandShake()257 bool WebSocket::HttpHandShake()
258 {
259 char msgBuf[SOCKET_HANDSHAKE_LEN];
260 int32_t msgLen = recv(client_, msgBuf, SOCKET_HANDSHAKE_LEN, 0);
261 if (msgLen <= 0) {
262 LOGE("ReadMsg failed readRet=%{public}d", msgLen);
263 return false;
264 } else {
265 msgBuf[msgLen - 1] = '\0';
266 HttpProtocol req;
267 if (!HttpProtocolDecode(msgBuf, req)) {
268 LOGE("HttpHandShake: Upgrade failed");
269 return false;
270 } else if (req.connection.find("Upgrade") != std::string::npos &&
271 req.upgrade.find("websocket") != std::string::npos && req.version.compare("HTTP/1.1") == 0) {
272 ProtocolUpgrade(req);
273 }
274 }
275 return true;
276 }
277
278 #if !defined(OHOS_PLATFORM)
InitTcpWebSocket(uint32_t timeoutLimit)279 bool WebSocket::InitTcpWebSocket(uint32_t timeoutLimit)
280 {
281 if (socketState_ != SocketState::UNINITED) {
282 LOGI("InitTcpWebSocket websocket has inited");
283 return true;
284 }
285 #if defined(WINDOWS_PLATFORM)
286 WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
287 WSADATA wsaData;
288 if (WSAStartup(sockVersion, &wsaData) != 0) {
289 LOGE("InitTcpWebSocket WSA init failed");
290 return false;
291 }
292 #endif
293 fd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
294 if (fd_ < SOCKET_SUCCESS) {
295 LOGE("InitTcpWebSocket socket init failed");
296 return false;
297 }
298 // allow specified port can be used at once and not wait TIME_WAIT status ending
299 int sockOptVal = 1;
300 if ((setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &sockOptVal, sizeof(sockOptVal))) != SOCKET_SUCCESS) {
301 LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed");
302 close(fd_);
303 fd_ = -1;
304 return false;
305 }
306
307 // set send and recv timeout
308 if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
309 LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
310 close(fd_);
311 fd_ = -1;
312 return false;
313 }
314
315 sockaddr_in addr_sin = {0};
316 addr_sin.sin_family = AF_INET;
317 addr_sin.sin_port = htons(9230); // 9230: sockName for tcp
318 addr_sin.sin_addr.s_addr = INADDR_ANY;
319 if (bind(fd_, reinterpret_cast<struct sockaddr*>(&addr_sin), sizeof(addr_sin)) < SOCKET_SUCCESS) {
320 LOGE("InitTcpWebSocket bind failed");
321 close(fd_);
322 fd_ = -1;
323 return false;
324 }
325 if (listen(fd_, 1) < SOCKET_SUCCESS) {
326 LOGE("InitTcpWebSocket listen failed");
327 close(fd_);
328 fd_ = -1;
329 return false;
330 }
331 socketState_ = SocketState::INITED;
332 return true;
333 }
334
ConnectTcpWebSocket()335 bool WebSocket::ConnectTcpWebSocket()
336 {
337 if (socketState_ == SocketState::UNINITED) {
338 LOGE("ConnectTcpWebSocket failed, websocket not inited");
339 return false;
340 }
341 if (socketState_ == SocketState::CONNECTED) {
342 LOGI("ConnectTcpWebSocket websocket has connected");
343 return true;
344 }
345
346 if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
347 LOGE("ConnectTcpWebSocket accept failed");
348 socketState_ = SocketState::UNINITED;
349 close(fd_);
350 fd_ = -1;
351 return false;
352 }
353
354 if (!HttpHandShake()) {
355 LOGE("ConnectTcpWebSocket HttpHandShake failed");
356 socketState_ = SocketState::UNINITED;
357 close(client_);
358 client_ = -1;
359 close(fd_);
360 fd_ = -1;
361 return false;
362 }
363 socketState_ = SocketState::CONNECTED;
364 return true;
365 }
366 #else
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)367 bool WebSocket::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
368 {
369 if (socketState_ != SocketState::UNINITED) {
370 LOGI("InitUnixWebSocket websocket has inited");
371 return true;
372 }
373 fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: defautlt protocol
374 if (fd_ < SOCKET_SUCCESS) {
375 LOGE("InitUnixWebSocket socket init failed");
376 return false;
377 }
378 // set send and recv timeout
379 if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
380 LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
381 close(fd_);
382 fd_ = -1;
383 return false;
384 }
385
386 struct sockaddr_un un;
387 if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
388 LOGE("InitUnixWebSocket memset_s failed");
389 close(fd_);
390 fd_ = -1;
391 return false;
392 }
393 un.sun_family = AF_UNIX;
394 if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
395 LOGE("InitUnixWebSocket strcpy_s failed");
396 close(fd_);
397 fd_ = -1;
398 return false;
399 }
400 un.sun_path[0] = '\0';
401 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
402 if (bind(fd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) < SOCKET_SUCCESS) {
403 LOGE("InitUnixWebSocket bind failed");
404 close(fd_);
405 fd_ = -1;
406 return false;
407 }
408 if (listen(fd_, 1) < SOCKET_SUCCESS) { // 1: connection num
409 LOGE("InitUnixWebSocket listen failed");
410 close(fd_);
411 fd_ = -1;
412 return false;
413 }
414 socketState_ = SocketState::INITED;
415 return true;
416 }
417
ConnectUnixWebSocket()418 bool WebSocket::ConnectUnixWebSocket()
419 {
420 if (socketState_ == SocketState::UNINITED) {
421 LOGE("ConnectUnixWebSocket failed, websocket not inited");
422 return false;
423 }
424 if (socketState_ == SocketState::CONNECTED) {
425 LOGI("ConnectUnixWebSocket websocket has connected");
426 return true;
427 }
428
429 if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
430 LOGE("ConnectUnixWebSocket accept failed");
431 socketState_ = SocketState::UNINITED;
432 close(fd_);
433 fd_ = -1;
434 return false;
435 }
436 if (!HttpHandShake()) {
437 LOGE("ConnectUnixWebSocket HttpHandShake failed");
438 socketState_ = SocketState::UNINITED;
439 shutdown(client_, SHUT_RDWR);
440 close(client_);
441 client_ = -1;
442 shutdown(fd_, SHUT_RDWR);
443 close(fd_);
444 fd_ = -1;
445 return false;
446 }
447 socketState_ = SocketState::CONNECTED;
448 return true;
449 }
450 #endif
451
IsConnected()452 bool WebSocket::IsConnected()
453 {
454 return socketState_ == SocketState::CONNECTED;
455 }
456
Close()457 void WebSocket::Close()
458 {
459 if (socketState_ == SocketState::UNINITED) {
460 return;
461 }
462 if (socketState_ == SocketState::CONNECTED) {
463 #if defined(OHOS_PLATFORM)
464 shutdown(client_, SHUT_RDWR);
465 #endif
466 close(client_);
467 client_ = -1;
468 }
469 socketState_ = SocketState::UNINITED;
470 usleep(10000); // 10000: time for websocket to enter the accept
471 #if defined(OHOS_PLATFORM)
472 shutdown(fd_, SHUT_RDWR);
473 #endif
474 close(fd_);
475 fd_ = -1;
476 }
477
NetToHostLongLong(char * buf,uint32_t len)478 uint64_t WebSocket::NetToHostLongLong(char* buf, uint32_t len)
479 {
480 uint64_t result = 0;
481 for (uint32_t i = 0; i < len; i++) {
482 result |= static_cast<unsigned char>(buf[i]);
483 if ((i + 1) < len) {
484 result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int
485 }
486 }
487 return result;
488 }
489
Recv(int32_t client,char * buf,size_t totalLen,int32_t flags) const490 bool WebSocket::Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) const
491 {
492 size_t recvLen = 0;
493 while (recvLen < totalLen) {
494 ssize_t len = recv(client, buf + recvLen, totalLen - recvLen, flags);
495 if (len <= 0) {
496 LOGE("Recv payload in while failed, websocket disconnect");
497 return false;
498 }
499 recvLen += static_cast<size_t>(len);
500 }
501 buf[totalLen] = '\0';
502 return true;
503 }
504
Send(int32_t client,const char * buf,size_t totalLen,int32_t flags) const505 bool WebSocket::Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) const
506 {
507 size_t sendLen = 0;
508 while (sendLen < totalLen) {
509 ssize_t len = send(client, buf + sendLen, totalLen - sendLen, flags);
510 if (len <= 0) {
511 LOGE("Send Message in while failed, websocket disconnect");
512 return false;
513 }
514 sendLen += static_cast<size_t>(len);
515 }
516 return true;
517 }
518
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)519 bool WebSocket::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
520 {
521 if (timeoutLimit > 0) {
522 struct timeval timeout = {timeoutLimit, 0};
523 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
524 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed");
525 return false;
526 }
527 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
528 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed");
529 return false;
530 }
531 }
532 return true;
533 }
534 } // namespace OHOS::ArkCompiler::Toolchain
535