1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tls_socket_server.h"
17
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdlib>
#include <limits>
#include <memory>
#include <numeric>
#include <regex>
#include <unordered_set>

#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <securec.h>
#include <sys/ioctl.h>
28
29 #include "base_context.h"
30 #include "netstack_common_utils.h"
31 #include "netstack_log.h"
32 #include "tls.h"
33
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocketServer {
37 namespace {
constexpr int SSL_RET_CODE = 0;

constexpr int BUF_SIZE = 2048;
constexpr int POLL_WAIT_TIME = 2000;
constexpr int OFFSET = 2;
constexpr int SSL_ERROR_RETURN = -1;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int LISETEN_COUNT = 516;
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *DNS = "DNS:";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
// A JSON string literal (quotes included) anchored at the start of the searched text.
// std::regex has no JavaScript-style /.../ delimiters: slashes there would be matched
// literally and push '^' into the middle of the pattern, so nothing could ever match.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address; each octet is restricted to 0-255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
int g_userCounter = 0;
69
IsIP(const std::string & ip)70 bool IsIP(const std::string &ip)
71 {
72 std::regex pattern(PATTERN);
73 std::smatch res;
74 return regex_match(ip, res, pattern);
75 }
76
SplitHostName(std::string & hostName)77 std::vector<std::string> SplitHostName(std::string &hostName)
78 {
79 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
80 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
81 }
82
/**
 * Returns true when `vecA` and `vecB` share at least one element.
 * Works on unsorted input: std::set_intersection requires both ranges to be sorted,
 * which the callers' vectors are not guaranteed to be. Neither input is modified.
 */
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    if (vecA.empty() || vecB.empty()) {
        return false;
    }
    const std::unordered_set<std::string> lookup(vecA.begin(), vecA.end());
    return std::any_of(vecB.begin(), vecB.end(),
                       [&lookup](const std::string &item) { return lookup.count(item) != 0; });
}
89
ConvertErrno()90 int ConvertErrno()
91 {
92 return TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE + errno;
93 }
94
ConvertSSLError(ssl_st * ssl)95 int ConvertSSLError(ssl_st *ssl)
96 {
97 if (!ssl) {
98 return TlsSocket::TLS_ERR_SSL_NULL;
99 }
100 return TlsSocket::TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
101 }
102
/** Returns the strerror() text for the calling thread's current errno. */
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
107
MakeSSLErrorString(int error)108 std::string MakeSSLErrorString(int error)
109 {
110 char err[TlsSocket::MAX_ERR_LEN] = {0};
111 ERR_error_string_n(error - TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
112 return err;
113 }
SplitEscapedAltNames(std::string & altNames)114 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
115 {
116 std::vector<std::string> result;
117 std::string currentToken;
118 size_t offset = 0;
119 while (offset != altNames.length()) {
120 auto nextSep = altNames.find_first_of(", ");
121 auto nextQuote = altNames.find_first_of('\"');
122 if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
123 currentToken += altNames.substr(offset, nextQuote);
124 std::regex jsonStringPattern(JSON_STRING_PATTERN);
125 std::smatch match;
126 std::string altNameSubStr = altNames.substr(nextQuote);
127 bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
128 if (!ret) {
129 return {""};
130 }
131 currentToken += result[0];
132 offset = nextQuote + result[0].length();
133 } else if (nextSep != std::string::npos) {
134 currentToken += altNames.substr(offset, nextSep);
135 result.push_back(currentToken);
136 currentToken = "";
137 offset = nextSep + OFFSET;
138 } else {
139 currentToken += altNames.substr(offset);
140 offset = altNames.length();
141 }
142 }
143 result.push_back(currentToken);
144 return result;
145 }
146 } // namespace
147
SetSocket(const int & socketFd)148 void TLSServerSendOptions::SetSocket(const int &socketFd)
149 {
150 socketFd_ = socketFd;
151 }
152
SetSendData(const std::string & data)153 void TLSServerSendOptions::SetSendData(const std::string &data)
154 {
155 data_ = data;
156 }
157
GetSocket() const158 const int &TLSServerSendOptions::GetSocket() const
159 {
160 return socketFd_;
161 }
162
GetSendData() const163 const std::string &TLSServerSendOptions::GetSendData() const
164 {
165 return data_;
166 }
167
~TLSSocketServer()168 TLSSocketServer::~TLSSocketServer()
169 {
170 isRunning_ = false;
171 clientIdConnections_.clear();
172
173 if (listenSocketFd_ != -1) {
174 shutdown(listenSocketFd_, SHUT_RDWR);
175 close(listenSocketFd_);
176 listenSocketFd_ = -1;
177 }
178 }
179
Listen(const TlsSocket::TLSConnectOptions & tlsListenOptions,const ListenCallback & callback)180 void TLSSocketServer::Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback)
181 {
182 if (!CommonUtils::HasInternetPermission()) {
183 CallListenCallback(PERMISSION_DENIED_CODE, callback);
184 return;
185 }
186 NETSTACK_LOGE("Listen 1 %{public}d", listenSocketFd_);
187 if (listenSocketFd_ >= 0) {
188 CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
189 return;
190 }
191 NETSTACK_LOGE("Listen 2 %{public}d", listenSocketFd_);
192 if (ExecBind(tlsListenOptions.GetNetAddress(), callback)) {
193 NETSTACK_LOGE("Listen 3 %{public}d", listenSocketFd_);
194 ExecAccept(tlsListenOptions, callback);
195 } else {
196 shutdown(listenSocketFd_, SHUT_RDWR);
197 close(listenSocketFd_);
198 listenSocketFd_ = -1;
199 }
200
201 PollThread(tlsListenOptions);
202 }
203
ExecBind(const Socket::NetAddress & address,const ListenCallback & callback)204 bool TLSSocketServer::ExecBind(const Socket::NetAddress &address, const ListenCallback &callback)
205 {
206 MakeIpSocket(address.GetSaFamily());
207 if (listenSocketFd_ < 0) {
208 int resErr = ConvertErrno();
209 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
210 CallOnErrorCallback(resErr, MakeErrnoString());
211 CallListenCallback(resErr, callback);
212 return false;
213 }
214 sockaddr_in addr4 = {0};
215 sockaddr_in6 addr6 = {0};
216 sockaddr *addr = nullptr;
217 socklen_t len;
218 GetAddr(address, &addr4, &addr6, &addr, &len);
219 if (addr == nullptr) {
220 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
221 CallOnErrorCallback(-1, "Address Is Invalid");
222 CallListenCallback(ConvertErrno(), callback);
223 return false;
224 }
225 if (bind(listenSocketFd_, addr, len) < 0) {
226 if (errno != EADDRINUSE) {
227 NETSTACK_LOGE("bind error is %{public}s %{public}d", strerror(errno), errno);
228 CallOnErrorCallback(-1, "Address binding failed");
229 CallListenCallback(ConvertErrno(), callback);
230 return false;
231 }
232 if (addr->sa_family == AF_INET) {
233 NETSTACK_LOGI("distribute a random port");
234 addr4.sin_port = 0; /* distribute a random port */
235 } else if (addr->sa_family == AF_INET6) {
236 NETSTACK_LOGI("distribute a random port");
237 addr6.sin6_port = 0; /* distribute a random port */
238 }
239 if (bind(listenSocketFd_, addr, len) < 0) {
240 NETSTACK_LOGE("rebind error is %{public}s %{public}d", strerror(errno), errno);
241 CallOnErrorCallback(-1, "Duplicate binding address failed");
242 CallListenCallback(ConvertErrno(), callback);
243 return false;
244 }
245 NETSTACK_LOGI("rebind success");
246 }
247 NETSTACK_LOGI("bind success");
248 address_ = address;
249 return true;
250 }
251
ExecAccept(const TlsSocket::TLSConnectOptions & tlsAcceptOptions,const ListenCallback & callback)252 void TLSSocketServer::ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback)
253 {
254 if (listenSocketFd_ < 0) {
255 int resErr = ConvertErrno();
256 NETSTACK_LOGE("accept error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
257 CallOnErrorCallback(resErr, MakeErrnoString());
258 callback(resErr);
259 return;
260 }
261 SetLocalTlsConfiguration(tlsAcceptOptions);
262 int ret = 0;
263
264 NETSTACK_LOGE(
265 "accept error is listenSocketFd_= %{public}d LISETEN_COUNT =%{public}d .GetVerifyMode() = %{public}d ",
266 listenSocketFd_, LISETEN_COUNT, tlsAcceptOptions.GetTlsSecureOptions().GetVerifyMode());
267 ret = listen(listenSocketFd_, LISETEN_COUNT);
268 if (ret < 0) {
269 int resErr = ConvertErrno();
270 NETSTACK_LOGE("tcp server listen error");
271 CallOnErrorCallback(resErr, MakeErrnoString());
272 callback(resErr);
273 return;
274 }
275 CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
276 }
277
Send(const TLSServerSendOptions & data,const TlsSocket::SendCallback & callback)278 bool TLSSocketServer::Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback)
279 {
280 int socketFd = data.GetSocket();
281 std::string info = data.GetSendData();
282
283 auto connect_iterator = clientIdConnections_.find(socketFd);
284 if (connect_iterator == clientIdConnections_.end()) {
285 NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
286 CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
287 return false;
288 }
289 auto connect = connect_iterator->second;
290 auto res = connect->Send(info);
291 if (!res) {
292 int resErr = ConvertSSLError(connect->GetSSL());
293 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
294 CallSendCallback(resErr, callback);
295 return false;
296 }
297 CallSendCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
298 return res;
299 }
300
CallSendCallback(int32_t err,TlsSocket::SendCallback callback)301 void TLSSocketServer::CallSendCallback(int32_t err, TlsSocket::SendCallback callback)
302 {
303 TlsSocket::SendCallback CallBackfunc = nullptr;
304 {
305 std::lock_guard<std::mutex> lock(mutex_);
306 if (callback) {
307 CallBackfunc = callback;
308 }
309 }
310
311 if (CallBackfunc) {
312 CallBackfunc(err);
313 }
314 }
315
Close(const int socketFd,const TlsSocket::CloseCallback & callback)316 void TLSSocketServer::Close(const int socketFd, const TlsSocket::CloseCallback &callback)
317 {
318 {
319 std::lock_guard<std::mutex> its_lock(connectMutex_);
320 for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
321 if (it->first == socketFd) {
322 auto res = it->second->Close();
323 if (!res) {
324 int resErr = ConvertSSLError(it->second->GetSSL());
325 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
326 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
327 callback(resErr);
328 return;
329 }
330 callback(TlsSocket::TLSSOCKET_SUCCESS);
331 return;
332 } else {
333 ++it;
334 }
335 }
336 }
337 NETSTACK_LOGE("socket = %{public}d There is no corresponding socketFd", socketFd);
338 CallOnErrorCallback(-1, "The send failed with no corresponding socketFd");
339 callback(TlsSocket::TLS_ERR_SYS_EINVAL);
340 }
341
Stop(const TlsSocket::CloseCallback & callback)342 void TLSSocketServer::Stop(const TlsSocket::CloseCallback &callback)
343 {
344 std::lock_guard<std::mutex> its_lock(connectMutex_);
345 for (const auto &c : clientIdConnections_) {
346 c.second->Close();
347 }
348 clientIdConnections_.clear();
349 close(listenSocketFd_);
350 listenSocketFd_ = -1;
351 callback(TlsSocket::TLSSOCKET_SUCCESS);
352 }
353
GetRemoteAddress(const int socketFd,const TlsSocket::GetRemoteAddressCallback & callback)354 void TLSSocketServer::GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback)
355 {
356 auto connect_iterator = clientIdConnections_.find(socketFd);
357 if (connect_iterator == clientIdConnections_.end()) {
358 NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
359 CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
360 callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
361 return;
362 }
363 auto connect = connect_iterator->second;
364 auto address = connect->GetAddress();
365 callback(TlsSocket::TLSSOCKET_SUCCESS, address);
366 }
367
GetState(const TlsSocket::GetStateCallback & callback)368 void TLSSocketServer::GetState(const TlsSocket::GetStateCallback &callback)
369 {
370 int opt;
371 socklen_t optLen = sizeof(int);
372 int r = getsockopt(listenSocketFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
373 if (r < 0) {
374 Socket::SocketStateBase state;
375 state.SetIsClose(true);
376 CallGetStateCallback(ConvertErrno(), state, callback);
377 return;
378 }
379 sockaddr sockAddr = {0};
380 socklen_t len = sizeof(sockaddr);
381 Socket::SocketStateBase state;
382 int ret = getsockname(listenSocketFd_, &sockAddr, &len);
383 state.SetIsBound(ret == 0);
384 ret = getpeername(listenSocketFd_, &sockAddr, &len);
385 if (ret != 0) {
386 NETSTACK_LOGE("getpeername failed");
387 }
388 state.SetIsConnected(GetConnectionClientCount() > 0);
389 CallGetStateCallback(TlsSocket::TLSSOCKET_SUCCESS, state, callback);
390 }
391
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,TlsSocket::GetStateCallback callback)392 void TLSSocketServer::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state,
393 TlsSocket::GetStateCallback callback)
394 {
395 TlsSocket::GetStateCallback CallBackfunc = nullptr;
396 {
397 std::lock_guard<std::mutex> lock(mutex_);
398 if (callback) {
399 CallBackfunc = callback;
400 }
401 }
402
403 if (CallBackfunc) {
404 CallBackfunc(err, state);
405 }
406 }
SetExtraOptions(const Socket::TCPExtraOptions & tcpExtraOptions,const TlsSocket::SetExtraOptionsCallback & callback)407 bool TLSSocketServer::SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
408 const TlsSocket::SetExtraOptionsCallback &callback)
409 {
410 if (tcpExtraOptions.IsKeepAlive()) {
411 int keepalive = 1;
412 if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
413 return false;
414 }
415 }
416
417 if (tcpExtraOptions.IsOOBInline()) {
418 int oobInline = 1;
419 if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
420 return false;
421 }
422 }
423
424 if (tcpExtraOptions.IsTCPNoDelay()) {
425 int tcpNoDelay = 1;
426 if (setsockopt(listenSocketFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
427 return false;
428 }
429 }
430
431 linger soLinger = {0};
432 soLinger.l_onoff = tcpExtraOptions.socketLinger.IsOn();
433 soLinger.l_linger = (int)tcpExtraOptions.socketLinger.GetLinger();
434 if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
435 return false;
436 }
437
438 return true;
439 }
440
SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions & config)441 void TLSSocketServer::SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
442 {
443 TLSServerConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
444 config.GetTlsSecureOptions().GetKeyPass());
445 TLSServerConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
446 TLSServerConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
447
448 TLSServerConfiguration_.SetVerifyMode(config.GetTlsSecureOptions().GetVerifyMode());
449
450 const auto protocolVec = config.GetTlsSecureOptions().GetProtocolChain();
451 if (!protocolVec.empty()) {
452 TLSServerConfiguration_.SetProtocol(protocolVec);
453 }
454 }
455
GetCertificate(const TlsSocket::GetCertificateCallback & callback)456 void TLSSocketServer::GetCertificate(const TlsSocket::GetCertificateCallback &callback)
457 {
458 const auto &cert = TLSServerConfiguration_.GetCertificate();
459 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
460 if (!cert.data.Length()) {
461 CallOnErrorCallback(-1, "cert not data Length");
462 callback(-1, {});
463 return;
464 }
465 callback(TlsSocket::TLSSOCKET_SUCCESS, cert);
466 }
467
GetRemoteCertificate(const int socketFd,const TlsSocket::GetRemoteCertificateCallback & callback)468 void TLSSocketServer::GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback)
469 {
470 auto connect_iterator = clientIdConnections_.find(socketFd);
471 if (connect_iterator == clientIdConnections_.end()) {
472 NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
473 CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
474 callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
475 return;
476 }
477 auto connect = connect_iterator->second;
478 const auto &remoteCert = connect->GetRemoteCertRawData();
479 if (!remoteCert.data.Length()) {
480 int resErr = ConvertSSLError(connect->GetSSL());
481 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
482 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
483 callback(resErr, {});
484 return;
485 }
486 callback(TlsSocket::TLSSOCKET_SUCCESS, remoteCert);
487 }
488
GetProtocol(const TlsSocket::GetProtocolCallback & callback)489 void TLSSocketServer::GetProtocol(const TlsSocket::GetProtocolCallback &callback)
490 {
491 if (TLSServerConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
492 callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V13);
493 return;
494 }
495 callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V12);
496 }
497
GetCipherSuite(const int socketFd,const TlsSocket::GetCipherSuiteCallback & callback)498 void TLSSocketServer::GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback)
499 {
500 auto connect_iterator = clientIdConnections_.find(socketFd);
501 if (connect_iterator == clientIdConnections_.end()) {
502 NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
503 CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
504 callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
505 return;
506 }
507 auto connect = connect_iterator->second;
508 auto cipherSuite = connect->GetCipherSuite();
509 if (cipherSuite.empty()) {
510 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
511 int resErr = ConvertSSLError(connect->GetSSL());
512 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
513 callback(resErr, cipherSuite);
514 return;
515 }
516 callback(TlsSocket::TLSSOCKET_SUCCESS, cipherSuite);
517 }
518
GetSignatureAlgorithms(const int socketFd,const TlsSocket::GetSignatureAlgorithmsCallback & callback)519 void TLSSocketServer::GetSignatureAlgorithms(const int socketFd,
520 const TlsSocket::GetSignatureAlgorithmsCallback &callback)
521 {
522 auto connect_iterator = clientIdConnections_.find(socketFd);
523 if (connect_iterator == clientIdConnections_.end()) {
524 NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
525 CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
526 callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
527 return;
528 }
529 auto connect = connect_iterator->second;
530 auto signatureAlgorithms = connect->GetSignatureAlgorithms();
531 if (signatureAlgorithms.empty()) {
532 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
533 int resErr = ConvertSSLError(connect->GetSSL());
534 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
535 callback(resErr, signatureAlgorithms);
536 return;
537 }
538 callback(TlsSocket::TLSSOCKET_SUCCESS, signatureAlgorithms);
539 }
540
OnMessage(const OnMessageCallback & onMessageCallback)541 void TLSSocketServer::Connection::OnMessage(const OnMessageCallback &onMessageCallback)
542 {
543 onMessageCallback_ = onMessageCallback;
544 }
545
OnClose(const OnCloseCallback & onCloseCallback)546 void TLSSocketServer::Connection::OnClose(const OnCloseCallback &onCloseCallback)
547 {
548 onCloseCallback_ = onCloseCallback;
549 }
550
OnConnect(const OnConnectCallback & onConnectCallback)551 void TLSSocketServer::OnConnect(const OnConnectCallback &onConnectCallback)
552 {
553 std::lock_guard<std::mutex> lock(mutex_);
554 onConnectCallback_ = onConnectCallback;
555 }
556
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)557 void TLSSocketServer::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
558 {
559 std::lock_guard<std::mutex> lock(mutex_);
560 onErrorCallback_ = onErrorCallback;
561 }
562
OffMessage()563 void TLSSocketServer::Connection::OffMessage()
564 {
565 if (onMessageCallback_) {
566 onMessageCallback_ = nullptr;
567 }
568 }
569
OffConnect()570 void TLSSocketServer::OffConnect()
571 {
572 std::lock_guard<std::mutex> lock(mutex_);
573 if (onConnectCallback_) {
574 onConnectCallback_ = nullptr;
575 }
576 }
577
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)578 void TLSSocketServer::Connection::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
579 {
580 onErrorCallback_ = onErrorCallback;
581 }
582
OffClose()583 void TLSSocketServer::Connection::OffClose()
584 {
585 if (onCloseCallback_) {
586 onCloseCallback_ = nullptr;
587 }
588 }
589
OffError()590 void TLSSocketServer::Connection::OffError()
591 {
592 onErrorCallback_ = nullptr;
593 }
594
CallOnErrorCallback(int32_t err,const std::string & errString)595 void TLSSocketServer::Connection::CallOnErrorCallback(int32_t err, const std::string &errString)
596 {
597 TlsSocket::OnErrorCallback CallBackfunc = nullptr;
598 {
599 if (onErrorCallback_) {
600 CallBackfunc = onErrorCallback_;
601 }
602 }
603
604 if (CallBackfunc) {
605 CallBackfunc(err, errString);
606 }
607 }
OffError()608 void TLSSocketServer::OffError()
609 {
610 std::lock_guard<std::mutex> lock(mutex_);
611 if (onErrorCallback_) {
612 onErrorCallback_ = nullptr;
613 }
614 }
615
MakeIpSocket(sa_family_t family)616 void TLSSocketServer::MakeIpSocket(sa_family_t family)
617 {
618 if (family != AF_INET && family != AF_INET6) {
619 return;
620 }
621 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
622 if (sock < 0) {
623 int resErr = ConvertErrno();
624 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
625 CallOnErrorCallback(resErr, MakeErrnoString());
626 return;
627 }
628 listenSocketFd_ = sock;
629 }
630
CallOnErrorCallback(int32_t err,const std::string & errString)631 void TLSSocketServer::CallOnErrorCallback(int32_t err, const std::string &errString)
632 {
633 TlsSocket::OnErrorCallback CallBackfunc = nullptr;
634 {
635 std::lock_guard<std::mutex> lock(mutex_);
636 if (onErrorCallback_) {
637 CallBackfunc = onErrorCallback_;
638 }
639 }
640
641 if (CallBackfunc) {
642 CallBackfunc(err, errString);
643 }
644 }
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)645 void TLSSocketServer::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6,
646 sockaddr **addr, socklen_t *len)
647 {
648 if (!addr6 || !addr4 || !len) {
649 return;
650 }
651 sa_family_t family = address.GetSaFamily();
652 if (family == AF_INET) {
653 addr4->sin_family = AF_INET;
654 addr4->sin_port = htons(address.GetPort());
655 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
656 *addr = reinterpret_cast<sockaddr *>(addr4);
657 *len = sizeof(sockaddr_in);
658 } else if (family == AF_INET6) {
659 addr6->sin6_family = AF_INET6;
660 addr6->sin6_port = htons(address.GetPort());
661 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
662 *addr = reinterpret_cast<sockaddr *>(addr6);
663 *len = sizeof(sockaddr_in6);
664 }
665 }
666
GetConnectionByClientID(int clientid)667 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientID(int clientid)
668 {
669 std::shared_ptr<Connection> ptrConnection = nullptr;
670
671 auto it = clientIdConnections_.find(clientid);
672 if (it != clientIdConnections_.end()) {
673 ptrConnection = it->second;
674 }
675
676 return ptrConnection;
677 }
678
GetConnectionClientCount()679 int TLSSocketServer::GetConnectionClientCount()
680 {
681 return g_userCounter;
682 }
683
CallListenCallback(int32_t err,ListenCallback callback)684 void TLSSocketServer::CallListenCallback(int32_t err, ListenCallback callback)
685 {
686 ListenCallback CallBackfunc = nullptr;
687 {
688 std::lock_guard<std::mutex> lock(mutex_);
689 if (callback) {
690 CallBackfunc = callback;
691 }
692 }
693
694 if (CallBackfunc) {
695 CallBackfunc(err);
696 }
697 }
698
SetAddress(const Socket::NetAddress address)699 void TLSSocketServer::Connection::SetAddress(const Socket::NetAddress address)
700 {
701 address_ = address;
702 }
703
GetRemoteCertRawData() const704 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetRemoteCertRawData() const
705 {
706 return remoteRawData_;
707 }
708
~Connection()709 TLSSocketServer::Connection::~Connection()
710 {
711 Close();
712 }
713
TlsAcceptToHost(int sock,const TlsSocket::TLSConnectOptions & options)714 bool TLSSocketServer::Connection::TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options)
715 {
716 SetTlsConfiguration(options);
717 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
718 if (!cipherSuite.empty()) {
719 connectionConfiguration_.SetCipherSuite(cipherSuite);
720 }
721 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
722 if (!signatureAlgorithms.empty()) {
723 connectionConfiguration_.SetSignatureAlgorithms(signatureAlgorithms);
724 }
725 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
726 if (!protocolVec.empty()) {
727 connectionConfiguration_.SetProtocol(protocolVec);
728 }
729 connectionConfiguration_.SetVerifyMode(options.GetTlsSecureOptions().GetVerifyMode());
730 socketFd_ = sock;
731 return StartTlsAccept(options);
732 }
733
SetTlsConfiguration(const TlsSocket::TLSConnectOptions & config)734 void TLSSocketServer::Connection::SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
735 {
736 connectionConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
737 config.GetTlsSecureOptions().GetKeyPass());
738 connectionConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
739 connectionConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
740 }
741
Send(const std::string & data)742 bool TLSSocketServer::Connection::Send(const std::string &data)
743 {
744 NETSTACK_LOGD("data to send :%{public}s", data.c_str());
745 if (data.empty()) {
746 NETSTACK_LOGE("data is empty");
747 return false;
748 }
749 if (!ssl_) {
750 NETSTACK_LOGE("ssl is null");
751 return false;
752 }
753 int len = SSL_write(ssl_, data.c_str(), data.length());
754 if (len < 0) {
755 int resErr = ConvertSSLError(GetSSL());
756 NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
757 data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
758 return false;
759 }
760 NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
761 return true;
762 }
763
Recv(char * buffer,int maxBufferSize)764 int TLSSocketServer::Connection::Recv(char *buffer, int maxBufferSize)
765 {
766 if (!ssl_) {
767 NETSTACK_LOGE("ssl is null");
768 return SSL_ERROR_RETURN;
769 }
770 return SSL_read(ssl_, buffer, maxBufferSize);
771 }
772
Close()773 bool TLSSocketServer::Connection::Close()
774 {
775 if (!ssl_) {
776 NETSTACK_LOGE("ssl is null");
777 return false;
778 }
779 int result = SSL_shutdown(ssl_);
780 if (result < 0) {
781 int resErr = ConvertSSLError(GetSSL());
782 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
783 MakeSSLErrorString(resErr).c_str());
784 }
785 SSL_free(ssl_);
786 ssl_ = nullptr;
787 if (socketFd_ != -1) {
788 shutdown(socketFd_, SHUT_RDWR);
789 close(socketFd_);
790 socketFd_ = -1;
791 }
792 if (!tlsContextServerPointer_) {
793 NETSTACK_LOGE("Tls context pointer is null");
794 return false;
795 }
796 tlsContextServerPointer_->CloseCtx();
797 return true;
798 }
799
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)800 bool TLSSocketServer::Connection::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
801 {
802 if (!ssl_) {
803 NETSTACK_LOGE("ssl is null");
804 return false;
805 }
806 size_t pos = 0;
807 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
808 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
809 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
810 for (const auto &str : alpnProtocols) {
811 len = str.length();
812 result[pos++] = len;
813 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
814 NETSTACK_LOGE("strcpy_s failed");
815 return false;
816 }
817 pos += len;
818 }
819 result[pos] = '\0';
820
821 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
822 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
823 int resErr = ConvertSSLError(GetSSL());
824 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
825 MakeSSLErrorString(resErr).c_str());
826 return false;
827 }
828 return true;
829 }
830
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)831 void TLSSocketServer::Connection::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
832 {
833 remoteInfo.SetAddress(address_.GetAddress());
834 remoteInfo.SetPort(address_.GetPort());
835 remoteInfo.SetFamily(address_.GetSaFamily());
836 }
837
GetTlsConfiguration() const838 TlsSocket::TLSConfiguration TLSSocketServer::Connection::GetTlsConfiguration() const
839 {
840 return connectionConfiguration_;
841 }
842
GetCipherSuite() const843 std::vector<std::string> TLSSocketServer::Connection::GetCipherSuite() const
844 {
845 if (!ssl_) {
846 NETSTACK_LOGE("ssl in null");
847 return {};
848 }
849 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
850 if (!sk) {
851 NETSTACK_LOGE("get ciphers failed");
852 return {};
853 }
854 TlsSocket::CipherSuite cipherSuite;
855 std::vector<std::string> cipherSuiteVec;
856 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
857 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
858 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
859 cipherSuiteVec.push_back(cipherSuite.cipherName_);
860 }
861 return cipherSuiteVec;
862 }
863
GetRemoteCertificate() const864 std::string TLSSocketServer::Connection::GetRemoteCertificate() const
865 {
866 return remoteCert_;
867 }
868
GetCertificate() const869 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetCertificate() const
870 {
871 return connectionConfiguration_.GetCertificate();
872 }
873
GetSignatureAlgorithms() const874 std::vector<std::string> TLSSocketServer::Connection::GetSignatureAlgorithms() const
875 {
876 return signatureAlgorithms_;
877 }
878
GetProtocol() const879 std::string TLSSocketServer::Connection::GetProtocol() const
880 {
881 if (!ssl_) {
882 NETSTACK_LOGE("ssl in null");
883 return PROTOCOL_UNKNOW;
884 }
885 if (connectionConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
886 return TlsSocket::PROTOCOL_TLS_V13;
887 }
888 return TlsSocket::PROTOCOL_TLS_V12;
889 }
890
SetSharedSigals()891 bool TLSSocketServer::Connection::SetSharedSigals()
892 {
893 if (!ssl_) {
894 NETSTACK_LOGE("ssl is null");
895 return false;
896 }
897 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
898 if (!number) {
899 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
900 return false;
901 }
902 for (int i = 0; i < number; i++) {
903 int hash_nid;
904 int sign_nid;
905 std::string sig_with_md;
906 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
907 switch (sign_nid) {
908 case EVP_PKEY_RSA:
909 sig_with_md = SIGN_NID_RSA;
910 break;
911 case EVP_PKEY_RSA_PSS:
912 sig_with_md = SIGN_NID_RSA_PSS;
913 break;
914 case EVP_PKEY_DSA:
915 sig_with_md = SIGN_NID_DSA;
916 break;
917 case EVP_PKEY_EC:
918 sig_with_md = SIGN_NID_ECDSA;
919 break;
920 case NID_ED25519:
921 sig_with_md = SIGN_NID_ED;
922 break;
923 case NID_ED448:
924 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
925 break;
926 default:
927 const char *sn = OBJ_nid2sn(sign_nid);
928 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
929 }
930 const char *sn_hash = OBJ_nid2sn(hash_nid);
931 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
932 signatureAlgorithms_.push_back(sig_with_md);
933 }
934 return true;
935 }
936
GetSSL() const937 ssl_st *TLSSocketServer::Connection::GetSSL() const
938 {
939 return ssl_;
940 }
941
GetAddress() const942 Socket::NetAddress TLSSocketServer::Connection::GetAddress() const
943 {
944 return address_;
945 }
946
GetSocketFd() const947 int TLSSocketServer::Connection::GetSocketFd() const
948 {
949 return socketFd_;
950 }
951
GetEventManager() const952 std::shared_ptr<EventManager> TLSSocketServer::Connection::GetEventManager() const
953 {
954 return eventManager_;
955 }
956
SetEventManager(std::shared_ptr<EventManager> eventManager)957 void TLSSocketServer::Connection::SetEventManager(std::shared_ptr<EventManager> eventManager)
958 {
959 eventManager_ = eventManager;
960 }
961
SetClientID(int32_t clientID)962 void TLSSocketServer::Connection::SetClientID(int32_t clientID)
963 {
964 clientID_ = clientID;
965 }
966
GetClientID()967 int TLSSocketServer::Connection::GetClientID()
968 {
969 return clientID_;
970 }
971
StartTlsAccept(const TlsSocket::TLSConnectOptions & options)972 bool TLSSocketServer::Connection::StartTlsAccept(const TlsSocket::TLSConnectOptions &options)
973 {
974 if (!CreatTlsContext()) {
975 NETSTACK_LOGE("failed to create tls context");
976 return false;
977 }
978 if (!StartShakingHands(options)) {
979 NETSTACK_LOGE("failed to shaking hands");
980 return false;
981 }
982 return true;
983 }
984
CreatTlsContext()985 bool TLSSocketServer::Connection::CreatTlsContext()
986 {
987 tlsContextServerPointer_ = TlsSocket::TLSContextServer::CreateConfiguration(connectionConfiguration_);
988 if (!tlsContextServerPointer_) {
989 NETSTACK_LOGE("failed to create tls context pointer");
990 return false;
991 }
992 if (!(ssl_ = tlsContextServerPointer_->CreateSsl())) {
993 NETSTACK_LOGE("failed to create ssl session");
994 return false;
995 }
996 SSL_set_fd(ssl_, socketFd_);
997 SSL_set_accept_state(ssl_);
998 return true;
999 }
1000
StartShakingHands(const TlsSocket::TLSConnectOptions & options)1001 bool TLSSocketServer::Connection::StartShakingHands(const TlsSocket::TLSConnectOptions &options)
1002 {
1003 if (!ssl_) {
1004 NETSTACK_LOGE("ssl is null");
1005 return false;
1006 }
1007 int result = SSL_accept(ssl_);
1008 if (result == -1) {
1009 int errorStatus = ConvertSSLError(ssl_);
1010 NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1011 MakeSSLErrorString(errorStatus).c_str());
1012 return false;
1013 }
1014
1015 std::vector<std::string> SslProtocolVer({SSL_get_version(ssl_)});
1016 connectionConfiguration_.SetProtocol({SslProtocolVer});
1017
1018 std::string list = SSL_get_cipher_list(ssl_, 0);
1019 NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1020 connectionConfiguration_.SetCipherSuite(list);
1021 if (!SetSharedSigals()) {
1022 NETSTACK_LOGE("Failed to set sharedSigalgs");
1023 }
1024
1025 if (!GetRemoteCertificateFromPeer()) {
1026 NETSTACK_LOGE("Failed to get remote certificate");
1027 }
1028 if (peerX509_ != nullptr) {
1029 NETSTACK_LOGE("peer x509Certificates is null");
1030
1031 if (!SetRemoteCertRawData()) {
1032 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1033 }
1034 TlsSocket::CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1035 if (!checkServerIdentity) {
1036 CheckServerIdentityLegal(hostName_, peerX509_);
1037 } else {
1038 checkServerIdentity(hostName_, {remoteCert_});
1039 }
1040 }
1041 return true;
1042 }
1043
GetRemoteCertificateFromPeer()1044 bool TLSSocketServer::Connection::GetRemoteCertificateFromPeer()
1045 {
1046 peerX509_ = SSL_get_peer_certificate(ssl_);
1047
1048 if (SSL_get_verify_result(ssl_) == X509_V_OK) {
1049 NETSTACK_LOGE("SSL_get_verify_result ==X509_V_OK");
1050 }
1051
1052 if (peerX509_ == nullptr) {
1053 int resErr = ConvertSSLError(GetSSL());
1054 NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1055 MakeSSLErrorString(resErr).c_str());
1056 return false;
1057 }
1058 BIO *bio = BIO_new(BIO_s_mem());
1059 if (!bio) {
1060 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1061 return false;
1062 }
1063 X509_print(bio, peerX509_);
1064 char data[REMOTE_CERT_LEN] = {0};
1065 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1066 NETSTACK_LOGE("BIO_read function returns error");
1067 BIO_free(bio);
1068 return false;
1069 }
1070 BIO_free(bio);
1071 remoteCert_ = std::string(data);
1072 return true;
1073 }
1074
SetRemoteCertRawData()1075 bool TLSSocketServer::Connection::SetRemoteCertRawData()
1076 {
1077 if (peerX509_ == nullptr) {
1078 NETSTACK_LOGE("peerX509 is null");
1079 return false;
1080 }
1081 int32_t length = i2d_X509(peerX509_, nullptr);
1082 if (length <= 0) {
1083 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1084 return false;
1085 }
1086 unsigned char *der = nullptr;
1087 (void)i2d_X509(peerX509_, &der);
1088 TlsSocket::SecureData data(der, length);
1089 remoteRawData_.data = data;
1090 OPENSSL_free(der);
1091 remoteRawData_.encodingFormat = TlsSocket::EncodingFormat::DER;
1092 return true;
1093 }
1094
/**
 * Reports whether s begins with prefix (an empty prefix always matches).
 */
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> & dnsNames,std::vector<std::string> & ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1099 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> &dnsNames, std::vector<std::string> &ips,
1100 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1101 {
1102 bool valid = false;
1103 std::string reason = UNKNOW_REASON;
1104 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1105 if (IsIP(hostName)) {
1106 auto it = find(ips.begin(), ips.end(), hostName);
1107 if (it == ips.end()) {
1108 reason = IP + hostName + " is not in the cert's list";
1109 }
1110 result = {valid, reason};
1111 return;
1112 }
1113 std::string tempHostName = "" + hostName;
1114 std::string tmpStr = "";
1115 if (!dnsNames.empty() || index > 0) {
1116 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1117 if (!dnsNames.empty()) {
1118 valid = SeekIntersection(hostParts, dnsNames);
1119 tmpStr = ". is not in the cert's altnames";
1120 } else {
1121 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1122 X509_NAME *pSubName = nullptr;
1123 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1124 if (len > 0) {
1125 std::vector<std::string> commonNameVec;
1126 commonNameVec.emplace_back(commonNameBuf);
1127 valid = SeekIntersection(hostParts, commonNameVec);
1128 tmpStr = ". is not cert's CN";
1129 }
1130 }
1131 if (!valid) {
1132 reason = HOST_NAME + tempHostName + tmpStr;
1133 }
1134
1135 result = {valid, reason};
1136 return;
1137 }
1138 reason = "Cert does not contain a DNS name";
1139 result = {valid, reason};
1140 }
1141
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1142 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName,
1143 const X509 *x509Certificates)
1144 {
1145 std::string hostname = hostName;
1146 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1147 if (!subjectName) {
1148 return "subject name is null";
1149 }
1150 char subNameBuf[BUF_SIZE] = {0};
1151 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1152 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1153 if (index < 0) {
1154 return "X509 get ext nid error";
1155 }
1156 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1157 if (ext == nullptr) {
1158 return "X509 get ext error";
1159 }
1160 ASN1_OBJECT *obj = nullptr;
1161 obj = X509_EXTENSION_get_object(ext);
1162 char subAltNameBuf[BUF_SIZE] = {0};
1163 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1164 NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1165
1166 return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1167 }
1168
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1169 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1170 const X509 *x509Certificates)
1171 {
1172 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1173 std::string altNames = reinterpret_cast<char *>(extData->data);
1174 std::string hostname = "" + hostName;
1175 BIO *bio = BIO_new(BIO_s_file());
1176 if (!bio) {
1177 return "bio is null";
1178 }
1179 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1180 ASN1_STRING_print(bio, extData);
1181 std::vector<std::string> dnsNames = {};
1182 std::vector<std::string> ips = {};
1183 constexpr int DNS_NAME_IDX = 4;
1184 constexpr int IP_NAME_IDX = 11;
1185 if (!altNames.empty()) {
1186 std::vector<std::string> splitAltNames;
1187 if (altNames.find('\"') != std::string::npos) {
1188 splitAltNames = SplitEscapedAltNames(altNames);
1189 } else {
1190 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1191 }
1192 for (auto const &iter : splitAltNames) {
1193 if (StartsWith(iter, DNS)) {
1194 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1195 } else if (StartsWith(iter, IP_ADDRESS)) {
1196 ips.push_back(iter.substr(IP_NAME_IDX));
1197 }
1198 }
1199 }
1200 std::tuple<bool, std::string> result;
1201 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1202 if (!std::get<0>(result)) {
1203 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1204 }
1205 return HOST_NAME + hostname + ". is cert's CN";
1206 }
1207
RemoveConnect(int socketFd)1208 void TLSSocketServer::RemoveConnect(int socketFd)
1209 {
1210 std::shared_ptr<Connection> ptrConnection = nullptr;
1211 {
1212 std::lock_guard<std::mutex> its_lock(connectMutex_);
1213
1214 for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1215 if (it->second->GetSocketFd() == socketFd) {
1216 break;
1217 } else {
1218 ++it;
1219 }
1220 }
1221 }
1222 if (ptrConnection != nullptr) {
1223 ptrConnection->CallOnCloseCallback(static_cast<unsigned int>(socketFd));
1224 ptrConnection->Close();
1225 }
1226 }
1227
RecvRemoteInfo(int socketFd,int index)1228 int TLSSocketServer::RecvRemoteInfo(int socketFd, int index)
1229 {
1230 {
1231 std::lock_guard<std::mutex> its_lock(connectMutex_);
1232 for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1233 if (it->second->GetSocketFd() == socketFd) {
1234 char buffer[MAX_BUFFER_SIZE];
1235 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
1236 NETSTACK_LOGE("memcpy_s failed");
1237 break;
1238 }
1239 int len = it->second->Recv(buffer, MAX_BUFFER_SIZE);
1240 NETSTACK_LOGE("revc message is size is %{public}d buffer is %{public}s ", len, buffer);
1241 if (len > 0) {
1242 Socket::SocketRemoteInfo remoteInfo;
1243 remoteInfo.SetSize(strlen(buffer));
1244 it->second->MakeRemoteInfo(remoteInfo);
1245 it->second->CallOnMessageCallback(socketFd, buffer, remoteInfo);
1246 return len;
1247 }
1248 break;
1249 } else {
1250 ++it;
1251 }
1252 }
1253 }
1254 RemoveConnect(socketFd);
1255 DropFdFromPollList(index);
1256 return -1;
1257 }
1258
CallOnMessageCallback(int32_t socketFd,const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)1259 void TLSSocketServer::Connection::CallOnMessageCallback(int32_t socketFd, const std::string &data,
1260 const Socket::SocketRemoteInfo &remoteInfo)
1261 {
1262 OnMessageCallback CallBackfunc = nullptr;
1263 {
1264 if (onMessageCallback_) {
1265 CallBackfunc = onMessageCallback_;
1266 }
1267 }
1268
1269 if (CallBackfunc) {
1270 CallBackfunc(socketFd, data, remoteInfo);
1271 }
1272 }
1273
AddConnect(int socketFd,std::shared_ptr<Connection> connection)1274 void TLSSocketServer::AddConnect(int socketFd, std::shared_ptr<Connection> connection)
1275 {
1276 std::lock_guard<std::mutex> its_lock(connectMutex_);
1277 clientIdConnections_[connection->GetClientID()] = connection;
1278 }
1279
CallOnCloseCallback(const int32_t socketFd)1280 void TLSSocketServer::Connection::CallOnCloseCallback(const int32_t socketFd)
1281 {
1282 OnCloseCallback CallBackfunc = nullptr;
1283 {
1284 if (onCloseCallback_) {
1285 CallBackfunc = onCloseCallback_;
1286 }
1287 }
1288
1289 if (CallBackfunc) {
1290 CallBackfunc(socketFd);
1291 }
1292 }
1293
CallOnConnectCallback(const int32_t socketFd,std::shared_ptr<EventManager> eventManager)1294 void TLSSocketServer::CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager)
1295 {
1296 OnConnectCallback CallBackfunc = nullptr;
1297 {
1298 std::lock_guard<std::mutex> lock(mutex_);
1299 if (onConnectCallback_) {
1300 CallBackfunc = onConnectCallback_;
1301 }
1302 }
1303
1304 if (CallBackfunc) {
1305 CallBackfunc(socketFd, eventManager);
1306 } else {
1307 NETSTACK_LOGE("CallOnConnectCallback fun === null");
1308 }
1309 }
1310
ProcessTcpAccept(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID)1311 void TLSSocketServer::ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID)
1312 {
1313 #if !defined(CROSS_PLATFORM)
1314 struct sockaddr_in clientAddress;
1315 socklen_t clientAddrLength = sizeof(clientAddress);
1316 int connectFD = accept(listenSocketFd_, (struct sockaddr *)&clientAddress, &clientAddrLength);
1317 if (connectFD < 0) {
1318 int resErr = ConvertErrno();
1319 NETSTACK_LOGE("Server accept new client ERROR");
1320 CallOnErrorCallback(resErr, MakeErrnoString());
1321 return;
1322 }
1323 NETSTACK_LOGI("Server accept new client SUCCESS");
1324 std::shared_ptr<Connection> connection = std::make_shared<Connection>();
1325 Socket::NetAddress netAddress;
1326 char clientIp[INET_ADDRSTRLEN];
1327 inet_ntop(address_.GetSaFamily(), &clientAddress.sin_addr, clientIp, INET_ADDRSTRLEN);
1328 int clientPort = ntohs(clientAddress.sin_port);
1329 netAddress.SetAddress(clientIp);
1330 netAddress.SetPort(clientPort);
1331 netAddress.SetFamilyBySaFamily(address_.GetSaFamily());
1332 connection->SetAddress(netAddress);
1333 connection->SetClientID(clientID);
1334 auto res = connection->TlsAcceptToHost(connectFD, tlsListenOptions);
1335 if (!res) {
1336 int resErr = ConvertSSLError(connection->GetSSL());
1337 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1338 return;
1339 }
1340 if (g_userCounter >= USER_LIMIT) {
1341 const std::string info = "Too many users!";
1342 connection->Send(info);
1343 connection->Close();
1344 NETSTACK_LOGE("Too many users");
1345 close(connectFD);
1346 CallOnErrorCallback(-1, "Too many users");
1347 return;
1348 }
1349 g_userCounter++;
1350 fds_[g_userCounter].fd = connectFD;
1351 fds_[g_userCounter].events = POLLIN | POLLRDHUP | POLLERR;
1352 fds_[g_userCounter].revents = 0;
1353 AddConnect(connectFD, connection);
1354 auto ptrEventManager = std::make_shared<EventManager>();
1355
1356 ptrEventManager->SetData(this);
1357 connection->SetEventManager(ptrEventManager);
1358 CallOnConnectCallback(clientID, ptrEventManager);
1359 NETSTACK_LOGI("New client come in, fd is %{public}d", connectFD);
1360 #endif
1361 }
1362
InitPollList(int & listendFd)1363 void TLSSocketServer::InitPollList(int &listendFd)
1364 {
1365 for (int i = 1; i <= USER_LIMIT; ++i) {
1366 fds_[i].fd = -1;
1367 fds_[i].events = 0;
1368 }
1369 fds_[0].fd = listendFd;
1370 fds_[0].events = POLLIN | POLLERR;
1371 fds_[0].revents = 0;
1372 }
1373
DropFdFromPollList(int & fd_index)1374 void TLSSocketServer::DropFdFromPollList(int &fd_index)
1375 {
1376 if (g_userCounter < 0) {
1377 NETSTACK_LOGE("g_userCounter = %{public}d", g_userCounter);
1378 return;
1379 }
1380 fds_[fd_index].fd = fds_[g_userCounter].fd;
1381
1382 fds_[g_userCounter].fd = -1;
1383 fds_[g_userCounter].events = 0;
1384 fd_index--;
1385 g_userCounter--;
1386 NETSTACK_LOGE("CallOnConnectCallback g_userCounter = %{public}d", g_userCounter);
1387 }
1388
PollThread(const TlsSocket::TLSConnectOptions & tlsListenOptions)1389 void TLSSocketServer::PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions)
1390 {
1391 #if !defined(CROSS_PLATFORM)
1392 int on = 1;
1393 isRunning_ = true;
1394 ioctl(listenSocketFd_, FIONBIO, (char *)&on);
1395 NETSTACK_LOGE("PollThread start working %{public}d", isRunning_);
1396 std::thread thread_([this, &tlsListenOptions]() {
1397 TlsSocket::TLSConnectOptions tlsOption = tlsListenOptions;
1398 InitPollList(listenSocketFd_);
1399 int clientId = 0;
1400 while (isRunning_) {
1401 int ret = poll(fds_, g_userCounter + 1, POLL_WAIT_TIME);
1402 if (ret < 0) {
1403 int resErr = ConvertErrno();
1404 NETSTACK_LOGE("Poll ERROR");
1405 CallOnErrorCallback(resErr, MakeErrnoString());
1406 break;
1407 }
1408 if (ret == 0) {
1409 continue;
1410 }
1411 for (int i = 0; i < g_userCounter + 1; ++i) {
1412 if ((fds_[i].fd == listenSocketFd_) && (static_cast<uint16_t>(fds_[i].revents) & POLLIN)) {
1413 ProcessTcpAccept(tlsOption, ++clientId);
1414 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLRDHUP) ||
1415 (static_cast<uint16_t>(fds_[i].revents) & POLLERR)) {
1416 RemoveConnect(fds_[i].fd);
1417 DropFdFromPollList(i);
1418 NETSTACK_LOGI("A client left");
1419 } else if (static_cast<uint16_t>(fds_[i].revents) & POLLIN) {
1420 RecvRemoteInfo(fds_[i].fd, i);
1421 }
1422 }
1423 }
1424 });
1425 thread_.detach();
1426 #endif
1427 }
1428
GetConnectionByClientEventManager(const EventManager * eventManager)1429 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientEventManager(
1430 const EventManager *eventManager)
1431 {
1432 std::shared_ptr<TLSSocketServer::Connection> ptrConnection = nullptr;
1433 std::lock_guard<std::mutex> its_lock(connectMutex_);
1434 for (const auto &it : clientIdConnections_) {
1435 if (it.second->GetEventManager().get() == eventManager) {
1436 ptrConnection = it.second;
1437 return ptrConnection;
1438 }
1439 }
1440 return ptrConnection;
1441 }
1442
CloseConnectionByEventManager(EventManager * eventManager)1443 void TLSSocketServer::CloseConnectionByEventManager(EventManager *eventManager)
1444 {
1445 std::shared_ptr<Connection> ptrConnection = GetConnectionByClientEventManager(eventManager);
1446
1447 if (ptrConnection != nullptr) {
1448 ptrConnection->Close();
1449 }
1450 }
1451
DeleteConnectionByEventManager(EventManager * eventManager)1452 void TLSSocketServer::DeleteConnectionByEventManager(EventManager *eventManager)
1453 {
1454 std::lock_guard<std::mutex> its_lock(connectMutex_);
1455 for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) {
1456 if (it->second->GetEventManager().get() == eventManager) {
1457 it = clientIdConnections_.erase(it);
1458 break;
1459 }
1460 }
1461 }
1462 } // namespace TlsSocketServer
1463 } // namespace NetStack
1464 } // namespace OHOS
1465