• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket.h"
17 
18 #include <chrono>
19 #include <memory>
20 #include <numeric>
21 #include <regex>
22 #include <securec.h>
23 #include <set>
24 #include <thread>
25 #include <poll.h>
26 
27 #include <netinet/tcp.h>
28 #include <openssl/err.h>
29 #include <openssl/ssl.h>
30 
31 #include "base_context.h"
32 #include "netstack_common_utils.h"
33 #include "netstack_log.h"
34 #include "tls.h"
35 #include "socket_exec_common.h"
36 
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
// Poll timeout (ms) used by TLSSocket::ReadMessage before it re-checks isRunning_.
constexpr int READ_TIMEOUT_MS = 500;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
// Value ReadMessage treats as "SSL wants more data" from Recv(); the read is simply retried.
constexpr int SSL_WANT_READ_RETURN = -2;
// Length of the ", " separator between alt-name entries.
constexpr int OFFSET = 2;
constexpr int DEFAULT_BUFFER_SIZE = 8192;
constexpr int DEFAULT_POLL_TIMEOUT_MS = 500;
constexpr int SEND_RETRY_TIMES = 5;
constexpr int SEND_POLL_TIMEOUT_MS = 1000;
// Size of the stack buffer ReadMessage hands to a single SSL read.
constexpr int MAX_RECV_BUFFER_SIZE = 1024 * 16;
// NOTE(review): the string constants below are presumably used later in this file (not visible here).
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
// Name given to the TLS client read thread.
static constexpr const char *TLS_SOCKET_CLIENT_READ = "OS_NET_TSCliRD";
// One JSON string literal (quotes included) anchored at the start of the input. The pattern was
// carried over from JavaScript with its "/.../" delimiters, which std::regex treats as literal
// slashes, so it could never match. Use with regex_search + match_continuous for prefix matching.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address (each octet 0-255).
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
76 
// Process-wide, mutex-protected map from a key to the set of certificate strings stored under it.
class CaCertCache {
public:
    // Returns the single shared instance, created on first use.
    static CaCertCache &GetInstance()
    {
        static CaCertCache cache;
        return cache;
    }

    // Returns a copy of the set stored under |key|, or an empty set when the key is unknown.
    std::set<std::string> Get(const std::string &key)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto found = map_.find(key);
        return found == map_.end() ? std::set<std::string>{} : found->second;
    }

    // Adds |val| to the set stored under |key|, creating that set when needed.
    void Set(const std::string &key, const std::string &val)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        auto &bucket = map_[key];
        bucket.insert(val);
    }

private:
    CaCertCache() = default;
    ~CaCertCache() = default;
    CaCertCache &operator=(const CaCertCache &) = delete;
    CaCertCache(const CaCertCache &) = delete;

    std::map<std::string, std::set<std::string>> map_;
    std::mutex mutex_;
};
110 
ConvertErrno()111 int ConvertErrno()
112 {
113     return TlsSocketError::TLS_ERR_SYS_BASE + errno;
114 }
115 
// Returns strerror()'s description of the current errno value as an owned string.
// strerror may use a shared static buffer, so the text is copied immediately.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
120 
MakeSSLErrorString(int error)121 std::string MakeSSLErrorString(int error)
122 {
123     char err[MAX_ERR_LEN] = {0};
124     ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
125     return err;
126 }
127 
// Number of hexadecimal digits in a JSON "\uXXXX" escape.
constexpr size_t JSON_HEX_ESCAPE_DIGITS = 4;

