• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H
18 
#include <unistd.h>

#include <any>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#include "extra_options_base.h"
#include "net_address.h"
#include "proxy_options.h"
#include "socket_error.h"
#include "socket_remote_info.h"
#include "socket_state_base.h"
#include "tcp_connect_options.h"
#include "tcp_extra_options.h"
#include "tcp_send_options.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_context.h"
#include "tls_key.h"
44 
45 namespace OHOS {
46 namespace NetStack {
47 namespace TlsSocket {
48 
49 using BindCallback = std::function<void(int32_t errorNumber)>;
50 using ConnectCallback = std::function<void(int32_t errorNumber)>;
51 using SendCallback = std::function<void(int32_t errorNumber)>;
52 using CloseCallback = std::function<void(int32_t errorNumber)>;
53 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
54 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
55 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>;
56 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>;
57 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
58 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
59 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>;
60 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>;
61 using GetSignatureAlgorithmsCallback =
62     std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>;
63 
64 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
65 using OnConnectCallback = std::function<void(void)>;
66 using OnCloseCallback = std::function<void(void)>;
67 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>;
68 
69 using CheckServerIdentity =
70     std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>;
71 
/**
 * ALPN protocol identifiers as registered in the IANA "TLS Application-Layer Protocol Negotiation (ALPN)
 * Protocol IDs" registry (RFC 7301). ALPN identifiers are compared byte-for-byte during negotiation, so
 * they must match the registered strings exactly: HTTP/1.1 is "http/1.1" ("http1.1", without the slash,
 * is not a registered identifier and is not recognized by peers), and HTTP/2 over TLS is "h2".
 * `inline constexpr` gives a single definition shared by every translation unit including this header.
 */
inline constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http/1.1";
inline constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2";

/** Maximum error-message length (presumably the size of error-string buffers — confirm against users). */
inline constexpr size_t MAX_ERR_LEN = 1024;
76 
77 /**
78  * Parameters required during communication
79  */
80 class TLSSecureOptions {
81 public:
82     TLSSecureOptions() = default;
83     ~TLSSecureOptions() = default;
84 
85     TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions);
86     TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions);
87     /**
88      * Set root CA Chain to verify the server cert
89      * @param caChain root certificate chain used to validate server certificates
90      */
91     void SetCaChain(const std::vector<std::string> &caChain);
92 
93     /**
94      * Set digital certificate for server verification
95      * @param cert digital certificate sent to the server to verify validity
96      */
97     void SetCert(const std::string &cert);
98 
99     /**
100      * Set key to decrypt server data
101      * @param keyChain key used to decrypt server data
102      */
103     void SetKey(const SecureData &key);
104 
105     /**
106      * Set the password to read the private key
107      * @param keyPass read the password of the private key
108      */
109     void SetKeyPass(const SecureData &keyPass);
110 
111     /**
112      * Set the protocol used in communication
113      * @param protocolChain protocol version number used
114      */
115     void SetProtocolChain(const std::vector<std::string> &protocolChain);
116 
117     /**
118      * Whether the peer cipher suite is preferred for communication
119      * @param useRemoteCipherPrefer whether the peer cipher suite is preferred
120      */
121     void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer);
122 
123     /**
124      * Encryption algorithm used in communication
125      * @param signatureAlgorithms encryption algorithm e.g: rsa
126      */
127     void SetSignatureAlgorithms(const std::string &signatureAlgorithms);
128 
129     /**
130      * Crypto suite used in communication
131      * @param cipherSuite cipher suite e.g:AES256-SHA256
132      */
133     void SetCipherSuite(const std::string &cipherSuite);
134 
135     /**
136      * Set a revoked certificate
137      * @param crlChain certificate Revocation List
138      */
139     void SetCrlChain(const std::vector<std::string> &crlChain);
140 
141     /**
142      * Get root CA Chain to verify the server cert
143      * @return root CA chain
144      */
145     [[nodiscard]] const std::vector<std::string> &GetCaChain() const;
146 
147     /**
148      * Obtain a certificate to send to the server for checking
149      * @return digital certificate obtained
150      */
151     [[nodiscard]] const std::string &GetCert() const;
152 
153     /**
154      * Obtain the private key in the communication process
155      * @return private key during communication
156      */
157     [[nodiscard]] const SecureData &GetKey() const;
158 
159     /**
160      * Get the password to read the private key
161      * @return read the password of the private key
162      */
163     [[nodiscard]] const SecureData &GetKeyPass() const;
164 
165     /**
166      * Get the protocol of the communication process
167      * @return protocol of communication process
168      */
169     [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const;
170 
171     /**
172      * Is the remote cipher suite being used for communication
173      * @return is use Remote Cipher Prefer
174      */
175     [[nodiscard]] bool UseRemoteCipherPrefer() const;
176 
177     /**
178      * Obtain the encryption algorithm used in the communication process
179      * @return encryption algorithm used in communication
180      */
181     [[nodiscard]] const std::string &GetSignatureAlgorithms() const;
182 
183     /**
184      * Obtain the cipher suite used in communication
185      * @return crypto suite used in communication
186      */
187     [[nodiscard]] const std::string &GetCipherSuite() const;
188 
189     /**
190      * Get revoked certificate chain
191      * @return revoked certificate chain
192      */
193     [[nodiscard]] const std::vector<std::string> &GetCrlChain() const;
194 
195     void SetVerifyMode(VerifyMode verifyMode);
196 
197     [[nodiscard]] VerifyMode GetVerifyMode() const;
198 
199 private:
200     std::vector<std::string> caChain_;
201     std::string cert_;
202     SecureData key_;
203     SecureData keyPass_;
204     std::vector<std::string> protocolChain_;
205     bool useRemoteCipherPrefer_ = false;
206     std::string signatureAlgorithms_;
207     std::string cipherSuite_;
208     std::vector<std::string> crlChain_;
209     VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE;
210 };
211 
212 /**
213  * Some options required during tls connection
214  */
215 class TLSConnectOptions {
216 public:
217     friend class TLSSocketExec;
218     /**
219      * Communication parameters required for connection establishment
220      * @param address communication parameters during connection
221      */
222     void SetNetAddress(const Socket::NetAddress &address);
223 
224     /**
225      * Parameters required during communication
226      * @param tlsSecureOptions certificate and other relevant parameters
227      */
228     void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions);
229 
230     /**
231      * Set the callback function to check the validity of the server
232      * @param checkServerIdentity callback function passed in by API caller
233      */
234     void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity);
235 
236     /**
237      * Set application layer protocol negotiation
238      * @param alpnProtocols application layer protocol negotiation
239      */
240     void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
241 
242     /**
243      * Set whether to skip remote validation
244      * @param skipRemoteValidation flag to choose whether to skip validation
245      */
246     void SetSkipRemoteValidation(bool skipRemoteValidation);
247 
248     /**
249      * Obtain the network address of the communication process
250      * @return network address
251      */
252     [[nodiscard]] Socket::NetAddress GetNetAddress() const;
253 
254     /**
255      * Obtain the parameters required in the communication process
256      * @return certificate and other relevant parameters
257      */
258     [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const;
259 
260     /**
261      * Get the check server ID callback function passed in by the API caller
262      * @return check the server identity callback function
263      */
264     [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const;
265 
266     /**
267      * Obtain the application layer protocol negotiation in the communication process
268      * @return application layer protocol negotiation
269      */
270     [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const;
271 
272     /**
273      * Get the choice of whether to skip remote validaion
274      * @return skipRemoteValidaion result
275      */
276     [[nodiscard]] bool GetSkipRemoteValidation() const;
277 
278     void SetHostName(const std::string &hostName);
279     [[nodiscard]] std::string GetHostName() const;
280 
281     std::shared_ptr<Socket::ProxyOptions> proxyOptions_{nullptr};
282 
283 private:
284     Socket::NetAddress address_;
285     TLSSecureOptions tlsSecureOptions_;
286     CheckServerIdentity checkServerIdentity_;
287     std::vector<std::string> alpnProtocols_;
288     bool skipRemoteValidation_ = false;
289     std::string hostName_;
290 };
291 
292 /**
293  * TLS socket interface class
294  */
295 class TLSSocket : public std::enable_shared_from_this<TLSSocket> {
296 public:
297     TLSSocket(const TLSSocket &) = delete;
298     TLSSocket(TLSSocket &&) = delete;
299 
300     TLSSocket &operator=(const TLSSocket &) = delete;
301     TLSSocket &operator=(TLSSocket &&) = delete;
302 
303     TLSSocket() = default;
304     ~TLSSocket() = default;
305 
TLSSocket(int sockFd)306     explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {}
307 
308     /**
309      * Create a socket and bind to the address specified by address
310      * @param address ip address
311      * @param callback callback to the caller if bind ok or not
312      */
313     void Bind(Socket::NetAddress &address, const BindCallback &callback);
314 
315     /**
316      * Establish a secure connection based on the created socket
317      * @param tlsConnectOptions some options required during tls connection
318      * @param callback callback to the caller if connect ok or not
319      */
320     void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback);
321 
322     /**
323      * Send data based on the created socket
324      * @param tcpSendOptions  some options required during tcp data transmission
325      * @param callback callback to the caller if send ok or not
326      */
327     void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback);
328 
329     /**
330      * Disconnect by releasing the socket when communicating
331      * @param callback callback to the caller
332      */
333     void Close(const CloseCallback &callback);
334 
335     /**
336      * Get the peer network address
337      * @param callback callback to the caller
338      */
339     void GetRemoteAddress(const GetRemoteAddressCallback &callback);
340 
341     /**
342      * Get the status of the current socket
343      * @param callback callback to the caller
344      */
345     void GetState(const GetStateCallback &callback);
346 
347     /**
348      * Gets or sets the options associated with the current socket
349      * @param tcpExtraOptions options associated with the current socket
350      * @param callback callback to the caller
351      */
352     void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback);
353 
354     /**
355      *  Get a local digital certificate
356      * @param callback callback to the caller
357      */
358     void GetCertificate(const GetCertificateCallback &callback);
359 
360     /**
361      * Get the peer digital certificate
362      * @param needChain need chain
363      * @param callback callback to the caller
364      */
365     void GetRemoteCertificate(const GetRemoteCertificateCallback &callback);
366 
367     /**
368      * Obtain the protocol used in communication
369      * @param callback callback to the caller
370      */
371     void GetProtocol(const GetProtocolCallback &callback);
372 
373     /**
374      * Obtain the cipher suite used in communication
375      * @param callback callback to the caller
376      */
377     void GetCipherSuite(const GetCipherSuiteCallback &callback);
378 
379     /**
380      * Obtain the encryption algorithm used in the communication process
381      * @param callback callback to the caller
382      */
383     void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback);
384 
385     /**
386      * Register a callback which is called when message is received
387      * @param onMessageCallback callback which is called when message is received
388      */
389     void OnMessage(const OnMessageCallback &onMessageCallback);
390 
391     /**
392      * Register the callback that is called when the connection is established
393      * @param onConnectCallback callback invoked when connection is established
394      */
395     void OnConnect(const OnConnectCallback &onConnectCallback);
396 
397     /**
398      * Register the callback that is called when the connection is disconnected
399      * @param onCloseCallback callback invoked when disconnected
400      */
401     void OnClose(const OnCloseCallback &onCloseCallback);
402 
403     /**
404      * Register the callback that is called when an error occurs
405      * @param onErrorCallback callback invoked when an error occurs
406      */
407     void OnError(const OnErrorCallback &onErrorCallback);
408 
409     /**
410      * Unregister the callback which is called when message is received
411      */
412     void OffMessage();
413 
414     /**
415      * Off Connect
416      */
417     void OffConnect();
418 
419     /**
420      * Off Close
421      */
422     void OffClose();
423 
424     /**
425      * Off Error
426      */
427     void OffError();
428 
429     /**
430      * Get the socket file description of the server
431      */
432     int GetSocketFd();
433 
434     /**
435      * Set the current socket file description address of the server
436      */
437     void SetLocalAddress(const Socket::NetAddress &address);
438 
439     /**
440      * Get the current socket file description address of the server
441      */
442     Socket::NetAddress GetLocalAddress();
443 
444     void ExecTlsGetAddr(
445         const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len);
446 
447     bool ExecTlsSetSockBlockFlag(int sock, bool noneBlock);
448 
449     bool IsExtSock() const;
450 
451 private:
452     class TLSSocketInternal final {
453     public:
454         TLSSocketInternal() = default;
455         ~TLSSocketInternal() = default;
456 
457         /**
458          * Establish an encrypted connection on the specified socket
459          * @param sock socket for establishing encrypted connection
460          * @param options some options required during tls connection
461          * @param isExtSock socket fd is originated from external source when constructing tls socket
462          * @return whether the encrypted connection is successfully established
463          */
464         bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock);
465 
466         /**
467          * Set the configuration items for establishing encrypted connections
468          * @param config configuration item when establishing encrypted connection
469          */
470         void SetTlsConfiguration(const TLSConnectOptions &config);
471 
472         /**
473          * Send data through an established encrypted connection
474          * @param data data sent over an established encrypted connection
475          * @return whether the data is successfully sent to the server
476          */
477         bool Send(const std::string &data);
478 
479         /**
480          * Receive the data sent by the server through the established encrypted connection
481          * @param buffer receive the data sent by the server
482          * @param maxBufferSize the size of the data received from the server
483          * @return whether the data sent by the server is successfully received
484          */
485         int Recv(char *buffer, int maxBufferSize);
486 
487         /**
488          * Disconnect encrypted connection
489          * @return whether the encrypted connection was successfully disconnected
490          */
491         bool Close();
492 
493         /**
494          * Set the application layer negotiation protocol in the encrypted communication process
495          * @param alpnProtocols application layer negotiation protocol
496          * @return set whether the application layer negotiation protocol is successful during encrypted communication
497          */
498         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
499 
500         /**
501          * Storage of server communication related network information
502          * @param remoteInfo communication related network information
503          */
504         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
505 
506         /**
507          * convert the code to ssl error code
508          * @return the value for ssl error code.
509          */
510         int ConvertSSLError(void);
511 
512         /**
513          * Get configuration options for encrypted communication process
514          * @return configuration options for encrypted communication processes
515          */
516         [[nodiscard]] TLSConfiguration GetTlsConfiguration() const;
517 
518         /**
519          * Obtain the cipher suite during encrypted communication
520          * @return crypto suite used in encrypted communication
521          */
522         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
523 
524         /**
525          * Obtain the peer certificate used in encrypted communication
526          * @return peer certificate used in encrypted communication
527          */
528         [[nodiscard]] std::string GetRemoteCertificate() const;
529 
530         /**
531          * Obtain the peer certificate used in encrypted communication
532          * @return peer certificate serialization data used in encrypted communication
533          */
534         [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const;
535 
536         /**
537          * Obtain the certificate used in encrypted communication
538          * @return certificate serialization data used in encrypted communication
539          */
540         [[nodiscard]] const X509CertRawData &GetCertificate() const;
541 
542         /**
543          * Get the encryption algorithm used in encrypted communication
544          * @return encryption algorithm used in encrypted communication
545          */
546         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
547 
548         /**
549          * Obtain the communication protocol used in encrypted communication
550          * @return communication protocol used in encrypted communication
551          */
552         [[nodiscard]] std::string GetProtocol() const;
553 
554         /**
555          * Set the information about the shared signature algorithm supported by peers during encrypted communication
556          * @return information about peer supported shared signature algorithms
557          */
558         [[nodiscard]] bool SetSharedSigals();
559 
560         /**
561          * Obtain the ssl used in encrypted communication
562          * @return SSL used in encrypted communication
563          */
564         [[nodiscard]] ssl_st *GetSSL();
565 
566     private:
567         bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd);
568         bool StartTlsConnected(const TLSConnectOptions &options);
569         bool CreatTlsContext();
570         bool StartShakingHands(const TLSConnectOptions &options);
571         bool GetRemoteCertificateFromPeer();
572         bool SetRemoteCertRawData();
573         bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize);
574         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
575         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
576                                              const X509 *x509Certificates);
577 
578     private:
579         std::mutex mutexForSsl_;
580         mutable std::shared_mutex rw_mutex_;
581         ssl_st *ssl_ = nullptr;
582         X509 *peerX509_ = nullptr;
583         uint16_t port_ = 0;
584         sa_family_t family_ = 0;
585         int32_t socketDescriptor_ = 0;
586 
587         TLSContext tlsContext_;
588         TLSConfiguration configuration_;
589         Socket::NetAddress address_;
590         X509CertRawData remoteRawData_;
591 
592         std::string hostName_;
593         std::string remoteCert_;
594 
595         std::vector<std::string> signatureAlgorithms_;
596         std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr;
597     };
598 
599 private:
600     TLSSocketInternal tlsSocketInternal_;
601 
602     static std::string MakeAddressString(sockaddr *addr);
603 
604     static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
605                         socklen_t *len);
606 
607     void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo);
608     void CallOnConnectCallback();
609     void CallOnCloseCallback();
610     void CallOnErrorCallback(int32_t err, const std::string &errString);
611 
612     void CallBindCallback(int32_t err, BindCallback callback);
613     void CallConnectCallback(int32_t err, ConnectCallback callback);
614     void CallSendCallback(int32_t err, SendCallback callback);
615     void CallCloseCallback(int32_t err, CloseCallback callback);
616     void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
617                                       GetRemoteAddressCallback callback);
618     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback);
619     void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback);
620     void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback);
621     void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
622                                           GetRemoteCertificateCallback callback);
623     void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback);
624     void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
625                                     GetCipherSuiteCallback callback);
626     void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
627                                             GetSignatureAlgorithmsCallback callback);
628 
629     int ReadMessage();
630     void StartReadMessage();
631 
632     void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback);
633     void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback);
634 
635     [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const;
636     [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const;
637 
638     void MakeIpSocket(sa_family_t family);
639 
640     template<class T>
DealCallback(int32_t err,T & callback)641     void DealCallback(int32_t err, T &callback)
642     {
643         if (callback) {
644             callback(err);
645         }
646     }
647 
648 private:
649     static constexpr const size_t MAX_ERROR_LEN = 128;
650     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
651 
652     OnMessageCallback onMessageCallback_;
653     OnConnectCallback onConnectCallback_;
654     OnCloseCallback onCloseCallback_;
655     OnErrorCallback onErrorCallback_;
656 
657     std::mutex mutex_;
658     std::mutex recvMutex_;
659     std::mutex cvMutex_;
660     bool isRunning_ = false;
661     bool isRunOver_ = true;
662     std::condition_variable cvSslFree_;
663     int sockFd_ = -1;
664     bool isExtSock_ = false;
665     Socket::NetAddress localAddress_;
666 };
667 } // namespace TlsSocket
668 } // namespace NetStack
669 } // namespace OHOS
670 
671 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H
672