• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "common/log_wrapper.h"
17 #include "define.h"
18 #include "platform/file.h"
19 #include "frame_builder.h"
20 #include "network.h"
21 #include "websocket_base.h"
22 
23 #include <mutex>
24 
25 namespace OHOS::ArkCompiler::Toolchain {
// Process-wide mutex acquired by WebSocketBase::SendReply for the duration of each reply.
std::mutex g_sendReplymutex;
ToString(CloseStatusCode status)27 static std::string ToString(CloseStatusCode status)
28 {
29     if (status == CloseStatusCode::NO_STATUS_CODE) {
30         return "";
31     }
32     std::string result;
33     PushNumberPerByte(result, EnumToNumber(status));
34     return result;
35 }
36 
~WebSocketBase()37 WebSocketBase::~WebSocketBase() noexcept
38 {
39     if (connectionFd_ != -1) {
40         LOGW("WebSocket connection is closed while destructing the object");
41         FdsanClose(reinterpret_cast<fd_t>(connectionFd_));
42         // Reset directly in order to prevent static analyzer warnings.
43         connectionFd_ = -1;
44     }
45 }
46 
47 // if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0
48 // and the last frame will be marked as 0x1.
49 // we just add the 'isLast' parameter to indicate whether it is the last frame.
SendReply(const std::string & message,FrameType frameType,bool isLast) const50 bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const
51 {
52     std::lock_guard<std::mutex> lock(g_sendReplymutex);
53     if (connectionState_.load() != ConnectionState::OPEN) {
54         LOGE("SendReply failed, websocket not connected");
55         return false;
56     }
57 
58     const auto frame = CreateFrame(isLast, frameType, message);
59     if (!SendUnderLock(frame)) {
60         LOGE("SendReply: send failed");
61         return false;
62     }
63     return true;
64 }
65 
/**
  *  The wire format of this data transmission section is described in detail using ABNF (RFC 5234).
  *  When receiving a message, we should decode it according to the spec. The structure is as follows:
  *     0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
  *    +-+-+-+-+-------+-+-------------+-------------------------------+
  *    |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
  *    |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
  *    |N|V|V|V|       |S|             |   (if payload len==126/127)   |
  *    | |1|2|3|       |K|             |                               |
  *    +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
  *    |     Extended payload length continued, if payload len == 127  |
  *    + - - - - - - - - - - - - - - - +-------------------------------+
  *    |                               |Masking-key, if MASK set to 1  |
  *    +-------------------------------+-------------------------------+
  *    | Masking-key (continued)       |          Payload Data         |
  *    +-------------------------------- - - - - - - - - - - - - - - - +
  *    :                     Payload Data continued ...                :
  *    + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
  *    |                     Payload Data continued ...                |
  *    +---------------------------------------------------------------+
  */
87 
ReadPayload(WebSocketFrame & wsFrame) const88 bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame) const
89 {
90     if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) {
91         uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0};
92         if (!RecvUnderLock(recvbuf, WebSocketFrame::TWO_BYTES_LENTH)) {
93             LOGE("ReadPayload: Recv payloadLen == 126 failed");
94             return false;
95         }
96         wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH);
97     } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) {
98         uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0};
99         if (!RecvUnderLock(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH)) {
100             LOGE("ReadPayload: Recv payloadLen == 127 failed");
101             return false;
102         }
103         wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH);
104     }
105     return DecodeMessage(wsFrame);
106 }
107 
HandleDataFrame(WebSocketFrame & wsFrame) const108 bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame) const
109 {
110     if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) {
111         return ReadPayload(wsFrame);
112     } else {
113         LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode);
114     }
115     return true;
116 }
117 
HandleControlFrame(WebSocketFrame & wsFrame)118 bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame)
119 {
120     if (wsFrame.opcode == EnumToNumber(FrameType::PING)) {
121         // A Pong frame sent in response to a Ping frame must have identical
122         // "Application data" as found in the message body of the Ping frame
123         // being replied to.
124         // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3
125         if (!ReadPayload(wsFrame)) {
126             LOGE("Failed to read ping frame payload");
127             return false;
128         }
129         SendPongFrame(wsFrame.payload);
130     } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) {
131         // might read payload to response by echoing the status code
132         CloseConnection(CloseStatusCode::NO_STATUS_CODE);
133     }
134     return true;
135 }
136 
Decode()137 std::string WebSocketBase::Decode()
138 {
139     if (auto state = connectionState_.load(); state != ConnectionState::OPEN) {
140         LOGE("Decode failed: websocket not connected, state = %{public}d", EnumToNumber(state));
141         return "";
142     }
143 
144     uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0};
145     if (!RecvUnderLock(recvbuf, WebSocketFrame::HEADER_LEN)) {
146         LOGE("Decode failed, client websocket disconnect");
147         CloseConnection(CloseStatusCode::UNEXPECTED_ERROR);
148         return std::string(DECODE_DISCONNECT_MSG);
149     }
150     WebSocketFrame wsFrame(recvbuf);
151     if (!ValidateIncomingFrame(wsFrame)) {
152         LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]);
153         CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
154         return std::string(DECODE_DISCONNECT_MSG);
155     }
156 
157     if (IsControlFrame(wsFrame.opcode)) {
158         if (HandleControlFrame(wsFrame)) {
159             return wsFrame.payload;
160         }
161     } else if (HandleDataFrame(wsFrame)) {
162         return wsFrame.payload;
163     }
164     // Unexpected data, must close the connection.
165     CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
166     return std::string(DECODE_DISCONNECT_MSG);
167 }
168 
IsConnected() const169 bool WebSocketBase::IsConnected() const
170 {
171     return connectionState_.load() == ConnectionState::OPEN;
172 }
173 
SetCloseConnectionCallback(CloseConnectionCallback cb)174 void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb)
175 {
176     closeCb_ = std::move(cb);
177 }
178 
SetFailConnectionCallback(FailConnectionCallback cb)179 void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb)
180 {
181     failCb_ = std::move(cb);
182 }
183 
OnConnectionClose(ConnectionCloseReason status)184 void WebSocketBase::OnConnectionClose(ConnectionCloseReason status)
185 {
186     if (status == ConnectionCloseReason::FAIL) {
187         if (failCb_) {
188             failCb_();
189         }
190     } else if (status == ConnectionCloseReason::CLOSE) {
191         if (closeCb_) {
192             closeCb_();
193         }
194     }
195 }
196 
CloseConnectionSocket(ConnectionCloseReason status)197 void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status)
198 {
199     OnConnectionClose(status);
200 
201     {
202         // Shared lock due to other thread possibly hanging on `recv` with acquired shared lock.
203         std::shared_lock lock(connectionMutex_);
204         int err = ShutdownSocket(connectionFd_);
205         if (err != 0) {
206             LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
207         }
208     }
209     {
210         // Unique lock due to close and write into `connectionFd_`.
211         // Note that `close` must be also done in critical section,
212         // otherwise the other thread can continue using the outdated and possibly reassigned file descriptor.
213         std::unique_lock lock(connectionMutex_);
214         FdsanClose(reinterpret_cast<fd_t>(connectionFd_));
215         // Reset directly in order to prevent static analyzer warnings.
216         connectionFd_ = -1;
217     }
218 
219     auto expected = ConnectionState::CLOSING;
220     if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSED)) {
221         LOGE("In connection transition CLOSING->CLOSED got initial state = %{public}d", EnumToNumber(expected));
222     }
223 }
224 
SendPongFrame(std::string payload) const225 void WebSocketBase::SendPongFrame(std::string payload) const
226 {
227     const auto frame = CreateFrame(true, FrameType::PONG, std::move(payload));
228     if (!SendUnderLock(frame)) {
229         LOGE("Decode: Send pong frame failed");
230     }
231 }
232 
SendCloseFrame(CloseStatusCode status) const233 void WebSocketBase::SendCloseFrame(CloseStatusCode status) const
234 {
235     const auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status));
236     if (!SendUnderLock(frame)) {
237         LOGE("SendCloseFrame: Send close frame failed");
238     }
239 }
240 
CloseConnection(CloseStatusCode status)241 bool WebSocketBase::CloseConnection(CloseStatusCode status)
242 {
243     auto expected = ConnectionState::OPEN;
244     if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSING)) {
245         // Concurrent connection close detected, do nothing.
246         return false;
247     }
248 
249     LOGI("Close connection, status = %{public}d", static_cast<int>(status));
250     SendCloseFrame(status);
251     // can close connection right after sending back close frame.
252     CloseConnectionSocket(ConnectionCloseReason::CLOSE);
253     return true;
254 }
255 
GetConnectionSocket() const256 int WebSocketBase::GetConnectionSocket() const
257 {
258     return connectionFd_;
259 }
260 
SetConnectionSocket(int socketFd)261 void WebSocketBase::SetConnectionSocket(int socketFd)
262 {
263     FdsanExchangeOwnerTag(reinterpret_cast<fd_t>(socketFd));
264     connectionFd_ = socketFd;
265 }
266 
GetConnectionMutex()267 std::shared_mutex &WebSocketBase::GetConnectionMutex()
268 {
269     return connectionMutex_;
270 }
271 
GetConnectionState() const272 WebSocketBase::ConnectionState WebSocketBase::GetConnectionState() const
273 {
274     return connectionState_.load();
275 }
276 
SetConnectionState(ConnectionState newState)277 WebSocketBase::ConnectionState WebSocketBase::SetConnectionState(ConnectionState newState)
278 {
279     return connectionState_.exchange(newState);
280 }
281 
CompareExchangeConnectionState(ConnectionState & expected,ConnectionState newState)282 bool WebSocketBase::CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState)
283 {
284     return connectionState_.compare_exchange_strong(expected, newState);
285 }
286 
SendUnderLock(const std::string & message) const287 bool WebSocketBase::SendUnderLock(const std::string& message) const
288 {
289     std::shared_lock lock(connectionMutex_);
290     return Send(connectionFd_, message, 0);
291 }
292 
SendUnderLock(const char * buf,size_t totalLen) const293 bool WebSocketBase::SendUnderLock(const char* buf, size_t totalLen) const
294 {
295     std::shared_lock lock(connectionMutex_);
296     return Send(connectionFd_, buf, totalLen, 0);
297 }
298 
RecvUnderLock(std::string & message) const299 bool WebSocketBase::RecvUnderLock(std::string& message) const
300 {
301     std::shared_lock lock(connectionMutex_);
302     return Recv(connectionFd_, message, 0);
303 }
304 
RecvUnderLock(uint8_t * buf,size_t totalLen) const305 bool WebSocketBase::RecvUnderLock(uint8_t* buf, size_t totalLen) const
306 {
307     std::shared_lock lock(connectionMutex_);
308     return Recv(connectionFd_, buf, totalLen, 0);
309 }
310 
311 /* static */
IsDecodeDisconnectMsg(const std::string & message)312 bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message)
313 {
314     return message == DECODE_DISCONNECT_MSG;
315 }
316 
317 #if !defined(OHOS_PLATFORM)
318 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)319 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
320 {
321     if (timeoutLimit > 0) {
322         struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
323         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO,
324             reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
325             LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
326             return false;
327         }
328         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO,
329             reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
330             LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
331             return false;
332         }
333     }
334     return true;
335 }
336 #else
337 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)338 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
339 {
340     if (timeoutLimit > 0) {
341         struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
342         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
343             LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
344             return false;
345         }
346         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
347             LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
348             return false;
349         }
350     }
351     return true;
352 }
353 #endif
354 
355 #if defined(WINDOWS_PLATFORM)
356 /* static */
ShutdownSocket(int32_t fd)357 int WebSocketBase::ShutdownSocket(int32_t fd)
358 {
359     return shutdown(fd, SD_BOTH);
360 }
361 #else
362 /* static */
ShutdownSocket(int32_t fd)363 int WebSocketBase::ShutdownSocket(int32_t fd)
364 {
365     return shutdown(fd, SHUT_RDWR);
366 }
367 #endif
368 } // namespace OHOS::ArkCompiler::Toolchain
369