1 /* 2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H 18 19 #include <any> 20 #include <condition_variable> 21 #include <cstring> 22 #include <functional> 23 #include <map> 24 #include <shared_mutex> 25 #include <thread> 26 #include <tuple> 27 #include <unistd.h> 28 #include <vector> 29 30 #include "extra_options_base.h" 31 #include "net_address.h" 32 #include "proxy_options.h" 33 #include "socket_error.h" 34 #include "socket_remote_info.h" 35 #include "socket_state_base.h" 36 #include "tcp_connect_options.h" 37 #include "tcp_extra_options.h" 38 #include "tcp_send_options.h" 39 #include "tls.h" 40 #include "tls_certificate.h" 41 #include "tls_configuration.h" 42 #include "tls_context.h" 43 #include "tls_key.h" 44 45 namespace OHOS { 46 namespace NetStack { 47 namespace TlsSocket { 48 49 using BindCallback = std::function<void(int32_t errorNumber)>; 50 using ConnectCallback = std::function<void(int32_t errorNumber)>; 51 using SendCallback = std::function<void(int32_t errorNumber)>; 52 using CloseCallback = std::function<void(int32_t errorNumber)>; 53 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 54 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 55 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>; 56 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>; 57 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 58 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 59 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>; 60 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>; 61 using GetSignatureAlgorithmsCallback = 62 std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>; 63 64 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 65 using OnConnectCallback = std::function<void(void)>; 66 using OnCloseCallback = std::function<void(void)>; 67 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>; 68 69 using CheckServerIdentity = 70 std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>; 71 72 constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http1.1"; 73 constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2"; 74 75 constexpr size_t MAX_ERR_LEN = 1024; 76 77 /** 78 * Parameters required during communication 79 */ 80 class TLSSecureOptions { 81 public: 82 TLSSecureOptions() = default; 83 ~TLSSecureOptions() = default; 84 85 TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions); 86 TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions); 87 /** 88 * Set root CA Chain to verify the server cert 89 * @param caChain root certificate chain used to validate server certificates 90 */ 91 void SetCaChain(const std::vector<std::string> &caChain); 92 93 /** 94 * Set digital certificate for server verification 95 * @param cert digital certificate sent to the server to verify validity 96 */ 97 void SetCert(const std::string &cert); 98 99 /** 100 * Set key to decrypt server data 101 * @param keyChain key used to decrypt server data 102 */ 103 void SetKey(const SecureData &key); 104 105 /** 106 * Set the password to read the private key 107 * @param keyPass read the password of the private key 108 */ 109 void SetKeyPass(const SecureData &keyPass); 110 111 /** 112 * Set the protocol used in communication 113 * @param protocolChain protocol version number used 114 */ 115 void SetProtocolChain(const std::vector<std::string> &protocolChain); 116 117 /** 118 * Whether the peer cipher suite is preferred for communication 119 * @param useRemoteCipherPrefer whether the peer cipher suite is preferred 120 */ 121 void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer); 122 123 /** 124 * Encryption algorithm used in communication 125 * @param signatureAlgorithms encryption algorithm e.g: rsa 126 */ 127 void SetSignatureAlgorithms(const std::string &signatureAlgorithms); 128 129 /** 130 * Crypto suite used in communication 131 * @param cipherSuite cipher suite e.g:AES256-SHA256 132 */ 133 void SetCipherSuite(const std::string &cipherSuite); 134 135 /** 136 * Set a revoked certificate 137 * @param crlChain certificate Revocation List 138 */ 139 void SetCrlChain(const std::vector<std::string> &crlChain); 140 141 /** 142 * Get root CA Chain to verify the server cert 143 * @return root CA chain 144 */ 145 [[nodiscard]] const std::vector<std::string> &GetCaChain() const; 146 147 /** 148 * Obtain a certificate to send to the server for checking 149 * @return digital certificate obtained 150 */ 151 [[nodiscard]] const std::string &GetCert() const; 152 153 /** 154 * Obtain the private key in the communication process 155 * @return private key during communication 156 */ 157 [[nodiscard]] const SecureData &GetKey() const; 158 159 /** 160 * Get the password to read the private key 161 * @return read the password of the private key 162 */ 163 [[nodiscard]] const SecureData &GetKeyPass() const; 164 165 /** 166 * Get the protocol of the communication process 167 * @return protocol of communication process 168 */ 169 [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const; 170 171 /** 172 * Is the remote cipher suite being used for communication 173 * @return is use Remote Cipher Prefer 174 */ 175 [[nodiscard]] bool UseRemoteCipherPrefer() const; 176 177 /** 178 * Obtain the encryption algorithm used in the communication process 179 * @return encryption algorithm used in communication 180 */ 181 [[nodiscard]] const std::string &GetSignatureAlgorithms() const; 182 183 /** 184 * Obtain the cipher suite used in communication 185 * @return crypto suite used in communication 186 */ 187 [[nodiscard]] const std::string &GetCipherSuite() const; 188 189 /** 190 * Get revoked certificate chain 191 * @return revoked certificate chain 192 */ 193 [[nodiscard]] const std::vector<std::string> &GetCrlChain() const; 194 195 void SetVerifyMode(VerifyMode verifyMode); 196 197 [[nodiscard]] VerifyMode GetVerifyMode() const; 198 199 private: 200 std::vector<std::string> caChain_; 201 std::string cert_; 202 SecureData key_; 203 SecureData keyPass_; 204 std::vector<std::string> protocolChain_; 205 bool useRemoteCipherPrefer_ = false; 206 std::string signatureAlgorithms_; 207 std::string cipherSuite_; 208 std::vector<std::string> crlChain_; 209 VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE; 210 }; 211 212 /** 213 * Some options required during tls connection 214 */ 215 class TLSConnectOptions { 216 public: 217 friend class TLSSocketExec; 218 /** 219 * Communication parameters required for connection establishment 220 * @param address communication parameters during connection 221 */ 222 void SetNetAddress(const Socket::NetAddress &address); 223 224 /** 225 * Parameters required during communication 226 * @param tlsSecureOptions certificate and other relevant parameters 227 */ 228 void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions); 229 230 /** 231 * Set the callback function to check the validity of the server 232 * @param checkServerIdentity callback function passed in by API caller 233 */ 234 void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity); 235 236 /** 237 * Set application layer protocol negotiation 238 * @param alpnProtocols application layer protocol negotiation 239 */ 240 void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 241 242 /** 243 * Set whether to skip remote validation 244 * @param skipRemoteValidation flag to choose whether to skip validation 245 */ 246 void SetSkipRemoteValidation(bool skipRemoteValidation); 247 248 /** 249 * Obtain the network address of the communication process 250 * @return network address 251 */ 252 [[nodiscard]] Socket::NetAddress GetNetAddress() const; 253 254 /** 255 * Obtain the parameters required in the communication process 256 * @return certificate and other relevant parameters 257 */ 258 [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const; 259 260 /** 261 * Get the check server ID callback function passed in by the API caller 262 * @return check the server identity callback function 263 */ 264 [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const; 265 266 /** 267 * Obtain the application layer protocol negotiation in the communication process 268 * @return application layer protocol negotiation 269 */ 270 [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const; 271 272 /** 273 * Get the choice of whether to skip remote validaion 274 * @return skipRemoteValidaion result 275 */ 276 [[nodiscard]] bool GetSkipRemoteValidation() const; 277 278 void SetHostName(const std::string &hostName); 279 [[nodiscard]] std::string GetHostName() const; 280 281 std::shared_ptr<Socket::ProxyOptions> proxyOptions_{nullptr}; 282 283 private: 284 Socket::NetAddress address_; 285 TLSSecureOptions tlsSecureOptions_; 286 CheckServerIdentity checkServerIdentity_; 287 std::vector<std::string> alpnProtocols_; 288 bool skipRemoteValidation_ = false; 289 std::string hostName_; 290 }; 291 292 /** 293 * TLS socket interface class 294 */ 295 class TLSSocket : public std::enable_shared_from_this<TLSSocket> { 296 public: 297 TLSSocket(const TLSSocket &) = delete; 298 TLSSocket(TLSSocket &&) = delete; 299 300 TLSSocket &operator=(const TLSSocket &) = delete; 301 TLSSocket &operator=(TLSSocket &&) = delete; 302 303 TLSSocket() = default; 304 ~TLSSocket() = default; 305 TLSSocket(int sockFd)306 explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {} 307 308 /** 309 * Create a socket and bind to the address specified by address 310 * @param address ip address 311 * @param callback callback to the caller if bind ok or not 312 */ 313 void Bind(Socket::NetAddress &address, const BindCallback &callback); 314 315 /** 316 * Establish a secure connection based on the created socket 317 * @param tlsConnectOptions some options required during tls connection 318 * @param callback callback to the caller if connect ok or not 319 */ 320 void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback); 321 322 /** 323 * Send data based on the created socket 324 * @param tcpSendOptions some options required during tcp data transmission 325 * @param callback callback to the caller if send ok or not 326 */ 327 void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback); 328 329 /** 330 * Disconnect by releasing the socket when communicating 331 * @param callback callback to the caller 332 */ 333 void Close(const CloseCallback &callback); 334 335 /** 336 * Get the peer network address 337 * @param callback callback to the caller 338 */ 339 void GetRemoteAddress(const GetRemoteAddressCallback &callback); 340 341 /** 342 * Get the status of the current socket 343 * @param callback callback to the caller 344 */ 345 void GetState(const GetStateCallback &callback); 346 347 /** 348 * Gets or sets the options associated with the current socket 349 * @param tcpExtraOptions options associated with the current socket 350 * @param callback callback to the caller 351 */ 352 void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback); 353 354 /** 355 * Get a local digital certificate 356 * @param callback callback to the caller 357 */ 358 void GetCertificate(const GetCertificateCallback &callback); 359 360 /** 361 * Get the peer digital certificate 362 * @param needChain need chain 363 * @param callback callback to the caller 364 */ 365 void GetRemoteCertificate(const GetRemoteCertificateCallback &callback); 366 367 /** 368 * Obtain the protocol used in communication 369 * @param callback callback to the caller 370 */ 371 void GetProtocol(const GetProtocolCallback &callback); 372 373 /** 374 * Obtain the cipher suite used in communication 375 * @param callback callback to the caller 376 */ 377 void GetCipherSuite(const GetCipherSuiteCallback &callback); 378 379 /** 380 * Obtain the encryption algorithm used in the communication process 381 * @param callback callback to the caller 382 */ 383 void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback); 384 385 /** 386 * Register a callback which is called when message is received 387 * @param onMessageCallback callback which is called when message is received 388 */ 389 void OnMessage(const OnMessageCallback &onMessageCallback); 390 391 /** 392 * Register the callback that is called when the connection is established 393 * @param onConnectCallback callback invoked when connection is established 394 */ 395 void OnConnect(const OnConnectCallback &onConnectCallback); 396 397 /** 398 * Register the callback that is called when the connection is disconnected 399 * @param onCloseCallback callback invoked when disconnected 400 */ 401 void OnClose(const OnCloseCallback &onCloseCallback); 402 403 /** 404 * Register the callback that is called when an error occurs 405 * @param onErrorCallback callback invoked when an error occurs 406 */ 407 void OnError(const OnErrorCallback &onErrorCallback); 408 409 /** 410 * Unregister the callback which is called when message is received 411 */ 412 void OffMessage(); 413 414 /** 415 * Off Connect 416 */ 417 void OffConnect(); 418 419 /** 420 * Off Close 421 */ 422 void OffClose(); 423 424 /** 425 * Off Error 426 */ 427 void OffError(); 428 429 /** 430 * Get the socket file description of the server 431 */ 432 int GetSocketFd(); 433 434 /** 435 * Set the current socket file description address of the server 436 */ 437 void SetLocalAddress(const Socket::NetAddress &address); 438 439 /** 440 * Get the current socket file description address of the server 441 */ 442 Socket::NetAddress GetLocalAddress(); 443 444 void ExecTlsGetAddr( 445 const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len); 446 447 bool ExecTlsSetSockBlockFlag(int sock, bool noneBlock); 448 449 bool IsExtSock() const; 450 451 private: 452 class TLSSocketInternal final { 453 public: 454 TLSSocketInternal() = default; 455 ~TLSSocketInternal() = default; 456 457 /** 458 * Establish an encrypted connection on the specified socket 459 * @param sock socket for establishing encrypted connection 460 * @param options some options required during tls connection 461 * @param isExtSock socket fd is originated from external source when constructing tls socket 462 * @return whether the encrypted connection is successfully established 463 */ 464 bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock); 465 466 /** 467 * Set the configuration items for establishing encrypted connections 468 * @param config configuration item when establishing encrypted connection 469 */ 470 void SetTlsConfiguration(const TLSConnectOptions &config); 471 472 /** 473 * Send data through an established encrypted connection 474 * @param data data sent over an established encrypted connection 475 * @return whether the data is successfully sent to the server 476 */ 477 bool Send(const std::string &data); 478 479 /** 480 * Receive the data sent by the server through the established encrypted connection 481 * @param buffer receive the data sent by the server 482 * @param maxBufferSize the size of the data received from the server 483 * @return whether the data sent by the server is successfully received 484 */ 485 int Recv(char *buffer, int maxBufferSize); 486 487 /** 488 * Disconnect encrypted connection 489 * @return whether the encrypted connection was successfully disconnected 490 */ 491 bool Close(); 492 493 /** 494 * Set the application layer negotiation protocol in the encrypted communication process 495 * @param alpnProtocols application layer negotiation protocol 496 * @return set whether the application layer negotiation protocol is successful during encrypted communication 497 */ 498 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 499 500 /** 501 * Storage of server communication related network information 502 * @param remoteInfo communication related network information 503 */ 504 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 505 506 /** 507 * convert the code to ssl error code 508 * @return the value for ssl error code. 509 */ 510 int ConvertSSLError(void); 511 512 /** 513 * Get configuration options for encrypted communication process 514 * @return configuration options for encrypted communication processes 515 */ 516 [[nodiscard]] TLSConfiguration GetTlsConfiguration() const; 517 518 /** 519 * Obtain the cipher suite during encrypted communication 520 * @return crypto suite used in encrypted communication 521 */ 522 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 523 524 /** 525 * Obtain the peer certificate used in encrypted communication 526 * @return peer certificate used in encrypted communication 527 */ 528 [[nodiscard]] std::string GetRemoteCertificate() const; 529 530 /** 531 * Obtain the peer certificate used in encrypted communication 532 * @return peer certificate serialization data used in encrypted communication 533 */ 534 [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const; 535 536 /** 537 * Obtain the certificate used in encrypted communication 538 * @return certificate serialization data used in encrypted communication 539 */ 540 [[nodiscard]] const X509CertRawData &GetCertificate() const; 541 542 /** 543 * Get the encryption algorithm used in encrypted communication 544 * @return encryption algorithm used in encrypted communication 545 */ 546 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 547 548 /** 549 * Obtain the communication protocol used in encrypted communication 550 * @return communication protocol used in encrypted communication 551 */ 552 [[nodiscard]] std::string GetProtocol() const; 553 554 /** 555 * Set the information about the shared signature algorithm supported by peers during encrypted communication 556 * @return information about peer supported shared signature algorithms 557 */ 558 [[nodiscard]] bool SetSharedSigals(); 559 560 /** 561 * Obtain the ssl used in encrypted communication 562 * @return SSL used in encrypted communication 563 */ 564 [[nodiscard]] ssl_st *GetSSL(); 565 566 private: 567 bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd); 568 bool StartTlsConnected(const TLSConnectOptions &options); 569 bool CreatTlsContext(); 570 bool StartShakingHands(const TLSConnectOptions &options); 571 bool GetRemoteCertificateFromPeer(); 572 bool SetRemoteCertRawData(); 573 bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize); 574 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 575 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 576 const X509 *x509Certificates); 577 578 private: 579 std::mutex mutexForSsl_; 580 mutable std::shared_mutex rw_mutex_; 581 ssl_st *ssl_ = nullptr; 582 X509 *peerX509_ = nullptr; 583 uint16_t port_ = 0; 584 sa_family_t family_ = 0; 585 int32_t socketDescriptor_ = 0; 586 587 TLSContext tlsContext_; 588 TLSConfiguration configuration_; 589 Socket::NetAddress address_; 590 X509CertRawData remoteRawData_; 591 592 std::string hostName_; 593 std::string remoteCert_; 594 595 std::vector<std::string> signatureAlgorithms_; 596 std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr; 597 }; 598 599 private: 600 TLSSocketInternal tlsSocketInternal_; 601 602 static std::string MakeAddressString(sockaddr *addr); 603 604 static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 605 socklen_t *len); 606 607 void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo); 608 void CallOnConnectCallback(); 609 void CallOnCloseCallback(); 610 void CallOnErrorCallback(int32_t err, const std::string &errString); 611 612 void CallBindCallback(int32_t err, BindCallback callback); 613 void CallConnectCallback(int32_t err, ConnectCallback callback); 614 void CallSendCallback(int32_t err, SendCallback callback); 615 void CallCloseCallback(int32_t err, CloseCallback callback); 616 void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address, 617 GetRemoteAddressCallback callback); 618 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback); 619 void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback); 620 void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback); 621 void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert, 622 GetRemoteCertificateCallback callback); 623 void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback); 624 void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite, 625 GetCipherSuiteCallback callback); 626 void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms, 627 GetSignatureAlgorithmsCallback callback); 628 629 int ReadMessage(); 630 void StartReadMessage(); 631 632 void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback); 633 void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback); 634 635 [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const; 636 [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const; 637 638 void MakeIpSocket(sa_family_t family); 639 640 template<class T> DealCallback(int32_t err,T & callback)641 void DealCallback(int32_t err, T &callback) 642 { 643 if (callback) { 644 callback(err); 645 } 646 } 647 648 private: 649 static constexpr const size_t MAX_ERROR_LEN = 128; 650 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 651 652 OnMessageCallback onMessageCallback_; 653 OnConnectCallback onConnectCallback_; 654 OnCloseCallback onCloseCallback_; 655 OnErrorCallback onErrorCallback_; 656 657 std::mutex mutex_; 658 std::mutex recvMutex_; 659 std::mutex cvMutex_; 660 bool isRunning_ = false; 661 bool isRunOver_ = true; 662 std::condition_variable cvSslFree_; 663 int sockFd_ = -1; 664 bool isExtSock_ = false; 665 Socket::NetAddress localAddress_; 666 }; 667 } // namespace TlsSocket 668 } // namespace NetStack 669 } // namespace OHOS 670 671 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H 672