1 /* 2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H 18 19 #include <any> 20 #include <condition_variable> 21 #include <cstring> 22 #include <functional> 23 #include <map> 24 #include <thread> 25 #include <tuple> 26 #include <unistd.h> 27 #include <vector> 28 29 #include "extra_options_base.h" 30 #include "net_address.h" 31 #include "socket_error.h" 32 #include "socket_remote_info.h" 33 #include "socket_state_base.h" 34 #include "tcp_connect_options.h" 35 #include "tcp_extra_options.h" 36 #include "tcp_send_options.h" 37 #include "tls.h" 38 #include "tls_certificate.h" 39 #include "tls_configuration.h" 40 #include "tls_context.h" 41 #include "tls_key.h" 42 43 namespace OHOS { 44 namespace NetStack { 45 namespace TlsSocket { 46 47 using BindCallback = std::function<void(int32_t errorNumber)>; 48 using ConnectCallback = std::function<void(int32_t errorNumber)>; 49 using SendCallback = std::function<void(int32_t errorNumber)>; 50 using CloseCallback = std::function<void(int32_t errorNumber)>; 51 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 52 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>; 53 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>; 54 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 55 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 56 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>; 57 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>; 58 using GetSignatureAlgorithmsCallback = 59 std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>; 60 61 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 62 using OnConnectCallback = std::function<void(void)>; 63 using OnCloseCallback = std::function<void(void)>; 64 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>; 65 66 using CheckServerIdentity = 67 std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>; 68 69 constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http1.1"; 70 constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2"; 71 72 constexpr size_t MAX_ERR_LEN = 1024; 73 74 /** 75 * Parameters required during communication 76 */ 77 class TLSSecureOptions { 78 public: 79 TLSSecureOptions() = default; 80 ~TLSSecureOptions() = default; 81 82 TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions); 83 TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions); 84 /** 85 * Set root CA Chain to verify the server cert 86 * @param caChain root certificate chain used to validate server certificates 87 */ 88 void SetCaChain(const std::vector<std::string> &caChain); 89 90 /** 91 * Set digital certificate for server verification 92 * @param cert digital certificate sent to the server to verify validity 93 */ 94 void SetCert(const std::string &cert); 95 96 /** 97 * Set key to decrypt server data 98 * @param keyChain key used to decrypt server data 99 */ 100 void SetKey(const SecureData &key); 101 102 /** 103 * Set the password to read the private key 104 * @param keyPass read the password of the private key 105 */ 106 void SetKeyPass(const SecureData &keyPass); 107 108 /** 109 * Set the protocol used in communication 110 * @param protocolChain protocol version number used 111 */ 112 void SetProtocolChain(const std::vector<std::string> &protocolChain); 113 114 /** 115 * Whether the peer cipher suite is preferred for communication 116 * @param useRemoteCipherPrefer whether the peer cipher suite is preferred 117 */ 118 void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer); 119 120 /** 121 * Encryption algorithm used in communication 122 * @param signatureAlgorithms encryption algorithm e.g: rsa 123 */ 124 void SetSignatureAlgorithms(const std::string &signatureAlgorithms); 125 126 /** 127 * Crypto suite used in communication 128 * @param cipherSuite cipher suite e.g:AES256-SHA256 129 */ 130 void SetCipherSuite(const std::string &cipherSuite); 131 132 /** 133 * Set a revoked certificate 134 * @param crlChain certificate Revocation List 135 */ 136 void SetCrlChain(const std::vector<std::string> &crlChain); 137 138 /** 139 * Get root CA Chain to verify the server cert 140 * @return root CA chain 141 */ 142 [[nodiscard]] const std::vector<std::string> &GetCaChain() const; 143 144 /** 145 * Obtain a certificate to send to the server for checking 146 * @return digital certificate obtained 147 */ 148 [[nodiscard]] const std::string &GetCert() const; 149 150 /** 151 * Obtain the private key in the communication process 152 * @return private key during communication 153 */ 154 [[nodiscard]] const SecureData &GetKey() const; 155 156 /** 157 * Get the password to read the private key 158 * @return read the password of the private key 159 */ 160 [[nodiscard]] const SecureData &GetKeyPass() const; 161 162 /** 163 * Get the protocol of the communication process 164 * @return protocol of communication process 165 */ 166 [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const; 167 168 /** 169 * Is the remote cipher suite being used for communication 170 * @return is use Remote Cipher Prefer 171 */ 172 [[nodiscard]] bool UseRemoteCipherPrefer() const; 173 174 /** 175 * Obtain the encryption algorithm used in the communication process 176 * @return encryption algorithm used in communication 177 */ 178 [[nodiscard]] const std::string &GetSignatureAlgorithms() const; 179 180 /** 181 * Obtain the cipher suite used in communication 182 * @return crypto suite used in communication 183 */ 184 [[nodiscard]] const std::string &GetCipherSuite() const; 185 186 /** 187 * Get revoked certificate chain 188 * @return revoked certificate chain 189 */ 190 [[nodiscard]] const std::vector<std::string> &GetCrlChain() const; 191 192 void SetVerifyMode(VerifyMode verifyMode); 193 194 [[nodiscard]] VerifyMode GetVerifyMode() const; 195 196 private: 197 std::vector<std::string> caChain_; 198 std::string cert_; 199 SecureData key_; 200 SecureData keyPass_; 201 std::vector<std::string> protocolChain_; 202 bool useRemoteCipherPrefer_ = false; 203 std::string signatureAlgorithms_; 204 std::string cipherSuite_; 205 std::vector<std::string> crlChain_; 206 VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE; 207 }; 208 209 /** 210 * Some options required during tls connection 211 */ 212 class TLSConnectOptions { 213 public: 214 /** 215 * Communication parameters required for connection establishment 216 * @param address communication parameters during connection 217 */ 218 void SetNetAddress(const Socket::NetAddress &address); 219 220 /** 221 * Parameters required during communication 222 * @param tlsSecureOptions certificate and other relevant parameters 223 */ 224 void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions); 225 226 /** 227 * Set the callback function to check the validity of the server 228 * @param checkServerIdentity callback function passed in by API caller 229 */ 230 void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity); 231 232 /** 233 * Set application layer protocol negotiation 234 * @param alpnProtocols application layer protocol negotiation 235 */ 236 void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 237 238 /** 239 * Obtain the network address of the communication process 240 * @return network address 241 */ 242 [[nodiscard]] Socket::NetAddress GetNetAddress() const; 243 244 /** 245 * Obtain the parameters required in the communication process 246 * @return certificate and other relevant parameters 247 */ 248 [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const; 249 250 /** 251 * Get the check server ID callback function passed in by the API caller 252 * @return check the server identity callback function 253 */ 254 [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const; 255 256 /** 257 * Obtain the application layer protocol negotiation in the communication process 258 * @return application layer protocol negotiation 259 */ 260 [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const; 261 262 private: 263 Socket::NetAddress address_; 264 TLSSecureOptions tlsSecureOptions_; 265 CheckServerIdentity checkServerIdentity_; 266 std::vector<std::string> alpnProtocols_; 267 }; 268 269 /** 270 * TLS socket interface class 271 */ 272 class TLSSocket { 273 public: 274 TLSSocket(const TLSSocket &) = delete; 275 TLSSocket(TLSSocket &&) = delete; 276 277 TLSSocket &operator=(const TLSSocket &) = delete; 278 TLSSocket &operator=(TLSSocket &&) = delete; 279 280 TLSSocket() = default; 281 ~TLSSocket() = default; 282 283 /** 284 * Create a socket and bind to the address specified by address 285 * @param address ip address 286 * @param callback callback to the caller if bind ok or not 287 */ 288 void Bind(const Socket::NetAddress &address, const BindCallback &callback); 289 290 /** 291 * Establish a secure connection based on the created socket 292 * @param tlsConnectOptions some options required during tls connection 293 * @param callback callback to the caller if connect ok or not 294 */ 295 void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback); 296 297 /** 298 * Send data based on the created socket 299 * @param tcpSendOptions some options required during tcp data transmission 300 * @param callback callback to the caller if send ok or not 301 */ 302 void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback); 303 304 /** 305 * Disconnect by releasing the socket when communicating 306 * @param callback callback to the caller 307 */ 308 void Close(const CloseCallback &callback); 309 310 /** 311 * Get the peer network address 312 * @param callback callback to the caller 313 */ 314 void GetRemoteAddress(const GetRemoteAddressCallback &callback); 315 316 /** 317 * Get the status of the current socket 318 * @param callback callback to the caller 319 */ 320 void GetState(const GetStateCallback &callback); 321 322 /** 323 * Gets or sets the options associated with the current socket 324 * @param tcpExtraOptions options associated with the current socket 325 * @param callback callback to the caller 326 */ 327 void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback); 328 329 /** 330 * Get a local digital certificate 331 * @param callback callback to the caller 332 */ 333 void GetCertificate(const GetCertificateCallback &callback); 334 335 /** 336 * Get the peer digital certificate 337 * @param needChain need chain 338 * @param callback callback to the caller 339 */ 340 void GetRemoteCertificate(const GetRemoteCertificateCallback &callback); 341 342 /** 343 * Obtain the protocol used in communication 344 * @param callback callback to the caller 345 */ 346 void GetProtocol(const GetProtocolCallback &callback); 347 348 /** 349 * Obtain the cipher suite used in communication 350 * @param callback callback to the caller 351 */ 352 void GetCipherSuite(const GetCipherSuiteCallback &callback); 353 354 /** 355 * Obtain the encryption algorithm used in the communication process 356 * @param callback callback to the caller 357 */ 358 void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback); 359 360 /** 361 * Register a callback which is called when message is received 362 * @param onMessageCallback callback which is called when message is received 363 */ 364 void OnMessage(const OnMessageCallback &onMessageCallback); 365 366 /** 367 * Register the callback that is called when the connection is established 368 * @param onConnectCallback callback invoked when connection is established 369 */ 370 void OnConnect(const OnConnectCallback &onConnectCallback); 371 372 /** 373 * Register the callback that is called when the connection is disconnected 374 * @param onCloseCallback callback invoked when disconnected 375 */ 376 void OnClose(const OnCloseCallback &onCloseCallback); 377 378 /** 379 * Register the callback that is called when an error occurs 380 * @param onErrorCallback callback invoked when an error occurs 381 */ 382 void OnError(const OnErrorCallback &onErrorCallback); 383 384 /** 385 * Unregister the callback which is called when message is received 386 */ 387 void OffMessage(); 388 389 /** 390 * Off Connect 391 */ 392 void OffConnect(); 393 394 /** 395 * Off Close 396 */ 397 void OffClose(); 398 399 /** 400 * Off Error 401 */ 402 void OffError(); 403 404 private: 405 class TLSSocketInternal final { 406 public: 407 TLSSocketInternal() = default; 408 ~TLSSocketInternal() = default; 409 410 /** 411 * Establish an encrypted connection on the specified socket 412 * @param sock socket for establishing encrypted connection 413 * @param options some options required during tls connection 414 * @return whether the encrypted connection is successfully established 415 */ 416 bool TlsConnectToHost(int sock, const TLSConnectOptions &options); 417 418 /** 419 * Set the configuration items for establishing encrypted connections 420 * @param config configuration item when establishing encrypted connection 421 */ 422 void SetTlsConfiguration(const TLSConnectOptions &config); 423 424 /** 425 * Send data through an established encrypted connection 426 * @param data data sent over an established encrypted connection 427 * @return whether the data is successfully sent to the server 428 */ 429 bool Send(const std::string &data); 430 431 /** 432 * Receive the data sent by the server through the established encrypted connection 433 * @param buffer receive the data sent by the server 434 * @param maxBufferSize the size of the data received from the server 435 * @return whether the data sent by the server is successfully received 436 */ 437 int Recv(char *buffer, int maxBufferSize); 438 439 /** 440 * Disconnect encrypted connection 441 * @return whether the encrypted connection was successfully disconnected 442 */ 443 bool Close(); 444 445 /** 446 * Set the application layer negotiation protocol in the encrypted communication process 447 * @param alpnProtocols application layer negotiation protocol 448 * @return set whether the application layer negotiation protocol is successful during encrypted communication 449 */ 450 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 451 452 /** 453 * Storage of server communication related network information 454 * @param remoteInfo communication related network information 455 */ 456 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 457 458 /** 459 * Get configuration options for encrypted communication process 460 * @return configuration options for encrypted communication processes 461 */ 462 [[nodiscard]] TLSConfiguration GetTlsConfiguration() const; 463 464 /** 465 * Obtain the cipher suite during encrypted communication 466 * @return crypto suite used in encrypted communication 467 */ 468 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 469 470 /** 471 * Obtain the peer certificate used in encrypted communication 472 * @return peer certificate used in encrypted communication 473 */ 474 [[nodiscard]] std::string GetRemoteCertificate() const; 475 476 /** 477 * Obtain the peer certificate used in encrypted communication 478 * @return peer certificate serialization data used in encrypted communication 479 */ 480 [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const; 481 482 /** 483 * Obtain the certificate used in encrypted communication 484 * @return certificate serialization data used in encrypted communication 485 */ 486 [[nodiscard]] const X509CertRawData &GetCertificate() const; 487 488 /** 489 * Get the encryption algorithm used in encrypted communication 490 * @return encryption algorithm used in encrypted communication 491 */ 492 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 493 494 /** 495 * Obtain the communication protocol used in encrypted communication 496 * @return communication protocol used in encrypted communication 497 */ 498 [[nodiscard]] std::string GetProtocol() const; 499 500 /** 501 * Set the information about the shared signature algorithm supported by peers during encrypted communication 502 * @return information about peer supported shared signature algorithms 503 */ 504 [[nodiscard]] bool SetSharedSigals(); 505 506 /** 507 * Obtain the ssl used in encrypted communication 508 * @return SSL used in encrypted communication 509 */ 510 [[nodiscard]] ssl_st *GetSSL() const; 511 512 private: 513 bool StartTlsConnected(const TLSConnectOptions &options); 514 bool CreatTlsContext(); 515 bool StartShakingHands(const TLSConnectOptions &options); 516 bool GetRemoteCertificateFromPeer(); 517 bool SetRemoteCertRawData(); 518 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 519 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 520 const X509 *x509Certificates); 521 522 private: 523 ssl_st *ssl_ = nullptr; 524 X509 *peerX509_ = nullptr; 525 uint16_t port_ = 0; 526 sa_family_t family_ = 0; 527 int32_t socketDescriptor_ = 0; 528 529 TLSContext tlsContext_; 530 TLSConfiguration configuration_; 531 Socket::NetAddress address_; 532 X509CertRawData remoteRawData_; 533 534 std::string hostName_; 535 std::string remoteCert_; 536 537 std::vector<std::string> signatureAlgorithms_; 538 std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr; 539 }; 540 541 private: 542 TLSSocketInternal tlsSocketInternal_; 543 544 static std::string MakeAddressString(sockaddr *addr); 545 546 static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 547 socklen_t *len); 548 549 void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo); 550 void CallOnConnectCallback(); 551 void CallOnCloseCallback(); 552 void CallOnErrorCallback(int32_t err, const std::string &errString); 553 554 void CallBindCallback(int32_t err, BindCallback callback); 555 void CallConnectCallback(int32_t err, ConnectCallback callback); 556 void CallSendCallback(int32_t err, SendCallback callback); 557 void CallCloseCallback(int32_t err, CloseCallback callback); 558 void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address, 559 GetRemoteAddressCallback callback); 560 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback); 561 void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback); 562 void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback); 563 void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert, 564 GetRemoteCertificateCallback callback); 565 void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback); 566 void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite, 567 GetCipherSuiteCallback callback); 568 void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms, 569 GetSignatureAlgorithmsCallback callback); 570 571 void StartReadMessage(); 572 573 void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback); 574 void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback); 575 576 [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const; 577 [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const; 578 579 void MakeIpSocket(sa_family_t family); 580 581 template<class T> DealCallback(int32_t err,T & callback)582 void DealCallback(int32_t err, T &callback) 583 { 584 T func = nullptr; 585 { 586 std::lock_guard<std::mutex> lock(mutex_); 587 if (callback) { 588 func = callback; 589 } 590 } 591 592 if (func) { 593 func(err); 594 } 595 } 596 597 private: 598 static constexpr const size_t MAX_ERROR_LEN = 128; 599 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 600 601 OnMessageCallback onMessageCallback_; 602 OnConnectCallback onConnectCallback_; 603 OnCloseCallback onCloseCallback_; 604 OnErrorCallback onErrorCallback_; 605 606 std::mutex mutex_; 607 bool isRunning_ = false; 608 bool isRunOver_ = true; 609 610 int sockFd_ = -1; 611 }; 612 } // namespace TlsSocket 613 } // namespace NetStack 614 } // namespace OHOS 615 616 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H 617