1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "tls_socket.h"

#include <algorithm>
#include <cctype>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <memory>
#include <numeric>
#include <regex>
#include <thread>

#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <sys/socket.h>

#include <openssl/err.h>
#include <openssl/ssl.h>
#include <securec.h>

#include "base_context.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "tls.h"
33
34 namespace OHOS {
35 namespace NetStack {
36 namespace {
constexpr int WAIT_MS = 10;
constexpr int TIMEOUT_MS = 10000;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
constexpr int OFFSET = 2;
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
// Matches a JSON string literal ("...") anchored at the start of the input.
// This appears to be a port of a JavaScript regex literal; the JavaScript `/.../`
// delimiters are not part of the pattern and, if kept in a C++ regex source,
// would be matched as literal '/' characters, so they are omitted here.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address with every octet in 0-255 (intended for std::regex_match).
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
66
ConvertErrno()67 int ConvertErrno()
68 {
69 return TlsSocketError::TLS_ERR_SYS_BASE + errno;
70 }
71
ConvertSSLError(ssl_st * ssl)72 int ConvertSSLError(ssl_st *ssl)
73 {
74 if (!ssl) {
75 return TLS_ERR_SSL_NULL;
76 }
77 return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
78 }
79
// Returns the strerror() description of the current errno value.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
84
MakeSSLErrorString(int error)85 std::string MakeSSLErrorString(int error)
86 {
87 char err[MAX_ERR_LEN] = {0};
88 ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
89 return err;
90 }
91
SplitEscapedAltNames(std::string & altNames)92 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
93 {
94 std::vector<std::string> result;
95 std::string currentToken;
96 size_t offset = 0;
97 while (offset != altNames.length()) {
98 auto nextSep = altNames.find_first_of(", ");
99 auto nextQuote = altNames.find_first_of('\"');
100 if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
101 currentToken += altNames.substr(offset, nextQuote);
102 std::regex jsonStringPattern(JSON_STRING_PATTERN);
103 std::smatch match;
104 std::string altNameSubStr = altNames.substr(nextQuote);
105 bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
106 if (!ret) {
107 return {""};
108 }
109 currentToken += result[0];
110 offset = nextQuote + result[0].length();
111 } else if (nextSep != std::string::npos) {
112 currentToken += altNames.substr(offset, nextSep);
113 result.push_back(currentToken);
114 currentToken = "";
115 offset = nextSep + OFFSET;
116 } else {
117 currentToken += altNames.substr(offset);
118 offset = altNames.length();
119 }
120 }
121 result.push_back(currentToken);
122 return result;
123 }
124
IsIP(const std::string & ip)125 bool IsIP(const std::string &ip)
126 {
127 std::regex pattern(PATTERN);
128 std::smatch res;
129 return regex_match(ip, res, pattern);
130 }
131
SplitHostName(std::string & hostName)132 std::vector<std::string> SplitHostName(std::string &hostName)
133 {
134 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
135 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
136 }
137
// Returns true when at least one string occurs in both vectors.
// Unlike std::set_intersection, this does not require either input to be sorted,
// and it leaves both inputs untouched.
bool SeekIntersection(const std::vector<std::string> &vecA, const std::vector<std::string> &vecB)
{
    return std::any_of(vecA.begin(), vecA.end(), [&vecB](const std::string &item) {
        return std::find(vecB.begin(), vecB.end(), item) != vecB.end();
    });
}
144 } // namespace
145
// Copy constructor: performs the member-wise copy through the copy-assignment operator.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
150
// Copy-assigns every option from `tlsSecureOptions`.
// Each member is assigned exactly once (key_ was previously copied twice), and
// self-assignment returns early.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    key_ = tlsSecureOptions.GetKey();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    return *this;
}
165
SetCaChain(const std::vector<std::string> & caChain)166 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
167 {
168 caChain_ = caChain;
169 }
170
SetCert(const std::string & cert)171 void TLSSecureOptions::SetCert(const std::string &cert)
172 {
173 cert_ = cert;
174 }
175
SetKey(const SecureData & key)176 void TLSSecureOptions::SetKey(const SecureData &key)
177 {
178 key_ = key;
179 }
180
SetKeyPass(const SecureData & keyPass)181 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
182 {
183 keyPass_ = keyPass;
184 }
185
SetProtocolChain(const std::vector<std::string> & protocolChain)186 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
187 {
188 protocolChain_ = protocolChain;
189 }
190
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)191 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
192 {
193 useRemoteCipherPrefer_ = useRemoteCipherPrefer;
194 }
195
SetSignatureAlgorithms(const std::string & signatureAlgorithms)196 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
197 {
198 signatureAlgorithms_ = signatureAlgorithms;
199 }
200
SetCipherSuite(const std::string & cipherSuite)201 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
202 {
203 cipherSuite_ = cipherSuite;
204 }
205
SetCrlChain(const std::vector<std::string> & crlChain)206 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
207 {
208 crlChain_ = crlChain;
209 }
210
GetCaChain() const211 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
212 {
213 return caChain_;
214 }
215
GetCert() const216 const std::string &TLSSecureOptions::GetCert() const
217 {
218 return cert_;
219 }
220
GetKey() const221 const SecureData &TLSSecureOptions::GetKey() const
222 {
223 return key_;
224 }
225
GetKeyPass() const226 const SecureData &TLSSecureOptions::GetKeyPass() const
227 {
228 return keyPass_;
229 }
230
GetProtocolChain() const231 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
232 {
233 return protocolChain_;
234 }
235
UseRemoteCipherPrefer() const236 bool TLSSecureOptions::UseRemoteCipherPrefer() const
237 {
238 return useRemoteCipherPrefer_;
239 }
240
GetSignatureAlgorithms() const241 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
242 {
243 return signatureAlgorithms_;
244 }
245
GetCipherSuite() const246 const std::string &TLSSecureOptions::GetCipherSuite() const
247 {
248 return cipherSuite_;
249 }
250
GetCrlChain() const251 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
252 {
253 return crlChain_;
254 }
255
SetNetAddress(const NetAddress & address)256 void TLSConnectOptions::SetNetAddress(const NetAddress &address)
257 {
258 address_.SetAddress(address.GetAddress());
259 address_.SetPort(address.GetPort());
260 address_.SetFamilyBySaFamily(address.GetSaFamily());
261 }
262
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)263 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
264 {
265 tlsSecureOptions_ = tlsSecureOptions;
266 }
267
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)268 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
269 {
270 checkServerIdentity_ = checkServerIdentity;
271 }
272
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)273 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
274 {
275 alpnProtocols_ = alpnProtocols;
276 }
277
GetNetAddress() const278 NetAddress TLSConnectOptions::GetNetAddress() const
279 {
280 return address_;
281 }
282
GetTlsSecureOptions() const283 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
284 {
285 return tlsSecureOptions_;
286 }
287
GetCheckServerIdentity() const288 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
289 {
290 return checkServerIdentity_;
291 }
292
GetAlpnProtocols() const293 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
294 {
295 return alpnProtocols_;
296 }
297
MakeAddressString(sockaddr * addr)298 std::string TLSSocket::MakeAddressString(sockaddr *addr)
299 {
300 if (!addr) {
301 return {};
302 }
303 if (addr->sa_family == AF_INET) {
304 auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
305 const char *str = inet_ntoa(addr4->sin_addr);
306 if (str == nullptr || strlen(str) == 0) {
307 return {};
308 }
309 return str;
310 } else if (addr->sa_family == AF_INET6) {
311 auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
312 char str[INET6_ADDRSTRLEN] = {0};
313 if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
314 return {};
315 }
316 return str;
317 }
318 return {};
319 }
320
GetAddr(const NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)321 void TLSSocket::GetAddr(const NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
322 socklen_t *len)
323 {
324 if (!addr6 || !addr4 || !len) {
325 return;
326 }
327 sa_family_t family = address.GetSaFamily();
328 if (family == AF_INET) {
329 addr4->sin_family = AF_INET;
330 addr4->sin_port = htons(address.GetPort());
331 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
332 *addr = reinterpret_cast<sockaddr *>(addr4);
333 *len = sizeof(sockaddr_in);
334 } else if (family == AF_INET6) {
335 addr6->sin6_family = AF_INET6;
336 addr6->sin6_port = htons(address.GetPort());
337 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
338 *addr = reinterpret_cast<sockaddr *>(addr6);
339 *len = sizeof(sockaddr_in6);
340 }
341 }
342
MakeIpSocket(sa_family_t family)343 void TLSSocket::MakeIpSocket(sa_family_t family)
344 {
345 if (family != AF_INET && family != AF_INET6) {
346 return;
347 }
348 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
349 if (sock < 0) {
350 int resErr = ConvertErrno();
351 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
352 CallOnErrorCallback(resErr, MakeErrnoString());
353 return;
354 }
355 sockFd_ = sock;
356 }
357
StartReadMessage()358 void TLSSocket::StartReadMessage()
359 {
360 std::thread thread([this]() {
361 isRunning_ = true;
362 isRunOver_ = false;
363 char buffer[MAX_BUFFER_SIZE];
364 bzero(buffer, MAX_BUFFER_SIZE);
365 while (isRunning_) {
366 int len = tlsSocketInternal_.Recv(buffer, MAX_BUFFER_SIZE);
367 if (!isRunning_) {
368 break;
369 }
370 if (len < 0) {
371 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
372 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
373 MakeSSLErrorString(resErr).c_str());
374 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
375 break;
376 }
377
378 if (len == 0) {
379 continue;
380 }
381
382 SocketRemoteInfo remoteInfo;
383 remoteInfo.SetSize(len);
384 tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
385 CallOnMessageCallback(buffer, remoteInfo);
386 }
387 isRunOver_ = true;
388 });
389 thread.detach();
390 }
391
CallOnMessageCallback(const std::string & data,const OHOS::NetStack::SocketRemoteInfo & remoteInfo)392 void TLSSocket::CallOnMessageCallback(const std::string &data, const OHOS::NetStack::SocketRemoteInfo &remoteInfo)
393 {
394 OnMessageCallback func = nullptr;
395 {
396 std::lock_guard<std::mutex> lock(mutex_);
397 if (onMessageCallback_) {
398 func = onMessageCallback_;
399 }
400 }
401
402 if (func) {
403 func(data, remoteInfo);
404 }
405 }
406
CallOnConnectCallback()407 void TLSSocket::CallOnConnectCallback()
408 {
409 OnConnectCallback func = nullptr;
410 {
411 std::lock_guard<std::mutex> lock(mutex_);
412 if (onConnectCallback_) {
413 func = onConnectCallback_;
414 }
415 }
416
417 if (func) {
418 func();
419 }
420 }
421
CallOnCloseCallback()422 void TLSSocket::CallOnCloseCallback()
423 {
424 OnCloseCallback func = nullptr;
425 {
426 std::lock_guard<std::mutex> lock(mutex_);
427 if (onCloseCallback_) {
428 func = onCloseCallback_;
429 }
430 }
431
432 if (func) {
433 func();
434 }
435 }
436
CallOnErrorCallback(int32_t err,const std::string & errString)437 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
438 {
439 OnErrorCallback func = nullptr;
440 {
441 std::lock_guard<std::mutex> lock(mutex_);
442 if (onErrorCallback_) {
443 func = onErrorCallback_;
444 }
445 }
446
447 if (func) {
448 func(err, errString);
449 }
450 }
451
CallBindCallback(int32_t err,BindCallback callback)452 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
453 {
454 BindCallback func = nullptr;
455 {
456 std::lock_guard<std::mutex> lock(mutex_);
457 if (callback) {
458 func = callback;
459 }
460 }
461
462 if (func) {
463 func(err);
464 }
465 }
466
CallConnectCallback(int32_t err,ConnectCallback callback)467 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
468 {
469 ConnectCallback func = nullptr;
470 {
471 std::lock_guard<std::mutex> lock(mutex_);
472 if (callback) {
473 func = callback;
474 }
475 }
476
477 if (func) {
478 func(err);
479 }
480 }
481
CallSendCallback(int32_t err,SendCallback callback)482 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
483 {
484 SendCallback func = nullptr;
485 {
486 std::lock_guard<std::mutex> lock(mutex_);
487 if (callback) {
488 func = callback;
489 }
490 }
491
492 if (func) {
493 func(err);
494 }
495 }
496
CallCloseCallback(int32_t err,CloseCallback callback)497 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
498 {
499 CloseCallback func = nullptr;
500 {
501 std::lock_guard<std::mutex> lock(mutex_);
502 if (callback) {
503 func = callback;
504 }
505 }
506
507 if (func) {
508 func(err);
509 }
510 }
511
CallGetRemoteAddressCallback(int32_t err,const NetAddress & address,GetRemoteAddressCallback callback)512 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const NetAddress &address, GetRemoteAddressCallback callback)
513 {
514 GetRemoteAddressCallback func = nullptr;
515 {
516 std::lock_guard<std::mutex> lock(mutex_);
517 if (callback) {
518 func = callback;
519 }
520 }
521
522 if (func) {
523 func(err, address);
524 }
525 }
526
CallGetStateCallback(int32_t err,const SocketStateBase & state,GetStateCallback callback)527 void TLSSocket::CallGetStateCallback(int32_t err, const SocketStateBase &state, GetStateCallback callback)
528 {
529 GetStateCallback func = nullptr;
530 {
531 std::lock_guard<std::mutex> lock(mutex_);
532 if (callback) {
533 func = callback;
534 }
535 }
536
537 if (func) {
538 func(err, state);
539 }
540 }
541
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)542 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
543 {
544 SetExtraOptionsCallback func = nullptr;
545 {
546 std::lock_guard<std::mutex> lock(mutex_);
547 if (callback) {
548 func = callback;
549 }
550 }
551
552 if (func) {
553 func(err);
554 }
555 }
556
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)557 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
558 {
559 GetCertificateCallback func = nullptr;
560 {
561 std::lock_guard<std::mutex> lock(mutex_);
562 if (callback) {
563 func = callback;
564 }
565 }
566
567 if (func) {
568 func(err, cert);
569 }
570 }
571
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)572 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
573 GetRemoteCertificateCallback callback)
574 {
575 GetRemoteCertificateCallback func = nullptr;
576 {
577 std::lock_guard<std::mutex> lock(mutex_);
578 if (callback) {
579 func = callback;
580 }
581 }
582
583 if (func) {
584 func(err, cert);
585 }
586 }
587
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)588 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
589 {
590 GetProtocolCallback func = nullptr;
591 {
592 std::lock_guard<std::mutex> lock(mutex_);
593 if (callback) {
594 func = callback;
595 }
596 }
597
598 if (func) {
599 func(err, protocol);
600 }
601 }
602
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)603 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
604 GetCipherSuiteCallback callback)
605 {
606 GetCipherSuiteCallback func = nullptr;
607 {
608 std::lock_guard<std::mutex> lock(mutex_);
609 if (callback) {
610 func = callback;
611 }
612 }
613
614 if (func) {
615 func(err, suite);
616 }
617 }
618
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)619 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
620 GetSignatureAlgorithmsCallback callback)
621 {
622 GetSignatureAlgorithmsCallback func = nullptr;
623 {
624 std::lock_guard<std::mutex> lock(mutex_);
625 if (callback) {
626 func = callback;
627 }
628 }
629
630 if (func) {
631 func(err, algorithms);
632 }
633 }
634
Bind(const OHOS::NetStack::NetAddress & address,const OHOS::NetStack::BindCallback & callback)635 void TLSSocket::Bind(const OHOS::NetStack::NetAddress &address, const OHOS::NetStack::BindCallback &callback)
636 {
637 if (!CommonUtils::HasInternetPermission()) {
638 CallBindCallback(PERMISSION_DENIED_CODE, callback);
639 return;
640 }
641 if (sockFd_ >= 0) {
642 CallBindCallback(TLSSOCKET_SUCCESS, callback);
643 return;
644 }
645
646 MakeIpSocket(address.GetSaFamily());
647 if (sockFd_ < 0) {
648 int resErr = ConvertErrno();
649 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
650 CallOnErrorCallback(resErr, MakeErrnoString());
651 CallBindCallback(resErr, callback);
652 return;
653 }
654
655 sockaddr_in addr4 = {0};
656 sockaddr_in6 addr6 = {0};
657 sockaddr *addr = nullptr;
658 socklen_t len;
659 GetAddr(address, &addr4, &addr6, &addr, &len);
660 if (addr == nullptr) {
661 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
662 CallOnErrorCallback(-1, "Address Is Invalid");
663 CallBindCallback(ConvertErrno(), callback);
664 return;
665 }
666 CallBindCallback(TLSSOCKET_SUCCESS, callback);
667 }
668
Connect(OHOS::NetStack::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::ConnectCallback & callback)669 void TLSSocket::Connect(OHOS::NetStack::TLSConnectOptions &tlsConnectOptions,
670 const OHOS::NetStack::ConnectCallback &callback)
671 {
672 if (sockFd_ < 0) {
673 int resErr = ConvertErrno();
674 NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
675 CallOnErrorCallback(resErr, MakeErrnoString());
676 callback(resErr);
677 return;
678 }
679
680 auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions);
681 if (!res) {
682 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
683 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
684 callback(resErr);
685 return;
686 }
687 StartReadMessage();
688 CallOnConnectCallback();
689 callback(TLSSOCKET_SUCCESS);
690 }
691
Send(const OHOS::NetStack::TCPSendOptions & tcpSendOptions,const OHOS::NetStack::SendCallback & callback)692 void TLSSocket::Send(const OHOS::NetStack::TCPSendOptions &tcpSendOptions, const OHOS::NetStack::SendCallback &callback)
693 {
694 (void)tcpSendOptions;
695
696 auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
697 if (!res) {
698 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
699 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
700 CallSendCallback(resErr, callback);
701 return;
702 }
703 CallSendCallback(TLSSOCKET_SUCCESS, callback);
704 }
705
WaitConditionWithTimeout(const bool * flag,const int32_t timeoutMs)706 bool WaitConditionWithTimeout(const bool *flag, const int32_t timeoutMs)
707 {
708 int maxWaitCnt = timeoutMs / WAIT_MS;
709 int cnt = 0;
710 while (!(*flag)) {
711 if (cnt >= maxWaitCnt) {
712 return false;
713 }
714 std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MS));
715 cnt++;
716 }
717 return true;
718 }
719
Close(const OHOS::NetStack::CloseCallback & callback)720 void TLSSocket::Close(const OHOS::NetStack::CloseCallback &callback)
721 {
722 if (!WaitConditionWithTimeout(&isRunning_, TIMEOUT_MS)) {
723 callback(ConvertErrno());
724 NETSTACK_LOGE("The error cause is that the runtime wait time is insufficient");
725 return;
726 }
727 isRunning_ = false;
728 if (!WaitConditionWithTimeout(&isRunOver_, TIMEOUT_MS)) {
729 callback(ConvertErrno());
730 NETSTACK_LOGE("The error is due to insufficient delay time");
731 return;
732 }
733 auto res = tlsSocketInternal_.Close();
734 if (!res) {
735 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
736 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
737 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
738 callback(resErr);
739 return;
740 }
741 CallOnCloseCallback();
742 callback(TLSSOCKET_SUCCESS);
743 }
744
GetRemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback & callback)745 void TLSSocket::GetRemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback &callback)
746 {
747 sockaddr sockAddr;
748 socklen_t len = sizeof(sockaddr);
749 int ret = getsockname(sockFd_, &sockAddr, &len);
750 if (ret < 0) {
751 int resErr = ConvertErrno();
752 NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
753 CallOnErrorCallback(resErr, MakeErrnoString());
754 CallGetRemoteAddressCallback(resErr, {}, callback);
755 return;
756 }
757
758 if (sockAddr.sa_family == AF_INET) {
759 GetIp4RemoteAddress(callback);
760 } else if (sockAddr.sa_family == AF_INET6) {
761 GetIp6RemoteAddress(callback);
762 }
763 }
764
GetIp4RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback & callback)765 void TLSSocket::GetIp4RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback &callback)
766 {
767 sockaddr_in addr4 = {0};
768 socklen_t len4 = sizeof(sockaddr_in);
769
770 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
771 if (ret < 0) {
772 int resErr = ConvertErrno();
773 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
774 CallOnErrorCallback(resErr, MakeErrnoString());
775 CallGetRemoteAddressCallback(resErr, {}, callback);
776 return;
777 }
778
779 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
780 if (address.empty()) {
781 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
782 CallOnErrorCallback(-1, "Address is invalid");
783 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
784 return;
785 }
786 NetAddress netAddress;
787 netAddress.SetAddress(address);
788 netAddress.SetFamilyBySaFamily(AF_INET);
789 netAddress.SetPort(ntohs(addr4.sin_port));
790 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
791 }
792
GetIp6RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback & callback)793 void TLSSocket::GetIp6RemoteAddress(const OHOS::NetStack::GetRemoteAddressCallback &callback)
794 {
795 sockaddr_in6 addr6 = {0};
796 socklen_t len6 = sizeof(sockaddr_in6);
797
798 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
799 if (ret < 0) {
800 int resErr = ConvertErrno();
801 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
802 CallOnErrorCallback(resErr, MakeErrnoString());
803 CallGetRemoteAddressCallback(resErr, {}, callback);
804 return;
805 }
806
807 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
808 if (address.empty()) {
809 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
810 CallOnErrorCallback(-1, "Address is invalid");
811 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
812 return;
813 }
814 NetAddress netAddress;
815 netAddress.SetAddress(address);
816 netAddress.SetFamilyBySaFamily(AF_INET6);
817 netAddress.SetPort(ntohs(addr6.sin6_port));
818 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
819 }
820
GetState(const OHOS::NetStack::GetStateCallback & callback)821 void TLSSocket::GetState(const OHOS::NetStack::GetStateCallback &callback)
822 {
823 int opt;
824 socklen_t optLen = sizeof(int);
825 int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
826 if (r < 0) {
827 SocketStateBase state;
828 state.SetIsClose(true);
829 CallGetStateCallback(ConvertErrno(), state, callback);
830 return;
831 }
832 sockaddr sockAddr;
833 socklen_t len = sizeof(sockaddr);
834 SocketStateBase state;
835 int ret = getsockname(sockFd_, &sockAddr, &len);
836 state.SetIsBound(ret == 0);
837 ret = getpeername(sockFd_, &sockAddr, &len);
838 state.SetIsConnected(ret == 0);
839 CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
840 }
841
SetBaseOptions(const ExtraOptionsBase & option) const842 bool TLSSocket::SetBaseOptions(const ExtraOptionsBase &option) const
843 {
844 if (option.GetReceiveBufferSize() != 0) {
845 int size = (int)option.GetReceiveBufferSize();
846 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
847 return false;
848 }
849 }
850
851 if (option.GetSendBufferSize() != 0) {
852 int size = (int)option.GetSendBufferSize();
853 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
854 return false;
855 }
856 }
857
858 if (option.IsReuseAddress()) {
859 int reuse = 1;
860 if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
861 return false;
862 }
863 }
864
865 if (option.GetSocketTimeout() != 0) {
866 timeval timeout = {(int)option.GetSocketTimeout(), 0};
867 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
868 return false;
869 }
870 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
871 return false;
872 }
873 }
874
875 return true;
876 }
877
SetExtraOptions(const TCPExtraOptions & option) const878 bool TLSSocket::SetExtraOptions(const TCPExtraOptions &option) const
879 {
880 if (option.IsKeepAlive()) {
881 int keepalive = 1;
882 if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
883 return false;
884 }
885 }
886
887 if (option.IsOOBInline()) {
888 int oobInline = 1;
889 if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
890 return false;
891 }
892 }
893
894 if (option.IsTCPNoDelay()) {
895 int tcpNoDelay = 1;
896 if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
897 return false;
898 }
899 }
900
901 linger soLinger = {0};
902 soLinger.l_onoff = option.socketLinger.IsOn();
903 soLinger.l_linger = (int)option.socketLinger.GetLinger();
904 if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
905 return false;
906 }
907
908 return true;
909 }
910
SetExtraOptions(const OHOS::NetStack::TCPExtraOptions & tcpExtraOptions,const OHOS::NetStack::SetExtraOptionsCallback & callback)911 void TLSSocket::SetExtraOptions(const OHOS::NetStack::TCPExtraOptions &tcpExtraOptions,
912 const OHOS::NetStack::SetExtraOptionsCallback &callback)
913 {
914 if (!SetBaseOptions(tcpExtraOptions)) {
915 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
916 CallOnErrorCallback(errno, MakeErrnoString());
917 CallSetExtraOptionsCallback(ConvertErrno(), callback);
918 return;
919 }
920
921 if (!SetExtraOptions(tcpExtraOptions)) {
922 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
923 CallOnErrorCallback(errno, MakeErrnoString());
924 CallSetExtraOptionsCallback(ConvertErrno(), callback);
925 return;
926 }
927
928 CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
929 }
930
GetCertificate(const GetCertificateCallback & callback)931 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
932 {
933 const auto &cert = tlsSocketInternal_.GetCertificate();
934 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
935
936 if (!cert.data.Length()) {
937 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
938 NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
939 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
940 callback(resErr, {});
941 return;
942 }
943 callback(TLSSOCKET_SUCCESS, cert);
944 }
945
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)946 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
947 {
948 const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
949 if (!remoteCert.data.Length()) {
950 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
951 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
952 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
953 callback(resErr, {});
954 return;
955 }
956 callback(TLSSOCKET_SUCCESS, remoteCert);
957 }
958
GetProtocol(const GetProtocolCallback & callback)959 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
960 {
961 const auto &protocol = tlsSocketInternal_.GetProtocol();
962 if (protocol.empty()) {
963 NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
964 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
965 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
966 callback(resErr, "");
967 return;
968 }
969 callback(TLSSOCKET_SUCCESS, protocol);
970 }
971
GetCipherSuite(const GetCipherSuiteCallback & callback)972 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
973 {
974 const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
975 if (cipherSuite.empty()) {
976 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
977 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
978 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
979 callback(resErr, cipherSuite);
980 return;
981 }
982 callback(TLSSOCKET_SUCCESS, cipherSuite);
983 }
984
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)985 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
986 {
987 const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
988 if (signatureAlgorithms.empty()) {
989 NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
990 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
991 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
992 callback(resErr, {});
993 return;
994 }
995 callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
996 }
997
OnMessage(const OHOS::NetStack::OnMessageCallback & onMessageCallback)998 void TLSSocket::OnMessage(const OHOS::NetStack::OnMessageCallback &onMessageCallback)
999 {
1000 std::lock_guard<std::mutex> lock(mutex_);
1001 onMessageCallback_ = onMessageCallback;
1002 }
1003
OffMessage()1004 void TLSSocket::OffMessage()
1005 {
1006 std::lock_guard<std::mutex> lock(mutex_);
1007 if (onMessageCallback_) {
1008 onMessageCallback_ = nullptr;
1009 }
1010 }
1011
OnConnect(const OnConnectCallback & onConnectCallback)1012 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1013 {
1014 std::lock_guard<std::mutex> lock(mutex_);
1015 onConnectCallback_ = onConnectCallback;
1016 }
1017
OffConnect()1018 void TLSSocket::OffConnect()
1019 {
1020 std::lock_guard<std::mutex> lock(mutex_);
1021 if (onConnectCallback_) {
1022 onConnectCallback_ = nullptr;
1023 }
1024 }
1025
OnClose(const OnCloseCallback & onCloseCallback)1026 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1027 {
1028 std::lock_guard<std::mutex> lock(mutex_);
1029 onCloseCallback_ = onCloseCallback;
1030 }
1031
OffClose()1032 void TLSSocket::OffClose()
1033 {
1034 std::lock_guard<std::mutex> lock(mutex_);
1035 if (onCloseCallback_) {
1036 onCloseCallback_ = nullptr;
1037 }
1038 }
1039
OnError(const OnErrorCallback & onErrorCallback)1040 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1041 {
1042 std::lock_guard<std::mutex> lock(mutex_);
1043 onErrorCallback_ = onErrorCallback;
1044 }
1045
OffError()1046 void TLSSocket::OffError()
1047 {
1048 std::lock_guard<std::mutex> lock(mutex_);
1049 if (onErrorCallback_) {
1050 onErrorCallback_ = nullptr;
1051 }
1052 }
1053
ExecSocketConnect(const std::string & hostName,int port,sa_family_t family,int socketDescriptor)1054 bool ExecSocketConnect(const std::string &hostName, int port, sa_family_t family, int socketDescriptor)
1055 {
1056 struct sockaddr_in dest = {0};
1057 dest.sin_family = family;
1058 dest.sin_port = htons(port);
1059 if (!inet_aton(hostName.c_str(), reinterpret_cast<in_addr *>(&dest.sin_addr.s_addr))) {
1060 NETSTACK_LOGE("inet_aton is error, hostName is %s", hostName.c_str());
1061 return false;
1062 }
1063 int connectResult = connect(socketDescriptor, reinterpret_cast<struct sockaddr *>(&dest), sizeof(dest));
1064 if (connectResult == -1) {
1065 NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1066 strerror(errno));
1067 return false;
1068 }
1069 return true;
1070 }
1071
TlsConnectToHost(int sock,const TLSConnectOptions & options)1072 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options)
1073 {
1074 SetTlsConfiguration(options);
1075 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1076 if (!cipherSuite.empty()) {
1077 configuration_.SetCipherSuite(cipherSuite);
1078 }
1079 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1080 if (!signatureAlgorithms.empty()) {
1081 configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1082 }
1083 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1084 if (!protocolVec.empty()) {
1085 configuration_.SetProtocol(protocolVec);
1086 }
1087
1088 hostName_ = options.GetNetAddress().GetAddress();
1089 port_ = options.GetNetAddress().GetPort();
1090 family_ = options.GetNetAddress().GetSaFamily();
1091 socketDescriptor_ = sock;
1092 if (!ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1093 options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1094 return false;
1095 }
1096 return StartTlsConnected(options);
1097 }
1098
SetTlsConfiguration(const TLSConnectOptions & config)1099 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1100 {
1101 configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1102 configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1103 configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1104 }
1105
Send(const std::string & data)1106 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1107 {
1108 NETSTACK_LOGD("data to send :%{public}s", data.c_str());
1109 if (data.empty()) {
1110 NETSTACK_LOGE("data is empty");
1111 return false;
1112 }
1113 if (!ssl_) {
1114 NETSTACK_LOGE("ssl is null");
1115 return false;
1116 }
1117 int len = SSL_write(ssl_, data.c_str(), data.length());
1118 if (len < 0) {
1119 int resErr = ConvertSSLError(GetSSL());
1120 NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
1121 data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
1122 return false;
1123 }
1124 NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
1125 return true;
1126 }
Recv(char * buffer,int maxBufferSize)1127 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1128 {
1129 if (!ssl_) {
1130 NETSTACK_LOGE("ssl is null");
1131 return SSL_ERROR_RETURN;
1132 }
1133 return SSL_read(ssl_, buffer, maxBufferSize);
1134 }
1135
Close()1136 bool TLSSocket::TLSSocketInternal::Close()
1137 {
1138 if (!ssl_) {
1139 NETSTACK_LOGE("ssl is null");
1140 return false;
1141 }
1142 int result = SSL_shutdown(ssl_);
1143 if (result < 0) {
1144 int resErr = ConvertSSLError(GetSSL());
1145 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1146 MakeSSLErrorString(resErr).c_str());
1147 return false;
1148 }
1149 SSL_free(ssl_);
1150 ssl_ = nullptr;
1151 close(socketDescriptor_);
1152 socketDescriptor_ = -1;
1153 if (!tlsContextPointer_) {
1154 NETSTACK_LOGE("Tls context pointer is null");
1155 return false;
1156 }
1157 tlsContextPointer_->CloseCtx();
1158 return true;
1159 }
1160
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1161 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1162 {
1163 if (!ssl_) {
1164 NETSTACK_LOGE("ssl is null");
1165 return false;
1166 }
1167 size_t pos = 0;
1168 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1169 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1170 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1171 for (const auto &str : alpnProtocols) {
1172 len = str.length();
1173 result[pos++] = len;
1174 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1175 NETSTACK_LOGE("strcpy_s failed");
1176 return false;
1177 }
1178 pos += len;
1179 }
1180 result[pos] = '\0';
1181
1182 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1183 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1184 int resErr = ConvertSSLError(GetSSL());
1185 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1186 MakeSSLErrorString(resErr).c_str());
1187 return false;
1188 }
1189 return true;
1190 }
1191
MakeRemoteInfo(SocketRemoteInfo & remoteInfo)1192 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(SocketRemoteInfo &remoteInfo)
1193 {
1194 remoteInfo.SetAddress(hostName_);
1195 remoteInfo.SetPort(port_);
1196 remoteInfo.SetFamily(family_);
1197 }
1198
GetTlsConfiguration() const1199 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1200 {
1201 return configuration_;
1202 }
1203
GetCipherSuite() const1204 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1205 {
1206 if (!ssl_) {
1207 NETSTACK_LOGE("ssl in null");
1208 return {};
1209 }
1210 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1211 if (!sk) {
1212 NETSTACK_LOGE("get ciphers failed");
1213 return {};
1214 }
1215 CipherSuite cipherSuite;
1216 std::vector<std::string> cipherSuiteVec;
1217 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1218 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1219 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1220 cipherSuiteVec.push_back(cipherSuite.cipherName_);
1221 }
1222 return cipherSuiteVec;
1223 }
1224
GetRemoteCertificate() const1225 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1226 {
1227 return remoteCert_;
1228 }
1229
GetCertificate() const1230 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1231 {
1232 return configuration_.GetCertificate();
1233 }
1234
GetSignatureAlgorithms() const1235 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1236 {
1237 return signatureAlgorithms_;
1238 }
1239
GetProtocol() const1240 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1241 {
1242 if (!ssl_) {
1243 NETSTACK_LOGE("ssl in null");
1244 return PROTOCOL_UNKNOW;
1245 }
1246 if (configuration_.GetProtocol() == TLS_V1_3) {
1247 return PROTOCOL_TLS_V13;
1248 }
1249 return PROTOCOL_TLS_V12;
1250 }
1251
SetSharedSigals()1252 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1253 {
1254 if (!ssl_) {
1255 NETSTACK_LOGE("ssl is null");
1256 return false;
1257 }
1258 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1259 if (!number) {
1260 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1261 return false;
1262 }
1263 for (int i = 0; i < number; i++) {
1264 int hash_nid;
1265 int sign_nid;
1266 std::string sig_with_md;
1267 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1268 switch (sign_nid) {
1269 case EVP_PKEY_RSA:
1270 sig_with_md = SIGN_NID_RSA;
1271 break;
1272 case EVP_PKEY_RSA_PSS:
1273 sig_with_md = SIGN_NID_RSA_PSS;
1274 break;
1275 case EVP_PKEY_DSA:
1276 sig_with_md = SIGN_NID_DSA;
1277 break;
1278 case EVP_PKEY_EC:
1279 sig_with_md = SIGN_NID_ECDSA;
1280 break;
1281 case NID_ED25519:
1282 sig_with_md = SIGN_NID_ED;
1283 break;
1284 case NID_ED448:
1285 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1286 break;
1287 default:
1288 const char *sn = OBJ_nid2sn(sign_nid);
1289 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1290 }
1291 const char *sn_hash = OBJ_nid2sn(hash_nid);
1292 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1293 signatureAlgorithms_.push_back(sig_with_md);
1294 }
1295 return true;
1296 }
1297
StartTlsConnected(const TLSConnectOptions & options)1298 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1299 {
1300 if (!CreatTlsContext()) {
1301 NETSTACK_LOGE("failed to create tls context");
1302 return false;
1303 }
1304 if (!StartShakingHands(options)) {
1305 NETSTACK_LOGE("failed to shaking hands");
1306 return false;
1307 }
1308 return true;
1309 }
1310
CreatTlsContext()1311 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1312 {
1313 tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1314 if (!tlsContextPointer_) {
1315 NETSTACK_LOGE("failed to create tls context pointer");
1316 return false;
1317 }
1318 if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1319 NETSTACK_LOGE("failed to create ssl session");
1320 return false;
1321 }
1322 SSL_set_fd(ssl_, socketDescriptor_);
1323 SSL_set_connect_state(ssl_);
1324 return true;
1325 }
1326
// True when @p s begins with @p prefix (an empty prefix matches every string).
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1331
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1332 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1333 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1334 {
1335 bool valid = false;
1336 std::string reason = UNKNOW_REASON;
1337 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1338 if (IsIP(hostName)) {
1339 auto it = find(ips.begin(), ips.end(), hostName);
1340 if (it == ips.end()) {
1341 reason = IP + hostName + " is not in the cert's list";
1342 }
1343 result = {valid, reason};
1344 return;
1345 }
1346 std::string tempHostName = "" + hostName;
1347 if (!dnsNames.empty() || index > 0) {
1348 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1349 if (!dnsNames.empty()) {
1350 valid = SeekIntersection(hostParts, dnsNames);
1351 if (!valid) {
1352 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1353 }
1354 } else {
1355 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1356 X509_NAME *pSubName = nullptr;
1357 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1358 if (len > 0) {
1359 std::vector<std::string> commonNameVec;
1360 commonNameVec.emplace_back(commonNameBuf);
1361 valid = SeekIntersection(hostParts, commonNameVec);
1362 if (!valid) {
1363 reason = HOST_NAME + tempHostName + ". is not cert's CN";
1364 }
1365 }
1366 }
1367 result = {valid, reason};
1368 return;
1369 }
1370 reason = "Cert does not contain a DNS name";
1371 result = {valid, reason};
1372 }
1373
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1374 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1375 const X509 *x509Certificates)
1376 {
1377 std::string hostname = hostName;
1378
1379 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1380 if (!subjectName) {
1381 return "subject name is null";
1382 }
1383 char subNameBuf[BUF_SIZE] = {0};
1384 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1385
1386 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1387 if (index < 0) {
1388 return "X509 get ext nid error";
1389 }
1390 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1391 if (ext == nullptr) {
1392 return "X509 get ext error";
1393 }
1394 ASN1_OBJECT *obj = nullptr;
1395 obj = X509_EXTENSION_get_object(ext);
1396 char subAltNameBuf[BUF_SIZE] = {0};
1397 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1398 NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1399
1400 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1401 std::string altNames = reinterpret_cast<char *>(extData->data);
1402
1403 BIO *bio = BIO_new(BIO_s_file());
1404 if (!bio) {
1405 return "bio is null";
1406 }
1407 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1408 ASN1_STRING_print(bio, extData);
1409 std::vector<std::string> dnsNames = {};
1410 std::vector<std::string> ips = {};
1411 constexpr int DNS_NAME_IDX = 4;
1412 constexpr int IP_NAME_IDX = 11;
1413 hostname = "" + hostname;
1414 if (!altNames.empty()) {
1415 std::vector<std::string> splitAltNames;
1416 if (altNames.find('\"') != std::string::npos) {
1417 splitAltNames = SplitEscapedAltNames(altNames);
1418 } else {
1419 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1420 }
1421 for (auto const &iter : splitAltNames) {
1422 if (StartsWith(iter, DNS)) {
1423 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1424 } else if (StartsWith(iter, IP_ADDRESS)) {
1425 ips.push_back(iter.substr(IP_NAME_IDX));
1426 }
1427 }
1428 }
1429 std::tuple<bool, std::string> result;
1430 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1431 if (!std::get<0>(result)) {
1432 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1433 }
1434 return HOST_NAME + hostname + ". is cert's CN";
1435 }
1436
StartShakingHands(const TLSConnectOptions & options)1437 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1438 {
1439 if (!ssl_) {
1440 NETSTACK_LOGE("ssl is null");
1441 return false;
1442 }
1443 int result = SSL_connect(ssl_);
1444 if (result == -1) {
1445 int errorStatus = ConvertSSLError(ssl_);
1446 NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1447 MakeSSLErrorString(errorStatus).c_str());
1448 return false;
1449 }
1450
1451 std::string list = SSL_get_cipher_list(ssl_, 0);
1452 NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1453 configuration_.SetCipherSuite(list);
1454 if (!SetSharedSigals()) {
1455 NETSTACK_LOGE("Failed to set sharedSigalgs");
1456 }
1457 if (!GetRemoteCertificateFromPeer()) {
1458 NETSTACK_LOGE("Failed to get remote certificate");
1459 }
1460 if (!peerX509_) {
1461 NETSTACK_LOGE("peer x509Certificates is null");
1462 return false;
1463 }
1464 if (!SetRemoteCertRawData()) {
1465 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1466 }
1467 CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1468 if (!checkServerIdentity) {
1469 CheckServerIdentityLegal(hostName_, peerX509_);
1470 } else {
1471 checkServerIdentity(hostName_, {remoteCert_});
1472 }
1473 NETSTACK_LOGI("SSL Get Version: %{public}s, SSL Get Cipher: %{public}s", SSL_get_version(ssl_),
1474 SSL_get_cipher(ssl_));
1475 return true;
1476 }
1477
GetRemoteCertificateFromPeer()1478 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1479 {
1480 peerX509_ = SSL_get_peer_certificate(ssl_);
1481 if (peerX509_ == nullptr) {
1482 int resErr = ConvertSSLError(GetSSL());
1483 NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1484 MakeSSLErrorString(resErr).c_str());
1485 return false;
1486 }
1487 BIO *bio = BIO_new(BIO_s_mem());
1488 if (!bio) {
1489 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1490 return false;
1491 }
1492 X509_print(bio, peerX509_);
1493 char data[REMOTE_CERT_LEN] = {0};
1494 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1495 NETSTACK_LOGE("BIO_read function returns error");
1496 BIO_free(bio);
1497 return false;
1498 }
1499 BIO_free(bio);
1500 remoteCert_ = std::string(data);
1501 return true;
1502 }
1503
SetRemoteCertRawData()1504 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1505 {
1506 if (peerX509_ == nullptr) {
1507 NETSTACK_LOGE("peerX509 is null");
1508 return false;
1509 }
1510 int32_t length = i2d_X509(peerX509_, nullptr);
1511 if (length <= 0) {
1512 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1513 return false;
1514 }
1515 unsigned char *der = nullptr;
1516 (void)i2d_X509(peerX509_, &der);
1517 SecureData data(der, length);
1518 remoteRawData_.data = data;
1519 OPENSSL_free(der);
1520 remoteRawData_.encodingFormat = DER;
1521 return true;
1522 }
1523
GetRemoteCertRawData() const1524 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1525 {
1526 return remoteRawData_;
1527 }
1528
GetSSL() const1529 ssl_st *TLSSocket::TLSSocketInternal::GetSSL() const
1530 {
1531 return ssl_;
1532 }
1533 } // namespace NetStack
1534 } // namespace OHOS
1535