1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/comm_util.h"
18
#include <arpa/inet.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <regex>
26
27 namespace mindspore {
28 namespace ps {
29 namespace core {
30 std::random_device CommUtil::rd;
31 std::mt19937_64 CommUtil::gen(rd());
32 std::uniform_int_distribution<> CommUtil::dis = std::uniform_int_distribution<>{0, 15};
33 std::uniform_int_distribution<> CommUtil::dis2 = std::uniform_int_distribution<>{8, 11};
34
CheckIpWithRegex(const std::string & ip)35 bool CommUtil::CheckIpWithRegex(const std::string &ip) {
36 std::regex pattern(
37 "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
38 "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
39 "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
40 "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
41 std::smatch res;
42 if (regex_match(ip, res, pattern)) {
43 return true;
44 }
45 return false;
46 }
47
CheckIp(const std::string & ip)48 bool CommUtil::CheckIp(const std::string &ip) {
49 if (!CheckIpWithRegex(ip)) {
50 return false;
51 }
52 uint32_t uAddr = inet_addr(ip.c_str());
53 if (INADDR_NONE == uAddr) {
54 return false;
55 }
56 return true;
57 }
58
CheckPort(const uint16_t & port)59 bool CommUtil::CheckPort(const uint16_t &port) {
60 if (port > 65535) {
61 MS_LOG(ERROR) << "The range of port should be 1 to 65535.";
62 return false;
63 }
64 return true;
65 }
66
GetAvailableInterfaceAndIP(std::string * interface,std::string * ip)67 void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *ip) {
68 MS_EXCEPTION_IF_NULL(interface);
69 MS_EXCEPTION_IF_NULL(ip);
70 struct ifaddrs *if_address = nullptr;
71 struct ifaddrs *ifa = nullptr;
72
73 interface->clear();
74 ip->clear();
75 if (getifaddrs(&if_address) == -1) {
76 MS_LOG(WARNING) << "Get ifaddrs failed.";
77 }
78 for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
79 if (ifa->ifa_addr == nullptr) {
80 continue;
81 }
82
83 if (ifa->ifa_addr->sa_family == AF_INET && (ifa->ifa_flags & IFF_LOOPBACK) == 0) {
84 char address_buffer[INET_ADDRSTRLEN] = {0};
85 void *sin_addr_ptr = &(reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr))->sin_addr;
86 MS_EXCEPTION_IF_NULL(sin_addr_ptr);
87 const char *net_ptr = inet_ntop(AF_INET, sin_addr_ptr, address_buffer, INET_ADDRSTRLEN);
88 MS_EXCEPTION_IF_NULL(net_ptr);
89
90 *ip = address_buffer;
91 *interface = ifa->ifa_name;
92 break;
93 }
94 }
95 MS_EXCEPTION_IF_NULL(if_address);
96 freeifaddrs(if_address);
97 }
98
GenerateUUID()99 std::string CommUtil::GenerateUUID() {
100 std::stringstream ss;
101 int i;
102 ss << std::hex;
103 for (i = 0; i < kGroup1RandomLength; i++) {
104 ss << dis(gen);
105 }
106 ss << "-";
107 for (i = 0; i < kGroup2RandomLength; i++) {
108 ss << dis(gen);
109 }
110 ss << "-4";
111 for (i = 0; i < kGroup3RandomLength - 1; i++) {
112 ss << dis(gen);
113 }
114 ss << "-";
115 ss << dis2(gen);
116 for (i = 0; i < kGroup4RandomLength - 1; i++) {
117 ss << dis(gen);
118 }
119 ss << "-";
120 for (i = 0; i < kGroup5RandomLength; i++) {
121 ss << dis(gen);
122 }
123 return ss.str();
124 }
125
NodeRoleToString(const NodeRole & role)126 std::string CommUtil::NodeRoleToString(const NodeRole &role) {
127 switch (role) {
128 case NodeRole::SCHEDULER:
129 return "SCHEDULER";
130 case NodeRole::SERVER:
131 return "SERVER";
132 case NodeRole::WORKER:
133 return "WORKER";
134 default:
135 MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
136 }
137 }
ValidateRankId(const enum NodeRole & node_role,const uint32_t & rank_id,const int32_t & total_worker_num,const int32_t & total_server_num)138 bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
139 const int32_t &total_server_num) {
140 if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
141 return false;
142 } else if (node_role == NodeRole::WORKER && (rank_id > IntToUint(total_worker_num) - 1)) {
143 return false;
144 }
145 return true;
146 }
147
Retry(const std::function<bool ()> & func,size_t max_attempts,size_t interval_milliseconds)148 bool CommUtil::Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds) {
149 for (size_t attempt = 0; attempt < max_attempts; ++attempt) {
150 if (func()) {
151 return true;
152 }
153 std::this_thread::sleep_for(std::chrono::milliseconds(interval_milliseconds));
154 }
155 return false;
156 }
157
LogCallback(int severity,const char * msg)158 void CommUtil::LogCallback(int severity, const char *msg) {
159 MS_EXCEPTION_IF_NULL(msg);
160 switch (severity) {
161 case EVENT_LOG_MSG:
162 MS_LOG(INFO) << kLibeventLogPrefix << msg;
163 break;
164 case EVENT_LOG_WARN:
165 MS_LOG(WARNING) << kLibeventLogPrefix << msg;
166 break;
167 case EVENT_LOG_ERR:
168 MS_LOG(ERROR) << kLibeventLogPrefix << msg;
169 break;
170 default:
171 break;
172 }
173 }
174
IsFileExists(const std::string & file)175 bool CommUtil::IsFileExists(const std::string &file) {
176 std::ifstream f(file.c_str());
177 if (!f.good()) {
178 return false;
179 } else {
180 f.close();
181 return true;
182 }
183 }
184
ClusterStateToString(const ClusterState & state)185 std::string CommUtil::ClusterStateToString(const ClusterState &state) {
186 MS_LOG(INFO) << "The cluster state:" << state;
187 if (state < SizeToInt(kClusterState.size())) {
188 return kClusterState.at(state);
189 } else {
190 return "";
191 }
192 }
193
ParseConfig(const Configuration & config,const std::string & data)194 std::string CommUtil::ParseConfig(const Configuration &config, const std::string &data) {
195 if (!config.IsInitialized()) {
196 MS_LOG(INFO) << "The config is not initialized.";
197 return "";
198 }
199
200 if (!const_cast<Configuration &>(config).Exists(data)) {
201 MS_LOG(INFO) << "The data:" << data << " is not exist.";
202 return "";
203 }
204
205 std::string path = config.GetString(data, "");
206 return path;
207 }
208
VerifyCertTime(const X509 * cert,int64_t time)209 bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) {
210 MS_EXCEPTION_IF_NULL(cert);
211 ASN1_TIME *start = X509_getm_notBefore(cert);
212 ASN1_TIME *end = X509_getm_notAfter(cert);
213 MS_EXCEPTION_IF_NULL(start);
214 MS_EXCEPTION_IF_NULL(end);
215 int day = 0;
216 int sec = 0;
217 if (!ASN1_TIME_diff(&day, &sec, start, NULL)) {
218 MS_LOG(WARNING) << "ASN1 time diff failed.";
219 return false;
220 }
221
222 if (day < 0 || sec < 0) {
223 MS_LOG(WARNING) << "Cert start time is later than now time.";
224 return false;
225 }
226 day = 0;
227 sec = 0;
228
229 if (!ASN1_TIME_diff(&day, &sec, NULL, end)) {
230 MS_LOG(WARNING) << "ASN1 time diff failed.";
231 return false;
232 }
233
234 int64_t interval = kCertExpireWarningTimeInDay;
235 if (time > 0) {
236 interval = time;
237 }
238
239 if (day < LongToInt(interval) && day >= 0) {
240 MS_LOG(WARNING) << "The certificate will expire in " << day << " days and " << sec << " seconds.";
241 } else if (day < 0 || sec < 0) {
242 MS_LOG(WARNING) << "The certificate has expired.";
243 return false;
244 }
245 return true;
246 }
247
VerifyCRL(const X509 * cert,const std::string & crl_path)248 bool CommUtil::VerifyCRL(const X509 *cert, const std::string &crl_path) {
249 MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
250 BIO *bio = BIO_new_file(crl_path.c_str(), "r");
251 MS_ERROR_IF_NULL_W_RET_VAL(bio, false);
252 X509_CRL *root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
253 MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false);
254 EVP_PKEY *evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert));
255 MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false);
256
257 int ret = X509_CRL_verify(root_crl, evp_pkey);
258 BIO_free_all(bio);
259 if (ret == 1) {
260 MS_LOG(WARNING) << "Equip cert in root crl, verify failed";
261 return false;
262 }
263 MS_LOG(INFO) << "VerifyCRL success.";
264 return true;
265 }
266
VerifyCommonName(const X509 * cert,const std::string & ca_path)267 bool CommUtil::VerifyCommonName(const X509 *cert, const std::string &ca_path) {
268 MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
269 X509 *cert_temp = const_cast<X509 *>(cert);
270 char subject_cn[256] = "";
271 char issuer_cn[256] = "";
272 X509_NAME *subject_name = X509_get_subject_name(cert_temp);
273 X509_NAME *issuer_name = X509_get_issuer_name(cert_temp);
274 MS_ERROR_IF_NULL_W_RET_VAL(subject_name, false);
275 MS_ERROR_IF_NULL_W_RET_VAL(issuer_name, false);
276 if (!X509_NAME_get_text_by_NID(subject_name, NID_commonName, subject_cn, sizeof(subject_cn))) {
277 MS_LOG(WARNING) << "Get text by nid failed.";
278 return false;
279 }
280 if (!X509_NAME_get_text_by_NID(issuer_name, NID_commonName, issuer_cn, sizeof(issuer_cn))) {
281 MS_LOG(WARNING) << "Get text by nid failed.";
282 return false;
283 }
284 MS_LOG(INFO) << "the subject:" << subject_cn << ", the issuer:" << issuer_cn;
285
286 BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
287 MS_EXCEPTION_IF_NULL(ca_bio);
288 X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
289 MS_EXCEPTION_IF_NULL(ca_cert);
290 char ca_subject_cn[256] = "";
291 char ca_issuer_cn[256] = "";
292 X509_NAME *ca_subject_name = X509_get_subject_name(ca_cert);
293 X509_NAME *ca_issuer_name = X509_get_issuer_name(ca_cert);
294 MS_ERROR_IF_NULL_W_RET_VAL(ca_subject_name, false);
295 MS_ERROR_IF_NULL_W_RET_VAL(ca_issuer_name, false);
296 if (!X509_NAME_get_text_by_NID(ca_subject_name, NID_commonName, ca_subject_cn, sizeof(subject_cn))) {
297 MS_LOG(WARNING) << "Get text by nid failed.";
298 return false;
299 }
300 if (!X509_NAME_get_text_by_NID(ca_issuer_name, NID_commonName, ca_issuer_cn, sizeof(issuer_cn))) {
301 MS_LOG(WARNING) << "Get text by nid failed.";
302 return false;
303 }
304 MS_LOG(INFO) << "the subject:" << ca_subject_cn << ", the issuer:" << ca_issuer_cn;
305 BIO_free_all(ca_bio);
306 if (strcmp(issuer_cn, ca_subject_cn) != 0) {
307 return false;
308 }
309 return true;
310 }
311
Split(const std::string & s,char delim)312 std::vector<std::string> CommUtil::Split(const std::string &s, char delim) {
313 std::vector<std::string> res;
314 std::stringstream ss(s);
315 std::string item;
316
317 while (getline(ss, item, delim)) {
318 res.push_back(item);
319 }
320 return res;
321 }
322
VerifyCipherList(const std::vector<std::string> & list)323 bool CommUtil::VerifyCipherList(const std::vector<std::string> &list) {
324 for (auto &item : list) {
325 if (!kCiphers.count(item)) {
326 MS_LOG(WARNING) << "The ciphter:" << item << " is not supported.";
327 return false;
328 }
329 }
330 return true;
331 }
332
InitOpenSSLEnv()333 void CommUtil::InitOpenSSLEnv() {
334 if (!SSL_library_init()) {
335 MS_LOG(EXCEPTION) << "SSL_library_init failed.";
336 }
337 if (!ERR_load_crypto_strings()) {
338 MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
339 }
340 if (!SSL_load_error_strings()) {
341 MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
342 }
343 if (!OpenSSL_add_all_algorithms()) {
344 MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
345 }
346 }
347 } // namespace core
348 } // namespace ps
349 } // namespace mindspore
350