1 /*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "tls_socket.h"

#include <chrono>
#include <memory>
#include <numeric>
#include <regex>
#include <securec.h>
#include <set>
#include <thread>
#include <poll.h>
#include <unistd.h>

#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/ssl.h>

#include "base_context.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "tls.h"
#include "socket_exec_common.h"
36
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
constexpr int READ_TIMEOUT_MS = 500;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
constexpr int SSL_WANT_READ_RETURN = -2;
constexpr int OFFSET = 2;
constexpr int DEFAULT_BUFFER_SIZE = 8192;
constexpr int DEFAULT_POLL_TIMEOUT_MS = 500;
constexpr int SEND_RETRY_TIMES = 5;
constexpr int SEND_POLL_TIMEOUT_MS = 1000;
constexpr int MAX_RECV_BUFFER_SIZE = 1024 * 16;
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
static constexpr const char *TLS_SOCKET_CLIENT_READ = "OS_NET_TSCliRD";
// A JSON string literal anchored at the start of the input (use with regex_search). std::regex
// takes the bare pattern: JavaScript-style "/.../" delimiters would be literal slashes here and
// would put '^' after a '/', making the pattern impossible to match.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address with each octet in 0..255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
76
/**
 * Process-wide, mutex-protected map from a key to the set of certificate strings stored under it.
 */
class CaCertCache {
public:
    /** Returns the single shared instance, constructed on first use. */
    static CaCertCache &GetInstance()
    {
        static CaCertCache cache;
        return cache;
    }

    /**
     * Returns a copy of the set stored under @p key, or an empty set when nothing is stored.
     */
    std::set<std::string> Get(const std::string &key)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto found = map_.find(key);
        if (found == map_.end()) {
            return {};
        }
        return found->second;
    }

    /** Adds @p val to the set stored under @p key, creating that set when absent. */
    void Set(const std::string &key, const std::string &val)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        auto &bucket = map_[key];
        bucket.insert(val);
    }

