• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket.h"
17 
18 #include <chrono>
19 #include <memory>
20 #include <numeric>
21 #include <regex>
22 #include <securec.h>
23 #include <thread>
24 
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/ssl.h>
28 
29 #include "base_context.h"
30 #include "netstack_common_utils.h"
31 #include "netstack_log.h"
32 #include "tls.h"
33 
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocket {
37 namespace {
// Poll interval and overall limit used by WaitConditionWithTimeout()/Close().
constexpr int WAIT_MS = 10;
constexpr int TIMEOUT_MS = 10000;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
// Passed as the `ret` argument of SSL_get_error() in ConvertSSLError().
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
// Length of the ", " separator between subject-alt-name entries.
constexpr int OFFSET = 2;
// A received message starting with QUIT_RESPONSE_CODE ends the read loop in StartReadMessage().
constexpr int QUIT_RESPONSE_CODE_LEN = 3;
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
constexpr const char *QUIT_RESPONSE_CODE = "221";
// A double-quoted JSON string literal. The pattern previously kept JavaScript's "/.../" regex
// delimiters, which std::regex treats as literal slashes, so it could never match.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-decimal IPv4 address, each octet 0-255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
69 
ConvertErrno()70 int ConvertErrno()
71 {
72     return TlsSocketError::TLS_ERR_SYS_BASE + errno;
73 }
74 
ConvertSSLError(ssl_st * ssl)75 int ConvertSSLError(ssl_st *ssl)
76 {
77     if (!ssl) {
78         return TLS_ERR_SSL_NULL;
79     }
80     return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
81 }
82 
// Returns the human-readable description of the current errno. strerror() may hand back a
// buffer that later calls overwrite, so its text is copied into the returned string at once.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
87 
MakeSSLErrorString(int error)88 std::string MakeSSLErrorString(int error)
89 {
90     char err[MAX_ERR_LEN] = {0};
91     ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
92     return err;
93 }
94 
// Returns the value (0-15) of hex digit `ch`, or -1 if it is not a hex digit.
int JsonHexDigitValue(char ch)
{
    constexpr int DECIMAL_DIGITS = 10;
    if (ch >= '0' && ch <= '9') {
        return ch - '0';
    }
    if (ch >= 'a' && ch <= 'f') {
        return ch - 'a' + DECIMAL_DIGITS;
    }
    if (ch >= 'A' && ch <= 'F') {
        return ch - 'A' + DECIMAL_DIGITS;
    }
    return -1;
}

// Appends the UTF-8 encoding of a code point in 0..0xFFFF to `out`.
// NOTE(review): UTF-16 surrogate pairs from \uD800-\uDFFF escapes are not combined.
void AppendUtf8CodePoint(std::string &out, unsigned int codePoint)
{
    if (codePoint < 0x80) {
        out += static_cast<char>(codePoint);
    } else if (codePoint < 0x800) {
        out += static_cast<char>(0xC0 | (codePoint >> 6));
        out += static_cast<char>(0x80 | (codePoint & 0x3F));
    } else {
        out += static_cast<char>(0xE0 | (codePoint >> 12));
        out += static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F));
        out += static_cast<char>(0x80 | (codePoint & 0x3F));
    }
}

// Parses the JSON string literal whose opening '"' is at text[start]. On success stores the
// unescaped contents in `decoded`, sets `end` just past the closing quote and returns true;
// returns false for malformed or unterminated literals.
bool ParseJsonStringLiteral(const std::string &text, size_t start, std::string &decoded, size_t &end)
{
    constexpr size_t UNICODE_ESCAPE_DIGITS = 4;
    constexpr unsigned int HEX_RADIX = 16;
    constexpr unsigned char FIRST_NON_CONTROL = 0x20;
    decoded.clear();
    size_t pos = start + 1;
    while (pos < text.size()) {
        const unsigned char ch = static_cast<unsigned char>(text[pos]);
        if (ch == '"') {
            end = pos + 1;
            return true;
        }
        if (ch < FIRST_NON_CONTROL) {
            return false; // raw control characters are not allowed in JSON strings
        }
        if (ch != '\\') {
            decoded += static_cast<char>(ch);
            ++pos;
            continue;
        }
        if (pos + 1 >= text.size()) {
            return false;
        }
        const char escape = text[pos + 1];
        pos += 2;
        switch (escape) {
            case '"':
            case '\\':
            case '/':
                decoded += escape;
                break;
            case 'b':
                decoded += '\b';
                break;
            case 'f':
                decoded += '\f';
                break;
            case 'n':
                decoded += '\n';
                break;
            case 'r':
                decoded += '\r';
                break;
            case 't':
                decoded += '\t';
                break;
            case 'u': {
                if (pos + UNICODE_ESCAPE_DIGITS > text.size()) {
                    return false;
                }
                unsigned int codePoint = 0;
                for (size_t i = 0; i < UNICODE_ESCAPE_DIGITS; ++i) {
                    const int digit = JsonHexDigitValue(text[pos + i]);
                    if (digit < 0) {
                        return false;
                    }
                    codePoint = codePoint * HEX_RADIX + static_cast<unsigned int>(digit);
                }
                pos += UNICODE_ESCAPE_DIGITS;
                AppendUtf8CodePoint(decoded, codePoint);
                break;
            }
            default:
                return false;
        }
    }
    return false;
}

