• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "common/log_wrapper.h"
17 #include "define.h"
18 #include "frame_builder.h"
19 #include "network.h"
20 #include "websocket_base.h"
21 
22 #include <mutex>
23 
24 namespace OHOS::ArkCompiler::Toolchain {
ToString(CloseStatusCode status)25 static std::string ToString(CloseStatusCode status)
26 {
27     if (status == CloseStatusCode::NO_STATUS_CODE) {
28         return "";
29     }
30     std::string result;
31     PushNumberPerByte(result, EnumToNumber(status));
32     return result;
33 }
34 
~WebSocketBase()35 WebSocketBase::~WebSocketBase() noexcept
36 {
37     if (connectionFd_ != -1) {
38         LOGW("WebSocket connection is closed while destructing the object");
39         close(connectionFd_);
40         // Reset directly in order to prevent static analyzer warnings.
41         connectionFd_ = -1;
42     }
43 }
44 
45 // if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0
46 // and the last frame will be marked as 0x1.
47 // we just add the 'isLast' parameter to indicate whether it is the last frame.
SendReply(const std::string & message,FrameType frameType,bool isLast) const48 bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const
49 {
50     if (connectionState_.load() != ConnectionState::OPEN) {
51         LOGE("SendReply failed, websocket not connected");
52         return false;
53     }
54 
55     const auto frame = CreateFrame(isLast, frameType, message);
56     if (!SendUnderLock(frame)) {
57         LOGE("SendReply: send failed");
58         return false;
59     }
60     return true;
61 }
62 
63 /**
64   *  The wired format of this data transmission section is described in detail through ABNFRFC5234.
65   *  When receive the message, we should decode it according the spec. The structure is as follows:
66   *     0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
67   *    +-+-+-+-+-------+-+-------------+-------------------------------+
68   *    |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
69   *    |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
70   *    |N|V|V|V|       |S|             |   (if payload len==126/127)   |
71   *    | |1|2|3|       |K|             |                               |
72   *    +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
73   *    |     Extended payload length continued, if payload len == 127  |
74   *    + - - - - - - - - - - - - - - - +-------------------------------+
75   *    |                               |Masking-key, if MASK set to 1  |
76   *    +-------------------------------+-------------------------------+
77   *    | Masking-key (continued)       |          Payload Data         |
78   *    +-------------------------------- - - - - - - - - - - - - - - - +
79   *    :                     Payload Data continued ...                :
80   *    + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
81   *    |                     Payload Data continued ...                |
82   *    +---------------------------------------------------------------+
83   */
84 
ReadPayload(WebSocketFrame & wsFrame) const85 bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame) const
86 {
87     if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) {
88         uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0};
89         if (!RecvUnderLock(recvbuf, WebSocketFrame::TWO_BYTES_LENTH)) {
90             LOGE("ReadPayload: Recv payloadLen == 126 failed");
91             return false;
92         }
93         wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH);
94     } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) {
95         uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0};
96         if (!RecvUnderLock(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH)) {
97             LOGE("ReadPayload: Recv payloadLen == 127 failed");
98             return false;
99         }
100         wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH);
101     }
102     return DecodeMessage(wsFrame);
103 }
104 
HandleDataFrame(WebSocketFrame & wsFrame) const105 bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame) const
106 {
107     if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) {
108         return ReadPayload(wsFrame);
109     } else {
110         LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode);
111     }
112     return true;
113 }
114 
HandleControlFrame(WebSocketFrame & wsFrame)115 bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame)
116 {
117     if (wsFrame.opcode == EnumToNumber(FrameType::PING)) {
118         // A Pong frame sent in response to a Ping frame must have identical
119         // "Application data" as found in the message body of the Ping frame
120         // being replied to.
121         // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3
122         if (!ReadPayload(wsFrame)) {
123             LOGE("Failed to read ping frame payload");
124             return false;
125         }
126         SendPongFrame(wsFrame.payload);
127     } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) {
128         // might read payload to response by echoing the status code
129         CloseConnection(CloseStatusCode::NO_STATUS_CODE);
130     }
131     return true;
132 }
133 
Decode()134 std::string WebSocketBase::Decode()
135 {
136     if (auto state = connectionState_.load(); state != ConnectionState::OPEN) {
137         LOGE("Decode failed: websocket not connected, state = %{public}d", EnumToNumber(state));
138         return "";
139     }
140 
141     uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0};
142     if (!RecvUnderLock(recvbuf, WebSocketFrame::HEADER_LEN)) {
143         LOGE("Decode failed, client websocket disconnect");
144         CloseConnection(CloseStatusCode::UNEXPECTED_ERROR);
145         return std::string(DECODE_DISCONNECT_MSG);
146     }
147     WebSocketFrame wsFrame(recvbuf);
148     if (!ValidateIncomingFrame(wsFrame)) {
149         LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]);
150         CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
151         return std::string(DECODE_DISCONNECT_MSG);
152     }
153 
154     if (IsControlFrame(wsFrame.opcode)) {
155         if (HandleControlFrame(wsFrame)) {
156             return wsFrame.payload;
157         }
158     } else if (HandleDataFrame(wsFrame)) {
159         return wsFrame.payload;
160     }
161     // Unexpected data, must close the connection.
162     CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
163     return std::string(DECODE_DISCONNECT_MSG);
164 }
165 
IsConnected() const166 bool WebSocketBase::IsConnected() const
167 {
168     return connectionState_.load() == ConnectionState::OPEN;
169 }
170 
SetCloseConnectionCallback(CloseConnectionCallback cb)171 void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb)
172 {
173     closeCb_ = std::move(cb);
174 }
175 
SetFailConnectionCallback(FailConnectionCallback cb)176 void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb)
177 {
178     failCb_ = std::move(cb);
179 }
180 
OnConnectionClose(ConnectionCloseReason status)181 void WebSocketBase::OnConnectionClose(ConnectionCloseReason status)
182 {
183     if (status == ConnectionCloseReason::FAIL) {
184         if (failCb_) {
185             failCb_();
186         }
187     } else if (status == ConnectionCloseReason::CLOSE) {
188         if (closeCb_) {
189             closeCb_();
190         }
191     }
192 }
193 
CloseConnectionSocket(ConnectionCloseReason status)194 void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status)
195 {
196     OnConnectionClose(status);
197 
198     {
199         // Shared lock due to other thread possibly hanging on `recv` with acquired shared lock.
200         std::shared_lock lock(connectionMutex_);
201         int err = ShutdownSocket(connectionFd_);
202         if (err != 0) {
203             LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
204         }
205     }
206     {
207         // Unique lock due to close and write into `connectionFd_`.
208         // Note that `close` must be also done in critical section,
209         // otherwise the other thread can continue using the outdated and possibly reassigned file descriptor.
210         std::unique_lock lock(connectionMutex_);
211         close(connectionFd_);
212         // Reset directly in order to prevent static analyzer warnings.
213         connectionFd_ = -1;
214     }
215 
216     auto expected = ConnectionState::CLOSING;
217     if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSED)) {
218         LOGE("In connection transition CLOSING->CLOSED got initial state = %{public}d", EnumToNumber(expected));
219     }
220 }
221 
SendPongFrame(std::string payload) const222 void WebSocketBase::SendPongFrame(std::string payload) const
223 {
224     const auto frame = CreateFrame(true, FrameType::PONG, std::move(payload));
225     if (!SendUnderLock(frame)) {
226         LOGE("Decode: Send pong frame failed");
227     }
228 }
229 
SendCloseFrame(CloseStatusCode status) const230 void WebSocketBase::SendCloseFrame(CloseStatusCode status) const
231 {
232     const auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status));
233     if (!SendUnderLock(frame)) {
234         LOGE("SendCloseFrame: Send close frame failed");
235     }
236 }
237 
CloseConnection(CloseStatusCode status)238 bool WebSocketBase::CloseConnection(CloseStatusCode status)
239 {
240     auto expected = ConnectionState::OPEN;
241     if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSING)) {
242         // Concurrent connection close detected, do nothing.
243         return false;
244     }
245 
246     LOGI("Close connection, status = %{public}d", static_cast<int>(status));
247     SendCloseFrame(status);
248     // can close connection right after sending back close frame.
249     CloseConnectionSocket(ConnectionCloseReason::CLOSE);
250     return true;
251 }
252 
GetConnectionSocket() const253 int WebSocketBase::GetConnectionSocket() const
254 {
255     return connectionFd_;
256 }
257 
SetConnectionSocket(int socketFd)258 void WebSocketBase::SetConnectionSocket(int socketFd)
259 {
260     connectionFd_ = socketFd;
261 }
262 
GetConnectionMutex()263 std::shared_mutex &WebSocketBase::GetConnectionMutex()
264 {
265     return connectionMutex_;
266 }
267 
GetConnectionState() const268 WebSocketBase::ConnectionState WebSocketBase::GetConnectionState() const
269 {
270     return connectionState_.load();
271 }
272 
SetConnectionState(ConnectionState newState)273 WebSocketBase::ConnectionState WebSocketBase::SetConnectionState(ConnectionState newState)
274 {
275     return connectionState_.exchange(newState);
276 }
277 
CompareExchangeConnectionState(ConnectionState & expected,ConnectionState newState)278 bool WebSocketBase::CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState)
279 {
280     return connectionState_.compare_exchange_strong(expected, newState);
281 }
282 
SendUnderLock(const std::string & message) const283 bool WebSocketBase::SendUnderLock(const std::string& message) const
284 {
285     std::shared_lock lock(connectionMutex_);
286     return Send(connectionFd_, message, 0);
287 }
288 
SendUnderLock(const char * buf,size_t totalLen) const289 bool WebSocketBase::SendUnderLock(const char* buf, size_t totalLen) const
290 {
291     std::shared_lock lock(connectionMutex_);
292     return Send(connectionFd_, buf, totalLen, 0);
293 }
294 
RecvUnderLock(std::string & message) const295 bool WebSocketBase::RecvUnderLock(std::string& message) const
296 {
297     std::shared_lock lock(connectionMutex_);
298     return Recv(connectionFd_, message, 0);
299 }
300 
RecvUnderLock(uint8_t * buf,size_t totalLen) const301 bool WebSocketBase::RecvUnderLock(uint8_t* buf, size_t totalLen) const
302 {
303     std::shared_lock lock(connectionMutex_);
304     return Recv(connectionFd_, buf, totalLen, 0);
305 }
306 
307 /* static */
IsDecodeDisconnectMsg(const std::string & message)308 bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message)
309 {
310     return message == DECODE_DISCONNECT_MSG;
311 }
312 
313 #if !defined(OHOS_PLATFORM)
314 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)315 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
316 {
317     if (timeoutLimit > 0) {
318         struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
319         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO,
320             reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
321             LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
322             return false;
323         }
324         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO,
325             reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
326             LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
327             return false;
328         }
329     }
330     return true;
331 }
332 #else
333 /* static */
SetWebSocketTimeOut(int32_t fd,uint32_t timeoutLimit)334 bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
335 {
336     if (timeoutLimit > 0) {
337         struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
338         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
339             LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
340             return false;
341         }
342         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
343             LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
344             return false;
345         }
346     }
347     return true;
348 }
349 #endif
350 
351 #if defined(WINDOWS_PLATFORM)
352 /* static */
ShutdownSocket(int32_t fd)353 int WebSocketBase::ShutdownSocket(int32_t fd)
354 {
355     return shutdown(fd, SD_BOTH);
356 }
357 #else
358 /* static */
ShutdownSocket(int32_t fd)359 int WebSocketBase::ShutdownSocket(int32_t fd)
360 {
361     return shutdown(fd, SHUT_RDWR);
362 }
363 #endif
364 } // namespace OHOS::ArkCompiler::Toolchain
365