• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/abstract_node.h"
18 #include "ps/core/node_recovery.h"
19 #include "ps/core/communicator/tcp_communicator.h"
20 #include "ps/core/communicator/http_communicator.h"
21 
22 namespace mindspore {
23 namespace ps {
24 namespace core {
Register(const std::shared_ptr<TcpClient> & client)25 void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
26   MS_EXCEPTION_IF_NULL(client);
27   auto message_meta = std::make_shared<MessageMeta>();
28   MS_EXCEPTION_IF_NULL(message_meta);
29   message_meta->set_cmd(NodeCommand::REGISTER);
30   message_meta->set_rank_id(node_info_.rank_id_);
31 
32   RegisterMessage register_message;
33   register_message.set_node_id(node_info_.node_id_);
34   register_message.set_role(node_info_.node_role_);
35   register_message.set_ip(node_info_.ip_);
36   register_message.set_port(node_info_.port_);
37 
38   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
39                << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
40 
41   if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
42                        register_message.ByteSizeLong())) {
43     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
44                       << " the node id:" << node_info_.node_id_ << " register timeout!";
45   }
46 }
47 
ProcessRegisterResp(const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)48 void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
49   MS_EXCEPTION_IF_NULL(meta);
50   MS_EXCEPTION_IF_NULL(data);
51   RegisterRespMessage register_resp_message;
52   CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
53   if (register_resp_message.node_id() != node_info_.node_id_) {
54     MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
55                       << " is not match the current node id:" << node_info_.node_id_;
56   }
57 
58   // Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
59   // scheduler is alive
60   UpdateSchedulerTime();
61 
62   MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered scheduler success!";
63 }
64 
Broadcast(const NodeRole & node_role,const DataPtr & message,size_t size,int command,const uint32_t & timeout)65 bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message, size_t size, int command,
66                              const uint32_t &timeout) {
67   MS_EXCEPTION_IF_NULL(message);
68   if (node_role != NodeRole::SERVER) {
69     MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
70   }
71 
72   uint64_t request_id = AddMessageTrack(nodes_address_.size());
73 
74   for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
75     auto message_meta = std::make_shared<MessageMeta>();
76     MS_EXCEPTION_IF_NULL(message_meta);
77     message_meta->set_cmd(NodeCommand::SEND_DATA);
78     message_meta->set_request_id(request_id);
79     message_meta->set_rank_id(node_info_.rank_id_);
80     message_meta->set_role(node_info_.node_role_);
81     message_meta->set_user_cmd(command);
82 
83     auto client = GetOrCreateTcpClient((*it).first.second);
84     if (!client->SendMessage(message_meta, Protos::RAW, message.get(), size)) {
85       MS_LOG(WARNING) << "Client send message failed.";
86     }
87   }
88   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
89                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
90   return Wait(request_id, timeout);
91 }
92 
set_ready_for_scale_out()93 void AbstractNode::set_ready_for_scale_out() {
94   MS_LOG(INFO) << "[Scale out]: begin to set ready for scale out.";
95   Register(client_to_scheduler_);
96   std::lock_guard<std::mutex> lock(client_mutex_);
97   connected_nodes_.clear();
98 }
99 
set_ready_for_scale_in()100 void AbstractNode::set_ready_for_scale_in() {
101   MS_LOG(INFO) << "[Scale in]: begin to set ready for scale in.";
102   if (!is_current_node_scale_in_) {
103     Register(client_to_scheduler_);
104     std::lock_guard<std::mutex> lock(client_mutex_);
105     connected_nodes_.clear();
106   }
107 }
108 
set_scale_out_done()109 void AbstractNode::set_scale_out_done() {
110   MS_LOG(INFO) << "[Scale out]: begin to set scale out done.";
111   auto message_meta = std::make_shared<MessageMeta>();
112   MS_EXCEPTION_IF_NULL(message_meta);
113   message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);
114 
115   ScaleOutDoneMessage scale_out_done_message;
116   scale_out_done_message.set_node_id(node_info_.node_id_);
117 
118   if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
119                        scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
120     MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
121                     << " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
122     return;
123   }
124 
125   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
126                << " the node id:" << node_info_.node_id_ << "is send scale_out_done to scheduler successful!";
127 }
128 
set_scale_in_done()129 void AbstractNode::set_scale_in_done() {
130   MS_LOG(INFO) << "[Scale in]: begin to set scale in done.";
131   auto message_meta = std::make_shared<MessageMeta>();
132   MS_EXCEPTION_IF_NULL(message_meta);
133   message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);
134 
135   ScaleInDoneMessage scale_in_done_message;
136   scale_in_done_message.set_node_id(node_info_.node_id_);
137 
138   if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
139                        scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
140     MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
141                     << " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
142     return;
143   }
144 
145   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
146                << " the node id:" << node_info_.node_id_ << "is send scale_in_done to scheduler successful!";
147 }
148 
BroadcastEvent(const uint32_t & event)149 void AbstractNode::BroadcastEvent(const uint32_t &event) {
150   auto message_meta = std::make_shared<MessageMeta>();
151   MS_EXCEPTION_IF_NULL(message_meta);
152   message_meta->set_cmd(NodeCommand::SEND_EVENT);
153 
154   EventMessage event_message;
155   event_message.set_event(event);
156   event_message.set_node_id(node_info_.node_id_);
157 
158   if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(),
159                        event_message.ByteSizeLong())) {
160     MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
161                   << " the node id:" << node_info_.node_id_ << " send event timeout!";
162     return;
163   }
164 
165   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
166                << " the node id:" << node_info_.node_id_ << "is send event to scheduler!";
167 }
168 
RegisterEventCallback(const core::ClusterEvent & event,const EventCallback & event_cb)169 void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
170   event_to_callback_.try_emplace(event, event_cb);
171 }
172 
RegisterCustomEventCallback(const uint32_t & event,const EventCallback & event_cb)173 void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb) {
174   custom_event_to_callback_.try_emplace(event, event_cb);
175 }
176 
Send(const NodeRole & node_role,const uint32_t & rank_id,const DataPtr & data,size_t len,int command,const uint32_t & timeout)177 bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
178                         int command, const uint32_t &timeout) {
179   if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
180     MS_LOG(DEBUG) << "The node is timeout, can not send message.";
181     return false;
182   }
183   MS_EXCEPTION_IF_NULL(data);
184   if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
185     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
186                       << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
187   }
188 
189   auto message_meta = std::make_shared<MessageMeta>();
190   MS_EXCEPTION_IF_NULL(message_meta);
191   message_meta->set_cmd(NodeCommand::SEND_DATA);
192   message_meta->set_rank_id(node_info_.rank_id_);
193   message_meta->set_role(node_info_.node_role_);
194   message_meta->set_user_cmd(command);
195 
196   auto client = GetOrCreateTcpClient(rank_id);
197   return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout);
198 }
199 
Send(const NodeRole & node_role,const std::vector<uint32_t> & rank_ids,const std::vector<DataPtr> & data,const std::vector<size_t> & lens,int command,const uint32_t & timeout)200 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
201                         const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
202                         const uint32_t &timeout) {
203   if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
204     MS_LOG(DEBUG) << "The node is timeout, can not send message.";
205     return false;
206   }
207 
208   uint64_t request_id = AddMessageTrack(data.size());
209 
210   if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) {
211     MS_LOG(EXCEPTION) << "The number of rank ids, data and lens are not equal!";
212   }
213   for (size_t it = 0; it < rank_ids.size(); ++it) {
214     if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
215       MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
216                         << ", the server num:" << server_num_ << ", the rank id:" << rank_ids.at(it);
217     }
218 
219     auto message_meta = std::make_shared<MessageMeta>();
220     MS_EXCEPTION_IF_NULL(message_meta);
221     message_meta->set_cmd(NodeCommand::SEND_DATA);
222     message_meta->set_request_id(request_id);
223     message_meta->set_rank_id(node_info_.rank_id_);
224     message_meta->set_role(node_info_.node_role_);
225     message_meta->set_user_cmd(command);
226 
227     auto send = data.at(it);
228     auto len = lens.at(it);
229     auto client = GetOrCreateTcpClient(rank_ids.at(it));
230     MS_EXCEPTION_IF_NULL(client);
231     if (!client->SendMessage(message_meta, Protos::RAW, send.get(), len)) {
232       MS_LOG(WARNING) << "Client send message failed.";
233     }
234   }
235   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
236                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
237   return Wait(request_id, timeout);
238 }
239 
Send(const NodeRole & node_role,const uint32_t & rank_id,const DataPtr & message,size_t len,int command,VectorPtr * output,const uint32_t & timeout)240 bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
241                         int command, VectorPtr *output, const uint32_t &timeout) {
242   if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
243     MS_LOG(DEBUG) << "The node is timeout, can not send message.";
244     return false;
245   }
246   MS_EXCEPTION_IF_NULL(message);
247   MS_EXCEPTION_IF_NULL(output);
248   if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
249     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
250                       << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
251   }
252 
253   uint64_t request_id = AddMessageTrack(1);
254   set_message_callback(request_id, [&]() {
255     receive_messages_mutex_.lock();
256     auto res = receive_messages_[request_id];
257     *output = res[rank_id];
258     receive_messages_.erase(request_id);
259     receive_messages_mutex_.unlock();
260   });
261 
262   auto message_meta = std::make_shared<MessageMeta>();
263   MS_EXCEPTION_IF_NULL(message_meta);
264   message_meta->set_cmd(NodeCommand::SEND_DATA);
265   message_meta->set_request_id(request_id);
266   message_meta->set_rank_id(node_info_.rank_id_);
267   message_meta->set_role(node_info_.node_role_);
268   message_meta->set_user_cmd(command);
269 
270   auto client = GetOrCreateTcpClient(rank_id);
271   MS_EXCEPTION_IF_NULL(client);
272   if (!client->SendMessage(message_meta, Protos::RAW, message.get(), len)) {
273     MS_LOG(WARNING) << "Client send message failed.";
274   }
275   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
276                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
277   return Wait(request_id, timeout);
278 }
279 
Send(const NodeRole & node_role,const std::vector<uint32_t> & rank_ids,const std::vector<DataPtr> & data,const std::vector<size_t> & data_lens,int command,std::vector<VectorPtr> * output,const uint32_t & timeout)280 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
281                         const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
282                         std::vector<VectorPtr> *output, const uint32_t &timeout) {
283   if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
284     MS_LOG(DEBUG) << "The node is timeout, can not send message.";
285     return false;
286   }
287   MS_EXCEPTION_IF_NULL(output);
288   uint64_t request_id = AddMessageTrack(data.size());
289 
290   if (rank_ids.size() != data.size()) {
291     MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
292   }
293 
294   size_t size = rank_ids.size();
295 
296   set_message_callback(request_id, [&]() {
297     receive_messages_mutex_.lock();
298     auto res = receive_messages_[request_id];
299     for (size_t it = 0; it < size; ++it) {
300       (*output).push_back(res[rank_ids.at(it)]);
301     }
302     receive_messages_.erase(request_id);
303     receive_messages_mutex_.unlock();
304   });
305 
306   for (size_t it = 0; it < size; ++it) {
307     if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
308       MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
309                         << ", the server num:" << server_num_ << ", the rank id:" << rank_ids.at(it);
310     }
311 
312     auto message_meta = std::make_shared<MessageMeta>();
313     MS_EXCEPTION_IF_NULL(message_meta);
314     message_meta->set_cmd(NodeCommand::SEND_DATA);
315     message_meta->set_request_id(request_id);
316     message_meta->set_rank_id(node_info_.rank_id_);
317     message_meta->set_role(node_info_.node_role_);
318     message_meta->set_user_cmd(command);
319 
320     auto send = data.at(it);
321     auto len = data_lens.at(it);
322 
323     auto client = GetOrCreateTcpClient(rank_ids.at(it));
324     MS_EXCEPTION_IF_NULL(client);
325     if (!client->SendMessage(message_meta, Protos::RAW, send.get(), len)) {
326       MS_LOG(WARNING) << "Client send message failed.";
327     }
328   }
329   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
330                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
331   return Wait(request_id, timeout);
332 }
333 
CollectiveSendAsync(const NodeRole & node_role,const uint32_t & rank_id,const void * data,size_t size)334 uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
335                                            size_t size) {
336   MS_EXCEPTION_IF_NULL(data);
337   if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
338     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
339                       << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
340   }
341 
342   std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
343   MS_EXCEPTION_IF_NULL(message_meta);
344   message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
345   message_meta->set_rank_id(node_info_.rank_id_);
346   message_meta->set_role(node_info_.node_role_);
347 
348   auto client = GetOrCreateTcpClient(rank_id);
349   MS_EXCEPTION_IF_NULL(client);
350   return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
351 }
352 
CollectiveReceiveAsync(const NodeRole & node_role,const uint32_t & rank_id,VectorPtr * output)353 std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
354                                                                    VectorPtr *output) {
355   MS_EXCEPTION_IF_NULL(output);
356   if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
357     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
358                       << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
359   }
360 
361   receive_callbacks_mutex_.lock();
362   uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
363   auto pair_data = std::make_pair(rank_id, rank_request_id);
364   receive_messages_done_[pair_data] = false;
365   if (received_data_.count(pair_data) > 0) {
366     auto res = received_data_[pair_data];
367     MS_EXCEPTION_IF_NULL(res);
368     *output = res;
369     (void)received_data_.erase(pair_data);
370     receive_messages_done_[pair_data] = true;
371     MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
372   } else {
373     receive_callbacks_[pair_data] = [=]() mutable {
374       auto res_output = received_data_[std::make_pair(rank_id, rank_request_id)];
375       MS_EXCEPTION_IF_NULL(res_output);
376       if (*output != nullptr) {
377         MS_LOG(WARNING) << "The output is not empty.";
378       }
379       *output = res_output;
380       received_data_.erase(std::make_pair(rank_id, rank_request_id));
381       receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
382       MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
383     };
384   }
385   receive_callbacks_mutex_.unlock();
386   return std::make_pair(rank_id, rank_request_id);
387 }
388 
CollectiveWait(const std::pair<uint32_t,uint64_t> & request_id,const uint32_t & timeout)389 bool AbstractNode::CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout) {
390   std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
391   bool res =
392     receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
393   if (receive_messages_done_.count(request_id) != 0) {
394     (void)receive_messages_done_.erase(request_id);
395   }
396   return res;
397 }
398 
InitFollowerScaler()399 bool AbstractNode::InitFollowerScaler() {
400   follower_scaler_ = std::make_unique<FollowerScaler>(this);
401   MS_EXCEPTION_IF_NULL(follower_scaler_);
402   follower_scaler_->RegisterScaleEventCallbacks();
403   return true;
404 }
405 
RegisterFollowerScalerBarrierBeforeScaleOut(const std::string & module,const BarrierBeforeScaleOut & barrier)406 void AbstractNode::RegisterFollowerScalerBarrierBeforeScaleOut(const std::string &module,
407                                                                const BarrierBeforeScaleOut &barrier) {
408   MS_EXCEPTION_IF_NULL(follower_scaler_);
409   follower_scaler_->RegisterBarrierBeforeScaleOut(module, barrier);
410 }
411 
RegisterFollowerScalerBarrierBeforeScaleIn(const std::string & module,const BarrierBeforeScaleIn & barrier)412 void AbstractNode::RegisterFollowerScalerBarrierBeforeScaleIn(const std::string &module,
413                                                               const BarrierBeforeScaleIn &barrier) {
414   MS_EXCEPTION_IF_NULL(follower_scaler_);
415   follower_scaler_->RegisterBarrierBeforeScaleIn(module, barrier);
416 }
417 
RegisterFollowerScalerHandlerAfterScaleOut(const std::string & module,const HandlerAfterScaleOut & handler)418 void AbstractNode::RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module,
419                                                               const HandlerAfterScaleOut &handler) {
420   MS_EXCEPTION_IF_NULL(follower_scaler_);
421   follower_scaler_->RegisterHandlerAfterScaleOut(module, handler);
422 }
423 
RegisterFollowerScalerHandlerAfterScaleIn(const std::string & module,const HandlerAfterScaleIn & handler)424 void AbstractNode::RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module,
425                                                              const HandlerAfterScaleIn &handler) {
426   MS_EXCEPTION_IF_NULL(follower_scaler_);
427   follower_scaler_->RegisterHandlerAfterScaleIn(module, handler);
428 }
429 
worker_num() const430 int32_t AbstractNode::worker_num() const { return worker_num_; }
431 
server_num() const432 int32_t AbstractNode::server_num() const { return server_num_; }
433 
set_worker_num(const int32_t & worker_num)434 void AbstractNode::set_worker_num(const int32_t &worker_num) { worker_num_ = worker_num; }
435 
set_server_num(const int32_t & server_num)436 void AbstractNode::set_server_num(const int32_t &server_num) { server_num_ = server_num; }
437 
scheduler_ip() const438 std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }
439 
set_scheduler_ip(const std::string & scheduler_ip)440 void AbstractNode::set_scheduler_ip(const std::string &scheduler_ip) { scheduler_ip_ = scheduler_ip; }
441 
scheduler_port() const442 uint16_t AbstractNode::scheduler_port() const { return scheduler_port_; }
443 
set_scheduler_port(const uint16_t & scheduler_port)444 void AbstractNode::set_scheduler_port(const uint16_t &scheduler_port) { scheduler_port_ = scheduler_port; }
445 
cluster_state() const446 ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }
447 
set_handler(const RequestHandler & handler)448 void AbstractNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
449 
Response(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)450 void AbstractNode::Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
451                             const void *data, size_t size) {
452   MS_EXCEPTION_IF_NULL(conn);
453   MS_EXCEPTION_IF_NULL(meta);
454   MS_EXCEPTION_IF_NULL(data);
455   MS_EXCEPTION_IF_NULL(server_);
456   meta->set_role(node_info_.node_role_);
457   meta->set_rank_id(node_info_.rank_id_);
458   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
459                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
460   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
461     MS_LOG(WARNING) << "Server response message failed.";
462   }
463 }
464 
GetOrCreateHttpComm(const std::string & ip,uint16_t port,const std::shared_ptr<TaskExecutor> & task_executor)465 std::shared_ptr<CommunicatorBase> AbstractNode::GetOrCreateHttpComm(
466   const std::string &ip, uint16_t port, const std::shared_ptr<TaskExecutor> &task_executor) {
467   MS_EXCEPTION_IF_NULL(task_executor);
468   std::lock_guard<std::mutex> lock(communicator_mutex_);
469   if (!communicators_.count(kHttpCommunicator)) {
470     MS_LOG(INFO) << "Create Http communicator.";
471     auto http_comm = std::make_shared<HttpCommunicator>(ip, port, task_executor);
472     MS_EXCEPTION_IF_NULL(http_comm);
473     communicators_[kHttpCommunicator] = http_comm;
474   }
475   return communicators_[kHttpCommunicator];
476 }
477 
GetOrCreateTcpComm(const std::string & scheduler_ip,std::int16_t scheduler_port,uint32_t worker_num,uint32_t server_num,const std::shared_ptr<TaskExecutor> & task_executor)478 std::shared_ptr<CommunicatorBase> AbstractNode::GetOrCreateTcpComm(const std::string &scheduler_ip,
479                                                                    std::int16_t scheduler_port, uint32_t worker_num,
480                                                                    uint32_t server_num,
481                                                                    const std::shared_ptr<TaskExecutor> &task_executor) {
482   MS_EXCEPTION_IF_NULL(task_executor);
483   std::lock_guard<std::mutex> lock(communicator_mutex_);
484   if (!communicators_.count(kTcpCommunicator)) {
485     MS_LOG(INFO) << "Create Tcp communicator.";
486     auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
487     PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
488     PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
489     PSContext::instance()->cluster_config().initial_worker_num = worker_num;
490     PSContext::instance()->cluster_config().initial_server_num = server_num;
491     MS_EXCEPTION_IF_NULL(tcp_comm);
492     PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
493     PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
494     PSContext::instance()->cluster_config().initial_worker_num = worker_num;
495     PSContext::instance()->cluster_config().initial_server_num = server_num;
496     MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num
497                  << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip
498                  << ", Scheduler port:" << scheduler_port;
499     communicators_[kTcpCommunicator] = tcp_comm;
500   }
501   return communicators_[kTcpCommunicator];
502 }
503 
StartHeartbeatTimer(const std::shared_ptr<TcpClient> & client)504 void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
505   MS_EXCEPTION_IF_NULL(client);
506   MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
507                << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
508                << " begin send heartbeat to the scheduler!";
509   heart_beat_thread_ = std::make_unique<std::thread>([&]() {
510     while (!is_finish_.load()) {
511       if (!Heartbeat(client)) {
512         MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
513                         << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
514         if (CheckSchedulerTimeout()) {
515           MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
516                           << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
517           is_finish_ = true;
518           wait_finish_cond_.notify_all();
519           if (!is_already_stopped_) {
520             OnEventCallback(ClusterEvent::SCHEDULER_TIMEOUT);
521           }
522         }
523       } else {
524         UpdateSchedulerTime();
525       }
526 
527       std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
528     }
529   });
530   MS_EXCEPTION_IF_NULL(heart_beat_thread_);
531   heart_beat_thread_->detach();
532 }
533 
Heartbeat(const std::shared_ptr<TcpClient> & client)534 bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
535   MS_EXCEPTION_IF_NULL(client);
536   auto meta = std::make_shared<MessageMeta>();
537   MS_EXCEPTION_IF_NULL(meta);
538   meta->set_cmd(NodeCommand::HEARTBEAT);
539 
540   HeartbeatMessage heartbeat_message;
541   heartbeat_message.set_node_id(node_info_.node_id_);
542 
543   if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
544                        heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
545     MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
546     return false;
547   }
548   return true;
549 }
550 
UpdateSchedulerTime()551 void AbstractNode::UpdateSchedulerTime() {
552   struct timeval current_time {};
553   (void)gettimeofday(&current_time, nullptr);
554   scheduler_time_ = current_time;
555   MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec;
556 }
557 
CheckSchedulerTimeout() const558 bool AbstractNode::CheckSchedulerTimeout() const {
559   struct timeval current_time {};
560   (void)gettimeofday(&current_time, nullptr);
561   int64_t old_time = scheduler_time_.tv_sec + PSContext::instance()->cluster_config().scheduler_timeout;
562   if (old_time < current_time.tv_sec) {
563     return true;
564   }
565   return false;
566 }
567 
// Handles the scheduler's reply to a heartbeat: caches the reported cluster
// state, rebuilds the roster of all known nodes, and reacts to a cluster-wide
// NODE_TIMEOUT state.
// @param meta response meta (non-null).
// @param data serialized HeartbeatRespMessage payload (non-null).
// @param size payload size in bytes.
void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  HeartbeatRespMessage heartbeat_resp_message;
  CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));

  current_cluster_state_ = heartbeat_resp_message.cluster_state();
  MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
                << CommUtil::ClusterStateToString(current_cluster_state_);

  // Rebuild the full node roster from scratch on every heartbeat response.
  all_nodes_info_.clear();
  for (const auto &it : heartbeat_resp_message.servers_meta()) {
    NodeInfo info;
    info.ip_ = it.ip();
    info.node_id_ = it.node_id();
    info.port_ = static_cast<uint16_t>(it.port());
    info.node_role_ = it.role();
    info.rank_id_ = it.rank_id();
    info.is_alive = it.is_alive();

    all_nodes_info_[info.node_id_] = info;
    MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
                  << ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
  }

  bool is_worker_or_server0 = heartbeat_resp_message.is_worker_or_server0();

  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
    // Without recovery support (or when the dead node is a worker / server 0,
    // per the scheduler-provided flag) the timeout is fatal: unblock waiters
    // and fire the NODE_TIMEOUT event. Otherwise the node can simply be
    // restarted to rejoin the cluster.
    if (node_recovery_ == nullptr || is_worker_or_server0) {
      MS_LOG(INFO) << "The recovery is disable.";
      is_ready_ = true;
      wait_start_cond_.notify_all();
      OnEventCallback(ClusterEvent::NODE_TIMEOUT);
    } else {
      MS_LOG(INFO) << "The node is support recovery, users can pull up this node to restore the cluster.";
    }
  }
}
606 
FetchServers(const std::shared_ptr<TcpClient> & client)607 void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
608   MS_EXCEPTION_IF_NULL(client);
609   auto meta = std::make_shared<MessageMeta>();
610   MS_EXCEPTION_IF_NULL(meta);
611   meta->set_cmd(NodeCommand::FETCH_METADATA);
612 
613   FetchServersMessage fetch_servers;
614   fetch_servers.set_node_id(node_info_.node_id_);
615   if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
616                        fetch_servers.ByteSizeLong())) {
617     MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
618   }
619 }
620 
ProcessFetchServersResp(const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)621 void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
622   MS_EXCEPTION_IF_NULL(meta);
623   MS_EXCEPTION_IF_NULL(data);
624   FetchServersRespMessage fetch_servers_resp_message;
625   CHECK_RETURN_TYPE(fetch_servers_resp_message.ParseFromArray(data, SizeToInt(size)));
626 
627   nodes_address_.clear();
628   for (const auto &it : fetch_servers_resp_message.servers_meta()) {
629     nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
630     MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
631   }
632 }
633 
ProcessSendMetadata(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)634 void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
635                                        const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
636                                        size_t size) {
637   MS_EXCEPTION_IF_NULL(conn);
638   MS_EXCEPTION_IF_NULL(meta);
639   MS_EXCEPTION_IF_NULL(data);
640   if (is_current_node_scale_in_) {
641     MS_LOG(WARNING) << "Trigger cluster scale in done event.";
642     node_info_.rank_id_ = UINT32_MAX;
643     OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
644     return;
645   }
646   SendMetadataMessage send_meta_message;
647   CHECK_RETURN_TYPE(send_meta_message.ParseFromArray(data, SizeToInt(size)));
648   worker_num_ = send_meta_message.worker_num();
649   server_num_ = send_meta_message.server_num();
650   if (send_meta_message.rank_id() < 0) {
651     MS_LOG(EXCEPTION) << "The rank id is wrong.";
652   }
653   node_info_.rank_id_ = send_meta_message.rank_id();
654   current_cluster_state_ = send_meta_message.cluster_state();
655   MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
656                << ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
657                << ", the rank id:" << node_info_.rank_id_;
658 
659   client_mutex_.lock();
660   nodes_address_.clear();
661   for (const auto &it : send_meta_message.servers_meta()) {
662     nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
663     MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port() << ", the rank id:" << it.rank_id();
664   }
665   client_mutex_.unlock();
666   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
667     MS_LOG(WARNING) << "Sever response message failed.";
668   }
669   is_ready_ = true;
670   wait_start_cond_.notify_all();
671 
672   if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT) {
673     MS_LOG(WARNING) << "Trigger cluster scale out done event.";
674     OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
675   }
676 
677   if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
678     MS_LOG(WARNING) << "Trigger cluster scale in done event.";
679     OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
680   }
681 
682   std::lock_guard<std::mutex> lock(client_mutex_);
683   connected_nodes_.clear();
684 }
685 
ProcessFinish(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)686 void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
687                                  const Protos &, const void *data, size_t size) {
688   MS_EXCEPTION_IF_NULL(conn);
689   MS_EXCEPTION_IF_NULL(meta);
690   MS_EXCEPTION_IF_NULL(data);
691   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
692     MS_LOG(WARNING) << "Server response message failed.";
693   }
694   is_finish_ = true;
695   wait_finish_cond_.notify_all();
696 }
697 
ProcessScaleOutDone(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)698 void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn,
699                                        const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
700                                        size_t size) {
701   MS_EXCEPTION_IF_NULL(conn);
702   MS_EXCEPTION_IF_NULL(meta);
703   MS_EXCEPTION_IF_NULL(data);
704   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
705     MS_LOG(WARNING) << "Server response message failed.";
706   }
707   is_ready_ = true;
708   current_cluster_state_ = ClusterState::CLUSTER_READY;
709 }
710 
ProcessScaleInDone(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)711 void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
712                                       const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
713                                       size_t size) {
714   MS_EXCEPTION_IF_NULL(conn);
715   MS_EXCEPTION_IF_NULL(meta);
716   MS_EXCEPTION_IF_NULL(data);
717   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
718     MS_LOG(WARNING) << "Server response message failed.";
719   }
720   is_ready_ = true;
721   current_cluster_state_ = ClusterState::CLUSTER_READY;
722 }
723 
ProcessEvent(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)724 void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
725                                 const Protos &, const void *data, size_t size) {
726   MS_EXCEPTION_IF_NULL(conn);
727   MS_EXCEPTION_IF_NULL(meta);
728   MS_EXCEPTION_IF_NULL(data);
729   EventRespMessage event_resp_message;
730   CHECK_RETURN_TYPE(event_resp_message.ParseFromArray(data, SizeToInt(size)));
731   uint32_t event = event_resp_message.event();
732   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
733     MS_LOG(WARNING) << "Server response message failed.";
734   }
735   OnCustomEventCallback(event);
736 }
737 
ProcessScaleOut(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)738 void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
739                                    const Protos &, const void *data, size_t size) {
740   MS_EXCEPTION_IF_NULL(conn);
741   MS_EXCEPTION_IF_NULL(meta);
742   MS_EXCEPTION_IF_NULL(data);
743 
744   ScaleOutMessage scale_out_message;
745   CHECK_RETURN_TYPE(scale_out_message.ParseFromArray(data, SizeToInt(size)));
746   int32_t worker_num = scale_out_message.worker_num();
747   int32_t server_num = scale_out_message.server_num();
748   MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;
749 
750   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
751     MS_LOG(WARNING) << "Server response message failed.";
752   }
753   OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
754   current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
755   is_ready_ = false;
756 }
757 
ProcessScaleIn(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)758 void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
759                                   const Protos &, const void *data, size_t size) {
760   MS_EXCEPTION_IF_NULL(conn);
761   MS_EXCEPTION_IF_NULL(meta);
762   MS_EXCEPTION_IF_NULL(data);
763 
764   ScaleInMessage scale_in_message;
765   CHECK_RETURN_TYPE(scale_in_message.ParseFromArray(data, SizeToInt(size)));
766   int32_t worker_num = scale_in_message.worker_num();
767   int32_t server_num = scale_in_message.server_num();
768   MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;
769 
770   is_current_node_scale_in_ = scale_in_message.is_node_scale_in();
771   if (is_current_node_scale_in_) {
772     MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
773                     << " the node id:" << node_info_.node_id_ << " is a scale in node!";
774   } else {
775     MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
776                     << " the node id:" << node_info_.node_id_ << " is not a scale in node!";
777   }
778 
779   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
780     MS_LOG(WARNING) << "Server response message failed.";
781   }
782   OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
783   current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
784   is_ready_ = false;
785 }
786 
Disconnect(const std::shared_ptr<TcpClient> & client,const uint32_t & timeout)787 bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
788   MS_EXCEPTION_IF_NULL(client);
789   auto meta = std::make_shared<MessageMeta>();
790   MS_EXCEPTION_IF_NULL(meta);
791   meta->set_cmd(NodeCommand::FINISH);
792 
793   std::string finish_message = node_info_.node_id_;
794 
795   if (!SendMessageSync(client, meta, Protos::RAW, finish_message.data(), finish_message.length())) {
796     MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
797                     << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
798   }
799   return WaitForDisconnect(timeout);
800 }
801 
WaitForDisconnect(const uint32_t & timeout)802 bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
803   std::unique_lock<std::mutex> lock(wait_finish_mutex_);
804   bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
805     if (is_finish_.load()) {
806       MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
807     }
808     return is_finish_.load();
809   });
810   return res;
811 }
812 
// Creates the TCP client towards the scheduler, wires its message and
// disconnect callbacks, starts its event loop on a detached thread, and blocks
// until the connection is established or the connect wait gives up.
// @return true when the client connected to the scheduler in time.
bool AbstractNode::InitClientToScheduler() {
  if (config_ == nullptr) {
    MS_LOG(WARNING) << "The config is empty.";
    return false;
  }
  client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get());
  MS_EXCEPTION_IF_NULL(client_to_scheduler_);
  // Dispatch scheduler responses through handlers_; a nullptr handler means
  // the command only needs the NotifyMessageArrival bookkeeping. Exceptions
  // are captured so they can be re-raised outside this callback thread.
  client_to_scheduler_->SetMessageCallback(
    [&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
      try {
        MS_EXCEPTION_IF_NULL(meta);
        MS_EXCEPTION_IF_NULL(data);
        if (handlers_.count(meta->cmd()) == 0) {
          MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
        }
        if (handlers_[meta->cmd()] != nullptr) {
          const auto &handler_ptr = handlers_[meta->cmd()];
          (this->*handler_ptr)(meta, data, size);
        }
        NotifyMessageArrival(meta);
      } catch (const std::exception &e) {
        MsException::Instance().SetException();
      }
    });

  client_to_scheduler_->Init();
  // The client's event loop runs on its own detached thread.
  client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
    MS_LOG(INFO) << "The node start a tcp client!";
    client_to_scheduler_->Start();
  });
  client_to_scheduler_thread_->detach();

  // On disconnect, wait the configured interval and re-init the connection as
  // long as the cluster has not been marked ready.
  client_to_scheduler_->set_disconnected_callback([&]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval));
    if (is_ready_.load() == false) {
      client_to_scheduler_->Init();
    }
  });
  bool wait_res = client_to_scheduler_->WaitConnected();
  if (!wait_res) {
    // NOTE(review): is_ready_ is set even though the connect failed —
    // presumably to unblock threads waiting on cluster start; confirm with
    // callers of this flag.
    is_ready_ = true;
  }
  return wait_res;
}
857 
GetOrCreateTcpClient(const uint32_t & rank_id)858 const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id) {
859   std::lock_guard<std::mutex> lock(client_mutex_);
860   if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
861     return connected_nodes_[rank_id];
862   } else {
863     if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
864       MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed!";
865     }
866     if (config_ == nullptr) {
867       MS_LOG(EXCEPTION) << "The config is empty.";
868     }
869     std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
870     uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
871     auto client = std::make_shared<TcpClient>(ip, port, config_.get());
872     MS_EXCEPTION_IF_NULL(client);
873     client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
874                                    size_t size) {
875       switch (meta->cmd()) {
876         case NodeCommand::SEND_DATA:
877           ProcessSendDataResp(meta, protos, data, size);
878           RunMessageCallback(meta->request_id());
879           break;
880         case NodeCommand::COLLECTIVE_SEND_DATA:
881           MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
882           break;
883         default:
884           MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
885       }
886       NotifyMessageArrival(meta);
887     });
888     client->Init();
889     connected_nodes_[rank_id] = client;
890     return connected_nodes_[rank_id];
891   }
892 }
893 
SendMessageSync(const std::shared_ptr<TcpClient> & client,const CommMessage & message,const uint32_t & timeout)894 bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
895                                    const uint32_t &timeout) {
896   MS_EXCEPTION_IF_NULL(client);
897   uint64_t request_id = AddMessageTrack(1);
898   const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
899   client->SendMessage(message);
900   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
901                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
902   return Wait(request_id, timeout);
903 }
904 
SendMessageAsync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)905 uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client,
906                                         const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
907                                         const void *data, size_t size) {
908   MS_EXCEPTION_IF_NULL(client);
909   MS_EXCEPTION_IF_NULL(meta);
910   MS_EXCEPTION_IF_NULL(data);
911   uint64_t request_id = AddMessageTrack(1);
912   meta->set_request_id(request_id);
913   client->SendMessage(meta, protos, data, size);
914   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
915                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
916   return request_id;
917 }
918 
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)919 bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
920                                    const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
921   MS_EXCEPTION_IF_NULL(client);
922   MS_EXCEPTION_IF_NULL(meta);
923   MS_EXCEPTION_IF_NULL(data);
924   uint64_t request_id = AddMessageTrack(1);
925   meta->set_request_id(request_id);
926   client->SendMessage(meta, protos, data, size);
927   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
928                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
929   return Wait(request_id, timeout);
930 }
931 
ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)932 void AbstractNode::ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn,
933                                              const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
934   MS_EXCEPTION_IF_NULL(conn);
935   MS_EXCEPTION_IF_NULL(meta);
936   MS_EXCEPTION_IF_NULL(data);
937   if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
938     MS_LOG(WARNING) << "Server response message failed.";
939   }
940 }
941 
ProcessSendData(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)942 void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
943                                    const Protos &, const void *data, size_t size) {
944   MS_EXCEPTION_IF_NULL(conn);
945   MS_EXCEPTION_IF_NULL(meta);
946   MS_EXCEPTION_IF_NULL(data);
947   std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
948   if (size > 0) {
949     size_t dest_size = size;
950     size_t src_size = size;
951     if (memcpy_s(res.get(), dest_size, data, src_size) != EOK) {
952       MS_LOG(EXCEPTION) << "The memcpy_s error";
953     }
954   }
955   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
956                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id()
957                 << " the current time is:"
958                 << std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now())
959                      .time_since_epoch()
960                      .count();
961   request_handler_(conn, meta, res, size);
962 }
963 
NotifyMessageArrival(const std::shared_ptr<MessageMeta> & meta)964 void AbstractNode::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
965   MS_EXCEPTION_IF_NULL(meta);
966   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
967   uint64_t request_id = meta->request_id();
968   if (message_tracker_.count(request_id)) {
969     message_tracker_[request_id].second++;
970   } else {
971     MS_LOG(WARNING) << "The requset id:" << request_id << " is removed.";
972   }
973   message_tracker_cond_.notify_all();
974 }
975 
RunReceiveCallback(const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)976 void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
977                                       size_t size) {
978   MS_EXCEPTION_IF_NULL(meta);
979   MS_EXCEPTION_IF_NULL(data);
980   receive_callbacks_mutex_.lock();
981   uint32_t rank_id = meta->rank_id();
982   // When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
983   // If they are equal, then call the callback function
984   uint64_t rank_request_id = NextActualRankRequestId(rank_id);
985   std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
986   size_t dest_size = size;
987   size_t src_size = size;
988   int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
989   if (ret != 0) {
990     receive_callbacks_mutex_.unlock();
991     MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
992   }
993   received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
994   MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id
995                 << ", the send request id is:" << meta->request_id() << " the size is:" << size;
996   auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
997   if (it != receive_callbacks_.end()) {
998     if (receive_messages_done_.count(std::make_pair(rank_id, rank_request_id)) != 0) {
999       if (it->second) {
1000         it->second();
1001       }
1002     }
1003     receive_cond_.notify_all();
1004     receive_callbacks_.erase(it);
1005   }
1006   receive_callbacks_mutex_.unlock();
1007 }
1008 
NextExpectedRankRequestId(const uint32_t & rank_id)1009 uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) {
1010   std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
1011   uint64_t rank_request_id = 1;
1012   if (expected_rank_request_ids_.count(rank_id)) {
1013     rank_request_id = ++expected_rank_request_ids_[rank_id];
1014     expected_rank_request_ids_[rank_id] = rank_request_id;
1015   } else {
1016     expected_rank_request_ids_[rank_id] = rank_request_id;
1017   }
1018   return rank_request_id;
1019 }
1020 
NextActualRankRequestId(const uint32_t & rank_id)1021 uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
1022   std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
1023   uint64_t rank_request_id = 1;
1024   if (actual_rank_request_ids_.count(rank_id)) {
1025     rank_request_id = ++actual_rank_request_ids_[rank_id];
1026     actual_rank_request_ids_[rank_id] = rank_request_id;
1027   } else {
1028     actual_rank_request_ids_[rank_id] = rank_request_id;
1029   }
1030   return rank_request_id;
1031 }
1032 
// Registers the handlers for responses coming back from the scheduler.
// A nullptr entry marks a command that needs no dedicated processing — only
// the generic NotifyMessageArrival bookkeeping in the client callback.
void AbstractNode::InitCommandHandler() {
  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
  handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
  handlers_[NodeCommand::FINISH] = nullptr;
  handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
  handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
  handlers_[NodeCommand::SEND_EVENT] = nullptr;
}
1042 
// Registers the handlers for requests arriving on this node's TCP server.
// SEND_DATA and COLLECTIVE_SEND_DATA map to nullptr because the server's
// message callback dispatches them explicitly before the table lookup.
void AbstractNode::InitServerHandler() {
  server_handler_[NodeCommand::SEND_METADATA] = &AbstractNode::ProcessSendMetadata;
  server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
  server_handler_[NodeCommand::SEND_DATA] = nullptr;
  server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
  server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
  server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
  server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
  server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
  server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
}
1054 
// Fills node_info_ with an id (from the config, the PS context, or a freshly
// generated UUID), the given role, and the ip/port the local server is bound
// to. Requires config_ and server_ to be set.
// @param role the role this node plays in the cluster.
void AbstractNode::InitNodeInfo(const NodeRole &role) {
  MS_EXCEPTION_IF_NULL(config_);
  MS_EXCEPTION_IF_NULL(server_);
  // Config-file id is used only when the context has no id set.
  if (PSContext::instance()->node_id().empty() && config_->Exists(kNodeId)) {
    node_info_.node_id_ = config_->Get(kNodeId, "");
  } else {
    node_info_.node_id_ = PSContext::instance()->node_id();
  }

  // Fall back to a generated UUID when neither source provided an id.
  if (node_info_.node_id_.empty()) {
    node_info_.node_id_ = CommUtil::GenerateUUID();
  }
  node_info_.node_role_ = role;
  node_info_.ip_ = server_->BoundIp();
  node_info_.port_ = server_->BoundPort();

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
               << ", the port:" << server_->BoundPort();
}
1075 
InitNodeNum()1076 void AbstractNode::InitNodeNum() {
1077   worker_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_worker_num);
1078   server_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_server_num);
1079   scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
1080   scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
1081   MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
1082                << ", the scheduler ip:" << scheduler_ip_ << ", the scheduler port:" << scheduler_port_;
1083 }
1084 
Recover()1085 bool AbstractNode::Recover() {
1086   MS_EXCEPTION_IF_NULL(config_);
1087   if (config_->Exists(kKeyRecovery)) {
1088     MS_LOG(INFO) << "The node is support recovery.";
1089     node_recovery_ = std::make_unique<NodeRecovery>(this);
1090     MS_EXCEPTION_IF_NULL(node_recovery_);
1091     node_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
1092     return node_recovery_->Recover();
1093   }
1094   return false;
1095 }
1096 
OnEventCallback(const ClusterEvent & event)1097 void AbstractNode::OnEventCallback(const ClusterEvent &event) {
1098   if (!event_to_callback_.count(event)) {
1099     MS_LOG(ERROR) << "[Event]:The event callback of " << event << " is not set.";
1100   } else {
1101     MS_LOG(INFO) << "[Event]:Trigger the event:" << event;
1102     if (event_to_callback_[event]) {
1103       event_to_callback_[event]();
1104     }
1105   }
1106 }
1107 
OnCustomEventCallback(const uint32_t & event)1108 void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
1109   if (!custom_event_to_callback_.count(event)) {
1110     MS_LOG(WARNING) << "[Custom event]:The event callback of " << event << " is not set.";
1111   } else {
1112     MS_LOG(INFO) << "[Custom event]:Trigger the event:" << event;
1113     if (custom_event_to_callback_[event]) {
1114       custom_event_to_callback_[event]();
1115     }
1116   }
1117 }
1118 
IsWorkerOrServer0(const std::unordered_map<std::string,NodeInfo> & info)1119 bool AbstractNode::IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info) {
1120   for (const auto &it : info) {
1121     if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) {
1122       return true;
1123     }
1124 
1125     if (it.second.is_alive == true && it.second.rank_id_ == 0 && it.second.node_role_ == NodeRole::SERVER) {
1126       return true;
1127     }
1128   }
1129   return false;
1130 }
1131 
// Creates this node's TCP server on an available local interface, wires the
// dispatch of incoming commands, and starts the server loop on a detached
// thread. Requires config_ to be set.
void AbstractNode::CreateTcpServer() {
  MS_EXCEPTION_IF_NULL(config_);
  std::string interface;
  std::string server_ip;
  CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
  // Port 0 lets the OS pick a free port; the bound port is later read back
  // via server_->BoundPort().
  server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
  MS_EXCEPTION_IF_NULL(server_);
  server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                  const Protos &protos, const void *data, size_t size) {
    MS_EXCEPTION_IF_NULL(meta);
    MS_EXCEPTION_IF_NULL(conn);
    MS_EXCEPTION_IF_NULL(data);
    if (server_handler_.count(meta->cmd()) == 0) {
      MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
    }

    // Data-plane commands are dispatched explicitly; everything else goes
    // through the server_handler_ table populated in InitServerHandler.
    if (meta->cmd() == NodeCommand::COLLECTIVE_SEND_DATA) {
      ProcessCollectiveSendData(conn, meta, data, size);
      RunReceiveCallback(meta, protos, data, size);
    } else if (meta->cmd() == NodeCommand::SEND_DATA) {
      ProcessSendData(conn, meta, protos, data, size);
    } else {
      const auto &handler_ptr = server_handler_[meta->cmd()];
      (this->*handler_ptr)(conn, meta, protos, data, size);
    }
  });
  server_->Init();
  // The server loop runs on its own detached thread.
  server_thread_ = std::make_unique<std::thread>([this]() {
    MS_LOG(INFO) << "The server node start a tcp server!";
    this->server_->Start();
  });
  MS_EXCEPTION_IF_NULL(server_thread_);
  server_thread_->detach();
}
1166 }  // namespace core
1167 }  // namespace ps
1168 }  // namespace mindspore
1169