• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/node_manager.h"
18 
19 namespace mindspore {
20 namespace ps {
21 namespace core {
InitNode()22 void NodeManager::InitNode() {
23   initial_total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
24                             PSContext::instance()->cluster_config().initial_worker_num;
25   meta_data_ = std::make_unique<ClusterMetadata>(PSContext::instance()->cluster_config().initial_worker_num,
26                                                  PSContext::instance()->cluster_config().initial_server_num);
27   MS_EXCEPTION_IF_NULL(meta_data_);
28   total_node_num_ = UintToInt(initial_total_node_num_);
29 }
30 
NextRankId(const RegisterMessage & register_message,const std::shared_ptr<MessageMeta> & meta)31 uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const std::shared_ptr<MessageMeta> &meta) {
32   MS_EXCEPTION_IF_NULL(meta);
33   MS_EXCEPTION_IF_NULL(meta_data_);
34   std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
35   uint32_t rank_id = UINT_MAX;
36 
37   const std::string &node_id = register_message.node_id();
38   if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
39     const std::string &new_ip = register_message.ip();
40     uint32_t new_port = register_message.port();
41     rank_id = registered_nodes_info_[node_id].rank_id_;
42     registered_nodes_info_[node_id].is_alive = true;
43     registered_nodes_info_[node_id].ip_ = new_ip;
44     registered_nodes_info_[node_id].port_ = static_cast<uint16_t>(new_port);
45     MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!";
46     return rank_id;
47   }
48 
49   if (register_message.role() == NodeRole::SERVER) {
50     const std::string &ip = register_message.ip();
51     uint32_t port = register_message.port();
52 
53     auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
54       bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::SERVER;
55       if (res) {
56         MS_LOG(INFO) << "The server node id:" << item.first << " rank id:" << item.second.rank_id_ << " is not alive.";
57         rank_id = item.second.rank_id_;
58       }
59       return res;
60     });
61     if (rank_it == registered_nodes_info_.end()) {
62       if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_server_rank_id_) {
63         rank_id = meta->rank_id();
64         MS_LOG(INFO) << "Use the old rank id:" << rank_id;
65       } else {
66         rank_id = IntToUint(++next_server_rank_id_);
67       }
68     } else {
69       registered_nodes_info_.erase((*rank_it).first);
70     }
71 
72     if (rank_id >= meta_data_->server_num) {
73       MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
74       rank_id = UINT_MAX;
75       --next_server_rank_id_;
76     }
77     NodeInfo node_info;
78     node_info.node_role_ = NodeRole::SERVER;
79     node_info.node_id_ = node_id;
80     node_info.rank_id_ = rank_id;
81     node_info.ip_ = ip;
82     node_info.port_ = static_cast<uint16_t>(port);
83     node_info.is_alive = true;
84     registered_nodes_info_[node_id] = node_info;
85     MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
86                  << " assign rank id:" << rank_id;
87   } else if (register_message.role() == NodeRole::WORKER) {
88     const std::string &ip = register_message.ip();
89     uint32_t port = register_message.port();
90 
91     auto worker_rank_it =
92       std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
93         bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
94         if (res) {
95           MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
96           rank_id = item.second.rank_id_;
97         }
98         return res;
99       });
100     if (worker_rank_it == registered_nodes_info_.end()) {
101       if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_worker_rank_id_) {
102         rank_id = meta->rank_id();
103         MS_LOG(INFO) << "Use the old rank id:" << rank_id;
104       } else {
105         rank_id = IntToUint(++next_worker_rank_id_);
106       }
107     } else {
108       registered_nodes_info_.erase((*worker_rank_it).first);
109     }
110 
111     if (rank_id >= meta_data_->worker_num) {
112       MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
113       rank_id = UINT_MAX;
114       --next_worker_rank_id_;
115     }
116     NodeInfo node_info;
117     node_info.node_role_ = NodeRole::WORKER;
118     node_info.node_id_ = node_id;
119     node_info.rank_id_ = rank_id;
120     node_info.ip_ = ip;
121     node_info.port_ = static_cast<uint16_t>(port);
122     node_info.is_alive = true;
123     registered_nodes_info_[node_id] = node_info;
124     MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
125   }
126   return rank_id;
127 }
128 
UpdateHeartbeat(const std::string & node_id)129 void NodeManager::UpdateHeartbeat(const std::string &node_id) {
130   std::lock_guard<std::mutex> lock(heartbeat_mutex_);
131   struct timeval current_time {};
132   (void)gettimeofday(&current_time, nullptr);
133   heartbeats_[node_id] = current_time;
134 }
135 
FetchServersMeta()136 std::vector<ServersMeta> NodeManager::FetchServersMeta() {
137   std::vector<ServersMeta> servers_meta_list;
138   for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
139     if (it->second.node_role_ == NodeRole::SERVER) {
140       ServersMeta servers_meta;
141       servers_meta.set_rank_id(it->second.rank_id_);
142       servers_meta.set_ip(it->second.ip_);
143       servers_meta.set_port(it->second.port_);
144       servers_meta_list.push_back(servers_meta);
145     }
146   }
147   return servers_meta_list;
148 }
149 
FetchAllNodesMeta()150 std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
151   std::vector<ServersMeta> servers_meta_list;
152   for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
153     ServersMeta servers_meta;
154     servers_meta.set_rank_id(it->second.rank_id_);
155     servers_meta.set_ip(it->second.ip_);
156     servers_meta.set_port(it->second.port_);
157     servers_meta.set_is_alive(it->second.is_alive);
158     servers_meta.set_role(it->second.node_role_);
159     servers_meta.set_node_id(it->second.node_id_);
160     servers_meta_list.push_back(servers_meta);
161   }
162   return servers_meta_list;
163 }
164 
UpdateCluster()165 void NodeManager::UpdateCluster() {
166   // 1. update cluster timeout state
167   struct timeval current_time {};
168   (void)gettimeofday(&current_time, nullptr);
169   timeout_nodes_info_.clear();
170   for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
171     if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
172       if (registered_nodes_info_.count(it->first)) {
173         MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
174         timeout_nodes_info_[it->first] = registered_nodes_info_[it->first];
175         registered_nodes_info_[it->first].is_alive = false;
176       }
177     }
178   }
179 
180   if (!timeout_nodes_info_.empty()) {
181     UpdateClusterState(ClusterState::NODE_TIMEOUT);
182     for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
183       (void)heartbeats_.erase(iter->first);
184       finish_nodes_id_.insert(iter->first);
185     }
186   }
187 
188   // 2. update cluster finish state
189   if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
190       SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
191     UpdateClusterState(ClusterState::CLUSTER_EXIT);
192   }
193 }
194 
CheckClusterTimeout()195 void NodeManager::CheckClusterTimeout() {
196   if (total_node_num_ != SizeToInt(registered_nodes_info_.size())) {
197     MS_LOG(WARNING) << "The cluster is not ready after "
198                     << PSContext::instance()->cluster_config().cluster_available_timeout
199                     << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
200                     << registered_nodes_info_.size();
201     current_node_num_ = SizeToInt(registered_nodes_info_.size());
202     UpdateClusterState(ClusterState::NODE_TIMEOUT);
203   }
204 }
205 
AddFinishNode(const std::string & finish_message)206 void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
207 
AddScaleOutDoneNode(const std::string & node_id)208 void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_done_nodes_id_.insert(node_id); }
209 
AddScaleInDoneNode(const std::string & node_id)210 void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
211 
IsAllNodesRegistered() const212 bool NodeManager::IsAllNodesRegistered() const {
213   int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
214                               [](auto item) { return item.second.is_alive == true; });
215   return num == total_node_num_;
216 }
217 
IsAllNodesFinished() const218 bool NodeManager::IsAllNodesFinished() const { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
219 
IsAllNodesScaleOutDone() const220 bool NodeManager::IsAllNodesScaleOutDone() const {
221   return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_;
222 }
223 
IsAllNodesScaleInDone() const224 bool NodeManager::IsAllNodesScaleInDone() const { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
225 
nodes_info() const226 const std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() const { return nodes_info_; }
227 
registered_nodes_info() const228 const std::unordered_map<std::string, NodeInfo> &NodeManager::registered_nodes_info() const {
229   return registered_nodes_info_;
230 }
231 
UpdateNodesInfo()232 void NodeManager::UpdateNodesInfo() {
233   MS_LOG(INFO) << "Update nodes info.";
234   nodes_info_.clear();
235   nodes_info_ = registered_nodes_info_;
236 }
237 
UpdateNodeState(const NodeState & state)238 void NodeManager::UpdateNodeState(const NodeState &state) {
239   std::lock_guard<std::mutex> lk(node_mutex_);
240   node_state_ = state;
241 }
242 
UpdateClusterState(const ClusterState & state)243 void NodeManager::UpdateClusterState(const ClusterState &state) {
244   std::lock_guard<std::mutex> lk(cluster_mutex_);
245   MS_LOG(INFO) << "[state]: Scheduler change state from:" << CommUtil::ClusterStateToString(cluster_state_) << " to "
246                << CommUtil::ClusterStateToString(state);
247   cluster_state_ = state;
248 }
249 
GetNodeState()250 NodeState NodeManager::GetNodeState() {
251   std::lock_guard<std::mutex> lk(node_mutex_);
252   return node_state_;
253 }
254 
GetClusterState()255 ClusterState NodeManager::GetClusterState() {
256   std::lock_guard<std::mutex> lk(cluster_mutex_);
257   return cluster_state_;
258 }
259 
ResetMetadata(const std::vector<std::string> & scale_in_nodes)260 void NodeManager::ResetMetadata(const std::vector<std::string> &scale_in_nodes) {
261   MS_LOG(WARNING) << "Reset metadata.";
262   std::vector<uint32_t> server_rank_ids;
263   if (GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
264     for (const auto &item : scale_in_nodes) {
265       if (registered_nodes_info_.count(item)) {
266         server_rank_ids.push_back(registered_nodes_info_[item].rank_id_);
267       }
268     }
269     auto min_rank_id = std::min_element(server_rank_ids.begin(), server_rank_ids.end());
270     next_server_rank_id_ = UintToInt(*min_rank_id - 1);
271     MS_LOG(INFO) << "The next server rank id:" << next_server_rank_id_;
272   }
273   registered_nodes_info_.clear();
274   heartbeats_.clear();
275 }
276 
IsWorkerOrServer0()277 bool NodeManager::IsWorkerOrServer0() {
278   bool res = std::any_of(registered_nodes_info_.begin(), registered_nodes_info_.end(), [](auto item) {
279     if (item.second.node_role_ == NodeRole::WORKER && item.second.is_alive == false) {
280       return true;
281     }
282 
283     if (item.second.node_role_ == NodeRole::SERVER && item.second.is_alive == false && item.second.rank_id_ == 0) {
284       return true;
285     }
286 
287     return false;
288   });
289 
290   return res;
291 }
292 
IsNodeRegistered(const std::string & node_id)293 bool NodeManager::IsNodeRegistered(const std::string &node_id) {
294   if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
295     return true;
296   }
297   return false;
298 }
299 
set_total_node_num(const int32_t & node_num)300 void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }
301 
total_node_num() const302 const int32_t &NodeManager::total_node_num() const { return total_node_num_; }
303 
set_worker_num(const int32_t & worker_num)304 void NodeManager::set_worker_num(const int32_t &worker_num) { meta_data_->worker_num = IntToUint(worker_num); }
305 
set_server_num(const int32_t & server_num)306 void NodeManager::set_server_num(const int32_t &server_num) { meta_data_->server_num = IntToUint(server_num); }
307 
worker_num() const308 int32_t NodeManager::worker_num() const { return UintToInt(meta_data_->worker_num); }
309 
server_num() const310 int32_t NodeManager::server_num() const { return UintToInt(meta_data_->server_num); }
311 }  // namespace core
312 }  // namespace ps
313 }  // namespace mindspore
314