• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #if defined(PANDA_TARGET_WINDOWS)
17 #include <ws2tcpip.h>
18 #else
19 #include <arpa/inet.h>
20 #endif
21 
22 #include <fcntl.h>
23 #include "common/log_wrapper.h"
24 #include "frame_builder.h"
25 #include "handshake_helper.h"
26 #include "server/websocket_server.h"
27 #include "string_utils.h"
28 
29 #include <mutex>
30 #include <thread>
31 
32 namespace OHOS::ArkCompiler::Toolchain {
ValidateHandShakeMessage(const HttpRequest & req)33 static bool ValidateHandShakeMessage(const HttpRequest& req)
34 {
35     std::string upgradeHeaderValue = req.upgrade;
36     // Switch to lower case in order to support obsolete versions of WebSocket protocol.
37     ToLowerCase(upgradeHeaderValue);
38     return req.connection.find("Upgrade") != std::string::npos &&
39         upgradeHeaderValue.find("websocket") != std::string::npos &&
40         req.version.compare("HTTP/1.1") == 0;
41 }
42 
~WebSocketServer()43 WebSocketServer::~WebSocketServer() noexcept
44 {
45     if (serverFd_ != -1) {
46         LOGW("WebSocket server is closed while destructing the object");
47         close(serverFd_);
48         // Reset directly in order to prevent static analyzer warnings.
49         serverFd_ = -1;
50     }
51 }
52 
DecodeMessage(WebSocketFrame & wsFrame) const53 bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame) const
54 {
55     const uint64_t msgLen = wsFrame.payloadLen;
56     if (msgLen == 0) {
57         // receiving empty data is OK
58         return true;
59     }
60     auto& buffer = wsFrame.payload;
61     buffer.resize(msgLen, 0);
62 
63     if (!RecvUnderLock(wsFrame.maskingKey, sizeof(wsFrame.maskingKey))) {
64         LOGE("DecodeMessage: Recv maskingKey failed");
65         return false;
66     }
67 
68     if (!RecvUnderLock(buffer)) {
69         LOGE("DecodeMessage: Recv message with mask failed");
70         return false;
71     }
72 
73     for (uint64_t i = 0; i < msgLen; i++) {
74         const auto j = i % WebSocketFrame::MASK_LEN;
75         buffer[i] = static_cast<uint8_t>(buffer[i]) ^ wsFrame.maskingKey[j];
76     }
77 
78     return true;
79 }
80 
ProtocolUpgrade(const HttpRequest & req)81 bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req)
82 {
83     unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
84     if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) {
85         LOGE("ProtocolUpgrade: failed to encode WebSocket-Key");
86         return false;
87     }
88 
89     ProtocolUpgradeBuilder requestBuilder(encodedKey);
90     if (!SendUnderLock(requestBuilder.GetUpgradeMessage(), requestBuilder.GetLength())) {
91         LOGE("ProtocolUpgrade: Send failed");
92         return false;
93     }
94     return true;
95 }
96 
ResponseInvalidHandShake() const97 bool WebSocketServer::ResponseInvalidHandShake() const
98 {
99     const std::string response(BAD_REQUEST_RESPONSE);
100     return SendUnderLock(response);
101 }
102 
HttpHandShake()103 bool WebSocketServer::HttpHandShake()
104 {
105     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
106     ssize_t msgLen = 0;
107     {
108         std::shared_lock lock(GetConnectionMutex());
109         while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
110             (errno == EINTR || errno == EAGAIN)) {
111             LOGW("HttpHandShake recv failed, errno = %{public}d", errno);
112         }
113     }
114     if (msgLen <= 0) {
115         LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast<long>(msgLen), errno);
116         return false;
117     }
118     // reduce to received size
119     msgBuf.resize(msgLen);
120 
121     HttpRequest req;
122     if (!HttpRequest::Decode(msgBuf, req)) {
123         LOGE("HttpHandShake: Upgrade failed");
124         return false;
125     }
126     if (validateCb_ && !validateCb_(req)) {
127         LOGE("HttpHandShake: Validation failed");
128         return false;
129     }
130 
131     if (ValidateHandShakeMessage(req)) {
132         return ProtocolUpgrade(req);
133     }
134 
135     LOGE("HttpHandShake: HTTP upgrade parameters failure");
136     if (!ResponseInvalidHandShake()) {
137         LOGE("HttpHandShake: failed to send 'bad request' response");
138     }
139     return false;
140 }
141 
MoveToConnectingState()142 bool WebSocketServer::MoveToConnectingState()
143 {
144     if (!serverUp_.load()) {
145         // Server `Close` happened, must not accept new connections.
146         return false;
147     }
148     auto expected = ConnectionState::CLOSED;
149     if (!CompareExchangeConnectionState(expected, ConnectionState::CONNECTING)) {
150         switch (expected) {
151             case ConnectionState::CLOSING:
152                 LOGW("MoveToConnectingState during closing connection");
153                 break;
154             case ConnectionState::OPEN:
155                 LOGW("MoveToConnectingState during already established connection");
156                 break;
157             case ConnectionState::CONNECTING:
158                 LOGE("MoveToConnectingState during opening connection, which violates WebSocketServer guarantees");
159                 break;
160             default:
161                 break;
162         }
163         return false;
164     }
165     // Must check once again because of `Close` method implementation.
166     if (!serverUp_.load()) {
167         // Server `Close` happened, `serverFd_` was closed, new connection must not be opened.
168         expected = SetConnectionState(ConnectionState::CLOSED);
169         if (expected != ConnectionState::CONNECTING) {
170             LOGE("AcceptNewConnection: Close guarantees violated");
171         }
172         return false;
173     }
174     return true;
175 }
176 
AcceptNewConnection()177 bool WebSocketServer::AcceptNewConnection()
178 {
179     if (!MoveToConnectingState()) {
180         return false;
181     }
182 
183     const int newConnectionFd = accept(serverFd_, nullptr, nullptr);
184     if (newConnectionFd < SOCKET_SUCCESS) {
185         LOGI("AcceptNewConnection accept has exited");
186 
187         auto expected = SetConnectionState(ConnectionState::CLOSED);
188         if (expected != ConnectionState::CONNECTING) {
189             LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
190                  EnumToNumber(expected));
191         }
192         return false;
193     }
194     {
195         std::unique_lock lock(GetConnectionMutex());
196         SetConnectionSocket(newConnectionFd);
197     }
198 
199     if (!HttpHandShake()) {
200         LOGW("AcceptNewConnection HttpHandShake failed");
201 
202         auto expected = SetConnectionState(ConnectionState::CLOSING);
203         if (expected != ConnectionState::CONNECTING) {
204             LOGE("AcceptNewConnection: violation due to concurrent close and accept: got %{public}d",
205                  EnumToNumber(expected));
206         }
207         CloseConnectionSocket(ConnectionCloseReason::FAIL);
208         return false;
209     }
210 
211     OnNewConnection();
212     return true;
213 }
214 
InitTcpWebSocket(int port,uint32_t timeoutLimit)215 bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit)
216 {
217     if (port < 0) {
218         LOGE("InitTcpWebSocket invalid port");
219         return false;
220     }
221     if (serverUp_.load()) {
222         LOGI("InitTcpWebSocket websocket has inited");
223         return true;
224     }
225 #if defined(WINDOWS_PLATFORM)
226     WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
227     WSADATA wsaData;
228     if (WSAStartup(sockVersion, &wsaData) != 0) {
229         LOGE("InitTcpWebSocket WSA init failed");
230         return false;
231     }
232 #endif
233     serverFd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
234     if (serverFd_ < SOCKET_SUCCESS) {
235         LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno);
236         return false;
237     }
238     // allow specified port can be used at once and not wait TIME_WAIT status ending
239     int sockOptVal = 1;
240     if ((setsockopt(serverFd_, SOL_SOCKET, SO_REUSEADDR,
241         reinterpret_cast<char *>(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) {
242         LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno);
243         CloseServerSocket();
244         return false;
245     }
246     if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
247         LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
248         CloseServerSocket();
249         return false;
250     }
251     return BindAndListenTcpWebSocket(port);
252 }
253 
BindAndListenTcpWebSocket(int port)254 bool WebSocketServer::BindAndListenTcpWebSocket(int port)
255 {
256     sockaddr_in addrSin = {};
257     addrSin.sin_family = AF_INET;
258     addrSin.sin_port = htons(port);
259     if (inet_pton(AF_INET, "127.0.0.1", &addrSin.sin_addr) != NET_SUCCESS) {
260         LOGE("BindAndListenTcpWebSocket inet_pton failed, error = %{public}d", errno);
261         return false;
262     }
263     if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&addrSin), sizeof(addrSin)) != SOCKET_SUCCESS ||
264         listen(serverFd_, 1) != SOCKET_SUCCESS) {
265         LOGE("BindAndListenTcpWebSocket bind/listen failed, errno = %{public}d", errno);
266         CloseServerSocket();
267         return false;
268     }
269     serverUp_.store(true);
270     return true;
271 }
272 
273 #if !defined(WINDOWS_PLATFORM)
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)274 bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
275 {
276     if (serverUp_.load()) {
277         LOGI("InitUnixWebSocket websocket has inited");
278         return true;
279     }
280     serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol
281     if (serverFd_ < SOCKET_SUCCESS) {
282         LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno);
283         return false;
284     }
285     // set send and recv timeout
286     if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
287         LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
288         CloseServerSocket();
289         return false;
290     }
291 
292     struct sockaddr_un un;
293     if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
294         LOGE("InitUnixWebSocket memset_s failed");
295         CloseServerSocket();
296         return false;
297     }
298     un.sun_family = AF_UNIX;
299     if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
300         LOGE("InitUnixWebSocket strcpy_s failed");
301         CloseServerSocket();
302         return false;
303     }
304     un.sun_path[0] = '\0';
305     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
306     if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) != SOCKET_SUCCESS) {
307         LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno);
308         CloseServerSocket();
309         return false;
310     }
311     if (listen(serverFd_, 1) != SOCKET_SUCCESS) { // 1: connection num
312         LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno);
313         CloseServerSocket();
314         return false;
315     }
316     serverUp_.store(true);
317     return true;
318 }
319 
InitUnixWebSocket(int socketfd)320 bool WebSocketServer::InitUnixWebSocket(int socketfd)
321 {
322     if (serverUp_.load()) {
323         LOGI("InitUnixWebSocket websocket has inited");
324         return true;
325     }
326     if (socketfd < SOCKET_SUCCESS) {
327         LOGE("InitUnixWebSocket socketfd is invalid");
328         return false;
329     }
330     SetConnectionSocket(socketfd);
331     const int flag = fcntl(socketfd, F_GETFL, 0);
332     if (flag == -1) {
333         LOGE("InitUnixWebSocket get client state is failed, error is %{public}s", strerror(errno));
334         return false;
335     }
336     fcntl(socketfd, F_SETFL, static_cast<size_t>(flag) & ~O_NONBLOCK);
337     serverUp_.store(true);
338     return true;
339 }
340 
ConnectUnixWebSocketBySocketpair()341 bool WebSocketServer::ConnectUnixWebSocketBySocketpair()
342 {
343     if (!MoveToConnectingState()) {
344         return false;
345     }
346 
347     if (!HttpHandShake()) {
348         LOGE("ConnectUnixWebSocket HttpHandShake failed");
349 
350         auto expected = SetConnectionState(ConnectionState::CLOSING);
351         if (expected != ConnectionState::CONNECTING) {
352             LOGE("ConnectUnixWebSocketBySocketpair: violation due to concurrent close and accept: got %{public}d",
353                  EnumToNumber(expected));
354         }
355         CloseConnectionSocket(ConnectionCloseReason::FAIL);
356         return false;
357     }
358 
359     OnNewConnection();
360     return true;
361 }
362 #endif  // WINDOWS_PLATFORM
363 
CloseServerSocket()364 void WebSocketServer::CloseServerSocket()
365 {
366     close(serverFd_);
367     serverFd_ = -1;
368 }
369 
OnNewConnection()370 void WebSocketServer::OnNewConnection()
371 {
372     LOGI("New client connected");
373     if (openCb_) {
374         openCb_();
375     }
376 
377     auto expected = SetConnectionState(ConnectionState::OPEN);
378     if (expected != ConnectionState::CONNECTING) {
379         LOGE("OnNewConnection violation: expected CONNECTING, but got %{public}d",
380              EnumToNumber(expected));
381     }
382 }
383 
SetValidateConnectionCallback(ValidateConnectionCallback cb)384 void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb)
385 {
386     validateCb_ = std::move(cb);
387 }
388 
SetOpenConnectionCallback(OpenConnectionCallback cb)389 void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb)
390 {
391     openCb_ = std::move(cb);
392 }
393 
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const394 bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
395 {
396     // "The server MUST close the connection upon receiving a frame that is not masked."
397     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
398     return wsFrame.mask == 1;
399 }
400 
CreateFrame(bool isLast,FrameType frameType) const401 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const
402 {
403     ServerFrameBuilder builder(isLast, frameType);
404     return builder.Build();
405 }
406 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const407 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
408 {
409     ServerFrameBuilder builder(isLast, frameType);
410     return builder.SetPayload(payload).Build();
411 }
412 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const413 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
414 {
415     ServerFrameBuilder builder(isLast, frameType);
416     return builder.SetPayload(std::move(payload)).Build();
417 }
418 
WaitConnectingStateEnds(ConnectionState connection)419 WebSocketServer::ConnectionState WebSocketServer::WaitConnectingStateEnds(ConnectionState connection)
420 {
421     auto shutdownSocketUnderLock = [this]() {
422         auto fd = GetConnectionSocket();
423         if (fd == -1) {
424             return false;
425         }
426         int err = ShutdownSocket(fd);
427         if (err != 0) {
428             LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
429         }
430         return true;
431     };
432 
433     auto connectionSocketWasShutdown = false;
434     while (connection == ConnectionState::CONNECTING) {
435         if (!connectionSocketWasShutdown) {
436             // Try to shutdown the already accepted connection socket,
437             // otherwise thread can infinitely hang on handshake recv.
438             std::shared_lock lock(GetConnectionMutex());
439             connectionSocketWasShutdown = shutdownSocketUnderLock();
440         }
441 
442         std::this_thread::yield();
443         connection = GetConnectionState();
444     }
445     return connection;
446 }
447 
Close()448 void WebSocketServer::Close()
449 {
450     // Firstly stop accepting new connections.
451     if (!serverUp_.exchange(false)) {
452         return;
453     }
454 
455     int err = ShutdownSocket(serverFd_);
456     if (err != 0) {
457         LOGW("Failed to shutdown server socket, errno = %{public}d", errno);
458     }
459 
460     // If there is a concurrent call to `CloseConnection`, we can immediately close `serverFd_`.
461     // This is because new connections are already prevented, hence no reads of `serverFd_` will be done,
462     // and the connection itself will be closed by someone else.
463     auto connection = GetConnectionState();
464     if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
465         CloseServerSocket();
466         return;
467     }
468 
469     connection = WaitConnectingStateEnds(connection);
470 
471     // Can safely close server socket, as there will be no more new connections attempts.
472     CloseServerSocket();
473     // Must check once again after finished `AcceptNewConnection`.
474     if (connection == ConnectionState::CLOSING || connection == ConnectionState::CLOSED) {
475         return;
476     }
477 
478     // If we reached this point, connection can be `OPEN`, `CLOSING` or `CLOSED`. Try to close it anyway.
479     CloseConnection(CloseStatusCode::SERVER_GO_AWAY);
480 }
481 } // namespace OHOS::ArkCompiler::Toolchain
482