• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "tls_socket.h"

#include <algorithm>
#include <cctype>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <memory>
#include <numeric>
#include <regex>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <securec.h>
#include <sys/socket.h>

#include "netmanager_base_permission.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "tls.h"
33 
34 namespace OHOS {
35 namespace NetStack {
36 namespace {
constexpr int WAIT_MS = 10;
constexpr int TIMEOUT_MS = 10000;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
constexpr int OFFSET = 2;
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
// One JSON-style double-quoted string: any character except '"', '\\' and control characters,
// or a JSON escape sequence. Written for std::regex (ECMAScript grammar): the previous text kept
// the JavaScript literal delimiters "/^" and "/", which std::regex treats as literal '/'
// characters, so the pattern could never match a quoted alt-name.
const std::regex JSON_STRING_PATTERN{R"("(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address whose four octets are each in the range 0-255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
66 
ConvertErrno()67 int ConvertErrno()
68 {
69     return TlsSocketError::TLS_ERR_SYS_BASE + errno;
70 }
71 
ConvertSSLError(ssl_st * ssl)72 int ConvertSSLError(ssl_st *ssl)
73 {
74     if (!ssl) {
75         return TLS_ERR_SSL_NULL;
76     }
77     return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
78 }
79 
/**
 * @return The strerror() description of the current errno value.
 */
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
84 
MakeSSLErrorString(int error)85 std::string MakeSSLErrorString(int error)
86 {
87     char err[MAX_ERR_LEN] = {0};
88     ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
89     return err;
90 }
91 
SplitEscapedAltNames(std::string & altNames)92 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
93 {
94     std::vector<std::string> result;
95     std::string currentToken;
96     size_t offset = 0;
97     while (offset != altNames.length()) {
98         auto nextSep = altNames.find_first_of(", ");
99         auto nextQuote = altNames.find_first_of('\"');
100         if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
101             currentToken += altNames.substr(offset, nextQuote);
102             std::regex jsonStringPattern(JSON_STRING_PATTERN);
103             std::smatch match;
104             std::string altNameSubStr = altNames.substr(nextQuote);
105             bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
106             if (!ret) {
107                 return {""};
108             }
109             currentToken += result[0];
110             offset = nextQuote + result[0].length();
111         } else if (nextSep != std::string::npos) {
112             currentToken += altNames.substr(offset, nextSep);
113             result.push_back(currentToken);
114             currentToken = "";
115             offset = nextSep + OFFSET;
116         } else {
117             currentToken += altNames.substr(offset);
118             offset = altNames.length();
119         }
120     }
121     result.push_back(currentToken);
122     return result;
123 }
124 
IsIP(const std::string & ip)125 bool IsIP(const std::string &ip)
126 {
127     std::regex pattern(PATTERN);
128     std::smatch res;
129     return regex_match(ip, res, pattern);
130 }
131 
SplitHostName(std::string & hostName)132 std::vector<std::string> SplitHostName(std::string &hostName)
133 {
134     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
135     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
136 }
137 
/**
 * Reports whether two string lists share at least one element.
 *
 * std::set_intersection requires sorted inputs; the lists compared here (e.g. split host-name
 * labels) are in arbitrary order, so a hash lookup is used instead.
 *
 * @param vecA First list (not modified).
 * @param vecB Second list (not modified).
 * @return true if some string appears in both lists.
 */
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    if (vecA.empty() || vecB.empty()) {
        return false;
    }
    const std::unordered_set<std::string> lookup(vecA.begin(), vecA.end());
    return std::any_of(vecB.begin(), vecB.end(),
                       [&lookup](const std::string &item) { return lookup.find(item) != lookup.end(); });
}
144 } // namespace
145 
/**
 * Copy constructor: delegates the member-wise copy to the copy-assignment operator.
 */
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
150 
/**
 * Copies every option from another TLSSecureOptions instance.
 *
 * @param tlsSecureOptions Source options.
 * @return *this.
 */
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    // Each member is assigned exactly once (key_ used to be assigned twice).
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    return *this;
}
165 
SetCaChain(const std::vector<std::string> & caChain)166 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
167 {
168     caChain_ = caChain;
169 }
170 
SetCert(const std::string & cert)171 void TLSSecureOptions::SetCert(const std::string &cert)
172 {
173     cert_ = cert;
174 }
175 
SetKey(const SecureData & key)176 void TLSSecureOptions::SetKey(const SecureData &key)
177 {
178     key_ = key;
179 }
180 
SetKeyPass(const SecureData & keyPass)181 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
182 {
183     keyPass_ = keyPass;
184 }
185 
SetProtocolChain(const std::vector<std::string> & protocolChain)186 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
187 {
188     protocolChain_ = protocolChain;
189 }
190 
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)191 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
192 {
193     useRemoteCipherPrefer_ = useRemoteCipherPrefer;
194 }
195 
SetSignatureAlgorithms(const std::string & signatureAlgorithms)196 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
197 {
198     signatureAlgorithms_ = signatureAlgorithms;
199 }
200 
SetCipherSuite(const std::string & cipherSuite)201 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
202 {
203     cipherSuite_ = cipherSuite;
204 }
205 
SetCrlChain(const std::vector<std::string> & crlChain)206 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
207 {
208     crlChain_ = crlChain;
209 }
210 
GetCaChain() const211 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
212 {
213     return caChain_;
214 }
215 
GetCert() const216 const std::string &TLSSecureOptions::GetCert() const
217 {
218     return cert_;
219 }
220 
GetKey() const221 const SecureData &TLSSecureOptions::GetKey() const
222 {
223     return key_;
224 }
225 
GetKeyPass() const226 const SecureData &TLSSecureOptions::GetKeyPass() const
227 {
228     return keyPass_;
229 }
230 
GetProtocolChain() const231 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
232 {
233     return protocolChain_;
234 }
235 
UseRemoteCipherPrefer() const236 bool TLSSecureOptions::UseRemoteCipherPrefer() const
237 {
238     return useRemoteCipherPrefer_;
239 }
240 
GetSignatureAlgorithms() const241 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
242 {
243     return signatureAlgorithms_;
244 }
245 
GetCipherSuite() const246 const std::string &TLSSecureOptions::GetCipherSuite() const
247 {
248     return cipherSuite_;
249 }
250 
GetCrlChain() const251 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
252 {
253     return crlChain_;
254 }
255 
SetNetAddress(const NetAddress & address)256 void TLSConnectOptions::SetNetAddress(const NetAddress &address)
257 {
258     address_.SetAddress(address.GetAddress());
259     address_.SetPort(address.GetPort());
260     address_.SetFamilyBySaFamily(address.GetSaFamily());
261 }
262 
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)263 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
264 {
265     tlsSecureOptions_ = tlsSecureOptions;
266 }
267 
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)268 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
269 {
270     checkServerIdentity_ = checkServerIdentity;
271 }
272 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)273 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
274 {
275     alpnProtocols_ = alpnProtocols;
276 }
277 
GetNetAddress() const278 NetAddress TLSConnectOptions::GetNetAddress() const
279 {
280     return address_;
281 }
282 
GetTlsSecureOptions() const283 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
284 {
285     return tlsSecureOptions_;
286 }
287 
GetCheckServerIdentity() const288 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
289 {
290     return checkServerIdentity_;
291 }
292 
GetAlpnProtocols() const293 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
294 {
295     return alpnProtocols_;
296 }
297 
MakeAddressString(sockaddr * addr)298 std::string TLSSocket::MakeAddressString(sockaddr *addr)
299 {
300     if (!addr) {
301         return {};
302     }
303     if (addr->sa_family == AF_INET) {
304         auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
305         const char *str = inet_ntoa(addr4->sin_addr);
306         if (str == nullptr || strlen(str) == 0) {
307             return {};
308         }
309         return str;
310     } else if (addr->sa_family == AF_INET6) {
311         auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
312         char str[INET6_ADDRSTRLEN] = {0};
313         if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
314             return {};
315         }
316         return str;
317     }
318     return {};
319 }
320 
GetAddr(const NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)321 void TLSSocket::GetAddr(const NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
322                         socklen_t *len)
323 {
324     if (!addr6 || !addr4 || !len) {
325         return;
326     }
327     sa_family_t family = address.GetSaFamily();
328     if (family == AF_INET) {
329         addr4->sin_family = AF_INET;
330         addr4->sin_port = htons(address.GetPort());
331         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
332         *addr = reinterpret_cast<sockaddr *>(addr4);
333         *len = sizeof(sockaddr_in);
334     } else if (family == AF_INET6) {
335         addr6->sin6_family = AF_INET6;
336         addr6->sin6_port = htons(address.GetPort());
337         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
338         *addr = reinterpret_cast<sockaddr *>(addr6);
339         *len = sizeof(sockaddr_in6);
340     }
341 }
342 
MakeIpSocket(sa_family_t family)343 void TLSSocket::MakeIpSocket(sa_family_t family)
344 {
345     if (family != AF_INET && family != AF_INET6) {
346         return;
347     }
348     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
349     if (sock < 0) {
350         int resErr = ConvertErrno();
351         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
352         CallOnErrorCallback(resErr, MakeErrnoString());
353         return;
354     }
355     sockFd_ = sock;
356 }
357 
StartReadMessage()358 void TLSSocket::StartReadMessage()
359 {
360     std::thread thread([this]() {
361         isRunning_ = true;
362         isRunOver_ = false;
363         char buffer[MAX_BUFFER_SIZE];
364         bzero(buffer, MAX_BUFFER_SIZE);
365         while (isRunning_) {
366             int len = tlsSocketInternal_.Recv(buffer, MAX_BUFFER_SIZE);
367             if (!isRunning_) {
368                 break;
369             }
370             if (len < 0) {
371                 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
372                 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
373                               MakeSSLErrorString(resErr).c_str());
374                 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
375                 break;
376             }
377 
378             if (len == 0) {
379                 continue;
380             }
381 
382             SocketRemoteInfo remoteInfo;
383             remoteInfo.SetSize(len);
384             tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
385             CallOnMessageCallback(buffer, remoteInfo);
386         }
387         isRunOver_ = true;
388     });
389     thread.detach();
390 }
391 
CallOnMessageCallback(const std::string & data,const OHOS::NetStack::SocketRemoteInfo & remoteInfo)392 void TLSSocket::CallOnMessageCallback(const std::string &data, const OHOS::NetStack::SocketRemoteInfo &remoteInfo)
393 {
394     OnMessageCallback func = nullptr;
395     {
396         std::lock_guard<std::mutex> lock(mutex_);
397         if (onMessageCallback_) {
398             func = onMessageCallback_;
399         }
400     }
401 
402     if (func) {
403         func(data, remoteInfo);
404     }
405 }
406 
CallOnConnectCallback()407 void TLSSocket::CallOnConnectCallback()
408 {
409     OnConnectCallback func = nullptr;
410     {
411         std::lock_guard<std::mutex> lock(mutex_);
412         if (onConnectCallback_) {
413             func = onConnectCallback_;
414         }
415     }
416 
417     if (func) {
418         func();
419     }
420 }
421 
CallOnCloseCallback()422 void TLSSocket::CallOnCloseCallback()
423 {
424     OnCloseCallback func = nullptr;
425     {
426         std::lock_guard<std::mutex> lock(mutex_);
427         if (onCloseCallback_) {
428             func = onCloseCallback_;
429         }
430     }
431 
432     if (func) {
433         func();
434     }
435 }
436 
CallOnErrorCallback(int32_t err,const std::string & errString)437 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
438 {
439     OnErrorCallback func = nullptr;
440     {
441         std::lock_guard<std::mutex> lock(mutex_);
442         if (onErrorCallback_) {
443             func = onErrorCallback_;
444         }
445     }
446 
447     if (func) {
448         func(err, errString);
449     }
450 }
451 
CallBindCallback(int32_t err,BindCallback callback)452 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
453 {
454     BindCallback func = nullptr;
455     {
456         std::lock_guard<std::mutex> lock(mutex_);
457         if (callback) {
458             func = callback;
459         }
460     }
461 
462     if (func) {
463         func(err);
464     }
465 }
466 
CallConnectCallback(int32_t err,ConnectCallback callback)467 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
468 {
469     ConnectCallback func = nullptr;
470     {
471         std::lock_guard<std::mutex> lock(mutex_);
472         if (callback) {
473             func = callback;
474         }
475     }
476 
477     if (func) {
478         func(err);
479     }
480 }
481 
CallSendCallback(int32_t err,SendCallback callback)482 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
483 {
484     SendCallback func = nullptr;
485     {
486         std::lock_guard<std::mutex> lock(mutex_);
487         if (callback) {
488             func = callback;
489         }
490     }
491 
492     if (func) {
493         func(err);
494     }
495 }
496 
CallCloseCallback(int32_t err,CloseCallback callback)497 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
498 {
499     CloseCallback func = nullptr;
500     {
501         std::lock_guard<std::mutex> lock(mutex_);
502         if (callback) {
503             func = callback;
504         }
505     }
506 
507     if (func) {
508         func(err);
509     }
510 }
511 
CallGetRemoteAddressCallback(int32_t err,const NetAddress & address,GetRemoteAddressCallback callback)512 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const NetAddress &address, GetRemoteAddressCallback callback)
513 {
514     GetRemoteAddressCallback func = nullptr;
515     {
516         std::lock_guard<std::mutex> lock(mutex_);
517         if (callback) {
518             func = callback;
519         }
520     }
521 
522     if (func) {
523         func(err, address);
524     }
525 }
526 
CallGetStateCallback(int32_t err,const SocketStateBase & state,GetStateCallback callback)527 void TLSSocket::CallGetStateCallback(int32_t err, const SocketStateBase &state, GetStateCallback callback)
528 {
529     GetStateCallback func = nullptr;
530     {
531         std::lock_guard<std::mutex> lock(mutex_);
532         if (callback) {
533             func = callback;
534         }
535     }
536 
537     if (func) {
538         func(err, state);
539     }
540 }
541 
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)542 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
543 {
544     SetExtraOptionsCallback func = nullptr;
545     {
546         std::lock_guard<std::mutex> lock(mutex_);
547         if (callback) {
548             func = callback;
549         }
550     }
551 
552     if (func) {
553         func(err);
554     }
555 }
556 
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)557 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
558 {
559     GetCertificateCallback func = nullptr;
560     {
561         std::lock_guard<std::mutex> lock(mutex_);
562         if (callback) {
563             func = callback;
564         }
565     }
566 
567     if (func) {
568         func(err, cert);
569     }
570 }
571 
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)572 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
573                                                  GetRemoteCertificateCallback callback)
574 {
575     GetRemoteCertificateCallback func = nullptr;
576     {
577         std::lock_guard<std::mutex> lock(mutex_);
578         if (callback) {
579             func = callback;
580         }
581     }
582 
583     if (func) {
584         func(err, cert);
585     }
586 }
587 
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)588 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
589 {
590     GetProtocolCallback func = nullptr;
591     {
592         std::lock_guard<std::mutex> lock(mutex_);
593         if (callback) {
594             func = callback;
595         }
596     }
597 
598     if (func) {
599         func(err, protocol);
600     }
601 }
602 
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)603 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
604                                            GetCipherSuiteCallback callback)
605 {
606     GetCipherSuiteCallback func = nullptr;
607     {
608         std::lock_guard<std::mutex> lock(mutex_);
609         if (callback) {
610             func = callback;
611         }
612     }
613 
614     if (func) {
615         func(err, suite);
616     }
617 }
618 
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)619 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
620                                                    GetSignatureAlgorithmsCallback callback)
621 {
622     GetSignatureAlgorithmsCallback func = nullptr;
623     {
624         std::lock_guard<std::mutex> lock(mutex_);
625         if (callback) {
626             func = callback;
627         }
628     }
629 
630     if (func) {
631         func(err, algorithms);
632     }
633 }
634 
Bind(const OHOS::NetStack::NetAddress & address,const OHOS::NetStack::BindCallback & callback)635 void TLSSocket::Bind(const OHOS::NetStack::NetAddress &address, const OHOS::NetStack::BindCallback &callback)
636 {
637     if (!NetManagerStandard::NetManagerPermission::CheckPermission(NetManagerStandard::Permission::INTERNET)) {
638         NETSTACK_LOGE("Bind permission check fail.");
639         CallBindCallback(TLS_ERR_PERMISSION_DENIED, callback);
640         return;
641     }
642 
643     if (sockFd_ >= 0) {
644         CallBindCallback(TLSSOCKET_SUCCESS, callback);
645         return;
646     }
647 
648     MakeIpSocket(address.GetSaFamily());
649     if (sockFd_ < 0) {
650         int resErr = ConvertErrno();
651         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
652         CallOnErrorCallback(resErr, MakeErrnoString());
653         CallBindCallback(resErr, callback);
654         return;
655     }
656 
657     sockaddr_in addr4 = {0};
658     sockaddr_in6 addr6 = {0};
659     sockaddr *addr = nullptr;
660     socklen_t len;
661     GetAddr(address, &addr4, &addr6, &addr, &len);
662     if (addr == nullptr) {
663         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
664         CallOnErrorCallback(-1, "Address Is Invalid");
665         CallBindCallback(ConvertErrno(), callback);
666         return;
667     }
668     CallBindCallback(TLSSOCKET_SUCCESS, callback);
669 }
670 
Connect(OHOS::NetStack::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::ConnectCallback & callback)671 void TLSSocket::Connect(OHOS::NetStack::TLSConnectOptions &tlsConnectOptions,
672                         const OHOS::NetStack::ConnectCallback &callback)
673 {
674     if (sockFd_ < 0) {
675         int resErr = ConvertErrno();
676         NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
677         CallOnErrorCallback(resErr, MakeErrnoString());
678         callback(resErr);
679         return;
680     }
681 
682     auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions);
683     if (!res) {
684         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
685         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
686         callback(resErr);
687         return;
688     }
689     StartReadMessage();
690     CallOnConnectCallback();
691     callback(TLSSOCKET_SUCCESS);
692 }
693 
Send(const OHOS::NetStack::TCPSendOptions & tcpSendOptions,const OHOS::NetStack::SendCallback & callback)694 void TLSSocket::Send(const OHOS::NetStack::TCPSendOptions &tcpSendOptions, const OHOS::NetStack::SendCallback &callback)
695 {
696     (void)tcpSendOptions;
697 
698     auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
699     if (!res) {
700         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
701         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
702         CallSendCallback(resErr, callback);
703         return;
704     }
705     CallSendCallback(TLSSOCKET_SUCCESS, callback);
706 }
707 
WaitConditionWithTimeout(const bool * flag,const int32_t timeoutMs)708 bool WaitConditionWithTimeout(const bool *flag, const int32_t timeoutMs)
709 {
710     int maxWaitCnt = timeoutMs / WAIT_MS;
711     int cnt = 0;
712     while (!(*flag)) {
713         if (cnt >= maxWaitCnt) {
714             return false;
715         }
716         std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MS));
717         cnt++;
718     }
719     return true;
720 }
721 
Close(const OHOS::NetStack::CloseCallback & callback)722 void TLSSocket::Close(const OHOS::NetStack::CloseCallback &callback)
723 {
724     if (!WaitConditionWithTimeout(&isRunning_, TIMEOUT_MS)) {
725         callback(ConvertErrno());
726         NETSTACK_LOGE("The error cause is that the runtime wait time is insufficient");
727         return;
728     }
729     isRunning_ = false;
730     if (!WaitConditionWithTimeout(&isRunOver_, TIMEOUT_MS)) {
731         callback(ConvertErrno());
732         NETSTACK_LOGE("The error is due to insufficient delay time");
733         return;
734     }
735     auto res = tlsSocketInternal_.Close();
736     if (!res) {
737         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
738         NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
739         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
740         callback(resErr);
741         return;
742     }
743     CallOnCloseCallback();
744     callback(TLSSOCKET_SUCCESS);
745 }
746 
GetRemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback & callback)747 void TLSSocket::GetRemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback &callback)
748 {
749     sa_family_t family;
750     socklen_t len = sizeof(family);
751     int ret = getsockname(sockFd_, reinterpret_cast<sockaddr *>(&family), &len);
752     if (ret < 0) {
753         int resErr = ConvertErrno();
754         NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
755         CallOnErrorCallback(resErr, MakeErrnoString());
756         CallGetRemoteAddressCallback(resErr, {}, callback);
757         return;
758     }
759 
760     if (family == AF_INET) {
761         GetIp4RemoteAddress(callback);
762     } else if (family == AF_INET6) {
763         GetIp6RemoteAddress(callback);
764     }
765 }
766 
GetIp4RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback & callback)767 void TLSSocket::GetIp4RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback &callback)
768 {
769     sockaddr_in addr4 = {0};
770     socklen_t len4 = sizeof(sockaddr_in);
771 
772     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
773     if (ret < 0) {
774         int resErr = ConvertErrno();
775         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
776         CallOnErrorCallback(resErr, MakeErrnoString());
777         CallGetRemoteAddressCallback(resErr, {}, callback);
778         return;
779     }
780 
781     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
782     if (address.empty()) {
783         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
784         CallOnErrorCallback(-1, "Address is invalid");
785         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
786         return;
787     }
788     NetAddress netAddress;
789     netAddress.SetAddress(address);
790     netAddress.SetFamilyBySaFamily(AF_INET);
791     netAddress.SetPort(ntohs(addr4.sin_port));
792     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
793 }
794 
GetIp6RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback & callback)795 void TLSSocket::GetIp6RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback &callback)
796 {
797     sockaddr_in6 addr6 = {0};
798     socklen_t len6 = sizeof(sockaddr_in6);
799 
800     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
801     if (ret < 0) {
802         int resErr = ConvertErrno();
803         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
804         CallOnErrorCallback(resErr, MakeErrnoString());
805         CallGetRemoteAddressCallback(resErr, {}, callback);
806         return;
807     }
808 
809     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
810     if (address.empty()) {
811         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
812         CallOnErrorCallback(-1, "Address is invalid");
813         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
814         return;
815     }
816     NetAddress netAddress;
817     netAddress.SetAddress(address);
818     netAddress.SetFamilyBySaFamily(AF_INET6);
819     netAddress.SetPort(ntohs(addr6.sin6_port));
820     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
821 }
822 
GetState(const OHOS::NetStack::GetStateCallback & callback)823 void TLSSocket::GetState(const OHOS::NetStack::GetStateCallback &callback)
824 {
825     int opt;
826     socklen_t optLen = sizeof(int);
827     int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
828     if (r < 0) {
829         SocketStateBase state;
830         state.SetIsClose(true);
831         CallGetStateCallback(ConvertErrno(), state, callback);
832         return;
833     }
834     sa_family_t family;
835     socklen_t len = sizeof(family);
836     SocketStateBase state;
837     int ret = getsockname(sockFd_, reinterpret_cast<sockaddr *>(&family), &len);
838     state.SetIsBound(ret == 0);
839     ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&family), &len);
840     state.SetIsConnected(ret == 0);
841     CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
842 }
843 
SetBaseOptions(const ExtraOptionsBase & option) const844 bool TLSSocket::SetBaseOptions(const ExtraOptionsBase &option) const
845 {
846     if (option.GetReceiveBufferSize() != 0) {
847         int size = (int)option.GetReceiveBufferSize();
848         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
849             return false;
850         }
851     }
852 
853     if (option.GetSendBufferSize() != 0) {
854         int size = (int)option.GetSendBufferSize();
855         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
856             return false;
857         }
858     }
859 
860     if (option.IsReuseAddress()) {
861         int reuse = 1;
862         if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
863             return false;
864         }
865     }
866 
867     if (option.GetSocketTimeout() != 0) {
868         timeval timeout = {(int)option.GetSocketTimeout(), 0};
869         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
870             return false;
871         }
872         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
873             return false;
874         }
875     }
876 
877     return true;
878 }
879 
SetExtraOptions(const TCPExtraOptions & option) const880 bool TLSSocket::SetExtraOptions(const TCPExtraOptions &option) const
881 {
882     if (option.IsKeepAlive()) {
883         int keepalive = 1;
884         if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
885             return false;
886         }
887     }
888 
889     if (option.IsOOBInline()) {
890         int oobInline = 1;
891         if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
892             return false;
893         }
894     }
895 
896     if (option.IsTCPNoDelay()) {
897         int tcpNoDelay = 1;
898         if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
899             return false;
900         }
901     }
902 
903     linger soLinger = {0};
904     soLinger.l_onoff = option.socketLinger.IsOn();
905     soLinger.l_linger = (int)option.socketLinger.GetLinger();
906     if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
907         return false;
908     }
909 
910     return true;
911 }
912 
SetExtraOptions(const OHOS::NetStack::TCPExtraOptions & tcpExtraOptions,const OHOS::NetStack::SetExtraOptionsCallback & callback)913 void TLSSocket::SetExtraOptions(const OHOS::NetStack::TCPExtraOptions &tcpExtraOptions,
914                                 const OHOS::NetStack::SetExtraOptionsCallback &callback)
915 {
916     if (!SetBaseOptions(tcpExtraOptions)) {
917         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
918         CallOnErrorCallback(errno, MakeErrnoString());
919         CallSetExtraOptionsCallback(ConvertErrno(), callback);
920         return;
921     }
922 
923     if (!SetExtraOptions(tcpExtraOptions)) {
924         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
925         CallOnErrorCallback(errno, MakeErrnoString());
926         CallSetExtraOptionsCallback(ConvertErrno(), callback);
927         return;
928     }
929 
930     CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
931 }
932 
GetCertificate(const GetCertificateCallback & callback)933 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
934 {
935     const auto &cert = tlsSocketInternal_.GetCertificate();
936     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
937 
938     if (!cert.data.Length()) {
939         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
940         NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
941         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
942         callback(resErr, {});
943         return;
944     }
945     callback(TLSSOCKET_SUCCESS, cert);
946 }
947 
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)948 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
949 {
950     const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
951     if (!remoteCert.data.Length()) {
952         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
953         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
954         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
955         callback(resErr, {});
956         return;
957     }
958     callback(TLSSOCKET_SUCCESS, remoteCert);
959 }
960 
GetProtocol(const GetProtocolCallback & callback)961 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
962 {
963     const auto &protocol = tlsSocketInternal_.GetProtocol();
964     if (protocol.empty()) {
965         NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
966         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
967         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
968         callback(resErr, "");
969         return;
970     }
971     callback(TLSSOCKET_SUCCESS, protocol);
972 }
973 
GetCipherSuite(const GetCipherSuiteCallback & callback)974 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
975 {
976     const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
977     if (cipherSuite.empty()) {
978         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
979         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
980         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
981         callback(resErr, cipherSuite);
982         return;
983     }
984     callback(TLSSOCKET_SUCCESS, cipherSuite);
985 }
986 
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)987 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
988 {
989     const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
990     if (signatureAlgorithms.empty()) {
991         NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
992         int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
993         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
994         callback(resErr, {});
995         return;
996     }
997     callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
998 }
999 
OnMessage(const OHOS::NetStack::OnMessageCallback & onMessageCallback)1000 void TLSSocket::OnMessage(const OHOS::NetStack::OnMessageCallback &onMessageCallback)
1001 {
1002     std::lock_guard<std::mutex> lock(mutex_);
1003     onMessageCallback_ = onMessageCallback;
1004 }
1005 
OffMessage()1006 void TLSSocket::OffMessage()
1007 {
1008     std::lock_guard<std::mutex> lock(mutex_);
1009     if (onMessageCallback_) {
1010         onMessageCallback_ = nullptr;
1011     }
1012 }
1013 
OnConnect(const OnConnectCallback & onConnectCallback)1014 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1015 {
1016     std::lock_guard<std::mutex> lock(mutex_);
1017     onConnectCallback_ = onConnectCallback;
1018 }
1019 
OffConnect()1020 void TLSSocket::OffConnect()
1021 {
1022     std::lock_guard<std::mutex> lock(mutex_);
1023     if (onConnectCallback_) {
1024         onConnectCallback_ = nullptr;
1025     }
1026 }
1027 
OnClose(const OnCloseCallback & onCloseCallback)1028 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1029 {
1030     std::lock_guard<std::mutex> lock(mutex_);
1031     onCloseCallback_ = onCloseCallback;
1032 }
1033 
OffClose()1034 void TLSSocket::OffClose()
1035 {
1036     std::lock_guard<std::mutex> lock(mutex_);
1037     if (onCloseCallback_) {
1038         onCloseCallback_ = nullptr;
1039     }
1040 }
1041 
OnError(const OnErrorCallback & onErrorCallback)1042 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1043 {
1044     std::lock_guard<std::mutex> lock(mutex_);
1045     onErrorCallback_ = onErrorCallback;
1046 }
1047 
OffError()1048 void TLSSocket::OffError()
1049 {
1050     std::lock_guard<std::mutex> lock(mutex_);
1051     if (onErrorCallback_) {
1052         onErrorCallback_ = nullptr;
1053     }
1054 }
1055 
ExecSocketConnect(const std::string & hostName,int port,sa_family_t family,int socketDescriptor)1056 bool ExecSocketConnect(const std::string &hostName, int port, sa_family_t family, int socketDescriptor)
1057 {
1058     struct sockaddr_in dest = {0};
1059     dest.sin_family = family;
1060     dest.sin_port = htons(port);
1061     if (!inet_aton(hostName.c_str(), reinterpret_cast<in_addr *>(&dest.sin_addr.s_addr))) {
1062         NETSTACK_LOGE("inet_aton is error, hostName is %s", hostName.c_str());
1063         return false;
1064     }
1065     int connectResult = connect(socketDescriptor, reinterpret_cast<struct sockaddr *>(&dest), sizeof(dest));
1066     if (connectResult == -1) {
1067         NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1068                       strerror(errno));
1069         return false;
1070     }
1071     return true;
1072 }
1073 
TlsConnectToHost(int sock,const TLSConnectOptions & options)1074 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options)
1075 {
1076     SetTlsConfiguration(options);
1077     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1078     if (!cipherSuite.empty()) {
1079         configuration_.SetCipherSuite(cipherSuite);
1080     }
1081     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1082     if (!signatureAlgorithms.empty()) {
1083         configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1084     }
1085     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1086     if (!protocolVec.empty()) {
1087         configuration_.SetProtocol(protocolVec);
1088     }
1089 
1090     hostName_ = options.GetNetAddress().GetAddress();
1091     port_ = options.GetNetAddress().GetPort();
1092     family_ = options.GetNetAddress().GetSaFamily();
1093     socketDescriptor_ = sock;
1094     if (!ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1095                            options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1096         return false;
1097     }
1098     return StartTlsConnected(options);
1099 }
1100 
SetTlsConfiguration(const TLSConnectOptions & config)1101 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1102 {
1103     configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1104     configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1105     configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1106 }
1107 
Send(const std::string & data)1108 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1109 {
1110     NETSTACK_LOGD("data to send :%{public}s", data.c_str());
1111     if (data.empty()) {
1112         NETSTACK_LOGE("data is empty");
1113         return false;
1114     }
1115     if (!ssl_) {
1116         NETSTACK_LOGE("ssl is null");
1117         return false;
1118     }
1119     int len = SSL_write(ssl_, data.c_str(), data.length());
1120     if (len < 0) {
1121         int resErr = ConvertSSLError(GetSSL());
1122         NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
1123                       data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
1124         return false;
1125     }
1126     NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
1127     return true;
1128 }
Recv(char * buffer,int maxBufferSize)1129 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1130 {
1131     if (!ssl_) {
1132         NETSTACK_LOGE("ssl is null");
1133         return SSL_ERROR_RETURN;
1134     }
1135     return SSL_read(ssl_, buffer, maxBufferSize);
1136 }
1137 
Close()1138 bool TLSSocket::TLSSocketInternal::Close()
1139 {
1140     if (!ssl_) {
1141         NETSTACK_LOGE("ssl is null");
1142         return false;
1143     }
1144     int result = SSL_shutdown(ssl_);
1145     if (result < 0) {
1146         int resErr = ConvertSSLError(GetSSL());
1147         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1148                       MakeSSLErrorString(resErr).c_str());
1149         return false;
1150     }
1151     SSL_free(ssl_);
1152     ssl_ = nullptr;
1153     close(socketDescriptor_);
1154     socketDescriptor_ = -1;
1155     if (!tlsContextPointer_) {
1156         NETSTACK_LOGE("Tls context pointer is null");
1157         return false;
1158     }
1159     tlsContextPointer_->CloseCtx();
1160     return true;
1161 }
1162 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1163 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1164 {
1165     if (!ssl_) {
1166         NETSTACK_LOGE("ssl is null");
1167         return false;
1168     }
1169     size_t pos = 0;
1170     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1171                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1172     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1173     for (const auto &str : alpnProtocols) {
1174         len = str.length();
1175         result[pos++] = len;
1176         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1177             NETSTACK_LOGE("strcpy_s failed");
1178             return false;
1179         }
1180         pos += len;
1181     }
1182     result[pos] = '\0';
1183 
1184     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1185     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1186         int resErr = ConvertSSLError(GetSSL());
1187         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1188                       MakeSSLErrorString(resErr).c_str());
1189         return false;
1190     }
1191     return true;
1192 }
1193 
MakeRemoteInfo(SocketRemoteInfo & remoteInfo)1194 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(SocketRemoteInfo &remoteInfo)
1195 {
1196     remoteInfo.SetAddress(hostName_);
1197     remoteInfo.SetPort(port_);
1198     remoteInfo.SetFamily(family_);
1199 }
1200 
GetTlsConfiguration() const1201 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1202 {
1203     return configuration_;
1204 }
1205 
GetCipherSuite() const1206 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1207 {
1208     if (!ssl_) {
1209         NETSTACK_LOGE("ssl in null");
1210         return {};
1211     }
1212     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1213     if (!sk) {
1214         NETSTACK_LOGE("get ciphers failed");
1215         return {};
1216     }
1217     CipherSuite cipherSuite;
1218     std::vector<std::string> cipherSuiteVec;
1219     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1220         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1221         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1222         cipherSuiteVec.push_back(cipherSuite.cipherName_);
1223     }
1224     return cipherSuiteVec;
1225 }
1226 
GetRemoteCertificate() const1227 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1228 {
1229     return remoteCert_;
1230 }
1231 
GetCertificate() const1232 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1233 {
1234     return configuration_.GetCertificate();
1235 }
1236 
GetSignatureAlgorithms() const1237 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1238 {
1239     return signatureAlgorithms_;
1240 }
1241 
GetProtocol() const1242 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1243 {
1244     if (!ssl_) {
1245         NETSTACK_LOGE("ssl in null");
1246         return PROTOCOL_UNKNOW;
1247     }
1248     if (configuration_.GetProtocol() == TLS_V1_3) {
1249         return PROTOCOL_TLS_V13;
1250     }
1251     return PROTOCOL_TLS_V12;
1252 }
1253 
SetSharedSigals()1254 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1255 {
1256     if (!ssl_) {
1257         NETSTACK_LOGE("ssl is null");
1258         return false;
1259     }
1260     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1261     if (!number) {
1262         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1263         return false;
1264     }
1265     for (int i = 0; i < number; i++) {
1266         int hash_nid;
1267         int sign_nid;
1268         std::string sig_with_md;
1269         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1270         switch (sign_nid) {
1271             case EVP_PKEY_RSA:
1272                 sig_with_md = SIGN_NID_RSA;
1273                 break;
1274             case EVP_PKEY_RSA_PSS:
1275                 sig_with_md = SIGN_NID_RSA_PSS;
1276                 break;
1277             case EVP_PKEY_DSA:
1278                 sig_with_md = SIGN_NID_DSA;
1279                 break;
1280             case EVP_PKEY_EC:
1281                 sig_with_md = SIGN_NID_ECDSA;
1282                 break;
1283             case NID_ED25519:
1284                 sig_with_md = SIGN_NID_ED;
1285                 break;
1286             case NID_ED448:
1287                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1288                 break;
1289             default:
1290                 const char *sn = OBJ_nid2sn(sign_nid);
1291                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1292         }
1293         const char *sn_hash = OBJ_nid2sn(hash_nid);
1294         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1295         signatureAlgorithms_.push_back(sig_with_md);
1296     }
1297     return true;
1298 }
1299 
StartTlsConnected(const TLSConnectOptions & options)1300 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1301 {
1302     if (!CreatTlsContext()) {
1303         NETSTACK_LOGE("failed to create tls context");
1304         return false;
1305     }
1306     if (!StartShakingHands(options)) {
1307         NETSTACK_LOGE("failed to shaking hands");
1308         return false;
1309     }
1310     return true;
1311 }
1312 
CreatTlsContext()1313 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1314 {
1315     tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1316     if (!tlsContextPointer_) {
1317         NETSTACK_LOGE("failed to create tls context pointer");
1318         return false;
1319     }
1320     if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1321         NETSTACK_LOGE("failed to create ssl session");
1322         return false;
1323     }
1324     SSL_set_fd(ssl_, socketDescriptor_);
1325     SSL_set_connect_state(ssl_);
1326     return true;
1327 }
1328 
/**
 * Returns true when s begins with prefix; an empty prefix always matches.
 */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only report a match that starts at the beginning of s.
    return s.rfind(prefix, 0) == 0;
}
1333 
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1334 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1335                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1336 {
1337     bool valid = false;
1338     std::string reason = UNKNOW_REASON;
1339     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1340     if (IsIP(hostName)) {
1341         auto it = find(ips.begin(), ips.end(), hostName);
1342         if (it == ips.end()) {
1343             reason = IP + hostName + " is not in the cert's list";
1344         }
1345         result = {valid, reason};
1346         return;
1347     }
1348     std::string tempHostName = "" + hostName;
1349     if (!dnsNames.empty() || index > 0) {
1350         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1351         if (!dnsNames.empty()) {
1352             valid = SeekIntersection(hostParts, dnsNames);
1353             if (!valid) {
1354                 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1355             }
1356         } else {
1357             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1358             X509_NAME *pSubName = nullptr;
1359             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1360             if (len > 0) {
1361                 std::vector<std::string> commonNameVec;
1362                 commonNameVec.emplace_back(commonNameBuf);
1363                 valid = SeekIntersection(hostParts, commonNameVec);
1364                 if (!valid) {
1365                     reason = HOST_NAME + tempHostName + ". is not cert's CN";
1366                 }
1367             }
1368         }
1369         result = {valid, reason};
1370         return;
1371     }
1372     reason = "Cert does not contain a DNS name";
1373     result = {valid, reason};
1374 }
1375 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1376 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1377                                                                    const X509 *x509Certificates)
1378 {
1379     std::string hostname = hostName;
1380 
1381     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1382     if (!subjectName) {
1383         return "subject name is null";
1384     }
1385     char subNameBuf[BUF_SIZE] = {0};
1386     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1387 
1388     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1389     if (index < 0) {
1390         return "X509 get ext nid error";
1391     }
1392     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1393     if (ext == nullptr) {
1394         return "X509 get ext error";
1395     }
1396     ASN1_OBJECT *obj = nullptr;
1397     obj = X509_EXTENSION_get_object(ext);
1398     char subAltNameBuf[BUF_SIZE] = {0};
1399     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1400     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1401 
1402     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1403     std::string altNames = reinterpret_cast<char *>(extData->data);
1404 
1405     BIO *bio = BIO_new(BIO_s_file());
1406     if (!bio) {
1407         return "bio is null";
1408     }
1409     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1410     ASN1_STRING_print(bio, extData);
1411     std::vector<std::string> dnsNames = {};
1412     std::vector<std::string> ips = {};
1413     constexpr int DNS_NAME_IDX = 4;
1414     constexpr int IP_NAME_IDX = 11;
1415     hostname = "" + hostname;
1416     if (!altNames.empty()) {
1417         std::vector<std::string> splitAltNames;
1418         if (altNames.find('\"') != std::string::npos) {
1419             splitAltNames = SplitEscapedAltNames(altNames);
1420         } else {
1421             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1422         }
1423         for (auto const &iter : splitAltNames) {
1424             if (StartsWith(iter, DNS)) {
1425                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1426             } else if (StartsWith(iter, IP_ADDRESS)) {
1427                 ips.push_back(iter.substr(IP_NAME_IDX));
1428             }
1429         }
1430     }
1431     std::tuple<bool, std::string> result;
1432     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1433     if (!std::get<0>(result)) {
1434         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1435     }
1436     return HOST_NAME + hostname + ". is cert's CN";
1437 }
1438 
StartShakingHands(const TLSConnectOptions & options)1439 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1440 {
1441     if (!ssl_) {
1442         NETSTACK_LOGE("ssl is null");
1443         return false;
1444     }
1445     int result = SSL_connect(ssl_);
1446     if (result == -1) {
1447         int errorStatus = ConvertSSLError(ssl_);
1448         NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1449                       MakeSSLErrorString(errorStatus).c_str());
1450         return false;
1451     }
1452 
1453     std::string list = SSL_get_cipher_list(ssl_, 0);
1454     NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1455     configuration_.SetCipherSuite(list);
1456     if (!SetSharedSigals()) {
1457         NETSTACK_LOGE("Failed to set sharedSigalgs");
1458     }
1459     if (!GetRemoteCertificateFromPeer()) {
1460         NETSTACK_LOGE("Failed to get remote certificate");
1461     }
1462     if (!peerX509_) {
1463         NETSTACK_LOGE("peer x509Certificates is null");
1464         return false;
1465     }
1466     if (!SetRemoteCertRawData()) {
1467         NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1468     }
1469     CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1470     if (!checkServerIdentity) {
1471         CheckServerIdentityLegal(hostName_, peerX509_);
1472     } else {
1473         checkServerIdentity(hostName_, {remoteCert_});
1474     }
1475     NETSTACK_LOGI("SSL Get Version: %{public}s, SSL Get Cipher: %{public}s", SSL_get_version(ssl_),
1476                   SSL_get_cipher(ssl_));
1477     return true;
1478 }
1479 
GetRemoteCertificateFromPeer()1480 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1481 {
1482     peerX509_ = SSL_get_peer_certificate(ssl_);
1483     if (peerX509_ == nullptr) {
1484         int resErr = ConvertSSLError(GetSSL());
1485         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1486                       MakeSSLErrorString(resErr).c_str());
1487         return false;
1488     }
1489     BIO *bio = BIO_new(BIO_s_mem());
1490     if (!bio) {
1491         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1492         return false;
1493     }
1494     X509_print(bio, peerX509_);
1495     char data[REMOTE_CERT_LEN] = {0};
1496     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1497         NETSTACK_LOGE("BIO_read function returns error");
1498         BIO_free(bio);
1499         return false;
1500     }
1501     BIO_free(bio);
1502     remoteCert_ = std::string(data);
1503     return true;
1504 }
1505 
SetRemoteCertRawData()1506 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1507 {
1508     if (peerX509_ == nullptr) {
1509         NETSTACK_LOGE("peerX509 is null");
1510         return false;
1511     }
1512     int32_t length = i2d_X509(peerX509_, nullptr);
1513     if (length <= 0) {
1514         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1515         return false;
1516     }
1517     unsigned char *der = nullptr;
1518     (void)i2d_X509(peerX509_, &der);
1519     SecureData data(der, length);
1520     remoteRawData_.data = data;
1521     OPENSSL_free(der);
1522     remoteRawData_.encodingFormat = DER;
1523     return true;
1524 }
1525 
GetRemoteCertRawData() const1526 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1527 {
1528     return remoteRawData_;
1529 }
1530 
GetSSL() const1531 ssl_st *TLSSocket::TLSSocketInternal::GetSSL() const
1532 {
1533     return ssl_;
1534 }
1535 } // namespace NetStack
1536 } // namespace OHOS
1537