1 /* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 18 19 #include "event_manager.h" 20 #include "extra_options_base.h" 21 #include "net_address.h" 22 #include "socket_error.h" 23 #include "socket_remote_info.h" 24 #include "socket_state_base.h" 25 #include "tcp_connect_options.h" 26 #include "tcp_extra_options.h" 27 #include "tcp_send_options.h" 28 #include "tls.h" 29 #include "tls_certificate.h" 30 #include "tls_configuration.h" 31 #include "tls_context_server.h" 32 #include "tls_key.h" 33 #include "tls_socket.h" 34 #include <any> 35 #include <condition_variable> 36 #include <cstring> 37 #include <functional> 38 #include <map> 39 #include <poll.h> 40 #include <thread> 41 #include <tuple> 42 #include <unistd.h> 43 #include <vector> 44 45 namespace OHOS { 46 namespace NetStack { 47 namespace TlsSocketServer { 48 constexpr int USER_LIMIT = 10; 49 using OnMessageCallback = 50 std::function<void(const int &socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 51 using OnCloseCallback = std::function<void(const int &socketFd)>; 52 using OnConnectCallback = std::function<void(const int &socketFd, std::shared_ptr<EventManager> eventManager)>; 53 using ListenCallback = std::function<void(int32_t errorNumber)>; 54 class TLSServerSendOptions { 55 public: 56 /** 57 * Set the socket ID to be transmitted 58 * @param socketFd Communication descriptor 59 */ 60 void SetSocket(const int &socketFd); 61 62 /** 63 * Set the data to send 64 * @param data Send data 65 */ 66 void SetSendData(const std::string &data); 67 68 /** 69 * Get the socket ID 70 * @return Gets the communication descriptor 71 */ 72 [[nodiscard]] const int &GetSocket() const; 73 74 /** 75 * Gets the data sent 76 * @return Send data 77 */ 78 [[nodiscard]] const std::string &GetSendData() const; 79 80 private: 81 int socketFd_; 82 std::string data_; 83 }; 84 85 class TLSSocketServer { 86 public: 87 TLSSocketServer(const TLSSocketServer &) = delete; 88 TLSSocketServer(TLSSocketServer &&) = delete; 89 90 TLSSocketServer &operator=(const TLSSocketServer &) = delete; 91 TLSSocketServer &operator=(TLSSocketServer &&) = delete; 92 93 TLSSocketServer() = default; 94 ~TLSSocketServer(); 95 96 /** 97 * Create sockets, bind and listen waiting for clients to connect 98 * @param tlsListenOptions Bind the listening connection configuration 99 * @param callback callback to the caller if bind ok or not 100 */ 101 void Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback); 102 103 /** 104 * Send data through an established encrypted connection 105 * @param data data sent over an established encrypted connection 106 * @return whether the data is successfully sent to the server 107 */ 108 bool Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback); 109 110 /** 111 * Disconnect by releasing the socket when communicating 112 * @param socketFd The socket ID of the client 113 * @param callback callback to the caller 114 */ 115 void Close(const int socketFd, const TlsSocket::CloseCallback &callback); 116 117 /** 118 * Disconnect by releasing the socket when communicating 119 * @param callback callback to the caller 120 */ 121 void Stop(const TlsSocket::CloseCallback &callback); 122 123 /** 124 * Get the peer network address 125 * @param socketFd The socket ID of the client 126 * @param callback callback to the caller 127 */ 128 void GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback); 129 130 /** 131 * Get the status of the current socket 132 * @param callback callback to the caller 133 */ 134 void GetState(const TlsSocket::GetStateCallback &callback); 135 136 /** 137 * Gets or sets the options associated with the current socket 138 * @param tcpExtraOptions options associated with the current socket 139 * @param callback callback to the caller 140 */ 141 bool SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, 142 const TlsSocket::SetExtraOptionsCallback &callback); 143 144 /** 145 * Get a local digital certificate 146 * @param callback callback to the caller 147 */ 148 void GetCertificate(const TlsSocket::GetCertificateCallback &callback); 149 150 /** 151 * Get the peer digital certificate 152 * @param socketFd The socket ID of the client 153 * @param needChain need chain 154 * @param callback callback to the caller 155 */ 156 void GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback); 157 158 /** 159 * Obtain the protocol used in communication 160 * @param callback callback to the caller 161 */ 162 void GetProtocol(const TlsSocket::GetProtocolCallback &callback); 163 164 /** 165 * Obtain the cipher suite used in communication 166 * @param socketFd The socket ID of the client 167 * @param callback callback to the caller 168 */ 169 void GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback); 170 171 /** 172 * Obtain the encryption algorithm used in the communication process 173 * @param socketFd The socket ID of the client 174 * @param callback callback to the caller 175 */ 176 void GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback); 177 178 /** 179 * Register the callback that is called when the connection is disconnected 180 * @param onCloseCallback callback invoked when disconnected 181 */ 182 183 /** 184 * Register the callback that is called when the connection is established 185 * @param onConnectCallback callback invoked when connection is established 186 */ 187 void OnConnect(const OnConnectCallback &onConnectCallback); 188 189 /** 190 * Register the callback that is called when an error occurs 191 * @param onErrorCallback callback invoked when an error occurs 192 */ 193 void OnError(const TlsSocket::OnErrorCallback &onErrorCallback); 194 195 /** 196 * Off Connect 197 */ 198 void OffConnect(); 199 200 /** 201 * Off Error 202 */ 203 void OffError(); 204 205 public: 206 class Connection : public std::enable_shared_from_this<Connection> { 207 public: 208 ~Connection(); 209 /** 210 * Establish an encrypted accept on the specified socket 211 * @param sock socket for establishing encrypted connection 212 * @param options some options required during tls accept 213 * @return whether the encrypted accept is successfully established 214 */ 215 bool TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options); 216 217 /** 218 * Set the configuration items for establishing encrypted connections 219 * @param config configuration item when establishing encrypted connection 220 */ 221 void SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config); 222 223 /** 224 * Set address information 225 */ 226 void SetAddress(const Socket::NetAddress address); 227 228 /** 229 * Send data through an established encrypted connection 230 * @param data data sent over an established encrypted connection 231 * @return whether the data is successfully sent to the server 232 */ 233 bool Send(const std::string &data); 234 235 /** 236 * Receive the data sent by the server through the established encrypted connection 237 * @param buffer receive the data sent by the server 238 * @param maxBufferSize the size of the data received from the server 239 * @return whether the data sent by the server is successfully received 240 */ 241 int Recv(char *buffer, int maxBufferSize); 242 243 /** 244 * Disconnect encrypted connection 245 * @return whether the encrypted connection was successfully disconnected 246 */ 247 bool Close(); 248 249 /** 250 * Set the application layer negotiation protocol in the encrypted communication process 251 * @param alpnProtocols application layer negotiation protocol 252 * @return set whether the application layer negotiation protocol is successful during encrypted communication 253 */ 254 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 255 256 /** 257 * Storage of server communication related network information 258 * @param remoteInfo communication related network information 259 */ 260 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 261 262 /** 263 * Get configuration options for encrypted communication process 264 * @return configuration options for encrypted communication processes 265 */ 266 [[nodiscard]] TlsSocket::TLSConfiguration GetTlsConfiguration() const; 267 268 /** 269 * Obtain the cipher suite during encrypted communication 270 * @return crypto suite used in encrypted communication 271 */ 272 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 273 274 /** 275 * Obtain the peer certificate used in encrypted communication 276 * @return peer certificate used in encrypted communication 277 */ 278 [[nodiscard]] std::string GetRemoteCertificate() const; 279 280 /** 281 * Obtain the peer certificate used in encrypted communication 282 * @return peer certificate serialization data used in encrypted communication 283 */ 284 [[nodiscard]] const TlsSocket::X509CertRawData &GetRemoteCertRawData() const; 285 286 /** 287 * Obtain the certificate used in encrypted communication 288 * @return certificate serialization data used in encrypted communication 289 */ 290 [[nodiscard]] const TlsSocket::X509CertRawData &GetCertificate() const; 291 292 /** 293 * Get the encryption algorithm used in encrypted communication 294 * @return encryption algorithm used in encrypted communication 295 */ 296 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 297 298 /** 299 * Obtain the communication protocol used in encrypted communication 300 * @return communication protocol used in encrypted communication 301 */ 302 [[nodiscard]] std::string GetProtocol() const; 303 304 /** 305 * Set the information about the shared signature algorithm supported by peers during encrypted communication 306 * @return information about peer supported shared signature algorithms 307 */ 308 [[nodiscard]] bool SetSharedSigals(); 309 310 /** 311 * Obtain the ssl used in encrypted communication 312 * @return SSL used in encrypted communication 313 */ 314 [[nodiscard]] ssl_st *GetSSL() const; 315 316 /** 317 * Get address information 318 * @return Returns the address information of the remote client 319 */ 320 [[nodiscard]] Socket::NetAddress GetAddress() const; 321 322 /** 323 * Get address information 324 * @return Returns the address information of the remote client 325 */ 326 [[nodiscard]] int GetSocketFd() const; 327 328 /** 329 * Get EventManager information 330 * @return Returns the address information of the remote client 331 */ 332 [[nodiscard]] std::shared_ptr<EventManager> GetEventManager() const; 333 334 void OnMessage(const OnMessageCallback &onMessageCallback); 335 /** 336 * Unregister the callback which is called when message is received 337 */ 338 void OffMessage(); 339 340 void CallOnMessageCallback(int32_t socketFd, const std::string &data, 341 const Socket::SocketRemoteInfo &remoteInfo); 342 343 void SetEventManager(std::shared_ptr<EventManager> eventManager); 344 345 void SetClientID(int32_t clientID); 346 347 [[nodiscard]] int GetClientID(); 348 349 void CallOnCloseCallback(const int32_t socketFd); 350 void OnClose(const OnCloseCallback &onCloseCallback); 351 OnCloseCallback onCloseCallback_; 352 353 /** 354 * Off Close 355 */ 356 void OffClose(); 357 358 /** 359 * Register the callback that is called when an error occurs 360 * @param onErrorCallback callback invoked when an error occurs 361 */ 362 void OnError(const TlsSocket::OnErrorCallback &onErrorCallback); 363 /** 364 * Off Error 365 */ 366 void OffError(); 367 368 void CallOnErrorCallback(int32_t err, const std::string &errString); 369 370 TlsSocket::OnErrorCallback onErrorCallback_; 371 372 private: 373 bool StartTlsAccept(const TlsSocket::TLSConnectOptions &options); 374 bool CreatTlsContext(); 375 bool StartShakingHands(const TlsSocket::TLSConnectOptions &options); 376 bool GetRemoteCertificateFromPeer(); 377 bool SetRemoteCertRawData(); 378 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 379 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 380 const X509 *x509Certificates); 381 382 private: 383 ssl_st *ssl_ = nullptr; 384 X509 *peerX509_ = nullptr; 385 int32_t socketFd_ = 0; 386 387 TlsSocket::TLSContextServer tlsContext_; 388 TlsSocket::TLSConfiguration connectionConfiguration_; 389 Socket::NetAddress address_; 390 TlsSocket::X509CertRawData remoteRawData_; 391 392 std::string hostName_; 393 std::string remoteCert_; 394 std::string keyPass_; 395 396 std::vector<std::string> signatureAlgorithms_; 397 std::unique_ptr<TlsSocket::TLSContextServer> tlsContextServerPointer_ = nullptr; 398 399 std::shared_ptr<EventManager> eventManager_ = nullptr; 400 int32_t clientID_ = 0; 401 OnMessageCallback onMessageCallback_; 402 }; 403 404 private: 405 void SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config); 406 int RecvRemoteInfo(int socketFd, int index); 407 void RemoveConnect(int socketFd); 408 void AddConnect(int socketFd, std::shared_ptr<Connection> connection); 409 void CallListenCallback(int32_t err, ListenCallback callback); 410 void CallOnErrorCallback(int32_t err, const std::string &errString); 411 412 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback); 413 void CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager); 414 void CallSendCallback(int32_t err, TlsSocket::SendCallback callback); 415 bool ExecBind(const Socket::NetAddress &address, const ListenCallback &callback); 416 void ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback); 417 void MakeIpSocket(sa_family_t family); 418 void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 419 socklen_t *len); 420 static constexpr const size_t MAX_ERROR_LEN = 128; 421 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 422 423 void PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions); 424 425 private: 426 std::mutex mutex_; 427 std::mutex connectMutex_; 428 int listenSocketFd_ = -1; 429 Socket::NetAddress address_; 430 std::map<int, std::shared_ptr<Connection>> clientIdConnections_; 431 TlsSocket::TLSConfiguration TLSServerConfiguration_; 432 433 OnConnectCallback onConnectCallback_; 434 TlsSocket::OnErrorCallback onErrorCallback_; 435 436 void ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientId); 437 void DropFdFromPollList(int &fd_index); 438 void InitPollList(int &listendFd); 439 440 struct pollfd fds_[USER_LIMIT + 1]; 441 442 bool isRunning_; 443 444 public: 445 std::shared_ptr<Connection> GetConnectionByClientID(int clientid); 446 int GetConnectionClientCount(); 447 448 std::shared_ptr<Connection> GetConnectionByClientEventManager(const EventManager *eventManager); 449 void CloseConnectionByEventManager(EventManager *eventManager); 450 void DeleteConnectionByEventManager(EventManager *eventManager); 451 }; 452 } // namespace TlsSocketServer 453 } // namespace NetStack 454 } // namespace OHOS 455 456 #endif // COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 457