1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #if defined(PANDA_TARGET_WINDOWS)
17 #include <ws2tcpip.h>
18 #else
19 #include <arpa/inet.h>
20 #endif
21
22 #include <fcntl.h>
23 #include "common/log_wrapper.h"
24 #include "frame_builder.h"
25 #include "handshake_helper.h"
26 #include "server/websocket_server.h"
27 #include "string_utils.h"
28
29 #include <mutex>
30 #include <thread>
31
32 namespace OHOS::ArkCompiler::Toolchain {
ValidateHandShakeMessage(const HttpRequest & req)33 static bool ValidateHandShakeMessage(const HttpRequest& req)
34 {
35 std::string upgradeHeaderValue = req.upgrade;
36 // Switch to lower case in order to support obsolete versions of WebSocket protocol.
37 ToLowerCase(upgradeHeaderValue);
38 return req.connection.find("Upgrade") != std::string::npos &&
39 upgradeHeaderValue.find("websocket") != std::string::npos &&
40 req.version.compare("HTTP/1.1") == 0;
41 }
42
~WebSocketServer()43 WebSocketServer::~WebSocketServer() noexcept
44 {
45 if (serverFd_ != -1) {
46 LOGW("WebSocket server is closed while destructing the object");
47 close(serverFd_);
48 // Reset directly in order to prevent static analyzer warnings.
49 serverFd_ = -1;
50 }
51 }
52
DecodeMessage(WebSocketFrame & wsFrame) const53 bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame) const
54 {
55 const uint64_t msgLen = wsFrame.payloadLen;
56 if (msgLen == 0) {
57 // receiving empty data is OK
58 return true;
59 }
60 auto& buffer = wsFrame.payload;
61 buffer.resize(msgLen, 0);
62
63 if (!RecvUnderLock(wsFrame.maskingKey, sizeof(wsFrame.maskingKey))) {
64 LOGE("DecodeMessage: Recv maskingKey failed");
65 return false;
66 }
67
68 if (!RecvUnderLock(buffer)) {
69 LOGE("DecodeMessage: Recv message with mask failed");
70 return false;
71 }
72
73 for (uint64_t i = 0; i < msgLen; i++) {
74 const auto j = i % WebSocketFrame::MASK_LEN;
75 buffer[i] = static_cast<uint8_t>(buffer[i]) ^ wsFrame.maskingKey[j];
76 }
77
78 return true;
79 }
80
ProtocolUpgrade(const HttpRequest & req)81 bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req)
82 {
83 unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
84 if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) {
85 LOGE("ProtocolUpgrade: failed to encode WebSocket-Key");
86 return false;
87 }
88
89 ProtocolUpgradeBuilder requestBuilder(encodedKey);
90 if (!SendUnderLock(requestBuilder.GetUpgradeMessage(), requestBuilder.GetLength())) {
91 LOGE("ProtocolUpgrade: Send failed");
92 return false;
93 }
94 return true;
95 }
96
ResponseInvalidHandShake() const97 bool WebSocketServer::ResponseInvalidHandShake() const
98 {
99 const std::string response(BAD_REQUEST_RESPONSE);
100 return SendUnderLock(response);
101 }
102
HttpHandShake()103 bool WebSocketServer::HttpHandShake()
104 {
105 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
106 ssize_t msgLen = 0;
107 {
108 std::shared_lock lock(GetConnectionMutex());
109 while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
110 (errno == EINTR || errno == EAGAIN)) {
111 LOGW("HttpHandShake recv failed, errno = %{public}d", errno);
112 }
113 }
114 if (msgLen <= 0) {
115 LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast<long>(msgLen), errno);
116 return false;
117 }
118 // reduce to received size
119 msgBuf.resize(msgLen);
120
121 HttpRequest req;
122 if (!HttpRequest::Decode(msgBuf, req)) {
123 LOGE("HttpHandShake: Upgrade failed");
124 return false;
125 }
126 if (validateCb_ && !validateCb_(req)) {
127 LOGE("HttpHandShake: Validation failed");
128 return false;
129 }
130
131 if (ValidateHandShakeMessage(req)) {
132 return ProtocolUpgrade(req);
133 }
134
135 LOGE("HttpHandShake: HTTP upgrade parameters failure");
136 if (!ResponseInvalidHandShake()) {
137 LOGE("HttpHandShake: failed to send 'bad request' response");
138 }
139 return false;
140 }
141
MoveToConnectingState()142 bool WebSocketServer::MoveToConnectingState()
143 {
144 if (!serverUp_.load()) {
145 // Server `Close` happened, must not accept new connections.
146 return false;
147 }
148 auto expected = ConnectionState::CLOSED;
149 if (!CompareExchangeConnectionState(expected, ConnectionState::CONNECTING)) {
150 switch (expected) {
151 case ConnectionState::CLOSING:
152 LOGW("MoveToConnectingState during closing connection");
153 break;
154 case ConnectionState::OPEN:
155 LOGW("MoveToConnectingState during already established connection");
156 break;
157 case ConnectionState::CONNECTING:
158 LOGE("MoveToConnectingState during opening connection, which violates WebSocketServer guarantees");
159 break;
160 default:
161 break;
162 }
163 return false;
164 }
165 // Must check once again because of `Close` method implementation.
166 if (!serverUp_.load()) {
167 // Server `Close` happened, `serverFd_` was closed, new connection must not be opened.
168 expected = SetConnectionState(ConnectionState::CLOSED);
169 if (expected != ConnectionState::CONNECTING) {
170 LOGE("AcceptNewConnection: Close guarantees violated");
171 }
172 return false;
173 }
174 return true;
175 }
176
AcceptNewConnection()177 bool WebSocketServer::AcceptNewConnection()
178 {
179 if (!MoveToConnectingState()) {
180 return false;
181 }
182
183 const int newConnectionFd = accept(serverFd_, nullptr, nullptr);
184 if (newConnectionFd < SOCKET_SUCCESS) {
185 LOGI("AcceptNewConnection accept has exited");
186
187 auto expected = SetConnectionState(ConnectionState::CLOSED);
188 if (expected != ConnectionState::CONNECTING) {
189 LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
190 EnumToNumber(expected));
191 }
192 return false;
193 }
194 {
195 std::unique_lock lock(GetConnectionMutex());
196 SetConnectionSocket(newConnectionFd);
197 }
198
199 if (!HttpHandShake()) {
200 LOGW("AcceptNewConnection HttpHandShake failed");
201
202 auto expected = SetConnectionState(ConnectionState::CLOSING);
203 if (expected != ConnectionState::CONNECTING) {
204 LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
205 EnumToNumber(expected));
206 }
207 CloseConnectionSocket(ConnectionCloseReason::FAIL);
208 return false;
209 }
210
211 OnNewConnection();
212 return true;
213 }
214
InitTcpWebSocket(int port,uint32_t timeoutLimit)215 bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit)
216 {
217 if (port < 0) {
218 LOGE("InitTcpWebSocket invalid port");
219 return false;
220 }
221 if (serverUp_.load()) {
222 LOGI("InitTcpWebSocket websocket has inited");
223 return true;
224 }
225 #if defined(WINDOWS_PLATFORM)
226 WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
227 WSADATA wsaData;
228 if (WSAStartup(sockVersion, &wsaData) != 0) {
229 LOGE("InitTcpWebSocket WSA init failed");
230 return false;
231 }
232 #endif
233 serverFd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
234 if (serverFd_ < SOCKET_SUCCESS) {
235 LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno);
236 return false;
237 }
238 // allow specified port can be used at once and not wait TIME_WAIT status ending
239 int sockOptVal = 1;
240 if ((setsockopt(serverFd_, SOL_SOCKET, SO_REUSEADDR,
241 reinterpret_cast<char *>(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) {
242 LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno);
243 CloseServerSocket();
244 return false;
245 }
246 if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
247 LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
248 CloseServerSocket();
249 return false;
250 }
251 return BindAndListenTcpWebSocket(port);
252 }
253
BindAndListenTcpWebSocket(int port)254 bool WebSocketServer::BindAndListenTcpWebSocket(int port)
255 {
256 sockaddr_in addrSin = {};
257 addrSin.sin_family = AF_INET;
258 addrSin.sin_port = htons(port);
259 if (inet_pton(AF_INET, "127.0.0.1", &addrSin.sin_addr) != NET_SUCCESS) {
260 LOGE("BindAndListenTcpWebSocket inet_pton failed, error = %{public}d", errno);
261 return false;
262 }
263 if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&addrSin), sizeof(addrSin)) != SOCKET_SUCCESS ||
264 listen(serverFd_, 1) != SOCKET_SUCCESS) {
265 LOGE("BindAndListenTcpWebSocket bind/listen failed, errno = %{public}d", errno);
266 CloseServerSocket();
267 return false;
268 }
269 serverUp_.store(true);
270 return true;
271 }
272
273 #if !defined(WINDOWS_PLATFORM)
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)274 bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
275 {
276 if (serverUp_.load()) {
277 LOGI("InitUnixWebSocket websocket has inited");
278 return true;
279 }
280 serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol
281 if (serverFd_ < SOCKET_SUCCESS) {
282 LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno);
283 return false;
284 }
285 // set send and recv timeout
286 if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
287 LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
288 CloseServerSocket();
289 return false;
290 }
291
292 struct sockaddr_un un;
293 if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
294 LOGE("InitUnixWebSocket memset_s failed");
295 CloseServerSocket();
296 return false;
297 }
298 un.sun_family = AF_UNIX;
299 if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
300 LOGE("InitUnixWebSocket strcpy_s failed");
301 CloseServerSocket();
302 return false;
303 }
304 un.sun_path[0] = '\0';
305 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
306 if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) != SOCKET_SUCCESS) {
307 LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno);
308 CloseServerSocket();
309 return false;
310 }
311 if (listen(serverFd_, 1) != SOCKET_SUCCESS) { // 1: connection num
312 LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno);
313 CloseServerSocket();
314 return false;
315 }
316 serverUp_.store(true);
317 return true;
318 }
319
InitUnixWebSocket(int socketfd)320 bool WebSocketServer::InitUnixWebSocket(int socketfd)
321 {
322 if (serverUp_.load()) {
323 LOGI("InitUnixWebSocket websocket has inited");
324 return true;
325 }
326 if (socketfd < SOCKET_SUCCESS) {
327 LOGE("InitUnixWebSocket socketfd is invalid");
328 return false;
329 }
330 SetConnectionSocket(socketfd);
331 const int flag = fcntl(socketfd, F_GETFL, 0);
332 if (flag == -1) {
333 LOGE("InitUnixWebSocket get client state is failed, error is %{public}s", strerror(errno));
334 return false;
335 }
336 fcntl(socketfd, F_SETFL, static_cast<size_t>(flag) & ~O_NONBLOCK);
337 serverUp_.store(true);
338 return true;
339 }
340
ConnectUnixWebSocketBySocketpair()341 bool WebSocketServer::ConnectUnixWebSocketBySocketpair()
342 {
343 if (!MoveToConnectingState()) {
344 return false;
345 }
346
347 if (!HttpHandShake()) {
348 LOGE("ConnectUnixWebSocket HttpHandShake failed");
349
350 auto expected = SetConnectionState(ConnectionState::CLOSING);
351 if (expected != ConnectionState::CONNECTING) {
352 LOGE("ConnectUnixWebSocketBySocketpair: violation due to concurrent close and accept: got %{public}d",
353 EnumToNumber(expected));
354 }
355 CloseConnectionSocket(ConnectionCloseReason::FAIL);
356 return false;
357 }
358
359 OnNewConnection();
360 return true;
361 }
362 #endif // WINDOWS_PLATFORM
363
CloseServerSocket()364 void WebSocketServer::CloseServerSocket()
365 {
366 close(serverFd_);
367 serverFd_ = -1;
368 }
369
OnNewConnection()370 void WebSocketServer::OnNewConnection()
371 {
372 LOGI("New client connected");
373 if (openCb_) {
374 openCb_();
375 }
376
377 auto expected = SetConnectionState(ConnectionState::OPEN);
378 if (expected != ConnectionState::CONNECTING) {
379 LOGE("OnNewConnection violation: expected CONNECTING, but got %{public}d",
380 EnumToNumber(expected));
381 }
382 }
383
SetValidateConnectionCallback(ValidateConnectionCallback cb)384 void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb)
385 {
386 validateCb_ = std::move(cb);
387 }
388
SetOpenConnectionCallback(OpenConnectionCallback cb)389 void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb)
390 {
391 openCb_ = std::move(cb);
392 }
393
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const394 bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
395 {
396 // "The server MUST close the connection upon receiving a frame that is not masked."
397 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
398 return wsFrame.mask == 1;
399 }
400
CreateFrame(bool isLast,FrameType frameType) const401 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const
402 {
403 ServerFrameBuilder builder(isLast, frameType);
404 return builder.Build();
405 }
406
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const407 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
408 {
409 ServerFrameBuilder builder(isLast, frameType);
410 return builder.SetPayload(payload).Build();
411 }
412
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const413 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
414 {
415 ServerFrameBuilder builder(isLast, frameType);
416 return builder.SetPayload(std::move(payload)).Build();
417 }
418
WaitConnectingStateEnds(ConnectionState connection)419 WebSocketServer::ConnectionState WebSocketServer::WaitConnectingStateEnds(ConnectionState connection)
420 {
421 auto shutdownSocketUnderLock = [this]() {
422 auto fd = GetConnectionSocket();
423 if (fd == -1) {
424 return false;
425 }
426 int err = ShutdownSocket(fd);
427 if (err != 0) {
428 LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
429 }
430 return true;
431 };
432
433 auto connectionSocketWasShutdown = false;
434 while (connection == ConnectionState::CONNECTING) {
435 if (!connectionSocketWasShutdown) {
436 // Try to shutdown the already accepted connection socket,
437 // otherwise thread can infinitely hang on handshake recv.
438 std::shared_lock lock(GetConnectionMutex());
439 connectionSocketWasShutdown = shutdownSocketUnderLock();
440 }
441
442 std::this_thread::yield();
443 connection = GetConnectionState();
444 }
445 return connection;
446 }
447
Close()448 void WebSocketServer::Close()
449 {
450 // Firstly stop accepting new connections.
451 if (!serverUp_.exchange(false)) {
452 return;
453 }
454
455 int err = ShutdownSocket(serverFd_);
456 if (err != 0) {
457 LOGW("Failed to shutdown server socket, errno = %{public}d", errno);
458 }
459
460 // If there is a concurrent call to `CloseConnection`, we can immediately close `serverFd_`.
461 // This is because new connections are already prevented, hence no reads of `serverFd_` will be done,
462 // and the connection itself will be closed by someone else.
463 auto connection = GetConnectionState();
464 if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
465 CloseServerSocket();
466 return;
467 }
468
469 connection = WaitConnectingStateEnds(connection);
470
471 // Can safely close server socket, as there will be no more new connections attempts.
472 CloseServerSocket();
473 // Must check once again after finished `AcceptNewConnection`.
474 if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
475 return;
476 }
477
478 // If we reached this point, connection can be `OPEN`, `CLOSING` or `CLOSED`. Try to close it anyway.
479 CloseConnection(CloseStatusCode::SERVER_GO_AWAY);
480 }
481 } // namespace OHOS::ArkCompiler::Toolchain
482