• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket_server.h"
17 
#include <algorithm>
#include <chrono>
#include <climits>
#include <cstdint>
#include <memory>
#include <netinet/tcp.h>
#include <numeric>
#include <openssl/err.h>
#include <openssl/ssl.h>

#include <regex>
#include <securec.h>
#include <sys/ioctl.h>
28 
29 #include "base_context.h"
30 #include "netstack_common_utils.h"
31 #include "netstack_log.h"
32 #include "tls.h"
33 
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocketServer {
37 namespace {
// "Return code" that ConvertSSLError() always passes to SSL_get_error().
constexpr int SSL_RET_CODE = 0;

constexpr int BUF_SIZE = 2048;
constexpr int POLL_WAIT_TIME = 2000;
// Number of characters SplitEscapedAltNames() steps past a separator (the length of ", ").
constexpr int OFFSET = 2;
// Value returned by Connection::Recv() when there is no SSL session.
constexpr int SSL_ERROR_RETURN = -1;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
// Backlog passed to listen() in ExecAccept(). (Identifier spelling kept: it is referenced elsewhere.)
constexpr int LISETEN_COUNT = 516;
// Separator used by SplitHostName() to split a host name into labels.
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *DNS = "DNS:";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *IP_ADDRESS = "IP Address:";
// Signature-algorithm name fragments; presumably used to build signature algorithm strings
// outside the visible code — verify against their users.
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
// Intended to recognise a JSON string literal (quote, escaped characters, quote).
// NOTE(review): the pattern keeps the JavaScript-style /.../ delimiters. std::regex treats the leading and
// trailing '/' as literal characters, so this can only match text that itself begins with '/'. Confirm the
// intent before relying on it.
const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"};
// Dotted-quad IPv4 address with every octet in 0-255; used by IsIP().
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
// Value reported by GetConnectionClientCount(); presumably maintained by the accept/close code elsewhere
// in this file — verify.
int g_userCounter = 0;
69 
IsIP(const std::string & ip)70 bool IsIP(const std::string &ip)
71 {
72     std::regex pattern(PATTERN);
73     std::smatch res;
74     return regex_match(ip, res, pattern);
75 }
76 
SplitHostName(std::string & hostName)77 std::vector<std::string> SplitHostName(std::string &hostName)
78 {
79     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
80     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
81 }
82 
/**
 * Returns true when the two label lists share at least one element.
 *
 * The previous implementation used std::set_intersection, which requires both ranges to be sorted.
 * Host-name label vectors are in hostname order, not sorted, so it could miss common elements.
 * These lists are tiny (a handful of labels), so a direct membership scan is simple and correct.
 * Neither input is modified.
 */
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    return std::any_of(vecA.begin(), vecA.end(), [&vecB](const std::string &item) {
        return std::find(vecB.begin(), vecB.end(), item) != vecB.end();
    });
}
89 
ConvertErrno()90 int ConvertErrno()
91 {
92     return TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE + errno;
93 }
94 
ConvertSSLError(ssl_st * ssl)95 int ConvertSSLError(ssl_st *ssl)
96 {
97     if (!ssl) {
98         return TlsSocket::TLS_ERR_SSL_NULL;
99     }
100     return TlsSocket::TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
101 }
102 
// Returns the strerror() text for the current errno.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
107 
MakeSSLErrorString(int error)108 std::string MakeSSLErrorString(int error)
109 {
110     char err[TlsSocket::MAX_ERR_LEN] = {0};
111     ERR_error_string_n(error - TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
112     return err;
113 }
// Appends the UTF-8 encoding of `codePoint` to `out` (lone surrogates are encoded as 3-byte sequences).
void AppendCodePointAsUtf8(uint32_t codePoint, std::string &out)
{
    if (codePoint < 0x80) {
        out += static_cast<char>(codePoint);
    } else if (codePoint < 0x800) {
        out += static_cast<char>(0xC0 | (codePoint >> 6));
        out += static_cast<char>(0x80 | (codePoint & 0x3F));
    } else if (codePoint < 0x10000) {
        out += static_cast<char>(0xE0 | (codePoint >> 12));
        out += static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F));
        out += static_cast<char>(0x80 | (codePoint & 0x3F));
    } else {
        out += static_cast<char>(0xF0 | (codePoint >> 18));
        out += static_cast<char>(0x80 | ((codePoint >> 12) & 0x3F));
        out += static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F));
        out += static_cast<char>(0x80 | (codePoint & 0x3F));
    }
}

// Reads exactly four hex digits of `text` starting at `pos` into `value`; false if missing or not hex.
bool ReadJsonHexQuad(const std::string &text, size_t pos, uint32_t &value)
{
    constexpr size_t quadLength = 4;
    if (pos > text.size() || text.size() - pos < quadLength) {
        return false;
    }
    value = 0;
    for (size_t i = pos; i < pos + quadLength; ++i) {
        const char ch = text[i];
        uint32_t digit = 0;
        if (ch >= '0' && ch <= '9') {
            digit = static_cast<uint32_t>(ch - '0');
        } else if (ch >= 'a' && ch <= 'f') {
            digit = static_cast<uint32_t>(ch - 'a' + 10);
        } else if (ch >= 'A' && ch <= 'F') {
            digit = static_cast<uint32_t>(ch - 'A' + 10);
        } else {
            return false;
        }
        value = (value << 4) | digit;
    }
    return true;
}

// Decodes the JSON string literal whose opening quote is at text[start] (equivalent of JSON.parse on it).
// On success appends the unescaped content to `out` and returns the number of characters consumed,
// including both quotes. On malformed input returns std::string::npos and leaves `out` untouched.
size_t DecodeJsonStringAt(const std::string &text, size_t start, std::string &out)
{
    if (start >= text.size() || text[start] != '"') {
        return std::string::npos;
    }
    std::string decoded;
    size_t pos = start + 1;
    while (pos < text.size()) {
        const unsigned char ch = static_cast<unsigned char>(text[pos]);
        if (ch == '"') {
            out += decoded;
            return pos + 1 - start;
        }
        if (ch < 0x20) {
            return std::string::npos;  // raw control characters are not allowed inside JSON strings
        }
        if (ch != '\\') {
            decoded += static_cast<char>(ch);
            ++pos;
            continue;
        }
        if (pos + 1 >= text.size()) {
            return std::string::npos;
        }
        const char escape = text[pos + 1];
        pos += 2;
        switch (escape) {
            case '"': decoded += '"'; break;
            case '\\': decoded += '\\'; break;
            case '/': decoded += '/'; break;
            case 'b': decoded += '\b'; break;
            case 'f': decoded += '\f'; break;
            case 'n': decoded += '\n'; break;
            case 'r': decoded += '\r'; break;
            case 't': decoded += '\t'; break;
            case 'u': {
                uint32_t codePoint = 0;
                if (!ReadJsonHexQuad(text, pos, codePoint)) {
                    return std::string::npos;
                }
                pos += 4;
                // Combine a UTF-16 high/low surrogate pair into one code point when both halves are present.
                if (codePoint >= 0xD800 && codePoint <= 0xDBFF && pos + 1 < text.size() && text[pos] == '\\' &&
                    text[pos + 1] == 'u') {
                    uint32_t low = 0;
                    if (ReadJsonHexQuad(text, pos + 2, low) && low >= 0xDC00 && low <= 0xDFFF) {
                        codePoint = 0x10000 + ((codePoint - 0xD800) << 10) + (low - 0xDC00);
                        pos += 6;
                    }
                }
                AppendCodePointAsUtf8(codePoint, decoded);
                break;
            }
            default:
                return std::string::npos;
        }
    }
    return std::string::npos;  // unterminated string
}

