1 /* 2 * Copyright (c) 2022 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H 18 19 #include <any> 20 #include <condition_variable> 21 #include <cstring> 22 #include <functional> 23 #include <map> 24 #include <thread> 25 #include <tuple> 26 #include <unistd.h> 27 #include <vector> 28 29 #include "extra_options_base.h" 30 #include "net_address.h" 31 #include "socket_error.h" 32 #include "socket_remote_info.h" 33 #include "socket_state_base.h" 34 #include "tcp_connect_options.h" 35 #include "tcp_extra_options.h" 36 #include "tcp_send_options.h" 37 #include "tls.h" 38 #include "tls_certificate.h" 39 #include "tls_configuration.h" 40 #include "tls_context.h" 41 #include "tls_key.h" 42 43 namespace OHOS { 44 namespace NetStack { 45 46 using BindCallback = std::function<void(int32_t errorNumber)>; 47 using ConnectCallback = std::function<void(int32_t errorNumber)>; 48 using SendCallback = std::function<void(int32_t errorNumber)>; 49 using CloseCallback = std::function<void(int32_t errorNumber)>; 50 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const NetAddress &address)>; 51 using GetStateCallback = std::function<void(int32_t errorNumber, const SocketStateBase &state)>; 52 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>; 53 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 54 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 55 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>; 56 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>; 57 using GetSignatureAlgorithmsCallback = 58 std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>; 59 60 using OnMessageCallback = std::function<void(const std::string &data, const SocketRemoteInfo &remoteInfo)>; 61 using OnConnectCallback = std::function<void(void)>; 62 using OnCloseCallback = std::function<void(void)>; 63 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>; 64 65 using CheckServerIdentity = 66 std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>; 67 68 constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http1.1"; 69 constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2"; 70 71 constexpr size_t MAX_ERR_LEN = 1024; 72 73 /** 74 * Parameters required during communication 75 */ 76 class TLSSecureOptions { 77 public: 78 TLSSecureOptions() = default; 79 ~TLSSecureOptions() = default; 80 81 TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions); 82 TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions); 83 /** 84 * Set root CA Chain to verify the server cert 85 * @param caChain root certificate chain used to validate server certificates 86 */ 87 void SetCaChain(const std::vector<std::string> &caChain); 88 89 /** 90 * Set digital certificate for server verification 91 * @param cert digital certificate sent to the server to verify validity 92 */ 93 void SetCert(const std::string &cert); 94 95 /** 96 * Set key to decrypt server data 97 * @param keyChain key used to decrypt server data 98 */ 99 void SetKey(const SecureData &key); 100 101 /** 102 * Set the password to read the private key 103 * @param keyPass read the password of the private key 104 */ 105 void SetKeyPass(const SecureData &keyPass); 106 107 /** 108 * Set the protocol used in communication 109 * @param protocolChain protocol version number used 110 */ 111 void SetProtocolChain(const std::vector<std::string> &protocolChain); 112 113 /** 114 * Whether the peer cipher suite is preferred for communication 115 * @param useRemoteCipherPrefer whether the peer cipher suite is preferred 116 */ 117 void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer); 118 119 /** 120 * Encryption algorithm used in communication 121 * @param signatureAlgorithms encryption algorithm e.g: rsa 122 */ 123 void SetSignatureAlgorithms(const std::string &signatureAlgorithms); 124 125 /** 126 * Crypto suite used in communication 127 * @param cipherSuite cipher suite e.g:AES256-SHA256 128 */ 129 void SetCipherSuite(const std::string &cipherSuite); 130 131 /** 132 * Set a revoked certificate 133 * @param crlChain certificate Revocation List 134 */ 135 void SetCrlChain(const std::vector<std::string> &crlChain); 136 137 /** 138 * Get root CA Chain to verify the server cert 139 * @return root CA chain 140 */ 141 [[nodiscard]] const std::vector<std::string> &GetCaChain() const; 142 143 /** 144 * Obtain a certificate to send to the server for checking 145 * @return digital certificate obtained 146 */ 147 [[nodiscard]] const std::string &GetCert() const; 148 149 /** 150 * Obtain the private key in the communication process 151 * @return private key during communication 152 */ 153 [[nodiscard]] const SecureData &GetKey() const; 154 155 /** 156 * Get the password to read the private key 157 * @return read the password of the private key 158 */ 159 [[nodiscard]] const SecureData &GetKeyPass() const; 160 161 /** 162 * Get the protocol of the communication process 163 * @return protocol of communication process 164 */ 165 [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const; 166 167 /** 168 * Is the remote cipher suite being used for communication 169 * @return is use Remote Cipher Prefer 170 */ 171 [[nodiscard]] bool UseRemoteCipherPrefer() const; 172 173 /** 174 * Obtain the encryption algorithm used in the communication process 175 * @return encryption algorithm used in communication 176 */ 177 [[nodiscard]] const std::string &GetSignatureAlgorithms() const; 178 179 /** 180 * Obtain the cipher suite used in communication 181 * @return crypto suite used in communication 182 */ 183 [[nodiscard]] const std::string &GetCipherSuite() const; 184 185 /** 186 * Get revoked certificate chain 187 * @return revoked certificate chain 188 */ 189 [[nodiscard]] const std::vector<std::string> &GetCrlChain() const; 190 191 private: 192 std::vector<std::string> caChain_; 193 std::string cert_; 194 SecureData key_; 195 SecureData keyPass_; 196 std::vector<std::string> protocolChain_; 197 bool useRemoteCipherPrefer_ = false; 198 std::string signatureAlgorithms_; 199 std::string cipherSuite_; 200 std::vector<std::string> crlChain_; 201 }; 202 203 /** 204 * Some options required during tls connection 205 */ 206 class TLSConnectOptions { 207 public: 208 /** 209 * Communication parameters required for connection establishment 210 * @param address communication parameters during connection 211 */ 212 void SetNetAddress(const NetAddress &address); 213 214 /** 215 * Parameters required during communication 216 * @param tlsSecureOptions certificate and other relevant parameters 217 */ 218 void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions); 219 220 /** 221 * Set the callback function to check the validity of the server 222 * @param checkServerIdentity callback function passed in by API caller 223 */ 224 void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity); 225 226 /** 227 * Set application layer protocol negotiation 228 * @param alpnProtocols application layer protocol negotiation 229 */ 230 void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 231 232 /** 233 * Obtain the network address of the communication process 234 * @return network address 235 */ 236 [[nodiscard]] NetAddress GetNetAddress() const; 237 238 /** 239 * Obtain the parameters required in the communication process 240 * @return certificate and other relevant parameters 241 */ 242 [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const; 243 244 /** 245 * Get the check server ID callback function passed in by the API caller 246 * @return check the server identity callback function 247 */ 248 [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const; 249 250 /** 251 * Obtain the application layer protocol negotiation in the communication process 252 * @return application layer protocol negotiation 253 */ 254 [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const; 255 256 private: 257 NetAddress address_; 258 TLSSecureOptions tlsSecureOptions_; 259 CheckServerIdentity checkServerIdentity_; 260 std::vector<std::string> alpnProtocols_; 261 }; 262 263 /** 264 * TLS socket interface class 265 */ 266 class TLSSocket { 267 public: 268 TLSSocket(const TLSSocket &) = delete; 269 TLSSocket(TLSSocket &&) = delete; 270 271 TLSSocket &operator=(const TLSSocket &) = delete; 272 TLSSocket &operator=(TLSSocket &&) = delete; 273 274 TLSSocket() = default; 275 ~TLSSocket() = default; 276 277 /** 278 * Create a socket and bind to the address specified by address 279 * @param address ip address 280 * @param callback callback to the caller if bind ok or not 281 */ 282 void Bind(const NetAddress &address, const BindCallback &callback); 283 284 /** 285 * Establish a secure connection based on the created socket 286 * @param tlsConnectOptions some options required during tls connection 287 * @param callback callback to the caller if connect ok or not 288 */ 289 void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback); 290 291 /** 292 * Send data based on the created socket 293 * @param tcpSendOptions some options required during tcp data transmission 294 * @param callback callback to the caller if send ok or not 295 */ 296 void Send(const TCPSendOptions &tcpSendOptions, const SendCallback &callback); 297 298 /** 299 * Disconnect by releasing the socket when communicating 300 * @param callback callback to the caller 301 */ 302 void Close(const CloseCallback &callback); 303 304 /** 305 * Get the peer network address 306 * @param callback callback to the caller 307 */ 308 void GetRemoteAddress(const GetRemoteAddressCallback &callback); 309 310 /** 311 * Get the status of the current socket 312 * @param callback callback to the caller 313 */ 314 void GetState(const GetStateCallback &callback); 315 316 /** 317 * Gets or sets the options associated with the current socket 318 * @param tcpExtraOptions options associated with the current socket 319 * @param callback callback to the caller 320 */ 321 void SetExtraOptions(const TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback); 322 323 /** 324 * Get a local digital certificate 325 * @param callback callback to the caller 326 */ 327 void GetCertificate(const GetCertificateCallback &callback); 328 329 /** 330 * Get the peer digital certificate 331 * @param needChain need chain 332 * @param callback callback to the caller 333 */ 334 void GetRemoteCertificate(const GetRemoteCertificateCallback &callback); 335 336 /** 337 * Obtain the protocol used in communication 338 * @param callback callback to the caller 339 */ 340 void GetProtocol(const GetProtocolCallback &callback); 341 342 /** 343 * Obtain the cipher suite used in communication 344 * @param callback callback to the caller 345 */ 346 void GetCipherSuite(const GetCipherSuiteCallback &callback); 347 348 /** 349 * Obtain the encryption algorithm used in the communication process 350 * @param callback callback to the caller 351 */ 352 void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback); 353 354 /** 355 * Register a callback which is called when message is received 356 * @param onMessageCallback callback which is called when message is received 357 */ 358 void OnMessage(const OnMessageCallback &onMessageCallback); 359 360 /** 361 * Register the callback that is called when the connection is established 362 * @param onConnectCallback callback invoked when connection is established 363 */ 364 void OnConnect(const OnConnectCallback &onConnectCallback); 365 366 /** 367 * Register the callback that is called when the connection is disconnected 368 * @param onCloseCallback callback invoked when disconnected 369 */ 370 void OnClose(const OnCloseCallback &onCloseCallback); 371 372 /** 373 * Register the callback that is called when an error occurs 374 * @param onErrorCallback callback invoked when an error occurs 375 */ 376 void OnError(const OnErrorCallback &onErrorCallback); 377 378 /** 379 * Unregister the callback which is called when message is received 380 */ 381 void OffMessage(); 382 383 /** 384 * Off Connect 385 */ 386 void OffConnect(); 387 388 /** 389 * Off Close 390 */ 391 void OffClose(); 392 393 /** 394 * Off Error 395 */ 396 void OffError(); 397 398 private: 399 class TLSSocketInternal final { 400 public: 401 TLSSocketInternal() = default; 402 ~TLSSocketInternal() = default; 403 404 /** 405 * Establish an encrypted connection on the specified socket 406 * @param sock socket for establishing encrypted connection 407 * @param options some options required during tls connection 408 * @return whether the encrypted connection is successfully established 409 */ 410 bool TlsConnectToHost(int sock, const TLSConnectOptions &options); 411 412 /** 413 * Set the configuration items for establishing encrypted connections 414 * @param config configuration item when establishing encrypted connection 415 */ 416 void SetTlsConfiguration(const TLSConnectOptions &config); 417 418 /** 419 * Send data through an established encrypted connection 420 * @param data data sent over an established encrypted connection 421 * @return whether the data is successfully sent to the server 422 */ 423 bool Send(const std::string &data); 424 425 /** 426 * Receive the data sent by the server through the established encrypted connection 427 * @param buffer receive the data sent by the server 428 * @param maxBufferSize the size of the data received from the server 429 * @return whether the data sent by the server is successfully received 430 */ 431 int Recv(char *buffer, int maxBufferSize); 432 433 /** 434 * Disconnect encrypted connection 435 * @return whether the encrypted connection was successfully disconnected 436 */ 437 bool Close(); 438 439 /** 440 * Set the application layer negotiation protocol in the encrypted communication process 441 * @param alpnProtocols application layer negotiation protocol 442 * @return set whether the application layer negotiation protocol is successful during encrypted communication 443 */ 444 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 445 446 /** 447 * Storage of server communication related network information 448 * @param remoteInfo communication related network information 449 */ 450 void MakeRemoteInfo(SocketRemoteInfo &remoteInfo); 451 452 /** 453 * Get configuration options for encrypted communication process 454 * @return configuration options for encrypted communication processes 455 */ 456 [[nodiscard]] TLSConfiguration GetTlsConfiguration() const; 457 458 /** 459 * Obtain the cipher suite during encrypted communication 460 * @return crypto suite used in encrypted communication 461 */ 462 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 463 464 /** 465 * Obtain the peer certificate used in encrypted communication 466 * @return peer certificate used in encrypted communication 467 */ 468 [[nodiscard]] std::string GetRemoteCertificate() const; 469 470 /** 471 * Obtain the peer certificate used in encrypted communication 472 * @return peer certificate serialization data used in encrypted communication 473 */ 474 [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const; 475 476 /** 477 * Obtain the certificate used in encrypted communication 478 * @return certificate serialization data used in encrypted communication 479 */ 480 [[nodiscard]] const X509CertRawData &GetCertificate() const; 481 482 /** 483 * Get the encryption algorithm used in encrypted communication 484 * @return encryption algorithm used in encrypted communication 485 */ 486 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 487 488 /** 489 * Obtain the communication protocol used in encrypted communication 490 * @return communication protocol used in encrypted communication 491 */ 492 [[nodiscard]] std::string GetProtocol() const; 493 494 /** 495 * Set the information about the shared signature algorithm supported by peers during encrypted communication 496 * @return information about peer supported shared signature algorithms 497 */ 498 [[nodiscard]] bool SetSharedSigals(); 499 500 /** 501 * Obtain the ssl used in encrypted communication 502 * @return SSL used in encrypted communication 503 */ 504 [[nodiscard]] ssl_st *GetSSL() const; 505 506 private: 507 bool StartTlsConnected(const TLSConnectOptions &options); 508 bool CreatTlsContext(); 509 bool StartShakingHands(const TLSConnectOptions &options); 510 bool GetRemoteCertificateFromPeer(); 511 bool SetRemoteCertRawData(); 512 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 513 514 private: 515 ssl_st *ssl_ = nullptr; 516 X509 *peerX509_ = nullptr; 517 uint16_t port_ = 0; 518 sa_family_t family_ = 0; 519 int32_t socketDescriptor_ = 0; 520 521 TLSContext tlsContext_; 522 TLSConfiguration configuration_; 523 NetAddress address_; 524 X509CertRawData remoteRawData_; 525 526 std::string hostName_; 527 std::string remoteCert_; 528 std::string keyPass_; 529 530 std::vector<std::string> signatureAlgorithms_; 531 std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr; 532 }; 533 534 private: 535 TLSSocketInternal tlsSocketInternal_; 536 537 static std::string MakeAddressString(sockaddr *addr); 538 539 static void GetAddr(const NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 540 socklen_t *len); 541 542 void CallOnMessageCallback(const std::string &data, const SocketRemoteInfo &remoteInfo); 543 void CallOnConnectCallback(); 544 void CallOnCloseCallback(); 545 void CallOnErrorCallback(int32_t err, const std::string &errString); 546 547 void CallBindCallback(int32_t err, BindCallback callback); 548 void CallConnectCallback(int32_t err, ConnectCallback callback); 549 void CallSendCallback(int32_t err, SendCallback callback); 550 void CallCloseCallback(int32_t err, CloseCallback callback); 551 void CallGetRemoteAddressCallback(int32_t err, const NetAddress &address, GetRemoteAddressCallback callback); 552 void CallGetStateCallback(int32_t err, const SocketStateBase &state, GetStateCallback callback); 553 void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback); 554 void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback); 555 void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert, 556 GetRemoteCertificateCallback callback); 557 void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback); 558 void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite, 559 GetCipherSuiteCallback callback); 560 void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms, 561 GetSignatureAlgorithmsCallback callback); 562 563 void StartReadMessage(); 564 565 void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback); 566 void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback); 567 568 [[nodiscard]] bool SetBaseOptions(const ExtraOptionsBase &option) const; 569 [[nodiscard]] bool SetExtraOptions(const TCPExtraOptions &option) const; 570 571 void MakeIpSocket(sa_family_t family); 572 573 private: 574 static constexpr const size_t MAX_ERROR_LEN = 128; 575 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 576 577 OnMessageCallback onMessageCallback_; 578 OnConnectCallback onConnectCallback_; 579 OnCloseCallback onCloseCallback_; 580 OnErrorCallback onErrorCallback_; 581 582 std::mutex mutex_; 583 bool isRunning_ = false; 584 bool isRunOver_ = true; 585 586 int sockFd_ = -1; 587 }; 588 } // namespace NetStack 589 } // namespace OHOS 590 591 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H 592