• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "tls_socket.h"

#include <algorithm>
#include <cctype>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <poll.h>
#include <regex>
#include <securec.h>
#include <set>
#include <string>
#include <thread>
#include <vector>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <unistd.h>

#include "base_context.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "socket_exec_common.h"
#include "tls.h"
36 
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
// Poll timeout used by ReadMessage() while waiting for incoming TLS data.
constexpr int READ_TIMEOUT_MS = 500;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
// Recv() result that ReadMessage() treats as "retry later" rather than a fatal error.
constexpr int SSL_WANT_READ_RETURN = -2;
// Length of the ", " separator skipped by SplitEscapedAltNames().
constexpr int OFFSET = 2;
constexpr int DEFAULT_BUFFER_SIZE = 8192;
constexpr int DEFAULT_POLL_TIMEOUT_MS = 500;
constexpr int SEND_RETRY_TIMES = 5;
constexpr int SEND_POLL_TIMEOUT_MS = 1000;
// Size of the stack buffer ReadMessage() receives into.
constexpr int MAX_RECV_BUFFER_SIZE = 1024 * 16;
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
// Name given to the TLS client reader thread.
static constexpr const char *TLS_SOCKET_CLIENT_READ = "OS_NET_TSCliRD";
// Matches one JSON-encoded string literal ("...", with \" \\ \/ \b \f \n \r \t \uXXXX escapes) at the very start of
// the input. The JavaScript regex-literal delimiters (/.../) that used to surround the pattern are not regex syntax in
// std::regex: they became literal '/' characters and made the pattern unmatchable.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address with each octet in the range 0..255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
76 
// Process-wide, mutex-guarded multimap from a string key to a set of string values.
// NOTE(review): by its name this caches CA certificate data per key - confirm against its users.
class CaCertCache {
public:
    // Returns the single shared instance (a function-local static created on first use).
    static CaCertCache &GetInstance()
    {
        static CaCertCache cache;
        return cache;
    }

    // Returns a copy of every value recorded under `key`; an unknown key yields an empty set.
    std::set<std::string> Get(const std::string &key)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto found = map_.find(key);
        return found == map_.end() ? std::set<std::string>{} : found->second;
    }

    // Records `val` under `key`. The key's bucket is created on demand and duplicate values collapse.
    void Set(const std::string &key, const std::string &val)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        auto &bucket = map_[key];
        bucket.insert(val);
    }

