1 /*
2 * Copyright (c) 2022-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "common/log_wrapper.h"
17 #include "define.h"
18 #include "platform/file.h"
19 #include "frame_builder.h"
20 #include "network.h"
21 #include "websocket_base.h"
22
23 #include <mutex>
24
25 namespace OHOS::ArkCompiler::Toolchain {
// Namespace-scope mutex shared by every WebSocketBase instance. SendReply holds it across the
// state check, frame construction and socket write, so concurrent SendReply calls are serialized.
std::mutex g_sendReplymutex {};
ToString(CloseStatusCode status)27 static std::string ToString(CloseStatusCode status)
28 {
29 if (status == CloseStatusCode::NO_STATUS_CODE) {
30 return "";
31 }
32 std::string result;
33 PushNumberPerByte(result, EnumToNumber(status));
34 return result;
35 }
36
~WebSocketBase()37 WebSocketBase::~WebSocketBase() noexcept
38 {
39 if (connectionFd_ != -1) {
40 LOGW("WebSocket connection is closed while destructing the object");
41 FdsanClose(reinterpret_cast<fd_t>(connectionFd_));
42 // Reset directly in order to prevent static analyzer warnings.
43 connectionFd_ = -1;
44 }
45 }
46
47 // if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0
48 // and the last frame will be marked as 0x1.
49 // we just add the 'isLast' parameter to indicate whether it is the last frame.
SendReply(const std::string & message,FrameType frameType,bool isLast) const50 bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const
51 {
52 std::lock_guard<std::mutex> lock(g_sendReplymutex);
53 if (connectionState_.load() != ConnectionState::OPEN) {
54 LOGE("SendReply failed, websocket not connected");
55 return false;
56 }
57
58 const auto frame = CreateFrame(isLast, frameType, message);
59 if (!SendUnderLock(frame)) {
60 LOGE("SendReply: send failed");
61 return false;
62 }
63 return true;
64 }
65
/**
 * The wire format of this data transmission section is described in detail in RFC 6455 (section 5.2),
 * using ABNF (RFC 5234). When a message is received, it must be decoded according to that spec.
 * The structure is as follows (bit indices restart at 0 for every byte):
 *
 *  0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
 * +-+-+-+-+-------+-+-------------+-------------------------------+
 * |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
 * |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
 * |N|V|V|V|       |S|             |   (if payload len==126/127)   |
 * | |1|2|3|       |K|             |                               |
 * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
 * |     Extended payload length continued, if payload len == 127  |
 * + - - - - - - - - - - - - - - - +-------------------------------+
 * |                               |Masking-key, if MASK set to 1  |
 * +-------------------------------+-------------------------------+
 * | Masking-key (continued)       |          Payload Data         |
 * +-------------------------------- - - - - - - - - - - - - - - - +
 * :                     Payload Data continued ...                :
 * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
 * |                     Payload Data continued ...                |
 * +---------------------------------------------------------------+
 */
