1 /* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #include <arpa/inet.h> 17 #include <sys/un.h> 18 19 #include "gtest/gtest.h" 20 #include "websocket.h" 21 #include "securec.h" 22 23 using namespace OHOS::ArkCompiler::Toolchain; 24 25 namespace panda::test { 26 class WebSocketTest : public testing::Test { 27 public: SetUpTestCase()28 static void SetUpTestCase() 29 { 30 GTEST_LOG_(INFO) << "SetUpTestCase"; 31 } 32 TearDownTestCase()33 static void TearDownTestCase() 34 { 35 GTEST_LOG_(INFO) << "TearDownCase"; 36 } 37 SetUp()38 void SetUp() override 39 { 40 } 41 TearDown()42 void TearDown() override 43 { 44 } 45 46 class ClientWebSocket : public WebSocket { 47 public: 48 ClientWebSocket() = default; 49 ~ClientWebSocket() = default; 50 #if defined(OHOS_PLATFORM) ClientConnectUnixWebSocket(const std::string & sockName,uint32_t timeoutLimit=0)51 bool ClientConnectUnixWebSocket(const std::string &sockName, uint32_t timeoutLimit = 0) 52 { 53 if (socketState_ != SocketState::UNINITED) { 54 std::cout << "ClientConnectUnixWebSocket::client has inited..." << std::endl; 55 return true; 56 } 57 58 client_ = socket(AF_UNIX, SOCK_STREAM, 0); 59 if (client_ < SOCKET_SUCCESS) { 60 std::cerr << "ClientConnectUnixWebSocket::client socket failed, error = " 61 << errno << ", desc = " << strerror(errno) << std::endl; 62 return false; 63 } 64 65 // set send and recv timeout limit 66 if (!SetWebSocketTimeOut(client_, timeoutLimit)) { 67 std::cerr << "ClientConnectUnixWebSocket::client SetWebSocketTimeOut failed, error = " 68 << errno << ", desc = " << strerror(errno) << std::endl; 69 close(client_); 70 client_ = -1; 71 return false; 72 } 73 74 struct sockaddr_un serverAddr; 75 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) { 76 std::cerr << "ClientConnectUnixWebSocket::client memset_s serverAddr failed, error = " 77 << errno << ", desc = " << strerror(errno) << std::endl; 78 close(client_); 79 client_ = -1; 80 return false; 81 } 82 serverAddr.sun_family = AF_UNIX; 83 if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) { 84 std::cerr << "ClientConnectUnixWebSocket::client strcpy_s serverAddr.sun_path failed, error = " 85 << errno << ", desc = " << strerror(errno) << std::endl; 86 close(client_); 87 client_ = -1; 88 return false; 89 } 90 serverAddr.sun_path[0] = '\0'; 91 92 uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1; 93 int ret = connect(client_, reinterpret_cast<struct sockaddr*>(&serverAddr), static_cast<int32_t>(len)); 94 if (ret != SOCKET_SUCCESS) { 95 std::cerr << "ClientConnectUnixWebSocket::client connect failed, error = " 96 << errno << ", desc = " << strerror(errno) << std::endl; 97 close(client_); 98 client_ = -1; 99 return false; 100 } 101 socketState_ = SocketState::INITED; 102 std::cout << "ClientConnectUnixWebSocket::client connect success..." << std::endl; 103 return true; 104 } 105 #else ClientConnectTcpWebSocket(uint32_t timeoutLimit=0)106 bool ClientConnectTcpWebSocket(uint32_t timeoutLimit = 0) 107 { 108 if (socketState_ != SocketState::UNINITED) { 109 std::cout << "ClientConnectTcpWebSocket::client has inited..." << std::endl; 110 return true; 111 } 112 113 client_ = socket(AF_INET, SOCK_STREAM, 0); 114 if (client_ < SOCKET_SUCCESS) { 115 std::cerr << "ClientConnectTcpWebSocket::client socket failed, error = " 116 << errno << ", desc = " << strerror(errno) << std::endl; 117 return false; 118 } 119 120 // set send and recv timeout limit 121 if (!SetWebSocketTimeOut(client_, timeoutLimit)) { 122 std::cerr << "ClientConnectTcpWebSocket::client SetWebSocketTimeOut failed, error = " 123 << errno << ", desc = " << strerror(errno) << std::endl; 124 close(client_); 125 client_ = -1; 126 return false; 127 } 128 129 struct sockaddr_in serverAddr; 130 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) { 131 std::cerr << "ClientConnectTcpWebSocket::client memset_s serverAddr failed, error = " 132 << errno << ", desc = " << strerror(errno) << std::endl; 133 close(client_); 134 client_ = -1; 135 return false; 136 } 137 serverAddr.sin_family = AF_INET; 138 serverAddr.sin_port = htons(9230); // 9230: sockName for tcp 139 if (int ret = inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr) < NET_SUCCESS) { 140 std::cerr << "ClientConnectTcpWebSocket::client inet_pton failed, ret = " 141 << ret << ", error = " << errno << ", desc = " << strerror(errno) << std::endl; 142 close(client_); 143 client_ = -1; 144 return false; 145 } 146 147 int ret = connect(client_, reinterpret_cast<struct sockaddr*>(&serverAddr), sizeof(serverAddr)); 148 if (ret != SOCKET_SUCCESS) { 149 std::cerr << "ClientConnectTcpWebSocket::client connect failed, error = " 150 << errno << ", desc = " << strerror(errno) << std::endl; 151 close(client_); 152 client_ = -1; 153 return false; 154 } 155 socketState_ = SocketState::INITED; 156 std::cout << "ClientConnectTcpWebSocket::client connect success..." << std::endl; 157 return true; 158 } 159 #endif 160 ClientSendWSUpgradeReq()161 bool ClientSendWSUpgradeReq() 162 { 163 if (socketState_ == SocketState::UNINITED) { 164 std::cerr << "ClientSendWSUpgradeReq::client has not inited..." << std::endl; 165 return false; 166 } 167 if (socketState_ == SocketState::CONNECTED) { 168 std::cout << "ClientSendWSUpgradeReq::client has connected..." << std::endl; 169 return true; 170 } 171 172 int msgLen = strlen(CLIENT_WEBSOCKET_UPGRADE_REQ); 173 int32_t sendLen = send(client_, CLIENT_WEBSOCKET_UPGRADE_REQ, msgLen, 0); 174 if (sendLen != msgLen) { 175 std::cerr << "ClientSendWSUpgradeReq::client send wsupgrade req failed, error = " 176 << errno << ", desc = " << strerror(errno) << std::endl; 177 socketState_ = SocketState::UNINITED; 178 #if defined(OHOS_PLATFORM) 179 shutdown(client_, SHUT_RDWR); 180 #endif 181 close(client_); 182 client_ = -1; 183 return false; 184 } 185 std::cout << "ClientSendWSUpgradeReq::client send wsupgrade req success..." << std::endl; 186 return true; 187 } 188 ClientRecvWSUpgradeRsp()189 bool ClientRecvWSUpgradeRsp() 190 { 191 if (socketState_ == SocketState::UNINITED) { 192 std::cerr << "ClientRecvWSUpgradeRsp::client has not inited..." << std::endl; 193 return false; 194 } 195 if (socketState_ == SocketState::CONNECTED) { 196 std::cout << "ClientRecvWSUpgradeRsp::client has connected..." << std::endl; 197 return true; 198 } 199 200 char recvBuf[CLIENT_WEBSOCKET_UPGRADE_RSP_LEN + 1] = {0}; 201 int32_t bufLen = recv(client_, recvBuf, CLIENT_WEBSOCKET_UPGRADE_RSP_LEN, 0); 202 if (bufLen != CLIENT_WEBSOCKET_UPGRADE_RSP_LEN) { 203 std::cerr << "ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = " 204 << errno << ", desc = " << strerror(errno) << std::endl; 205 socketState_ = SocketState::UNINITED; 206 #if defined(OHOS_PLATFORM) 207 shutdown(client_, SHUT_RDWR); 208 #endif 209 close(client_); 210 client_ = -1; 211 return false; 212 } 213 socketState_ = SocketState::CONNECTED; 214 std::cout << "ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success..." << std::endl; 215 return true; 216 } 217 ClientSendReq(const std::string & message)218 bool ClientSendReq(const std::string &message) 219 { 220 if (socketState_ != SocketState::CONNECTED) { 221 std::cerr << "ClientSendReq::client has not connected..." << std::endl; 222 return false; 223 } 224 225 uint32_t msgLen = message.length(); 226 std::unique_ptr<char []> msgBuf = std::make_unique<char []>(msgLen + 15); // 15: the maximum expand length 227 char *sendBuf = msgBuf.get(); 228 uint32_t sendMsgLen = 0; 229 sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'. 230 uint32_t mask = 1; 231 // Depending on the length of the messages, client will use shift operation to get the res 232 // and store them in the buffer. 233 if (msgLen <= 125) { // 125: situation 1 when message's length <= 125 234 sendBuf[1] = msgLen | (mask << 7); // 7: mask need shift left by 7 bits 235 sendMsgLen = 2; // 2: the length of header frame is 2; 236 } else if (msgLen < 65536) { // 65536: message's length 237 sendBuf[1] = 126 | (mask << 7); // 126: payloadLen according to the spec; 7: mask shift left by 7 bits 238 sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1) 239 sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0) 240 sendMsgLen = 4; // 4: the length of header frame is 4 241 } else { 242 sendBuf[1] = 127 | (mask << 7); // 127: payloadLen according to the spec; 7: mask shift left by 7 bits 243 for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits 244 sendBuf[i] = 0; 245 } 246 sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3) 247 sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2) 248 sendBuf[8] = ((msgLen & 0x0000ff00) >> 8); // 8: shift 8 bits => res * (256^1) 249 sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0) 250 sendMsgLen = 10; // 10: the length of header frame is 10 251 } 252 253 if (memcpy_s(sendBuf + sendMsgLen, SOCKET_MASK_LEN, MASK_KEY, SOCKET_MASK_LEN) != EOK) { 254 std::cerr << "ClientSendReq::client memcpy_s MASK_KEY failed, error = " 255 << errno << ", desc = " << strerror(errno) << std::endl; 256 return false; 257 } 258 sendMsgLen += SOCKET_MASK_LEN; 259 260 std::string maskMessage; 261 for (uint64_t i = 0; i < msgLen; i++) { 262 uint64_t j = i % SOCKET_MASK_LEN; 263 maskMessage.push_back(message[i] ^ MASK_KEY[j]); 264 } 265 if (memcpy_s(sendBuf + sendMsgLen, msgLen, maskMessage.c_str(), msgLen) != EOK) { 266 std::cerr << "ClientSendReq::client memcpy_s maskMessage failed, error = " 267 << errno << ", desc = " << strerror(errno) << std::endl; 268 return false; 269 } 270 msgBuf[sendMsgLen + msgLen] = '\0'; 271 272 if (send(client_, sendBuf, sendMsgLen + msgLen, 0) != static_cast<int>(sendMsgLen + msgLen)) { 273 std::cerr << "ClientSendReq::client send msg req failed, error = " 274 << errno << ", desc = " << strerror(errno) << std::endl; 275 return false; 276 } 277 std::cout << "ClientRecvWSUpgradeRsp::client send msg req success..." << std::endl; 278 return true; 279 } 280 Close()281 void Close() 282 { 283 if (socketState_ == SocketState::UNINITED) { 284 return; 285 } 286 #if defined(OHOS_PLATFORM) 287 shutdown(client_, SHUT_RDWR); 288 #endif 289 close(client_); 290 client_ = -1; 291 socketState_ = SocketState::UNINITED; 292 } 293 294 private: 295 static constexpr char CLIENT_WEBSOCKET_UPGRADE_REQ[] = "GET / HTTP/1.1\r\n" 296 "Connection: Upgrade\r\n" 297 "Pragma: no-cache\r\n" 298 "Cache-Control: no-cache\r\n" 299 "Upgrade: websocket\r\n" 300 "Sec-WebSocket-Version: 13\r\n" 301 "Accept-Encoding: gzip, deflate, br\r\n" 302 "Sec-WebSocket-Key: 64b4B+s5JDlgkdg7NekJ+g==\r\n" 303 "Sec-WebSocket-Extensions: permessage-deflate\r\n"; 304 static constexpr int32_t CLIENT_WEBSOCKET_UPGRADE_RSP_LEN = 129; 305 static constexpr char MASK_KEY[SOCKET_MASK_LEN + 1] = "abcd"; 306 static constexpr int NET_SUCCESS = 1; 307 }; 308 309 #if defined(OHOS_PLATFORM) 310 static constexpr char UNIX_DOMAIN_PATH[] = "server.sock"; 311 #endif 312 static constexpr char HELLO_SERVER[] = "hello server"; 313 static constexpr char HELLO_CLIENT[] = "hello client"; 314 static constexpr char SERVER_OK[] = "server ok"; 315 static constexpr char CLIENT_OK[] = "client ok"; 316 static constexpr char QUIT[] = "quit"; 317 static const std::string LONG_MSG; 318 static const std::string LONG_LONG_MSG; 319 }; 320 321 const std::string WebSocketTest::LONG_MSG = std::string(1000, 'f'); 322 const std::string WebSocketTest::LONG_LONG_MSG = std::string(0xfffff, 'f'); 323 324 HWTEST_F(WebSocketTest, ConnectWebSocketTest, testing::ext::TestSize.Level0) 325 { 326 WebSocket serverSocket; 327 bool ret = false; 328 #if defined(OHOS_PLATFORM) 329 int appPid = getpid(); 330 ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); 331 #else 332 ret = serverSocket.InitTcpWebSocket(9230, 5); 333 #endif 334 ASSERT_TRUE(ret); 335 pid_t pid = fork(); 336 if (pid == 0) { 337 // subprocess, handle client connect and recv/send message 338 // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess, 339 // so testcase still success finally. 340 ClientWebSocket clientSocket; 341 bool retClient = false; 342 #if defined(OHOS_PLATFORM) 343 retClient = clientSocket.ClientConnectUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); 344 #else 345 retClient = clientSocket.ClientConnectTcpWebSocket(5); 346 #endif 347 ASSERT_TRUE(retClient); 348 retClient = clientSocket.ClientSendWSUpgradeReq(); 349 ASSERT_TRUE(retClient); 350 retClient = clientSocket.ClientRecvWSUpgradeRsp(); 351 ASSERT_TRUE(retClient); 352 retClient = clientSocket.ClientSendReq(HELLO_SERVER); 353 EXPECT_TRUE(retClient); 354 std::string recv = clientSocket.Decode(); 355 EXPECT_EQ(strcmp(recv.c_str(), HELLO_CLIENT), 0); 356 if (strcmp(recv.c_str(), HELLO_CLIENT) == 0) { 357 retClient = clientSocket.ClientSendReq(CLIENT_OK); 358 EXPECT_TRUE(retClient); 359 } 360 retClient = clientSocket.ClientSendReq(LONG_MSG); 361 EXPECT_TRUE(retClient); 362 recv = clientSocket.Decode(); 363 EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0); 364 if (strcmp(recv.c_str(), SERVER_OK) == 0) { 365 retClient = clientSocket.ClientSendReq(CLIENT_OK); 366 EXPECT_TRUE(retClient); 367 } 368 retClient = clientSocket.ClientSendReq(LONG_LONG_MSG); 369 EXPECT_TRUE(retClient); 370 recv = clientSocket.Decode(); 371 EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0); 372 if (strcmp(recv.c_str(), SERVER_OK) == 0) { 373 retClient = clientSocket.ClientSendReq(CLIENT_OK); 374 EXPECT_TRUE(retClient); 375 } 376 retClient = clientSocket.ClientSendReq(QUIT); 377 EXPECT_TRUE(retClient); 378 clientSocket.Close(); 379 exit(0); 380 } else if (pid > 0) { 381 // mainprocess, handle server connect and recv/send message 382 #if defined(OHOS_PLATFORM) 383 ret = serverSocket.ConnectUnixWebSocket(); 384 #else 385 ret = serverSocket.ConnectTcpWebSocket(); 386 #endif 387 ASSERT_TRUE(ret); 388 std::string recv = serverSocket.Decode(); 389 EXPECT_EQ(strcmp(recv.c_str(), HELLO_SERVER), 0); 390 serverSocket.SendReply(HELLO_CLIENT); 391 recv = serverSocket.Decode(); 392 EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0); 393 recv = serverSocket.Decode(); 394 EXPECT_EQ(strcmp(recv.c_str(), LONG_MSG.c_str()), 0); 395 serverSocket.SendReply(SERVER_OK); 396 recv = serverSocket.Decode(); 397 EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0); 398 recv = serverSocket.Decode(); 399 EXPECT_EQ(strcmp(recv.c_str(), LONG_LONG_MSG.c_str()), 0); 400 serverSocket.SendReply(SERVER_OK); 401 recv = serverSocket.Decode(); 402 EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0); 403 recv = serverSocket.Decode(); 404 EXPECT_EQ(strcmp(recv.c_str(), QUIT), 0); 405 serverSocket.Close(); 406 // sleep ensure that linux os core can really release resource 407 sleep(3); 408 } else { 409 std::cerr << "ConnectWebSocketTest::fork failed, error = " 410 << errno << ", desc = " << strerror(errno) << std::endl; 411 } 412 } 413 414 HWTEST_F(WebSocketTest, ReConnectWebSocketTest, testing::ext::TestSize.Level0) 415 { 416 WebSocket serverSocket; 417 bool ret = false; 418 #if defined(OHOS_PLATFORM) 419 int appPid = getpid(); 420 ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); 421 #else 422 ret = serverSocket.InitTcpWebSocket(9230, 5); 423 #endif 424 ASSERT_TRUE(ret); 425 for (int i = 0; i < 5; i++) { 426 pid_t pid = fork(); 427 if (pid == 0) { 428 // subprocess, handle client connect and recv/send message 429 // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess, 430 // so testcase still success finally. 431 ClientWebSocket clientSocket; 432 bool retClient = false; 433 #if defined(OHOS_PLATFORM) 434 retClient = clientSocket.ClientConnectUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); 435 #else 436 retClient = clientSocket.ClientConnectTcpWebSocket(5); 437 #endif 438 ASSERT_TRUE(retClient); 439 retClient = clientSocket.ClientSendWSUpgradeReq(); 440 ASSERT_TRUE(retClient); 441 retClient = clientSocket.ClientRecvWSUpgradeRsp(); 442 ASSERT_TRUE(retClient); 443 retClient = clientSocket.ClientSendReq(HELLO_SERVER + std::to_string(i)); 444 EXPECT_TRUE(retClient); 445 std::string recv = clientSocket.Decode(); 446 EXPECT_EQ(strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()), 0); 447 if (strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()) == 0) { 448 retClient = clientSocket.ClientSendReq(CLIENT_OK + std::to_string(i)); 449 EXPECT_TRUE(retClient); 450 } 451 clientSocket.Close(); 452 exit(0); 453 } else if (pid > 0) { 454 // mainprocess, handle server connect and recv/send message 455 #if defined(OHOS_PLATFORM) 456 ret = serverSocket.ConnectUnixWebSocket(); 457 #else 458 ret = serverSocket.ConnectTcpWebSocket(); 459 #endif 460 ASSERT_TRUE(ret); 461 std::string recv = serverSocket.Decode(); 462 EXPECT_EQ(strcmp(recv.c_str(), (HELLO_SERVER + std::to_string(i)).c_str()), 0); 463 serverSocket.SendReply(HELLO_CLIENT + std::to_string(i)); 464 recv = serverSocket.Decode(); 465 EXPECT_EQ(strcmp(recv.c_str(), (CLIENT_OK + std::to_string(i)).c_str()), 0); 466 while (serverSocket.IsConnected()) { 467 serverSocket.Decode(); 468 } 469 } else { 470 std::cerr << "ReConnectWebSocketTest::fork failed, error = " 471 << errno << ", desc = " << strerror(errno) << std::endl; 472 } 473 } 474 serverSocket.Close(); 475 } 476 } // namespace panda::test 477