// Returns a regex matching one JSON string literal, quotes included. Use it with
// std::regex_constants::match_continuous so the literal must begin at the search position.
const std::regex &JsonStringPrefixPattern()
{
    static const std::regex pattern{R"("(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
    return pattern;
}

// Decodes the four hex digits of a "\uXXXX" escape starting at text[pos].
// Callers must have validated them with JsonStringPrefixPattern().
unsigned int ParseJsonHex4(const std::string &text, size_t pos)
{
    unsigned int value = 0;
    for (size_t k = pos; k < pos + JSON_HEX_ESCAPE_DIGITS; ++k) {
        const char c = text[k];
        unsigned int digit = 0;
        if (c >= '0' && c <= '9') {
            digit = static_cast<unsigned int>(c - '0');
        } else if (c >= 'a' && c <= 'f') {
            digit = static_cast<unsigned int>(c - 'a' + 10);
        } else {
            digit = static_cast<unsigned int>(c - 'A' + 10);
        }
        value = (value << 4) | digit;
    }
    return value;
}

// Appends the UTF-8 encoding of Unicode code point |cp| to |out|.
void AppendUtf8CodePoint(std::string &out, unsigned int cp)
{
    if (cp < 0x80) {
        out += static_cast<char>(cp);
    } else if (cp < 0x800) {
        out += static_cast<char>(0xC0 | (cp >> 6));
        out += static_cast<char>(0x80 | (cp & 0x3F));
    } else if (cp < 0x10000) {
        out += static_cast<char>(0xE0 | (cp >> 12));
        out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        out += static_cast<char>(0x80 | (cp & 0x3F));
    } else {
        out += static_cast<char>(0xF0 | (cp >> 18));
        out += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
        out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        out += static_cast<char>(0x80 | (cp & 0x3F));
    }
}

// Converts a validated JSON string literal (surrounding quotes included) into its decoded UTF-8
// text, following JSON.parse semantics for the escapes allowed by JsonStringPrefixPattern().
std::string UnescapeJsonStringLiteral(const std::string &literal)
{
    std::string text;
    const size_t closingQuote = literal.size() - 1;
    for (size_t i = 1; i < closingQuote; ++i) {
        if (literal[i] != '\\') {
            text += literal[i];
            continue;
        }
        const char escape = literal[++i];
        switch (escape) {
            case 'b':
                text += '\b';
                break;
            case 'f':
                text += '\f';
                break;
            case 'n':
                text += '\n';
                break;
            case 'r':
                text += '\r';
                break;
            case 't':
                text += '\t';
                break;
            case 'u': {
                unsigned int codePoint = ParseJsonHex4(literal, i + 1);
                i += JSON_HEX_ESCAPE_DIGITS;
                // A high surrogate followed by a "\uDC00".."\uDFFF" escape encodes one
                // supplementary-plane code point.
                const bool isHighSurrogate = codePoint >= 0xD800 && codePoint <= 0xDBFF;
                if (isHighSurrogate && i + 2 + JSON_HEX_ESCAPE_DIGITS < closingQuote && literal[i + 1] == '\\' &&
                    literal[i + 2] == 'u') {
                    const unsigned int low = ParseJsonHex4(literal, i + 3);
                    if (low >= 0xDC00 && low <= 0xDFFF) {
                        codePoint = 0x10000 + ((codePoint - 0xD800) << 10) + (low - 0xDC00);
                        i += 2 + JSON_HEX_ESCAPE_DIGITS;
                    }
                }
                AppendUtf8CodePoint(text, codePoint);
                break;
            }
            default:
                // '"', '\\' and '/' stand for themselves.
                text += escape;
                break;
        }
    }
    return text;
}

// Splits a subject-alt-name string on ", " separators. A section enclosed in double quotes is
// parsed as a JSON string literal, so it may itself contain ", " and escape sequences; its decoded
// text is appended to the current entry.
// Returns the entries in order, or {""} when a quoted section is not a valid JSON string literal.
// An empty input yields {""}.
std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
{
    static const std::string separator = ", ";
    std::vector<std::string> result;
    std::string currentToken;
    size_t offset = 0;
    while (offset < altNames.length()) {
        const size_t nextSep = altNames.find(separator, offset);
        const size_t nextQuote = altNames.find('"', offset);
        if (nextQuote != std::string::npos && (nextSep == std::string::npos || nextQuote < nextSep)) {
            // A quoted section starts before the next separator: consume the whole literal.
            currentToken += altNames.substr(offset, nextQuote - offset);
            std::smatch match;
            auto quoteBegin = altNames.cbegin() + static_cast<std::string::difference_type>(nextQuote);
            if (!std::regex_search(quoteBegin, altNames.cend(), match, JsonStringPrefixPattern(),
                                   std::regex_constants::match_continuous)) {
                return {""};
            }
            currentToken += UnescapeJsonStringLiteral(match.str(0));
            offset = nextQuote + static_cast<size_t>(match.length(0));
        } else if (nextSep != std::string::npos) {
            currentToken += altNames.substr(offset, nextSep - offset);
            result.push_back(currentToken);
            currentToken.clear();
            offset = nextSep + separator.length();
        } else {
            currentToken += altNames.substr(offset);
            offset = altNames.length();
        }
    }
    result.push_back(currentToken);
    return result;
}
160 
IsIP(const std::string & ip)161 bool IsIP(const std::string &ip)
162 {
163     std::regex pattern(PATTERN);
164     std::smatch res;
165     return regex_match(ip, res, pattern);
166 }
167 
SplitHostName(std::string & hostName)168 std::vector<std::string> SplitHostName(std::string &hostName)
169 {
170     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
171     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
172 }
173 
// Returns true when at least one string occurs in both vectors, regardless of element order.
// std::set_intersection is only correct on sorted ranges, and nothing here guarantees the inputs
// are sorted. A lookup set is used instead, and neither argument is modified.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    const std::set<std::string> lookup(vecA.begin(), vecA.end());
    for (const auto &item : vecB) {
        if (lookup.count(item) != 0) {
            return true;
        }
    }
    return false;
}
180 } // namespace
181 
SetSockBlockFlag(int sock,bool noneBlock)182 static bool SetSockBlockFlag(int sock, bool noneBlock)
183 {
184     int flags = fcntl(sock, F_GETFL, 0);
185     while (flags == -1 && errno == EINTR) {
186         flags = fcntl(sock, F_GETFL, 0);
187     }
188     if (flags == -1) {
189         NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
190         return false;
191     }
192 
193     auto newFlags = static_cast<size_t>(flags);
194     if (noneBlock) {
195         newFlags |= static_cast<size_t>(O_NONBLOCK);
196     } else {
197         newFlags &= ~static_cast<size_t>(O_NONBLOCK);
198     }
199 
200     int ret = fcntl(sock, F_SETFL, newFlags);
201     while (ret == -1 && errno == EINTR) {
202         ret = fcntl(sock, F_SETFL, newFlags);
203     }
204     if (ret == -1) {
205         NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
206         return false;
207     }
208     return true;
209 }
210 
// Copy constructor. It delegates to the copy-assignment operator so both copy paths copy
// exactly the same set of fields.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
215 
// Copies every TLS option field from |tlsSecureOptions|.
// Self-assignment returns early instead of assigning each member from a reference to itself,
// which would rely on every member type (including SecureData) tolerating self-assignment.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
231 
SetCaChain(const std::vector<std::string> & caChain)232 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
233 {
234     caChain_ = caChain;
235 }
236 
SetCert(const std::string & cert)237 void TLSSecureOptions::SetCert(const std::string &cert)
238 {
239     cert_ = cert;
240 }
241 
SetKey(const SecureData & key)242 void TLSSecureOptions::SetKey(const SecureData &key)
243 {
244     key_ = key;
245 }
246 
SetKeyPass(const SecureData & keyPass)247 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
248 {
249     keyPass_ = keyPass;
250 }
251 
SetProtocolChain(const std::vector<std::string> & protocolChain)252 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
253 {
254     protocolChain_ = protocolChain;
255 }
256 
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)257 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
258 {
259     useRemoteCipherPrefer_ = useRemoteCipherPrefer;
260 }
261 
SetSignatureAlgorithms(const std::string & signatureAlgorithms)262 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
263 {
264     signatureAlgorithms_ = signatureAlgorithms;
265 }
266 
SetCipherSuite(const std::string & cipherSuite)267 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
268 {
269     cipherSuite_ = cipherSuite;
270 }
271 
SetCrlChain(const std::vector<std::string> & crlChain)272 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
273 {
274     crlChain_ = crlChain;
275 }
276 
GetCaChain() const277 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
278 {
279     return caChain_;
280 }
281 
GetCert() const282 const std::string &TLSSecureOptions::GetCert() const
283 {
284     return cert_;
285 }
286 
GetKey() const287 const SecureData &TLSSecureOptions::GetKey() const
288 {
289     return key_;
290 }
291 
GetKeyPass() const292 const SecureData &TLSSecureOptions::GetKeyPass() const
293 {
294     return keyPass_;
295 }
296 
GetProtocolChain() const297 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
298 {
299     return protocolChain_;
300 }
301 
UseRemoteCipherPrefer() const302 bool TLSSecureOptions::UseRemoteCipherPrefer() const
303 {
304     return useRemoteCipherPrefer_;
305 }
306 
GetSignatureAlgorithms() const307 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
308 {
309     return signatureAlgorithms_;
310 }
311 
GetCipherSuite() const312 const std::string &TLSSecureOptions::GetCipherSuite() const
313 {
314     return cipherSuite_;
315 }
316 
GetCrlChain() const317 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
318 {
319     return crlChain_;
320 }
321 
SetVerifyMode(VerifyMode verifyMode)322 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
323 {
324     TLSVerifyMode_ = verifyMode;
325 }
326 
GetVerifyMode() const327 VerifyMode TLSSecureOptions::GetVerifyMode() const
328 {
329     return TLSVerifyMode_;
330 }
331 
SetNetAddress(const Socket::NetAddress & address)332 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
333 {
334     address_.SetFamilyBySaFamily(address.GetSaFamily());
335     address_.SetRawAddress(address.GetAddress());
336     address_.SetPort(address.GetPort());
337 }
338 
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)339 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
340 {
341     tlsSecureOptions_ = tlsSecureOptions;
342 }
343 
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)344 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
345 {
346     checkServerIdentity_ = checkServerIdentity;
347 }
348 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)349 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
350 {
351     alpnProtocols_ = alpnProtocols;
352 }
353 
SetSkipRemoteValidation(bool skipRemoteValidation)354 void TLSConnectOptions::SetSkipRemoteValidation(bool skipRemoteValidation)
355 {
356     skipRemoteValidation_ = skipRemoteValidation;
357 }
358 
GetNetAddress() const359 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
360 {
361     return address_;
362 }
363 
GetTlsSecureOptions() const364 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
365 {
366     return tlsSecureOptions_;
367 }
368 
GetCheckServerIdentity() const369 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
370 {
371     return checkServerIdentity_;
372 }
373 
GetAlpnProtocols() const374 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
375 {
376     return alpnProtocols_;
377 }
378 
GetSkipRemoteValidation() const379 bool TLSConnectOptions::GetSkipRemoteValidation() const
380 {
381     return skipRemoteValidation_;
382 }
383 
SetHostName(const std::string & hostName)384 void TLSConnectOptions::SetHostName(const std::string &hostName)
385 {
386     hostName_ = hostName;
387 }
388 
GetHostName() const389 std::string TLSConnectOptions::GetHostName() const
390 {
391     return hostName_;
392 }
393 
MakeAddressString(sockaddr * addr)394 std::string TLSSocket::MakeAddressString(sockaddr *addr)
395 {
396     if (!addr) {
397         return {};
398     }
399     if (addr->sa_family == AF_INET) {
400         auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
401         const char *str = inet_ntoa(addr4->sin_addr);
402         if (str == nullptr || strlen(str) == 0) {
403             return {};
404         }
405         return str;
406     } else if (addr->sa_family == AF_INET6) {
407         auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
408         char str[INET6_ADDRSTRLEN] = {0};
409         if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
410             return {};
411         }
412         return str;
413     }
414     return {};
415 }
416 
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)417 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
418                         socklen_t *len)
419 {
420     if (!addr6 || !addr4 || !len) {
421         return;
422     }
423     sa_family_t family = address.GetSaFamily();
424     if (family == AF_INET) {
425         addr4->sin_family = AF_INET;
426         addr4->sin_port = htons(address.GetPort());
427         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
428         *addr = reinterpret_cast<sockaddr *>(addr4);
429         *len = sizeof(sockaddr_in);
430     } else if (family == AF_INET6) {
431         addr6->sin6_family = AF_INET6;
432         addr6->sin6_port = htons(address.GetPort());
433         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
434         *addr = reinterpret_cast<sockaddr *>(addr6);
435         *len = sizeof(sockaddr_in6);
436     }
437 }
438 
MakeIpSocket(sa_family_t family)439 void TLSSocket::MakeIpSocket(sa_family_t family)
440 {
441     if (family != AF_INET && family != AF_INET6) {
442         return;
443     }
444     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
445     if (sock < 0) {
446         int resErr = ConvertErrno();
447         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
448         CallOnErrorCallback(resErr, MakeErrnoString());
449         return;
450     }
451     sockFd_ = sock;
452 }
453 
ReadMessage()454 int TLSSocket::ReadMessage()
455 {
456     char buffer[MAX_RECV_BUFFER_SIZE];
457     if (memset_s(buffer, MAX_RECV_BUFFER_SIZE, 0, MAX_RECV_BUFFER_SIZE) != EOK) {
458         NETSTACK_LOGE("memset_s failed!");
459         return -1;
460     }
461     nfds_t num = 1;
462     pollfd fds[1] = {{.fd = sockFd_, .events = POLLIN}};
463     int ret = poll(fds, num, READ_TIMEOUT_MS);
464     if (ret < 0) {
465         if (errno == EAGAIN || errno == EINTR) {
466             return 0;
467         }
468         int resErr = ConvertErrno();
469         NETSTACK_LOGE("Message poll errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
470         CallOnErrorCallback(resErr, MakeErrnoString());
471         return ret;
472     } else if (ret == 0) {
473         NETSTACK_LOGD("tls recv poll timeout");
474         return ret;
475     }
476 
477     std::lock_guard<std::mutex> lock(recvMutex_);
478     if (!isRunning_) {
479         return -1;
480     }
481     int len = tlsSocketInternal_.Recv(buffer, MAX_RECV_BUFFER_SIZE);
482     if (len < 0) {
483         if (errno == EAGAIN || errno == EINTR || len == SSL_WANT_READ_RETURN) {
484             return 0;
485         }
486         int resErr = tlsSocketInternal_.ConvertSSLError();
487         NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s",
488                       resErr, MakeSSLErrorString(resErr).c_str());
489         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
490         return len;
491     } else if (len == 0) {
492         NETSTACK_LOGI("Message recv len 0, session is closed by peer");
493         CallOnCloseCallback();
494         return -1;
495     }
496     Socket::SocketRemoteInfo remoteInfo;
497     remoteInfo.SetSize(len);
498     tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
499     std::string bufContent(buffer, len);
500     CallOnMessageCallback(bufContent, remoteInfo);
501 
502     return ret;
503 }
504 
StartReadMessage()505 void TLSSocket::StartReadMessage()
506 {
507     std::thread thread([this]() {
508         isRunning_ = true;
509         isRunOver_ = false;
510 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
511         pthread_setname_np(TLS_SOCKET_CLIENT_READ);
512 #else
513         pthread_setname_np(pthread_self(), TLS_SOCKET_CLIENT_READ);
514 #endif
515         while (isRunning_) {
516             int ret = ReadMessage();
517             if (ret < 0) {
518                 break;
519             }
520         }
521         isRunOver_ = true;
522         cvSslFree_.notify_one();
523     });
524     thread.detach();
525 }
526 
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)527 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
528 {
529     OnMessageCallback func = nullptr;
530     {
531         std::lock_guard<std::mutex> lock(mutex_);
532         if (onMessageCallback_) {
533             func = onMessageCallback_;
534         }
535     }
536 
537     if (func) {
538         func(data, remoteInfo);
539     }
540 }
541 
CallOnConnectCallback()542 void TLSSocket::CallOnConnectCallback()
543 {
544     OnConnectCallback func = nullptr;
545     {
546         std::lock_guard<std::mutex> lock(mutex_);
547         if (onConnectCallback_) {
548             func = onConnectCallback_;
549         }
550     }
551 
552     if (func) {
553         func();
554     }
555 }
556 
CallOnCloseCallback()557 void TLSSocket::CallOnCloseCallback()
558 {
559     OnCloseCallback func = nullptr;
560     {
561         std::lock_guard<std::mutex> lock(mutex_);
562         if (onCloseCallback_) {
563             func = onCloseCallback_;
564         }
565     }
566 
567     if (func) {
568         func();
569     }
570 }
571 
CallOnErrorCallback(int32_t err,const std::string & errString)572 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
573 {
574     OnErrorCallback func = nullptr;
575     {
576         std::lock_guard<std::mutex> lock(mutex_);
577         if (onErrorCallback_) {
578             func = onErrorCallback_;
579         }
580     }
581 
582     if (func) {
583         func(err, errString);
584     }
585 }
586 
// Reports a Bind() result |err| to |callback| through the shared DealCallback helper.
// NOTE(review): DealCallback is defined outside this view; whether it tolerates an empty
// callback or takes any lock cannot be confirmed here.
CallBindCallback(int32_t err,BindCallback callback)587 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
588 {
589     DealCallback<BindCallback>(err, callback);
590 }
591 
// Reports a connect result |err| to |callback| through DealCallback.
CallConnectCallback(int32_t err,ConnectCallback callback)592 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
593 {
594     DealCallback<ConnectCallback>(err, callback);
595 }
596 
// Reports a Send() result |err| to |callback| through DealCallback.
CallSendCallback(int32_t err,SendCallback callback)597 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
598 {
599     DealCallback<SendCallback>(err, callback);
600 }
601 
// Reports a close result |err| to |callback| through DealCallback.
CallCloseCallback(int32_t err,CloseCallback callback)602 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
603 {
604     DealCallback<CloseCallback>(err, callback);
605 }
606 
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)607 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
608                                              GetRemoteAddressCallback callback)
609 {
610     GetRemoteAddressCallback func = nullptr;
611     {
612         std::lock_guard<std::mutex> lock(mutex_);
613         if (callback) {
614             func = callback;
615         }
616     }
617 
618     if (func) {
619         func(err, address);
620     }
621 }
622 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)623 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
624 {
625     GetStateCallback func = nullptr;
626     {
627         std::lock_guard<std::mutex> lock(mutex_);
628         if (callback) {
629             func = callback;
630         }
631     }
632 
633     if (func) {
634         func(err, state);
635     }
636 }
637 
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)638 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
639 {
640     DealCallback<SetExtraOptionsCallback>(err, callback);
641 }
642 
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)643 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
644 {
645     GetCertificateCallback func = nullptr;
646     {
647         std::lock_guard<std::mutex> lock(mutex_);
648         if (callback) {
649             func = callback;
650         }
651     }
652 
653     if (func) {
654         func(err, cert);
655     }
656 }
657 
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)658 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
659                                                  GetRemoteCertificateCallback callback)
660 {
661     GetRemoteCertificateCallback func = nullptr;
662     {
663         std::lock_guard<std::mutex> lock(mutex_);
664         if (callback) {
665             func = callback;
666         }
667     }
668 
669     if (func) {
670         func(err, cert);
671     }
672 }
673 
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)674 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
675 {
676     GetProtocolCallback func = nullptr;
677     {
678         std::lock_guard<std::mutex> lock(mutex_);
679         if (callback) {
680             func = callback;
681         }
682     }
683 
684     if (func) {
685         func(err, protocol);
686     }
687 }
688 
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)689 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
690                                            GetCipherSuiteCallback callback)
691 {
692     GetCipherSuiteCallback func = nullptr;
693     {
694         std::lock_guard<std::mutex> lock(mutex_);
695         if (callback) {
696             func = callback;
697         }
698     }
699 
700     if (func) {
701         func(err, suite);
702     }
703 }
704 
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)705 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
706                                                    GetSignatureAlgorithmsCallback callback)
707 {
708     GetSignatureAlgorithmsCallback func = nullptr;
709     {
710         std::lock_guard<std::mutex> lock(mutex_);
711         if (callback) {
712             func = callback;
713         }
714     }
715 
716     if (func) {
717         func(err, algorithms);
718     }
719 }
720 
Bind(Socket::NetAddress & address,const BindCallback & callback)721 void TLSSocket::Bind(Socket::NetAddress &address, const BindCallback &callback)
722 {
723     static constexpr int32_t PARSE_ERROR_CODE = 401;
724     if (!CommonUtils::HasInternetPermission()) {
725         CallBindCallback(PERMISSION_DENIED_CODE, callback);
726         return;
727     }
728     if (sockFd_ >= 0) {
729         CallBindCallback(TLSSOCKET_SUCCESS, callback);
730         return;
731     }
732 
733     MakeIpSocket(address.GetSaFamily());
734     if (sockFd_ < 0) {
735         int resErr = ConvertErrno();
736         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
737         CallOnErrorCallback(resErr, MakeErrnoString());
738         CallBindCallback(resErr, callback);
739         return;
740     }
741 
742     auto temp = address.GetAddress();
743     address.SetRawAddress("");
744     address.SetAddress(temp);
745     if (address.GetAddress().empty()) {
746         CallBindCallback(PARSE_ERROR_CODE, callback);
747         return;
748     }
749 
750     sockaddr_in addr4 = {0};
751     sockaddr_in6 addr6 = {0};
752     sockaddr *addr = nullptr;
753     socklen_t len;
754     GetAddr(address, &addr4, &addr6, &addr, &len);
755     if (addr == nullptr) {
756         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
757         CallOnErrorCallback(-1, "Address Is Invalid");
758         CallBindCallback(ConvertErrno(), callback);
759         return;
760     }
761     CallBindCallback(TLSSOCKET_SUCCESS, callback);
762 }
763 
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)764 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
765                         const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
766 {
767     if (sockFd_ < 0) {
768         int resErr = ConvertErrno();
769         NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
770         CallOnErrorCallback(resErr, MakeErrnoString());
771         callback(resErr);
772         return;
773     }
774 
775     if (isExtSock_ && !SetSockBlockFlag(sockFd_, false)) {
776         int resErr = ConvertErrno();
777         NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
778         CallOnErrorCallback(resErr, MakeErrnoString());
779         callback(resErr);
780         return;
781     }
782 
783     auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions, isExtSock_);
784     if (!res) {
785         int resErr = tlsSocketInternal_.ConvertSSLError();
786         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
787         callback(resErr);
788         return;
789     }
790     if (!SetSockBlockFlag(sockFd_, true)) {
791         int resErr = ConvertErrno();
792         NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
793         CallOnErrorCallback(resErr, MakeErrnoString());
794         callback(resErr);
795         return;
796     }
797     StartReadMessage();
798     CallOnConnectCallback();
799     callback(TLSSOCKET_SUCCESS);
800 }
801 
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)802 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
803 {
804     (void)tcpSendOptions;
805 
806     auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
807     if (!res) {
808         int resErr = tlsSocketInternal_.ConvertSSLError();
809         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
810         CallSendCallback(resErr, callback);
811         return;
812     }
813     CallSendCallback(TLSSOCKET_SUCCESS, callback);
814 }
815 
Close(const CloseCallback & callback)816 void TLSSocket::Close(const CloseCallback &callback)
817 {
818     isRunning_ = false;
819     std::unique_lock<std::mutex> cvLock(cvMutex_);
820     cvSslFree_.wait(cvLock, [this]() -> bool { return isRunOver_; });
821 
822     std::lock_guard<std::mutex> lock(recvMutex_);
823     auto res = tlsSocketInternal_.Close();
824     if (!res) {
825         int resErr = tlsSocketInternal_.ConvertSSLError();
826         NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
827         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
828         callback(resErr);
829         return;
830     }
831     sockFd_ = -1;
832     CallOnCloseCallback();
833     callback(TLSSOCKET_SUCCESS);
834 }
835 
GetRemoteAddress(const GetRemoteAddressCallback & callback)836 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
837 {
838     sockaddr sockAddr = {0};
839     socklen_t len = sizeof(sockaddr);
840     int ret = getsockname(sockFd_, &sockAddr, &len);
841     if (ret < 0) {
842         int resErr = ConvertErrno();
843         NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
844         CallOnErrorCallback(resErr, MakeErrnoString());
845         CallGetRemoteAddressCallback(resErr, {}, callback);
846         return;
847     }
848 
849     if (sockAddr.sa_family == AF_INET) {
850         GetIp4RemoteAddress(callback);
851     } else if (sockAddr.sa_family == AF_INET6) {
852         GetIp6RemoteAddress(callback);
853     }
854 }
855 
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)856 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
857 {
858     sockaddr_in addr4 = {0};
859     socklen_t len4 = sizeof(sockaddr_in);
860 
861     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
862     if (ret < 0) {
863         int resErr = ConvertErrno();
864         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
865         CallOnErrorCallback(resErr, MakeErrnoString());
866         CallGetRemoteAddressCallback(resErr, {}, callback);
867         return;
868     }
869 
870     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
871     if (address.empty()) {
872         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
873         CallOnErrorCallback(-1, "Address is invalid");
874         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
875         return;
876     }
877     Socket::NetAddress netAddress;
878     netAddress.SetFamilyBySaFamily(AF_INET);
879     netAddress.SetRawAddress(address);
880     netAddress.SetPort(ntohs(addr4.sin_port));
881     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
882 }
883 
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)884 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
885 {
886     sockaddr_in6 addr6 = {0};
887     socklen_t len6 = sizeof(sockaddr_in6);
888 
889     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
890     if (ret < 0) {
891         int resErr = ConvertErrno();
892         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
893         CallOnErrorCallback(resErr, MakeErrnoString());
894         CallGetRemoteAddressCallback(resErr, {}, callback);
895         return;
896     }
897 
898     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
899     if (address.empty()) {
900         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
901         CallOnErrorCallback(-1, "Address is invalid");
902         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
903         return;
904     }
905     Socket::NetAddress netAddress;
906     netAddress.SetFamilyBySaFamily(AF_INET6);
907     netAddress.SetRawAddress(address);
908     netAddress.SetPort(ntohs(addr6.sin6_port));
909     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
910 }
911 
GetState(const GetStateCallback & callback)912 void TLSSocket::GetState(const GetStateCallback &callback)
913 {
914     int opt;
915     socklen_t optLen = sizeof(int);
916     int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
917     if (r < 0) {
918         Socket::SocketStateBase state;
919         state.SetIsClose(true);
920         CallGetStateCallback(ConvertErrno(), state, callback);
921         return;
922     }
923     sockaddr sockAddr = {0};
924     socklen_t len = sizeof(sockaddr);
925     Socket::SocketStateBase state;
926     int ret = getsockname(sockFd_, &sockAddr, &len);
927     state.SetIsBound(ret == 0);
928     ret = getpeername(sockFd_, &sockAddr, &len);
929     state.SetIsConnected(ret == 0);
930     CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
931 }
932 
SetBaseOptions(const Socket::ExtraOptionsBase & option) const933 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
934 {
935     if (option.GetReceiveBufferSize() != 0) {
936         int size = (int)option.GetReceiveBufferSize();
937         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
938             return false;
939         }
940     }
941 
942     if (option.GetSendBufferSize() != 0) {
943         int size = (int)option.GetSendBufferSize();
944         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
945             return false;
946         }
947     }
948 
949     if (option.IsReuseAddress()) {
950         int reuse = 1;
951         if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
952             return false;
953         }
954     }
955 
956     if (option.GetSocketTimeout() != 0) {
957         timeval timeout = {(int)option.GetSocketTimeout(), 0};
958         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
959             return false;
960         }
961         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
962             return false;
963         }
964     }
965 
966     return true;
967 }
968 
SetExtraOptions(const Socket::TCPExtraOptions & option) const969 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
970 {
971     if (option.IsKeepAlive()) {
972         int keepalive = 1;
973         if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
974             return false;
975         }
976     }
977 
978     if (option.IsOOBInline()) {
979         int oobInline = 1;
980         if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
981             return false;
982         }
983     }
984 
985     if (option.IsTCPNoDelay()) {
986         int tcpNoDelay = 1;
987         if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
988             return false;
989         }
990     }
991 
992     linger soLinger = {0};
993     soLinger.l_onoff = option.socketLinger.IsOn();
994     soLinger.l_linger = (int)option.socketLinger.GetLinger();
995     if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
996         return false;
997     }
998 
999     return true;
1000 }
1001 
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)1002 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
1003                                 const SetExtraOptionsCallback &callback)
1004 {
1005     if (!SetBaseOptions(tcpExtraOptions)) {
1006         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1007         CallOnErrorCallback(errno, MakeErrnoString());
1008         CallSetExtraOptionsCallback(ConvertErrno(), callback);
1009         return;
1010     }
1011 
1012     if (!SetExtraOptions(tcpExtraOptions)) {
1013         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1014         CallOnErrorCallback(errno, MakeErrnoString());
1015         CallSetExtraOptionsCallback(ConvertErrno(), callback);
1016         return;
1017     }
1018 
1019     CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
1020 }
1021 
GetCertificate(const GetCertificateCallback & callback)1022 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
1023 {
1024     const auto &cert = tlsSocketInternal_.GetCertificate();
1025     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
1026 
1027     if (!cert.data.Length()) {
1028         int resErr = tlsSocketInternal_.ConvertSSLError();
1029         NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1030         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1031         callback(resErr, {});
1032         return;
1033     }
1034     callback(TLSSOCKET_SUCCESS, cert);
1035 }
1036 
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)1037 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
1038 {
1039     const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
1040     if (!remoteCert.data.Length()) {
1041         int resErr = tlsSocketInternal_.ConvertSSLError();
1042         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1043         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1044         callback(resErr, {});
1045         return;
1046     }
1047     callback(TLSSOCKET_SUCCESS, remoteCert);
1048 }
1049 
GetProtocol(const GetProtocolCallback & callback)1050 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
1051 {
1052     const auto &protocol = tlsSocketInternal_.GetProtocol();
1053     if (protocol.empty()) {
1054         NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
1055         int resErr = tlsSocketInternal_.ConvertSSLError();
1056         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1057         callback(resErr, "");
1058         return;
1059     }
1060     callback(TLSSOCKET_SUCCESS, protocol);
1061 }
1062 
GetCipherSuite(const GetCipherSuiteCallback & callback)1063 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
1064 {
1065     const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
1066     if (cipherSuite.empty()) {
1067         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
1068         int resErr = tlsSocketInternal_.ConvertSSLError();
1069         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1070         callback(resErr, cipherSuite);
1071         return;
1072     }
1073     callback(TLSSOCKET_SUCCESS, cipherSuite);
1074 }
1075 
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)1076 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
1077 {
1078     const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
1079     if (signatureAlgorithms.empty()) {
1080         NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
1081         int resErr = tlsSocketInternal_.ConvertSSLError();
1082         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1083         callback(resErr, {});
1084         return;
1085     }
1086     callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
1087 }
1088 
OnMessage(const OnMessageCallback & onMessageCallback)1089 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
1090 {
1091     std::lock_guard<std::mutex> lock(mutex_);
1092     onMessageCallback_ = onMessageCallback;
1093 }
1094 
OffMessage()1095 void TLSSocket::OffMessage()
1096 {
1097     std::lock_guard<std::mutex> lock(mutex_);
1098     if (onMessageCallback_) {
1099         onMessageCallback_ = nullptr;
1100     }
1101 }
1102 
OnConnect(const OnConnectCallback & onConnectCallback)1103 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1104 {
1105     std::lock_guard<std::mutex> lock(mutex_);
1106     onConnectCallback_ = onConnectCallback;
1107 }
1108 
OffConnect()1109 void TLSSocket::OffConnect()
1110 {
1111     std::lock_guard<std::mutex> lock(mutex_);
1112     if (onConnectCallback_) {
1113         onConnectCallback_ = nullptr;
1114     }
1115 }
1116 
OnClose(const OnCloseCallback & onCloseCallback)1117 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1118 {
1119     std::lock_guard<std::mutex> lock(mutex_);
1120     onCloseCallback_ = onCloseCallback;
1121 }
1122 
OffClose()1123 void TLSSocket::OffClose()
1124 {
1125     std::lock_guard<std::mutex> lock(mutex_);
1126     if (onCloseCallback_) {
1127         onCloseCallback_ = nullptr;
1128     }
1129 }
1130 
OnError(const OnErrorCallback & onErrorCallback)1131 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1132 {
1133     std::lock_guard<std::mutex> lock(mutex_);
1134     onErrorCallback_ = onErrorCallback;
1135 }
1136 
OffError()1137 void TLSSocket::OffError()
1138 {
1139     std::lock_guard<std::mutex> lock(mutex_);
1140     if (onErrorCallback_) {
1141         onErrorCallback_ = nullptr;
1142     }
1143 }
1144 
GetSocketFd()1145 int TLSSocket::GetSocketFd()
1146 {
1147     return sockFd_;
1148 }
1149 
SetLocalAddress(const Socket::NetAddress & address)1150 void TLSSocket::SetLocalAddress(const Socket::NetAddress &address)
1151 {
1152     localAddress_ = address;
1153 }
1154 
GetLocalAddress()1155 Socket::NetAddress TLSSocket::GetLocalAddress()
1156 {
1157     return localAddress_;
1158 }
1159 
GetCloseState()1160 bool TLSSocket::GetCloseState()
1161 {
1162     return isClosed;
1163 }
1164 
SetCloseState(bool flag)1165 void TLSSocket::SetCloseState(bool flag)
1166 {
1167     isClosed = flag;
1168 }
1169 
GetCloseLock()1170 std::mutex &TLSSocket::GetCloseLock()
1171 {
1172     return mutexForClose_;
1173 }
1174 
ExecSocketConnect(const std::string & host,int port,sa_family_t family,int socketDescriptor)1175 bool ExecSocketConnect(const std::string &host, int port, sa_family_t family, int socketDescriptor)
1176 {
1177     auto hostName = ConvertAddressToIp(host, family);
1178     struct sockaddr_in dest = {0};
1179     dest.sin_family = family;
1180     dest.sin_port = htons(port);
1181 
1182     sockaddr_in addr4 = {0};
1183     sockaddr_in6 addr6 = {0};
1184     sockaddr *addr = nullptr;
1185     socklen_t len = 0;
1186     if (family == AF_INET) {
1187         if (inet_pton(AF_INET, hostName.c_str(), &addr4.sin_addr.s_addr) <= 0) {
1188             return false;
1189         }
1190         addr4.sin_family = family;
1191         addr4.sin_port = htons(port);
1192         addr = reinterpret_cast<sockaddr *>(&addr4);
1193         len = sizeof(sockaddr_in);
1194     } else {
1195         if (inet_pton(AF_INET6, hostName.c_str(), &addr6.sin6_addr) <= 0) {
1196             return false;
1197         }
1198         addr6.sin6_family = family;
1199         addr6.sin6_port = htons(port);
1200         addr = reinterpret_cast<sockaddr *>(&addr6);
1201         len = sizeof(sockaddr_in6);
1202     }
1203 
1204     int connectResult = connect(socketDescriptor, addr, len);
1205     if (connectResult == -1) {
1206         NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1207                       strerror(errno));
1208         return false;
1209     }
1210     return true;
1211 }
1212 
ConvertSSLError(void)1213 int TLSSocket::TLSSocketInternal::ConvertSSLError(void)
1214 {
1215     std::lock_guard<std::mutex> lock(mutexForSsl_);
1216     if (!ssl_) {
1217         return TLS_ERR_SSL_NULL;
1218     }
1219     return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1220 }
1221 
TlsConnectToHost(int sock,const TLSConnectOptions & options,bool isExtSock)1222 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock)
1223 {
1224     SetTlsConfiguration(options);
1225     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1226     if (!cipherSuite.empty()) {
1227         configuration_.SetCipherSuite(cipherSuite);
1228     }
1229     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1230     if (!signatureAlgorithms.empty()) {
1231         configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1232     }
1233     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1234     if (!protocolVec.empty()) {
1235         configuration_.SetProtocol(protocolVec);
1236     }
1237     configuration_.SetSkipFlag(options.GetSkipRemoteValidation());
1238     hostName_ = options.GetNetAddress().GetAddress();
1239     port_ = options.GetNetAddress().GetPort();
1240     family_ = options.GetNetAddress().GetSaFamily();
1241     socketDescriptor_ = sock;
1242     if (!isExtSock && !ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1243                                          options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1244         return false;
1245     }
1246     return StartTlsConnected(options);
1247 }
1248 
SetTlsConfiguration(const TLSConnectOptions & config)1249 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1250 {
1251     configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1252     configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1253     configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1254     configuration_.SetNetAddress(config.GetNetAddress());
1255 }
1256 
SendRetry(ssl_st * ssl,const char * curPos,size_t curSendSize,int sockfd)1257 bool TLSSocket::TLSSocketInternal::SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd)
1258 {
1259     pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1260     for (int i = 0; i <= SEND_RETRY_TIMES; i++) {
1261         int ret = poll(fds, 1, SEND_POLL_TIMEOUT_MS);
1262         if (ret < 0) {
1263             if (errno == EAGAIN || errno == EINTR) {
1264                 continue;
1265             }
1266             NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1267             return false;
1268         } else if (ret == 0) {
1269             NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1270             continue;
1271         }
1272         int len = SSL_write(ssl, curPos, curSendSize);
1273         if (len < 0) {
1274             int err = SSL_get_error(ssl, SSL_RET_CODE);
1275             if (err == SSL_ERROR_WANT_WRITE || errno == EAGAIN) {
1276                 NETSTACK_LOGI("write retry times: %{public}d err: %{public}d errno: %{public}d", i, err, errno);
1277                 continue;
1278             } else {
1279                 NETSTACK_LOGE("write failed err: %{public}d errno: %{public}d", err, errno);
1280                 return false;
1281             }
1282         } else if (len == 0) {
1283             NETSTACK_LOGI("send len is 0, should have sent len");
1284             return false;
1285         } else {
1286             return true;
1287         }
1288     }
1289     return false;
1290 }
1291 
PollSend(int sockfd,ssl_st * ssl,const char * pdata,int sendSize)1292 bool TLSSocket::TLSSocketInternal::PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize)
1293 {
1294     int bufferSize = DEFAULT_BUFFER_SIZE;
1295     auto curPos = pdata;
1296     nfds_t num = 1;
1297     pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1298     while (sendSize > 0) {
1299         int ret = poll(fds, num, DEFAULT_POLL_TIMEOUT_MS);
1300         if (ret < 0) {
1301             if (errno == EAGAIN || errno == EINTR) {
1302                 continue;
1303             }
1304             NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1305             return false;
1306         } else if (ret == 0) {
1307             NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1308             continue;
1309         }
1310         std::lock_guard<std::mutex> lock(mutexForSsl_);
1311         if (!ssl) {
1312             NETSTACK_LOGE("ssl is null");
1313             return false;
1314         }
1315         size_t curSendSize = std::min<size_t>(sendSize, bufferSize);
1316         int len = SSL_write(ssl, curPos, curSendSize);
1317         if (len < 0) {
1318             int err = SSL_get_error(ssl, SSL_RET_CODE);
1319             if (err != SSL_ERROR_WANT_WRITE || errno != EAGAIN) {
1320                 NETSTACK_LOGE("write failed, return, err: %{public}d errno: %{public}d", err, errno);
1321                 return false;
1322             } else if (!SendRetry(ssl, curPos, curSendSize, sockfd)) {
1323                 return false;
1324             }
1325         } else if (len == 0) {
1326             NETSTACK_LOGI("send len is 0, should have sent len is %{public}d", sendSize);
1327             return false;
1328         }
1329         curPos += len;
1330         sendSize -= len;
1331     }
1332     return true;
1333 }
1334 
Send(const std::string & data)1335 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1336 {
1337     {
1338         std::lock_guard<std::mutex> lock(mutexForSsl_);
1339         if (!ssl_) {
1340             NETSTACK_LOGE("ssl is null");
1341             return false;
1342         }
1343     }
1344 
1345     if (data.empty()) {
1346         NETSTACK_LOGE("data is empty");
1347         return true;
1348     }
1349 
1350     if (!PollSend(socketDescriptor_, ssl_, data.c_str(), data.size())) {
1351         return false;
1352     }
1353     return true;
1354 }
Recv(char * buffer,int maxBufferSize)1355 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1356 {
1357     if (!ssl_) {
1358         NETSTACK_LOGE("ssl is null");
1359         return SSL_ERROR_RETURN;
1360     }
1361 
1362     int ret = SSL_read(ssl_, buffer, maxBufferSize);
1363     if (ret < 0) {
1364         int err = SSL_get_error(ssl_, SSL_RET_CODE);
1365         switch (err) {
1366             case SSL_ERROR_SSL:
1367                 NETSTACK_LOGE("An error occurred in the SSL library");
1368                 return SSL_ERROR_RETURN;
1369             case SSL_ERROR_ZERO_RETURN:
1370                 NETSTACK_LOGE("peer disconnected...");
1371                 return SSL_ERROR_RETURN;
1372             case SSL_ERROR_WANT_READ:
1373                 NETSTACK_LOGD("SSL_read function no data available for reading, try again at a later time");
1374                 return SSL_WANT_READ_RETURN;
1375             default:
1376                 NETSTACK_LOGE("SSL_read function failed, error code is %{public}d", err);
1377                 return SSL_ERROR_RETURN;
1378         }
1379     }
1380     return ret;
1381 }
1382 
Close()1383 bool TLSSocket::TLSSocketInternal::Close()
1384 {
1385     std::lock_guard<std::mutex> lock(mutexForSsl_);
1386     if (!ssl_) {
1387         NETSTACK_LOGE("ssl is null, fd =%{public}d", socketDescriptor_);
1388         return false;
1389     }
1390     int result = SSL_shutdown(ssl_);
1391     if (result < 0) {
1392         int resErr = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1393         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1394                       MakeSSLErrorString(resErr).c_str());
1395     }
1396     NETSTACK_LOGI("tls socket close, fd =%{public}d", socketDescriptor_);
1397     SSL_free(ssl_);
1398     ssl_ = nullptr;
1399     close(socketDescriptor_);
1400     socketDescriptor_ = -1;
1401     if (!tlsContextPointer_) {
1402         NETSTACK_LOGE("Tls context pointer is null");
1403         return false;
1404     }
1405     tlsContextPointer_->CloseCtx();
1406     return true;
1407 }
1408 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1409 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1410 {
1411     if (!ssl_) {
1412         NETSTACK_LOGE("ssl is null");
1413         return false;
1414     }
1415     size_t pos = 0;
1416     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1417                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1418     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1419     for (const auto &str : alpnProtocols) {
1420         len = str.length();
1421         result[pos++] = len;
1422         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1423             NETSTACK_LOGE("strcpy_s failed");
1424             return false;
1425         }
1426         pos += len;
1427     }
1428     result[pos] = '\0';
1429 
1430     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1431     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1432         int resErr = ConvertSSLError();
1433         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1434                       MakeSSLErrorString(resErr).c_str());
1435         return false;
1436     }
1437     return true;
1438 }
1439 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1440 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1441 {
1442     remoteInfo.SetFamily(family_);
1443     remoteInfo.SetAddress(hostName_);
1444     remoteInfo.SetPort(port_);
1445 }
1446 
GetTlsConfiguration() const1447 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1448 {
1449     return configuration_;
1450 }
1451 
GetCipherSuite() const1452 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1453 {
1454     if (!ssl_) {
1455         NETSTACK_LOGE("ssl in null");
1456         return {};
1457     }
1458     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1459     if (!sk) {
1460         NETSTACK_LOGE("get ciphers failed");
1461         return {};
1462     }
1463     CipherSuite cipherSuite;
1464     std::vector<std::string> cipherSuiteVec;
1465     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1466         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1467         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1468         cipherSuiteVec.push_back(cipherSuite.cipherName_);
1469     }
1470     return cipherSuiteVec;
1471 }
1472 
GetRemoteCertificate() const1473 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1474 {
1475     return remoteCert_;
1476 }
1477 
GetCertificate() const1478 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1479 {
1480     return configuration_.GetCertificate();
1481 }
1482 
GetSignatureAlgorithms() const1483 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1484 {
1485     return signatureAlgorithms_;
1486 }
1487 
GetProtocol() const1488 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1489 {
1490     if (!ssl_) {
1491         NETSTACK_LOGE("ssl in null");
1492         return PROTOCOL_UNKNOW;
1493     }
1494     if (configuration_.GetProtocol() == TLS_V1_3) {
1495         return PROTOCOL_TLS_V13;
1496     }
1497     return PROTOCOL_TLS_V12;
1498 }
1499 
SetSharedSigals()1500 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1501 {
1502     if (!ssl_) {
1503         NETSTACK_LOGE("ssl is null");
1504         return false;
1505     }
1506     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1507     if (!number) {
1508         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1509         return false;
1510     }
1511     for (int i = 0; i < number; i++) {
1512         int hash_nid;
1513         int sign_nid;
1514         std::string sig_with_md;
1515         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1516         switch (sign_nid) {
1517             case EVP_PKEY_RSA:
1518                 sig_with_md = SIGN_NID_RSA;
1519                 break;
1520             case EVP_PKEY_RSA_PSS:
1521                 sig_with_md = SIGN_NID_RSA_PSS;
1522                 break;
1523             case EVP_PKEY_DSA:
1524                 sig_with_md = SIGN_NID_DSA;
1525                 break;
1526             case EVP_PKEY_EC:
1527                 sig_with_md = SIGN_NID_ECDSA;
1528                 break;
1529             case NID_ED25519:
1530                 sig_with_md = SIGN_NID_ED;
1531                 break;
1532             case NID_ED448:
1533                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1534                 break;
1535             default:
1536                 const char *sn = OBJ_nid2sn(sign_nid);
1537                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1538         }
1539         const char *sn_hash = OBJ_nid2sn(hash_nid);
1540         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1541         signatureAlgorithms_.push_back(sig_with_md);
1542     }
1543     return true;
1544 }
1545 
StartTlsConnected(const TLSConnectOptions & options)1546 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1547 {
1548     if (!CreatTlsContext()) {
1549         NETSTACK_LOGE("failed to create tls context");
1550         return false;
1551     }
1552     if (!StartShakingHands(options)) {
1553         NETSTACK_LOGE("failed to shaking hands");
1554         return false;
1555     }
1556     return true;
1557 }
1558 
CreatTlsContext()1559 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1560 {
1561     tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1562     if (!tlsContextPointer_) {
1563         NETSTACK_LOGE("failed to create tls context pointer");
1564         return false;
1565     }
1566 
1567     std::lock_guard<std::mutex> lock(mutexForSsl_);
1568     if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1569         NETSTACK_LOGE("failed to create ssl session");
1570         return false;
1571     }
1572 
1573     SSL_set_fd(ssl_, socketDescriptor_);
1574     SSL_set_connect_state(ssl_);
1575     return true;
1576 }
1577 
/** @return true if @p s begins with @p prefix (an empty prefix always matches). */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1582 
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1583 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1584                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1585 {
1586     bool valid = false;
1587     std::string reason = UNKNOW_REASON;
1588     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1589     if (IsIP(hostName)) {
1590         auto it = find(ips.begin(), ips.end(), hostName);
1591         if (it == ips.end()) {
1592             reason = IP + hostName + " is not in the cert's list";
1593         }
1594         result = {valid, reason};
1595         return;
1596     }
1597     std::string tempHostName = "" + hostName;
1598     if (!dnsNames.empty() || index > 0) {
1599         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1600         if (!dnsNames.empty()) {
1601             valid = SeekIntersection(hostParts, dnsNames);
1602             if (!valid) {
1603                 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1604             }
1605         } else {
1606             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1607             X509_NAME *pSubName = nullptr;
1608             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1609             if (len > 0) {
1610                 std::vector<std::string> commonNameVec;
1611                 commonNameVec.emplace_back(commonNameBuf);
1612                 valid = SeekIntersection(hostParts, commonNameVec);
1613                 if (!valid) {
1614                     reason = HOST_NAME + tempHostName + ". is not cert's CN";
1615                 }
1616             }
1617         }
1618         result = {valid, reason};
1619         return;
1620     }
1621     reason = "Cert does not contain a DNS name";
1622     result = {valid, reason};
1623 }
1624 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1625 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1626                                                                    const X509 *x509Certificates)
1627 {
1628     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1629     if (!subjectName) {
1630         return "subject name is null";
1631     }
1632     char subNameBuf[BUF_SIZE] = {0};
1633     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1634 
1635     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1636     if (index < 0) {
1637         return "X509 get ext nid error";
1638     }
1639     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1640     if (ext == nullptr) {
1641         return "X509 get ext error";
1642     }
1643     ASN1_OBJECT *obj = nullptr;
1644     obj = X509_EXTENSION_get_object(ext);
1645     char subAltNameBuf[BUF_SIZE] = {0};
1646     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1647     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1648 
1649     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1650 }
1651 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1652 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1653                                                                    const X509 *x509Certificates)
1654 {
1655     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1656     if (!extData) {
1657         NETSTACK_LOGE("extData is nullptr");
1658         return "";
1659     }
1660 
1661     std::string altNames = reinterpret_cast<char *>(extData->data);
1662     std::string hostname = " " + hostName;
1663     BIO *bio = BIO_new(BIO_s_file());
1664     if (!bio) {
1665         return "bio is null";
1666     }
1667     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1668     ASN1_STRING_print(bio, extData);
1669     std::vector<std::string> dnsNames = {};
1670     std::vector<std::string> ips = {};
1671     constexpr int DNS_NAME_IDX = 4;
1672     constexpr int IP_NAME_IDX = 11;
1673     if (!altNames.empty()) {
1674         std::vector<std::string> splitAltNames;
1675         if (altNames.find('\"') != std::string::npos) {
1676             splitAltNames = SplitEscapedAltNames(altNames);
1677         } else {
1678             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1679         }
1680         for (auto const &iter : splitAltNames) {
1681             if (StartsWith(iter, DNS)) {
1682                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1683             } else if (StartsWith(iter, IP_ADDRESS)) {
1684                 ips.push_back(iter.substr(IP_NAME_IDX));
1685             }
1686         }
1687     }
1688     std::tuple<bool, std::string> result;
1689     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1690     if (!std::get<0>(result)) {
1691         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1692     }
1693     return HOST_NAME + hostname + ". is cert's CN";
1694 }
1695 
LoadCaCertFromMemory(X509_STORE * store,const std::string & pemCerts)1696 static void LoadCaCertFromMemory(X509_STORE *store, const std::string &pemCerts)
1697 {
1698     if (!store || pemCerts.empty() || pemCerts.size() > static_cast<size_t>(INT_MAX)) {
1699         return;
1700     }
1701 
1702     auto cbio = BIO_new_mem_buf(pemCerts.data(), static_cast<int>(pemCerts.size()));
1703     if (!cbio) {
1704         return;
1705     }
1706 
1707     auto inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr);
1708     if (!inf) {
1709         BIO_free(cbio);
1710         return;
1711     }
1712 
1713     /* add each entry from PEM file to x509_store */
1714     for (int i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); ++i) {
1715         auto itmp = sk_X509_INFO_value(inf, i);
1716         if (!itmp) {
1717             continue;
1718         }
1719         if (itmp->x509) {
1720             X509_STORE_add_cert(store, itmp->x509);
1721         }
1722         if (itmp->crl) {
1723             X509_STORE_add_crl(store, itmp->crl);
1724         }
1725     }
1726 
1727     sk_X509_INFO_pop_free(inf, X509_INFO_free);
1728     BIO_free(cbio);
1729 }
1730 
X509_to_PEM(X509 * cert)1731 static std::string X509_to_PEM(X509 *cert)
1732 {
1733     if (!cert) {
1734         return {};
1735     }
1736     BIO *bio = BIO_new(BIO_s_mem());
1737     if (!bio) {
1738         return {};
1739     }
1740     if (!PEM_write_bio_X509(bio, cert)) {
1741         BIO_free(bio);
1742         return {};
1743     }
1744 
1745     char *data = nullptr;
1746     auto pemStringLength = BIO_get_mem_data(bio, &data);
1747     if (!data) {
1748         BIO_free(bio);
1749         return {};
1750     }
1751     std::string certificateInPEM(data, pemStringLength);
1752     BIO_free(bio);
1753     return certificateInPEM;
1754 }
1755 
CacheCertificates(const std::string & hostName,SSL * ssl)1756 static void CacheCertificates(const std::string &hostName, SSL *ssl)
1757 {
1758     if (!ssl || hostName.empty()) {
1759         return;
1760     }
1761     auto certificatesStack = SSL_get_peer_cert_chain(ssl);
1762     if (!certificatesStack) {
1763         return;
1764     }
1765     auto numCertificates = sk_X509_num(certificatesStack);
1766     for (auto i = 0; i < numCertificates; ++i) {
1767         auto cert = sk_X509_value(certificatesStack, i);
1768         auto certificateInPEM = X509_to_PEM(cert);
1769         if (!certificateInPEM.empty()) {
1770             CaCertCache::GetInstance().Set(hostName, certificateInPEM);
1771         }
1772     }
1773 }
1774 
LoadCachedCaCert(const std::string & hostName,SSL * ssl)1775 static void LoadCachedCaCert(const std::string &hostName, SSL *ssl)
1776 {
1777     if (!ssl) {
1778         return;
1779     }
1780     auto cachedPem = CaCertCache::GetInstance().Get(hostName);
1781     auto sslCtx = SSL_get_SSL_CTX(ssl);
1782     if (!sslCtx) {
1783         return;
1784     }
1785     auto x509Store = SSL_CTX_get_cert_store(sslCtx);
1786     if (!x509Store) {
1787         return;
1788     }
1789     for (const auto &pem : cachedPem) {
1790         LoadCaCertFromMemory(x509Store, pem);
1791     }
1792 }
1793 
StartShakingHands(const TLSConnectOptions & options)1794 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1795 {
1796     {
1797         std::lock_guard<std::mutex> lock(mutexForSsl_);
1798         if (!ssl_) {
1799             NETSTACK_LOGE("ssl is null");
1800             return false;
1801         }
1802 
1803         auto hostName = options.GetHostName();
1804         // indicates hostName is not ip address
1805         if (hostName != options.GetNetAddress().GetAddress()) {
1806             LoadCachedCaCert(hostName, ssl_);
1807         }
1808 
1809         int result = SSL_connect(ssl_);
1810         if (result == -1) {
1811             char err[MAX_ERR_LEN] = {0};
1812             auto code = ERR_get_error();
1813             ERR_error_string_n(code, err, MAX_ERR_LEN);
1814             int errorStatus = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1815             NETSTACK_LOGE("SSLConnect fail %{public}d, error: %{public}s errno: %{public}d ERR_get_error %{public}s",
1816                           errorStatus, MakeSSLErrorString(errorStatus).c_str(), errno, err);
1817             return false;
1818         }
1819 
1820         // indicates hostName is not ip address
1821         if (hostName != options.GetNetAddress().GetAddress()) {
1822             CacheCertificates(hostName, ssl_);
1823         }
1824 
1825         std::string list = SSL_get_cipher_list(ssl_, 0);
1826         NETSTACK_LOGI("cipher_list: %{public}s, Version: %{public}s, Cipher: %{public}s", list.c_str(),
1827                       SSL_get_version(ssl_), SSL_get_cipher(ssl_));
1828         configuration_.SetCipherSuite(list);
1829     }
1830     if (!SetSharedSigals()) {
1831         NETSTACK_LOGE("Failed to set sharedSigalgs");
1832     }
1833     if (!GetRemoteCertificateFromPeer()) {
1834         NETSTACK_LOGE("Failed to get remote certificate");
1835     }
1836     if (!peerX509_) {
1837         NETSTACK_LOGE("peer x509Certificates is null");
1838         return false;
1839     }
1840     if (!SetRemoteCertRawData()) {
1841         NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1842     }
1843     CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1844     if (!checkServerIdentity) {
1845         CheckServerIdentityLegal(hostName_, peerX509_);
1846     } else {
1847         checkServerIdentity(hostName_, {remoteCert_});
1848     }
1849     return true;
1850 }
1851 
GetRemoteCertificateFromPeer()1852 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1853 {
1854     peerX509_ = SSL_get_peer_certificate(ssl_);
1855     if (peerX509_ == nullptr) {
1856         int resErr = ConvertSSLError();
1857         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1858                       MakeSSLErrorString(resErr).c_str());
1859         return false;
1860     }
1861     BIO *bio = BIO_new(BIO_s_mem());
1862     if (!bio) {
1863         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1864         return false;
1865     }
1866     X509_print(bio, peerX509_);
1867     char data[REMOTE_CERT_LEN] = {0};
1868     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1869         NETSTACK_LOGE("BIO_read function returns error");
1870         BIO_free(bio);
1871         return false;
1872     }
1873     BIO_free(bio);
1874     remoteCert_ = std::string(data);
1875     return true;
1876 }
1877 
SetRemoteCertRawData()1878 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1879 {
1880     if (peerX509_ == nullptr) {
1881         NETSTACK_LOGE("peerX509 is null");
1882         return false;
1883     }
1884     int32_t length = i2d_X509(peerX509_, nullptr);
1885     if (length <= 0) {
1886         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1887         return false;
1888     }
1889     unsigned char *der = nullptr;
1890     (void)i2d_X509(peerX509_, &der);
1891     SecureData data(der, length);
1892     remoteRawData_.data = data;
1893     OPENSSL_free(der);
1894     remoteRawData_.encodingFormat = DER;
1895     return true;
1896 }
1897 
GetRemoteCertRawData() const1898 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1899 {
1900     return remoteRawData_;
1901 }
1902 
GetSSL()1903 ssl_st *TLSSocket::TLSSocketInternal::GetSSL()
1904 {
1905     std::lock_guard<std::mutex> lock(mutexForSsl_);
1906     return ssl_;
1907 }
1908 } // namespace TlsSocket
1909 } // namespace NetStack
1910 } // namespace OHOS
1911