/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/core/abstract_node.h"

#include "include/common/debug/common.h"
#include "ps/core/communicator/http_communicator.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/node_recovery.h"

namespace mindspore {
namespace ps {
namespace core {
AbstractNode::~AbstractNode() {
  try {
    if (client_to_scheduler_ != nullptr) {
      client_to_scheduler_->Stop();
    }
    if (client_to_scheduler_thread_ != nullptr && client_to_scheduler_thread_->joinable()) {
      client_to_scheduler_thread_->join();
    }
    if (heart_beat_thread_ != nullptr && heart_beat_thread_->joinable()) {
      heart_beat_thread_->join();
    }
    if (server_ != nullptr) {
      server_->Stop();
    }
    if (server_thread_ != nullptr && server_thread_->joinable()) {
      server_thread_->join();
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "AbstractNode destructor run failed, error message: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "AbstractNode destructor run failed, unknown error occurred.";
  }
}

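// Registration flow: Register() asynchronously sends a REGISTER command carrying this node's
// id, role, ip, port and recovery flag to the scheduler; the scheduler's reply is handled by
// ProcessRegisterResp(), which assigns the rank id of this node.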
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::REGISTER);
  message_meta->set_rank_id(node_info_.rank_id_);

  RegisterMessage register_message;
  register_message.set_node_id(node_info_.node_id_);
  register_message.set_role(node_info_.node_role_);
  register_message.set_ip(node_info_.ip_);
  register_message.set_port(node_info_.port_);
  register_message.set_is_recover(is_recover.load());

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";

  if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
                        register_message.ByteSizeLong())) {
    MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                  << " the node id:" << node_info_.node_id_ << " register timeout!";
  } else {
    MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                 << " the node id:" << node_info_.node_id_ << " sent the register message successfully!";
  }
}

void AbstractNode::SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info) {
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::FAILURE_EVENT_INFO);

  std::string now_time = ps::core::CommUtil::GetNowTime().time_str_mill;
  FailureEventMessage failure_event_message;
  failure_event_message.set_node_role(node_role);
  failure_event_message.set_ip(node_info_.ip_);
  failure_event_message.set_port(node_info_.port_);
  failure_event_message.set_time(now_time);
  failure_event_message.set_event(event_info);

  MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
               << " begin to send failure message to the scheduler!";

  if (!SendMessageAsync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                        failure_event_message.SerializeAsString().data(), failure_event_message.ByteSizeLong())) {
    MS_LOG(ERROR) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
                  << " send failure message timeout!";
  } else {
    MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
                 << " sent failure message " << event_info << " successfully!";
  }
}

void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  RegisterRespMessage register_resp_message;
  CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
  MS_LOG(INFO) << "The node id got from the scheduler is:" << register_resp_message.node_id()
               << ", rank_id is:" << register_resp_message.rank_id();

  if (register_resp_message.node_id() != node_info_.node_id_) {
    MS_LOG(ERROR) << "The node id received:" << register_resp_message.node_id()
                  << " does not match the current node id:" << node_info_.node_id_;
    return;
  }
  node_info_.rank_id_ = register_resp_message.rank_id();
  if (node_info_.rank_id_ == UINT32_MAX) {
    MS_LOG(ERROR) << "The rank id received is invalid:" << register_resp_message.rank_id();
    return;
  }

  // Receiving the register response indicates that the scheduler is alive, so update the time point at which the
  // scheduler was last seen alive.
  UpdateSchedulerTime();

  MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " registered to the scheduler successfully!";
}

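// Request tracking: Broadcast() first counts the matching server nodes and registers that count
// with AddMessageTrack(), then fans the payload out under a single request id and blocks in
// Wait(request_id, timeout) until every tracked response has arrived or the timeout expires.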
bool AbstractNode::Broadcast(const NodeRole &node_role, const std::string &message, int command,
                             const uint32_t &timeout) {
  if (node_role != NodeRole::SERVER) {
    MS_LOG(EXCEPTION) << "Currently only broadcast to server nodes is supported.";
  }

  uint32_t broadcast_size = 0;
  (void)std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) {
    if (addr.first.first == node_role) {
      ++broadcast_size;
    }
  });
  uint64_t request_id = AddMessageTrack(broadcast_size);

  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
    if (it->first.first != node_role) {
      continue;
    }
    auto message_meta = std::make_shared<MessageMeta>();
    MS_EXCEPTION_IF_NULL(message_meta);
    message_meta->set_cmd(NodeCommand::SEND_DATA);
    message_meta->set_request_id(request_id);
    message_meta->set_rank_id(node_info_.rank_id_);
    message_meta->set_role(node_info_.node_role_);
    message_meta->set_user_cmd(command);

    auto client = GetOrCreateTcpClient((*it).first.second);
    if (!client->SendMessage(message_meta, Protos::RAW, message.data(), message.size())) {
      MS_LOG(WARNING) << "Client send message failed.";
    }
  }
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << ", the request id is:" << request_id;
  return Wait(request_id, timeout);
}

void AbstractNode::set_ready_for_scale_out() {
  MS_LOG(INFO) << "[Scale out]: begin to set ready for scale out.";
  Register(client_to_scheduler_);
  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();
}

void AbstractNode::set_ready_for_scale_in() {
  MS_LOG(INFO) << "[Scale in]: begin to set ready for scale in.";
  if (!is_current_node_scale_in_) {
    Register(client_to_scheduler_);
    std::lock_guard<std::mutex> lock(client_mutex_);
    connected_nodes_.clear();
  }
}

void AbstractNode::set_scale_out_done() {
  MS_LOG(INFO) << "[Scale out]: begin to set scale out done.";
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);

  ScaleOutDoneMessage scale_out_done_message;
  scale_out_done_message.set_node_id(node_info_.node_id_);

  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                       scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
    return;
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " sent scale_out_done to the scheduler successfully!";
}

void AbstractNode::set_scale_in_done() {
  MS_LOG(INFO) << "[Scale in]: begin to set scale in done.";
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);

  ScaleInDoneMessage scale_in_done_message;
  scale_in_done_message.set_node_id(node_info_.node_id_);

  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                       scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
    return;
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " sent scale_in_done to the scheduler successfully!";
}

void AbstractNode::BroadcastEvent(const uint32_t &event) {
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SEND_EVENT);

  EventRespMessage event_resp_message;
  event_resp_message.set_event(event);

  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
    const uint32_t rank_id = (*it).first.second;
    const NodeRole role = (*it).first.first;
    auto client = GetOrCreateTcpClient(rank_id, role);
    if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
                         event_resp_message.ByteSizeLong())) {
      MS_LOG(WARNING) << "Send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id
                      << " timed out!";
    }
  }
  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " sent the event to the servers/workers!";
}

void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
  event_to_callback_.try_emplace(event, event_cb);
}

void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb) {
  custom_event_to_callback_.try_emplace(event, event_cb);
}

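// Point-to-point send: when the caller passes a non-null output pointer, a callback keyed by the
// tracked request id is registered; once the peer's response arrives, the callback copies the
// reply out of receive_messages_ into *output before Wait() returns.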
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const void *message, size_t len,
                        int command, VectorPtr *output, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(message);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
    return false;
  }

  uint64_t request_id = AddMessageTrack(1);
  if (output != nullptr) {
    set_message_callback(request_id, [this, request_id, rank_id, output]() {
      std::lock_guard<std::mutex> lock(receive_messages_mutex_);
      auto res = receive_messages_[request_id];
      *output = res[rank_id];
      receive_messages_.erase(request_id);
    });
  }

  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SEND_DATA);
  message_meta->set_request_id(request_id);
  message_meta->set_rank_id(node_info_.rank_id_);
  message_meta->set_role(node_info_.node_role_);
  message_meta->set_user_cmd(command);

  auto client = GetOrCreateTcpClient(rank_id, node_role);
  MS_EXCEPTION_IF_NULL(client);
  if (!client->SendMessage(message_meta, Protos::RAW, message, len)) {
    MS_LOG(WARNING) << "Client send message failed.";
  }
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << ", the request id is:" << request_id;
  return Wait(request_id, timeout);
}

bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const std::string &msg, int command,
                        VectorPtr *output, const uint32_t &timeout) {
  return Send(node_role, rank_id, msg.data(), msg.length(), command, output, timeout);
}

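// Batched variant of Send(): all messages share one request id, so a single Wait() call blocks
// until every addressed rank has responded (or the timeout expires).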
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                        const std::vector<std::string> &msgs, int command, std::vector<VectorPtr> *output,
                        const uint32_t &timeout) {
  // Validate the inputs before registering the message track, so a mismatch does not leak a track entry.
  if (rank_ids.size() != msgs.size()) {
    MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of messages!";
  }
  uint64_t request_id = AddMessageTrack(msgs.size());

  if (output != nullptr) {
    set_message_callback(request_id, [this, request_id, &rank_ids, output]() {
      std::lock_guard<std::mutex> lock(receive_messages_mutex_);
      auto &res = receive_messages_[request_id];
      for (auto &rank_id : rank_ids) {
        auto &response = res[rank_id];
        output->push_back(response);
      }
      receive_messages_.erase(request_id);
    });
  }
  size_t size = rank_ids.size();
  for (size_t it = 0; it < size; ++it) {
    if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                        << ", the server num:" << server_num_ << ", the rank id:" << rank_ids.at(it);
    }

    auto message_meta = std::make_shared<MessageMeta>();
    MS_EXCEPTION_IF_NULL(message_meta);
    message_meta->set_cmd(NodeCommand::SEND_DATA);
    message_meta->set_request_id(request_id);
    message_meta->set_rank_id(node_info_.rank_id_);
    message_meta->set_role(node_info_.node_role_);
    message_meta->set_user_cmd(command);

    auto &msg = msgs.at(it);

    auto client = GetOrCreateTcpClient(rank_ids.at(it), node_role);
    MS_EXCEPTION_IF_NULL(client);
    if (!client->SendMessage(message_meta, Protos::RAW, msg.data(), msg.size())) {
      MS_LOG(WARNING) << "Client send message failed.";
    }
  }
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << ", the request id is:" << request_id;
  return Wait(request_id, timeout);
}

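// Scheduler round trip: exactly one response is expected, so the message track is created with a
// count of 1. The response payload is stored in received_scheduler_messages_ by
// ProcessReceiveSchedulerResp() and handed back through the output parameter.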
bool AbstractNode::SendToScheduler(const void *message, size_t len, NodeCommand node_cmd, VectorPtr *output,
                                   const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(message);

  uint32_t expected_response_num = 1;
  uint64_t request_id = AddMessageTrack(expected_response_num);
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(node_cmd);
  message_meta->set_request_id(request_id);

  MS_EXCEPTION_IF_NULL(client_to_scheduler_);
  if (!client_to_scheduler_->SendMessage(message_meta, Protos::RAW, message, len)) {
    MS_LOG(WARNING) << "Failed to send message " << node_cmd << " to scheduler.";
  }

  bool ret = Wait(request_id, timeout);
  if (!ret) {
    MS_LOG(ERROR) << "Sending message " << node_cmd << " to scheduler timed out.";
    return ret;
  }

  // Assign the response value from the scheduler.
  if (output != nullptr) {
    if (received_scheduler_messages_.count(request_id) == 0) {
      MS_LOG(ERROR) << "The response message of command " << node_cmd << ", request_id " << request_id
                    << " has not been received yet.";
      return false;
    }
    *output = received_scheduler_messages_[request_id];
    (void)received_scheduler_messages_.erase(request_id);
  }
  return ret;
}

uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
                                           size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
    return 0;
  }

  std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
  message_meta->set_rank_id(node_info_.rank_id_);
  message_meta->set_role(node_info_.node_role_);

  auto client = GetOrCreateTcpClient(rank_id, node_role);
  MS_EXCEPTION_IF_NULL(client);
  return SendCollectiveMeta(client, message_meta, Protos::RAW, data, size);
}

static std::string CollectiveMetaToString(const CollectiveMessageMeta &meta) {
  std::ostringstream os;
  os << "{iteration:" << meta.iteration() << ", data:" << meta.weight_name() << ", send rank:" << meta.send_rank_id()
     << ", recv rank:" << meta.recv_rank_id() << ", phase:" << meta.phase() << ", chunk index:" << meta.chunk_index()
     << ", for index:" << meta.for_index() << "}";
  return os.str();
}

uint64_t AbstractNode::FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data,
                                             size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  auto recv_rank_id = collective_meta.recv_rank_id();
  if (!CommUtil::ValidateRankId(SERVER, recv_rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << recv_rank_id;
    return 0;
  }
  std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
  message_meta->set_rank_id(node_info_.rank_id_);
  message_meta->set_role(node_info_.node_role_);
  *(message_meta->mutable_collective_meta()) = collective_meta;
  message_meta->mutable_collective_meta()->set_enable_flag(true);
  message_meta->mutable_collective_meta()->set_send_rank_id(node_info_.rank_id_);

  MS_LOG(DEBUG) << "Send data to rank id:" << recv_rank_id
                << ", send meta:" << CollectiveMetaToString(message_meta->collective_meta());
  auto client = GetOrCreateTcpClient(recv_rank_id, SERVER);
  MS_EXCEPTION_IF_NULL(client);
  return SendCollectiveMeta(client, message_meta, Protos::RAW, data, size);
}

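// Receive loop: the wait below is sliced into one-second intervals so that, whenever a slice times
// out, the node can abort early if another server has already reported this iteration as failed.
// Data belonging to an older iteration is skipped; mismatched metadata within the same iteration
// is treated as an error.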
bool AbstractNode::FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output,
                                         const uint32_t &timeout) {
  if (output == nullptr) {
    return false;
  }
  auto send_rank_id = expect_meta.send_rank_id();
  if (!CommUtil::ValidateRankId(SERVER, send_rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << send_rank_id;
    return false;
  }
  auto check_meta = [](const CollectiveMessageMeta &left, const CollectiveMessageMeta &right) {
    return left.iteration() == right.iteration() && left.weight_name() == right.weight_name() &&
           left.recv_rank_id() == right.recv_rank_id() && left.send_rank_id() == right.send_rank_id() &&
           left.phase() == right.phase() && left.chunk_index() == right.chunk_index() &&
           left.for_index() == right.for_index();
  };
  auto iteration_num = expect_meta.iteration();
  std::unique_lock<std::mutex> lock(fl_receive_mutex_);
  auto &recv_data_list = fl_received_data_[send_rank_id];
  for (uint32_t i = 0; i < timeout; i++) {
    if (recv_data_list.empty()) {
      fl_receive_cond_.wait_for(lock, std::chrono::seconds(1), [&recv_data_list]() { return !recv_data_list.empty(); });
      if (recv_data_list.empty()) {               // this one-second slice timed out
        if (HasIterationFailed(iteration_num)) {  // the iteration was reported as failed by another server
          MS_LOG(WARNING) << "Detect iteration " << iteration_num << " has failed";
          return false;
        }
        continue;
      }
    }
    while (!recv_data_list.empty()) {
      auto first = recv_data_list.begin();
      auto recv_meta = std::move(first->first);
      auto recv_data = std::move(first->second);
      recv_data_list.erase(first);
      MS_LOG(DEBUG) << "Handle received data from rank id:" << send_rank_id
                    << ", recv meta:" << CollectiveMetaToString(recv_meta);
      if (recv_meta.iteration() != expect_meta.iteration()) {
        MS_LOG(WARNING) << "Skip recv data, iteration of recv meta " << recv_meta.iteration()
                        << " != iteration of expected meta " << expect_meta.iteration();
        continue;
      }
      // Mismatched data within the same iteration is an error.
      if (!check_meta(recv_meta, expect_meta)) {
        MS_LOG(WARNING) << "Recv meta does not match expected meta, recv meta: " << CollectiveMetaToString(recv_meta)
                        << ", expected meta: " << CollectiveMetaToString(expect_meta);
        return false;
      }
      *output = recv_data;
      return true;  // data received successfully
    }
  }
  return false;
}

bool AbstractNode::FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output,
                                    const uint32_t &timeout) {
  if (output == nullptr) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, parameter output is invalid";
    return false;
  }
  auto data_recved = FlCollectiveWaitInner(expect_meta, output, timeout);
  if (!data_recved) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, expect meta: " << CollectiveMetaToString(expect_meta);
    return false;
  }
  if (*output == nullptr) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, recv buffer is invalid";
    return false;
  }
  if (expect_size != (*output)->size()) {
    MS_LOG(ERROR) << "Expected data size " << expect_size << " != recv data size " << (*output)->size()
                  << ", expect meta: " << CollectiveMetaToString(expect_meta);
    return false;
  }
  return true;
}

void AbstractNode::OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data) {
  std::unique_lock<std::mutex> lock(fl_receive_mutex_);
  auto &recv_meta = message_meta.collective_meta();
  auto send_rank_id = recv_meta.send_rank_id();
  MS_LOG(DEBUG) << "Receive data from rank id:" << send_rank_id << ", recv meta:" << CollectiveMetaToString(recv_meta);
  fl_received_data_[send_rank_id].emplace_back(std::make_pair(recv_meta, data));
  fl_receive_cond_.notify_all();
}

bool AbstractNode::HasIterationFailed(uint32_t iteration_num) const {
  return iteration_num == failed_iteration_num_ && iteration_failed_;
}

std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
                                                                   VectorPtr *output) {
  MS_EXCEPTION_IF_NULL(output);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
    return std::make_pair(0, 0);
  }

  std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
  uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
  auto pair_data = std::make_pair(rank_id, rank_request_id);
  receive_messages_done_[pair_data] = false;
  if (received_data_.count(pair_data) > 0) {
    auto res = received_data_[pair_data];
    MS_EXCEPTION_IF_NULL(res);
    *output = res;
    (void)received_data_.erase(pair_data);
    receive_messages_done_[pair_data] = true;
    MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
  } else {
    receive_callbacks_[pair_data] = [=]() mutable {
      auto res_output = received_data_[std::make_pair(rank_id, rank_request_id)];
      MS_EXCEPTION_IF_NULL(res_output);
      if (*output != nullptr) {
        MS_LOG(WARNING) << "The output is not empty.";
      }
      *output = res_output;
      received_data_.erase(std::make_pair(rank_id, rank_request_id));
      receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
      MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
    };
  }
  return std::make_pair(rank_id, rank_request_id);
}

bool AbstractNode::CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout) {
  std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
  bool res =
    receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
  if (receive_messages_done_.count(request_id) != 0) {
    (void)receive_messages_done_.erase(request_id);
  }
  return res;
}

PersistentState AbstractNode::persistent_state() const { return persistent_state_; }
void AbstractNode::set_persistent_state(PersistentState persistent_state) { persistent_state_ = persistent_state; }

uint32_t AbstractNode::worker_num() const { return worker_num_; }

uint32_t AbstractNode::server_num() const { return server_num_; }

void AbstractNode::set_worker_num(const uint32_t &worker_num) { worker_num_ = worker_num; }

void AbstractNode::set_server_num(const uint32_t &server_num) { server_num_ = server_num; }

std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }

void AbstractNode::set_scheduler_ip(const std::string &scheduler_ip) { scheduler_ip_ = scheduler_ip; }

uint16_t AbstractNode::scheduler_port() const { return scheduler_port_; }

void AbstractNode::set_scheduler_port(const uint16_t &scheduler_port) { scheduler_port_ = scheduler_port; }

ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }

void AbstractNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }

void AbstractNode::Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                            const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(server_);
  meta->set_role(node_info_.node_role_);
  meta->set_rank_id(node_info_.rank_id_);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << ", the request id is:" << meta->request_id();
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
}

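// Heartbeat loop: a dedicated thread sends one heartbeat per heartbeat_interval milliseconds. On
// failure it checks whether the scheduler has timed out; while the connection is down it
// accumulates reconnect_interval until it exceeds connect_interval, then tries to reconnect.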
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
               << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
               << " begin to send heartbeats to the scheduler!";
  heart_beat_thread_ = std::make_unique<std::thread>([&]() {
    uint32_t connect_interval = PSContext::instance()->cluster_config().connect_interval;
    uint32_t heartbeat_interval = PSContext::instance()->cluster_config().heartbeat_interval * 1000;
    uint32_t reconnect_interval = 0;
    if (heartbeat_interval > connect_interval) {
      MS_LOG(WARNING) << "heartbeat_interval [" << heartbeat_interval << "] is larger than connect_interval ["
                      << connect_interval << "], reset connect_interval to " << heartbeat_interval;
      connect_interval = heartbeat_interval;  // apply the reset announced in the warning above
    }
    while (!is_finish_.load()) {
      if (!Heartbeat(client)) {
        MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                        << ", the node id is:" << node_info_.node_id_ << ", send heartbeat failed!";
        if (CheckSchedulerTimeout()) {
          MS_LOG(WARNING) << "Scheduler is timed out, please recover it.";
        }
      } else {
        UpdateSchedulerTime();
      }

      if (!is_already_finished_ && (client->connection_status() == -1)) {
        if (reconnect_interval > connect_interval) {
          MS_LOG(WARNING) << "Connection to the scheduler is disconnected, try to reconnect.";
          reconnect_interval = 0;
          ConnectToScheduler();
        } else {
          reconnect_interval += heartbeat_interval;
        }
      }

      std::this_thread::sleep_for(std::chrono::milliseconds(heartbeat_interval));
    }
  });
  MS_EXCEPTION_IF_NULL(heart_beat_thread_);
}

bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  if (client->connection_status() != 1) {
    return false;
  }
  auto meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(meta);
  meta->set_cmd(NodeCommand::HEARTBEAT);

  HeartbeatMessage heartbeat_message;
  heartbeat_message.set_node_id(node_info_.node_id_);
  heartbeat_message.set_persistent_state(PersistentState::NOT_ENABLE_PERSIST);

  // The worker role does not support disaster recovery currently.
  if (EnableRecovery() && role() == NodeRole::SERVER) {
    heartbeat_message.set_persistent_state(persistent_state_);
  }

  if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
                       heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
    MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " send heartbeat timed out!";
    return false;
  }
  return true;
}

void AbstractNode::UpdateSchedulerTime() {
  struct timeval current_time {};
  (void)gettimeofday(&current_time, nullptr);
  scheduler_time_ = current_time;
  MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec;
}

bool AbstractNode::CheckSchedulerTimeout() const {
  struct timeval current_time {};
  (void)gettimeofday(&current_time, nullptr);
  int64_t old_time = scheduler_time_.tv_sec + PSContext::instance()->cluster_config().scheduler_timeout;
  if (old_time < current_time.tv_sec) {
    return true;
  }
  return false;
}

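// Heartbeat response: the scheduler piggybacks the cluster state and the liveness of all nodes on
// its reply. On NODE_TIMEOUT without recovery support, this node unblocks all waiters and raises
// the NODE_TIMEOUT event; with recovery enabled, a BEGIN_PERSIST command triggers the
// ON_BEGIN_PERSIST callback on server nodes.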
void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  HeartbeatRespMessage heartbeat_resp_message;
  CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));

  if (heartbeat_resp_message.cluster_state() != current_cluster_state_ &&
      current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN &&
      current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) {
    UpdateClusterState(heartbeat_resp_message.cluster_state());
  }
  MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
                << CommUtil::ClusterStateToString(current_cluster_state_);

  std::string timeout_node_ids;

  all_nodes_info_.clear();
  for (const auto &it : heartbeat_resp_message.servers_meta()) {
    NodeInfo info;
    info.ip_ = it.ip();
    info.node_id_ = it.node_id();
    info.port_ = it.port();
    info.node_role_ = it.role();
    info.rank_id_ = it.rank_id();
    info.is_alive = it.is_alive();

    if (!info.is_alive) {
      timeout_node_ids += (info.node_id_ + " ");
    }

    all_nodes_info_[info.node_id_] = info;
    MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
                  << ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
  }
  bool is_worker = heartbeat_resp_message.is_worker();
  bool is_ps_mode = PSContext::instance()->server_mode() == ps::kServerModePS;
  bool not_enable_recover_node_timeout = (is_worker && is_ps_mode);

  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
    if (node_recovery_ == nullptr || not_enable_recover_node_timeout) {
      MS_LOG(INFO) << "The recovery is disabled. Trigger NODE_TIMEOUT event.";
      // Avoid other methods blocking endlessly when the NODE_TIMEOUT event is triggered.
      is_ready_ = true;
      wait_start_cond_.notify_all();
      is_finish_ = true;
      wait_finish_cond_.notify_all();
      OnEventCallback(ClusterEvent::NODE_TIMEOUT);
    } else {
      MS_LOG(INFO) << "The nodes: " << timeout_node_ids
                   << "support recovery, users can pull up these nodes to restore the cluster.";
    }
  }

  if (!EnableRecovery()) {
    return;
  }

  PersistentCommand persistent_cmd = heartbeat_resp_message.persistent_cmd();
  // The worker role does not support disaster recovery for the time being.
  if (role() == NodeRole::SERVER && persistent_cmd == PersistentCommand::BEGIN_PERSIST &&
      persistent_state_ != PersistentState::PERSISTING) {
    OnEventCallback(ClusterEvent::ON_BEGIN_PERSIST);
  }
}

void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(meta);
  meta->set_cmd(NodeCommand::FETCH_METADATA);

  FetchServersMessage fetch_servers;
  fetch_servers.set_node_id(node_info_.node_id_);
  if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
                       fetch_servers.ByteSizeLong())) {
    MS_LOG(EXCEPTION) << "Fetching servers' addresses timed out!";
  }
}

void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  FetchServersRespMessage fetch_servers_resp_message;
  CHECK_RETURN_TYPE(fetch_servers_resp_message.ParseFromArray(data, SizeToInt(size)));

  nodes_address_.clear();
  for (const auto &it : fetch_servers_resp_message.servers_meta()) {
    nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
    MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
  }
}

void AbstractNode::ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data,
                                               size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  std::lock_guard<std::mutex> lock(receive_messages_mutex_);

  const uint64_t request_id = meta->request_id();
  VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
  if (size > 0) {
    size_t dest_size = size;
    size_t src_size = size;
    auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "The memcpy_s error, errno(" << ret << ")";
    }
  }
  received_scheduler_messages_[request_id] = received_data;
}

void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
                                       const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                       size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (is_current_node_scale_in_) {
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    node_info_.rank_id_ = UINT32_MAX;
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
    return;
  }
  SendMetadataMessage send_meta_message;
  send_meta_message.ParseFromArray(data, SizeToInt(size));
  worker_num_ = send_meta_message.worker_num();
  server_num_ = send_meta_message.server_num();
  if (send_meta_message.rank_id() < 0) {
    MS_LOG(EXCEPTION) << "The rank id is wrong.";
  }
  node_info_.rank_id_ = send_meta_message.rank_id();
  UpdateClusterState(send_meta_message.cluster_state());
  MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
               << ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
               << ", the rank id:" << node_info_.rank_id_;

  {
    std::lock_guard<std::mutex> lock(client_mutex_);
    nodes_address_.clear();
    for (const auto &it : send_meta_message.servers_meta()) {
      nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
      MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(it.role()) << ", node id:" << it.node_id()
                   << ", rank id:" << it.rank_id() << ", ip:" << it.ip() << ", port:" << it.port();
    }
  }
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_ready_ = true;
  wait_start_cond_.notify_all();

  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT) {
    MS_LOG(WARNING) << "Trigger cluster scale out done event.";
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
  }

  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
  }

  if (cancelSafeModeFn_ && current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT_ROLLBACK) {
    MS_LOG(WARNING) << "Trigger cluster scale out rollback done event.";
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_ROLLBACK_DONE);
    cancelSafeModeFn_();
  }

  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();

  OnEventCallback(ClusterEvent::ON_SEND_META_DATA);
}

void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                 const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_finish_ = true;
  wait_finish_cond_.notify_all();
}

void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn,
                                       const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                       size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_LOG(INFO) << "This node received a scale out done message from the scheduler.";
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_ready_ = true;
  UpdateClusterState(ClusterState::CLUSTER_READY);
}

void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
                                      const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                      size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_ready_ = true;
  UpdateClusterState(ClusterState::CLUSTER_READY);
}

void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  EventRespMessage event_resp_message;
  event_resp_message.ParseFromArray(data, SizeToInt(size));
  uint32_t event = event_resp_message.event();
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  MS_LOG(INFO) << "This node received an event:" << event;
  if (event == static_cast<uint32_t>(ps::UserDefineEvent::kNodeTimeout)) {
    OnEventCallback(ClusterEvent::NODE_TIMEOUT);
  } else {
    OnCustomEventCallback(event);
  }
}

void AbstractNode::ProcessScaleOutRollback(const std::shared_ptr<TcpConnection> &conn,
                                           const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                           size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }

  UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT_ROLLBACK);

  MS_LOG(INFO) << "[Scale out rollback]: begin to set scale out rollback.";
  Register(client_to_scheduler_);
  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();

  MS_LOG(INFO) << "The node starts scale out rollback.";
}

void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                   const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  ScaleOutMessage scale_out_message;
  scale_out_message.ParseFromArray(data, SizeToInt(size));
  int32_t worker_num = scale_out_message.worker_num();
  int32_t server_num = scale_out_message.server_num();
  MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
  UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
  is_ready_ = false;
}

void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                  const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  ScaleInMessage scale_in_message;
  scale_in_message.ParseFromArray(data, SizeToInt(size));
  int32_t worker_num = scale_in_message.worker_num();
  int32_t server_num = scale_in_message.server_num();
  MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;

  is_current_node_scale_in_ = scale_in_message.is_node_scale_in();
  if (is_current_node_scale_in_) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " is a scale in node!";
  } else {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " is not a scale in node!";
  }

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
  UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
  is_ready_ = false;
}

void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn,
                                            const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                            size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  SendMetadataMessage scheduler_recovery_message;
  (void)scheduler_recovery_message.ParseFromArray(data, SizeToInt(size));
  worker_num_ = scheduler_recovery_message.worker_num();
  server_num_ = scheduler_recovery_message.server_num();
  uint32_t rank_id = scheduler_recovery_message.rank_id();

  MS_LOG(INFO) << "[Scheduler Recovery]: The scheduler recovery worker num:" << worker_num_
               << ", the server num:" << server_num_ << ", the rank id: " << rank_id;

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "[Scheduler Recovery]: Server response message failed.";
  }
  MS_LOG(INFO) << "[Scheduler Recovery]: Server responded to the message successfully!";

  ConnectToScheduler();
  bool connected = client_to_scheduler_->WaitConnected();
  if (!connected) {
    MS_LOG(WARNING) << "[Scheduler Recovery]: Server node connecting to the scheduler timed out!";
  }

  Register(client_to_scheduler_);
  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();
  MS_LOG(INFO) << "[Scheduler Recovery]: This node connected to the scheduler successfully!";

  if (cancelSafeModeFn_ && (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN ||
                            current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT)) {
    MS_LOG(INFO) << "[Scheduler Recovery]: Cancel safe mode for " << kClusterState.at(current_cluster_state_);
    cancelSafeModeFn_();
  }

  UpdateClusterState(ClusterState::CLUSTER_SCHEDULER_RECOVERY);
  is_ready_ = false;
}

bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  auto meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(meta);
  meta->set_cmd(NodeCommand::FINISH);

  std::string finish_message = node_info_.node_id_;

  if (!SendMessageSync(client, meta, Protos::RAW, finish_message.data(), finish_message.length())) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " send Finish message timeout!";
  }
  return WaitForDisconnect(timeout);
}

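// Note: a timeout of UINT32_MAX below means "block indefinitely"; callers use it to park the
// current thread until the cluster reports that it has finished.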
bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
  // If the cluster state is NODE_TIMEOUT, this node is already disconnected.
  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
    return true;
  }
  std::unique_lock<std::mutex> lock(wait_finish_mutex_);
  auto condition_func = [this] {
    if (is_finish_.load()) {
      MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " finished successfully!";
    }
    return is_finish_.load();
  };

  bool res;
  if (timeout == UINT32_MAX) {
    // Callers should use this branch to block the thread indefinitely.
    wait_finish_cond_.wait(lock, condition_func);
    res = true;
  } else {
    res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), condition_func);
  }

  return res;
}

void AbstractNode::InitClientToServer() {
  // Create a tcp client to this node itself, in case event dispatch fails when sending a message to server 0 fails.
  client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, node_info_.node_role_);
  MS_EXCEPTION_IF_NULL(client_to_server_);
  client_to_server_->Init();
  MS_LOG(INFO) << "The node started a tcp client to itself!";
}

InitClientToScheduler()1081 bool AbstractNode::InitClientToScheduler() {
1082   if (config_ == nullptr) {
1083     MS_LOG(WARNING) << "The config is empty.";
1084     return false;
1085   }
1086   client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, NodeRole::SCHEDULER);
1087   MS_EXCEPTION_IF_NULL(client_to_scheduler_);
1088   client_to_scheduler_->SetMessageCallback(
1089     [&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
1090       try {
1091         MS_EXCEPTION_IF_NULL(meta);
1092         MS_EXCEPTION_IF_NULL(data);
1093         if (handlers_.count(meta->cmd()) == 0) {
1094           MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
1095         }
1096         if (handlers_[meta->cmd()] != nullptr) {
1097           const auto &handler_ptr = handlers_[meta->cmd()];
1098           (this->*handler_ptr)(meta, data, size);
1099         }
1100         NotifyMessageArrival(meta);
1101       } catch (const std::exception &e) {
1102         MsException::Instance().SetException();
1103       }
1104     });
1105   ConnectToScheduler();
1106   StartHeartbeatTimer(client_to_scheduler_);
1107   MS_LOG(INFO) << "Start heartbeat timer!";
1108 
1109   bool wait_res = client_to_scheduler_->WaitConnected();
1110   if (!wait_res) {
1111     is_ready_ = true;
1112   }
1113   return wait_res;
1114 }
void AbstractNode::ConnectToScheduler() {
  client_to_scheduler_->Init();
  if (TcpClient::is_started()) {
    return;
  }

  if (client_to_scheduler_thread_ != nullptr && client_to_scheduler_thread_->joinable()) {
    client_to_scheduler_thread_->join();
  }
  client_to_scheduler_thread_ = std::make_unique<std::thread>([this]() {
    MS_LOG(INFO) << "The node starts a tcp client!";
    client_to_scheduler_->Start();
  });
}

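// Return the cached tcp client for (role, rank_id), or create one from the address table
// received from the scheduler. The client's message callback dispatches responses by
// command and then wakes any thread waiting on the request id.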
const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id, const NodeRole &role) {
  std::lock_guard<std::mutex> lock(client_mutex_);
  auto key = std::make_pair(role, rank_id);
  if (connected_nodes_.find(key) != connected_nodes_.end()) {
    return connected_nodes_[key];
  } else {
    if (nodes_address_.find(key) == nodes_address_.end()) {
      MS_LOG(EXCEPTION) << "Receiving the node's address from the scheduler failed. Role: " << role
                        << ", rank: " << rank_id;
    }
    if (config_ == nullptr) {
      MS_LOG(EXCEPTION) << "The config is empty.";
    }

    MS_LOG(INFO) << "Create tcp client for role: " << role << ", rank: " << rank_id;
    std::string ip = nodes_address_[key].first;
    uint16_t port = nodes_address_[key].second;
    auto client = std::make_shared<TcpClient>(ip, port, role);
    MS_EXCEPTION_IF_NULL(client);
    client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
                                   size_t size) {
      switch (meta->cmd()) {
        case NodeCommand::SEND_DATA:
          ProcessSendDataResp(meta, protos, data, size);
          RunMessageCallback(meta->request_id());
          break;
        case NodeCommand::COLLECTIVE_SEND_DATA:
          MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_
                        << " receive a collective_send_data message response!";
          break;
        case NodeCommand::SEND_EVENT:
          MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_
                        << " receive a send_event command message response!";
          break;
        default:
          MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
      }
      NotifyMessageArrival(meta);
    });
    client->Init();
    connected_nodes_[key] = client;
    return connected_nodes_[key];
  }
}

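// Send a message and block until the response arrives or the timeout expires. Each request
// gets a fresh request id from the message tracker; NotifyMessageArrival releases the waiter.
//
// A minimal usage sketch (hypothetical caller code; `peer_rank` and `payload` are invented
// for illustration, and the default timeout parameter from the header is assumed):
//   auto client = GetOrCreateTcpClient(peer_rank, NodeRole::SERVER);
//   auto meta = std::make_shared<MessageMeta>();
//   meta->set_cmd(NodeCommand::SEND_DATA);
//   if (!SendMessageSync(client, meta, Protos::RAW, payload.data(), payload.size())) {
//     MS_LOG(WARNING) << "Send to rank " << peer_rank << " timed out.";
//   }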
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
                                   const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  uint64_t request_id = AddMessageTrack(1);
  const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
  client->SendMessage(message);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return Wait(request_id, timeout);
}

bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
                                   const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  uint64_t request_id = AddMessageTrack(1);
  meta->set_request_id(request_id);
  client->SendMessage(meta, protos, data, size);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return Wait(request_id, timeout);
}

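// Like SendMessageSync, but return the request id instead of blocking, so the caller can
// overlap other work and call Wait(request_id, timeout) later.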
uint64_t AbstractNode::SendCollectiveMeta(const std::shared_ptr<TcpClient> &client,
                                          const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
                                          const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(client);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  uint64_t request_id = AddMessageTrack(1);
  meta->set_request_id(request_id);
  client->SendMessage(meta, protos, data, size);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return request_id;
}

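// Acknowledge a collective send by echoing the payload back to the sender, then hand the
// data to the receive-callback machinery below.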
void AbstractNode::ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn,
                                             const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
                                             const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  RunReceiveCallback(meta, protos, data, size);
}

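// Trace the arrival of a SEND_DATA request and forward it to the registered request_handler_.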
void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                   const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id()
                << " the current time is:"
                << std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now())
                     .time_since_epoch()
                     .count();
  request_handler_(conn, meta, data, size);
}

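// Record one arrived response for the request id in the message tracker and wake all
// threads blocked in Wait. A missing tracker entry means the request was already removed,
// likely after a timeout.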
void AbstractNode::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
  MS_EXCEPTION_IF_NULL(meta);
  std::lock_guard<std::mutex> lock(message_tracker_mutex_);
  uint64_t request_id = meta->request_id();
  if (message_tracker_.count(request_id)) {
    message_tracker_[request_id].second++;
  } else {
    MS_LOG(WARNING) << "The request id:" << request_id << " is removed.";
  }
  message_tracker_cond_.notify_all();
}

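// Copy the received payload, route collective data to OnRecvCollectiveData, and otherwise
// match the message against the receive callback registered for (rank id, rank request id).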
void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                      size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
  size_t dest_size = size;
  size_t src_size = size;
  int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
  }
  if (meta->collective_meta().enable_flag()) {
    OnRecvCollectiveData(*meta, received_data);
    return;
  }
  std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
  uint32_t rank_id = meta->rank_id();
  // Generate the actual rank request id for this rank and compare it with the expected one:
  // when a callback is registered for the matching (rank id, rank request id) pair, invoke it.
  uint64_t rank_request_id = NextActualRankRequestId(rank_id);
  received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
  MS_LOG(DEBUG) << "Run receive data callback, the rank id:" << rank_id
                << ", the rank request id is:" << rank_request_id
                << ", the send request id is:" << meta->request_id() << " the size is:" << size;
  auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
  if (it != receive_callbacks_.end()) {
    if (receive_messages_done_.count(std::make_pair(rank_id, rank_request_id)) != 0) {
      if (it->second) {
        it->second();
      }
    }
    receive_cond_.notify_all();
    receive_callbacks_.erase(it);
  }
}

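// Expected/actual rank request ids are per-rank sequence numbers: the receiving side
// presumably bumps the "expected" counter when it registers a receive, while the "actual"
// counter is bumped when data from that rank arrives; a receive completes when they match.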
uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) {
  std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
  uint64_t rank_request_id = 1;
  if (expected_rank_request_ids_.count(rank_id)) {
    rank_request_id = ++expected_rank_request_ids_[rank_id];
  } else {
    expected_rank_request_ids_[rank_id] = rank_request_id;
  }
  return rank_request_id;
}

uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
  std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
  uint64_t rank_request_id = 1;
  if (actual_rank_request_ids_.count(rank_id)) {
    rank_request_id = ++actual_rank_request_ids_[rank_id];
  } else {
    actual_rank_request_ids_[rank_id] = rank_request_id;
  }
  return rank_request_id;
}

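// Map scheduler-response commands to member-function handlers. A nullptr entry means the
// response needs no processing beyond NotifyMessageArrival.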
void AbstractNode::InitCommandHandler() {
  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
  handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
  handlers_[NodeCommand::FINISH] = nullptr;
  handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
  handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
  handlers_[NodeCommand::SEND_EVENT] = nullptr;
  RegisterActorRouteTableRspHandler();
  RegisterInitCollectCommResphandler();
  RegisterRecoveryRespHandler();
}

void AbstractNode::RegisterActorRouteTableRspHandler() {
  handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
  handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
  handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
}

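// Map the commands this node serves over its tcp server to member-function handlers;
// unknown commands are rejected in the server's message callback.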
void AbstractNode::InitServerHandler() {
  server_handler_[NodeCommand::SEND_METADATA] = &AbstractNode::ProcessSendMetadata;
  server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
  server_handler_[NodeCommand::SEND_DATA] = &AbstractNode::ProcessSendData;
  server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = &AbstractNode::ProcessCollectiveSendData;
  server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
  server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
  server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
  server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
  server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
  server_handler_[NodeCommand::SCHEDULER_RECOVERY] = &AbstractNode::ProcessSchedulerRecovery;
  server_handler_[NodeCommand::PREPARE_BUILDING_NETWORK] = &AbstractNode::ProcessPrepareBuildingNetwork;
  server_handler_[NodeCommand::SCALE_OUT_ROLLBACK] = &AbstractNode::ProcessScaleOutRollback;
}

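// Resolve the node id in priority order: the id from PSContext if set, otherwise the
// kNodeId entry of the config file, otherwise a freshly generated UUID. The ip and port
// come from the already-bound tcp server.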
void AbstractNode::InitNodeInfo(const NodeRole &role) {
  MS_EXCEPTION_IF_NULL(config_);
  MS_EXCEPTION_IF_NULL(server_);
  if (PSContext::instance()->node_id().empty() && config_->Exists(kNodeId)) {
    node_info_.node_id_ = config_->Get(kNodeId, "");
  } else {
    node_info_.node_id_ = PSContext::instance()->node_id();
  }

  if (node_info_.node_id_.empty()) {
    node_info_.node_id_ = CommUtil::GenerateUUID();
  }
  node_info_.node_role_ = role;
  node_info_.ip_ = server_->BoundIp();
  node_info_.port_ = server_->BoundPort();

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
               << ", the port:" << server_->BoundPort();
}

void AbstractNode::InitNodeNum() {
  worker_num_ = PSContext::instance()->cluster_config().initial_worker_num;
  server_num_ = PSContext::instance()->cluster_config().initial_server_num;
  scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
  scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
  MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
               << ", the scheduler ip:" << scheduler_ip_ << ", the scheduler port:" << scheduler_port_;
}

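// Restore persisted cluster metadata when recovery is configured (a kKeyRecovery entry
// exists in the config); return false when recovery is disabled or initialization fails.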
bool AbstractNode::Recover() {
  MS_EXCEPTION_IF_NULL(config_);
  if (config_->Exists(kKeyRecovery)) {
    MS_LOG(INFO) << "The node supports recovery.";
    node_recovery_ = std::make_unique<NodeRecovery>(this);
    MS_EXCEPTION_IF_NULL(node_recovery_);
    if (node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
      MS_LOG(INFO) << "Initializing node recovery succeeded.";
      return node_recovery_->Recover();
    }
  }
  return false;
}

void AbstractNode::OnEventCallback(const ClusterEvent &event) {
  if (!event_to_callback_.count(event)) {
    MS_LOG(INFO) << "[Event]:The event callback of " << event << " is not set.";
  } else {
    MS_LOG(INFO) << "[Event]:Trigger the event:" << event;
    if (event_to_callback_[event]) {
      event_to_callback_[event]();
    }
  }
}

void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
  if (!custom_event_to_callback_.count(event)) {
    MS_LOG(WARNING) << "[Custom event]:The event callback of " << event << " is not set.";
  } else {
    MS_LOG(INFO) << "[Custom event]:Trigger the event:" << event;
    if (custom_event_to_callback_[event]) {
      custom_event_to_callback_[event]();
    }
  }
}

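// Return true if the given node set still contains an alive worker, or an alive server
// with rank 0.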
bool AbstractNode::IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info) {
  for (const auto &it : info) {
    if (it.second.is_alive && it.second.node_role_ == NodeRole::WORKER) {
      return true;
    }

    if (it.second.is_alive && it.second.rank_id_ == 0 && it.second.node_role_ == NodeRole::SERVER) {
      return true;
    }
  }
  return false;
}

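// Build this node's tcp server: bind to MS_WORKER_IP if set (otherwise the first available
// network interface), with the listening port presumably picked from port_range since the
// explicit port argument is 0, install the command dispatcher, and start the event loop on
// a dedicated thread.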
void AbstractNode::CreateTcpServer(const std::pair<uint32_t, uint32_t> &port_range) {
  MS_EXCEPTION_IF_NULL(config_);
  std::string interface;
  std::string server_ip = common::GetEnv("MS_WORKER_IP");
  if (server_ip.empty()) {
    MS_LOG(INFO) << "The 'MS_WORKER_IP' env is not set, so the first available network interface is used.";
    CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
  }

  server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get(), port_range);
  MS_EXCEPTION_IF_NULL(server_);
  server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                  const Protos &protos, const void *data, size_t size) {
    MS_EXCEPTION_IF_NULL(conn);
    MS_EXCEPTION_IF_NULL(meta);
    MS_EXCEPTION_IF_NULL(data);
    MS_LOG(DEBUG) << "Receive message cmd " << meta->cmd() << ", size is " << size;
    const auto &handler_pair = server_handler_.find(meta->cmd());
    if (handler_pair == server_handler_.end()) {
      MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
    }
    (this->*(handler_pair->second))(conn, meta, protos, data, size);
  });

  server_->Init();
  server_thread_ = std::make_unique<std::thread>([this]() {
    MS_LOG(INFO) << "The worker node or server node starts a tcp server!";
    this->server_->Start();
  });
  MS_EXCEPTION_IF_NULL(server_thread_);
}

void AbstractNode::UpdateClusterState(const ClusterState &state) {
  std::lock_guard<std::mutex> lock(cluster_state_mutex_);
  std::string state_str = CommUtil::ClusterStateToString(state);
  if (state_str.empty()) {
    return;
  }

  if (state == current_cluster_state_) {
    return;
  }
  MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_)
               << " to " << state_str;
  current_cluster_state_ = state;
}

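// Persist the current cluster topology (scheduler address, worker and server counts)
// through the node recovery module so it can be restored after a restart.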
void AbstractNode::PersistMetaData() {
  if (node_recovery_ == nullptr) {
    MS_LOG(WARNING) << "The node recovery is null, so the meta data is not persisted.";
    return;
  }
  if (config_->Exists(kKeyRecovery)) {
    ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
    clusterConfig.scheduler_host = this->scheduler_ip();
    clusterConfig.scheduler_port = this->scheduler_port();
    clusterConfig.initial_worker_num = worker_num_;
    clusterConfig.initial_server_num = server_num_;

    node_recovery_->Persist(clusterConfig);
  }
}

void AbstractNode::ProcessPrepareBuildingNetwork(const std::shared_ptr<TcpConnection> &conn,
                                                 const std::shared_ptr<MessageMeta> &meta, const Protos &,
                                                 const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(ERROR) << "Server response message failed, preparing for building network failed.";
  } else {
    MS_LOG(INFO) << "Prepare for building network succeeded.";
  }
}
}  // namespace core
}  // namespace ps
}  // namespace mindspore