1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/abstract_node.h"
18 #include "ps/core/node_recovery.h"
19 #include "ps/core/communicator/tcp_communicator.h"
20 #include "ps/core/communicator/http_communicator.h"
21
22 namespace mindspore {
23 namespace ps {
24 namespace core {
Register(const std::shared_ptr<TcpClient> & client)25 void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
26 MS_EXCEPTION_IF_NULL(client);
27 auto message_meta = std::make_shared<MessageMeta>();
28 MS_EXCEPTION_IF_NULL(message_meta);
29 message_meta->set_cmd(NodeCommand::REGISTER);
30 message_meta->set_rank_id(node_info_.rank_id_);
31
32 RegisterMessage register_message;
33 register_message.set_node_id(node_info_.node_id_);
34 register_message.set_role(node_info_.node_role_);
35 register_message.set_ip(node_info_.ip_);
36 register_message.set_port(node_info_.port_);
37
38 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
39 << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
40
41 if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
42 register_message.ByteSizeLong())) {
43 MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
44 << " the node id:" << node_info_.node_id_ << " register timeout!";
45 }
46 }
47
ProcessRegisterResp(const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)48 void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
49 MS_EXCEPTION_IF_NULL(meta);
50 MS_EXCEPTION_IF_NULL(data);
51 RegisterRespMessage register_resp_message;
52 CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
53 if (register_resp_message.node_id() != node_info_.node_id_) {
54 MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
55 << " is not match the current node id:" << node_info_.node_id_;
56 }
57
58 // Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
59 // scheduler is alive
60 UpdateSchedulerTime();
61
62 MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered scheduler success!";
63 }
64
Broadcast(const NodeRole & node_role,const DataPtr & message,size_t size,int command,const uint32_t & timeout)65 bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message, size_t size, int command,
66 const uint32_t &timeout) {
67 MS_EXCEPTION_IF_NULL(message);
68 if (node_role != NodeRole::SERVER) {
69 MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
70 }
71
72 uint64_t request_id = AddMessageTrack(nodes_address_.size());
73
74 for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
75 auto message_meta = std::make_shared<MessageMeta>();
76 MS_EXCEPTION_IF_NULL(message_meta);
77 message_meta->set_cmd(NodeCommand::SEND_DATA);
78 message_meta->set_request_id(request_id);
79 message_meta->set_rank_id(node_info_.rank_id_);
80 message_meta->set_role(node_info_.node_role_);
81 message_meta->set_user_cmd(command);
82
83 auto client = GetOrCreateTcpClient((*it).first.second);
84 if (!client->SendMessage(message_meta, Protos::RAW, message.get(), size)) {
85 MS_LOG(WARNING) << "Client send message failed.";
86 }
87 }
88 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
89 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
90 return Wait(request_id, timeout);
91 }
92
set_ready_for_scale_out()93 void AbstractNode::set_ready_for_scale_out() {
94 MS_LOG(INFO) << "[Scale out]: begin to set ready for scale out.";
95 Register(client_to_scheduler_);
96 std::lock_guard<std::mutex> lock(client_mutex_);
97 connected_nodes_.clear();
98 }
99
set_ready_for_scale_in()100 void AbstractNode::set_ready_for_scale_in() {
101 MS_LOG(INFO) << "[Scale in]: begin to set ready for scale in.";
102 if (!is_current_node_scale_in_) {
103 Register(client_to_scheduler_);
104 std::lock_guard<std::mutex> lock(client_mutex_);
105 connected_nodes_.clear();
106 }
107 }
108
set_scale_out_done()109 void AbstractNode::set_scale_out_done() {
110 MS_LOG(INFO) << "[Scale out]: begin to set scale out done.";
111 auto message_meta = std::make_shared<MessageMeta>();
112 MS_EXCEPTION_IF_NULL(message_meta);
113 message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);
114
115 ScaleOutDoneMessage scale_out_done_message;
116 scale_out_done_message.set_node_id(node_info_.node_id_);
117
118 if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
119 scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
120 MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
121 << " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
122 return;
123 }
124
125 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
126 << " the node id:" << node_info_.node_id_ << "is send scale_out_done to scheduler successful!";
127 }
128
set_scale_in_done()129 void AbstractNode::set_scale_in_done() {
130 MS_LOG(INFO) << "[Scale in]: begin to set scale in done.";
131 auto message_meta = std::make_shared<MessageMeta>();
132 MS_EXCEPTION_IF_NULL(message_meta);
133 message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);
134
135 ScaleInDoneMessage scale_in_done_message;
136 scale_in_done_message.set_node_id(node_info_.node_id_);
137
138 if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
139 scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
140 MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
141 << " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
142 return;
143 }
144
145 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
146 << " the node id:" << node_info_.node_id_ << "is send scale_in_done to scheduler successful!";
147 }
148
BroadcastEvent(const uint32_t & event)149 void AbstractNode::BroadcastEvent(const uint32_t &event) {
150 auto message_meta = std::make_shared<MessageMeta>();
151 MS_EXCEPTION_IF_NULL(message_meta);
152 message_meta->set_cmd(NodeCommand::SEND_EVENT);
153
154 EventMessage event_message;
155 event_message.set_event(event);
156 event_message.set_node_id(node_info_.node_id_);
157
158 if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(),
159 event_message.ByteSizeLong())) {
160 MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
161 << " the node id:" << node_info_.node_id_ << " send event timeout!";
162 return;
163 }
164
165 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
166 << " the node id:" << node_info_.node_id_ << "is send event to scheduler!";
167 }
168
// Registers a callback for a cluster-level event. The first registration for
// a given event wins: try_emplace does not overwrite an existing entry.
void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
  event_to_callback_.try_emplace(event, event_cb);
}
172
// Registers a callback for a user-defined (custom) event id. As with
// RegisterEventCallback, an existing registration is not overwritten.
void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb) {
  custom_event_to_callback_.try_emplace(event, event_cb);
}
176
Send(const NodeRole & node_role,const uint32_t & rank_id,const DataPtr & data,size_t len,int command,const uint32_t & timeout)177 bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
178 int command, const uint32_t &timeout) {
179 if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
180 MS_LOG(DEBUG) << "The node is timeout, can not send message.";
181 return false;
182 }
183 MS_EXCEPTION_IF_NULL(data);
184 if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
185 MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
186 << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
187 }
188
189 auto message_meta = std::make_shared<MessageMeta>();
190 MS_EXCEPTION_IF_NULL(message_meta);
191 message_meta->set_cmd(NodeCommand::SEND_DATA);
192 message_meta->set_rank_id(node_info_.rank_id_);
193 message_meta->set_role(node_info_.node_role_);
194 message_meta->set_user_cmd(command);
195
196 auto client = GetOrCreateTcpClient(rank_id);
197 return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout);
198 }
199
Send(const NodeRole & node_role,const std::vector<uint32_t> & rank_ids,const std::vector<DataPtr> & data,const std::vector<size_t> & lens,int command,const uint32_t & timeout)200 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
201 const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
202 const uint32_t &timeout) {
203 if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
204 MS_LOG(DEBUG) << "The node is timeout, can not send message.";
205 return false;
206 }
207
208 uint64_t request_id = AddMessageTrack(data.size());
209
210 if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) {
211 MS_LOG(EXCEPTION) << "The number of rank ids, data and lens are not equal!";
212 }
213 for (size_t it = 0; it < rank_ids.size(); ++it) {
214 if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
215 MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
216 << ", the server num:" << server_num_ << ", the rank id:" << rank_ids.at(it);
217 }
218
219 auto message_meta = std::make_shared<MessageMeta>();
220 MS_EXCEPTION_IF_NULL(message_meta);
221 message_meta->set_cmd(NodeCommand::SEND_DATA);
222 message_meta->set_request_id(request_id);
223 message_meta->set_rank_id(node_info_.rank_id_);
224 message_meta->set_role(node_info_.node_role_);
225 message_meta->set_user_cmd(command);
226
227 auto send = data.at(it);
228 auto len = lens.at(it);
229 auto client = GetOrCreateTcpClient(rank_ids.at(it));
230 MS_EXCEPTION_IF_NULL(client);
231 if (!client->SendMessage(message_meta, Protos::RAW, send.get(), len)) {
232 MS_LOG(WARNING) << "Client send message failed.";
233 }
234 }
235 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
236 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
237 return Wait(request_id, timeout);
238 }
239
Send(const NodeRole & node_role,const uint32_t & rank_id,const DataPtr & message,size_t len,int command,VectorPtr * output,const uint32_t & timeout)240 bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
241 int command, VectorPtr *output, const uint32_t &timeout) {
242 if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
243 MS_LOG(DEBUG) << "The node is timeout, can not send message.";
244 return false;
245 }
246 MS_EXCEPTION_IF_NULL(message);
247 MS_EXCEPTION_IF_NULL(output);
248 if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
249 MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
250 << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
251 }
252
253 uint64_t request_id = AddMessageTrack(1);
254 set_message_callback(request_id, [&]() {
255 receive_messages_mutex_.lock();
256 auto res = receive_messages_[request_id];
257 *output = res[rank_id];
258 receive_messages_.erase(request_id);
259 receive_messages_mutex_.unlock();
260 });
261
262 auto message_meta = std::make_shared<MessageMeta>();
263 MS_EXCEPTION_IF_NULL(message_meta);
264 message_meta->set_cmd(NodeCommand::SEND_DATA);
265 message_meta->set_request_id(request_id);
266 message_meta->set_rank_id(node_info_.rank_id_);
267 message_meta->set_role(node_info_.node_role_);
268 message_meta->set_user_cmd(command);
269
270 auto client = GetOrCreateTcpClient(rank_id);
271 MS_EXCEPTION_IF_NULL(client);
272 if (!client->SendMessage(message_meta, Protos::RAW, message.get(), len)) {
273 MS_LOG(WARNING) << "Client send message failed.";
274 }
275 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
276 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
277 return Wait(request_id, timeout);
278 }
279
Send(const NodeRole & node_role,const std::vector<uint32_t> & rank_ids,const std::vector<DataPtr> & data,const std::vector<size_t> & data_lens,int command,std::vector<VectorPtr> * output,const uint32_t & timeout)280 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
281 const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
282 std::vector<VectorPtr> *output, const uint32_t &timeout) {
283 if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
284 MS_LOG(DEBUG) << "The node is timeout, can not send message.";
285 return false;
286 }
287 MS_EXCEPTION_IF_NULL(output);
288 uint64_t request_id = AddMessageTrack(data.size());
289
290 if (rank_ids.size() != data.size()) {
291 MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
292 }
293
294 size_t size = rank_ids.size();
295
296 set_message_callback(request_id, [&]() {
297 receive_messages_mutex_.lock();
298 auto res = receive_messages_[request_id];
299 for (size_t it = 0; it < size; ++it) {
300 (*output).push_back(res[rank_ids.at(it)]);
301 }
302 receive_messages_.erase(request_id);
303 receive_messages_mutex_.unlock();
304 });
305
306 for (size_t it = 0; it < size; ++it) {
307 if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
308 MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
309 << ", the server num:" << server_num_ << ", the rank id:" << rank_ids.at(it);
310 }
311
312 auto message_meta = std::make_shared<MessageMeta>();
313 MS_EXCEPTION_IF_NULL(message_meta);
314 message_meta->set_cmd(NodeCommand::SEND_DATA);
315 message_meta->set_request_id(request_id);
316 message_meta->set_rank_id(node_info_.rank_id_);
317 message_meta->set_role(node_info_.node_role_);
318 message_meta->set_user_cmd(command);
319
320 auto send = data.at(it);
321 auto len = data_lens.at(it);
322
323 auto client = GetOrCreateTcpClient(rank_ids.at(it));
324 MS_EXCEPTION_IF_NULL(client);
325 if (!client->SendMessage(message_meta, Protos::RAW, send.get(), len)) {
326 MS_LOG(WARNING) << "Client send message failed.";
327 }
328 }
329 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
330 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
331 return Wait(request_id, timeout);
332 }
333
CollectiveSendAsync(const NodeRole & node_role,const uint32_t & rank_id,const void * data,size_t size)334 uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
335 size_t size) {
336 MS_EXCEPTION_IF_NULL(data);
337 if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
338 MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
339 << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
340 }
341
342 std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
343 MS_EXCEPTION_IF_NULL(message_meta);
344 message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
345 message_meta->set_rank_id(node_info_.rank_id_);
346 message_meta->set_role(node_info_.node_role_);
347
348 auto client = GetOrCreateTcpClient(rank_id);
349 MS_EXCEPTION_IF_NULL(client);
350 return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
351 }
352
CollectiveReceiveAsync(const NodeRole & node_role,const uint32_t & rank_id,VectorPtr * output)353 std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
354 VectorPtr *output) {
355 MS_EXCEPTION_IF_NULL(output);
356 if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
357 MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
358 << ", the server num:" << server_num_ << ", the rank id:" << rank_id;
359 }
360
361 receive_callbacks_mutex_.lock();
362 uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
363 auto pair_data = std::make_pair(rank_id, rank_request_id);
364 receive_messages_done_[pair_data] = false;
365 if (received_data_.count(pair_data) > 0) {
366 auto res = received_data_[pair_data];
367 MS_EXCEPTION_IF_NULL(res);
368 *output = res;
369 (void)received_data_.erase(pair_data);
370 receive_messages_done_[pair_data] = true;
371 MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
372 } else {
373 receive_callbacks_[pair_data] = [=]() mutable {
374 auto res_output = received_data_[std::make_pair(rank_id, rank_request_id)];
375 MS_EXCEPTION_IF_NULL(res_output);
376 if (*output != nullptr) {
377 MS_LOG(WARNING) << "The output is not empty.";
378 }
379 *output = res_output;
380 received_data_.erase(std::make_pair(rank_id, rank_request_id));
381 receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
382 MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
383 };
384 }
385 receive_callbacks_mutex_.unlock();
386 return std::make_pair(rank_id, rank_request_id);
387 }
388
CollectiveWait(const std::pair<uint32_t,uint64_t> & request_id,const uint32_t & timeout)389 bool AbstractNode::CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout) {
390 std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
391 bool res =
392 receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
393 if (receive_messages_done_.count(request_id) != 0) {
394 (void)receive_messages_done_.erase(request_id);
395 }
396 return res;
397 }
398
// Creates this node's FollowerScaler and hooks its scale-event callbacks.
// Always returns true (MS_EXCEPTION_IF_NULL throws on allocation failure).
bool AbstractNode::InitFollowerScaler() {
  follower_scaler_ = std::make_unique<FollowerScaler>(this);
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterScaleEventCallbacks();
  return true;
}
405
// Registers, on behalf of `module`, a barrier executed before scale-out
// starts. Requires InitFollowerScaler() to have been called first.
void AbstractNode::RegisterFollowerScalerBarrierBeforeScaleOut(const std::string &module,
                                                               const BarrierBeforeScaleOut &barrier) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterBarrierBeforeScaleOut(module, barrier);
}
411
// Registers, on behalf of `module`, a barrier executed before scale-in
// starts. Requires InitFollowerScaler() to have been called first.
void AbstractNode::RegisterFollowerScalerBarrierBeforeScaleIn(const std::string &module,
                                                              const BarrierBeforeScaleIn &barrier) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterBarrierBeforeScaleIn(module, barrier);
}
417
// Registers, on behalf of `module`, a handler executed after scale-out
// completes. Requires InitFollowerScaler() to have been called first.
void AbstractNode::RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module,
                                                              const HandlerAfterScaleOut &handler) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterHandlerAfterScaleOut(module, handler);
}
423
// Registers, on behalf of `module`, a handler executed after scale-in
// completes. Requires InitFollowerScaler() to have been called first.
void AbstractNode::RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module,
                                                             const HandlerAfterScaleIn &handler) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterHandlerAfterScaleIn(module, handler);
}
429
// Returns the configured number of worker nodes.
int32_t AbstractNode::worker_num() const { return worker_num_; }

