• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/communicator/ssl_client.h"
18 
19 #include <sys/time.h>
20 #include <openssl/pem.h>
21 #include <openssl/sha.h>
22 
23 #include <cstdio>
24 #include <cstring>
25 #include <cstdlib>
26 #include <vector>
27 #include <iomanip>
28 #include <sstream>
29 
30 namespace mindspore {
31 namespace ps {
32 namespace core {
SSLClient()33 SSLClient::SSLClient() : ssl_ctx_(nullptr), check_time_thread_(nullptr), running_(false), is_ready_(false) {
34   InitSSL();
35 }
36 
~SSLClient()37 SSLClient::~SSLClient() { CleanSSL(); }
38 
InitSSL()39 void SSLClient::InitSSL() {
40   CommUtil::InitOpenSSLEnv();
41   ssl_ctx_ = SSL_CTX_new(SSLv23_client_method());
42   if (!ssl_ctx_) {
43     MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
44   }
45   std::unique_ptr<Configuration> config_ =
46     std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
47   MS_EXCEPTION_IF_NULL(config_);
48   if (!config_->Initialize()) {
49     MS_LOG(EXCEPTION) << "The config file is empty.";
50   }
51 
52   // 1.Parse the client's certificate and the ciphertext of key.
53   std::string client_cert = kCertificateChain;
54   std::string path = CommUtil::ParseConfig(*config_, kClientCertPath);
55   if (!CommUtil::IsFileExists(path)) {
56     MS_LOG(EXCEPTION) << "The key:" << kClientCertPath << "'s value is not exist.";
57   }
58   client_cert = path;
59 
60   // 2. Parse the client password.
61   std::string client_password = PSContext::instance()->client_password();
62   if (client_password.empty()) {
63     MS_LOG(EXCEPTION) << "The client password's value is empty.";
64   }
65   EVP_PKEY *pkey = nullptr;
66   X509 *cert = nullptr;
67   STACK_OF(X509) *ca_stack = nullptr;
68   BIO *bio = BIO_new_file(client_cert.c_str(), "rb");
69   MS_EXCEPTION_IF_NULL(bio);
70   PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
71   MS_EXCEPTION_IF_NULL(p12);
72   BIO_free_all(bio);
73   if (!PKCS12_parse(p12, client_password.c_str(), &pkey, &cert, &ca_stack)) {
74     MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
75   }
76 
77   PKCS12_free(p12);
78   if (cert == nullptr) {
79     MS_LOG(EXCEPTION) << "the cert is nullptr";
80   }
81   if (pkey == nullptr) {
82     MS_LOG(EXCEPTION) << "the key is nullptr";
83   }
84 
85   // 3. load ca cert.
86   std::string client_ca = kCAcrt;
87   std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
88   if (!CommUtil::IsFileExists(ca_path)) {
89     MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
90   }
91   client_ca = ca_path;
92 
93   std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
94   if (crl_path.empty()) {
95     MS_LOG(INFO) << "The crl path is empty.";
96   } else if (!CommUtil::VerifyCRL(cert, crl_path)) {
97     MS_LOG(EXCEPTION) << "Verify crl failed.";
98   }
99 
100   if (!CommUtil::VerifyCommonName(cert, client_ca)) {
101     MS_LOG(EXCEPTION) << "Verify common name failed.";
102   }
103 
104   SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0);
105 
106   if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) {
107     MS_LOG(EXCEPTION) << "SSL load ca location failed!";
108   }
109 
110   std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
111   std::vector<std::string> ciphers = CommUtil::Split(default_cipher_list, kColon);
112   if (!CommUtil::VerifyCipherList(ciphers)) {
113     MS_LOG(EXCEPTION) << "The cipher is wrong.";
114   }
115   if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
116     MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
117   }
118 
119   // 4. load client cert
120   if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
121     MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
122   }
123 
124   if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
125     MS_LOG(EXCEPTION) << "SSL use private key file failed!";
126   }
127 
128   if (!SSL_CTX_check_private_key(ssl_ctx_)) {
129     MS_LOG(EXCEPTION) << "SSL check private key file failed!";
130   }
131 
132   if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
133                                        SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
134     MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
135   }
136 
137   if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
138     MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
139   }
140 
141   StartCheckCertTime(*config_, cert);
142 }
143 
CleanSSL()144 void SSLClient::CleanSSL() {
145   if (ssl_ctx_ != nullptr) {
146     SSL_CTX_free(ssl_ctx_);
147   }
148   ERR_free_strings();
149   EVP_cleanup();
150   ERR_remove_thread_state(nullptr);
151   CRYPTO_cleanup_all_ex_data();
152   StopCheckCertTime();
153 }
154 
StartCheckCertTime(const Configuration & config,const X509 * cert)155 void SSLClient::StartCheckCertTime(const Configuration &config, const X509 *cert) {
156   MS_EXCEPTION_IF_NULL(cert);
157   MS_LOG(INFO) << "The client start check cert.";
158   int64_t interval = kCertCheckIntervalInHour;
159 
160   int64_t warning_time = kCertExpireWarningTimeInDay;
161   if (config.Exists(kCertExpireWarningTime)) {
162     int64_t res_time = config.GetInt(kCertExpireWarningTime, 0);
163     if (res_time < kMinWarningTime || res_time > kMaxWarningTime) {
164       MS_LOG(EXCEPTION) << "The Certificate expiration warning time should be [7, 180]";
165     }
166     warning_time = res_time;
167   }
168   MS_LOG(INFO) << "The interval time is:" << interval << ", the warning time is:" << warning_time;
169   running_ = true;
170   check_time_thread_ = std::make_unique<std::thread>([&, cert, interval, warning_time]() {
171     while (running_) {
172       if (!CommUtil::VerifyCertTime(cert, warning_time)) {
173         MS_LOG(WARNING) << "Verify cert time failed.";
174       }
175       std::unique_lock<std::mutex> lock(mutex_);
176       bool res = cond_.wait_for(lock, std::chrono::hours(interval), [&] {
177         bool result = is_ready_.load();
178         return result;
179       });
180       MS_LOG(INFO) << "Wait for res:" << res;
181     }
182   });
183   MS_EXCEPTION_IF_NULL(check_time_thread_);
184 }
185 
StopCheckCertTime()186 void SSLClient::StopCheckCertTime() {
187   running_ = false;
188   is_ready_ = true;
189   cond_.notify_all();
190   if (check_time_thread_ != nullptr) {
191     check_time_thread_->join();
192   }
193 }
194 
GetSSLCtx() const195 SSL_CTX *SSLClient::GetSSLCtx() const { return ssl_ctx_; }
196 }  // namespace core
197 }  // namespace ps
198 }  // namespace mindspore
199