1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/communicator/ssl_wrapper.h"
18
19 #include <sys/time.h>
20 #include <openssl/pem.h>
21 #include <openssl/sha.h>
22
23 #include <cstdio>
24 #include <cstring>
25 #include <cstdlib>
26 #include <vector>
27 #include <iomanip>
28 #include <sstream>
29
30 namespace mindspore {
31 namespace ps {
32 namespace core {
SSLWrapper()33 SSLWrapper::SSLWrapper()
34 : ssl_ctx_(nullptr),
35 rootFirstCA_(nullptr),
36 rootSecondCA_(nullptr),
37 check_time_thread_(nullptr),
38 running_(false),
39 is_ready_(false) {
40 InitSSL();
41 }
42
~SSLWrapper()43 SSLWrapper::~SSLWrapper() { CleanSSL(); }
44
InitSSL()45 void SSLWrapper::InitSSL() {
46 CommUtil::InitOpenSSLEnv();
47 ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
48 if (!ssl_ctx_) {
49 MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
50 }
51 X509_STORE *store = SSL_CTX_get_cert_store(ssl_ctx_);
52 MS_EXCEPTION_IF_NULL(store);
53 if (X509_STORE_set_default_paths(store) != 1) {
54 MS_LOG(EXCEPTION) << "X509_STORE_set_default_paths failed";
55 }
56 std::unique_ptr<Configuration> config_ =
57 std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
58 MS_EXCEPTION_IF_NULL(config_);
59 if (!config_->Initialize()) {
60 MS_LOG(EXCEPTION) << "The config file is empty.";
61 }
62
63 // 1.Parse the server's certificate and the ciphertext of key.
64 std::string server_cert = kCertificateChain;
65 std::string path = CommUtil::ParseConfig(*(config_), kServerCertPath);
66 if (!CommUtil::IsFileExists(path)) {
67 MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist.";
68 }
69 server_cert = path;
70
71 MS_LOG(INFO) << "The server cert path:" << server_cert;
72
73 // 2. Parse the server password.
74 std::string server_password = PSContext::instance()->server_password();
75 if (server_password.empty()) {
76 MS_LOG(EXCEPTION) << "The client password's value is empty.";
77 }
78
79 EVP_PKEY *pkey = nullptr;
80 X509 *cert = nullptr;
81 STACK_OF(X509) *ca_stack = nullptr;
82 BIO *bio = BIO_new_file(server_cert.c_str(), "rb");
83 MS_EXCEPTION_IF_NULL(bio);
84 PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
85 MS_EXCEPTION_IF_NULL(p12);
86 BIO_free_all(bio);
87 if (!PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack)) {
88 MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
89 }
90 PKCS12_free(p12);
91 std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
92 std::vector<std::string> ciphers = CommUtil::Split(default_cipher_list, kColon);
93 if (!CommUtil::VerifyCipherList(ciphers)) {
94 MS_LOG(EXCEPTION) << "The cipher is wrong.";
95 }
96 if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
97 MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
98 }
99
100 std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
101 if (crl_path.empty()) {
102 MS_LOG(INFO) << "The crl path is empty.";
103 } else if (!CommUtil::VerifyCRL(cert, crl_path)) {
104 MS_LOG(EXCEPTION) << "Verify crl failed.";
105 }
106
107 std::string client_ca = kCAcrt;
108 std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
109 if (!CommUtil::IsFileExists(ca_path)) {
110 MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
111 }
112 client_ca = ca_path;
113
114 if (!CommUtil::VerifyCommonName(cert, client_ca)) {
115 MS_LOG(EXCEPTION) << "Verify common name failed.";
116 }
117
118 SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0);
119 if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) {
120 MS_LOG(EXCEPTION) << "SSL load ca location failed!";
121 }
122
123 if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
124 MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
125 }
126
127 if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
128 MS_LOG(EXCEPTION) << "SSL use private key file failed!";
129 }
130
131 if (!SSL_CTX_check_private_key(ssl_ctx_)) {
132 MS_LOG(EXCEPTION) << "SSL check private key file failed!";
133 }
134 if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
135 SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
136 MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
137 }
138
139 if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
140 MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
141 }
142
143 StartCheckCertTime(*config_, cert, client_ca);
144 }
145
CleanSSL()146 void SSLWrapper::CleanSSL() {
147 if (ssl_ctx_ != nullptr) {
148 SSL_CTX_free(ssl_ctx_);
149 }
150 ERR_free_strings();
151 EVP_cleanup();
152 ERR_remove_thread_state(nullptr);
153 CRYPTO_cleanup_all_ex_data();
154 StopCheckCertTime();
155 }
156
ConvertAsn1Time(const ASN1_TIME * const time) const157 time_t SSLWrapper::ConvertAsn1Time(const ASN1_TIME *const time) const {
158 MS_EXCEPTION_IF_NULL(time);
159 struct tm t;
160 const char *str = (const char *)time->data;
161 MS_EXCEPTION_IF_NULL(str);
162 size_t i = 0;
163
164 if (memset_s(&t, sizeof(t), 0, sizeof(t)) != EOK) {
165 MS_LOG(EXCEPTION) << "Memset Failed!";
166 }
167
168 if (time->type == V_ASN1_UTCTIME) {
169 t.tm_year = (str[i++] - '0') * kBase;
170 t.tm_year += (str[i++] - '0');
171 if (t.tm_year < kSeventyYear) {
172 t.tm_year += kHundredYear;
173 }
174 } else if (time->type == V_ASN1_GENERALIZEDTIME) {
175 t.tm_year = (str[i++] - '0') * kThousandYear;
176 t.tm_year += (str[i++] - '0') * kHundredYear;
177 t.tm_year += (str[i++] - '0') * kBase;
178 t.tm_year += (str[i++] - '0');
179 t.tm_year -= kBaseYear;
180 }
181 t.tm_mon = (str[i++] - '0') * kBase;
182 // -1 since January is 0 not 1.
183 t.tm_mon += (str[i++] - '0') - kJanuary;
184 t.tm_mday = (str[i++] - '0') * kBase;
185 t.tm_mday += (str[i++] - '0');
186 t.tm_hour = (str[i++] - '0') * kBase;
187 t.tm_hour += (str[i++] - '0');
188 t.tm_min = (str[i++] - '0') * kBase;
189 t.tm_min += (str[i++] - '0');
190 t.tm_sec = (str[i++] - '0') * kBase;
191 t.tm_sec += (str[i++] - '0');
192
193 return mktime(&t);
194 }
195
StartCheckCertTime(const Configuration & config,const X509 * cert,const std::string & ca_path)196 void SSLWrapper::StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path) {
197 MS_EXCEPTION_IF_NULL(cert);
198 MS_LOG(INFO) << "The server start check cert.";
199 int64_t interval = kCertCheckIntervalInHour;
200
201 int64_t warning_time = kCertExpireWarningTimeInDay;
202 if (config.Exists(kCertExpireWarningTime)) {
203 int64_t res_time = config.GetInt(kCertExpireWarningTime, 0);
204 if (res_time < kMinWarningTime || res_time > kMaxWarningTime) {
205 MS_LOG(EXCEPTION) << "The Certificate expiration warning time should be [7, 180]";
206 }
207 warning_time = res_time;
208 }
209 MS_LOG(INFO) << "The interval time is:" << interval << ", the warning time is:" << warning_time;
210 BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
211 MS_EXCEPTION_IF_NULL(ca_bio);
212 X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
213 BIO_free_all(ca_bio);
214 MS_EXCEPTION_IF_NULL(ca_cert);
215
216 running_ = true;
217 check_time_thread_ = std::make_unique<std::thread>([&, cert, ca_cert, interval, warning_time]() {
218 while (running_) {
219 if (!CommUtil::VerifyCertTime(cert, warning_time)) {
220 MS_LOG(WARNING) << "Verify server cert time failed.";
221 }
222
223 if (!CommUtil::VerifyCertTime(ca_cert, warning_time)) {
224 MS_LOG(WARNING) << "Verify ca cert time failed.";
225 }
226 std::unique_lock<std::mutex> lock(mutex_);
227 bool res = cond_.wait_for(lock, std::chrono::hours(interval), [&] {
228 bool result = is_ready_.load();
229 return result;
230 });
231 MS_LOG(INFO) << "Wait for res:" << res;
232 }
233 });
234 MS_EXCEPTION_IF_NULL(check_time_thread_);
235 }
236
StopCheckCertTime()237 void SSLWrapper::StopCheckCertTime() {
238 running_ = false;
239 is_ready_ = true;
240 cond_.notify_all();
241 if (check_time_thread_ != nullptr) {
242 check_time_thread_->join();
243 }
244 }
245
GetSSLCtx(bool)246 SSL_CTX *SSLWrapper::GetSSLCtx(bool) { return ssl_ctx_; }
247 } // namespace core
248 } // namespace ps
249 } // namespace mindspore
250