private:
    CaCertCache() = default;
    ~CaCertCache() = default;
    CaCertCache(const CaCertCache &) = delete;
    CaCertCache &operator=(const CaCertCache &) = delete;

    std::map<std::string, std::set<std::string>> map_;
    std::mutex mutex_;
};
110 
ConvertErrno()111 int ConvertErrno()
112 {
113     return TlsSocketError::TLS_ERR_SYS_BASE + errno;
114 }
115 
// Returns strerror()'s description of the current errno value.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
120 
MakeSSLErrorString(int error)121 std::string MakeSSLErrorString(int error)
122 {
123     char err[MAX_ERR_LEN] = {0};
124     ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
125     return err;
126 }
127 
SplitEscapedAltNames(std::string & altNames)128 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
129 {
130     std::vector<std::string> result;
131     std::string currentToken;
132     size_t offset = 0;
133     while (offset != altNames.length()) {
134         auto nextSep = altNames.find_first_of(", ");
135         auto nextQuote = altNames.find_first_of('\"');
136         if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
137             currentToken += altNames.substr(offset, nextQuote);
138             std::regex jsonStringPattern(JSON_STRING_PATTERN);
139             std::smatch match;
140             std::string altNameSubStr = altNames.substr(nextQuote);
141             bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
142             if (!ret) {
143                 return {""};
144             }
145             currentToken += result[0];
146             offset = nextQuote + result[0].length();
147         } else if (nextSep != std::string::npos) {
148             currentToken += altNames.substr(offset, nextSep);
149             result.push_back(currentToken);
150             currentToken = "";
151             offset = nextSep + OFFSET;
152         } else {
153             currentToken += altNames.substr(offset);
154             offset = altNames.length();
155         }
156     }
157     result.push_back(currentToken);
158     return result;
159 }
160 
IsIP(const std::string & ip)161 bool IsIP(const std::string &ip)
162 {
163     std::regex pattern(PATTERN);
164     std::smatch res;
165     return regex_match(ip, res, pattern);
166 }
167 
SplitHostName(std::string & hostName)168 std::vector<std::string> SplitHostName(std::string &hostName)
169 {
170     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
171     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
172 }
173 
// Returns true when vecA and vecB share at least one element.
// Neither input needs to be sorted. std::set_intersection requires sorted ranges and silently misses
// common elements when given unsorted input.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    const std::set<std::string> lookup(vecB.begin(), vecB.end());
    return std::any_of(vecA.begin(), vecA.end(),
                       [&lookup](const std::string &item) { return lookup.count(item) != 0; });
}
180 } // namespace
181 
SetSockBlockFlag(int sock,bool noneBlock)182 static bool SetSockBlockFlag(int sock, bool noneBlock)
183 {
184     int flags = fcntl(sock, F_GETFL, 0);
185     while (flags == -1 && errno == EINTR) {
186         flags = fcntl(sock, F_GETFL, 0);
187     }
188     if (flags == -1) {
189         NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
190         return false;
191     }
192 
193     auto newFlags = static_cast<size_t>(flags);
194     if (noneBlock) {
195         newFlags |= static_cast<size_t>(O_NONBLOCK);
196     } else {
197         newFlags &= ~static_cast<size_t>(O_NONBLOCK);
198     }
199 
200     int ret = fcntl(sock, F_SETFL, newFlags);
201     while (ret == -1 && errno == EINTR) {
202         ret = fcntl(sock, F_SETFL, newFlags);
203     }
204     if (ret == -1) {
205         NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
206         return false;
207     }
208     return true;
209 }
210 
// Copy constructor: default-constructs the members, then copies every option field through copy assignment.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
215 
// Copies every option field (key, CA/CRL chains, certificate, key password, protocols, signature algorithms,
// cipher suite, cipher-preference flag and verify mode) from `tlsSecureOptions`. Self-assignment is a no-op.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    // key_ used to be assigned twice; the second copy of the key material was redundant.
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
231 
SetCaChain(const std::vector<std::string> & caChain)232 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
233 {
234     caChain_ = caChain;
235 }
236 
SetCert(const std::string & cert)237 void TLSSecureOptions::SetCert(const std::string &cert)
238 {
239     cert_ = cert;
240 }
241 
SetKey(const SecureData & key)242 void TLSSecureOptions::SetKey(const SecureData &key)
243 {
244     key_ = key;
245 }
246 
SetKeyPass(const SecureData & keyPass)247 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
248 {
249     keyPass_ = keyPass;
250 }
251 
SetProtocolChain(const std::vector<std::string> & protocolChain)252 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
253 {
254     protocolChain_ = protocolChain;
255 }
256 
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)257 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
258 {
259     useRemoteCipherPrefer_ = useRemoteCipherPrefer;
260 }
261 
SetSignatureAlgorithms(const std::string & signatureAlgorithms)262 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
263 {
264     signatureAlgorithms_ = signatureAlgorithms;
265 }
266 
SetCipherSuite(const std::string & cipherSuite)267 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
268 {
269     cipherSuite_ = cipherSuite;
270 }
271 
SetCrlChain(const std::vector<std::string> & crlChain)272 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
273 {
274     crlChain_ = crlChain;
275 }
276 
GetCaChain() const277 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
278 {
279     return caChain_;
280 }
281 
GetCert() const282 const std::string &TLSSecureOptions::GetCert() const
283 {
284     return cert_;
285 }
286 
GetKey() const287 const SecureData &TLSSecureOptions::GetKey() const
288 {
289     return key_;
290 }
291 
GetKeyPass() const292 const SecureData &TLSSecureOptions::GetKeyPass() const
293 {
294     return keyPass_;
295 }
296 
GetProtocolChain() const297 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
298 {
299     return protocolChain_;
300 }
301 
UseRemoteCipherPrefer() const302 bool TLSSecureOptions::UseRemoteCipherPrefer() const
303 {
304     return useRemoteCipherPrefer_;
305 }
306 
GetSignatureAlgorithms() const307 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
308 {
309     return signatureAlgorithms_;
310 }
311 
GetCipherSuite() const312 const std::string &TLSSecureOptions::GetCipherSuite() const
313 {
314     return cipherSuite_;
315 }
316 
GetCrlChain() const317 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
318 {
319     return crlChain_;
320 }
321 
SetVerifyMode(VerifyMode verifyMode)322 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
323 {
324     TLSVerifyMode_ = verifyMode;
325 }
326 
GetVerifyMode() const327 VerifyMode TLSSecureOptions::GetVerifyMode() const
328 {
329     return TLSVerifyMode_;
330 }
331 
SetNetAddress(const Socket::NetAddress & address)332 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
333 {
334     address_.SetFamilyBySaFamily(address.GetSaFamily());
335     address_.SetRawAddress(address.GetAddress());
336     address_.SetPort(address.GetPort());
337 }
338 
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)339 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
340 {
341     tlsSecureOptions_ = tlsSecureOptions;
342 }
343 
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)344 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
345 {
346     checkServerIdentity_ = checkServerIdentity;
347 }
348 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)349 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
350 {
351     alpnProtocols_ = alpnProtocols;
352 }
353 
SetSkipRemoteValidation(bool skipRemoteValidation)354 void TLSConnectOptions::SetSkipRemoteValidation(bool skipRemoteValidation)
355 {
356     skipRemoteValidation_ = skipRemoteValidation;
357 }
358 
GetNetAddress() const359 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
360 {
361     return address_;
362 }
363 
GetTlsSecureOptions() const364 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
365 {
366     return tlsSecureOptions_;
367 }
368 
GetCheckServerIdentity() const369 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
370 {
371     return checkServerIdentity_;
372 }
373 
GetAlpnProtocols() const374 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
375 {
376     return alpnProtocols_;
377 }
378 
GetSkipRemoteValidation() const379 bool TLSConnectOptions::GetSkipRemoteValidation() const
380 {
381     return skipRemoteValidation_;
382 }
383 
SetHostName(const std::string & hostName)384 void TLSConnectOptions::SetHostName(const std::string &hostName)
385 {
386     hostName_ = hostName;
387 }
388 
GetHostName() const389 std::string TLSConnectOptions::GetHostName() const
390 {
391     return hostName_;
392 }
393 
MakeAddressString(sockaddr * addr)394 std::string TLSSocket::MakeAddressString(sockaddr *addr)
395 {
396     if (!addr) {
397         return {};
398     }
399     if (addr->sa_family == AF_INET) {
400         auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
401         const char *str = inet_ntoa(addr4->sin_addr);
402         if (str == nullptr || strlen(str) == 0) {
403             return {};
404         }
405         return str;
406     } else if (addr->sa_family == AF_INET6) {
407         auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
408         char str[INET6_ADDRSTRLEN] = {0};
409         if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
410             return {};
411         }
412         return str;
413     }
414     return {};
415 }
416 
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)417 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
418                         socklen_t *len)
419 {
420     if (!addr6 || !addr4 || !len) {
421         return;
422     }
423     sa_family_t family = address.GetSaFamily();
424     if (family == AF_INET) {
425         addr4->sin_family = AF_INET;
426         addr4->sin_port = htons(address.GetPort());
427         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
428         *addr = reinterpret_cast<sockaddr *>(addr4);
429         *len = sizeof(sockaddr_in);
430     } else if (family == AF_INET6) {
431         addr6->sin6_family = AF_INET6;
432         addr6->sin6_port = htons(address.GetPort());
433         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
434         *addr = reinterpret_cast<sockaddr *>(addr6);
435         *len = sizeof(sockaddr_in6);
436     }
437 }
438 
ExecTlsSetSockBlockFlag(int sock,bool noneBlock)439 bool TLSSocket::ExecTlsSetSockBlockFlag(int sock, bool noneBlock)
440 {
441     return SetSockBlockFlag(sock, noneBlock);
442 }
443 
ExecTlsGetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)444 void TLSSocket::ExecTlsGetAddr(
445     const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len)
446 {
447     GetAddr(address, addr4, addr6, addr, len);
448 }
449 
IsExtSock() const450 bool TLSSocket::IsExtSock() const
451 {
452     return isExtSock_;
453 }
454 
MakeIpSocket(sa_family_t family)455 void TLSSocket::MakeIpSocket(sa_family_t family)
456 {
457     if (family != AF_INET && family != AF_INET6) {
458         return;
459     }
460     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
461     if (sock < 0) {
462         int resErr = ConvertErrno();
463         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
464         CallOnErrorCallback(resErr, MakeErrnoString());
465         return;
466     }
467     sockFd_ = sock;
468 }
469 
ReadMessage()470 int TLSSocket::ReadMessage()
471 {
472     char buffer[MAX_RECV_BUFFER_SIZE];
473     if (memset_s(buffer, MAX_RECV_BUFFER_SIZE, 0, MAX_RECV_BUFFER_SIZE) != EOK) {
474         NETSTACK_LOGE("memset_s failed!");
475         return -1;
476     }
477     nfds_t num = 1;
478     pollfd fds[1] = {{.fd = sockFd_, .events = POLLIN}};
479     int ret = poll(fds, num, READ_TIMEOUT_MS);
480     if (ret < 0) {
481         if (errno == EAGAIN || errno == EINTR) {
482             return 0;
483         }
484         int resErr = ConvertErrno();
485         NETSTACK_LOGE("Message poll errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
486         CallOnErrorCallback(resErr, MakeErrnoString());
487         return ret;
488     } else if (ret == 0) {
489         NETSTACK_LOGD("tls recv poll timeout");
490         return ret;
491     }
492 
493     std::lock_guard<std::mutex> lock(recvMutex_);
494     if (!isRunning_) {
495         return -1;
496     }
497     int len = tlsSocketInternal_.Recv(buffer, MAX_RECV_BUFFER_SIZE);
498     if (len < 0) {
499         if (errno == EAGAIN || errno == EINTR || len == SSL_WANT_READ_RETURN) {
500             return 0;
501         }
502         int resErr = tlsSocketInternal_.ConvertSSLError();
503         NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
504                       MakeSSLErrorString(resErr).c_str());
505         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
506         return len;
507     } else if (len == 0) {
508         NETSTACK_LOGI("Message recv len 0, session is closed by peer");
509         CallOnCloseCallback();
510         return -1;
511     }
512     Socket::SocketRemoteInfo remoteInfo;
513     remoteInfo.SetSize(len);
514     tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
515     std::string bufContent(buffer, len);
516     CallOnMessageCallback(bufContent, remoteInfo);
517 
518     return ret;
519 }
520 
StartReadMessage()521 void TLSSocket::StartReadMessage()
522 {
523     auto wp = std::weak_ptr<TLSSocket>(shared_from_this());
524     std::thread thread([wp]() {
525         auto tlsSocket = wp.lock();
526         if (tlsSocket == nullptr) {
527             return;
528         }
529         tlsSocket->isRunning_ = true;
530         tlsSocket->isRunOver_ = false;
531 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
532         pthread_setname_np(TLS_SOCKET_CLIENT_READ);
533 #else
534         pthread_setname_np(pthread_self(), TLS_SOCKET_CLIENT_READ);
535 #endif
536         while (tlsSocket->isRunning_) {
537             int ret = tlsSocket->ReadMessage();
538             if (ret < 0) {
539                 break;
540             }
541         }
542         tlsSocket->isRunOver_ = true;
543         tlsSocket->cvSslFree_.notify_one();
544     });
545     thread.detach();
546 }
547 
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)548 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
549 {
550     OnMessageCallback func = nullptr;
551     {
552         std::lock_guard<std::mutex> lock(mutex_);
553         if (onMessageCallback_) {
554             func = onMessageCallback_;
555         }
556     }
557 
558     if (func) {
559         func(data, remoteInfo);
560     }
561 }
562 
CallOnConnectCallback()563 void TLSSocket::CallOnConnectCallback()
564 {
565     OnConnectCallback func = nullptr;
566     {
567         std::lock_guard<std::mutex> lock(mutex_);
568         if (onConnectCallback_) {
569             func = onConnectCallback_;
570         }
571     }
572 
573     if (func) {
574         func();
575     }
576 }
577 
CallOnCloseCallback()578 void TLSSocket::CallOnCloseCallback()
579 {
580     OnCloseCallback func = nullptr;
581     {
582         std::lock_guard<std::mutex> lock(mutex_);
583         if (onCloseCallback_) {
584             func = onCloseCallback_;
585         }
586     }
587 
588     if (func) {
589         func();
590     }
591 }
592 
CallOnErrorCallback(int32_t err,const std::string & errString)593 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
594 {
595     OnErrorCallback func = nullptr;
596     {
597         std::lock_guard<std::mutex> lock(mutex_);
598         if (onErrorCallback_) {
599             func = onErrorCallback_;
600         }
601     }
602 
603     if (func) {
604         func(err, errString);
605     }
606 }
607 
CallBindCallback(int32_t err,BindCallback callback)608 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
609 {
610     DealCallback<BindCallback>(err, callback);
611 }
612 
CallConnectCallback(int32_t err,ConnectCallback callback)613 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
614 {
615     DealCallback<ConnectCallback>(err, callback);
616 }
617 
CallSendCallback(int32_t err,SendCallback callback)618 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
619 {
620     DealCallback<SendCallback>(err, callback);
621 }
622 
CallCloseCallback(int32_t err,CloseCallback callback)623 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
624 {
625     DealCallback<CloseCallback>(err, callback);
626 }
627 
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)628 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
629                                              GetRemoteAddressCallback callback)
630 {
631     if (callback) {
632         callback(err, address);
633     }
634 }
635 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)636 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
637 {
638     if (callback) {
639         callback(err, state);
640     }
641 }
642 
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)643 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
644 {
645     DealCallback<SetExtraOptionsCallback>(err, callback);
646 }
647 
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)648 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
649 {
650     if (callback) {
651         callback(err, cert);
652     }
653 }
654 
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)655 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
656                                                  GetRemoteCertificateCallback callback)
657 {
658     if (callback) {
659         callback(err, cert);
660     }
661 }
662 
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)663 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
664 {
665     if (callback) {
666         callback(err, protocol);
667     }
668 }
669 
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)670 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
671                                            GetCipherSuiteCallback callback)
672 {
673     if (callback) {
674         callback(err, suite);
675     }
676 }
677 
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)678 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
679                                                    GetSignatureAlgorithmsCallback callback)
680 {
681     if (callback) {
682         callback(err, algorithms);
683     }
684 }
685 
Bind(Socket::NetAddress & address,const BindCallback & callback)686 void TLSSocket::Bind(Socket::NetAddress &address, const BindCallback &callback)
687 {
688     static constexpr int32_t PARSE_ERROR_CODE = 401;
689     if (!CommonUtils::HasInternetPermission()) {
690         CallBindCallback(PERMISSION_DENIED_CODE, callback);
691         return;
692     }
693     if (sockFd_ >= 0) {
694         CallBindCallback(TLSSOCKET_SUCCESS, callback);
695         return;
696     }
697 
698     MakeIpSocket(address.GetSaFamily());
699     if (sockFd_ < 0) {
700         int resErr = ConvertErrno();
701         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
702         CallOnErrorCallback(resErr, MakeErrnoString());
703         CallBindCallback(resErr, callback);
704         return;
705     }
706 
707     auto temp = address.GetAddress();
708     address.SetRawAddress("");
709     address.SetAddress(temp);
710     if (address.GetAddress().empty()) {
711         CallBindCallback(PARSE_ERROR_CODE, callback);
712         return;
713     }
714 
715     sockaddr_in addr4 = {0};
716     sockaddr_in6 addr6 = {0};
717     sockaddr *addr = nullptr;
718     socklen_t len;
719     GetAddr(address, &addr4, &addr6, &addr, &len);
720     if (addr == nullptr) {
721         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
722         CallOnErrorCallback(-1, "Address Is Invalid");
723         CallBindCallback(ConvertErrno(), callback);
724         return;
725     }
726     CallBindCallback(TLSSOCKET_SUCCESS, callback);
727 }
728 
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)729 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
730                         const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
731 {
732     if (sockFd_ < 0) {
733         int resErr = ConvertErrno();
734         NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
735         CallOnErrorCallback(resErr, MakeErrnoString());
736         callback(resErr);
737         return;
738     }
739 
740     if (isExtSock_ && !SetSockBlockFlag(sockFd_, false)) {
741         int resErr = ConvertErrno();
742         NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
743         CallOnErrorCallback(resErr, MakeErrnoString());
744         callback(resErr);
745         return;
746     }
747 
748     auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions, isExtSock_);
749     if (!res) {
750         int resErr = tlsSocketInternal_.ConvertSSLError();
751         NETSTACK_LOGE("connect error is %{public}d %{public}d", resErr, errno);
752         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
753         callback(resErr);
754         return;
755     }
756     if (!SetSockBlockFlag(sockFd_, true)) {
757         int resErr = ConvertErrno();
758         NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
759         CallOnErrorCallback(resErr, MakeErrnoString());
760         callback(resErr);
761         return;
762     }
763     StartReadMessage();
764     CallOnConnectCallback();
765     callback(TLSSOCKET_SUCCESS);
766 }
767 
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)768 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
769 {
770     (void)tcpSendOptions;
771 
772     auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
773     if (!res) {
774         int resErr = tlsSocketInternal_.ConvertSSLError();
775         NETSTACK_LOGE("send error is %{public}d %{public}d", resErr, errno);
776         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
777         CallSendCallback(resErr, callback);
778         return;
779     }
780     CallSendCallback(TLSSOCKET_SUCCESS, callback);
781 }
782 
Close(const CloseCallback & callback)783 void TLSSocket::Close(const CloseCallback &callback)
784 {
785     isRunning_ = false;
786     std::unique_lock<std::mutex> cvLock(cvMutex_);
787     auto wp = std::weak_ptr<TLSSocket>(shared_from_this());
788     cvSslFree_.wait(cvLock, [wp]() -> bool {
789         auto tlsSocket = wp.lock();
790         if (tlsSocket == nullptr) {
791             return true;
792         }
793         return tlsSocket->isRunOver_;
794     });
795 
796     std::lock_guard<std::mutex> lock(recvMutex_);
797     auto res = tlsSocketInternal_.Close();
798     if (!res) {
799         int resErr = tlsSocketInternal_.ConvertSSLError();
800         NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
801         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
802         callback(resErr);
803         return;
804     }
805     sockFd_ = -1;
806     CallOnCloseCallback();
807     callback(TLSSOCKET_SUCCESS);
808 }
809 
GetRemoteAddress(const GetRemoteAddressCallback & callback)810 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
811 {
812     sockaddr sockAddr = {0};
813     socklen_t len = sizeof(sockaddr);
814     int ret = getsockname(sockFd_, &sockAddr, &len);
815     if (ret < 0) {
816         int resErr = ConvertErrno();
817         NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
818         CallOnErrorCallback(resErr, MakeErrnoString());
819         CallGetRemoteAddressCallback(resErr, {}, callback);
820         return;
821     }
822 
823     if (sockAddr.sa_family == AF_INET) {
824         GetIp4RemoteAddress(callback);
825     } else if (sockAddr.sa_family == AF_INET6) {
826         GetIp6RemoteAddress(callback);
827     }
828 }
829 
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)830 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
831 {
832     sockaddr_in addr4 = {0};
833     socklen_t len4 = sizeof(sockaddr_in);
834 
835     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
836     if (ret < 0) {
837         int resErr = ConvertErrno();
838         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
839         CallOnErrorCallback(resErr, MakeErrnoString());
840         CallGetRemoteAddressCallback(resErr, {}, callback);
841         return;
842     }
843 
844     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
845     if (address.empty()) {
846         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
847         CallOnErrorCallback(-1, "Address is invalid");
848         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
849         return;
850     }
851     Socket::NetAddress netAddress;
852     netAddress.SetFamilyBySaFamily(AF_INET);
853     netAddress.SetRawAddress(address);
854     netAddress.SetPort(ntohs(addr4.sin_port));
855     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
856 }
857 
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)858 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
859 {
860     sockaddr_in6 addr6 = {0};
861     socklen_t len6 = sizeof(sockaddr_in6);
862 
863     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
864     if (ret < 0) {
865         int resErr = ConvertErrno();
866         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
867         CallOnErrorCallback(resErr, MakeErrnoString());
868         CallGetRemoteAddressCallback(resErr, {}, callback);
869         return;
870     }
871 
872     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
873     if (address.empty()) {
874         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
875         CallOnErrorCallback(-1, "Address is invalid");
876         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
877         return;
878     }
879     Socket::NetAddress netAddress;
880     netAddress.SetFamilyBySaFamily(AF_INET6);
881     netAddress.SetRawAddress(address);
882     netAddress.SetPort(ntohs(addr6.sin6_port));
883     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
884 }
885 
GetState(const GetStateCallback & callback)886 void TLSSocket::GetState(const GetStateCallback &callback)
887 {
888     int opt;
889     socklen_t optLen = sizeof(int);
890     int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
891     if (r < 0) {
892         Socket::SocketStateBase state;
893         state.SetIsClose(true);
894         CallGetStateCallback(ConvertErrno(), state, callback);
895         return;
896     }
897     sockaddr sockAddr = {0};
898     socklen_t len = sizeof(sockaddr);
899     Socket::SocketStateBase state;
900     int ret = getsockname(sockFd_, &sockAddr, &len);
901     state.SetIsBound(ret == 0);
902     ret = getpeername(sockFd_, &sockAddr, &len);
903     state.SetIsConnected(ret == 0);
904     CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
905 }
906 
SetBaseOptions(const Socket::ExtraOptionsBase & option) const907 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
908 {
909     if (option.GetReceiveBufferSize() != 0) {
910         int size = (int)option.GetReceiveBufferSize();
911         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
912             return false;
913         }
914     }
915 
916     if (option.GetSendBufferSize() != 0) {
917         int size = (int)option.GetSendBufferSize();
918         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
919             return false;
920         }
921     }
922 
923     if (option.IsReuseAddress()) {
924         int reuse = 1;
925         if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
926             return false;
927         }
928     }
929 
930     if (option.GetSocketTimeout() != 0) {
931         timeval timeout = {(int)option.GetSocketTimeout(), 0};
932         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
933             return false;
934         }
935         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
936             return false;
937         }
938     }
939 
940     return true;
941 }
942 
SetExtraOptions(const Socket::TCPExtraOptions & option) const943 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
944 {
945     if (option.IsKeepAlive()) {
946         int keepalive = 1;
947         if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
948             return false;
949         }
950     }
951 
952     if (option.IsOOBInline()) {
953         int oobInline = 1;
954         if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
955             return false;
956         }
957     }
958 
959     if (option.IsTCPNoDelay()) {
960         int tcpNoDelay = 1;
961         if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
962             return false;
963         }
964     }
965 
966     linger soLinger = {0};
967     soLinger.l_onoff = option.socketLinger.IsOn();
968     soLinger.l_linger = (int)option.socketLinger.GetLinger();
969     if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
970         return false;
971     }
972 
973     return true;
974 }
975 
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)976 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
977                                 const SetExtraOptionsCallback &callback)
978 {
979     if (!SetBaseOptions(tcpExtraOptions)) {
980         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
981         CallOnErrorCallback(errno, MakeErrnoString());
982         CallSetExtraOptionsCallback(ConvertErrno(), callback);
983         return;
984     }
985 
986     if (!SetExtraOptions(tcpExtraOptions)) {
987         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
988         CallOnErrorCallback(errno, MakeErrnoString());
989         CallSetExtraOptionsCallback(ConvertErrno(), callback);
990         return;
991     }
992 
993     CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
994 }
995 
GetCertificate(const GetCertificateCallback & callback)996 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
997 {
998     const auto &cert = tlsSocketInternal_.GetCertificate();
999     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
1000 
1001     if (!cert.data.Length()) {
1002         int resErr = tlsSocketInternal_.ConvertSSLError();
1003         NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1004         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1005         callback(resErr, {});
1006         return;
1007     }
1008     callback(TLSSOCKET_SUCCESS, cert);
1009 }
1010 
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)1011 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
1012 {
1013     const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
1014     if (!remoteCert.data.Length()) {
1015         int resErr = tlsSocketInternal_.ConvertSSLError();
1016         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1017         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1018         callback(resErr, {});
1019         return;
1020     }
1021     callback(TLSSOCKET_SUCCESS, remoteCert);
1022 }
1023 
GetProtocol(const GetProtocolCallback & callback)1024 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
1025 {
1026     const auto &protocol = tlsSocketInternal_.GetProtocol();
1027     if (protocol.empty()) {
1028         NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
1029         int resErr = tlsSocketInternal_.ConvertSSLError();
1030         NETSTACK_LOGE("getProtocol error is %{public}d %{public}d", resErr, errno);
1031         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1032         callback(resErr, "");
1033         return;
1034     }
1035     callback(TLSSOCKET_SUCCESS, protocol);
1036 }
1037 
GetCipherSuite(const GetCipherSuiteCallback & callback)1038 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
1039 {
1040     const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
1041     if (cipherSuite.empty()) {
1042         int resErr = tlsSocketInternal_.ConvertSSLError();
1043         NETSTACK_LOGE("getCipherSuite error is %{public}d %{public}d", resErr, errno);
1044         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1045         callback(resErr, cipherSuite);
1046         return;
1047     }
1048     callback(TLSSOCKET_SUCCESS, cipherSuite);
1049 }
1050 
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)1051 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
1052 {
1053     const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
1054     if (signatureAlgorithms.empty()) {
1055         int resErr = tlsSocketInternal_.ConvertSSLError();
1056         NETSTACK_LOGE("getSignatureAlgorithms error is %{public}d %{public}d", resErr, errno);
1057         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1058         callback(resErr, {});
1059         return;
1060     }
1061     callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
1062 }
1063 
OnMessage(const OnMessageCallback & onMessageCallback)1064 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
1065 {
1066     std::lock_guard<std::mutex> lock(mutex_);
1067     onMessageCallback_ = onMessageCallback;
1068 }
1069 
OffMessage()1070 void TLSSocket::OffMessage()
1071 {
1072     std::lock_guard<std::mutex> lock(mutex_);
1073     if (onMessageCallback_) {
1074         onMessageCallback_ = nullptr;
1075     }
1076 }
1077 
OnConnect(const OnConnectCallback & onConnectCallback)1078 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1079 {
1080     std::lock_guard<std::mutex> lock(mutex_);
1081     onConnectCallback_ = onConnectCallback;
1082 }
1083 
OffConnect()1084 void TLSSocket::OffConnect()
1085 {
1086     std::lock_guard<std::mutex> lock(mutex_);
1087     if (onConnectCallback_) {
1088         onConnectCallback_ = nullptr;
1089     }
1090 }
1091 
OnClose(const OnCloseCallback & onCloseCallback)1092 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1093 {
1094     std::lock_guard<std::mutex> lock(mutex_);
1095     onCloseCallback_ = onCloseCallback;
1096 }
1097 
OffClose()1098 void TLSSocket::OffClose()
1099 {
1100     std::lock_guard<std::mutex> lock(mutex_);
1101     if (onCloseCallback_) {
1102         onCloseCallback_ = nullptr;
1103     }
1104 }
1105 
OnError(const OnErrorCallback & onErrorCallback)1106 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1107 {
1108     std::lock_guard<std::mutex> lock(mutex_);
1109     onErrorCallback_ = onErrorCallback;
1110 }
1111 
OffError()1112 void TLSSocket::OffError()
1113 {
1114     std::lock_guard<std::mutex> lock(mutex_);
1115     if (onErrorCallback_) {
1116         onErrorCallback_ = nullptr;
1117     }
1118 }
1119 
GetSocketFd()1120 int TLSSocket::GetSocketFd()
1121 {
1122     return sockFd_;
1123 }
1124 
SetLocalAddress(const Socket::NetAddress & address)1125 void TLSSocket::SetLocalAddress(const Socket::NetAddress &address)
1126 {
1127     localAddress_ = address;
1128 }
1129 
GetLocalAddress()1130 Socket::NetAddress TLSSocket::GetLocalAddress()
1131 {
1132     return localAddress_;
1133 }
1134 
ExecSocketConnect(const std::string & host,int port,sa_family_t family,int socketDescriptor)1135 bool ExecSocketConnect(const std::string &host, int port, sa_family_t family, int socketDescriptor)
1136 {
1137     auto hostName = ConvertAddressToIp(host, family);
1138 
1139     sockaddr_in addr4 = {0};
1140     sockaddr_in6 addr6 = {0};
1141     sockaddr *addr = nullptr;
1142     socklen_t len = 0;
1143     if (family == AF_INET) {
1144         if (inet_pton(AF_INET, hostName.c_str(), &addr4.sin_addr.s_addr) <= 0) {
1145             return false;
1146         }
1147         addr4.sin_family = family;
1148         addr4.sin_port = htons(port);
1149         addr = reinterpret_cast<sockaddr *>(&addr4);
1150         len = sizeof(sockaddr_in);
1151     } else {
1152         if (inet_pton(AF_INET6, hostName.c_str(), &addr6.sin6_addr) <= 0) {
1153             return false;
1154         }
1155         addr6.sin6_family = family;
1156         addr6.sin6_port = htons(port);
1157         addr = reinterpret_cast<sockaddr *>(&addr6);
1158         len = sizeof(sockaddr_in6);
1159     }
1160 
1161     int connectResult = connect(socketDescriptor, addr, len);
1162     if (connectResult == -1) {
1163         NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1164                       strerror(errno));
1165         return false;
1166     }
1167     return true;
1168 }
1169 
ConvertSSLError(void)1170 int TLSSocket::TLSSocketInternal::ConvertSSLError(void)
1171 {
1172     std::lock_guard<std::mutex> lock(mutexForSsl_);
1173     if (!ssl_) {
1174         return TLS_ERR_SSL_NULL;
1175     }
1176     return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1177 }
1178 
TlsConnectToHost(int sock,const TLSConnectOptions & options,bool isExtSock)1179 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock)
1180 {
1181     SetTlsConfiguration(options);
1182     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1183     if (!cipherSuite.empty()) {
1184         configuration_.SetCipherSuite(cipherSuite);
1185     }
1186     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1187     if (!signatureAlgorithms.empty()) {
1188         configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1189     }
1190     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1191     if (!protocolVec.empty()) {
1192         configuration_.SetProtocol(protocolVec);
1193     }
1194     configuration_.SetSkipFlag(options.GetSkipRemoteValidation());
1195     hostName_ = options.GetNetAddress().GetAddress();
1196     port_ = options.GetNetAddress().GetPort();
1197     family_ = options.GetNetAddress().GetSaFamily();
1198     socketDescriptor_ = sock;
1199     if (options.proxyOptions_ == nullptr && !isExtSock &&
1200         !ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1201         options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1202         return false;
1203     }
1204     return StartTlsConnected(options);
1205 }
1206 
SetTlsConfiguration(const TLSConnectOptions & config)1207 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1208 {
1209     configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1210     configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1211     configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1212     configuration_.SetNetAddress(config.GetNetAddress());
1213 }
1214 
SendRetry(ssl_st * ssl,const char * curPos,size_t curSendSize,int sockfd)1215 bool TLSSocket::TLSSocketInternal::SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd)
1216 {
1217     pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1218     for (int i = 0; i <= SEND_RETRY_TIMES; i++) {
1219         int ret = poll(fds, 1, SEND_POLL_TIMEOUT_MS);
1220         if (ret < 0) {
1221             if (errno == EAGAIN || errno == EINTR) {
1222                 continue;
1223             }
1224             NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1225             return false;
1226         } else if (ret == 0) {
1227             NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1228             continue;
1229         }
1230         int len = SSL_write(ssl, curPos, curSendSize);
1231         if (len < 0) {
1232             int err = SSL_get_error(ssl, SSL_RET_CODE);
1233             NETSTACK_LOGE("Error in PollSend, errno is %{public}d %{public}d", err, errno);
1234             if (err == SSL_ERROR_WANT_WRITE || errno == EAGAIN) {
1235                 NETSTACK_LOGI("write retry times: %{public}d err: %{public}d errno: %{public}d", i, err, errno);
1236                 continue;
1237             } else {
1238                 NETSTACK_LOGE("write failed err: %{public}d errno: %{public}d", err, errno);
1239                 return false;
1240             }
1241         } else if (len == 0) {
1242             NETSTACK_LOGI("send len is 0, should have sent len");
1243             return false;
1244         } else {
1245             return true;
1246         }
1247     }
1248     return false;
1249 }
1250 
PollSend(int sockfd,ssl_st * ssl,const char * pdata,int sendSize)1251 bool TLSSocket::TLSSocketInternal::PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize)
1252 {
1253     int bufferSize = DEFAULT_BUFFER_SIZE;
1254     auto curPos = pdata;
1255     nfds_t num = 1;
1256     pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1257     while (sendSize > 0) {
1258         int ret = poll(fds, num, DEFAULT_POLL_TIMEOUT_MS);
1259         if (ret < 0) {
1260             if (errno == EAGAIN || errno == EINTR) {
1261                 continue;
1262             }
1263             NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1264             return false;
1265         } else if (ret == 0) {
1266             NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1267             continue;
1268         }
1269         std::lock_guard<std::mutex> lock(mutexForSsl_);
1270         if (!ssl) {
1271             NETSTACK_LOGE("ssl is null");
1272             return false;
1273         }
1274         size_t curSendSize = std::min<size_t>(sendSize, bufferSize);
1275         int len = SSL_write(ssl, curPos, curSendSize);
1276         if (len < 0) {
1277             int err = SSL_get_error(ssl, SSL_RET_CODE);
1278             NETSTACK_LOGE("Error in PollSend, errno is %{public}d %{public}d", err, errno);
1279             if (err != SSL_ERROR_WANT_WRITE || errno != EAGAIN) {
1280                 NETSTACK_LOGE("write failed, return, err: %{public}d errno: %{public}d", err, errno);
1281                 return false;
1282             } else if (!SendRetry(ssl, curPos, curSendSize, sockfd)) {
1283                 return false;
1284             }
1285         } else if (len == 0) {
1286             NETSTACK_LOGI("send len is 0, should have sent len is %{public}d", sendSize);
1287             return false;
1288         }
1289         curPos += len;
1290         sendSize -= len;
1291     }
1292     return true;
1293 }
1294 
Send(const std::string & data)1295 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1296 {
1297     {
1298         std::lock_guard<std::mutex> lock(mutexForSsl_);
1299         if (!ssl_) {
1300             NETSTACK_LOGE("ssl is null");
1301             return false;
1302         }
1303     }
1304 
1305     if (data.empty()) {
1306         NETSTACK_LOGE("data is empty");
1307         return true;
1308     }
1309 
1310     if (!PollSend(socketDescriptor_, ssl_, data.c_str(), data.size())) {
1311         return false;
1312     }
1313     return true;
1314 }
Recv(char * buffer,int maxBufferSize)1315 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1316 {
1317     std::lock_guard<std::mutex> lock(mutexForSsl_);
1318     if (!ssl_) {
1319         NETSTACK_LOGE("ssl is null");
1320         return SSL_ERROR_RETURN;
1321     }
1322 
1323     int ret = SSL_read(ssl_, buffer, maxBufferSize);
1324     if (ret < 0) {
1325         int err = SSL_get_error(ssl_, SSL_RET_CODE);
1326         switch (err) {
1327             case SSL_ERROR_SSL:
1328                 NETSTACK_LOGE("An error occurred in the SSL library %{public}d %{public}d", err, errno);
1329                 return SSL_ERROR_RETURN;
1330             case SSL_ERROR_ZERO_RETURN:
1331                 NETSTACK_LOGE("peer disconnected...");
1332                 return SSL_ERROR_RETURN;
1333             case SSL_ERROR_WANT_READ:
1334                 NETSTACK_LOGD("SSL_read function no data available for reading, try again at a later time");
1335                 return SSL_WANT_READ_RETURN;
1336             default:
1337                 NETSTACK_LOGE("SSL_read function failed, error code is %{public}d", err);
1338                 return SSL_ERROR_RETURN;
1339         }
1340     }
1341     return ret;
1342 }
1343 
Close()1344 bool TLSSocket::TLSSocketInternal::Close()
1345 {
1346     std::lock_guard<std::mutex> lock(mutexForSsl_);
1347     if (!ssl_) {
1348         NETSTACK_LOGE("ssl is null, fd =%{public}d", socketDescriptor_);
1349         return false;
1350     }
1351     int result = SSL_shutdown(ssl_);
1352     if (result < 0) {
1353         int resErr = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1354         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1355                       MakeSSLErrorString(resErr).c_str());
1356     }
1357     NETSTACK_LOGI("tls socket close, fd =%{public}d", socketDescriptor_);
1358     SSL_free(ssl_);
1359     ssl_ = nullptr;
1360     close(socketDescriptor_);
1361     socketDescriptor_ = -1;
1362     if (!tlsContextPointer_) {
1363         NETSTACK_LOGE("Tls context pointer is null");
1364         return false;
1365     }
1366     tlsContextPointer_->CloseCtx();
1367     return true;
1368 }
1369 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1370 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1371 {
1372     if (!ssl_) {
1373         NETSTACK_LOGE("ssl is null");
1374         return false;
1375     }
1376     size_t pos = 0;
1377     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1378                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1379     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1380     for (const auto &str : alpnProtocols) {
1381         len = str.length();
1382         result[pos++] = len;
1383         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1384             NETSTACK_LOGE("strcpy_s failed");
1385             return false;
1386         }
1387         pos += len;
1388     }
1389     result[pos] = '\0';
1390 
1391     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1392     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1393         int resErr = ConvertSSLError();
1394         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1395                       MakeSSLErrorString(resErr).c_str());
1396         return false;
1397     }
1398     return true;
1399 }
1400 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1401 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1402 {
1403     remoteInfo.SetFamily(family_);
1404     remoteInfo.SetAddress(hostName_);
1405     remoteInfo.SetPort(port_);
1406 }
1407 
GetTlsConfiguration() const1408 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1409 {
1410     return configuration_;
1411 }
1412 
GetCipherSuite() const1413 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1414 {
1415     if (!ssl_) {
1416         NETSTACK_LOGE("ssl in null");
1417         return {};
1418     }
1419     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1420     if (!sk) {
1421         NETSTACK_LOGE("get ciphers failed");
1422         return {};
1423     }
1424     CipherSuite cipherSuite;
1425     std::vector<std::string> cipherSuiteVec;
1426     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1427         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1428         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1429         cipherSuiteVec.push_back(cipherSuite.cipherName_);
1430     }
1431     return cipherSuiteVec;
1432 }
1433 
GetRemoteCertificate() const1434 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1435 {
1436     return remoteCert_;
1437 }
1438 
GetCertificate() const1439 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1440 {
1441     return configuration_.GetCertificate();
1442 }
1443 
GetSignatureAlgorithms() const1444 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1445 {
1446     return signatureAlgorithms_;
1447 }
1448 
GetProtocol() const1449 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1450 {
1451     if (!ssl_) {
1452         NETSTACK_LOGE("ssl in null");
1453         return PROTOCOL_UNKNOW;
1454     }
1455     if (configuration_.GetProtocol() == TLS_V1_3) {
1456         return PROTOCOL_TLS_V13;
1457     }
1458     return PROTOCOL_TLS_V12;
1459 }
1460 
SetSharedSigals()1461 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1462 {
1463     if (!ssl_) {
1464         NETSTACK_LOGE("ssl is null");
1465         return false;
1466     }
1467     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1468     if (!number) {
1469         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1470         return false;
1471     }
1472     for (int i = 0; i < number; i++) {
1473         int hash_nid;
1474         int sign_nid;
1475         std::string sig_with_md;
1476         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1477         switch (sign_nid) {
1478             case EVP_PKEY_RSA:
1479                 sig_with_md = SIGN_NID_RSA;
1480                 break;
1481             case EVP_PKEY_RSA_PSS:
1482                 sig_with_md = SIGN_NID_RSA_PSS;
1483                 break;
1484             case EVP_PKEY_DSA:
1485                 sig_with_md = SIGN_NID_DSA;
1486                 break;
1487             case EVP_PKEY_EC:
1488                 sig_with_md = SIGN_NID_ECDSA;
1489                 break;
1490             case NID_ED25519:
1491                 sig_with_md = SIGN_NID_ED;
1492                 break;
1493             case NID_ED448:
1494                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1495                 break;
1496             default:
1497                 const char *sn = OBJ_nid2sn(sign_nid);
1498                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1499         }
1500         const char *sn_hash = OBJ_nid2sn(hash_nid);
1501         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1502         signatureAlgorithms_.push_back(sig_with_md);
1503     }
1504     return true;
1505 }
1506 
StartTlsConnected(const TLSConnectOptions & options)1507 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1508 {
1509     if (!CreatTlsContext()) {
1510         NETSTACK_LOGE("failed to create tls context");
1511         return false;
1512     }
1513     if (!StartShakingHands(options)) {
1514         NETSTACK_LOGE("failed to shaking hands");
1515         return false;
1516     }
1517     return true;
1518 }
1519 
CreatTlsContext()1520 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1521 {
1522     tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1523     if (!tlsContextPointer_) {
1524         NETSTACK_LOGE("failed to create tls context pointer");
1525         return false;
1526     }
1527 
1528     std::lock_guard<std::mutex> lock(mutexForSsl_);
1529     if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1530         NETSTACK_LOGE("failed to create ssl session");
1531         return false;
1532     }
1533 
1534     SSL_set_fd(ssl_, socketDescriptor_);
1535     SSL_set_connect_state(ssl_);
1536     return true;
1537 }
1538 
/** @return true when @p s begins with @p prefix (an empty prefix always matches). */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only report a match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1543 
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1544 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1545                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1546 {
1547     bool valid = false;
1548     std::string reason = UNKNOW_REASON;
1549     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1550     if (IsIP(hostName)) {
1551         auto it = find(ips.begin(), ips.end(), hostName);
1552         if (it == ips.end()) {
1553             reason = IP + hostName + " is not in the cert's list";
1554         }
1555         result = {valid, reason};
1556         return;
1557     }
1558     std::string tempHostName = "" + hostName;
1559     if (!dnsNames.empty() || index > 0) {
1560         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1561         if (!dnsNames.empty()) {
1562             valid = SeekIntersection(hostParts, dnsNames);
1563             if (!valid) {
1564                 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1565             }
1566         } else {
1567             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1568             X509_NAME *pSubName = nullptr;
1569             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1570             if (len > 0) {
1571                 std::vector<std::string> commonNameVec;
1572                 commonNameVec.emplace_back(commonNameBuf);
1573                 valid = SeekIntersection(hostParts, commonNameVec);
1574                 if (!valid) {
1575                     reason = HOST_NAME + tempHostName + ". is not cert's CN";
1576                 }
1577             }
1578         }
1579         result = {valid, reason};
1580         return;
1581     }
1582     reason = "Cert does not contain a DNS name";
1583     result = {valid, reason};
1584 }
1585 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1586 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1587                                                                    const X509 *x509Certificates)
1588 {
1589     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1590     if (!subjectName) {
1591         return "subject name is null";
1592     }
1593     char subNameBuf[BUF_SIZE] = {0};
1594     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1595 
1596     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1597     if (index < 0) {
1598         return "X509 get ext nid error";
1599     }
1600     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1601     if (ext == nullptr) {
1602         return "X509 get ext error";
1603     }
1604     ASN1_OBJECT *obj = nullptr;
1605     obj = X509_EXTENSION_get_object(ext);
1606     char subAltNameBuf[BUF_SIZE] = {0};
1607     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1608 
1609     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1610 }
1611 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1612 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1613                                                                    const X509 *x509Certificates)
1614 {
1615     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1616     if (!extData) {
1617         NETSTACK_LOGE("extData is nullptr");
1618         return "";
1619     }
1620     std::string altNames = reinterpret_cast<char *>(extData->data);
1621     std::string hostname = " " + hostName;
1622     BIO *bio = BIO_new(BIO_s_file());
1623     if (!bio) {
1624         return "bio is null";
1625     }
1626     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1627     ASN1_STRING_print(bio, extData);
1628     std::vector<std::string> dnsNames = {};
1629     std::vector<std::string> ips = {};
1630     constexpr int DNS_NAME_IDX = 4;
1631     constexpr int IP_NAME_IDX = 11;
1632     if (!altNames.empty()) {
1633         std::vector<std::string> splitAltNames;
1634         if (altNames.find('\"') != std::string::npos) {
1635             splitAltNames = SplitEscapedAltNames(altNames);
1636         } else {
1637             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1638         }
1639         for (auto const &iter : splitAltNames) {
1640             if (StartsWith(iter, DNS)) {
1641                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1642             } else if (StartsWith(iter, IP_ADDRESS)) {
1643                 ips.push_back(iter.substr(IP_NAME_IDX));
1644             }
1645         }
1646     }
1647     std::tuple<bool, std::string> result;
1648     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1649     if (!std::get<0>(result)) {
1650         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1651     }
1652     return HOST_NAME + hostname + ". is cert's CN";
1653 }
1654 
LoadCaCertFromMemory(X509_STORE * store,const std::string & pemCerts)1655 static void LoadCaCertFromMemory(X509_STORE *store, const std::string &pemCerts)
1656 {
1657     if (!store || pemCerts.empty() || pemCerts.size() > static_cast<size_t>(INT_MAX)) {
1658         return;
1659     }
1660 
1661     auto cbio = BIO_new_mem_buf(pemCerts.data(), static_cast<int>(pemCerts.size()));
1662     if (!cbio) {
1663         return;
1664     }
1665 
1666     auto inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr);
1667     if (!inf) {
1668         BIO_free(cbio);
1669         return;
1670     }
1671 
1672     /* add each entry from PEM file to x509_store */
1673     for (int i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); ++i) {
1674         auto itmp = sk_X509_INFO_value(inf, i);
1675         if (!itmp) {
1676             continue;
1677         }
1678         if (itmp->x509) {
1679             X509_STORE_add_cert(store, itmp->x509);
1680         }
1681         if (itmp->crl) {
1682             X509_STORE_add_crl(store, itmp->crl);
1683         }
1684     }
1685 
1686     sk_X509_INFO_pop_free(inf, X509_INFO_free);
1687     BIO_free(cbio);
1688 }
1689 
X509_to_PEM(X509 * cert)1690 static std::string X509_to_PEM(X509 *cert)
1691 {
1692     if (!cert) {
1693         return {};
1694     }
1695     BIO *bio = BIO_new(BIO_s_mem());
1696     if (!bio) {
1697         return {};
1698     }
1699     if (!PEM_write_bio_X509(bio, cert)) {
1700         BIO_free(bio);
1701         return {};
1702     }
1703 
1704     char *data = nullptr;
1705     auto pemStringLength = BIO_get_mem_data(bio, &data);
1706     if (!data) {
1707         BIO_free(bio);
1708         return {};
1709     }
1710     std::string certificateInPEM(data, pemStringLength);
1711     BIO_free(bio);
1712     return certificateInPEM;
1713 }
1714 
CacheCertificates(const std::string & hostName,SSL * ssl)1715 static void CacheCertificates(const std::string &hostName, SSL *ssl)
1716 {
1717     if (!ssl || hostName.empty()) {
1718         return;
1719     }
1720     auto certificatesStack = SSL_get_peer_cert_chain(ssl);
1721     if (!certificatesStack) {
1722         return;
1723     }
1724     auto numCertificates = sk_X509_num(certificatesStack);
1725     for (auto i = 0; i < numCertificates; ++i) {
1726         auto cert = sk_X509_value(certificatesStack, i);
1727         auto certificateInPEM = X509_to_PEM(cert);
1728         if (!certificateInPEM.empty()) {
1729             CaCertCache::GetInstance().Set(hostName, certificateInPEM);
1730         }
1731     }
1732 }
1733 
LoadCachedCaCert(const std::string & hostName,SSL * ssl)1734 static void LoadCachedCaCert(const std::string &hostName, SSL *ssl)
1735 {
1736     if (!ssl) {
1737         return;
1738     }
1739     auto cachedPem = CaCertCache::GetInstance().Get(hostName);
1740     auto sslCtx = SSL_get_SSL_CTX(ssl);
1741     if (!sslCtx) {
1742         return;
1743     }
1744     auto x509Store = SSL_CTX_get_cert_store(sslCtx);
1745     if (!x509Store) {
1746         return;
1747     }
1748     for (const auto &pem : cachedPem) {
1749         LoadCaCertFromMemory(x509Store, pem);
1750     }
1751 }
1752 
StartShakingHands(const TLSConnectOptions & options)1753 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1754 {
1755     {
1756         std::lock_guard<std::mutex> lock(mutexForSsl_);
1757         if (!ssl_) {
1758             NETSTACK_LOGE("ssl is null");
1759             return false;
1760         }
1761 
1762         auto hostName = options.GetHostName();
1763         // indicates hostName is not ip address
1764         if (hostName != options.GetNetAddress().GetAddress()) {
1765             LoadCachedCaCert(hostName, ssl_);
1766         }
1767 
1768         int result = SSL_connect(ssl_);
1769         if (result == -1) {
1770             char err[MAX_ERR_LEN] = {0};
1771             auto code = ERR_get_error();
1772             ERR_error_string_n(code, err, MAX_ERR_LEN);
1773             int errorStatus = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1774             NETSTACK_LOGE("SSLConnect fail %{public}d, error: %{public}s errno: %{public}d ERR_get_error %{public}s",
1775                           errorStatus, MakeSSLErrorString(errorStatus).c_str(), errno, err);
1776             return false;
1777         }
1778 
1779         // indicates hostName is not ip address
1780         if (hostName != options.GetNetAddress().GetAddress()) {
1781             CacheCertificates(hostName, ssl_);
1782         }
1783 
1784         std::string list = SSL_get_cipher_list(ssl_, 0);
1785         NETSTACK_LOGI("cipher_list: %{public}s, Version: %{public}s, Cipher: %{public}s", list.c_str(),
1786                       SSL_get_version(ssl_), SSL_get_cipher(ssl_));
1787         configuration_.SetCipherSuite(list);
1788     }
1789     if (!SetSharedSigals()) {
1790         NETSTACK_LOGE("Failed to set sharedSigalgs");
1791     }
1792     if (!GetRemoteCertificateFromPeer()) {
1793         NETSTACK_LOGE("Failed to get remote certificate");
1794     }
1795     if (!peerX509_) {
1796         NETSTACK_LOGE("peer x509Certificates is null");
1797         return false;
1798     }
1799     if (!SetRemoteCertRawData()) {
1800         NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1801     }
1802     CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1803     if (!checkServerIdentity) {
1804         CheckServerIdentityLegal(hostName_, peerX509_);
1805     } else {
1806         checkServerIdentity(hostName_, {remoteCert_});
1807     }
1808     return true;
1809 }
1810 
GetRemoteCertificateFromPeer()1811 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1812 {
1813     peerX509_ = SSL_get_peer_certificate(ssl_);
1814     if (peerX509_ == nullptr) {
1815         int resErr = ConvertSSLError();
1816         NETSTACK_LOGE("open fail errno, errno is %{public}d %{public}d", resErr, errno);
1817         return false;
1818     }
1819     BIO *bio = BIO_new(BIO_s_mem());
1820     if (!bio) {
1821         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1822         return false;
1823     }
1824     X509_print(bio, peerX509_);
1825     char data[REMOTE_CERT_LEN] = {0};
1826     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1827         NETSTACK_LOGE("BIO_read function returns error");
1828         BIO_free(bio);
1829         return false;
1830     }
1831     BIO_free(bio);
1832     remoteCert_ = std::string(data);
1833     return true;
1834 }
1835 
SetRemoteCertRawData()1836 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1837 {
1838     if (peerX509_ == nullptr) {
1839         NETSTACK_LOGE("peerX509 is null");
1840         return false;
1841     }
1842     int32_t length = i2d_X509(peerX509_, nullptr);
1843     if (length <= 0) {
1844         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1845         return false;
1846     }
1847     unsigned char *der = nullptr;
1848     (void)i2d_X509(peerX509_, &der);
1849     SecureData data(der, length);
1850     remoteRawData_.data = data;
1851     OPENSSL_free(der);
1852     remoteRawData_.encodingFormat = DER;
1853     return true;
1854 }
1855 
GetRemoteCertRawData() const1856 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1857 {
1858     return remoteRawData_;
1859 }
1860 
GetSSL()1861 ssl_st *TLSSocket::TLSSocketInternal::GetSSL()
1862 {
1863     std::lock_guard<std::mutex> lock(mutexForSsl_);
1864     return ssl_;
1865 }
1866 } // namespace TlsSocket
1867 } // namespace NetStack
1868 } // namespace OHOS
1869