1 /*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tls_socket.h"
17
18 #include <chrono>
19 #include <memory>
20 #include <numeric>
21 #include <poll.h>
22 #include <regex>
23 #include <securec.h>
24 #include <set>
25 #include <thread>
26
27 #include <netinet/tcp.h>
28 #include <openssl/err.h>
29 #include <openssl/ssl.h>
30
31 #include "base_context.h"
32 #include "netstack_common_utils.h"
33 #include "netstack_log.h"
34 #include "socket_exec_common.h"
35 #include "tls.h"
36
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
// ---- Timeouts (milliseconds), retry counts and buffer sizes ----
constexpr int READ_TIMEOUT_MS{500};
constexpr int REMOTE_CERT_LEN{8192};
constexpr int COMMON_NAME_BUF_SIZE{256};
constexpr int BUF_SIZE{2048};
constexpr int DEFAULT_BUFFER_SIZE{8192};
constexpr int DEFAULT_POLL_TIMEOUT_MS{500};
constexpr int SEND_RETRY_TIMES{5};
constexpr int SEND_POLL_TIMEOUT_MS{1000};
constexpr int MAX_RECV_BUFFER_SIZE{1024 * 16};

// ---- Sentinel return codes ----
constexpr int SSL_RET_CODE{0};
constexpr int SSL_ERROR_RETURN{-1};
constexpr int SSL_WANT_READ_RETURN{-2};

// Value added to a separator position when stepping past it in SplitEscapedAltNames.
constexpr int OFFSET{2};

// ---- Separators and fixed text fragments ----
constexpr const char *SPLIT_ALT_NAMES{","};
constexpr const char *SPLIT_HOST_NAME{"."};
constexpr const char *PROTOCOL_UNKNOW{"UNKNOW_PROTOCOL"};
constexpr const char *UNKNOW_REASON{"Unknown reason"};
constexpr const char *IP{"IP: "};
constexpr const char *HOST_NAME{"hostname: "};
constexpr const char *DNS{"DNS:"};
constexpr const char *IP_ADDRESS{"IP Address:"};

// ---- Signature algorithm name fragments ----
constexpr const char *SIGN_NID_RSA{"RSA+"};
constexpr const char *SIGN_NID_RSA_PSS{"RSA-PSS+"};
constexpr const char *SIGN_NID_DSA{"DSA+"};
constexpr const char *SIGN_NID_ECDSA{"ECDSA+"};
constexpr const char *SIGN_NID_ED{"Ed25519+"};
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT{"Ed448+"};
constexpr const char *SIGN_NID_UNDEF_ADD{"UNDEF+"};
constexpr const char *SIGN_NID_UNDEF{"UNDEF"};
constexpr const char *OPERATOR_PLUS_SIGN{"+"};

// Name given to the background thread that reads from the TLS connection.
constexpr const char *TLS_SOCKET_CLIENT_READ{"OS_NET_TSCliRD"};

// Pattern for a JSON string literal.
// NOTE(review): the text keeps JavaScript-style '/' delimiters around the expression; std::regex
// treats them as literal characters - confirm this is intended before relying on a match.
const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"};

// Dotted-quad IPv4 address with every octet limited to 0-255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
76
/**
 * Singleton cache that maps a key to the set of strings stored under it.
 * Every access takes the internal mutex, so Get and Set may be called concurrently.
 */
class CaCertCache {
public:
    CaCertCache(const CaCertCache &) = delete;
    CaCertCache &operator=(const CaCertCache &) = delete;

    /** Returns the single shared instance, created on first use. */
    static CaCertCache &GetInstance()
    {
        static CaCertCache cache;
        return cache;
    }

    /** Returns a copy of the set stored under key, or an empty set when key is unknown. */
    std::set<std::string> Get(const std::string &key)
    {
        std::lock_guard guard(mutex_);
        const auto found = map_.find(key);
        return (found == map_.end()) ? std::set<std::string>{} : found->second;
    }

    /** Adds val to the set stored under key, creating the set when it does not exist yet. */
    void Set(const std::string &key, const std::string &val)
    {
        std::lock_guard guard(mutex_);
        map_[key].insert(val);
    }

private:
    CaCertCache() = default;
    ~CaCertCache() = default;

    std::map<std::string, std::set<std::string>> map_;
    std::mutex mutex_;
};
110
ConvertErrno()111 int ConvertErrno()
112 {
113 return TlsSocketError::TLS_ERR_SYS_BASE + errno;
114 }
115
// Returns the strerror() text for the calling thread's current errno.
std::string MakeErrnoString()
{
    const int sysError = errno;
    return std::string(strerror(sysError));
}
120
MakeSSLErrorString(int error)121 std::string MakeSSLErrorString(int error)
122 {
123 char err[MAX_ERR_LEN] = {0};
124 ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
125 return err;
126 }
127
/**
 * Returns the length of the JSON string literal that starts at text[start] (both quotes included),
 * or 0 when no well-formed literal starts there.
 *
 * Accepted grammar: '"', then any number of either an unescaped character (not '"', not '\\',
 * not a control character below 0x20) or an escape (\" \\ \/ \b \f \n \r \t or \u followed by
 * four hex digits), then '"'.
 */
size_t JsonStringLiteralLength(const std::string &text, size_t start)
{
    static const std::string simpleEscapes = "\"\\/bfnrt";
    constexpr size_t simpleEscapeLen = 2;  // backslash + one character
    constexpr size_t unicodeEscapeLen = 6; // backslash + 'u' + four hex digits
    constexpr unsigned char firstPrintable = 0x20;
    const auto isHex = [](char c) {
        return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F');
    };

    if (start >= text.size() || text[start] != '"') {
        return 0;
    }
    size_t pos = start + 1;
    while (pos < text.size()) {
        const auto ch = static_cast<unsigned char>(text[pos]);
        if (ch == '"') {
            return pos - start + 1;
        }
        if (ch < firstPrintable) {
            return 0;
        }
        if (ch != '\\') {
            ++pos;
            continue;
        }
        if (pos + 1 >= text.size()) {
            return 0;
        }
        const char esc = text[pos + 1];
        if (esc == 'u') {
            if (pos + unicodeEscapeLen > text.size()) {
                return 0;
            }
            for (size_t i = simpleEscapeLen; i < unicodeEscapeLen; ++i) {
                if (!isHex(text[pos + i])) {
                    return 0;
                }
            }
            pos += unicodeEscapeLen;
        } else if (simpleEscapes.find(esc) != std::string::npos) {
            pos += simpleEscapeLen;
        } else {
            return 0;
        }
    }
    return 0; // closing quote never found
}

/**
 * Splits a subject-alternative-name list on ", " while treating JSON string literals as opaque,
 * so a ", " inside a quoted section does not end the current entry.
 *
 * @param altNames Text to split; it is not modified.
 * @return The entries in order of appearance (an empty input yields one empty entry). When a quote
 *         does not start a well-formed JSON string literal, a vector holding a single empty string
 *         is returned. Quoted sections are copied verbatim, quotes and escapes included.
 */
