• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include <arpa/inet.h>
#include <sys/un.h>
#include <sys/wait.h>

#include "gtest/gtest.h"
#include "websocket.h"
#include "securec.h"

using namespace OHOS::ArkCompiler::Toolchain;
24 
25 namespace panda::test {
26 class WebSocketTest : public testing::Test {
27 public:
SetUpTestCase()28     static void SetUpTestCase()
29     {
30         GTEST_LOG_(INFO) << "SetUpTestCase";
31     }
32 
TearDownTestCase()33     static void TearDownTestCase()
34     {
35         GTEST_LOG_(INFO) << "TearDownCase";
36     }
37 
SetUp()38     void SetUp() override
39     {
40     }
41 
TearDown()42     void TearDown() override
43     {
44     }
45 
46     class ClientWebSocket : public WebSocket {
47     public:
48         ClientWebSocket() = default;
49         ~ClientWebSocket() = default;
50 #if defined(OHOS_PLATFORM)
ClientConnectUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit=0)51         bool ClientConnectUnixWebSocket(const std::string &sockName, uint32_t timeoutLimit = 0)
52         {
53             if (socketState_ != SocketState::UNINITED) {
54                 std::cout << "ClientConnectUnixWebSocket::client has inited..." << std::endl;
55                 return true;
56             }
57 
58             client_ = socket(AF_UNIX, SOCK_STREAM, 0);
59             if (client_ < SOCKET_SUCCESS) {
60                 std::cerr << "ClientConnectUnixWebSocket::client socket failed, error = "
61                           << errno << ", desc = " << strerror(errno) << std::endl;
62                 return false;
63             }
64 
65             // set send and recv timeout limit
66             if (!SetWebSocketTimeOut(client_, timeoutLimit)) {
67                 std::cerr << "ClientConnectUnixWebSocket::client SetWebSocketTimeOut failed, error = "
68                           << errno << ", desc = " << strerror(errno) << std::endl;
69                 close(client_);
70                 client_ = -1;
71                 return false;
72             }
73 
74             struct sockaddr_un serverAddr;
75             if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
76                 std::cerr << "ClientConnectUnixWebSocket::client memset_s serverAddr failed, error = "
77                           << errno << ", desc = " << strerror(errno) << std::endl;
78                 close(client_);
79                 client_ = -1;
80                 return false;
81             }
82             serverAddr.sun_family = AF_UNIX;
83             if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) {
84                 std::cerr << "ClientConnectUnixWebSocket::client strcpy_s serverAddr.sun_path failed, error = "
85                           << errno << ", desc = " << strerror(errno) << std::endl;
86                 close(client_);
87                 client_ = -1;
88                 return false;
89             }
90             serverAddr.sun_path[0] = '\0';
91 
92             uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1;
93             int ret = connect(client_, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len));
94             if (ret != SOCKET_SUCCESS) {
95                 std::cerr << "ClientConnectUnixWebSocket::client connect failed, error = "
96                           << errno << ", desc = " << strerror(errno) << std::endl;
97                 close(client_);
98                 client_ = -1;
99                 return false;
100             }
101             socketState_ = SocketState::INITED;
102             std::cout << "ClientConnectUnixWebSocket::client connect success..." << std::endl;
103             return true;
104         }
105 #else
ClientConnectTcpWebSocket(uint32_t timeoutLimit=0)106         bool ClientConnectTcpWebSocket(uint32_t timeoutLimit = 0)
107         {
108             if (socketState_ != SocketState::UNINITED) {
109                 std::cout << "ClientConnectTcpWebSocket::client has inited..." << std::endl;
110                 return true;
111             }
112 
113             client_ = socket(AF_INET, SOCK_STREAM, 0);
114             if (client_ < SOCKET_SUCCESS) {
115                 std::cerr << "ClientConnectTcpWebSocket::client socket failed, error = "
116                           << errno << ", desc = " << strerror(errno) << std::endl;
117                 return false;
118             }
119 
120             // set send and recv timeout limit
121             if (!SetWebSocketTimeOut(client_, timeoutLimit)) {
122                 std::cerr << "ClientConnectTcpWebSocket::client SetWebSocketTimeOut failed, error = "
123                           << errno << ", desc = " << strerror(errno) << std::endl;
124                 close(client_);
125                 client_ = -1;
126                 return false;
127             }
128 
129             struct sockaddr_in serverAddr;
130             if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
131                 std::cerr << "ClientConnectTcpWebSocket::client memset_s serverAddr failed, error = "
132                           << errno << ", desc = " << strerror(errno) << std::endl;
133                 close(client_);
134                 client_ = -1;
135                 return false;
136             }
137             serverAddr.sin_family = AF_INET;
138             serverAddr.sin_port = htons(9230); // 9230: sockName for tcp
139             if (int ret = inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr) < NET_SUCCESS) {
140                 std::cerr << "ClientConnectTcpWebSocket::client inet_pton failed, ret = "
141                           << ret << ", error = " << errno << ", desc = " << strerror(errno) << std::endl;
142                 close(client_);
143                 client_ = -1;
144                 return false;
145             }
146 
147             int ret = connect(client_, reinterpret_cast<struct sockaddr*>(&serverAddr), sizeof(serverAddr));
148             if (ret != SOCKET_SUCCESS) {
149                 std::cerr << "ClientConnectTcpWebSocket::client connect failed, error = "
150                           << errno << ", desc = " << strerror(errno) << std::endl;
151                 close(client_);
152                 client_ = -1;
153                 return false;
154             }
155             socketState_ = SocketState::INITED;
156             std::cout << "ClientConnectTcpWebSocket::client connect success..." << std::endl;
157             return true;
158         }
159 #endif
160 
ClientSendWSUpgradeReq()161         bool ClientSendWSUpgradeReq()
162         {
163             if (socketState_ == SocketState::UNINITED) {
164                 std::cerr << "ClientSendWSUpgradeReq::client has not inited..." << std::endl;
165                 return false;
166             }
167             if (socketState_ == SocketState::CONNECTED) {
168                 std::cout << "ClientSendWSUpgradeReq::client has connected..." << std::endl;
169                 return true;
170             }
171 
172             int msgLen = strlen(CLIENT_WEBSOCKET_UPGRADE_REQ);
173             int32_t sendLen = send(client_, CLIENT_WEBSOCKET_UPGRADE_REQ, msgLen, 0);
174             if (sendLen != msgLen) {
175                 std::cerr << "ClientSendWSUpgradeReq::client send wsupgrade req failed, error = "
176                           << errno << ", desc = " << strerror(errno) << std::endl;
177                 socketState_ = SocketState::UNINITED;
178 #if defined(OHOS_PLATFORM)
179                 shutdown(client_, SHUT_RDWR);
180 #endif
181                 close(client_);
182                 client_ = -1;
183                 return false;
184             }
185             std::cout << "ClientSendWSUpgradeReq::client send wsupgrade req success..." << std::endl;
186             return true;
187         }
188 
ClientRecvWSUpgradeRsp()189         bool ClientRecvWSUpgradeRsp()
190         {
191             if (socketState_ == SocketState::UNINITED) {
192                 std::cerr << "ClientRecvWSUpgradeRsp::client has not inited..." << std::endl;
193                 return false;
194             }
195             if (socketState_ == SocketState::CONNECTED) {
196                 std::cout << "ClientRecvWSUpgradeRsp::client has connected..." << std::endl;
197                 return true;
198             }
199 
200             char recvBuf[CLIENT_WEBSOCKET_UPGRADE_RSP_LEN + 1] = {0};
201             int32_t bufLen = recv(client_, recvBuf, CLIENT_WEBSOCKET_UPGRADE_RSP_LEN, 0);
202             if (bufLen != CLIENT_WEBSOCKET_UPGRADE_RSP_LEN) {
203                 std::cerr << "ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = "
204                           << errno << ", desc = " << strerror(errno) << std::endl;
205                 socketState_ = SocketState::UNINITED;
206 #if defined(OHOS_PLATFORM)
207                 shutdown(client_, SHUT_RDWR);
208 #endif
209                 close(client_);
210                 client_ = -1;
211                 return false;
212             }
213             socketState_ = SocketState::CONNECTED;
214             std::cout << "ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success..." << std::endl;
215             return true;
216         }
217 
ClientSendReq(const std::string & message)218         bool ClientSendReq(const std::string &message)
219         {
220             if (socketState_ != SocketState::CONNECTED) {
221                 std::cerr << "ClientSendReq::client has not connected..." << std::endl;
222                 return false;
223             }
224 
225             uint32_t msgLen = message.length();
226             std::unique_ptr<char []> msgBuf = std::make_unique<char []>(msgLen + 15); // 15: the maximum expand length
227             char *sendBuf = msgBuf.get();
228             uint32_t sendMsgLen = 0;
229             sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'.
230             uint32_t mask = 1;
231             // Depending on the length of the messages, client will use shift operation to get the res
232             // and store them in the buffer.
233             if (msgLen <= 125) { // 125: situation 1 when message's length <= 125
234                 sendBuf[1] = msgLen | (mask << 7); // 7: mask need shift left by 7 bits
235                 sendMsgLen = 2; // 2: the length of header frame is 2;
236             } else if (msgLen < 65536) { // 65536: message's length
237                 sendBuf[1] = 126 | (mask << 7); // 126: payloadLen according to the spec; 7: mask shift left by 7 bits
238                 sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1)
239                 sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0)
240                 sendMsgLen = 4; // 4: the length of header frame is 4
241             } else {
242                 sendBuf[1] = 127 | (mask << 7); // 127: payloadLen according to the spec; 7: mask shift left by 7 bits
243                 for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits
244                     sendBuf[i] = 0;
245                 }
246                 sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3)
247                 sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2)
248                 sendBuf[8] = ((msgLen & 0x0000ff00) >> 8);  // 8: shift 8 bits => res * (256^1)
249                 sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0)
250                 sendMsgLen = 10; // 10: the length of header frame is 10
251             }
252 
253             if (memcpy_s(sendBuf + sendMsgLen, SOCKET_MASK_LEN, MASK_KEY, SOCKET_MASK_LEN) != EOK) {
254                 std::cerr << "ClientSendReq::client memcpy_s MASK_KEY failed, error = "
255                           << errno << ", desc = " << strerror(errno) << std::endl;
256                 return false;
257             }
258             sendMsgLen += SOCKET_MASK_LEN;
259 
260             std::string maskMessage;
261             for (uint64_t i = 0; i < msgLen; i++) {
262                 uint64_t j = i % SOCKET_MASK_LEN;
263                 maskMessage.push_back(message[i] ^ MASK_KEY[j]);
264             }
265             if (memcpy_s(sendBuf + sendMsgLen, msgLen, maskMessage.c_str(), msgLen) != EOK) {
266                 std::cerr << "ClientSendReq::client memcpy_s maskMessage failed, error = "
267                           << errno << ", desc = " << strerror(errno) << std::endl;
268                 return false;
269             }
270             msgBuf[sendMsgLen + msgLen] = '\0';
271 
272             if (send(client_, sendBuf, sendMsgLen + msgLen, 0) != static_cast<int>(sendMsgLen + msgLen)) {
273                 std::cerr << "ClientSendReq::client send msg req failed, error = "
274                           << errno << ", desc = " << strerror(errno) << std::endl;
275                 return false;
276             }
277             std::cout << "ClientRecvWSUpgradeRsp::client send msg req success..." << std::endl;
278             return true;
279         }
280 
Close()281         void Close()
282         {
283             if (socketState_ == SocketState::UNINITED) {
284                 return;
285             }
286 #if defined(OHOS_PLATFORM)
287             shutdown(client_, SHUT_RDWR);
288 #endif
289             close(client_);
290             client_ = -1;
291             socketState_ = SocketState::UNINITED;
292         }
293 
294     private:
295         static constexpr char CLIENT_WEBSOCKET_UPGRADE_REQ[] =  "GET / HTTP/1.1\r\n"
296                                                                 "Connection: Upgrade\r\n"
297                                                                 "Pragma: no-cache\r\n"
298                                                                 "Cache-Control: no-cache\r\n"
299                                                                 "Upgrade: websocket\r\n"
300                                                                 "Sec-WebSocket-Version: 13\r\n"
301                                                                 "Accept-Encoding: gzip, deflate, br\r\n"
302                                                                 "Sec-WebSocket-Key: 64b4B+s5JDlgkdg7NekJ+g==\r\n"
303                                                                 "Sec-WebSocket-Extensions: permessage-deflate\r\n";
304         static constexpr int32_t CLIENT_WEBSOCKET_UPGRADE_RSP_LEN = 129;
305         static constexpr char MASK_KEY[SOCKET_MASK_LEN + 1] = "abcd";
306         static constexpr int NET_SUCCESS = 1;
307     };
308 
309 #if defined(OHOS_PLATFORM)
310     static constexpr char UNIX_DOMAIN_PATH[] = "server.sock";
311 #endif
312     static constexpr char HELLO_SERVER[]   = "hello server";
313     static constexpr char HELLO_CLIENT[]   = "hello client";
314     static constexpr char SERVER_OK[]      = "server ok";
315     static constexpr char CLIENT_OK[]      = "client ok";
316     static constexpr char QUIT[]           = "quit";
317     static const std::string LONG_MSG;
318     static const std::string LONG_LONG_MSG;
319 };
320 
321 const std::string WebSocketTest::LONG_MSG       = std::string(1000, 'f');
322 const std::string WebSocketTest::LONG_LONG_MSG  = std::string(0xfffff, 'f');
323 
324 HWTEST_F(WebSocketTest, ConnectWebSocketTest, testing::ext::TestSize.Level0)
325 {
326     WebSocket serverSocket;
327     bool ret = false;
328 #if defined(OHOS_PLATFORM)
329     int appPid = getpid();
330     ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5);
331 #else
332     ret = serverSocket.InitTcpWebSocket(9230, 5);
333 #endif
334     ASSERT_TRUE(ret);
335     pid_t pid = fork();
336     if (pid == 0) {
337         // subprocess, handle client connect and recv/send message
338         // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess,
339         //       so testcase still success finally.
340         ClientWebSocket clientSocket;
341         bool retClient = false;
342 #if defined(OHOS_PLATFORM)
343         retClient = clientSocket.ClientConnectUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5);
344 #else
345         retClient = clientSocket.ClientConnectTcpWebSocket(5);
346 #endif
347         ASSERT_TRUE(retClient);
348         retClient = clientSocket.ClientSendWSUpgradeReq();
349         ASSERT_TRUE(retClient);
350         retClient = clientSocket.ClientRecvWSUpgradeRsp();
351         ASSERT_TRUE(retClient);
352         retClient = clientSocket.ClientSendReq(HELLO_SERVER);
353         EXPECT_TRUE(retClient);
354         std::string recv = clientSocket.Decode();
355         EXPECT_EQ(strcmp(recv.c_str(), HELLO_CLIENT), 0);
356         if (strcmp(recv.c_str(), HELLO_CLIENT) == 0) {
357             retClient = clientSocket.ClientSendReq(CLIENT_OK);
358             EXPECT_TRUE(retClient);
359         }
360         retClient = clientSocket.ClientSendReq(LONG_MSG);
361         EXPECT_TRUE(retClient);
362         recv = clientSocket.Decode();
363         EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0);
364         if (strcmp(recv.c_str(), SERVER_OK) == 0) {
365             retClient = clientSocket.ClientSendReq(CLIENT_OK);
366             EXPECT_TRUE(retClient);
367         }
368         retClient = clientSocket.ClientSendReq(LONG_LONG_MSG);
369         EXPECT_TRUE(retClient);
370         recv = clientSocket.Decode();
371         EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0);
372         if (strcmp(recv.c_str(), SERVER_OK) == 0) {
373             retClient = clientSocket.ClientSendReq(CLIENT_OK);
374             EXPECT_TRUE(retClient);
375         }
376         retClient = clientSocket.ClientSendReq(QUIT);
377         EXPECT_TRUE(retClient);
378         clientSocket.Close();
379         exit(0);
380     } else if (pid > 0) {
381         // mainprocess, handle server connect and recv/send message
382 #if defined(OHOS_PLATFORM)
383         ret = serverSocket.ConnectUnixWebSocket();
384 #else
385         ret = serverSocket.ConnectTcpWebSocket();
386 #endif
387         ASSERT_TRUE(ret);
388         std::string recv = serverSocket.Decode();
389         EXPECT_EQ(strcmp(recv.c_str(), HELLO_SERVER), 0);
390         serverSocket.SendReply(HELLO_CLIENT);
391         recv = serverSocket.Decode();
392         EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0);
393         recv = serverSocket.Decode();
394         EXPECT_EQ(strcmp(recv.c_str(), LONG_MSG.c_str()), 0);
395         serverSocket.SendReply(SERVER_OK);
396         recv = serverSocket.Decode();
397         EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0);
398         recv = serverSocket.Decode();
399         EXPECT_EQ(strcmp(recv.c_str(), LONG_LONG_MSG.c_str()), 0);
400         serverSocket.SendReply(SERVER_OK);
401         recv = serverSocket.Decode();
402         EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0);
403         recv = serverSocket.Decode();
404         EXPECT_EQ(strcmp(recv.c_str(), QUIT), 0);
405         serverSocket.Close();
406         // sleep ensure that linux os core can really release resource
407         sleep(3);
408     } else {
409         std::cerr << "ConnectWebSocketTest::fork failed, error = "
410                   << errno << ", desc = " << strerror(errno) << std::endl;
411     }
412 }
413 
414 HWTEST_F(WebSocketTest, ReConnectWebSocketTest, testing::ext::TestSize.Level0)
415 {
416     WebSocket serverSocket;
417     bool ret = false;
418 #if defined(OHOS_PLATFORM)
419     int appPid = getpid();
420     ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5);
421 #else
422     ret = serverSocket.InitTcpWebSocket(9230, 5);
423 #endif
424     ASSERT_TRUE(ret);
425     for (int i = 0; i < 5; i++) {
426         pid_t pid = fork();
427         if (pid == 0) {
428             // subprocess, handle client connect and recv/send message
429             // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess,
430             //       so testcase still success finally.
431             ClientWebSocket clientSocket;
432             bool retClient = false;
433 #if defined(OHOS_PLATFORM)
434             retClient = clientSocket.ClientConnectUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5);
435 #else
436             retClient = clientSocket.ClientConnectTcpWebSocket(5);
437 #endif
438             ASSERT_TRUE(retClient);
439             retClient = clientSocket.ClientSendWSUpgradeReq();
440             ASSERT_TRUE(retClient);
441             retClient = clientSocket.ClientRecvWSUpgradeRsp();
442             ASSERT_TRUE(retClient);
443             retClient = clientSocket.ClientSendReq(HELLO_SERVER + std::to_string(i));
444             EXPECT_TRUE(retClient);
445             std::string recv = clientSocket.Decode();
446             EXPECT_EQ(strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()), 0);
447             if (strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()) == 0) {
448                 retClient = clientSocket.ClientSendReq(CLIENT_OK + std::to_string(i));
449                 EXPECT_TRUE(retClient);
450             }
451             clientSocket.Close();
452             exit(0);
453         } else if (pid > 0) {
454             // mainprocess, handle server connect and recv/send message
455 #if defined(OHOS_PLATFORM)
456             ret = serverSocket.ConnectUnixWebSocket();
457 #else
458             ret = serverSocket.ConnectTcpWebSocket();
459 #endif
460             ASSERT_TRUE(ret);
461             std::string recv = serverSocket.Decode();
462             EXPECT_EQ(strcmp(recv.c_str(), (HELLO_SERVER + std::to_string(i)).c_str()), 0);
463             serverSocket.SendReply(HELLO_CLIENT + std::to_string(i));
464             recv = serverSocket.Decode();
465             EXPECT_EQ(strcmp(recv.c_str(), (CLIENT_OK + std::to_string(i)).c_str()), 0);
466             while (serverSocket.IsConnected()) {
467                 serverSocket.Decode();
468             }
469         } else {
470             std::cerr << "ReConnectWebSocketTest::fork failed, error = "
471                       << errno << ", desc = " << strerror(errno) << std::endl;
472         }
473     }
474     serverSocket.Close();
475 }
476 }  // namespace panda::test
477