/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 
18 #include "ps/core/communicator/ssl_wrapper.h"
19 
20 #include <sys/time.h>
21 #include <openssl/pem.h>
22 #include <openssl/sha.h>
23 
24 #include <cstdio>
25 #include <cstring>
26 #include <cstdlib>
27 #include <vector>
28 #include <iomanip>
29 #include <sstream>
30 
31 namespace mindspore {
32 namespace ps {
33 namespace core {
SSLWrapper()34 SSLWrapper::SSLWrapper()
35     : ssl_ctx_(nullptr),
36       rootFirstCA_(nullptr),
37       rootSecondCA_(nullptr),
38       check_time_thread_(nullptr),
39       running_(false),
40       is_ready_(false) {}
41 
~SSLWrapper()42 SSLWrapper::~SSLWrapper() { CleanSSL(); }
43 
InitSSL()44 void SSLWrapper::InitSSL() {
45   std::unique_lock<std::mutex> lock(mutex_);
46   if (init_) {
47     return;
48   }
49   init_ = true;
50   CommUtil::InitOpensslLib();
51   ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
52   if (!ssl_ctx_) {
53     MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
54   }
55   X509_STORE *store = SSL_CTX_get_cert_store(ssl_ctx_);
56   MS_EXCEPTION_IF_NULL(store);
57   if (X509_STORE_set_default_paths(store) != 1) {
58     MS_LOG(EXCEPTION) << "X509_STORE_set_default_paths failed";
59   }
60   std::unique_ptr<Configuration> config_ =
61     std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
62   MS_EXCEPTION_IF_NULL(config_);
63   if (!config_->Initialize()) {
64     MS_LOG(EXCEPTION) << "The config file is empty.";
65   }
66 
67   // 1.Parse the server's certificate and the ciphertext of key.
68   std::string path = CommUtil::ParseConfig(*(config_), kServerCertPath);
69   if (!CommUtil::IsFileExists(path)) {
70     MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist.";
71   }
72   std::string server_cert = path;
73   MS_LOG(INFO) << "The server cert path:" << server_cert;
74 
75   // 2. Parse the server password.
76   char *server_password = PSContext::instance()->server_password();
77   if (strlen(server_password) == 0) {
78     MS_LOG(EXCEPTION) << "The client password's value is empty.";
79   }
80 
81   EVP_PKEY *pkey = nullptr;
82   X509 *cert = nullptr;
83   STACK_OF(X509) *ca_stack = nullptr;
84   BIO *bio = BIO_new_file(server_cert.c_str(), "rb");
85   if (bio == nullptr) {
86     PSContext::instance()->ClearServerPassword();
87     MS_LOG(EXCEPTION) << "Read server cert file failed.";
88   }
89   PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
90   BIO_free_all(bio);
91   if (p12 == nullptr) {
92     PSContext::instance()->ClearServerPassword();
93     MS_LOG(EXCEPTION) << "Create PKCS12 cert failed, please check whether the certificate is correct.";
94   }
95   if (PKCS12_parse(p12, server_password, &pkey, &cert, &ca_stack) == 0) {
96     if (ERR_GET_REASON(ERR_peek_last_error()) == PKCS12_R_MAC_VERIFY_FAILURE) {
97       PSContext::instance()->ClearServerPassword();
98       MS_LOG(EXCEPTION) << "The server password is invalid!";
99     }
100     PSContext::instance()->ClearServerPassword();
101     MS_LOG(EXCEPTION) << "PKCS12_parse failed, the reason is " << ERR_reason_error_string(ERR_peek_last_error());
102   }
103   PSContext::instance()->ClearServerPassword();
104 
105   PKCS12_free(p12);
106   MS_EXCEPTION_IF_NULL(cert);
107   MS_EXCEPTION_IF_NULL(pkey);
108   if (ca_stack != nullptr) {
109     MS_LOG(EXCEPTION) << "The cert is invalid: ca_stack should be empty.";
110   }
111 
112   std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
113   if (!CommUtil::IsFileExists(ca_path)) {
114     MS_LOG(EXCEPTION) << "The key:" << kCaCertPath << "'s value is not exist.";
115   }
116   BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
117   if (ca_bio == nullptr) {
118     MS_LOG(EXCEPTION) << "Read CA cert file failed.";
119   }
120   X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
121   X509_CRL *crl = nullptr;
122   std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
123   if (crl_path.empty()) {
124     MS_LOG(INFO) << "The crl path is empty.";
125   } else if (!CommUtil::checkCRLTime(crl_path)) {
126     MS_LOG(EXCEPTION) << "check crl time failed";
127   } else if (!CommUtil::VerifyCRL(caCert, crl_path, &crl)) {
128     MS_LOG(EXCEPTION) << "Verify crl failed.";
129   }
130 
131   CommUtil::verifyCertPipeline(caCert, cert);
132 
133   SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0);
134   if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) {
135     MS_LOG(EXCEPTION) << "SSL load ca location failed!";
136   }
137 
138   InitSSLCtx(*config_, cert, pkey, crl);
139   StartCheckCertTime(*config_, cert, ca_path);
140 
141   EVP_PKEY_free(pkey);
142   X509_free(caCert);
143   X509_free(cert);
144   BIO_vfree(ca_bio);
145   if (crl != nullptr) {
146     X509_CRL_free(crl);
147   }
148 }
149 
InitSSLCtx(const Configuration & config,const X509 * cert,const EVP_PKEY * pkey,X509_CRL * crl)150 void SSLWrapper::InitSSLCtx(const Configuration &config, const X509 *cert, const EVP_PKEY *pkey, X509_CRL *crl) {
151   std::string default_cipher_list = CommUtil::ParseConfig(config, kCipherList);
152   std::vector<std::string> ciphers = CommUtil::Split(default_cipher_list, kColon);
153   if (!CommUtil::VerifyCipherList(ciphers)) {
154     MS_LOG(EXCEPTION) << "The cipher is wrong.";
155   }
156   if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
157     MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
158   }
159   if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
160     MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
161   }
162 
163   if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
164     MS_LOG(EXCEPTION) << "SSL use private key file failed!";
165   }
166 
167   if (!SSL_CTX_check_private_key(ssl_ctx_)) {
168     MS_LOG(EXCEPTION) << "SSL check private key file failed!";
169   }
170   if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
171                                        SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
172     MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
173   }
174   if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
175     MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
176   }
177 
178   if (crl != nullptr) {
179     // Load CRL into the `X509_STORE`
180     X509_STORE *x509_store = SSL_CTX_get_cert_store(ssl_ctx_);
181     if (X509_STORE_add_crl(x509_store, crl) != 1) {
182       MS_LOG(EXCEPTION) << "ssl server X509_STORE add crl failed!";
183     }
184 
185     // Enable CRL checking
186     X509_VERIFY_PARAM *param = SSL_CTX_get0_param(ssl_ctx_);
187     if (param == nullptr) {
188       MS_LOG(EXCEPTION) << "ssl server X509_VERIFY_PARAM is nullptr!";
189     }
190     if (X509_VERIFY_PARAM_set_flags(param, X509_V_FLAG_CRL_CHECK) != 1) {
191       MS_LOG(EXCEPTION) << "ssl server X509_VERIFY_PARAM set flag X509_V_FLAG_CRL_CHECK failed!";
192     }
193   }
194 
195   SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
196 }
197 
CleanSSL()198 void SSLWrapper::CleanSSL() {
199   if (ssl_ctx_ != nullptr) {
200     SSL_CTX_free(ssl_ctx_);
201   }
202   ERR_free_strings();
203   EVP_cleanup();
204   ERR_remove_thread_state(nullptr);
205   CRYPTO_cleanup_all_ex_data();
206   StopCheckCertTime();
207 }
208 
ConvertAsn1Time(const ASN1_TIME * const time) const209 time_t SSLWrapper::ConvertAsn1Time(const ASN1_TIME *const time) const {
210   MS_EXCEPTION_IF_NULL(time);
211   struct tm t;
212   const char *str = (const char *)time->data;
213   MS_EXCEPTION_IF_NULL(str);
214   size_t i = 0;
215 
216   if (memset_s(&t, sizeof(t), 0, sizeof(t)) != EOK) {
217     MS_LOG(EXCEPTION) << "Memset Failed!";
218   }
219 
220   if (time->type == V_ASN1_UTCTIME) {
221     t.tm_year = (str[i++] - '0') * kBase;
222     t.tm_year += (str[i++] - '0');
223     if (t.tm_year < kSeventyYear) {
224       t.tm_year += kHundredYear;
225     }
226   } else if (time->type == V_ASN1_GENERALIZEDTIME) {
227     t.tm_year = (str[i++] - '0') * kThousandYear;
228     t.tm_year += (str[i++] - '0') * kHundredYear;
229     t.tm_year += (str[i++] - '0') * kBase;
230     t.tm_year += (str[i++] - '0');
231     t.tm_year -= kBaseYear;
232   }
233   t.tm_mon = (str[i++] - '0') * kBase;
234   // -1 since January is 0 not 1.
235   t.tm_mon += (str[i++] - '0') - kJanuary;
236   t.tm_mday = (str[i++] - '0') * kBase;
237   t.tm_mday += (str[i++] - '0');
238   t.tm_hour = (str[i++] - '0') * kBase;
239   t.tm_hour += (str[i++] - '0');
240   t.tm_min = (str[i++] - '0') * kBase;
241   t.tm_min += (str[i++] - '0');
242   t.tm_sec = (str[i++] - '0') * kBase;
243   t.tm_sec += (str[i++] - '0');
244 
245   return mktime(&t);
246 }
247 
StartCheckCertTime(const Configuration & config,const X509 * cert,const std::string & ca_path)248 void SSLWrapper::StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path) {
249   MS_EXCEPTION_IF_NULL(cert);
250   MS_LOG(INFO) << "The server start check cert.";
251   int64_t interval = kCertCheckIntervalInHour;
252 
253   int64_t warning_time = kCertExpireWarningTimeInDay;
254   if (config.Exists(kCertExpireWarningTime)) {
255     int64_t res_time = config.GetInt(kCertExpireWarningTime, 0);
256     if (res_time < kMinWarningTime || res_time > kMaxWarningTime) {
257       MS_LOG(EXCEPTION) << "The Certificate expiration warning time should be [7, 180]";
258     }
259     warning_time = res_time;
260   }
261   MS_LOG(INFO) << "The interval time is:" << interval << ", the warning time is:" << warning_time;
262   BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
263   MS_EXCEPTION_IF_NULL(ca_bio);
264   X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
265   BIO_free_all(ca_bio);
266   MS_EXCEPTION_IF_NULL(ca_cert);
267 
268   running_ = true;
269   check_time_thread_ = std::make_unique<std::thread>([&, cert, ca_cert, interval, warning_time]() {
270     while (running_) {
271       if (!CommUtil::VerifyCertTime(cert, warning_time)) {
272         MS_LOG(WARNING) << "Verify server cert time failed.";
273       }
274 
275       if (!CommUtil::VerifyCertTime(ca_cert, warning_time)) {
276         MS_LOG(WARNING) << "Verify ca cert time failed.";
277       }
278       std::unique_lock<std::mutex> lock(mutex_);
279       bool res = cond_.wait_for(lock, std::chrono::hours(interval), [&] {
280         bool result = is_ready_.load();
281         return result;
282       });
283       MS_LOG(INFO) << "Wait for res:" << res;
284     }
285   });
286   MS_EXCEPTION_IF_NULL(check_time_thread_);
287 }
288 
StopCheckCertTime()289 void SSLWrapper::StopCheckCertTime() {
290   running_ = false;
291   is_ready_ = true;
292   cond_.notify_all();
293   if (check_time_thread_ != nullptr) {
294     check_time_thread_->join();
295   }
296 }
297 
GetSSLCtx(bool)298 SSL_CTX *SSLWrapper::GetSSLCtx(bool) { return ssl_ctx_; }
299 }  // namespace core
300 }  // namespace ps
301 }  // namespace mindspore
302