• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/comm_util.h"
18 
19 #include <arpa/inet.h>
20 #include <cstdio>
21 #include <cstdlib>
22 #include <cstring>
23 #include <functional>
24 #include <algorithm>
25 #include <regex>
26 
27 namespace mindspore {
28 namespace ps {
29 namespace core {
30 std::random_device CommUtil::rd;
31 std::mt19937_64 CommUtil::gen(rd());
32 std::uniform_int_distribution<> CommUtil::dis = std::uniform_int_distribution<>{0, 15};
33 std::uniform_int_distribution<> CommUtil::dis2 = std::uniform_int_distribution<>{8, 11};
34 
CheckIpWithRegex(const std::string & ip)35 bool CommUtil::CheckIpWithRegex(const std::string &ip) {
36   std::regex pattern(
37     "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
38     "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
39     "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
40     "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
41   std::smatch res;
42   if (regex_match(ip, res, pattern)) {
43     return true;
44   }
45   return false;
46 }
47 
CheckIp(const std::string & ip)48 bool CommUtil::CheckIp(const std::string &ip) {
49   if (!CheckIpWithRegex(ip)) {
50     return false;
51   }
52   uint32_t uAddr = inet_addr(ip.c_str());
53   if (INADDR_NONE == uAddr) {
54     return false;
55   }
56   return true;
57 }
58 
CheckPort(const uint16_t & port)59 bool CommUtil::CheckPort(const uint16_t &port) {
60   if (port > 65535) {
61     MS_LOG(ERROR) << "The range of port should be 1 to 65535.";
62     return false;
63   }
64   return true;
65 }
66 
GetAvailableInterfaceAndIP(std::string * interface,std::string * ip)67 void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *ip) {
68   MS_EXCEPTION_IF_NULL(interface);
69   MS_EXCEPTION_IF_NULL(ip);
70   struct ifaddrs *if_address = nullptr;
71   struct ifaddrs *ifa = nullptr;
72 
73   interface->clear();
74   ip->clear();
75   if (getifaddrs(&if_address) == -1) {
76     MS_LOG(WARNING) << "Get ifaddrs failed.";
77   }
78   for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
79     if (ifa->ifa_addr == nullptr) {
80       continue;
81     }
82 
83     if (ifa->ifa_addr->sa_family == AF_INET && (ifa->ifa_flags & IFF_LOOPBACK) == 0) {
84       char address_buffer[INET_ADDRSTRLEN] = {0};
85       void *sin_addr_ptr = &(reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr))->sin_addr;
86       MS_EXCEPTION_IF_NULL(sin_addr_ptr);
87       const char *net_ptr = inet_ntop(AF_INET, sin_addr_ptr, address_buffer, INET_ADDRSTRLEN);
88       MS_EXCEPTION_IF_NULL(net_ptr);
89 
90       *ip = address_buffer;
91       *interface = ifa->ifa_name;
92       break;
93     }
94   }
95   MS_EXCEPTION_IF_NULL(if_address);
96   freeifaddrs(if_address);
97 }
98 
GenerateUUID()99 std::string CommUtil::GenerateUUID() {
100   std::stringstream ss;
101   int i;
102   ss << std::hex;
103   for (i = 0; i < kGroup1RandomLength; i++) {
104     ss << dis(gen);
105   }
106   ss << "-";
107   for (i = 0; i < kGroup2RandomLength; i++) {
108     ss << dis(gen);
109   }
110   ss << "-4";
111   for (i = 0; i < kGroup3RandomLength - 1; i++) {
112     ss << dis(gen);
113   }
114   ss << "-";
115   ss << dis2(gen);
116   for (i = 0; i < kGroup4RandomLength - 1; i++) {
117     ss << dis(gen);
118   }
119   ss << "-";
120   for (i = 0; i < kGroup5RandomLength; i++) {
121     ss << dis(gen);
122   }
123   return ss.str();
124 }
125 
NodeRoleToString(const NodeRole & role)126 std::string CommUtil::NodeRoleToString(const NodeRole &role) {
127   switch (role) {
128     case NodeRole::SCHEDULER:
129       return "SCHEDULER";
130     case NodeRole::SERVER:
131       return "SERVER";
132     case NodeRole::WORKER:
133       return "WORKER";
134     default:
135       MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
136   }
137 }
ValidateRankId(const enum NodeRole & node_role,const uint32_t & rank_id,const int32_t & total_worker_num,const int32_t & total_server_num)138 bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
139                               const int32_t &total_server_num) {
140   if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
141     return false;
142   } else if (node_role == NodeRole::WORKER && (rank_id > IntToUint(total_worker_num) - 1)) {
143     return false;
144   }
145   return true;
146 }
147 
Retry(const std::function<bool ()> & func,size_t max_attempts,size_t interval_milliseconds)148 bool CommUtil::Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds) {
149   for (size_t attempt = 0; attempt < max_attempts; ++attempt) {
150     if (func()) {
151       return true;
152     }
153     std::this_thread::sleep_for(std::chrono::milliseconds(interval_milliseconds));
154   }
155   return false;
156 }
157 
LogCallback(int severity,const char * msg)158 void CommUtil::LogCallback(int severity, const char *msg) {
159   MS_EXCEPTION_IF_NULL(msg);
160   switch (severity) {
161     case EVENT_LOG_MSG:
162       MS_LOG(INFO) << kLibeventLogPrefix << msg;
163       break;
164     case EVENT_LOG_WARN:
165       MS_LOG(WARNING) << kLibeventLogPrefix << msg;
166       break;
167     case EVENT_LOG_ERR:
168       MS_LOG(ERROR) << kLibeventLogPrefix << msg;
169       break;
170     default:
171       break;
172   }
173 }
174 
IsFileExists(const std::string & file)175 bool CommUtil::IsFileExists(const std::string &file) {
176   std::ifstream f(file.c_str());
177   if (!f.good()) {
178     return false;
179   } else {
180     f.close();
181     return true;
182   }
183 }
184 
ClusterStateToString(const ClusterState & state)185 std::string CommUtil::ClusterStateToString(const ClusterState &state) {
186   MS_LOG(INFO) << "The cluster state:" << state;
187   if (state < SizeToInt(kClusterState.size())) {
188     return kClusterState.at(state);
189   } else {
190     return "";
191   }
192 }
193 
ParseConfig(const Configuration & config,const std::string & data)194 std::string CommUtil::ParseConfig(const Configuration &config, const std::string &data) {
195   if (!config.IsInitialized()) {
196     MS_LOG(INFO) << "The config is not initialized.";
197     return "";
198   }
199 
200   if (!const_cast<Configuration &>(config).Exists(data)) {
201     MS_LOG(INFO) << "The data:" << data << " is not exist.";
202     return "";
203   }
204 
205   std::string path = config.GetString(data, "");
206   return path;
207 }
208 
VerifyCertTime(const X509 * cert,int64_t time)209 bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) {
210   MS_EXCEPTION_IF_NULL(cert);
211   ASN1_TIME *start = X509_getm_notBefore(cert);
212   ASN1_TIME *end = X509_getm_notAfter(cert);
213   MS_EXCEPTION_IF_NULL(start);
214   MS_EXCEPTION_IF_NULL(end);
215   int day = 0;
216   int sec = 0;
217   if (!ASN1_TIME_diff(&day, &sec, start, NULL)) {
218     MS_LOG(WARNING) << "ASN1 time diff failed.";
219     return false;
220   }
221 
222   if (day < 0 || sec < 0) {
223     MS_LOG(WARNING) << "Cert start time is later than now time.";
224     return false;
225   }
226   day = 0;
227   sec = 0;
228 
229   if (!ASN1_TIME_diff(&day, &sec, NULL, end)) {
230     MS_LOG(WARNING) << "ASN1 time diff failed.";
231     return false;
232   }
233 
234   int64_t interval = kCertExpireWarningTimeInDay;
235   if (time > 0) {
236     interval = time;
237   }
238 
239   if (day < LongToInt(interval) && day >= 0) {
240     MS_LOG(WARNING) << "The certificate will expire in " << day << " days and " << sec << " seconds.";
241   } else if (day < 0 || sec < 0) {
242     MS_LOG(WARNING) << "The certificate has expired.";
243     return false;
244   }
245   return true;
246 }
247 
VerifyCRL(const X509 * cert,const std::string & crl_path)248 bool CommUtil::VerifyCRL(const X509 *cert, const std::string &crl_path) {
249   MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
250   BIO *bio = BIO_new_file(crl_path.c_str(), "r");
251   MS_ERROR_IF_NULL_W_RET_VAL(bio, false);
252   X509_CRL *root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
253   MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false);
254   EVP_PKEY *evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert));
255   MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false);
256 
257   int ret = X509_CRL_verify(root_crl, evp_pkey);
258   BIO_free_all(bio);
259   if (ret == 1) {
260     MS_LOG(WARNING) << "Equip cert in root crl, verify failed";
261     return false;
262   }
263   MS_LOG(INFO) << "VerifyCRL success.";
264   return true;
265 }
266 
VerifyCommonName(const X509 * cert,const std::string & ca_path)267 bool CommUtil::VerifyCommonName(const X509 *cert, const std::string &ca_path) {
268   MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
269   X509 *cert_temp = const_cast<X509 *>(cert);
270   char subject_cn[256] = "";
271   char issuer_cn[256] = "";
272   X509_NAME *subject_name = X509_get_subject_name(cert_temp);
273   X509_NAME *issuer_name = X509_get_issuer_name(cert_temp);
274   MS_ERROR_IF_NULL_W_RET_VAL(subject_name, false);
275   MS_ERROR_IF_NULL_W_RET_VAL(issuer_name, false);
276   if (!X509_NAME_get_text_by_NID(subject_name, NID_commonName, subject_cn, sizeof(subject_cn))) {
277     MS_LOG(WARNING) << "Get text by nid failed.";
278     return false;
279   }
280   if (!X509_NAME_get_text_by_NID(issuer_name, NID_commonName, issuer_cn, sizeof(issuer_cn))) {
281     MS_LOG(WARNING) << "Get text by nid failed.";
282     return false;
283   }
284   MS_LOG(INFO) << "the subject:" << subject_cn << ", the issuer:" << issuer_cn;
285 
286   BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
287   MS_EXCEPTION_IF_NULL(ca_bio);
288   X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
289   MS_EXCEPTION_IF_NULL(ca_cert);
290   char ca_subject_cn[256] = "";
291   char ca_issuer_cn[256] = "";
292   X509_NAME *ca_subject_name = X509_get_subject_name(ca_cert);
293   X509_NAME *ca_issuer_name = X509_get_issuer_name(ca_cert);
294   MS_ERROR_IF_NULL_W_RET_VAL(ca_subject_name, false);
295   MS_ERROR_IF_NULL_W_RET_VAL(ca_issuer_name, false);
296   if (!X509_NAME_get_text_by_NID(ca_subject_name, NID_commonName, ca_subject_cn, sizeof(subject_cn))) {
297     MS_LOG(WARNING) << "Get text by nid failed.";
298     return false;
299   }
300   if (!X509_NAME_get_text_by_NID(ca_issuer_name, NID_commonName, ca_issuer_cn, sizeof(issuer_cn))) {
301     MS_LOG(WARNING) << "Get text by nid failed.";
302     return false;
303   }
304   MS_LOG(INFO) << "the subject:" << ca_subject_cn << ", the issuer:" << ca_issuer_cn;
305   BIO_free_all(ca_bio);
306   if (strcmp(issuer_cn, ca_subject_cn) != 0) {
307     return false;
308   }
309   return true;
310 }
311 
Split(const std::string & s,char delim)312 std::vector<std::string> CommUtil::Split(const std::string &s, char delim) {
313   std::vector<std::string> res;
314   std::stringstream ss(s);
315   std::string item;
316 
317   while (getline(ss, item, delim)) {
318     res.push_back(item);
319   }
320   return res;
321 }
322 
VerifyCipherList(const std::vector<std::string> & list)323 bool CommUtil::VerifyCipherList(const std::vector<std::string> &list) {
324   for (auto &item : list) {
325     if (!kCiphers.count(item)) {
326       MS_LOG(WARNING) << "The ciphter:" << item << " is not supported.";
327       return false;
328     }
329   }
330   return true;
331 }
332 
InitOpenSSLEnv()333 void CommUtil::InitOpenSSLEnv() {
334   if (!SSL_library_init()) {
335     MS_LOG(EXCEPTION) << "SSL_library_init failed.";
336   }
337   if (!ERR_load_crypto_strings()) {
338     MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
339   }
340   if (!SSL_load_error_strings()) {
341     MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
342   }
343   if (!OpenSSL_add_all_algorithms()) {
344     MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
345   }
346 }
347 }  // namespace core
348 }  // namespace ps
349 }  // namespace mindspore
350