1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <functional>
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 #include <utility>
22 #include <unordered_map>
23 #include "utils/ms_exception.h"
24 #include "proto/topology.pb.h"
25 #include "include/backend/distributed/ps/ps_context.h"
26 #include "include/backend/distributed/rpc/tcp/constants.h"
27 #include "include/backend/distributed/recovery/recovery_context.h"
28 #include "distributed/recovery/file_configuration.h"
29 #include "distributed/cluster/topology/meta_server_node.h"
30 #include "utils/convert_utils_base.h"
31
32 namespace mindspore {
33 namespace distributed {
34 namespace cluster {
35 namespace topology {
// Configuration key under which the JSON document of all compute node states is persisted.
static constexpr char kComputeNodeStates[] = "compute_node_states";

// Field names used inside each persisted compute node state entry.
static constexpr char kNodeId[] = "node_id";
static constexpr char kHostName[] = "host_name";
static constexpr char kRole[] = "role";
static constexpr char kRankId[] = "rank_id";

// File name, relative to the recovery path, of the persisted cluster metadata.
static constexpr char kRecoveryFileName[] = "recovery.dat";
43
~MetaServerNode()44 MetaServerNode::~MetaServerNode() {
45 try {
46 (void)Finalize(true);
47 } catch (std::exception &) {
48 MS_LOG(ERROR) << "Failed to finalize MetaServerNode.";
49 }
50 }
51
Initialize()52 bool MetaServerNode::Initialize() {
53 // Init metadata for the cluster.
54 SetMetaData();
55
56 // Init the address of meta server node.
57 RETURN_IF_FALSE_WITH_LOG(FillMetaServerAddress(&meta_server_addr_),
58 "Failed to init the address of meta server node.");
59
60 // Init the TCP server.
61 RETURN_IF_FALSE_WITH_LOG(InitTCPServer(), "Failed to create the TCP server.");
62
63 // The meta server node is restarted and the metadata of cluster needs to be recovered.
64 if (recovery::IsEnableRecovery()) {
65 RETURN_IF_FALSE_WITH_LOG(Recovery(), "Failed to recover from configuration.");
66 }
67
68 start_time_ = Now();
69
70 // Init the thread for monitoring the state of the cluster topo.
71 topo_monitor_ = std::thread(&MetaServerNode::UpdateTopoState, this);
72 return true;
73 }
74
Initialized()75 bool MetaServerNode::Initialized() {
76 return topo_state_ == TopoState::kInitialized || topo_state_ == TopoState::kFinished;
77 }
78
Finalize(bool force)79 bool MetaServerNode::Finalize(bool force) {
80 if (finalized_) {
81 return true;
82 }
83 if (topo_state_ != TopoState::kFinished && !force &&
84 (recovery::IsEnableRecovery() || (abnormal_node_num_ == 0 && !recovery::IsEnableRecovery()))) {
85 MS_LOG(WARNING) << "The meta server node can not be finalized because there are still " << nodes_.size()
86 << " alive nodes.";
87 return false;
88 } else {
89 if (abnormal_node_num_ > 0) {
90 MS_LOG(ERROR) << "There are " << abnormal_node_num_ << " abnormal compute graph nodes.";
91 }
92
93 // Release the TCP server.
94 if (tcp_server_ != nullptr) {
95 tcp_server_->Finalize();
96 tcp_server_.reset();
97 }
98
99 // Stop the topo monitor thread.
100 enable_monitor_ = false;
101 if (topo_monitor_.joinable()) {
102 topo_monitor_.join();
103 }
104 if (force) {
105 MS_LOG(INFO) << "The meta server node is forced to finalized.";
106 }
107 finalized_ = true;
108 MsException::Instance().CheckException();
109 return true;
110 }
111 }
112
SetMetaData()113 void MetaServerNode::SetMetaData() {
114 // The validation check happened in cluster_context.cc, so we don't validating in this method.
115 if (!common::GetEnv(kEnvWorkerNum).empty()) {
116 role_expect_num_[kEnvRoleOfWorker] = IntToUint(std::stoi(common::GetEnv(kEnvWorkerNum)));
117 }
118 if (!common::GetEnv(kEnvServerNum).empty()) {
119 role_expect_num_[kEnvRoleOfServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)));
120 role_expect_num_[kEnvRoleOfPServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)));
121 }
122 }
123
InitTCPServer()124 bool MetaServerNode::InitTCPServer() {
125 bool enable_ssl = ps::PSContext::instance()->enable_ssl();
126 tcp_server_ = std::make_unique<rpc::TCPServer>(enable_ssl);
127 MS_EXCEPTION_IF_NULL(tcp_server_);
128 RETURN_IF_FALSE_WITH_LOG(tcp_server_->Initialize(meta_server_addr_.GetUrl()), "Failed to init the tcp server.");
129 tcp_server_->SetMessageHandler(std::bind(&MetaServerNode::HandleMessage, this, std::placeholders::_1));
130
131 // Configure the message processors for the TCP server.
132 system_msg_handlers_[MessageName::kRegistration] =
133 std::bind(&MetaServerNode::ProcessRegister, this, std::placeholders::_1);
134 system_msg_handlers_[MessageName::kUnregistration] =
135 std::bind(&MetaServerNode::ProcessUnregister, this, std::placeholders::_1);
136 system_msg_handlers_[MessageName::kHeartbeat] =
137 std::bind(&MetaServerNode::ProcessHeartbeat, this, std::placeholders::_1);
138 system_msg_handlers_[MessageName::kWriteMetadata] =
139 std::bind(&MetaServerNode::ProcessWriteMetadata, this, std::placeholders::_1);
140 system_msg_handlers_[MessageName::kReadMetadata] =
141 std::bind(&MetaServerNode::ProcessReadMetadata, this, std::placeholders::_1);
142 system_msg_handlers_[MessageName::kDeleteMetadata] =
143 std::bind(&MetaServerNode::ProcessDeleteMetadata, this, std::placeholders::_1);
144 system_msg_handlers_[MessageName::kGetHostNames] =
145 std::bind(&MetaServerNode::ProcessGetHostNames, this, std::placeholders::_1);
146 return true;
147 }
148
HandleMessage(MessageBase * const message)149 MessageBase *const MetaServerNode::HandleMessage(MessageBase *const message) {
150 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
151 const auto &name = message->Name();
152
153 // Handle system messages.
154 if (std::all_of(name.begin(), name.end(), ::isdigit)) {
155 const auto &message_name = static_cast<MessageName>(std::stoi(message->Name()));
156 const auto &handler = system_msg_handlers_.find(message_name);
157 if (handler == system_msg_handlers_.end()) {
158 MS_LOG(ERROR) << "Unknown system message name: " << message->Name();
159 delete message;
160 return rpc::NULL_MSG;
161 }
162 auto ret_msg = system_msg_handlers_[message_name](message);
163 delete message;
164 return ret_msg;
165 } else {
166 // Handle user defined messages.
167 const auto &handler = message_handlers_.find(name);
168 if (handler == message_handlers_.end()) {
169 MS_LOG(ERROR) << "Unknown message name: " << name;
170 delete message;
171 return rpc::NULL_MSG;
172 }
173 const auto &result = (*message_handlers_[name])(message->Body());
174 if (result.length() > 0) {
175 auto rt_msg = CreateMessage(meta_server_addr_.GetUrl(), name, result);
176 delete message;
177 MS_EXCEPTION_IF_NULL(rt_msg);
178 return rt_msg.release();
179 } else {
180 delete message;
181 return rpc::NULL_MSG;
182 }
183 }
184 }
185
ProcessRegister(MessageBase * const message)186 MessageBase *const MetaServerNode::ProcessRegister(MessageBase *const message) {
187 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
188 RegistrationMessage registration;
189 const std::string &body = message->Body();
190 (void)registration.ParseFromArray(body.c_str(), SizeToInt(body.length()));
191
192 // Add the compute graph node into registered nodes.
193 const auto &node_id = registration.node_id();
194 const auto &host_name = registration.host_name();
195 const auto &host_ip = registration.host_ip();
196 const auto &role = registration.role();
197 std::unique_lock<std::shared_mutex> lock(nodes_mutex_);
198 if (nodes_.find(node_id) == nodes_.end()) {
199 uint32_t rank_id;
200 if (common::IsStrNumeric(node_id)) {
201 // This means node id is not randomly generated. So directly convert to int.
202 rank_id = static_cast<uint32_t>(std::atoi(node_id.c_str()));
203 } else {
204 rank_id = AllocateRankId(role);
205 }
206
207 // Check validation of this registered node.
208 std::string reject_reason = "";
209 if (!CheckRankIdValidation(node_id, role, rank_id, host_ip, &reject_reason)) {
210 RegistrationRespMessage reg_resp_msg;
211 reg_resp_msg.set_success(false);
212 reg_resp_msg.set_error_reason(reject_reason);
213 auto response =
214 CreateMessage(meta_server_addr_.GetUrl(), MessageName::kInvalidNode, reg_resp_msg.SerializeAsString());
215 return response.release();
216 }
217
218 std::shared_ptr<NodeInfo> node_info = std::make_shared<NodeInfo>(node_id);
219 MS_ERROR_IF_NULL_W_RET_VAL(node_info, rpc::NULL_MSG);
220 node_info->host_name = host_name;
221 node_info->host_ip = host_ip;
222 node_info->role = role;
223 node_info->rank_id = rank_id;
224 node_info->state = NodeState::kRegistered;
225 (void)time(&(node_info->last_update));
226 nodes_[node_id] = node_info;
227 MS_LOG(WARNING) << "The new node: " << node_id << "(role: " << role << ")"
228 << ", rank id: " << rank_id << ", hostname: " << node_info->host_name << ", ip: " << host_ip
229 << " is registered successfully. Currently registered node number: " << nodes_.size()
230 << ", expected node number: " << total_node_num_;
231 (void)TransitionToInitialized();
232
233 RegistrationRespMessage reg_resp_msg;
234 reg_resp_msg.set_success(true);
235 reg_resp_msg.set_rank_id(rank_id);
236 reg_resp_msg.set_node_num(SizeToUint(total_node_num_));
237 std::string content = reg_resp_msg.SerializeAsString();
238
239 auto message = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess, content);
240 MS_EXCEPTION_IF_NULL(message);
241 return message.release();
242 } else {
243 if (!recovery::IsEnableRecovery()) {
244 MS_LOG(WARNING) << "Node " << node_id << " registered repeatedly. It's host ip is " << host_ip
245 << ". Reject this node.";
246 RegistrationRespMessage reg_resp_msg;
247 reg_resp_msg.set_success(false);
248 reg_resp_msg.set_error_reason(
249 "Repeated registration node: " + node_id +
250 " to the scheduler. Please check if there's another scheduler process with port:" +
251 std::to_string(meta_server_addr_.port) +
252 " still running, or this is an extra node for distributed job. You can run command: 'netstat -anp|grep " +
253 std::to_string(meta_server_addr_.port) +
254 "' to check residual scheduler process. If another residual scheduler's still running, please kill it or "
255 "change '--master_port' to a unoccupied port number of 'msrun' command and "
256 "retry.");
257 auto response =
258 CreateMessage(meta_server_addr_.GetUrl(), MessageName::kInvalidNode, reg_resp_msg.SerializeAsString());
259 return response.release();
260 }
261 auto node_info = nodes_[node_id];
262 MS_EXCEPTION_IF_NULL(node_info);
263 node_info->host_ip = host_ip;
264 MS_LOG(WARNING) << "The node: " << node_id << " have been recovered. IP address: " << host_ip
265 << ", rank id: " << node_info->rank_id;
266 (void)metadata_.insert(std::make_pair(node_info->role + node_info->node_id, std::to_string(node_info->rank_id)));
267
268 RegistrationRespMessage reg_resp_msg;
269 reg_resp_msg.set_success(true);
270 reg_resp_msg.set_rank_id(node_info->rank_id);
271 std::string content = reg_resp_msg.SerializeAsString();
272
273 auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess, content);
274 MS_EXCEPTION_IF_NULL(response);
275 return response.release();
276 }
277 }
278
ProcessUnregister(MessageBase * const message)279 MessageBase *const MetaServerNode::ProcessUnregister(MessageBase *const message) {
280 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
281 UnregistrationMessage unregistration;
282 const std::string &body = message->Body();
283 (void)unregistration.ParseFromArray(body.c_str(), SizeToInt(body.length()));
284
285 const auto &node_id = unregistration.node_id();
286
287 if (topo_state_ != TopoState::kInitialized) {
288 MS_LOG(ERROR) << "Unable to process unreg message from node " << node_id << " because the state of the topology is "
289 << topo_state_;
290 auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kUninitTopo,
291 std::to_string(static_cast<int>(MessageName::kUninitTopo)));
292 MS_EXCEPTION_IF_NULL(response);
293 return response.release();
294 }
295
296 std::unique_lock<std::shared_mutex> lock(nodes_mutex_);
297 if (nodes_.find(node_id) == nodes_.end()) {
298 MS_LOG(ERROR) << "Received unregistration message from invalid compute graph node: " << node_id;
299 auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kInvalidNode,
300 std::to_string(static_cast<int>(MessageName::kInvalidNode)));
301 MS_EXCEPTION_IF_NULL(response);
302 return response.release();
303 }
304 (void)nodes_.erase(node_id);
305 MS_LOG(WARNING) << "Node " << node_id << " has unregistered.";
306 if (nodes_.size() == 0) {
307 topo_state_ = TopoState::kFinished;
308 }
309 auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess,
310 std::to_string(static_cast<int>(MessageName::kSuccess)));
311 MS_EXCEPTION_IF_NULL(response);
312 return response.release();
313 }
314
ProcessHeartbeat(MessageBase * const message)315 MessageBase *const MetaServerNode::ProcessHeartbeat(MessageBase *const message) {
316 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
317 HeartbeatMessage heartbeat;
318 const std::string &body = message->Body();
319 (void)heartbeat.ParseFromArray(body.c_str(), SizeToInt(body.length()));
320
321 // Update the state(timestamp) of this node.
322 const auto &node_id = heartbeat.node_id();
323 std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
324 if (nodes_.find(node_id) != nodes_.end()) {
325 auto &node = nodes_[node_id];
326 MS_ERROR_IF_NULL_W_RET_VAL(node, rpc::NULL_MSG);
327 (void)time(&(node->last_update));
328 node->state = NodeState::kRegistered;
329
330 HeartbeatRespMessage resp_msg;
331 resp_msg.set_success(static_cast<bool>(MessageName::kSuccess));
332 resp_msg.set_topo_state(static_cast<uint32_t>(topo_state_));
333 resp_msg.set_nodes_num(SizeToUint(total_node_num_));
334 resp_msg.set_abnormal_nodes_num(SizeToUint(abnormal_node_num_));
335 auto content = resp_msg.SerializeAsString();
336 auto response = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kSuccess, content);
337 MS_EXCEPTION_IF_NULL(response);
338 return response.release();
339 } else {
340 MS_LOG(ERROR) << "Invalid node: " << node_id << ".";
341 return rpc::NULL_MSG;
342 }
343 }
344
ProcessWriteMetadata(MessageBase * const message)345 MessageBase *const MetaServerNode::ProcessWriteMetadata(MessageBase *const message) {
346 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
347 const std::string &body = message->Body();
348 MetadataMessage meta_msg;
349 (void)meta_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
350 if (meta_msg.name().length() == 0) {
351 MS_LOG(ERROR) << "Empty metadata name.";
352 return rpc::NULL_MSG;
353 }
354 std::shared_lock<std::shared_mutex> lock(meta_mutex_);
355 metadata_[meta_msg.name()] = meta_msg.value();
356 return rpc::NULL_MSG;
357 }
358
ProcessReadMetadata(MessageBase * const message)359 MessageBase *const MetaServerNode::ProcessReadMetadata(MessageBase *const message) {
360 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
361 const std::string &body = message->Body();
362 MetadataMessage meta_msg;
363 (void)meta_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
364
365 std::shared_lock<std::shared_mutex> lock(meta_mutex_);
366 MessageName result;
367 std::unique_ptr<MessageBase> response;
368
369 if (metadata_.find(meta_msg.name()) == metadata_.end()) {
370 result = MessageName::kInvalidMetadata;
371 } else {
372 result = MessageName::kValidMetadata;
373 std::string meta_value = metadata_[meta_msg.name()];
374 meta_msg.set_value(meta_value);
375 }
376 response = CreateMessage(meta_server_addr_.GetUrl(), result, meta_msg.SerializeAsString());
377 MS_EXCEPTION_IF_NULL(response);
378 return response.release();
379 }
380
ProcessDeleteMetadata(MessageBase * const message)381 MessageBase *const MetaServerNode::ProcessDeleteMetadata(MessageBase *const message) {
382 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
383 const std::string &body = message->Body();
384 MetadataMessage meta_msg;
385 (void)meta_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
386
387 std::shared_lock<std::shared_mutex> lock(meta_mutex_);
388 MessageName result;
389 std::unique_ptr<MessageBase> response;
390
391 if (metadata_.find(meta_msg.name()) == metadata_.end()) {
392 result = MessageName::kInvalidMetadata;
393 } else {
394 result = MessageName::kValidMetadata;
395 (void)metadata_.erase(meta_msg.name());
396 }
397 response = CreateMessage(meta_server_addr_.GetUrl(), result, meta_msg.SerializeAsString());
398 MS_EXCEPTION_IF_NULL(response);
399 return response.release();
400 }
401
ProcessGetHostNames(MessageBase * const message)402 MessageBase *const MetaServerNode::ProcessGetHostNames(MessageBase *const message) {
403 MS_ERROR_IF_NULL_W_RET_VAL(message, rpc::NULL_MSG);
404 // Convert result to the message.
405 nlohmann::json hostnames = nlohmann::json::array();
406 nlohmann::json retval = nlohmann::json::object();
407 MessageName result;
408
409 if (nodes_.size() != total_node_num_) {
410 result = MessageName::kInvalidMetadata;
411 } else {
412 result = MessageName::kValidMetadata;
413
414 auto node_role = message->body;
415
416 // Collect all the hostnames from nodes info.
417 std::vector<std::string> tmp_hostnames(nodes_.size(), "");
418 std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
419
420 // The hostnames must are sorted strictly by the rank id.
421 for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
422 auto node_info = iter->second;
423 MS_EXCEPTION_IF_NULL(node_info);
424 if (node_info->role != node_role) {
425 continue;
426 }
427 if (node_info->rank_id >= 0 && node_info->rank_id < tmp_hostnames.size()) {
428 tmp_hostnames[node_info->rank_id] = node_info->host_name;
429 } else {
430 MS_LOG(ERROR) << "Invalid rank id: " << node_info->rank_id << " for node: " << node_info->node_id;
431 continue;
432 }
433 }
434
435 // The hostname of the node whose role name not match is empty, and should be skipped.
436 for (size_t i = 0; i < tmp_hostnames.size(); ++i) {
437 if (tmp_hostnames[i] != "") {
438 hostnames.push_back(tmp_hostnames[i]);
439 }
440 }
441 }
442
443 retval[kHostNames] = hostnames;
444 try {
445 MS_LOG(DEBUG) << "Host names are " << retval.dump();
446 } catch (const std::exception &e) {
447 MS_LOG(ERROR) << "Failed to dump host names json " << e.what();
448 }
449 auto response = CreateMessage(meta_server_addr_.GetUrl(), result, retval.dump());
450 MS_EXCEPTION_IF_NULL(response);
451 return response.release();
452 }
453
UpdateTopoState()454 void MetaServerNode::UpdateTopoState() {
455 try {
456 while (enable_monitor_) {
457 nodes_mutex_.lock();
458
459 // Update the state of topology.
460 if (topo_state_ == TopoState::kInitializing) {
461 if (TransitionToInitialized()) {
462 continue;
463 }
464 MS_LOG(INFO) << "The cluster topology is in the process of constructing, current alive node num: ("
465 << nodes_.size() << "/" << total_node_num_ << ")";
466 } else if (topo_state_ == TopoState::kInitialized) {
467 if (nodes_.size() == 0) {
468 topo_state_ = TopoState::kFinished;
469 }
470 }
471
472 if (!disable_heartbeat_) {
473 // Update the state of compute graph nodes if heartbeat is enabled.
474 size_t abnormal_node_num = 0;
475 std::vector<std::string> time_out_node_ids = {};
476 for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
477 auto node_id = iter->first;
478 auto node_info = iter->second;
479 MS_EXCEPTION_IF_NULL(node_info);
480 time_t now = time(&now);
481 auto elapsed = difftime(now, node_info->last_update);
482 if (elapsed > node_timeout_) {
483 node_info->state = NodeState::kTimeout;
484 ++abnormal_node_num;
485 time_out_node_ids.push_back(node_id);
486 MS_LOG(ERROR) << "The node: " << node_id
487 << " is timed out. It may exit with exception, please check this node's log.";
488 }
489 }
490 abnormal_node_num_ = abnormal_node_num;
491 if (abnormal_node_num_ > 0 && !recovery::IsEnableRecovery()) {
492 MS_LOG(EXCEPTION) << "The total number of timed out node is " << abnormal_node_num_
493 << ". Timed out node list is: " << time_out_node_ids << ", worker " << time_out_node_ids[0]
494 << " is the first one timed out, please check its log.";
495 }
496 }
497
498 nodes_mutex_.unlock();
499 static const size_t interval = 3;
500 (void)sleep(interval);
501 }
502 } catch (const std::exception &e) {
503 nodes_mutex_.unlock();
504 MsException::Instance().SetException();
505 }
506 }
507
TransitionToInitialized()508 bool MetaServerNode::TransitionToInitialized() {
509 if (nodes_.size() == total_node_num_) {
510 // After all nodes are successfully registered, reassign rank ids so they could be continuous.
511 ReassignNodeRank();
512
513 // Assign port range for each node after cluster is initialized.
514 AssignPortRange();
515
516 // Persist the cluster metadata into storage through configuration.
517 if (recovery::IsEnableRecovery() && configuration_ != nullptr && configuration_->Empty()) {
518 if (!Persist()) {
519 MS_LOG(EXCEPTION) << "Failed to persist the metadata of the cluster.";
520 }
521 }
522 topo_state_ = TopoState::kInitialized;
523 MS_LOG(INFO) << "The cluster topology has been constructed successfully.";
524 return true;
525 }
526 return false;
527 }
528
AssignPortRange()529 void MetaServerNode::AssignPortRange() {
530 MS_LOG(DEBUG) << "Start assigning port range for nodes...";
531 std::unordered_map<std::string, uint32_t> each_host_node_num;
532 std::unordered_map<std::string, uint32_t> node_index_map;
533 // Assign computing graph nodes' port range according to their hosts.
534 for (const auto &n : nodes_) {
535 std::string node_id = n.first;
536 const auto &node_info = n.second;
537
538 uint32_t &host_node_num = each_host_node_num[node_info->host_name];
539 node_index_map[node_info->node_id] = host_node_num;
540 host_node_num++;
541 }
542
543 NodePortRanges node_ranges;
544 for (const auto &n : nodes_) {
545 std::string node_id = n.first;
546 const auto &node_info = n.second;
547 uint32_t node_index = node_index_map[node_id];
548 uint32_t each_node_range = kNodePortRangeNum / each_host_node_num[node_info->host_name];
549 uint32_t min_port = kStartPort + each_node_range * node_index;
550 uint32_t max_port = min_port + each_node_range - 1;
551 PortRange range;
552 range.set_min_port(min_port);
553 range.set_max_port(max_port);
554 (void)node_ranges.mutable_data()->insert({node_id, range});
555 MS_LOG(INFO) << "The port range for node " << node_id << ", rank id: " << node_info->rank_id
556 << ", min port: " << min_port << ", max port: " << max_port;
557 }
558 (void)metadata_.insert({kNodePortRange, node_ranges.SerializeAsString()});
559 }
560
Recovery()561 bool MetaServerNode::Recovery() {
562 std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
563 std::string recovery_path = recovery::RecoveryPath();
564 RETURN_IF_FALSE_WITH_LOG(CheckFilePath(recovery_path), "Invalid recovery path: " << recovery_path);
565 configuration_ = std::make_unique<recovery::FileConfiguration>(recovery_path + "/" + kRecoveryFileName);
566 MS_EXCEPTION_IF_NULL(configuration_);
567
568 RETURN_IF_FALSE_WITH_LOG(configuration_->Initialize(),
569 "Failed to initialize the recovery file configuration from file path: " << recovery_path);
570
571 if (configuration_->Empty()) {
572 MS_LOG(INFO) << "The meta server node is started for the first time.";
573 return true;
574
575 // The meta server node is restarted and the metadata of cluster needs to be recovered.
576 } else {
577 MS_LOG(INFO) << "Begin to recover the meta server node.";
578 std::string states_key = kComputeNodeStates;
579 RETURN_IF_FALSE_WITH_LOG(configuration_->Exists(states_key),
580 "Can not find the key " + states_key + " in configuration.");
581
582 // Check the validation of the previous metadata.
583 const auto &states = configuration_->Get(states_key, "");
584 nlohmann::json node_states = nlohmann::json::parse(states);
585 RETURN_IF_FALSE_WITH_LOG(node_states.size() == total_node_num_,
586 "Invalid number of node in configuration: " + std::to_string(node_states.size()) +
587 ", expected total number of node: " + std::to_string(total_node_num_));
588
589 // Restore the nodes state.
590 for (auto iter = node_states.begin(); iter != node_states.end(); ++iter) {
591 const auto &node_id = iter.key();
592 std::shared_ptr<NodeInfo> node_info = std::make_shared<NodeInfo>(node_id);
593 MS_EXCEPTION_IF_NULL(node_info);
594 (void)time(&(node_info->last_update));
595 node_info->host_name = iter.value().at(kHostName);
596 node_info->role = iter.value().at(kRole);
597 node_info->rank_id = iter.value().at(kRankId);
598 node_info->state = NodeState::kRegistered;
599 nodes_[node_id] = node_info;
600 }
601
602 if (nodes_.size() == total_node_num_) {
603 topo_state_ = TopoState::kInitialized;
604 }
605 MS_LOG(INFO) << "The meta server node has been recovered successfully.";
606 }
607 return true;
608 }
609
Persist()610 bool MetaServerNode::Persist() {
611 if (total_node_num_ != nodes_.size()) {
612 MS_LOG(ERROR) << "Invalid number of alive node: " << nodes_.size()
613 << ", the expected total number of node is: " << total_node_num_;
614 return false;
615 }
616
617 // The thread safety of nodes_ visiting has been guarded by the caller.
618 nlohmann::json node_states;
619 for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
620 const auto &node_id = iter->first;
621 nlohmann::json node_state;
622 node_state[kNodeId] = node_id;
623
624 MS_EXCEPTION_IF_NULL(iter->second);
625 node_state[kHostName] = iter->second->host_name;
626 node_state[kRole] = iter->second->role;
627 node_state[kRankId] = iter->second->rank_id;
628 node_states[node_id] = node_state;
629 }
630
631 MS_EXCEPTION_IF_NULL(configuration_);
632 configuration_->Put(kComputeNodeStates, node_states.dump());
633 RETURN_IF_FALSE_WITH_LOG(configuration_->Flush(), "Failed to flush configuration.");
634 return true;
635 }
636
AllocateRankId(const std::string & role)637 uint32_t MetaServerNode::AllocateRankId(const std::string &role) {
638 std::shared_lock<std::shared_mutex> lock(rank_mutex_);
639 if (role_expect_num_.find(role) == role_expect_num_.end()) {
640 MS_LOG(WARNING) << "Role: " << role << " is invalid.";
641 return UINT32_MAX;
642 }
643 if (next_rank_ids_.count(role) == 0) {
644 next_rank_ids_[role] = 0;
645 } else {
646 // If this role's rank id has exceeded, do not increase next_rank_ids_ and return an exceeded rank id. The caller
647 // will check rank id's validation and reject this request.
648 if (next_rank_ids_[role] == role_expect_num_[role] - 1) {
649 return next_rank_ids_[role] + 1;
650 }
651 next_rank_ids_[role] += 1;
652 }
653 return next_rank_ids_[role];
654 }
655
CheckRankIdValidation(const std::string & node_id,const std::string & role,uint32_t rank_id,const std::string & host_ip,std::string * reject_reason)656 bool MetaServerNode::CheckRankIdValidation(const std::string &node_id, const std::string &role, uint32_t rank_id,
657 const std::string &host_ip, std::string *reject_reason) {
658 if (role_expect_num_.find(role) == role_expect_num_.end()) {
659 MS_LOG(WARNING) << "Registered node role: " << role << " is invalid.";
660 return false;
661 }
662 // Whether rank id has already exists.
663 bool rank_id_exist = std::any_of(nodes_.begin(), nodes_.end(), [&role, &rank_id](const auto &n) {
664 return n.second->role == role && n.second->rank_id == rank_id;
665 });
666 // Whether rank id exceeds upper bound.
667 bool is_extra_node = (rank_id >= role_expect_num_[role]);
668 if (rank_id_exist) {
669 *reject_reason = "Rank id:" + std::to_string(rank_id) + " for role:" + role + " exists.";
670 }
671 if (is_extra_node) {
672 *reject_reason = "This node is extra or rank id exceeds. Total node number for role " + role + " is " +
673 std::to_string(role_expect_num_[role]) + " but got rank id " + std::to_string(rank_id);
674 }
675 if (rank_id_exist || is_extra_node) {
676 MS_LOG(WARNING) << "Rejecting registration request for node " << node_id << " from host " << host_ip
677 << ". Rejection reason: " << *reject_reason;
678 return false;
679 }
680 return true;
681 }
682
ReassignNodeRank()683 void MetaServerNode::ReassignNodeRank() {
684 if (std::all_of(nodes_.begin(), nodes_.end(), [](const auto &node) { return common::IsStrNumeric(node.first); })) {
685 MS_LOG(WARNING) << "Rank ids are already set by numeric node ids. No need to reassign them.";
686 for (const auto &n : nodes_) {
687 const std::shared_ptr<NodeInfo> &node_info = n.second;
688 const std::string &role = node_info->role;
689 (void)metadata_.insert(std::make_pair(role + node_info->node_id, std::to_string(node_info->rank_id)));
690 }
691 return;
692 }
693
694 MS_LOG(INFO) << "Start sorting and reassiging rank ids for nodes according to node ips and node ids.";
695 std::map<std::string, std::map<NodeKey, uint32_t>> node_ranks;
696 for (auto &n : nodes_) {
697 std::shared_ptr<NodeInfo> &node_info = n.second;
698 NodeKey node_key = {node_info->host_ip, node_info->node_id};
699 (void)node_ranks[node_info->role].insert(std::make_pair(node_key, 0));
700 }
701
702 for (auto &n : node_ranks) {
703 std::map<NodeKey, uint32_t> &node_key_ranks = n.second;
704 uint32_t accum_rank_id = 0;
705 for (auto &node_rank : node_key_ranks) {
706 node_rank.second = accum_rank_id++;
707 }
708 }
709
710 for (auto &n : nodes_) {
711 std::shared_ptr<NodeInfo> &node_info = n.second;
712 const std::string &role = node_info->role;
713 NodeKey node_key = {node_info->host_ip, node_info->node_id};
714 uint32_t new_rank = node_ranks[role][node_key];
715
716 MS_LOG(WARNING) << "Assign rank id of node id: " << node_info->node_id << ", role: " << role
717 << ", with host ip: " << node_info->host_ip << ", old rank id: " << node_info->rank_id
718 << ", new rank id: " << new_rank;
719
720 node_info->rank_id = new_rank;
721 (void)metadata_.insert(std::make_pair(role + node_info->node_id, std::to_string(node_info->rank_id)));
722 }
723 }
724
TopologyState() const725 TopoState MetaServerNode::TopologyState() const { return topo_state_; }
726
GetAliveNodeNum()727 size_t MetaServerNode::GetAliveNodeNum() {
728 std::shared_lock<std::shared_mutex> lock(nodes_mutex_);
729 size_t count = 0;
730 for (auto iter = nodes_.begin(); iter != nodes_.end(); ++iter) {
731 auto node_info = iter->second;
732 MS_EXCEPTION_IF_NULL(node_info);
733
734 // Only the node which has been authenticated is alive.
735 if (node_info->state == NodeState::kRegistered) {
736 ++count;
737 }
738 }
739 return count;
740 }
741
RegisterMessageHandler(const std::string & name,const std::shared_ptr<std::function<std::string (const std::string &)>> & handler)742 bool MetaServerNode::RegisterMessageHandler(
743 const std::string &name, const std::shared_ptr<std::function<std::string(const std::string &)>> &handler) {
744 if (message_handlers_.find(name) != message_handlers_.end()) {
745 MS_LOG(ERROR) << "The message name: " << name << " have already been registered";
746 return false;
747 }
748 message_handlers_[name] = handler;
749 return true;
750 }
751 } // namespace topology
752 } // namespace cluster
753 } // namespace distributed
754 } // namespace mindspore
755