/**
 * Splits a certificate subjectAltName string ("DNS:a, DNS:b, ...") on ", " separators.
 *
 * Sections wrapped in double quotes are JSON-escaped. They may contain ", " themselves; they are
 * decoded and kept inside the current token (mirrors Node.js splitEscapedAltNames).
 *
 * @param altNames the raw alt-name string; not modified.
 * @return the tokens in order, or a single empty string when a quoted section is malformed.
 *
 * Fixes over the previous version:
 *  - Searches start at `offset`, not 0; the old code looped forever on "a, b".
 *  - Finds the ", " substring rather than any ',' or ' ' character.
 *  - The quote condition is no longer inverted.
 *  - substr() receives lengths rather than absolute positions.
 *  - The decoded quoted text is appended, instead of indexing the still-empty result vector
 *    (undefined behaviour).
 */
std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
{
    const std::string separator = ", ";
    std::vector<std::string> result;
    std::string currentToken;
    size_t offset = 0;
    while (offset < altNames.length()) {
        const size_t nextSep = altNames.find(separator, offset);
        const size_t nextQuote = altNames.find('"', offset);
        if (nextQuote != std::string::npos && (nextSep == std::string::npos || nextQuote < nextSep)) {
            // A quoted section starts before the next separator: keep the text before it, then decode it.
            currentToken += altNames.substr(offset, nextQuote - offset);
            const size_t consumed = DecodeJsonStringAt(altNames, nextQuote, currentToken);
            if (consumed == std::string::npos) {
                return {""};
            }
            offset = nextQuote + consumed;
        } else if (nextSep != std::string::npos) {
            currentToken += altNames.substr(offset, nextSep - offset);
            result.push_back(currentToken);
            currentToken.clear();
            offset = nextSep + separator.length();
        } else {
            currentToken += altNames.substr(offset);
            offset = altNames.length();
        }
    }
    result.push_back(currentToken);
    return result;
}
146 } // namespace
147 
SetSocket(const int & socketFd)148 void TLSServerSendOptions::SetSocket(const int &socketFd)
149 {
150     socketFd_ = socketFd;
151 }
152 
SetSendData(const std::string & data)153 void TLSServerSendOptions::SetSendData(const std::string &data)
154 {
155     data_ = data;
156 }
157 
GetSocket() const158 const int &TLSServerSendOptions::GetSocket() const
159 {
160     return socketFd_;
161 }
162 
GetSendData() const163 const std::string &TLSServerSendOptions::GetSendData() const
164 {
165     return data_;
166 }
167 
~TLSSocketServer()168 TLSSocketServer::~TLSSocketServer()
169 {
170     isRunning_ = false;
171     clientIdConnections_.clear();
172 
173     if (listenSocketFd_ != -1) {
174         shutdown(listenSocketFd_, SHUT_RDWR);
175         close(listenSocketFd_);
176         listenSocketFd_ = -1;
177     }
178 }
179 
Listen(const TlsSocket::TLSConnectOptions & tlsListenOptions,const ListenCallback & callback)180 void TLSSocketServer::Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback)
181 {
182     if (!CommonUtils::HasInternetPermission()) {
183         CallListenCallback(PERMISSION_DENIED_CODE, callback);
184         return;
185     }
186     NETSTACK_LOGE("Listen 1 %{public}d", listenSocketFd_);
187     if (listenSocketFd_ >= 0) {
188         CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
189         return;
190     }
191     NETSTACK_LOGE("Listen 2 %{public}d", listenSocketFd_);
192     if (ExecBind(tlsListenOptions.GetNetAddress(), callback)) {
193         NETSTACK_LOGE("Listen 3 %{public}d", listenSocketFd_);
194         ExecAccept(tlsListenOptions, callback);
195     } else {
196         shutdown(listenSocketFd_, SHUT_RDWR);
197         close(listenSocketFd_);
198         listenSocketFd_ = -1;
199     }
200 
201     PollThread(tlsListenOptions);
202 }
203 
ExecBind(const Socket::NetAddress & address,const ListenCallback & callback)204 bool TLSSocketServer::ExecBind(const Socket::NetAddress &address, const ListenCallback &callback)
205 {
206     MakeIpSocket(address.GetSaFamily());
207     if (listenSocketFd_ < 0) {
208         int resErr = ConvertErrno();
209         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
210         CallOnErrorCallback(resErr, MakeErrnoString());
211         CallListenCallback(resErr, callback);
212         return false;
213     }
214     sockaddr_in addr4 = {0};
215     sockaddr_in6 addr6 = {0};
216     sockaddr *addr = nullptr;
217     socklen_t len;
218     GetAddr(address, &addr4, &addr6, &addr, &len);
219     if (addr == nullptr) {
220         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
221         CallOnErrorCallback(-1, "Address Is Invalid");
222         CallListenCallback(ConvertErrno(), callback);
223         return false;
224     }
225     if (bind(listenSocketFd_, addr, len) < 0) {
226         if (errno != EADDRINUSE) {
227             NETSTACK_LOGE("bind error is %{public}s %{public}d", strerror(errno), errno);
228             CallOnErrorCallback(-1, "Address binding failed");
229             CallListenCallback(ConvertErrno(), callback);
230             return false;
231         }
232         if (addr->sa_family == AF_INET) {
233             NETSTACK_LOGI("distribute a random port");
234             addr4.sin_port = 0; /* distribute a random port */
235         } else if (addr->sa_family == AF_INET6) {
236             NETSTACK_LOGI("distribute a random port");
237             addr6.sin6_port = 0; /* distribute a random port */
238         }
239         if (bind(listenSocketFd_, addr, len) < 0) {
240             NETSTACK_LOGE("rebind error is %{public}s %{public}d", strerror(errno), errno);
241             CallOnErrorCallback(-1, "Duplicate binding address failed");
242             CallListenCallback(ConvertErrno(), callback);
243             return false;
244         }
245         NETSTACK_LOGI("rebind success");
246     }
247     NETSTACK_LOGI("bind success");
248     address_ = address;
249     return true;
250 }
251 
ExecAccept(const TlsSocket::TLSConnectOptions & tlsAcceptOptions,const ListenCallback & callback)252 void TLSSocketServer::ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback)
253 {
254     if (listenSocketFd_ < 0) {
255         int resErr = ConvertErrno();
256         NETSTACK_LOGE("accept error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
257         CallOnErrorCallback(resErr, MakeErrnoString());
258         callback(resErr);
259         return;
260     }
261     SetLocalTlsConfiguration(tlsAcceptOptions);
262     int ret = 0;
263 
264     NETSTACK_LOGE(
265         "accept error is listenSocketFd_=  %{public}d LISETEN_COUNT =%{public}d .GetVerifyMode()  = %{public}d ",
266         listenSocketFd_, LISETEN_COUNT, tlsAcceptOptions.GetTlsSecureOptions().GetVerifyMode());
267     ret = listen(listenSocketFd_, LISETEN_COUNT);
268     if (ret < 0) {
269         int resErr = ConvertErrno();
270         NETSTACK_LOGE("tcp server listen error");
271         CallOnErrorCallback(resErr, MakeErrnoString());
272         callback(resErr);
273         return;
274     }
275     CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
276 }
277 
Send(const TLSServerSendOptions & data,const TlsSocket::SendCallback & callback)278 bool TLSSocketServer::Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback)
279 {
280     int socketFd = data.GetSocket();
281     std::string info = data.GetSendData();
282 
283     auto connect_iterator = clientIdConnections_.find(socketFd);
284     if (connect_iterator == clientIdConnections_.end()) {
285         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
286         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
287         return false;
288     }
289     auto connect = connect_iterator->second;
290     auto res = connect->Send(info);
291     if (!res) {
292         int resErr = ConvertSSLError(connect->GetSSL());
293         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
294         CallSendCallback(resErr, callback);
295         return false;
296     }
297     CallSendCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
298     return res;
299 }
300 
CallSendCallback(int32_t err,TlsSocket::SendCallback callback)301 void TLSSocketServer::CallSendCallback(int32_t err, TlsSocket::SendCallback callback)
302 {
303     TlsSocket::SendCallback CallBackfunc = nullptr;
304     {
305         std::lock_guard<std::mutex> lock(mutex_);
306         if (callback) {
307             CallBackfunc = callback;
308         }
309     }
310 
311     if (CallBackfunc) {
312         CallBackfunc(err);
313     }
314 }
315 
Close(const int socketFd,const TlsSocket::CloseCallback & callback)316 void TLSSocketServer::Close(const int socketFd, const TlsSocket::CloseCallback &callback)
317 {
318     {
319         std::lock_guard<std::mutex> its_lock(connectMutex_);
320         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
321             if (it->first == socketFd) {
322                 auto res = it->second->Close();
323                 if (!res) {
324                     int resErr = ConvertSSLError(it->second->GetSSL());
325                     NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
326                     CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
327                     callback(resErr);
328                     return;
329                 }
330                 callback(TlsSocket::TLSSOCKET_SUCCESS);
331                 return;
332             } else {
333                 ++it;
334             }
335         }
336     }
337     NETSTACK_LOGE("socket = %{public}d There is no corresponding socketFd", socketFd);
338     CallOnErrorCallback(-1, "The send failed with no corresponding socketFd");
339     callback(TlsSocket::TLS_ERR_SYS_EINVAL);
340 }
341 
Stop(const TlsSocket::CloseCallback & callback)342 void TLSSocketServer::Stop(const TlsSocket::CloseCallback &callback)
343 {
344     std::lock_guard<std::mutex> its_lock(connectMutex_);
345     for (const auto &c : clientIdConnections_) {
346         c.second->Close();
347     }
348     clientIdConnections_.clear();
349     close(listenSocketFd_);
350     listenSocketFd_ = -1;
351     callback(TlsSocket::TLSSOCKET_SUCCESS);
352 }
353 
GetRemoteAddress(const int socketFd,const TlsSocket::GetRemoteAddressCallback & callback)354 void TLSSocketServer::GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback)
355 {
356     auto connect_iterator = clientIdConnections_.find(socketFd);
357     if (connect_iterator == clientIdConnections_.end()) {
358         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
359         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
360         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
361         return;
362     }
363     auto connect = connect_iterator->second;
364     auto address = connect->GetAddress();
365     callback(TlsSocket::TLSSOCKET_SUCCESS, address);
366 }
367 
GetState(const TlsSocket::GetStateCallback & callback)368 void TLSSocketServer::GetState(const TlsSocket::GetStateCallback &callback)
369 {
370     int opt;
371     socklen_t optLen = sizeof(int);
372     int r = getsockopt(listenSocketFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
373     if (r < 0) {
374         Socket::SocketStateBase state;
375         state.SetIsClose(true);
376         CallGetStateCallback(ConvertErrno(), state, callback);
377         return;
378     }
379     sockaddr sockAddr = {0};
380     socklen_t len = sizeof(sockaddr);
381     Socket::SocketStateBase state;
382     int ret = getsockname(listenSocketFd_, &sockAddr, &len);
383     state.SetIsBound(ret == 0);
384     ret = getpeername(listenSocketFd_, &sockAddr, &len);
385     if (ret != 0) {
386         NETSTACK_LOGE("getpeername failed");
387     }
388     state.SetIsConnected(GetConnectionClientCount() > 0);
389     CallGetStateCallback(TlsSocket::TLSSOCKET_SUCCESS, state, callback);
390 }
391 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,TlsSocket::GetStateCallback callback)392 void TLSSocketServer::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state,
393                                            TlsSocket::GetStateCallback callback)
394 {
395     TlsSocket::GetStateCallback CallBackfunc = nullptr;
396     {
397         std::lock_guard<std::mutex> lock(mutex_);
398         if (callback) {
399             CallBackfunc = callback;
400         }
401     }
402 
403     if (CallBackfunc) {
404         CallBackfunc(err, state);
405     }
406 }
SetExtraOptions(const Socket::TCPExtraOptions & tcpExtraOptions,const TlsSocket::SetExtraOptionsCallback & callback)407 bool TLSSocketServer::SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
408                                       const TlsSocket::SetExtraOptionsCallback &callback)
409 {
410     if (tcpExtraOptions.IsKeepAlive()) {
411         int keepalive = 1;
412         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
413             return false;
414         }
415     }
416 
417     if (tcpExtraOptions.IsOOBInline()) {
418         int oobInline = 1;
419         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
420             return false;
421         }
422     }
423 
424     if (tcpExtraOptions.IsTCPNoDelay()) {
425         int tcpNoDelay = 1;
426         if (setsockopt(listenSocketFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
427             return false;
428         }
429     }
430 
431     linger soLinger = {0};
432     soLinger.l_onoff = tcpExtraOptions.socketLinger.IsOn();
433     soLinger.l_linger = (int)tcpExtraOptions.socketLinger.GetLinger();
434     if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
435         return false;
436     }
437 
438     return true;
439 }
440 
SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions & config)441 void TLSSocketServer::SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
442 {
443     TLSServerConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
444                                           config.GetTlsSecureOptions().GetKeyPass());
445     TLSServerConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
446     TLSServerConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
447 
448     TLSServerConfiguration_.SetVerifyMode(config.GetTlsSecureOptions().GetVerifyMode());
449 
450     const auto protocolVec = config.GetTlsSecureOptions().GetProtocolChain();
451     if (!protocolVec.empty()) {
452         TLSServerConfiguration_.SetProtocol(protocolVec);
453     }
454 }
455 
GetCertificate(const TlsSocket::GetCertificateCallback & callback)456 void TLSSocketServer::GetCertificate(const TlsSocket::GetCertificateCallback &callback)
457 {
458     const auto &cert = TLSServerConfiguration_.GetCertificate();
459     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
460     if (!cert.data.Length()) {
461         CallOnErrorCallback(-1, "cert not data Length");
462         callback(-1, {});
463         return;
464     }
465     callback(TlsSocket::TLSSOCKET_SUCCESS, cert);
466 }
467 
GetRemoteCertificate(const int socketFd,const TlsSocket::GetRemoteCertificateCallback & callback)468 void TLSSocketServer::GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback)
469 {
470     auto connect_iterator = clientIdConnections_.find(socketFd);
471     if (connect_iterator == clientIdConnections_.end()) {
472         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
473         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
474         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
475         return;
476     }
477     auto connect = connect_iterator->second;
478     const auto &remoteCert = connect->GetRemoteCertRawData();
479     if (!remoteCert.data.Length()) {
480         int resErr = ConvertSSLError(connect->GetSSL());
481         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
482         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
483         callback(resErr, {});
484         return;
485     }
486     callback(TlsSocket::TLSSOCKET_SUCCESS, remoteCert);
487 }
488 
GetProtocol(const TlsSocket::GetProtocolCallback & callback)489 void TLSSocketServer::GetProtocol(const TlsSocket::GetProtocolCallback &callback)
490 {
491     if (TLSServerConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
492         callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V13);
493         return;
494     }
495     callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V12);
496 }
497 
GetCipherSuite(const int socketFd,const TlsSocket::GetCipherSuiteCallback & callback)498 void TLSSocketServer::GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback)
499 {
500     auto connect_iterator = clientIdConnections_.find(socketFd);
501     if (connect_iterator == clientIdConnections_.end()) {
502         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
503         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
504         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
505         return;
506     }
507     auto connect = connect_iterator->second;
508     auto cipherSuite = connect->GetCipherSuite();
509     if (cipherSuite.empty()) {
510         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
511         int resErr = ConvertSSLError(connect->GetSSL());
512         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
513         callback(resErr, cipherSuite);
514         return;
515     }
516     callback(TlsSocket::TLSSOCKET_SUCCESS, cipherSuite);
517 }
518 
GetSignatureAlgorithms(const int socketFd,const TlsSocket::GetSignatureAlgorithmsCallback & callback)519 void TLSSocketServer::GetSignatureAlgorithms(const int socketFd,
520                                              const TlsSocket::GetSignatureAlgorithmsCallback &callback)
521 {
522     auto connect_iterator = clientIdConnections_.find(socketFd);
523     if (connect_iterator == clientIdConnections_.end()) {
524         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
525         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
526         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
527         return;
528     }
529     auto connect = connect_iterator->second;
530     auto signatureAlgorithms = connect->GetSignatureAlgorithms();
531     if (signatureAlgorithms.empty()) {
532         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
533         int resErr = ConvertSSLError(connect->GetSSL());
534         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
535         callback(resErr, signatureAlgorithms);
536         return;
537     }
538     callback(TlsSocket::TLSSOCKET_SUCCESS, signatureAlgorithms);
539 }
540 
OnMessage(const OnMessageCallback & onMessageCallback)541 void TLSSocketServer::Connection::OnMessage(const OnMessageCallback &onMessageCallback)
542 {
543     onMessageCallback_ = onMessageCallback;
544 }
545 
OnClose(const OnCloseCallback & onCloseCallback)546 void TLSSocketServer::Connection::OnClose(const OnCloseCallback &onCloseCallback)
547 {
548     onCloseCallback_ = onCloseCallback;
549 }
550 
OnConnect(const OnConnectCallback & onConnectCallback)551 void TLSSocketServer::OnConnect(const OnConnectCallback &onConnectCallback)
552 {
553     std::lock_guard<std::mutex> lock(mutex_);
554     onConnectCallback_ = onConnectCallback;
555 }
556 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)557 void TLSSocketServer::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
558 {
559     std::lock_guard<std::mutex> lock(mutex_);
560     onErrorCallback_ = onErrorCallback;
561 }
562 
OffMessage()563 void TLSSocketServer::Connection::OffMessage()
564 {
565     if (onMessageCallback_) {
566         onMessageCallback_ = nullptr;
567     }
568 }
569 
OffConnect()570 void TLSSocketServer::OffConnect()
571 {
572     std::lock_guard<std::mutex> lock(mutex_);
573     if (onConnectCallback_) {
574         onConnectCallback_ = nullptr;
575     }
576 }
577 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)578 void TLSSocketServer::Connection::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
579 {
580     onErrorCallback_ = onErrorCallback;
581 }
582 
OffClose()583 void TLSSocketServer::Connection::OffClose()
584 {
585     if (onCloseCallback_) {
586         onCloseCallback_ = nullptr;
587     }
588 }
589 
OffError()590 void TLSSocketServer::Connection::OffError()
591 {
592     onErrorCallback_ = nullptr;
593 }
594 
CallOnErrorCallback(int32_t err,const std::string & errString)595 void TLSSocketServer::Connection::CallOnErrorCallback(int32_t err, const std::string &errString)
596 {
597     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
598     {
599         if (onErrorCallback_) {
600             CallBackfunc = onErrorCallback_;
601         }
602     }
603 
604     if (CallBackfunc) {
605         CallBackfunc(err, errString);
606     }
607 }
OffError()608 void TLSSocketServer::OffError()
609 {
610     std::lock_guard<std::mutex> lock(mutex_);
611     if (onErrorCallback_) {
612         onErrorCallback_ = nullptr;
613     }
614 }
615 
MakeIpSocket(sa_family_t family)616 void TLSSocketServer::MakeIpSocket(sa_family_t family)
617 {
618     if (family != AF_INET && family != AF_INET6) {
619         return;
620     }
621     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
622     if (sock < 0) {
623         int resErr = ConvertErrno();
624         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
625         CallOnErrorCallback(resErr, MakeErrnoString());
626         return;
627     }
628     listenSocketFd_ = sock;
629 }
630 
CallOnErrorCallback(int32_t err,const std::string & errString)631 void TLSSocketServer::CallOnErrorCallback(int32_t err, const std::string &errString)
632 {
633     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
634     {
635         std::lock_guard<std::mutex> lock(mutex_);
636         if (onErrorCallback_) {
637             CallBackfunc = onErrorCallback_;
638         }
639     }
640 
641     if (CallBackfunc) {
642         CallBackfunc(err, errString);
643     }
644 }
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)645 void TLSSocketServer::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6,
646                               sockaddr **addr, socklen_t *len)
647 {
648     if (!addr6 || !addr4 || !len) {
649         return;
650     }
651     sa_family_t family = address.GetSaFamily();
652     if (family == AF_INET) {
653         addr4->sin_family = AF_INET;
654         addr4->sin_port = htons(address.GetPort());
655         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
656         *addr = reinterpret_cast<sockaddr *>(addr4);
657         *len = sizeof(sockaddr_in);
658     } else if (family == AF_INET6) {
659         addr6->sin6_family = AF_INET6;
660         addr6->sin6_port = htons(address.GetPort());
661         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
662         *addr = reinterpret_cast<sockaddr *>(addr6);
663         *len = sizeof(sockaddr_in6);
664     }
665 }
666 
GetConnectionByClientID(int clientid)667 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientID(int clientid)
668 {
669     std::shared_ptr<Connection> ptrConnection = nullptr;
670 
671     auto it = clientIdConnections_.find(clientid);
672     if (it != clientIdConnections_.end()) {
673         ptrConnection = it->second;
674     }
675 
676     return ptrConnection;
677 }
678 
GetConnectionClientCount()679 int TLSSocketServer::GetConnectionClientCount()
680 {
681     return g_userCounter;
682 }
683 
CallListenCallback(int32_t err,ListenCallback callback)684 void TLSSocketServer::CallListenCallback(int32_t err, ListenCallback callback)
685 {
686     ListenCallback CallBackfunc = nullptr;
687     {
688         std::lock_guard<std::mutex> lock(mutex_);
689         if (callback) {
690             CallBackfunc = callback;
691         }
692     }
693 
694     if (CallBackfunc) {
695         CallBackfunc(err);
696     }
697 }
698 
SetAddress(const Socket::NetAddress address)699 void TLSSocketServer::Connection::SetAddress(const Socket::NetAddress address)
700 {
701     address_ = address;
702 }
703 
GetRemoteCertRawData() const704 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetRemoteCertRawData() const
705 {
706     return remoteRawData_;
707 }
708 
~Connection()709 TLSSocketServer::Connection::~Connection()
710 {
711     Close();
712 }
713 
TlsAcceptToHost(int sock,const TlsSocket::TLSConnectOptions & options)714 bool TLSSocketServer::Connection::TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options)
715 {
716     SetTlsConfiguration(options);
717     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
718     if (!cipherSuite.empty()) {
719         connectionConfiguration_.SetCipherSuite(cipherSuite);
720     }
721     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
722     if (!signatureAlgorithms.empty()) {
723         connectionConfiguration_.SetSignatureAlgorithms(signatureAlgorithms);
724     }
725     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
726     if (!protocolVec.empty()) {
727         connectionConfiguration_.SetProtocol(protocolVec);
728     }
729     connectionConfiguration_.SetVerifyMode(options.GetTlsSecureOptions().GetVerifyMode());
730     socketFd_ = sock;
731     return StartTlsAccept(options);
732 }
733 
SetTlsConfiguration(const TlsSocket::TLSConnectOptions & config)734 void TLSSocketServer::Connection::SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
735 {
736     connectionConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
737                                            config.GetTlsSecureOptions().GetKeyPass());
738     connectionConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
739     connectionConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
740 }
741 
Send(const std::string & data)742 bool TLSSocketServer::Connection::Send(const std::string &data)
743 {
744     NETSTACK_LOGD("data to send :%{public}s", data.c_str());
745     if (data.empty()) {
746         NETSTACK_LOGE("data is empty");
747         return false;
748     }
749     if (!ssl_) {
750         NETSTACK_LOGE("ssl is null");
751         return false;
752     }
753     int len = SSL_write(ssl_, data.c_str(), data.length());
754     if (len < 0) {
755         int resErr = ConvertSSLError(GetSSL());
756         NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
757                       data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
758         return false;
759     }
760     NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
761     return true;
762 }
763 
Recv(char * buffer,int maxBufferSize)764 int TLSSocketServer::Connection::Recv(char *buffer, int maxBufferSize)
765 {
766     if (!ssl_) {
767         NETSTACK_LOGE("ssl is null");
768         return SSL_ERROR_RETURN;
769     }
770     return SSL_read(ssl_, buffer, maxBufferSize);
771 }
772 
Close()773 bool TLSSocketServer::Connection::Close()
774 {
775     if (!ssl_) {
776         NETSTACK_LOGE("ssl is null");
777         return false;
778     }
779     int result = SSL_shutdown(ssl_);
780     if (result < 0) {
781         int resErr = ConvertSSLError(GetSSL());
782         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
783                       MakeSSLErrorString(resErr).c_str());
784     }
785     SSL_free(ssl_);
786     ssl_ = nullptr;
787     if (socketFd_ != -1) {
788         shutdown(socketFd_, SHUT_RDWR);
789         close(socketFd_);
790         socketFd_ = -1;
791     }
792     if (!tlsContextServerPointer_) {
793         NETSTACK_LOGE("Tls context pointer is null");
794         return false;
795     }
796     tlsContextServerPointer_->CloseCtx();
797     return true;
798 }
799 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)800 bool TLSSocketServer::Connection::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
801 {
802     if (!ssl_) {
803         NETSTACK_LOGE("ssl is null");
804         return false;
805     }
806     size_t pos = 0;
807     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
808                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
809     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
810     for (const auto &str : alpnProtocols) {
811         len = str.length();
812         result[pos++] = len;
813         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
814             NETSTACK_LOGE("strcpy_s failed");
815             return false;
816         }
817         pos += len;
818     }
819     result[pos] = '\0';
820 
821     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
822     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
823         int resErr = ConvertSSLError(GetSSL());
824         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
825                       MakeSSLErrorString(resErr).c_str());
826         return false;
827     }
828     return true;
829 }
830 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)831 void TLSSocketServer::Connection::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
832 {
833     remoteInfo.SetAddress(address_.GetAddress());
834     remoteInfo.SetPort(address_.GetPort());
835     remoteInfo.SetFamily(address_.GetSaFamily());
836 }
837 
GetTlsConfiguration() const838 TlsSocket::TLSConfiguration TLSSocketServer::Connection::GetTlsConfiguration() const
839 {
840     return connectionConfiguration_;
841 }
842 
GetCipherSuite() const843 std::vector<std::string> TLSSocketServer::Connection::GetCipherSuite() const
844 {
845     if (!ssl_) {
846         NETSTACK_LOGE("ssl in null");
847         return {};
848     }
849     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
850     if (!sk) {
851         NETSTACK_LOGE("get ciphers failed");
852         return {};
853     }
854     TlsSocket::CipherSuite cipherSuite;
855     std::vector<std::string> cipherSuiteVec;
856     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
857         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
858         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
859         cipherSuiteVec.push_back(cipherSuite.cipherName_);
860     }
861     return cipherSuiteVec;
862 }
863 
GetRemoteCertificate() const864 std::string TLSSocketServer::Connection::GetRemoteCertificate() const
865 {
866     return remoteCert_;
867 }
868 
GetCertificate() const869 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetCertificate() const
870 {
871     return connectionConfiguration_.GetCertificate();
872 }
873 
GetSignatureAlgorithms() const874 std::vector<std::string> TLSSocketServer::Connection::GetSignatureAlgorithms() const
875 {
876     return signatureAlgorithms_;
877 }
878 
GetProtocol() const879 std::string TLSSocketServer::Connection::GetProtocol() const
880 {
881     if (!ssl_) {
882         NETSTACK_LOGE("ssl in null");
883         return PROTOCOL_UNKNOW;
884     }
885     if (connectionConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
886         return TlsSocket::PROTOCOL_TLS_V13;
887     }
888     return TlsSocket::PROTOCOL_TLS_V12;
889 }
890 
SetSharedSigals()891 bool TLSSocketServer::Connection::SetSharedSigals()
892 {
893     if (!ssl_) {
894         NETSTACK_LOGE("ssl is null");
895         return false;
896     }
897     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
898     if (!number) {
899         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
900         return false;
901     }
902     for (int i = 0; i < number; i++) {
903         int hash_nid;
904         int sign_nid;
905         std::string sig_with_md;
906         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
907         switch (sign_nid) {
908             case EVP_PKEY_RSA:
909                 sig_with_md = SIGN_NID_RSA;
910                 break;
911             case EVP_PKEY_RSA_PSS:
912                 sig_with_md = SIGN_NID_RSA_PSS;
913                 break;
914             case EVP_PKEY_DSA:
915                 sig_with_md = SIGN_NID_DSA;
916                 break;
917             case EVP_PKEY_EC:
918                 sig_with_md = SIGN_NID_ECDSA;
919                 break;
920             case NID_ED25519:
921                 sig_with_md = SIGN_NID_ED;
922                 break;
923             case NID_ED448:
924                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
925                 break;
926             default:
927                 const char *sn = OBJ_nid2sn(sign_nid);
928                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
929         }
930         const char *sn_hash = OBJ_nid2sn(hash_nid);
931         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
932         signatureAlgorithms_.push_back(sig_with_md);
933     }
934     return true;
935 }
936 
GetSSL() const937 ssl_st *TLSSocketServer::Connection::GetSSL() const
938 {
939     return ssl_;
940 }
941 
GetAddress() const942 Socket::NetAddress TLSSocketServer::Connection::GetAddress() const
943 {
944     return address_;
945 }
946 
GetSocketFd() const947 int TLSSocketServer::Connection::GetSocketFd() const
948 {
949     return socketFd_;
950 }
951 
GetEventManager() const952 std::shared_ptr<EventManager> TLSSocketServer::Connection::GetEventManager() const
953 {
954     return eventManager_;
955 }
956 
SetEventManager(std::shared_ptr<EventManager> eventManager)957 void TLSSocketServer::Connection::SetEventManager(std::shared_ptr<EventManager> eventManager)
958 {
959     eventManager_ = eventManager;
960 }
961 
SetClientID(int32_t clientID)962 void TLSSocketServer::Connection::SetClientID(int32_t clientID)
963 {
964     clientID_ = clientID;
965 }
966 
GetClientID()967 int TLSSocketServer::Connection::GetClientID()
968 {
969     return clientID_;
970 }
971 
StartTlsAccept(const TlsSocket::TLSConnectOptions & options)972 bool TLSSocketServer::Connection::StartTlsAccept(const TlsSocket::TLSConnectOptions &options)
973 {
974     if (!CreatTlsContext()) {
975         NETSTACK_LOGE("failed to create tls context");
976         return false;
977     }
978     if (!StartShakingHands(options)) {
979         NETSTACK_LOGE("failed to shaking hands");
980         return false;
981     }
982     return true;
983 }
984 
CreatTlsContext()985 bool TLSSocketServer::Connection::CreatTlsContext()
986 {
987     tlsContextServerPointer_ = TlsSocket::TLSContextServer::CreateConfiguration(connectionConfiguration_);
988     if (!tlsContextServerPointer_) {
989         NETSTACK_LOGE("failed to create tls context pointer");
990         return false;
991     }
992     if (!(ssl_ = tlsContextServerPointer_->CreateSsl())) {
993         NETSTACK_LOGE("failed to create ssl session");
994         return false;
995     }
996     SSL_set_fd(ssl_, socketFd_);
997     SSL_set_accept_state(ssl_);
998     return true;
999 }
1000 
StartShakingHands(const TlsSocket::TLSConnectOptions & options)1001 bool TLSSocketServer::Connection::StartShakingHands(const TlsSocket::TLSConnectOptions &options)
1002 {
1003     if (!ssl_) {
1004         NETSTACK_LOGE("ssl is null");
1005         return false;
1006     }
1007     int result = SSL_accept(ssl_);
1008     if (result == -1) {
1009         int errorStatus = ConvertSSLError(ssl_);
1010         NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1011                       MakeSSLErrorString(errorStatus).c_str());
1012         return false;
1013     }
1014 
1015     std::vector<std::string> SslProtocolVer({SSL_get_version(ssl_)});
1016     connectionConfiguration_.SetProtocol({SslProtocolVer});
1017 
1018     std::string list = SSL_get_cipher_list(ssl_, 0);
1019     NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1020     connectionConfiguration_.SetCipherSuite(list);
1021     if (!SetSharedSigals()) {
1022         NETSTACK_LOGE("Failed to set sharedSigalgs");
1023     }
1024 
1025     if (!GetRemoteCertificateFromPeer()) {
1026         NETSTACK_LOGE("Failed to get remote certificate");
1027     }
1028     if (peerX509_ != nullptr) {
1029         NETSTACK_LOGE("peer x509Certificates is null");
1030 
1031         if (!SetRemoteCertRawData()) {
1032             NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1033         }
1034         TlsSocket::CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1035         if (!checkServerIdentity) {
1036             CheckServerIdentityLegal(hostName_, peerX509_);
1037         } else {
1038             checkServerIdentity(hostName_, {remoteCert_});
1039         }
1040     }
1041     return true;
1042 }
1043 
GetRemoteCertificateFromPeer()1044 bool TLSSocketServer::Connection::GetRemoteCertificateFromPeer()
1045 {
1046     peerX509_ = SSL_get_peer_certificate(ssl_);
1047 
1048     if (SSL_get_verify_result(ssl_) == X509_V_OK) {
1049         NETSTACK_LOGE("SSL_get_verify_result ==X509_V_OK");
1050     }
1051 
1052     if (peerX509_ == nullptr) {
1053         int resErr = ConvertSSLError(GetSSL());
1054         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1055                       MakeSSLErrorString(resErr).c_str());
1056         return false;
1057     }
1058     BIO *bio = BIO_new(BIO_s_mem());
1059     if (!bio) {
1060         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1061         return false;
1062     }
1063     X509_print(bio, peerX509_);
1064     char data[REMOTE_CERT_LEN] = {0};
1065     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1066         NETSTACK_LOGE("BIO_read function returns error");
1067         BIO_free(bio);
1068         return false;
1069     }
1070     BIO_free(bio);
1071     remoteCert_ = std::string(data);
1072     return true;
1073 }
1074 
SetRemoteCertRawData()1075 bool TLSSocketServer::Connection::SetRemoteCertRawData()
1076 {
1077     if (peerX509_ == nullptr) {
1078         NETSTACK_LOGE("peerX509 is null");
1079         return false;
1080     }
1081     int32_t length = i2d_X509(peerX509_, nullptr);
1082     if (length <= 0) {
1083         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1084         return false;
1085     }
1086     unsigned char *der = nullptr;
1087     (void)i2d_X509(peerX509_, &der);
1088     TlsSocket::SecureData data(der, length);
1089     remoteRawData_.data = data;
1090     OPENSSL_free(der);
1091     remoteRawData_.encodingFormat = TlsSocket::EncodingFormat::DER;
1092     return true;
1093 }
1094 
/**
 * Tells whether s begins with prefix; an empty prefix matches every string.
 */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> & dnsNames,std::vector<std::string> & ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1099 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> &dnsNames, std::vector<std::string> &ips,
