• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <arpa/inet.h>
17 
18 #include "common/log_wrapper.h"
19 #include "frame_builder.h"
20 #include "handshake_helper.h"
21 #include "string_utils.h"
22 #include "client/websocket_client.h"
23 
24 namespace OHOS::ArkCompiler::Toolchain {
ValidateServerHandShake(HttpResponse & response)25 static bool ValidateServerHandShake(HttpResponse& response)
26 {
27     static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101";
28     static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket";
29     static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade";
30     // NB! `defaultWebSocketKey` must match "Sec-WebSocket-Key" used in `WebSocketClient`.
31     static constexpr unsigned char defaultWebSocketKey[] = "64b4B+s5JDlgkdg7NekJ+g==";
32 
33     // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
34     if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
35         return false;
36     }
37     ToLowerCase(response.upgrade);
38     if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
39         return false;
40     }
41     ToLowerCase(response.connection);
42     if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
43         return false;
44     }
45 
46     // The same WebSocket-Key is used for all connections
47     // - must either use a randomly-selected, as required by spec or do this calculation statically.
48     unsigned char expectedSecWebSocketAccept[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
49     if (!WebSocketKeyEncoder::EncodeKey(defaultWebSocketKey, expectedSecWebSocketAccept)) {
50         LOGE("ValidateServerHandShake failed to generate expected Sec-WebSocket-Accept token");
51         return false;
52     }
53 
54     Trim(response.secWebSocketAccept);
55     if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
56         response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept)) != 0) {
57         return false;
58     }
59 
60     // may support two remaining checks
61     return true;
62 }
63 
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)64 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
65 {
66     if (GetConnectionState() != ConnectionState::CLOSED) {
67         LOGE("InitToolchainWebSocketForPort::client has inited.");
68         return true;
69     }
70 
71     int connection = socket(AF_INET, SOCK_STREAM, 0);
72     if (connection < SOCKET_SUCCESS) {
73         LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
74             errno, strerror(errno));
75         return false;
76     }
77     SetConnectionSocket(connection);
78 
79     // set send and recv timeout limit
80     if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
81         LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
82             errno, strerror(errno));
83         CloseOnInitFailure();
84         return false;
85     }
86 
87     sockaddr_in clientAddr;
88     if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
89         LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
90             errno, strerror(errno));
91         CloseOnInitFailure();
92         return false;
93     }
94     clientAddr.sin_family = AF_INET;
95     clientAddr.sin_port = htons(port);
96     int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
97     if (ret != NET_SUCCESS) {
98         LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
99             errno, strerror(errno));
100         CloseOnInitFailure();
101         return false;
102     }
103 
104     ret = connect(connection, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
105     if (ret != SOCKET_SUCCESS) {
106         LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
107             errno, strerror(errno));
108         CloseOnInitFailure();
109         return false;
110     }
111     SetConnectionState(ConnectionState::CONNECTING);
112     LOGI("InitToolchainWebSocketForPort::client connect success.");
113     return true;
114 }
115 
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)116 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
117 {
118     if (GetConnectionState() != ConnectionState::CLOSED) {
119         LOGE("InitToolchainWebSocketForSockName::client has inited.");
120         return true;
121     }
122 
123     int connection = socket(AF_UNIX, SOCK_STREAM, 0);
124     if (connection < SOCKET_SUCCESS) {
125         LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
126             errno, strerror(errno));
127         return false;
128     }
129     SetConnectionSocket(connection);
130 
131     // set send and recv timeout limit
132     if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
133         LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
134             "error = %{public}d, desc = %{public}s",
135             errno, strerror(errno));
136         CloseOnInitFailure();
137         return false;
138     }
139 
140     struct sockaddr_un serverAddr;
141     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
142         LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
143             "error = %{public}d, desc = %{public}s",
144             errno, strerror(errno));
145         CloseOnInitFailure();
146         return false;
147     }
148     serverAddr.sun_family = AF_UNIX;
149     if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
150         LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
151             "error = %{public}d, desc = %{public}s",
152             errno, strerror(errno));
153         CloseOnInitFailure();
154         return false;
155     }
156     serverAddr.sun_path[0] = '\0';
157 
158     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
159     int ret = connect(connection, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
160     if (ret != SOCKET_SUCCESS) {
161         LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
162             errno, strerror(errno));
163         CloseOnInitFailure();
164         return false;
165     }
166     SetConnectionState(ConnectionState::CONNECTING);
167     LOGI("InitToolchainWebSocketForSockName::client connect success.");
168     return true;
169 }
170 
ClientSendWSUpgradeReq()171 bool WebSocketClient::ClientSendWSUpgradeReq()
172 {
173     auto state = GetConnectionState();
174     if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
175         LOGE("ClientSendWSUpgradeReq::client has not inited.");
176         return false;
177     }
178     if (state == ConnectionState::OPEN) {
179         LOGE("ClientSendWSUpgradeReq::client has connected.");
180         return true;
181     }
182 
183     // length without null-terminator
184     if (!Send(GetConnectionSocket(), CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) {
185         LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
186             errno, strerror(errno));
187         CloseOnInitFailure();
188         return false;
189     }
190     LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
191     return true;
192 }
193 
ClientRecvWSUpgradeRsp()194 bool WebSocketClient::ClientRecvWSUpgradeRsp()
195 {
196     auto state = GetConnectionState();
197     if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
198         LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
199         return false;
200     }
201     if (state == ConnectionState::OPEN) {
202         LOGE("ClientRecvWSUpgradeRsp::client has connected.");
203         return true;
204     }
205 
206     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
207     ssize_t msgLen = 0;
208     while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
209            (errno == EINTR || errno == EAGAIN)) {
210         LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno = %{public}d", errno);
211     }
212     if (msgLen <= 0) {
213         LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
214             errno, strerror(errno));
215         CloseOnInitFailure();
216         return false;
217     }
218     // reduce to received size
219     msgBuf.resize(msgLen);
220 
221     HttpResponse response;
222     if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
223         LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
224         CloseOnInitFailure();
225         return false;
226     }
227 
228     SetConnectionState(ConnectionState::OPEN);
229     LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
230     return true;
231 }
232 
GetSocketStateString()233 std::string WebSocketClient::GetSocketStateString()
234 {
235     return std::string(SOCKET_STATE_NAMES[EnumToNumber(GetConnectionState())]);
236 }
237 
DecodeMessage(WebSocketFrame & wsFrame) const238 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
239 {
240     uint64_t msgLen = wsFrame.payloadLen;
241     if (msgLen == 0) {
242         // receiving empty data is OK
243         return true;
244     }
245     auto& buffer = wsFrame.payload;
246     buffer.resize(msgLen, 0);
247 
248     if (!RecvUnderLock(buffer)) {
249         LOGE("DecodeMessage: Recv message without mask failed");
250         return false;
251     }
252 
253     return true;
254 }
255 
Close()256 void WebSocketClient::Close()
257 {
258     if (!CloseConnection(CloseStatusCode::SERVER_GO_AWAY)) {
259         LOGW("Failed to close WebSocketClient");
260     }
261 }
262 
CloseOnInitFailure()263 void WebSocketClient::CloseOnInitFailure()
264 {
265     // Must be in `CONNECTING` or `CLOSED` state.
266     auto expected = ConnectionState::CONNECTING;
267     if (!CompareExchangeConnectionState(expected, ConnectionState::CLOSING) &&
268         expected != ConnectionState::CLOSED) {
269         LOGE("CloseOnInitFailure violation: must be either CONNECTING or CLOSED, but got %{public}d",
270              EnumToNumber(expected));
271     }
272     CloseConnectionSocket(ConnectionCloseReason::FAIL);
273 }
274 
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const275 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
276 {
277     // "A server MUST NOT mask any frames that it sends to the client."
278     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
279     return wsFrame.mask == 0;
280 }
281 
CreateFrame(bool isLast,FrameType frameType) const282 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
283 {
284     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
285     return builder.Build();
286 }
287 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const288 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
289 {
290     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
291     return builder.SetPayload(payload).Build();
292 }
293 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const294 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
295 {
296     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
297     return builder.SetPayload(std::move(payload)).Build();
298 }
299 }  // namespace OHOS::ArkCompiler::Toolchain
300