• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <functional>
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 #include <utility>
22 #include <unordered_map>
23 #include "utils/ms_exception.h"
24 #include "proto/topology.pb.h"
25 #include "include/backend/distributed/ps/ps_context.h"
26 #include "include/backend/distributed/rpc/tcp/constants.h"
27 #include "include/backend/distributed/recovery/recovery_context.h"
28 #include "distributed/recovery/file_configuration.h"
29 #include "distributed/cluster/topology/meta_server_node.h"
30 #include "utils/convert_utils_base.h"
31 
32 namespace mindspore {
33 namespace distributed {
34 namespace cluster {
35 namespace topology {
// Name of the file, inside the recovery directory, that stores the persisted cluster metadata.
constexpr char kRecoveryFileName[] = "recovery.dat";

// Configuration key under which the states of all compute graph nodes are persisted.
constexpr char kComputeNodeStates[] = "compute_node_states";

// Keys of the fields stored for every persisted compute graph node state.
constexpr char kNodeId[] = "node_id";
constexpr char kHostName[] = "host_name";
constexpr char kRole[] = "role";
constexpr char kRankId[] = "rank_id";
43 
~MetaServerNode()44 MetaServerNode::~MetaServerNode() {
45   try {
46     (void)Finalize(true);
47   } catch (std::exception &) {
48     MS_LOG(ERROR) << "Failed to finalize MetaServerNode.";
49   }
50 }
51 
Initialize()52 bool MetaServerNode::Initialize() {
53   // Init metadata for the cluster.
54   SetMetaData();
55 
56   // Init the address of meta server node.
57   RETURN_IF_FALSE_WITH_LOG(FillMetaServerAddress(&meta_server_addr_),
58                            "Failed to init the address of meta server node.");
59 
60   // Init the TCP server.
61   RETURN_IF_FALSE_WITH_LOG(InitTCPServer(), "Failed to create the TCP server.");
62 
63   // The meta server node is restarted and the metadata of cluster needs to be recovered.
64   if (recovery::IsEnableRecovery()) {
65     RETURN_IF_FALSE_WITH_LOG(Recovery(), "Failed to recover from configuration.");
66   }
67 
68   start_time_ = Now();
69 
70   // Init the thread for monitoring the state of the cluster topo.
71   topo_monitor_ = std::thread(&MetaServerNode::UpdateTopoState, this);
72   return true;
73 }
74 
Initialized()75 bool MetaServerNode::Initialized() {
76   return topo_state_ == TopoState::kInitialized || topo_state_ == TopoState::kFinished;
77 }
78 
Finalize(bool force)79 bool MetaServerNode::Finalize(bool force) {
80   if (finalized_) {
81     return true;
82   }
83   if (topo_state_ != TopoState::kFinished && !force &&
84       (recovery::IsEnableRecovery() || (abnormal_node_num_ == 0 && !recovery::IsEnableRecovery()))) {
85     MS_LOG(WARNING) << "The meta server node can not be finalized because there are still " << nodes_.size()
86                     << " alive nodes.";
87     return false;
88   } else {
89     if (abnormal_node_num_ > 0) {
90       MS_LOG(ERROR) << "There are " << abnormal_node_num_ << " abnormal compute graph nodes.";
91     }
92 
93     // Release the TCP server.
94     if (tcp_server_ != nullptr) {
95       tcp_server_->Finalize();
96       tcp_server_.reset();
97     }
98 
99     // Stop the topo monitor thread.
100     enable_monitor_ = false;
101     if (topo_monitor_.joinable()) {
102       topo_monitor_.join();
103     }
104     if (force) {
105       MS_LOG(INFO) << "The meta server node is forced to finalized.";
106     }
107     finalized_ = true;
108     MsException::Instance().CheckException();
109     return true;
110   }
111 }
112 
SetMetaData()113 void MetaServerNode::SetMetaData() {
114   // The validation check happened in cluster_context.cc, so we don't validating in this method.
115   if (!common::GetEnv(kEnvWorkerNum).empty()) {
116     role_expect_num_[kEnvRoleOfWorker] = IntToUint(std::stoi(common::GetEnv(kEnvWorkerNum)));
117   }
118   if (!common::GetEnv(kEnvServerNum).empty()) {
119     role_expect_num_[kEnvRoleOfServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)));
120     role_expect_num_[kEnvRoleOfPServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)));
121   }
122 }
123 
InitTCPServer()124 bool MetaServerNode::InitTCPServer() {
125   bool enable_ssl = ps::PSContext::instance()->enable_ssl();
126   tcp_server_ = std::make_unique<rpc::TCPServer>(enable_ssl);
127   MS_EXCEPTION_IF_NULL(tcp_server_);
128   RETURN_IF_FALSE_WITH_LOG(tcp_server_->Initialize(meta_server_addr_.GetUrl()), "Failed to init the tcp server.");
129   tcp_server_->SetMessageHandler(std::bind(&MetaServerNode::HandleMessage, this, std::placeholders::_1));
130 
131   // Configure the message processors for the TCP server.
132   system_msg_handlers_[MessageName::kRegistration] =
133     std::bind(&MetaServerNode::ProcessRegister, this, std::placeholders::_1);
134   system_msg_handlers_[MessageName::kUnregistration] =
135     std::bind(&MetaServerNode::ProcessUnregister, this, std::placeholders::_1);
136   system_msg_handlers_[MessageName::kHeartbeat] =
137     std::bind(&MetaServerNode::ProcessHeartbeat, this, std::placeholders::_1);
138   system_msg_handlers_[MessageName::kWriteMetadata] =
139     std::bind(&MetaServerNode::ProcessWriteMetadata, this, std::placeholders::_1);
140   system_msg_handlers_[MessageName::kReadMetadata] =
141     std::bind(&MetaServerNode::ProcessReadMetadata, this, std::placeholders::_1);
142   system_msg_handlers_[MessageName::kDeleteMetadata] =
143     std::bind(&MetaServerNode::ProcessDeleteMetadata, this, std::placeholders::_1);
144   system_msg_handlers_[MessageName::kGetHostNames] =
145     std::bind(&MetaServerNode::ProcessGetHostNames, this, std::placeholders::_1);
146   return true;
147 }
148 
HandleMessage(MessageBase * const message)149 MessageBase *const MetaServerNode::HandleMessage(MessageBase *const message) {
150   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
151   const auto &name = message->Name();
152 
153   // Handle system messages.
154   if (std::all_of(name.begin(), name.end(), ::isdigit)) {
155     const auto &message_name = static_cast<MessageName>(std::stoi(message->Name()));
156     const auto &handler = system_msg_handlers_.find(message_name);
157     if (handler == system_msg_handlers_.end()) {
158       MS_LOG(ERROR) << "Unknown system message name: " << message->Name();
159       delete message;
160       return rpc::NULL_MSG;
161     }
162     auto ret_msg = system_msg_handlers_[message_name](message);
163     delete message;
164     return ret_msg;
165   } else {
166     // Handle user defined messages.
167     const auto &handler = message_handlers_.find(name);
168     if (handler == message_handlers_.end()) {
169       MS_LOG(ERROR) << "Unknown message name: " << name;
170       delete message;
171       return rpc::NULL_MSG;
172     }
173     const auto &result = (*message_handlers_[name])(message->Body());
174     if (result.length() > 0) {
175       auto rt_msg = CreateMessage(meta_server_addr_.GetUrl(), name, result);
176       delete message;
177       MS_EXCEPTION_IF_NULL(rt_msg);
178       return rt_msg.release();
179     } else {
180       delete message;
181       return rpc::NULL_MSG;
182     }
183   }
184 }
185 
ProcessRegister(MessageBase * const message)186 MessageBase *const MetaServerNode::ProcessRegister(MessageBase *const message) {
187   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
188   RegistrationMessage registration;
189   const std::string &body = message->Body();
190   (void)registration.ParseFromArray(body.c_str(), SizeToInt(body.length()));
191 
192   // Add the compute graph node into registered nodes.
193   const auto &node_id = registration.node_id();
194   const auto &host_name = registration.host_name();
195   const auto &host_ip = registration.host_ip();
196   const auto &role = registration.role();
197   std::unique_lock<std::shared_mutex> lock(nodes_mutex_);
198   if (nodes_.find(node_id) == nodes_.end()) {
199     uint32_t rank_id;
200     if (common::IsStrNumeric(node_id)) {
201       // This means node id is not randomly generated. So directly convert to int.
202       rank_id = static_cast<uint32_t>(std::atoi(node_id.c_str()));
203     } else {
204       rank_id = AllocateRankId(role);
205     }
206 
207     // Check validation of this registered node.
208     std::string reject_reason = "";
209     if (!CheckRankIdValidation(node_id, role, rank_id, host_ip, &reject_reason)) {
210       RegistrationRespMessage reg_resp_msg;
211       reg_resp_msg.set_success(false);
212       reg_resp_msg.set_error_reason(reject_reason);
213       auto response =
214         CreateMessage(meta_server_addr_.GetUrl(), MessageName::kInvalidNode, reg_resp_msg.SerializeAsString());
215       return response.release();
216     }
217 
218     std::shared_ptr<NodeInfo> node_info = std::make_shared<NodeInfo>(node_id);
219     MS_ERROR_IF_NULL_W_RET_VAL(node_info, rpc::NULL_MSG);
220     node_info->host_name = host_name;
221     node_info->host_ip = host_ip;
222     node_info->role = role;
223     node_info->rank_id = rank_id;
224     node_info->state = NodeState::kRegistered;
225     (void)time(&(node_info->last_update));
226     nodes_[node_id] = node_info;
227     MS_LOG(WARNING) << "The new node: " << node_id << "(role: " << role << ")"
228                     << ", rank id: " << rank_id << ", hostname: " << node_info->host_name << ", ip: " << host_ip
229                     << " is registered successfully. Currently registered node number: " << nodes_.size()
230                     << ", expected node number: " << total_node_num_;
231     (void)TransitionToInitialized();
232 
233     RegistrationRespMessage reg_resp_msg;
234     reg_resp_msg.set_success(true);
235     reg_resp_msg.set_rank_id(rank_id);
236     reg_resp_msg.set_node_num(SizeToUint(total_node_num_));
237     std::string content = reg_resp_msg.SerializeAsString();
238 
239     auto message = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess, content);
240     MS_EXCEPTION_IF_NULL(message);
241     return message.release();
242   } else {
243     if (!recovery::IsEnableRecovery()) {
244       MS_LOG(WARNING) << "Node " << node_id << " registered repeatedly. It's host ip is " << host_ip
245                       << ". Reject this node.";
246       RegistrationRespMessage reg_resp_msg;
247       reg_resp_msg.set_success(false);
248       reg_resp_msg.set_error_reason(
249         "Repeated registration node: " + node_id +
250         " to the scheduler. Please check if there's another scheduler process with port:" +
251         std::to_string(meta_server_addr_.port) +
252         " still running, or this is an extra node for distributed job. You can run command: 'netstat -anp|grep " +
253         std::to_string(meta_server_addr_.port) +
254         "' to check residual scheduler process. If another residual scheduler's still running, please kill it or "
255         "change '--master_port' to a unoccupied port number of 'msrun' command and "
256         "retry.");
257       auto response =
258         CreateMessage(meta_server_addr_.GetUrl(), MessageName::kInvalidNode, reg_resp_msg.SerializeAsString());
259       return response.release();
260     }
261     auto node_info = nodes_[node_id];
262     MS_EXCEPTION_IF_NULL(node_info);
263     node_info->host_ip = host_ip;
264     MS_LOG(WARNING) << "The node: " << node_id << " have been recovered. IP address: " << host_ip
265                     << ", rank id: " << node_info->rank_id;
266     (void)metadata_.insert(std::make_pair(node_info->role + node_info->node_id, std::to_string(node_info->rank_id)));
267 
268     RegistrationRespMessage reg_resp_msg;
269     reg_resp_msg.set_success(true);
270     reg_resp_msg.set_rank_id(node_info->rank_id);
271     std::string content = reg_resp_msg.SerializeAsString();
272 
273     auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess, content);
274     MS_EXCEPTION_IF_NULL(response);
275     return response.release();
276   }
277 }
278 
ProcessUnregister(MessageBase * const message)279 MessageBase *const MetaServerNode::ProcessUnregister(MessageBase *const message) {
280   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
281   UnregistrationMessage unregistration;
282   const std::string &body = message->Body();
283   (void)unregistration.ParseFromArray(body.c_str(), SizeToInt(body.length()));
284 
285   const auto &node_id = unregistration.node_id();
286 
287   if (topo_state_ != TopoState::kInitialized) {
288     MS_LOG(ERROR) << "Unable to process unreg message from node " << node_id << " because the state of the topology is "
289                   << topo_state_;
290     auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kUninitTopo,
291                                   std::to_string(static_cast<int>(MessageName::kUninitTopo)));
292     MS_EXCEPTION_IF_NULL(response);
293     return response.release();
294   }
295 
296   std::unique_lock<std::shared_mutex> lock(nodes_mutex_);
297   if (nodes_.find(node_id) == nodes_.end()) {
298     MS_LOG(ERROR) << "Received unregistration message from invalid compute graph node: " << node_id;
299     auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kInvalidNode,
300                                   std::to_string(static_cast<int>(MessageName::kInvalidNode)));
301     MS_EXCEPTION_IF_NULL(response);
302     return response.release();
303   }
304   (void)nodes_.erase(node_id);
305   MS_LOG(WARNING) << "Node " << node_id << " has unregistered.";
306   if (nodes_.size() == 0) {
307     topo_state_ = TopoState::kFinished;
308   }
309   auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess,
310                                 std::to_string(static_cast<int>(MessageName::kSuccess)));
311   MS_EXCEPTION_IF_NULL(response);
312   return response.release();
313 }
314 
ProcessHeartbeat(MessageBase * const message)315 MessageBase *const MetaServerNode::ProcessHeartbeat(MessageBase *const message) {
316   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
317   HeartbeatMessage heartbeat;
318   const std::string &body = message->Body();
319   (void)heartbeat.ParseFromArray(body.c_str(), SizeToInt(body.length()));
320 
321   // Update the state(timestamp) of this node.
322   const auto &node_id = heartbeat.node_id();
323   std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
324   if (nodes_.find(node_id) != nodes_.end()) {
325     auto &node = nodes_[node_id];
326     MS_ERROR_IF_NULL_W_RET_VAL(node, rpc::NULL_MSG);
327     (void)time(&(node->last_update));
328     node->state = NodeState::kRegistered;
329 
330     HeartbeatRespMessage resp_msg;
331     resp_msg.set_success(static_cast<bool>(MessageName::kSuccess));
332     resp_msg.set_topo_state(static_cast<uint32_t>(topo_state_));
333     resp_msg.set_nodes_num(SizeToUint(total_node_num_));
334     resp_msg.set_abnormal_nodes_num(SizeToUint(abnormal_node_num_));
335     auto content = resp_msg.SerializeAsString();
336     auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess, content);
337     MS_EXCEPTION_IF_NULL(response);
338     return response.release();
339   } else {
340     MS_LOG(ERROR) << "Invalid node: " << node_id << ".";
341     return rpc::NULL_MSG;
342   }
343 }
344 
ProcessWriteMetadata(MessageBase * const message)345 MessageBase *const MetaServerNode::ProcessWriteMetadata(MessageBase *const message) {
346   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
347   const std::string &body = message->Body();
348   MetadataMessage meta_msg;
349   (void)meta_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
350   if (meta_msg.name().length() == 0) {
351     MS_LOG(ERROR) << "Empty metadata name.";
352     return rpc::NULL_MSG;
353   }
354   std::shared_lock<std::shared_mutex> lock(meta_mutex_);
355   metadata_[meta_msg.name()] = meta_msg.value();
356   return rpc::NULL_MSG;
357 }
358 
ProcessReadMetadata(MessageBase * const message)359 MessageBase *const MetaServerNode::ProcessReadMetadata(MessageBase *const message) {
360   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
361   const std::string &body = message->Body();
362   MetadataMessage meta_msg;
363   (void)meta_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
364 
365   std::shared_lock<std::shared_mutex> lock(meta_mutex_);
366   MessageName result;
367   std::unique_ptr<MessageBase> response;
368 
369   if (metadata_.find(meta_msg.name()) == metadata_.end()) {
370     result = MessageName::kInvalidMetadata;
371   } else {
372     result = MessageName::kValidMetadata;
373     std::string meta_value = metadata_[meta_msg.name()];
374     meta_msg.set_value(meta_value);
375   }
376   response = CreateMessage(meta_server_addr_.GetUrl(), result, meta_msg.SerializeAsString());
377   MS_EXCEPTION_IF_NULL(response);
378   return response.release();
379 }
380 
ProcessDeleteMetadata(MessageBase * const message)381 MessageBase *const MetaServerNode::ProcessDeleteMetadata(MessageBase *const message) {
382   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
383   const std::string &body = message->Body();
384   MetadataMessage meta_msg;
385   (void)meta_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
386 
387   std::shared_lock<std::shared_mutex> lock(meta_mutex_);
388   MessageName result;
389   std::unique_ptr<MessageBase> response;
390 
391   if (metadata_.find(meta_msg.name()) == metadata_.end()) {
392     result = MessageName::kInvalidMetadata;
393   } else {
394     result = MessageName::kValidMetadata;
395     (void)metadata_.erase(meta_msg.name());
396   }
397   response = CreateMessage(meta_server_addr_.GetUrl(), result, meta_msg.SerializeAsString());
398   MS_EXCEPTION_IF_NULL(response);
399   return response.release();
400 }
401 
ProcessGetHostNames(MessageBase * const message)402 MessageBase *const MetaServerNode::ProcessGetHostNames(MessageBase *const message) {
403   MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
404   // Convert result to the message.
405   nlohmann::json hostnames = nlohmann::json::array();
406   nlohmann::json retval = nlohmann::json::object();
407   MessageName result;
408 
409   if (nodes_.size() != total_node_num_) {
410     result = MessageName::kInvalidMetadata;
411   } else {
412     result = MessageName::kValidMetadata;
413 
414     auto node_role = message->body;
415 
416     // Collect all the hostnames from nodes info.
417     std::vector<std::string> tmp_hostnames(nodes_.size(), "");
418     std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
419 
420     // The hostnames must are sorted strictly by the rank id.
421     for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
422       auto node_info = iter->second;
423       MS_EXCEPTION_IF_NULL(node_info);
424       if (node_info->role != node_role) {
425         continue;
426       }
427       if (node_info->rank_id >= 0 && node_info->rank_id < tmp_hostnames.size()) {
428         tmp_hostnames[node_info->rank_id] = node_info->host_name;
429       } else {
430         MS_LOG(ERROR) << "Invalid rank id: " << node_info->rank_id << " for node: " << node_info->node_id;
431         continue;
432       }
433     }
434 
435     // The hostname of the node whose role name not match is empty, and should be skipped.
436     for (size_t i = 0; i < tmp_hostnames.size(); ++i) {
437       if (tmp_hostnames[i] != "") {
438         hostnames.push_back(tmp_hostnames[i]);
439       }
440     }
441   }
442 
443   retval[kHostNames] = hostnames;
444   try {
445     MS_LOG(DEBUG) << "Host names are " << retval.dump();
446   } catch (const std::exception &e) {
447     MS_LOG(ERROR) << "Failed to dump host names json " << e.what();
448   }
449   auto response = CreateMessage(meta_server_addr_.GetUrl(), result, retval.dump());
450   MS_EXCEPTION_IF_NULL(response);
451   return response.release();
452 }
453 
UpdateTopoState()454 void MetaServerNode::UpdateTopoState() {
455   try {
456     while (enable_monitor_) {
457       nodes_mutex_.lock();
458 
459       // Update the state of topology.
460       if (topo_state_ == TopoState::kInitializing) {
461         if (TransitionToInitialized()) {
462           continue;
463         }
464         MS_LOG(INFO) << "The cluster topology is in the process of constructing, current alive node num: ("
465                      << nodes_.size() << "/" << total_node_num_ << ")";
466       } else if (topo_state_ == TopoState::kInitialized) {
467         if (nodes_.size() == 0) {
468           topo_state_ = TopoState::kFinished;
469         }
470       }
471 
472       if (!disable_heartbeat_) {
473         // Update the state of compute graph nodes if heartbeat is enabled.
474         size_t abnormal_node_num = 0;
475         std::vector<std::string> time_out_node_ids = {};
476         for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
477           auto node_id = iter->first;
478           auto node_info = iter->second;
479           MS_EXCEPTION_IF_NULL(node_info);
480           time_t now = time(&now);
481           auto elapsed = difftime(now, node_info->last_update);
482           if (elapsed > node_timeout_) {
483             node_info->state = NodeState::kTimeout;
484             ++abnormal_node_num;
485             time_out_node_ids.push_back(node_id);
486             MS_LOG(ERROR) << "The node: " << node_id
487                           << " is timed out. It may exit with exception, please check this node's log.";
488           }
489         }
490         abnormal_node_num_ = abnormal_node_num;
491         if (abnormal_node_num_ > 0 && !recovery::IsEnableRecovery()) {
492           MS_LOG(EXCEPTION) << "The total number of timed out node is " << abnormal_node_num_
493                             << ". Timed out node list is: " << time_out_node_ids << ", worker " << time_out_node_ids[0]
494                             << " is the first one timed out, please check its log.";
495         }
496       }
497 
498       nodes_mutex_.unlock();
499       static const size_t interval = 3;
500       (void)sleep(interval);
501     }
502   } catch (const std::exception &e) {
503     nodes_mutex_.unlock();
504     MsException::Instance().SetException();
505   }
506 }
507 
TransitionToInitialized()508 bool MetaServerNode::TransitionToInitialized() {
509   if (nodes_.size() == total_node_num_) {
510     // After all nodes are successfully registered, reassign rank ids so they could be continuous.
511     ReassignNodeRank();
512 
513     // Assign port range for each node after cluster is initialized.
514     AssignPortRange();
515 
516     // Persist the cluster metadata into storage through configuration.
517     if (recovery::IsEnableRecovery() && configuration_ != nullptr && configuration_->Empty()) {
518       if (!Persist()) {
519         MS_LOG(EXCEPTION) << "Failed to persist the metadata of the cluster.";
520       }
521     }
522     topo_state_ = TopoState::kInitialized;
523     MS_LOG(INFO) << "The cluster topology has been constructed successfully.";
524     return true;
525   }
526   return false;
527 }
528 
AssignPortRange()529 void MetaServerNode::AssignPortRange() {
530   MS_LOG(DEBUG) << "Start assigning port range for nodes...";
531   std::unordered_map<std::string, uint32_t> each_host_node_num;
532   std::unordered_map<std::string, uint32_t> node_index_map;
533   // Assign computing graph nodes' port range according to their hosts.
534   for (const auto &n : nodes_) {
535     std::string node_id = n.first;
536     const auto &node_info = n.second;
537 
538     uint32_t &host_node_num = each_host_node_num[node_info->host_name];
539     node_index_map[node_info->node_id] = host_node_num;
540     host_node_num++;
541   }
542 
543   NodePortRanges node_ranges;
544   for (const auto &n : nodes_) {
545     std::string node_id = n.first;
546     const auto &node_info = n.second;
547     uint32_t node_index = node_index_map[node_id];
548     uint32_t each_node_range = kNodePortRangeNum / each_host_node_num[node_info->host_name];
549     uint32_t min_port = kStartPort + each_node_range * node_index;
550     uint32_t max_port = min_port + each_node_range - 1;
551     PortRange range;
552     range.set_min_port(min_port);
553     range.set_max_port(max_port);
554     (void)node_ranges.mutable_data()->insert({node_id, range});
555     MS_LOG(INFO) << "The port range for node " << node_id << ", rank id: " << node_info->rank_id
556                  << ", min port: " << min_port << ", max port: " << max_port;
557   }
558   (void)metadata_.insert({kNodePortRange, node_ranges.SerializeAsString()});
559 }
560 
Recovery()561 bool MetaServerNode::Recovery() {
562   std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
563   std::string recovery_path = recovery::RecoveryPath();
564   RETURN_IF_FALSE_WITH_LOG(CheckFilePath(recovery_path), "Invalid recovery path: " << recovery_path);
565   configuration_ = std::make_unique<recovery::FileConfiguration>(recovery_path + "/" + kRecoveryFileName);
566   MS_EXCEPTION_IF_NULL(configuration_);
567 
568   RETURN_IF_FALSE_WITH_LOG(configuration_->Initialize(),
569                            "Failed to initialize the recovery file configuration from file path: " << recovery_path);
570 
571   if (configuration_->Empty()) {
572     MS_LOG(INFO) << "The meta server node is started for the first time.";
573     return true;
574 
575     // The meta server node is restarted and the metadata of cluster needs to be recovered.
576   } else {
577     MS_LOG(INFO) << "Begin to recover the meta server node.";
578     std::string states_key = kComputeNodeStates;
579     RETURN_IF_FALSE_WITH_LOG(configuration_->Exists(states_key),
580                              "Can not find the key " + states_key + " in configuration.");
581 
582     // Check the validation of the previous metadata.
583     const auto &states = configuration_->Get(states_key, "");
584     nlohmann::json node_states = nlohmann::json::parse(states);
585     RETURN_IF_FALSE_WITH_LOG(node_states.size() == total_node_num_,
586                              "Invalid number of node in configuration: " + std::to_string(node_states.size()) +
587                                ", expected total number of node: " + std::to_string(total_node_num_));
588 
589     // Restore the nodes state.
590     for (auto iter = node_states.begin(); iter != node_states.end(); ++iter) {
591       const auto &node_id = iter.key();
592       std::shared_ptr<NodeInfo> node_info = std::make_shared<NodeInfo>(node_id);
593       MS_EXCEPTION_IF_NULL(node_info);
594       (void)time(&(node_info->last_update));
595       node_info->host_name = iter.value().at(kHostName);
596       node_info->role = iter.value().at(kRole);
597       node_info->rank_id = iter.value().at(kRankId);
598       node_info->state = NodeState::kRegistered;
599       nodes_[node_id] = node_info;
600     }
601 
602     if (nodes_.size() == total_node_num_) {
603       topo_state_ = TopoState::kInitialized;
604     }
605     MS_LOG(INFO) << "The meta server node has been recovered successfully.";
606   }
607   return true;
608 }
609 
Persist()610 bool MetaServerNode::Persist() {
611   if (total_node_num_ != nodes_.size()) {
612     MS_LOG(ERROR) << "Invalid number of alive node: " << nodes_.size()
613                   << ", the expected total number of node is: " << total_node_num_;
614     return false;
615   }
616 
617   // The thread safety of nodes_ visiting has been guarded by the caller.
618   nlohmann::json node_states;
619   for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
620     const auto &node_id = iter->first;
621     nlohmann::json node_state;
622     node_state[kNodeId] = node_id;
623 
624     MS_EXCEPTION_IF_NULL(iter->second);
625     node_state[kHostName] = iter->second->host_name;
626     node_state[kRole] = iter->second->role;
627     node_state[kRankId] = iter->second->rank_id;
628     node_states[node_id] = node_state;
629   }
630 
631   MS_EXCEPTION_IF_NULL(configuration_);
632   configuration_->Put(kComputeNodeStates, node_states.dump());
633   RETURN_IF_FALSE_WITH_LOG(configuration_->Flush(), "Failed to flush configuration.");
634   return true;
635 }
636 
AllocateRankId(const std::string & role)637 uint32_t MetaServerNode::AllocateRankId(const std::string &role) {
638   std::shared_lock<std::shared_mutex> lock(rank_mutex_);
639   if (role_expect_num_.find(role) == role_expect_num_.end()) {
640     MS_LOG(WARNING) << "Role: " << role << " is invalid.";
641     return UINT32_MAX;
642   }
643   if (next_rank_ids_.count(role) == 0) {
644     next_rank_ids_[role] = 0;
645   } else {
646     // If this role's rank id has exceeded, do not increase next_rank_ids_ and return an exceeded rank id. The caller
647     // will check rank id's validation and reject this request.
648     if (next_rank_ids_[role] == role_expect_num_[role] - 1) {
649       return next_rank_ids_[role] + 1;
650     }
651     next_rank_ids_[role] += 1;
652   }
653   return next_rank_ids_[role];
654 }
655 
CheckRankIdValidation(const std::string & node_id,const std::string & role,uint32_t rank_id,const std::string & host_ip,std::string * reject_reason)656 bool MetaServerNode::CheckRankIdValidation(const std::string &node_id, const std::string &role, uint32_t rank_id,
657                                            const std::string &host_ip, std::string *reject_reason) {
658   if (role_expect_num_.find(role) == role_expect_num_.end()) {
659     MS_LOG(WARNING) << "Registered node role: " << role << " is invalid.";
660     return false;
661   }
662   // Whether rank id has already exists.
663   bool rank_id_exist = std::any_of(nodes_.begin(), nodes_.end(), [&role, &rank_id](const auto &n) {
664     return n.second->role == role && n.second->rank_id == rank_id;
665   });
666   // Whether rank id exceeds upper bound.
667   bool is_extra_node = (rank_id >= role_expect_num_[role]);
668   if (rank_id_exist) {
669     *reject_reason = "Rank id:" + std::to_string(rank_id) + " for role:" + role + " exists.";
670   }
671   if (is_extra_node) {
672     *reject_reason = "This node is extra or rank id exceeds. Total node number for role " + role + " is " +
673                      std::to_string(role_expect_num_[role]) + " but got rank id " + std::to_string(rank_id);
674   }
675   if (rank_id_exist || is_extra_node) {
676     MS_LOG(WARNING) << "Rejecting registration request for node " << node_id << " from host " << host_ip
677                     << ". Rejection reason: " << *reject_reason;
678     return false;
679   }
680   return true;
681 }
682 
ReassignNodeRank()683 void MetaServerNode::ReassignNodeRank() {
684   if (std::all_of(nodes_.begin(), nodes_.end(), [](const auto &node) { return common::IsStrNumeric(node.first); })) {
685     MS_LOG(WARNING) << "Rank ids are already set by numeric node ids. No need to reassign them.";
686     for (const auto &n : nodes_) {
687       const std::shared_ptr<NodeInfo> &node_info = n.second;
688       const std::string &role = node_info->role;
689       (void)metadata_.insert(std::make_pair(role + node_info->node_id, std::to_string(node_info->rank_id)));
690     }
691     return;
692   }
693 
694   MS_LOG(INFO) << "Start sorting and reassiging rank ids for nodes according to node ips and node ids.";
695   std::map<std::string, std::map<NodeKey, uint32_t>> node_ranks;
696   for (auto &n : nodes_) {
697     std::shared_ptr<NodeInfo> &node_info = n.second;
698     NodeKey node_key = {node_info->host_ip, node_info->node_id};
699     (void)node_ranks[node_info->role].insert(std::make_pair(node_key, 0));
700   }
701 
702   for (auto &n : node_ranks) {
703     std::map<NodeKey, uint32_t> &node_key_ranks = n.second;
704     uint32_t accum_rank_id = 0;
705     for (auto &node_rank : node_key_ranks) {
706       node_rank.second = accum_rank_id++;
707     }
708   }
709 
710   for (auto &n : nodes_) {
711     std::shared_ptr<NodeInfo> &node_info = n.second;
712     const std::string &role = node_info->role;
713     NodeKey node_key = {node_info->host_ip, node_info->node_id};
714     uint32_t new_rank = node_ranks[role][node_key];
715 
716     MS_LOG(WARNING) << "Assign rank id of node id: " << node_info->node_id << ", role: " << role
717                     << ", with host ip: " << node_info->host_ip << ", old rank id: " << node_info->rank_id
718                     << ", new rank id: " << new_rank;
719 
720     node_info->rank_id = new_rank;
721     (void)metadata_.insert(std::make_pair(role + node_info->node_id, std::to_string(node_info->rank_id)));
722   }
723 }
724 
TopologyState() const725 TopoState MetaServerNode::TopologyState() const { return topo_state_; }
726 
GetAliveNodeNum()727 size_t MetaServerNode::GetAliveNodeNum() {
728   std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
729   size_t count = 0;
730   for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
731     auto node_info = iter->second;
732     MS_EXCEPTION_IF_NULL(node_info);
733 
734     // Only the node which has been authenticated is alive.
735     if (node_info->state == NodeState::kRegistered) {
736       ++count;
737     }
738   }
739   return count;
740 }
741 
RegisterMessageHandler(const std::string & name,const std::shared_ptr<std::function<std::string (const std::string &)>> & handler)742 bool MetaServerNode::RegisterMessageHandler(
743   const std::string &name, const std::shared_ptr<std::function<std::string(const std::string &)>> &handler) {
744   if (message_handlers_.find(name) != message_handlers_.end()) {
745     MS_LOG(ERROR) << "The message name: " << name << " have already been registered";
746     return false;
747   }
748   message_handlers_[name] = handler;
749   return true;
750 }
751 }  // namespace topology
752 }  // namespace cluster
753 }  // namespace distributed
754 }  // namespace mindspore
755