// Splits `altNames` on ", " into separate entries (presumably the text form of a certificate's
// subject alternative names — confirm against callers). A double-quoted section is a JSON string
// literal that may itself contain ", "; it is unescaped and appended to the current entry.
// Returns {""} if a quoted section is malformed. `altNames` is not modified.
std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
{
    const std::string separator = ", ";
    std::vector<std::string> result;
    std::string currentToken;
    size_t offset = 0;
    while (offset < altNames.length()) {
        const size_t nextSep = altNames.find(separator, offset);
        const size_t nextQuote = altNames.find('"', offset);
        if (nextQuote != std::string::npos && (nextSep == std::string::npos || nextQuote < nextSep)) {
            // A quoted section begins before the next separator.
            currentToken += altNames.substr(offset, nextQuote - offset);
            std::string unescaped;
            size_t end = 0;
            if (!ParseJsonStringLiteral(altNames, nextQuote, unescaped, end)) {
                return {""};
            }
            currentToken += unescaped;
            offset = end;
        } else if (nextSep != std::string::npos) {
            // An unquoted separator ends the current entry.
            currentToken += altNames.substr(offset, nextSep - offset);
            result.push_back(currentToken);
            currentToken.clear();
            offset = nextSep + separator.size();
        } else {
            // No separator or quote remains: the rest belongs to the current entry.
            currentToken += altNames.substr(offset);
            offset = altNames.length();
        }
    }
    result.push_back(currentToken);
    return result;
}
127 
IsIP(const std::string & ip)128 bool IsIP(const std::string &ip)
129 {
130     std::regex pattern(PATTERN);
131     std::smatch res;
132     return regex_match(ip, res, pattern);
133 }
134 
SplitHostName(std::string & hostName)135 std::vector<std::string> SplitHostName(std::string &hostName)
136 {
137     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
138     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
139 }
140 
// Returns true when vecA and vecB share at least one element.
// std::set_intersection requires sorted ranges and the inputs are not guaranteed to be sorted,
// so sorted copies are intersected. The callers' vectors are left unchanged.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    std::vector<std::string> sortedA(vecA);
    std::vector<std::string> sortedB(vecB);
    std::sort(sortedA.begin(), sortedA.end());
    std::sort(sortedB.begin(), sortedB.end());
    std::vector<std::string> common;
    std::set_intersection(sortedA.begin(), sortedA.end(), sortedB.begin(), sortedB.end(),
                          std::back_inserter(common));
    return !common.empty();
}
147 } // namespace
148 
// Copy constructor: members start default-constructed and are then filled in by the
// copy-assignment operator, so copy construction and copy assignment share one implementation.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
153 
// Copy-assigns every option field from `tlsSecureOptions`.
// Self-assignment is a no-op. key_ was previously assigned twice; it is now copied once.
// Any field added to TLSSecureOptions must also be copied here.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
169 
SetCaChain(const std::vector<std::string> & caChain)170 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
171 {
172     caChain_ = caChain;
173 }
174 
SetCert(const std::string & cert)175 void TLSSecureOptions::SetCert(const std::string &cert)
176 {
177     cert_ = cert;
178 }
179 
SetKey(const SecureData & key)180 void TLSSecureOptions::SetKey(const SecureData &key)
181 {
182     key_ = key;
183 }
184 
SetKeyPass(const SecureData & keyPass)185 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
186 {
187     keyPass_ = keyPass;
188 }
189 
SetProtocolChain(const std::vector<std::string> & protocolChain)190 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
191 {
192     protocolChain_ = protocolChain;
193 }
194 
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)195 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
196 {
197     useRemoteCipherPrefer_ = useRemoteCipherPrefer;
198 }
199 
SetSignatureAlgorithms(const std::string & signatureAlgorithms)200 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
201 {
202     signatureAlgorithms_ = signatureAlgorithms;
203 }
204 
SetCipherSuite(const std::string & cipherSuite)205 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
206 {
207     cipherSuite_ = cipherSuite;
208 }
209 
SetCrlChain(const std::vector<std::string> & crlChain)210 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
211 {
212     crlChain_ = crlChain;
213 }
214 
GetCaChain() const215 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
216 {
217     return caChain_;
218 }
219 
GetCert() const220 const std::string &TLSSecureOptions::GetCert() const
221 {
222     return cert_;
223 }
224 
GetKey() const225 const SecureData &TLSSecureOptions::GetKey() const
226 {
227     return key_;
228 }
229 
GetKeyPass() const230 const SecureData &TLSSecureOptions::GetKeyPass() const
231 {
232     return keyPass_;
233 }
234 
GetProtocolChain() const235 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
236 {
237     return protocolChain_;
238 }
239 
UseRemoteCipherPrefer() const240 bool TLSSecureOptions::UseRemoteCipherPrefer() const
241 {
242     return useRemoteCipherPrefer_;
243 }
244 
GetSignatureAlgorithms() const245 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
246 {
247     return signatureAlgorithms_;
248 }
249 
GetCipherSuite() const250 const std::string &TLSSecureOptions::GetCipherSuite() const
251 {
252     return cipherSuite_;
253 }
254 
GetCrlChain() const255 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
256 {
257     return crlChain_;
258 }
259 
SetVerifyMode(VerifyMode verifyMode)260 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
261 {
262     TLSVerifyMode_ = verifyMode;
263 }
264 
GetVerifyMode() const265 VerifyMode TLSSecureOptions::GetVerifyMode() const
266 {
267     return TLSVerifyMode_;
268 }
269 
SetNetAddress(const Socket::NetAddress & address)270 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
271 {
272     address_.SetAddress(address.GetAddress());
273     address_.SetPort(address.GetPort());
274     address_.SetFamilyBySaFamily(address.GetSaFamily());
275 }
276 
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)277 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
278 {
279     tlsSecureOptions_ = tlsSecureOptions;
280 }
281 
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)282 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
283 {
284     checkServerIdentity_ = checkServerIdentity;
285 }
286 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)287 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
288 {
289     alpnProtocols_ = alpnProtocols;
290 }
291 
GetNetAddress() const292 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
293 {
294     return address_;
295 }
296 
GetTlsSecureOptions() const297 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
298 {
299     return tlsSecureOptions_;
300 }
301 
GetCheckServerIdentity() const302 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
303 {
304     return checkServerIdentity_;
305 }
306 
GetAlpnProtocols() const307 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
308 {
309     return alpnProtocols_;
310 }
311 
MakeAddressString(sockaddr * addr)312 std::string TLSSocket::MakeAddressString(sockaddr *addr)
313 {
314     if (!addr) {
315         return {};
316     }
317     if (addr->sa_family == AF_INET) {
318         auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
319         const char *str = inet_ntoa(addr4->sin_addr);
320         if (str == nullptr || strlen(str) == 0) {
321             return {};
322         }
323         return str;
324     } else if (addr->sa_family == AF_INET6) {
325         auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
326         char str[INET6_ADDRSTRLEN] = {0};
327         if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
328             return {};
329         }
330         return str;
331     }
332     return {};
333 }
334 
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)335 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
336                         socklen_t *len)
337 {
338     if (!addr6 || !addr4 || !len) {
339         return;
340     }
341     sa_family_t family = address.GetSaFamily();
342     if (family == AF_INET) {
343         addr4->sin_family = AF_INET;
344         addr4->sin_port = htons(address.GetPort());
345         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
346         *addr = reinterpret_cast<sockaddr *>(addr4);
347         *len = sizeof(sockaddr_in);
348     } else if (family == AF_INET6) {
349         addr6->sin6_family = AF_INET6;
350         addr6->sin6_port = htons(address.GetPort());
351         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
352         *addr = reinterpret_cast<sockaddr *>(addr6);
353         *len = sizeof(sockaddr_in6);
354     }
355 }
356 
MakeIpSocket(sa_family_t family)357 void TLSSocket::MakeIpSocket(sa_family_t family)
358 {
359     if (family != AF_INET && family != AF_INET6) {
360         return;
361     }
362     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
363     if (sock < 0) {
364         int resErr = ConvertErrno();
365         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
366         CallOnErrorCallback(resErr, MakeErrnoString());
367         return;
368     }
369     sockFd_ = sock;
370 }
371 
StartReadMessage()372 void TLSSocket::StartReadMessage()
373 {
374     std::thread thread([this]() {
375         isRunning_ = true;
376         isRunOver_ = false;
377         while (isRunning_) {
378             char buffer[MAX_BUFFER_SIZE];
379             if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
380                 NETSTACK_LOGE("memcpy_s failed!");
381                 break;
382             }
383             int len = tlsSocketInternal_.Recv(buffer, MAX_BUFFER_SIZE);
384             if (!isRunning_) {
385                 break;
386             }
387             if (len < 0) {
388                 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
389                 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
390                               MakeSSLErrorString(resErr).c_str());
391                 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
392                 break;
393             }
394 
395             if (len == 0) {
396                 continue;
397             }
398 
399             Socket::SocketRemoteInfo remoteInfo;
400             remoteInfo.SetSize(strlen(buffer));
401             tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
402             CallOnMessageCallback(buffer, remoteInfo);
403             if (strncmp(buffer, QUIT_RESPONSE_CODE, QUIT_RESPONSE_CODE_LEN) == 0) {
404                 break;
405             }
406         }
407         isRunOver_ = true;
408     });
409     thread.detach();
410 }
411 
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)412 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
413 {
414     OnMessageCallback func = nullptr;
415     {
416         std::lock_guard<std::mutex> lock(mutex_);
417         if (onMessageCallback_) {
418             func = onMessageCallback_;
419         }
420     }
421 
422     if (func) {
423         func(data, remoteInfo);
424     }
425 }
426 
CallOnConnectCallback()427 void TLSSocket::CallOnConnectCallback()
428 {
429     OnConnectCallback func = nullptr;
430     {
431         std::lock_guard<std::mutex> lock(mutex_);
432         if (onConnectCallback_) {
433             func = onConnectCallback_;
434         }
435     }
436 
437     if (func) {
438         func();
439     }
440 }
441 
CallOnCloseCallback()442 void TLSSocket::CallOnCloseCallback()
443 {
444     OnCloseCallback func = nullptr;
445     {
446         std::lock_guard<std::mutex> lock(mutex_);
447         if (onCloseCallback_) {
448             func = onCloseCallback_;
449         }
450     }
451 
452     if (func) {
453         func();
454     }
455 }
456 
CallOnErrorCallback(int32_t err,const std::string & errString)457 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
458 {
459     OnErrorCallback func = nullptr;
460     {
461         std::lock_guard<std::mutex> lock(mutex_);
462         if (onErrorCallback_) {
463             func = onErrorCallback_;
464         }
465     }
466 
467     if (func) {
468         func(err, errString);
469     }
470 }
471 
CallBindCallback(int32_t err,BindCallback callback)472 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
473 {
474     DealCallback<BindCallback>(err, callback);
475 }
476 
CallConnectCallback(int32_t err,ConnectCallback callback)477 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
478 {
479     DealCallback<ConnectCallback>(err, callback);
480 }
481 
CallSendCallback(int32_t err,SendCallback callback)482 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
483 {
484     DealCallback<SendCallback>(err, callback);
485 }
486 
CallCloseCallback(int32_t err,CloseCallback callback)487 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
488 {
489     DealCallback<CloseCallback>(err, callback);
490 }
491 
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)492 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
493                                              GetRemoteAddressCallback callback)
494 {
495     GetRemoteAddressCallback func = nullptr;
496     {
497         std::lock_guard<std::mutex> lock(mutex_);
498         if (callback) {
499             func = callback;
500         }
501     }
502 
503     if (func) {
504         func(err, address);
505     }
506 }
507 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)508 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
509 {
510     GetStateCallback func = nullptr;
511     {
512         std::lock_guard<std::mutex> lock(mutex_);
513         if (callback) {
514             func = callback;
515         }
516     }
517 
518     if (func) {
519         func(err, state);
520     }
521 }
522 
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)523 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
524 {
525     DealCallback<SetExtraOptionsCallback>(err, callback);
526 }
527 
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)528 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
529 {
530     GetCertificateCallback func = nullptr;
531     {
532         std::lock_guard<std::mutex> lock(mutex_);
533         if (callback) {
534             func = callback;
535         }
536     }
537 
538     if (func) {
539         func(err, cert);
540     }
541 }
542 
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)543 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
544                                                  GetRemoteCertificateCallback callback)
545 {
546     GetRemoteCertificateCallback func = nullptr;
547     {
548         std::lock_guard<std::mutex> lock(mutex_);
549         if (callback) {
550             func = callback;
551         }
552     }
553 
554     if (func) {
555         func(err, cert);
556     }
557 }
558 
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)559 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
560 {
561     GetProtocolCallback func = nullptr;
562     {
563         std::lock_guard<std::mutex> lock(mutex_);
564         if (callback) {
565             func = callback;
566         }
567     }
568 
569     if (func) {
570         func(err, protocol);
571     }
572 }
573 
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)574 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
575                                            GetCipherSuiteCallback callback)
576 {
577     GetCipherSuiteCallback func = nullptr;
578     {
579         std::lock_guard<std::mutex> lock(mutex_);
580         if (callback) {
581             func = callback;
582         }
583     }
584 
585     if (func) {
586         func(err, suite);
587     }
588 }
589 
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)590 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
591                                                    GetSignatureAlgorithmsCallback callback)
592 {
593     GetSignatureAlgorithmsCallback func = nullptr;
594     {
595         std::lock_guard<std::mutex> lock(mutex_);
596         if (callback) {
597             func = callback;
598         }
599     }
600 
601     if (func) {
602         func(err, algorithms);
603     }
604 }
605 
Bind(const Socket::NetAddress & address,const BindCallback & callback)606 void TLSSocket::Bind(const Socket::NetAddress &address, const BindCallback &callback)
607 {
608     if (!CommonUtils::HasInternetPermission()) {
609         CallBindCallback(PERMISSION_DENIED_CODE, callback);
610         return;
611     }
612     if (sockFd_ >= 0) {
613         CallBindCallback(TLSSOCKET_SUCCESS, callback);
614         return;
615     }
616 
617     MakeIpSocket(address.GetSaFamily());
618     if (sockFd_ < 0) {
619         int resErr = ConvertErrno();
620         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
621         CallOnErrorCallback(resErr, MakeErrnoString());
622         CallBindCallback(resErr, callback);
623         return;
624     }
625 
626     sockaddr_in addr4 = {0};
627     sockaddr_in6 addr6 = {0};
628     sockaddr *addr = nullptr;
629     socklen_t len;
630     GetAddr(address, &addr4, &addr6, &addr, &len);
631     if (addr == nullptr) {
632         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
633         CallOnErrorCallback(-1, "Address Is Invalid");
634         CallBindCallback(ConvertErrno(), callback);
635         return;
636     }
637     CallBindCallback(TLSSOCKET_SUCCESS, callback);
638 }
639 
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)640 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
641                         const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
642 {
643     if (sockFd_ < 0) {
644         int resErr = ConvertErrno();
645         NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
646         CallOnErrorCallback(resErr, MakeErrnoString());
647         callback(resErr);
648         return;
649     }
650 
651     auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions);
652     if (!res) {
653         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
654         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
655         callback(resErr);
656         return;
657     }
658     StartReadMessage();
659     CallOnConnectCallback();
660     callback(TLSSOCKET_SUCCESS);
661 }
662 
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)663 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
664 {
665     (void)tcpSendOptions;
666 
667     auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
668     if (!res) {
669         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
670         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
671         CallSendCallback(resErr, callback);
672         return;
673     }
674     CallSendCallback(TLSSOCKET_SUCCESS, callback);
675 }
676 
WaitConditionWithTimeout(const bool * flag,const int32_t timeoutMs)677 bool WaitConditionWithTimeout(const bool *flag, const int32_t timeoutMs)
678 {
679     int maxWaitCnt = timeoutMs / WAIT_MS;
680     int cnt = 0;
681     while (!(*flag)) {
682         if (cnt >= maxWaitCnt) {
683             return false;
684         }
685         std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MS));
686         cnt++;
687     }
688     return true;
689 }
690 
Close(const CloseCallback & callback)691 void TLSSocket::Close(const CloseCallback &callback)
692 {
693     if (!WaitConditionWithTimeout(&isRunning_, TIMEOUT_MS)) {
694         callback(ConvertErrno());
695         NETSTACK_LOGE("The error cause is that the runtime wait time is insufficient");
696         return;
697     }
698     isRunning_ = false;
699 
700     auto res = tlsSocketInternal_.Close();
701     if (!res) {
702         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
703         NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
704         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
705         callback(resErr);
706         return;
707     }
708     CallOnCloseCallback();
709     callback(TLSSOCKET_SUCCESS);
710 }
711 
GetRemoteAddress(const GetRemoteAddressCallback & callback)712 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
713 {
714     sockaddr sockAddr = {0};
715     socklen_t len = sizeof(sockaddr);
716     int ret = getsockname(sockFd_, &sockAddr, &len);
717     if (ret < 0) {
718         int resErr = ConvertErrno();
719         NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
720         CallOnErrorCallback(resErr, MakeErrnoString());
721         CallGetRemoteAddressCallback(resErr, {}, callback);
722         return;
723     }
724 
725     if (sockAddr.sa_family == AF_INET) {
726         GetIp4RemoteAddress(callback);
727     } else if (sockAddr.sa_family == AF_INET6) {
728         GetIp6RemoteAddress(callback);
729     }
730 }
731 
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)732 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
733 {
734     sockaddr_in addr4 = {0};
735     socklen_t len4 = sizeof(sockaddr_in);
736 
737     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
738     if (ret < 0) {
739         int resErr = ConvertErrno();
740         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
741         CallOnErrorCallback(resErr, MakeErrnoString());
742         CallGetRemoteAddressCallback(resErr, {}, callback);
743         return;
744     }
745 
746     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
747     if (address.empty()) {
748         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
749         CallOnErrorCallback(-1, "Address is invalid");
750         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
751         return;
752     }
753     Socket::NetAddress netAddress;
754     netAddress.SetAddress(address);
755     netAddress.SetFamilyBySaFamily(AF_INET);
756     netAddress.SetPort(ntohs(addr4.sin_port));
757     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
758 }
759 
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)760 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
761 {
762     sockaddr_in6 addr6 = {0};
763     socklen_t len6 = sizeof(sockaddr_in6);
764 
765     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
766     if (ret < 0) {
767         int resErr = ConvertErrno();
768         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
769         CallOnErrorCallback(resErr, MakeErrnoString());
770         CallGetRemoteAddressCallback(resErr, {}, callback);
771         return;
772     }
773 
774     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
775     if (address.empty()) {
776         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
777         CallOnErrorCallback(-1, "Address is invalid");
778         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
779         return;
780     }
781     Socket::NetAddress netAddress;
782     netAddress.SetAddress(address);
783     netAddress.SetFamilyBySaFamily(AF_INET6);
784     netAddress.SetPort(ntohs(addr6.sin6_port));
785     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
786 }
787 
GetState(const GetStateCallback & callback)788 void TLSSocket::GetState(const GetStateCallback &callback)
789 {
790     int opt;
791     socklen_t optLen = sizeof(int);
792     int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
793     if (r < 0) {
794         Socket::SocketStateBase state;
795         state.SetIsClose(true);
796         CallGetStateCallback(ConvertErrno(), state, callback);
797         return;
798     }
799     sockaddr sockAddr = {0};
800     socklen_t len = sizeof(sockaddr);
801     Socket::SocketStateBase state;
802     int ret = getsockname(sockFd_, &sockAddr, &len);
803     state.SetIsBound(ret == 0);
804     ret = getpeername(sockFd_, &sockAddr, &len);
805     state.SetIsConnected(ret == 0);
806     CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
807 }
808 
SetBaseOptions(const Socket::ExtraOptionsBase & option) const809 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
810 {
811     if (option.GetReceiveBufferSize() != 0) {
812         int size = (int)option.GetReceiveBufferSize();
813         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
814             return false;
815         }
816     }
817 
818     if (option.GetSendBufferSize() != 0) {
819         int size = (int)option.GetSendBufferSize();
820         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
821             return false;
822         }
823     }
824 
825     if (option.IsReuseAddress()) {
826         int reuse = 1;
827         if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
828             return false;
829         }
830     }
831 
832     if (option.GetSocketTimeout() != 0) {
833         timeval timeout = {(int)option.GetSocketTimeout(), 0};
834         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
835             return false;
836         }
837         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
838             return false;
839         }
840     }
841 
842     return true;
843 }
844 
SetExtraOptions(const Socket::TCPExtraOptions & option) const845 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
846 {
847     if (option.IsKeepAlive()) {
848         int keepalive = 1;
849         if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
850             return false;
851         }
852     }
853 
854     if (option.IsOOBInline()) {
855         int oobInline = 1;
856         if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
857             return false;
858         }
859     }
860 
861     if (option.IsTCPNoDelay()) {
862         int tcpNoDelay = 1;
863         if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
864             return false;
865         }
866     }
867 
868     linger soLinger = {0};
869     soLinger.l_onoff = option.socketLinger.IsOn();
870     soLinger.l_linger = (int)option.socketLinger.GetLinger();
871     if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
872         return false;
873     }
874 
875     return true;
876 }
877 
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)878 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
879                                 const SetExtraOptionsCallback &callback)
880 {
881     if (!SetBaseOptions(tcpExtraOptions)) {
882         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
883         CallOnErrorCallback(errno, MakeErrnoString());
884         CallSetExtraOptionsCallback(ConvertErrno(), callback);
885         return;
886     }
887 
888     if (!SetExtraOptions(tcpExtraOptions)) {
889         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
890         CallOnErrorCallback(errno, MakeErrnoString());
891         CallSetExtraOptionsCallback(ConvertErrno(), callback);
892         return;
893     }
894 
895     CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
896 }
897 
GetCertificate(const GetCertificateCallback & callback)898 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
899 {
900     const auto &cert = tlsSocketInternal_.GetCertificate();
901     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
902 
903     if (!cert.data.Length()) {
904         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
905         NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
906         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
907         callback(resErr, {});
908         return;
909     }
910     callback(TLSSOCKET_SUCCESS, cert);
911 }
912 
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)913 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
914 {
915     const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
916     if (!remoteCert.data.Length()) {
917         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
918         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
919         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
920         callback(resErr, {});
921         return;
922     }
923     callback(TLSSOCKET_SUCCESS, remoteCert);
924 }
925 
GetProtocol(const GetProtocolCallback & callback)926 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
927 {
928     const auto &protocol = tlsSocketInternal_.GetProtocol();
929     if (protocol.empty()) {
930         NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
931         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
932         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
933         callback(resErr, "");
934         return;
935     }
936     callback(TLSSOCKET_SUCCESS, protocol);
937 }
938 
GetCipherSuite(const GetCipherSuiteCallback & callback)939 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
940 {
941     const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
942     if (cipherSuite.empty()) {
943         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
944         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
945         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
946         callback(resErr, cipherSuite);
947         return;
948     }
949     callback(TLSSOCKET_SUCCESS, cipherSuite);
950 }
951 
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)952 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
953 {
954     const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
955     if (signatureAlgorithms.empty()) {
956         NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
957         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
958         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
959         callback(resErr, {});
960         return;
961     }
962     callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
963 }
964 
OnMessage(const OnMessageCallback & onMessageCallback)965 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
966 {
967     std::lock_guard<std::mutex> lock(mutex_);
968     onMessageCallback_ = onMessageCallback;
969 }
970 
OffMessage()971 void TLSSocket::OffMessage()
972 {
973     std::lock_guard<std::mutex> lock(mutex_);
974     if (onMessageCallback_) {
975         onMessageCallback_ = nullptr;
976     }
977 }
978 
OnConnect(const OnConnectCallback & onConnectCallback)979 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
980 {
981     std::lock_guard<std::mutex> lock(mutex_);
982     onConnectCallback_ = onConnectCallback;
983 }
984 
OffConnect()985 void TLSSocket::OffConnect()
986 {
987     std::lock_guard<std::mutex> lock(mutex_);
988     if (onConnectCallback_) {
989         onConnectCallback_ = nullptr;
990     }
991 }
992 
OnClose(const OnCloseCallback & onCloseCallback)993 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
994 {
995     std::lock_guard<std::mutex> lock(mutex_);
996     onCloseCallback_ = onCloseCallback;
997 }
998 
OffClose()999 void TLSSocket::OffClose()
1000 {
1001     std::lock_guard<std::mutex> lock(mutex_);
1002     if (onCloseCallback_) {
1003         onCloseCallback_ = nullptr;
1004     }
1005 }
1006 
OnError(const OnErrorCallback & onErrorCallback)1007 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1008 {
1009     std::lock_guard<std::mutex> lock(mutex_);
1010     onErrorCallback_ = onErrorCallback;
1011 }
1012 
OffError()1013 void TLSSocket::OffError()
1014 {
1015     std::lock_guard<std::mutex> lock(mutex_);
1016     if (onErrorCallback_) {
1017         onErrorCallback_ = nullptr;
1018     }
1019 }
1020 
ExecSocketConnect(const std::string & hostName,int port,sa_family_t family,int socketDescriptor)1021 bool ExecSocketConnect(const std::string &hostName, int port, sa_family_t family, int socketDescriptor)
1022 {
1023     struct sockaddr_in dest = {0};
1024     dest.sin_family = family;
1025     dest.sin_port = htons(port);
1026     if (!inet_aton(hostName.c_str(), reinterpret_cast<in_addr *>(&dest.sin_addr.s_addr))) {
1027         NETSTACK_LOGE("inet_aton is error, hostName is %s", hostName.c_str());
1028         return false;
1029     }
1030     int connectResult = connect(socketDescriptor, reinterpret_cast<struct sockaddr *>(&dest), sizeof(dest));
1031     if (connectResult == -1) {
1032         NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1033                       strerror(errno));
1034         return false;
1035     }
1036     return true;
1037 }
1038 
TlsConnectToHost(int sock,const TLSConnectOptions & options)1039 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options)
1040 {
1041     SetTlsConfiguration(options);
1042     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1043     if (!cipherSuite.empty()) {
1044         configuration_.SetCipherSuite(cipherSuite);
1045     }
1046     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1047     if (!signatureAlgorithms.empty()) {
1048         configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1049     }
1050     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1051     if (!protocolVec.empty()) {
1052         configuration_.SetProtocol(protocolVec);
1053     }
1054 
1055     hostName_ = options.GetNetAddress().GetAddress();
1056     port_ = options.GetNetAddress().GetPort();
1057     family_ = options.GetNetAddress().GetSaFamily();
1058     socketDescriptor_ = sock;
1059     if (!ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1060                            options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1061         return false;
1062     }
1063     return StartTlsConnected(options);
1064 }
1065 
SetTlsConfiguration(const TLSConnectOptions & config)1066 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1067 {
1068     configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1069     configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1070     configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1071 }
1072 
Send(const std::string & data)1073 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1074 {
1075     NETSTACK_LOGD("data to send :%{public}s", data.c_str());
1076     if (data.empty()) {
1077         NETSTACK_LOGE("data is empty");
1078         return false;
1079     }
1080     if (!ssl_) {
1081         NETSTACK_LOGE("ssl is null");
1082         return false;
1083     }
1084     int len = SSL_write(ssl_, data.c_str(), data.length());
1085     if (len < 0) {
1086         int resErr = ConvertSSLError(GetSSL());
1087         NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
1088                       data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
1089         return false;
1090     }
1091     NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
1092     return true;
1093 }
Recv(char * buffer,int maxBufferSize)1094 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1095 {
1096     if (!ssl_) {
1097         NETSTACK_LOGE("ssl is null");
1098         return SSL_ERROR_RETURN;
1099     }
1100     return SSL_read(ssl_, buffer, maxBufferSize);
1101 }
1102 
Close()1103 bool TLSSocket::TLSSocketInternal::Close()
1104 {
1105     if (!ssl_) {
1106         NETSTACK_LOGE("ssl is null");
1107         return false;
1108     }
1109     int result = SSL_shutdown(ssl_);
1110     if (result < 0) {
1111         int resErr = ConvertSSLError(GetSSL());
1112         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1113                       MakeSSLErrorString(resErr).c_str());
1114         return false;
1115     }
1116     SSL_free(ssl_);
1117     ssl_ = nullptr;
1118     close(socketDescriptor_);
1119     socketDescriptor_ = -1;
1120     if (!tlsContextPointer_) {
1121         NETSTACK_LOGE("Tls context pointer is null");
1122         return false;
1123     }
1124     tlsContextPointer_->CloseCtx();
1125     return true;
1126 }
1127 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1128 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1129 {
1130     if (!ssl_) {
1131         NETSTACK_LOGE("ssl is null");
1132         return false;
1133     }
1134     size_t pos = 0;
1135     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1136                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1137     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1138     for (const auto &str : alpnProtocols) {
1139         len = str.length();
1140         result[pos++] = len;
1141         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1142             NETSTACK_LOGE("strcpy_s failed");
1143             return false;
1144         }
1145         pos += len;
1146     }
1147     result[pos] = '\0';
1148 
1149     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1150     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1151         int resErr = ConvertSSLError(GetSSL());
1152         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1153                       MakeSSLErrorString(resErr).c_str());
1154         return false;
1155     }
1156     return true;
1157 }
1158 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1159 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1160 {
1161     remoteInfo.SetAddress(hostName_);
1162     remoteInfo.SetPort(port_);
1163     remoteInfo.SetFamily(family_);
1164 }
1165 
GetTlsConfiguration() const1166 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1167 {
1168     return configuration_;
1169 }
1170 
GetCipherSuite() const1171 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1172 {
1173     if (!ssl_) {
1174         NETSTACK_LOGE("ssl in null");
1175         return {};
1176     }
1177     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1178     if (!sk) {
1179         NETSTACK_LOGE("get ciphers failed");
1180         return {};
1181     }
1182     CipherSuite cipherSuite;
1183     std::vector<std::string> cipherSuiteVec;
1184     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1185         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1186         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1187         cipherSuiteVec.push_back(cipherSuite.cipherName_);
1188     }
1189     return cipherSuiteVec;
1190 }
1191 
GetRemoteCertificate() const1192 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1193 {
1194     return remoteCert_;
1195 }
1196 
GetCertificate() const1197 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1198 {
1199     return configuration_.GetCertificate();
1200 }
1201 
GetSignatureAlgorithms() const1202 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1203 {
1204     return signatureAlgorithms_;
1205 }
1206 
GetProtocol() const1207 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1208 {
1209     if (!ssl_) {
1210         NETSTACK_LOGE("ssl in null");
1211         return PROTOCOL_UNKNOW;
1212     }
1213     if (configuration_.GetProtocol() == TLS_V1_3) {
1214         return PROTOCOL_TLS_V13;
1215     }
1216     return PROTOCOL_TLS_V12;
1217 }
1218 
SetSharedSigals()1219 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1220 {
1221     if (!ssl_) {
1222         NETSTACK_LOGE("ssl is null");
1223         return false;
1224     }
1225     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1226     if (!number) {
1227         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1228         return false;
1229     }
1230     for (int i = 0; i < number; i++) {
1231         int hash_nid;
1232         int sign_nid;
1233         std::string sig_with_md;
1234         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1235         switch (sign_nid) {
1236             case EVP_PKEY_RSA:
1237                 sig_with_md = SIGN_NID_RSA;
1238                 break;
1239             case EVP_PKEY_RSA_PSS:
1240                 sig_with_md = SIGN_NID_RSA_PSS;
1241                 break;
1242             case EVP_PKEY_DSA:
1243                 sig_with_md = SIGN_NID_DSA;
1244                 break;
1245             case EVP_PKEY_EC:
1246                 sig_with_md = SIGN_NID_ECDSA;
1247                 break;
1248             case NID_ED25519:
1249                 sig_with_md = SIGN_NID_ED;
1250                 break;
1251             case NID_ED448:
1252                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1253                 break;
1254             default:
1255                 const char *sn = OBJ_nid2sn(sign_nid);
1256                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1257         }
1258         const char *sn_hash = OBJ_nid2sn(hash_nid);
1259         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1260         signatureAlgorithms_.push_back(sig_with_md);
1261     }
1262     return true;
1263 }
1264 
StartTlsConnected(const TLSConnectOptions & options)1265 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1266 {
1267     if (!CreatTlsContext()) {
1268         NETSTACK_LOGE("failed to create tls context");
1269         return false;
1270     }
1271     if (!StartShakingHands(options)) {
1272         NETSTACK_LOGE("failed to shaking hands");
1273         return false;
1274     }
1275     return true;
1276 }
1277 
CreatTlsContext()1278 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1279 {
1280     tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1281     if (!tlsContextPointer_) {
1282         NETSTACK_LOGE("failed to create tls context pointer");
1283         return false;
1284     }
1285     if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1286         NETSTACK_LOGE("failed to create ssl session");
1287         return false;
1288     }
1289     SSL_set_fd(ssl_, socketDescriptor_);
1290     SSL_set_connect_state(ssl_);
1291     return true;
1292 }
1293 
// True when s begins with prefix; an empty prefix always matches.
// rfind anchored at position 0 can only report a match at index 0.
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    return s.rfind(prefix, 0) == 0;
}
1298 
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1299 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1300                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1301 {
1302     bool valid = false;
1303     std::string reason = UNKNOW_REASON;
1304     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1305     if (IsIP(hostName)) {
1306         auto it = find(ips.begin(), ips.end(), hostName);
1307         if (it == ips.end()) {
1308             reason = IP + hostName + " is not in the cert's list";
1309         }
1310         result = {valid, reason};
1311         return;
1312     }
1313     std::string tempHostName = "" + hostName;
1314     if (!dnsNames.empty() || index > 0) {
1315         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1316         if (!dnsNames.empty()) {
1317             valid = SeekIntersection(hostParts, dnsNames);
1318             if (!valid) {
1319                 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1320             }
1321         } else {
1322             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1323             X509_NAME *pSubName = nullptr;
1324             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1325             if (len > 0) {
1326                 std::vector<std::string> commonNameVec;
1327                 commonNameVec.emplace_back(commonNameBuf);
1328                 valid = SeekIntersection(hostParts, commonNameVec);
1329                 if (!valid) {
1330                     reason = HOST_NAME + tempHostName + ". is not cert's CN";
1331                 }
1332             }
1333         }
1334         result = {valid, reason};
1335         return;
1336     }
1337     reason = "Cert does not contain a DNS name";
1338     result = {valid, reason};
1339 }
1340 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1341 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1342                                                                    const X509 *x509Certificates)
1343 {
1344     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1345     if (!subjectName) {
1346         return "subject name is null";
1347     }
1348     char subNameBuf[BUF_SIZE] = {0};
1349     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1350 
1351     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1352     if (index < 0) {
1353         return "X509 get ext nid error";
1354     }
1355     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1356     if (ext == nullptr) {
1357         return "X509 get ext error";
1358     }
1359     ASN1_OBJECT *obj = nullptr;
1360     obj = X509_EXTENSION_get_object(ext);
1361     char subAltNameBuf[BUF_SIZE] = {0};
1362     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1363     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1364 
1365     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1366 }
1367 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1368 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1369                                                                    const X509 *x509Certificates)
1370 {
1371     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1372     std::string altNames = reinterpret_cast<char *>(extData->data);
1373     std::string hostname = " " + hostName;
1374     BIO *bio = BIO_new(BIO_s_file());
1375     if (!bio) {
1376         return "bio is null";
1377     }
1378     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1379     ASN1_STRING_print(bio, extData);
1380     std::vector<std::string> dnsNames = {};
1381     std::vector<std::string> ips = {};
1382     constexpr int DNS_NAME_IDX = 4;
1383     constexpr int IP_NAME_IDX = 11;
1384     if (!altNames.empty()) {
1385         std::vector<std::string> splitAltNames;
1386         if (altNames.find('\"') != std::string::npos) {
1387             splitAltNames = SplitEscapedAltNames(altNames);
1388         } else {
1389             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1390         }
1391         for (auto const &iter : splitAltNames) {
1392             if (StartsWith(iter, DNS)) {
1393                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1394             } else if (StartsWith(iter, IP_ADDRESS)) {
1395                 ips.push_back(iter.substr(IP_NAME_IDX));
1396             }
1397         }
1398     }
1399     std::tuple<bool, std::string> result;
1400     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1401     if (!std::get<0>(result)) {
1402         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1403     }
1404     return HOST_NAME + hostname + ". is cert's CN";
1405 }
1406 
StartShakingHands(const TLSConnectOptions & options)1407 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1408 {
1409     if (!ssl_) {
1410         NETSTACK_LOGE("ssl is null");
1411         return false;
1412     }
1413     int result = SSL_connect(ssl_);
1414     if (result == -1) {
1415         int errorStatus = ConvertSSLError(ssl_);
1416         NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1417                       MakeSSLErrorString(errorStatus).c_str());
1418         return false;
1419     }
1420 
1421     std::string list = SSL_get_cipher_list(ssl_, 0);
1422     NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1423     configuration_.SetCipherSuite(list);
1424     if (!SetSharedSigals()) {
1425         NETSTACK_LOGE("Failed to set sharedSigalgs");
1426     }
1427     if (!GetRemoteCertificateFromPeer()) {
1428         NETSTACK_LOGE("Failed to get remote certificate");
1429     }
1430     if (!peerX509_) {
1431         NETSTACK_LOGE("peer x509Certificates is null");
1432         return false;
1433     }
1434     if (!SetRemoteCertRawData()) {
1435         NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1436     }
1437     CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1438     if (!checkServerIdentity) {
1439         CheckServerIdentityLegal(hostName_, peerX509_);
1440     } else {
1441         checkServerIdentity(hostName_, {remoteCert_});
1442     }
1443     NETSTACK_LOGI("SSL Get Version: %{public}s, SSL Get Cipher: %{public}s", SSL_get_version(ssl_),
1444                   SSL_get_cipher(ssl_));
1445     return true;
1446 }
1447 
GetRemoteCertificateFromPeer()1448 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1449 {
1450     peerX509_ = SSL_get_peer_certificate(ssl_);
1451     if (peerX509_ == nullptr) {
1452         int resErr = ConvertSSLError(GetSSL());
1453         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1454                       MakeSSLErrorString(resErr).c_str());
1455         return false;
1456     }
1457     BIO *bio = BIO_new(BIO_s_mem());
1458     if (!bio) {
1459         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1460         return false;
1461     }
1462     X509_print(bio, peerX509_);
1463     char data[REMOTE_CERT_LEN] = {0};
1464     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1465         NETSTACK_LOGE("BIO_read function returns error");
1466         BIO_free(bio);
1467         return false;
1468     }
1469     BIO_free(bio);
1470     remoteCert_ = std::string(data);
1471     return true;
1472 }
1473 
SetRemoteCertRawData()1474 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1475 {
1476     if (peerX509_ == nullptr) {
1477         NETSTACK_LOGE("peerX509 is null");
1478         return false;
1479     }
1480     int32_t length = i2d_X509(peerX509_, nullptr);
1481     if (length <= 0) {
1482         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1483         return false;
1484     }
1485     unsigned char *der = nullptr;
1486     (void)i2d_X509(peerX509_, &der);
1487     SecureData data(der, length);
1488     remoteRawData_.data = data;
1489     OPENSSL_free(der);
1490     remoteRawData_.encodingFormat = DER;
1491     return true;
1492 }
1493 
GetRemoteCertRawData() const1494 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1495 {
1496     return remoteRawData_;
1497 }
1498 
GetSSL() const1499 ssl_st *TLSSocket::TLSSocketInternal::GetSSL() const
1500 {
1501     return ssl_;
1502 }
1503 } // namespace TlsSocket
1504 } // namespace NetStack
1505 } // namespace OHOS
1506