1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "common/log_wrapper.h"
17 #include "define.h"
18 #include "frame_builder.h"
19 #include "network.h"
20 #include "websocket_base.h"
21
22 #include <mutex>
23
24 namespace OHOS::ArkCompiler::Toolchain {
ToString(CloseStatusCode status)25 static std::string ToString(CloseStatusCode status)
26 {
27 if (status == CloseStatusCode::NO_STATUS_CODE) {
28 return "";
29 }
30 std::string result;
31 PushNumberPerByte(result, EnumToNumber(status));
32 return result;
33 }
34
~WebSocketBase()35 WebSocketBase::~WebSocketBase() noexcept
36 {
37 if (connectionFd_ != -1) {
38 LOGW("WebSocket connection is closed while destructing the object");
39 close(connectionFd_);
40 // Reset directly in order to prevent static analyzer warnings.
41 connectionFd_ = -1;
42 }
43 }
44
45 // if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0
46 // and the last frame will be marked as 0x1.
47 // we just add the 'isLast' parameter to indicate whether it is the last frame.
SendReply(const std::string & message,FrameType frameType,bool isLast) const48 bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const
49 {
50 if (connectionState_.load() != ConnectionState::OPEN) {
51 LOGE("SendReply failed, websocket not connected");
52 return false;
53 }
54
55 const auto frame = CreateFrame(isLast, frameType, message);
56 if (!SendUnderLock(frame)) {
57 LOGE("SendReply: send failed");
58 return false;
59 }
60 return true;
61 }
62
63 /**
 * The wire format of this data transfer section is described in detail using ABNF (RFC 5234).
 * When receiving a message, we should decode it according to the spec. The structure is as follows:
66 * 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
67 * +-+-+-+-+-------+-+-------------+-------------------------------+
68 * |F|R|R|R| opcode|M| Payload len | Extended payload length |
69 * |I|S|S|S| (4) |A| (7) | (16/64) |
70 * |N|V|V|V| |S| | (if payload len==126/127) |
71 * | |1|2|3| |K| | |
72 * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
73 * | Extended payload length continued, if payload len == 127 |
74 * + - - - - - - - - - - - - - - - +-------------------------------+
75 * | |Masking-key, if MASK set to 1 |
76 * +-------------------------------+-------------------------------+
77 * | Masking-key (continued) | Payload Data |
78 * +-------------------------------- - - - - - - - - - - - - - - - +
79 * : Payload Data continued ... :
80 * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
81 * | Payload Data continued ... |
82 * +---------------------------------------------------------------+
83 */
84
ReadPayload(WebSocketFrame & wsFrame) const85 bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame) const
86 {
87 if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) {
88 uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0};
89 if (!RecvUnderLock(recvbuf, WebSocketFrame::TWO_BYTES_LENTH)) {
90 LOGE("ReadPayload: Recv payloadLen == 126 failed");
91 return false;
92 }
93 wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH);
94 } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) {
95 uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0};
96 if (!RecvUnderLock(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH)) {
97 LOGE("ReadPayload: Recv payloadLen == 127 failed");
98 return false;
99 }
100 wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH);
101 }
102 return DecodeMessage(wsFrame);
103 }
104
HandleDataFrame(WebSocketFrame & wsFrame) const105 bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame) const
106 {
107 if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) {
108 return ReadPayload(wsFrame);
109 } else {
110 LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode);
111 }
112 return true;
113 }
114
HandleControlFrame(WebSocketFrame & wsFrame)115 bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame)
116 {
117 if (wsFrame.opcode == EnumToNumber(FrameType::PING)) {
118 // A Pong frame sent in response to a Ping frame must have identical
119 // "Application data" as found in the message body of the Ping frame
120 // being replied to.
121 // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3
122 if (!ReadPayload(wsFrame)) {
123 LOGE("Failed to read ping frame payload");
124 return false;
125 }
126 SendPongFrame(wsFrame.payload);
127 } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) {
128 // might read payload to response by echoing the status code
129 CloseConnection(CloseStatusCode::NO_STATUS_CODE);
130 }
131 return true;
132 }
133
Decode()134 std::string WebSocketBase::Decode()
135 {
136 if (auto state = connectionState_.load(); state != ConnectionState::OPEN) {
137 LOGE("Decode failed: websocket not connected, state = %{public}d", EnumToNumber(state));
138 return "";
139 }
140
141 uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0};
142 if (!RecvUnderLock(recvbuf, WebSocketFrame::HEADER_LEN)) {
143 LOGE("Decode failed, client websocket disconnect");
144 CloseConnection(CloseStatusCode::UNEXPECTED_ERROR);
145 return std::string(DECODE_DISCONNECT_MSG);
146 }
147 WebSocketFrame wsFrame(recvbuf);
148 if (!ValidateIncomingFrame(wsFrame)) {
149 LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]);
150 CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
151 return std::string(DECODE_DISCONNECT_MSG);
152 }
153
154 if (IsControlFrame(wsFrame.opcode)) {
155 if (HandleControlFrame(wsFrame)) {
156 return wsFrame.payload;
157 }
158 } else if (HandleDataFrame(wsFrame)) {
159 return wsFrame.payload;
160 }
161 // Unexpected data, must close the connection.
162 CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
163 return std::string(DECODE_DISCONNECT_MSG);
164 }
165
IsConnected() const166 bool WebSocketBase::IsConnected() const
167 {
168 return connectionState_.load() == ConnectionState::OPEN;
169 }
170
SetCloseConnectionCallback(CloseConnectionCallback cb)171 void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb)
172 {
173 closeCb_ = std::move(cb);
174 }
175
SetFailConnectionCallback(FailConnectionCallback cb)176 void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb)
177 {
178 failCb_ = std::move(cb);
179 }
180
OnConnectionClose(ConnectionCloseReason status)181 void WebSocketBase::OnConnectionClose(ConnectionCloseReason status)
182 {
183 if (status == ConnectionCloseReason::FAIL) {
184 if (failCb_) {
185 failCb_();
186 }
187 } else if (status == ConnectionCloseReason::CLOSE) {
188 if (closeCb_) {
189 closeCb_();
190 }
191 }
192 }
193
CloseConnectionSocket(ConnectionCloseReason status)194 void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status)
195 {
196 OnConnectionClose(status);
197
198 {
199 // Shared lock due to other thread possibly hanging on `recv` with acquired shared lock.
200 std::shared_lock lock(connectionMutex_);
201 int err = ShutdownSocket(connectionFd_);
202 if (err != 0) {
203 LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
204 }
205 }
206 {
207 // Unique lock due to close and write into `connectionFd_`.
208 // Note that `close` must be also done in critical section,
209 // otherwise the other thread can continue using the outdated and possibly reassigned file descriptor.
210 std::unique_lock lock(connectionMutex_);
211 close(connectionFd_);
212 // Reset directly in order to prevent static analyzer warnings.
213 connectionFd_ = -1;
214 }
215
216 auto expected = ConnectionState::CLOSING;
217 if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSED)) {
218 LOGE("In connection transition CLOSING->CLOSED got initial state = %{public}d", EnumToNumber(expected));
219 }
220 }
221
SendPongFrame(std::string payload) const222 void WebSocketBase::SendPongFrame(std::string payload) const
223 {
224 const auto frame = CreateFrame(true, FrameType::PONG, std::move(payload));
225 if (!SendUnderLock(frame)) {
226 LOGE("Decode: Send pong frame failed");
227 }
228 }
229
SendCloseFrame(CloseStatusCode status) const230 void WebSocketBase::SendCloseFrame(CloseStatusCode status) const
231 {
232 const auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status));
233 if (!SendUnderLock(frame)) {
234 LOGE("SendCloseFrame: Send close frame failed");
235 }
236 }
237
CloseConnection(CloseStatusCode status)238 bool WebSocketBase::CloseConnection(CloseStatusCode status)
239 {
240 auto expected = ConnectionState::OPEN;
241 if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSING)) {
242 // Concurrent connection close detected, do nothing.
243 return false;
244 }
245
246 LOGI("Close connection, status = %{public}d", static_cast<int>(status));
247 SendCloseFrame(status);
248 // can close connection right after sending back close frame.
249 CloseConnectionSocket(ConnectionCloseReason::CLOSE);
250 return true;
251 }
252
GetConnectionSocket() const253 int WebSocketBase::GetConnectionSocket() const
254 {
255 return connectionFd_;
256 }
257
SetConnectionSocket(int socketFd)258 void WebSocketBase::SetConnectionSocket(int socketFd)
259 {
260 connectionFd_ = socketFd;
261 }
262
GetConnectionMutex()263 std::shared_mutex &WebSocketBase::GetConnectionMutex()
264 {
265 return connectionMutex_;
266 }
267
GetConnectionState() const268 WebSocketBase::ConnectionState WebSocketBase::GetConnectionState() const
269 {
270 return connectionState_.load();
271 }
272
SetConnectionState(ConnectionState newState)273 WebSocketBase::ConnectionState WebSocketBase::SetConnectionState(ConnectionState newState)
274 {
275 return connectionState_.exchange(newState);
276 }
277
CompareExchangeConnectionState(ConnectionState & expected,ConnectionState newState)278 bool WebSocketBase::CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState)
279 {
280 return connectionState_.compare_exchange_strong(expected, newState);
281 }
282
SendUnderLock(const std::string & message) const283 bool WebSocketBase::SendUnderLock(const std::string& message) const
284 {
285 std::shared_lock lock(connectionMutex_);
286 return Send(connectionFd_, message, 0);
287 }
288
SendUnderLock(const char * buf,size_t totalLen) const289 bool WebSocketBase::SendUnderLock(const char* buf, size_t totalLen) const
290 {
291 std::shared_lock lock(connectionMutex_);
292 return Send(connectionFd_, buf, totalLen, 0);
293 }
294
RecvUnderLock(std::string & message) const295 bool WebSocketBase::RecvUnderLock(std::string& message) const
296 {
297 std::shared_lock lock(connectionMutex_);
298 return Recv(connectionFd_, message, 0);
299 }
300
RecvUnderLock(uint8_t * buf,size_t totalLen) const301 bool WebSocketBase::RecvUnderLock(uint8_t* buf, size_t totalLen) const
302 {
303 std::shared_lock lock(connectionMutex_);
304 return Recv(connectionFd_, buf, totalLen, 0);
305 }
306
307 /* static */
IsDecodeDisconnectMsg(const std::string & message)308 bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message)
309 {
310 return message == DECODE_DISCONNECT_MSG;
311 }
312
313 #if !defined(OHOS_PLATFORM)
314 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)315 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
316 {
317 if (timeoutLimit > 0) {
318 struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
319 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO,
320 reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
321 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
322 return false;
323 }
324 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO,
325 reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
326 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
327 return false;
328 }
329 }
330 return true;
331 }
332 #else
333 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)334 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
335 {
336 if (timeoutLimit > 0) {
337 struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
338 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
339 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
340 return false;
341 }
342 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
343 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
344 return false;
345 }
346 }
347 return true;
348 }
349 #endif
350
351 #if defined(WINDOWS_PLATFORM)
352 /* static */
ShutdownSocket(int32_t fd)353 int WebSocketBase::ShutdownSocket(int32_t fd)
354 {
355 return shutdown(fd, SD_BOTH);
356 }
357 #else
358 /* static */
ShutdownSocket(int32_t fd)359 int WebSocketBase::ShutdownSocket(int32_t fd)
360 {
361 return shutdown(fd, SHUT_RDWR);
362 }
363 #endif
364 } // namespace OHOS::ArkCompiler::Toolchain
365