1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <sys/socket.h>
17 #include <sys/time.h>
18 #include <arpa/inet.h>
19 #include <sys/un.h>
20 #include <unistd.h>
21
22 #include "common/log_wrapper.h"
23 #include "websocket/frame_builder.h"
24 #include "websocket/handshake_helper.h"
25 #include "websocket/network.h"
26 #include "websocket/string_utils.h"
27 #include "websocket/client/websocket_client.h"
28
29 namespace OHOS::ArkCompiler::Toolchain {
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)30 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
31 {
32 if (socketState_ != SocketState::UNINITED) {
33 LOGE("InitToolchainWebSocketForPort::client has inited.");
34 return true;
35 }
36
37 connectionFd_ = socket(AF_INET, SOCK_STREAM, 0);
38 if (connectionFd_ < SOCKET_SUCCESS) {
39 LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
40 errno, strerror(errno));
41 return false;
42 }
43
44 // set send and recv timeout limit
45 if (!SetWebSocketTimeOut(connectionFd_, timeoutLimit)) {
46 LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
47 errno, strerror(errno));
48 CloseConnectionSocketOnFail();
49 return false;
50 }
51
52 sockaddr_in clientAddr;
53 if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
54 LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
55 errno, strerror(errno));
56 CloseConnectionSocketOnFail();
57 return false;
58 }
59 clientAddr.sin_family = AF_INET;
60 clientAddr.sin_port = htons(port);
61 int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
62 if (ret != NET_SUCCESS) {
63 LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
64 errno, strerror(errno));
65 CloseConnectionSocketOnFail();
66 return false;
67 }
68
69 ret = connect(connectionFd_, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
70 if (ret != SOCKET_SUCCESS) {
71 LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
72 errno, strerror(errno));
73 CloseConnectionSocketOnFail();
74 return false;
75 }
76 socketState_ = SocketState::INITED;
77 LOGI("InitToolchainWebSocketForPort::client connect success.");
78 return true;
79 }
80
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)81 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
82 {
83 if (socketState_ != SocketState::UNINITED) {
84 LOGE("InitToolchainWebSocketForSockName::client has inited.");
85 return true;
86 }
87
88 connectionFd_ = socket(AF_UNIX, SOCK_STREAM, 0);
89 if (connectionFd_ < SOCKET_SUCCESS) {
90 LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
91 errno, strerror(errno));
92 return false;
93 }
94
95 // set send and recv timeout limit
96 if (!SetWebSocketTimeOut(connectionFd_, timeoutLimit)) {
97 LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
98 "error = %{public}d, desc = %{public}s",
99 errno, strerror(errno));
100 CloseConnectionSocketOnFail();
101 return false;
102 }
103
104 struct sockaddr_un serverAddr;
105 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
106 LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
107 "error = %{public}d, desc = %{public}s",
108 errno, strerror(errno));
109 CloseConnectionSocketOnFail();
110 return false;
111 }
112 serverAddr.sun_family = AF_UNIX;
113 if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
114 LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
115 "error = %{public}d, desc = %{public}s",
116 errno, strerror(errno));
117 CloseConnectionSocketOnFail();
118 return false;
119 }
120 serverAddr.sun_path[0] = '\0';
121
122 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
123 int ret = connect(connectionFd_, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
124 if (ret != SOCKET_SUCCESS) {
125 LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
126 errno, strerror(errno));
127 CloseConnectionSocketOnFail();
128 return false;
129 }
130 socketState_ = SocketState::INITED;
131 LOGI("InitToolchainWebSocketForSockName::client connect success.");
132 return true;
133 }
134
ClientSendWSUpgradeReq()135 bool WebSocketClient::ClientSendWSUpgradeReq()
136 {
137 if (socketState_ == SocketState::UNINITED) {
138 LOGE("ClientSendWSUpgradeReq::client has not inited.");
139 return false;
140 }
141 if (socketState_ == SocketState::CONNECTED) {
142 LOGE("ClientSendWSUpgradeReq::client has connected.");
143 return true;
144 }
145
146 // length without null-terminator
147 if (!Send(connectionFd_, CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) {
148 LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
149 errno, strerror(errno));
150 CloseConnectionSocketOnFail();
151 return false;
152 }
153 LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
154 return true;
155 }
156
ClientRecvWSUpgradeRsp()157 bool WebSocketClient::ClientRecvWSUpgradeRsp()
158 {
159 if (socketState_ == SocketState::UNINITED) {
160 LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
161 return false;
162 }
163 if (socketState_ == SocketState::CONNECTED) {
164 LOGE("ClientRecvWSUpgradeRsp::client has connected.");
165 return true;
166 }
167
168 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
169 ssize_t msgLen = 0;
170 while ((msgLen = recv(connectionFd_, msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 && errno == EINTR) {
171 LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno == EINTR");
172 }
173 if (msgLen <= 0) {
174 LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
175 errno, strerror(errno));
176 CloseConnectionSocketOnFail();
177 return false;
178 }
179 // reduce to received size
180 msgBuf.resize(msgLen);
181
182 HttpResponse response;
183 if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
184 LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
185 CloseConnectionSocketOnFail();
186 return false;
187 }
188
189 socketState_ = SocketState::CONNECTED;
190 LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
191 return true;
192 }
193
GetSocketStateString()194 std::string WebSocketClient::GetSocketStateString()
195 {
196 return std::string(SOCKET_STATE_NAMES[EnumToNumber(socketState_.load())]);
197 }
198
199 /* static */
ValidateServerHandShake(HttpResponse & response)200 bool WebSocketClient::ValidateServerHandShake(HttpResponse& response)
201 {
202 // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
203 if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
204 return false;
205 }
206 ToLowerCase(response.upgrade);
207 if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
208 return false;
209 }
210 ToLowerCase(response.connection);
211 if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
212 return false;
213 }
214
215 // The same WebSocket-Key is used for all connections
216 // - must either use a randomly-selected, as required by spec or do this calculation statically.
217 unsigned char expectedSecWebSocketAccept_[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
218 if (!WebSocketKeyEncoder::EncodeKey(DEFAULT_WEB_SOCKET_KEY, expectedSecWebSocketAccept_)) {
219 LOGE("ValidateServerHandShake::client failed to generate expected Sec-WebSocket-Accept token");
220 return false;
221 }
222
223 Trim(response.secWebSocketAccept);
224 if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
225 response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept_)) != 0) {
226 return false;
227 }
228
229 // may support two remaining checks
230 return true;
231 }
232
DecodeMessage(WebSocketFrame & wsFrame) const233 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
234 {
235 uint64_t msgLen = wsFrame.payloadLen;
236 if (msgLen == 0) {
237 // receiving empty data is OK
238 return true;
239 }
240 auto& buffer = wsFrame.payload;
241 buffer.resize(msgLen, 0);
242
243 if (!Recv(connectionFd_, buffer, 0)) {
244 LOGE("DecodeMessage: Recv message without mask failed");
245 return false;
246 }
247
248 return true;
249 }
250
Close()251 void WebSocketClient::Close()
252 {
253 if (socketState_ == SocketState::CONNECTED) {
254 CloseConnection(CloseStatusCode::SERVER_GO_AWAY, SocketState::UNINITED);
255 }
256 }
257
CloseConnectionSocketOnFail()258 void WebSocketClient::CloseConnectionSocketOnFail()
259 {
260 CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::UNINITED);
261 }
262
ValidateIncomingFrame(const WebSocketFrame & wsFrame)263 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame)
264 {
265 // "A server MUST NOT mask any frames that it sends to the client."
266 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
267 return wsFrame.mask == 0;
268 }
269
CreateFrame(bool isLast,FrameType frameType) const270 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
271 {
272 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
273 return builder.Build();
274 }
275
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const276 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
277 {
278 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
279 return builder.SetPayload(payload).Build();
280 }
281
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const282 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
283 {
284 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
285 return builder.SetPayload(std::move(payload)).Build();
286 }
287 } // namespace OHOS::ArkCompiler::Toolchain
288