1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <arpa/inet.h>
17
18 #include "common/log_wrapper.h"
19 #include "frame_builder.h"
20 #include "handshake_helper.h"
21 #include "string_utils.h"
22 #include "client/websocket_client.h"
23
24 namespace OHOS::ArkCompiler::Toolchain {
ValidateServerHandShake(HttpResponse & response)25 static bool ValidateServerHandShake(HttpResponse& response)
26 {
27 static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101";
28 static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket";
29 static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade";
30 // NB! `defaultWebSocketKey` must match "Sec-WebSocket-Key" used in `WebSocketClient`.
31 static constexpr unsigned char defaultWebSocketKey[] = "64b4B+s5JDlgkdg7NekJ+g==";
32
33 // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
34 if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
35 return false;
36 }
37 ToLowerCase(response.upgrade);
38 if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
39 return false;
40 }
41 ToLowerCase(response.connection);
42 if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
43 return false;
44 }
45
46 // The same WebSocket-Key is used for all connections
47 // - must either use a randomly-selected, as required by spec or do this calculation statically.
48 unsigned char expectedSecWebSocketAccept[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
49 if (!WebSocketKeyEncoder::EncodeKey(defaultWebSocketKey, expectedSecWebSocketAccept)) {
50 LOGE("ValidateServerHandShake failed to generate expected Sec-WebSocket-Accept token");
51 return false;
52 }
53
54 Trim(response.secWebSocketAccept);
55 if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
56 response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept)) != 0) {
57 return false;
58 }
59
60 // may support two remaining checks
61 return true;
62 }
63
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)64 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
65 {
66 if (GetConnectionState() != ConnectionState::CLOSED) {
67 LOGE("InitToolchainWebSocketForPort::client has inited.");
68 return true;
69 }
70
71 int connection = socket(AF_INET, SOCK_STREAM, 0);
72 if (connection < SOCKET_SUCCESS) {
73 LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
74 errno, strerror(errno));
75 return false;
76 }
77 SetConnectionSocket(connection);
78
79 // set send and recv timeout limit
80 if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
81 LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
82 errno, strerror(errno));
83 CloseOnInitFailure();
84 return false;
85 }
86
87 sockaddr_in clientAddr;
88 if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
89 LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
90 errno, strerror(errno));
91 CloseOnInitFailure();
92 return false;
93 }
94 clientAddr.sin_family = AF_INET;
95 clientAddr.sin_port = htons(port);
96 int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
97 if (ret != NET_SUCCESS) {
98 LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
99 errno, strerror(errno));
100 CloseOnInitFailure();
101 return false;
102 }
103
104 ret = connect(connection, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
105 if (ret != SOCKET_SUCCESS) {
106 LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
107 errno, strerror(errno));
108 CloseOnInitFailure();
109 return false;
110 }
111 SetConnectionState(ConnectionState::CONNECTING);
112 LOGI("InitToolchainWebSocketForPort::client connect success.");
113 return true;
114 }
115
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)116 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
117 {
118 if (GetConnectionState() != ConnectionState::CLOSED) {
119 LOGE("InitToolchainWebSocketForSockName::client has inited.");
120 return true;
121 }
122
123 int connection = socket(AF_UNIX, SOCK_STREAM, 0);
124 if (connection < SOCKET_SUCCESS) {
125 LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
126 errno, strerror(errno));
127 return false;
128 }
129 SetConnectionSocket(connection);
130
131 // set send and recv timeout limit
132 if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
133 LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
134 "error = %{public}d, desc = %{public}s",
135 errno, strerror(errno));
136 CloseOnInitFailure();
137 return false;
138 }
139
140 struct sockaddr_un serverAddr;
141 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
142 LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
143 "error = %{public}d, desc = %{public}s",
144 errno, strerror(errno));
145 CloseOnInitFailure();
146 return false;
147 }
148 serverAddr.sun_family = AF_UNIX;
149 if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
150 LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
151 "error = %{public}d, desc = %{public}s",
152 errno, strerror(errno));
153 CloseOnInitFailure();
154 return false;
155 }
156 serverAddr.sun_path[0] = '\0';
157
158 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
159 int ret = connect(connection, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
160 if (ret != SOCKET_SUCCESS) {
161 LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
162 errno, strerror(errno));
163 CloseOnInitFailure();
164 return false;
165 }
166 SetConnectionState(ConnectionState::CONNECTING);
167 LOGI("InitToolchainWebSocketForSockName::client connect success.");
168 return true;
169 }
170
ClientSendWSUpgradeReq()171 bool WebSocketClient::ClientSendWSUpgradeReq()
172 {
173 auto state = GetConnectionState();
174 if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
175 LOGE("ClientSendWSUpgradeReq::client has not inited.");
176 return false;
177 }
178 if (state == ConnectionState::OPEN) {
179 LOGE("ClientSendWSUpgradeReq::client has connected.");
180 return true;
181 }
182
183 // length without null-terminator
184 if (!Send(GetConnectionSocket(), CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) {
185 LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
186 errno, strerror(errno));
187 CloseOnInitFailure();
188 return false;
189 }
190 LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
191 return true;
192 }
193
ClientRecvWSUpgradeRsp()194 bool WebSocketClient::ClientRecvWSUpgradeRsp()
195 {
196 auto state = GetConnectionState();
197 if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
198 LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
199 return false;
200 }
201 if (state == ConnectionState::OPEN) {
202 LOGE("ClientRecvWSUpgradeRsp::client has connected.");
203 return true;
204 }
205
206 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
207 ssize_t msgLen = 0;
208 while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
209 (errno == EINTR || errno == EAGAIN)) {
210 LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno = %{public}d", errno);
211 }
212 if (msgLen <= 0) {
213 LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
214 errno, strerror(errno));
215 CloseOnInitFailure();
216 return false;
217 }
218 // reduce to received size
219 msgBuf.resize(msgLen);
220
221 HttpResponse response;
222 if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
223 LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
224 CloseOnInitFailure();
225 return false;
226 }
227
228 SetConnectionState(ConnectionState::OPEN);
229 LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
230 return true;
231 }
232
GetSocketStateString()233 std::string WebSocketClient::GetSocketStateString()
234 {
235 return std::string(SOCKET_STATE_NAMES[EnumToNumber(GetConnectionState())]);
236 }
237
DecodeMessage(WebSocketFrame & wsFrame) const238 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
239 {
240 uint64_t msgLen = wsFrame.payloadLen;
241 if (msgLen == 0) {
242 // receiving empty data is OK
243 return true;
244 }
245 auto& buffer = wsFrame.payload;
246 buffer.resize(msgLen, 0);
247
248 if (!RecvUnderLock(buffer)) {
249 LOGE("DecodeMessage: Recv message without mask failed");
250 return false;
251 }
252
253 return true;
254 }
255
Close()256 void WebSocketClient::Close()
257 {
258 if (!CloseConnection(CloseStatusCode::SERVER_GO_AWAY)) {
259 LOGW("Failed to close WebSocketClient");
260 }
261 }
262
CloseOnInitFailure()263 void WebSocketClient::CloseOnInitFailure()
264 {
265 // Must be in `CONNECTING` or `CLOSED` state.
266 auto expected = ConnectionState::CONNECTING;
267 if (!CompareExchangeConnectionState(expected, ConnectionState::CLOSING) &&
268 expected != ConnectionState::CLOSED) {
269 LOGE("CloseOnInitFailure violation: must be either CONNECTING or CLOSED, but got %{public}d",
270 EnumToNumber(expected));
271 }
272 CloseConnectionSocket(ConnectionCloseReason::FAIL);
273 }
274
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const275 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
276 {
277 // "A server MUST NOT mask any frames that it sends to the client."
278 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
279 return wsFrame.mask == 0;
280 }
281
CreateFrame(bool isLast,FrameType frameType) const282 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
283 {
284 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
285 return builder.Build();
286 }
287
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const288 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
289 {
290 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
291 return builder.SetPayload(payload).Build();
292 }
293
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const294 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
295 {
296 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
297 return builder.SetPayload(std::move(payload)).Build();
298 }
299 } // namespace OHOS::ArkCompiler::Toolchain
300