• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "websocket.h"
17 
18 #include "define.h"
19 #include "log_wrapper.h"
20 #include "securec.h"
21 
22 namespace OHOS::ArkCompiler::Toolchain {
/**
 *  SendReply frames a message in one of 3 ways, depending on the message length:
 *    1. length <= 125
 *    2. length >= 126 && length < 65536
 *    3. length >= 65536
 */
29 
SendReply(const std::string & message) const30 void WebSocket::SendReply(const std::string& message) const
31 {
32     if (socketState_ != SocketState::CONNECTED) {
33         LOGE("SendReply failed, websocket not connected");
34         return;
35     }
36     const size_t msgLen = message.length();
37     std::unique_ptr<char []> msgBuf = std::make_unique<char []>(msgLen + 11); // 11: the maximum expandable length
38     char* sendBuf = msgBuf.get();
39     uint32_t sendMsgLen;
40     sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'.
41 
42     // Depending on the length of the messages, server will use shift operation to get the res
43     // and store them in the buffer.
44     if (msgLen <= 125) { // 125: situation 1 when message's length <= 125
45         sendBuf[1] = msgLen;
46         sendMsgLen = 2; // 2: the length of header frame is 2
47     } else if (msgLen < 65536) { // 65536: message's length
48         sendBuf[1] = 126; // 126: payloadLen according to the spec
49         sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1)
50         sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0)
51         sendMsgLen = 4; // 4: the length of header frame is 4
52     } else {
53         sendBuf[1] = 127; // 127: payloadLen according to the spec
54         for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits
55             sendBuf[i] = 0;
56         }
57         sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3)
58         sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2)
59         sendBuf[8] = ((msgLen & 0x0000ff00) >> 8);  // 8: shift 8 bits => res * (256^1)
60         sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0)
61         sendMsgLen = 10; // 10: the length of header frame is 10
62     }
63     if (memcpy_s(sendBuf + sendMsgLen, msgLen, message.c_str(), msgLen) != EOK) {
64         LOGE("SendReply: memcpy_s failed");
65         return;
66     }
67     sendBuf[sendMsgLen + msgLen] = '\0';
68     if (!Send(client_, sendBuf, sendMsgLen + msgLen, 0)) {
69         LOGE("SendReply: send failed");
70         return;
71     }
72 }
73 
HttpProtocolDecode(const std::string & request,HttpProtocol & req)74 bool WebSocket::HttpProtocolDecode(const std::string& request, HttpProtocol& req)
75 {
76     if (request.find("GET") == std::string::npos) {
77         LOGE("Handshake failed: lack of necessary info");
78         return false;
79     }
80     std::vector<std::string> reqStr = ProtocolSplit(request, EOL);
81     for (size_t i = 0; i < reqStr.size(); i++) {
82         if (i == 0) {
83             std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), " ");
84             req.version = headers.at(2); // 2: to get the version param
85         } else if (i < reqStr.size() - 1) {
86             std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), ": ");
87             if (reqStr.at(i).find("Connection") != std::string::npos) {
88                 req.connection = headers.at(1); // 1: to get the connection param
89             } else if (reqStr.at(i).find("Upgrade") != std::string::npos) {
90                 req.upgrade = headers.at(1); // 1: to get the upgrade param
91             } else if (reqStr.at(i).find("Sec-WebSocket-Key") != std::string::npos) {
92                 req.secWebSocketKey = headers.at(1); // 1: to get the secWebSocketKey param
93             }
94         }
95     }
96     return true;
97 }
98 
/**
 *  The wire format of a data frame is described in detail by the ABNF (RFC 5234) given in RFC 6455 section 5.2.
 *  A received message must be decoded according to that spec. The frame structure is as follows:
 *     0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
 *    +-+-+-+-+-------+-+-------------+-------------------------------+
 *    |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
 *    |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
 *    |N|V|V|V|       |S|             |   (if payload len==126/127)   |
 *    | |1|2|3|       |K|             |                               |
 *    +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
 *    |     Extended payload length continued, if payload len == 127  |
 *    + - - - - - - - - - - - - - - - +-------------------------------+
 *    |                               |Masking-key, if MASK set to 1  |
 *    +-------------------------------+-------------------------------+
 *    | Masking-key (continued)       |          Payload Data         |
 *    +-------------------------------- - - - - - - - - - - - - - - - +
 *    :                     Payload Data continued ...                :
 *    + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
 *    |                     Payload Data continued ...                |
 *    +---------------------------------------------------------------+
 */
