• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <sys/socket.h>
17 #include <sys/time.h>
18 #include <arpa/inet.h>
19 #include <sys/un.h>
20 #include <unistd.h>
21 
22 #include "common/log_wrapper.h"
23 #include "websocket/frame_builder.h"
24 #include "websocket/handshake_helper.h"
25 #include "websocket/network.h"
26 #include "websocket/string_utils.h"
27 #include "websocket/client/websocket_client.h"
28 
29 namespace OHOS::ArkCompiler::Toolchain {
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)30 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
31 {
32     if (socketState_ != SocketState::UNINITED) {
33         LOGE("InitToolchainWebSocketForPort::client has inited.");
34         return true;
35     }
36 
37     connectionFd_ = socket(AF_INET, SOCK_STREAM, 0);
38     if (connectionFd_ < SOCKET_SUCCESS) {
39         LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
40             errno, strerror(errno));
41         return false;
42     }
43 
44     // set send and recv timeout limit
45     if (!SetWebSocketTimeOut(connectionFd_, timeoutLimit)) {
46         LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
47             errno, strerror(errno));
48         CloseConnectionSocketOnFail();
49         return false;
50     }
51 
52     sockaddr_in clientAddr;
53     if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
54         LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
55             errno, strerror(errno));
56         CloseConnectionSocketOnFail();
57         return false;
58     }
59     clientAddr.sin_family = AF_INET;
60     clientAddr.sin_port = htons(port);
61     int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
62     if (ret != NET_SUCCESS) {
63         LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
64             errno, strerror(errno));
65         CloseConnectionSocketOnFail();
66         return false;
67     }
68 
69     ret = connect(connectionFd_, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
70     if (ret != SOCKET_SUCCESS) {
71         LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
72             errno, strerror(errno));
73         CloseConnectionSocketOnFail();
74         return false;
75     }
76     socketState_ = SocketState::INITED;
77     LOGI("InitToolchainWebSocketForPort::client connect success.");
78     return true;
79 }
80 
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)81 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
82 {
83     if (socketState_ != SocketState::UNINITED) {
84         LOGE("InitToolchainWebSocketForSockName::client has inited.");
85         return true;
86     }
87 
88     connectionFd_ = socket(AF_UNIX, SOCK_STREAM, 0);
89     if (connectionFd_ < SOCKET_SUCCESS) {
90         LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
91             errno, strerror(errno));
92         return false;
93     }
94 
95     // set send and recv timeout limit
96     if (!SetWebSocketTimeOut(connectionFd_, timeoutLimit)) {
97         LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
98             "error = %{public}d, desc = %{public}s",
99             errno, strerror(errno));
100         CloseConnectionSocketOnFail();
101         return false;
102     }
103 
104     struct sockaddr_un serverAddr;
105     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
106         LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
107             "error = %{public}d, desc = %{public}s",
108             errno, strerror(errno));
109         CloseConnectionSocketOnFail();
110         return false;
111     }
112     serverAddr.sun_family = AF_UNIX;
113     if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
114         LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
115             "error = %{public}d, desc = %{public}s",
116             errno, strerror(errno));
117         CloseConnectionSocketOnFail();
118         return false;
119     }
120     serverAddr.sun_path[0] = '\0';
121 
122     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
123     int ret = connect(connectionFd_, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
124     if (ret != SOCKET_SUCCESS) {
125         LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
126             errno, strerror(errno));
127         CloseConnectionSocketOnFail();
128         return false;
129     }
130     socketState_ = SocketState::INITED;
131     LOGI("InitToolchainWebSocketForSockName::client connect success.");
132     return true;
133 }
134 
ClientSendWSUpgradeReq()135 bool WebSocketClient::ClientSendWSUpgradeReq()
136 {
137     if (socketState_ == SocketState::UNINITED) {
138         LOGE("ClientSendWSUpgradeReq::client has not inited.");
139         return false;
140     }
141     if (socketState_ == SocketState::CONNECTED) {
142         LOGE("ClientSendWSUpgradeReq::client has connected.");
143         return true;
144     }
145 
146     // length without null-terminator
147     if (!Send(connectionFd_, CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) {
148         LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
149             errno, strerror(errno));
150         CloseConnectionSocketOnFail();
151         return false;
152     }
153     LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
154     return true;
155 }
156 
ClientRecvWSUpgradeRsp()157 bool WebSocketClient::ClientRecvWSUpgradeRsp()
158 {
159     if (socketState_ == SocketState::UNINITED) {
160         LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
161         return false;
162     }
163     if (socketState_ == SocketState::CONNECTED) {
164         LOGE("ClientRecvWSUpgradeRsp::client has connected.");
165         return true;
166     }
167 
168     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
169     ssize_t msgLen = 0;
170     while ((msgLen = recv(connectionFd_, msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 && errno == EINTR) {
171         LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno == EINTR");
172     }
173     if (msgLen <= 0) {
174         LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
175             errno, strerror(errno));
176         CloseConnectionSocketOnFail();
177         return false;
178     }
179     // reduce to received size
180     msgBuf.resize(msgLen);
181 
182     HttpResponse response;
183     if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
184         LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
185         CloseConnectionSocketOnFail();
186         return false;
187     }
188 
189     socketState_ = SocketState::CONNECTED;
190     LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
191     return true;
192 }
193 
GetSocketStateString()194 std::string WebSocketClient::GetSocketStateString()
195 {
196     return std::string(SOCKET_STATE_NAMES[EnumToNumber(socketState_.load())]);
197 }
198 
199 /* static */
ValidateServerHandShake(HttpResponse & response)200 bool WebSocketClient::ValidateServerHandShake(HttpResponse& response)
201 {
202     // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
203     if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
204         return false;
205     }
206     ToLowerCase(response.upgrade);
207     if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
208         return false;
209     }
210     ToLowerCase(response.connection);
211     if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
212         return false;
213     }
214 
215     // The same WebSocket-Key is used for all connections
216     // - must either use a randomly-selected, as required by spec or do this calculation statically.
217     unsigned char expectedSecWebSocketAccept_[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
218     if (!WebSocketKeyEncoder::EncodeKey(DEFAULT_WEB_SOCKET_KEY, expectedSecWebSocketAccept_)) {
219         LOGE("ValidateServerHandShake::client failed to generate expected Sec-WebSocket-Accept token");
220         return false;
221     }
222 
223     Trim(response.secWebSocketAccept);
224     if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
225         response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept_)) != 0) {
226         return false;
227     }
228 
229     // may support two remaining checks
230     return true;
231 }
232 
DecodeMessage(WebSocketFrame & wsFrame) const233 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
234 {
235     uint64_t msgLen = wsFrame.payloadLen;
236     if (msgLen == 0) {
237         // receiving empty data is OK
238         return true;
239     }
240     auto& buffer = wsFrame.payload;
241     buffer.resize(msgLen, 0);
242 
243     if (!Recv(connectionFd_, buffer, 0)) {
244         LOGE("DecodeMessage: Recv message without mask failed");
245         return false;
246     }
247 
248     return true;
249 }
250 
Close()251 void WebSocketClient::Close()
252 {
253     if (socketState_ == SocketState::CONNECTED) {
254         CloseConnection(CloseStatusCode::SERVER_GO_AWAY, SocketState::UNINITED);
255     }
256 }
257 
CloseConnectionSocketOnFail()258 void WebSocketClient::CloseConnectionSocketOnFail()
259 {
260     CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::UNINITED);
261 }
262 
ValidateIncomingFrame(const WebSocketFrame & wsFrame)263 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame)
264 {
265     // "A server MUST NOT mask any frames that it sends to the client."
266     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
267     return wsFrame.mask == 0;
268 }
269 
CreateFrame(bool isLast,FrameType frameType) const270 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
271 {
272     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
273     return builder.Build();
274 }
275 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const276 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
277 {
278     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
279     return builder.SetPayload(payload).Build();
280 }
281 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const282 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
283 {
284     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
285     return builder.SetPayload(std::move(payload)).Build();
286 }
287 }  // namespace OHOS::ArkCompiler::Toolchain
288