1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #if defined(PANDA_TARGET_WINDOWS)
17 #include <ws2tcpip.h>
18 #else
19 #include <arpa/inet.h>
20 #endif
21
22 #include <fcntl.h>
23 #include "common/log_wrapper.h"
24 #include "platform/file.h"
25 #include "frame_builder.h"
26 #include "handshake_helper.h"
27 #include "server/websocket_server.h"
28 #include "string_utils.h"
29
30 #include <mutex>
31 #include <thread>
32
33 namespace OHOS::ArkCompiler::Toolchain {
ValidateHandShakeMessage(const HttpRequest & req)34 static bool ValidateHandShakeMessage(const HttpRequest& req)
35 {
36 std::string upgradeHeaderValue = req.upgrade;
37 // Switch to lower case in order to support obsolete versions of WebSocket protocol.
38 ToLowerCase(upgradeHeaderValue);
39 return req.connection.find("Upgrade") != std::string::npos &&
40 upgradeHeaderValue.find("websocket") != std::string::npos &&
41 req.version.compare("HTTP/1.1") == 0;
42 }
43
~WebSocketServer()44 WebSocketServer::~WebSocketServer() noexcept
45 {
46 if (serverFd_ != -1) {
47 LOGW("WebSocket server is closed while destructing the object");
48 FdsanClose(reinterpret_cast<fd_t>(serverFd_));
49 // Reset directly in order to prevent static analyzer warnings.
50 serverFd_ = -1;
51 }
52 }
53
DecodeMessage(WebSocketFrame & wsFrame) const54 bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame) const
55 {
56 const uint64_t msgLen = wsFrame.payloadLen;
57 if (msgLen == 0) {
58 // receiving empty data is OK
59 return true;
60 }
61 auto& buffer = wsFrame.payload;
62 buffer.resize(msgLen, 0);
63
64 if (!RecvUnderLock(wsFrame.maskingKey, sizeof(wsFrame.maskingKey))) {
65 LOGE("DecodeMessage: Recv maskingKey failed");
66 return false;
67 }
68
69 if (!RecvUnderLock(buffer)) {
70 LOGE("DecodeMessage: Recv message with mask failed");
71 return false;
72 }
73
74 for (uint64_t i = 0; i < msgLen; i++) {
75 const auto j = i % WebSocketFrame::MASK_LEN;
76 buffer[i] = static_cast<uint8_t>(buffer[i]) ^ wsFrame.maskingKey[j];
77 }
78
79 return true;
80 }
81
ProtocolUpgrade(const HttpRequest & req)82 bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req)
83 {
84 unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
85 if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) {
86 LOGE("ProtocolUpgrade: failed to encode WebSocket-Key");
87 return false;
88 }
89
90 ProtocolUpgradeBuilder requestBuilder(encodedKey);
91 if (!SendUnderLock(requestBuilder.GetUpgradeMessage(), requestBuilder.GetLength())) {
92 LOGE("ProtocolUpgrade: Send failed");
93 return false;
94 }
95 return true;
96 }
97
ResponseInvalidHandShake() const98 bool WebSocketServer::ResponseInvalidHandShake() const
99 {
100 const std::string response(BAD_REQUEST_RESPONSE);
101 return SendUnderLock(response);
102 }
103
HttpHandShake()104 bool WebSocketServer::HttpHandShake()
105 {
106 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
107 ssize_t msgLen = 0;
108 {
109 std::shared_lock lock(GetConnectionMutex());
110 while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
111 (errno == EINTR || errno == EAGAIN)) {
112 LOGW("HttpHandShake recv failed, errno = %{public}d", errno);
113 }
114 }
115 if (msgLen <= 0) {
116 LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast<long>(msgLen), errno);
117 return false;
118 }
119 // reduce to received size
120 msgBuf.resize(msgLen);
121
122 HttpRequest req;
123 if (!HttpRequest::Decode(msgBuf, req)) {
124 LOGE("HttpHandShake: Upgrade failed");
125 return false;
126 }
127 if (validateCb_ && !validateCb_(req)) {
128 LOGE("HttpHandShake: Validation failed");
129 return false;
130 }
131
132 if (ValidateHandShakeMessage(req)) {
133 return ProtocolUpgrade(req);
134 }
135
136 LOGE("HttpHandShake: HTTP upgrade parameters failure");
137 if (!ResponseInvalidHandShake()) {
138 LOGE("HttpHandShake: failed to send 'bad request' response");
139 }
140 return false;
141 }
142
MoveToConnectingState()143 bool WebSocketServer::MoveToConnectingState()
144 {
145 if (!serverUp_.load()) {
146 // Server `Close` happened, must not accept new connections.
147 return false;
148 }
149 auto expected = ConnectionState::CLOSED;
150 if (!CompareExchangeConnectionState(expected, ConnectionState::CONNECTING)) {
151 switch (expected) {
152 case ConnectionState::CLOSING:
153 LOGW("MoveToConnectingState during closing connection");
154 break;
155 case ConnectionState::OPEN:
156 LOGW("MoveToConnectingState during already established connection");
157 break;
158 case ConnectionState::CONNECTING:
159 LOGE("MoveToConnectingState during opening connection, which violates WebSocketServer guarantees");
160 break;
161 default:
162 break;
163 }
164 return false;
165 }
166 // Must check once again because of `Close` method implementation.
167 if (!serverUp_.load()) {
168 // Server `Close` happened, `serverFd_` was closed, new connection must not be opened.
169 expected = SetConnectionState(ConnectionState::CLOSED);
170 if (expected != ConnectionState::CONNECTING) {
171 LOGE("AcceptNewConnection: Close guarantees violated");
172 }
173 return false;
174 }
175 return true;
176 }
177
AcceptNewConnection()178 bool WebSocketServer::AcceptNewConnection()
179 {
180 if (!MoveToConnectingState()) {
181 return false;
182 }
183
184 const int newConnectionFd = accept(serverFd_, nullptr, nullptr);
185 if (newConnectionFd < SOCKET_SUCCESS) {
186 LOGI("AcceptNewConnection accept has exited");
187
188 auto expected = SetConnectionState(ConnectionState::CLOSED);
189 if (expected != ConnectionState::CONNECTING) {
190 LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
191 EnumToNumber(expected));
192 }
193 return false;
194 }
195 {
196 std::unique_lock lock(GetConnectionMutex());
197 SetConnectionSocket(newConnectionFd);
198 }
199
200 if (!HttpHandShake()) {
201 LOGW("AcceptNewConnection HttpHandShake failed");
202
203 auto expected = SetConnectionState(ConnectionState::CLOSING);
204 if (expected != ConnectionState::CONNECTING) {
205 LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
206 EnumToNumber(expected));
207 }
208 CloseConnectionSocket(ConnectionCloseReason::FAIL);
209 return false;
210 }
211
212 OnNewConnection();
213 return true;
214 }
215
InitTcpWebSocket(int port,uint32_t timeoutLimit)216 bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit)
217 {
218 if (port < 0) {
219 LOGE("InitTcpWebSocket invalid port");
220 return false;
221 }
222 if (serverUp_.load()) {
223 LOGI("InitTcpWebSocket websocket has inited");
224 return true;
225 }
226 #if defined(WINDOWS_PLATFORM)
227 WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
228 WSADATA wsaData;
229 if (WSAStartup(sockVersion, &wsaData) != 0) {
230 LOGE("InitTcpWebSocket WSA init failed");
231 return false;
232 }
233 #endif
234 serverFd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
235 if (serverFd_ < SOCKET_SUCCESS) {
236 LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno);
237 return false;
238 }
239 FdsanExchangeOwnerTag(reinterpret_cast<fd_t>(serverFd_));
240 // allow specified port can be used at once and not wait TIME_WAIT status ending
241 int sockOptVal = 1;
242 if ((setsockopt(serverFd_, SOL_SOCKET, SO_REUSEADDR,
243 reinterpret_cast<char *>(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) {
244 LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno);
245 CloseServerSocket();
246 return false;
247 }
248 if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
249 LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
250 CloseServerSocket();
251 return false;
252 }
253 return BindAndListenTcpWebSocket(port);
254 }
255
BindAndListenTcpWebSocket(int port)256 bool WebSocketServer::BindAndListenTcpWebSocket(int port)
257 {
258 sockaddr_in addrSin = {};
259 addrSin.sin_family = AF_INET;
260 addrSin.sin_port = htons(port);
261 if (inet_pton(AF_INET, "127.0.0.1", &addrSin.sin_addr) != NET_SUCCESS) {
262 LOGE("BindAndListenTcpWebSocket inet_pton failed, error = %{public}d", errno);
263 return false;
264 }
265 if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&addrSin), sizeof(addrSin)) != SOCKET_SUCCESS ||
266 listen(serverFd_, 1) != SOCKET_SUCCESS) {
267 LOGE("BindAndListenTcpWebSocket bind/listen failed, errno = %{public}d", errno);
268 CloseServerSocket();
269 return false;
270 }
271 serverUp_.store(true);
272 return true;
273 }
274
275 #if !defined(WINDOWS_PLATFORM)
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)276 bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
277 {
278 if (serverUp_.load()) {
279 LOGI("InitUnixWebSocket websocket has inited");
280 return true;
281 }
282 serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol
283 if (serverFd_ < SOCKET_SUCCESS) {
284 LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno);
285 return false;
286 }
287 FdsanExchangeOwnerTag(reinterpret_cast<fd_t>(serverFd_));
288 // set send and recv timeout
289 if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
290 LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
291 CloseServerSocket();
292 return false;
293 }
294
295 struct sockaddr_un un;
296 if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
297 LOGE("InitUnixWebSocket memset_s failed");
298 CloseServerSocket();
299 return false;
300 }
301 un.sun_family = AF_UNIX;
302 if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
303 LOGE("InitUnixWebSocket strcpy_s failed");
304 CloseServerSocket();
305 return false;
306 }
307 un.sun_path[0] = '\0';
308 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
309 if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) != SOCKET_SUCCESS) {
310 LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno);
311 CloseServerSocket();
312 return false;
313 }
314 if (listen(serverFd_, 1) != SOCKET_SUCCESS) { // 1: connection num
315 LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno);
316 CloseServerSocket();
317 return false;
318 }
319 serverUp_.store(true);
320 return true;
321 }
322
InitUnixWebSocket(int socketfd)323 bool WebSocketServer::InitUnixWebSocket(int socketfd)
324 {
325 if (serverUp_.load()) {
326 LOGI("InitUnixWebSocket websocket has inited");
327 return true;
328 }
329 if (socketfd < SOCKET_SUCCESS) {
330 LOGE("InitUnixWebSocket socketfd is invalid");
331 return false;
332 }
333 SetConnectionSocket(socketfd);
334 const int flag = fcntl(socketfd, F_GETFL, 0);
335 if (flag == -1) {
336 LOGE("InitUnixWebSocket get client state is failed, error is %{public}s", strerror(errno));
337 return false;
338 }
339 fcntl(socketfd, F_SETFL, static_cast<size_t>(flag) & ~O_NONBLOCK);
340 serverUp_.store(true);
341 return true;
342 }
343
ConnectUnixWebSocketBySocketpair()344 bool WebSocketServer::ConnectUnixWebSocketBySocketpair()
345 {
346 if (!MoveToConnectingState()) {
347 return false;
348 }
349
350 if (!HttpHandShake()) {
351 LOGE("ConnectUnixWebSocket HttpHandShake failed");
352
353 auto expected = SetConnectionState(ConnectionState::CLOSING);
354 if (expected != ConnectionState::CONNECTING) {
355 LOGE("ConnectUnixWebSocketBySocketpair: violation due to concurrent close and accept: got %{public}d",
356 EnumToNumber(expected));
357 }
358 CloseConnectionSocket(ConnectionCloseReason::FAIL);
359 return false;
360 }
361
362 OnNewConnection();
363 return true;
364 }
365 #endif // WINDOWS_PLATFORM
366
CloseServerSocket()367 void WebSocketServer::CloseServerSocket()
368 {
369 FdsanClose(reinterpret_cast<fd_t>(serverFd_));
370 serverFd_ = -1;
371 }
372
OnNewConnection()373 void WebSocketServer::OnNewConnection()
374 {
375 LOGI("New client connected");
376 if (openCb_) {
377 openCb_();
378 }
379
380 auto expected = SetConnectionState(ConnectionState::OPEN);
381 if (expected != ConnectionState::CONNECTING) {
382 LOGE("OnNewConnection violation: expected CONNECTING, but got %{public}d",
383 EnumToNumber(expected));
384 }
385 }
386
/**
 * Installs the callback that HttpHandShake invokes on the decoded client request.
 * If the callback is set and returns false, the handshake is rejected.
 *
 * @param cb validation callback; an empty callback disables this extra check.
 */
