• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "websocket.h"
17 
18 #include "define.h"
19 #include "log_wrapper.h"
20 #include "securec.h"
21 
22 namespace OHOS::ArkCompiler::Toolchain {
23 /**
24  *  SendMessage in WebSocket has 3 situations:
25  *    1. message's length <= 125
26  *    2. message's length >= 126 && message's length < 65536
27  *    3. message's length >= 65536
28  */
29 
SendReply(const std::string & message) const30 void WebSocket::SendReply(const std::string& message) const
31 {
32     if (socketState_ != SocketState::CONNECTED) {
33         LOGE("SendReply failed, websocket not connected");
34         return;
35     }
36     const size_t msgLen = message.length();
37     std::unique_ptr<char []> msgBuf = std::make_unique<char []>(msgLen + 11); // 11: the maximum expandable length
38     char* sendBuf = msgBuf.get();
39     uint32_t sendMsgLen;
40     sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'.
41 
42     // Depending on the length of the messages, server will use shift operation to get the res
43     // and store them in the buffer.
44     if (msgLen <= 125) { // 125: situation 1 when message's length <= 125
45         sendBuf[1] = msgLen;
46         sendMsgLen = 2; // 2: the length of header frame is 2
47     } else if (msgLen < 65536) { // 65536: message's length
48         sendBuf[1] = 126; // 126: payloadLen according to the spec
49         sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1)
50         sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0)
51         sendMsgLen = 4; // 4: the length of header frame is 4
52     } else {
53         sendBuf[1] = 127; // 127: payloadLen according to the spec
54         for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits
55             sendBuf[i] = 0;
56         }
57         sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3)
58         sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2)
59         sendBuf[8] = ((msgLen & 0x0000ff00) >> 8);  // 8: shift 8 bits => res * (256^1)
60         sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0)
61         sendMsgLen = 10; // 10: the length of header frame is 10
62     }
63     if (memcpy_s(sendBuf + sendMsgLen, msgLen, message.c_str(), msgLen) != EOK) {
64         LOGE("SendReply: memcpy_s failed");
65         return;
66     }
67     sendBuf[sendMsgLen + msgLen] = '\0';
68     if (!Send(client_, sendBuf, sendMsgLen + msgLen, 0)) {
69         LOGE("SendReply: send failed");
70         return;
71     }
72 }
73 
HttpProtocolDecode(const std::string & request,HttpProtocol & req)74 bool WebSocket::HttpProtocolDecode(const std::string& request, HttpProtocol& req)
75 {
76     if (request.find("GET") == std::string::npos) {
77         LOGE("Handshake failed: lack of necessary info");
78         return false;
79     }
80     std::vector<std::string> reqStr = ProtocolSplit(request, EOL);
81     for (size_t i = 0; i < reqStr.size(); i++) {
82         if (i == 0) {
83             std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), " ");
84             req.version = headers.at(2); // 2: to get the version param
85         } else if (i < reqStr.size() - 1) {
86             std::vector<std::string> headers = ProtocolSplit(reqStr.at(i), ": ");
87             if (reqStr.at(i).find("Connection") != std::string::npos) {
88                 req.connection = headers.at(1); // 1: to get the connection param
89             } else if (reqStr.at(i).find("Upgrade") != std::string::npos) {
90                 req.upgrade = headers.at(1); // 1: to get the upgrade param
91             } else if (reqStr.at(i).find("Sec-WebSocket-Key") != std::string::npos) {
92                 req.secWebSocketKey = headers.at(1); // 1: to get the secWebSocketKey param
93             }
94         }
95     }
96     return true;
97 }
98 
99 /**
100   *  The wire format of this data transmission section is described in detail using ABNF (RFC 5234).
101   *  When receiving a message, we should decode it according to the spec. The structure is as follows:
102   *     0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
103   *    +-+-+-+-+-------+-+-------------+-------------------------------+
104   *    |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
105   *    |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
106   *    |N|V|V|V|       |S|             |   (if payload len==126/127)   |
107   *    | |1|2|3|       |K|             |                               |
108   *    +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
109   *    |     Extended payload length continued, if payload len == 127  |
110   *    + - - - - - - - - - - - - - - - +-------------------------------+
111   *    |                               |Masking-key, if MASK set to 1  |
112   *    +-------------------------------+-------------------------------+
113   *    | Masking-key (continued)       |          Payload Data         |
114   *    +-------------------------------- - - - - - - - - - - - - - - - +
115   *    :                     Payload Data continued ...                :
116   *    + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
117   *    |                     Payload Data continued ...                |
118   *    +---------------------------------------------------------------+
119   */
120 
HandleFrame(WebSocketFrame & wsFrame)121 bool WebSocket::HandleFrame(WebSocketFrame& wsFrame)
122 {
123     if (wsFrame.payloadLen == 126) { // 126: the payloadLen read from frame
124         char recvbuf[PAYLOAD_LEN + 1] = {0};
125         if (!Recv(client_, recvbuf, PAYLOAD_LEN, 0)) {
126             LOGE("HandleFrame: Recv payloadLen == 126 failed");
127             return false;
128         }
129 
130         uint16_t msgLen = 0;
131         if (memcpy_s(&msgLen, sizeof(recvbuf), recvbuf, sizeof(recvbuf) - 1) != EOK) {
132             LOGE("HandleFrame: memcpy_s failed");
133             return false;
134         }
135         wsFrame.payloadLen = ntohs(msgLen);
136     } else if (wsFrame.payloadLen > 126) { // 126: the payloadLen read from frame
137         char recvbuf[EXTEND_PAYLOAD_LEN + 1] = {0};
138         if (!Recv(client_, recvbuf, EXTEND_PAYLOAD_LEN, 0)) {
139             LOGE("HandleFrame: Recv payloadLen > 127 failed");
140             return false;
141         }
142         wsFrame.payloadLen = NetToHostLongLong(recvbuf, EXTEND_PAYLOAD_LEN);
143     }
144     return DecodeMessage(wsFrame);
145 }
146 
DecodeMessage(WebSocketFrame & wsFrame)147 bool WebSocket::DecodeMessage(WebSocketFrame& wsFrame)
148 {
149     if (wsFrame.payloadLen == 0 || wsFrame.payloadLen > UINT64_MAX) {
150         LOGE("ReadMsg length error, expected greater than zero and less than UINT64_MAX");
151         return false;
152     }
153     uint64_t msgLen = wsFrame.payloadLen;
154     wsFrame.payload = std::make_unique<char []>(msgLen + 1);
155     if (wsFrame.mask == 1) {
156         char buf[msgLen + 1];
157         if (!Recv(client_, wsFrame.maskingkey, SOCKET_MASK_LEN, 0)) {
158             LOGE("DecodeMessage: Recv maskingkey failed");
159             return false;
160         }
161 
162         if (!Recv(client_, buf, msgLen, 0)) {
163             LOGE("DecodeMessage: Recv message with mask failed");
164             return false;
165         }
166 
167         for (uint64_t i = 0; i < msgLen; i++) {
168             uint64_t j = i % SOCKET_MASK_LEN;
169             wsFrame.payload.get()[i] = buf[i] ^ wsFrame.maskingkey[j];
170         }
171     } else {
172         char buf[msgLen + 1];
173         if (!Recv(client_, buf, msgLen, 0)) {
174             LOGE("DecodeMessage: Recv message without mask failed");
175             return false;
176         }
177 
178         if (memcpy_s(wsFrame.payload.get(), msgLen, buf, msgLen) != EOK) {
179             LOGE("DecodeMessage: memcpy_s failed");
180             return false;
181         }
182     }
183     wsFrame.payload.get()[msgLen] = '\0';
184     return true;
185 }
186 
ProtocolUpgrade(const HttpProtocol & req)187 bool WebSocket::ProtocolUpgrade(const HttpProtocol& req)
188 {
189     std::string rawKey = req.secWebSocketKey + WEB_SOCKET_GUID;
190     unsigned const char* webSocketKey = reinterpret_cast<unsigned const char*>(std::move(rawKey).c_str());
191     unsigned char hash[SHA_DIGEST_LENGTH + 1];
192     SHA1(webSocketKey, strlen(reinterpret_cast<const char*>(webSocketKey)), hash);
193     hash[SHA_DIGEST_LENGTH] = '\0';
194     unsigned char encodedKey[ENCODED_KEY_LEN];
195     EVP_EncodeBlock(encodedKey, reinterpret_cast<const unsigned char*>(hash), SHA_DIGEST_LENGTH);
196     std::string response;
197 
198     std::ostringstream sstream;
199     sstream << "HTTP/1.1 101 Switching Protocols" << EOL;
200     sstream << "Connection: upgrade" << EOL;
201     sstream << "Upgrade: websocket" << EOL;
202     sstream << "Sec-WebSocket-Accept: " << encodedKey << EOL;
203     sstream << EOL;
204     response = sstream.str();
205     if (!Send(client_, response.c_str(), response.length(), 0)) {
206         LOGE("ProtocolUpgrade: Send failed");
207         return false;
208     }
209     return true;
210 }
211 
Decode()212 std::string WebSocket::Decode()
213 {
214     if (socketState_ != SocketState::CONNECTED) {
215         LOGE("Decode failed, websocket not connected!");
216         return "";
217     }
218     char recvbuf[SOCKET_HEADER_LEN + 1];
219     if (!Recv(client_, recvbuf, SOCKET_HEADER_LEN, 0)) {
220         LOGE("Decode failed, client websocket disconnect");
221         socketState_ = SocketState::INITED;
222 #if defined(OHOS_PLATFORM)
223         shutdown(client_, SHUT_RDWR);
224         close(client_);
225         client_ = -1;
226 #else
227         close(client_);
228         client_ = -1;
229 #endif
230         return "";
231     }
232     WebSocketFrame wsFrame;
233     int32_t index = 0;
234     wsFrame.fin = static_cast<uint8_t>(recvbuf[index] >> 7); // 7: shift right by 7 bits to get the fin
235     wsFrame.opcode = static_cast<uint8_t>(recvbuf[index] & 0xf);
236     if (wsFrame.opcode == 0x1) { // 0x1: 0x1 means a text frame
237         index++;
238         wsFrame.mask = static_cast<uint8_t>((recvbuf[index] >> 7) & 0x1); // 7: to get the mask
239         wsFrame.payloadLen = recvbuf[index] & 0x7f;
240         if (HandleFrame(wsFrame)) {
241             return wsFrame.payload.get();
242         }
243         return "";
244     }
245     return "";
246 }
247 
HttpHandShake()248 bool WebSocket::HttpHandShake()
249 {
250     char msgBuf[SOCKET_HANDSHAKE_LEN];
251     int32_t msgLen = recv(client_, msgBuf, SOCKET_HANDSHAKE_LEN, 0);
252     if (msgLen <= 0) {
253         LOGE("ReadMsg failed readRet=%{public}d", msgLen);
254         return false;
255     } else {
256         msgBuf[msgLen - 1] = '\0';
257         HttpProtocol req;
258         if (!HttpProtocolDecode(msgBuf, req)) {
259             LOGE("HttpHandShake: Upgrade failed");
260             return false;
261         } else if (req.connection.find("Upgrade") != std::string::npos &&
262             req.upgrade.find("websocket") != std::string::npos && req.version.compare("HTTP/1.1") == 0) {
263             ProtocolUpgrade(req);
264         }
265     }
266     return true;
267 }
268 
269 #if !defined(OHOS_PLATFORM)
InitTcpWebSocket(uint32_t timeoutLimit)270 bool WebSocket::InitTcpWebSocket(uint32_t timeoutLimit)
271 {
272     if (socketState_ != SocketState::UNINITED) {
273         LOGI("InitTcpWebSocket websocket has inited");
274         return true;
275     }
276 #if defined(WINDOWS_PLATFORM)
277     WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
278     WSADATA wsaData;
279     if (WSAStartup(sockVersion, &wsaData) != 0) {
280         LOGE("InitTcpWebSocket WSA init failed");
281         return false;
282     }
283 #endif
284     fd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
285     if (fd_ < SOCKET_SUCCESS) {
286         LOGE("InitTcpWebSocket socket init failed");
287         return false;
288     }
289     // allow specified port can be used at once and not wait TIME_WAIT status ending
290     int sockOptVal = 1;
291     if ((setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &sockOptVal, sizeof(sockOptVal))) != SOCKET_SUCCESS) {
292         LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed");
293         close(fd_);
294         fd_ = -1;
295         return false;
296     }
297 
298     // set send and recv timeout
299     if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
300         LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
301         close(fd_);
302         fd_ = -1;
303         return false;
304     }
305 
306     sockaddr_in addr_sin = {0};
307     addr_sin.sin_family = AF_INET;
308     addr_sin.sin_port = htons(9230); // 9230: sockName for tcp
309     addr_sin.sin_addr.s_addr = INADDR_ANY;
310     if (bind(fd_, reinterpret_cast<struct sockaddr*>(&addr_sin), sizeof(addr_sin)) < SOCKET_SUCCESS) {
311         LOGE("InitTcpWebSocket bind failed");
312         close(fd_);
313         fd_ = -1;
314         return false;
315     }
316     if (listen(fd_, 1) < SOCKET_SUCCESS) {
317         LOGE("InitTcpWebSocket listen failed");
318         close(fd_);
319         fd_ = -1;
320         return false;
321     }
322     socketState_ = SocketState::INITED;
323     return true;
324 }
325 
ConnectTcpWebSocket()326 bool WebSocket::ConnectTcpWebSocket()
327 {
328     if (socketState_ == SocketState::UNINITED) {
329         LOGE("ConnectTcpWebSocket failed, websocket not inited");
330         return false;
331     }
332     if (socketState_ == SocketState::CONNECTED) {
333         LOGI("ConnectTcpWebSocket websocket has connected");
334         return true;
335     }
336 
337     if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
338         LOGE("ConnectTcpWebSocket accept failed");
339         socketState_ = SocketState::UNINITED;
340         close(fd_);
341         fd_ = -1;
342         return false;
343     }
344 
345     if (!HttpHandShake()) {
346         LOGE("ConnectTcpWebSocket HttpHandShake failed");
347         socketState_ = SocketState::UNINITED;
348         close(client_);
349         client_ = -1;
350         close(fd_);
351         fd_ = -1;
352         return false;
353     }
354     socketState_ = SocketState::CONNECTED;
355     return true;
356 }
357 #else
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)358 bool WebSocket::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
359 {
360     if (socketState_ != SocketState::UNINITED) {
361         LOGI("InitUnixWebSocket websocket has inited");
362         return true;
363     }
364     fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: defautlt protocol
365     if (fd_ < SOCKET_SUCCESS) {
366         LOGE("InitUnixWebSocket socket init failed");
367         return false;
368     }
369     // set send and recv timeout
370     if (!SetWebSocketTimeOut(fd_, timeoutLimit)) {
371         LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
372         close(fd_);
373         fd_ = -1;
374         return false;
375     }
376 
377     struct sockaddr_un un;
378     if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
379         LOGE("InitUnixWebSocket memset_s failed");
380         close(fd_);
381         fd_ = -1;
382         return false;
383     }
384     un.sun_family = AF_UNIX;
385     if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
386         LOGE("InitUnixWebSocket strcpy_s failed");
387         close(fd_);
388         fd_ = -1;
389         return false;
390     }
391     un.sun_path[0] = '\0';
392     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
393     if (bind(fd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) < SOCKET_SUCCESS) {
394         LOGE("InitUnixWebSocket bind failed");
395         close(fd_);
396         fd_ = -1;
397         return false;
398     }
399     if (listen(fd_, 1) < SOCKET_SUCCESS) { // 1: connection num
400         LOGE("InitUnixWebSocket listen failed");
401         close(fd_);
402         fd_ = -1;
403         return false;
404     }
405     socketState_ = SocketState::INITED;
406     return true;
407 }
408 
ConnectUnixWebSocket()409 bool WebSocket::ConnectUnixWebSocket()
410 {
411     if (socketState_ == SocketState::UNINITED) {
412         LOGE("ConnectUnixWebSocket failed, websocket not inited");
413         return false;
414     }
415     if (socketState_ == SocketState::CONNECTED) {
416         LOGI("ConnectUnixWebSocket websocket has connected");
417         return true;
418     }
419 
420     if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
421         LOGE("ConnectUnixWebSocket accept failed");
422         socketState_ = SocketState::UNINITED;
423         close(fd_);
424         fd_ = -1;
425         return false;
426     }
427     if (!HttpHandShake()) {
428         LOGE("ConnectUnixWebSocket HttpHandShake failed");
429         socketState_ = SocketState::UNINITED;
430         shutdown(client_, SHUT_RDWR);
431         close(client_);
432         client_ = -1;
433         shutdown(fd_, SHUT_RDWR);
434         close(fd_);
435         fd_ = -1;
436         return false;
437     }
438     socketState_ = SocketState::CONNECTED;
439     return true;
440 }
441 #endif
442 
IsConnected()443 bool WebSocket::IsConnected()
444 {
445     return socketState_ == SocketState::CONNECTED;
446 }
447 
Close()448 void WebSocket::Close()
449 {
450     if (socketState_ == SocketState::UNINITED) {
451         return;
452     }
453     if (socketState_ == SocketState::CONNECTED) {
454 #if defined(OHOS_PLATFORM)
455         shutdown(client_, SHUT_RDWR);
456 #endif
457         close(client_);
458         client_ = -1;
459     }
460     socketState_ = SocketState::UNINITED;
461     usleep(10000); // 10000: time for websocket to enter the accept
462 #if defined(OHOS_PLATFORM)
463     shutdown(fd_, SHUT_RDWR);
464 #endif
465     close(fd_);
466     fd_ = -1;
467 }
468 
NetToHostLongLong(char * buf,uint32_t len)469 uint64_t WebSocket::NetToHostLongLong(char* buf, uint32_t len)
470 {
471     uint64_t result = 0;
472     for (uint32_t i = 0; i < len; i++) {
473         result |= static_cast<unsigned char>(buf[i]);
474         if ((i + 1) < len) {
475             result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int
476         }
477     }
478     return result;
479 }
480 
Recv(int32_t client,char * buf,size_t totalLen,int32_t flags) const481 bool WebSocket::Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) const
482 {
483     size_t recvLen = 0;
484     while (recvLen < totalLen) {
485         ssize_t len = recv(client, buf + recvLen, totalLen - recvLen, flags);
486         if (len <= 0) {
487             LOGE("Recv payload in while failed, websocket disconnect");
488             return false;
489         }
490         recvLen += static_cast<size_t>(len);
491     }
492     buf[totalLen] = '\0';
493     return true;
494 }
495 
Send(int32_t client,const char * buf,size_t totalLen,int32_t flags) const496 bool WebSocket::Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) const
497 {
498     size_t sendLen = 0;
499     while (sendLen < totalLen) {
500         ssize_t len = send(client, buf + sendLen, totalLen - sendLen, flags);
501         if (len <= 0) {
502             LOGE("Send Message in while failed, websocket disconnect");
503             return false;
504         }
505         sendLen += static_cast<size_t>(len);
506     }
507     return true;
508 }
509 
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)510 bool WebSocket::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
511 {
512     if (timeoutLimit > 0) {
513         struct timeval timeout = {timeoutLimit, 0};
514         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
515             LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed");
516             return false;
517         }
518         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
519             LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed");
520             return false;
521         }
522     }
523     return true;
524 }
525 } // namespace OHOS::ArkCompiler::Toolchain
526