• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <arpa/inet.h>
17 
18 #include "common/log_wrapper.h"
19 #include "frame_builder.h"
20 #include "handshake_helper.h"
21 #include "string_utils.h"
22 #include "client/websocket_client.h"
23 
24 namespace OHOS::ArkCompiler::Toolchain {
ValidateServerHandShake(HttpResponse & response)25 bool WebSocketClient::ValidateServerHandShake(HttpResponse& response)
26 {
27     static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101";
28     static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket";
29     static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade";
30 
31     // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
32     if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
33         return false;
34     }
35     ToLowerCase(response.upgrade);
36     if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
37         return false;
38     }
39     ToLowerCase(response.connection);
40     if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
41         return false;
42     }
43 
44     // The same WebSocket-Key is used for all connections
45     // - must either use a randomly-selected, as required by spec or do this calculation statically.
46     unsigned char expectedSecWebSocketAccept[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
47     if (!WebSocketKeyEncoder::EncodeKey(secWebSocketKey_, expectedSecWebSocketAccept)) {
48         LOGE("ValidateServerHandShake failed to generate expected Sec-WebSocket-Accept token");
49         return false;
50     }
51 
52     Trim(response.secWebSocketAccept);
53     if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
54         response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept)) != 0) {
55         return false;
56     }
57 
58     // may support two remaining checks
59     return true;
60 }
61 
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)62 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
63 {
64     if (GetConnectionState() != ConnectionState::CLOSED) {
65         LOGE("InitToolchainWebSocketForPort::client has inited.");
66         return true;
67     }
68 
69     int connection = socket(AF_INET, SOCK_STREAM, 0);
70     if (connection < SOCKET_SUCCESS) {
71         LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
72             errno, strerror(errno));
73         return false;
74     }
75     SetConnectionSocket(connection);
76 
77     // set send and recv timeout limit
78     if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
79         LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
80             errno, strerror(errno));
81         CloseOnInitFailure();
82         return false;
83     }
84 
85     sockaddr_in clientAddr;
86     if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
87         LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
88             errno, strerror(errno));
89         CloseOnInitFailure();
90         return false;
91     }
92     clientAddr.sin_family = AF_INET;
93     clientAddr.sin_port = htons(port);
94     int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
95     if (ret != NET_SUCCESS) {
96         LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
97             errno, strerror(errno));
98         CloseOnInitFailure();
99         return false;
100     }
101 
102     ret = connect(connection, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
103     if (ret != SOCKET_SUCCESS) {
104         LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
105             errno, strerror(errno));
106         CloseOnInitFailure();
107         return false;
108     }
109     SetConnectionState(ConnectionState::CONNECTING);
110     LOGI("InitToolchainWebSocketForPort::client connect success.");
111     return true;
112 }
113 
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)114 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
115 {
116     if (GetConnectionState() != ConnectionState::CLOSED) {
117         LOGE("InitToolchainWebSocketForSockName::client has inited.");
118         return true;
119     }
120 
121     int connection = socket(AF_UNIX, SOCK_STREAM, 0);
122     if (connection < SOCKET_SUCCESS) {
123         LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
124             errno, strerror(errno));
125         return false;
126     }
127     SetConnectionSocket(connection);
128 
129     // set send and recv timeout limit
130     if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
131         LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
132             "error = %{public}d, desc = %{public}s",
133             errno, strerror(errno));
134         CloseOnInitFailure();
135         return false;
136     }
137 
138     struct sockaddr_un serverAddr;
139     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
140         LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
141             "error = %{public}d, desc = %{public}s",
142             errno, strerror(errno));
143         CloseOnInitFailure();
144         return false;
145     }
146     serverAddr.sun_family = AF_UNIX;
147     if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
148         LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
149             "error = %{public}d, desc = %{public}s",
150             errno, strerror(errno));
151         CloseOnInitFailure();
152         return false;
153     }
154     serverAddr.sun_path[0] = '\0';
155 
156     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
157     int ret = connect(connection, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
158     if (ret != SOCKET_SUCCESS) {
159         LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
160             errno, strerror(errno));
161         CloseOnInitFailure();
162         return false;
163     }
164     SetConnectionState(ConnectionState::CONNECTING);
165     LOGI("InitToolchainWebSocketForSockName::client connect success.");
166     return true;
167 }
168 
ClientSendWSUpgradeReq()169 bool WebSocketClient::ClientSendWSUpgradeReq()
170 {
171     auto state = GetConnectionState();
172     if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
173         LOGE("ClientSendWSUpgradeReq::client has not inited.");
174         return false;
175     }
176     if (state == ConnectionState::OPEN) {
177         LOGE("ClientSendWSUpgradeReq::client has connected.");
178         return true;
179     }
180 
181     secWebSocketKey_ = WebSocketKeyEncoder::GenerateRandomSecWSKey();
182     std::string upgradeReq = std::string(CLIENT_WS_UPGRADE_REQ_BEFORE_KEY) + secWebSocketKey_ +
183                              std::string(CLIENT_WS_UPGRADE_REQ_AFTER_KEY);
184     if (!Send(GetConnectionSocket(), upgradeReq.data(), upgradeReq.size(), 0)) {
185         LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
186             errno, strerror(errno));
187         CloseOnInitFailure();
188         return false;
189     }
190     LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
191     return true;
192 }
193 
ClientRecvWSUpgradeRsp()194 bool WebSocketClient::ClientRecvWSUpgradeRsp()
195 {
196     auto state = GetConnectionState();
197     if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
198         LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
199         return false;
200     }
201     if (state == ConnectionState::OPEN) {
202         LOGE("ClientRecvWSUpgradeRsp::client has connected.");
203         return true;
204     }
205 
206     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
207     ssize_t msgLen = 0;
208     while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
209            (errno == EINTR || errno == EAGAIN)) {
210         LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno = %{public}d", errno);
211     }
212     if (msgLen <= 0) {
213         LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
214             errno, strerror(errno));
215         CloseOnInitFailure();
216         return false;
217     }
218     // reduce to received size
219     msgBuf.resize(msgLen);
220 
221     HttpResponse response;
222     if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
223         LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
224         CloseOnInitFailure();
225         return false;
226     }
227 
228     SetConnectionState(ConnectionState::OPEN);
229     LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
230     return true;
231 }
232 
GetSocketStateString()233 std::string WebSocketClient::GetSocketStateString()
234 {
235     return std::string(SOCKET_STATE_NAMES[EnumToNumber(GetConnectionState())]);
236 }
237 
DecodeMessage(WebSocketFrame & wsFrame) const238 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
239 {
240     uint64_t msgLen = wsFrame.payloadLen;
241     if (msgLen == 0) {
242         // receiving empty data is OK
243         return true;
244     }
245     auto& buffer = wsFrame.payload;
246     buffer.resize(msgLen, 0);
247 
248     if (!RecvUnderLock(buffer)) {
249         LOGE("DecodeMessage: Recv message without mask failed");
250         return false;
251     }
252 
253     return true;
254 }
255 
Close()256 void WebSocketClient::Close()
257 {
258     if (!CloseConnection(CloseStatusCode::SERVER_GO_AWAY)) {
259         LOGW("Failed to close WebSocketClient");
260     }
261 }
262 
CloseOnInitFailure()263 void WebSocketClient::CloseOnInitFailure()
264 {
265     // Must be in `CONNECTING` or `CLOSED` state.
266     auto expected = ConnectionState::CONNECTING;
267     if (!CompareExchangeConnectionState(expected, ConnectionState::CLOSING) &&
268         expected != ConnectionState::CLOSED) {
269         LOGE("CloseOnInitFailure violation: must be either CONNECTING or CLOSED, but got %{public}d",
270              EnumToNumber(expected));
271     }
272     CloseConnectionSocket(ConnectionCloseReason::FAIL);
273 }
274 
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const275 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
276 {
277     // "A server MUST NOT mask any frames that it sends to the client."
278     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
279     return wsFrame.mask == 0;
280 }
281 
CreateFrame(bool isLast,FrameType frameType) const282 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
283 {
284     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
285     return builder.Build();
286 }
287 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const288 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
289 {
290     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
291     return builder.SetPayload(payload).Build();
292 }
293 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const294 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
295 {
296     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
297     return builder.SetPayload(std::move(payload)).Build();
298 }
299 }  // namespace OHOS::ArkCompiler::Toolchain
300