• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H
18 
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <any>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#include "extra_options_base.h"
#include "net_address.h"
#include "proxy_options.h"
#include "socket_error.h"
#include "socket_remote_info.h"
#include "socket_state_base.h"
#include "tcp_connect_options.h"
#include "tcp_extra_options.h"
#include "tcp_send_options.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_context.h"
#include "tls_key.h"
43 
44 namespace OHOS {
45 namespace NetStack {
46 namespace TlsSocket {
47 
48 using BindCallback = std::function<void(int32_t errorNumber)>;
49 using ConnectCallback = std::function<void(int32_t errorNumber)>;
50 using SendCallback = std::function<void(int32_t errorNumber)>;
51 using CloseCallback = std::function<void(int32_t errorNumber)>;
52 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
53 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
54 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>;
55 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>;
56 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
57 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
58 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>;
59 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>;
60 using GetSignatureAlgorithmsCallback =
61     std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>;
62 
63 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
64 using OnConnectCallback = std::function<void(void)>;
65 using OnCloseCallback = std::function<void(void)>;
66 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>;
67 
68 using CheckServerIdentity =
69     std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>;
70 
/**
 * ALPN protocol identifiers, spelled exactly as registered in the IANA
 * "TLS Application-Layer Protocol Negotiation (ALPN) Protocol IDs" registry
 * (RFC 7301). The HTTP/1.1 identifier contains a slash ("http/1.1"); any other
 * spelling is not a registered ID and will not be matched by a conforming peer.
 */
inline constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http/1.1";
inline constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2";

/** Maximum error-message length; presumably sizes error-string buffers (confirm at use sites). */
inline constexpr size_t MAX_ERR_LEN = 1024;
75 
76 /**
77  * Parameters required during communication
78  */
79 class TLSSecureOptions {
80 public:
81     TLSSecureOptions() = default;
82     ~TLSSecureOptions() = default;
83 
84     TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions);
85     TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions);
86     /**
87      * Set root CA Chain to verify the server cert
88      * @param caChain root certificate chain used to validate server certificates
89      */
90     void SetCaChain(const std::vector<std::string> &caChain);
91 
92     /**
93      * Set digital certificate for server verification
94      * @param cert digital certificate sent to the server to verify validity
95      */
96     void SetCert(const std::string &cert);
97 
98     /**
99      * Set key to decrypt server data
100      * @param keyChain key used to decrypt server data
101      */
102     void SetKey(const SecureData &key);
103 
104     /**
105      * Set the password to read the private key
106      * @param keyPass read the password of the private key
107      */
108     void SetKeyPass(const SecureData &keyPass);
109 
110     /**
111      * Set the protocol used in communication
112      * @param protocolChain protocol version number used
113      */
114     void SetProtocolChain(const std::vector<std::string> &protocolChain);
115 
116     /**
117      * Whether the peer cipher suite is preferred for communication
118      * @param useRemoteCipherPrefer whether the peer cipher suite is preferred
119      */
120     void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer);
121 
122     /**
123      * Encryption algorithm used in communication
124      * @param signatureAlgorithms encryption algorithm e.g: rsa
125      */
126     void SetSignatureAlgorithms(const std::string &signatureAlgorithms);
127 
128     /**
129      * Crypto suite used in communication
130      * @param cipherSuite cipher suite e.g:AES256-SHA256
131      */
132     void SetCipherSuite(const std::string &cipherSuite);
133 
134     /**
135      * Set a revoked certificate
136      * @param crlChain certificate Revocation List
137      */
138     void SetCrlChain(const std::vector<std::string> &crlChain);
139 
140     /**
141      * Get root CA Chain to verify the server cert
142      * @return root CA chain
143      */
144     [[nodiscard]] const std::vector<std::string> &GetCaChain() const;
145 
146     /**
147      * Obtain a certificate to send to the server for checking
148      * @return digital certificate obtained
149      */
150     [[nodiscard]] const std::string &GetCert() const;
151 
152     /**
153      * Obtain the private key in the communication process
154      * @return private key during communication
155      */
156     [[nodiscard]] const SecureData &GetKey() const;
157 
158     /**
159      * Get the password to read the private key
160      * @return read the password of the private key
161      */
162     [[nodiscard]] const SecureData &GetKeyPass() const;
163 
164     /**
165      * Get the protocol of the communication process
166      * @return protocol of communication process
167      */
168     [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const;
169 
170     /**
171      * Is the remote cipher suite being used for communication
172      * @return is use Remote Cipher Prefer
173      */
174     [[nodiscard]] bool UseRemoteCipherPrefer() const;
175 
176     /**
177      * Obtain the encryption algorithm used in the communication process
178      * @return encryption algorithm used in communication
179      */
180     [[nodiscard]] const std::string &GetSignatureAlgorithms() const;
181 
182     /**
183      * Obtain the cipher suite used in communication
184      * @return crypto suite used in communication
185      */
186     [[nodiscard]] const std::string &GetCipherSuite() const;
187 
188     /**
189      * Get revoked certificate chain
190      * @return revoked certificate chain
191      */
192     [[nodiscard]] const std::vector<std::string> &GetCrlChain() const;
193 
194     void SetVerifyMode(VerifyMode verifyMode);
195 
196     [[nodiscard]] VerifyMode GetVerifyMode() const;
197 
198 private:
199     std::vector<std::string> caChain_;
200     std::string cert_;
201     SecureData key_;
202     SecureData keyPass_;
203     std::vector<std::string> protocolChain_;
204     bool useRemoteCipherPrefer_ = false;
205     std::string signatureAlgorithms_;
206     std::string cipherSuite_;
207     std::vector<std::string> crlChain_;
208     VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE;
209 };
210 
211 /**
212  * Some options required during tls connection
213  */
214 class TLSConnectOptions {
215 public:
216     friend class TLSSocketExec;
217     /**
218      * Communication parameters required for connection establishment
219      * @param address communication parameters during connection
220      */
221     void SetNetAddress(const Socket::NetAddress &address);
222 
223     /**
224      * Parameters required during communication
225      * @param tlsSecureOptions certificate and other relevant parameters
226      */
227     void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions);
228 
229     /**
230      * Set the callback function to check the validity of the server
231      * @param checkServerIdentity callback function passed in by API caller
232      */
233     void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity);
234 
235     /**
236      * Set application layer protocol negotiation
237      * @param alpnProtocols application layer protocol negotiation
238      */
239     void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
240 
241     /**
242      * Set whether to skip remote validation
243      * @param skipRemoteValidation flag to choose whether to skip validation
244      */
245     void SetSkipRemoteValidation(bool skipRemoteValidation);
246 
247     /**
248      * Obtain the network address of the communication process
249      * @return network address
250      */
251     [[nodiscard]] Socket::NetAddress GetNetAddress() const;
252 
253     /**
254      * Obtain the parameters required in the communication process
255      * @return certificate and other relevant parameters
256      */
257     [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const;
258 
259     /**
260      * Get the check server ID callback function passed in by the API caller
261      * @return check the server identity callback function
262      */
263     [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const;
264 
265     /**
266      * Obtain the application layer protocol negotiation in the communication process
267      * @return application layer protocol negotiation
268      */
269     [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const;
270 
271     /**
272      * Get the choice of whether to skip remote validaion
273      * @return skipRemoteValidaion result
274      */
275     [[nodiscard]] bool GetSkipRemoteValidation() const;
276 
277     void SetHostName(const std::string &hostName);
278     [[nodiscard]] std::string GetHostName() const;
279 
280     std::shared_ptr<Socket::ProxyOptions> proxyOptions_{nullptr};
281 
282 private:
283     Socket::NetAddress address_;
284     TLSSecureOptions tlsSecureOptions_;
285     CheckServerIdentity checkServerIdentity_;
286     std::vector<std::string> alpnProtocols_;
287     bool skipRemoteValidation_ = false;
288     std::string hostName_;
289 };
290 
291 /**
292  * TLS socket interface class
293  */
294 class TLSSocket : public std::enable_shared_from_this<TLSSocket> {
295 public:
296     TLSSocket(const TLSSocket &) = delete;
297     TLSSocket(TLSSocket &&) = delete;
298 
299     TLSSocket &operator=(const TLSSocket &) = delete;
300     TLSSocket &operator=(TLSSocket &&) = delete;
301 
302     TLSSocket() = default;
303     ~TLSSocket() = default;
304 
TLSSocket(int sockFd)305     explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {}
306 
307     /**
308      * Create a socket and bind to the address specified by address
309      * @param address ip address
310      * @param callback callback to the caller if bind ok or not
311      */
312     void Bind(Socket::NetAddress &address, const BindCallback &callback);
313 
314     /**
315      * Establish a secure connection based on the created socket
316      * @param tlsConnectOptions some options required during tls connection
317      * @param callback callback to the caller if connect ok or not
318      */
319     void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback);
320 
321     /**
322      * Send data based on the created socket
323      * @param tcpSendOptions  some options required during tcp data transmission
324      * @param callback callback to the caller if send ok or not
325      */
326     void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback);
327 
328     /**
329      * Disconnect by releasing the socket when communicating
330      * @param callback callback to the caller
331      */
332     void Close(const CloseCallback &callback);
333 
334     /**
335      * Get the peer network address
336      * @param callback callback to the caller
337      */
338     void GetRemoteAddress(const GetRemoteAddressCallback &callback);
339 
340     /**
341      * Get the status of the current socket
342      * @param callback callback to the caller
343      */
344     void GetState(const GetStateCallback &callback);
345 
346     /**
347      * Gets or sets the options associated with the current socket
348      * @param tcpExtraOptions options associated with the current socket
349      * @param callback callback to the caller
350      */
351     void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback);
352 
353     /**
354      *  Get a local digital certificate
355      * @param callback callback to the caller
356      */
357     void GetCertificate(const GetCertificateCallback &callback);
358 
359     /**
360      * Get the peer digital certificate
361      * @param needChain need chain
362      * @param callback callback to the caller
363      */
364     void GetRemoteCertificate(const GetRemoteCertificateCallback &callback);
365 
366     /**
367      * Obtain the protocol used in communication
368      * @param callback callback to the caller
369      */
370     void GetProtocol(const GetProtocolCallback &callback);
371 
372     /**
373      * Obtain the cipher suite used in communication
374      * @param callback callback to the caller
375      */
376     void GetCipherSuite(const GetCipherSuiteCallback &callback);
377 
378     /**
379      * Obtain the encryption algorithm used in the communication process
380      * @param callback callback to the caller
381      */
382     void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback);
383 
384     /**
385      * Register a callback which is called when message is received
386      * @param onMessageCallback callback which is called when message is received
387      */
388     void OnMessage(const OnMessageCallback &onMessageCallback);
389 
390     /**
391      * Register the callback that is called when the connection is established
392      * @param onConnectCallback callback invoked when connection is established
393      */
394     void OnConnect(const OnConnectCallback &onConnectCallback);
395 
396     /**
397      * Register the callback that is called when the connection is disconnected
398      * @param onCloseCallback callback invoked when disconnected
399      */
400     void OnClose(const OnCloseCallback &onCloseCallback);
401 
402     /**
403      * Register the callback that is called when an error occurs
404      * @param onErrorCallback callback invoked when an error occurs
405      */
406     void OnError(const OnErrorCallback &onErrorCallback);
407 
408     /**
409      * Unregister the callback which is called when message is received
410      */
411     void OffMessage();
412 
413     /**
414      * Off Connect
415      */
416     void OffConnect();
417 
418     /**
419      * Off Close
420      */
421     void OffClose();
422 
423     /**
424      * Off Error
425      */
426     void OffError();
427 
428     /**
429      * Get the socket file description of the server
430      */
431     int GetSocketFd();
432 
433     /**
434      * Set the current socket file description address of the server
435      */
436     void SetLocalAddress(const Socket::NetAddress &address);
437 
438     /**
439      * Get the current socket file description address of the server
440      */
441     Socket::NetAddress GetLocalAddress();
442 
443     void ExecTlsGetAddr(
444         const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len);
445 
446     bool ExecTlsSetSockBlockFlag(int sock, bool noneBlock);
447 
448     bool IsExtSock() const;
449 
450 private:
451     class TLSSocketInternal final {
452     public:
453         TLSSocketInternal() = default;
454         ~TLSSocketInternal() = default;
455 
456         /**
457          * Establish an encrypted connection on the specified socket
458          * @param sock socket for establishing encrypted connection
459          * @param options some options required during tls connection
460          * @param isExtSock socket fd is originated from external source when constructing tls socket
461          * @return whether the encrypted connection is successfully established
462          */
463         bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock);
464 
465         /**
466          * Set the configuration items for establishing encrypted connections
467          * @param config configuration item when establishing encrypted connection
468          */
469         void SetTlsConfiguration(const TLSConnectOptions &config);
470 
471         /**
472          * Send data through an established encrypted connection
473          * @param data data sent over an established encrypted connection
474          * @return whether the data is successfully sent to the server
475          */
476         bool Send(const std::string &data);
477 
478         /**
479          * Receive the data sent by the server through the established encrypted connection
480          * @param buffer receive the data sent by the server
481          * @param maxBufferSize the size of the data received from the server
482          * @return whether the data sent by the server is successfully received
483          */
484         int Recv(char *buffer, int maxBufferSize);
485 
486         /**
487          * Disconnect encrypted connection
488          * @return whether the encrypted connection was successfully disconnected
489          */
490         bool Close();
491 
492         /**
493          * Set the application layer negotiation protocol in the encrypted communication process
494          * @param alpnProtocols application layer negotiation protocol
495          * @return set whether the application layer negotiation protocol is successful during encrypted communication
496          */
497         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
498 
499         /**
500          * Storage of server communication related network information
501          * @param remoteInfo communication related network information
502          */
503         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
504 
505         /**
506          * convert the code to ssl error code
507          * @return the value for ssl error code.
508          */
509         int ConvertSSLError(void);
510 
511         /**
512          * Get configuration options for encrypted communication process
513          * @return configuration options for encrypted communication processes
514          */
515         [[nodiscard]] TLSConfiguration GetTlsConfiguration() const;
516 
517         /**
518          * Obtain the cipher suite during encrypted communication
519          * @return crypto suite used in encrypted communication
520          */
521         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
522 
523         /**
524          * Obtain the peer certificate used in encrypted communication
525          * @return peer certificate used in encrypted communication
526          */
527         [[nodiscard]] std::string GetRemoteCertificate() const;
528 
529         /**
530          * Obtain the peer certificate used in encrypted communication
531          * @return peer certificate serialization data used in encrypted communication
532          */
533         [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const;
534 
535         /**
536          * Obtain the certificate used in encrypted communication
537          * @return certificate serialization data used in encrypted communication
538          */
539         [[nodiscard]] const X509CertRawData &GetCertificate() const;
540 
541         /**
542          * Get the encryption algorithm used in encrypted communication
543          * @return encryption algorithm used in encrypted communication
544          */
545         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
546 
547         /**
548          * Obtain the communication protocol used in encrypted communication
549          * @return communication protocol used in encrypted communication
550          */
551         [[nodiscard]] std::string GetProtocol() const;
552 
553         /**
554          * Set the information about the shared signature algorithm supported by peers during encrypted communication
555          * @return information about peer supported shared signature algorithms
556          */
557         [[nodiscard]] bool SetSharedSigals();
558 
559         /**
560          * Obtain the ssl used in encrypted communication
561          * @return SSL used in encrypted communication
562          */
563         [[nodiscard]] ssl_st *GetSSL();
564 
565     private:
566         bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd);
567         bool StartTlsConnected(const TLSConnectOptions &options);
568         bool CreatTlsContext();
569         bool StartShakingHands(const TLSConnectOptions &options);
570         bool GetRemoteCertificateFromPeer();
571         bool SetRemoteCertRawData();
572         bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize);
573         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
574         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
575                                              const X509 *x509Certificates);
576 
577     private:
578         std::mutex mutexForSsl_;
579         ssl_st *ssl_ = nullptr;
580         X509 *peerX509_ = nullptr;
581         uint16_t port_ = 0;
582         sa_family_t family_ = 0;
583         int32_t socketDescriptor_ = 0;
584 
585         TLSContext tlsContext_;
586         TLSConfiguration configuration_;
587         Socket::NetAddress address_;
588         X509CertRawData remoteRawData_;
589 
590         std::string hostName_;
591         std::string remoteCert_;
592 
593         std::vector<std::string> signatureAlgorithms_;
594         std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr;
595     };
596 
597 private:
598     TLSSocketInternal tlsSocketInternal_;
599 
600     static std::string MakeAddressString(sockaddr *addr);
601 
602     static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
603                         socklen_t *len);
604 
605     void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo);
606     void CallOnConnectCallback();
607     void CallOnCloseCallback();
608     void CallOnErrorCallback(int32_t err, const std::string &errString);
609 
610     void CallBindCallback(int32_t err, BindCallback callback);
611     void CallConnectCallback(int32_t err, ConnectCallback callback);
612     void CallSendCallback(int32_t err, SendCallback callback);
613     void CallCloseCallback(int32_t err, CloseCallback callback);
614     void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
615                                       GetRemoteAddressCallback callback);
616     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback);
617     void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback);
618     void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback);
619     void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
620                                           GetRemoteCertificateCallback callback);
621     void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback);
622     void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
623                                     GetCipherSuiteCallback callback);
624     void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
625                                             GetSignatureAlgorithmsCallback callback);
626 
627     int ReadMessage();
628     void StartReadMessage();
629 
630     void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback);
631     void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback);
632 
633     [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const;
634     [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const;
635 
636     void MakeIpSocket(sa_family_t family);
637 
638     template<class T>
DealCallback(int32_t err,T & callback)639     void DealCallback(int32_t err, T &callback)
640     {
641         if (callback) {
642             callback(err);
643         }
644     }
645 
646 private:
647     static constexpr const size_t MAX_ERROR_LEN = 128;
648     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
649 
650     OnMessageCallback onMessageCallback_;
651     OnConnectCallback onConnectCallback_;
652     OnCloseCallback onCloseCallback_;
653     OnErrorCallback onErrorCallback_;
654 
655     std::mutex mutex_;
656     std::mutex recvMutex_;
657     std::mutex cvMutex_;
658     bool isRunning_ = false;
659     bool isRunOver_ = true;
660     std::condition_variable cvSslFree_;
661     int sockFd_ = -1;
662     bool isExtSock_ = false;
663     Socket::NetAddress localAddress_;
664 };
665 } // namespace TlsSocket
666 } // namespace NetStack
667 } // namespace OHOS
668 
669 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H
670