1 /* 2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H 18 19 #include <any> 20 #include <condition_variable> 21 #include <cstring> 22 #include <functional> 23 #include <map> 24 #include <thread> 25 #include <tuple> 26 #include <unistd.h> 27 #include <vector> 28 29 #include "extra_options_base.h" 30 #include "net_address.h" 31 #include "proxy_options.h" 32 #include "socket_error.h" 33 #include "socket_remote_info.h" 34 #include "socket_state_base.h" 35 #include "tcp_connect_options.h" 36 #include "tcp_extra_options.h" 37 #include "tcp_send_options.h" 38 #include "tls.h" 39 #include "tls_certificate.h" 40 #include "tls_configuration.h" 41 #include "tls_context.h" 42 #include "tls_key.h" 43 44 namespace OHOS { 45 namespace NetStack { 46 namespace TlsSocket { 47 48 using BindCallback = std::function<void(int32_t errorNumber)>; 49 using ConnectCallback = std::function<void(int32_t errorNumber)>; 50 using SendCallback = std::function<void(int32_t errorNumber)>; 51 using CloseCallback = std::function<void(int32_t errorNumber)>; 52 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 53 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 54 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>; 55 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>; 56 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 57 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 58 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>; 59 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>; 60 using GetSignatureAlgorithmsCallback = 61 std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>; 62 63 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 64 using OnConnectCallback = std::function<void(void)>; 65 using OnCloseCallback = std::function<void(void)>; 66 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>; 67 68 using CheckServerIdentity = 69 std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>; 70 71 constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http1.1"; 72 constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2"; 73 74 constexpr size_t MAX_ERR_LEN = 1024; 75 76 /** 77 * Parameters required during communication 78 */ 79 class TLSSecureOptions { 80 public: 81 TLSSecureOptions() = default; 82 ~TLSSecureOptions() = default; 83 84 TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions); 85 TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions); 86 /** 87 * Set root CA Chain to verify the server cert 88 * @param caChain root certificate chain used to validate server certificates 89 */ 90 void SetCaChain(const std::vector<std::string> &caChain); 91 92 /** 93 * Set digital certificate for server verification 94 * @param cert digital certificate sent to the server to verify validity 95 */ 96 void SetCert(const std::string &cert); 97 98 /** 99 * Set key to decrypt server data 100 * @param keyChain key used to decrypt server data 101 */ 102 void SetKey(const SecureData &key); 103 104 /** 105 * Set the password to read the private key 106 * @param keyPass read the password of the private key 107 */ 108 void SetKeyPass(const SecureData &keyPass); 109 110 /** 111 * Set the protocol used in communication 112 * @param protocolChain protocol version number used 113 */ 114 void SetProtocolChain(const std::vector<std::string> &protocolChain); 115 116 /** 117 * Whether the peer cipher suite is preferred for communication 118 * @param useRemoteCipherPrefer whether the peer cipher suite is preferred 119 */ 120 void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer); 121 122 /** 123 * Encryption algorithm used in communication 124 * @param signatureAlgorithms encryption algorithm e.g: rsa 125 */ 126 void SetSignatureAlgorithms(const std::string &signatureAlgorithms); 127 128 /** 129 * Crypto suite used in communication 130 * @param cipherSuite cipher suite e.g:AES256-SHA256 131 */ 132 void SetCipherSuite(const std::string &cipherSuite); 133 134 /** 135 * Set a revoked certificate 136 * @param crlChain certificate Revocation List 137 */ 138 void SetCrlChain(const std::vector<std::string> &crlChain); 139 140 /** 141 * Get root CA Chain to verify the server cert 142 * @return root CA chain 143 */ 144 [[nodiscard]] const std::vector<std::string> &GetCaChain() const; 145 146 /** 147 * Obtain a certificate to send to the server for checking 148 * @return digital certificate obtained 149 */ 150 [[nodiscard]] const std::string &GetCert() const; 151 152 /** 153 * Obtain the private key in the communication process 154 * @return private key during communication 155 */ 156 [[nodiscard]] const SecureData &GetKey() const; 157 158 /** 159 * Get the password to read the private key 160 * @return read the password of the private key 161 */ 162 [[nodiscard]] const SecureData &GetKeyPass() const; 163 164 /** 165 * Get the protocol of the communication process 166 * @return protocol of communication process 167 */ 168 [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const; 169 170 /** 171 * Is the remote cipher suite being used for communication 172 * @return is use Remote Cipher Prefer 173 */ 174 [[nodiscard]] bool UseRemoteCipherPrefer() const; 175 176 /** 177 * Obtain the encryption algorithm used in the communication process 178 * @return encryption algorithm used in communication 179 */ 180 [[nodiscard]] const std::string &GetSignatureAlgorithms() const; 181 182 /** 183 * Obtain the cipher suite used in communication 184 * @return crypto suite used in communication 185 */ 186 [[nodiscard]] const std::string &GetCipherSuite() const; 187 188 /** 189 * Get revoked certificate chain 190 * @return revoked certificate chain 191 */ 192 [[nodiscard]] const std::vector<std::string> &GetCrlChain() const; 193 194 void SetVerifyMode(VerifyMode verifyMode); 195 196 [[nodiscard]] VerifyMode GetVerifyMode() const; 197 198 private: 199 std::vector<std::string> caChain_; 200 std::string cert_; 201 SecureData key_; 202 SecureData keyPass_; 203 std::vector<std::string> protocolChain_; 204 bool useRemoteCipherPrefer_ = false; 205 std::string signatureAlgorithms_; 206 std::string cipherSuite_; 207 std::vector<std::string> crlChain_; 208 VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE; 209 }; 210 211 /** 212 * Some options required during tls connection 213 */ 214 class TLSConnectOptions { 215 public: 216 friend class TLSSocketExec; 217 /** 218 * Communication parameters required for connection establishment 219 * @param address communication parameters during connection 220 */ 221 void SetNetAddress(const Socket::NetAddress &address); 222 223 /** 224 * Parameters required during communication 225 * @param tlsSecureOptions certificate and other relevant parameters 226 */ 227 void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions); 228 229 /** 230 * Set the callback function to check the validity of the server 231 * @param checkServerIdentity callback function passed in by API caller 232 */ 233 void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity); 234 235 /** 236 * Set application layer protocol negotiation 237 * @param alpnProtocols application layer protocol negotiation 238 */ 239 void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 240 241 /** 242 * Set whether to skip remote validation 243 * @param skipRemoteValidation flag to choose whether to skip validation 244 */ 245 void SetSkipRemoteValidation(bool skipRemoteValidation); 246 247 /** 248 * Obtain the network address of the communication process 249 * @return network address 250 */ 251 [[nodiscard]] Socket::NetAddress GetNetAddress() const; 252 253 /** 254 * Obtain the parameters required in the communication process 255 * @return certificate and other relevant parameters 256 */ 257 [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const; 258 259 /** 260 * Get the check server ID callback function passed in by the API caller 261 * @return check the server identity callback function 262 */ 263 [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const; 264 265 /** 266 * Obtain the application layer protocol negotiation in the communication process 267 * @return application layer protocol negotiation 268 */ 269 [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const; 270 271 /** 272 * Get the choice of whether to skip remote validaion 273 * @return skipRemoteValidaion result 274 */ 275 [[nodiscard]] bool GetSkipRemoteValidation() const; 276 277 void SetHostName(const std::string &hostName); 278 [[nodiscard]] std::string GetHostName() const; 279 280 std::shared_ptr<Socket::ProxyOptions> proxyOptions_{nullptr}; 281 282 private: 283 Socket::NetAddress address_; 284 TLSSecureOptions tlsSecureOptions_; 285 CheckServerIdentity checkServerIdentity_; 286 std::vector<std::string> alpnProtocols_; 287 bool skipRemoteValidation_ = false; 288 std::string hostName_; 289 }; 290 291 /** 292 * TLS socket interface class 293 */ 294 class TLSSocket : public std::enable_shared_from_this<TLSSocket> { 295 public: 296 TLSSocket(const TLSSocket &) = delete; 297 TLSSocket(TLSSocket &&) = delete; 298 299 TLSSocket &operator=(const TLSSocket &) = delete; 300 TLSSocket &operator=(TLSSocket &&) = delete; 301 302 TLSSocket() = default; 303 ~TLSSocket() = default; 304 TLSSocket(int sockFd)305 explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {} 306 307 /** 308 * Create a socket and bind to the address specified by address 309 * @param address ip address 310 * @param callback callback to the caller if bind ok or not 311 */ 312 void Bind(Socket::NetAddress &address, const BindCallback &callback); 313 314 /** 315 * Establish a secure connection based on the created socket 316 * @param tlsConnectOptions some options required during tls connection 317 * @param callback callback to the caller if connect ok or not 318 */ 319 void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback); 320 321 /** 322 * Send data based on the created socket 323 * @param tcpSendOptions some options required during tcp data transmission 324 * @param callback callback to the caller if send ok or not 325 */ 326 void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback); 327 328 /** 329 * Disconnect by releasing the socket when communicating 330 * @param callback callback to the caller 331 */ 332 void Close(const CloseCallback &callback); 333 334 /** 335 * Get the peer network address 336 * @param callback callback to the caller 337 */ 338 void GetRemoteAddress(const GetRemoteAddressCallback &callback); 339 340 /** 341 * Get the status of the current socket 342 * @param callback callback to the caller 343 */ 344 void GetState(const GetStateCallback &callback); 345 346 /** 347 * Gets or sets the options associated with the current socket 348 * @param tcpExtraOptions options associated with the current socket 349 * @param callback callback to the caller 350 */ 351 void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback); 352 353 /** 354 * Get a local digital certificate 355 * @param callback callback to the caller 356 */ 357 void GetCertificate(const GetCertificateCallback &callback); 358 359 /** 360 * Get the peer digital certificate 361 * @param needChain need chain 362 * @param callback callback to the caller 363 */ 364 void GetRemoteCertificate(const GetRemoteCertificateCallback &callback); 365 366 /** 367 * Obtain the protocol used in communication 368 * @param callback callback to the caller 369 */ 370 void GetProtocol(const GetProtocolCallback &callback); 371 372 /** 373 * Obtain the cipher suite used in communication 374 * @param callback callback to the caller 375 */ 376 void GetCipherSuite(const GetCipherSuiteCallback &callback); 377 378 /** 379 * Obtain the encryption algorithm used in the communication process 380 * @param callback callback to the caller 381 */ 382 void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback); 383 384 /** 385 * Register a callback which is called when message is received 386 * @param onMessageCallback callback which is called when message is received 387 */ 388 void OnMessage(const OnMessageCallback &onMessageCallback); 389 390 /** 391 * Register the callback that is called when the connection is established 392 * @param onConnectCallback callback invoked when connection is established 393 */ 394 void OnConnect(const OnConnectCallback &onConnectCallback); 395 396 /** 397 * Register the callback that is called when the connection is disconnected 398 * @param onCloseCallback callback invoked when disconnected 399 */ 400 void OnClose(const OnCloseCallback &onCloseCallback); 401 402 /** 403 * Register the callback that is called when an error occurs 404 * @param onErrorCallback callback invoked when an error occurs 405 */ 406 void OnError(const OnErrorCallback &onErrorCallback); 407 408 /** 409 * Unregister the callback which is called when message is received 410 */ 411 void OffMessage(); 412 413 /** 414 * Off Connect 415 */ 416 void OffConnect(); 417 418 /** 419 * Off Close 420 */ 421 void OffClose(); 422 423 /** 424 * Off Error 425 */ 426 void OffError(); 427 428 /** 429 * Get the socket file description of the server 430 */ 431 int GetSocketFd(); 432 433 /** 434 * Set the current socket file description address of the server 435 */ 436 void SetLocalAddress(const Socket::NetAddress &address); 437 438 /** 439 * Get the current socket file description address of the server 440 */ 441 Socket::NetAddress GetLocalAddress(); 442 443 void ExecTlsGetAddr( 444 const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len); 445 446 bool ExecTlsSetSockBlockFlag(int sock, bool noneBlock); 447 448 bool IsExtSock() const; 449 450 private: 451 class TLSSocketInternal final { 452 public: 453 TLSSocketInternal() = default; 454 ~TLSSocketInternal() = default; 455 456 /** 457 * Establish an encrypted connection on the specified socket 458 * @param sock socket for establishing encrypted connection 459 * @param options some options required during tls connection 460 * @param isExtSock socket fd is originated from external source when constructing tls socket 461 * @return whether the encrypted connection is successfully established 462 */ 463 bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock); 464 465 /** 466 * Set the configuration items for establishing encrypted connections 467 * @param config configuration item when establishing encrypted connection 468 */ 469 void SetTlsConfiguration(const TLSConnectOptions &config); 470 471 /** 472 * Send data through an established encrypted connection 473 * @param data data sent over an established encrypted connection 474 * @return whether the data is successfully sent to the server 475 */ 476 bool Send(const std::string &data); 477 478 /** 479 * Receive the data sent by the server through the established encrypted connection 480 * @param buffer receive the data sent by the server 481 * @param maxBufferSize the size of the data received from the server 482 * @return whether the data sent by the server is successfully received 483 */ 484 int Recv(char *buffer, int maxBufferSize); 485 486 /** 487 * Disconnect encrypted connection 488 * @return whether the encrypted connection was successfully disconnected 489 */ 490 bool Close(); 491 492 /** 493 * Set the application layer negotiation protocol in the encrypted communication process 494 * @param alpnProtocols application layer negotiation protocol 495 * @return set whether the application layer negotiation protocol is successful during encrypted communication 496 */ 497 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 498 499 /** 500 * Storage of server communication related network information 501 * @param remoteInfo communication related network information 502 */ 503 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 504 505 /** 506 * convert the code to ssl error code 507 * @return the value for ssl error code. 508 */ 509 int ConvertSSLError(void); 510 511 /** 512 * Get configuration options for encrypted communication process 513 * @return configuration options for encrypted communication processes 514 */ 515 [[nodiscard]] TLSConfiguration GetTlsConfiguration() const; 516 517 /** 518 * Obtain the cipher suite during encrypted communication 519 * @return crypto suite used in encrypted communication 520 */ 521 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 522 523 /** 524 * Obtain the peer certificate used in encrypted communication 525 * @return peer certificate used in encrypted communication 526 */ 527 [[nodiscard]] std::string GetRemoteCertificate() const; 528 529 /** 530 * Obtain the peer certificate used in encrypted communication 531 * @return peer certificate serialization data used in encrypted communication 532 */ 533 [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const; 534 535 /** 536 * Obtain the certificate used in encrypted communication 537 * @return certificate serialization data used in encrypted communication 538 */ 539 [[nodiscard]] const X509CertRawData &GetCertificate() const; 540 541 /** 542 * Get the encryption algorithm used in encrypted communication 543 * @return encryption algorithm used in encrypted communication 544 */ 545 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 546 547 /** 548 * Obtain the communication protocol used in encrypted communication 549 * @return communication protocol used in encrypted communication 550 */ 551 [[nodiscard]] std::string GetProtocol() const; 552 553 /** 554 * Set the information about the shared signature algorithm supported by peers during encrypted communication 555 * @return information about peer supported shared signature algorithms 556 */ 557 [[nodiscard]] bool SetSharedSigals(); 558 559 /** 560 * Obtain the ssl used in encrypted communication 561 * @return SSL used in encrypted communication 562 */ 563 [[nodiscard]] ssl_st *GetSSL(); 564 565 private: 566 bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd); 567 bool StartTlsConnected(const TLSConnectOptions &options); 568 bool CreatTlsContext(); 569 bool StartShakingHands(const TLSConnectOptions &options); 570 bool GetRemoteCertificateFromPeer(); 571 bool SetRemoteCertRawData(); 572 bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize); 573 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 574 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 575 const X509 *x509Certificates); 576 577 private: 578 std::mutex mutexForSsl_; 579 ssl_st *ssl_ = nullptr; 580 X509 *peerX509_ = nullptr; 581 uint16_t port_ = 0; 582 sa_family_t family_ = 0; 583 int32_t socketDescriptor_ = 0; 584 585 TLSContext tlsContext_; 586 TLSConfiguration configuration_; 587 Socket::NetAddress address_; 588 X509CertRawData remoteRawData_; 589 590 std::string hostName_; 591 std::string remoteCert_; 592 593 std::vector<std::string> signatureAlgorithms_; 594 std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr; 595 }; 596 597 private: 598 TLSSocketInternal tlsSocketInternal_; 599 600 static std::string MakeAddressString(sockaddr *addr); 601 602 static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 603 socklen_t *len); 604 605 void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo); 606 void CallOnConnectCallback(); 607 void CallOnCloseCallback(); 608 void CallOnErrorCallback(int32_t err, const std::string &errString); 609 610 void CallBindCallback(int32_t err, BindCallback callback); 611 void CallConnectCallback(int32_t err, ConnectCallback callback); 612 void CallSendCallback(int32_t err, SendCallback callback); 613 void CallCloseCallback(int32_t err, CloseCallback callback); 614 void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address, 615 GetRemoteAddressCallback callback); 616 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback); 617 void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback); 618 void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback); 619 void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert, 620 GetRemoteCertificateCallback callback); 621 void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback); 622 void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite, 623 GetCipherSuiteCallback callback); 624 void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms, 625 GetSignatureAlgorithmsCallback callback); 626 627 int ReadMessage(); 628 void StartReadMessage(); 629 630 void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback); 631 void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback); 632 633 [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const; 634 [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const; 635 636 void MakeIpSocket(sa_family_t family); 637 638 template<class T> DealCallback(int32_t err,T & callback)639 void DealCallback(int32_t err, T &callback) 640 { 641 if (callback) { 642 callback(err); 643 } 644 } 645 646 private: 647 static constexpr const size_t MAX_ERROR_LEN = 128; 648 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 649 650 OnMessageCallback onMessageCallback_; 651 OnConnectCallback onConnectCallback_; 652 OnCloseCallback onCloseCallback_; 653 OnErrorCallback onErrorCallback_; 654 655 std::mutex mutex_; 656 std::mutex recvMutex_; 657 std::mutex cvMutex_; 658 bool isRunning_ = false; 659 bool isRunOver_ = true; 660 std::condition_variable cvSslFree_; 661 int sockFd_ = -1; 662 bool isExtSock_ = false; 663 Socket::NetAddress localAddress_; 664 }; 665 } // namespace TlsSocket 666 } // namespace NetStack 667 } // namespace OHOS 668 669 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H 670