• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket_server.h"
17 
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <memory>
#include <numeric>
#include <regex>
#include <string>
#include <vector>

#include <netinet/tcp.h>
#include <sys/ioctl.h>

#include <openssl/err.h>
#include <openssl/ssl.h>
#include <securec.h>
28 
29 #include "base_context.h"
30 #include "netstack_common_utils.h"
31 #include "netstack_log.h"
32 #include "tls.h"
33 
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocketServer {
37 #if UNITTEST
38 #else
39 namespace {
40 #endif // UNITTEST
// Size of the scratch buffer handed to ERR_error_string_n() by MakeSSLErrorString().
constexpr size_t MAX_ERR_LENGTH = 1024;

// "ret" argument passed to SSL_get_error() by ConvertSSLError().
constexpr int SSL_RET_CODE = 0;

constexpr int BUF_SIZE = 2048;
// NOTE(review): presumably a poll() timeout in milliseconds — confirm at the use site.
constexpr int POLL_WAIT_TIME = 2000;
// Number of characters skipped after a ", " separator in SplitEscapedAltNames().
constexpr int OFFSET = 2;
// Value Connection::Recv() returns when the connection has no SSL object.
constexpr int SSL_ERROR_RETURN = -1;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
// Backlog passed to listen(); the misspelled name is kept because other code refers to it.
constexpr int LISETEN_COUNT = 516;
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *DNS = "DNS:";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
// NOTE(review): presumably the name given to the server's read/poll thread — confirm at the use site.
static constexpr const char *TLS_SOCKET_SERVER_READ = "OS_NET_TSAccRD";
// A JSON string literal anchored at the start of the input: '"', then any mix of characters other than
// '"', '\' and U+0000..U+001F or JSON escapes (\" \\ \/ \b \f \n \r \t \uXXXX), then a closing '"'.
// This is Node.js' jsonStringPattern without the JavaScript /.../ delimiters; std::regex would treat
// those slashes as literal characters, and the pattern could then never match.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address with every octet in 0..255; used by IsIP().
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
// Value returned by TLSSocketServer::GetConnectionClientCount().
// NOTE(review): a plain int shared by every server instance — confirm how updates to it are synchronised.
int g_userCounter = 0;
75 
IsIP(const std::string & ip)76 bool IsIP(const std::string &ip)
77 {
78     std::regex pattern(PATTERN);
79     std::smatch res;
80     return regex_match(ip, res, pattern);
81 }
82 
SplitHostName(std::string & hostName)83 std::vector<std::string> SplitHostName(std::string &hostName)
84 {
85     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
86     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
87 }
88 
// Returns true when vecA and vecB have at least one element in common. Neither argument is modified.
// std::set_intersection requires both ranges to be sorted, and label lists such as {"www", "example", "com"}
// are not, so it could miss shared elements. Instead, each element of vecA is looked up in a sorted copy
// of vecB.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    std::vector<std::string> sortedB(vecB);
    std::sort(sortedB.begin(), sortedB.end());
    return std::any_of(vecA.begin(), vecA.end(), [&sortedB](const std::string &item) {
        return std::binary_search(sortedB.begin(), sortedB.end(), item);
    });
}
95 
ConvertErrno()96 int ConvertErrno()
97 {
98     return TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE + errno;
99 }
100 
ConvertSSLError(ssl_st * ssl)101 int ConvertSSLError(ssl_st *ssl)
102 {
103     if (!ssl) {
104         return TlsSocket::TLS_ERR_SSL_NULL;
105     }
106     return TlsSocket::TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
107 }
108 
// Returns strerror()'s description of the current errno as an owned string.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
113 
MakeSSLErrorString(int error)114 std::string MakeSSLErrorString(int error)
115 {
116     char err[MAX_ERR_LENGTH] = {0};
117     ERR_error_string_n(error - TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
118     return err;
119 }
SplitEscapedAltNames(std::string & altNames)120 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
121 {
122     std::vector<std::string> result;
123     std::string currentToken;
124     size_t offset = 0;
125     while (offset != altNames.length()) {
126         auto nextSep = altNames.find_first_of(", ");
127         auto nextQuote = altNames.find_first_of('\"');
128         if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
129             currentToken += altNames.substr(offset, nextQuote);
130             std::regex jsonStringPattern(JSON_STRING_PATTERN);
131             std::smatch match;
132             std::string altNameSubStr = altNames.substr(nextQuote);
133             bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
134             if (!ret) {
135                 return {""};
136             }
137             currentToken += result[0];
138             offset = nextQuote + result[0].length();
139         } else if (nextSep != std::string::npos) {
140             currentToken += altNames.substr(offset, nextSep);
141             result.push_back(currentToken);
142             currentToken = "";
143             offset = nextSep + OFFSET;
144         } else {
145             currentToken += altNames.substr(offset);
146             offset = altNames.length();
147         }
148     }
149     result.push_back(currentToken);
150     return result;
151 }
152 #if UNITTEST
153 #else
154 } // namespace
155 #endif
156 
SetSocket(const int & socketFd)157 void TLSServerSendOptions::SetSocket(const int &socketFd)
158 {
159     socketFd_ = socketFd;
160 }
161 
SetSendData(const std::string & data)162 void TLSServerSendOptions::SetSendData(const std::string &data)
163 {
164     data_ = data;
165 }
166 
GetSocket() const167 const int &TLSServerSendOptions::GetSocket() const
168 {
169     return socketFd_;
170 }
171 
GetSendData() const172 const std::string &TLSServerSendOptions::GetSendData() const
173 {
174     return data_;
175 }
176 
~TLSSocketServer()177 TLSSocketServer::~TLSSocketServer()
178 {
179     isRunning_ = false;
180     clientIdConnections_.clear();
181 
182     if (listenSocketFd_ != -1) {
183         shutdown(listenSocketFd_, SHUT_RDWR);
184         close(listenSocketFd_);
185         listenSocketFd_ = -1;
186     }
187 }
188 
Listen(const TlsSocket::TLSConnectOptions & tlsListenOptions,const ListenCallback & callback)189 void TLSSocketServer::Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback)
190 {
191     if (!CommonUtils::HasInternetPermission()) {
192         CallListenCallback(PERMISSION_DENIED_CODE, callback);
193         return;
194     }
195     NETSTACK_LOGE("Listen 1 %{public}d", listenSocketFd_);
196     if (listenSocketFd_ >= 0) {
197         CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
198         return;
199     }
200     NETSTACK_LOGE("Listen 2 %{public}d", listenSocketFd_);
201     if (ExecBind(tlsListenOptions.GetNetAddress(), callback)) {
202         NETSTACK_LOGE("Listen 3 %{public}d", listenSocketFd_);
203         ExecAccept(tlsListenOptions, callback);
204     } else {
205         shutdown(listenSocketFd_, SHUT_RDWR);
206         close(listenSocketFd_);
207         listenSocketFd_ = -1;
208     }
209 
210     PollThread(tlsListenOptions);
211 }
212 
ExecBind(const Socket::NetAddress & address,const ListenCallback & callback)213 bool TLSSocketServer::ExecBind(const Socket::NetAddress &address, const ListenCallback &callback)
214 {
215     MakeIpSocket(address.GetSaFamily());
216     if (listenSocketFd_ < 0) {
217         int resErr = ConvertErrno();
218         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
219         CallOnErrorCallback(resErr, MakeErrnoString());
220         CallListenCallback(resErr, callback);
221         return false;
222     }
223     sockaddr_in addr4 = {0};
224     sockaddr_in6 addr6 = {0};
225     sockaddr *addr = nullptr;
226     socklen_t len;
227     GetAddr(address, &addr4, &addr6, &addr, &len);
228     if (addr == nullptr) {
229         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
230         CallOnErrorCallback(-1, "Address Is Invalid");
231         CallListenCallback(ConvertErrno(), callback);
232         return false;
233     }
234     if (bind(listenSocketFd_, addr, len) < 0) {
235         if (errno != EADDRINUSE) {
236             NETSTACK_LOGE("bind error is %{public}s %{public}d", strerror(errno), errno);
237             CallOnErrorCallback(-1, "Address binding failed");
238             CallListenCallback(ConvertErrno(), callback);
239             return false;
240         }
241         if (addr->sa_family == AF_INET) {
242             NETSTACK_LOGI("distribute a random port");
243             addr4.sin_port = 0; /* distribute a random port */
244         } else if (addr->sa_family == AF_INET6) {
245             NETSTACK_LOGI("distribute a random port");
246             addr6.sin6_port = 0; /* distribute a random port */
247         }
248         if (bind(listenSocketFd_, addr, len) < 0) {
249             NETSTACK_LOGE("rebind error is %{public}s %{public}d", strerror(errno), errno);
250             CallOnErrorCallback(-1, "Duplicate binding address failed");
251             CallListenCallback(ConvertErrno(), callback);
252             return false;
253         }
254         NETSTACK_LOGI("rebind success");
255     }
256     NETSTACK_LOGI("bind success");
257     address_ = address;
258     return true;
259 }
260 
ExecAccept(const TlsSocket::TLSConnectOptions & tlsAcceptOptions,const ListenCallback & callback)261 void TLSSocketServer::ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback)
262 {
263     if (listenSocketFd_ < 0) {
264         int resErr = ConvertErrno();
265         NETSTACK_LOGE("accept error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
266         CallOnErrorCallback(resErr, MakeErrnoString());
267         callback(resErr);
268         return;
269     }
270     SetLocalTlsConfiguration(tlsAcceptOptions);
271     int ret = 0;
272 
273     NETSTACK_LOGE(
274         "accept error is listenSocketFd_=  %{public}d LISETEN_COUNT =%{public}d .GetVerifyMode()  = %{public}d ",
275         listenSocketFd_, LISETEN_COUNT, tlsAcceptOptions.GetTlsSecureOptions().GetVerifyMode());
276     ret = listen(listenSocketFd_, LISETEN_COUNT);
277     if (ret < 0) {
278         int resErr = ConvertErrno();
279         NETSTACK_LOGE("tcp server listen error");
280         CallOnErrorCallback(resErr, MakeErrnoString());
281         callback(resErr);
282         return;
283     }
284     CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
285 }
286 
Send(const TLSServerSendOptions & data,const TlsSocket::SendCallback & callback)287 bool TLSSocketServer::Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback)
288 {
289     int socketFd = data.GetSocket();
290     std::string info = data.GetSendData();
291 
292     auto connect_iterator = clientIdConnections_.find(socketFd);
293     if (connect_iterator == clientIdConnections_.end()) {
294         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
295         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
296         return false;
297     }
298     auto connect = connect_iterator->second;
299     auto res = connect->Send(info);
300     if (!res) {
301         int resErr = ConvertSSLError(connect->GetSSL());
302         NETSTACK_LOGE("send error is %{public}d %{public}d", resErr, errno);
303         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
304         CallSendCallback(resErr, callback);
305         return false;
306     }
307     CallSendCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
308     return res;
309 }
310 
CallSendCallback(int32_t err,TlsSocket::SendCallback callback)311 void TLSSocketServer::CallSendCallback(int32_t err, TlsSocket::SendCallback callback)
312 {
313     if (callback) {
314         callback(err);
315     }
316 }
317 
Close(const int socketFd,const TlsSocket::CloseCallback & callback)318 void TLSSocketServer::Close(const int socketFd, const TlsSocket::CloseCallback &callback)
319 {
320     {
321         std::lock_guard<std::mutex> its_lock(connectMutex_);
322         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
323             if (it->first == socketFd) {
324                 auto res = it->second->Close();
325                 if (!res) {
326                     int resErr = ConvertSSLError(it->second->GetSSL());
327                     NETSTACK_LOGE("close error is %{public}d %{public}d", resErr, errno);
328                     CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
329                     callback(resErr);
330                     return;
331                 }
332                 callback(TlsSocket::TLSSOCKET_SUCCESS);
333                 return;
334             } else {
335                 ++it;
336             }
337         }
338     }
339     NETSTACK_LOGE("socket = %{public}d There is no corresponding socketFd", socketFd);
340     CallOnErrorCallback(-1, "The send failed with no corresponding socketFd");
341     callback(TlsSocket::TLS_ERR_SYS_EINVAL);
342 }
343 
Stop(const TlsSocket::CloseCallback & callback)344 void TLSSocketServer::Stop(const TlsSocket::CloseCallback &callback)
345 {
346     std::lock_guard<std::mutex> its_lock(connectMutex_);
347     for (const auto &c : clientIdConnections_) {
348         c.second->Close();
349     }
350     clientIdConnections_.clear();
351     close(listenSocketFd_);
352     listenSocketFd_ = -1;
353     callback(TlsSocket::TLSSOCKET_SUCCESS);
354 }
355 
GetRemoteAddress(const int socketFd,const TlsSocket::GetRemoteAddressCallback & callback)356 void TLSSocketServer::GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback)
357 {
358     auto connect_iterator = clientIdConnections_.find(socketFd);
359     if (connect_iterator == clientIdConnections_.end()) {
360         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
361         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
362         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
363         return;
364     }
365     auto connect = connect_iterator->second;
366     auto address = connect->GetAddress();
367     callback(TlsSocket::TLSSOCKET_SUCCESS, address);
368 }
369 
GetLocalAddress(const int socketFd,const TlsSocket::GetLocalAddressCallback & callback)370 void TLSSocketServer::GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback)
371 {
372     auto connect_iterator = clientIdConnections_.find(socketFd);
373     if (connect_iterator == clientIdConnections_.end()) {
374         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
375         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
376         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
377         return;
378     }
379     auto connect = connect_iterator->second;
380     auto localAddress = connect->GetLocalAddress();
381     callback(TlsSocket::TLSSOCKET_SUCCESS, localAddress);
382 }
383 
GetState(const TlsSocket::GetStateCallback & callback)384 void TLSSocketServer::GetState(const TlsSocket::GetStateCallback &callback)
385 {
386     int opt;
387     socklen_t optLen = sizeof(int);
388     int r = getsockopt(listenSocketFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
389     if (r < 0) {
390         Socket::SocketStateBase state;
391         state.SetIsClose(true);
392         CallGetStateCallback(ConvertErrno(), state, callback);
393         return;
394     }
395     sockaddr sockAddr = {0};
396     socklen_t len = sizeof(sockaddr);
397     Socket::SocketStateBase state;
398     int ret = getsockname(listenSocketFd_, &sockAddr, &len);
399     state.SetIsBound(ret == 0);
400     ret = getpeername(listenSocketFd_, &sockAddr, &len);
401     if (ret != 0) {
402         NETSTACK_LOGE("getpeername failed");
403     }
404     state.SetIsConnected(GetConnectionClientCount() > 0);
405     CallGetStateCallback(TlsSocket::TLSSOCKET_SUCCESS, state, callback);
406 }
407 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,TlsSocket::GetStateCallback callback)408 void TLSSocketServer::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state,
409                                            TlsSocket::GetStateCallback callback)
410 {
411     if (callback) {
412         callback(err, state);
413     }
414 }
SetExtraOptions(const Socket::TCPExtraOptions & tcpExtraOptions,const TlsSocket::SetExtraOptionsCallback & callback)415 bool TLSSocketServer::SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
416                                       const TlsSocket::SetExtraOptionsCallback &callback)
417 {
418     if (tcpExtraOptions.IsKeepAlive()) {
419         int keepalive = 1;
420         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
421             return false;
422         }
423     }
424 
425     if (tcpExtraOptions.IsOOBInline()) {
426         int oobInline = 1;
427         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
428             return false;
429         }
430     }
431 
432     if (tcpExtraOptions.IsTCPNoDelay()) {
433         int tcpNoDelay = 1;
434         if (setsockopt(listenSocketFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
435             return false;
436         }
437     }
438 
439     linger soLinger = {0};
440     soLinger.l_onoff = tcpExtraOptions.socketLinger.IsOn();
441     soLinger.l_linger = (int)tcpExtraOptions.socketLinger.GetLinger();
442     if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
443         return false;
444     }
445 
446     return true;
447 }
448 
SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions & config)449 void TLSSocketServer::SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
450 {
451     TLSServerConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
452                                           config.GetTlsSecureOptions().GetKeyPass());
453     TLSServerConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
454     TLSServerConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
455 
456     TLSServerConfiguration_.SetVerifyMode(config.GetTlsSecureOptions().GetVerifyMode());
457 
458     const auto protocolVec = config.GetTlsSecureOptions().GetProtocolChain();
459     if (!protocolVec.empty()) {
460         TLSServerConfiguration_.SetProtocol(protocolVec);
461     }
462 }
463 
GetCertificate(const TlsSocket::GetCertificateCallback & callback)464 void TLSSocketServer::GetCertificate(const TlsSocket::GetCertificateCallback &callback)
465 {
466     const auto &cert = TLSServerConfiguration_.GetCertificate();
467     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
468     if (!cert.data.Length()) {
469         CallOnErrorCallback(-1, "cert not data Length");
470         callback(-1, {});
471         return;
472     }
473     callback(TlsSocket::TLSSOCKET_SUCCESS, cert);
474 }
475 
GetRemoteCertificate(const int socketFd,const TlsSocket::GetRemoteCertificateCallback & callback)476 void TLSSocketServer::GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback)
477 {
478     auto connect_iterator = clientIdConnections_.find(socketFd);
479     if (connect_iterator == clientIdConnections_.end()) {
480         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
481         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
482         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
483         return;
484     }
485     auto connect = connect_iterator->second;
486     const auto &remoteCert = connect->GetRemoteCertRawData();
487     if (!remoteCert.data.Length()) {
488         int resErr = ConvertSSLError(connect->GetSSL());
489         NETSTACK_LOGE("GetRemoteCertificate error is %{public}d %{public}d", resErr, errno);
490         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
491         callback(resErr, {});
492         return;
493     }
494     callback(TlsSocket::TLSSOCKET_SUCCESS, remoteCert);
495 }
496 
GetProtocol(const TlsSocket::GetProtocolCallback & callback)497 void TLSSocketServer::GetProtocol(const TlsSocket::GetProtocolCallback &callback)
498 {
499     if (TLSServerConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
500         callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V13);
501         return;
502     }
503     callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V12);
504 }
505 
GetCipherSuite(const int socketFd,const TlsSocket::GetCipherSuiteCallback & callback)506 void TLSSocketServer::GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback)
507 {
508     auto connect_iterator = clientIdConnections_.find(socketFd);
509     if (connect_iterator == clientIdConnections_.end()) {
510         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
511         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
512         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
513         return;
514     }
515     auto connect = connect_iterator->second;
516     auto cipherSuite = connect->GetCipherSuite();
517     if (cipherSuite.empty()) {
518         int resErr = ConvertSSLError(connect->GetSSL());
519         NETSTACK_LOGE("GetCipherSuite error is %{public}d %{public}d", resErr, errno);
520         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
521         callback(resErr, cipherSuite);
522         return;
523     }
524     callback(TlsSocket::TLSSOCKET_SUCCESS, cipherSuite);
525 }
526 
GetSignatureAlgorithms(const int socketFd,const TlsSocket::GetSignatureAlgorithmsCallback & callback)527 void TLSSocketServer::GetSignatureAlgorithms(const int socketFd,
528                                              const TlsSocket::GetSignatureAlgorithmsCallback &callback)
529 {
530     auto connect_iterator = clientIdConnections_.find(socketFd);
531     if (connect_iterator == clientIdConnections_.end()) {
532         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
533         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
534         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
535         return;
536     }
537     auto connect = connect_iterator->second;
538     auto signatureAlgorithms = connect->GetSignatureAlgorithms();
539     if (signatureAlgorithms.empty()) {
540         int resErr = ConvertSSLError(connect->GetSSL());
541         NETSTACK_LOGE("GetSignatureAlgorithms error is %{public}d %{public}d", resErr, errno);
542         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
543         callback(resErr, signatureAlgorithms);
544         return;
545     }
546     callback(TlsSocket::TLSSOCKET_SUCCESS, signatureAlgorithms);
547 }
548 
OnMessage(const OnMessageCallback & onMessageCallback)549 void TLSSocketServer::Connection::OnMessage(const OnMessageCallback &onMessageCallback)
550 {
551     onMessageCallback_ = onMessageCallback;
552     CachedMessageCallback();
553 }
554 
OnClose(const OnCloseCallback & onCloseCallback)555 void TLSSocketServer::Connection::OnClose(const OnCloseCallback &onCloseCallback)
556 {
557     onCloseCallback_ = onCloseCallback;
558 }
559 
OnConnect(const OnConnectCallback & onConnectCallback)560 void TLSSocketServer::OnConnect(const OnConnectCallback &onConnectCallback)
561 {
562     std::lock_guard<std::mutex> lock(mutex_);
563     onConnectCallback_ = onConnectCallback;
564 }
565 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)566 void TLSSocketServer::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
567 {
568     std::lock_guard<std::mutex> lock(mutex_);
569     onErrorCallback_ = onErrorCallback;
570 }
571 
OffMessage()572 void TLSSocketServer::Connection::OffMessage()
573 {
574     if (onMessageCallback_) {
575         onMessageCallback_ = nullptr;
576     }
577 }
578 
OffConnect()579 void TLSSocketServer::OffConnect()
580 {
581     std::lock_guard<std::mutex> lock(mutex_);
582     if (onConnectCallback_) {
583         onConnectCallback_ = nullptr;
584     }
585 }
586 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)587 void TLSSocketServer::Connection::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
588 {
589     onErrorCallback_ = onErrorCallback;
590 }
591 
OffClose()592 void TLSSocketServer::Connection::OffClose()
593 {
594     if (onCloseCallback_) {
595         onCloseCallback_ = nullptr;
596     }
597 }
598 
OffError()599 void TLSSocketServer::Connection::OffError()
600 {
601     onErrorCallback_ = nullptr;
602 }
603 
CallOnErrorCallback(int32_t err,const std::string & errString)604 void TLSSocketServer::Connection::CallOnErrorCallback(int32_t err, const std::string &errString)
605 {
606     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
607     {
608         if (onErrorCallback_) {
609             CallBackfunc = onErrorCallback_;
610         }
611     }
612 
613     if (CallBackfunc) {
614         CallBackfunc(err, errString);
615     }
616 }
OffError()617 void TLSSocketServer::OffError()
618 {
619     std::lock_guard<std::mutex> lock(mutex_);
620     if (onErrorCallback_) {
621         onErrorCallback_ = nullptr;
622     }
623 }
624 
MakeIpSocket(sa_family_t family)625 void TLSSocketServer::MakeIpSocket(sa_family_t family)
626 {
627     if (family != AF_INET && family != AF_INET6) {
628         return;
629     }
630     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
631     if (sock < 0) {
632         int resErr = ConvertErrno();
633         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
634         CallOnErrorCallback(resErr, MakeErrnoString());
635         return;
636     }
637     listenSocketFd_ = sock;
638 }
639 
CallOnErrorCallback(int32_t err,const std::string & errString)640 void TLSSocketServer::CallOnErrorCallback(int32_t err, const std::string &errString)
641 {
642     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
643     {
644         std::lock_guard<std::mutex> lock(mutex_);
645         if (onErrorCallback_) {
646             CallBackfunc = onErrorCallback_;
647         }
648     }
649 
650     if (CallBackfunc) {
651         CallBackfunc(err, errString);
652     }
653 }
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)654 void TLSSocketServer::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6,
655                               sockaddr **addr, socklen_t *len)
656 {
657     if (!addr6 || !addr4 || !len) {
658         return;
659     }
660     sa_family_t family = address.GetSaFamily();
661     if (family == AF_INET) {
662         addr4->sin_family = AF_INET;
663         addr4->sin_port = htons(address.GetPort());
664         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
665         *addr = reinterpret_cast<sockaddr *>(addr4);
666         *len = sizeof(sockaddr_in);
667     } else if (family == AF_INET6) {
668         addr6->sin6_family = AF_INET6;
669         addr6->sin6_port = htons(address.GetPort());
670         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
671         *addr = reinterpret_cast<sockaddr *>(addr6);
672         *len = sizeof(sockaddr_in6);
673     }
674 }
675 
GetListenSocketFd()676 int TLSSocketServer::GetListenSocketFd()
677 {
678     return listenSocketFd_;
679 }
680 
SetLocalAddress(const Socket::NetAddress & address)681 void TLSSocketServer::SetLocalAddress(const Socket::NetAddress &address)
682 {
683     localAddress_ = address;
684 }
685 
GetLocalAddress()686 Socket::NetAddress TLSSocketServer::GetLocalAddress()
687 {
688     return localAddress_;
689 }
690 
GetConnectionByClientID(int clientid)691 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientID(int clientid)
692 {
693     std::shared_ptr<Connection> ptrConnection = nullptr;
694 
695     auto it = clientIdConnections_.find(clientid);
696     if (it != clientIdConnections_.end()) {
697         ptrConnection = it->second;
698     }
699 
700     return ptrConnection;
701 }
702 
GetConnectionClientCount()703 int TLSSocketServer::GetConnectionClientCount()
704 {
705     return g_userCounter;
706 }
707 
CallListenCallback(int32_t err,ListenCallback callback)708 void TLSSocketServer::CallListenCallback(int32_t err, ListenCallback callback)
709 {
710     if (callback) {
711         callback(err);
712     }
713 }
714 
SetAddress(const Socket::NetAddress address)715 void TLSSocketServer::Connection::SetAddress(const Socket::NetAddress address)
716 {
717     address_ = address;
718 }
719 
SetLocalAddress(const Socket::NetAddress address)720 void TLSSocketServer::Connection::SetLocalAddress(const Socket::NetAddress address)
721 {
722     localAddress_ = address;
723 }
724 
GetRemoteCertRawData() const725 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetRemoteCertRawData() const
726 {
727     return remoteRawData_;
728 }
729 
~Connection()730 TLSSocketServer::Connection::~Connection()
731 {
732     Close();
733 }
734 
TlsAcceptToHost(int sock,const TlsSocket::TLSConnectOptions & options)735 bool TLSSocketServer::Connection::TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options)
736 {
737     SetTlsConfiguration(options);
738     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
739     if (!cipherSuite.empty()) {
740         connectionConfiguration_.SetCipherSuite(cipherSuite);
741     }
742     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
743     if (!signatureAlgorithms.empty()) {
744         connectionConfiguration_.SetSignatureAlgorithms(signatureAlgorithms);
745     }
746     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
747     if (!protocolVec.empty()) {
748         connectionConfiguration_.SetProtocol(protocolVec);
749     }
750     connectionConfiguration_.SetVerifyMode(options.GetTlsSecureOptions().GetVerifyMode());
751     socketFd_ = sock;
752     return StartTlsAccept(options);
753 }
754 
SetTlsConfiguration(const TlsSocket::TLSConnectOptions & config)755 void TLSSocketServer::Connection::SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
756 {
757     connectionConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
758                                            config.GetTlsSecureOptions().GetKeyPass());
759     connectionConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
760     connectionConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
761     connectionConfiguration_.SetNetAddress(config.GetNetAddress());
762 }
763 
Send(const std::string & data)764 bool TLSSocketServer::Connection::Send(const std::string &data)
765 {
766     if (!ssl_) {
767         NETSTACK_LOGE("ssl is null");
768         return false;
769     }
770     if (data.empty()) {
771         NETSTACK_LOGI("data is empty");
772         return true;
773     }
774     int len = SSL_write(ssl_, data.c_str(), data.length());
775     if (len < 0) {
776         int resErr = ConvertSSLError(GetSSL());
777         NETSTACK_LOGE("data send failed! error is %{public}d %{public}d", resErr, errno);
778         return false;
779     }
780     NETSTACK_LOGD("data Sent successfully,sent in total %{public}d bytes!", len);
781     return true;
782 }
783 
Recv(char * buffer,int maxBufferSize)784 int TLSSocketServer::Connection::Recv(char *buffer, int maxBufferSize)
785 {
786     if (!ssl_) {
787         NETSTACK_LOGE("ssl is null");
788         return SSL_ERROR_RETURN;
789     }
790     return SSL_read(ssl_, buffer, maxBufferSize);
791 }
792 
Close()793 bool TLSSocketServer::Connection::Close()
794 {
795     if (!ssl_) {
796         NETSTACK_LOGE("ssl is null");
797         return false;
798     }
799     int result = SSL_shutdown(ssl_);
800     if (result < 0) {
801         int resErr = ConvertSSLError(GetSSL());
802         NETSTACK_LOGE("Error in shutdown, error is %{public}d %{public}d", resErr, errno);
803     }
804     SSL_free(ssl_);
805     ssl_ = nullptr;
806     if (socketFd_ != -1) {
807         shutdown(socketFd_, SHUT_RDWR);
808         close(socketFd_);
809         socketFd_ = -1;
810     }
811     if (!tlsContextServerPointer_) {
812         NETSTACK_LOGE("Tls context pointer is null");
813         return false;
814     }
815     tlsContextServerPointer_->CloseCtx();
816     return true;
817 }
818 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)819 bool TLSSocketServer::Connection::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
820 {
821     if (!ssl_) {
822         NETSTACK_LOGE("ssl is null");
823         return false;
824     }
825     size_t pos = 0;
826     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
827                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
828     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
829     for (const auto &str : alpnProtocols) {
830         len = str.length();
831         result[pos++] = len;
832         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
833             NETSTACK_LOGE("strcpy_s failed");
834             return false;
835         }
836         pos += len;
837     }
838     result[pos] = '\0';
839 
840     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
841     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
842         int resErr = ConvertSSLError(GetSSL());
843         NETSTACK_LOGE("Failed to set negotiable protocol list, error is %{public}d %{public}d", resErr, errno);
844         return false;
845     }
846     return true;
847 }
848 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)849 void TLSSocketServer::Connection::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
850 {
851     remoteInfo.SetAddress(address_.GetAddress());
852     remoteInfo.SetPort(address_.GetPort());
853     remoteInfo.SetFamily(address_.GetSaFamily());
854 }
855 
GetTlsConfiguration() const856 TlsSocket::TLSConfiguration TLSSocketServer::Connection::GetTlsConfiguration() const
857 {
858     return connectionConfiguration_;
859 }
860 
GetCipherSuite() const861 std::vector<std::string> TLSSocketServer::Connection::GetCipherSuite() const
862 {
863     if (!ssl_) {
864         NETSTACK_LOGE("ssl in null");
865         return {};
866     }
867     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
868     if (!sk) {
869         NETSTACK_LOGE("get ciphers failed");
870         return {};
871     }
872     TlsSocket::CipherSuite cipherSuite;
873     std::vector<std::string> cipherSuiteVec;
874     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
875         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
876         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
877         cipherSuiteVec.push_back(cipherSuite.cipherName_);
878     }
879     return cipherSuiteVec;
880 }
881 
GetRemoteCertificate() const882 std::string TLSSocketServer::Connection::GetRemoteCertificate() const
883 {
884     return remoteCert_;
885 }
886 
GetCertificate() const887 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetCertificate() const
888 {
889     return connectionConfiguration_.GetCertificate();
890 }
891 
GetSignatureAlgorithms() const892 std::vector<std::string> TLSSocketServer::Connection::GetSignatureAlgorithms() const
893 {
894     return signatureAlgorithms_;
895 }
896 
GetProtocol() const897 std::string TLSSocketServer::Connection::GetProtocol() const
898 {
899     if (!ssl_) {
900         NETSTACK_LOGE("ssl in null");
901         return PROTOCOL_UNKNOW;
902     }
903     if (connectionConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
904         return TlsSocket::PROTOCOL_TLS_V13;
905     }
906     return TlsSocket::PROTOCOL_TLS_V12;
907 }
908 
SetSharedSigals()909 bool TLSSocketServer::Connection::SetSharedSigals()
910 {
911     if (!ssl_) {
912         NETSTACK_LOGE("ssl is null");
913         return false;
914     }
915     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
916     if (!number) {
917         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
918         return false;
919     }
920     for (int i = 0; i < number; i++) {
921         int hash_nid;
922         int sign_nid;
923         std::string sig_with_md;
924         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
925         switch (sign_nid) {
926             case EVP_PKEY_RSA:
927                 sig_with_md = SIGN_NID_RSA;
928                 break;
929             case EVP_PKEY_RSA_PSS:
930                 sig_with_md = SIGN_NID_RSA_PSS;
931                 break;
932             case EVP_PKEY_DSA:
933                 sig_with_md = SIGN_NID_DSA;
934                 break;
935             case EVP_PKEY_EC:
936                 sig_with_md = SIGN_NID_ECDSA;
937                 break;
938             case NID_ED25519:
939                 sig_with_md = SIGN_NID_ED;
940                 break;
941             case NID_ED448:
942                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
943                 break;
944             default:
945                 const char *sn = OBJ_nid2sn(sign_nid);
946                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
947         }
948         const char *sn_hash = OBJ_nid2sn(hash_nid);
949         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
950         signatureAlgorithms_.push_back(sig_with_md);
951     }
952     return true;
953 }
954 
GetSSL() const955 ssl_st *TLSSocketServer::Connection::GetSSL() const
956 {
957     return ssl_;
958 }
959 
GetAddress() const960 Socket::NetAddress TLSSocketServer::Connection::GetAddress() const
961 {
962     return address_;
963 }
964 
GetLocalAddress() const965 Socket::NetAddress TLSSocketServer::Connection::GetLocalAddress() const
966 {
967     return localAddress_;
968 }
969 
GetSocketFd() const970 int TLSSocketServer::Connection::GetSocketFd() const
971 {
972     return socketFd_;
973 }
974 
GetEventManager() const975 std::shared_ptr<EventManager> TLSSocketServer::Connection::GetEventManager() const
976 {
977     return eventManager_;
978 }
979 
SetEventManager(std::shared_ptr<EventManager> eventManager)980 void TLSSocketServer::Connection::SetEventManager(std::shared_ptr<EventManager> eventManager)
981 {
982     eventManager_ = eventManager;
983 }
984 
SetClientID(int32_t clientID)985 void TLSSocketServer::Connection::SetClientID(int32_t clientID)
986 {
987     clientID_ = clientID;
988 }
989 
GetClientID()990 int TLSSocketServer::Connection::GetClientID()
991 {
992     return clientID_;
993 }
994 
StartTlsAccept(const TlsSocket::TLSConnectOptions & options)995 bool TLSSocketServer::Connection::StartTlsAccept(const TlsSocket::TLSConnectOptions &options)
996 {
997     if (!CreatTlsContext()) {
998         NETSTACK_LOGE("failed to create tls context");
999         return false;
1000     }
1001     if (!StartShakingHands(options)) {
1002         NETSTACK_LOGE("failed to shaking hands");
1003         return false;
1004     }
1005     return true;
1006 }
1007 
CreatTlsContext()1008 bool TLSSocketServer::Connection::CreatTlsContext()
1009 {
1010     tlsContextServerPointer_ = TlsSocket::TLSContextServer::CreateConfiguration(connectionConfiguration_);
1011     if (!tlsContextServerPointer_) {
1012         NETSTACK_LOGE("failed to create tls context pointer");
1013         return false;
1014     }
1015     if (!(ssl_ = tlsContextServerPointer_->CreateSsl())) {
1016         NETSTACK_LOGE("failed to create ssl session");
1017         return false;
1018     }
1019     SSL_set_fd(ssl_, socketFd_);
1020     SSL_set_accept_state(ssl_);
1021     return true;
1022 }
1023 
StartShakingHands(const TlsSocket::TLSConnectOptions & options)1024 bool TLSSocketServer::Connection::StartShakingHands(const TlsSocket::TLSConnectOptions &options)
1025 {
1026     if (!ssl_) {
1027         NETSTACK_LOGE("ssl is null");
1028         return false;
1029     }
1030     int result = SSL_accept(ssl_);
1031     if (result == -1) {
1032         int errorStatus = ConvertSSLError(ssl_);
1033         NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1034                       MakeSSLErrorString(errorStatus).c_str());
1035         return false;
1036     }
1037 
1038     std::vector<std::string> SslProtocolVer({SSL_get_version(ssl_)});
1039     connectionConfiguration_.SetProtocol({SslProtocolVer});
1040 
1041     std::string list = SSL_get_cipher_list(ssl_, 0);
1042     NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1043     connectionConfiguration_.SetCipherSuite(list);
1044     if (!SetSharedSigals()) {
1045         NETSTACK_LOGE("Failed to set sharedSigalgs");
1046     }
1047 
1048     if (!GetRemoteCertificateFromPeer()) {
1049         NETSTACK_LOGE("Failed to get remote certificate");
1050     }
1051     if (peerX509_ != nullptr) {
1052         NETSTACK_LOGE("peer x509Certificates is null");
1053 
1054         if (!SetRemoteCertRawData()) {
1055             NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1056         }
1057         TlsSocket::CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1058         if (!checkServerIdentity) {
1059             CheckServerIdentityLegal(hostName_, peerX509_);
1060         } else {
1061             checkServerIdentity(hostName_, {remoteCert_});
1062         }
1063     }
1064     return true;
1065 }
1066 
GetRemoteCertificateFromPeer()1067 bool TLSSocketServer::Connection::GetRemoteCertificateFromPeer()
1068 {
1069     peerX509_ = SSL_get_peer_certificate(ssl_);
1070 
1071     if (SSL_get_verify_result(ssl_) == X509_V_OK) {
1072         NETSTACK_LOGE("SSL_get_verify_result ==X509_V_OK");
1073     }
1074 
1075     if (peerX509_ == nullptr) {
1076         int resErr = ConvertSSLError(GetSSL());
1077         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1078                       MakeSSLErrorString(resErr).c_str());
1079         return false;
1080     }
1081     BIO *bio = BIO_new(BIO_s_mem());
1082     if (!bio) {
1083         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1084         return false;
1085     }
1086     X509_print(bio, peerX509_);
1087     char data[REMOTE_CERT_LEN] = {0};
1088     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1089         NETSTACK_LOGE("BIO_read function returns error");
1090         BIO_free(bio);
1091         return false;
1092     }
1093     BIO_free(bio);
1094     remoteCert_ = std::string(data);
1095     return true;
1096 }
1097 
SetRemoteCertRawData()1098 bool TLSSocketServer::Connection::SetRemoteCertRawData()
1099 {
1100     if (peerX509_ == nullptr) {
1101         NETSTACK_LOGE("peerX509 is null");
1102         return false;
1103     }
1104     int32_t length = i2d_X509(peerX509_, nullptr);
1105     if (length <= 0) {
1106         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1107         return false;
1108     }
1109     unsigned char *der = nullptr;
1110     (void)i2d_X509(peerX509_, &der);
1111     TlsSocket::SecureData data(der, length);
1112     remoteRawData_.data = data;
1113     OPENSSL_free(der);
1114     remoteRawData_.encodingFormat = TlsSocket::EncodingFormat::DER;
1115     return true;
1116 }
1117 
/**
 * @brief Tells whether @p s begins with @p prefix.
 *
 * @param s String to inspect.
 * @param prefix Candidate prefix; an empty prefix matches every string.
 * @return true when the first prefix.size() characters of @p s equal @p prefix.
 */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only report a match located at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> & dnsNames,std::vector<std::string> & ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1122 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> &dnsNames, std::vector<std::string> &ips,
