• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <sys/socket.h>
17 #include <sys/time.h>
18 #include <arpa/inet.h>
19 #include <sys/un.h>
20 #include <unistd.h>
21 
22 #include "common/log_wrapper.h"
23 #include "frame_builder.h"
24 #include "handshake_helper.h"
25 #include "network.h"
26 #include "string_utils.h"
27 #include "client/websocket_client.h"
28 
29 namespace OHOS::ArkCompiler::Toolchain {
ValidateServerHandShake(HttpResponse & response)30 static bool ValidateServerHandShake(HttpResponse& response)
31 {
32     static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101";
33     static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket";
34     static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade";
35     // NB! `defaultWebSocketKey` must match "Sec-WebSocket-Key" used in `WebSocketClient`.
36     static constexpr unsigned char defaultWebSocketKey[] = "64b4B+s5JDlgkdg7NekJ+g==";
37 
38     // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1
39     if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) {
40         return false;
41     }
42     ToLowerCase(response.upgrade);
43     if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) {
44         return false;
45     }
46     ToLowerCase(response.connection);
47     if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) {
48         return false;
49     }
50 
51     // The same WebSocket-Key is used for all connections
52     // - must either use a randomly-selected, as required by spec or do this calculation statically.
53     unsigned char expectedSecWebSocketAccept[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1];
54     if (!WebSocketKeyEncoder::EncodeKey(defaultWebSocketKey, expectedSecWebSocketAccept)) {
55         LOGE("ValidateServerHandShake failed to generate expected Sec-WebSocket-Accept token");
56         return false;
57     }
58 
59     Trim(response.secWebSocketAccept);
60     if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN ||
61         response.secWebSocketAccept.compare(reinterpret_cast<const char *>(expectedSecWebSocketAccept)) != 0) {
62         return false;
63     }
64 
65     // may support two remaining checks
66     return true;
67 }
68 
InitToolchainWebSocketForPort(int port,uint32_t timeoutLimit)69 bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit)
70 {
71     if (GetConnectionState() != ConnectionState::CLOSED) {
72         LOGE("InitToolchainWebSocketForPort::client has inited.");
73         return true;
74     }
75 
76     int connection = socket(AF_INET, SOCK_STREAM, 0);
77     if (connection < SOCKET_SUCCESS) {
78         LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s",
79             errno, strerror(errno));
80         return false;
81     }
82     SetConnectionSocket(connection);
83 
84     // set send and recv timeout limit
85     if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
86         LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s",
87             errno, strerror(errno));
88         CloseOnInitFailure();
89         return false;
90     }
91 
92     sockaddr_in clientAddr;
93     if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) {
94         LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s",
95             errno, strerror(errno));
96         CloseOnInitFailure();
97         return false;
98     }
99     clientAddr.sin_family = AF_INET;
100     clientAddr.sin_port = htons(port);
101     int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr);
102     if (ret != NET_SUCCESS) {
103         LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s",
104             errno, strerror(errno));
105         CloseOnInitFailure();
106         return false;
107     }
108 
109     ret = connect(connection, reinterpret_cast<struct sockaddr*>(&clientAddr), sizeof(clientAddr));
110     if (ret != SOCKET_SUCCESS) {
111         LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s",
112             errno, strerror(errno));
113         CloseOnInitFailure();
114         return false;
115     }
116     SetConnectionState(ConnectionState::CONNECTING);
117     LOGI("InitToolchainWebSocketForPort::client connect success.");
118     return true;
119 }
120 
InitToolchainWebSocketForSockName(const std::string & sockName,uint32_t timeoutLimit)121 bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit)
122 {
123     if (GetConnectionState() != ConnectionState::CLOSED) {
124         LOGE("InitToolchainWebSocketForSockName::client has inited.");
125         return true;
126     }
127 
128     int connection = socket(AF_UNIX, SOCK_STREAM, 0);
129     if (connection < SOCKET_SUCCESS) {
130         LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s",
131             errno, strerror(errno));
132         return false;
133     }
134     SetConnectionSocket(connection);
135 
136     // set send and recv timeout limit
137     if (!SetWebSocketTimeOut(connection, timeoutLimit)) {
138         LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, "
139             "error = %{public}d, desc = %{public}s",
140             errno, strerror(errno));
141         CloseOnInitFailure();
142         return false;
143     }
144 
145     struct sockaddr_un serverAddr;
146     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
147         LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, "
148             "error = %{public}d, desc = %{public}s",
149             errno, strerror(errno));
150         CloseOnInitFailure();
151         return false;
152     }
153     serverAddr.sun_family = AF_UNIX;
154     if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
155         LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, "
156             "error = %{public}d, desc = %{public}s",
157             errno, strerror(errno));
158         CloseOnInitFailure();
159         return false;
160     }
161     serverAddr.sun_path[0] = '\0';
162 
163     uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
164     int ret = connect(connection, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
165     if (ret != SOCKET_SUCCESS) {
166         LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s",
167             errno, strerror(errno));
168         CloseOnInitFailure();
169         return false;
170     }
171     SetConnectionState(ConnectionState::CONNECTING);
172     LOGI("InitToolchainWebSocketForSockName::client connect success.");
173     return true;
174 }
175 
ClientSendWSUpgradeReq()176 bool WebSocketClient::ClientSendWSUpgradeReq()
177 {
178     auto state = GetConnectionState();
179     if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
180         LOGE("ClientSendWSUpgradeReq::client has not inited.");
181         return false;
182     }
183     if (state == ConnectionState::OPEN) {
184         LOGE("ClientSendWSUpgradeReq::client has connected.");
185         return true;
186     }
187 
188     // length without null-terminator
189     if (!Send(GetConnectionSocket(), CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) {
190         LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn",
191             errno, strerror(errno));
192         CloseOnInitFailure();
193         return false;
194     }
195     LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success.");
196     return true;
197 }
198 
ClientRecvWSUpgradeRsp()199 bool WebSocketClient::ClientRecvWSUpgradeRsp()
200 {
201     auto state = GetConnectionState();
202     if (state == ConnectionState::CLOSING || state == ConnectionState::CLOSED) {
203         LOGE("ClientRecvWSUpgradeRsp::client has not inited.");
204         return false;
205     }
206     if (state == ConnectionState::OPEN) {
207         LOGE("ClientRecvWSUpgradeRsp::client has connected.");
208         return true;
209     }
210 
211     std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0);
212     ssize_t msgLen = 0;
213     while ((msgLen = recv(GetConnectionSocket(), msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0)) < 0 &&
214            (errno == EINTR || errno == EAGAIN)) {
215         LOGW("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, errno = %{public}d", errno);
216     }
217     if (msgLen <= 0) {
218         LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn",
219             errno, strerror(errno));
220         CloseOnInitFailure();
221         return false;
222     }
223     // reduce to received size
224     msgBuf.resize(msgLen);
225 
226     HttpResponse response;
227     if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) {
228         LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid");
229         CloseOnInitFailure();
230         return false;
231     }
232 
233     SetConnectionState(ConnectionState::OPEN);
234     LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success.");
235     return true;
236 }
237 
GetSocketStateString()238 std::string WebSocketClient::GetSocketStateString()
239 {
240     return std::string(SOCKET_STATE_NAMES[EnumToNumber(GetConnectionState())]);
241 }
242 
DecodeMessage(WebSocketFrame & wsFrame) const243 bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame) const
244 {
245     uint64_t msgLen = wsFrame.payloadLen;
246     if (msgLen == 0) {
247         // receiving empty data is OK
248         return true;
249     }
250     auto& buffer = wsFrame.payload;
251     buffer.resize(msgLen, 0);
252 
253     if (!RecvUnderLock(buffer)) {
254         LOGE("DecodeMessage: Recv message without mask failed");
255         return false;
256     }
257 
258     return true;
259 }
260 
Close()261 void WebSocketClient::Close()
262 {
263     if (!CloseConnection(CloseStatusCode::SERVER_GO_AWAY)) {
264         LOGW("Failed to close WebSocketClient");
265     }
266 }
267 
CloseOnInitFailure()268 void WebSocketClient::CloseOnInitFailure()
269 {
270     // Must be in `CONNECTING` or `CLOSED` state.
271     auto expected = ConnectionState::CONNECTING;
272     if (!CompareExchangeConnectionState(expected, ConnectionState::CLOSING) &&
273         expected != ConnectionState::CLOSED) {
274         LOGE("CloseOnInitFailure violation: must be either CONNECTING or CLOSED, but got %{public}d",
275              EnumToNumber(expected));
276     }
277     CloseConnectionSocket(ConnectionCloseReason::FAIL);
278 }
279 
ValidateIncomingFrame(const WebSocketFrame & wsFrame) const280 bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) const
281 {
282     // "A server MUST NOT mask any frames that it sends to the client."
283     // https://www.rfc-editor.org/rfc/rfc6455#section-5.1
284     return wsFrame.mask == 0;
285 }
286 
CreateFrame(bool isLast,FrameType frameType) const287 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const
288 {
289     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
290     return builder.Build();
291 }
292 
CreateFrame(bool isLast,FrameType frameType,const std::string & payload) const293 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const
294 {
295     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
296     return builder.SetPayload(payload).Build();
297 }
298 
CreateFrame(bool isLast,FrameType frameType,std::string && payload) const299 std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const
300 {
301     ClientFrameBuilder builder(isLast, frameType, MASK_KEY);
302     return builder.SetPayload(std::move(payload)).Build();
303 }
304 }  // namespace OHOS::ArkCompiler::Toolchain
305