120 
HandleFrame(WebSocketFrame & wsFrame)121 bool WebSocket::HandleFrame(WebSocketFrame& wsFrame)
122 {
123     if (wsFrame.payloadLen == 126) { // 126: the payloadLen read from frame
124         char recvbuf[PAYLOAD_LEN + 1] = {0};
125         if (!Recv(client_, recvbuf, PAYLOAD_LEN, 0)) {
126             LOGE("HandleFrame: Recv payloadLen == 126 failed");
127             return false;
128         }
129 
130         uint16_t msgLen = 0;
131         if (memcpy_s(&msgLen, sizeof(recvbuf), recvbuf, sizeof(recvbuf) - 1) != EOK) {
132             LOGE("HandleFrame: memcpy_s failed");
133             return false;
134         }
135         wsFrame.payloadLen = ntohs(msgLen);
136     } else if (wsFrame.payloadLen > 126) { // 126: the payloadLen read from frame
137         char recvbuf[EXTEND_PAYLOAD_LEN + 1] = {0};
138         if (!Recv(client_, recvbuf, EXTEND_PAYLOAD_LEN, 0)) {
139             LOGE("HandleFrame: Recv payloadLen > 127 failed");
140             return false;
141         }
142         wsFrame.payloadLen = NetToHostLongLong(recvbuf, EXTEND_PAYLOAD_LEN);
143     }
144     return DecodeMessage(wsFrame);
145 }
146 
DecodeMessage(WebSocketFrame & wsFrame)147 bool WebSocket::DecodeMessage(WebSocketFrame& wsFrame)
148 {
149     if (wsFrame.payloadLen == 0 || wsFrame.payloadLen > UINT64_MAX) {
150         LOGE("ReadMsg length error, expected greater than zero and less than UINT64_MAX");
151         return false;
152     }
153     uint64_t msgLen = wsFrame.payloadLen;
154     wsFrame.payload = std::make_unique<char []>(msgLen + 1);
155     if (wsFrame.mask == 1) {
156         char buf[msgLen + 1];
157         if (!Recv(client_, wsFrame.maskingkey, SOCKET_MASK_LEN, 0)) {
158             LOGE("DecodeMessage: Recv maskingkey failed");
159             return false;
160         }
161 
162         if (!Recv(client_, buf, msgLen, 0)) {
163             LOGE("DecodeMessage: Recv message with mask failed");
164             return false;
165         }
166 
167         for (uint64_t i = 0; i < msgLen; i++) {
168             uint64_t j = i % SOCKET_MASK_LEN;
169             wsFrame.payload.get()[i] = buf[i] ^ wsFrame.maskingkey[j];
170         }
171     } else {
172         char buf[msgLen + 1];
173         if (!Recv(client_, buf, msgLen, 0)) {
174             LOGE("DecodeMessage: Recv message without mask failed");
175             return false;
176         }
177 
178         if (memcpy_s(wsFrame.payload.get(), msgLen, buf, msgLen) != EOK) {
179             LOGE("DecodeMessage: memcpy_s failed");
180             return false;
181         }
182     }
183     wsFrame.payload.get()[msgLen] = '\0';
184     return true;
185 }
186 
ProtocolUpgrade(const HttpProtocol & req)187 bool WebSocket::ProtocolUpgrade(const HttpProtocol& req)
188 {
189     std::string rawKey = req.secWebSocketKey + WEB_SOCKET_GUID;
190     unsigned const char* webSocketKey = reinterpret_cast<unsigned const char*>(std::move(rawKey).c_str());
191     unsigned char hash[SHA_DIGEST_LENGTH + 1];
192     SHA1(webSocketKey, strlen(reinterpret_cast<const char*>(webSocketKey)), hash);
193     hash[SHA_DIGEST_LENGTH] = '\0';
194     unsigned char encodedKey[ENCODED_KEY_LEN];
195     EVP_EncodeBlock(encodedKey, reinterpret_cast<const unsigned char*>(hash), SHA_DIGEST_LENGTH);
196     std::string response;
197 
198     std::ostringstream sstream;
199     sstream << "HTTP/1.1 101 Switching Protocols" << EOL;
200     sstream << "Connection: upgrade" << EOL;
201     sstream << "Upgrade: websocket" << EOL;
202     sstream << "Sec-WebSocket-Accept: " << encodedKey << EOL;
203     sstream << EOL;
204     response = sstream.str();
205     if (!Send(client_, response.c_str(), response.length(), 0)) {
206         LOGE("ProtocolUpgrade: Send failed");
207         return false;
208     }
209     return true;
210 }
211 
Decode()212 std::string WebSocket::Decode()
213 {
214     if (socketState_ != SocketState::CONNECTED) {
215         LOGE("Decode failed, websocket not connected!");
216         return "";
217     }
218     char recvbuf[SOCKET_HEADER_LEN + 1];
219     if (!Recv(client_, recvbuf, SOCKET_HEADER_LEN, 0)) {
220         LOGE("Decode failed, client websocket disconnect");
221         socketState_ = SocketState::INITED;
222 #if defined(OHOS_PLATFORM)
223         shutdown(client_, SHUT_RDWR);
224         close(client_);
225         client_ = -1;
226 #else
227         close(client_);
228         client_ = -1;
229 #endif
230         return "";
231     }
232     WebSocketFrame wsFrame;
233     int32_t index = 0;
234     wsFrame.fin = static_cast<uint8_t>(recvbuf[index] >> 7); // 7: shift right by 7 bits to get the fin
235     wsFrame.opcode = static_cast<uint8_t>(recvbuf[index] & 0xf);
236     if (wsFrame.opcode == 0x1) { // 0x1: 0x1 means a text frame
237         index++;
238         wsFrame.mask = static_cast<uint8_t>((recvbuf[index] >> 7) & 0x1); // 7: to get the mask
239         wsFrame.payloadLen = recvbuf[index] & 0x7f;
240         if (HandleFrame(wsFrame)) {
241             return wsFrame.payload.get();
242         }
243         return "";
244     } else if (wsFrame.opcode == 0x9) { // 0x9: 0x9 means a ping frame
245         // send pong frame
246         char pongFrame[SOCKET_HEADER_LEN] = {0};
247         pongFrame[0] = 0x8a; // 0x8a: 0x8a means a pong frame
248         pongFrame[1] = 0x0;
249         if (!Send(client_, pongFrame, SOCKET_HEADER_LEN, 0)) {
250             LOGE("Decode: Send pong frame failed");
251             return "";
252         }
253     }
254     return "";
255 }
256 
HttpHandShake()257 bool WebSocket::HttpHandShake()
258 {
259     char msgBuf[SOCKET_HANDSHAKE_LEN];
260     int32_t msgLen = recv(client_, msgBuf, SOCKET_HANDSHAKE_LEN, 0);
261     if (msgLen <= 0) {
262         LOGE("ReadMsg failed readRet=%{public}d", msgLen);
263         return false;
264     } else {
265         msgBuf[msgLen - 1] = '\0';
266         HttpProtocol req;
267         if (!HttpProtocolDecode(msgBuf, req)) {
268             LOGE("HttpHandShake: Upgrade failed");
269             return false;
270         } else if (req.connection.find("Upgrade") != std::string::npos &&
271             req.upgrade.find("websocket") != std::string::npos && req.version.compare("HTTP/1.1") == 0) {
272             ProtocolUpgrade(req);
273         }
274     }
275     return true;
276 }
277 
278 #if !defined(OHOS_PLATFORM)
InitTcpWebSocket(uint32_t timeoutLimit)279 bool WebSocket::InitTcpWebSocket(uint32_t timeoutLimit)
280 {
281     if (socketState_ != SocketState::UNINITED) {
282         LOGI("InitTcpWebSocket websocket has inited");
283         return true;
284     }
285 #if defined(WINDOWS_PLATFORM)
286     WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
287     WSADATA wsaData;
288     if (WSAStartup(sockVersion, &wsaData) != 0) {
289         LOGE("InitTcpWebSocket WSA init failed");
290         return false;
291     }
292 #endif
293     fd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
294     if (fd_ < SOCKET_SUCCESS) {
295         LOGE("InitTcpWebSocket socket init failed");
296         return false;
297     }
298     // allow specified port can be used at once and not wait TIME_WAIT status ending
299     int sockOptVal = 1;
300     if ((setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &sockOptVal, sizeof(sockOptVal))) != SOCKET_SUCCESS) {
301         LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed");
302         close(fd_);
303         fd_ = -1;
304         return false;
305     }
306 
307     // set send and recv timeout
308     if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
309         LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
310         close(fd_);
311         fd_ = -1;
312         return false;
313     }
314 
315     sockaddr_in addr_sin = {0};
316     addr_sin.sin_family = AF_INET;
317     addr_sin.sin_port = htons(9230); // 9230: sockName for tcp
318     addr_sin.sin_addr.s_addr = INADDR_ANY;
319     if (bind(fd_, reinterpret_cast<struct sockaddr*>(&addr_sin), sizeof(addr_sin)) < SOCKET_SUCCESS) {
320         LOGE("InitTcpWebSocket bind failed");
321         close(fd_);
322         fd_ = -1;
323         return false;
324     }
325     if (listen(fd_, 1) < SOCKET_SUCCESS) {
326         LOGE("InitTcpWebSocket listen failed");
327         close(fd_);
328         fd_ = -1;
329         return false;
330     }
331     socketState_ = SocketState::INITED;
332     return true;
333 }
334 
ConnectTcpWebSocket()335 bool WebSocket::ConnectTcpWebSocket()
336 {
337     if (socketState_ == SocketState::UNINITED) {
338         LOGE("ConnectTcpWebSocket failed, websocket not inited");
339         return false;
340     }
341     if (socketState_ == SocketState::CONNECTED) {
342         LOGI("ConnectTcpWebSocket websocket has connected");
343         return true;
344     }
345 
346     if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
347         LOGE("ConnectTcpWebSocket accept failed");
348         socketState_ = SocketState::UNINITED;
349         close(fd_);
350         fd_ = -1;
351         return false;
352     }
353 
354     if (!HttpHandShake()) {
355         LOGE("ConnectTcpWebSocket HttpHandShake failed");
356         socketState_ = SocketState::UNINITED;
357         close(client_);
358         client_ = -1;
359         close(fd_);
360         fd_ = -1;
361         return false;
362     }
363     socketState_ = SocketState::CONNECTED;
364     return true;
365 }
366 #else
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)367 bool WebSocket::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
368 {
369     if (socketState_ != SocketState::UNINITED) {
370         LOGI("InitUnixWebSocket websocket has inited");
371         return true;
372     }
373     fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: defautlt protocol
374     if (fd_ < SOCKET_SUCCESS) {
375         LOGE("InitUnixWebSocket socket init failed");
376         return false;
377     }
378     // set send and recv timeout
379     if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
380         LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
381         close(fd_);
382         fd_ = -1;
383         return false;
384     }
385 
386     struct sockaddr_un un;
387     if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
388         LOGE("InitUnixWebSocket memset_s failed");
389         close(fd_);
390         fd_ = -1;
391         return false;
392     }
393     un.sun_family = AF_UNIX;
394     if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
395         LOGE("InitUnixWebSocket strcpy_s failed");
396         close(fd_);
397         fd_ = -1;
398         return false;
399     }
400     un.sun_path[0] = '\0';
401     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
402     if (bind(fd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) < SOCKET_SUCCESS) {
403         LOGE("InitUnixWebSocket bind failed");
404         close(fd_);
405         fd_ = -1;
406         return false;
407     }
408     if (listen(fd_, 1) < SOCKET_SUCCESS) { // 1: connection num
409         LOGE("InitUnixWebSocket listen failed");
410         close(fd_);
411         fd_ = -1;
412         return false;
413     }
414     socketState_ = SocketState::INITED;
415     return true;
416 }
417 
ConnectUnixWebSocket()418 bool WebSocket::ConnectUnixWebSocket()
419 {
420     if (socketState_ == SocketState::UNINITED) {
421         LOGE("ConnectUnixWebSocket failed, websocket not inited");
422         return false;
423     }
424     if (socketState_ == SocketState::CONNECTED) {
425         LOGI("ConnectUnixWebSocket websocket has connected");
426         return true;
427     }
428 
429     if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
430         LOGE("ConnectUnixWebSocket accept failed");
431         socketState_ = SocketState::UNINITED;
432         close(fd_);
433         fd_ = -1;
434         return false;
435     }
436     if (!HttpHandShake()) {
437         LOGE("ConnectUnixWebSocket HttpHandShake failed");
438         socketState_ = SocketState::UNINITED;
439         shutdown(client_, SHUT_RDWR);
440         close(client_);
441         client_ = -1;
442         shutdown(fd_, SHUT_RDWR);
443         close(fd_);
444         fd_ = -1;
445         return false;
446     }
447     socketState_ = SocketState::CONNECTED;
448     return true;
449 }
450 #endif
451 
IsConnected()452 bool WebSocket::IsConnected()
453 {
454     return socketState_ == SocketState::CONNECTED;
455 }
456 
Close()457 void WebSocket::Close()
458 {
459     if (socketState_ == SocketState::UNINITED) {
460         return;
461     }
462     if (socketState_ == SocketState::CONNECTED) {
463 #if defined(OHOS_PLATFORM)
464         shutdown(client_, SHUT_RDWR);
465 #endif
466         close(client_);
467         client_ = -1;
468     }
469     socketState_ = SocketState::UNINITED;
470     usleep(10000); // 10000: time for websocket to enter the accept
471 #if defined(OHOS_PLATFORM)
472     shutdown(fd_, SHUT_RDWR);
473 #endif
474     close(fd_);
475     fd_ = -1;
476 }
477 
NetToHostLongLong(char * buf,uint32_t len)478 uint64_t WebSocket::NetToHostLongLong(char* buf, uint32_t len)
479 {
480     uint64_t result = 0;
481     for (uint32_t i = 0; i < len; i++) {
482         result |= static_cast<unsigned char>(buf[i]);
483         if ((i + 1) < len) {
484             result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int
485         }
486     }
487     return result;
488 }
489 
Recv(int32_t client,char * buf,size_t totalLen,int32_t flags) const490 bool WebSocket::Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) const
491 {
492     size_t recvLen = 0;
493     while (recvLen < totalLen) {
494         ssize_t len = recv(client, buf + recvLen, totalLen - recvLen, flags);
495         if (len <= 0) {
496             LOGE("Recv payload in while failed, websocket disconnect");
497             return false;
498         }
499         recvLen += static_cast<size_t>(len);
500     }
501     buf[totalLen] = '\0';
502     return true;
503 }
504 
Send(int32_t client,const char * buf,size_t totalLen,int32_t flags) const505 bool WebSocket::Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) const
506 {
507     size_t sendLen = 0;
508     while (sendLen < totalLen) {
509         ssize_t len = send(client, buf + sendLen, totalLen - sendLen, flags);
510         if (len <= 0) {
511             LOGE("Send Message in while failed, websocket disconnect");
512             return false;
513         }
514         sendLen += static_cast<size_t>(len);
515     }
516     return true;
517 }
518 
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)519 bool WebSocket::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
520 {
521     if (timeoutLimit > 0) {
522         struct timeval timeout = {timeoutLimit, 0};
523         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
524             LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed");
525             return false;
526         }
527         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
528             LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed");
529             return false;
530         }
531     }
532     return true;
533 }
534 } // namespace OHOS::ArkCompiler::Toolchain
535