private:
    CaCertCache() = default;
    ~CaCertCache() = default;
    CaCertCache(const CaCertCache &) = delete;
    CaCertCache &operator=(const CaCertCache &) = delete;

    std::map<std::string, std::set<std::string>> map_;
    std::mutex mutex_;
};
110
ConvertErrno()111 int ConvertErrno()
112 {
113 return TlsSocketError::TLS_ERR_SYS_BASE + errno;
114 }
115
// Returns the strerror() text for the current errno.
// NOTE(review): strerror() is not required to be thread-safe; strerror_r would be.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
120
MakeSSLErrorString(int error)121 std::string MakeSSLErrorString(int error)
122 {
123 char err[MAX_ERR_LEN] = {0};
124 ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
125 return err;
126 }
127
SplitEscapedAltNames(std::string & altNames)128 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
129 {
130 std::vector<std::string> result;
131 std::string currentToken;
132 size_t offset = 0;
133 while (offset != altNames.length()) {
134 auto nextSep = altNames.find_first_of(", ");
135 auto nextQuote = altNames.find_first_of('\"');
136 if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
137 currentToken += altNames.substr(offset, nextQuote);
138 std::regex jsonStringPattern(JSON_STRING_PATTERN);
139 std::smatch match;
140 std::string altNameSubStr = altNames.substr(nextQuote);
141 bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
142 if (!ret) {
143 return {""};
144 }
145 currentToken += result[0];
146 offset = nextQuote + result[0].length();
147 } else if (nextSep != std::string::npos) {
148 currentToken += altNames.substr(offset, nextSep);
149 result.push_back(currentToken);
150 currentToken = "";
151 offset = nextSep + OFFSET;
152 } else {
153 currentToken += altNames.substr(offset);
154 offset = altNames.length();
155 }
156 }
157 result.push_back(currentToken);
158 return result;
159 }
160
IsIP(const std::string & ip)161 bool IsIP(const std::string &ip)
162 {
163 std::regex pattern(PATTERN);
164 std::smatch res;
165 return regex_match(ip, res, pattern);
166 }
167
SplitHostName(std::string & hostName)168 std::vector<std::string> SplitHostName(std::string &hostName)
169 {
170 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
171 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
172 }
173
// Returns true when at least one string occurs in both @p vecA and @p vecB. Neither vector is
// modified. std::set_intersection needs both ranges sorted, which these vectors are not, so a
// lookup set built from vecA is used instead.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    const std::set<std::string> lookup(vecA.begin(), vecA.end());
    for (const auto &item : vecB) {
        if (lookup.find(item) != lookup.end()) {
            return true;
        }
    }
    return false;
}
180 } // namespace
181
SetSockBlockFlag(int sock,bool noneBlock)182 static bool SetSockBlockFlag(int sock, bool noneBlock)
183 {
184 int flags = fcntl(sock, F_GETFL, 0);
185 while (flags == -1 && errno == EINTR) {
186 flags = fcntl(sock, F_GETFL, 0);
187 }
188 if (flags == -1) {
189 NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
190 return false;
191 }
192
193 auto newFlags = static_cast<size_t>(flags);
194 if (noneBlock) {
195 newFlags |= static_cast<size_t>(O_NONBLOCK);
196 } else {
197 newFlags &= ~static_cast<size_t>(O_NONBLOCK);
198 }
199
200 int ret = fcntl(sock, F_SETFL, newFlags);
201 while (ret == -1 && errno == EINTR) {
202 ret = fcntl(sock, F_SETFL, newFlags);
203 }
204 if (ret == -1) {
205 NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
206 return false;
207 }
208 return true;
209 }
210
// Copy constructor: members start default-constructed and are then filled in by operator=.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
215
/**
 * Copies every option (key, CA chain, certificate, protocols, CRLs, key password, signature
 * algorithms, cipher suite, remote-cipher preference and verify mode) from @p tlsSecureOptions.
 *
 * @return *this.
 */
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    // Self-assignment is a no-op; the guard avoids relying on SecureData tolerating it.
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
231
SetCaChain(const std::vector<std::string> & caChain)232 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
233 {
234 caChain_ = caChain;
235 }
236
SetCert(const std::string & cert)237 void TLSSecureOptions::SetCert(const std::string &cert)
238 {
239 cert_ = cert;
240 }
241
SetKey(const SecureData & key)242 void TLSSecureOptions::SetKey(const SecureData &key)
243 {
244 key_ = key;
245 }
246
SetKeyPass(const SecureData & keyPass)247 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
248 {
249 keyPass_ = keyPass;
250 }
251
SetProtocolChain(const std::vector<std::string> & protocolChain)252 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
253 {
254 protocolChain_ = protocolChain;
255 }
256
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)257 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
258 {
259 useRemoteCipherPrefer_ = useRemoteCipherPrefer;
260 }
261
SetSignatureAlgorithms(const std::string & signatureAlgorithms)262 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
263 {
264 signatureAlgorithms_ = signatureAlgorithms;
265 }
266
SetCipherSuite(const std::string & cipherSuite)267 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
268 {
269 cipherSuite_ = cipherSuite;
270 }
271
SetCrlChain(const std::vector<std::string> & crlChain)272 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
273 {
274 crlChain_ = crlChain;
275 }
276
GetCaChain() const277 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
278 {
279 return caChain_;
280 }
281
GetCert() const282 const std::string &TLSSecureOptions::GetCert() const
283 {
284 return cert_;
285 }
286
GetKey() const287 const SecureData &TLSSecureOptions::GetKey() const
288 {
289 return key_;
290 }
291
GetKeyPass() const292 const SecureData &TLSSecureOptions::GetKeyPass() const
293 {
294 return keyPass_;
295 }
296
GetProtocolChain() const297 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
298 {
299 return protocolChain_;
300 }
301
UseRemoteCipherPrefer() const302 bool TLSSecureOptions::UseRemoteCipherPrefer() const
303 {
304 return useRemoteCipherPrefer_;
305 }
306
GetSignatureAlgorithms() const307 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
308 {
309 return signatureAlgorithms_;
310 }
311
GetCipherSuite() const312 const std::string &TLSSecureOptions::GetCipherSuite() const
313 {
314 return cipherSuite_;
315 }
316
GetCrlChain() const317 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
318 {
319 return crlChain_;
320 }
321
SetVerifyMode(VerifyMode verifyMode)322 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
323 {
324 TLSVerifyMode_ = verifyMode;
325 }
326
GetVerifyMode() const327 VerifyMode TLSSecureOptions::GetVerifyMode() const
328 {
329 return TLSVerifyMode_;
330 }
331
SetNetAddress(const Socket::NetAddress & address)332 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
333 {
334 address_.SetFamilyBySaFamily(address.GetSaFamily());
335 address_.SetRawAddress(address.GetAddress());
336 address_.SetPort(address.GetPort());
337 }
338
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)339 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
340 {
341 tlsSecureOptions_ = tlsSecureOptions;
342 }
343
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)344 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
345 {
346 checkServerIdentity_ = checkServerIdentity;
347 }
348
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)349 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
350 {
351 alpnProtocols_ = alpnProtocols;
352 }
353
SetSkipRemoteValidation(bool skipRemoteValidation)354 void TLSConnectOptions::SetSkipRemoteValidation(bool skipRemoteValidation)
355 {
356 skipRemoteValidation_ = skipRemoteValidation;
357 }
358
GetNetAddress() const359 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
360 {
361 return address_;
362 }
363
GetTlsSecureOptions() const364 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
365 {
366 return tlsSecureOptions_;
367 }
368
GetCheckServerIdentity() const369 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
370 {
371 return checkServerIdentity_;
372 }
373
GetAlpnProtocols() const374 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
375 {
376 return alpnProtocols_;
377 }
378
GetSkipRemoteValidation() const379 bool TLSConnectOptions::GetSkipRemoteValidation() const
380 {
381 return skipRemoteValidation_;
382 }
383
SetHostName(const std::string & hostName)384 void TLSConnectOptions::SetHostName(const std::string &hostName)
385 {
386 hostName_ = hostName;
387 }
388
GetHostName() const389 std::string TLSConnectOptions::GetHostName() const
390 {
391 return hostName_;
392 }
393
MakeAddressString(sockaddr * addr)394 std::string TLSSocket::MakeAddressString(sockaddr *addr)
395 {
396 if (!addr) {
397 return {};
398 }
399 if (addr->sa_family == AF_INET) {
400 auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
401 const char *str = inet_ntoa(addr4->sin_addr);
402 if (str == nullptr || strlen(str) == 0) {
403 return {};
404 }
405 return str;
406 } else if (addr->sa_family == AF_INET6) {
407 auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
408 char str[INET6_ADDRSTRLEN] = {0};
409 if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
410 return {};
411 }
412 return str;
413 }
414 return {};
415 }
416
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)417 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
418 socklen_t *len)
419 {
420 if (!addr6 || !addr4 || !len) {
421 return;
422 }
423 sa_family_t family = address.GetSaFamily();
424 if (family == AF_INET) {
425 addr4->sin_family = AF_INET;
426 addr4->sin_port = htons(address.GetPort());
427 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
428 *addr = reinterpret_cast<sockaddr *>(addr4);
429 *len = sizeof(sockaddr_in);
430 } else if (family == AF_INET6) {
431 addr6->sin6_family = AF_INET6;
432 addr6->sin6_port = htons(address.GetPort());
433 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
434 *addr = reinterpret_cast<sockaddr *>(addr6);
435 *len = sizeof(sockaddr_in6);
436 }
437 }
438
MakeIpSocket(sa_family_t family)439 void TLSSocket::MakeIpSocket(sa_family_t family)
440 {
441 if (family != AF_INET && family != AF_INET6) {
442 return;
443 }
444 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
445 if (sock < 0) {
446 int resErr = ConvertErrno();
447 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
448 CallOnErrorCallback(resErr, MakeErrnoString());
449 return;
450 }
451 sockFd_ = sock;
452 }
453
ReadMessage()454 int TLSSocket::ReadMessage()
455 {
456 char buffer[MAX_RECV_BUFFER_SIZE];
457 if (memset_s(buffer, MAX_RECV_BUFFER_SIZE, 0, MAX_RECV_BUFFER_SIZE) != EOK) {
458 NETSTACK_LOGE("memset_s failed!");
459 return -1;
460 }
461 nfds_t num = 1;
462 pollfd fds[1] = {{.fd = sockFd_, .events = POLLIN}};
463 int ret = poll(fds, num, READ_TIMEOUT_MS);
464 if (ret < 0) {
465 if (errno == EAGAIN || errno == EINTR) {
466 return 0;
467 }
468 int resErr = ConvertErrno();
469 NETSTACK_LOGE("Message poll errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
470 CallOnErrorCallback(resErr, MakeErrnoString());
471 return ret;
472 } else if (ret == 0) {
473 NETSTACK_LOGD("tls recv poll timeout");
474 return ret;
475 }
476
477 std::lock_guard<std::mutex> lock(recvMutex_);
478 if (!isRunning_) {
479 return -1;
480 }
481 int len = tlsSocketInternal_.Recv(buffer, MAX_RECV_BUFFER_SIZE);
482 if (len < 0) {
483 if (errno == EAGAIN || errno == EINTR || len == SSL_WANT_READ_RETURN) {
484 return 0;
485 }
486 int resErr = tlsSocketInternal_.ConvertSSLError();
487 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s",
488 resErr, MakeSSLErrorString(resErr).c_str());
489 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
490 return len;
491 } else if (len == 0) {
492 NETSTACK_LOGI("Message recv len 0, session is closed by peer");
493 CallOnCloseCallback();
494 return -1;
495 }
496 Socket::SocketRemoteInfo remoteInfo;
497 remoteInfo.SetSize(len);
498 tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
499 std::string bufContent(buffer, len);
500 CallOnMessageCallback(bufContent, remoteInfo);
501
502 return ret;
503 }
504
StartReadMessage()505 void TLSSocket::StartReadMessage()
506 {
507 std::thread thread([this]() {
508 isRunning_ = true;
509 isRunOver_ = false;
510 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
511 pthread_setname_np(TLS_SOCKET_CLIENT_READ);
512 #else
513 pthread_setname_np(pthread_self(), TLS_SOCKET_CLIENT_READ);
514 #endif
515 while (isRunning_) {
516 int ret = ReadMessage();
517 if (ret < 0) {
518 break;
519 }
520 }
521 isRunOver_ = true;
522 cvSslFree_.notify_one();
523 });
524 thread.detach();
525 }
526
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)527 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
528 {
529 OnMessageCallback func = nullptr;
530 {
531 std::lock_guard<std::mutex> lock(mutex_);
532 if (onMessageCallback_) {
533 func = onMessageCallback_;
534 }
535 }
536
537 if (func) {
538 func(data, remoteInfo);
539 }
540 }
541
CallOnConnectCallback()542 void TLSSocket::CallOnConnectCallback()
543 {
544 OnConnectCallback func = nullptr;
545 {
546 std::lock_guard<std::mutex> lock(mutex_);
547 if (onConnectCallback_) {
548 func = onConnectCallback_;
549 }
550 }
551
552 if (func) {
553 func();
554 }
555 }
556
CallOnCloseCallback()557 void TLSSocket::CallOnCloseCallback()
558 {
559 OnCloseCallback func = nullptr;
560 {
561 std::lock_guard<std::mutex> lock(mutex_);
562 if (onCloseCallback_) {
563 func = onCloseCallback_;
564 }
565 }
566
567 if (func) {
568 func();
569 }
570 }
571
CallOnErrorCallback(int32_t err,const std::string & errString)572 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
573 {
574 OnErrorCallback func = nullptr;
575 {
576 std::lock_guard<std::mutex> lock(mutex_);
577 if (onErrorCallback_) {
578 func = onErrorCallback_;
579 }
580 }
581
582 if (func) {
583 func(err, errString);
584 }
585 }
586
CallBindCallback(int32_t err,BindCallback callback)587 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
588 {
589 DealCallback<BindCallback>(err, callback);
590 }
591
CallConnectCallback(int32_t err,ConnectCallback callback)592 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
593 {
594 DealCallback<ConnectCallback>(err, callback);
595 }
596
CallSendCallback(int32_t err,SendCallback callback)597 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
598 {
599 DealCallback<SendCallback>(err, callback);
600 }
601
CallCloseCallback(int32_t err,CloseCallback callback)602 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
603 {
604 DealCallback<CloseCallback>(err, callback);
605 }
606
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)607 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
608 GetRemoteAddressCallback callback)
609 {
610 GetRemoteAddressCallback func = nullptr;
611 {
612 std::lock_guard<std::mutex> lock(mutex_);
613 if (callback) {
614 func = callback;
615 }
616 }
617
618 if (func) {
619 func(err, address);
620 }
621 }
622
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)623 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
624 {
625 GetStateCallback func = nullptr;
626 {
627 std::lock_guard<std::mutex> lock(mutex_);
628 if (callback) {
629 func = callback;
630 }
631 }
632
633 if (func) {
634 func(err, state);
635 }
636 }
637
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)638 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
639 {
640 DealCallback<SetExtraOptionsCallback>(err, callback);
641 }
642
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)643 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
644 {
645 GetCertificateCallback func = nullptr;
646 {
647 std::lock_guard<std::mutex> lock(mutex_);
648 if (callback) {
649 func = callback;
650 }
651 }
652
653 if (func) {
654 func(err, cert);
655 }
656 }
657
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)658 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
659 GetRemoteCertificateCallback callback)
660 {
661 GetRemoteCertificateCallback func = nullptr;
662 {
663 std::lock_guard<std::mutex> lock(mutex_);
664 if (callback) {
665 func = callback;
666 }
667 }
668
669 if (func) {
670 func(err, cert);
671 }
672 }
673
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)674 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
675 {
676 GetProtocolCallback func = nullptr;
677 {
678 std::lock_guard<std::mutex> lock(mutex_);
679 if (callback) {
680 func = callback;
681 }
682 }
683
684 if (func) {
685 func(err, protocol);
686 }
687 }
688
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)689 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
690 GetCipherSuiteCallback callback)
691 {
692 GetCipherSuiteCallback func = nullptr;
693 {
694 std::lock_guard<std::mutex> lock(mutex_);
695 if (callback) {
696 func = callback;
697 }
698 }
699
700 if (func) {
701 func(err, suite);
702 }
703 }
704
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)705 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
706 GetSignatureAlgorithmsCallback callback)
707 {
708 GetSignatureAlgorithmsCallback func = nullptr;
709 {
710 std::lock_guard<std::mutex> lock(mutex_);
711 if (callback) {
712 func = callback;
713 }
714 }
715
716 if (func) {
717 func(err, algorithms);
718 }
719 }
720
Bind(Socket::NetAddress & address,const BindCallback & callback)721 void TLSSocket::Bind(Socket::NetAddress &address, const BindCallback &callback)
722 {
723 static constexpr int32_t PARSE_ERROR_CODE = 401;
724 if (!CommonUtils::HasInternetPermission()) {
725 CallBindCallback(PERMISSION_DENIED_CODE, callback);
726 return;
727 }
728 if (sockFd_ >= 0) {
729 CallBindCallback(TLSSOCKET_SUCCESS, callback);
730 return;
731 }
732
733 MakeIpSocket(address.GetSaFamily());
734 if (sockFd_ < 0) {
735 int resErr = ConvertErrno();
736 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
737 CallOnErrorCallback(resErr, MakeErrnoString());
738 CallBindCallback(resErr, callback);
739 return;
740 }
741
742 auto temp = address.GetAddress();
743 address.SetRawAddress("");
744 address.SetAddress(temp);
745 if (address.GetAddress().empty()) {
746 CallBindCallback(PARSE_ERROR_CODE, callback);
747 return;
748 }
749
750 sockaddr_in addr4 = {0};
751 sockaddr_in6 addr6 = {0};
752 sockaddr *addr = nullptr;
753 socklen_t len;
754 GetAddr(address, &addr4, &addr6, &addr, &len);
755 if (addr == nullptr) {
756 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
757 CallOnErrorCallback(-1, "Address Is Invalid");
758 CallBindCallback(ConvertErrno(), callback);
759 return;
760 }
761 CallBindCallback(TLSSOCKET_SUCCESS, callback);
762 }
763
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)764 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
765 const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
766 {
767 if (sockFd_ < 0) {
768 int resErr = ConvertErrno();
769 NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
770 CallOnErrorCallback(resErr, MakeErrnoString());
771 callback(resErr);
772 return;
773 }
774
775 if (isExtSock_ && !SetSockBlockFlag(sockFd_, false)) {
776 int resErr = ConvertErrno();
777 NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
778 CallOnErrorCallback(resErr, MakeErrnoString());
779 callback(resErr);
780 return;
781 }
782
783 auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions, isExtSock_);
784 if (!res) {
785 int resErr = tlsSocketInternal_.ConvertSSLError();
786 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
787 callback(resErr);
788 return;
789 }
790 if (!SetSockBlockFlag(sockFd_, true)) {
791 int resErr = ConvertErrno();
792 NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
793 CallOnErrorCallback(resErr, MakeErrnoString());
794 callback(resErr);
795 return;
796 }
797 StartReadMessage();
798 CallOnConnectCallback();
799 callback(TLSSOCKET_SUCCESS);
800 }
801
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)802 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
803 {
804 (void)tcpSendOptions;
805
806 auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
807 if (!res) {
808 int resErr = tlsSocketInternal_.ConvertSSLError();
809 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
810 CallSendCallback(resErr, callback);
811 return;
812 }
813 CallSendCallback(TLSSOCKET_SUCCESS, callback);
814 }
815
Close(const CloseCallback & callback)816 void TLSSocket::Close(const CloseCallback &callback)
817 {
818 isRunning_ = false;
819 std::unique_lock<std::mutex> cvLock(cvMutex_);
820 cvSslFree_.wait(cvLock, [this]() -> bool { return isRunOver_; });
821
822 std::lock_guard<std::mutex> lock(recvMutex_);
823 auto res = tlsSocketInternal_.Close();
824 if (!res) {
825 int resErr = tlsSocketInternal_.ConvertSSLError();
826 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
827 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
828 callback(resErr);
829 return;
830 }
831 sockFd_ = -1;
832 CallOnCloseCallback();
833 callback(TLSSOCKET_SUCCESS);
834 }
835
GetRemoteAddress(const GetRemoteAddressCallback & callback)836 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
837 {
838 sockaddr sockAddr = {0};
839 socklen_t len = sizeof(sockaddr);
840 int ret = getsockname(sockFd_, &sockAddr, &len);
841 if (ret < 0) {
842 int resErr = ConvertErrno();
843 NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
844 CallOnErrorCallback(resErr, MakeErrnoString());
845 CallGetRemoteAddressCallback(resErr, {}, callback);
846 return;
847 }
848
849 if (sockAddr.sa_family == AF_INET) {
850 GetIp4RemoteAddress(callback);
851 } else if (sockAddr.sa_family == AF_INET6) {
852 GetIp6RemoteAddress(callback);
853 }
854 }
855
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)856 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
857 {
858 sockaddr_in addr4 = {0};
859 socklen_t len4 = sizeof(sockaddr_in);
860
861 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
862 if (ret < 0) {
863 int resErr = ConvertErrno();
864 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
865 CallOnErrorCallback(resErr, MakeErrnoString());
866 CallGetRemoteAddressCallback(resErr, {}, callback);
867 return;
868 }
869
870 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
871 if (address.empty()) {
872 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
873 CallOnErrorCallback(-1, "Address is invalid");
874 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
875 return;
876 }
877 Socket::NetAddress netAddress;
878 netAddress.SetFamilyBySaFamily(AF_INET);
879 netAddress.SetRawAddress(address);
880 netAddress.SetPort(ntohs(addr4.sin_port));
881 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
882 }
883
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)884 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
885 {
886 sockaddr_in6 addr6 = {0};
887 socklen_t len6 = sizeof(sockaddr_in6);
888
889 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
890 if (ret < 0) {
891 int resErr = ConvertErrno();
892 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
893 CallOnErrorCallback(resErr, MakeErrnoString());
894 CallGetRemoteAddressCallback(resErr, {}, callback);
895 return;
896 }
897
898 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
899 if (address.empty()) {
900 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
901 CallOnErrorCallback(-1, "Address is invalid");
902 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
903 return;
904 }
905 Socket::NetAddress netAddress;
906 netAddress.SetFamilyBySaFamily(AF_INET6);
907 netAddress.SetRawAddress(address);
908 netAddress.SetPort(ntohs(addr6.sin6_port));
909 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
910 }
911
GetState(const GetStateCallback & callback)912 void TLSSocket::GetState(const GetStateCallback &callback)
913 {
914 int opt;
915 socklen_t optLen = sizeof(int);
916 int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
917 if (r < 0) {
918 Socket::SocketStateBase state;
919 state.SetIsClose(true);
920 CallGetStateCallback(ConvertErrno(), state, callback);
921 return;
922 }
923 sockaddr sockAddr = {0};
924 socklen_t len = sizeof(sockaddr);
925 Socket::SocketStateBase state;
926 int ret = getsockname(sockFd_, &sockAddr, &len);
927 state.SetIsBound(ret == 0);
928 ret = getpeername(sockFd_, &sockAddr, &len);
929 state.SetIsConnected(ret == 0);
930 CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
931 }
932
SetBaseOptions(const Socket::ExtraOptionsBase & option) const933 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
934 {
935 if (option.GetReceiveBufferSize() != 0) {
936 int size = (int)option.GetReceiveBufferSize();
937 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
938 return false;
939 }
940 }
941
942 if (option.GetSendBufferSize() != 0) {
943 int size = (int)option.GetSendBufferSize();
944 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
945 return false;
946 }
947 }
948
949 if (option.IsReuseAddress()) {
950 int reuse = 1;
951 if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
952 return false;
953 }
954 }
955
956 if (option.GetSocketTimeout() != 0) {
957 timeval timeout = {(int)option.GetSocketTimeout(), 0};
958 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
959 return false;
960 }
961 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
962 return false;
963 }
964 }
965
966 return true;
967 }
968
SetExtraOptions(const Socket::TCPExtraOptions & option) const969 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
970 {
971 if (option.IsKeepAlive()) {
972 int keepalive = 1;
973 if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
974 return false;
975 }
976 }
977
978 if (option.IsOOBInline()) {
979 int oobInline = 1;
980 if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
981 return false;
982 }
983 }
984
985 if (option.IsTCPNoDelay()) {
986 int tcpNoDelay = 1;
987 if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
988 return false;
989 }
990 }
991
992 linger soLinger = {0};
993 soLinger.l_onoff = option.socketLinger.IsOn();
994 soLinger.l_linger = (int)option.socketLinger.GetLinger();
995 if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
996 return false;
997 }
998
999 return true;
1000 }
1001
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)1002 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
1003 const SetExtraOptionsCallback &callback)
1004 {
1005 if (!SetBaseOptions(tcpExtraOptions)) {
1006 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1007 CallOnErrorCallback(errno, MakeErrnoString());
1008 CallSetExtraOptionsCallback(ConvertErrno(), callback);
1009 return;
1010 }
1011
1012 if (!SetExtraOptions(tcpExtraOptions)) {
1013 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1014 CallOnErrorCallback(errno, MakeErrnoString());
1015 CallSetExtraOptionsCallback(ConvertErrno(), callback);
1016 return;
1017 }
1018
1019 CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
1020 }
1021
GetCertificate(const GetCertificateCallback & callback)1022 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
1023 {
1024 const auto &cert = tlsSocketInternal_.GetCertificate();
1025 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
1026
1027 if (!cert.data.Length()) {
1028 int resErr = tlsSocketInternal_.ConvertSSLError();
1029 NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1030 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1031 callback(resErr, {});
1032 return;
1033 }
1034 callback(TLSSOCKET_SUCCESS, cert);
1035 }
1036
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)1037 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
1038 {
1039 const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
1040 if (!remoteCert.data.Length()) {
1041 int resErr = tlsSocketInternal_.ConvertSSLError();
1042 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1043 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1044 callback(resErr, {});
1045 return;
1046 }
1047 callback(TLSSOCKET_SUCCESS, remoteCert);
1048 }
1049
GetProtocol(const GetProtocolCallback & callback)1050 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
1051 {
1052 const auto &protocol = tlsSocketInternal_.GetProtocol();
1053 if (protocol.empty()) {
1054 NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
1055 int resErr = tlsSocketInternal_.ConvertSSLError();
1056 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1057 callback(resErr, "");
1058 return;
1059 }
1060 callback(TLSSOCKET_SUCCESS, protocol);
1061 }
1062
GetCipherSuite(const GetCipherSuiteCallback & callback)1063 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
1064 {
1065 const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
1066 if (cipherSuite.empty()) {
1067 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
1068 int resErr = tlsSocketInternal_.ConvertSSLError();
1069 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1070 callback(resErr, cipherSuite);
1071 return;
1072 }
1073 callback(TLSSOCKET_SUCCESS, cipherSuite);
1074 }
1075
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)1076 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
1077 {
1078 const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
1079 if (signatureAlgorithms.empty()) {
1080 NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
1081 int resErr = tlsSocketInternal_.ConvertSSLError();
1082 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1083 callback(resErr, {});
1084 return;
1085 }
1086 callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
1087 }
1088
OnMessage(const OnMessageCallback & onMessageCallback)1089 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
1090 {
1091 std::lock_guard<std::mutex> lock(mutex_);
1092 onMessageCallback_ = onMessageCallback;
1093 }
1094
OffMessage()1095 void TLSSocket::OffMessage()
1096 {
1097 std::lock_guard<std::mutex> lock(mutex_);
1098 if (onMessageCallback_) {
1099 onMessageCallback_ = nullptr;
1100 }
1101 }
1102
OnConnect(const OnConnectCallback & onConnectCallback)1103 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1104 {
1105 std::lock_guard<std::mutex> lock(mutex_);
1106 onConnectCallback_ = onConnectCallback;
1107 }
1108
OffConnect()1109 void TLSSocket::OffConnect()
1110 {
1111 std::lock_guard<std::mutex> lock(mutex_);
1112 if (onConnectCallback_) {
1113 onConnectCallback_ = nullptr;
1114 }
1115 }
1116
OnClose(const OnCloseCallback & onCloseCallback)1117 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1118 {
1119 std::lock_guard<std::mutex> lock(mutex_);
1120 onCloseCallback_ = onCloseCallback;
1121 }
1122
OffClose()1123 void TLSSocket::OffClose()
1124 {
1125 std::lock_guard<std::mutex> lock(mutex_);
1126 if (onCloseCallback_) {
1127 onCloseCallback_ = nullptr;
1128 }
1129 }
1130
OnError(const OnErrorCallback & onErrorCallback)1131 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1132 {
1133 std::lock_guard<std::mutex> lock(mutex_);
1134 onErrorCallback_ = onErrorCallback;
1135 }
1136
OffError()1137 void TLSSocket::OffError()
1138 {
1139 std::lock_guard<std::mutex> lock(mutex_);
1140 if (onErrorCallback_) {
1141 onErrorCallback_ = nullptr;
1142 }
1143 }
1144
GetSocketFd()1145 int TLSSocket::GetSocketFd()
1146 {
1147 return sockFd_;
1148 }
1149
SetLocalAddress(const Socket::NetAddress & address)1150 void TLSSocket::SetLocalAddress(const Socket::NetAddress &address)
1151 {
1152 localAddress_ = address;
1153 }
1154
GetLocalAddress()1155 Socket::NetAddress TLSSocket::GetLocalAddress()
1156 {
1157 return localAddress_;
1158 }
1159
GetCloseState()1160 bool TLSSocket::GetCloseState()
1161 {
1162 return isClosed;
1163 }
1164
SetCloseState(bool flag)1165 void TLSSocket::SetCloseState(bool flag)
1166 {
1167 isClosed = flag;
1168 }
1169
GetCloseLock()1170 std::mutex &TLSSocket::GetCloseLock()
1171 {
1172 return mutexForClose_;
1173 }
1174
ExecSocketConnect(const std::string & host,int port,sa_family_t family,int socketDescriptor)1175 bool ExecSocketConnect(const std::string &host, int port, sa_family_t family, int socketDescriptor)
1176 {
1177 auto hostName = ConvertAddressToIp(host, family);
1178 struct sockaddr_in dest = {0};
1179 dest.sin_family = family;
1180 dest.sin_port = htons(port);
1181
1182 sockaddr_in addr4 = {0};
1183 sockaddr_in6 addr6 = {0};
1184 sockaddr *addr = nullptr;
1185 socklen_t len = 0;
1186 if (family == AF_INET) {
1187 if (inet_pton(AF_INET, hostName.c_str(), &addr4.sin_addr.s_addr) <= 0) {
1188 return false;
1189 }
1190 addr4.sin_family = family;
1191 addr4.sin_port = htons(port);
1192 addr = reinterpret_cast<sockaddr *>(&addr4);
1193 len = sizeof(sockaddr_in);
1194 } else {
1195 if (inet_pton(AF_INET6, hostName.c_str(), &addr6.sin6_addr) <= 0) {
1196 return false;
1197 }
1198 addr6.sin6_family = family;
1199 addr6.sin6_port = htons(port);
1200 addr = reinterpret_cast<sockaddr *>(&addr6);
1201 len = sizeof(sockaddr_in6);
1202 }
1203
1204 int connectResult = connect(socketDescriptor, addr, len);
1205 if (connectResult == -1) {
1206 NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1207 strerror(errno));
1208 return false;
1209 }
1210 return true;
1211 }
1212
ConvertSSLError(void)1213 int TLSSocket::TLSSocketInternal::ConvertSSLError(void)
1214 {
1215 std::lock_guard<std::mutex> lock(mutexForSsl_);
1216 if (!ssl_) {
1217 return TLS_ERR_SSL_NULL;
1218 }
1219 return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1220 }
1221
TlsConnectToHost(int sock,const TLSConnectOptions & options,bool isExtSock)1222 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock)
1223 {
1224 SetTlsConfiguration(options);
1225 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1226 if (!cipherSuite.empty()) {
1227 configuration_.SetCipherSuite(cipherSuite);
1228 }
1229 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1230 if (!signatureAlgorithms.empty()) {
1231 configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1232 }
1233 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1234 if (!protocolVec.empty()) {
1235 configuration_.SetProtocol(protocolVec);
1236 }
1237 configuration_.SetSkipFlag(options.GetSkipRemoteValidation());
1238 hostName_ = options.GetNetAddress().GetAddress();
1239 port_ = options.GetNetAddress().GetPort();
1240 family_ = options.GetNetAddress().GetSaFamily();
1241 socketDescriptor_ = sock;
1242 if (!isExtSock && !ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1243 options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1244 return false;
1245 }
1246 return StartTlsConnected(options);
1247 }
1248
SetTlsConfiguration(const TLSConnectOptions & config)1249 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1250 {
1251 configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1252 configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1253 configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1254 configuration_.SetNetAddress(config.GetNetAddress());
1255 }
1256
SendRetry(ssl_st * ssl,const char * curPos,size_t curSendSize,int sockfd)1257 bool TLSSocket::TLSSocketInternal::SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd)
1258 {
1259 pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1260 for (int i = 0; i <= SEND_RETRY_TIMES; i++) {
1261 int ret = poll(fds, 1, SEND_POLL_TIMEOUT_MS);
1262 if (ret < 0) {
1263 if (errno == EAGAIN || errno == EINTR) {
1264 continue;
1265 }
1266 NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1267 return false;
1268 } else if (ret == 0) {
1269 NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1270 continue;
1271 }
1272 int len = SSL_write(ssl, curPos, curSendSize);
1273 if (len < 0) {
1274 int err = SSL_get_error(ssl, SSL_RET_CODE);
1275 if (err == SSL_ERROR_WANT_WRITE || errno == EAGAIN) {
1276 NETSTACK_LOGI("write retry times: %{public}d err: %{public}d errno: %{public}d", i, err, errno);
1277 continue;
1278 } else {
1279 NETSTACK_LOGE("write failed err: %{public}d errno: %{public}d", err, errno);
1280 return false;
1281 }
1282 } else if (len == 0) {
1283 NETSTACK_LOGI("send len is 0, should have sent len");
1284 return false;
1285 } else {
1286 return true;
1287 }
1288 }
1289 return false;
1290 }
1291
PollSend(int sockfd,ssl_st * ssl,const char * pdata,int sendSize)1292 bool TLSSocket::TLSSocketInternal::PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize)
1293 {
1294 int bufferSize = DEFAULT_BUFFER_SIZE;
1295 auto curPos = pdata;
1296 nfds_t num = 1;
1297 pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1298 while (sendSize > 0) {
1299 int ret = poll(fds, num, DEFAULT_POLL_TIMEOUT_MS);
1300 if (ret < 0) {
1301 if (errno == EAGAIN || errno == EINTR) {
1302 continue;
1303 }
1304 NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1305 return false;
1306 } else if (ret == 0) {
1307 NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1308 continue;
1309 }
1310 std::lock_guard<std::mutex> lock(mutexForSsl_);
1311 if (!ssl) {
1312 NETSTACK_LOGE("ssl is null");
1313 return false;
1314 }
1315 size_t curSendSize = std::min<size_t>(sendSize, bufferSize);
1316 int len = SSL_write(ssl, curPos, curSendSize);
1317 if (len < 0) {
1318 int err = SSL_get_error(ssl, SSL_RET_CODE);
1319 if (err != SSL_ERROR_WANT_WRITE || errno != EAGAIN) {
1320 NETSTACK_LOGE("write failed, return, err: %{public}d errno: %{public}d", err, errno);
1321 return false;
1322 } else if (!SendRetry(ssl, curPos, curSendSize, sockfd)) {
1323 return false;
1324 }
1325 } else if (len == 0) {
1326 NETSTACK_LOGI("send len is 0, should have sent len is %{public}d", sendSize);
1327 return false;
1328 }
1329 curPos += len;
1330 sendSize -= len;
1331 }
1332 return true;
1333 }
1334
Send(const std::string & data)1335 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1336 {
1337 {
1338 std::lock_guard<std::mutex> lock(mutexForSsl_);
1339 if (!ssl_) {
1340 NETSTACK_LOGE("ssl is null");
1341 return false;
1342 }
1343 }
1344
1345 if (data.empty()) {
1346 NETSTACK_LOGE("data is empty");
1347 return true;
1348 }
1349
1350 if (!PollSend(socketDescriptor_, ssl_, data.c_str(), data.size())) {
1351 return false;
1352 }
1353 return true;
1354 }
Recv(char * buffer,int maxBufferSize)1355 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1356 {
1357 if (!ssl_) {
1358 NETSTACK_LOGE("ssl is null");
1359 return SSL_ERROR_RETURN;
1360 }
1361
1362 int ret = SSL_read(ssl_, buffer, maxBufferSize);
1363 if (ret < 0) {
1364 int err = SSL_get_error(ssl_, SSL_RET_CODE);
1365 switch (err) {
1366 case SSL_ERROR_SSL:
1367 NETSTACK_LOGE("An error occurred in the SSL library");
1368 return SSL_ERROR_RETURN;
1369 case SSL_ERROR_ZERO_RETURN:
1370 NETSTACK_LOGE("peer disconnected...");
1371 return SSL_ERROR_RETURN;
1372 case SSL_ERROR_WANT_READ:
1373 NETSTACK_LOGD("SSL_read function no data available for reading, try again at a later time");
1374 return SSL_WANT_READ_RETURN;
1375 default:
1376 NETSTACK_LOGE("SSL_read function failed, error code is %{public}d", err);
1377 return SSL_ERROR_RETURN;
1378 }
1379 }
1380 return ret;
1381 }
1382
Close()1383 bool TLSSocket::TLSSocketInternal::Close()
1384 {
1385 std::lock_guard<std::mutex> lock(mutexForSsl_);
1386 if (!ssl_) {
1387 NETSTACK_LOGE("ssl is null, fd =%{public}d", socketDescriptor_);
1388 return false;
1389 }
1390 int result = SSL_shutdown(ssl_);
1391 if (result < 0) {
1392 int resErr = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1393 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1394 MakeSSLErrorString(resErr).c_str());
1395 }
1396 NETSTACK_LOGI("tls socket close, fd =%{public}d", socketDescriptor_);
1397 SSL_free(ssl_);
1398 ssl_ = nullptr;
1399 close(socketDescriptor_);
1400 socketDescriptor_ = -1;
1401 if (!tlsContextPointer_) {
1402 NETSTACK_LOGE("Tls context pointer is null");
1403 return false;
1404 }
1405 tlsContextPointer_->CloseCtx();
1406 return true;
1407 }
1408
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1409 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1410 {
1411 if (!ssl_) {
1412 NETSTACK_LOGE("ssl is null");
1413 return false;
1414 }
1415 size_t pos = 0;
1416 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1417 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1418 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1419 for (const auto &str : alpnProtocols) {
1420 len = str.length();
1421 result[pos++] = len;
1422 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1423 NETSTACK_LOGE("strcpy_s failed");
1424 return false;
1425 }
1426 pos += len;
1427 }
1428 result[pos] = '\0';
1429
1430 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1431 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1432 int resErr = ConvertSSLError();
1433 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1434 MakeSSLErrorString(resErr).c_str());
1435 return false;
1436 }
1437 return true;
1438 }
1439
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1440 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1441 {
1442 remoteInfo.SetFamily(family_);
1443 remoteInfo.SetAddress(hostName_);
1444 remoteInfo.SetPort(port_);
1445 }
1446
GetTlsConfiguration() const1447 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1448 {
1449 return configuration_;
1450 }
1451
GetCipherSuite() const1452 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1453 {
1454 if (!ssl_) {
1455 NETSTACK_LOGE("ssl in null");
1456 return {};
1457 }
1458 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1459 if (!sk) {
1460 NETSTACK_LOGE("get ciphers failed");
1461 return {};
1462 }
1463 CipherSuite cipherSuite;
1464 std::vector<std::string> cipherSuiteVec;
1465 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1466 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1467 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1468 cipherSuiteVec.push_back(cipherSuite.cipherName_);
1469 }
1470 return cipherSuiteVec;
1471 }
1472
GetRemoteCertificate() const1473 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1474 {
1475 return remoteCert_;
1476 }
1477
GetCertificate() const1478 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1479 {
1480 return configuration_.GetCertificate();
1481 }
1482
GetSignatureAlgorithms() const1483 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1484 {
1485 return signatureAlgorithms_;
1486 }
1487
GetProtocol() const1488 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1489 {
1490 if (!ssl_) {
1491 NETSTACK_LOGE("ssl in null");
1492 return PROTOCOL_UNKNOW;
1493 }
1494 if (configuration_.GetProtocol() == TLS_V1_3) {
1495 return PROTOCOL_TLS_V13;
1496 }
1497 return PROTOCOL_TLS_V12;
1498 }
1499
SetSharedSigals()1500 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1501 {
1502 if (!ssl_) {
1503 NETSTACK_LOGE("ssl is null");
1504 return false;
1505 }
1506 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1507 if (!number) {
1508 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1509 return false;
1510 }
1511 for (int i = 0; i < number; i++) {
1512 int hash_nid;
1513 int sign_nid;
1514 std::string sig_with_md;
1515 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1516 switch (sign_nid) {
1517 case EVP_PKEY_RSA:
1518 sig_with_md = SIGN_NID_RSA;
1519 break;
1520 case EVP_PKEY_RSA_PSS:
1521 sig_with_md = SIGN_NID_RSA_PSS;
1522 break;
1523 case EVP_PKEY_DSA:
1524 sig_with_md = SIGN_NID_DSA;
1525 break;
1526 case EVP_PKEY_EC:
1527 sig_with_md = SIGN_NID_ECDSA;
1528 break;
1529 case NID_ED25519:
1530 sig_with_md = SIGN_NID_ED;
1531 break;
1532 case NID_ED448:
1533 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1534 break;
1535 default:
1536 const char *sn = OBJ_nid2sn(sign_nid);
1537 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1538 }
1539 const char *sn_hash = OBJ_nid2sn(hash_nid);
1540 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1541 signatureAlgorithms_.push_back(sig_with_md);
1542 }
1543 return true;
1544 }
1545
StartTlsConnected(const TLSConnectOptions & options)1546 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1547 {
1548 if (!CreatTlsContext()) {
1549 NETSTACK_LOGE("failed to create tls context");
1550 return false;
1551 }
1552 if (!StartShakingHands(options)) {
1553 NETSTACK_LOGE("failed to shaking hands");
1554 return false;
1555 }
1556 return true;
1557 }
1558
CreatTlsContext()1559 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1560 {
1561 tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1562 if (!tlsContextPointer_) {
1563 NETSTACK_LOGE("failed to create tls context pointer");
1564 return false;
1565 }
1566
1567 std::lock_guard<std::mutex> lock(mutexForSsl_);
1568 if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1569 NETSTACK_LOGE("failed to create ssl session");
1570 return false;
1571 }
1572
1573 SSL_set_fd(ssl_, socketDescriptor_);
1574 SSL_set_connect_state(ssl_);
1575 return true;
1576 }
1577
/** Returns true when @p s begins with @p prefix (an empty prefix always matches). */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only report a match located at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1582
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1583 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1584 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1585 {
1586 bool valid = false;
1587 std::string reason = UNKNOW_REASON;
1588 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1589 if (IsIP(hostName)) {
1590 auto it = find(ips.begin(), ips.end(), hostName);
1591 if (it == ips.end()) {
1592 reason = IP + hostName + " is not in the cert's list";
1593 }
1594 result = {valid, reason};
1595 return;
1596 }
1597 std::string tempHostName = "" + hostName;
1598 if (!dnsNames.empty() || index > 0) {
1599 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1600 if (!dnsNames.empty()) {
1601 valid = SeekIntersection(hostParts, dnsNames);
1602 if (!valid) {
1603 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1604 }
1605 } else {
1606 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1607 X509_NAME *pSubName = nullptr;
1608 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1609 if (len > 0) {
1610 std::vector<std::string> commonNameVec;
1611 commonNameVec.emplace_back(commonNameBuf);
1612 valid = SeekIntersection(hostParts, commonNameVec);
1613 if (!valid) {
1614 reason = HOST_NAME + tempHostName + ". is not cert's CN";
1615 }
1616 }
1617 }
1618 result = {valid, reason};
1619 return;
1620 }
1621 reason = "Cert does not contain a DNS name";
1622 result = {valid, reason};
1623 }
1624
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1625 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1626 const X509 *x509Certificates)
1627 {
1628 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1629 if (!subjectName) {
1630 return "subject name is null";
1631 }
1632 char subNameBuf[BUF_SIZE] = {0};
1633 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1634
1635 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1636 if (index < 0) {
1637 return "X509 get ext nid error";
1638 }
1639 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1640 if (ext == nullptr) {
1641 return "X509 get ext error";
1642 }
1643 ASN1_OBJECT *obj = nullptr;
1644 obj = X509_EXTENSION_get_object(ext);
1645 char subAltNameBuf[BUF_SIZE] = {0};
1646 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1647 NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1648
1649 return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1650 }
1651
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1652 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1653 const X509 *x509Certificates)
1654 {
1655 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1656 if (!extData) {
1657 NETSTACK_LOGE("extData is nullptr");
1658 return "";
1659 }
1660
1661 std::string altNames = reinterpret_cast<char *>(extData->data);
1662 std::string hostname = " " + hostName;
1663 BIO *bio = BIO_new(BIO_s_file());
1664 if (!bio) {
1665 return "bio is null";
1666 }
1667 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1668 ASN1_STRING_print(bio, extData);
1669 std::vector<std::string> dnsNames = {};
1670 std::vector<std::string> ips = {};
1671 constexpr int DNS_NAME_IDX = 4;
1672 constexpr int IP_NAME_IDX = 11;
1673 if (!altNames.empty()) {
1674 std::vector<std::string> splitAltNames;
1675 if (altNames.find('\"') != std::string::npos) {
1676 splitAltNames = SplitEscapedAltNames(altNames);
1677 } else {
1678 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1679 }
1680 for (auto const &iter : splitAltNames) {
1681 if (StartsWith(iter, DNS)) {
1682 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1683 } else if (StartsWith(iter, IP_ADDRESS)) {
1684 ips.push_back(iter.substr(IP_NAME_IDX));
1685 }
1686 }
1687 }
1688 std::tuple<bool, std::string> result;
1689 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1690 if (!std::get<0>(result)) {
1691 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1692 }
1693 return HOST_NAME + hostname + ". is cert's CN";
1694 }
1695
LoadCaCertFromMemory(X509_STORE * store,const std::string & pemCerts)1696 static void LoadCaCertFromMemory(X509_STORE *store, const std::string &pemCerts)
1697 {
1698 if (!store || pemCerts.empty() || pemCerts.size() > static_cast<size_t>(INT_MAX)) {
1699 return;
1700 }
1701
1702 auto cbio = BIO_new_mem_buf(pemCerts.data(), static_cast<int>(pemCerts.size()));
1703 if (!cbio) {
1704 return;
1705 }
1706
1707 auto inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr);
1708 if (!inf) {
1709 BIO_free(cbio);
1710 return;
1711 }
1712
1713 /* add each entry from PEM file to x509_store */
1714 for (int i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); ++i) {
1715 auto itmp = sk_X509_INFO_value(inf, i);
1716 if (!itmp) {
1717 continue;
1718 }
1719 if (itmp->x509) {
1720 X509_STORE_add_cert(store, itmp->x509);
1721 }
1722 if (itmp->crl) {
1723 X509_STORE_add_crl(store, itmp->crl);
1724 }
1725 }
1726
1727 sk_X509_INFO_pop_free(inf, X509_INFO_free);
1728 BIO_free(cbio);
1729 }
1730
X509_to_PEM(X509 * cert)1731 static std::string X509_to_PEM(X509 *cert)
1732 {
1733 if (!cert) {
1734 return {};
1735 }
1736 BIO *bio = BIO_new(BIO_s_mem());
1737 if (!bio) {
1738 return {};
1739 }
1740 if (!PEM_write_bio_X509(bio, cert)) {
1741 BIO_free(bio);
1742 return {};
1743 }
1744
1745 char *data = nullptr;
1746 auto pemStringLength = BIO_get_mem_data(bio, &data);
1747 if (!data) {
1748 BIO_free(bio);
1749 return {};
1750 }
1751 std::string certificateInPEM(data, pemStringLength);
1752 BIO_free(bio);
1753 return certificateInPEM;
1754 }
1755
CacheCertificates(const std::string & hostName,SSL * ssl)1756 static void CacheCertificates(const std::string &hostName, SSL *ssl)
1757 {
1758 if (!ssl || hostName.empty()) {
1759 return;
1760 }
1761 auto certificatesStack = SSL_get_peer_cert_chain(ssl);
1762 if (!certificatesStack) {
1763 return;
1764 }
1765 auto numCertificates = sk_X509_num(certificatesStack);
1766 for (auto i = 0; i < numCertificates; ++i) {
1767 auto cert = sk_X509_value(certificatesStack, i);
1768 auto certificateInPEM = X509_to_PEM(cert);
1769 if (!certificateInPEM.empty()) {
1770 CaCertCache::GetInstance().Set(hostName, certificateInPEM);
1771 }
1772 }
1773 }
1774
LoadCachedCaCert(const std::string & hostName,SSL * ssl)1775 static void LoadCachedCaCert(const std::string &hostName, SSL *ssl)
1776 {
1777 if (!ssl) {
1778 return;
1779 }
1780 auto cachedPem = CaCertCache::GetInstance().Get(hostName);
1781 auto sslCtx = SSL_get_SSL_CTX(ssl);
1782 if (!sslCtx) {
1783 return;
1784 }
1785 auto x509Store = SSL_CTX_get_cert_store(sslCtx);
1786 if (!x509Store) {
1787 return;
1788 }
1789 for (const auto &pem : cachedPem) {
1790 LoadCaCertFromMemory(x509Store, pem);
1791 }
1792 }
1793
StartShakingHands(const TLSConnectOptions & options)1794 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1795 {
1796 {
1797 std::lock_guard<std::mutex> lock(mutexForSsl_);
1798 if (!ssl_) {
1799 NETSTACK_LOGE("ssl is null");
1800 return false;
1801 }
1802
1803 auto hostName = options.GetHostName();
1804 // indicates hostName is not ip address
1805 if (hostName != options.GetNetAddress().GetAddress()) {
1806 LoadCachedCaCert(hostName, ssl_);
1807 }
1808
1809 int result = SSL_connect(ssl_);
1810 if (result == -1) {
1811 char err[MAX_ERR_LEN] = {0};
1812 auto code = ERR_get_error();
1813 ERR_error_string_n(code, err, MAX_ERR_LEN);
1814 int errorStatus = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1815 NETSTACK_LOGE("SSLConnect fail %{public}d, error: %{public}s errno: %{public}d ERR_get_error %{public}s",
1816 errorStatus, MakeSSLErrorString(errorStatus).c_str(), errno, err);
1817 return false;
1818 }
1819
1820 // indicates hostName is not ip address
1821 if (hostName != options.GetNetAddress().GetAddress()) {
1822 CacheCertificates(hostName, ssl_);
1823 }
1824
1825 std::string list = SSL_get_cipher_list(ssl_, 0);
1826 NETSTACK_LOGI("cipher_list: %{public}s, Version: %{public}s, Cipher: %{public}s", list.c_str(),
1827 SSL_get_version(ssl_), SSL_get_cipher(ssl_));
1828 configuration_.SetCipherSuite(list);
1829 }
1830 if (!SetSharedSigals()) {
1831 NETSTACK_LOGE("Failed to set sharedSigalgs");
1832 }
1833 if (!GetRemoteCertificateFromPeer()) {
1834 NETSTACK_LOGE("Failed to get remote certificate");
1835 }
1836 if (!peerX509_) {
1837 NETSTACK_LOGE("peer x509Certificates is null");
1838 return false;
1839 }
1840 if (!SetRemoteCertRawData()) {
1841 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1842 }
1843 CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1844 if (!checkServerIdentity) {
1845 CheckServerIdentityLegal(hostName_, peerX509_);
1846 } else {
1847 checkServerIdentity(hostName_, {remoteCert_});
1848 }
1849 return true;
1850 }
1851
GetRemoteCertificateFromPeer()1852 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1853 {
1854 peerX509_ = SSL_get_peer_certificate(ssl_);
1855 if (peerX509_ == nullptr) {
1856 int resErr = ConvertSSLError();
1857 NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1858 MakeSSLErrorString(resErr).c_str());
1859 return false;
1860 }
1861 BIO *bio = BIO_new(BIO_s_mem());
1862 if (!bio) {
1863 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1864 return false;
1865 }
1866 X509_print(bio, peerX509_);
1867 char data[REMOTE_CERT_LEN] = {0};
1868 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1869 NETSTACK_LOGE("BIO_read function returns error");
1870 BIO_free(bio);
1871 return false;
1872 }
1873 BIO_free(bio);
1874 remoteCert_ = std::string(data);
1875 return true;
1876 }
1877
SetRemoteCertRawData()1878 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1879 {
1880 if (peerX509_ == nullptr) {
1881 NETSTACK_LOGE("peerX509 is null");
1882 return false;
1883 }
1884 int32_t length = i2d_X509(peerX509_, nullptr);
1885 if (length <= 0) {
1886 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1887 return false;
1888 }
1889 unsigned char *der = nullptr;
1890 (void)i2d_X509(peerX509_, &der);
1891 SecureData data(der, length);
1892 remoteRawData_.data = data;
1893 OPENSSL_free(der);
1894 remoteRawData_.encodingFormat = DER;
1895 return true;
1896 }
1897
GetRemoteCertRawData() const1898 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1899 {
1900 return remoteRawData_;
1901 }
1902
GetSSL()1903 ssl_st *TLSSocket::TLSSocketInternal::GetSSL()
1904 {
1905 std::lock_guard<std::mutex> lock(mutexForSsl_);
1906 return ssl_;
1907 }
1908 } // namespace TlsSocket
1909 } // namespace NetStack
1910 } // namespace OHOS
1911