// Returns the configured number of server nodes.
int32_t AbstractNode::server_num() const { return server_num_; }

// Updates the worker count (e.g. after a scale event).
void AbstractNode::set_worker_num(const int32_t &worker_num) { worker_num_ = worker_num; }

// Updates the server count (e.g. after a scale event).
void AbstractNode::set_server_num(const int32_t &server_num) { server_num_ = server_num; }

// Returns the scheduler's IP address.
std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }

// Sets the scheduler's IP address.
void AbstractNode::set_scheduler_ip(const std::string &scheduler_ip) { scheduler_ip_ = scheduler_ip; }

// Returns the scheduler's TCP port.
uint16_t AbstractNode::scheduler_port() const { return scheduler_port_; }

// Sets the scheduler's TCP port.
void AbstractNode::set_scheduler_port(const uint16_t &scheduler_port) { scheduler_port_ = scheduler_port; }

// Returns the cluster state most recently reported by the scheduler.
ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }

// Installs the handler invoked for user-level requests received by this node.
void AbstractNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
449
Response(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)450 void AbstractNode::Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
451 const void *data, size_t size) {
452 MS_EXCEPTION_IF_NULL(conn);
453 MS_EXCEPTION_IF_NULL(meta);
454 MS_EXCEPTION_IF_NULL(data);
455 MS_EXCEPTION_IF_NULL(server_);
456 meta->set_role(node_info_.node_role_);
457 meta->set_rank_id(node_info_.rank_id_);
458 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
459 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
460 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
461 MS_LOG(WARNING) << "Server response message failed.";
462 }
463 }
464
GetOrCreateHttpComm(const std::string & ip,uint16_t port,const std::shared_ptr<TaskExecutor> & task_executor)465 std::shared_ptr<CommunicatorBase> AbstractNode::GetOrCreateHttpComm(
466 const std::string &ip, uint16_t port, const std::shared_ptr<TaskExecutor> &task_executor) {
467 MS_EXCEPTION_IF_NULL(task_executor);
468 std::lock_guard<std::mutex> lock(communicator_mutex_);
469 if (!communicators_.count(kHttpCommunicator)) {
470 MS_LOG(INFO) << "Create Http communicator.";
471 auto http_comm = std::make_shared<HttpCommunicator>(ip, port, task_executor);
472 MS_EXCEPTION_IF_NULL(http_comm);
473 communicators_[kHttpCommunicator] = http_comm;
474 }
475 return communicators_[kHttpCommunicator];
476 }
477
GetOrCreateTcpComm(const std::string & scheduler_ip,std::int16_t scheduler_port,uint32_t worker_num,uint32_t server_num,const std::shared_ptr<TaskExecutor> & task_executor)478 std::shared_ptr<CommunicatorBase> AbstractNode::GetOrCreateTcpComm(const std::string &scheduler_ip,
479 std::int16_t scheduler_port, uint32_t worker_num,
480 uint32_t server_num,
481 const std::shared_ptr<TaskExecutor> &task_executor) {
482 MS_EXCEPTION_IF_NULL(task_executor);
483 std::lock_guard<std::mutex> lock(communicator_mutex_);
484 if (!communicators_.count(kTcpCommunicator)) {
485 MS_LOG(INFO) << "Create Tcp communicator.";
486 auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
487 PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
488 PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
489 PSContext::instance()->cluster_config().initial_worker_num = worker_num;
490 PSContext::instance()->cluster_config().initial_server_num = server_num;
491 MS_EXCEPTION_IF_NULL(tcp_comm);
492 PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
493 PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
494 PSContext::instance()->cluster_config().initial_worker_num = worker_num;
495 PSContext::instance()->cluster_config().initial_server_num = server_num;
496 MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num
497 << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip
498 << ", Scheduler port:" << scheduler_port;
499 communicators_[kTcpCommunicator] = tcp_comm;
500 }
501 return communicators_[kTcpCommunicator];
502 }
503
StartHeartbeatTimer(const std::shared_ptr<TcpClient> & client)504 void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
505 MS_EXCEPTION_IF_NULL(client);
506 MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
507 << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
508 << " begin send heartbeat to the scheduler!";
509 heart_beat_thread_ = std::make_unique<std::thread>([&]() {
510 while (!is_finish_.load()) {
511 if (!Heartbeat(client)) {
512 MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
513 << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
514 if (CheckSchedulerTimeout()) {
515 MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
516 << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
517 is_finish_ = true;
518 wait_finish_cond_.notify_all();
519 if (!is_already_stopped_) {
520 OnEventCallback(ClusterEvent::SCHEDULER_TIMEOUT);
521 }
522 }
523 } else {
524 UpdateSchedulerTime();
525 }
526
527 std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
528 }
529 });
530 MS_EXCEPTION_IF_NULL(heart_beat_thread_);
531 heart_beat_thread_->detach();
532 }
533
Heartbeat(const std::shared_ptr<TcpClient> & client)534 bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
535 MS_EXCEPTION_IF_NULL(client);
536 auto meta = std::make_shared<MessageMeta>();
537 MS_EXCEPTION_IF_NULL(meta);
538 meta->set_cmd(NodeCommand::HEARTBEAT);
539
540 HeartbeatMessage heartbeat_message;
541 heartbeat_message.set_node_id(node_info_.node_id_);
542
543 if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
544 heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
545 MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
546 return false;
547 }
548 return true;
549 }
550
UpdateSchedulerTime()551 void AbstractNode::UpdateSchedulerTime() {
552 struct timeval current_time {};
553 (void)gettimeofday(¤t_time, nullptr);
554 scheduler_time_ = current_time;
555 MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec;
556 }
557
CheckSchedulerTimeout() const558 bool AbstractNode::CheckSchedulerTimeout() const {
559 struct timeval current_time {};
560 (void)gettimeofday(¤t_time, nullptr);
561 int64_t old_time = scheduler_time_.tv_sec + PSContext::instance()->cluster_config().scheduler_timeout;
562 if (old_time < current_time.tv_sec) {
563 return true;
564 }
565 return false;
566 }
567
// Handles the scheduler's HEARTBEAT response: adopts the reported cluster
// state, rebuilds the per-node liveness table, and reacts to NODE_TIMEOUT.
void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  HeartbeatRespMessage heartbeat_resp_message;
  CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));

  current_cluster_state_ = heartbeat_resp_message.cluster_state();
  MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
                << CommUtil::ClusterStateToString(current_cluster_state_);

  // Rebuild the node table from scratch with the metadata the scheduler sent.
  all_nodes_info_.clear();
  for (const auto &it : heartbeat_resp_message.servers_meta()) {
    NodeInfo info;
    info.ip_ = it.ip();
    info.node_id_ = it.node_id();
    info.port_ = static_cast<uint16_t>(it.port());
    info.node_role_ = it.role();
    info.rank_id_ = it.rank_id();
    info.is_alive = it.is_alive();

    all_nodes_info_[info.node_id_] = info;
    MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
                  << ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
  }

  bool is_worker_or_server0 = heartbeat_resp_message.is_worker_or_server0();

  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
    // Without a recovery handler (or for worker/server-0 nodes, which are not
    // recoverable here) the timeout is fatal: unblock start-waiters and fire
    // the NODE_TIMEOUT event. Otherwise the node can simply be restarted.
    if (node_recovery_ == nullptr || is_worker_or_server0) {
      MS_LOG(INFO) << "The recovery is disable.";
      is_ready_ = true;
      wait_start_cond_.notify_all();
      OnEventCallback(ClusterEvent::NODE_TIMEOUT);
    } else {
      MS_LOG(INFO) << "The node is support recovery, users can pull up this node to restore the cluster.";
    }
  }
}
606
FetchServers(const std::shared_ptr<TcpClient> & client)607 void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
608 MS_EXCEPTION_IF_NULL(client);
609 auto meta = std::make_shared<MessageMeta>();
610 MS_EXCEPTION_IF_NULL(meta);
611 meta->set_cmd(NodeCommand::FETCH_METADATA);
612
613 FetchServersMessage fetch_servers;
614 fetch_servers.set_node_id(node_info_.node_id_);
615 if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
616 fetch_servers.ByteSizeLong())) {
617 MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
618 }
619 }
620
ProcessFetchServersResp(const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)621 void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
622 MS_EXCEPTION_IF_NULL(meta);
623 MS_EXCEPTION_IF_NULL(data);
624 FetchServersRespMessage fetch_servers_resp_message;
625 CHECK_RETURN_TYPE(fetch_servers_resp_message.ParseFromArray(data, SizeToInt(size)));
626
627 nodes_address_.clear();
628 for (const auto &it : fetch_servers_resp_message.servers_meta()) {
629 nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
630 MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
631 }
632 }
633
ProcessSendMetadata(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)634 void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
635 const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
636 size_t size) {
637 MS_EXCEPTION_IF_NULL(conn);
638 MS_EXCEPTION_IF_NULL(meta);
639 MS_EXCEPTION_IF_NULL(data);
640 if (is_current_node_scale_in_) {
641 MS_LOG(WARNING) << "Trigger cluster scale in done event.";
642 node_info_.rank_id_ = UINT32_MAX;
643 OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
644 return;
645 }
646 SendMetadataMessage send_meta_message;
647 CHECK_RETURN_TYPE(send_meta_message.ParseFromArray(data, SizeToInt(size)));
648 worker_num_ = send_meta_message.worker_num();
649 server_num_ = send_meta_message.server_num();
650 if (send_meta_message.rank_id() < 0) {
651 MS_LOG(EXCEPTION) << "The rank id is wrong.";
652 }
653 node_info_.rank_id_ = send_meta_message.rank_id();
654 current_cluster_state_ = send_meta_message.cluster_state();
655 MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
656 << ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
657 << ", the rank id:" << node_info_.rank_id_;
658
659 client_mutex_.lock();
660 nodes_address_.clear();
661 for (const auto &it : send_meta_message.servers_meta()) {
662 nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
663 MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port() << ", the rank id:" << it.rank_id();
664 }
665 client_mutex_.unlock();
666 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
667 MS_LOG(WARNING) << "Sever response message failed.";
668 }
669 is_ready_ = true;
670 wait_start_cond_.notify_all();
671
672 if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT) {
673 MS_LOG(WARNING) << "Trigger cluster scale out done event.";
674 OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
675 }
676
677 if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
678 MS_LOG(WARNING) << "Trigger cluster scale in done event.";
679 OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
680 }
681
682 std::lock_guard<std::mutex> lock(client_mutex_);
683 connected_nodes_.clear();
684 }
685
ProcessFinish(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)686 void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
687 const Protos &, const void *data, size_t size) {
688 MS_EXCEPTION_IF_NULL(conn);
689 MS_EXCEPTION_IF_NULL(meta);
690 MS_EXCEPTION_IF_NULL(data);
691 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
692 MS_LOG(WARNING) << "Server response message failed.";
693 }
694 is_finish_ = true;
695 wait_finish_cond_.notify_all();
696 }
697
ProcessScaleOutDone(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)698 void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn,
699 const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
700 size_t size) {
701 MS_EXCEPTION_IF_NULL(conn);
702 MS_EXCEPTION_IF_NULL(meta);
703 MS_EXCEPTION_IF_NULL(data);
704 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
705 MS_LOG(WARNING) << "Server response message failed.";
706 }
707 is_ready_ = true;
708 current_cluster_state_ = ClusterState::CLUSTER_READY;
709 }
710
ProcessScaleInDone(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)711 void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
712 const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
713 size_t size) {
714 MS_EXCEPTION_IF_NULL(conn);
715 MS_EXCEPTION_IF_NULL(meta);
716 MS_EXCEPTION_IF_NULL(data);
717 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
718 MS_LOG(WARNING) << "Server response message failed.";
719 }
720 is_ready_ = true;
721 current_cluster_state_ = ClusterState::CLUSTER_READY;
722 }
723
ProcessEvent(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)724 void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
725 const Protos &, const void *data, size_t size) {
726 MS_EXCEPTION_IF_NULL(conn);
727 MS_EXCEPTION_IF_NULL(meta);
728 MS_EXCEPTION_IF_NULL(data);
729 EventRespMessage event_resp_message;
730 CHECK_RETURN_TYPE(event_resp_message.ParseFromArray(data, SizeToInt(size)));
731 uint32_t event = event_resp_message.event();
732 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
733 MS_LOG(WARNING) << "Server response message failed.";
734 }
735 OnCustomEventCallback(event);
736 }
737
ProcessScaleOut(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)738 void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
739 const Protos &, const void *data, size_t size) {
740 MS_EXCEPTION_IF_NULL(conn);
741 MS_EXCEPTION_IF_NULL(meta);
742 MS_EXCEPTION_IF_NULL(data);
743
744 ScaleOutMessage scale_out_message;
745 CHECK_RETURN_TYPE(scale_out_message.ParseFromArray(data, SizeToInt(size)));
746 int32_t worker_num = scale_out_message.worker_num();
747 int32_t server_num = scale_out_message.server_num();
748 MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;
749
750 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
751 MS_LOG(WARNING) << "Server response message failed.";
752 }
753 OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
754 current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
755 is_ready_ = false;
756 }
757
ProcessScaleIn(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)758 void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
759 const Protos &, const void *data, size_t size) {
760 MS_EXCEPTION_IF_NULL(conn);
761 MS_EXCEPTION_IF_NULL(meta);
762 MS_EXCEPTION_IF_NULL(data);
763
764 ScaleInMessage scale_in_message;
765 CHECK_RETURN_TYPE(scale_in_message.ParseFromArray(data, SizeToInt(size)));
766 int32_t worker_num = scale_in_message.worker_num();
767 int32_t server_num = scale_in_message.server_num();
768 MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;
769
770 is_current_node_scale_in_ = scale_in_message.is_node_scale_in();
771 if (is_current_node_scale_in_) {
772 MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
773 << " the node id:" << node_info_.node_id_ << " is a scale in node!";
774 } else {
775 MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
776 << " the node id:" << node_info_.node_id_ << " is not a scale in node!";
777 }
778
779 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
780 MS_LOG(WARNING) << "Server response message failed.";
781 }
782 OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
783 current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
784 is_ready_ = false;
785 }
786
Disconnect(const std::shared_ptr<TcpClient> & client,const uint32_t & timeout)787 bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
788 MS_EXCEPTION_IF_NULL(client);
789 auto meta = std::make_shared<MessageMeta>();
790 MS_EXCEPTION_IF_NULL(meta);
791 meta->set_cmd(NodeCommand::FINISH);
792
793 std::string finish_message = node_info_.node_id_;
794
795 if (!SendMessageSync(client, meta, Protos::RAW, finish_message.data(), finish_message.length())) {
796 MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
797 << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
798 }
799 return WaitForDisconnect(timeout);
800 }
801
WaitForDisconnect(const uint32_t & timeout)802 bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
803 std::unique_lock<std::mutex> lock(wait_finish_mutex_);
804 bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
805 if (is_finish_.load()) {
806 MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
807 }
808 return is_finish_.load();
809 });
810 return res;
811 }
812
// Creates and starts the TCP client used to talk to the scheduler.
// Wires up: (1) a message callback that dispatches scheduler responses through
// handlers_, (2) a detached event-loop thread, (3) a reconnect-on-disconnect
// callback. Returns true once the connection is established; on connection
// failure returns false but still sets is_ready_ so waiters are released.
bool AbstractNode::InitClientToScheduler() {
  if (config_ == nullptr) {
    MS_LOG(WARNING) << "The config is empty.";
    return false;
  }
  client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get());
  MS_EXCEPTION_IF_NULL(client_to_scheduler_);
  client_to_scheduler_->SetMessageCallback(
    [&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
      // Runs on the client's event-loop thread; exceptions must not escape,
      // so they are parked in MsException for the main thread to rethrow.
      try {
        MS_EXCEPTION_IF_NULL(meta);
        MS_EXCEPTION_IF_NULL(data);
        if (handlers_.count(meta->cmd()) == 0) {
          MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
        }
        // A registered-but-null handler means "ack only" (e.g. FINISH).
        if (handlers_[meta->cmd()] != nullptr) {
          const auto &handler_ptr = handlers_[meta->cmd()];
          (this->*handler_ptr)(meta, data, size);
        }
        NotifyMessageArrival(meta);
      } catch (const std::exception &e) {
        MsException::Instance().SetException();
      }
    });

  client_to_scheduler_->Init();
  client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
    MS_LOG(INFO) << "The node start a tcp client!";
    client_to_scheduler_->Start();
  });
  // Detached on purpose: the client object owns the loop's lifetime.
  client_to_scheduler_thread_->detach();

  // NOTE(review): registered after Init()/Start(); a disconnect in that window
  // would be missed — confirm TcpClient tolerates late callback registration.
  client_to_scheduler_->set_disconnected_callback([&]() {
    // Back off before re-initializing to avoid a tight reconnect loop.
    std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval));
    if (is_ready_.load() == false) {
      client_to_scheduler_->Init();
    }
  });
  bool wait_res = client_to_scheduler_->WaitConnected();
  if (!wait_res) {
    // Unblock anyone waiting on readiness even though the connect failed.
    is_ready_ = true;
  }
  return wait_res;
}
857
GetOrCreateTcpClient(const uint32_t & rank_id)858 const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id) {
859 std::lock_guard<std::mutex> lock(client_mutex_);
860 if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
861 return connected_nodes_[rank_id];
862 } else {
863 if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
864 MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed!";
865 }
866 if (config_ == nullptr) {
867 MS_LOG(EXCEPTION) << "The config is empty.";
868 }
869 std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
870 uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
871 auto client = std::make_shared<TcpClient>(ip, port, config_.get());
872 MS_EXCEPTION_IF_NULL(client);
873 client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
874 size_t size) {
875 switch (meta->cmd()) {
876 case NodeCommand::SEND_DATA:
877 ProcessSendDataResp(meta, protos, data, size);
878 RunMessageCallback(meta->request_id());
879 break;
880 case NodeCommand::COLLECTIVE_SEND_DATA:
881 MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
882 break;
883 default:
884 MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
885 }
886 NotifyMessageArrival(meta);
887 });
888 client->Init();
889 connected_nodes_[rank_id] = client;
890 return connected_nodes_[rank_id];
891 }
892 }
893
// Sends a CommMessage and blocks until the matching response arrives or the
// tracker times out. Returns true when the response arrived in time.
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
                                   const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  // Register this request in the message tracker so NotifyMessageArrival can
  // match the response; expect exactly 1 reply.
  uint64_t request_id = AddMessageTrack(1);
  // NOTE(review): const_cast mutates a parameter declared const — callers may
  // not expect their message's request id to change. Fixing this properly
  // requires a signature change (take by value or non-const ref).
  const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
  client->SendMessage(message);
  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
  return Wait(request_id, timeout);
}
904
SendMessageAsync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)905 uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client,
906 const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
907 const void *data, size_t size) {
908 MS_EXCEPTION_IF_NULL(client);
909 MS_EXCEPTION_IF_NULL(meta);
910 MS_EXCEPTION_IF_NULL(data);
911 uint64_t request_id = AddMessageTrack(1);
912 meta->set_request_id(request_id);
913 client->SendMessage(meta, protos, data, size);
914 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
915 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
916 return request_id;
917 }
918
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)919 bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
920 const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
921 MS_EXCEPTION_IF_NULL(client);
922 MS_EXCEPTION_IF_NULL(meta);
923 MS_EXCEPTION_IF_NULL(data);
924 uint64_t request_id = AddMessageTrack(1);
925 meta->set_request_id(request_id);
926 client->SendMessage(meta, protos, data, size);
927 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
928 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
929 return Wait(request_id, timeout);
930 }
931
ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)932 void AbstractNode::ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn,
933 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
934 MS_EXCEPTION_IF_NULL(conn);
935 MS_EXCEPTION_IF_NULL(meta);
936 MS_EXCEPTION_IF_NULL(data);
937 if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
938 MS_LOG(WARNING) << "Server response message failed.";
939 }
940 }
941
ProcessSendData(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)942 void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
943 const Protos &, const void *data, size_t size) {
944 MS_EXCEPTION_IF_NULL(conn);
945 MS_EXCEPTION_IF_NULL(meta);
946 MS_EXCEPTION_IF_NULL(data);
947 std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
948 if (size > 0) {
949 size_t dest_size = size;
950 size_t src_size = size;
951 if (memcpy_s(res.get(), dest_size, data, src_size) != EOK) {
952 MS_LOG(EXCEPTION) << "The memcpy_s error";
953 }
954 }
955 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
956 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id()
957 << " the current time is:"
958 << std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now())
959 .time_since_epoch()
960 .count();
961 request_handler_(conn, meta, res, size);
962 }
963
NotifyMessageArrival(const std::shared_ptr<MessageMeta> & meta)964 void AbstractNode::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
965 MS_EXCEPTION_IF_NULL(meta);
966 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
967 uint64_t request_id = meta->request_id();
968 if (message_tracker_.count(request_id)) {
969 message_tracker_[request_id].second++;
970 } else {
971 MS_LOG(WARNING) << "The requset id:" << request_id << " is removed.";
972 }
973 message_tracker_cond_.notify_all();
974 }
975
RunReceiveCallback(const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)976 void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
977 size_t size) {
978 MS_EXCEPTION_IF_NULL(meta);
979 MS_EXCEPTION_IF_NULL(data);
980 receive_callbacks_mutex_.lock();
981 uint32_t rank_id = meta->rank_id();
982 // When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
983 // If they are equal, then call the callback function
984 uint64_t rank_request_id = NextActualRankRequestId(rank_id);
985 std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
986 size_t dest_size = size;
987 size_t src_size = size;
988 int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
989 if (ret != 0) {
990 receive_callbacks_mutex_.unlock();
991 MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
992 }
993 received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
994 MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id
995 << ", the send request id is:" << meta->request_id() << " the size is:" << size;
996 auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
997 if (it != receive_callbacks_.end()) {
998 if (receive_messages_done_.count(std::make_pair(rank_id, rank_request_id)) != 0) {
999 if (it->second) {
1000 it->second();
1001 }
1002 }
1003 receive_cond_.notify_all();
1004 receive_callbacks_.erase(it);
1005 }
1006 receive_callbacks_mutex_.unlock();
1007 }
1008
NextExpectedRankRequestId(const uint32_t & rank_id)1009 uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) {
1010 std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
1011 uint64_t rank_request_id = 1;
1012 if (expected_rank_request_ids_.count(rank_id)) {
1013 rank_request_id = ++expected_rank_request_ids_[rank_id];
1014 expected_rank_request_ids_[rank_id] = rank_request_id;
1015 } else {
1016 expected_rank_request_ids_[rank_id] = rank_request_id;
1017 }
1018 return rank_request_id;
1019 }
1020
NextActualRankRequestId(const uint32_t & rank_id)1021 uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
1022 std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
1023 uint64_t rank_request_id = 1;
1024 if (actual_rank_request_ids_.count(rank_id)) {
1025 rank_request_id = ++actual_rank_request_ids_[rank_id];
1026 actual_rank_request_ids_[rank_id] = rank_request_id;
1027 } else {
1028 actual_rank_request_ids_[rank_id] = rank_request_id;
1029 }
1030 return rank_request_id;
1031 }
1032
// Registers handlers for responses coming back from the scheduler.
// A nullptr entry means the command is recognized but needs no processing
// beyond NotifyMessageArrival (see the client message callback in
// InitClientToScheduler).
void AbstractNode::InitCommandHandler() {
  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
  handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
  handlers_[NodeCommand::FINISH] = nullptr;
  handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
  handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
  handlers_[NodeCommand::SEND_EVENT] = nullptr;
}
1042
// Registers handlers for commands received by this node's own TCP server.
// SEND_DATA and COLLECTIVE_SEND_DATA are nullptr because CreateTcpServer's
// message callback special-cases them before consulting this table.
void AbstractNode::InitServerHandler() {
  server_handler_[NodeCommand::SEND_METADATA] = &AbstractNode::ProcessSendMetadata;
  server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
  server_handler_[NodeCommand::SEND_DATA] = nullptr;
  server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
  server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
  server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
  server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
  server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
  server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
}
1054
InitNodeInfo(const NodeRole & role)1055 void AbstractNode::InitNodeInfo(const NodeRole &role) {
1056 MS_EXCEPTION_IF_NULL(config_);
1057 MS_EXCEPTION_IF_NULL(server_);
1058 if (PSContext::instance()->node_id().empty() && config_->Exists(kNodeId)) {
1059 node_info_.node_id_ = config_->Get(kNodeId, "");
1060 } else {
1061 node_info_.node_id_ = PSContext::instance()->node_id();
1062 }
1063
1064 if (node_info_.node_id_.empty()) {
1065 node_info_.node_id_ = CommUtil::GenerateUUID();
1066 }
1067 node_info_.node_role_ = role;
1068 node_info_.ip_ = server_->BoundIp();
1069 node_info_.port_ = server_->BoundPort();
1070
1071 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
1072 << " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
1073 << ", the port:" << server_->BoundPort();
1074 }
1075
InitNodeNum()1076 void AbstractNode::InitNodeNum() {
1077 worker_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_worker_num);
1078 server_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_server_num);
1079 scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
1080 scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
1081 MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
1082 << ", the scheduler ip:" << scheduler_ip_ << ", the scheduler port:" << scheduler_port_;
1083 }
1084
Recover()1085 bool AbstractNode::Recover() {
1086 MS_EXCEPTION_IF_NULL(config_);
1087 if (config_->Exists(kKeyRecovery)) {
1088 MS_LOG(INFO) << "The node is support recovery.";
1089 node_recovery_ = std::make_unique<NodeRecovery>(this);
1090 MS_EXCEPTION_IF_NULL(node_recovery_);
1091 node_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
1092 return node_recovery_->Recover();
1093 }
1094 return false;
1095 }
1096
OnEventCallback(const ClusterEvent & event)1097 void AbstractNode::OnEventCallback(const ClusterEvent &event) {
1098 if (!event_to_callback_.count(event)) {
1099 MS_LOG(ERROR) << "[Event]:The event callback of " << event << " is not set.";
1100 } else {
1101 MS_LOG(INFO) << "[Event]:Trigger the event:" << event;
1102 if (event_to_callback_[event]) {
1103 event_to_callback_[event]();
1104 }
1105 }
1106 }
1107
OnCustomEventCallback(const uint32_t & event)1108 void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
1109 if (!custom_event_to_callback_.count(event)) {
1110 MS_LOG(WARNING) << "[Custom event]:The event callback of " << event << " is not set.";
1111 } else {
1112 MS_LOG(INFO) << "[Custom event]:Trigger the event:" << event;
1113 if (custom_event_to_callback_[event]) {
1114 custom_event_to_callback_[event]();
1115 }
1116 }
1117 }
1118
IsWorkerOrServer0(const std::unordered_map<std::string,NodeInfo> & info)1119 bool AbstractNode::IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info) {
1120 for (const auto &it : info) {
1121 if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) {
1122 return true;
1123 }
1124
1125 if (it.second.is_alive == true && it.second.rank_id_ == 0 && it.second.node_role_ == NodeRole::SERVER) {
1126 return true;
1127 }
1128 }
1129 return false;
1130 }
1131
// Creates this node's TCP server on an automatically chosen local interface
// (port 0 = OS-assigned), installs the command dispatch callback, and starts
// the server event loop on a detached thread.
void AbstractNode::CreateTcpServer() {
  MS_EXCEPTION_IF_NULL(config_);
  std::string interface;
  std::string server_ip;
  CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
  server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
  MS_EXCEPTION_IF_NULL(server_);
  server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
                                  const Protos &protos, const void *data, size_t size) {
    // Runs on the server's event-loop thread.
    MS_EXCEPTION_IF_NULL(meta);
    MS_EXCEPTION_IF_NULL(conn);
    MS_EXCEPTION_IF_NULL(data);
    if (server_handler_.count(meta->cmd()) == 0) {
      MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
    }

    // Data-plane commands are special-cased (their table entries are nullptr,
    // see InitServerHandler); everything else dispatches through the table.
    if (meta->cmd() == NodeCommand::COLLECTIVE_SEND_DATA) {
      ProcessCollectiveSendData(conn, meta, data, size);
      RunReceiveCallback(meta, protos, data, size);
    } else if (meta->cmd() == NodeCommand::SEND_DATA) {
      ProcessSendData(conn, meta, protos, data, size);
    } else {
      const auto &handler_ptr = server_handler_[meta->cmd()];
      (this->*handler_ptr)(conn, meta, protos, data, size);
    }
  });
  server_->Init();
  server_thread_ = std::make_unique<std::thread>([this]() {
    MS_LOG(INFO) << "The server node start a tcp server!";
    this->server_->Start();
  });
  MS_EXCEPTION_IF_NULL(server_thread_);
  // Detached on purpose: the server object owns the loop's lifetime.
  server_thread_->detach();
}
1166 } // namespace core
1167 } // namespace ps
1168 } // namespace mindspore
1169