1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <arpa/inet.h>
17
18 #include "common/log_wrapper.h"
19 #include "frame_builder.h"
20 #include "handshake_helper.h"
21 #include "string_utils.h"
22 #include "client/websocket_client.h"
23
24 namespace OHOS::ArkCompiler::Toolchain {
ValidateServerHandShake(HttpResponse & response)25 bool WebSocketClient::ValidateServerHandShake(HttpResponse& response)
26 {
27 static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101";
28 static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket";
29 static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade";
30
31 // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
32 if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
33 return false;
34 }
35 ToLowerCase(response.upgrade);
36 if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
37 return false;
38 }
39 ToLowerCase(response.connection);
40 if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
41 return false;
42 }
43
44 // The same WebSocket-Key is used for all connections
45 // - must either use a randomly-selected, as required by spec or do this calculation statically.
46 unsigned char expectedSecWebSocketAccept[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
47 if (!WebSocketKeyEncoder::EncodeKey(secWebSocketKey_, expectedSecWebSocketAccept)) {
48 LOGE("ValidateServerHandShake failed to generate expected Sec-WebSocket-Accept token");
49 return false;
50 }
51
52 Trim(response.secWebSocketAccept);
53 if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
54 response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept)) != 0) {
55 return false;
56 }
57
58 // may support two remaining checks
59 return true;
60 }
61
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)62 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
63 {
64 if (GetConnectionState() != ConnectionState::CLOSED) {
65 LOGE("InitToolchainWebSocketForPort::client has inited.");
66 return true;
67 }
68
69 int connection = socket(AF_INET, SOCK_STREAM, 0);
70 if (connection < SOCKET_SUCCESS) {
71 LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
72 errno, strerror(errno));
73 return false;
74 }
75 SetConnectionSocket(connection);
76
77 // set send and recv timeout limit
78 if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
79 LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
80 errno, strerror(errno));
81 CloseOnInitFailure();
82 return false;
83 }
84
85 sockaddr_in clientAddr;
86 if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
87 LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
88 errno, strerror(errno));
89 CloseOnInitFailure();
90 return false;
91 }
92 clientAddr.sin_family = AF_INET;
93 clientAddr.sin_port = htons(port);
94 int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
95 if (ret != NET_SUCCESS) {
96 LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
97 errno, strerror(errno));
98 CloseOnInitFailure();
99 return false;
100 }
101
102 ret = connect(connection, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
103 if (ret != SOCKET_SUCCESS) {
104 LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
105 errno, strerror(errno));
106 CloseOnInitFailure();
107 return false;
108 }
109 SetConnectionState(ConnectionState::CONNECTING);
110 LOGI("InitToolchainWebSocketForPort::client connect success.");
111 return true;
112 }
113
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)114 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
115 {
116 if (GetConnectionState() != ConnectionState::CLOSED) {
117 LOGE("InitToolchainWebSocketForSockName::client has inited.");
118 return true;
119 }
120
121 int connection = socket(AF_UNIX, SOCK_STREAM, 0);
122 if (connection < SOCKET_SUCCESS) {
123 LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
124 errno, strerror(errno));
125 return false;
126 }
127 SetConnectionSocket(connection);
128
129 // set send and recv timeout limit
130 if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
131 LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
132 "error = %{public}d, desc = %{public}s",
133 errno, strerror(errno));
134 CloseOnInitFailure();
135 return false;
136 }
137
138 struct sockaddr_un serverAddr;
139 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
140 LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
141 "error = %{public}d, desc = %{public}s",
142 errno, strerror(errno));
143 CloseOnInitFailure();
144 return false;
145 }
146 serverAddr.sun_family = AF_UNIX;
147 if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
148 LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
149 "error = %{public}d, desc = %{public}s",
150 errno, strerror(errno));
151 CloseOnInitFailure();
152 return false;
153 }
154 serverAddr.sun_path[0] = '\0';
155
156 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
157 int ret = connect(connection, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
158 if (ret != SOCKET_SUCCESS) {
159 LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
160 errno, strerror(errno));
161 CloseOnInitFailure();
162 return false;
163 }
164 SetConnectionState(ConnectionState::CONNECTING);
165 LOGI("InitToolchainWebSocketForSockName::client connect success.");
166 return true;
167 }
168
ClientSendWSUpgradeReq()169 bool WebSocketClient::ClientSendWSUpgradeReq()
170 {
171 auto state = GetConnectionState();
172 if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
173 LOGE("ClientSendWSUpgradeReq::client has not inited.");
174 return false;
175 }
176 if (state == ConnectionState::OPEN) {
177 LOGE("ClientSendWSUpgradeReq::client has connected.");
178 return true;
179 }
180
181 secWebSocketKey_ = WebSocketKeyEncoder::GenerateRandomSecWSKey();
182 std::string upgradeReq = std::string(CLIENT_WS_UPGRADE_REQ_BEFORE_KEY) + secWebSocketKey_ +
183 std::string(CLIENT_WS_UPGRADE_REQ_AFTER_KEY);
184 if (!Send(GetConnectionSocket(), upgradeReq.data(), upgradeReq.size(), 0)) {
185 LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
186 errno, strerror(errno));
187 CloseOnInitFailure();
188 return false;
189 }
190 LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
191 return true;
192 }
193
ClientRecvWSUpgradeRsp()194 bool WebSocketClient::ClientRecvWSUpgradeRsp()
195 {
196 auto state = GetConnectionState();
197 if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
198 LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
199 return false;
200 }
201 if (state == ConnectionState::OPEN) {
202 LOGE("ClientRecvWSUpgradeRsp::client has connected.");
203 return true;
204 }
205
206 std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
207 ssize_t msgLen = 0;
208 while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
209 (errno == EINTR || errno == EAGAIN)) {
210 LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno = %{public}d", errno);
211 }
212 if (msgLen <= 0) {
213 LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
214 errno, strerror(errno));
215 CloseOnInitFailure();
216 return false;
217 }
218 // reduce to received size
219 msgBuf.resize(msgLen);
220
221 HttpResponse response;
222 if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
223 LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
224 CloseOnInitFailure();
225 return false;
226 }
227
228 SetConnectionState(ConnectionState::OPEN);
229 LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
230 return true;
231 }
232
GetSocketStateString()233 std::string WebSocketClient::GetSocketStateString()
234 {
235 return std::string(SOCKET_STATE_NAMES[EnumToNumber(GetConnectionState())]);
236 }
237
DecodeMessage(WebSocketFrame & wsFrame) const238 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
239 {
240 uint64_t msgLen = wsFrame.payloadLen;
241 if (msgLen == 0) {
242 // receiving empty data is OK
243 return true;
244 }
245 auto& buffer = wsFrame.payload;
246 buffer.resize(msgLen, 0);
247
248 if (!RecvUnderLock(buffer)) {
249 LOGE("DecodeMessage: Recv message without mask failed");
250 return false;
251 }
252
253 return true;
254 }
255
Close()256 void WebSocketClient::Close()
257 {
258 if (!CloseConnection(CloseStatusCode::SERVER_GO_AWAY)) {
259 LOGW("Failed to close WebSocketClient");
260 }
261 }
262
CloseOnInitFailure()263 void WebSocketClient::CloseOnInitFailure()
264 {
265 // Must be in `CONNECTING` or `CLOSED` state.
266 auto expected = ConnectionState::CONNECTING;
267 if (!CompareExchangeConnectionState(expected, ConnectionState::CLOSING) &&
268 expected != ConnectionState::CLOSED) {
269 LOGE("CloseOnInitFailure violation: must be either CONNECTING or CLOSED, but got %{public}d",
270 EnumToNumber(expected));
271 }
272 CloseConnectionSocket(ConnectionCloseReason::FAIL);
273 }
274
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const275 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
276 {
277 // "A server MUST NOT mask any frames that it sends to the client."
278 // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
279 return wsFrame.mask == 0;
280 }
281
CreateFrame(bool isLast,FrameType frameType) const282 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
283 {
284 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
285 return builder.Build();
286 }
287
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const288 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
289 {
290 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
291 return builder.SetPayload(payload).Build();
292 }
293
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const294 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
295 {
296 ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
297 return builder.SetPayload(std::move(payload)).Build();
298 }
299 } // namespace OHOS::ArkCompiler::Toolchain
300