• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H
17 #define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H
18 
19 #include "web_socket_frame.h"
20 
21 #include <atomic>
22 #include <functional>
23 #include <shared_mutex>
24 #include <type_traits>
25 
26 namespace OHOS::ArkCompiler::Toolchain {
/**
 * @brief Status codes placed in the body of a WebSocket `CLOSE` frame.
 * Except for `NO_STATUS_CODE`, every value is one defined by RFC 6455, section 7.4.1.
 * The underlying type is `uint16_t`, matching the 2-byte status field of a close frame.
 * The enumeration is unscoped, so its enumerators are visible directly in the enclosing namespace.
 */
enum CloseStatusCode : uint16_t {
    NO_STATUS_CODE = 0,         // not an RFC 6455 code; presumably "no status to report" - TODO confirm usage
    NORMAL = 1000,              // normal closure
    SERVER_GO_AWAY = 1001,      // endpoint is going away
    PROTOCOL_ERROR = 1002,      // protocol error detected
    UNACCEPTABLE_DATA = 1003,   // received a type of data that cannot be accepted
    INCONSISTENT_DATA = 1007,   // data not consistent with the message type (e.g. non-UTF-8 text)
    POLICY_VIOLATION = 1008,    // message violates the endpoint's policy
    MESSAGE_TOO_BIG = 1009,     // message too big to process
    UNEXPECTED_ERROR = 1011     // unexpected condition prevented fulfilling the request
};
38 
39 class WebSocketBase {
40 public:
41     using CloseConnectionCallback = std::function<void()>;
42     using FailConnectionCallback = std::function<void()>;
43 
44 public:
45     static bool IsDecodeDisconnectMsg(const std::string& message);
46 
47     WebSocketBase() = default;
48     virtual ~WebSocketBase() noexcept;
49 
50     /**
51      * @brief Receive and decode a message.
52      * Must not be called concurrently on the same connection.
53      * Safe to call concurrently with `SendReply` and `Close`.
54      * Control frames are handled according to specification with an empty string as returned value,
55      * otherwise the method returns the decoded received message.
56      * Note that this method closes the connection after receiving invalid data.
57      * This event can be checked with `IsDecodeDisconnectMsg`.
58      */
59     std::string Decode();
60 
61     /**
62      * @brief Send message on current connection.
63      * Safe to call concurrently with: `SendReply`, `Decode`, `Close`.
64      * Note that the connection is not closed on transmission failures.
65      * @param message text payload.
66      * @param frameType frame type, must be either TEXT, BINARY or CONTINUATION.
67      * @param isLast flag indicating whether the message is the final.
68      * @returns true on success, false otherwise.
69      */
70     bool SendReply(const std::string& message, FrameType frameType = FrameType::TEXT, bool isLast = true) const;
71 
72     /**
73      * @brief Check if connection is in `OPEN` state.
74      */
75     bool IsConnected() const;
76 
77     /**
78      * @brief Set callback for calling after normal connection close.
79      * Non thread safe.
80      */
81     void SetCloseConnectionCallback(CloseConnectionCallback cb);
82 
83     /**
84      * @brief Set callback for calling after closing connection on any failure.
85      * Non thread safe.
86      */
87     void SetFailConnectionCallback(FailConnectionCallback cb);
88 
89     /**
90      * @brief Send `CLOSE` frame and close the connection socket.
91      * Does nothing if connection was not in `OPEN` state.
92      * @param status close status code specified in sent frame.
93      * @returns true if connection was closed, false otherwise.
94      */
95     bool CloseConnection(CloseStatusCode status);
96 
97 protected:
98     enum class ConnectionState : uint8_t {
99         CONNECTING,
100         OPEN,
101         CLOSING,
102         CLOSED,
103     };
104 
105     enum class ConnectionCloseReason: uint8_t {
106         FAIL,
107         CLOSE,
108     };
109 
110 protected:
111     /**
112      * @brief Set `send` and `recv` timeout limits.
113      * @param fd socket to set timeout on.
114      * @param timeoutLimit timeout in seconds. If zero, function is no-op.
115      * @returns true on success, false otherwise.
116      */
117     static bool SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit);
118 
119     /**
120      * @brief Shutdown socket for sends and receives.
121      * Note that the implementation of this function is platform-specific,
122      * so there is no unified way to retrieve error code returned from system call.
123      * @param fd socket file descriptor.
124      * @returns zero on success, `-1` otherwise.
125      */
126     static int ShutdownSocket(int32_t fd);
127 
128     /**
129      * @brief Close the connection socket.
130      * Must be transition from `CLOSING` to `CLOSED` connection state.
131      * @param status close reason, depends which callback to execute.
132      */
133     void CloseConnectionSocket(ConnectionCloseReason status);
134 
135     /**
136      * @brief Execute user-provided callbacks before closing the connection socket.
137      */
138     void OnConnectionClose(ConnectionCloseReason status);
139 
140     int GetConnectionSocket() const;
141     void SetConnectionSocket(int socketFd);
142     std::shared_mutex &GetConnectionMutex();
143 
144     ConnectionState GetConnectionState() const;
145     ConnectionState SetConnectionState(ConnectionState newState);
146     bool CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState);
147 
148     bool HandleDataFrame(WebSocketFrame& wsFrame) const;
149     bool HandleControlFrame(WebSocketFrame& wsFrame);
150     bool ReadPayload(WebSocketFrame& wsFrame) const;
151     void SendPongFrame(std::string payload) const;
152     void SendCloseFrame(CloseStatusCode status) const;
153 
154     bool SendUnderLock(const std::string& message) const;
155     bool SendUnderLock(const char* buf, size_t totalLen) const;
156     bool RecvUnderLock(std::string& message) const;
157     bool RecvUnderLock(uint8_t* buf, size_t totalLen) const;
158 
159     virtual bool ValidateIncomingFrame(const WebSocketFrame& wsFrame) const = 0;
160     virtual std::string CreateFrame(bool isLast, FrameType frameType) const = 0;
161     virtual std::string CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const = 0;
162     virtual std::string CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const = 0;
163     virtual bool DecodeMessage(WebSocketFrame& wsFrame) const = 0;
164 
165 protected:
166     static constexpr size_t HTTP_HANDSHAKE_MAX_LEN = 1024;
167     static constexpr int SOCKET_SUCCESS = 0;
168 
169 private:
170     std::atomic<ConnectionState> connectionState_ {ConnectionState::CLOSED};
171 
172     mutable std::shared_mutex connectionMutex_;
173     int connectionFd_ {-1};
174 
175     // Callbacks used during different stages of connection lifecycle.
176     CloseConnectionCallback closeCb_;
177     FailConnectionCallback failCb_;
178 
179     static constexpr std::string_view DECODE_DISCONNECT_MSG = "disconnect";
180 };
181 } // namespace OHOS::ArkCompiler::Toolchain
182 
183 #endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H
184