• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15 
16 #include <fcntl.h>
17 #include "common/log_wrapper.h"
18 #include "websocket/frame_builder.h"
19 #include "websocket/handshake_helper.h"
20 #include "websocket/network.h"
21 #include "websocket/server/websocket_server.h"
22 
23 namespace OHOS::ArkCompiler::Toolchain {
DecodeMessage(WebSocketFrame & wsFrame) const24 bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame) const
25 {
26     uint64_t msgLen = wsFrame.payloadLen;
27     if (msgLen == 0) {
28         // receiving empty data is OK
29         return true;
30     }
31     auto& buffer = wsFrame.payload;
32     buffer.resize(msgLen, 0);
33 
34     if (!Recv(connectionFd_, wsFrame.maskingKey, sizeof(wsFrame.maskingKey), 0)) {
35         LOGE("DecodeMessage: Recv maskingKey failed");
36         return false;
37     }
38 
39     if (!Recv(connectionFd_, buffer, 0)) {
40         LOGE("DecodeMessage: Recv message with mask failed");
41         return false;
42     }
43 
44     for (uint64_t i = 0; i < msgLen; i++) {
45         auto j = i % WebSocketFrame::MASK_LEN;
46         buffer[i] = static_cast<uint8_t>(buffer[i]) ^ wsFrame.maskingKey[j];
47     }
48 
49     return true;
50 }
51 
ProtocolUpgrade(const HttpRequest & req)52 bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req)
53 {
54     unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
55     if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) {
56         LOGE("ProtocolUpgrade: failed to encode WebSocket-Key");
57         return false;
58     }
59 
60     ProtocolUpgradeBuilder requestBuilder(encodedKey);
61     if (!Send(connectionFd_, requestBuilder.GetUpgradeMessage(), requestBuilder.GetLength(), 0)) {
62         LOGE("ProtocolUpgrade: Send failed");
63         return false;
64     }
65     return true;
66 }
67 
ResponseInvalidHandShake() const68 bool WebSocketServer::ResponseInvalidHandShake() const
69 {
70     std::string response(BAD_REQUEST_RESPONSE);
71     return Send(connectionFd_, response, 0);
72 }
73 
HttpHandShake()74 bool WebSocketServer::HttpHandShake()
75 {
76     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
77     ssize_t msgLen = 0;
78     while ((msgLen = recv(connectionFd_, msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 && errno == EINTR) {
79         LOGW("HttpHandShake recv failed, errno == EINTR");
80     }
81     if (msgLen <= 0) {
82         LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast<long>(msgLen), errno);
83         return false;
84     }
85     // reduce to received size
86     msgBuf.resize(msgLen);
87 
88     HttpRequest req;
89     if (!HttpRequest::Decode(msgBuf, req)) {
90         LOGE("HttpHandShake: Upgrade failed");
91         return false;
92     }
93     if (validateCb_ && !validateCb_(req)) {
94         LOGE("HttpHandShake: Validation failed");
95         return false;
96     }
97 
98     if (ValidateHandShakeMessage(req)) {
99         return ProtocolUpgrade(req);
100     }
101 
102     LOGE("HttpHandShake: HTTP upgrade parameters failure");
103     if (!ResponseInvalidHandShake()) {
104         LOGE("HttpHandShake: failed to send 'bad request' response");
105     }
106     return false;
107 }
108 
109 /* static */
ValidateHandShakeMessage(const HttpRequest & req)110 bool WebSocketServer::ValidateHandShakeMessage(const HttpRequest& req)
111 {
112     return req.connection.find("Upgrade") != std::string::npos &&
113         req.upgrade.find("websocket") != std::string::npos &&
114         req.version.compare("HTTP/1.1") == 0;
115 }
116 
AcceptNewConnection()117 bool WebSocketServer::AcceptNewConnection()
118 {
119     if (socketState_ == SocketState::UNINITED) {
120         LOGE("AcceptNewConnection failed, websocket not inited");
121         return false;
122     }
123     if (socketState_ == SocketState::CONNECTED) {
124         LOGI("AcceptNewConnection websocket has connected");
125         return true;
126     }
127 
128     if ((connectionFd_ = accept(serverFd_, nullptr, nullptr)) < SOCKET_SUCCESS) {
129         LOGI("AcceptNewConnection accept has exited");
130         return false;
131     }
132 
133     if (!HttpHandShake()) {
134         LOGE("AcceptNewConnection HttpHandShake failed");
135         CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::INITED);
136         return false;
137     }
138     OnNewConnection();
139     return true;
140 }
141 
142 #if !defined(OHOS_PLATFORM)
InitTcpWebSocket(int port,uint32_t timeoutLimit)143 bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit)
144 {
145     if (port < 0) {
146         LOGE("InitTcpWebSocket invalid port");
147         return false;
148     }
149     if (socketState_ != SocketState::UNINITED) {
150         LOGI("InitTcpWebSocket websocket has inited");
151         return true;
152     }
153 #if defined(WINDOWS_PLATFORM)
154     WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2
155     WSADATA wsaData;
156     if (WSAStartup(sockVersion, &wsaData) != 0) {
157         LOGE("InitTcpWebSocket WSA init failed");
158         return false;
159     }
160 #endif
161     serverFd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
162     if (serverFd_ < SOCKET_SUCCESS) {
163         LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno);
164         return false;
165     }
166     // allow specified port can be used at once and not wait TIME_WAIT status ending
167     int sockOptVal = 1;
168     if ((setsockopt(serverFd_, SOL_SOCKET, SO_REUSEADDR,
169         reinterpret_cast<char *>(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) {
170         LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno);
171         CloseServerSocket();
172         return false;
173     }
174     // set send and recv timeout
175     if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
176         LOGE("InitTcpWebSocket SetWebSocketTimeOut failed");
177         CloseServerSocket();
178         return false;
179     }
180     sockaddr_in addrSin = {};
181     addrSin.sin_family = AF_INET;
182     addrSin.sin_port = htons(port);
183     addrSin.sin_addr.s_addr = INADDR_ANY;
184     if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&addrSin), sizeof(addrSin)) != SOCKET_SUCCESS) {
185         LOGE("InitTcpWebSocket bind failed, errno = %{public}d", errno);
186         CloseServerSocket();
187         return false;
188     }
189     if (listen(serverFd_, 1) != SOCKET_SUCCESS) {
190         LOGE("InitTcpWebSocket listen failed, errno = %{public}d", errno);
191         CloseServerSocket();
192         return false;
193     }
194     socketState_ = SocketState::INITED;
195     return true;
196 }
197 #else
InitUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit)198 bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit)
199 {
200     if (socketState_ != SocketState::UNINITED) {
201         LOGI("InitUnixWebSocket websocket has inited");
202         return true;
203     }
204     serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol
205     if (serverFd_ < SOCKET_SUCCESS) {
206         LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno);
207         return false;
208     }
209     // set send and recv timeout
210     if (!SetWebSocketTimeOut(serverFd_, timeoutLimit)) {
211         LOGE("InitUnixWebSocket SetWebSocketTimeOut failed");
212         CloseServerSocket();
213         return false;
214     }
215 
216     struct sockaddr_un un;
217     if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) {
218         LOGE("InitUnixWebSocket memset_s failed");
219         CloseServerSocket();
220         return false;
221     }
222     un.sun_family = AF_UNIX;
223     if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) {
224         LOGE("InitUnixWebSocket strcpy_s failed");
225         CloseServerSocket();
226         return false;
227     }
228     un.sun_path[0] = '\0';
229     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
230     if (bind(serverFd_, reinterpret_cast<struct sockaddr*>(&un), static_cast<int32_t>(len)) != SOCKET_SUCCESS) {
231         LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno);
232         CloseServerSocket();
233         return false;
234     }
235     if (listen(serverFd_, 1) != SOCKET_SUCCESS) { // 1: connection num
236         LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno);
237         CloseServerSocket();
238         return false;
239     }
240     socketState_ = SocketState::INITED;
241     return true;
242 }
243 
InitUnixWebSocket(int socketfd)244 bool WebSocketServer::InitUnixWebSocket(int socketfd)
245 {
246     if (socketState_ != SocketState::UNINITED) {
247         LOGI("InitUnixWebSocket websocket has inited");
248         return true;
249     }
250     if (socketfd < SOCKET_SUCCESS) {
251         LOGE("InitUnixWebSocket socketfd is invalid");
252         socketState_ = SocketState::UNINITED;
253         return false;
254     }
255     connectionFd_ = socketfd;
256     int flag = fcntl(connectionFd_, F_GETFL, 0);
257     if (flag == -1) {
258         LOGE("InitUnixWebSocket get client state is failed");
259         return false;
260     }
261     fcntl(connectionFd_, F_SETFL, static_cast<size_t>(flag) & ~O_NONBLOCK);
262     socketState_ = SocketState::INITED;
263     return true;
264 }
265 
ConnectUnixWebSocketBySocketpair()266 bool WebSocketServer::ConnectUnixWebSocketBySocketpair()
267 {
268     if (socketState_ == SocketState::UNINITED) {
269         LOGE("ConnectUnixWebSocket failed, websocket not inited");
270         return false;
271     }
272     if (socketState_ == SocketState::CONNECTED) {
273         LOGI("ConnectUnixWebSocket websocket has connected");
274         return true;
275     }
276 
277     if (!HttpHandShake()) {
278         LOGE("ConnectUnixWebSocket HttpHandShake failed");
279         CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::UNINITED);
280         return false;
281     }
282     socketState_ = SocketState::CONNECTED;
283     return true;
284 }
285 #endif
286 
CloseServerSocket()287 void WebSocketServer::CloseServerSocket()
288 {
289     close(serverFd_);
290     serverFd_ = -1;
291     socketState_ = SocketState::UNINITED;
292 }
293 
OnNewConnection()294 void WebSocketServer::OnNewConnection()
295 {
296     LOGI("New client connected");
297     socketState_ = SocketState::CONNECTED;
298     if (openCb_) {
299         openCb_();
300     }
301 }
302 
SetValidateConnectionCallback(ValidateConnectionCallback cb)303 void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb)
304 {
305     validateCb_ = std::move(cb);
306 }
307 
SetOpenConnectionCallback(OpenConnectionCallback cb)308 void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb)
309 {
310     openCb_ = std::move(cb);
311 }
312 
ValidateIncomingFrame(const WebSocketFrame & wsFrame)313 bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame)
314 {
315     // "The server MUST close the connection upon receiving a frame that is not masked."
316     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
317     return wsFrame.mask == 1;
318 }
319 
CreateFrame(bool isLast,FrameType frameType) const320 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const
321 {
322     ServerFrameBuilder builder(isLast, frameType);
323     return builder.Build();
324 }
325 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const326 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
327 {
328     ServerFrameBuilder builder(isLast, frameType);
329     return builder.SetPayload(payload).Build();
330 }
331 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const332 std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
333 {
334     ServerFrameBuilder builder(isLast, frameType);
335     return builder.SetPayload(std::move(payload)).Build();
336 }
337 
Close()338 void WebSocketServer::Close()
339 {
340     if (socketState_ == SocketState::UNINITED) {
341         return;
342     }
343     if (socketState_ == SocketState::CONNECTED) {
344         CloseConnection(CloseStatusCode::SERVER_GO_AWAY, SocketState::INITED);
345     }
346     usleep(10000); // 10000: time for websocket to enter the accept
347 #if defined(OHOS_PLATFORM)
348     shutdown(serverFd_, SHUT_RDWR);
349 #endif
350     CloseServerSocket();
351 }
352 } // namespace OHOS::ArkCompiler::Toolchain
353