std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
{
    static const std::string separator = ", ";
    std::vector<std::string> result;
    std::string currentToken;
    size_t offset = 0;
    while (offset < altNames.length()) {
        // Both searches must start at offset; searching from 0 would find the same position forever.
        const size_t nextSep = altNames.find(separator, offset);
        const size_t nextQuote = altNames.find('"', offset);
        if (nextQuote != std::string::npos && (nextSep == std::string::npos || nextQuote < nextSep)) {
            // A quote comes before any separator: take the text up to it, then the whole literal.
            currentToken.append(altNames, offset, nextQuote - offset);
            const size_t literalLen = JsonStringLiteralLength(altNames, nextQuote);
            if (literalLen == 0) {
                return {""};
            }
            currentToken.append(altNames, nextQuote, literalLen);
            offset = nextQuote + literalLen;
        } else if (nextSep != std::string::npos) {
            // A separator comes first: the current entry ends here.
            currentToken.append(altNames, offset, nextSep - offset);
            result.push_back(currentToken);
            currentToken.clear();
            offset = nextSep + separator.length();
        } else {
            // Neither quote nor separator left: the remainder belongs to the last entry.
            currentToken.append(altNames, offset, std::string::npos);
            offset = altNames.length();
        }
    }
    result.push_back(currentToken);
    return result;
}
160
IsIP(const std::string & ip)161 bool IsIP(const std::string &ip)
162 {
163 std::regex pattern(PATTERN);
164 std::smatch res;
165 return regex_match(ip, res, pattern);
166 }
167
SplitHostName(std::string & hostName)168 std::vector<std::string> SplitHostName(std::string &hostName)
169 {
170 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
171 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
172 }
173
/**
 * Reports whether vecA and vecB have at least one string in common.
 *
 * The inputs do not need to be sorted: std::set_intersection requires sorted ranges and gives
 * wrong answers otherwise, so membership is tested through a std::set instead.
 *
 * @param vecA First list; not modified.
 * @param vecB Second list; not modified.
 * @return true when some element occurs in both lists, false otherwise (including empty input).
 */
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    const std::set<std::string> lookup(vecA.begin(), vecA.end());
    for (const auto &candidate : vecB) {
        if (lookup.count(candidate) != 0) {
            return true;
        }
    }
    return false;
}
180 } // namespace
181
SetSockBlockFlag(int sock,bool noneBlock)182 static bool SetSockBlockFlag(int sock, bool noneBlock)
183 {
184 int flags = fcntl(sock, F_GETFL, 0);
185 while (flags == -1 && errno == EINTR) {
186 flags = fcntl(sock, F_GETFL, 0);
187 }
188 if (flags == -1) {
189 NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
190 return false;
191 }
192
193 auto newFlags = static_cast<size_t>(flags);
194 if (noneBlock) {
195 newFlags |= static_cast<size_t>(O_NONBLOCK);
196 } else {
197 newFlags &= ~static_cast<size_t>(O_NONBLOCK);
198 }
199
200 int ret = fcntl(sock, F_SETFL, newFlags);
201 while (ret == -1 && errno == EINTR) {
202 ret = fcntl(sock, F_SETFL, newFlags);
203 }
204 if (ret == -1) {
205 NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
206 return false;
207 }
208 return true;
209 }
210
// Copy constructor: members are default-initialised first and then overwritten by the
// copy-assignment operator.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
215
/**
 * Copy-assigns every option from tlsSecureOptions.
 *
 * Self-assignment is a no-op, and the key is copied once (the original copied it twice).
 *
 * @param tlsSecureOptions Source options.
 * @return *this.
 */
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    cert_ = tlsSecureOptions.GetCert();
    caChain_ = tlsSecureOptions.GetCaChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
