• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
18 
19 #include "event_manager.h"
20 #include "extra_options_base.h"
21 #include "net_address.h"
22 #include "socket_error.h"
23 #include "socket_remote_info.h"
24 #include "socket_state_base.h"
25 #include "tcp_connect_options.h"
26 #include "tcp_extra_options.h"
27 #include "tcp_send_options.h"
28 #include "tls.h"
29 #include "tls_certificate.h"
30 #include "tls_configuration.h"
31 #include "tls_context_server.h"
32 #include "tls_key.h"
33 #include "tls_socket.h"
34 #include <any>
35 #include <condition_variable>
36 #include <cstring>
37 #include <functional>
38 #include <map>
39 #include <poll.h>
40 #include <thread>
41 #include <tuple>
42 #include <unistd.h>
43 #include <vector>
44 
45 namespace OHOS {
46 namespace NetStack {
47 namespace TlsSocketServer {
/**
 * Capacity constant for the server's poll list: TLSSocketServer sizes fds_ as USER_LIMIT + 1
 * (presumably USER_LIMIT client slots plus one for the listening socket — confirm in the implementation).
 */
constexpr int USER_LIMIT{10};
49 using OnMessageCallback =
50     std::function<void(const int &socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
51 using OnCloseCallback = std::function<void(const int &socketFd)>;
52 using OnConnectCallback = std::function<void(const int &socketFd, std::shared_ptr<EventManager> eventManager)>;
53 using ListenCallback = std::function<void(int32_t errorNumber)>;
/**
 * Parameters for TLSSocketServer::Send: the client socket to send on and the payload.
 */
class TLSServerSendOptions {
public:
    /**
     * Set the socket descriptor of the client connection the data is sent on
     * @param socketFd Communication descriptor of the target client
     */
    void SetSocket(const int &socketFd);

    /**
     * Set the data to send
     * @param data Payload to send
     */
    void SetSendData(const std::string &data);

    /**
     * Get the socket descriptor of the target client connection
     * @return The communication descriptor
     */
    [[nodiscard]] const int &GetSocket() const;

    /**
     * Get the data to be sent
     * @return Payload to send
     */
    [[nodiscard]] const std::string &GetSendData() const;

private:
    // -1 means "no socket set yet". Previously uninitialized, so reading it before SetSocket was undefined behavior.
    int socketFd_ = -1;
    std::string data_;
};
84 
85 class TLSSocketServer {
86 public:
87     TLSSocketServer(const TLSSocketServer &) = delete;
88     TLSSocketServer(TLSSocketServer &&) = delete;
89 
90     TLSSocketServer &operator=(const TLSSocketServer &) = delete;
91     TLSSocketServer &operator=(TLSSocketServer &&) = delete;
92 
93     TLSSocketServer() = default;
94     ~TLSSocketServer();
95 
96     /**
97      * Create sockets, bind and listen waiting for clients to connect
98      * @param tlsListenOptions Bind the listening connection configuration
99      * @param callback callback to the caller if bind ok or not
100      */
101     void Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback);
102 
103     /**
104      * Send data through an established encrypted connection
105      * @param data data sent over an established encrypted connection
106      * @return whether the data is successfully sent to the server
107      */
108     bool Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback);
109 
110     /**
111      * Disconnect by releasing the socket when communicating
112      * @param socketFd The socket ID of the client
113      * @param callback callback to the caller
114      */
115     void Close(const int socketFd, const TlsSocket::CloseCallback &callback);
116 
117     /**
118      * Disconnect by releasing the socket when communicating
119      * @param callback callback to the caller
120      */
121     void Stop(const TlsSocket::CloseCallback &callback);
122 
123     /**
124      * Get the peer network address
125      * @param socketFd The socket ID of the client
126      * @param callback callback to the caller
127      */
128     void GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback);
129 
130     /**
131      * Get the status of the current socket
132      * @param callback callback to the caller
133      */
134     void GetState(const TlsSocket::GetStateCallback &callback);
135 
136     /**
137      * Gets or sets the options associated with the current socket
138      * @param tcpExtraOptions options associated with the current socket
139      * @param callback callback to the caller
140      */
141     bool SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
142                          const TlsSocket::SetExtraOptionsCallback &callback);
143 
144     /**
145      *  Get a local digital certificate
146      * @param callback callback to the caller
147      */
148     void GetCertificate(const TlsSocket::GetCertificateCallback &callback);
149 
150     /**
151      * Get the peer digital certificate
152      * @param socketFd The socket ID of the client
153      * @param needChain need chain
154      * @param callback callback to the caller
155      */
156     void GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback);
157 
158     /**
159      * Obtain the protocol used in communication
160      * @param callback callback to the caller
161      */
162     void GetProtocol(const TlsSocket::GetProtocolCallback &callback);
163 
164     /**
165      * Obtain the cipher suite used in communication
166      * @param socketFd The socket ID of the client
167      * @param callback callback to the caller
168      */
169     void GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback);
170 
171     /**
172      * Obtain the encryption algorithm used in the communication process
173      * @param socketFd The socket ID of the client
174      * @param callback callback to the caller
175      */
176     void GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback);
177 
178     /**
179      * Register the callback that is called when the connection is disconnected
180      * @param onCloseCallback callback invoked when disconnected
181      */
182 
183     /**
184      * Register the callback that is called when the connection is established
185      * @param onConnectCallback callback invoked when connection is established
186      */
187     void OnConnect(const OnConnectCallback &onConnectCallback);
188 
189     /**
190      * Register the callback that is called when an error occurs
191      * @param onErrorCallback callback invoked when an error occurs
192      */
193     void OnError(const TlsSocket::OnErrorCallback &onErrorCallback);
194 
195     /**
196      * Off Connect
197      */
198     void OffConnect();
199 
200     /**
201      * Off Error
202      */
203     void OffError();
204 
205 public:
206     class Connection : public std::enable_shared_from_this<Connection> {
207     public:
208         ~Connection();
209         /**
210          * Establish an encrypted accept on the specified socket
211          * @param sock socket for establishing encrypted connection
212          * @param options some options required during tls accept
213          * @return whether the encrypted accept is successfully established
214          */
215         bool TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options);
216 
217         /**
218          * Set the configuration items for establishing encrypted connections
219          * @param config configuration item when establishing encrypted connection
220          */
221         void SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config);
222 
223         /**
224          * Set address information
225          */
226         void SetAddress(const Socket::NetAddress address);
227 
228         /**
229          * Send data through an established encrypted connection
230          * @param data data sent over an established encrypted connection
231          * @return whether the data is successfully sent to the server
232          */
233         bool Send(const std::string &data);
234 
235         /**
236          * Receive the data sent by the server through the established encrypted connection
237          * @param buffer receive the data sent by the server
238          * @param maxBufferSize the size of the data received from the server
239          * @return whether the data sent by the server is successfully received
240          */
241         int Recv(char *buffer, int maxBufferSize);
242 
243         /**
244          * Disconnect encrypted connection
245          * @return whether the encrypted connection was successfully disconnected
246          */
247         bool Close();
248 
249         /**
250          * Set the application layer negotiation protocol in the encrypted communication process
251          * @param alpnProtocols application layer negotiation protocol
252          * @return set whether the application layer negotiation protocol is successful during encrypted communication
253          */
254         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
255 
256         /**
257          * Storage of server communication related network information
258          * @param remoteInfo communication related network information
259          */
260         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
261 
262         /**
263          * Get configuration options for encrypted communication process
264          * @return configuration options for encrypted communication processes
265          */
266         [[nodiscard]] TlsSocket::TLSConfiguration GetTlsConfiguration() const;
267 
268         /**
269          * Obtain the cipher suite during encrypted communication
270          * @return crypto suite used in encrypted communication
271          */
272         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
273 
274         /**
275          * Obtain the peer certificate used in encrypted communication
276          * @return peer certificate used in encrypted communication
277          */
278         [[nodiscard]] std::string GetRemoteCertificate() const;
279 
280         /**
281          * Obtain the peer certificate used in encrypted communication
282          * @return peer certificate serialization data used in encrypted communication
283          */
284         [[nodiscard]] const TlsSocket::X509CertRawData &GetRemoteCertRawData() const;
285 
286         /**
287          * Obtain the certificate used in encrypted communication
288          * @return certificate serialization data used in encrypted communication
289          */
290         [[nodiscard]] const TlsSocket::X509CertRawData &GetCertificate() const;
291 
292         /**
293          * Get the encryption algorithm used in encrypted communication
294          * @return encryption algorithm used in encrypted communication
295          */
296         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
297 
298         /**
299          * Obtain the communication protocol used in encrypted communication
300          * @return communication protocol used in encrypted communication
301          */
302         [[nodiscard]] std::string GetProtocol() const;
303 
304         /**
305          * Set the information about the shared signature algorithm supported by peers during encrypted communication
306          * @return information about peer supported shared signature algorithms
307          */
308         [[nodiscard]] bool SetSharedSigals();
309 
310         /**
311          * Obtain the ssl used in encrypted communication
312          * @return SSL used in encrypted communication
313          */
314         [[nodiscard]] ssl_st *GetSSL() const;
315 
316         /**
317          * Get address information
318          * @return Returns the address information of the remote client
319          */
320         [[nodiscard]] Socket::NetAddress GetAddress() const;
321 
322         /**
323          * Get address information
324          * @return Returns the address information of the remote client
325          */
326         [[nodiscard]] int GetSocketFd() const;
327 
328         /**
329          * Get EventManager information
330          * @return Returns the address information of the remote client
331          */
332         [[nodiscard]] std::shared_ptr<EventManager> GetEventManager() const;
333 
334         void OnMessage(const OnMessageCallback &onMessageCallback);
335         /**
336          * Unregister the callback which is called when message is received
337          */
338         void OffMessage();
339 
340         void CallOnMessageCallback(int32_t socketFd, const std::string &data,
341                                    const Socket::SocketRemoteInfo &remoteInfo);
342 
343         void SetEventManager(std::shared_ptr<EventManager> eventManager);
344 
345         void SetClientID(int32_t clientID);
346 
347         [[nodiscard]] int GetClientID();
348 
349         void CallOnCloseCallback(const int32_t socketFd);
350         void OnClose(const OnCloseCallback &onCloseCallback);
351         OnCloseCallback onCloseCallback_;
352 
353         /**
354          * Off Close
355          */
356         void OffClose();
357 
358         /**
359          * Register the callback that is called when an error occurs
360          * @param onErrorCallback callback invoked when an error occurs
361          */
362         void OnError(const TlsSocket::OnErrorCallback &onErrorCallback);
363         /**
364          * Off Error
365          */
366         void OffError();
367 
368         void CallOnErrorCallback(int32_t err, const std::string &errString);
369 
370         TlsSocket::OnErrorCallback onErrorCallback_;
371 
372     private:
373         bool StartTlsAccept(const TlsSocket::TLSConnectOptions &options);
374         bool CreatTlsContext();
375         bool StartShakingHands(const TlsSocket::TLSConnectOptions &options);
376         bool GetRemoteCertificateFromPeer();
377         bool SetRemoteCertRawData();
378         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
379         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
380                                              const X509 *x509Certificates);
381 
382     private:
383         ssl_st *ssl_ = nullptr;
384         X509 *peerX509_ = nullptr;
385         int32_t socketFd_ = 0;
386 
387         TlsSocket::TLSContextServer tlsContext_;
388         TlsSocket::TLSConfiguration connectionConfiguration_;
389         Socket::NetAddress address_;
390         TlsSocket::X509CertRawData remoteRawData_;
391 
392         std::string hostName_;
393         std::string remoteCert_;
394         std::string keyPass_;
395 
396         std::vector<std::string> signatureAlgorithms_;
397         std::unique_ptr<TlsSocket::TLSContextServer> tlsContextServerPointer_ = nullptr;
398 
399         std::shared_ptr<EventManager> eventManager_ = nullptr;
400         int32_t clientID_ = 0;
401         OnMessageCallback onMessageCallback_;
402     };
403 
404 private:
405     void SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config);
406     int RecvRemoteInfo(int socketFd, int index);
407     void RemoveConnect(int socketFd);
408     void AddConnect(int socketFd, std::shared_ptr<Connection> connection);
409     void CallListenCallback(int32_t err, ListenCallback callback);
410     void CallOnErrorCallback(int32_t err, const std::string &errString);
411 
412     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback);
413     void CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager);
414     void CallSendCallback(int32_t err, TlsSocket::SendCallback callback);
415     bool ExecBind(const Socket::NetAddress &address, const ListenCallback &callback);
416     void ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback);
417     void MakeIpSocket(sa_family_t family);
418     void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
419                  socklen_t *len);
420     static constexpr const size_t MAX_ERROR_LEN = 128;
421     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
422 
423     void PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions);
424 
425 private:
426     std::mutex mutex_;
427     std::mutex connectMutex_;
428     int listenSocketFd_ = -1;
429     Socket::NetAddress address_;
430     std::map<int, std::shared_ptr<Connection>> clientIdConnections_;
431     TlsSocket::TLSConfiguration TLSServerConfiguration_;
432 
433     OnConnectCallback onConnectCallback_;
434     TlsSocket::OnErrorCallback onErrorCallback_;
435 
436     void ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientId);
437     void DropFdFromPollList(int &fd_index);
438     void InitPollList(int &listendFd);
439 
440     struct pollfd fds_[USER_LIMIT + 1];
441 
442     bool isRunning_;
443 
444 public:
445     std::shared_ptr<Connection> GetConnectionByClientID(int clientid);
446     int GetConnectionClientCount();
447 
448     std::shared_ptr<Connection> GetConnectionByClientEventManager(const EventManager *eventManager);
449     void CloseConnectionByEventManager(EventManager *eventManager);
450     void DeleteConnectionByEventManager(EventManager *eventManager);
451 };
452 } // namespace TlsSocketServer
453 } // namespace NetStack
454 } // namespace OHOS
455 
456 #endif // COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
457