1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <fcntl.h>
17 #include "common/log_wrapper.h"
18 #include "websocket/frame_builder.h"
19 #include "websocket/handshake_helper.h"
20 #include "websocket/network.h"
21 #include "websocket/server/websocket_server.h"
22
23 namespace OHOS::ArkCompiler::Toolchain {
DecodeMessage(WebSocketFrame & wsFrame) const24 bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame) const
25 {
26 uint64_t msgLen = wsFrame.payloadLen;
27 if (msgLen == 0) {
28 // receiving empty data is OK
29 return true;
30 }
31 auto& buffer = wsFrame.payload;
32 buffer.resize(msgLen, 0);
33
34 if (!Recv(connectionFd_, wsFrame.maskingKey, sizeof(wsFrame.maskingKey), 0)) {
35 LOGE("DecodeMessage: Recv maskingKey failed");
36 return false;
37 }
38
39 if (!Recv(connectionFd_, buffer, 0)) {
40 LOGE("DecodeMessage: Recv message with mask failed");
41 return false;
42 }
43
44 for (uint64_t i = 0; i < msgLen; i++) {
45 auto j = i % WebSocketFrame::MASK_LEN;
46 buffer[i] = static_cast<uint8_t>(buffer[i]) ^ wsFrame.maskingKey[j];
47 }
48
49 return true;
50 }
51
ProtocolUpgrade(const HttpRequest & req)52 bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req)
53 {
54 unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
55 if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) {
56 LOGE("ProtocolUpgrade: failed to encode WebSocket-Key");
57 return false;
58 }
59
60 ProtocolUpgradeBuilder requestBuilder(encodedKey);
61 if (!Send(connectionFd_, requestBuilder.GetUpgradeMessage(), requestBuilder.GetLength(), 0)) {
62 LOGE("ProtocolUpgrade: Send failed");
63 return false;
64 }
65 return true;
66 }
67
ResponseInvalidHandShake() const68 bool WebSocketServer::ResponseInvalidHandShake() const
69 {
70 std::string response(BAD_REQUEST_RESPONSE);
71 return Send(connectionFd_, response, 0);
72 }
73
HttpHandShake()74 bool WebSocketServer::HttpHandShake()
75 {
76 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
77 ssize_t msgLen = 0;
78 while ((msgLen = recv(connectionFd_, msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 && errno == EINTR) {
79 LOGW("HttpHandShake recv failed, errno == EINTR");
80 }
81 if (msgLen <= 0) {
82 LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast<long>(msgLen), errno);
83 return false;
84 }
85 // reduce to received size
86 msgBuf.resize(msgLen);
87
88 HttpRequest req;
89 if (!HttpRequest::Decode(msgBuf, req)) {
90 LOGE("HttpHandShake: Upgrade failed");
91 return false;
92 }
93 if (validateCb_ && !validateCb_(req)) {
94 LOGE("HttpHandShake: Validation failed");
95 return false;
96 }
97
98 if (ValidateHandShakeMessage(req)) {
99 return ProtocolUpgrade(req);
100 }
101
102 LOGE("HttpHandShake: HTTP upgrade parameters failure");
103 if (!ResponseInvalidHandShake()) {
104 LOGE("HttpHandShake: failed to send 'bad request' response");
105 }
106 return false;
107 }
108
109 /* static */
ValidateHandShakeMessage(const HttpRequest & req)110 bool WebSocketServer::ValidateHandShakeMessage(const HttpRequest& req)
111 {
112 return req.connection.find("Upgrade") != std::string::npos &&
113 req.upgrade.find("websocket") != std::string::npos &&
114 req.version.compare("HTTP/1.1") == 0;
115 }
116
AcceptNewConnection()117 bool WebSocketServer::AcceptNewConnection()
118 {
119 if (socketState_ == SocketState::UNINITED) {
120 LOGE("AcceptNewConnection failed, websocket not inited");
121 return false;
122 }
123 if (socketState_ == SocketState::CONNECTED) {
124 LOGI("AcceptNewConnection websocket has connected");
125 return true;
126 }
127
128 if ((connectionFd_ = accept(serverFd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
129 LOGI("AcceptNewConnection accept has exited");
130 return false;
131 }
132
133 if (!HttpHandShake()) {
134 LOGE("AcceptNewConnection HttpHandShake failed");
135 CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::INITED);
136 return false;
137 }
138 OnNewConnection();
139 return true;
140 }
141
142 #if !defined(OHOS_PLATFORM)
InitTcpWebSocket(int port,uint32_t timeoutLimit)143 bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit)
144 {
145 if (port < 0) {
146 LOGE("InitTcpWebSocket invalid port");
147 return false;
148 }
149 if (socketState_ != SocketState::UNINITED) {
150 LOGI("InitTcpWebSocket websocket has inited");
151 return true;
152 }
153 #if defined(WINDOWS_PLATFORM)
154 WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
155 WSADATA wsaData;
156 if (WSAStartup(sockVersion, &wsaData) != 0) {
157 LOGE("InitTcpWebSocket WSA init failed");
158 return false;
159 }
160 #endif
161 serverFd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
162 if (serverFd_ < SOCKET_SUCCESS) {
163 LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno);
164 return false;
165 }
166 // allow specified port can be used at once and not wait TIME_WAIT status ending
167 int sockOptVal = 1;
168 if ((setsockopt(serverFd_, SOL_SOCKET, SO_REUSEADDR,
169 reinterpret_cast<char *>(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) {
170 LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno);
171 CloseServerSocket();
172 return false;
173 }
174 // set send and recv timeout
175 if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
176 LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
177 CloseServerSocket();
178 return false;
179 }
180 sockaddr_in addrSin = {};
181 addrSin.sin_family = AF_INET;
182 addrSin.sin_port = htons(port);
183 addrSin.sin_addr.s_addr = INADDR_ANY;
184 if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&addrSin), sizeof(addrSin)) != SOCKET_SUCCESS) {
185 LOGE("InitTcpWebSocket bind failed, errno = %{public}d", errno);
186 CloseServerSocket();
187 return false;
188 }
189 if (listen(serverFd_, 1) != SOCKET_SUCCESS) {
190 LOGE("InitTcpWebSocket listen failed, errno = %{public}d", errno);
191 CloseServerSocket();
192 return false;
193 }
194 socketState_ = SocketState::INITED;
195 return true;
196 }
197 #else
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)198 bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
199 {
200 if (socketState_ != SocketState::UNINITED) {
201 LOGI("InitUnixWebSocket websocket has inited");
202 return true;
203 }
204 serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol
205 if (serverFd_ < SOCKET_SUCCESS) {
206 LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno);
207 return false;
208 }
209 // set send and recv timeout
210 if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
211 LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
212 CloseServerSocket();
213 return false;
214 }
215
216 struct sockaddr_un un;
217 if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
218 LOGE("InitUnixWebSocket memset_s failed");
219 CloseServerSocket();
220 return false;
221 }
222 un.sun_family = AF_UNIX;
223 if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
224 LOGE("InitUnixWebSocket strcpy_s failed");
225 CloseServerSocket();
226 return false;
227 }
228 un.sun_path[0] = '\0';
229 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
230 if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) != SOCKET_SUCCESS) {
231 LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno);
232 CloseServerSocket();
233 return false;
234 }
235 if (listen(serverFd_, 1) != SOCKET_SUCCESS) { // 1: connection num
236 LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno);
237 CloseServerSocket();
238 return false;
239 }
240 socketState_ = SocketState::INITED;
241 return true;
242 }
243
InitUnixWebSocket(int socketfd)244 bool WebSocketServer::InitUnixWebSocket(int socketfd)
245 {
246 if (socketState_ != SocketState::UNINITED) {
247 LOGI("InitUnixWebSocket websocket has inited");
248 return true;
249 }
250 if (socketfd < SOCKET_SUCCESS) {
251 LOGE("InitUnixWebSocket socketfd is invalid");
252 socketState_ = SocketState::UNINITED;
253 return false;
254 }
255 connectionFd_ = socketfd;
256 int flag = fcntl(connectionFd_, F_GETFL, 0);
257 if (flag == -1) {
258 LOGE("InitUnixWebSocket get client state is failed");
259 return false;
260 }
261 fcntl(connectionFd_, F_SETFL, static_cast<size_t>(flag) & ~O_NONBLOCK);
262 socketState_ = SocketState::INITED;
263 return true;
264 }
265
ConnectUnixWebSocketBySocketpair()266 bool WebSocketServer::ConnectUnixWebSocketBySocketpair()
267 {
268 if (socketState_ == SocketState::UNINITED) {
269 LOGE("ConnectUnixWebSocket failed, websocket not inited");
270 return false;
271 }
272 if (socketState_ == SocketState::CONNECTED) {
273 LOGI("ConnectUnixWebSocket websocket has connected");
274 return true;
275 }
276
277 if (!HttpHandShake()) {
278 LOGE("ConnectUnixWebSocket HttpHandShake failed");
279 CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::UNINITED);
280 return false;
281 }
282 socketState_ = SocketState::CONNECTED;
283 return true;
284 }
285 #endif
286
CloseServerSocket()287 void WebSocketServer::CloseServerSocket()
288 {
289 close(serverFd_);
290 serverFd_ = -1;
291 socketState_ = SocketState::UNINITED;
292 }
293
OnNewConnection()294 void WebSocketServer::OnNewConnection()
295 {
296 LOGI("New client connected");
297 socketState_ = SocketState::CONNECTED;
298 if (openCb_) {
299 openCb_();
300 }
301 }
302
SetValidateConnectionCallback(ValidateConnectionCallback cb)303 void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb)
304 {
305 validateCb_ = std::move(cb);
306 }
307
SetOpenConnectionCallback(OpenConnectionCallback cb)308 void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb)
309 {
310 openCb_ = std::move(cb);
311 }
312
ValidateIncomingFrame(const WebSocketFrame & wsFrame)313 bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame)
314 {
315 // "The server MUST close the connection upon receiving a frame that is not masked."
316 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
317 return wsFrame.mask == 1;
318 }
319
CreateFrame(bool isLast,FrameType frameType) const320 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const
321 {
322 ServerFrameBuilder builder(isLast, frameType);
323 return builder.Build();
324 }
325
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const326 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
327 {
328 ServerFrameBuilder builder(isLast, frameType);
329 return builder.SetPayload(payload).Build();
330 }
331
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const332 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
333 {
334 ServerFrameBuilder builder(isLast, frameType);
335 return builder.SetPayload(std::move(payload)).Build();
336 }
337
Close()338 void WebSocketServer::Close()
339 {
340 if (socketState_ == SocketState::UNINITED) {
341 return;
342 }
343 if (socketState_ == SocketState::CONNECTED) {
344 CloseConnection(CloseStatusCode::SERVER_GO_AWAY, SocketState::INITED);
345 }
346 usleep(10000); // 10000: time for websocket to enter the accept
347 #if defined(OHOS_PLATFORM)
348 shutdown(serverFd_, SHUT_RDWR);
349 #endif
350 CloseServerSocket();
351 }
352 } // namespace OHOS::ArkCompiler::Toolchain
353