• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "tls_socket_server.h"

#include <netinet/tcp.h>
#include <sys/ioctl.h>

#include <algorithm>
#include <chrono>
#include <memory>
#include <numeric>
#include <regex>

#include <openssl/err.h>
#include <openssl/ssl.h>
#include <securec.h>

#include "base_context.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "tls.h"
33 
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocketServer {
37 #if UNITTEST
38 #else
39 namespace {
40 #endif // UNITTEST
// ---- Sizes and limits -------------------------------------------------------
// Scratch-buffer size handed to ERR_error_string_n() in MakeSSLErrorString().
constexpr size_t MAX_ERR_LENGTH{1024};

// Fixed return code passed to SSL_get_error() by ConvertSSLError().
constexpr int SSL_RET_CODE{0};

constexpr int BUF_SIZE{2048};
constexpr int POLL_WAIT_TIME{2000};
// Number of characters skipped after a separator in SplitEscapedAltNames().
constexpr int OFFSET{2};
// Value Connection::Recv() returns when no SSL object is attached.
constexpr int SSL_ERROR_RETURN{-1};
constexpr int REMOTE_CERT_LEN{8192};
constexpr int COMMON_NAME_BUF_SIZE{256};
// Backlog passed to listen() in ExecAccept() (identifier spelling kept for compatibility).
constexpr int LISETEN_COUNT{516};

// ---- String fragments (usage mostly outside this part of the file) ----------
constexpr const char *SPLIT_HOST_NAME{"."};
constexpr const char *SPLIT_ALT_NAMES{","};
constexpr const char *DNS{"DNS:"};
constexpr const char *HOST_NAME{"hostname: "};
constexpr const char *IP_ADDRESS{"IP Address:"};
constexpr const char *SIGN_NID_RSA{"RSA+"};
constexpr const char *SIGN_NID_RSA_PSS{"RSA-PSS+"};
constexpr const char *SIGN_NID_DSA{"DSA+"};
constexpr const char *SIGN_NID_ECDSA{"ECDSA+"};
constexpr const char *SIGN_NID_ED{"Ed25519+"};
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT{"Ed448+"};
constexpr const char *SIGN_NID_UNDEF_ADD{"UNDEF+"};
constexpr const char *PROTOCOL_UNKNOW{"UNKNOW_PROTOCOL"};
constexpr const char *SIGN_NID_UNDEF{"UNDEF"};
constexpr const char *OPERATOR_PLUS_SIGN{"+"};
constexpr const char *UNKNOW_REASON{"Unknown reason"};
constexpr const char *IP{"IP: "};
static constexpr const char *TLS_SOCKET_SERVER_READ{"OS_NET_TSAccRD"};

// ---- Regular expressions ----------------------------------------------------
// NOTE(review): the surrounding '/' characters are JavaScript regex-literal delimiters. std::regex
// treats them as literal characters, so '^' must match after a consumed '/', and this pattern
// cannot match any input.
const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"};
// Dotted-quad IPv4 address with every octet in 0-255; used by IsIP().
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};

// Client count reported by GetConnectionClientCount().
// NOTE(review): plain int, not atomic -- confirm it is only modified from one thread.
int g_userCounter{0};
75 
IsIP(const std::string & ip)76 bool IsIP(const std::string &ip)
77 {
78     std::regex pattern(PATTERN);
79     std::smatch res;
80     return regex_match(ip, res, pattern);
81 }
82 
SplitHostName(std::string & hostName)83 std::vector<std::string> SplitHostName(std::string &hostName)
84 {
85     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
86     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
87 }
88 
/**
 * Reports whether the two vectors share at least one equal element.
 *
 * The previous implementation used std::set_intersection. That algorithm requires both ranges
 * to be sorted, and nothing here sorts them, so unsorted input could miss common elements.
 * std::find_first_of has no ordering precondition.
 * @param vecA first list (not modified).
 * @param vecB second list (not modified).
 * @return true if some element of vecA equals some element of vecB; false if either is empty.
 */
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    return std::find_first_of(vecA.begin(), vecA.end(), vecB.begin(), vecB.end()) != vecA.end();
}
95 
ConvertErrno()96 int ConvertErrno()
97 {
98     return TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE + errno;
99 }
100 
ConvertSSLError(ssl_st * ssl)101 int ConvertSSLError(ssl_st *ssl)
102 {
103     if (!ssl) {
104         return TlsSocket::TLS_ERR_SSL_NULL;
105     }
106     return TlsSocket::TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
107 }
108 
/** Returns strerror()'s description of the calling thread's current errno value. */
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
113 
MakeSSLErrorString(int error)114 std::string MakeSSLErrorString(int error)
115 {
116     char err[MAX_ERR_LENGTH] = {0};
117     ERR_error_string_n(error - TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
118     return err;
119 }
SplitEscapedAltNames(std::string & altNames)120 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
121 {
122     std::vector<std::string> result;
123     std::string currentToken;
124     size_t offset = 0;
125     while (offset != altNames.length()) {
126         auto nextSep = altNames.find_first_of(", ");
127         auto nextQuote = altNames.find_first_of('\"');
128         if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
129             currentToken += altNames.substr(offset, nextQuote);
130             std::regex jsonStringPattern(JSON_STRING_PATTERN);
131             std::smatch match;
132             std::string altNameSubStr = altNames.substr(nextQuote);
133             bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
134             if (!ret) {
135                 return {""};
136             }
137             currentToken += result[0];
138             offset = nextQuote + result[0].length();
139         } else if (nextSep != std::string::npos) {
140             currentToken += altNames.substr(offset, nextSep);
141             result.push_back(currentToken);
142             currentToken = "";
143             offset = nextSep + OFFSET;
144         } else {
145             currentToken += altNames.substr(offset);
146             offset = altNames.length();
147         }
148     }
149     result.push_back(currentToken);
150     return result;
151 }
152 #if UNITTEST
153 #else
154 } // namespace
155 #endif
156 
SetSocket(const int & socketFd)157 void TLSServerSendOptions::SetSocket(const int &socketFd)
158 {
159     socketFd_ = socketFd;
160 }
161 
SetSendData(const std::string & data)162 void TLSServerSendOptions::SetSendData(const std::string &data)
163 {
164     data_ = data;
165 }
166 
GetSocket() const167 const int &TLSServerSendOptions::GetSocket() const
168 {
169     return socketFd_;
170 }
171 
GetSendData() const172 const std::string &TLSServerSendOptions::GetSendData() const
173 {
174     return data_;
175 }
176 
~TLSSocketServer()177 TLSSocketServer::~TLSSocketServer()
178 {
179     isRunning_ = false;
180     clientIdConnections_.clear();
181 
182     if (listenSocketFd_ != -1) {
183         shutdown(listenSocketFd_, SHUT_RDWR);
184         close(listenSocketFd_);
185         listenSocketFd_ = -1;
186     }
187 }
188 
Listen(const TlsSocket::TLSConnectOptions & tlsListenOptions,const ListenCallback & callback)189 void TLSSocketServer::Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback)
190 {
191     if (!CommonUtils::HasInternetPermission()) {
192         CallListenCallback(PERMISSION_DENIED_CODE, callback);
193         return;
194     }
195     NETSTACK_LOGE("Listen 1 %{public}d", listenSocketFd_);
196     if (listenSocketFd_ >= 0) {
197         CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
198         return;
199     }
200     NETSTACK_LOGE("Listen 2 %{public}d, %{public}d", listenSocketFd_, g_userCounter);
201     if (ExecBind(tlsListenOptions.GetNetAddress(), callback)) {
202         NETSTACK_LOGE("Listen 3 %{public}d", listenSocketFd_);
203         ExecAccept(tlsListenOptions, callback);
204     } else {
205         shutdown(listenSocketFd_, SHUT_RDWR);
206         close(listenSocketFd_);
207         listenSocketFd_ = -1;
208     }
209     if (isRunning_) {
210         isRunning_ = false;
211         WaitForRcvThdExit();
212     }
213     PollThread(tlsListenOptions);
214 }
215 
ExecBind(const Socket::NetAddress & address,const ListenCallback & callback)216 bool TLSSocketServer::ExecBind(const Socket::NetAddress &address, const ListenCallback &callback)
217 {
218     MakeIpSocket(address.GetSaFamily());
219     if (listenSocketFd_ < 0) {
220         int resErr = ConvertErrno();
221         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
222         CallOnErrorCallback(resErr, MakeErrnoString());
223         CallListenCallback(resErr, callback);
224         return false;
225     }
226     sockaddr_in addr4 = {0};
227     sockaddr_in6 addr6 = {0};
228     sockaddr *addr = nullptr;
229     socklen_t len;
230     GetAddr(address, &addr4, &addr6, &addr, &len);
231     if (addr == nullptr) {
232         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
233         CallOnErrorCallback(-1, "Address Is Invalid");
234         CallListenCallback(ConvertErrno(), callback);
235         return false;
236     }
237     int reuse = 1; // 1 means enable reuseaddr feature
238     if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
239         NETSTACK_LOGE("failed to set tls server listen socket reuseaddr on, sockfd: %{public}d", listenSocketFd_);
240     }
241     if (bind(listenSocketFd_, addr, len) < 0) {
242         if (errno != EADDRINUSE) {
243             NETSTACK_LOGE("bind error is %{public}s %{public}d", strerror(errno), errno);
244             CallOnErrorCallback(-1, "Address binding failed");
245             CallListenCallback(ConvertErrno(), callback);
246             return false;
247         }
248         if (addr->sa_family == AF_INET) {
249             NETSTACK_LOGI("distribute a random port");
250             addr4.sin_port = 0; /* distribute a random port */
251         } else if (addr->sa_family == AF_INET6) {
252             NETSTACK_LOGI("distribute a random port");
253             addr6.sin6_port = 0; /* distribute a random port */
254         }
255         if (bind(listenSocketFd_, addr, len) < 0) {
256             NETSTACK_LOGE("rebind error is %{public}s %{public}d", strerror(errno), errno);
257             CallOnErrorCallback(-1, "Duplicate binding address failed");
258             CallListenCallback(ConvertErrno(), callback);
259             return false;
260         }
261         NETSTACK_LOGI("rebind success");
262     }
263     NETSTACK_LOGI("bind success");
264     address_ = address;
265     return true;
266 }
267 
ExecAccept(const TlsSocket::TLSConnectOptions & tlsAcceptOptions,const ListenCallback & callback)268 void TLSSocketServer::ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback)
269 {
270     if (listenSocketFd_ < 0) {
271         int resErr = ConvertErrno();
272         NETSTACK_LOGE("accept error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
273         CallOnErrorCallback(resErr, MakeErrnoString());
274         callback(resErr);
275         return;
276     }
277     SetLocalTlsConfiguration(tlsAcceptOptions);
278     int ret = 0;
279     ret = listen(listenSocketFd_, LISETEN_COUNT);
280     if (ret < 0) {
281         int resErr = ConvertErrno();
282         NETSTACK_LOGE("tcp server listen error");
283         CallOnErrorCallback(resErr, MakeErrnoString());
284         callback(resErr);
285         return;
286     }
287     CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
288 }
289 
Send(const TLSServerSendOptions & data,const TlsSocket::SendCallback & callback)290 bool TLSSocketServer::Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback)
291 {
292     int socketFd = data.GetSocket();
293     std::string info = data.GetSendData();
294 
295     auto connect_iterator = clientIdConnections_.find(socketFd);
296     if (connect_iterator == clientIdConnections_.end()) {
297         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
298         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
299         return false;
300     }
301     auto connect = connect_iterator->second;
302     auto res = connect->Send(info);
303     if (!res) {
304         int resErr = ConvertSSLError(connect->GetSSL());
305         NETSTACK_LOGE("send error is %{public}d %{public}d", resErr, errno);
306         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
307         CallSendCallback(resErr, callback);
308         return false;
309     }
310     CallSendCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
311     return res;
312 }
313 
CallSendCallback(int32_t err,TlsSocket::SendCallback callback)314 void TLSSocketServer::CallSendCallback(int32_t err, TlsSocket::SendCallback callback)
315 {
316     if (callback) {
317         callback(err);
318     }
319 }
320 
Close(const int socketFd,const TlsSocket::CloseCallback & callback)321 void TLSSocketServer::Close(const int socketFd, const TlsSocket::CloseCallback &callback)
322 {
323     {
324         std::shared_lock<std::shared_mutex> its_lock(connectMutex_);
325         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
326             if (it->first == socketFd) {
327                 auto res = it->second->Close();
328                 if (!res) {
329                     int resErr = ConvertSSLError(it->second->GetSSL());
330                     NETSTACK_LOGE("close error is %{public}d %{public}d", resErr, errno);
331                     CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
332                     callback(resErr);
333                     return;
334                 }
335                 callback(TlsSocket::TLSSOCKET_SUCCESS);
336                 return;
337             } else {
338                 ++it;
339             }
340         }
341     }
342     NETSTACK_LOGE("socket = %{public}d There is no corresponding socketFd", socketFd);
343     CallOnErrorCallback(-1, "The send failed with no corresponding socketFd");
344     callback(TlsSocket::TLS_ERR_SYS_EINVAL);
345 }
346 
Stop(const TlsSocket::CloseCallback & callback)347 void TLSSocketServer::Stop(const TlsSocket::CloseCallback &callback)
348 {
349     if (!CommonUtils::HasInternetPermission()) {
350         callback(PERMISSION_DENIED_CODE);
351     }
352     close(listenSocketFd_);
353     listenSocketFd_ = -1;
354     NETSTACK_LOGE("g_userCounter = %{public}d", g_userCounter);
355     callback(TlsSocket::TLSSOCKET_SUCCESS);
356 }
357 
GetRemoteAddress(const int socketFd,const TlsSocket::GetRemoteAddressCallback & callback)358 void TLSSocketServer::GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback)
359 {
360     auto connect_iterator = clientIdConnections_.find(socketFd);
361     if (connect_iterator == clientIdConnections_.end()) {
362         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
363         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
364         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
365         return;
366     }
367     auto connect = connect_iterator->second;
368     auto address = connect->GetAddress();
369     callback(TlsSocket::TLSSOCKET_SUCCESS, address);
370 }
371 
GetLocalAddress(const int socketFd,const TlsSocket::GetLocalAddressCallback & callback)372 void TLSSocketServer::GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback)
373 {
374     auto connect_iterator = clientIdConnections_.find(socketFd);
375     if (connect_iterator == clientIdConnections_.end()) {
376         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
377         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
378         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
379         return;
380     }
381     auto connect = connect_iterator->second;
382     auto localAddress = connect->GetLocalAddress();
383     callback(TlsSocket::TLSSOCKET_SUCCESS, localAddress);
384 }
385 
GetState(const TlsSocket::GetStateCallback & callback)386 void TLSSocketServer::GetState(const TlsSocket::GetStateCallback &callback)
387 {
388     int opt;
389     socklen_t optLen = sizeof(int);
390     int r = getsockopt(listenSocketFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
391     if (r < 0) {
392         Socket::SocketStateBase state;
393         state.SetIsClose(true);
394         CallGetStateCallback(ConvertErrno(), state, callback);
395         return;
396     }
397     sockaddr sockAddr = {0};
398     socklen_t len = sizeof(sockaddr);
399     Socket::SocketStateBase state;
400     int ret = getsockname(listenSocketFd_, &sockAddr, &len);
401     state.SetIsBound(ret == 0);
402     ret = getpeername(listenSocketFd_, &sockAddr, &len);
403     if (ret != 0) {
404         NETSTACK_LOGE("getpeername failed");
405     }
406     state.SetIsConnected(GetConnectionClientCount() > 0);
407     CallGetStateCallback(TlsSocket::TLSSOCKET_SUCCESS, state, callback);
408 }
409 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,TlsSocket::GetStateCallback callback)410 void TLSSocketServer::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state,
411                                            TlsSocket::GetStateCallback callback)
412 {
413     if (callback) {
414         callback(err, state);
415     }
416 }
SetExtraOptions(const Socket::TCPExtraOptions & tcpExtraOptions,const TlsSocket::SetExtraOptionsCallback & callback)417 bool TLSSocketServer::SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
418                                       const TlsSocket::SetExtraOptionsCallback &callback)
419 {
420     if (tcpExtraOptions.IsKeepAlive()) {
421         int keepalive = 1;
422         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
423             return false;
424         }
425     }
426 
427     if (tcpExtraOptions.IsOOBInline()) {
428         int oobInline = 1;
429         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
430             return false;
431         }
432     }
433 
434     if (tcpExtraOptions.IsTCPNoDelay()) {
435         int tcpNoDelay = 1;
436         if (setsockopt(listenSocketFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
437             return false;
438         }
439     }
440 
441     linger soLinger = {0};
442     soLinger.l_onoff = tcpExtraOptions.socketLinger.IsOn();
443     soLinger.l_linger = (int)tcpExtraOptions.socketLinger.GetLinger();
444     if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
445         return false;
446     }
447 
448     return true;
449 }
450 
SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions & config)451 void TLSSocketServer::SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
452 {
453     TLSServerConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
454                                           config.GetTlsSecureOptions().GetKeyPass());
455     TLSServerConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
456     TLSServerConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
457 
458     TLSServerConfiguration_.SetVerifyMode(config.GetTlsSecureOptions().GetVerifyMode());
459 
460     const auto protocolVec = config.GetTlsSecureOptions().GetProtocolChain();
461     if (!protocolVec.empty()) {
462         TLSServerConfiguration_.SetProtocol(protocolVec);
463     }
464 }
465 
GetCertificate(const TlsSocket::GetCertificateCallback & callback)466 void TLSSocketServer::GetCertificate(const TlsSocket::GetCertificateCallback &callback)
467 {
468     const auto &cert = TLSServerConfiguration_.GetCertificate();
469     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
470     if (!cert.data.Length()) {
471         CallOnErrorCallback(-1, "cert not data Length");
472         callback(-1, {});
473         return;
474     }
475     callback(TlsSocket::TLSSOCKET_SUCCESS, cert);
476 }
477 
GetRemoteCertificate(const int socketFd,const TlsSocket::GetRemoteCertificateCallback & callback)478 void TLSSocketServer::GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback)
479 {
480     auto connect_iterator = clientIdConnections_.find(socketFd);
481     if (connect_iterator == clientIdConnections_.end()) {
482         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
483         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
484         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
485         return;
486     }
487     auto connect = connect_iterator->second;
488     const auto &remoteCert = connect->GetRemoteCertRawData();
489     if (!remoteCert.data.Length()) {
490         int resErr = ConvertSSLError(connect->GetSSL());
491         NETSTACK_LOGE("GetRemoteCertificate error is %{public}d %{public}d", resErr, errno);
492         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
493         callback(resErr, {});
494         return;
495     }
496     callback(TlsSocket::TLSSOCKET_SUCCESS, remoteCert);
497 }
498 
GetProtocol(const TlsSocket::GetProtocolCallback & callback)499 void TLSSocketServer::GetProtocol(const TlsSocket::GetProtocolCallback &callback)
500 {
501     if (TLSServerConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
502         callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V13);
503         return;
504     }
505     callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V12);
506 }
507 
GetCipherSuite(const int socketFd,const TlsSocket::GetCipherSuiteCallback & callback)508 void TLSSocketServer::GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback)
509 {
510     auto connect_iterator = clientIdConnections_.find(socketFd);
511     if (connect_iterator == clientIdConnections_.end()) {
512         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
513         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
514         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
515         return;
516     }
517     auto connect = connect_iterator->second;
518     auto cipherSuite = connect->GetCipherSuite();
519     if (cipherSuite.empty()) {
520         int resErr = ConvertSSLError(connect->GetSSL());
521         NETSTACK_LOGE("GetCipherSuite error is %{public}d %{public}d", resErr, errno);
522         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
523         callback(resErr, cipherSuite);
524         return;
525     }
526     callback(TlsSocket::TLSSOCKET_SUCCESS, cipherSuite);
527 }
528 
GetSignatureAlgorithms(const int socketFd,const TlsSocket::GetSignatureAlgorithmsCallback & callback)529 void TLSSocketServer::GetSignatureAlgorithms(const int socketFd,
530                                              const TlsSocket::GetSignatureAlgorithmsCallback &callback)
531 {
532     auto connect_iterator = clientIdConnections_.find(socketFd);
533     if (connect_iterator == clientIdConnections_.end()) {
534         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
535         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
536         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
537         return;
538     }
539     auto connect = connect_iterator->second;
540     auto signatureAlgorithms = connect->GetSignatureAlgorithms();
541     if (signatureAlgorithms.empty()) {
542         int resErr = ConvertSSLError(connect->GetSSL());
543         NETSTACK_LOGE("GetSignatureAlgorithms error is %{public}d %{public}d", resErr, errno);
544         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
545         callback(resErr, signatureAlgorithms);
546         return;
547     }
548     callback(TlsSocket::TLSSOCKET_SUCCESS, signatureAlgorithms);
549 }
550 
OnMessage(const OnMessageCallback & onMessageCallback)551 void TLSSocketServer::Connection::OnMessage(const OnMessageCallback &onMessageCallback)
552 {
553     onMessageCallback_ = onMessageCallback;
554     CachedMessageCallback();
555 }
556 
OnClose(const OnCloseCallback & onCloseCallback)557 void TLSSocketServer::Connection::OnClose(const OnCloseCallback &onCloseCallback)
558 {
559     onCloseCallback_ = onCloseCallback;
560 }
561 
OnConnect(const OnConnectCallback & onConnectCallback)562 void TLSSocketServer::OnConnect(const OnConnectCallback &onConnectCallback)
563 {
564     std::lock_guard<std::mutex> lock(mutex_);
565     onConnectCallback_ = onConnectCallback;
566 }
567 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)568 void TLSSocketServer::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
569 {
570     std::lock_guard<std::mutex> lock(mutex_);
571     onErrorCallback_ = onErrorCallback;
572 }
573 
OffMessage()574 void TLSSocketServer::Connection::OffMessage()
575 {
576     if (onMessageCallback_) {
577         onMessageCallback_ = nullptr;
578     }
579 }
580 
OffConnect()581 void TLSSocketServer::OffConnect()
582 {
583     std::lock_guard<std::mutex> lock(mutex_);
584     if (onConnectCallback_) {
585         onConnectCallback_ = nullptr;
586     }
587 }
588 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)589 void TLSSocketServer::Connection::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
590 {
591     onErrorCallback_ = onErrorCallback;
592 }
593 
OffClose()594 void TLSSocketServer::Connection::OffClose()
595 {
596     if (onCloseCallback_) {
597         onCloseCallback_ = nullptr;
598     }
599 }
600 
OffError()601 void TLSSocketServer::Connection::OffError()
602 {
603     onErrorCallback_ = nullptr;
604 }
605 
CallOnErrorCallback(int32_t err,const std::string & errString)606 void TLSSocketServer::Connection::CallOnErrorCallback(int32_t err, const std::string &errString)
607 {
608     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
609     {
610         if (onErrorCallback_) {
611             CallBackfunc = onErrorCallback_;
612         }
613     }
614 
615     if (CallBackfunc) {
616         CallBackfunc(err, errString);
617     }
618 }
OffError()619 void TLSSocketServer::OffError()
620 {
621     std::lock_guard<std::mutex> lock(mutex_);
622     if (onErrorCallback_) {
623         onErrorCallback_ = nullptr;
624     }
625 }
626 
MakeIpSocket(sa_family_t family)627 void TLSSocketServer::MakeIpSocket(sa_family_t family)
628 {
629     if (family != AF_INET && family != AF_INET6) {
630         return;
631     }
632     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
633     if (sock < 0) {
634         int resErr = ConvertErrno();
635         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
636         CallOnErrorCallback(resErr, MakeErrnoString());
637         return;
638     }
639     listenSocketFd_ = sock;
640 }
641 
CallOnErrorCallback(int32_t err,const std::string & errString)642 void TLSSocketServer::CallOnErrorCallback(int32_t err, const std::string &errString)
643 {
644     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
645     {
646         std::lock_guard<std::mutex> lock(mutex_);
647         if (onErrorCallback_) {
648             CallBackfunc = onErrorCallback_;
649         }
650     }
651 
652     if (CallBackfunc) {
653         CallBackfunc(err, errString);
654     }
655 }
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)656 void TLSSocketServer::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6,
657                               sockaddr **addr, socklen_t *len)
658 {
659     if (!addr6 || !addr4 || !len) {
660         return;
661     }
662     sa_family_t family = address.GetSaFamily();
663     if (family == AF_INET) {
664         addr4->sin_family = AF_INET;
665         addr4->sin_port = htons(address.GetPort());
666         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
667         *addr = reinterpret_cast<sockaddr *>(addr4);
668         *len = sizeof(sockaddr_in);
669     } else if (family == AF_INET6) {
670         addr6->sin6_family = AF_INET6;
671         addr6->sin6_port = htons(address.GetPort());
672         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
673         *addr = reinterpret_cast<sockaddr *>(addr6);
674         *len = sizeof(sockaddr_in6);
675     }
676 }
677 
GetListenSocketFd()678 int TLSSocketServer::GetListenSocketFd()
679 {
680     return listenSocketFd_;
681 }
682 
SetLocalAddress(const Socket::NetAddress & address)683 void TLSSocketServer::SetLocalAddress(const Socket::NetAddress &address)
684 {
685     localAddress_ = address;
686 }
687 
GetLocalAddress()688 Socket::NetAddress TLSSocketServer::GetLocalAddress()
689 {
690     return localAddress_;
691 }
692 
GetConnectionByClientID(int clientid)693 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientID(int clientid)
694 {
695     std::shared_ptr<Connection> ptrConnection = nullptr;
696 
697     auto it = clientIdConnections_.find(clientid);
698     if (it != clientIdConnections_.end()) {
699         ptrConnection = it->second;
700     }
701 
702     return ptrConnection;
703 }
704 
GetConnectionClientCount()705 int TLSSocketServer::GetConnectionClientCount()
706 {
707     return g_userCounter;
708 }
709 
CallListenCallback(int32_t err,ListenCallback callback)710 void TLSSocketServer::CallListenCallback(int32_t err, ListenCallback callback)
711 {
712     if (callback) {
713         callback(err);
714     }
715 }
716 
SetAddress(const Socket::NetAddress address)717 void TLSSocketServer::Connection::SetAddress(const Socket::NetAddress address)
718 {
719     address_ = address;
720 }
721 
SetLocalAddress(const Socket::NetAddress address)722 void TLSSocketServer::Connection::SetLocalAddress(const Socket::NetAddress address)
723 {
724     localAddress_ = address;
725 }
726 
GetRemoteCertRawData() const727 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetRemoteCertRawData() const
728 {
729     return remoteRawData_;
730 }
731 
~Connection()732 TLSSocketServer::Connection::~Connection()
733 {
734     NETSTACK_LOGI("TLSSocketServer ~conn");
735     Close();
736 }
737 
TlsAcceptToHost(int sock,const TlsSocket::TLSConnectOptions & options)738 bool TLSSocketServer::Connection::TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options)
739 {
740     SetTlsConfiguration(options);
741     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
742     if (!cipherSuite.empty()) {
743         connectionConfiguration_.SetCipherSuite(cipherSuite);
744     }
745     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
746     if (!signatureAlgorithms.empty()) {
747         connectionConfiguration_.SetSignatureAlgorithms(signatureAlgorithms);
748     }
749     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
750     if (!protocolVec.empty()) {
751         connectionConfiguration_.SetProtocol(protocolVec);
752     }
753     connectionConfiguration_.SetVerifyMode(options.GetTlsSecureOptions().GetVerifyMode());
754     socketFd_ = sock;
755     return StartTlsAccept(options);
756 }
757 
SetTlsConfiguration(const TlsSocket::TLSConnectOptions & config)758 void TLSSocketServer::Connection::SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
759 {
760     connectionConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
761                                            config.GetTlsSecureOptions().GetKeyPass());
762     connectionConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
763     connectionConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
764     connectionConfiguration_.SetNetAddress(config.GetNetAddress());
765 }
766 
Send(const std::string & data)767 bool TLSSocketServer::Connection::Send(const std::string &data)
768 {
769     if (!ssl_) {
770         NETSTACK_LOGE("ssl is null");
771         return false;
772     }
773     if (data.empty()) {
774         NETSTACK_LOGI("data is empty");
775         return true;
776     }
777     int len = SSL_write(ssl_, data.c_str(), data.length());
778     if (len < 0) {
779         int resErr = ConvertSSLError(GetSSL());
780         NETSTACK_LOGE("data send failed! error is %{public}d %{public}d", resErr, errno);
781         return false;
782     }
783     NETSTACK_LOGD("data Sent successfully,sent in total %{public}d bytes!", len);
784     return true;
785 }
786 
Recv(char * buffer,int maxBufferSize)787 int TLSSocketServer::Connection::Recv(char *buffer, int maxBufferSize)
788 {
789     if (!ssl_) {
790         NETSTACK_LOGE("ssl is null");
791         return SSL_ERROR_RETURN;
792     }
793     return SSL_read(ssl_, buffer, maxBufferSize);
794 }
795 
Close()796 bool TLSSocketServer::Connection::Close()
797 {
798     if (!ssl_) {
799         NETSTACK_LOGE("ssl is null");
800         return false;
801     }
802     int result = SSL_shutdown(ssl_);
803     if (result < 0) {
804         int resErr = ConvertSSLError(GetSSL());
805         NETSTACK_LOGE("Error in shutdown, error is %{public}d %{public}d", resErr, errno);
806     }
807     SSL_free(ssl_);
808     ssl_ = nullptr;
809     if (socketFd_ != -1) {
810         shutdown(socketFd_, SHUT_RDWR);
811         close(socketFd_);
812         NETSTACK_LOGI("close connection socketFd %{public}d", socketFd_);
813         socketFd_ = -1;
814     }
815     if (!tlsContextServerPointer_) {
816         NETSTACK_LOGE("Tls context pointer is null");
817         return false;
818     }
819     tlsContextServerPointer_->CloseCtx();
820     return true;
821 }
822 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)823 bool TLSSocketServer::Connection::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
824 {
825     if (!ssl_) {
826         NETSTACK_LOGE("ssl is null");
827         return false;
828     }
829     size_t pos = 0;
830     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
831                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
832     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
833     for (const auto &str : alpnProtocols) {
834         len = str.length();
835         result[pos++] = len;
836         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
837             NETSTACK_LOGE("strcpy_s failed");
838             return false;
839         }
840         pos += len;
841     }
842     result[pos] = '\0';
843 
844     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
845     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
846         int resErr = ConvertSSLError(GetSSL());
847         NETSTACK_LOGE("Failed to set negotiable protocol list, error is %{public}d %{public}d", resErr, errno);
848         return false;
849     }
850     return true;
851 }
852 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)853 void TLSSocketServer::Connection::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
854 {
855     remoteInfo.SetAddress(address_.GetAddress());
856     remoteInfo.SetPort(address_.GetPort());
857     remoteInfo.SetFamily(address_.GetSaFamily());
858 }
859 
GetTlsConfiguration() const860 TlsSocket::TLSConfiguration TLSSocketServer::Connection::GetTlsConfiguration() const
861 {
862     return connectionConfiguration_;
863 }
864 
GetCipherSuite() const865 std::vector<std::string> TLSSocketServer::Connection::GetCipherSuite() const
866 {
867     if (!ssl_) {
868         NETSTACK_LOGE("ssl in null");
869         return {};
870     }
871     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
872     if (!sk) {
873         NETSTACK_LOGE("get ciphers failed");
874         return {};
875     }
876     TlsSocket::CipherSuite cipherSuite;
877     std::vector<std::string> cipherSuiteVec;
878     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
879         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
880         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
881         cipherSuiteVec.push_back(cipherSuite.cipherName_);
882     }
883     return cipherSuiteVec;
884 }
885 
GetRemoteCertificate() const886 std::string TLSSocketServer::Connection::GetRemoteCertificate() const
887 {
888     return remoteCert_;
889 }
890 
GetCertificate() const891 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetCertificate() const
892 {
893     return connectionConfiguration_.GetCertificate();
894 }
895 
GetSignatureAlgorithms() const896 std::vector<std::string> TLSSocketServer::Connection::GetSignatureAlgorithms() const
897 {
898     return signatureAlgorithms_;
899 }
900 
GetProtocol() const901 std::string TLSSocketServer::Connection::GetProtocol() const
902 {
903     if (!ssl_) {
904         NETSTACK_LOGE("ssl in null");
905         return PROTOCOL_UNKNOW;
906     }
907     if (connectionConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
908         return TlsSocket::PROTOCOL_TLS_V13;
909     }
910     return TlsSocket::PROTOCOL_TLS_V12;
911 }
912 
SetSharedSigals()913 bool TLSSocketServer::Connection::SetSharedSigals()
914 {
915     if (!ssl_) {
916         NETSTACK_LOGE("ssl is null");
917         return false;
918     }
919     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
920     if (!number) {
921         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
922         return false;
923     }
924     for (int i = 0; i < number; i++) {
925         int hash_nid;
926         int sign_nid;
927         std::string sig_with_md;
928         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
929         switch (sign_nid) {
930             case EVP_PKEY_RSA:
931                 sig_with_md = SIGN_NID_RSA;
932                 break;
933             case EVP_PKEY_RSA_PSS:
934                 sig_with_md = SIGN_NID_RSA_PSS;
935                 break;
936             case EVP_PKEY_DSA:
937                 sig_with_md = SIGN_NID_DSA;
938                 break;
939             case EVP_PKEY_EC:
940                 sig_with_md = SIGN_NID_ECDSA;
941                 break;
942             case NID_ED25519:
943                 sig_with_md = SIGN_NID_ED;
944                 break;
945             case NID_ED448:
946                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
947                 break;
948             default:
949                 const char *sn = OBJ_nid2sn(sign_nid);
950                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
951         }
952         const char *sn_hash = OBJ_nid2sn(hash_nid);
953         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
954         signatureAlgorithms_.push_back(sig_with_md);
955     }
956     return true;
957 }
958 
GetSSL() const959 ssl_st *TLSSocketServer::Connection::GetSSL() const
960 {
961     return ssl_;
962 }
963 
GetAddress() const964 Socket::NetAddress TLSSocketServer::Connection::GetAddress() const
965 {
966     return address_;
967 }
968 
GetLocalAddress() const969 Socket::NetAddress TLSSocketServer::Connection::GetLocalAddress() const
970 {
971     return localAddress_;
972 }
973 
GetSocketFd() const974 int TLSSocketServer::Connection::GetSocketFd() const
975 {
976     return socketFd_;
977 }
978 
GetEventManager() const979 std::shared_ptr<EventManager> TLSSocketServer::Connection::GetEventManager() const
980 {
981     return eventManager_;
982 }
983 
SetEventManager(std::shared_ptr<EventManager> eventManager)984 void TLSSocketServer::Connection::SetEventManager(std::shared_ptr<EventManager> eventManager)
985 {
986     eventManager_ = eventManager;
987 }
988 
SetClientID(int32_t clientID)989 void TLSSocketServer::Connection::SetClientID(int32_t clientID)
990 {
991     clientID_ = clientID;
992 }
993 
GetClientID()994 int TLSSocketServer::Connection::GetClientID()
995 {
996     return clientID_;
997 }
998 
StartTlsAccept(const TlsSocket::TLSConnectOptions & options)999 bool TLSSocketServer::Connection::StartTlsAccept(const TlsSocket::TLSConnectOptions &options)
1000 {
1001     if (!CreatTlsContext()) {
1002         NETSTACK_LOGE("failed to create tls context");
1003         return false;
1004     }
1005     if (!StartShakingHands(options)) {
1006         NETSTACK_LOGE("failed to shaking hands");
1007         return false;
1008     }
1009     return true;
1010 }
1011 
CreatTlsContext()1012 bool TLSSocketServer::Connection::CreatTlsContext()
1013 {
1014     tlsContextServerPointer_ = TlsSocket::TLSContextServer::CreateConfiguration(connectionConfiguration_);
1015     if (!tlsContextServerPointer_) {
1016         NETSTACK_LOGE("failed to create tls context pointer");
1017         return false;
1018     }
1019     if (!(ssl_ = tlsContextServerPointer_->CreateSsl())) {
1020         NETSTACK_LOGE("failed to create ssl session");
1021         return false;
1022     }
1023     SSL_set_fd(ssl_, socketFd_);
1024     SSL_set_accept_state(ssl_);
1025     return true;
1026 }
1027 
StartShakingHands(const TlsSocket::TLSConnectOptions & options)1028 bool TLSSocketServer::Connection::StartShakingHands(const TlsSocket::TLSConnectOptions &options)
1029 {
1030     if (!ssl_) {
1031         NETSTACK_LOGE("ssl is null");
1032         return false;
1033     }
1034     int result = SSL_accept(ssl_);
1035     if (result == -1) {
1036         int errorStatus = ConvertSSLError(ssl_);
1037         NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1038                       MakeSSLErrorString(errorStatus).c_str());
1039         return false;
1040     }
1041 
1042     std::vector<std::string> SslProtocolVer({SSL_get_version(ssl_)});
1043     connectionConfiguration_.SetProtocol({SslProtocolVer});
1044 
1045     std::string list = SSL_get_cipher_list(ssl_, 0);
1046     NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1047     connectionConfiguration_.SetCipherSuite(list);
1048     if (!SetSharedSigals()) {
1049         NETSTACK_LOGE("Failed to set sharedSigalgs");
1050     }
1051 
1052     if (!GetRemoteCertificateFromPeer()) {
1053         NETSTACK_LOGE("Failed to get remote certificate");
1054     }
1055     if (peerX509_ != nullptr) {
1056         NETSTACK_LOGE("peer x509Certificates is null");
1057 
1058         if (!SetRemoteCertRawData()) {
1059             NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1060         }
1061         TlsSocket::CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1062         if (!checkServerIdentity) {
1063             CheckServerIdentityLegal(hostName_, peerX509_);
1064         } else {
1065             checkServerIdentity(hostName_, {remoteCert_});
1066         }
1067     }
1068     return true;
1069 }
1070 
GetRemoteCertificateFromPeer()1071 bool TLSSocketServer::Connection::GetRemoteCertificateFromPeer()
1072 {
1073     peerX509_ = SSL_get_peer_certificate(ssl_);
1074 
1075     if (SSL_get_verify_result(ssl_) == X509_V_OK) {
1076         NETSTACK_LOGE("SSL_get_verify_result ==X509_V_OK");
1077     }
1078 
1079     if (peerX509_ == nullptr) {
1080         int resErr = ConvertSSLError(GetSSL());
1081         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1082                       MakeSSLErrorString(resErr).c_str());
1083         return false;
1084     }
1085     BIO *bio = BIO_new(BIO_s_mem());
1086     if (!bio) {
1087         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1088         return false;
1089     }
1090     X509_print(bio, peerX509_);
1091     char data[REMOTE_CERT_LEN] = {0};
1092     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1093         NETSTACK_LOGE("BIO_read function returns error");
1094         BIO_free(bio);
1095         return false;
1096     }
1097     BIO_free(bio);
1098     remoteCert_ = std::string(data);
1099     return true;
1100 }
1101 
SetRemoteCertRawData()1102 bool TLSSocketServer::Connection::SetRemoteCertRawData()
1103 {
1104     if (peerX509_ == nullptr) {
1105         NETSTACK_LOGE("peerX509 is null");
1106         return false;
1107     }
1108     int32_t length = i2d_X509(peerX509_, nullptr);
1109     if (length <= 0) {
1110         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1111         return false;
1112     }
1113     unsigned char *der = nullptr;
1114     (void)i2d_X509(peerX509_, &der);
1115     TlsSocket::SecureData data(der, length);
1116     remoteRawData_.data = data;
1117     OPENSSL_free(der);
1118     remoteRawData_.encodingFormat = TlsSocket::EncodingFormat::DER;
1119     return true;
1120 }
1121 
/**
 * Reports whether @p s begins with @p prefix (an empty prefix always matches).
 */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    if (prefix.size() > s.size()) {
        return false;
    }
    // rfind anchored at position 0 can only match at the very start of the string.
    return s.rfind(prefix, 0) == 0;
}
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> & dnsNames,std::vector<std::string> & ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1126 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> &dnsNames, std::vector<std::string> &ips,
1127                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1128 {
1129     bool valid = false;
1130     std::string reason = UNKNOW_REASON;
1131     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1132     if (IsIP(hostName)) {
1133         auto it = find(ips.begin(), ips.end(), hostName);
1134         if (it == ips.end()) {
1135             reason = IP + hostName + " is not in the cert's list";
1136         }
1137         result = {valid, reason};
1138         return;
1139     }
1140     std::string tempHostName = "" + hostName;
1141     if (!dnsNames.empty() || index > 0) {
1142         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1143         std::string tmpStr = "";
1144         if (!dnsNames.empty()) {
1145             valid = SeekIntersection(hostParts, dnsNames);
1146             tmpStr = ". is not in the cert's altnames";
1147         } else {
1148             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1149             X509_NAME *pSubName = nullptr;
1150             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1151             if (len > 0) {
1152                 std::vector<std::string> commonNameVec;
1153                 commonNameVec.emplace_back(commonNameBuf);
1154                 valid = SeekIntersection(hostParts, commonNameVec);
1155                 tmpStr = ". is not cert's CN";
1156             }
1157         }
1158         if (!valid) {
1159             reason = HOST_NAME + tempHostName + tmpStr;
1160         }
1161 
1162         result = {valid, reason};
1163         return;
1164     }
1165     reason = "Cert does not contain a DNS name";
1166     result = {valid, reason};
1167 }
1168 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1169 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName,
1170                                                                   const X509 *x509Certificates)
1171 {
1172     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1173     if (!subjectName) {
1174         return "subject name is null";
1175     }
1176     char subNameBuf[BUF_SIZE] = {0};
1177     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1178     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1179     if (index < 0) {
1180         return "X509 get ext nid error";
1181     }
1182     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1183     if (ext == nullptr) {
1184         return "X509 get ext error";
1185     }
1186     ASN1_OBJECT *obj = nullptr;
1187     obj = X509_EXTENSION_get_object(ext);
1188     char subAltNameBuf[BUF_SIZE] = {0};
1189     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1190     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1191 
1192     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1193 }
1194 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1195 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1196                                                                   const X509 *x509Certificates)
1197 {
1198     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1199     if (!extData) {
1200         NETSTACK_LOGE("extData is nullptr");
1201         return "";
1202     }
1203     std::string altNames = reinterpret_cast<char *>(extData->data);
1204     std::string hostname = "" + hostName;
1205     BIO *bio = BIO_new(BIO_s_file());
1206     if (!bio) {
1207         return "bio is null";
1208     }
1209     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1210     ASN1_STRING_print(bio, extData);
1211     std::vector<std::string> dnsNames = {};
1212     std::vector<std::string> ips = {};
1213     constexpr int DNS_NAME_IDX = 4;
1214     constexpr int IP_NAME_IDX = 11;
1215     if (!altNames.empty()) {
1216         std::vector<std::string> splitAltNames;
1217         if (altNames.find('\"') != std::string::npos) {
1218             splitAltNames = SplitEscapedAltNames(altNames);
1219         } else {
1220             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1221         }
1222         for (auto const &iter : splitAltNames) {
1223             if (StartsWith(iter, DNS)) {
1224                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1225             } else if (StartsWith(iter, IP_ADDRESS)) {
1226                 ips.push_back(iter.substr(IP_NAME_IDX));
1227             }
1228         }
1229     }
1230     std::tuple<bool, std::string> result;
1231     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1232     if (!std::get<0>(result)) {
1233         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1234     }
1235     return HOST_NAME + hostname + ". is cert's CN";
1236 }
1237 
RemoveConnect(int socketFd)1238 void TLSSocketServer::RemoveConnect(int socketFd)
1239 {
1240     std::unique_lock<std::shared_mutex> its_lock(connectMutex_);
1241     for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) {
1242         if (it->second == nullptr) {
1243             NETSTACK_LOGE("tlsconnection is nullptr");
1244             continue;
1245         }
1246         if (it->second->GetSocketFd() == socketFd) {
1247             clientIdConnections_.erase(it);
1248             break;
1249         }
1250     }
1251 }
1252 
RecvRemoteInfo(int socketFd,int index)1253 bool TLSSocketServer::RecvRemoteInfo(int socketFd, int index)
1254 {
1255     {
1256         std::shared_lock<std::shared_mutex> its_lock(connectMutex_);
1257         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1258             if (it->second == nullptr) {
1259                 NETSTACK_LOGE("tlsconnection is nullptr");
1260                 return false;
1261             }
1262             if (it->second->GetSocketFd() == socketFd) {
1263                 char buffer[MAX_BUFFER_SIZE];
1264                 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
1265                     NETSTACK_LOGE("memcpy_s failed");
1266                     break;
1267                 }
1268                 int len = it->second->Recv(buffer, MAX_BUFFER_SIZE);
1269                 NETSTACK_LOGE("revc message is size is %{public}d", len);
1270                 if (len > 0) {
1271                     Socket::SocketRemoteInfo remoteInfo;
1272                     remoteInfo.SetSize(strlen(buffer));
1273                     it->second->MakeRemoteInfo(remoteInfo);
1274                     it->second->CallOnMessageCallback(socketFd, buffer, remoteInfo);
1275                     return false;
1276                 } else if (len == 0) {
1277                     NETSTACK_LOGE("tls connection is closed by peer, clientId: %{public}d, Fd: %{public}d",
1278                         it->second->GetClientID(), socketFd);
1279                     it->second->CallOnCloseCallback(socketFd);
1280                     break;
1281                 } else {
1282                     int resErr = ConvertSSLError(it->second->GetSSL());
1283                     NETSTACK_LOGE("recv fail, clientId: %{public}d, Fd: %{public}d, "
1284                         "ssl error is %{public}d, error info is %{public}s",
1285                         it->second->GetClientID(), socketFd, resErr, MakeSSLErrorString(resErr).c_str());
1286                     it->second->CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1287                     break;
1288                 }
1289 #if defined(CROSS_PLATFORM)
1290                 if (len == 0 &&  errno == 0) {
1291                     NETSTACK_LOGI("A client left");
1292                 }
1293 #endif
1294             } else {
1295                 ++it;
1296             }
1297         }
1298     }
1299     RemoveConnect(socketFd);
1300     return DropFdFromPollList(index);
1301 }
1302 
CachedMessageCallback()1303 void TLSSocketServer::Connection::CachedMessageCallback()
1304 {
1305     int32_t socketFd = GetSocketFd();
1306     if (socketFd < 0) {
1307         NETSTACK_LOGE("socketFd is invalid to recv message");
1308         return;
1309     }
1310     if (onMessageCallback_) {
1311         while (!dataCache_->IsEmpty()) {
1312             CacheInfo cache = dataCache_->Get();
1313             onMessageCallback_(socketFd, cache.data, cache.remoteInfo);
1314         }
1315     }
1316     NETSTACK_LOGD("Cached message is callbacked for socket %{public}d", socketFd);
1317 }
1318 
CallOnMessageCallback(int32_t socketFd,const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)1319 void TLSSocketServer::Connection::CallOnMessageCallback(int32_t socketFd, const std::string &data,
1320                                                         const Socket::SocketRemoteInfo &remoteInfo)
1321 {
1322     OnMessageCallback CallBackfunc = nullptr;
1323     {
1324         if (onMessageCallback_) {
1325             CallBackfunc = onMessageCallback_;
1326         }
1327     }
1328 
1329     if (CallBackfunc) {
1330         while (!dataCache_->IsEmpty()) {
1331             CacheInfo cache = dataCache_->Get();
1332             CallBackfunc(socketFd, cache.data, cache.remoteInfo);
1333         }
1334         CallBackfunc(socketFd, data, remoteInfo);
1335     } else {
1336         NETSTACK_LOGD("message callback is not registered");
1337         CacheInfo cache = {data, remoteInfo};
1338         dataCache_->Set(cache);
1339     }
1340 }
1341 
AddConnect(int socketFd,std::shared_ptr<Connection> connection)1342 void TLSSocketServer::AddConnect(int socketFd, std::shared_ptr<Connection> connection)
1343 {
1344     std::unique_lock<std::shared_mutex> its_lock(connectMutex_);
1345     clientIdConnections_[connection->GetClientID()] = connection;
1346 }
1347 
CallOnCloseCallback(const int32_t socketFd)1348 void TLSSocketServer::Connection::CallOnCloseCallback(const int32_t socketFd)
1349 {
1350     OnCloseCallback CallBackfunc = nullptr;
1351     {
1352         if (onCloseCallback_) {
1353             CallBackfunc = onCloseCallback_;
1354         }
1355     }
1356 
1357     if (CallBackfunc) {
1358         CallBackfunc(socketFd);
1359     }
1360 }
1361 
CallOnConnectCallback(const int32_t socketFd,std::shared_ptr<EventManager> eventManager)1362 void TLSSocketServer::CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager)
1363 {
1364     OnConnectCallback CallBackfunc = nullptr;
1365     {
1366         std::lock_guard<std::mutex> lock(mutex_);
1367         if (onConnectCallback_) {
1368             CallBackfunc = onConnectCallback_;
1369         }
1370     }
1371 
1372     if (CallBackfunc) {
1373         CallBackfunc(socketFd, eventManager);
1374     } else {
1375         NETSTACK_LOGE("CallOnConnectCallback  fun === null");
1376     }
1377 }
1378 
GetTlsConnectionLocalAddress(int acceptSockFD,Socket::NetAddress & localAddress)1379 bool TLSSocketServer::GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress)
1380 {
1381     struct sockaddr_storage addr{};
1382     socklen_t addrLen = sizeof(addr);
1383     if (getsockname(acceptSockFD, (struct sockaddr *)&addr, &addrLen) < 0) {
1384         if (acceptSockFD > 0) {
1385             close(acceptSockFD);
1386             CallOnErrorCallback(errno, strerror(errno));
1387             return false;
1388         }
1389     }
1390     char ipStr[INET6_ADDRSTRLEN] = {0};
1391     if (addr.ss_family == AF_INET) {
1392         auto *addr_in = (struct sockaddr_in *)&addr;
1393         inet_ntop(AF_INET, &addr_in->sin_addr, ipStr, sizeof(ipStr));
1394         localAddress.SetFamilyBySaFamily(AF_INET);
1395         localAddress.SetRawAddress(ipStr);
1396         localAddress.SetPort(ntohs(addr_in->sin_port));
1397     } else if (addr.ss_family == AF_INET6) {
1398         auto *addr_in6 = (struct sockaddr_in6 *)&addr;
1399         inet_ntop(AF_INET6, &addr_in6->sin6_addr, ipStr, sizeof(ipStr));
1400         localAddress.SetFamilyBySaFamily(AF_INET6);
1401         localAddress.SetRawAddress(ipStr);
1402         localAddress.SetPort(ntohs(addr_in6->sin6_port));
1403     }
1404     return true;
1405 }
1406 
ProcessTcpAccept(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID)1407 void TLSSocketServer::ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID)
1408 {
1409     struct sockaddr_in clientAddress;
1410     socklen_t clientAddrLength = sizeof(clientAddress);
1411     int connectFD = accept(listenSocketFd_, (struct sockaddr *)&clientAddress, &clientAddrLength);
1412     if (connectFD < 0) {
1413         int resErr = ConvertErrno();
1414         NETSTACK_LOGE("Server accept new client ERROR");
1415         CallOnErrorCallback(resErr, MakeErrnoString());
1416         return;
1417     }
1418     NETSTACK_LOGI("Server accept new client SUCCESS");
1419     std::shared_ptr<Connection> connection = std::make_shared<Connection>();
1420     Socket::NetAddress netAddress;
1421     Socket::NetAddress localAddress;
1422     char clientIp[INET6_ADDRSTRLEN] = {0};
1423     inet_ntop(address_.GetSaFamily(), &clientAddress.sin_addr, clientIp, INET_ADDRSTRLEN);
1424     int clientPort = ntohs(clientAddress.sin_port);
1425     netAddress.SetRawAddress(clientIp);
1426     netAddress.SetPort(clientPort);
1427     netAddress.SetFamilyBySaFamily(address_.GetSaFamily());
1428     connection->SetAddress(netAddress);
1429     if (!GetTlsConnectionLocalAddress(connectFD, localAddress)) {
1430         NETSTACK_LOGE("GetTlsConnectionLocalAddress");
1431         return;
1432     }
1433     connection->SetLocalAddress(localAddress);
1434     SetTlsConnectionSecureOptions(tlsListenOptions, clientID, connectFD, connection);
1435 }
SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID,int connectFD,std::shared_ptr<Connection> & connection)1436 void TLSSocketServer::SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID,
1437                                                     int connectFD, std::shared_ptr<Connection> &connection)
1438 {
1439     connection->SetClientID(clientID);
1440     auto res = connection->TlsAcceptToHost(connectFD, tlsListenOptions);
1441     if (!res) {
1442         int resErr = ConvertSSLError(connection->GetSSL());
1443         NETSTACK_LOGE("setTlsConnectionSecureOptions error is %{public}d %{public}d", resErr, errno);
1444         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1445         return;
1446     }
1447     if (g_userCounter >= USER_LIMIT) {
1448         const std::string info = "Too many users!";
1449         connection->Send(info);
1450         connection->Close();
1451         NETSTACK_LOGE("Too many users");
1452         if (connection->GetSocketFd() != -1) {
1453             close(connectFD);
1454         }
1455         CallOnErrorCallback(-1, "Too many users");
1456         return;
1457     }
1458     g_userCounter++;
1459     fds_[g_userCounter].fd = connectFD;
1460 #if defined(CROSS_PLATFORM)
1461     fds_[g_userCounter].events = POLLIN | POLLERR;
1462 #else
1463     fds_[g_userCounter].events = POLLIN | POLLRDHUP | POLLERR;
1464 #endif
1465     fds_[g_userCounter].revents = 0;
1466     AddConnect(connectFD, connection);
1467     auto ptrEventManager = std::make_shared<EventManager>();
1468     ptrEventManager->SetData(this);
1469     connection->SetEventManager(ptrEventManager);
1470     CallOnConnectCallback(clientID, ptrEventManager);
1471     NETSTACK_LOGI("New client come in, fd is %{public}d", connectFD);
1472 }
1473 
InitPollList(const int & listendFd)1474 void TLSSocketServer::InitPollList(const int &listendFd)
1475 {
1476     fds_[0].fd = listendFd;
1477     fds_[0].events = POLLIN | POLLERR;
1478     fds_[0].revents = 0;
1479 }
1480 
DropFdFromPollList(int & fd_index)1481 bool TLSSocketServer::DropFdFromPollList(int &fd_index)
1482 {
1483     if (g_userCounter < 0) {
1484         NETSTACK_LOGE("g_userCounter = %{public}d", g_userCounter);
1485         return true;
1486     }
1487     if (fd_index == 0) {
1488         // index 0 is for listen only
1489         fds_[0].fd = -1;
1490         fds_[0].events = 0;
1491         NETSTACK_LOGI("drop listenFd from poll List, g_userCounter = %{public}d", g_userCounter);
1492     } else {
1493         // remove the fd_index, and insert the last index
1494         fds_[fd_index].fd = fds_[g_userCounter].fd;
1495         fds_[g_userCounter].fd = -1;
1496         fds_[g_userCounter].events = 0;
1497         fd_index--;
1498         g_userCounter--;
1499         NETSTACK_LOGI("drop clientFd from poll List, g_userCounter = %{public}d", g_userCounter);
1500     }
1501     for (int i = 0; i < g_userCounter + 1; ++i) {
1502         if (fds_[i].fd > 0) {
1503             return false;
1504         }
1505     }
1506     return true;
1507 }
NotifyRcvThdExit()1508 void TLSSocketServer::NotifyRcvThdExit()
1509 {
1510     std::unique_lock<std::mutex> lock(sockRcvThdMtx_);
1511     sockRcvExit_ = true;
1512     sockRcvThdCon_.notify_one();
1513     NETSTACK_LOGI("recv thread exit");
1514 }
1515 
WaitForRcvThdExit()1516 void TLSSocketServer::WaitForRcvThdExit()
1517 {
1518     std::unique_lock<std::mutex> lock(sockRcvThdMtx_);
1519     sockRcvThdCon_.wait(lock, [this]() { return sockRcvExit_; });
1520 }
1521 
PollThread(const TlsSocket::TLSConnectOptions & tlsListenOptions)1522 void TLSSocketServer::PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions)
1523 {
1524     int on = 1;
1525     isRunning_ = true;
1526     ioctl(listenSocketFd_, FIONBIO, (char *)&on);
1527     NETSTACK_LOGI("PollThread  start working %{public}d", isRunning_);
1528     std::thread thread_([this, tlsOption = tlsListenOptions]() {
1529 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
1530         pthread_setname_np(TLS_SOCKET_SERVER_READ);
1531 #else
1532         pthread_setname_np(pthread_self(), TLS_SOCKET_SERVER_READ);
1533 #endif
1534         InitPollList(listenSocketFd_);
1535         bool exitLoop = false;
1536         while (isRunning_ && !exitLoop) {
1537             int ret = poll(fds_, g_userCounter + 1, POLL_WAIT_TIME);
1538             if (ret < 0) {
1539                 int resErr = ConvertErrno();
1540                 NETSTACK_LOGE("Poll ERROR");
1541                 CallOnErrorCallback(resErr, MakeErrnoString());
1542                 break;
1543             }
1544             if (ret == 0) {
1545                 continue;
1546             }
1547             for (int i = 0; i < g_userCounter + 1; ++i) {
1548                 if ((fds_[i].fd == listenSocketFd_) && (static_cast<uint16_t>(fds_[i].revents) & POLLIN)) {
1549                     ProcessTcpAccept(tlsOption, g_userCounter + 1);
1550 #if !defined(CROSS_PLATFORM)
1551                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLRDHUP) ||
1552                            (static_cast<uint16_t>(fds_[i].revents) & (POLLERR | POLLNVAL))) {
1553 #else
1554                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLERR | POLLNVAL)) {
1555 #endif
1556                     RemoveConnect(fds_[i].fd);
1557                     exitLoop = DropFdFromPollList(i);
1558                 } else if (static_cast<uint16_t>(fds_[i].revents) & POLLIN) {
1559                     exitLoop = RecvRemoteInfo(fds_[i].fd, i);
1560                 }
1561             }
1562         }
1563         isRunning_ = false;
1564         NotifyRcvThdExit();
1565     });
1566     thread_.detach();
1567 }
1568 
GetConnectionByClientEventManager(const std::shared_ptr<EventManager> & eventManager)1569 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientEventManager(
1570     const std::shared_ptr<EventManager> &eventManager)
1571 {
1572     std::shared_lock<std::shared_mutex> its_lock(connectMutex_);
1573     auto it = std::find_if(clientIdConnections_.begin(), clientIdConnections_.end(), [eventManager](const auto& pair) {
1574         return pair.second->GetEventManager() == eventManager;
1575     });
1576     if (it == clientIdConnections_.end()) {
1577         return nullptr;
1578     }
1579     return it->second;
1580 }
1581 
CloseConnectionByEventManager(const std::shared_ptr<EventManager> & eventManager)1582 void TLSSocketServer::CloseConnectionByEventManager(const std::shared_ptr<EventManager> &eventManager)
1583 {
1584     std::shared_ptr<Connection> ptrConnection = GetConnectionByClientEventManager(eventManager);
1585 
1586     if (ptrConnection != nullptr) {
1587         ptrConnection->Close();
1588     }
1589 }
1590 
DeleteConnectionByEventManager(const std::shared_ptr<EventManager> & eventManager)1591 void TLSSocketServer::DeleteConnectionByEventManager(const std::shared_ptr<EventManager> &eventManager)
1592 {
1593     std::unique_lock<std::shared_mutex> its_lock(connectMutex_);
1594     for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) {
1595         if (it->second->GetEventManager() == eventManager) {
1596             it = clientIdConnections_.erase(it);
1597             break;
1598         }
1599     }
1600 }
1601 } // namespace TlsSocketServer
1602 } // namespace NetStack
1603 } // namespace OHOS
1604