1123                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1124 {
1125     bool valid = false;
1126     std::string reason = UNKNOW_REASON;
1127     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1128     if (IsIP(hostName)) {
1129         auto it = find(ips.begin(), ips.end(), hostName);
1130         if (it == ips.end()) {
1131             reason = IP + hostName + " is not in the cert's list";
1132         }
1133         result = {valid, reason};
1134         return;
1135     }
1136     std::string tempHostName = "" + hostName;
1137     if (!dnsNames.empty() || index > 0) {
1138         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1139         std::string tmpStr = "";
1140         if (!dnsNames.empty()) {
1141             valid = SeekIntersection(hostParts, dnsNames);
1142             tmpStr = ". is not in the cert's altnames";
1143         } else {
1144             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1145             X509_NAME *pSubName = nullptr;
1146             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1147             if (len > 0) {
1148                 std::vector<std::string> commonNameVec;
1149                 commonNameVec.emplace_back(commonNameBuf);
1150                 valid = SeekIntersection(hostParts, commonNameVec);
1151                 tmpStr = ". is not cert's CN";
1152             }
1153         }
1154         if (!valid) {
1155             reason = HOST_NAME + tempHostName + tmpStr;
1156         }
1157 
1158         result = {valid, reason};
1159         return;
1160     }
1161     reason = "Cert does not contain a DNS name";
1162     result = {valid, reason};
1163 }
1164 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1165 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName,
1166                                                                   const X509 *x509Certificates)
1167 {
1168     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1169     if (!subjectName) {
1170         return "subject name is null";
1171     }
1172     char subNameBuf[BUF_SIZE] = {0};
1173     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1174     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1175     if (index < 0) {
1176         return "X509 get ext nid error";
1177     }
1178     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1179     if (ext == nullptr) {
1180         return "X509 get ext error";
1181     }
1182     ASN1_OBJECT *obj = nullptr;
1183     obj = X509_EXTENSION_get_object(ext);
1184     char subAltNameBuf[BUF_SIZE] = {0};
1185     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1186     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1187 
1188     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1189 }
1190 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1191 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1192                                                                   const X509 *x509Certificates)
1193 {
1194     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1195     if (!extData) {
1196         NETSTACK_LOGE("extData is nullptr");
1197         return "";
1198     }
1199     std::string altNames = reinterpret_cast<char *>(extData->data);
1200     std::string hostname = "" + hostName;
1201     BIO *bio = BIO_new(BIO_s_file());
1202     if (!bio) {
1203         return "bio is null";
1204     }
1205     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1206     ASN1_STRING_print(bio, extData);
1207     std::vector<std::string> dnsNames = {};
1208     std::vector<std::string> ips = {};
1209     constexpr int DNS_NAME_IDX = 4;
1210     constexpr int IP_NAME_IDX = 11;
1211     if (!altNames.empty()) {
1212         std::vector<std::string> splitAltNames;
1213         if (altNames.find('\"') != std::string::npos) {
1214             splitAltNames = SplitEscapedAltNames(altNames);
1215         } else {
1216             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1217         }
1218         for (auto const &iter : splitAltNames) {
1219             if (StartsWith(iter, DNS)) {
1220                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1221             } else if (StartsWith(iter, IP_ADDRESS)) {
1222                 ips.push_back(iter.substr(IP_NAME_IDX));
1223             }
1224         }
1225     }
1226     std::tuple<bool, std::string> result;
1227     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1228     if (!std::get<0>(result)) {
1229         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1230     }
1231     return HOST_NAME + hostname + ". is cert's CN";
1232 }
1233 
RemoveConnect(int socketFd)1234 void TLSSocketServer::RemoveConnect(int socketFd)
1235 {
1236     std::lock_guard<std::mutex> its_lock(connectMutex_);
1237 
1238     for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1239         if (it->second->GetSocketFd() == socketFd) {
1240             break;
1241         } else {
1242             ++it;
1243         }
1244     }
1245 }
1246 
RecvRemoteInfo(int socketFd,int index)1247 int TLSSocketServer::RecvRemoteInfo(int socketFd, int index)
1248 {
1249     {
1250         std::lock_guard<std::mutex> its_lock(connectMutex_);
1251         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1252             if (it->second->GetSocketFd() == socketFd) {
1253                 char buffer[MAX_BUFFER_SIZE];
1254                 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
1255                     NETSTACK_LOGE("memcpy_s failed");
1256                     break;
1257                 }
1258                 int len = it->second->Recv(buffer, MAX_BUFFER_SIZE);
1259                 NETSTACK_LOGE("revc message is size is %{public}d", len);
1260                 if (len > 0) {
1261                     Socket::SocketRemoteInfo remoteInfo;
1262                     remoteInfo.SetSize(strlen(buffer));
1263                     it->second->MakeRemoteInfo(remoteInfo);
1264                     it->second->CallOnMessageCallback(socketFd, buffer, remoteInfo);
1265                     return len;
1266                 }
1267 #if defined(CROSS_PLATFORM)
1268                 if (len == 0 &&  errno == 0) {
1269                     NETSTACK_LOGI("A client left");
1270                 }
1271 #endif
1272                 break;
1273             } else {
1274                 ++it;
1275             }
1276         }
1277     }
1278     RemoveConnect(socketFd);
1279     DropFdFromPollList(index);
1280     return -1;
1281 }
1282 
CachedMessageCallback()1283 void TLSSocketServer::Connection::CachedMessageCallback()
1284 {
1285     int32_t socketFd = GetSocketFd();
1286     if (socketFd < 0) {
1287         NETSTACK_LOGE("socketFd is invalid to recv message");
1288         return;
1289     }
1290     if (onMessageCallback_) {
1291         while (!dataCache_->IsEmpty()) {
1292             CacheInfo cache = dataCache_->Get();
1293             onMessageCallback_(socketFd, cache.data, cache.remoteInfo);
1294         }
1295     }
1296     NETSTACK_LOGD("Cached message is callbacked for socket %{public}d", socketFd);
1297 }
1298 
CallOnMessageCallback(int32_t socketFd,const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)1299 void TLSSocketServer::Connection::CallOnMessageCallback(int32_t socketFd, const std::string &data,
1300                                                         const Socket::SocketRemoteInfo &remoteInfo)
1301 {
1302     OnMessageCallback CallBackfunc = nullptr;
1303     {
1304         if (onMessageCallback_) {
1305             CallBackfunc = onMessageCallback_;
1306         }
1307     }
1308 
1309     if (CallBackfunc) {
1310         while (!dataCache_->IsEmpty()) {
1311             CacheInfo cache = dataCache_->Get();
1312             CallBackfunc(socketFd, cache.data, cache.remoteInfo);
1313         }
1314         CallBackfunc(socketFd, data, remoteInfo);
1315     } else {
1316         NETSTACK_LOGD("message callback is not registered");
1317         CacheInfo cache = {data, remoteInfo};
1318         dataCache_->Set(cache);
1319     }
1320 }
1321 
AddConnect(int socketFd,std::shared_ptr<Connection> connection)1322 void TLSSocketServer::AddConnect(int socketFd, std::shared_ptr<Connection> connection)
1323 {
1324     std::lock_guard<std::mutex> its_lock(connectMutex_);
1325     clientIdConnections_[connection->GetClientID()] = connection;
1326 }
1327 
CallOnCloseCallback(const int32_t socketFd)1328 void TLSSocketServer::Connection::CallOnCloseCallback(const int32_t socketFd)
1329 {
1330     OnCloseCallback CallBackfunc = nullptr;
1331     {
1332         if (onCloseCallback_) {
1333             CallBackfunc = onCloseCallback_;
1334         }
1335     }
1336 
1337     if (CallBackfunc) {
1338         CallBackfunc(socketFd);
1339     }
1340 }
1341 
CallOnConnectCallback(const int32_t socketFd,std::shared_ptr<EventManager> eventManager)1342 void TLSSocketServer::CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager)
1343 {
1344     OnConnectCallback CallBackfunc = nullptr;
1345     {
1346         std::lock_guard<std::mutex> lock(mutex_);
1347         if (onConnectCallback_) {
1348             CallBackfunc = onConnectCallback_;
1349         }
1350     }
1351 
1352     if (CallBackfunc) {
1353         CallBackfunc(socketFd, eventManager);
1354     } else {
1355         NETSTACK_LOGE("CallOnConnectCallback  fun === null");
1356     }
1357 }
1358 
GetTlsConnectionLocalAddress(int acceptSockFD,Socket::NetAddress & localAddress)1359 bool TLSSocketServer::GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress)
1360 {
1361     struct sockaddr_storage addr{};
1362     socklen_t addrLen = sizeof(addr);
1363     if (getsockname(acceptSockFD, (struct sockaddr *)&addr, &addrLen) < 0) {
1364         if (acceptSockFD > 0) {
1365             close(acceptSockFD);
1366             CallOnErrorCallback(errno, strerror(errno));
1367             return false;
1368         }
1369     }
1370     char ipStr[INET6_ADDRSTRLEN] = {0};
1371     if (addr.ss_family == AF_INET) {
1372         auto *addr_in = (struct sockaddr_in *)&addr;
1373         inet_ntop(AF_INET, &addr_in->sin_addr, ipStr, sizeof(ipStr));
1374         localAddress.SetFamilyBySaFamily(AF_INET);
1375         localAddress.SetRawAddress(ipStr);
1376         localAddress.SetPort(ntohs(addr_in->sin_port));
1377     } else if (addr.ss_family == AF_INET6) {
1378         auto *addr_in6 = (struct sockaddr_in6 *)&addr;
1379         inet_ntop(AF_INET6, &addr_in6->sin6_addr, ipStr, sizeof(ipStr));
1380         localAddress.SetFamilyBySaFamily(AF_INET6);
1381         localAddress.SetRawAddress(ipStr);
1382         localAddress.SetPort(ntohs(addr_in6->sin6_port));
1383     }
1384     return true;
1385 }
1386 
ProcessTcpAccept(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID)1387 void TLSSocketServer::ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID)
1388 {
1389     struct sockaddr_in clientAddress;
1390     socklen_t clientAddrLength = sizeof(clientAddress);
1391     int connectFD = accept(listenSocketFd_, (struct sockaddr *)&clientAddress, &clientAddrLength);
1392     if (connectFD < 0) {
1393         int resErr = ConvertErrno();
1394         NETSTACK_LOGE("Server accept new client ERROR");
1395         CallOnErrorCallback(resErr, MakeErrnoString());
1396         return;
1397     }
1398     NETSTACK_LOGI("Server accept new client SUCCESS");
1399     std::shared_ptr<Connection> connection = std::make_shared<Connection>();
1400     Socket::NetAddress netAddress;
1401     Socket::NetAddress localAddress;
1402     char clientIp[INET6_ADDRSTRLEN] = {0};
1403     inet_ntop(address_.GetSaFamily(), &clientAddress.sin_addr, clientIp, INET_ADDRSTRLEN);
1404     int clientPort = ntohs(clientAddress.sin_port);
1405     netAddress.SetRawAddress(clientIp);
1406     netAddress.SetPort(clientPort);
1407     netAddress.SetFamilyBySaFamily(address_.GetSaFamily());
1408     connection->SetAddress(netAddress);
1409     if (GetTlsConnectionLocalAddress(connectFD, localAddress)) {
1410         connection->SetLocalAddress(localAddress);
1411     }
1412     SetTlsConnectionSecureOptions(tlsListenOptions, clientID, connectFD, connection);
1413 }
SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID,int connectFD,std::shared_ptr<Connection> & connection)1414 void TLSSocketServer::SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID,
1415                                                     int connectFD, std::shared_ptr<Connection> &connection)
1416 {
1417     connection->SetClientID(clientID);
1418     auto res = connection->TlsAcceptToHost(connectFD, tlsListenOptions);
1419     if (!res) {
1420         int resErr = ConvertSSLError(connection->GetSSL());
1421         NETSTACK_LOGE("setTlsConnectionSecureOptions error is %{public}d %{public}d", resErr, errno);
1422         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1423         return;
1424     }
1425     if (g_userCounter >= USER_LIMIT) {
1426         const std::string info = "Too many users!";
1427         connection->Send(info);
1428         connection->Close();
1429         NETSTACK_LOGE("Too many users");
1430         if (connection->GetSocketFd() != -1) {
1431             close(connectFD);
1432         }
1433         CallOnErrorCallback(-1, "Too many users");
1434         return;
1435     }
1436     g_userCounter++;
1437     fds_[g_userCounter].fd = connectFD;
1438 #if defined(CROSS_PLATFORM)
1439     fds_[g_userCounter].events = POLLIN | POLLERR;
1440 #else
1441     fds_[g_userCounter].events = POLLIN | POLLRDHUP | POLLERR;
1442 #endif
1443     fds_[g_userCounter].revents = 0;
1444     AddConnect(connectFD, connection);
1445     auto ptrEventManager = std::make_shared<EventManager>();
1446     ptrEventManager->SetData(this);
1447     connection->SetEventManager(ptrEventManager);
1448     CallOnConnectCallback(clientID, ptrEventManager);
1449     NETSTACK_LOGI("New client come in, fd is %{public}d", connectFD);
1450 }
1451 
InitPollList(int & listendFd)1452 void TLSSocketServer::InitPollList(int &listendFd)
1453 {
1454     for (int i = 1; i <= USER_LIMIT; ++i) {
1455         fds_[i].fd = -1;
1456         fds_[i].events = 0;
1457     }
1458     fds_[0].fd = listendFd;
1459     fds_[0].events = POLLIN | POLLERR;
1460     fds_[0].revents = 0;
1461 }
1462 
DropFdFromPollList(int & fd_index)1463 void TLSSocketServer::DropFdFromPollList(int &fd_index)
1464 {
1465     if (g_userCounter < 0) {
1466         NETSTACK_LOGE("g_userCounter  = %{public}d", g_userCounter);
1467         return;
1468     }
1469     fds_[fd_index].fd = fds_[g_userCounter].fd;
1470 
1471     fds_[g_userCounter].fd = -1;
1472     fds_[g_userCounter].events = 0;
1473     fd_index--;
1474     g_userCounter--;
1475     NETSTACK_LOGE("CallOnConnectCallback  g_userCounter  = %{public}d", g_userCounter);
1476 }
1477 
PollThread(const TlsSocket::TLSConnectOptions & tlsListenOptions)1478 void TLSSocketServer::PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions)
1479 {
1480     int on = 1;
1481     isRunning_ = true;
1482     ioctl(listenSocketFd_, FIONBIO, (char *)&on);
1483     NETSTACK_LOGE("PollThread  start working %{public}d", isRunning_);
1484     std::thread thread_([this, tlsOption = tlsListenOptions]() {
1485 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
1486         pthread_setname_np(TLS_SOCKET_SERVER_READ);
1487 #else
1488         pthread_setname_np(pthread_self(), TLS_SOCKET_SERVER_READ);
1489 #endif
1490         InitPollList(listenSocketFd_);
1491         int clientId = 0;
1492         while (isRunning_) {
1493             int ret = poll(fds_, g_userCounter + 1, POLL_WAIT_TIME);
1494             if (ret < 0) {
1495                 int resErr = ConvertErrno();
1496                 NETSTACK_LOGE("Poll ERROR");
1497                 CallOnErrorCallback(resErr, MakeErrnoString());
1498                 break;
1499             }
1500             if (ret == 0) {
1501                 continue;
1502             }
1503             for (int i = 0; i < g_userCounter + 1; ++i) {
1504                 if ((fds_[i].fd == listenSocketFd_) && (static_cast<uint16_t>(fds_[i].revents) & POLLIN)) {
1505                     ProcessTcpAccept(tlsOption, ++clientId);
1506 #if !defined(CROSS_PLATFORM)
1507                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLRDHUP) ||
1508                            (static_cast<uint16_t>(fds_[i].revents) & POLLERR)) {
1509 #else
1510                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLERR)) {
1511 #endif
1512                     RemoveConnect(fds_[i].fd);
1513                     DropFdFromPollList(i);
1514                     NETSTACK_LOGI("A client left");
1515                 } else if (static_cast<uint16_t>(fds_[i].revents) & POLLIN) {
1516                     RecvRemoteInfo(fds_[i].fd, i);
1517                 }
1518             }
1519         }
1520     });
1521     thread_.detach();
1522 }
1523 
GetConnectionByClientEventManager(const std::shared_ptr<EventManager> & eventManager)1524 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientEventManager(
1525     const std::shared_ptr<EventManager> &eventManager)
1526 {
1527     std::lock_guard<std::mutex> its_lock(connectMutex_);
1528     auto it = std::find_if(clientIdConnections_.begin(), clientIdConnections_.end(), [eventManager](const auto& pair) {
1529         return pair.second->GetEventManager() == eventManager;
1530     });
1531     if (it == clientIdConnections_.end()) {
1532         return nullptr;
1533     }
1534     return it->second;
1535 }
1536 
CloseConnectionByEventManager(const std::shared_ptr<EventManager> & eventManager)1537 void TLSSocketServer::CloseConnectionByEventManager(const std::shared_ptr<EventManager> &eventManager)
1538 {
1539     std::shared_ptr<Connection> ptrConnection = GetConnectionByClientEventManager(eventManager);
1540 
1541     if (ptrConnection != nullptr) {
1542         ptrConnection->Close();
1543     }
1544 }
1545 
DeleteConnectionByEventManager(const std::shared_ptr<EventManager> & eventManager)1546 void TLSSocketServer::DeleteConnectionByEventManager(const std::shared_ptr<EventManager> &eventManager)
1547 {
1548     std::lock_guard<std::mutex> its_lock(connectMutex_);
1549     for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) {
1550         if (it->second->GetEventManager() == eventManager) {
1551             it = clientIdConnections_.erase(it);
1552             break;
1553         }
1554     }
1555 }
1556 } // namespace TlsSocketServer
1557 } // namespace NetStack
1558 } // namespace OHOS
1559