• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #if defined(PANDA_TARGET_WINDOWS)
17 #include <ws2tcpip.h>
18 #else
19 #include <arpa/inet.h>
20 #endif
21 
22 #include <fcntl.h>
23 #include "common/log_wrapper.h"
24 #include "platform/file.h"
25 #include "frame_builder.h"
26 #include "handshake_helper.h"
27 #include "server/websocket_server.h"
28 #include "string_utils.h"
29 
30 #include <mutex>
31 #include <thread>
32 
33 namespace OHOS::ArkCompiler::Toolchain {
ValidateHandShakeMessage(const HttpRequest & req)34 static bool ValidateHandShakeMessage(const HttpRequest& req)
35 {
36     std::string upgradeHeaderValue = req.upgrade;
37     // Switch to lower case in order to support obsolete versions of WebSocket protocol.
38     ToLowerCase(upgradeHeaderValue);
39     return req.connection.find("Upgrade") != std::string::npos &&
40         upgradeHeaderValue.find("websocket") != std::string::npos &&
41         req.version.compare("HTTP/1.1") == 0;
42 }
43 
~WebSocketServer()44 WebSocketServer::~WebSocketServer() noexcept
45 {
46     if (serverFd_ != -1) {
47         LOGW("WebSocket server is closed while destructing the object");
48         FdsanClose(reinterpret_cast<fd_t>(serverFd_));
49         // Reset directly in order to prevent static analyzer warnings.
50         serverFd_ = -1;
51     }
52 }
53 
DecodeMessage(WebSocketFrame & wsFrame) const54 bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame) const
55 {
56     const uint64_t msgLen = wsFrame.payloadLen;
57     if (msgLen == 0) {
58         // receiving empty data is OK
59         return true;
60     }
61     auto& buffer = wsFrame.payload;
62     buffer.resize(msgLen, 0);
63 
64     if (!RecvUnderLock(wsFrame.maskingKey, sizeof(wsFrame.maskingKey))) {
65         LOGE("DecodeMessage: Recv maskingKey failed");
66         return false;
67     }
68 
69     if (!RecvUnderLock(buffer)) {
70         LOGE("DecodeMessage: Recv message with mask failed");
71         return false;
72     }
73 
74     for (uint64_t i = 0; i < msgLen; i++) {
75         const auto j = i % WebSocketFrame::MASK_LEN;
76         buffer[i] = static_cast<uint8_t>(buffer[i]) ^ wsFrame.maskingKey[j];
77     }
78 
79     return true;
80 }
81 
ProtocolUpgrade(const HttpRequest & req)82 bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req)
83 {
84     unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
85     if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) {
86         LOGE("ProtocolUpgrade: failed to encode WebSocket-Key");
87         return false;
88     }
89 
90     ProtocolUpgradeBuilder requestBuilder(encodedKey);
91     if (!SendUnderLock(requestBuilder.GetUpgradeMessage(), requestBuilder.GetLength())) {
92         LOGE("ProtocolUpgrade: Send failed");
93         return false;
94     }
95     return true;
96 }
97 
ResponseInvalidHandShake() const98 bool WebSocketServer::ResponseInvalidHandShake() const
99 {
100     const std::string response(BAD_REQUEST_RESPONSE);
101     return SendUnderLock(response);
102 }
103 
HttpHandShake()104 bool WebSocketServer::HttpHandShake()
105 {
106     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
107     ssize_t msgLen = 0;
108     {
109         std::shared_lock lock(GetConnectionMutex());
110         while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
111             (errno == EINTR || errno == EAGAIN)) {
112             LOGW("HttpHandShake recv failed, errno = %{public}d", errno);
113         }
114     }
115     if (msgLen <= 0) {
116         LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast<long>(msgLen), errno);
117         return false;
118     }
119     // reduce to received size
120     msgBuf.resize(msgLen);
121 
122     HttpRequest req;
123     if (!HttpRequest::Decode(msgBuf, req)) {
124         LOGE("HttpHandShake: Upgrade failed");
125         return false;
126     }
127     if (validateCb_ && !validateCb_(req)) {
128         LOGE("HttpHandShake: Validation failed");
129         return false;
130     }
131 
132     if (ValidateHandShakeMessage(req)) {
133         return ProtocolUpgrade(req);
134     }
135 
136     LOGE("HttpHandShake: HTTP upgrade parameters failure");
137     if (!ResponseInvalidHandShake()) {
138         LOGE("HttpHandShake: failed to send 'bad request' response");
139     }
140     return false;
141 }
142 
MoveToConnectingState()143 bool WebSocketServer::MoveToConnectingState()
144 {
145     if (!serverUp_.load()) {
146         // Server `Close` happened, must not accept new connections.
147         return false;
148     }
149     auto expected = ConnectionState::CLOSED;
150     if (!CompareExchangeConnectionState(expected, ConnectionState::CONNECTING)) {
151         switch (expected) {
152             case ConnectionState::CLOSING:
153                 LOGW("MoveToConnectingState during closing connection");
154                 break;
155             case ConnectionState::OPEN:
156                 LOGW("MoveToConnectingState during already established connection");
157                 break;
158             case ConnectionState::CONNECTING:
159                 LOGE("MoveToConnectingState during opening connection, which violates WebSocketServer guarantees");
160                 break;
161             default:
162                 break;
163         }
164         return false;
165     }
166     // Must check once again because of `Close` method implementation.
167     if (!serverUp_.load()) {
168         // Server `Close` happened, `serverFd_` was closed, new connection must not be opened.
169         expected = SetConnectionState(ConnectionState::CLOSED);
170         if (expected != ConnectionState::CONNECTING) {
171             LOGE("AcceptNewConnection: Close guarantees violated");
172         }
173         return false;
174     }
175     return true;
176 }
177 
AcceptNewConnection()178 bool WebSocketServer::AcceptNewConnection()
179 {
180     if (!MoveToConnectingState()) {
181         return false;
182     }
183 
184     const int newConnectionFd = accept(serverFd_, nullptr, nullptr);
185     if (newConnectionFd < SOCKET_SUCCESS) {
186         LOGI("AcceptNewConnection accept has exited");
187 
188         auto expected = SetConnectionState(ConnectionState::CLOSED);
189         if (expected != ConnectionState::CONNECTING) {
190             LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
191                  EnumToNumber(expected));
192         }
193         return false;
194     }
195     {
196         std::unique_lock lock(GetConnectionMutex());
197         SetConnectionSocket(newConnectionFd);
198     }
199 
200     if (!HttpHandShake()) {
201         LOGW("AcceptNewConnection HttpHandShake failed");
202 
203         auto expected = SetConnectionState(ConnectionState::CLOSING);
204         if (expected != ConnectionState::CONNECTING) {
205             LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
206                  EnumToNumber(expected));
207         }
208         CloseConnectionSocket(ConnectionCloseReason::FAIL);
209         return false;
210     }
211 
212     OnNewConnection();
213     return true;
214 }
215 
InitTcpWebSocket(int port,uint32_t timeoutLimit)216 bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit)
217 {
218     if (port < 0) {
219         LOGE("InitTcpWebSocket invalid port");
220         return false;
221     }
222     if (serverUp_.load()) {
223         LOGI("InitTcpWebSocket websocket has inited");
224         return true;
225     }
226 #if defined(WINDOWS_PLATFORM)
227     WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
228     WSADATA wsaData;
229     if (WSAStartup(sockVersion, &wsaData) != 0) {
230         LOGE("InitTcpWebSocket WSA init failed");
231         return false;
232     }
233 #endif
234     serverFd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
235     if (serverFd_ < SOCKET_SUCCESS) {
236         LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno);
237         return false;
238     }
239     FdsanExchangeOwnerTag(reinterpret_cast<fd_t>(serverFd_));
240     // allow specified port can be used at once and not wait TIME_WAIT status ending
241     int sockOptVal = 1;
242     if ((setsockopt(serverFd_, SOL_SOCKET, SO_REUSEADDR,
243         reinterpret_cast<char *>(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) {
244         LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno);
245         CloseServerSocket();
246         return false;
247     }
248     if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
249         LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
250         CloseServerSocket();
251         return false;
252     }
253     return BindAndListenTcpWebSocket(port);
254 }
255 
BindAndListenTcpWebSocket(int port)256 bool WebSocketServer::BindAndListenTcpWebSocket(int port)
257 {
258     sockaddr_in addrSin = {};
259     addrSin.sin_family = AF_INET;
260     addrSin.sin_port = htons(port);
261     if (inet_pton(AF_INET, "127.0.0.1", &addrSin.sin_addr) != NET_SUCCESS) {
262         LOGE("BindAndListenTcpWebSocket inet_pton failed, error = %{public}d", errno);
263         return false;
264     }
265     if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&addrSin), sizeof(addrSin)) != SOCKET_SUCCESS ||
266         listen(serverFd_, 1) != SOCKET_SUCCESS) {
267         LOGE("BindAndListenTcpWebSocket bind/listen failed, errno = %{public}d", errno);
268         CloseServerSocket();
269         return false;
270     }
271     serverUp_.store(true);
272     return true;
273 }
274 
275 #if !defined(WINDOWS_PLATFORM)
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)276 bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
277 {
278     if (serverUp_.load()) {
279         LOGI("InitUnixWebSocket websocket has inited");
280         return true;
281     }
282     serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol
283     if (serverFd_ < SOCKET_SUCCESS) {
284         LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno);
285         return false;
286     }
287     FdsanExchangeOwnerTag(reinterpret_cast<fd_t>(serverFd_));
288     // set send and recv timeout
289     if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
290         LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
291         CloseServerSocket();
292         return false;
293     }
294 
295     struct sockaddr_un un;
296     if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
297         LOGE("InitUnixWebSocket memset_s failed");
298         CloseServerSocket();
299         return false;
300     }
301     un.sun_family = AF_UNIX;
302     if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
303         LOGE("InitUnixWebSocket strcpy_s failed");
304         CloseServerSocket();
305         return false;
306     }
307     un.sun_path[0] = '\0';
308     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
309     if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) != SOCKET_SUCCESS) {
310         LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno);
311         CloseServerSocket();
312         return false;
313     }
314     if (listen(serverFd_, 1) != SOCKET_SUCCESS) { // 1: connection num
315         LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno);
316         CloseServerSocket();
317         return false;
318     }
319     serverUp_.store(true);
320     return true;
321 }
322 
InitUnixWebSocket(int socketfd)323 bool WebSocketServer::InitUnixWebSocket(int socketfd)
324 {
325     if (serverUp_.load()) {
326         LOGI("InitUnixWebSocket websocket has inited");
327         return true;
328     }
329     if (socketfd < SOCKET_SUCCESS) {
330         LOGE("InitUnixWebSocket socketfd is invalid");
331         return false;
332     }
333     SetConnectionSocket(socketfd);
334     const int flag = fcntl(socketfd, F_GETFL, 0);
335     if (flag == -1) {
336         LOGE("InitUnixWebSocket get client state is failed, error is %{public}s", strerror(errno));
337         return false;
338     }
339     fcntl(socketfd, F_SETFL, static_cast<size_t>(flag) & ~O_NONBLOCK);
340     serverUp_.store(true);
341     return true;
342 }
343 
ConnectUnixWebSocketBySocketpair()344 bool WebSocketServer::ConnectUnixWebSocketBySocketpair()
345 {
346     if (!MoveToConnectingState()) {
347         return false;
348     }
349 
350     if (!HttpHandShake()) {
351         LOGE("ConnectUnixWebSocket HttpHandShake failed");
352 
353         auto expected = SetConnectionState(ConnectionState::CLOSING);
354         if (expected != ConnectionState::CONNECTING) {
355             LOGE("ConnectUnixWebSocketBySocketpair: violation due to concurrent close and accept: got %{public}d",
356                  EnumToNumber(expected));
357         }
358         CloseConnectionSocket(ConnectionCloseReason::FAIL);
359         return false;
360     }
361 
362     OnNewConnection();
363     return true;
364 }
365 #endif  // WINDOWS_PLATFORM
366 
CloseServerSocket()367 void WebSocketServer::CloseServerSocket()
368 {
369     FdsanClose(reinterpret_cast<fd_t>(serverFd_));
370     serverFd_ = -1;
371 }
372 
OnNewConnection()373 void WebSocketServer::OnNewConnection()
374 {
375     LOGI("New client connected");
376     if (openCb_) {
377         openCb_();
378     }
379 
380     auto expected = SetConnectionState(ConnectionState::OPEN);
381     if (expected != ConnectionState::CONNECTING) {
382         LOGE("OnNewConnection violation: expected CONNECTING, but got %{public}d",
383              EnumToNumber(expected));
384     }
385 }
386 
SetValidateConnectionCallback(ValidateConnectionCallback cb)387 void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb)
388 {
389     validateCb_ = std::move(cb);
390 }
391 
SetOpenConnectionCallback(OpenConnectionCallback cb)392 void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb)
393 {
394     openCb_ = std::move(cb);
395 }
396 
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const397 bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
398 {
399     // "The server MUST close the connection upon receiving a frame that is not masked."
400     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
401     return wsFrame.mask == 1;
402 }
403 
CreateFrame(bool isLast,FrameType frameType) const404 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const
405 {
406     ServerFrameBuilder builder(isLast, frameType);
407     return builder.Build();
408 }
409 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const410 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
411 {
412     ServerFrameBuilder builder(isLast, frameType);
413     return builder.SetPayload(payload).Build();
414 }
415 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const416 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
417 {
418     ServerFrameBuilder builder(isLast, frameType);
419     return builder.SetPayload(std::move(payload)).Build();
420 }
421 
WaitConnectingStateEnds(ConnectionState connection)422 WebSocketServer::ConnectionState WebSocketServer::WaitConnectingStateEnds(ConnectionState connection)
423 {
424     auto shutdownSocketUnderLock = [this]() {
425         auto fd = GetConnectionSocket();
426         if (fd == -1) {
427             return false;
428         }
429         int err = ShutdownSocket(fd);
430         if (err != 0) {
431             LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
432         }
433         return true;
434     };
435 
436     auto connectionSocketWasShutdown = false;
437     while (connection == ConnectionState::CONNECTING) {
438         if (!connectionSocketWasShutdown) {
439             // Try to shutdown the already accepted connection socket,
440             // otherwise thread can infinitely hang on handshake recv.
441             std::shared_lock lock(GetConnectionMutex());
442             connectionSocketWasShutdown = shutdownSocketUnderLock();
443         }
444 
445         std::this_thread::yield();
446         connection = GetConnectionState();
447     }
448     return connection;
449 }
450 
Close()451 void WebSocketServer::Close()
452 {
453     // Firstly stop accepting new connections.
454     if (!serverUp_.exchange(false)) {
455         return;
456     }
457 
458     int err = ShutdownSocket(serverFd_);
459     if (err != 0) {
460         LOGW("Failed to shutdown server socket, errno = %{public}d", errno);
461     }
462 
463     // If there is a concurrent call to `CloseConnection`, we can immediately close `serverFd_`.
464     // This is because new connections are already prevented, hence no reads of `serverFd_` will be done,
465     // and the connection itself will be closed by someone else.
466     auto connection = GetConnectionState();
467     if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
468         CloseServerSocket();
469         return;
470     }
471 
472     connection = WaitConnectingStateEnds(connection);
473 
474     // Can safely close server socket, as there will be no more new connections attempts.
475     CloseServerSocket();
476     // Must check once again after finished `AcceptNewConnection`.
477     if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
478         return;
479     }
480 
481     // If we reached this point, connection can be `OPEN`, `CLOSING` or `CLOSED`. Try to close it anyway.
482     CloseConnection(CloseStatusCode::SERVER_GO_AWAY);
483 }
484 } // namespace OHOS::ArkCompiler::Toolchain
485