231
SetCaChain(const std::vector<std::string> & caChain)232 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
233 {
234 caChain_ = caChain;
235 }
236
SetCert(const std::string & cert)237 void TLSSecureOptions::SetCert(const std::string &cert)
238 {
239 cert_ = cert;
240 }
241
SetKey(const SecureData & key)242 void TLSSecureOptions::SetKey(const SecureData &key)
243 {
244 key_ = key;
245 }
246
SetKeyPass(const SecureData & keyPass)247 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
248 {
249 keyPass_ = keyPass;
250 }
251
SetProtocolChain(const std::vector<std::string> & protocolChain)252 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
253 {
254 protocolChain_ = protocolChain;
255 }
256
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)257 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
258 {
259 useRemoteCipherPrefer_ = useRemoteCipherPrefer;
260 }
261
SetSignatureAlgorithms(const std::string & signatureAlgorithms)262 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
263 {
264 signatureAlgorithms_ = signatureAlgorithms;
265 }
266
SetCipherSuite(const std::string & cipherSuite)267 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
268 {
269 cipherSuite_ = cipherSuite;
270 }
271
SetCrlChain(const std::vector<std::string> & crlChain)272 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
273 {
274 crlChain_ = crlChain;
275 }
276
GetCaChain() const277 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
278 {
279 return caChain_;
280 }
281
GetCert() const282 const std::string &TLSSecureOptions::GetCert() const
283 {
284 return cert_;
285 }
286
GetKey() const287 const SecureData &TLSSecureOptions::GetKey() const
288 {
289 return key_;
290 }
291
GetKeyPass() const292 const SecureData &TLSSecureOptions::GetKeyPass() const
293 {
294 return keyPass_;
295 }
296
GetProtocolChain() const297 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
298 {
299 return protocolChain_;
300 }
301
UseRemoteCipherPrefer() const302 bool TLSSecureOptions::UseRemoteCipherPrefer() const
303 {
304 return useRemoteCipherPrefer_;
305 }
306
GetSignatureAlgorithms() const307 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
308 {
309 return signatureAlgorithms_;
310 }
311
GetCipherSuite() const312 const std::string &TLSSecureOptions::GetCipherSuite() const
313 {
314 return cipherSuite_;
315 }
316
GetCrlChain() const317 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
318 {
319 return crlChain_;
320 }
321
SetVerifyMode(VerifyMode verifyMode)322 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
323 {
324 TLSVerifyMode_ = verifyMode;
325 }
326
GetVerifyMode() const327 VerifyMode TLSSecureOptions::GetVerifyMode() const
328 {
329 return TLSVerifyMode_;
330 }
331
SetNetAddress(const Socket::NetAddress & address)332 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
333 {
334 address_.SetFamilyBySaFamily(address.GetSaFamily());
335 address_.SetRawAddress(address.GetAddress());
336 address_.SetPort(address.GetPort());
337 }
338
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)339 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
340 {
341 tlsSecureOptions_ = tlsSecureOptions;
342 }
343
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)344 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
345 {
346 checkServerIdentity_ = checkServerIdentity;
347 }
348
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)349 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
350 {
351 alpnProtocols_ = alpnProtocols;
352 }
353
SetSkipRemoteValidation(bool skipRemoteValidation)354 void TLSConnectOptions::SetSkipRemoteValidation(bool skipRemoteValidation)
355 {
356 skipRemoteValidation_ = skipRemoteValidation;
357 }
358
GetNetAddress() const359 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
360 {
361 return address_;
362 }
363
GetTlsSecureOptions() const364 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
365 {
366 return tlsSecureOptions_;
367 }
368
GetCheckServerIdentity() const369 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
370 {
371 return checkServerIdentity_;
372 }
373
GetAlpnProtocols() const374 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
375 {
376 return alpnProtocols_;
377 }
378
GetSkipRemoteValidation() const379 bool TLSConnectOptions::GetSkipRemoteValidation() const
380 {
381 return skipRemoteValidation_;
382 }
383
SetHostName(const std::string & hostName)384 void TLSConnectOptions::SetHostName(const std::string &hostName)
385 {
386 hostName_ = hostName;
387 }
388
GetHostName() const389 std::string TLSConnectOptions::GetHostName() const
390 {
391 return hostName_;
392 }
393
MakeAddressString(sockaddr * addr)394 std::string TLSSocket::MakeAddressString(sockaddr *addr)
395 {
396 if (!addr) {
397 return {};
398 }
399 if (addr->sa_family == AF_INET) {
400 auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
401 const char *str = inet_ntoa(addr4->sin_addr);
402 if (str == nullptr || strlen(str) == 0) {
403 return {};
404 }
405 return str;
406 } else if (addr->sa_family == AF_INET6) {
407 auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
408 char str[INET6_ADDRSTRLEN] = {0};
409 if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
410 return {};
411 }
412 return str;
413 }
414 return {};
415 }
416
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)417 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
418 socklen_t *len)
419 {
420 if (!addr6 || !addr4 || !len) {
421 return;
422 }
423 sa_family_t family = address.GetSaFamily();
424 if (family == AF_INET) {
425 addr4->sin_family = AF_INET;
426 addr4->sin_port = htons(address.GetPort());
427 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
428 *addr = reinterpret_cast<sockaddr *>(addr4);
429 *len = sizeof(sockaddr_in);
430 } else if (family == AF_INET6) {
431 addr6->sin6_family = AF_INET6;
432 addr6->sin6_port = htons(address.GetPort());
433 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
434 *addr = reinterpret_cast<sockaddr *>(addr6);
435 *len = sizeof(sockaddr_in6);
436 }
437 }
438
ExecTlsSetSockBlockFlag(int sock,bool noneBlock)439 bool TLSSocket::ExecTlsSetSockBlockFlag(int sock, bool noneBlock)
440 {
441 return SetSockBlockFlag(sock, noneBlock);
442 }
443
ExecTlsGetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)444 void TLSSocket::ExecTlsGetAddr(
445 const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len)
446 {
447 GetAddr(address, addr4, addr6, addr, len);
448 }
449
IsExtSock() const450 bool TLSSocket::IsExtSock() const
451 {
452 return isExtSock_;
453 }
454
MakeIpSocket(sa_family_t family)455 void TLSSocket::MakeIpSocket(sa_family_t family)
456 {
457 if (family != AF_INET && family != AF_INET6) {
458 return;
459 }
460 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
461 if (sock < 0) {
462 int resErr = ConvertErrno();
463 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
464 CallOnErrorCallback(resErr, MakeErrnoString());
465 return;
466 }
467 sockFd_ = sock;
468 }
469
ReadMessage()470 int TLSSocket::ReadMessage()
471 {
472 char buffer[MAX_RECV_BUFFER_SIZE];
473 if (memset_s(buffer, MAX_RECV_BUFFER_SIZE, 0, MAX_RECV_BUFFER_SIZE) != EOK) {
474 NETSTACK_LOGE("memset_s failed!");
475 return -1;
476 }
477 nfds_t num = 1;
478 pollfd fds[1] = {{.fd = sockFd_, .events = POLLIN}};
479 int ret = poll(fds, num, READ_TIMEOUT_MS);
480 if (ret < 0) {
481 if (errno == EAGAIN || errno == EINTR) {
482 return 0;
483 }
484 int resErr = ConvertErrno();
485 NETSTACK_LOGE("Message poll errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
486 CallOnErrorCallback(resErr, MakeErrnoString());
487 return ret;
488 } else if (ret == 0) {
489 NETSTACK_LOGD("tls recv poll timeout");
490 return ret;
491 }
492
493 std::lock_guard<std::mutex> lock(recvMutex_);
494 if (!isRunning_) {
495 return -1;
496 }
497 int len = tlsSocketInternal_.Recv(buffer, MAX_RECV_BUFFER_SIZE);
498 if (len < 0) {
499 if (errno == EAGAIN || errno == EINTR || len == SSL_WANT_READ_RETURN) {
500 return 0;
501 }
502 int resErr = tlsSocketInternal_.ConvertSSLError();
503 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
504 MakeSSLErrorString(resErr).c_str());
505 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
506 return len;
507 } else if (len == 0) {
508 NETSTACK_LOGI("Message recv len 0, session is closed by peer");
509 CallOnCloseCallback();
510 return -1;
511 }
512 Socket::SocketRemoteInfo remoteInfo;
513 remoteInfo.SetSize(len);
514 tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
515 std::string bufContent(buffer, len);
516 CallOnMessageCallback(bufContent, remoteInfo);
517
518 return ret;
519 }
520
StartReadMessage()521 void TLSSocket::StartReadMessage()
522 {
523 auto wp = std::weak_ptr<TLSSocket>(shared_from_this());
524 std::thread thread([wp]() {
525 auto tlsSocket = wp.lock();
526 if (tlsSocket == nullptr) {
527 return;
528 }
529 tlsSocket->isRunning_ = true;
530 tlsSocket->isRunOver_ = false;
531 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
532 pthread_setname_np(TLS_SOCKET_CLIENT_READ);
533 #else
534 pthread_setname_np(pthread_self(), TLS_SOCKET_CLIENT_READ);
535 #endif
536 while (tlsSocket->isRunning_) {
537 int ret = tlsSocket->ReadMessage();
538 if (ret < 0) {
539 break;
540 }
541 }
542 tlsSocket->isRunOver_ = true;
543 tlsSocket->cvSslFree_.notify_one();
544 });
545 thread.detach();
546 }
547
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)548 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
549 {
550 OnMessageCallback func = nullptr;
551 {
552 std::lock_guard<std::mutex> lock(mutex_);
553 if (onMessageCallback_) {
554 func = onMessageCallback_;
555 }
556 }
557
558 if (func) {
559 func(data, remoteInfo);
560 }
561 }
562
CallOnConnectCallback()563 void TLSSocket::CallOnConnectCallback()
564 {
565 OnConnectCallback func = nullptr;
566 {
567 std::lock_guard<std::mutex> lock(mutex_);
568 if (onConnectCallback_) {
569 func = onConnectCallback_;
570 }
571 }
572
573 if (func) {
574 func();
575 }
576 }
577
CallOnCloseCallback()578 void TLSSocket::CallOnCloseCallback()
579 {
580 OnCloseCallback func = nullptr;
581 {
582 std::lock_guard<std::mutex> lock(mutex_);
583 if (onCloseCallback_) {
584 func = onCloseCallback_;
585 }
586 }
587
588 if (func) {
589 func();
590 }
591 }
592
CallOnErrorCallback(int32_t err,const std::string & errString)593 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
594 {
595 OnErrorCallback func = nullptr;
596 {
597 std::lock_guard<std::mutex> lock(mutex_);
598 if (onErrorCallback_) {
599 func = onErrorCallback_;
600 }
601 }
602
603 if (func) {
604 func(err, errString);
605 }
606 }
607
CallBindCallback(int32_t err,BindCallback callback)608 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
609 {
610 DealCallback<BindCallback>(err, callback);
611 }
612
CallConnectCallback(int32_t err,ConnectCallback callback)613 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
614 {
615 DealCallback<ConnectCallback>(err, callback);
616 }
617
CallSendCallback(int32_t err,SendCallback callback)618 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
619 {
620 DealCallback<SendCallback>(err, callback);
621 }
622
CallCloseCallback(int32_t err,CloseCallback callback)623 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
624 {
625 DealCallback<CloseCallback>(err, callback);
626 }
627
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)628 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
629 GetRemoteAddressCallback callback)
630 {
631 if (callback) {
632 callback(err, address);
633 }
634 }
635
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)636 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
637 {
638 if (callback) {
639 callback(err, state);
640 }
641 }
642
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)643 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
644 {
645 DealCallback<SetExtraOptionsCallback>(err, callback);
646 }
647
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)648 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
649 {
650 if (callback) {
651 callback(err, cert);
652 }
653 }
654
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)655 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
656 GetRemoteCertificateCallback callback)
657 {
658 if (callback) {
659 callback(err, cert);
660 }
661 }
662
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)663 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
664 {
665 if (callback) {
666 callback(err, protocol);
667 }
668 }
669
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)670 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
671 GetCipherSuiteCallback callback)
672 {
673 if (callback) {
674 callback(err, suite);
675 }
676 }
677
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)678 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
679 GetSignatureAlgorithmsCallback callback)
680 {
681 if (callback) {
682 callback(err, algorithms);
683 }
684 }
685
Bind(Socket::NetAddress & address,const BindCallback & callback)686 void TLSSocket::Bind(Socket::NetAddress &address, const BindCallback &callback)
687 {
688 static constexpr int32_t PARSE_ERROR_CODE = 401;
689 if (!CommonUtils::HasInternetPermission()) {
690 CallBindCallback(PERMISSION_DENIED_CODE, callback);
691 return;
692 }
693 if (sockFd_ >= 0) {
694 CallBindCallback(TLSSOCKET_SUCCESS, callback);
695 return;
696 }
697
698 MakeIpSocket(address.GetSaFamily());
699 if (sockFd_ < 0) {
700 int resErr = ConvertErrno();
701 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
702 CallOnErrorCallback(resErr, MakeErrnoString());
703 CallBindCallback(resErr, callback);
704 return;
705 }
706
707 auto temp = address.GetAddress();
708 address.SetRawAddress("");
709 address.SetAddress(temp);
710 if (address.GetAddress().empty()) {
711 CallBindCallback(PARSE_ERROR_CODE, callback);
712 return;
713 }
714
715 sockaddr_in addr4 = {0};
716 sockaddr_in6 addr6 = {0};
717 sockaddr *addr = nullptr;
718 socklen_t len;
719 GetAddr(address, &addr4, &addr6, &addr, &len);
720 if (addr == nullptr) {
721 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
722 CallOnErrorCallback(-1, "Address Is Invalid");
723 CallBindCallback(ConvertErrno(), callback);
724 return;
725 }
726 CallBindCallback(TLSSOCKET_SUCCESS, callback);
727 }
728
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)729 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
730 const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
731 {
732 if (sockFd_ < 0) {
733 int resErr = ConvertErrno();
734 NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
735 CallOnErrorCallback(resErr, MakeErrnoString());
736 callback(resErr);
737 return;
738 }
739
740 if (isExtSock_ && !SetSockBlockFlag(sockFd_, false)) {
741 int resErr = ConvertErrno();
742 NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
743 CallOnErrorCallback(resErr, MakeErrnoString());
744 callback(resErr);
745 return;
746 }
747
748 auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions, isExtSock_);
749 if (!res) {
750 int resErr = tlsSocketInternal_.ConvertSSLError();
751 NETSTACK_LOGE("connect error is %{public}d %{public}d", resErr, errno);
752 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
753 callback(resErr);
754 return;
755 }
756 if (!SetSockBlockFlag(sockFd_, true)) {
757 int resErr = ConvertErrno();
758 NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
759 CallOnErrorCallback(resErr, MakeErrnoString());
760 callback(resErr);
761 return;
762 }
763 StartReadMessage();
764 CallOnConnectCallback();
765 callback(TLSSOCKET_SUCCESS);
766 }
767
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)768 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
769 {
770 (void)tcpSendOptions;
771
772 auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
773 if (!res) {
774 int resErr = tlsSocketInternal_.ConvertSSLError();
775 NETSTACK_LOGE("send error is %{public}d %{public}d", resErr, errno);
776 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
777 CallSendCallback(resErr, callback);
778 return;
779 }
780 CallSendCallback(TLSSOCKET_SUCCESS, callback);
781 }
782
Close(const CloseCallback & callback)783 void TLSSocket::Close(const CloseCallback &callback)
784 {
785 isRunning_ = false;
786 std::unique_lock<std::mutex> cvLock(cvMutex_);
787 auto wp = std::weak_ptr<TLSSocket>(shared_from_this());
788 cvSslFree_.wait(cvLock, [wp]() -> bool {
789 auto tlsSocket = wp.lock();
790 if (tlsSocket == nullptr) {
791 return true;
792 }
793 return tlsSocket->isRunOver_;
794 });
795
796 std::lock_guard<std::mutex> lock(recvMutex_);
797 auto res = tlsSocketInternal_.Close();
798 if (!res) {
799 int resErr = tlsSocketInternal_.ConvertSSLError();
800 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
801 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
802 callback(resErr);
803 return;
804 }
805 sockFd_ = -1;
806 CallOnCloseCallback();
807 callback(TLSSOCKET_SUCCESS);
808 }
809
GetRemoteAddress(const GetRemoteAddressCallback & callback)810 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
811 {
812 sockaddr sockAddr = {0};
813 socklen_t len = sizeof(sockaddr);
814 int ret = getsockname(sockFd_, &sockAddr, &len);
815 if (ret < 0) {
816 int resErr = ConvertErrno();
817 NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
818 CallOnErrorCallback(resErr, MakeErrnoString());
819 CallGetRemoteAddressCallback(resErr, {}, callback);
820 return;
821 }
822
823 if (sockAddr.sa_family == AF_INET) {
824 GetIp4RemoteAddress(callback);
825 } else if (sockAddr.sa_family == AF_INET6) {
826 GetIp6RemoteAddress(callback);
827 }
828 }
829
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)830 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
831 {
832 sockaddr_in addr4 = {0};
833 socklen_t len4 = sizeof(sockaddr_in);
834
835 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
836 if (ret < 0) {
837 int resErr = ConvertErrno();
838 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
839 CallOnErrorCallback(resErr, MakeErrnoString());
840 CallGetRemoteAddressCallback(resErr, {}, callback);
841 return;
842 }
843
844 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
845 if (address.empty()) {
846 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
847 CallOnErrorCallback(-1, "Address is invalid");
848 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
849 return;
850 }
851 Socket::NetAddress netAddress;
852 netAddress.SetFamilyBySaFamily(AF_INET);
853 netAddress.SetRawAddress(address);
854 netAddress.SetPort(ntohs(addr4.sin_port));
855 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
856 }
857
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)858 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
859 {
860 sockaddr_in6 addr6 = {0};
861 socklen_t len6 = sizeof(sockaddr_in6);
862
863 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
864 if (ret < 0) {
865 int resErr = ConvertErrno();
866 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
867 CallOnErrorCallback(resErr, MakeErrnoString());
868 CallGetRemoteAddressCallback(resErr, {}, callback);
869 return;
870 }
871
872 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
873 if (address.empty()) {
874 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
875 CallOnErrorCallback(-1, "Address is invalid");
876 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
877 return;
878 }
879 Socket::NetAddress netAddress;
880 netAddress.SetFamilyBySaFamily(AF_INET6);
881 netAddress.SetRawAddress(address);
882 netAddress.SetPort(ntohs(addr6.sin6_port));
883 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
884 }
885
GetState(const GetStateCallback & callback)886 void TLSSocket::GetState(const GetStateCallback &callback)
887 {
888 int opt;
889 socklen_t optLen = sizeof(int);
890 int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
891 if (r < 0) {
892 Socket::SocketStateBase state;
893 state.SetIsClose(true);
894 CallGetStateCallback(ConvertErrno(), state, callback);
895 return;
896 }
897 sockaddr sockAddr = {0};
898 socklen_t len = sizeof(sockaddr);
899 Socket::SocketStateBase state;
900 int ret = getsockname(sockFd_, &sockAddr, &len);
901 state.SetIsBound(ret == 0);
902 ret = getpeername(sockFd_, &sockAddr, &len);
903 state.SetIsConnected(ret == 0);
904 CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
905 }
906
SetBaseOptions(const Socket::ExtraOptionsBase & option) const907 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
908 {
909 if (option.GetReceiveBufferSize() != 0) {
910 int size = (int)option.GetReceiveBufferSize();
911 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
912 return false;
913 }
914 }
915
916 if (option.GetSendBufferSize() != 0) {
917 int size = (int)option.GetSendBufferSize();
918 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
919 return false;
920 }
921 }
922
923 if (option.IsReuseAddress()) {
924 int reuse = 1;
925 if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
926 return false;
927 }
928 }
929
930 if (option.GetSocketTimeout() != 0) {
931 timeval timeout = {(int)option.GetSocketTimeout(), 0};
932 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
933 return false;
934 }
935 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
936 return false;
937 }
938 }
939
940 return true;
941 }
942
SetExtraOptions(const Socket::TCPExtraOptions & option) const943 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
944 {
945 if (option.IsKeepAlive()) {
946 int keepalive = 1;
947 if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
948 return false;
949 }
950 }
951
952 if (option.IsOOBInline()) {
953 int oobInline = 1;
954 if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
955 return false;
956 }
957 }
958
959 if (option.IsTCPNoDelay()) {
960 int tcpNoDelay = 1;
961 if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
962 return false;
963 }
964 }
965
966 linger soLinger = {0};
967 soLinger.l_onoff = option.socketLinger.IsOn();
968 soLinger.l_linger = (int)option.socketLinger.GetLinger();
969 if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
970 return false;
971 }
972
973 return true;
974 }
975
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)976 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
977 const SetExtraOptionsCallback &callback)
978 {
979 if (!SetBaseOptions(tcpExtraOptions)) {
980 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
981 CallOnErrorCallback(errno, MakeErrnoString());
982 CallSetExtraOptionsCallback(ConvertErrno(), callback);
983 return;
984 }
985
986 if (!SetExtraOptions(tcpExtraOptions)) {
987 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
988 CallOnErrorCallback(errno, MakeErrnoString());
989 CallSetExtraOptionsCallback(ConvertErrno(), callback);
990 return;
991 }
992
993 CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
994 }
995
GetCertificate(const GetCertificateCallback & callback)996 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
997 {
998 const auto &cert = tlsSocketInternal_.GetCertificate();
999 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
1000
1001 if (!cert.data.Length()) {
1002 int resErr = tlsSocketInternal_.ConvertSSLError();
1003 NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1004 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1005 callback(resErr, {});
1006 return;
1007 }
1008 callback(TLSSOCKET_SUCCESS, cert);
1009 }
1010
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)1011 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
1012 {
1013 const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
1014 if (!remoteCert.data.Length()) {
1015 int resErr = tlsSocketInternal_.ConvertSSLError();
1016 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1017 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1018 callback(resErr, {});
1019 return;
1020 }
1021 callback(TLSSOCKET_SUCCESS, remoteCert);
1022 }
1023
GetProtocol(const GetProtocolCallback & callback)1024 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
1025 {
1026 const auto &protocol = tlsSocketInternal_.GetProtocol();
1027 if (protocol.empty()) {
1028 NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
1029 int resErr = tlsSocketInternal_.ConvertSSLError();
1030 NETSTACK_LOGE("getProtocol error is %{public}d %{public}d", resErr, errno);
1031 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1032 callback(resErr, "");
1033 return;
1034 }
1035 callback(TLSSOCKET_SUCCESS, protocol);
1036 }
1037
GetCipherSuite(const GetCipherSuiteCallback & callback)1038 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
1039 {
1040 const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
1041 if (cipherSuite.empty()) {
1042 int resErr = tlsSocketInternal_.ConvertSSLError();
1043 NETSTACK_LOGE("getCipherSuite error is %{public}d %{public}d", resErr, errno);
1044 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1045 callback(resErr, cipherSuite);
1046 return;
1047 }
1048 callback(TLSSOCKET_SUCCESS, cipherSuite);
1049 }
1050
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)1051 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
1052 {
1053 const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
1054 if (signatureAlgorithms.empty()) {
1055 int resErr = tlsSocketInternal_.ConvertSSLError();
1056 NETSTACK_LOGE("getSignatureAlgorithms error is %{public}d %{public}d", resErr, errno);
1057 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1058 callback(resErr, {});
1059 return;
1060 }
1061 callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
1062 }
1063
OnMessage(const OnMessageCallback & onMessageCallback)1064 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
1065 {
1066 std::lock_guard<std::mutex> lock(mutex_);
1067 onMessageCallback_ = onMessageCallback;
1068 }
1069
OffMessage()1070 void TLSSocket::OffMessage()
1071 {
1072 std::lock_guard<std::mutex> lock(mutex_);
1073 if (onMessageCallback_) {
1074 onMessageCallback_ = nullptr;
1075 }
1076 }
1077
OnConnect(const OnConnectCallback & onConnectCallback)1078 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1079 {
1080 std::lock_guard<std::mutex> lock(mutex_);
1081 onConnectCallback_ = onConnectCallback;
1082 }
1083
OffConnect()1084 void TLSSocket::OffConnect()
1085 {
1086 std::lock_guard<std::mutex> lock(mutex_);
1087 if (onConnectCallback_) {
1088 onConnectCallback_ = nullptr;
1089 }
1090 }
1091
OnClose(const OnCloseCallback & onCloseCallback)1092 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1093 {
1094 std::lock_guard<std::mutex> lock(mutex_);
1095 onCloseCallback_ = onCloseCallback;
1096 }
1097
OffClose()1098 void TLSSocket::OffClose()
1099 {
1100 std::lock_guard<std::mutex> lock(mutex_);
1101 if (onCloseCallback_) {
1102 onCloseCallback_ = nullptr;
1103 }
1104 }
1105
OnError(const OnErrorCallback & onErrorCallback)1106 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1107 {
1108 std::lock_guard<std::mutex> lock(mutex_);
1109 onErrorCallback_ = onErrorCallback;
1110 }
1111
OffError()1112 void TLSSocket::OffError()
1113 {
1114 std::lock_guard<std::mutex> lock(mutex_);
1115 if (onErrorCallback_) {
1116 onErrorCallback_ = nullptr;
1117 }
1118 }
1119
GetSocketFd()1120 int TLSSocket::GetSocketFd()
1121 {
1122 return sockFd_;
1123 }
1124
SetLocalAddress(const Socket::NetAddress & address)1125 void TLSSocket::SetLocalAddress(const Socket::NetAddress &address)
1126 {
1127 localAddress_ = address;
1128 }
1129
GetLocalAddress()1130 Socket::NetAddress TLSSocket::GetLocalAddress()
1131 {
1132 return localAddress_;
1133 }
1134
ExecSocketConnect(const std::string & host,int port,sa_family_t family,int socketDescriptor)1135 bool ExecSocketConnect(const std::string &host, int port, sa_family_t family, int socketDescriptor)
1136 {
1137 auto hostName = ConvertAddressToIp(host, family);
1138
1139 sockaddr_in addr4 = {0};
1140 sockaddr_in6 addr6 = {0};
1141 sockaddr *addr = nullptr;
1142 socklen_t len = 0;
1143 if (family == AF_INET) {
1144 if (inet_pton(AF_INET, hostName.c_str(), &addr4.sin_addr.s_addr) <= 0) {
1145 return false;
1146 }
1147 addr4.sin_family = family;
1148 addr4.sin_port = htons(port);
1149 addr = reinterpret_cast<sockaddr *>(&addr4);
1150 len = sizeof(sockaddr_in);
1151 } else {
1152 if (inet_pton(AF_INET6, hostName.c_str(), &addr6.sin6_addr) <= 0) {
1153 return false;
1154 }
1155 addr6.sin6_family = family;
1156 addr6.sin6_port = htons(port);
1157 addr = reinterpret_cast<sockaddr *>(&addr6);
1158 len = sizeof(sockaddr_in6);
1159 }
1160
1161 int connectResult = connect(socketDescriptor, addr, len);
1162 if (connectResult == -1) {
1163 NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1164 strerror(errno));
1165 return false;
1166 }
1167 return true;
1168 }
1169
ConvertSSLError(void)1170 int TLSSocket::TLSSocketInternal::ConvertSSLError(void)
1171 {
1172 std::lock_guard<std::mutex> lock(mutexForSsl_);
1173 if (!ssl_) {
1174 return TLS_ERR_SSL_NULL;
1175 }
1176 return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1177 }
1178
TlsConnectToHost(int sock,const TLSConnectOptions & options,bool isExtSock)1179 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock)
1180 {
1181 SetTlsConfiguration(options);
1182 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1183 if (!cipherSuite.empty()) {
1184 configuration_.SetCipherSuite(cipherSuite);
1185 }
1186 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1187 if (!signatureAlgorithms.empty()) {
1188 configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1189 }
1190 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1191 if (!protocolVec.empty()) {
1192 configuration_.SetProtocol(protocolVec);
1193 }
1194 configuration_.SetSkipFlag(options.GetSkipRemoteValidation());
1195 hostName_ = options.GetNetAddress().GetAddress();
1196 port_ = options.GetNetAddress().GetPort();
1197 family_ = options.GetNetAddress().GetSaFamily();
1198 socketDescriptor_ = sock;
1199 if (options.proxyOptions_ == nullptr && !isExtSock &&
1200 !ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1201 options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1202 return false;
1203 }
1204 return StartTlsConnected(options);
1205 }
1206
SetTlsConfiguration(const TLSConnectOptions & config)1207 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1208 {
1209 configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1210 configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1211 configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1212 configuration_.SetNetAddress(config.GetNetAddress());
1213 }
1214
SendRetry(ssl_st * ssl,const char * curPos,size_t curSendSize,int sockfd)1215 bool TLSSocket::TLSSocketInternal::SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd)
1216 {
1217 pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1218 for (int i = 0; i <= SEND_RETRY_TIMES; i++) {
1219 int ret = poll(fds, 1, SEND_POLL_TIMEOUT_MS);
1220 if (ret < 0) {
1221 if (errno == EAGAIN || errno == EINTR) {
1222 continue;
1223 }
1224 NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1225 return false;
1226 } else if (ret == 0) {
1227 NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1228 continue;
1229 }
1230 int len = SSL_write(ssl, curPos, curSendSize);
1231 if (len < 0) {
1232 int err = SSL_get_error(ssl, SSL_RET_CODE);
1233 NETSTACK_LOGE("Error in PollSend, errno is %{public}d %{public}d", err, errno);
1234 if (err == SSL_ERROR_WANT_WRITE || errno == EAGAIN) {
1235 NETSTACK_LOGI("write retry times: %{public}d err: %{public}d errno: %{public}d", i, err, errno);
1236 continue;
1237 } else {
1238 NETSTACK_LOGE("write failed err: %{public}d errno: %{public}d", err, errno);
1239 return false;
1240 }
1241 } else if (len == 0) {
1242 NETSTACK_LOGI("send len is 0, should have sent len");
1243 return false;
1244 } else {
1245 return true;
1246 }
1247 }
1248 return false;
1249 }
1250
PollSend(int sockfd,ssl_st * ssl,const char * pdata,int sendSize)1251 bool TLSSocket::TLSSocketInternal::PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize)
1252 {
1253 int bufferSize = DEFAULT_BUFFER_SIZE;
1254 auto curPos = pdata;
1255 nfds_t num = 1;
1256 pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1257 while (sendSize > 0) {
1258 int ret = poll(fds, num, DEFAULT_POLL_TIMEOUT_MS);
1259 if (ret < 0) {
1260 if (errno == EAGAIN || errno == EINTR) {
1261 continue;
1262 }
1263 NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1264 return false;
1265 } else if (ret == 0) {
1266 NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1267 continue;
1268 }
1269 std::lock_guard<std::mutex> lock(mutexForSsl_);
1270 if (!ssl) {
1271 NETSTACK_LOGE("ssl is null");
1272 return false;
1273 }
1274 size_t curSendSize = std::min<size_t>(sendSize, bufferSize);
1275 int len = SSL_write(ssl, curPos, curSendSize);
1276 if (len < 0) {
1277 int err = SSL_get_error(ssl, SSL_RET_CODE);
1278 NETSTACK_LOGE("Error in PollSend, errno is %{public}d %{public}d", err, errno);
1279 if (err != SSL_ERROR_WANT_WRITE || errno != EAGAIN) {
1280 NETSTACK_LOGE("write failed, return, err: %{public}d errno: %{public}d", err, errno);
1281 return false;
1282 } else if (!SendRetry(ssl, curPos, curSendSize, sockfd)) {
1283 return false;
1284 }
1285 } else if (len == 0) {
1286 NETSTACK_LOGI("send len is 0, should have sent len is %{public}d", sendSize);
1287 return false;
1288 }
1289 curPos += len;
1290 sendSize -= len;
1291 }
1292 return true;
1293 }
1294
Send(const std::string & data)1295 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1296 {
1297 {
1298 std::lock_guard<std::mutex> lock(mutexForSsl_);
1299 if (!ssl_) {
1300 NETSTACK_LOGE("ssl is null");
1301 return false;
1302 }
1303 }
1304
1305 if (data.empty()) {
1306 NETSTACK_LOGE("data is empty");
1307 return true;
1308 }
1309
1310 if (!PollSend(socketDescriptor_, ssl_, data.c_str(), data.size())) {
1311 return false;
1312 }
1313 return true;
1314 }
Recv(char * buffer,int maxBufferSize)1315 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1316 {
1317 std::lock_guard<std::mutex> lock(mutexForSsl_);
1318 if (!ssl_) {
1319 NETSTACK_LOGE("ssl is null");
1320 return SSL_ERROR_RETURN;
1321 }
1322
1323 int ret = SSL_read(ssl_, buffer, maxBufferSize);
1324 if (ret < 0) {
1325 int err = SSL_get_error(ssl_, SSL_RET_CODE);
1326 switch (err) {
1327 case SSL_ERROR_SSL:
1328 NETSTACK_LOGE("An error occurred in the SSL library %{public}d %{public}d", err, errno);
1329 return SSL_ERROR_RETURN;
1330 case SSL_ERROR_ZERO_RETURN:
1331 NETSTACK_LOGE("peer disconnected...");
1332 return SSL_ERROR_RETURN;
1333 case SSL_ERROR_WANT_READ:
1334 NETSTACK_LOGD("SSL_read function no data available for reading, try again at a later time");
1335 return SSL_WANT_READ_RETURN;
1336 default:
1337 NETSTACK_LOGE("SSL_read function failed, error code is %{public}d", err);
1338 return SSL_ERROR_RETURN;
1339 }
1340 }
1341 return ret;
1342 }
1343
Close()1344 bool TLSSocket::TLSSocketInternal::Close()
1345 {
1346 std::lock_guard<std::mutex> lock(mutexForSsl_);
1347 if (!ssl_) {
1348 NETSTACK_LOGE("ssl is null, fd =%{public}d", socketDescriptor_);
1349 return false;
1350 }
1351 int result = SSL_shutdown(ssl_);
1352 if (result < 0) {
1353 int resErr = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1354 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1355 MakeSSLErrorString(resErr).c_str());
1356 }
1357 NETSTACK_LOGI("tls socket close, fd =%{public}d", socketDescriptor_);
1358 SSL_free(ssl_);
1359 ssl_ = nullptr;
1360 close(socketDescriptor_);
1361 socketDescriptor_ = -1;
1362 if (!tlsContextPointer_) {
1363 NETSTACK_LOGE("Tls context pointer is null");
1364 return false;
1365 }
1366 tlsContextPointer_->CloseCtx();
1367 return true;
1368 }
1369
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1370 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1371 {
1372 if (!ssl_) {
1373 NETSTACK_LOGE("ssl is null");
1374 return false;
1375 }
1376 size_t pos = 0;
1377 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1378 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1379 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1380 for (const auto &str : alpnProtocols) {
1381 len = str.length();
1382 result[pos++] = len;
1383 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1384 NETSTACK_LOGE("strcpy_s failed");
1385 return false;
1386 }
1387 pos += len;
1388 }
1389 result[pos] = '\0';
1390
1391 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1392 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1393 int resErr = ConvertSSLError();
1394 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1395 MakeSSLErrorString(resErr).c_str());
1396 return false;
1397 }
1398 return true;
1399 }
1400
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1401 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1402 {
1403 remoteInfo.SetFamily(family_);
1404 remoteInfo.SetAddress(hostName_);
1405 remoteInfo.SetPort(port_);
1406 }
1407
GetTlsConfiguration() const1408 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1409 {
1410 return configuration_;
1411 }
1412
GetCipherSuite() const1413 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1414 {
1415 if (!ssl_) {
1416 NETSTACK_LOGE("ssl in null");
1417 return {};
1418 }
1419 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1420 if (!sk) {
1421 NETSTACK_LOGE("get ciphers failed");
1422 return {};
1423 }
1424 CipherSuite cipherSuite;
1425 std::vector<std::string> cipherSuiteVec;
1426 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1427 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1428 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1429 cipherSuiteVec.push_back(cipherSuite.cipherName_);
1430 }
1431 return cipherSuiteVec;
1432 }
1433
GetRemoteCertificate() const1434 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1435 {
1436 return remoteCert_;
1437 }
1438
GetCertificate() const1439 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1440 {
1441 return configuration_.GetCertificate();
1442 }
1443
GetSignatureAlgorithms() const1444 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1445 {
1446 return signatureAlgorithms_;
1447 }
1448
GetProtocol() const1449 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1450 {
1451 if (!ssl_) {
1452 NETSTACK_LOGE("ssl in null");
1453 return PROTOCOL_UNKNOW;
1454 }
1455 if (configuration_.GetProtocol() == TLS_V1_3) {
1456 return PROTOCOL_TLS_V13;
1457 }
1458 return PROTOCOL_TLS_V12;
1459 }
1460
SetSharedSigals()1461 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1462 {
1463 if (!ssl_) {
1464 NETSTACK_LOGE("ssl is null");
1465 return false;
1466 }
1467 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1468 if (!number) {
1469 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1470 return false;
1471 }
1472 for (int i = 0; i < number; i++) {
1473 int hash_nid;
1474 int sign_nid;
1475 std::string sig_with_md;
1476 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1477 switch (sign_nid) {
1478 case EVP_PKEY_RSA:
1479 sig_with_md = SIGN_NID_RSA;
1480 break;
1481 case EVP_PKEY_RSA_PSS:
1482 sig_with_md = SIGN_NID_RSA_PSS;
1483 break;
1484 case EVP_PKEY_DSA:
1485 sig_with_md = SIGN_NID_DSA;
1486 break;
1487 case EVP_PKEY_EC:
1488 sig_with_md = SIGN_NID_ECDSA;
1489 break;
1490 case NID_ED25519:
1491 sig_with_md = SIGN_NID_ED;
1492 break;
1493 case NID_ED448:
1494 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1495 break;
1496 default:
1497 const char *sn = OBJ_nid2sn(sign_nid);
1498 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1499 }
1500 const char *sn_hash = OBJ_nid2sn(hash_nid);
1501 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1502 signatureAlgorithms_.push_back(sig_with_md);
1503 }
1504 return true;
1505 }
1506
StartTlsConnected(const TLSConnectOptions & options)1507 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1508 {
1509 if (!CreatTlsContext()) {
1510 NETSTACK_LOGE("failed to create tls context");
1511 return false;
1512 }
1513 if (!StartShakingHands(options)) {
1514 NETSTACK_LOGE("failed to shaking hands");
1515 return false;
1516 }
1517 return true;
1518 }
1519
CreatTlsContext()1520 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1521 {
1522 tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1523 if (!tlsContextPointer_) {
1524 NETSTACK_LOGE("failed to create tls context pointer");
1525 return false;
1526 }
1527
1528 std::lock_guard<std::mutex> lock(mutexForSsl_);
1529 if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1530 NETSTACK_LOGE("failed to create ssl session");
1531 return false;
1532 }
1533
1534 SSL_set_fd(ssl_, socketDescriptor_);
1535 SSL_set_connect_state(ssl_);
1536 return true;
1537 }
1538
// Returns true when |s| begins with |prefix| (an empty prefix always matches).
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind() anchored at position 0 can only match at the very beginning.
    return s.rfind(prefix, 0) == 0;
}
1543
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1544 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1545 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1546 {
1547 bool valid = false;
1548 std::string reason = UNKNOW_REASON;
1549 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1550 if (IsIP(hostName)) {
1551 auto it = find(ips.begin(), ips.end(), hostName);
1552 if (it == ips.end()) {
1553 reason = IP + hostName + " is not in the cert's list";
1554 }
1555 result = {valid, reason};
1556 return;
1557 }
1558 std::string tempHostName = "" + hostName;
1559 if (!dnsNames.empty() || index > 0) {
1560 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1561 if (!dnsNames.empty()) {
1562 valid = SeekIntersection(hostParts, dnsNames);
1563 if (!valid) {
1564 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1565 }
1566 } else {
1567 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1568 X509_NAME *pSubName = nullptr;
1569 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1570 if (len > 0) {
1571 std::vector<std::string> commonNameVec;
1572 commonNameVec.emplace_back(commonNameBuf);
1573 valid = SeekIntersection(hostParts, commonNameVec);
1574 if (!valid) {
1575 reason = HOST_NAME + tempHostName + ". is not cert's CN";
1576 }
1577 }
1578 }
1579 result = {valid, reason};
1580 return;
1581 }
1582 reason = "Cert does not contain a DNS name";
1583 result = {valid, reason};
1584 }
1585
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1586 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1587 const X509 *x509Certificates)
1588 {
1589 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1590 if (!subjectName) {
1591 return "subject name is null";
1592 }
1593 char subNameBuf[BUF_SIZE] = {0};
1594 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1595
1596 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1597 if (index < 0) {
1598 return "X509 get ext nid error";
1599 }
1600 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1601 if (ext == nullptr) {
1602 return "X509 get ext error";
1603 }
1604 ASN1_OBJECT *obj = nullptr;
1605 obj = X509_EXTENSION_get_object(ext);
1606 char subAltNameBuf[BUF_SIZE] = {0};
1607 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1608
1609 return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1610 }
1611
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1612 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1613 const X509 *x509Certificates)
1614 {
1615 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1616 if (!extData) {
1617 NETSTACK_LOGE("extData is nullptr");
1618 return "";
1619 }
1620 std::string altNames = reinterpret_cast<char *>(extData->data);
1621 std::string hostname = " " + hostName;
1622 BIO *bio = BIO_new(BIO_s_file());
1623 if (!bio) {
1624 return "bio is null";
1625 }
1626 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1627 ASN1_STRING_print(bio, extData);
1628 std::vector<std::string> dnsNames = {};
1629 std::vector<std::string> ips = {};
1630 constexpr int DNS_NAME_IDX = 4;
1631 constexpr int IP_NAME_IDX = 11;
1632 if (!altNames.empty()) {
1633 std::vector<std::string> splitAltNames;
1634 if (altNames.find('\"') != std::string::npos) {
1635 splitAltNames = SplitEscapedAltNames(altNames);
1636 } else {
1637 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1638 }
1639 for (auto const &iter : splitAltNames) {
1640 if (StartsWith(iter, DNS)) {
1641 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1642 } else if (StartsWith(iter, IP_ADDRESS)) {
1643 ips.push_back(iter.substr(IP_NAME_IDX));
1644 }
1645 }
1646 }
1647 std::tuple<bool, std::string> result;
1648 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1649 if (!std::get<0>(result)) {
1650 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1651 }
1652 return HOST_NAME + hostname + ". is cert's CN";
1653 }
1654
LoadCaCertFromMemory(X509_STORE * store,const std::string & pemCerts)1655 static void LoadCaCertFromMemory(X509_STORE *store, const std::string &pemCerts)
1656 {
1657 if (!store || pemCerts.empty() || pemCerts.size() > static_cast<size_t>(INT_MAX)) {
1658 return;
1659 }
1660
1661 auto cbio = BIO_new_mem_buf(pemCerts.data(), static_cast<int>(pemCerts.size()));
1662 if (!cbio) {
1663 return;
1664 }
1665
1666 auto inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr);
1667 if (!inf) {
1668 BIO_free(cbio);
1669 return;
1670 }
1671
1672 /* add each entry from PEM file to x509_store */
1673 for (int i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); ++i) {
1674 auto itmp = sk_X509_INFO_value(inf, i);
1675 if (!itmp) {
1676 continue;
1677 }
1678 if (itmp->x509) {
1679 X509_STORE_add_cert(store, itmp->x509);
1680 }
1681 if (itmp->crl) {
1682 X509_STORE_add_crl(store, itmp->crl);
1683 }
1684 }
1685
1686 sk_X509_INFO_pop_free(inf, X509_INFO_free);
1687 BIO_free(cbio);
1688 }
1689
X509_to_PEM(X509 * cert)1690 static std::string X509_to_PEM(X509 *cert)
1691 {
1692 if (!cert) {
1693 return {};
1694 }
1695 BIO *bio = BIO_new(BIO_s_mem());
1696 if (!bio) {
1697 return {};
1698 }
1699 if (!PEM_write_bio_X509(bio, cert)) {
1700 BIO_free(bio);
1701 return {};
1702 }
1703
1704 char *data = nullptr;
1705 auto pemStringLength = BIO_get_mem_data(bio, &data);
1706 if (!data) {
1707 BIO_free(bio);
1708 return {};
1709 }
1710 std::string certificateInPEM(data, pemStringLength);
1711 BIO_free(bio);
1712 return certificateInPEM;
1713 }
1714
CacheCertificates(const std::string & hostName,SSL * ssl)1715 static void CacheCertificates(const std::string &hostName, SSL *ssl)
1716 {
1717 if (!ssl || hostName.empty()) {
1718 return;
1719 }
1720 auto certificatesStack = SSL_get_peer_cert_chain(ssl);
1721 if (!certificatesStack) {
1722 return;
1723 }
1724 auto numCertificates = sk_X509_num(certificatesStack);
1725 for (auto i = 0; i < numCertificates; ++i) {
1726 auto cert = sk_X509_value(certificatesStack, i);
1727 auto certificateInPEM = X509_to_PEM(cert);
1728 if (!certificateInPEM.empty()) {
1729 CaCertCache::GetInstance().Set(hostName, certificateInPEM);
1730 }
1731 }
1732 }
1733
LoadCachedCaCert(const std::string & hostName,SSL * ssl)1734 static void LoadCachedCaCert(const std::string &hostName, SSL *ssl)
1735 {
1736 if (!ssl) {
1737 return;
1738 }
1739 auto cachedPem = CaCertCache::GetInstance().Get(hostName);
1740 auto sslCtx = SSL_get_SSL_CTX(ssl);
1741 if (!sslCtx) {
1742 return;
1743 }
1744 auto x509Store = SSL_CTX_get_cert_store(sslCtx);
1745 if (!x509Store) {
1746 return;
1747 }
1748 for (const auto &pem : cachedPem) {
1749 LoadCaCertFromMemory(x509Store, pem);
1750 }
1751 }
1752
StartShakingHands(const TLSConnectOptions & options)1753 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1754 {
1755 {
1756 std::lock_guard<std::mutex> lock(mutexForSsl_);
1757 if (!ssl_) {
1758 NETSTACK_LOGE("ssl is null");
1759 return false;
1760 }
1761
1762 auto hostName = options.GetHostName();
1763 // indicates hostName is not ip address
1764 if (hostName != options.GetNetAddress().GetAddress()) {
1765 LoadCachedCaCert(hostName, ssl_);
1766 }
1767
1768 int result = SSL_connect(ssl_);
1769 if (result == -1) {
1770 char err[MAX_ERR_LEN] = {0};
1771 auto code = ERR_get_error();
1772 ERR_error_string_n(code, err, MAX_ERR_LEN);
1773 int errorStatus = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1774 NETSTACK_LOGE("SSLConnect fail %{public}d, error: %{public}s errno: %{public}d ERR_get_error %{public}s",
1775 errorStatus, MakeSSLErrorString(errorStatus).c_str(), errno, err);
1776 return false;
1777 }
1778
1779 // indicates hostName is not ip address
1780 if (hostName != options.GetNetAddress().GetAddress()) {
1781 CacheCertificates(hostName, ssl_);
1782 }
1783
1784 std::string list = SSL_get_cipher_list(ssl_, 0);
1785 NETSTACK_LOGI("cipher_list: %{public}s, Version: %{public}s, Cipher: %{public}s", list.c_str(),
1786 SSL_get_version(ssl_), SSL_get_cipher(ssl_));
1787 configuration_.SetCipherSuite(list);
1788 }
1789 if (!SetSharedSigals()) {
1790 NETSTACK_LOGE("Failed to set sharedSigalgs");
1791 }
1792 if (!GetRemoteCertificateFromPeer()) {
1793 NETSTACK_LOGE("Failed to get remote certificate");
1794 }
1795 if (!peerX509_) {
1796 NETSTACK_LOGE("peer x509Certificates is null");
1797 return false;
1798 }
1799 if (!SetRemoteCertRawData()) {
1800 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1801 }
1802 CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1803 if (!checkServerIdentity) {
1804 CheckServerIdentityLegal(hostName_, peerX509_);
1805 } else {
1806 checkServerIdentity(hostName_, {remoteCert_});
1807 }
1808 return true;
1809 }
1810
GetRemoteCertificateFromPeer()1811 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1812 {
1813 peerX509_ = SSL_get_peer_certificate(ssl_);
1814 if (peerX509_ == nullptr) {
1815 int resErr = ConvertSSLError();
1816 NETSTACK_LOGE("open fail errno, errno is %{public}d %{public}d", resErr, errno);
1817 return false;
1818 }
1819 BIO *bio = BIO_new(BIO_s_mem());
1820 if (!bio) {
1821 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1822 return false;
1823 }
1824 X509_print(bio, peerX509_);
1825 char data[REMOTE_CERT_LEN] = {0};
1826 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1827 NETSTACK_LOGE("BIO_read function returns error");
1828 BIO_free(bio);
1829 return false;
1830 }
1831 BIO_free(bio);
1832 remoteCert_ = std::string(data);
1833 return true;
1834 }
1835
SetRemoteCertRawData()1836 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1837 {
1838 if (peerX509_ == nullptr) {
1839 NETSTACK_LOGE("peerX509 is null");
1840 return false;
1841 }
1842 int32_t length = i2d_X509(peerX509_, nullptr);
1843 if (length <= 0) {
1844 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1845 return false;
1846 }
1847 unsigned char *der = nullptr;
1848 (void)i2d_X509(peerX509_, &der);
1849 SecureData data(der, length);
1850 remoteRawData_.data = data;
1851 OPENSSL_free(der);
1852 remoteRawData_.encodingFormat = DER;
1853 return true;
1854 }
1855
GetRemoteCertRawData() const1856 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1857 {
1858 return remoteRawData_;
1859 }
1860
GetSSL()1861 ssl_st *TLSSocket::TLSSocketInternal::GetSSL()
1862 {
1863 std::lock_guard<std::mutex> lock(mutexForSsl_);
1864 return ssl_;
1865 }
1866 } // namespace TlsSocket
1867 } // namespace NetStack
1868 } // namespace OHOS
1869