1100                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1101 {
1102     bool valid = false;
1103     std::string reason = UNKNOW_REASON;
1104     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1105     if (IsIP(hostName)) {
1106         auto it = find(ips.begin(), ips.end(), hostName);
1107         if (it == ips.end()) {
1108             reason = IP + hostName + " is not in the cert's list";
1109         }
1110         result = {valid, reason};
1111         return;
1112     }
1113     std::string tempHostName = "" + hostName;
1114     std::string tmpStr = "";
1115     if (!dnsNames.empty() || index > 0) {
1116         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1117         if (!dnsNames.empty()) {
1118             valid = SeekIntersection(hostParts, dnsNames);
1119             tmpStr = ". is not in the cert's altnames";
1120         } else {
1121             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1122             X509_NAME *pSubName = nullptr;
1123             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1124             if (len > 0) {
1125                 std::vector<std::string> commonNameVec;
1126                 commonNameVec.emplace_back(commonNameBuf);
1127                 valid = SeekIntersection(hostParts, commonNameVec);
1128                 tmpStr = ". is not cert's CN";
1129             }
1130         }
1131         if (!valid) {
1132             reason = HOST_NAME + tempHostName + tmpStr;
1133         }
1134 
1135         result = {valid, reason};
1136         return;
1137     }
1138     reason = "Cert does not contain a DNS name";
1139     result = {valid, reason};
1140 }
1141 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1142 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName,
1143                                                                   const X509 *x509Certificates)
1144 {
1145     std::string hostname = hostName;
1146     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1147     if (!subjectName) {
1148         return "subject name is null";
1149     }
1150     char subNameBuf[BUF_SIZE] = {0};
1151     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1152     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1153     if (index < 0) {
1154         return "X509 get ext nid error";
1155     }
1156     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1157     if (ext == nullptr) {
1158         return "X509 get ext error";
1159     }
1160     ASN1_OBJECT *obj = nullptr;
1161     obj = X509_EXTENSION_get_object(ext);
1162     char subAltNameBuf[BUF_SIZE] = {0};
1163     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1164     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1165 
1166     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1167 }
1168 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1169 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1170                                                                   const X509 *x509Certificates)
1171 {
1172     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1173     std::string altNames = reinterpret_cast<char *>(extData->data);
1174     std::string hostname = "" + hostName;
1175     BIO *bio = BIO_new(BIO_s_file());
1176     if (!bio) {
1177         return "bio is null";
1178     }
1179     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1180     ASN1_STRING_print(bio, extData);
1181     std::vector<std::string> dnsNames = {};
1182     std::vector<std::string> ips = {};
1183     constexpr int DNS_NAME_IDX = 4;
1184     constexpr int IP_NAME_IDX = 11;
1185     if (!altNames.empty()) {
1186         std::vector<std::string> splitAltNames;
1187         if (altNames.find('\"') != std::string::npos) {
1188             splitAltNames = SplitEscapedAltNames(altNames);
1189         } else {
1190             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1191         }
1192         for (auto const &iter : splitAltNames) {
1193             if (StartsWith(iter, DNS)) {
1194                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1195             } else if (StartsWith(iter, IP_ADDRESS)) {
1196                 ips.push_back(iter.substr(IP_NAME_IDX));
1197             }
1198         }
1199     }
1200     std::tuple<bool, std::string> result;
1201     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1202     if (!std::get<0>(result)) {
1203         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1204     }
1205     return HOST_NAME + hostname + ". is cert's CN";
1206 }
1207 
RemoveConnect(int socketFd)1208 void TLSSocketServer::RemoveConnect(int socketFd)
1209 {
1210     std::shared_ptr<Connection> ptrConnection = nullptr;
1211     {
1212         std::lock_guard<std::mutex> its_lock(connectMutex_);
1213 
1214         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1215             if (it->second->GetSocketFd() == socketFd) {
1216                 break;
1217             } else {
1218                 ++it;
1219             }
1220         }
1221     }
1222     if (ptrConnection != nullptr) {
1223         ptrConnection->CallOnCloseCallback(static_cast<unsigned int>(socketFd));
1224         ptrConnection->Close();
1225     }
1226 }
1227 
RecvRemoteInfo(int socketFd,int index)1228 int TLSSocketServer::RecvRemoteInfo(int socketFd, int index)
1229 {
1230     {
1231         std::lock_guard<std::mutex> its_lock(connectMutex_);
1232         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1233             if (it->second->GetSocketFd() == socketFd) {
1234                 char buffer[MAX_BUFFER_SIZE];
1235                 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
1236                     NETSTACK_LOGE("memcpy_s failed");
1237                     break;
1238                 }
1239                 int len = it->second->Recv(buffer, MAX_BUFFER_SIZE);
1240                 NETSTACK_LOGE("revc message is size is  %{public}d  buffer is   %{public}s ", len, buffer);
1241                 if (len > 0) {
1242                     Socket::SocketRemoteInfo remoteInfo;
1243                     remoteInfo.SetSize(strlen(buffer));
1244                     it->second->MakeRemoteInfo(remoteInfo);
1245                     it->second->CallOnMessageCallback(socketFd, buffer, remoteInfo);
1246                     return len;
1247                 }
1248                 break;
1249             } else {
1250                 ++it;
1251             }
1252         }
1253     }
1254     RemoveConnect(socketFd);
1255     DropFdFromPollList(index);
1256     return -1;
1257 }
1258 
CallOnMessageCallback(int32_t socketFd,const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)1259 void TLSSocketServer::Connection::CallOnMessageCallback(int32_t socketFd, const std::string &data,
1260                                                         const Socket::SocketRemoteInfo &remoteInfo)
1261 {
1262     OnMessageCallback CallBackfunc = nullptr;
1263     {
1264         if (onMessageCallback_) {
1265             CallBackfunc = onMessageCallback_;
1266         }
1267     }
1268 
1269     if (CallBackfunc) {
1270         CallBackfunc(socketFd, data, remoteInfo);
1271     }
1272 }
1273 
AddConnect(int socketFd,std::shared_ptr<Connection> connection)1274 void TLSSocketServer::AddConnect(int socketFd, std::shared_ptr<Connection> connection)
1275 {
1276     std::lock_guard<std::mutex> its_lock(connectMutex_);
1277     clientIdConnections_[connection->GetClientID()] = connection;
1278 }
1279 
CallOnCloseCallback(const int32_t socketFd)1280 void TLSSocketServer::Connection::CallOnCloseCallback(const int32_t socketFd)
1281 {
1282     OnCloseCallback CallBackfunc = nullptr;
1283     {
1284         if (onCloseCallback_) {
1285             CallBackfunc = onCloseCallback_;
1286         }
1287     }
1288 
1289     if (CallBackfunc) {
1290         CallBackfunc(socketFd);
1291     }
1292 }
1293 
CallOnConnectCallback(const int32_t socketFd,std::shared_ptr<EventManager> eventManager)1294 void TLSSocketServer::CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager)
1295 {
1296     OnConnectCallback CallBackfunc = nullptr;
1297     {
1298         std::lock_guard<std::mutex> lock(mutex_);
1299         if (onConnectCallback_) {
1300             CallBackfunc = onConnectCallback_;
1301         }
1302     }
1303 
1304     if (CallBackfunc) {
1305         CallBackfunc(socketFd, eventManager);
1306     } else {
1307         NETSTACK_LOGE("CallOnConnectCallback  fun === null");
1308     }
1309 }
1310 
ProcessTcpAccept(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID)1311 void TLSSocketServer::ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID)
1312 {
1313 #if !defined(CROSS_PLATFORM)
1314     struct sockaddr_in clientAddress;
1315     socklen_t clientAddrLength = sizeof(clientAddress);
1316     int connectFD = accept(listenSocketFd_, (struct sockaddr *)&clientAddress, &clientAddrLength);
1317     if (connectFD < 0) {
1318         int resErr = ConvertErrno();
1319         NETSTACK_LOGE("Server accept new client ERROR");
1320         CallOnErrorCallback(resErr, MakeErrnoString());
1321         return;
1322     }
1323     NETSTACK_LOGI("Server accept new client SUCCESS");
1324     std::shared_ptr<Connection> connection = std::make_shared<Connection>();
1325     Socket::NetAddress netAddress;
1326     char clientIp[INET_ADDRSTRLEN];
1327     inet_ntop(address_.GetSaFamily(), &clientAddress.sin_addr, clientIp, INET_ADDRSTRLEN);
1328     int clientPort = ntohs(clientAddress.sin_port);
1329     netAddress.SetAddress(clientIp);
1330     netAddress.SetPort(clientPort);
1331     netAddress.SetFamilyBySaFamily(address_.GetSaFamily());
1332     connection->SetAddress(netAddress);
1333     connection->SetClientID(clientID);
1334     auto res = connection->TlsAcceptToHost(connectFD, tlsListenOptions);
1335     if (!res) {
1336         int resErr = ConvertSSLError(connection->GetSSL());
1337         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1338         return;
1339     }
1340     if (g_userCounter >= USER_LIMIT) {
1341         const std::string info = "Too many users!";
1342         connection->Send(info);
1343         connection->Close();
1344         NETSTACK_LOGE("Too many users");
1345         close(connectFD);
1346         CallOnErrorCallback(-1, "Too many users");
1347         return;
1348     }
1349     g_userCounter++;
1350     fds_[g_userCounter].fd = connectFD;
1351     fds_[g_userCounter].events = POLLIN | POLLRDHUP | POLLERR;
1352     fds_[g_userCounter].revents = 0;
1353     AddConnect(connectFD, connection);
1354     auto ptrEventManager = std::make_shared<EventManager>();
1355 
1356     ptrEventManager->SetData(this);
1357     connection->SetEventManager(ptrEventManager);
1358     CallOnConnectCallback(clientID, ptrEventManager);
1359     NETSTACK_LOGI("New client come in, fd is %{public}d", connectFD);
1360 #endif
1361 }
1362 
InitPollList(int & listendFd)1363 void TLSSocketServer::InitPollList(int &listendFd)
1364 {
1365     for (int i = 1; i <= USER_LIMIT; ++i) {
1366         fds_[i].fd = -1;
1367         fds_[i].events = 0;
1368     }
1369     fds_[0].fd = listendFd;
1370     fds_[0].events = POLLIN | POLLERR;
1371     fds_[0].revents = 0;
1372 }
1373 
DropFdFromPollList(int & fd_index)1374 void TLSSocketServer::DropFdFromPollList(int &fd_index)
1375 {
1376     if (g_userCounter < 0) {
1377         NETSTACK_LOGE("g_userCounter  = %{public}d", g_userCounter);
1378         return;
1379     }
1380     fds_[fd_index].fd = fds_[g_userCounter].fd;
1381 
1382     fds_[g_userCounter].fd = -1;
1383     fds_[g_userCounter].events = 0;
1384     fd_index--;
1385     g_userCounter--;
1386     NETSTACK_LOGE("CallOnConnectCallback  g_userCounter  = %{public}d", g_userCounter);
1387 }
1388 
PollThread(const TlsSocket::TLSConnectOptions & tlsListenOptions)1389 void TLSSocketServer::PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions)
1390 {
1391 #if !defined(CROSS_PLATFORM)
1392     int on = 1;
1393     isRunning_ = true;
1394     ioctl(listenSocketFd_, FIONBIO, (char *)&on);
1395     NETSTACK_LOGE("PollThread  start working %{public}d", isRunning_);
1396     std::thread thread_([this, &tlsListenOptions]() {
1397         TlsSocket::TLSConnectOptions tlsOption = tlsListenOptions;
1398         InitPollList(listenSocketFd_);
1399         int clientId = 0;
1400         while (isRunning_) {
1401             int ret = poll(fds_, g_userCounter + 1, POLL_WAIT_TIME);
1402             if (ret < 0) {
1403                 int resErr = ConvertErrno();
1404                 NETSTACK_LOGE("Poll ERROR");
1405                 CallOnErrorCallback(resErr, MakeErrnoString());
1406                 break;
1407             }
1408             if (ret == 0) {
1409                 continue;
1410             }
1411             for (int i = 0; i < g_userCounter + 1; ++i) {
1412                 if ((fds_[i].fd == listenSocketFd_) && (static_cast<uint16_t>(fds_[i].revents) & POLLIN)) {
1413                     ProcessTcpAccept(tlsOption, ++clientId);
1414                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLRDHUP) ||
1415                            (static_cast<uint16_t>(fds_[i].revents) & POLLERR)) {
1416                     RemoveConnect(fds_[i].fd);
1417                     DropFdFromPollList(i);
1418                     NETSTACK_LOGI("A client left");
1419                 } else if (static_cast<uint16_t>(fds_[i].revents) & POLLIN) {
1420                     RecvRemoteInfo(fds_[i].fd, i);
1421                 }
1422             }
1423         }
1424     });
1425     thread_.detach();
1426 #endif
1427 }
1428 
GetConnectionByClientEventManager(const EventManager * eventManager)1429 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientEventManager(
1430     const EventManager *eventManager)
1431 {
1432     std::shared_ptr<TLSSocketServer::Connection> ptrConnection = nullptr;
1433     std::lock_guard<std::mutex> its_lock(connectMutex_);
1434     for (const auto &it : clientIdConnections_) {
1435         if (it.second->GetEventManager().get() == eventManager) {
1436             ptrConnection = it.second;
1437             return ptrConnection;
1438         }
1439     }
1440     return ptrConnection;
1441 }
1442 
CloseConnectionByEventManager(EventManager * eventManager)1443 void TLSSocketServer::CloseConnectionByEventManager(EventManager *eventManager)
1444 {
1445     std::shared_ptr<Connection> ptrConnection = GetConnectionByClientEventManager(eventManager);
1446 
1447     if (ptrConnection != nullptr) {
1448         ptrConnection->Close();
1449     }
1450 }
1451 
DeleteConnectionByEventManager(EventManager * eventManager)1452 void TLSSocketServer::DeleteConnectionByEventManager(EventManager *eventManager)
1453 {
1454     std::lock_guard<std::mutex> its_lock(connectMutex_);
1455     for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) {
1456         if (it->second->GetEventManager().get() == eventManager) {
1457             it = clientIdConnections_.erase(it);
1458             break;
1459         }
1460     }
1461 }
1462 } // namespace TlsSocketServer
1463 } // namespace NetStack
1464 } // namespace OHOS
1465