1 /*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "tls_socket.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <memory>
#include <numeric>
#include <regex>
#include <securec.h>
#include <thread>
#include <unordered_set>

#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/ssl.h>

#include "base_context.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "tls.h"
33
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocket {
37 namespace {
constexpr int WAIT_MS = 10;        // polling step, in milliseconds, used by WaitConditionWithTimeout()
constexpr int TIMEOUT_MS = 10000;  // how long Close() waits for isRunning_ to become true
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;    // return value handed to SSL_get_error() by ConvertSSLError()
constexpr int SSL_ERROR_RETURN = -1;
constexpr int OFFSET = 2;          // length of the ", " separator skipped by SplitEscapedAltNames()
constexpr int QUIT_RESPONSE_CODE_LEN = 3;  // number of leading bytes compared with QUIT_RESPONSE_CODE
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
constexpr const char *QUIT_RESPONSE_CODE = "221";
// A JSON string literal anchored at the start of the input: a double quote, then any mix of characters other
// than the double quote, the backslash and U+0000..U+001F, or one of the JSON escapes (quote, backslash, slash,
// b, f, n, r, t, or "u" plus four hex digits), then the closing double quote.
// std::regex has no JavaScript-style "/.../" delimiters; writing them made the slashes literal characters and put
// '^' in the middle of the pattern, so the expression could never match.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address; each octet is restricted to 0-255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
69
ConvertErrno()70 int ConvertErrno()
71 {
72 return TlsSocketError::TLS_ERR_SYS_BASE + errno;
73 }
74
ConvertSSLError(ssl_st * ssl)75 int ConvertSSLError(ssl_st *ssl)
76 {
77 if (!ssl) {
78 return TLS_ERR_SSL_NULL;
79 }
80 return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
81 }
82
// Returns the strerror() text for the current errno.
std::string MakeErrnoString()
{
    const int currentErrno = errno;
    return std::string(strerror(currentErrno));
}
87
MakeSSLErrorString(int error)88 std::string MakeSSLErrorString(int error)
89 {
90 char err[MAX_ERR_LEN] = {0};
91 ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
92 return err;
93 }
94
SplitEscapedAltNames(std::string & altNames)95 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
96 {
97 std::vector<std::string> result;
98 std::string currentToken;
99 size_t offset = 0;
100 while (offset != altNames.length()) {
101 auto nextSep = altNames.find_first_of(", ");
102 auto nextQuote = altNames.find_first_of('\"');
103 if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
104 currentToken += altNames.substr(offset, nextQuote);
105 std::regex jsonStringPattern(JSON_STRING_PATTERN);
106 std::smatch match;
107 std::string altNameSubStr = altNames.substr(nextQuote);
108 bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
109 if (!ret) {
110 return {""};
111 }
112 currentToken += result[0];
113 offset = nextQuote + result[0].length();
114 } else if (nextSep != std::string::npos) {
115 currentToken += altNames.substr(offset, nextSep);
116 result.push_back(currentToken);
117 currentToken = "";
118 offset = nextSep + OFFSET;
119 } else {
120 currentToken += altNames.substr(offset);
121 offset = altNames.length();
122 }
123 }
124 result.push_back(currentToken);
125 return result;
126 }
127
IsIP(const std::string & ip)128 bool IsIP(const std::string &ip)
129 {
130 std::regex pattern(PATTERN);
131 std::smatch res;
132 return regex_match(ip, res, pattern);
133 }
134
SplitHostName(std::string & hostName)135 std::vector<std::string> SplitHostName(std::string &hostName)
136 {
137 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
138 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
139 }
140
// Returns true when at least one string appears in both vectors. Order and duplicates do not matter.
// std::set_intersection requires both ranges to be sorted; callers pass unsorted data (e.g. host-name
// labels), which made the old result wrong. A hash-set lookup has no ordering precondition.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    const std::unordered_set<std::string> lookup(vecA.begin(), vecA.end());
    return std::any_of(vecB.begin(), vecB.end(),
                       [&lookup](const std::string &item) { return lookup.count(item) != 0; });
}
147 } // namespace
148
// Copy constructor: default-constructs the members, then delegates to the copy-assignment operator.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
153
// Copies every option from `tlsSecureOptions` into this object and returns *this.
// key_ was previously assigned twice; self-assignment is now a no-op.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
169
SetCaChain(const std::vector<std::string> & caChain)170 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
171 {
172 caChain_ = caChain;
173 }
174
SetCert(const std::string & cert)175 void TLSSecureOptions::SetCert(const std::string &cert)
176 {
177 cert_ = cert;
178 }
179
SetKey(const SecureData & key)180 void TLSSecureOptions::SetKey(const SecureData &key)
181 {
182 key_ = key;
183 }
184
SetKeyPass(const SecureData & keyPass)185 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
186 {
187 keyPass_ = keyPass;
188 }
189
SetProtocolChain(const std::vector<std::string> & protocolChain)190 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
191 {
192 protocolChain_ = protocolChain;
193 }
194
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)195 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
196 {
197 useRemoteCipherPrefer_ = useRemoteCipherPrefer;
198 }
199
SetSignatureAlgorithms(const std::string & signatureAlgorithms)200 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
201 {
202 signatureAlgorithms_ = signatureAlgorithms;
203 }
204
SetCipherSuite(const std::string & cipherSuite)205 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
206 {
207 cipherSuite_ = cipherSuite;
208 }
209
SetCrlChain(const std::vector<std::string> & crlChain)210 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
211 {
212 crlChain_ = crlChain;
213 }
214
GetCaChain() const215 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
216 {
217 return caChain_;
218 }
219
GetCert() const220 const std::string &TLSSecureOptions::GetCert() const
221 {
222 return cert_;
223 }
224
GetKey() const225 const SecureData &TLSSecureOptions::GetKey() const
226 {
227 return key_;
228 }
229
GetKeyPass() const230 const SecureData &TLSSecureOptions::GetKeyPass() const
231 {
232 return keyPass_;
233 }
234
GetProtocolChain() const235 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
236 {
237 return protocolChain_;
238 }
239
UseRemoteCipherPrefer() const240 bool TLSSecureOptions::UseRemoteCipherPrefer() const
241 {
242 return useRemoteCipherPrefer_;
243 }
244
GetSignatureAlgorithms() const245 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
246 {
247 return signatureAlgorithms_;
248 }
249
GetCipherSuite() const250 const std::string &TLSSecureOptions::GetCipherSuite() const
251 {
252 return cipherSuite_;
253 }
254
GetCrlChain() const255 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
256 {
257 return crlChain_;
258 }
259
SetVerifyMode(VerifyMode verifyMode)260 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
261 {
262 TLSVerifyMode_ = verifyMode;
263 }
264
GetVerifyMode() const265 VerifyMode TLSSecureOptions::GetVerifyMode() const
266 {
267 return TLSVerifyMode_;
268 }
269
SetNetAddress(const Socket::NetAddress & address)270 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
271 {
272 address_.SetAddress(address.GetAddress());
273 address_.SetPort(address.GetPort());
274 address_.SetFamilyBySaFamily(address.GetSaFamily());
275 }
276
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)277 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
278 {
279 tlsSecureOptions_ = tlsSecureOptions;
280 }
281
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)282 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
283 {
284 checkServerIdentity_ = checkServerIdentity;
285 }
286
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)287 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
288 {
289 alpnProtocols_ = alpnProtocols;
290 }
291
GetNetAddress() const292 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
293 {
294 return address_;
295 }
296
GetTlsSecureOptions() const297 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
298 {
299 return tlsSecureOptions_;
300 }
301
GetCheckServerIdentity() const302 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
303 {
304 return checkServerIdentity_;
305 }
306
GetAlpnProtocols() const307 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
308 {
309 return alpnProtocols_;
310 }
311
MakeAddressString(sockaddr * addr)312 std::string TLSSocket::MakeAddressString(sockaddr *addr)
313 {
314 if (!addr) {
315 return {};
316 }
317 if (addr->sa_family == AF_INET) {
318 auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
319 const char *str = inet_ntoa(addr4->sin_addr);
320 if (str == nullptr || strlen(str) == 0) {
321 return {};
322 }
323 return str;
324 } else if (addr->sa_family == AF_INET6) {
325 auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
326 char str[INET6_ADDRSTRLEN] = {0};
327 if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
328 return {};
329 }
330 return str;
331 }
332 return {};
333 }
334
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)335 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
336 socklen_t *len)
337 {
338 if (!addr6 || !addr4 || !len) {
339 return;
340 }
341 sa_family_t family = address.GetSaFamily();
342 if (family == AF_INET) {
343 addr4->sin_family = AF_INET;
344 addr4->sin_port = htons(address.GetPort());
345 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
346 *addr = reinterpret_cast<sockaddr *>(addr4);
347 *len = sizeof(sockaddr_in);
348 } else if (family == AF_INET6) {
349 addr6->sin6_family = AF_INET6;
350 addr6->sin6_port = htons(address.GetPort());
351 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
352 *addr = reinterpret_cast<sockaddr *>(addr6);
353 *len = sizeof(sockaddr_in6);
354 }
355 }
356
MakeIpSocket(sa_family_t family)357 void TLSSocket::MakeIpSocket(sa_family_t family)
358 {
359 if (family != AF_INET && family != AF_INET6) {
360 return;
361 }
362 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
363 if (sock < 0) {
364 int resErr = ConvertErrno();
365 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
366 CallOnErrorCallback(resErr, MakeErrnoString());
367 return;
368 }
369 sockFd_ = sock;
370 }
371
StartReadMessage()372 void TLSSocket::StartReadMessage()
373 {
374 std::thread thread([this]() {
375 isRunning_ = true;
376 isRunOver_ = false;
377 while (isRunning_) {
378 char buffer[MAX_BUFFER_SIZE];
379 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
380 NETSTACK_LOGE("memcpy_s failed!");
381 break;
382 }
383 int len = tlsSocketInternal_.Recv(buffer, MAX_BUFFER_SIZE);
384 if (!isRunning_) {
385 break;
386 }
387 if (len < 0) {
388 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
389 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
390 MakeSSLErrorString(resErr).c_str());
391 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
392 break;
393 }
394
395 if (len == 0) {
396 continue;
397 }
398
399 Socket::SocketRemoteInfo remoteInfo;
400 remoteInfo.SetSize(len);
401 tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
402 std::string bufContent(buffer, len);
403 CallOnMessageCallback(bufContent, remoteInfo);
404 if (strncmp(buffer, QUIT_RESPONSE_CODE, QUIT_RESPONSE_CODE_LEN) == 0) {
405 break;
406 }
407 }
408 isRunOver_ = true;
409 });
410 thread.detach();
411 }
412
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)413 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
414 {
415 OnMessageCallback func = nullptr;
416 {
417 std::lock_guard<std::mutex> lock(mutex_);
418 if (onMessageCallback_) {
419 func = onMessageCallback_;
420 }
421 }
422
423 if (func) {
424 func(data, remoteInfo);
425 }
426 }
427
CallOnConnectCallback()428 void TLSSocket::CallOnConnectCallback()
429 {
430 OnConnectCallback func = nullptr;
431 {
432 std::lock_guard<std::mutex> lock(mutex_);
433 if (onConnectCallback_) {
434 func = onConnectCallback_;
435 }
436 }
437
438 if (func) {
439 func();
440 }
441 }
442
CallOnCloseCallback()443 void TLSSocket::CallOnCloseCallback()
444 {
445 OnCloseCallback func = nullptr;
446 {
447 std::lock_guard<std::mutex> lock(mutex_);
448 if (onCloseCallback_) {
449 func = onCloseCallback_;
450 }
451 }
452
453 if (func) {
454 func();
455 }
456 }
457
CallOnErrorCallback(int32_t err,const std::string & errString)458 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
459 {
460 OnErrorCallback func = nullptr;
461 {
462 std::lock_guard<std::mutex> lock(mutex_);
463 if (onErrorCallback_) {
464 func = onErrorCallback_;
465 }
466 }
467
468 if (func) {
469 func(err, errString);
470 }
471 }
472
CallBindCallback(int32_t err,BindCallback callback)473 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
474 {
475 DealCallback<BindCallback>(err, callback);
476 }
477
CallConnectCallback(int32_t err,ConnectCallback callback)478 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
479 {
480 DealCallback<ConnectCallback>(err, callback);
481 }
482
CallSendCallback(int32_t err,SendCallback callback)483 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
484 {
485 DealCallback<SendCallback>(err, callback);
486 }
487
CallCloseCallback(int32_t err,CloseCallback callback)488 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
489 {
490 DealCallback<CloseCallback>(err, callback);
491 }
492
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)493 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
494 GetRemoteAddressCallback callback)
495 {
496 GetRemoteAddressCallback func = nullptr;
497 {
498 std::lock_guard<std::mutex> lock(mutex_);
499 if (callback) {
500 func = callback;
501 }
502 }
503
504 if (func) {
505 func(err, address);
506 }
507 }
508
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)509 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
510 {
511 GetStateCallback func = nullptr;
512 {
513 std::lock_guard<std::mutex> lock(mutex_);
514 if (callback) {
515 func = callback;
516 }
517 }
518
519 if (func) {
520 func(err, state);
521 }
522 }
523
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)524 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
525 {
526 DealCallback<SetExtraOptionsCallback>(err, callback);
527 }
528
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)529 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
530 {
531 GetCertificateCallback func = nullptr;
532 {
533 std::lock_guard<std::mutex> lock(mutex_);
534 if (callback) {
535 func = callback;
536 }
537 }
538
539 if (func) {
540 func(err, cert);
541 }
542 }
543
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)544 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
545 GetRemoteCertificateCallback callback)
546 {
547 GetRemoteCertificateCallback func = nullptr;
548 {
549 std::lock_guard<std::mutex> lock(mutex_);
550 if (callback) {
551 func = callback;
552 }
553 }
554
555 if (func) {
556 func(err, cert);
557 }
558 }
559
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)560 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
561 {
562 GetProtocolCallback func = nullptr;
563 {
564 std::lock_guard<std::mutex> lock(mutex_);
565 if (callback) {
566 func = callback;
567 }
568 }
569
570 if (func) {
571 func(err, protocol);
572 }
573 }
574
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)575 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
576 GetCipherSuiteCallback callback)
577 {
578 GetCipherSuiteCallback func = nullptr;
579 {
580 std::lock_guard<std::mutex> lock(mutex_);
581 if (callback) {
582 func = callback;
583 }
584 }
585
586 if (func) {
587 func(err, suite);
588 }
589 }
590
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)591 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
592 GetSignatureAlgorithmsCallback callback)
593 {
594 GetSignatureAlgorithmsCallback func = nullptr;
595 {
596 std::lock_guard<std::mutex> lock(mutex_);
597 if (callback) {
598 func = callback;
599 }
600 }
601
602 if (func) {
603 func(err, algorithms);
604 }
605 }
606
Bind(const Socket::NetAddress & address,const BindCallback & callback)607 void TLSSocket::Bind(const Socket::NetAddress &address, const BindCallback &callback)
608 {
609 if (!CommonUtils::HasInternetPermission()) {
610 CallBindCallback(PERMISSION_DENIED_CODE, callback);
611 return;
612 }
613 if (sockFd_ >= 0) {
614 CallBindCallback(TLSSOCKET_SUCCESS, callback);
615 return;
616 }
617
618 MakeIpSocket(address.GetSaFamily());
619 if (sockFd_ < 0) {
620 int resErr = ConvertErrno();
621 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
622 CallOnErrorCallback(resErr, MakeErrnoString());
623 CallBindCallback(resErr, callback);
624 return;
625 }
626
627 sockaddr_in addr4 = {0};
628 sockaddr_in6 addr6 = {0};
629 sockaddr *addr = nullptr;
630 socklen_t len;
631 GetAddr(address, &addr4, &addr6, &addr, &len);
632 if (addr == nullptr) {
633 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
634 CallOnErrorCallback(-1, "Address Is Invalid");
635 CallBindCallback(ConvertErrno(), callback);
636 return;
637 }
638 CallBindCallback(TLSSOCKET_SUCCESS, callback);
639 }
640
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)641 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
642 const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
643 {
644 if (sockFd_ < 0) {
645 int resErr = ConvertErrno();
646 NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
647 CallOnErrorCallback(resErr, MakeErrnoString());
648 callback(resErr);
649 return;
650 }
651
652 auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions);
653 if (!res) {
654 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
655 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
656 callback(resErr);
657 return;
658 }
659 StartReadMessage();
660 CallOnConnectCallback();
661 callback(TLSSOCKET_SUCCESS);
662 }
663
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)664 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
665 {
666 (void)tcpSendOptions;
667
668 auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
669 if (!res) {
670 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
671 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
672 CallSendCallback(resErr, callback);
673 return;
674 }
675 CallSendCallback(TLSSOCKET_SUCCESS, callback);
676 }
677
WaitConditionWithTimeout(const bool * flag,const int32_t timeoutMs)678 bool WaitConditionWithTimeout(const bool *flag, const int32_t timeoutMs)
679 {
680 int maxWaitCnt = timeoutMs / WAIT_MS;
681 int cnt = 0;
682 while (!(*flag)) {
683 if (cnt >= maxWaitCnt) {
684 return false;
685 }
686 std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MS));
687 cnt++;
688 }
689 return true;
690 }
691
Close(const CloseCallback & callback)692 void TLSSocket::Close(const CloseCallback &callback)
693 {
694 if (!WaitConditionWithTimeout(&isRunning_, TIMEOUT_MS)) {
695 callback(ConvertErrno());
696 NETSTACK_LOGE("The error cause is that the runtime wait time is insufficient");
697 return;
698 }
699 isRunning_ = false;
700
701 auto res = tlsSocketInternal_.Close();
702 if (!res) {
703 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
704 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
705 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
706 callback(resErr);
707 return;
708 }
709 CallOnCloseCallback();
710 callback(TLSSOCKET_SUCCESS);
711 }
712
GetRemoteAddress(const GetRemoteAddressCallback & callback)713 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
714 {
715 sockaddr sockAddr = {0};
716 socklen_t len = sizeof(sockaddr);
717 int ret = getsockname(sockFd_, &sockAddr, &len);
718 if (ret < 0) {
719 int resErr = ConvertErrno();
720 NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
721 CallOnErrorCallback(resErr, MakeErrnoString());
722 CallGetRemoteAddressCallback(resErr, {}, callback);
723 return;
724 }
725
726 if (sockAddr.sa_family == AF_INET) {
727 GetIp4RemoteAddress(callback);
728 } else if (sockAddr.sa_family == AF_INET6) {
729 GetIp6RemoteAddress(callback);
730 }
731 }
732
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)733 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
734 {
735 sockaddr_in addr4 = {0};
736 socklen_t len4 = sizeof(sockaddr_in);
737
738 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
739 if (ret < 0) {
740 int resErr = ConvertErrno();
741 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
742 CallOnErrorCallback(resErr, MakeErrnoString());
743 CallGetRemoteAddressCallback(resErr, {}, callback);
744 return;
745 }
746
747 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
748 if (address.empty()) {
749 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
750 CallOnErrorCallback(-1, "Address is invalid");
751 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
752 return;
753 }
754 Socket::NetAddress netAddress;
755 netAddress.SetAddress(address);
756 netAddress.SetFamilyBySaFamily(AF_INET);
757 netAddress.SetPort(ntohs(addr4.sin_port));
758 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
759 }
760
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)761 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
762 {
763 sockaddr_in6 addr6 = {0};
764 socklen_t len6 = sizeof(sockaddr_in6);
765
766 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
767 if (ret < 0) {
768 int resErr = ConvertErrno();
769 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
770 CallOnErrorCallback(resErr, MakeErrnoString());
771 CallGetRemoteAddressCallback(resErr, {}, callback);
772 return;
773 }
774
775 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
776 if (address.empty()) {
777 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
778 CallOnErrorCallback(-1, "Address is invalid");
779 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
780 return;
781 }
782 Socket::NetAddress netAddress;
783 netAddress.SetAddress(address);
784 netAddress.SetFamilyBySaFamily(AF_INET6);
785 netAddress.SetPort(ntohs(addr6.sin6_port));
786 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
787 }
788
GetState(const GetStateCallback & callback)789 void TLSSocket::GetState(const GetStateCallback &callback)
790 {
791 int opt;
792 socklen_t optLen = sizeof(int);
793 int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
794 if (r < 0) {
795 Socket::SocketStateBase state;
796 state.SetIsClose(true);
797 CallGetStateCallback(ConvertErrno(), state, callback);
798 return;
799 }
800 sockaddr sockAddr = {0};
801 socklen_t len = sizeof(sockaddr);
802 Socket::SocketStateBase state;
803 int ret = getsockname(sockFd_, &sockAddr, &len);
804 state.SetIsBound(ret == 0);
805 ret = getpeername(sockFd_, &sockAddr, &len);
806 state.SetIsConnected(ret == 0);
807 CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
808 }
809
SetBaseOptions(const Socket::ExtraOptionsBase & option) const810 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
811 {
812 if (option.GetReceiveBufferSize() != 0) {
813 int size = (int)option.GetReceiveBufferSize();
814 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
815 return false;
816 }
817 }
818
819 if (option.GetSendBufferSize() != 0) {
820 int size = (int)option.GetSendBufferSize();
821 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
822 return false;
823 }
824 }
825
826 if (option.IsReuseAddress()) {
827 int reuse = 1;
828 if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
829 return false;
830 }
831 }
832
833 if (option.GetSocketTimeout() != 0) {
834 timeval timeout = {(int)option.GetSocketTimeout(), 0};
835 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
836 return false;
837 }
838 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
839 return false;
840 }
841 }
842
843 return true;
844 }
845
SetExtraOptions(const Socket::TCPExtraOptions & option) const846 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
847 {
848 if (option.IsKeepAlive()) {
849 int keepalive = 1;
850 if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
851 return false;
852 }
853 }
854
855 if (option.IsOOBInline()) {
856 int oobInline = 1;
857 if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
858 return false;
859 }
860 }
861
862 if (option.IsTCPNoDelay()) {
863 int tcpNoDelay = 1;
864 if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
865 return false;
866 }
867 }
868
869 linger soLinger = {0};
870 soLinger.l_onoff = option.socketLinger.IsOn();
871 soLinger.l_linger = (int)option.socketLinger.GetLinger();
872 if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
873 return false;
874 }
875
876 return true;
877 }
878
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)879 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
880 const SetExtraOptionsCallback &callback)
881 {
882 if (!SetBaseOptions(tcpExtraOptions)) {
883 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
884 CallOnErrorCallback(errno, MakeErrnoString());
885 CallSetExtraOptionsCallback(ConvertErrno(), callback);
886 return;
887 }
888
889 if (!SetExtraOptions(tcpExtraOptions)) {
890 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
891 CallOnErrorCallback(errno, MakeErrnoString());
892 CallSetExtraOptionsCallback(ConvertErrno(), callback);
893 return;
894 }
895
896 CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
897 }
898
GetCertificate(const GetCertificateCallback & callback)899 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
900 {
901 const auto &cert = tlsSocketInternal_.GetCertificate();
902 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
903
904 if (!cert.data.Length()) {
905 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
906 NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
907 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
908 callback(resErr, {});
909 return;
910 }
911 callback(TLSSOCKET_SUCCESS, cert);
912 }
913
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)914 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
915 {
916 const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
917 if (!remoteCert.data.Length()) {
918 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
919 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
920 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
921 callback(resErr, {});
922 return;
923 }
924 callback(TLSSOCKET_SUCCESS, remoteCert);
925 }
926
GetProtocol(const GetProtocolCallback & callback)927 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
928 {
929 const auto &protocol = tlsSocketInternal_.GetProtocol();
930 if (protocol.empty()) {
931 NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
932 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
933 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
934 callback(resErr, "");
935 return;
936 }
937 callback(TLSSOCKET_SUCCESS, protocol);
938 }
939
GetCipherSuite(const GetCipherSuiteCallback & callback)940 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
941 {
942 const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
943 if (cipherSuite.empty()) {
944 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
945 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
946 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
947 callback(resErr, cipherSuite);
948 return;
949 }
950 callback(TLSSOCKET_SUCCESS, cipherSuite);
951 }
952
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)953 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
954 {
955 const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
956 if (signatureAlgorithms.empty()) {
957 NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
958 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
959 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
960 callback(resErr, {});
961 return;
962 }
963 callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
964 }
965
OnMessage(const OnMessageCallback & onMessageCallback)966 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
967 {
968 std::lock_guard<std::mutex> lock(mutex_);
969 onMessageCallback_ = onMessageCallback;
970 }
971
OffMessage()972 void TLSSocket::OffMessage()
973 {
974 std::lock_guard<std::mutex> lock(mutex_);
975 if (onMessageCallback_) {
976 onMessageCallback_ = nullptr;
977 }
978 }
979
OnConnect(const OnConnectCallback & onConnectCallback)980 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
981 {
982 std::lock_guard<std::mutex> lock(mutex_);
983 onConnectCallback_ = onConnectCallback;
984 }
985
OffConnect()986 void TLSSocket::OffConnect()
987 {
988 std::lock_guard<std::mutex> lock(mutex_);
989 if (onConnectCallback_) {
990 onConnectCallback_ = nullptr;
991 }
992 }
993
OnClose(const OnCloseCallback & onCloseCallback)994 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
995 {
996 std::lock_guard<std::mutex> lock(mutex_);
997 onCloseCallback_ = onCloseCallback;
998 }
999
OffClose()1000 void TLSSocket::OffClose()
1001 {
1002 std::lock_guard<std::mutex> lock(mutex_);
1003 if (onCloseCallback_) {
1004 onCloseCallback_ = nullptr;
1005 }
1006 }
1007
OnError(const OnErrorCallback & onErrorCallback)1008 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1009 {
1010 std::lock_guard<std::mutex> lock(mutex_);
1011 onErrorCallback_ = onErrorCallback;
1012 }
1013
OffError()1014 void TLSSocket::OffError()
1015 {
1016 std::lock_guard<std::mutex> lock(mutex_);
1017 if (onErrorCallback_) {
1018 onErrorCallback_ = nullptr;
1019 }
1020 }
1021
ExecSocketConnect(const std::string & hostName,int port,sa_family_t family,int socketDescriptor)1022 bool ExecSocketConnect(const std::string &hostName, int port, sa_family_t family, int socketDescriptor)
1023 {
1024 struct sockaddr_in dest = {0};
1025 dest.sin_family = family;
1026 dest.sin_port = htons(port);
1027 if (!inet_aton(hostName.c_str(), reinterpret_cast<in_addr *>(&dest.sin_addr.s_addr))) {
1028 NETSTACK_LOGE("inet_aton is error, hostName is %s", hostName.c_str());
1029 return false;
1030 }
1031 int connectResult = connect(socketDescriptor, reinterpret_cast<struct sockaddr *>(&dest), sizeof(dest));
1032 if (connectResult == -1) {
1033 NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1034 strerror(errno));
1035 return false;
1036 }
1037 return true;
1038 }
1039
TlsConnectToHost(int sock,const TLSConnectOptions & options)1040 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options)
1041 {
1042 SetTlsConfiguration(options);
1043 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1044 if (!cipherSuite.empty()) {
1045 configuration_.SetCipherSuite(cipherSuite);
1046 }
1047 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1048 if (!signatureAlgorithms.empty()) {
1049 configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1050 }
1051 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1052 if (!protocolVec.empty()) {
1053 configuration_.SetProtocol(protocolVec);
1054 }
1055
1056 hostName_ = options.GetNetAddress().GetAddress();
1057 port_ = options.GetNetAddress().GetPort();
1058 family_ = options.GetNetAddress().GetSaFamily();
1059 socketDescriptor_ = sock;
1060 if (!ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1061 options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1062 return false;
1063 }
1064 return StartTlsConnected(options);
1065 }
1066
SetTlsConfiguration(const TLSConnectOptions & config)1067 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1068 {
1069 configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1070 configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1071 configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1072 }
1073
Send(const std::string & data)1074 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1075 {
1076 NETSTACK_LOGD("data to send :%{public}s", data.c_str());
1077 if (data.empty()) {
1078 NETSTACK_LOGE("data is empty");
1079 return false;
1080 }
1081 if (!ssl_) {
1082 NETSTACK_LOGE("ssl is null");
1083 return false;
1084 }
1085 int len = SSL_write(ssl_, data.c_str(), data.length());
1086 if (len < 0) {
1087 int resErr = ConvertSSLError(GetSSL());
1088 NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
1089 data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
1090 return false;
1091 }
1092 NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
1093 return true;
1094 }
Recv(char * buffer,int maxBufferSize)1095 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1096 {
1097 if (!ssl_) {
1098 NETSTACK_LOGE("ssl is null");
1099 return SSL_ERROR_RETURN;
1100 }
1101 return SSL_read(ssl_, buffer, maxBufferSize);
1102 }
1103
Close()1104 bool TLSSocket::TLSSocketInternal::Close()
1105 {
1106 if (!ssl_) {
1107 NETSTACK_LOGE("ssl is null");
1108 return false;
1109 }
1110 int result = SSL_shutdown(ssl_);
1111 if (result < 0) {
1112 int resErr = ConvertSSLError(GetSSL());
1113 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1114 MakeSSLErrorString(resErr).c_str());
1115 return false;
1116 }
1117 SSL_free(ssl_);
1118 ssl_ = nullptr;
1119 close(socketDescriptor_);
1120 socketDescriptor_ = -1;
1121 if (!tlsContextPointer_) {
1122 NETSTACK_LOGE("Tls context pointer is null");
1123 return false;
1124 }
1125 tlsContextPointer_->CloseCtx();
1126 return true;
1127 }
1128
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1129 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1130 {
1131 if (!ssl_) {
1132 NETSTACK_LOGE("ssl is null");
1133 return false;
1134 }
1135 size_t pos = 0;
1136 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1137 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1138 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1139 for (const auto &str : alpnProtocols) {
1140 len = str.length();
1141 result[pos++] = len;
1142 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1143 NETSTACK_LOGE("strcpy_s failed");
1144 return false;
1145 }
1146 pos += len;
1147 }
1148 result[pos] = '\0';
1149
1150 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1151 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1152 int resErr = ConvertSSLError(GetSSL());
1153 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1154 MakeSSLErrorString(resErr).c_str());
1155 return false;
1156 }
1157 return true;
1158 }
1159
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1160 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1161 {
1162 remoteInfo.SetAddress(hostName_);
1163 remoteInfo.SetPort(port_);
1164 remoteInfo.SetFamily(family_);
1165 }
1166
GetTlsConfiguration() const1167 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1168 {
1169 return configuration_;
1170 }
1171
GetCipherSuite() const1172 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1173 {
1174 if (!ssl_) {
1175 NETSTACK_LOGE("ssl in null");
1176 return {};
1177 }
1178 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1179 if (!sk) {
1180 NETSTACK_LOGE("get ciphers failed");
1181 return {};
1182 }
1183 CipherSuite cipherSuite;
1184 std::vector<std::string> cipherSuiteVec;
1185 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1186 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1187 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1188 cipherSuiteVec.push_back(cipherSuite.cipherName_);
1189 }
1190 return cipherSuiteVec;
1191 }
1192
GetRemoteCertificate() const1193 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1194 {
1195 return remoteCert_;
1196 }
1197
GetCertificate() const1198 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1199 {
1200 return configuration_.GetCertificate();
1201 }
1202
GetSignatureAlgorithms() const1203 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1204 {
1205 return signatureAlgorithms_;
1206 }
1207
GetProtocol() const1208 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1209 {
1210 if (!ssl_) {
1211 NETSTACK_LOGE("ssl in null");
1212 return PROTOCOL_UNKNOW;
1213 }
1214 if (configuration_.GetProtocol() == TLS_V1_3) {
1215 return PROTOCOL_TLS_V13;
1216 }
1217 return PROTOCOL_TLS_V12;
1218 }
1219
SetSharedSigals()1220 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1221 {
1222 if (!ssl_) {
1223 NETSTACK_LOGE("ssl is null");
1224 return false;
1225 }
1226 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1227 if (!number) {
1228 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1229 return false;
1230 }
1231 for (int i = 0; i < number; i++) {
1232 int hash_nid;
1233 int sign_nid;
1234 std::string sig_with_md;
1235 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1236 switch (sign_nid) {
1237 case EVP_PKEY_RSA:
1238 sig_with_md = SIGN_NID_RSA;
1239 break;
1240 case EVP_PKEY_RSA_PSS:
1241 sig_with_md = SIGN_NID_RSA_PSS;
1242 break;
1243 case EVP_PKEY_DSA:
1244 sig_with_md = SIGN_NID_DSA;
1245 break;
1246 case EVP_PKEY_EC:
1247 sig_with_md = SIGN_NID_ECDSA;
1248 break;
1249 case NID_ED25519:
1250 sig_with_md = SIGN_NID_ED;
1251 break;
1252 case NID_ED448:
1253 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1254 break;
1255 default:
1256 const char *sn = OBJ_nid2sn(sign_nid);
1257 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1258 }
1259 const char *sn_hash = OBJ_nid2sn(hash_nid);
1260 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1261 signatureAlgorithms_.push_back(sig_with_md);
1262 }
1263 return true;
1264 }
1265
StartTlsConnected(const TLSConnectOptions & options)1266 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1267 {
1268 if (!CreatTlsContext()) {
1269 NETSTACK_LOGE("failed to create tls context");
1270 return false;
1271 }
1272 if (!StartShakingHands(options)) {
1273 NETSTACK_LOGE("failed to shaking hands");
1274 return false;
1275 }
1276 return true;
1277 }
1278
CreatTlsContext()1279 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1280 {
1281 tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1282 if (!tlsContextPointer_) {
1283 NETSTACK_LOGE("failed to create tls context pointer");
1284 return false;
1285 }
1286 if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1287 NETSTACK_LOGE("failed to create ssl session");
1288 return false;
1289 }
1290 SSL_set_fd(ssl_, socketDescriptor_);
1291 SSL_set_connect_state(ssl_);
1292 return true;
1293 }
1294
/** True when @p s begins with @p prefix (an empty prefix always matches). */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only report a match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1299
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1300 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1301 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1302 {
1303 bool valid = false;
1304 std::string reason = UNKNOW_REASON;
1305 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1306 if (IsIP(hostName)) {
1307 auto it = find(ips.begin(), ips.end(), hostName);
1308 if (it == ips.end()) {
1309 reason = IP + hostName + " is not in the cert's list";
1310 }
1311 result = {valid, reason};
1312 return;
1313 }
1314 std::string tempHostName = "" + hostName;
1315 if (!dnsNames.empty() || index > 0) {
1316 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1317 if (!dnsNames.empty()) {
1318 valid = SeekIntersection(hostParts, dnsNames);
1319 if (!valid) {
1320 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1321 }
1322 } else {
1323 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1324 X509_NAME *pSubName = nullptr;
1325 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1326 if (len > 0) {
1327 std::vector<std::string> commonNameVec;
1328 commonNameVec.emplace_back(commonNameBuf);
1329 valid = SeekIntersection(hostParts, commonNameVec);
1330 if (!valid) {
1331 reason = HOST_NAME + tempHostName + ". is not cert's CN";
1332 }
1333 }
1334 }
1335 result = {valid, reason};
1336 return;
1337 }
1338 reason = "Cert does not contain a DNS name";
1339 result = {valid, reason};
1340 }
1341
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1342 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1343 const X509 *x509Certificates)
1344 {
1345 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1346 if (!subjectName) {
1347 return "subject name is null";
1348 }
1349 char subNameBuf[BUF_SIZE] = {0};
1350 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1351
1352 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1353 if (index < 0) {
1354 return "X509 get ext nid error";
1355 }
1356 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1357 if (ext == nullptr) {
1358 return "X509 get ext error";
1359 }
1360 ASN1_OBJECT *obj = nullptr;
1361 obj = X509_EXTENSION_get_object(ext);
1362 char subAltNameBuf[BUF_SIZE] = {0};
1363 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1364 NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1365
1366 return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1367 }
1368
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1369 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1370 const X509 *x509Certificates)
1371 {
1372 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1373 std::string altNames = reinterpret_cast<char *>(extData->data);
1374 std::string hostname = " " + hostName;
1375 BIO *bio = BIO_new(BIO_s_file());
1376 if (!bio) {
1377 return "bio is null";
1378 }
1379 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1380 ASN1_STRING_print(bio, extData);
1381 std::vector<std::string> dnsNames = {};
1382 std::vector<std::string> ips = {};
1383 constexpr int DNS_NAME_IDX = 4;
1384 constexpr int IP_NAME_IDX = 11;
1385 if (!altNames.empty()) {
1386 std::vector<std::string> splitAltNames;
1387 if (altNames.find('\"') != std::string::npos) {
1388 splitAltNames = SplitEscapedAltNames(altNames);
1389 } else {
1390 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1391 }
1392 for (auto const &iter : splitAltNames) {
1393 if (StartsWith(iter, DNS)) {
1394 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1395 } else if (StartsWith(iter, IP_ADDRESS)) {
1396 ips.push_back(iter.substr(IP_NAME_IDX));
1397 }
1398 }
1399 }
1400 std::tuple<bool, std::string> result;
1401 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1402 if (!std::get<0>(result)) {
1403 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1404 }
1405 return HOST_NAME + hostname + ". is cert's CN";
1406 }
1407
StartShakingHands(const TLSConnectOptions & options)1408 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1409 {
1410 if (!ssl_) {
1411 NETSTACK_LOGE("ssl is null");
1412 return false;
1413 }
1414 int result = SSL_connect(ssl_);
1415 if (result == -1) {
1416 int errorStatus = ConvertSSLError(ssl_);
1417 NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1418 MakeSSLErrorString(errorStatus).c_str());
1419 return false;
1420 }
1421
1422 std::string list = SSL_get_cipher_list(ssl_, 0);
1423 NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1424 configuration_.SetCipherSuite(list);
1425 if (!SetSharedSigals()) {
1426 NETSTACK_LOGE("Failed to set sharedSigalgs");
1427 }
1428 if (!GetRemoteCertificateFromPeer()) {
1429 NETSTACK_LOGE("Failed to get remote certificate");
1430 }
1431 if (!peerX509_) {
1432 NETSTACK_LOGE("peer x509Certificates is null");
1433 return false;
1434 }
1435 if (!SetRemoteCertRawData()) {
1436 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1437 }
1438 CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1439 if (!checkServerIdentity) {
1440 CheckServerIdentityLegal(hostName_, peerX509_);
1441 } else {
1442 checkServerIdentity(hostName_, {remoteCert_});
1443 }
1444 NETSTACK_LOGI("SSL Get Version: %{public}s, SSL Get Cipher: %{public}s", SSL_get_version(ssl_),
1445 SSL_get_cipher(ssl_));
1446 return true;
1447 }
1448
GetRemoteCertificateFromPeer()1449 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1450 {
1451 peerX509_ = SSL_get_peer_certificate(ssl_);
1452 if (peerX509_ == nullptr) {
1453 int resErr = ConvertSSLError(GetSSL());
1454 NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1455 MakeSSLErrorString(resErr).c_str());
1456 return false;
1457 }
1458 BIO *bio = BIO_new(BIO_s_mem());
1459 if (!bio) {
1460 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1461 return false;
1462 }
1463 X509_print(bio, peerX509_);
1464 char data[REMOTE_CERT_LEN] = {0};
1465 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1466 NETSTACK_LOGE("BIO_read function returns error");
1467 BIO_free(bio);
1468 return false;
1469 }
1470 BIO_free(bio);
1471 remoteCert_ = std::string(data);
1472 return true;
1473 }
1474
SetRemoteCertRawData()1475 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1476 {
1477 if (peerX509_ == nullptr) {
1478 NETSTACK_LOGE("peerX509 is null");
1479 return false;
1480 }
1481 int32_t length = i2d_X509(peerX509_, nullptr);
1482 if (length <= 0) {
1483 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1484 return false;
1485 }
1486 unsigned char *der = nullptr;
1487 (void)i2d_X509(peerX509_, &der);
1488 SecureData data(der, length);
1489 remoteRawData_.data = data;
1490 OPENSSL_free(der);
1491 remoteRawData_.encodingFormat = DER;
1492 return true;
1493 }
1494
GetRemoteCertRawData() const1495 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1496 {
1497 return remoteRawData_;
1498 }
1499
GetSSL() const1500 ssl_st *TLSSocket::TLSSocketInternal::GetSSL() const
1501 {
1502 return ssl_;
1503 }
1504 } // namespace TlsSocket
1505 } // namespace NetStack
1506 } // namespace OHOS
1507