1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/abstract_node.h"
18
19 #include "include/common/debug/common.h"
20 #include "ps/core/communicator/http_communicator.h"
21 #include "ps/core/communicator/tcp_communicator.h"
22 #include "ps/core/node_recovery.h"
23
24 namespace mindspore {
25 namespace ps {
26 namespace core {
AbstractNode::~AbstractNode() {
  try {
    if (client_to_scheduler_ != nullptr) {
      client_to_scheduler_->Stop();
    }
    if (client_to_scheduler_thread_ != nullptr && client_to_scheduler_thread_->joinable()) {
      client_to_scheduler_thread_->join();
    }
    if (heart_beat_thread_ != nullptr && heart_beat_thread_->joinable()) {
      heart_beat_thread_->join();
    }
    if (server_ != nullptr) {
      server_->Stop();
    }
    if (server_thread_ != nullptr && server_thread_->joinable()) {
      server_thread_->join();
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "AbstractNode destructor run failed, error message: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "AbstractNode destructor run failed, unknown error occurred.";
  }
}

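// Registration flow: the node sends a REGISTER command carrying its id, role, ip, port and recovery flag to the
// scheduler. The scheduler replies asynchronously; ProcessRegisterResp below consumes the response and assigns
// this node's rank id.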
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::REGISTER);
  message_meta->set_rank_id(node_info_.rank_id_);

  RegisterMessage register_message;
  register_message.set_node_id(node_info_.node_id_);
  register_message.set_role(node_info_.node_role_);
  register_message.set_ip(node_info_.ip_);
  register_message.set_port(node_info_.port_);
  register_message.set_is_recover(is_recover.load());

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";

  if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
                        register_message.ByteSizeLong())) {
    MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                  << " the node id:" << node_info_.node_id_ << " failed to send the register message!";
  } else {
    MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                 << " the node id:" << node_info_.node_id_ << " sent the register message successfully!";
  }
}

void AbstractNode::SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info) {
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::FAILURE_EVENT_INFO);

  std::string now_time = ps::core::CommUtil::GetNowTime().time_str_mill;
  FailureEventMessage failure_event_message;
  failure_event_message.set_node_role(node_role);
  failure_event_message.set_ip(node_info_.ip_);
  failure_event_message.set_port(node_info_.port_);
  failure_event_message.set_time(now_time);
  failure_event_message.set_event(event_info);

  MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
               << " begin to send failure message to scheduler!";

  if (!SendMessageAsync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                        failure_event_message.SerializeAsString().data(), failure_event_message.ByteSizeLong())) {
    MS_LOG(ERROR) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
                  << " send failure message timeout!";
  } else {
    MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
                 << " send failure message " << event_info << " success!";
  }
}

void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  RegisterRespMessage register_resp_message;
  CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
  MS_LOG(INFO) << "The node id received from the scheduler is:" << register_resp_message.node_id()
               << ", rank_id is:" << register_resp_message.rank_id();

  if (register_resp_message.node_id() != node_info_.node_id_) {
    MS_LOG(ERROR) << "The node id received:" << register_resp_message.node_id()
                  << " does not match the current node id:" << node_info_.node_id_;
    return;
  }
  node_info_.rank_id_ = register_resp_message.rank_id();
  if (node_info_.rank_id_ == UINT32_MAX) {
    MS_LOG(ERROR) << "The rank id received is invalid:" << register_resp_message.rank_id();
    return;
  }

  // Receiving the register response indicates that the scheduler is alive, so update the time point at which the
  // scheduler was last seen alive.
  UpdateSchedulerTime();

  MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered to the scheduler successfully!";
}

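// Broadcast sends one payload to every known server node under a single request id. The expected response count
// passed to AddMessageTrack equals the number of server addresses, so Wait() returns only once every server has
// answered or the timeout expires.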
bool AbstractNode::Broadcast(const NodeRole &node_role, const std::string &message, int command,
                             const uint32_t &timeout) {
  if (node_role != NodeRole::SERVER) {
    MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
  }

  uint32_t broadcast_size = 0;
  (void)std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) {
    if (addr.first.first == node_role) {
      ++broadcast_size;
    }
  });
  uint64_t request_id = AddMessageTrack(broadcast_size);

  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
    if (it->first.first != node_role) {
      continue;
    }
    auto message_meta = std::make_shared<MessageMeta>();
    MS_EXCEPTION_IF_NULL(message_meta);
    message_meta->set_cmd(NodeCommand::SEND_DATA);
    message_meta->set_request_id(request_id);
    message_meta->set_rank_id(node_info_.rank_id_);
    message_meta->set_role(node_info_.node_role_);
    message_meta->set_user_cmd(command);

    auto client = GetOrCreateTcpClient((*it).first.second);
    if (!client->SendMessage(message_meta, Protos::RAW, message.data(), message.size())) {
      MS_LOG(WARNING) << "Client send message failed.";
    }
  }
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return Wait(request_id, timeout);
}

void AbstractNode::set_ready_for_scale_out() {
  MS_LOG(INFO) << "[Scale out]: begin to set ready for scale out.";
  Register(client_to_scheduler_);
  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();
}

void AbstractNode::set_ready_for_scale_in() {
  MS_LOG(INFO) << "[Scale in]: begin to set ready for scale in.";
  if (!is_current_node_scale_in_) {
    Register(client_to_scheduler_);
    std::lock_guard<std::mutex> lock(client_mutex_);
    connected_nodes_.clear();
  }
}

void AbstractNode::set_scale_out_done() {
  MS_LOG(INFO) << "[Scale out]: begin to set scale out done.";
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);

  ScaleOutDoneMessage scale_out_done_message;
  scale_out_done_message.set_node_id(node_info_.node_id_);

  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                       scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
    return;
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " sent scale_out_done to the scheduler successfully!";
}

void AbstractNode::set_scale_in_done() {
  MS_LOG(INFO) << "[Scale in]: begin to set scale in done.";
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);

  ScaleInDoneMessage scale_in_done_message;
  scale_in_done_message.set_node_id(node_info_.node_id_);

  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                       scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
    return;
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " sent scale_in_done to the scheduler successfully!";
}

void AbstractNode::BroadcastEvent(const uint32_t &event) {
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SEND_EVENT);

  EventRespMessage event_resp_message;
  event_resp_message.set_event(event);

  for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
    const uint32_t rank_id = (*it).first.second;
    const NodeRole role = (*it).first.first;
    auto client = GetOrCreateTcpClient(rank_id, role);
    if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
                         event_resp_message.ByteSizeLong())) {
      MS_LOG(WARNING) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id
                      << " timeout!";
    }
  }
  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << " send event to server/worker!";
}

void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
  event_to_callback_.try_emplace(event, event_cb);
}

void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb) {
  custom_event_to_callback_.try_emplace(event, event_cb);
}

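// A minimal caller-side sketch of the synchronous Send overload below (values are hypothetical):
//
//   VectorPtr response;
//   std::string payload = BuildPayload();  // hypothetical helper
//   if (node.Send(NodeRole::SERVER, /*rank_id=*/0, payload.data(), payload.length(),
//                 /*command=*/0, &response, /*timeout=*/30)) {
//     // response now holds the raw bytes returned by server 0.
//   }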
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const void *message, size_t len,
                        int command, VectorPtr *output, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(message);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
    return false;
  }

  uint64_t request_id = AddMessageTrack(1);
  if (output != nullptr) {
    set_message_callback(request_id, [this, request_id, rank_id, output]() {
      receive_messages_mutex_.lock();
      auto res = receive_messages_[request_id];
      *output = res[rank_id];
      receive_messages_.erase(request_id);
      receive_messages_mutex_.unlock();
    });
  }

  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::SEND_DATA);
  message_meta->set_request_id(request_id);
  message_meta->set_rank_id(node_info_.rank_id_);
  message_meta->set_role(node_info_.node_role_);
  message_meta->set_user_cmd(command);

  auto client = GetOrCreateTcpClient(rank_id, node_role);
  MS_EXCEPTION_IF_NULL(client);
  if (!client->SendMessage(message_meta, Protos::RAW, message, len)) {
    MS_LOG(WARNING) << "Client send message failed.";
  }
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return Wait(request_id, timeout);
}

bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const std::string &msg, int command,
                        VectorPtr *output, const uint32_t &timeout) {
  return Send(node_role, rank_id, msg.data(), msg.length(), command, output, timeout);
}

bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
                        const std::vector<std::string> &msgs, int command, std::vector<VectorPtr> *output,
                        const uint32_t &timeout) {
  // Validate the inputs before tracking the request, so a size mismatch does not leave a dangling message track.
  if (rank_ids.size() != msgs.size()) {
    MS_LOG(EXCEPTION) << "The number of rank ids and messages are not equal!";
  }
  uint64_t request_id = AddMessageTrack(msgs.size());

  if (output != nullptr) {
    set_message_callback(request_id, [this, request_id, &rank_ids, output]() {
      receive_messages_mutex_.lock();
      auto &res = receive_messages_[request_id];
      for (auto &rank_id : rank_ids) {
        auto &response = res[rank_id];
        output->push_back(response);
      }
      receive_messages_.erase(request_id);
      receive_messages_mutex_.unlock();
    });
  }
  size_t size = rank_ids.size();
  for (size_t i = 0; i < size; ++i) {
    if (!CommUtil::ValidateRankId(node_role, rank_ids.at(i), worker_num_, server_num_)) {
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                        << ", the server num:" << server_num_ << ", the rank id:" << rank_ids.at(i);
    }

    auto message_meta = std::make_shared<MessageMeta>();
    MS_EXCEPTION_IF_NULL(message_meta);
    message_meta->set_cmd(NodeCommand::SEND_DATA);
    message_meta->set_request_id(request_id);
    message_meta->set_rank_id(node_info_.rank_id_);
    message_meta->set_role(node_info_.node_role_);
    message_meta->set_user_cmd(command);

    auto &msg = msgs.at(i);

    auto client = GetOrCreateTcpClient(rank_ids.at(i), node_role);
    MS_EXCEPTION_IF_NULL(client);
    if (!client->SendMessage(message_meta, Protos::RAW, msg.data(), msg.size())) {
      MS_LOG(WARNING) << "Client send message failed.";
    }
  }
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return Wait(request_id, timeout);
}

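// SendToScheduler tracks exactly one expected response. The reply payload is stored by
// ProcessReceiveSchedulerResp under the same request id and handed back through *output once Wait() succeeds.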
bool AbstractNode::SendToScheduler(const void *message, size_t len, NodeCommand node_cmd, VectorPtr *output,
                                   const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(message);

  uint32_t expected_response_num = 1;
  uint64_t request_id = AddMessageTrack(expected_response_num);
  auto message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(node_cmd);
  message_meta->set_request_id(request_id);

  MS_EXCEPTION_IF_NULL(client_to_scheduler_);
  if (!client_to_scheduler_->SendMessage(message_meta, Protos::RAW, message, len)) {
    MS_LOG(WARNING) << "Failed to send message " << node_cmd << " to scheduler.";
  }

  bool ret = Wait(request_id, timeout);
  if (!ret) {
    MS_LOG(ERROR) << "Sending message " << node_cmd << " to scheduler timed out.";
    return ret;
  }

  // Assign the response value from the scheduler.
  if (output != nullptr) {
    if (received_scheduler_messages_.count(request_id) == 0) {
      MS_LOG(ERROR) << "The response message of command " << node_cmd << ", request_id " << request_id
                    << " is not received yet.";
      return false;
    }
    *output = received_scheduler_messages_[request_id];
    (void)received_scheduler_messages_.erase(request_id);
  }
  return ret;
}

uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
                                           size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
    return 0;
  }

  std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
  message_meta->set_rank_id(node_info_.rank_id_);
  message_meta->set_role(node_info_.node_role_);

  auto client = GetOrCreateTcpClient(rank_id, node_role);
  MS_EXCEPTION_IF_NULL(client);
  return SendCollectiveMeta(client, message_meta, Protos::RAW, data, size);
}

static std::string CollectiveMetaToString(const CollectiveMessageMeta &meta) {
  std::ostringstream os;
  os << "{iteration:" << meta.iteration() << ", data:" << meta.weight_name() << ", send rank:" << meta.send_rank_id()
     << ", recv rank:" << meta.recv_rank_id() << ", phase:" << meta.phase() << ", chunk index:" << meta.chunk_index()
     << ", for index:" << meta.for_index() << "}";
  return os.str();
}

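// The FL variant of the collective send attaches the full CollectiveMessageMeta (iteration, weight name, phase,
// chunk/for indices), so the receiver can match a chunk to the exact step of the algorithm rather than only to a
// sender rank.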
uint64_t AbstractNode::FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data,
                                             size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  auto recv_rank_id = collective_meta.recv_rank_id();
  if (!CommUtil::ValidateRankId(SERVER, recv_rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << recv_rank_id;
    return 0;
  }
  std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(message_meta);
  message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
  message_meta->set_rank_id(node_info_.rank_id_);
  message_meta->set_role(node_info_.node_role_);
  *(message_meta->mutable_collective_meta()) = collective_meta;
  message_meta->mutable_collective_meta()->set_enable_flag(true);
  message_meta->mutable_collective_meta()->set_send_rank_id(node_info_.rank_id_);

  MS_LOG(DEBUG) << "Send data to rank id:" << recv_rank_id
                << ", send meta:" << CollectiveMetaToString(message_meta->collective_meta());
  auto client = GetOrCreateTcpClient(recv_rank_id, SERVER);
  MS_EXCEPTION_IF_NULL(client);
  return SendCollectiveMeta(client, message_meta, Protos::RAW, data, size);
}

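// Waits up to `timeout` seconds, polled in one-second slices, for data from expect_meta.send_rank_id(). Data from
// another iteration is skipped, metadata that mismatches within the same iteration aborts the wait, and a failure
// reported for the current iteration (HasIterationFailed) aborts it as well.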
bool AbstractNode::FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output,
                                         const uint32_t &timeout) {
  if (output == nullptr) {
    return false;
  }
  auto send_rank_id = expect_meta.send_rank_id();
  if (!CommUtil::ValidateRankId(SERVER, send_rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << send_rank_id;
    return false;
  }
  auto check_meta = [](const CollectiveMessageMeta &left, const CollectiveMessageMeta &right) {
    return left.iteration() == right.iteration() && left.weight_name() == right.weight_name() &&
           left.recv_rank_id() == right.recv_rank_id() && left.send_rank_id() == right.send_rank_id() &&
           left.phase() == right.phase() && left.chunk_index() == right.chunk_index() &&
           left.for_index() == right.for_index();
  };
  auto iteration_num = expect_meta.iteration();
  std::unique_lock<std::mutex> lock(fl_receive_mutex_);
  auto &recv_data_list = fl_received_data_[send_rank_id];
  for (uint32_t i = 0; i < timeout; i++) {
    if (recv_data_list.empty()) {
      fl_receive_cond_.wait_for(lock, std::chrono::seconds(1), [&recv_data_list]() { return !recv_data_list.empty(); });
      if (recv_data_list.empty()) {  // timeout
        // If another server has reported that this iteration failed, stop waiting.
        if (HasIterationFailed(iteration_num)) {
          MS_LOG(WARNING) << "Detected that iteration " << iteration_num << " has failed";
          return false;
        }
        continue;
      }
    }
    while (!recv_data_list.empty()) {
      auto first = recv_data_list.begin();
      auto recv_meta = std::move(first->first);
      auto recv_data = std::move(first->second);
      recv_data_list.erase(first);
      MS_LOG(DEBUG) << "Handle receive data from rank id:" << send_rank_id
                    << ", recv meta:" << CollectiveMetaToString(recv_meta);
      if (recv_meta.iteration() != expect_meta.iteration()) {
        MS_LOG(WARNING) << "Skip recv data, iteration of recv meta " << recv_meta.iteration()
                        << " != iteration of expected meta " << expect_meta.iteration();
        continue;
      }
      // Mismatched data within the same iteration is an error.
      if (!check_meta(recv_meta, expect_meta)) {
        MS_LOG(WARNING) << "Recv meta does not match expected meta, recv meta: " << CollectiveMetaToString(recv_meta)
                        << ", expected meta: " << CollectiveMetaToString(expect_meta);
        return false;
      }
      *output = recv_data;
      return true;  // data received successfully
    }
  }
  return false;
}

bool AbstractNode::FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output,
                                    const uint32_t &timeout) {
  if (output == nullptr) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, the parameter output is invalid";
    return false;
  }
  auto data_recved = FlCollectiveWaitInner(expect_meta, output, timeout);
  if (!data_recved) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, expect meta: " << CollectiveMetaToString(expect_meta);
    return false;
  }
  if (*output == nullptr) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, the recv buffer is invalid";
    return false;
  }
  if (expect_size != (*output)->size()) {
    MS_LOG(ERROR) << "Expected data size " << expect_size << " != recv data size " << (*output)->size()
                  << ", expect meta: " << CollectiveMetaToString(expect_meta);
    return false;
  }
  return true;
}

void AbstractNode::OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data) {
  std::unique_lock<std::mutex> lock(fl_receive_mutex_);
  auto &recv_meta = message_meta.collective_meta();
  auto send_rank_id = recv_meta.send_rank_id();
  MS_LOG(DEBUG) << "Receive data from rank id:" << send_rank_id << ", recv meta:" << CollectiveMetaToString(recv_meta);
  fl_received_data_[send_rank_id].emplace_back(std::make_pair(recv_meta, data));
  fl_receive_cond_.notify_all();
}

bool AbstractNode::HasIterationFailed(uint32_t iteration_num) const {
  return iteration_num == failed_iteration_num_ && iteration_failed_;
}

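// CollectiveReceiveAsync pairs each receive with a per-rank request id from NextExpectedRankRequestId. If the data
// for (rank_id, rank_request_id) has already arrived, it is consumed immediately; otherwise a callback is parked in
// receive_callbacks_ and fired by RunReceiveCallback once the data shows up. CollectiveWait blocks on the result.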
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
                                                                   VectorPtr *output) {
  MS_EXCEPTION_IF_NULL(output);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
    return std::make_pair(0, 0);
  }

  receive_callbacks_mutex_.lock();
  uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
  auto pair_data = std::make_pair(rank_id, rank_request_id);
  receive_messages_done_[pair_data] = false;
  if (received_data_.count(pair_data) > 0) {
    auto res = received_data_[pair_data];
    MS_EXCEPTION_IF_NULL(res);
    *output = res;
    (void)received_data_.erase(pair_data);
    receive_messages_done_[pair_data] = true;
    MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
  } else {
    receive_callbacks_[pair_data] = [=]() mutable {
      auto res_output = received_data_[std::make_pair(rank_id, rank_request_id)];
      MS_EXCEPTION_IF_NULL(res_output);
      if (*output != nullptr) {
        MS_LOG(WARNING) << "The output is not empty.";
      }
      *output = res_output;
      received_data_.erase(std::make_pair(rank_id, rank_request_id));
      receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
      MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
    };
  }
  receive_callbacks_mutex_.unlock();
  return std::make_pair(rank_id, rank_request_id);
}

bool AbstractNode::CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout) {
  std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
  bool res =
    receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
  if (receive_messages_done_.count(request_id) != 0) {
    (void)receive_messages_done_.erase(request_id);
  }
  return res;
}

PersistentState AbstractNode::persistent_state() const { return persistent_state_; }

void AbstractNode::set_persistent_state(PersistentState persistent_state) { persistent_state_ = persistent_state; }

uint32_t AbstractNode::worker_num() const { return worker_num_; }

uint32_t AbstractNode::server_num() const { return server_num_; }

void AbstractNode::set_worker_num(const uint32_t &worker_num) { worker_num_ = worker_num; }

void AbstractNode::set_server_num(const uint32_t &server_num) { server_num_ = server_num; }

std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }

void AbstractNode::set_scheduler_ip(const std::string &scheduler_ip) { scheduler_ip_ = scheduler_ip; }

uint16_t AbstractNode::scheduler_port() const { return scheduler_port_; }

void AbstractNode::set_scheduler_port(const uint16_t &scheduler_port) { scheduler_port_ = scheduler_port; }

ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }

void AbstractNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }

void AbstractNode::Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                            const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(server_);
  meta->set_role(node_info_.node_role_);
  meta->set_rank_id(node_info_.rank_id_);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
}

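// The heartbeat thread ticks every heartbeat_interval milliseconds: it sends a HEARTBEAT to the scheduler,
// refreshes the last-alive timestamp on success, and on a broken connection accumulates reconnect_interval until it
// exceeds connect_interval, at which point ConnectToScheduler() is retried.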
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
               << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
               << " begin send heartbeat to the scheduler!";
  // Capture the client shared_ptr by value: the thread outlives this function's stack frame.
  heart_beat_thread_ = std::make_unique<std::thread>([this, client]() {
    uint32_t connect_interval = PSContext::instance()->cluster_config().connect_interval;
    uint32_t heartbeat_interval = PSContext::instance()->cluster_config().heartbeat_interval * 1000;
    uint32_t reconnect_interval = 0;
    if (heartbeat_interval > connect_interval) {
      MS_LOG(WARNING) << "heartbeat_interval [" << heartbeat_interval << "] is larger than connect_interval ["
                      << connect_interval << "], so reconnection checks effectively run at the heartbeat interval.";
    }
    while (!is_finish_.load()) {
      if (!Heartbeat(client)) {
        MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                        << ", the node id is:" << node_info_.node_id_ << " send heartbeat failed!";
        if (CheckSchedulerTimeout()) {
          MS_LOG(WARNING) << "The scheduler heartbeat has timed out, please recover the scheduler.";
        }
      } else {
        UpdateSchedulerTime();
      }

      if (!is_already_finished_ && (client->connection_status() == -1)) {
        if (reconnect_interval > connect_interval) {
          MS_LOG(WARNING) << "The connection to the scheduler is disconnected, try to reconnect.";
          reconnect_interval = 0;
          ConnectToScheduler();
        } else {
          reconnect_interval += heartbeat_interval;
        }
      }

      std::this_thread::sleep_for(std::chrono::milliseconds(heartbeat_interval));
    }
  });
  MS_EXCEPTION_IF_NULL(heart_beat_thread_);
}

bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  if (client->connection_status() != 1) {
    return false;
  }
  auto meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(meta);
  meta->set_cmd(NodeCommand::HEARTBEAT);

  HeartbeatMessage heartbeat_message;
  heartbeat_message.set_node_id(node_info_.node_id_);
  heartbeat_message.set_persistent_state(PersistentState::NOT_ENABLE_PERSIST);

  // The worker role does not support disaster recovery currently.
  if (EnableRecovery() && role() == NodeRole::SERVER) {
    heartbeat_message.set_persistent_state(persistent_state_);
  }

  if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
                       heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
    MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " send heartbeat timeout!";
    return false;
  }
  return true;
}

void AbstractNode::UpdateSchedulerTime() {
  struct timeval current_time {};
  (void)gettimeofday(&current_time, nullptr);
  scheduler_time_ = current_time;
  MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec;
}

bool AbstractNode::CheckSchedulerTimeout() const {
  struct timeval current_time {};
  (void)gettimeofday(&current_time, nullptr);
  int64_t old_time = scheduler_time_.tv_sec + PSContext::instance()->cluster_config().scheduler_timeout;
  if (old_time < current_time.tv_sec) {
    return true;
  }
  return false;
}

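// The heartbeat response carries the scheduler's view of the cluster: it refreshes all_nodes_info_, records the ids
// of nodes that are no longer alive, and, when the cluster is in NODE_TIMEOUT without recovery support, unblocks all
// waiters and raises the NODE_TIMEOUT event.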
void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  HeartbeatRespMessage heartbeat_resp_message;
  CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));

  if (heartbeat_resp_message.cluster_state() != current_cluster_state_ &&
      current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN &&
      current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) {
    UpdateClusterState(heartbeat_resp_message.cluster_state());
  }
  MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
                << CommUtil::ClusterStateToString(current_cluster_state_);

  std::string timeout_node_id;

  all_nodes_info_.clear();
  for (const auto &it : heartbeat_resp_message.servers_meta()) {
    NodeInfo info;
    info.ip_ = it.ip();
    info.node_id_ = it.node_id();
    info.port_ = it.port();
    info.node_role_ = it.role();
    info.rank_id_ = it.rank_id();
    info.is_alive = it.is_alive();

    if (!info.is_alive) {
      timeout_node_id += (info.node_id_ + " ");
    }

    all_nodes_info_[info.node_id_] = info;
    MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
                  << ", the node role:" << CommUtil::NodeRoleToString(info.node_role_)
                  << " is alive:" << info.is_alive;
  }
  bool is_worker = heartbeat_resp_message.is_worker();
  bool is_ps_mode = PSContext::instance()->server_mode() == ps::kServerModePS;
  bool not_enable_recover_node_timeout = (is_worker && is_ps_mode);

  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
    if (node_recovery_ == nullptr || not_enable_recover_node_timeout) {
      MS_LOG(INFO) << "The recovery is disabled. Trigger NODE_TIMEOUT event.";
      // Avoid other methods blocking endlessly when the NODE_TIMEOUT event is triggered.
      is_ready_ = true;
      wait_start_cond_.notify_all();
      is_finish_ = true;
      wait_finish_cond_.notify_all();
      OnEventCallback(ClusterEvent::NODE_TIMEOUT);
    } else {
      MS_LOG(INFO) << "The nodes:" << timeout_node_id
                   << " support recovery, users can restart these nodes to restore the cluster.";
    }
  }

  if (!EnableRecovery()) {
    return;
  }

  PersistentCommand persistent_cmd = heartbeat_resp_message.persistent_cmd();
  // The worker role does not support disaster recovery for the time being.
  if (role() == NodeRole::SERVER && persistent_cmd == PersistentCommand::BEGIN_PERSIST &&
      persistent_state_ != PersistentState::PERSISTING) {
    OnEventCallback(ClusterEvent::ON_BEGIN_PERSIST);
  }
}

void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(meta);
  meta->set_cmd(NodeCommand::FETCH_METADATA);

  FetchServersMessage fetch_servers;
  fetch_servers.set_node_id(node_info_.node_id_);
  if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
                       fetch_servers.ByteSizeLong())) {
    MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
  }
}

void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  FetchServersRespMessage fetch_servers_resp_message;
  CHECK_RETURN_TYPE(fetch_servers_resp_message.ParseFromArray(data, SizeToInt(size)));

  nodes_address_.clear();
  for (const auto &it : fetch_servers_resp_message.servers_meta()) {
    nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
    MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
  }
}

void AbstractNode::ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data,
                                               size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  std::lock_guard<std::mutex> lock(receive_messages_mutex_);

  const uint64_t request_id = meta->request_id();
  VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
  if (size > 0) {
    size_t dest_size = size;
    size_t src_size = size;
    auto ret = memcpy_s(received_data->data(), dest_size, data, src_size);
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
    }
  }
  received_scheduler_messages_[request_id] = received_data;
}

void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
                                       const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                       size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (is_current_node_scale_in_) {
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    node_info_.rank_id_ = UINT32_MAX;
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
    return;
  }
  SendMetadataMessage send_meta_message;
  send_meta_message.ParseFromArray(data, SizeToInt(size));
  worker_num_ = send_meta_message.worker_num();
  server_num_ = send_meta_message.server_num();
  if (send_meta_message.rank_id() < 0) {
    MS_LOG(EXCEPTION) << "The rank id is wrong.";
  }
  node_info_.rank_id_ = send_meta_message.rank_id();
  UpdateClusterState(send_meta_message.cluster_state());
  MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
               << ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
               << ", the rank id:" << node_info_.rank_id_;

  client_mutex_.lock();
  nodes_address_.clear();
  for (const auto &it : send_meta_message.servers_meta()) {
    nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
    MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(it.role()) << ", node id:" << it.node_id()
                 << ", rank id:" << it.rank_id() << ", ip:" << it.ip() << ", port:" << it.port();
  }
  client_mutex_.unlock();
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_ready_ = true;
  wait_start_cond_.notify_all();

  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT) {
    MS_LOG(WARNING) << "Trigger cluster scale out done event.";
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
  }

  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
  }

  if (cancelSafeModeFn_ && current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT_ROLLBACK) {
    MS_LOG(WARNING) << "Trigger cluster scale out rollback done event.";
    OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_ROLLBACK_DONE);
    cancelSafeModeFn_();
  }

  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();

  OnEventCallback(ClusterEvent::ON_SEND_META_DATA);
}

void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                 const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_finish_ = true;
  wait_finish_cond_.notify_all();
}

void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn,
                                       const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                       size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_LOG(INFO) << "This node received a scale out done message from the scheduler.";
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_ready_ = true;
  UpdateClusterState(ClusterState::CLUSTER_READY);
}

void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
                                      const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                      size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  is_ready_ = true;
  UpdateClusterState(ClusterState::CLUSTER_READY);
}

void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  EventRespMessage event_resp_message;
  event_resp_message.ParseFromArray(data, SizeToInt(size));
  uint32_t event = event_resp_message.event();
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  MS_LOG(INFO) << "This node received an event:" << event;
  if (event == static_cast<uint32_t>(ps::UserDefineEvent::kNodeTimeout)) {
    OnEventCallback(ClusterEvent::NODE_TIMEOUT);
  } else {
    OnCustomEventCallback(event);
  }
}

void AbstractNode::ProcessScaleOutRollback(const std::shared_ptr<TcpConnection> &conn,
                                           const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                           size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }

  UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT_ROLLBACK);

  MS_LOG(INFO) << "[Scale out rollback]: begin to set scale out rollback.";
  Register(client_to_scheduler_);
  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();

  MS_LOG(INFO) << "The node begins the scale out rollback.";
}

void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                   const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  ScaleOutMessage scale_out_message;
  scale_out_message.ParseFromArray(data, SizeToInt(size));
  int32_t worker_num = scale_out_message.worker_num();
  int32_t server_num = scale_out_message.server_num();
  MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
  UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
  is_ready_ = false;
}

void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                  const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  ScaleInMessage scale_in_message;
  scale_in_message.ParseFromArray(data, SizeToInt(size));
  int32_t worker_num = scale_in_message.worker_num();
  int32_t server_num = scale_in_message.server_num();
  MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;

  is_current_node_scale_in_ = scale_in_message.is_node_scale_in();
  if (is_current_node_scale_in_) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " is a scale in node!";
  } else {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " is not a scale in node!";
  }

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
  UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
  is_ready_ = false;
}

void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn,
                                            const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                            size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  SendMetadataMessage scheduler_recovery_message;
  (void)scheduler_recovery_message.ParseFromArray(data, SizeToInt(size));
  worker_num_ = scheduler_recovery_message.worker_num();
  server_num_ = scheduler_recovery_message.server_num();
  uint32_t rank_id = scheduler_recovery_message.rank_id();

  MS_LOG(INFO) << "[Scheduler Recovery]: The scheduler recovery worker num:" << worker_num_
               << ", the server num:" << server_num_ << ", the rank id: " << rank_id;

  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "[Scheduler Recovery]: Server response message failed.";
  }
  MS_LOG(INFO) << "[Scheduler Recovery]: Server response message success!";

  ConnectToScheduler();
  bool connected = client_to_scheduler_->WaitConnected();
  if (!connected) {
    MS_LOG(WARNING) << "[Scheduler Recovery]: Server node connect to scheduler timed out!";
  }

  Register(client_to_scheduler_);
  std::lock_guard<std::mutex> lock(client_mutex_);
  connected_nodes_.clear();
  MS_LOG(INFO) << "[Scheduler Recovery]: This node connected to the scheduler successfully!";

  if (cancelSafeModeFn_ && (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN ||
                            current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT)) {
    MS_LOG(INFO) << "[Scheduler Recovery]: Cancel safe mode for " << kClusterState.at(current_cluster_state_);
    cancelSafeModeFn_();
  }

  UpdateClusterState(ClusterState::CLUSTER_SCHEDULER_RECOVERY);
  is_ready_ = false;
}

bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  auto meta = std::make_shared<MessageMeta>();
  MS_EXCEPTION_IF_NULL(meta);
  meta->set_cmd(NodeCommand::FINISH);

  std::string finish_message = node_info_.node_id_;

  if (!SendMessageSync(client, meta, Protos::RAW, finish_message.data(), finish_message.length())) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
  }
  return WaitForDisconnect(timeout);
}

bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
  // If the cluster state is NODE_TIMEOUT, this node is already disconnected.
  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
    return true;
  }
  std::unique_lock<std::mutex> lock(wait_finish_mutex_);
  auto condition_func = [this] {
    if (is_finish_.load()) {
      MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " has finished successfully!";
    }
    return is_finish_.load();
  };

  bool res;
  if (timeout == UINT32_MAX) {
    // The caller should use this branch to block the thread indefinitely.
    wait_finish_cond_.wait(lock, condition_func);
    res = true;
  } else {
    res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), condition_func);
  }

  return res;
}

void AbstractNode::InitClientToServer() {
  // Create a TCP client to this node itself, so that event dispatch still works even if sending to server 0 fails.
  client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, node_info_.node_role_);
  MS_EXCEPTION_IF_NULL(client_to_server_);
  client_to_server_->Init();
  MS_LOG(INFO) << "The node start a tcp client to this node!";
}

bool AbstractNode::InitClientToScheduler() {
  if (config_ == nullptr) {
    MS_LOG(WARNING) << "The config is empty.";
    return false;
  }
  client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, NodeRole::SCHEDULER);
  MS_EXCEPTION_IF_NULL(client_to_scheduler_);
  client_to_scheduler_->SetMessageCallback(
    [&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
      try {
        MS_EXCEPTION_IF_NULL(meta);
        MS_EXCEPTION_IF_NULL(data);
        if (handlers_.count(meta->cmd()) == 0) {
          MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
        }
        if (handlers_[meta->cmd()] != nullptr) {
          const auto &handler_ptr = handlers_[meta->cmd()];
          (this->*handler_ptr)(meta, data, size);
        }
        NotifyMessageArrival(meta);
      } catch (const std::exception &e) {
        MsException::Instance().SetException();
      }
    });
  ConnectToScheduler();
  StartHeartbeatTimer(client_to_scheduler_);
  MS_LOG(INFO) << "Start heartbeat timer!";

  bool wait_res = client_to_scheduler_->WaitConnected();
  if (!wait_res) {
    is_ready_ = true;
  }
  return wait_res;
}

void AbstractNode::ConnectToScheduler() {
  client_to_scheduler_->Init();
  if (TcpClient::is_started()) {
    return;
  }

  if (client_to_scheduler_thread_ != nullptr && client_to_scheduler_thread_->joinable()) {
    client_to_scheduler_thread_->join();
  }
  client_to_scheduler_thread_ = std::make_unique<std::thread>([this]() {
    MS_LOG(INFO) << "The node start a tcp client!";
    client_to_scheduler_->Start();
  });
}

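// TCP clients to peer nodes are cached in connected_nodes_, keyed by (role, rank id). A cache miss creates a client
// from the address table fetched from the scheduler and installs a callback that dispatches SEND_DATA,
// COLLECTIVE_SEND_DATA and SEND_EVENT responses before notifying the message tracker.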
const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id, const NodeRole &role) {
  std::lock_guard<std::mutex> lock(client_mutex_);
  auto key = std::make_pair(role, rank_id);
  if (connected_nodes_.find(key) != connected_nodes_.end()) {
    return connected_nodes_[key];
  } else {
    if (nodes_address_.find(key) == nodes_address_.end()) {
      MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed. Role: " << role << ", rank: " << rank_id;
    }
    if (config_ == nullptr) {
      MS_LOG(EXCEPTION) << "The config is empty.";
    }

    MS_LOG(INFO) << "Create tcp client for role: " << role << ", rank: " << rank_id;
    std::string ip = nodes_address_[key].first;
    uint16_t port = nodes_address_[key].second;
    auto client = std::make_shared<TcpClient>(ip, port, role);
    MS_EXCEPTION_IF_NULL(client);
    client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
                                   size_t size) {
      switch (meta->cmd()) {
        case NodeCommand::SEND_DATA:
          ProcessSendDataResp(meta, protos, data, size);
          RunMessageCallback(meta->request_id());
          break;
        case NodeCommand::COLLECTIVE_SEND_DATA:
          MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_
                        << " receive a collective_send_data message response!";
          break;
        case NodeCommand::SEND_EVENT:
          MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " receive a send_event command message response!";
          break;
        default:
          MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
      }
      NotifyMessageArrival(meta);
    });
    client->Init();
    connected_nodes_[key] = client;
    return connected_nodes_[key];
  }
}

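// Both synchronous send paths follow the same pattern: allocate a request id via AddMessageTrack(1), stamp it into
// the message meta, send, and block in Wait() until NotifyMessageArrival has counted the response or the timeout
// expires. SendCollectiveMeta differs only in returning the request id instead of waiting on it.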
SendMessageSync(const std::shared_ptr<TcpClient> & client,const CommMessage & message,const uint32_t & timeout)1172 bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
1173 const uint32_t &timeout) {
1174 MS_EXCEPTION_IF_NULL(client);
1175 uint64_t request_id = AddMessageTrack(1);
1176 const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
1177 client->SendMessage(message);
1178 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
1179 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
1180 return Wait(request_id, timeout);
1181 }
1182
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)1183 bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
1184 const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
1185 MS_EXCEPTION_IF_NULL(client);
1186 MS_EXCEPTION_IF_NULL(meta);
1187 MS_EXCEPTION_IF_NULL(data);
1188 uint64_t request_id = AddMessageTrack(1);
1189 meta->set_request_id(request_id);
1190 client->SendMessage(meta, protos, data, size);
1191 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
1192 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
1193 return Wait(request_id, timeout);
1194 }

uint64_t AbstractNode::SendCollectiveMeta(const std::shared_ptr<TcpClient> &client,
                                          const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
                                          const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(client);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  uint64_t request_id = AddMessageTrack(1);
  meta->set_request_id(request_id);
  client->SendMessage(meta, protos, data, size);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << ", the sent request id is:" << request_id;
  return request_id;
}
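
// Pattern note (an assumed caller sketch, not code from this file): unlike
// SendMessageSync, SendCollectiveMeta returns the tracked request id immediately,
// so several transfers can be in flight before the caller blocks on Wait().
//
//   uint64_t request_id = SendCollectiveMeta(client, meta, Protos::RAW, data, size);
//   // ... issue further sends or overlap local work here ...
//   if (!Wait(request_id, kCommTimeoutInSeconds)) {  // hypothetical timeout constant
//     MS_LOG(ERROR) << "Collective send " << request_id << " timed out.";
//   }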

void AbstractNode::ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn,
                                             const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
                                             const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_EXCEPTION_IF_NULL(server_);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(WARNING) << "Server response message failed.";
  }
  RunReceiveCallback(meta, protos, data, size);
}

void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                   const Protos &, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << ", the received request id is:" << meta->request_id()
                << ", the current time is:"
                << std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now())
                     .time_since_epoch()
                     .count();
  request_handler_(conn, meta, data, size);
}

void AbstractNode::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
  MS_EXCEPTION_IF_NULL(meta);
  std::lock_guard<std::mutex> lock(message_tracker_mutex_);
  uint64_t request_id = meta->request_id();
  if (message_tracker_.count(request_id)) {
    message_tracker_[request_id].second++;
  } else {
    MS_LOG(WARNING) << "The request id:" << request_id << " has been removed.";
  }
  message_tracker_cond_.notify_all();
}
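
// How the tracker pairs with Wait() (an explanatory sketch; the member types are
// assumptions inferred from the usage above): message_tracker_ maps a request id to
// an (expected, received) pair of response counts. AddMessageTrack(n) registers n
// expected responses, each NotifyMessageArrival() bumps the received count, and
// Wait() blocks on message_tracker_cond_ until the counts match or the timeout
// expires, roughly:
//
//   std::unique_lock<std::mutex> lock(message_tracker_mutex_);
//   (void)message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
//     auto &track = message_tracker_[request_id];
//     return track.first == track.second;  // all expected responses arrived
//   });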

void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
                                      size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
  size_t dest_size = size;
  size_t src_size = size;
  int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
  }
  if (meta->collective_meta().enable_flag()) {
    OnRecvCollectiveData(*meta, received_data);
    return;
  }
  std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
  uint32_t rank_id = meta->rank_id();
  // When a collective message arrives, generate the actual rank request id and compare it with the expected
  // rank request id; if they match, run the registered callback.
  uint64_t rank_request_id = NextActualRankRequestId(rank_id);
  received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
  MS_LOG(DEBUG) << "Run the receive data callback, the rank id:" << rank_id
                << ", the rank request id is:" << rank_request_id
                << ", the send request id is:" << meta->request_id() << ", the size is:" << size;
  auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
  if (it != receive_callbacks_.end()) {
    if (receive_messages_done_.count(std::make_pair(rank_id, rank_request_id)) != 0 && it->second) {
      it->second();
    }
    receive_cond_.notify_all();
    receive_callbacks_.erase(it);
  }
}

uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) {
  std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
  uint64_t rank_request_id = 1;
  if (expected_rank_request_ids_.count(rank_id)) {
    rank_request_id = ++expected_rank_request_ids_[rank_id];
  } else {
    expected_rank_request_ids_[rank_id] = rank_request_id;
  }
  return rank_request_id;
}

uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
  std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
  uint64_t rank_request_id = 1;
  if (actual_rank_request_ids_.count(rank_id)) {
    rank_request_id = ++actual_rank_request_ids_[rank_id];
  } else {
    actual_rank_request_ids_[rank_id] = rank_request_id;
  }
  return rank_request_id;
}
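
// Why two counters exist (an explanatory note with a hypothetical registration
// sketch): the receiving side advances the "expected" counter when it registers
// interest in data from a rank, while the network side advances the "actual"
// counter in RunReceiveCallback() when data from that rank arrives. Both start at 1
// and grow monotonically per rank, so the key (rank_id, rank_request_id) matches
// each arrival to the registration that awaits it, even when several messages from
// one rank arrive back to back.
//
//   uint64_t expected_id = NextExpectedRankRequestId(rank_id);  // at registration time
//   receive_callbacks_[std::make_pair(rank_id, expected_id)] = callback;  // assumed usage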

void AbstractNode::InitCommandHandler() {
  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
  handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
  handlers_[NodeCommand::FINISH] = nullptr;
  handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
  handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
  handlers_[NodeCommand::SEND_EVENT] = nullptr;
  RegisterActorRouteTableRspHandler();
  RegisterInitCollectCommResphandler();
  RegisterRecoveryRespHandler();
}
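
// Dispatch note (a sketch of how handlers_ is assumed to be consumed by the
// client-side receive path of this class): a nullptr entry means the response
// carries no payload that needs processing, so the receive path only records the
// arrival, roughly:
//
//   auto it = handlers_.find(meta->cmd());
//   if (it != handlers_.end() && it->second != nullptr) {
//     (this->*(it->second))(meta, data, size);  // assumed handler signature
//   }
//   NotifyMessageArrival(meta);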

void AbstractNode::RegisterActorRouteTableRspHandler() {
  handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
  handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
  handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &AbstractNode::ProcessReceiveSchedulerResp;
}

void AbstractNode::InitServerHandler() {
  server_handler_[NodeCommand::SEND_METADATA] = &AbstractNode::ProcessSendMetadata;
  server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
  server_handler_[NodeCommand::SEND_DATA] = &AbstractNode::ProcessSendData;
  server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = &AbstractNode::ProcessCollectiveSendData;
  server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
  server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
  server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
  server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
  server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
  server_handler_[NodeCommand::SCHEDULER_RECOVERY] = &AbstractNode::ProcessSchedulerRecovery;
  server_handler_[NodeCommand::PREPARE_BUILDING_NETWORK] = &AbstractNode::ProcessPrepareBuildingNetwork;
  server_handler_[NodeCommand::SCALE_OUT_ROLLBACK] = &AbstractNode::ProcessScaleOutRollback;
}

void AbstractNode::InitNodeInfo(const NodeRole &role) {
  MS_EXCEPTION_IF_NULL(config_);
  MS_EXCEPTION_IF_NULL(server_);
  if (PSContext::instance()->node_id().empty() && config_->Exists(kNodeId)) {
    node_info_.node_id_ = config_->Get(kNodeId, "");
  } else {
    node_info_.node_id_ = PSContext::instance()->node_id();
  }

  if (node_info_.node_id_.empty()) {
    node_info_.node_id_ = CommUtil::GenerateUUID();
  }
  node_info_.node_role_ = role;
  node_info_.ip_ = server_->BoundIp();
  node_info_.port_ = server_->BoundPort();

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << ", the node id is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
               << ", the port:" << server_->BoundPort();
}

void AbstractNode::InitNodeNum() {
  worker_num_ = PSContext::instance()->cluster_config().initial_worker_num;
  server_num_ = PSContext::instance()->cluster_config().initial_server_num;
  scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
  scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
  MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
               << ", the scheduler ip:" << scheduler_ip_ << ", the scheduler port:" << scheduler_port_;
}

bool AbstractNode::Recover() {
  MS_EXCEPTION_IF_NULL(config_);
  if (config_->Exists(kKeyRecovery)) {
    MS_LOG(INFO) << "The node supports recovery.";
    node_recovery_ = std::make_unique<NodeRecovery>(this);
    MS_EXCEPTION_IF_NULL(node_recovery_);
    if (node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
      MS_LOG(INFO) << "Initializing node recovery succeeded.";
      return node_recovery_->Recover();
    }
  }
  return false;
}
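
// Configuration sketch (the key names and values below are assumptions for
// illustration, not a verified schema): kKeyRecovery is expected to point at a
// JSON object in the node config file that tells NodeRecovery where to persist
// and restore cluster metadata, e.g.:
//
//   {
//     "recovery": {
//       "storage_type": 1,                      // hypothetical: 1 = local file storage
//       "storage_file_path": "/tmp/recovery/"   // hypothetical persistence directory
//     }
//   }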

void AbstractNode::OnEventCallback(const ClusterEvent &event) {
  if (!event_to_callback_.count(event)) {
    MS_LOG(INFO) << "[Event]:The event callback of " << event << " is not set.";
  } else {
    MS_LOG(INFO) << "[Event]:Trigger the event:" << event;
    if (event_to_callback_[event]) {
      event_to_callback_[event]();
    }
  }
}

void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
  if (!custom_event_to_callback_.count(event)) {
    MS_LOG(WARNING) << "[Custom event]:The event callback of " << event << " is not set.";
  } else {
    MS_LOG(INFO) << "[Custom event]:Trigger the event:" << event;
    if (custom_event_to_callback_[event]) {
      custom_event_to_callback_[event]();
    }
  }
}

bool AbstractNode::IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info) {
  // Returns true when any alive worker exists, or when the alive server of rank 0 exists.
  for (const auto &it : info) {
    if (it.second.is_alive && it.second.node_role_ == NodeRole::WORKER) {
      return true;
    }

    if (it.second.is_alive && it.second.rank_id_ == 0 && it.second.node_role_ == NodeRole::SERVER) {
      return true;
    }
  }
  return false;
}

void AbstractNode::CreateTcpServer(const std::pair<uint32_t, uint32_t> &port_range) {
  MS_EXCEPTION_IF_NULL(config_);
  std::string interface;
  std::string server_ip = common::GetEnv("MS_WORKER_IP");
  if (server_ip.empty()) {
    MS_LOG(INFO) << "The 'MS_WORKER_IP' env is not set, so the first available network interface is used.";
    CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
  }

  server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get(), port_range);
  MS_EXCEPTION_IF_NULL(server_);
  server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                  const Protos &protos, const void *data, size_t size) {
    MS_EXCEPTION_IF_NULL(conn);
    MS_EXCEPTION_IF_NULL(meta);
    MS_EXCEPTION_IF_NULL(data);
    MS_LOG(DEBUG) << "Receive message cmd " << meta->cmd() << ", size is " << size;
    const auto &handler_pair = server_handler_.find(meta->cmd());
    if (handler_pair == server_handler_.end()) {
      MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
    }
    (this->*(handler_pair->second))(conn, meta, protos, data, size);
  });

  server_->Init();
  server_thread_ = std::make_unique<std::thread>([this]() {
    MS_LOG(INFO) << "The worker node or server node starts a tcp server!";
    this->server_->Start();
  });
  MS_EXCEPTION_IF_NULL(server_thread_);
}
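
// Deployment note (illustrative, with a hypothetical address): the bound address can
// be pinned explicitly instead of auto-detected, e.g. `export MS_WORKER_IP=192.168.1.10`
// in the launching shell. Passing port 0 together with port_range lets TcpServer pick a
// free port inside the range; the chosen endpoint is read back later through BoundIp()
// and BoundPort() in InitNodeInfo().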

void AbstractNode::UpdateClusterState(const ClusterState &state) {
  std::lock_guard<std::mutex> lock(cluster_state_mutex_);
  std::string state_str = CommUtil::ClusterStateToString(state);
  if (state_str.empty()) {
    return;
  }

  if (state == current_cluster_state_) {
    return;
  }
  MS_LOG(INFO) << "[state]: Cluster state changes from:" << CommUtil::ClusterStateToString(current_cluster_state_)
               << " to " << state_str;
  current_cluster_state_ = state;
}

void AbstractNode::PersistMetaData() {
  if (node_recovery_ == nullptr) {
    MS_LOG(WARNING) << "The node recovery is null, so the meta data is not persisted.";
    return;
  }
  if (config_->Exists(kKeyRecovery)) {
    ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
    clusterConfig.scheduler_host = this->scheduler_ip();
    clusterConfig.scheduler_port = this->scheduler_port();
    clusterConfig.initial_worker_num = worker_num_;
    clusterConfig.initial_server_num = server_num_;

    node_recovery_->Persist(clusterConfig);
  }
}
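
// Recovery flow sketch (an assumed sequence, not code from this file): the metadata
// persisted here is what Recover() above reads back after a restart, so a typical
// lifecycle looks like:
//
//   PersistMetaData();  // after the cluster topology changes (e.g. scale out/in)
//   // ... the node process restarts ...
//   Recover();          // restores scheduler host/port and worker/server counts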

void AbstractNode::ProcessPrepareBuildingNetwork(const std::shared_ptr<TcpConnection> &conn,
                                                 const std::shared_ptr<MessageMeta> &meta, const Protos &,
                                                 const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
    MS_LOG(ERROR) << "The server response message failed, preparing for building network failed.";
  } else {
    MS_LOG(INFO) << "Preparing for building network succeeded.";
  }
}
}  // namespace core
}  // namespace ps
}  // namespace mindspore