1 /*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tls_socket.h"
17
#include <algorithm>
#include <chrono>
#include <memory>
#include <numeric>
#include <regex>
#include <securec.h>
#include <thread>
24
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/ssl.h>
28
29 #include "base_context.h"
30 #include "netstack_common_utils.h"
31 #include "netstack_log.h"
32 #include "tls.h"
33
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocket {
37 namespace {
// Polling step and overall deadline (milliseconds) used when Close() waits for the reader thread.
constexpr int WAIT_MS = 10;
constexpr int TIMEOUT_MS = 10000;

// Buffer sizes.
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;

// SSL_RET_CODE is the `ret` argument handed to SSL_get_error() by ConvertSSLError().
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;

// OFFSET is the number of characters skipped after a separator in SplitEscapedAltNames().
constexpr int OFFSET = 2;

// The reader thread stops once a message starts with QUIT_RESPONSE_CODE.
constexpr int QUIT_RESPONSE_CODE_LEN = 3;
constexpr const char *QUIT_RESPONSE_CODE = "221";

// Separators and labels used while examining certificate names.
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";

// Signature algorithm name fragments.
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";

// NOTE(review): this pattern text starts with "/^" and ends with "/", i.e. it still carries
// JavaScript regex-literal delimiters, so it cannot match a string that begins with a quote.
const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"};

// Dotted-quad IPv4 address (each octet 0-255).
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
69
ConvertErrno()70 int ConvertErrno()
71 {
72 return TlsSocketError::TLS_ERR_SYS_BASE + errno;
73 }
74
ConvertSSLError(ssl_st * ssl)75 int ConvertSSLError(ssl_st *ssl)
76 {
77 if (!ssl) {
78 return TLS_ERR_SSL_NULL;
79 }
80 return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
81 }
82
// Returns the strerror() description of the current errno as a std::string.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
87
MakeSSLErrorString(int error)88 std::string MakeSSLErrorString(int error)
89 {
90 char err[MAX_ERR_LEN] = {0};
91 ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
92 return err;
93 }
94
// Splits a subjectAltName listing such as `DNS:a.com, DNS:b.com` into its individual entries.
//
// Entries are separated by ", ". A separator that appears inside a double-quoted JSON string
// literal does not split the entry; the quoted literal is appended to the entry verbatim
// (it is not JSON-decoded). If a quote is found that does not start a well-formed JSON string
// literal, a vector holding a single empty string is returned.
//
// An empty input yields one empty entry.
std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
{
    // A JSON string literal; it is only ever matched at the position of the quote
    // (match_continuous), never searched for further along the input.
    static const std::regex quotedPattern{
        R"re("(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")re"};
    static const std::string separator = ", ";

    std::vector<std::string> result;
    std::string currentToken;
    size_t offset = 0;
    while (offset != altNames.length()) {
        // Both searches start at `offset` so that already-consumed text is never revisited.
        const size_t nextSep = altNames.find(separator, offset);
        const size_t nextQuote = altNames.find('"', offset);
        if (nextQuote != std::string::npos && (nextSep == std::string::npos || nextQuote < nextSep)) {
            // A quote comes before any separator: copy the text up to it, then the whole literal.
            currentToken.append(altNames, offset, nextQuote - offset);
            std::smatch match;
            const auto quoteBegin = altNames.cbegin() + static_cast<std::string::difference_type>(nextQuote);
            if (!std::regex_search(quoteBegin, altNames.cend(), match, quotedPattern,
                                   std::regex_constants::match_continuous)) {
                return {""};
            }
            currentToken += match.str(0);
            offset = nextQuote + static_cast<size_t>(match.length(0));
        } else if (nextSep != std::string::npos) {
            // A separator comes first: the current entry ends here.
            currentToken.append(altNames, offset, nextSep - offset);
            result.push_back(currentToken);
            currentToken.clear();
            offset = nextSep + separator.length();
        } else {
            // Neither quote nor separator remains: the rest belongs to the current entry.
            currentToken.append(altNames, offset, std::string::npos);
            offset = altNames.length();
        }
    }
    result.push_back(currentToken);
    return result;
}
127
IsIP(const std::string & ip)128 bool IsIP(const std::string &ip)
129 {
130 std::regex pattern(PATTERN);
131 std::smatch res;
132 return regex_match(ip, res, pattern);
133 }
134
SplitHostName(std::string & hostName)135 std::vector<std::string> SplitHostName(std::string &hostName)
136 {
137 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
138 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
139 }
140
// Returns true when at least one string occurs in both `vecA` and `vecB`.
// The inputs may be in any order and are not modified.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    // std::set_intersection requires sorted ranges; find_first_of has no ordering precondition.
    return std::find_first_of(vecA.begin(), vecA.end(), vecB.begin(), vecB.end()) != vecA.end();
}
147 } // namespace
148
// Copy constructor: delegates to the copy-assignment operator.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
153
// Copy assignment: copies every option from `tlsSecureOptions` through its public getters.
// Assigning an object to itself is a no-op.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
169
SetCaChain(const std::vector<std::string> & caChain)170 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
171 {
172 caChain_ = caChain;
173 }
174
SetCert(const std::string & cert)175 void TLSSecureOptions::SetCert(const std::string &cert)
176 {
177 cert_ = cert;
178 }
179
SetKey(const SecureData & key)180 void TLSSecureOptions::SetKey(const SecureData &key)
181 {
182 key_ = key;
183 }
184
SetKeyPass(const SecureData & keyPass)185 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
186 {
187 keyPass_ = keyPass;
188 }
189
SetProtocolChain(const std::vector<std::string> & protocolChain)190 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
191 {
192 protocolChain_ = protocolChain;
193 }
194
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)195 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
196 {
197 useRemoteCipherPrefer_ = useRemoteCipherPrefer;
198 }
199
SetSignatureAlgorithms(const std::string & signatureAlgorithms)200 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
201 {
202 signatureAlgorithms_ = signatureAlgorithms;
203 }
204
SetCipherSuite(const std::string & cipherSuite)205 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
206 {
207 cipherSuite_ = cipherSuite;
208 }
209
SetCrlChain(const std::vector<std::string> & crlChain)210 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
211 {
212 crlChain_ = crlChain;
213 }
214
GetCaChain() const215 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
216 {
217 return caChain_;
218 }
219
GetCert() const220 const std::string &TLSSecureOptions::GetCert() const
221 {
222 return cert_;
223 }
224
GetKey() const225 const SecureData &TLSSecureOptions::GetKey() const
226 {
227 return key_;
228 }
229
GetKeyPass() const230 const SecureData &TLSSecureOptions::GetKeyPass() const
231 {
232 return keyPass_;
233 }
234
GetProtocolChain() const235 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
236 {
237 return protocolChain_;
238 }
239
UseRemoteCipherPrefer() const240 bool TLSSecureOptions::UseRemoteCipherPrefer() const
241 {
242 return useRemoteCipherPrefer_;
243 }
244
GetSignatureAlgorithms() const245 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
246 {
247 return signatureAlgorithms_;
248 }
249
GetCipherSuite() const250 const std::string &TLSSecureOptions::GetCipherSuite() const
251 {
252 return cipherSuite_;
253 }
254
GetCrlChain() const255 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
256 {
257 return crlChain_;
258 }
259
SetVerifyMode(VerifyMode verifyMode)260 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
261 {
262 TLSVerifyMode_ = verifyMode;
263 }
264
GetVerifyMode() const265 VerifyMode TLSSecureOptions::GetVerifyMode() const
266 {
267 return TLSVerifyMode_;
268 }
269
SetNetAddress(const Socket::NetAddress & address)270 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
271 {
272 address_.SetAddress(address.GetAddress());
273 address_.SetPort(address.GetPort());
274 address_.SetFamilyBySaFamily(address.GetSaFamily());
275 }
276
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)277 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
278 {
279 tlsSecureOptions_ = tlsSecureOptions;
280 }
281
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)282 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
283 {
284 checkServerIdentity_ = checkServerIdentity;
285 }
286
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)287 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
288 {
289 alpnProtocols_ = alpnProtocols;
290 }
291
GetNetAddress() const292 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
293 {
294 return address_;
295 }
296
GetTlsSecureOptions() const297 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
298 {
299 return tlsSecureOptions_;
300 }
301
GetCheckServerIdentity() const302 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
303 {
304 return checkServerIdentity_;
305 }
306
GetAlpnProtocols() const307 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
308 {
309 return alpnProtocols_;
310 }
311
MakeAddressString(sockaddr * addr)312 std::string TLSSocket::MakeAddressString(sockaddr *addr)
313 {
314 if (!addr) {
315 return {};
316 }
317 if (addr->sa_family == AF_INET) {
318 auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
319 const char *str = inet_ntoa(addr4->sin_addr);
320 if (str == nullptr || strlen(str) == 0) {
321 return {};
322 }
323 return str;
324 } else if (addr->sa_family == AF_INET6) {
325 auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
326 char str[INET6_ADDRSTRLEN] = {0};
327 if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
328 return {};
329 }
330 return str;
331 }
332 return {};
333 }
334
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)335 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
336 socklen_t *len)
337 {
338 if (!addr6 || !addr4 || !len) {
339 return;
340 }
341 sa_family_t family = address.GetSaFamily();
342 if (family == AF_INET) {
343 addr4->sin_family = AF_INET;
344 addr4->sin_port = htons(address.GetPort());
345 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
346 *addr = reinterpret_cast<sockaddr *>(addr4);
347 *len = sizeof(sockaddr_in);
348 } else if (family == AF_INET6) {
349 addr6->sin6_family = AF_INET6;
350 addr6->sin6_port = htons(address.GetPort());
351 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
352 *addr = reinterpret_cast<sockaddr *>(addr6);
353 *len = sizeof(sockaddr_in6);
354 }
355 }
356
MakeIpSocket(sa_family_t family)357 void TLSSocket::MakeIpSocket(sa_family_t family)
358 {
359 if (family != AF_INET && family != AF_INET6) {
360 return;
361 }
362 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
363 if (sock < 0) {
364 int resErr = ConvertErrno();
365 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
366 CallOnErrorCallback(resErr, MakeErrnoString());
367 return;
368 }
369 sockFd_ = sock;
370 }
371
StartReadMessage()372 void TLSSocket::StartReadMessage()
373 {
374 std::thread thread([this]() {
375 isRunning_ = true;
376 isRunOver_ = false;
377 while (isRunning_) {
378 char buffer[MAX_BUFFER_SIZE];
379 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
380 NETSTACK_LOGE("memcpy_s failed!");
381 break;
382 }
383 int len = tlsSocketInternal_.Recv(buffer, MAX_BUFFER_SIZE);
384 if (!isRunning_) {
385 break;
386 }
387 if (len < 0) {
388 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
389 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s", resErr,
390 MakeSSLErrorString(resErr).c_str());
391 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
392 break;
393 }
394
395 if (len == 0) {
396 continue;
397 }
398
399 Socket::SocketRemoteInfo remoteInfo;
400 remoteInfo.SetSize(strlen(buffer));
401 tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
402 CallOnMessageCallback(buffer, remoteInfo);
403 if (strncmp(buffer, QUIT_RESPONSE_CODE, QUIT_RESPONSE_CODE_LEN) == 0) {
404 break;
405 }
406 }
407 isRunOver_ = true;
408 });
409 thread.detach();
410 }
411
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)412 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
413 {
414 OnMessageCallback func = nullptr;
415 {
416 std::lock_guard<std::mutex> lock(mutex_);
417 if (onMessageCallback_) {
418 func = onMessageCallback_;
419 }
420 }
421
422 if (func) {
423 func(data, remoteInfo);
424 }
425 }
426
CallOnConnectCallback()427 void TLSSocket::CallOnConnectCallback()
428 {
429 OnConnectCallback func = nullptr;
430 {
431 std::lock_guard<std::mutex> lock(mutex_);
432 if (onConnectCallback_) {
433 func = onConnectCallback_;
434 }
435 }
436
437 if (func) {
438 func();
439 }
440 }
441
CallOnCloseCallback()442 void TLSSocket::CallOnCloseCallback()
443 {
444 OnCloseCallback func = nullptr;
445 {
446 std::lock_guard<std::mutex> lock(mutex_);
447 if (onCloseCallback_) {
448 func = onCloseCallback_;
449 }
450 }
451
452 if (func) {
453 func();
454 }
455 }
456
CallOnErrorCallback(int32_t err,const std::string & errString)457 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
458 {
459 OnErrorCallback func = nullptr;
460 {
461 std::lock_guard<std::mutex> lock(mutex_);
462 if (onErrorCallback_) {
463 func = onErrorCallback_;
464 }
465 }
466
467 if (func) {
468 func(err, errString);
469 }
470 }
471
CallBindCallback(int32_t err,BindCallback callback)472 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
473 {
474 DealCallback<BindCallback>(err, callback);
475 }
476
CallConnectCallback(int32_t err,ConnectCallback callback)477 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
478 {
479 DealCallback<ConnectCallback>(err, callback);
480 }
481
CallSendCallback(int32_t err,SendCallback callback)482 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
483 {
484 DealCallback<SendCallback>(err, callback);
485 }
486
CallCloseCallback(int32_t err,CloseCallback callback)487 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
488 {
489 DealCallback<CloseCallback>(err, callback);
490 }
491
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)492 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
493 GetRemoteAddressCallback callback)
494 {
495 GetRemoteAddressCallback func = nullptr;
496 {
497 std::lock_guard<std::mutex> lock(mutex_);
498 if (callback) {
499 func = callback;
500 }
501 }
502
503 if (func) {
504 func(err, address);
505 }
506 }
507
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)508 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
509 {
510 GetStateCallback func = nullptr;
511 {
512 std::lock_guard<std::mutex> lock(mutex_);
513 if (callback) {
514 func = callback;
515 }
516 }
517
518 if (func) {
519 func(err, state);
520 }
521 }
522
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)523 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
524 {
525 DealCallback<SetExtraOptionsCallback>(err, callback);
526 }
527
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)528 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
529 {
530 GetCertificateCallback func = nullptr;
531 {
532 std::lock_guard<std::mutex> lock(mutex_);
533 if (callback) {
534 func = callback;
535 }
536 }
537
538 if (func) {
539 func(err, cert);
540 }
541 }
542
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)543 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
544 GetRemoteCertificateCallback callback)
545 {
546 GetRemoteCertificateCallback func = nullptr;
547 {
548 std::lock_guard<std::mutex> lock(mutex_);
549 if (callback) {
550 func = callback;
551 }
552 }
553
554 if (func) {
555 func(err, cert);
556 }
557 }
558
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)559 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
560 {
561 GetProtocolCallback func = nullptr;
562 {
563 std::lock_guard<std::mutex> lock(mutex_);
564 if (callback) {
565 func = callback;
566 }
567 }
568
569 if (func) {
570 func(err, protocol);
571 }
572 }
573
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)574 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
575 GetCipherSuiteCallback callback)
576 {
577 GetCipherSuiteCallback func = nullptr;
578 {
579 std::lock_guard<std::mutex> lock(mutex_);
580 if (callback) {
581 func = callback;
582 }
583 }
584
585 if (func) {
586 func(err, suite);
587 }
588 }
589
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)590 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
591 GetSignatureAlgorithmsCallback callback)
592 {
593 GetSignatureAlgorithmsCallback func = nullptr;
594 {
595 std::lock_guard<std::mutex> lock(mutex_);
596 if (callback) {
597 func = callback;
598 }
599 }
600
601 if (func) {
602 func(err, algorithms);
603 }
604 }
605
Bind(const Socket::NetAddress & address,const BindCallback & callback)606 void TLSSocket::Bind(const Socket::NetAddress &address, const BindCallback &callback)
607 {
608 if (!CommonUtils::HasInternetPermission()) {
609 CallBindCallback(PERMISSION_DENIED_CODE, callback);
610 return;
611 }
612 if (sockFd_ >= 0) {
613 CallBindCallback(TLSSOCKET_SUCCESS, callback);
614 return;
615 }
616
617 MakeIpSocket(address.GetSaFamily());
618 if (sockFd_ < 0) {
619 int resErr = ConvertErrno();
620 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
621 CallOnErrorCallback(resErr, MakeErrnoString());
622 CallBindCallback(resErr, callback);
623 return;
624 }
625
626 sockaddr_in addr4 = {0};
627 sockaddr_in6 addr6 = {0};
628 sockaddr *addr = nullptr;
629 socklen_t len;
630 GetAddr(address, &addr4, &addr6, &addr, &len);
631 if (addr == nullptr) {
632 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
633 CallOnErrorCallback(-1, "Address Is Invalid");
634 CallBindCallback(ConvertErrno(), callback);
635 return;
636 }
637 CallBindCallback(TLSSOCKET_SUCCESS, callback);
638 }
639
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)640 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
641 const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
642 {
643 if (sockFd_ < 0) {
644 int resErr = ConvertErrno();
645 NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
646 CallOnErrorCallback(resErr, MakeErrnoString());
647 callback(resErr);
648 return;
649 }
650
651 auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions);
652 if (!res) {
653 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
654 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
655 callback(resErr);
656 return;
657 }
658 StartReadMessage();
659 CallOnConnectCallback();
660 callback(TLSSOCKET_SUCCESS);
661 }
662
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)663 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
664 {
665 (void)tcpSendOptions;
666
667 auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
668 if (!res) {
669 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
670 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
671 CallSendCallback(resErr, callback);
672 return;
673 }
674 CallSendCallback(TLSSOCKET_SUCCESS, callback);
675 }
676
WaitConditionWithTimeout(const bool * flag,const int32_t timeoutMs)677 bool WaitConditionWithTimeout(const bool *flag, const int32_t timeoutMs)
678 {
679 int maxWaitCnt = timeoutMs / WAIT_MS;
680 int cnt = 0;
681 while (!(*flag)) {
682 if (cnt >= maxWaitCnt) {
683 return false;
684 }
685 std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MS));
686 cnt++;
687 }
688 return true;
689 }
690
Close(const CloseCallback & callback)691 void TLSSocket::Close(const CloseCallback &callback)
692 {
693 if (!WaitConditionWithTimeout(&isRunning_, TIMEOUT_MS)) {
694 callback(ConvertErrno());
695 NETSTACK_LOGE("The error cause is that the runtime wait time is insufficient");
696 return;
697 }
698 isRunning_ = false;
699
700 auto res = tlsSocketInternal_.Close();
701 if (!res) {
702 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
703 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
704 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
705 callback(resErr);
706 return;
707 }
708 CallOnCloseCallback();
709 callback(TLSSOCKET_SUCCESS);
710 }
711
GetRemoteAddress(const GetRemoteAddressCallback & callback)712 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
713 {
714 sockaddr sockAddr = {0};
715 socklen_t len = sizeof(sockaddr);
716 int ret = getsockname(sockFd_, &sockAddr, &len);
717 if (ret < 0) {
718 int resErr = ConvertErrno();
719 NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
720 CallOnErrorCallback(resErr, MakeErrnoString());
721 CallGetRemoteAddressCallback(resErr, {}, callback);
722 return;
723 }
724
725 if (sockAddr.sa_family == AF_INET) {
726 GetIp4RemoteAddress(callback);
727 } else if (sockAddr.sa_family == AF_INET6) {
728 GetIp6RemoteAddress(callback);
729 }
730 }
731
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)732 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
733 {
734 sockaddr_in addr4 = {0};
735 socklen_t len4 = sizeof(sockaddr_in);
736
737 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
738 if (ret < 0) {
739 int resErr = ConvertErrno();
740 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
741 CallOnErrorCallback(resErr, MakeErrnoString());
742 CallGetRemoteAddressCallback(resErr, {}, callback);
743 return;
744 }
745
746 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
747 if (address.empty()) {
748 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
749 CallOnErrorCallback(-1, "Address is invalid");
750 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
751 return;
752 }
753 Socket::NetAddress netAddress;
754 netAddress.SetAddress(address);
755 netAddress.SetFamilyBySaFamily(AF_INET);
756 netAddress.SetPort(ntohs(addr4.sin_port));
757 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
758 }
759
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)760 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
761 {
762 sockaddr_in6 addr6 = {0};
763 socklen_t len6 = sizeof(sockaddr_in6);
764
765 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
766 if (ret < 0) {
767 int resErr = ConvertErrno();
768 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
769 CallOnErrorCallback(resErr, MakeErrnoString());
770 CallGetRemoteAddressCallback(resErr, {}, callback);
771 return;
772 }
773
774 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
775 if (address.empty()) {
776 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
777 CallOnErrorCallback(-1, "Address is invalid");
778 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
779 return;
780 }
781 Socket::NetAddress netAddress;
782 netAddress.SetAddress(address);
783 netAddress.SetFamilyBySaFamily(AF_INET6);
784 netAddress.SetPort(ntohs(addr6.sin6_port));
785 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
786 }
787
GetState(const GetStateCallback & callback)788 void TLSSocket::GetState(const GetStateCallback &callback)
789 {
790 int opt;
791 socklen_t optLen = sizeof(int);
792 int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
793 if (r < 0) {
794 Socket::SocketStateBase state;
795 state.SetIsClose(true);
796 CallGetStateCallback(ConvertErrno(), state, callback);
797 return;
798 }
799 sockaddr sockAddr = {0};
800 socklen_t len = sizeof(sockaddr);
801 Socket::SocketStateBase state;
802 int ret = getsockname(sockFd_, &sockAddr, &len);
803 state.SetIsBound(ret == 0);
804 ret = getpeername(sockFd_, &sockAddr, &len);
805 state.SetIsConnected(ret == 0);
806 CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
807 }
808
SetBaseOptions(const Socket::ExtraOptionsBase & option) const809 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
810 {
811 if (option.GetReceiveBufferSize() != 0) {
812 int size = (int)option.GetReceiveBufferSize();
813 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
814 return false;
815 }
816 }
817
818 if (option.GetSendBufferSize() != 0) {
819 int size = (int)option.GetSendBufferSize();
820 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
821 return false;
822 }
823 }
824
825 if (option.IsReuseAddress()) {
826 int reuse = 1;
827 if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
828 return false;
829 }
830 }
831
832 if (option.GetSocketTimeout() != 0) {
833 timeval timeout = {(int)option.GetSocketTimeout(), 0};
834 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
835 return false;
836 }
837 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
838 return false;
839 }
840 }
841
842 return true;
843 }
844
SetExtraOptions(const Socket::TCPExtraOptions & option) const845 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
846 {
847 if (option.IsKeepAlive()) {
848 int keepalive = 1;
849 if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
850 return false;
851 }
852 }
853
854 if (option.IsOOBInline()) {
855 int oobInline = 1;
856 if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
857 return false;
858 }
859 }
860
861 if (option.IsTCPNoDelay()) {
862 int tcpNoDelay = 1;
863 if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
864 return false;
865 }
866 }
867
868 linger soLinger = {0};
869 soLinger.l_onoff = option.socketLinger.IsOn();
870 soLinger.l_linger = (int)option.socketLinger.GetLinger();
871 if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
872 return false;
873 }
874
875 return true;
876 }
877
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)878 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
879 const SetExtraOptionsCallback &callback)
880 {
881 if (!SetBaseOptions(tcpExtraOptions)) {
882 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
883 CallOnErrorCallback(errno, MakeErrnoString());
884 CallSetExtraOptionsCallback(ConvertErrno(), callback);
885 return;
886 }
887
888 if (!SetExtraOptions(tcpExtraOptions)) {
889 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
890 CallOnErrorCallback(errno, MakeErrnoString());
891 CallSetExtraOptionsCallback(ConvertErrno(), callback);
892 return;
893 }
894
895 CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
896 }
897
GetCertificate(const GetCertificateCallback & callback)898 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
899 {
900 const auto &cert = tlsSocketInternal_.GetCertificate();
901 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
902
903 if (!cert.data.Length()) {
904 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
905 NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
906 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
907 callback(resErr, {});
908 return;
909 }
910 callback(TLSSOCKET_SUCCESS, cert);
911 }
912
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)913 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
914 {
915 const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
916 if (!remoteCert.data.Length()) {
917 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
918 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
919 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
920 callback(resErr, {});
921 return;
922 }
923 callback(TLSSOCKET_SUCCESS, remoteCert);
924 }
925
GetProtocol(const GetProtocolCallback & callback)926 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
927 {
928 const auto &protocol = tlsSocketInternal_.GetProtocol();
929 if (protocol.empty()) {
930 NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
931 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
932 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
933 callback(resErr, "");
934 return;
935 }
936 callback(TLSSOCKET_SUCCESS, protocol);
937 }
938
GetCipherSuite(const GetCipherSuiteCallback & callback)939 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
940 {
941 const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
942 if (cipherSuite.empty()) {
943 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
944 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
945 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
946 callback(resErr, cipherSuite);
947 return;
948 }
949 callback(TLSSOCKET_SUCCESS, cipherSuite);
950 }
951
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)952 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
953 {
954 const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
955 if (signatureAlgorithms.empty()) {
956 NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
957 int resErr = ConvertSSLError(tlsSocketInternal_.GetSSL());
958 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
959 callback(resErr, {});
960 return;
961 }
962 callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
963 }
964
OnMessage(const OnMessageCallback & onMessageCallback)965 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
966 {
967 std::lock_guard<std::mutex> lock(mutex_);
968 onMessageCallback_ = onMessageCallback;
969 }
970
OffMessage()971 void TLSSocket::OffMessage()
972 {
973 std::lock_guard<std::mutex> lock(mutex_);
974 if (onMessageCallback_) {
975 onMessageCallback_ = nullptr;
976 }
977 }
978
OnConnect(const OnConnectCallback & onConnectCallback)979 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
980 {
981 std::lock_guard<std::mutex> lock(mutex_);
982 onConnectCallback_ = onConnectCallback;
983 }
984
OffConnect()985 void TLSSocket::OffConnect()
986 {
987 std::lock_guard<std::mutex> lock(mutex_);
988 if (onConnectCallback_) {
989 onConnectCallback_ = nullptr;
990 }
991 }
992
OnClose(const OnCloseCallback & onCloseCallback)993 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
994 {
995 std::lock_guard<std::mutex> lock(mutex_);
996 onCloseCallback_ = onCloseCallback;
997 }
998
OffClose()999 void TLSSocket::OffClose()
1000 {
1001 std::lock_guard<std::mutex> lock(mutex_);
1002 if (onCloseCallback_) {
1003 onCloseCallback_ = nullptr;
1004 }
1005 }
1006
OnError(const OnErrorCallback & onErrorCallback)1007 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1008 {
1009 std::lock_guard<std::mutex> lock(mutex_);
1010 onErrorCallback_ = onErrorCallback;
1011 }
1012
OffError()1013 void TLSSocket::OffError()
1014 {
1015 std::lock_guard<std::mutex> lock(mutex_);
1016 if (onErrorCallback_) {
1017 onErrorCallback_ = nullptr;
1018 }
1019 }
1020
ExecSocketConnect(const std::string & hostName,int port,sa_family_t family,int socketDescriptor)1021 bool ExecSocketConnect(const std::string &hostName, int port, sa_family_t family, int socketDescriptor)
1022 {
1023 struct sockaddr_in dest = {0};
1024 dest.sin_family = family;
1025 dest.sin_port = htons(port);
1026 if (!inet_aton(hostName.c_str(), reinterpret_cast<in_addr *>(&dest.sin_addr.s_addr))) {
1027 NETSTACK_LOGE("inet_aton is error, hostName is %s", hostName.c_str());
1028 return false;
1029 }
1030 int connectResult = connect(socketDescriptor, reinterpret_cast<struct sockaddr *>(&dest), sizeof(dest));
1031 if (connectResult == -1) {
1032 NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1033 strerror(errno));
1034 return false;
1035 }
1036 return true;
1037 }
1038
TlsConnectToHost(int sock,const TLSConnectOptions & options)1039 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options)
1040 {
1041 SetTlsConfiguration(options);
1042 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1043 if (!cipherSuite.empty()) {
1044 configuration_.SetCipherSuite(cipherSuite);
1045 }
1046 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1047 if (!signatureAlgorithms.empty()) {
1048 configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1049 }
1050 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1051 if (!protocolVec.empty()) {
1052 configuration_.SetProtocol(protocolVec);
1053 }
1054
1055 hostName_ = options.GetNetAddress().GetAddress();
1056 port_ = options.GetNetAddress().GetPort();
1057 family_ = options.GetNetAddress().GetSaFamily();
1058 socketDescriptor_ = sock;
1059 if (!ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1060 options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1061 return false;
1062 }
1063 return StartTlsConnected(options);
1064 }
1065
SetTlsConfiguration(const TLSConnectOptions & config)1066 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1067 {
1068 configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1069 configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1070 configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1071 }
1072
Send(const std::string & data)1073 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1074 {
1075 NETSTACK_LOGD("data to send :%{public}s", data.c_str());
1076 if (data.empty()) {
1077 NETSTACK_LOGE("data is empty");
1078 return false;
1079 }
1080 if (!ssl_) {
1081 NETSTACK_LOGE("ssl is null");
1082 return false;
1083 }
1084 int len = SSL_write(ssl_, data.c_str(), data.length());
1085 if (len < 0) {
1086 int resErr = ConvertSSLError(GetSSL());
1087 NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
1088 data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
1089 return false;
1090 }
1091 NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
1092 return true;
1093 }
Recv(char * buffer,int maxBufferSize)1094 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1095 {
1096 if (!ssl_) {
1097 NETSTACK_LOGE("ssl is null");
1098 return SSL_ERROR_RETURN;
1099 }
1100 return SSL_read(ssl_, buffer, maxBufferSize);
1101 }
1102
Close()1103 bool TLSSocket::TLSSocketInternal::Close()
1104 {
1105 if (!ssl_) {
1106 NETSTACK_LOGE("ssl is null");
1107 return false;
1108 }
1109 int result = SSL_shutdown(ssl_);
1110 if (result < 0) {
1111 int resErr = ConvertSSLError(GetSSL());
1112 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1113 MakeSSLErrorString(resErr).c_str());
1114 return false;
1115 }
1116 SSL_free(ssl_);
1117 ssl_ = nullptr;
1118 close(socketDescriptor_);
1119 socketDescriptor_ = -1;
1120 if (!tlsContextPointer_) {
1121 NETSTACK_LOGE("Tls context pointer is null");
1122 return false;
1123 }
1124 tlsContextPointer_->CloseCtx();
1125 return true;
1126 }
1127
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1128 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1129 {
1130 if (!ssl_) {
1131 NETSTACK_LOGE("ssl is null");
1132 return false;
1133 }
1134 size_t pos = 0;
1135 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1136 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1137 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1138 for (const auto &str : alpnProtocols) {
1139 len = str.length();
1140 result[pos++] = len;
1141 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1142 NETSTACK_LOGE("strcpy_s failed");
1143 return false;
1144 }
1145 pos += len;
1146 }
1147 result[pos] = '\0';
1148
1149 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1150 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1151 int resErr = ConvertSSLError(GetSSL());
1152 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1153 MakeSSLErrorString(resErr).c_str());
1154 return false;
1155 }
1156 return true;
1157 }
1158
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1159 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1160 {
1161 remoteInfo.SetAddress(hostName_);
1162 remoteInfo.SetPort(port_);
1163 remoteInfo.SetFamily(family_);
1164 }
1165
GetTlsConfiguration() const1166 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1167 {
1168 return configuration_;
1169 }
1170
GetCipherSuite() const1171 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1172 {
1173 if (!ssl_) {
1174 NETSTACK_LOGE("ssl in null");
1175 return {};
1176 }
1177 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1178 if (!sk) {
1179 NETSTACK_LOGE("get ciphers failed");
1180 return {};
1181 }
1182 CipherSuite cipherSuite;
1183 std::vector<std::string> cipherSuiteVec;
1184 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1185 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1186 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1187 cipherSuiteVec.push_back(cipherSuite.cipherName_);
1188 }
1189 return cipherSuiteVec;
1190 }
1191
GetRemoteCertificate() const1192 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1193 {
1194 return remoteCert_;
1195 }
1196
GetCertificate() const1197 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1198 {
1199 return configuration_.GetCertificate();
1200 }
1201
GetSignatureAlgorithms() const1202 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1203 {
1204 return signatureAlgorithms_;
1205 }
1206
GetProtocol() const1207 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1208 {
1209 if (!ssl_) {
1210 NETSTACK_LOGE("ssl in null");
1211 return PROTOCOL_UNKNOW;
1212 }
1213 if (configuration_.GetProtocol() == TLS_V1_3) {
1214 return PROTOCOL_TLS_V13;
1215 }
1216 return PROTOCOL_TLS_V12;
1217 }
1218
SetSharedSigals()1219 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1220 {
1221 if (!ssl_) {
1222 NETSTACK_LOGE("ssl is null");
1223 return false;
1224 }
1225 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1226 if (!number) {
1227 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1228 return false;
1229 }
1230 for (int i = 0; i < number; i++) {
1231 int hash_nid;
1232 int sign_nid;
1233 std::string sig_with_md;
1234 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1235 switch (sign_nid) {
1236 case EVP_PKEY_RSA:
1237 sig_with_md = SIGN_NID_RSA;
1238 break;
1239 case EVP_PKEY_RSA_PSS:
1240 sig_with_md = SIGN_NID_RSA_PSS;
1241 break;
1242 case EVP_PKEY_DSA:
1243 sig_with_md = SIGN_NID_DSA;
1244 break;
1245 case EVP_PKEY_EC:
1246 sig_with_md = SIGN_NID_ECDSA;
1247 break;
1248 case NID_ED25519:
1249 sig_with_md = SIGN_NID_ED;
1250 break;
1251 case NID_ED448:
1252 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1253 break;
1254 default:
1255 const char *sn = OBJ_nid2sn(sign_nid);
1256 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1257 }
1258 const char *sn_hash = OBJ_nid2sn(hash_nid);
1259 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1260 signatureAlgorithms_.push_back(sig_with_md);
1261 }
1262 return true;
1263 }
1264
StartTlsConnected(const TLSConnectOptions & options)1265 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1266 {
1267 if (!CreatTlsContext()) {
1268 NETSTACK_LOGE("failed to create tls context");
1269 return false;
1270 }
1271 if (!StartShakingHands(options)) {
1272 NETSTACK_LOGE("failed to shaking hands");
1273 return false;
1274 }
1275 return true;
1276 }
1277
CreatTlsContext()1278 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1279 {
1280 tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1281 if (!tlsContextPointer_) {
1282 NETSTACK_LOGE("failed to create tls context pointer");
1283 return false;
1284 }
1285 if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1286 NETSTACK_LOGE("failed to create ssl session");
1287 return false;
1288 }
1289 SSL_set_fd(ssl_, socketDescriptor_);
1290 SSL_set_connect_state(ssl_);
1291 return true;
1292 }
1293
// Returns true when `s` begins with `prefix` (an empty prefix always matches).
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // Searching backwards from index 0 can only match at the very start of `s`.
    return s.rfind(prefix, 0) == 0;
}
1298
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1299 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1300 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1301 {
1302 bool valid = false;
1303 std::string reason = UNKNOW_REASON;
1304 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1305 if (IsIP(hostName)) {
1306 auto it = find(ips.begin(), ips.end(), hostName);
1307 if (it == ips.end()) {
1308 reason = IP + hostName + " is not in the cert's list";
1309 }
1310 result = {valid, reason};
1311 return;
1312 }
1313 std::string tempHostName = "" + hostName;
1314 if (!dnsNames.empty() || index > 0) {
1315 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1316 if (!dnsNames.empty()) {
1317 valid = SeekIntersection(hostParts, dnsNames);
1318 if (!valid) {
1319 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1320 }
1321 } else {
1322 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1323 X509_NAME *pSubName = nullptr;
1324 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1325 if (len > 0) {
1326 std::vector<std::string> commonNameVec;
1327 commonNameVec.emplace_back(commonNameBuf);
1328 valid = SeekIntersection(hostParts, commonNameVec);
1329 if (!valid) {
1330 reason = HOST_NAME + tempHostName + ". is not cert's CN";
1331 }
1332 }
1333 }
1334 result = {valid, reason};
1335 return;
1336 }
1337 reason = "Cert does not contain a DNS name";
1338 result = {valid, reason};
1339 }
1340
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1341 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1342 const X509 *x509Certificates)
1343 {
1344 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1345 if (!subjectName) {
1346 return "subject name is null";
1347 }
1348 char subNameBuf[BUF_SIZE] = {0};
1349 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1350
1351 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1352 if (index < 0) {
1353 return "X509 get ext nid error";
1354 }
1355 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1356 if (ext == nullptr) {
1357 return "X509 get ext error";
1358 }
1359 ASN1_OBJECT *obj = nullptr;
1360 obj = X509_EXTENSION_get_object(ext);
1361 char subAltNameBuf[BUF_SIZE] = {0};
1362 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1363 NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1364
1365 return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1366 }
1367
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1368 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1369 const X509 *x509Certificates)
1370 {
1371 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1372 std::string altNames = reinterpret_cast<char *>(extData->data);
1373 std::string hostname = " " + hostName;
1374 BIO *bio = BIO_new(BIO_s_file());
1375 if (!bio) {
1376 return "bio is null";
1377 }
1378 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1379 ASN1_STRING_print(bio, extData);
1380 std::vector<std::string> dnsNames = {};
1381 std::vector<std::string> ips = {};
1382 constexpr int DNS_NAME_IDX = 4;
1383 constexpr int IP_NAME_IDX = 11;
1384 if (!altNames.empty()) {
1385 std::vector<std::string> splitAltNames;
1386 if (altNames.find('\"') != std::string::npos) {
1387 splitAltNames = SplitEscapedAltNames(altNames);
1388 } else {
1389 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1390 }
1391 for (auto const &iter : splitAltNames) {
1392 if (StartsWith(iter, DNS)) {
1393 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1394 } else if (StartsWith(iter, IP_ADDRESS)) {
1395 ips.push_back(iter.substr(IP_NAME_IDX));
1396 }
1397 }
1398 }
1399 std::tuple<bool, std::string> result;
1400 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1401 if (!std::get<0>(result)) {
1402 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1403 }
1404 return HOST_NAME + hostname + ". is cert's CN";
1405 }
1406
StartShakingHands(const TLSConnectOptions & options)1407 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1408 {
1409 if (!ssl_) {
1410 NETSTACK_LOGE("ssl is null");
1411 return false;
1412 }
1413 int result = SSL_connect(ssl_);
1414 if (result == -1) {
1415 int errorStatus = ConvertSSLError(ssl_);
1416 NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1417 MakeSSLErrorString(errorStatus).c_str());
1418 return false;
1419 }
1420
1421 std::string list = SSL_get_cipher_list(ssl_, 0);
1422 NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1423 configuration_.SetCipherSuite(list);
1424 if (!SetSharedSigals()) {
1425 NETSTACK_LOGE("Failed to set sharedSigalgs");
1426 }
1427 if (!GetRemoteCertificateFromPeer()) {
1428 NETSTACK_LOGE("Failed to get remote certificate");
1429 }
1430 if (!peerX509_) {
1431 NETSTACK_LOGE("peer x509Certificates is null");
1432 return false;
1433 }
1434 if (!SetRemoteCertRawData()) {
1435 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1436 }
1437 CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1438 if (!checkServerIdentity) {
1439 CheckServerIdentityLegal(hostName_, peerX509_);
1440 } else {
1441 checkServerIdentity(hostName_, {remoteCert_});
1442 }
1443 NETSTACK_LOGI("SSL Get Version: %{public}s, SSL Get Cipher: %{public}s", SSL_get_version(ssl_),
1444 SSL_get_cipher(ssl_));
1445 return true;
1446 }
1447
GetRemoteCertificateFromPeer()1448 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1449 {
1450 peerX509_ = SSL_get_peer_certificate(ssl_);
1451 if (peerX509_ == nullptr) {
1452 int resErr = ConvertSSLError(GetSSL());
1453 NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1454 MakeSSLErrorString(resErr).c_str());
1455 return false;
1456 }
1457 BIO *bio = BIO_new(BIO_s_mem());
1458 if (!bio) {
1459 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1460 return false;
1461 }
1462 X509_print(bio, peerX509_);
1463 char data[REMOTE_CERT_LEN] = {0};
1464 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1465 NETSTACK_LOGE("BIO_read function returns error");
1466 BIO_free(bio);
1467 return false;
1468 }
1469 BIO_free(bio);
1470 remoteCert_ = std::string(data);
1471 return true;
1472 }
1473
SetRemoteCertRawData()1474 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1475 {
1476 if (peerX509_ == nullptr) {
1477 NETSTACK_LOGE("peerX509 is null");
1478 return false;
1479 }
1480 int32_t length = i2d_X509(peerX509_, nullptr);
1481 if (length <= 0) {
1482 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1483 return false;
1484 }
1485 unsigned char *der = nullptr;
1486 (void)i2d_X509(peerX509_, &der);
1487 SecureData data(der, length);
1488 remoteRawData_.data = data;
1489 OPENSSL_free(der);
1490 remoteRawData_.encodingFormat = DER;
1491 return true;
1492 }
1493
GetRemoteCertRawData() const1494 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1495 {
1496 return remoteRawData_;
1497 }
1498
GetSSL() const1499 ssl_st *TLSSocket::TLSSocketInternal::GetSSL() const
1500 {
1501 return ssl_;
1502 }
1503 } // namespace TlsSocket
1504 } // namespace NetStack
1505 } // namespace OHOS
1506