• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H
18 
#include <any>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <vector>

#include "extra_options_base.h"
#include "net_address.h"
#include "socket_error.h"
#include "socket_remote_info.h"
#include "socket_state_base.h"
#include "tcp_connect_options.h"
#include "tcp_extra_options.h"
#include "tcp_send_options.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_context.h"
#include "tls_key.h"
42 
43 namespace OHOS {
44 namespace NetStack {
45 namespace TlsSocket {
46 
47 using BindCallback = std::function<void(int32_t errorNumber)>;
48 using ConnectCallback = std::function<void(int32_t errorNumber)>;
49 using SendCallback = std::function<void(int32_t errorNumber)>;
50 using CloseCallback = std::function<void(int32_t errorNumber)>;
51 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
52 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>;
53 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>;
54 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
55 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
56 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>;
57 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>;
58 using GetSignatureAlgorithmsCallback =
59     std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>;
60 
61 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
62 using OnConnectCallback = std::function<void(void)>;
63 using OnCloseCallback = std::function<void(void)>;
64 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>;
65 
66 using CheckServerIdentity =
67     std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>;
68 
// Application-Layer Protocol Negotiation (ALPN) protocol identifiers.
// These strings are sent on the wire during the TLS handshake and must match the IANA
// "TLS Application-Layer Protocol Negotiation (ALPN) Protocol IDs" registry byte-for-byte
// (RFC 7301). HTTP/1.1 is registered as "http/1.1" -- "http1.1" would never match a
// server's offer. HTTP/2 over TLS is "h2".
inline constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http/1.1";
inline constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2";

// Error-message length limit; its consumers are not visible in this header.
inline constexpr size_t MAX_ERR_LEN = 1024;
73 
74 /**
75  * Parameters required during communication
76  */
77 class TLSSecureOptions {
78 public:
79     TLSSecureOptions() = default;
80     ~TLSSecureOptions() = default;
81 
82     TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions);
83     TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions);
84     /**
85      * Set root CA Chain to verify the server cert
86      * @param caChain root certificate chain used to validate server certificates
87      */
88     void SetCaChain(const std::vector<std::string> &caChain);
89 
90     /**
91      * Set digital certificate for server verification
92      * @param cert digital certificate sent to the server to verify validity
93      */
94     void SetCert(const std::string &cert);
95 
96     /**
97      * Set key to decrypt server data
98      * @param keyChain key used to decrypt server data
99      */
100     void SetKey(const SecureData &key);
101 
102     /**
103      * Set the password to read the private key
104      * @param keyPass read the password of the private key
105      */
106     void SetKeyPass(const SecureData &keyPass);
107 
108     /**
109      * Set the protocol used in communication
110      * @param protocolChain protocol version number used
111      */
112     void SetProtocolChain(const std::vector<std::string> &protocolChain);
113 
114     /**
115      * Whether the peer cipher suite is preferred for communication
116      * @param useRemoteCipherPrefer whether the peer cipher suite is preferred
117      */
118     void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer);
119 
120     /**
121      * Encryption algorithm used in communication
122      * @param signatureAlgorithms encryption algorithm e.g: rsa
123      */
124     void SetSignatureAlgorithms(const std::string &signatureAlgorithms);
125 
126     /**
127      * Crypto suite used in communication
128      * @param cipherSuite cipher suite e.g:AES256-SHA256
129      */
130     void SetCipherSuite(const std::string &cipherSuite);
131 
132     /**
133      * Set a revoked certificate
134      * @param crlChain certificate Revocation List
135      */
136     void SetCrlChain(const std::vector<std::string> &crlChain);
137 
138     /**
139      * Get root CA Chain to verify the server cert
140      * @return root CA chain
141      */
142     [[nodiscard]] const std::vector<std::string> &GetCaChain() const;
143 
144     /**
145      * Obtain a certificate to send to the server for checking
146      * @return digital certificate obtained
147      */
148     [[nodiscard]] const std::string &GetCert() const;
149 
150     /**
151      * Obtain the private key in the communication process
152      * @return private key during communication
153      */
154     [[nodiscard]] const SecureData &GetKey() const;
155 
156     /**
157      * Get the password to read the private key
158      * @return read the password of the private key
159      */
160     [[nodiscard]] const SecureData &GetKeyPass() const;
161 
162     /**
163      * Get the protocol of the communication process
164      * @return protocol of communication process
165      */
166     [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const;
167 
168     /**
169      * Is the remote cipher suite being used for communication
170      * @return is use Remote Cipher Prefer
171      */
172     [[nodiscard]] bool UseRemoteCipherPrefer() const;
173 
174     /**
175      * Obtain the encryption algorithm used in the communication process
176      * @return encryption algorithm used in communication
177      */
178     [[nodiscard]] const std::string &GetSignatureAlgorithms() const;
179 
180     /**
181      * Obtain the cipher suite used in communication
182      * @return crypto suite used in communication
183      */
184     [[nodiscard]] const std::string &GetCipherSuite() const;
185 
186     /**
187      * Get revoked certificate chain
188      * @return revoked certificate chain
189      */
190     [[nodiscard]] const std::vector<std::string> &GetCrlChain() const;
191 
192     void SetVerifyMode(VerifyMode verifyMode);
193 
194     [[nodiscard]] VerifyMode GetVerifyMode() const;
195 
196 private:
197     std::vector<std::string> caChain_;
198     std::string cert_;
199     SecureData key_;
200     SecureData keyPass_;
201     std::vector<std::string> protocolChain_;
202     bool useRemoteCipherPrefer_ = false;
203     std::string signatureAlgorithms_;
204     std::string cipherSuite_;
205     std::vector<std::string> crlChain_;
206     VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE;
207 };
208 
209 /**
210  * Some options required during tls connection
211  */
212 class TLSConnectOptions {
213 public:
214     /**
215      * Communication parameters required for connection establishment
216      * @param address communication parameters during connection
217      */
218     void SetNetAddress(const Socket::NetAddress &address);
219 
220     /**
221      * Parameters required during communication
222      * @param tlsSecureOptions certificate and other relevant parameters
223      */
224     void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions);
225 
226     /**
227      * Set the callback function to check the validity of the server
228      * @param checkServerIdentity callback function passed in by API caller
229      */
230     void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity);
231 
232     /**
233      * Set application layer protocol negotiation
234      * @param alpnProtocols application layer protocol negotiation
235      */
236     void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
237 
238     /**
239      * Obtain the network address of the communication process
240      * @return network address
241      */
242     [[nodiscard]] Socket::NetAddress GetNetAddress() const;
243 
244     /**
245      * Obtain the parameters required in the communication process
246      * @return certificate and other relevant parameters
247      */
248     [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const;
249 
250     /**
251      * Get the check server ID callback function passed in by the API caller
252      * @return check the server identity callback function
253      */
254     [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const;
255 
256     /**
257      * Obtain the application layer protocol negotiation in the communication process
258      * @return application layer protocol negotiation
259      */
260     [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const;
261 
262 private:
263     Socket::NetAddress address_;
264     TLSSecureOptions tlsSecureOptions_;
265     CheckServerIdentity checkServerIdentity_;
266     std::vector<std::string> alpnProtocols_;
267 };
268 
269 /**
270  * TLS socket interface class
271  */
272 class TLSSocket {
273 public:
274     TLSSocket(const TLSSocket &) = delete;
275     TLSSocket(TLSSocket &&) = delete;
276 
277     TLSSocket &operator=(const TLSSocket &) = delete;
278     TLSSocket &operator=(TLSSocket &&) = delete;
279 
280     TLSSocket() = default;
281     ~TLSSocket() = default;
282 
283     /**
284      * Create a socket and bind to the address specified by address
285      * @param address ip address
286      * @param callback callback to the caller if bind ok or not
287      */
288     void Bind(const Socket::NetAddress &address, const BindCallback &callback);
289 
290     /**
291      * Establish a secure connection based on the created socket
292      * @param tlsConnectOptions some options required during tls connection
293      * @param callback callback to the caller if connect ok or not
294      */
295     void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback);
296 
297     /**
298      * Send data based on the created socket
299      * @param tcpSendOptions  some options required during tcp data transmission
300      * @param callback callback to the caller if send ok or not
301      */
302     void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback);
303 
304     /**
305      * Disconnect by releasing the socket when communicating
306      * @param callback callback to the caller
307      */
308     void Close(const CloseCallback &callback);
309 
310     /**
311      * Get the peer network address
312      * @param callback callback to the caller
313      */
314     void GetRemoteAddress(const GetRemoteAddressCallback &callback);
315 
316     /**
317      * Get the status of the current socket
318      * @param callback callback to the caller
319      */
320     void GetState(const GetStateCallback &callback);
321 
322     /**
323      * Gets or sets the options associated with the current socket
324      * @param tcpExtraOptions options associated with the current socket
325      * @param callback callback to the caller
326      */
327     void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback);
328 
329     /**
330      *  Get a local digital certificate
331      * @param callback callback to the caller
332      */
333     void GetCertificate(const GetCertificateCallback &callback);
334 
335     /**
336      * Get the peer digital certificate
337      * @param needChain need chain
338      * @param callback callback to the caller
339      */
340     void GetRemoteCertificate(const GetRemoteCertificateCallback &callback);
341 
342     /**
343      * Obtain the protocol used in communication
344      * @param callback callback to the caller
345      */
346     void GetProtocol(const GetProtocolCallback &callback);
347 
348     /**
349      * Obtain the cipher suite used in communication
350      * @param callback callback to the caller
351      */
352     void GetCipherSuite(const GetCipherSuiteCallback &callback);
353 
354     /**
355      * Obtain the encryption algorithm used in the communication process
356      * @param callback callback to the caller
357      */
358     void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback);
359 
360     /**
361      * Register a callback which is called when message is received
362      * @param onMessageCallback callback which is called when message is received
363      */
364     void OnMessage(const OnMessageCallback &onMessageCallback);
365 
366     /**
367      * Register the callback that is called when the connection is established
368      * @param onConnectCallback callback invoked when connection is established
369      */
370     void OnConnect(const OnConnectCallback &onConnectCallback);
371 
372     /**
373      * Register the callback that is called when the connection is disconnected
374      * @param onCloseCallback callback invoked when disconnected
375      */
376     void OnClose(const OnCloseCallback &onCloseCallback);
377 
378     /**
379      * Register the callback that is called when an error occurs
380      * @param onErrorCallback callback invoked when an error occurs
381      */
382     void OnError(const OnErrorCallback &onErrorCallback);
383 
384     /**
385      * Unregister the callback which is called when message is received
386      */
387     void OffMessage();
388 
389     /**
390      * Off Connect
391      */
392     void OffConnect();
393 
394     /**
395      * Off Close
396      */
397     void OffClose();
398 
399     /**
400      * Off Error
401      */
402     void OffError();
403 
404 private:
405     class TLSSocketInternal final {
406     public:
407         TLSSocketInternal() = default;
408         ~TLSSocketInternal() = default;
409 
410         /**
411          * Establish an encrypted connection on the specified socket
412          * @param sock socket for establishing encrypted connection
413          * @param options some options required during tls connection
414          * @return whether the encrypted connection is successfully established
415          */
416         bool TlsConnectToHost(int sock, const TLSConnectOptions &options);
417 
418         /**
419          * Set the configuration items for establishing encrypted connections
420          * @param config configuration item when establishing encrypted connection
421          */
422         void SetTlsConfiguration(const TLSConnectOptions &config);
423 
424         /**
425          * Send data through an established encrypted connection
426          * @param data data sent over an established encrypted connection
427          * @return whether the data is successfully sent to the server
428          */
429         bool Send(const std::string &data);
430 
431         /**
432          * Receive the data sent by the server through the established encrypted connection
433          * @param buffer receive the data sent by the server
434          * @param maxBufferSize the size of the data received from the server
435          * @return whether the data sent by the server is successfully received
436          */
437         int Recv(char *buffer, int maxBufferSize);
438 
439         /**
440          * Disconnect encrypted connection
441          * @return whether the encrypted connection was successfully disconnected
442          */
443         bool Close();
444 
445         /**
446          * Set the application layer negotiation protocol in the encrypted communication process
447          * @param alpnProtocols application layer negotiation protocol
448          * @return set whether the application layer negotiation protocol is successful during encrypted communication
449          */
450         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
451 
452         /**
453          * Storage of server communication related network information
454          * @param remoteInfo communication related network information
455          */
456         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
457 
458         /**
459          * Get configuration options for encrypted communication process
460          * @return configuration options for encrypted communication processes
461          */
462         [[nodiscard]] TLSConfiguration GetTlsConfiguration() const;
463 
464         /**
465          * Obtain the cipher suite during encrypted communication
466          * @return crypto suite used in encrypted communication
467          */
468         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
469 
470         /**
471          * Obtain the peer certificate used in encrypted communication
472          * @return peer certificate used in encrypted communication
473          */
474         [[nodiscard]] std::string GetRemoteCertificate() const;
475 
476         /**
477          * Obtain the peer certificate used in encrypted communication
478          * @return peer certificate serialization data used in encrypted communication
479          */
480         [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const;
481 
482         /**
483          * Obtain the certificate used in encrypted communication
484          * @return certificate serialization data used in encrypted communication
485          */
486         [[nodiscard]] const X509CertRawData &GetCertificate() const;
487 
488         /**
489          * Get the encryption algorithm used in encrypted communication
490          * @return encryption algorithm used in encrypted communication
491          */
492         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
493 
494         /**
495          * Obtain the communication protocol used in encrypted communication
496          * @return communication protocol used in encrypted communication
497          */
498         [[nodiscard]] std::string GetProtocol() const;
499 
500         /**
501          * Set the information about the shared signature algorithm supported by peers during encrypted communication
502          * @return information about peer supported shared signature algorithms
503          */
504         [[nodiscard]] bool SetSharedSigals();
505 
506         /**
507          * Obtain the ssl used in encrypted communication
508          * @return SSL used in encrypted communication
509          */
510         [[nodiscard]] ssl_st *GetSSL() const;
511 
512     private:
513         bool StartTlsConnected(const TLSConnectOptions &options);
514         bool CreatTlsContext();
515         bool StartShakingHands(const TLSConnectOptions &options);
516         bool GetRemoteCertificateFromPeer();
517         bool SetRemoteCertRawData();
518         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
519         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
520                                              const X509 *x509Certificates);
521 
522     private:
523         ssl_st *ssl_ = nullptr;
524         X509 *peerX509_ = nullptr;
525         uint16_t port_ = 0;
526         sa_family_t family_ = 0;
527         int32_t socketDescriptor_ = 0;
528 
529         TLSContext tlsContext_;
530         TLSConfiguration configuration_;
531         Socket::NetAddress address_;
532         X509CertRawData remoteRawData_;
533 
534         std::string hostName_;
535         std::string remoteCert_;
536 
537         std::vector<std::string> signatureAlgorithms_;
538         std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr;
539     };
540 
541 private:
542     TLSSocketInternal tlsSocketInternal_;
543 
544     static std::string MakeAddressString(sockaddr *addr);
545 
546     static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
547                         socklen_t *len);
548 
549     void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo);
550     void CallOnConnectCallback();
551     void CallOnCloseCallback();
552     void CallOnErrorCallback(int32_t err, const std::string &errString);
553 
554     void CallBindCallback(int32_t err, BindCallback callback);
555     void CallConnectCallback(int32_t err, ConnectCallback callback);
556     void CallSendCallback(int32_t err, SendCallback callback);
557     void CallCloseCallback(int32_t err, CloseCallback callback);
558     void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
559                                       GetRemoteAddressCallback callback);
560     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback);
561     void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback);
562     void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback);
563     void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
564                                           GetRemoteCertificateCallback callback);
565     void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback);
566     void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
567                                     GetCipherSuiteCallback callback);
568     void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
569                                             GetSignatureAlgorithmsCallback callback);
570 
571     void StartReadMessage();
572 
573     void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback);
574     void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback);
575 
576     [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const;
577     [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const;
578 
579     void MakeIpSocket(sa_family_t family);
580 
581     template<class T>
DealCallback(int32_t err,T & callback)582     void DealCallback(int32_t err, T &callback)
583     {
584         T func = nullptr;
585         {
586             std::lock_guard<std::mutex> lock(mutex_);
587             if (callback) {
588                 func = callback;
589             }
590         }
591 
592         if (func) {
593             func(err);
594         }
595     }
596 
597 private:
598     static constexpr const size_t MAX_ERROR_LEN = 128;
599     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
600 
601     OnMessageCallback onMessageCallback_;
602     OnConnectCallback onConnectCallback_;
603     OnCloseCallback onCloseCallback_;
604     OnErrorCallback onErrorCallback_;
605 
606     std::mutex mutex_;
607     bool isRunning_ = false;
608     bool isRunOver_ = true;
609 
610     int sockFd_ = -1;
611 };
612 } // namespace TlsSocket
613 } // namespace NetStack
614 } // namespace OHOS
615 
616 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H
617