1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <sys/socket.h>
17 #include <sys/time.h>
18 #include <arpa/inet.h>
19 #include <sys/un.h>
20 #include <unistd.h>
21
22 #include "common/log_wrapper.h"
23 #include "frame_builder.h"
24 #include "handshake_helper.h"
25 #include "network.h"
26 #include "string_utils.h"
27 #include "client/websocket_client.h"
28
29 namespace OHOS::ArkCompiler::Toolchain {
ValidateServerHandShake(HttpResponse & response)30 static bool ValidateServerHandShake(HttpResponse& response)
31 {
32 static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101";
33 static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket";
34 static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade";
35 // NB! `defaultWebSocketKey` must match "Sec-WebSocket-Key" used in `WebSocketClient`.
36 static constexpr unsigned char defaultWebSocketKey[] = "64b4B+s5JDlgkdg7NekJ+g==";
37
38 // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
39 if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
40 return false;
41 }
42 ToLowerCase(response.upgrade);
43 if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
44 return false;
45 }
46 ToLowerCase(response.connection);
47 if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
48 return false;
49 }
50
51 // The same WebSocket-Key is used for all connections
52 // - must either use a randomly-selected, as required by spec or do this calculation statically.
53 unsigned char expectedSecWebSocketAccept[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
54 if (!WebSocketKeyEncoder::EncodeKey(defaultWebSocketKey, expectedSecWebSocketAccept)) {
55 LOGE("ValidateServerHandShake failed to generate expected Sec-WebSocket-Accept token");
56 return false;
57 }
58
59 Trim(response.secWebSocketAccept);
60 if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
61 response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept)) != 0) {
62 return false;
63 }
64
65 // may support two remaining checks
66 return true;
67 }
68
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)69 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
70 {
71 if (GetConnectionState() != ConnectionState::CLOSED) {
72 LOGE("InitToolchainWebSocketForPort::client has inited.");
73 return true;
74 }
75
76 int connection = socket(AF_INET, SOCK_STREAM, 0);
77 if (connection < SOCKET_SUCCESS) {
78 LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
79 errno, strerror(errno));
80 return false;
81 }
82 SetConnectionSocket(connection);
83
84 // set send and recv timeout limit
85 if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
86 LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
87 errno, strerror(errno));
88 CloseOnInitFailure();
89 return false;
90 }
91
92 sockaddr_in clientAddr;
93 if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
94 LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
95 errno, strerror(errno));
96 CloseOnInitFailure();
97 return false;
98 }
99 clientAddr.sin_family = AF_INET;
100 clientAddr.sin_port = htons(port);
101 int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
102 if (ret != NET_SUCCESS) {
103 LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
104 errno, strerror(errno));
105 CloseOnInitFailure();
106 return false;
107 }
108
109 ret = connect(connection, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
110 if (ret != SOCKET_SUCCESS) {
111 LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
112 errno, strerror(errno));
113 CloseOnInitFailure();
114 return false;
115 }
116 SetConnectionState(ConnectionState::CONNECTING);
117 LOGI("InitToolchainWebSocketForPort::client connect success.");
118 return true;
119 }
120
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)121 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
122 {
123 if (GetConnectionState() != ConnectionState::CLOSED) {
124 LOGE("InitToolchainWebSocketForSockName::client has inited.");
125 return true;
126 }
127
128 int connection = socket(AF_UNIX, SOCK_STREAM, 0);
129 if (connection < SOCKET_SUCCESS) {
130 LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
131 errno, strerror(errno));
132 return false;
133 }
134 SetConnectionSocket(connection);
135
136 // set send and recv timeout limit
137 if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
138 LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
139 "error = %{public}d, desc = %{public}s",
140 errno, strerror(errno));
141 CloseOnInitFailure();
142 return false;
143 }
144
145 struct sockaddr_un serverAddr;
146 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
147 LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
148 "error = %{public}d, desc = %{public}s",
149 errno, strerror(errno));
150 CloseOnInitFailure();
151 return false;
152 }
153 serverAddr.sun_family = AF_UNIX;
154 if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
155 LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
156 "error = %{public}d, desc = %{public}s",
157 errno, strerror(errno));
158 CloseOnInitFailure();
159 return false;
160 }
161 serverAddr.sun_path[0] = '\0';
162
163 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
164 int ret = connect(connection, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
165 if (ret != SOCKET_SUCCESS) {
166 LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
167 errno, strerror(errno));
168 CloseOnInitFailure();
169 return false;
170 }
171 SetConnectionState(ConnectionState::CONNECTING);
172 LOGI("InitToolchainWebSocketForSockName::client connect success.");
173 return true;
174 }
175
ClientSendWSUpgradeReq()176 bool WebSocketClient::ClientSendWSUpgradeReq()
177 {
178 auto state = GetConnectionState();
179 if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
180 LOGE("ClientSendWSUpgradeReq::client has not inited.");
181 return false;
182 }
183 if (state == ConnectionState::OPEN) {
184 LOGE("ClientSendWSUpgradeReq::client has connected.");
185 return true;
186 }
187
188 // length without null-terminator
189 if (!Send(GetConnectionSocket(), CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) {
190 LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
191 errno, strerror(errno));
192 CloseOnInitFailure();
193 return false;
194 }
195 LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
196 return true;
197 }
198
ClientRecvWSUpgradeRsp()199 bool WebSocketClient::ClientRecvWSUpgradeRsp()
200 {
201 auto state = GetConnectionState();
202 if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
203 LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
204 return false;
205 }
206 if (state == ConnectionState::OPEN) {
207 LOGE("ClientRecvWSUpgradeRsp::client has connected.");
208 return true;
209 }
210
211 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
212 ssize_t msgLen = 0;
213 while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
214 (errno == EINTR || errno == EAGAIN)) {
215 LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno = %{public}d", errno);
216 }
217 if (msgLen <= 0) {
218 LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
219 errno, strerror(errno));
220 CloseOnInitFailure();
221 return false;
222 }
223 // reduce to received size
224 msgBuf.resize(msgLen);
225
226 HttpResponse response;
227 if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
228 LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
229 CloseOnInitFailure();
230 return false;
231 }
232
233 SetConnectionState(ConnectionState::OPEN);
234 LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
235 return true;
236 }
237
GetSocketStateString()238 std::string WebSocketClient::GetSocketStateString()
239 {
240 return std::string(SOCKET_STATE_NAMES[EnumToNumber(GetConnectionState())]);
241 }
242
DecodeMessage(WebSocketFrame & wsFrame) const243 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
244 {
245 uint64_t msgLen = wsFrame.payloadLen;
246 if (msgLen == 0) {
247 // receiving empty data is OK
248 return true;
249 }
250 auto& buffer = wsFrame.payload;
251 buffer.resize(msgLen, 0);
252
253 if (!RecvUnderLock(buffer)) {
254 LOGE("DecodeMessage: Recv message without mask failed");
255 return false;
256 }
257
258 return true;
259 }
260
Close()261 void WebSocketClient::Close()
262 {
263 if (!CloseConnection(CloseStatusCode::SERVER_GO_AWAY)) {
264 LOGW("Failed to close WebSocketClient");
265 }
266 }
267
CloseOnInitFailure()268 void WebSocketClient::CloseOnInitFailure()
269 {
270 // Must be in `CONNECTING` or `CLOSED` state.
271 auto expected = ConnectionState::CONNECTING;
272 if (!CompareExchangeConnectionState(expected, ConnectionState::CLOSING) &&
273 expected != ConnectionState::CLOSED) {
274 LOGE("CloseOnInitFailure violation: must be either CONNECTING or CLOSED, but got %{public}d",
275 EnumToNumber(expected));
276 }
277 CloseConnectionSocket(ConnectionCloseReason::FAIL);
278 }
279
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const280 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
281 {
282 // "A server MUST NOT mask any frames that it sends to the client."
283 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
284 return wsFrame.mask == 0;
285 }
286
CreateFrame(bool isLast,FrameType frameType) const287 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
288 {
289 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
290 return builder.Build();
291 }
292
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const293 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
294 {
295 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
296 return builder.SetPayload(payload).Build();
297 }
298
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const299 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
300 {
301 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
302 return builder.SetPayload(std::move(payload)).Build();
303 }
304 } // namespace OHOS::ArkCompiler::Toolchain
305