1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/node_manager.h"
18
19 namespace mindspore {
20 namespace ps {
21 namespace core {
InitNode()22 void NodeManager::InitNode() {
23 initial_total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
24 PSContext::instance()->cluster_config().initial_worker_num;
25 meta_data_ = std::make_unique<ClusterMetadata>(PSContext::instance()->cluster_config().initial_worker_num,
26 PSContext::instance()->cluster_config().initial_server_num);
27 MS_EXCEPTION_IF_NULL(meta_data_);
28 total_node_num_ = UintToInt(initial_total_node_num_);
29 }
30
NextRankId(const RegisterMessage & register_message,const std::shared_ptr<MessageMeta> & meta)31 uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr<MessageMeta> &meta) {
32 MS_EXCEPTION_IF_NULL(meta);
33 MS_EXCEPTION_IF_NULL(meta_data_);
34 std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
35 uint32_t rank_id = UINT_MAX;
36
37 const std::string &node_id = register_message.node_id();
38 if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
39 const std::string &new_ip = register_message.ip();
40 uint32_t new_port = register_message.port();
41 rank_id = registered_nodes_info_[node_id].rank_id_;
42 registered_nodes_info_[node_id].is_alive = true;
43 registered_nodes_info_[node_id].ip_ = new_ip;
44 registered_nodes_info_[node_id].port_ = static_cast<uint16_t>(new_port);
45 MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!";
46 return rank_id;
47 }
48
49 if (register_message.role() == NodeRole::SERVER) {
50 const std::string &ip = register_message.ip();
51 uint32_t port = register_message.port();
52
53 auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
54 bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::SERVER;
55 if (res) {
56 MS_LOG(INFO) << "The server node id:" << item.first << " rank id:" << item.second.rank_id_ << " is not alive.";
57 rank_id = item.second.rank_id_;
58 }
59 return res;
60 });
61 if (rank_it == registered_nodes_info_.end()) {
62 if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_server_rank_id_) {
63 rank_id = meta->rank_id();
64 MS_LOG(INFO) << "Use the old rank id:" << rank_id;
65 } else {
66 rank_id = IntToUint(++next_server_rank_id_);
67 }
68 } else {
69 registered_nodes_info_.erase((*rank_it).first);
70 }
71
72 if (rank_id >= meta_data_->server_num) {
73 MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
74 rank_id = UINT_MAX;
75 --next_server_rank_id_;
76 }
77 NodeInfo node_info;
78 node_info.node_role_ = NodeRole::SERVER;
79 node_info.node_id_ = node_id;
80 node_info.rank_id_ = rank_id;
81 node_info.ip_ = ip;
82 node_info.port_ = static_cast<uint16_t>(port);
83 node_info.is_alive = true;
84 registered_nodes_info_[node_id] = node_info;
85 MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
86 << " assign rank id:" << rank_id;
87 } else if (register_message.role() == NodeRole::WORKER) {
88 const std::string &ip = register_message.ip();
89 uint32_t port = register_message.port();
90
91 auto worker_rank_it =
92 std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
93 bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
94 if (res) {
95 MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
96 rank_id = item.second.rank_id_;
97 }
98 return res;
99 });
100 if (worker_rank_it == registered_nodes_info_.end()) {
101 if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_worker_rank_id_) {
102 rank_id = meta->rank_id();
103 MS_LOG(INFO) << "Use the old rank id:" << rank_id;
104 } else {
105 rank_id = IntToUint(++next_worker_rank_id_);
106 }
107 } else {
108 registered_nodes_info_.erase((*worker_rank_it).first);
109 }
110
111 if (rank_id >= meta_data_->worker_num) {
112 MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
113 rank_id = UINT_MAX;
114 --next_worker_rank_id_;
115 }
116 NodeInfo node_info;
117 node_info.node_role_ = NodeRole::WORKER;
118 node_info.node_id_ = node_id;
119 node_info.rank_id_ = rank_id;
120 node_info.ip_ = ip;
121 node_info.port_ = static_cast<uint16_t>(port);
122 node_info.is_alive = true;
123 registered_nodes_info_[node_id] = node_info;
124 MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
125 }
126 return rank_id;
127 }
128
UpdateHeartbeat(const std::string & node_id)129 void NodeManager::UpdateHeartbeat(const std::string &node_id) {
130 std::lock_guard<std::mutex> lock(heartbeat_mutex_);
131 struct timeval current_time {};
132 (void)gettimeofday(¤t_time, nullptr);
133 heartbeats_[node_id] = current_time;
134 }
135
FetchServersMeta()136 std::vector<ServersMeta> NodeManager::FetchServersMeta() {
137 std::vector<ServersMeta> servers_meta_list;
138 for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
139 if (it->second.node_role_ == NodeRole::SERVER) {
140 ServersMeta servers_meta;
141 servers_meta.set_rank_id(it->second.rank_id_);
142 servers_meta.set_ip(it->second.ip_);
143 servers_meta.set_port(it->second.port_);
144 servers_meta_list.push_back(servers_meta);
145 }
146 }
147 return servers_meta_list;
148 }
149
FetchAllNodesMeta()150 std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
151 std::vector<ServersMeta> servers_meta_list;
152 for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
153 ServersMeta servers_meta;
154 servers_meta.set_rank_id(it->second.rank_id_);
155 servers_meta.set_ip(it->second.ip_);
156 servers_meta.set_port(it->second.port_);
157 servers_meta.set_is_alive(it->second.is_alive);
158 servers_meta.set_role(it->second.node_role_);
159 servers_meta.set_node_id(it->second.node_id_);
160 servers_meta_list.push_back(servers_meta);
161 }
162 return servers_meta_list;
163 }
164
UpdateCluster()165 void NodeManager::UpdateCluster() {
166 // 1. update cluster timeout state
167 struct timeval current_time {};
168 (void)gettimeofday(¤t_time, nullptr);
169 timeout_nodes_info_.clear();
170 for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
171 if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
172 if (registered_nodes_info_.count(it->first)) {
173 MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
174 timeout_nodes_info_[it->first] = registered_nodes_info_[it->first];
175 registered_nodes_info_[it->first].is_alive = false;
176 }
177 }
178 }
179
180 if (!timeout_nodes_info_.empty()) {
181 UpdateClusterState(ClusterState::NODE_TIMEOUT);
182 for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
183 (void)heartbeats_.erase(iter->first);
184 finish_nodes_id_.insert(iter->first);
185 }
186 }
187
188 // 2. update cluster finish state
189 if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
190 SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
191 UpdateClusterState(ClusterState::CLUSTER_EXIT);
192 }
193 }
194
CheckClusterTimeout()195 void NodeManager::CheckClusterTimeout() {
196 if (total_node_num_ != SizeToInt(registered_nodes_info_.size())) {
197 MS_LOG(WARNING) << "The cluster is not ready after "
198 << PSContext::instance()->cluster_config().cluster_available_timeout
199 << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
200 << registered_nodes_info_.size();
201 current_node_num_ = SizeToInt(registered_nodes_info_.size());
202 UpdateClusterState(ClusterState::NODE_TIMEOUT);
203 }
204 }
205
AddFinishNode(const std::string & finish_message)206 void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
207
AddScaleOutDoneNode(const std::string & node_id)208 void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_done_nodes_id_.insert(node_id); }
209
AddScaleInDoneNode(const std::string & node_id)210 void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
211
IsAllNodesRegistered() const212 bool NodeManager::IsAllNodesRegistered() const {
213 int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
214 [](auto item) { return item.second.is_alive == true; });
215 return num == total_node_num_;
216 }
217
IsAllNodesFinished() const218 bool NodeManager::IsAllNodesFinished() const { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
219
IsAllNodesScaleOutDone() const220 bool NodeManager::IsAllNodesScaleOutDone() const {
221 return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_;
222 }
223
IsAllNodesScaleInDone() const224 bool NodeManager::IsAllNodesScaleInDone() const { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
225
nodes_info() const226 const std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() const { return nodes_info_; }
227
registered_nodes_info() const228 const std::unordered_map<std::string, NodeInfo> &NodeManager::registered_nodes_info() const {
229 return registered_nodes_info_;
230 }
231
UpdateNodesInfo()232 void NodeManager::UpdateNodesInfo() {
233 MS_LOG(INFO) << "Update nodes info.";
234 nodes_info_.clear();
235 nodes_info_ = registered_nodes_info_;
236 }
237
UpdateNodeState(const NodeState & state)238 void NodeManager::UpdateNodeState(const NodeState &state) {
239 std::lock_guard<std::mutex> lk(node_mutex_);
240 node_state_ = state;
241 }
242
UpdateClusterState(const ClusterState & state)243 void NodeManager::UpdateClusterState(const ClusterState &state) {
244 std::lock_guard<std::mutex> lk(cluster_mutex_);
245 MS_LOG(INFO) << "[state]: Scheduler change state from:" << CommUtil::ClusterStateToString(cluster_state_) << " to "
246 << CommUtil::ClusterStateToString(state);
247 cluster_state_ = state;
248 }
249
GetNodeState()250 NodeState NodeManager::GetNodeState() {
251 std::lock_guard<std::mutex> lk(node_mutex_);
252 return node_state_;
253 }
254
GetClusterState()255 ClusterState NodeManager::GetClusterState() {
256 std::lock_guard<std::mutex> lk(cluster_mutex_);
257 return cluster_state_;
258 }
259
ResetMetadata(const std::vector<std::string> & scale_in_nodes)260 void NodeManager::ResetMetadata(const std::vector<std::string> &scale_in_nodes) {
261 MS_LOG(WARNING) << "Reset metadata.";
262 std::vector<uint32_t> server_rank_ids;
263 if (GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
264 for (const auto &item : scale_in_nodes) {
265 if (registered_nodes_info_.count(item)) {
266 server_rank_ids.push_back(registered_nodes_info_[item].rank_id_);
267 }
268 }
269 auto min_rank_id = std::min_element(server_rank_ids.begin(), server_rank_ids.end());
270 next_server_rank_id_ = UintToInt(*min_rank_id - 1);
271 MS_LOG(INFO) << "The next server rank id:" << next_server_rank_id_;
272 }
273 registered_nodes_info_.clear();
274 heartbeats_.clear();
275 }
276
IsWorkerOrServer0()277 bool NodeManager::IsWorkerOrServer0() {
278 bool res = std::any_of(registered_nodes_info_.begin(), registered_nodes_info_.end(), [](auto item) {
279 if (item.second.node_role_ == NodeRole::WORKER && item.second.is_alive == false) {
280 return true;
281 }
282
283 if (item.second.node_role_ == NodeRole::SERVER && item.second.is_alive == false && item.second.rank_id_ == 0) {
284 return true;
285 }
286
287 return false;
288 });
289
290 return res;
291 }
292
IsNodeRegistered(const std::string & node_id)293 bool NodeManager::IsNodeRegistered(const std::string &node_id) {
294 if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
295 return true;
296 }
297 return false;
298 }
299
set_total_node_num(const int32_t & node_num)300 void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }
301
total_node_num() const302 const int32_t &NodeManager::total_node_num() const { return total_node_num_; }
303
set_worker_num(const int32_t & worker_num)304 void NodeManager::set_worker_num(const int32_t &worker_num) { meta_data_->worker_num = IntToUint(worker_num); }
305
set_server_num(const int32_t & server_num)306 void NodeManager::set_server_num(const int32_t &server_num) { meta_data_->server_num = IntToUint(server_num); }
307
worker_num() const308 int32_t NodeManager::worker_num() const { return UintToInt(meta_data_->worker_num); }
309
server_num() const310 int32_t NodeManager::server_num() const { return UintToInt(meta_data_->server_num); }
311 } // namespace core
312 } // namespace ps
313 } // namespace mindspore
314