void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb)
{
    validateCb_ = std::move(cb);
}
391
/**
 * Installs the callback that OnNewConnection invokes after a successful handshake,
 * before the connection state is switched to OPEN.
 *
 * @param cb open callback; an empty callback is simply not called.
 */
void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb)
{
    openCb_ = std::move(cb);
}
396
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const397 bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
398 {
399 // "The server MUST close the connection upon receiving a frame that is not masked."
400 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
401 return wsFrame.mask == 1;
402 }
403
CreateFrame(bool isLast,FrameType frameType) const404 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const
405 {
406 ServerFrameBuilder builder(isLast, frameType);
407 return builder.Build();
408 }
409
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const410 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
411 {
412 ServerFrameBuilder builder(isLast, frameType);
413 return builder.SetPayload(payload).Build();
414 }
415
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const416 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
417 {
418 ServerFrameBuilder builder(isLast, frameType);
419 return builder.SetPayload(std::move(payload)).Build();
420 }
421
/**
 * Spins (yielding each iteration) until the connection leaves the CONNECTING state.
 *
 * While waiting, it shuts down the connection socket once one has been set, so that a thread blocked
 * in the handshake recv wakes up and can finish its state transition. The shutdown is attempted
 * until it succeeds once (i.e. until a socket other than -1 is seen).
 *
 * @param connection state last observed by the caller.
 * @return the first observed state that is not CONNECTING; `connection` itself if it already
 *         is not CONNECTING.
 */