87
ReadPayload(WebSocketFrame & wsFrame) const88 bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame) const
89 {
90 if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) {
91 uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0};
92 if (!RecvUnderLock(recvbuf, WebSocketFrame::TWO_BYTES_LENTH)) {
93 LOGE("ReadPayload: Recv payloadLen == 126 failed");
94 return false;
95 }
96 wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH);
97 } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) {
98 uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0};
99 if (!RecvUnderLock(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH)) {
100 LOGE("ReadPayload: Recv payloadLen == 127 failed");
101 return false;
102 }
103 wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH);
104 }
105 return DecodeMessage(wsFrame);
106 }
107
HandleDataFrame(WebSocketFrame & wsFrame) const108 bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame) const
109 {
110 if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) {
111 return ReadPayload(wsFrame);
112 } else {
113 LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode);
114 }
115 return true;
116 }
117
HandleControlFrame(WebSocketFrame & wsFrame)118 bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame)
119 {
120 if (wsFrame.opcode == EnumToNumber(FrameType::PING)) {
121 // A Pong frame sent in response to a Ping frame must have identical
122 // "Application data" as found in the message body of the Ping frame
123 // being replied to.
124 // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3
125 if (!ReadPayload(wsFrame)) {
126 LOGE("Failed to read ping frame payload");
127 return false;
128 }
129 SendPongFrame(wsFrame.payload);
130 } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) {
131 // might read payload to response by echoing the status code
132 CloseConnection(CloseStatusCode::NO_STATUS_CODE);
133 }
134 return true;
135 }
136
Decode()137 std::string WebSocketBase::Decode()
138 {
139 if (auto state = connectionState_.load(); state != ConnectionState::OPEN) {
140 LOGE("Decode failed: websocket not connected, state = %{public}d", EnumToNumber(state));
141 return "";
142 }
143
144 uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0};
145 if (!RecvUnderLock(recvbuf, WebSocketFrame::HEADER_LEN)) {
146 LOGE("Decode failed, client websocket disconnect");
147 CloseConnection(CloseStatusCode::UNEXPECTED_ERROR);
148 return std::string(DECODE_DISCONNECT_MSG);
149 }
150 WebSocketFrame wsFrame(recvbuf);
151 if (!ValidateIncomingFrame(wsFrame)) {
152 LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]);
153 CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
154 return std::string(DECODE_DISCONNECT_MSG);
155 }
156
157 if (IsControlFrame(wsFrame.opcode)) {
158 if (HandleControlFrame(wsFrame)) {
159 return wsFrame.payload;
160 }
161 } else if (HandleDataFrame(wsFrame)) {
162 return wsFrame.payload;
163 }
164 // Unexpected data, must close the connection.
165 CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
166 return std::string(DECODE_DISCONNECT_MSG);
167 }
168
IsConnected() const169 bool WebSocketBase::IsConnected() const
170 {
171 return connectionState_.load() == ConnectionState::OPEN;
172 }
173
SetCloseConnectionCallback(CloseConnectionCallback cb)174 void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb)
175 {
176 closeCb_ = std::move(cb);
177 }
178
SetFailConnectionCallback(FailConnectionCallback cb)179 void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb)
180 {
181 failCb_ = std::move(cb);
182 }
183
OnConnectionClose(ConnectionCloseReason status)184 void WebSocketBase::OnConnectionClose(ConnectionCloseReason status)
185 {
186 if (status == ConnectionCloseReason::FAIL) {
187 if (failCb_) {
188 failCb_();
189 }
190 } else if (status == ConnectionCloseReason::CLOSE) {
191 if (closeCb_) {
192 closeCb_();
193 }
194 }
195 }
196
// Tears down the connection socket and finishes the CLOSING -> CLOSED state transition.
// Steps, in order:
//   1. Invoke the fail/close callback selected by `status` (OnConnectionClose). This runs before
//      the socket is shut down and before the state becomes CLOSED.
//   2. shutdown() the socket while holding connectionMutex_ in shared mode.
//   3. Close the descriptor and reset connectionFd_ to -1 while holding connectionMutex_ exclusively.
//   4. Move the state from CLOSING to CLOSED; if the state was something else, it is only logged.
// @param status Which callback (fail or close) to invoke.
void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status)
{
    OnConnectionClose(status);

    {
        // Shared lock due to other thread possibly hanging on `recv` with acquired shared lock.
        // Presumably the shutdown also makes such a blocked `recv` return, so that thread releases its
        // shared lock before the unique lock below is requested.
        std::shared_lock lock(connectionMutex_);
        int err = ShutdownSocket(connectionFd_);
        if (err != 0) {
            // NOTE(review): on Windows, shutdown() reports errors via WSAGetLastError(), so errno may be stale.
            LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
        }
    }
    {
        // Unique lock due to close and write into `connectionFd_`.
        // Note that `close` must be also done in critical section,
        // otherwise the other thread can continue using the outdated and possibly reassigned file descriptor.
        std::unique_lock lock(connectionMutex_);
        FdsanClose(reinterpret_cast<fd_t>(connectionFd_));
        // Reset directly in order to prevent static analyzer warnings.
        connectionFd_ = -1;
    }

    // Only CLOSING -> CLOSED is expected here; any other starting state is reported but left as is.
    auto expected = ConnectionState::CLOSING;
    if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSED)) {
        LOGE("In connection transition CLOSING->CLOSED got initial state = %{public}d", EnumToNumber(expected));
    }
}
224
SendPongFrame(std::string payload) const225 void WebSocketBase::SendPongFrame(std::string payload) const
226 {
227 const auto frame = CreateFrame(true, FrameType::PONG, std::move(payload));
228 if (!SendUnderLock(frame)) {
229 LOGE("Decode: Send pong frame failed");
230 }
231 }
232
SendCloseFrame(CloseStatusCode status) const233 void WebSocketBase::SendCloseFrame(CloseStatusCode status) const
234 {
235 const auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status));
236 if (!SendUnderLock(frame)) {
237 LOGE("SendCloseFrame: Send close frame failed");
238 }
239 }
240
CloseConnection(CloseStatusCode status)241 bool WebSocketBase::CloseConnection(CloseStatusCode status)
242 {
243 auto expected = ConnectionState::OPEN;
244 if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSING)) {
245 // Concurrent connection close detected, do nothing.
246 return false;
247 }
248
249 LOGI("Close connection, status = %{public}d", static_cast<int>(status));
250 SendCloseFrame(status);
251 // can close connection right after sending back close frame.
252 CloseConnectionSocket(ConnectionCloseReason::CLOSE);
253 return true;
254 }
255
GetConnectionSocket() const256 int WebSocketBase::GetConnectionSocket() const
257 {
258 return connectionFd_;
259 }
260
SetConnectionSocket(int socketFd)261 void WebSocketBase::SetConnectionSocket(int socketFd)
262 {
263 FdsanExchangeOwnerTag(reinterpret_cast<fd_t>(socketFd));
264 connectionFd_ = socketFd;
265 }
266
GetConnectionMutex()267 std::shared_mutex &WebSocketBase::GetConnectionMutex()
268 {
269 return connectionMutex_;
270 }
271
GetConnectionState() const272 WebSocketBase::ConnectionState WebSocketBase::GetConnectionState() const
273 {
274 return connectionState_.load();
275 }
276
SetConnectionState(ConnectionState newState)277 WebSocketBase::ConnectionState WebSocketBase::SetConnectionState(ConnectionState newState)
278 {
279 return connectionState_.exchange(newState);
280 }
281
CompareExchangeConnectionState(ConnectionState & expected,ConnectionState newState)282 bool WebSocketBase::CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState)
283 {
284 return connectionState_.compare_exchange_strong(expected, newState);
285 }
286
SendUnderLock(const std::string & message) const287 bool WebSocketBase::SendUnderLock(const std::string& message) const
288 {
289 std::shared_lock lock(connectionMutex_);
290 return Send(connectionFd_, message, 0);
291 }
292
SendUnderLock(const char * buf,size_t totalLen) const293 bool WebSocketBase::SendUnderLock(const char* buf, size_t totalLen) const
294 {
295 std::shared_lock lock(connectionMutex_);
296 return Send(connectionFd_, buf, totalLen, 0);
297 }
298
RecvUnderLock(std::string & message) const299 bool WebSocketBase::RecvUnderLock(std::string& message) const
300 {
301 std::shared_lock lock(connectionMutex_);
302 return Recv(connectionFd_, message, 0);
303 }
304
RecvUnderLock(uint8_t * buf,size_t totalLen) const305 bool WebSocketBase::RecvUnderLock(uint8_t* buf, size_t totalLen) const
306 {
307 std::shared_lock lock(connectionMutex_);
308 return Recv(connectionFd_, buf, totalLen, 0);
309 }
310
311 /* static */
IsDecodeDisconnectMsg(const std::string & message)312 bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message)
313 {
314 return message == DECODE_DISCONNECT_MSG;
315 }
316
317 #if !defined(OHOS_PLATFORM)
318 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)319 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
320 {
321 if (timeoutLimit > 0) {
322 struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
323 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO,
324 reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
325 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
326 return false;
327 }
328 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO,
329 reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
330 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
331 return false;
332 }
333 }
334 return true;
335 }
336 #else
337 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)338 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
339 {
340 if (timeoutLimit > 0) {
341 struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
342 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
343 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
344 return false;
345 }
346 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
347 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
348 return false;
349 }
350 }
351 return true;
352 }
353 #endif
354
355 #if defined(WINDOWS_PLATFORM)
356 /* static */
ShutdownSocket(int32_t fd)357 int WebSocketBase::ShutdownSocket(int32_t fd)
358 {
359 return shutdown(fd, SD_BOTH);
360 }
361 #else
362 /* static */
ShutdownSocket(int32_t fd)363 int WebSocketBase::ShutdownSocket(int32_t fd)
364 {
365 return shutdown(fd, SHUT_RDWR);
366 }
367 #endif
368 } // namespace OHOS::ArkCompiler::Toolchain
369