WebSocketServer::ConnectionState WebSocketServer::WaitConnectingStateEnds(ConnectionState connection)
{
    // Meant to be called while holding the connection mutex (see the shared_lock in the loop below).
    // Returns false if no connection socket is set yet, so the loop retries later.
    auto shutdownSocketUnderLock = [this]() {
        auto fd = GetConnectionSocket();
        if (fd == -1) {
            return false;
        }
        int err = ShutdownSocket(fd);
        if (err != 0) {
            LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
        }
        // A failed shutdown is only logged; it still counts as attempted.
        return true;
    };

    auto connectionSocketWasShutdown = false;
    while (connection == ConnectionState::CONNECTING) {
        if (!connectionSocketWasShutdown) {
            // Try to shutdown the already accepted connection socket,
            // otherwise thread can infinitely hang on handshake recv.
            std::shared_lock lock(GetConnectionMutex());
            connectionSocketWasShutdown = shutdownSocketUnderLock();
        }

        std::this_thread::yield();
        connection = GetConnectionState();
    }
    return connection;
}
450
/**
 * Stops the server. It prevents new connections, shuts down and closes the listening socket and,
 * if a connection ended up OPEN, closes it with SERVER_GO_AWAY.
 *
 * Only the call that flips `serverUp_` from true to false does any work; later calls return
 * immediately.
 */
void WebSocketServer::Close()
{
    // Firstly stop accepting new connections.
    if (!serverUp_.exchange(false)) {
        return;
    }

    // Shutting the listening socket down is presumably what unblocks a thread waiting in accept().
    // NOTE(review): if the server was initialized from an existing socket (InitUnixWebSocket(int)),
    // serverFd_ may be -1 here and this only logs a warning — confirm.
    int err = ShutdownSocket(serverFd_);
    if (err != 0) {
        LOGW("Failed to shutdown server socket, errno = %{public}d", errno);
    }

    // If there is a concurrent call to `CloseConnection`, we can immediately close `serverFd_`.
    // This is because new connections are already prevented, hence no reads of `serverFd_` will be done,
    // and the connection itself will be closed by someone else.
    auto connection = GetConnectionState();
    if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
        CloseServerSocket();
        return;
    }

    // A connection attempt may be in progress: wait until it either fails or becomes OPEN.
    connection = WaitConnectingStateEnds(connection);

    // Can safely close server socket, as there will be no more new connections attempts.
    CloseServerSocket();
    // Must check once again after finished `AcceptNewConnection`.
    if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
        return;
    }

    // The last observed state was `OPEN`, but it may have become `CLOSING` or `CLOSED` concurrently
    // since then. Try to close it anyway.
    CloseConnection(CloseStatusCode::SERVER_GO_AWAY);
}
484 } // namespace OHOS::ArkCompiler::Toolchain
485