1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/backend/distributed/cluster/topology/compute_graph_node.h"
17 #include <utility>
18 #include <random>
19 #include <nlohmann/json.hpp>
20 #include "utils/log_adapter.h"
21 #include "utils/ms_exception.h"
22 #include "include/backend/distributed/cluster/topology/common.h"
23 #include "include/backend/distributed/recovery/recovery_context.h"
24 #include "include/backend/distributed/constants.h"
25 #include "proto/topology.pb.h"
26 #include "include/backend/distributed/ps/ps_context.h"
27 #include "include/backend/distributed/rpc/tcp/constants.h"
28 #include "utils/convert_utils_base.h"
29
30 namespace mindspore {
31 namespace distributed {
32 namespace cluster {
33 namespace topology {
// Prefix of the metadata key that marks a metadata exchange as started; the biz name is appended to it.
constexpr char kStartExchangeMetaPrefix[] = "START_EXCHANGE_META_";
// Prefix of the per-rank metadata key that marks a rank as done with the exchange; the rank id is appended to it.
constexpr char kExchangeMetaDonePrefix[] = "EXCHANGE_META_DONE_";
// Value stored under a flag key while the flag is set.
constexpr char kMetaFlagValue[] = "1";
// Value expected from GetMetadata once a flag has been deleted (GetMetadata yields "" for a missing key).
constexpr char kMetaDeleteFlagValue[] = "";
38
~ComputeGraphNode()39 ComputeGraphNode::~ComputeGraphNode() {
40 if (!finalized_) {
41 try {
42 (void)Finalize(true);
43 } catch (std::exception &) {
44 MS_LOG(ERROR) << "Failed to finalize ComputeGraphNode.";
45 }
46 }
47 }
48
Initialize()49 bool ComputeGraphNode::Initialize() {
50 // Init the address of meta server node.
51 RETURN_IF_FALSE_WITH_LOG(FillMetaServerAddress(&meta_server_addr_),
52 "Failed to init the address of meta server node.");
53
54 // Init the TCP client.
55 bool enable_ssl = ps::PSContext::instance()->enable_ssl();
56 tcp_client_ = std::make_unique<rpc::TCPClient>(enable_ssl);
57 MS_EXCEPTION_IF_NULL(tcp_client_);
58 RETURN_IF_FALSE_WITH_LOG(tcp_client_->Initialize(), "Failed to create the TCP client.");
59
60 hb_client_ = std::make_unique<rpc::TCPClient>(enable_ssl);
61 MS_EXCEPTION_IF_NULL(hb_client_);
62 RETURN_IF_FALSE_WITH_LOG(hb_client_->Initialize(), "Failed to create the heartbeat tcp client.");
63
64 // Register itself to meta server node.
65 bool success = false;
66 if (!enable_ssl) {
67 success = ReconnectWithTimeoutWindow(std::bind(&ComputeGraphNode::Register, this),
68 "Failed to register and try to reconnect to the meta server.", topo_timeout_);
69 } else {
70 const auto &server_url = meta_server_addr_.GetUrl();
71 size_t retry = 10;
72 while (!success && retry-- > 0) {
73 success = Register();
74 if (success) {
75 break;
76 }
77
78 if (tcp_client_ != nullptr) {
79 (void)tcp_client_->Disconnect(server_url);
80 tcp_client_->Finalize();
81 tcp_client_.reset();
82 }
83 if (hb_client_ != nullptr) {
84 (void)hb_client_->Disconnect(server_url);
85 hb_client_->Finalize();
86 hb_client_.reset();
87 }
88
89 tcp_client_ = std::make_unique<rpc::TCPClient>(enable_ssl);
90 MS_EXCEPTION_IF_NULL(tcp_client_);
91 RETURN_IF_FALSE_WITH_LOG(tcp_client_->Initialize(), "Failed to create the TCP client.");
92
93 hb_client_ = std::make_unique<rpc::TCPClient>(enable_ssl);
94 MS_EXCEPTION_IF_NULL(hb_client_);
95 RETURN_IF_FALSE_WITH_LOG(hb_client_->Initialize(), "Failed to create the heartbeat tcp client.");
96 }
97 }
98 if (!success) {
99 return false;
100 }
101
102 // Enable the heartbeat to meta server node.
103 enable_hb_ = true;
104 heartbeat_ = std::thread(&ComputeGraphNode::Heartbeat, this);
105 return true;
106 }
107
Initialized()108 bool ComputeGraphNode::Initialized() {
109 // The cgn is initialized only when the cluster is ready, or there will be error message unexpected.
110 return authenticated_ && topo_state_ == TopoState::kInitialized;
111 }
112
Finalize(bool force)113 bool ComputeGraphNode::Finalize(bool force) {
114 // Stop the heartbeat thread.
115 enable_hb_ = false;
116 if (heartbeat_.joinable()) {
117 heartbeat_.join();
118 }
119
120 // Exit the compute graph node from the cluster topology.
121 while (!force) {
122 bool success = ReconnectIfNeeded(std::bind(&ComputeGraphNode::Unregister, this),
123 "Failed to unregister and try to reconnect to the meta server.", kNoRetry);
124 if (!success) {
125 MS_LOG(ERROR) << "Failed to unregister from the meta server node.";
126 if (recovery::IsEnableRecovery()) {
127 continue;
128 } else {
129 break;
130 }
131 } else {
132 MS_LOG(INFO) << "The compute graph node has been unregistered successfully.";
133 break;
134 }
135 }
136
137 // Release the TCP client.
138 bool enable_ssl = ps::PSContext::instance()->enable_ssl();
139 const auto &server_url = meta_server_addr_.GetUrl();
140 if (tcp_client_ != nullptr) {
141 if (!(enable_ssl && !authenticated_)) {
142 (void)tcp_client_->Disconnect(server_url);
143 }
144 tcp_client_->Finalize();
145 tcp_client_.reset();
146 }
147
148 if (hb_client_ != nullptr) {
149 if (!(enable_ssl && !authenticated_)) {
150 (void)hb_client_->Disconnect(server_url);
151 }
152 hb_client_->Finalize();
153 hb_client_.reset();
154 }
155 return true;
156 }
157
Register()158 bool ComputeGraphNode::Register() {
159 MS_EXCEPTION_IF_NULL(hb_client_);
160 MS_EXCEPTION_IF_NULL(tcp_client_);
161 const auto &server_url = meta_server_addr_.GetUrl();
162 MS_LOG(INFO) << "Start connecting heartbeat client.";
163 if (!hb_client_->IsConnected(server_url)) {
164 if (!hb_client_->Connect(server_url, kNoRetry)) {
165 MS_LOG(WARNING) << "Failed to connect to the meta server node url: " << server_url;
166 return false;
167 }
168 }
169
170 MS_LOG(INFO) << "Start connecting business client.";
171 if (!tcp_client_->IsConnected(server_url)) {
172 if (!tcp_client_->Connect(server_url, kNoRetry)) {
173 MS_LOG(WARNING) << "Failed to connect to the meta server node url: " << server_url;
174 return false;
175 }
176 }
177
178 RegistrationMessage reg_msg;
179 reg_msg.set_node_id(node_id_);
180 reg_msg.set_role(role_);
181
182 // Set the local hostname.
183 char host_name[MAX_HOSTNAME_LEN] = {0};
184 if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
185 MS_LOG(ERROR) << "Failed to get local host name.";
186 return false;
187 }
188 reg_msg.set_host_name(std::string(host_name));
189
190 // Set client ip address.
191 client_ip_ = hb_client_->GetClientIPByDstUrl(server_url);
192 reg_msg.set_host_ip(client_ip_);
193
194 std::string content = reg_msg.SerializeAsString();
195 auto message = CreateMessage(server_url, MessageName::kRegistration, content);
196 MS_EXCEPTION_IF_NULL(message);
197
198 const uint32_t timeout = 10;
199 MessageBase *response = hb_client_->ReceiveSync(std::move(message), timeout);
200 if (response == nullptr) {
201 return false;
202 }
203 auto body = response->body;
204 delete response;
205 response = nullptr;
206
207 RegistrationRespMessage reg_resp_msg;
208 (void)reg_resp_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
209
210 if (reg_resp_msg.success()) {
211 authenticated_ = true;
212 rank_id_ = reg_resp_msg.rank_id();
213 MS_LOG(INFO) << "The compute graph node: " << node_id_ << " has been registered successfully.";
214 return true;
215 } else {
216 MS_LOG(EXCEPTION) << "Failed to register the compute graph node: " << node_id_
217 << ". Reason: " << reg_resp_msg.error_reason();
218 }
219 }
220
Unregister()221 bool ComputeGraphNode::Unregister() {
222 MS_EXCEPTION_IF_NULL(hb_client_);
223
224 UnregistrationMessage unreg_msg;
225 unreg_msg.set_node_id(node_id_);
226
227 std::string content = unreg_msg.SerializeAsString();
228 auto message = CreateMessage(meta_server_addr_.GetUrl(), MessageName::kUnregistration, content);
229 MS_EXCEPTION_IF_NULL(message);
230
231 const uint32_t timeout = 10;
232 MessageBase *response = hb_client_->ReceiveSync(std::move(message), timeout);
233 if (response == nullptr) {
234 return false;
235 }
236 auto unreg_rt = response->body;
237 delete response;
238 response = nullptr;
239
240 if (std::to_string(static_cast<int>(MessageName::kSuccess)) == unreg_rt) {
241 return true;
242 } else {
243 return false;
244 }
245 }
246
Heartbeat()247 bool ComputeGraphNode::Heartbeat() {
248 std::random_device rd;
249 std::mt19937 gen(rd());
250 int random_time_lower =
251 common::GetEnv("MS_RETRY_INTERVAL_LOWER").empty() ? 3 : std::stoi(common::GetEnv("MS_RETRY_INTERVAL_LOWER"));
252 int random_time_upper =
253 common::GetEnv("MS_RETRY_INTERVAL_UPPER").empty() ? 5 : std::stoi(common::GetEnv("MS_RETRY_INTERVAL_UPPER"));
254 std::uniform_int_distribution<> distrib(random_time_lower, random_time_upper);
255 MS_LOG(INFO) << "Interval of heartbeat lower and upper are " << random_time_lower << " and " << random_time_upper;
256 try {
257 MS_EXCEPTION_IF_NULL(hb_client_);
258
259 MS_LOG(INFO) << "The heartbeat thread is started.";
260 while (enable_hb_) {
261 HeartbeatMessage hb_msg;
262 hb_msg.set_node_id(node_id_);
263
264 const auto &server_url = meta_server_addr_.GetUrl();
265 std::string content = hb_msg.SerializeAsString();
266 auto message = CreateMessage(server_url, MessageName::kHeartbeat, content);
267 MS_EXCEPTION_IF_NULL(message);
268
269 MessageBase *response = hb_client_->ReceiveSync(std::move(message));
270 if (response == nullptr) {
271 MS_LOG(ERROR)
272 << "Failed to send heartbeat message to meta server node and try to reconnect to the meta server.";
273 if (!Reconnect()) {
274 if (!recovery::IsEnableRecovery() && topo_state_ != TopoState::kInitializing) {
275 topo_state_ = TopoState::kFailed;
276 if (abnormal_callback_ != nullptr) {
277 (*abnormal_callback_)();
278 }
279 MS_LOG(EXCEPTION)
280 << "Failed to connect to the meta server. Maybe it has exited. Please check scheduler's log.";
281 } else {
282 MS_LOG(ERROR) << "Failed to connect to the meta server. Maybe it has exited. Please check scheduler's log.";
283 }
284 }
285 } else {
286 auto &body = response->body;
287 HeartbeatRespMessage resp_msg;
288 (void)resp_msg.ParseFromArray(body.c_str(), SizeToInt(body.length()));
289 topo_state_ = static_cast<TopoState>(resp_msg.topo_state());
290 if (topo_state_ == TopoState::kInitialized && disable_heartbeat_) {
291 MS_LOG(WARNING)
292 << "After cluster is initialized, disconnect heartbeat client if MS_DISABLE_HEARTBEAT is set to 1.";
293 (void)hb_client_->Disconnect(meta_server_addr_.GetUrl());
294 break;
295 }
296
297 auto nodes_num = resp_msg.nodes_num();
298 auto abnormal_nodes_num = resp_msg.abnormal_nodes_num();
299 if (abnormal_nodes_num > 0 && !recovery::IsEnableRecovery()) {
300 topo_state_ = TopoState::kFailed;
301 if (abnormal_callback_ != nullptr) {
302 (*abnormal_callback_)();
303 }
304 delete response;
305 MS_LOG(EXCEPTION) << "The state of the cluster is error, total nodes num: " << nodes_num
306 << ", abnormal nodes num: " << abnormal_nodes_num;
307 }
308 delete response;
309 }
310
311 uint32_t interval = distrib(gen);
312 MS_LOG(DEBUG) << "Heart beat interval " << interval;
313 (void)sleep(interval);
314 }
315 } catch (const std::exception &e) {
316 MsException::Instance().SetException();
317 }
318 return true;
319 }
320
ReconnectIfNeeded(const std::function<bool (void)> & func,const std::string & error,size_t retry)321 bool ComputeGraphNode::ReconnectIfNeeded(const std::function<bool(void)> &func, const std::string &error,
322 size_t retry) {
323 bool success = false;
324
325 while (!success && retry > 0) {
326 success = func();
327 if (!success) {
328 // Retry to reconnect to the meta server.
329 MS_LOG(WARNING) << error;
330 (void)sleep(kExecuteInterval);
331 (void)Reconnect();
332 }
333 --retry;
334 }
335 return success;
336 }
337
ReconnectWithTimeoutWindow(const std::function<bool (void)> & func,const std::string & error,size_t time_out)338 bool ComputeGraphNode::ReconnectWithTimeoutWindow(const std::function<bool(void)> &func, const std::string &error,
339 size_t time_out) {
340 size_t time_out_in_milli = time_out * 1000;
341 size_t start_tick = LongToSize(CURRENT_TIMESTAMP_MILLI.count());
342 bool success = false;
343 while (!success && LongToSize(CURRENT_TIMESTAMP_MILLI.count()) - start_tick <= time_out_in_milli) {
344 success = func();
345 if (!success) {
346 // Retry to reconnect to the meta server.
347 MS_LOG(WARNING) << error;
348 (void)sleep(kExecuteInterval);
349 (void)Reconnect();
350 }
351 }
352 return success;
353 }
354
Reconnect()355 bool ComputeGraphNode::Reconnect() {
356 MS_ERROR_IF_NULL_W_RET_VAL(tcp_client_, false);
357 MS_ERROR_IF_NULL_W_RET_VAL(hb_client_, false);
358
359 auto server_url = meta_server_addr_.GetUrl();
360 // Disconnect from meta server node firstly.
361 while (tcp_client_->IsConnected(server_url)) {
362 (void)tcp_client_->Disconnect(server_url);
363 }
364 while (hb_client_->IsConnected(server_url)) {
365 (void)hb_client_->Disconnect(server_url);
366 }
367
368 // Reconnect to the meta server node.
369 if (!tcp_client_->IsConnected(server_url)) {
370 MS_LOG(INFO) << "Start reconnecting business client.";
371 (void)tcp_client_->Connect(server_url, kNoRetry);
372 }
373 if (!tcp_client_->IsConnected(server_url)) {
374 return false;
375 }
376 if (!hb_client_->IsConnected(server_url)) {
377 MS_LOG(INFO) << "Start reconnecting heartbeat client.";
378 (void)hb_client_->Connect(server_url, kNoRetry);
379 }
380 return hb_client_->IsConnected(server_url);
381 }
382
SendMessageToMSN(const std::string msg_name,const std::string & msg_body,bool sync)383 bool ComputeGraphNode::SendMessageToMSN(const std::string msg_name, const std::string &msg_body, bool sync) {
384 MS_EXCEPTION_IF_NULL(tcp_client_);
385
386 auto message = CreateMessage(meta_server_addr_.GetUrl(), msg_name, msg_body);
387 MS_EXCEPTION_IF_NULL(message);
388
389 if (sync) {
390 auto retval = tcp_client_->SendSync(std::move(message));
391 if (retval) {
392 return true;
393 } else {
394 return false;
395 }
396 } else {
397 (void)tcp_client_->SendSync(std::move(message));
398 return true;
399 }
400 }
401
RetrieveMessageFromMSN(const std::string & msg_name,uint32_t timeout)402 std::shared_ptr<std::string> ComputeGraphNode::RetrieveMessageFromMSN(const std::string &msg_name, uint32_t timeout) {
403 return RetrieveMessageFromMSN(msg_name, msg_name);
404 }
405
PutMetadata(const std::string & name,const std::string & value,bool sync)406 bool ComputeGraphNode::PutMetadata(const std::string &name, const std::string &value, bool sync) {
407 MetadataMessage metadata;
408 metadata.set_name(name);
409 metadata.set_value(value);
410 return SendMessageToMSN(std::to_string(static_cast<int>(MessageName::kWriteMetadata)), metadata.SerializeAsString(),
411 sync);
412 }
413
PutMetadata(const std::string & name,const void * value,const size_t & size)414 bool ComputeGraphNode::PutMetadata(const std::string &name, const void *value, const size_t &size) {
415 MetadataMessage metadata;
416 metadata.set_name(name);
417 metadata.set_value(value, size);
418 return SendMessageToMSN(std::to_string(static_cast<int>(MessageName::kWriteMetadata)), metadata.SerializeAsString());
419 }
420
GetMetadata(const std::string & name,uint32_t)421 std::string ComputeGraphNode::GetMetadata(const std::string &name, uint32_t) {
422 MetadataMessage metadata;
423 metadata.set_name(name);
424
425 auto message = CreateMessage(meta_server_addr_.GetUrl(), std::to_string(static_cast<int>(MessageName::kReadMetadata)),
426 metadata.SerializeAsString());
427 MS_EXCEPTION_IF_NULL(message);
428
429 MS_EXCEPTION_IF_NULL(tcp_client_);
430 auto retval = tcp_client_->ReceiveSync(std::move(message));
431 if (retval != rpc::NULL_MSG && (retval->name == std::to_string(static_cast<int>(MessageName::kValidMetadata)))) {
432 (void)metadata.ParseFromArray(retval->body.c_str(), SizeToInt(retval->body.length()));
433 return metadata.value();
434 }
435 return "";
436 }
437
DeleteMetadata(const std::string & name,uint32_t)438 bool ComputeGraphNode::DeleteMetadata(const std::string &name, uint32_t) {
439 MetadataMessage metadata;
440 metadata.set_name(name);
441
442 auto message =
443 CreateMessage(meta_server_addr_.GetUrl(), std::to_string(static_cast<int>(MessageName::kDeleteMetadata)),
444 metadata.SerializeAsString());
445 MS_EXCEPTION_IF_NULL(message);
446
447 MS_EXCEPTION_IF_NULL(tcp_client_);
448 auto retval = tcp_client_->ReceiveSync(std::move(message));
449 if (retval != rpc::NULL_MSG && (retval->name == std::to_string(static_cast<int>(MessageName::kValidMetadata)))) {
450 return true;
451 } else {
452 return false;
453 }
454 }
455
// The transaction of the exchange process is as follows:
// step 1: RANK[0]       - Start the exchange process (set the START_EXCHANGE_META_${biz} flag);
// step 2: RANK[0-(N-1)] - Wait for the exchange to start (check that the START_EXCHANGE_META_${biz} flag is set);
// step 3: RANK[0-(N-1)] - Do the exchange (exchange the metadata through the meta server node);
// step 4: RANK[0-(N-1)] - Finish the exchange process (set the EXCHANGE_META_DONE_${RANK_ID} flag);
// step 5: RANK[0]       - Exit the exchange process (check all the EXCHANGE_META_DONE_${RANK_ID} flags &
//                         delete all the EXCHANGE_META_DONE_${RANK_ID} flags &
//                         delete all the metadata in results &
//                         delete the START_EXCHANGE_META_${biz} flag);
// step 6: RANK[0-(N-1)] - Exit the exchange process (check that the START_EXCHANGE_META_${biz} flag is deleted);
ExchangeMetadata(const std::string & biz,const size_t & rank_size,const std::vector<std::string> & names_prefix,const std::vector<std::string> & values,std::map<std::string,std::string> * results,uint32_t timeout)466 bool ComputeGraphNode::ExchangeMetadata(const std::string &biz, const size_t &rank_size,
467 const std::vector<std::string> &names_prefix,
468 const std::vector<std::string> &values,
469 std::map<std::string, std::string> *results, uint32_t timeout) {
470 std::unique_lock<std::shared_mutex> lock(exchange_meta_mutex_);
471 MS_ERROR_IF_NULL_W_RET_VAL(results, false);
472 MS_LOG(INFO) << "Start to exchange metadata for the biz: " << biz;
473 if (names_prefix.size() != values.size()) {
474 return false;
475 }
476 if (timeout == 0) {
477 return false;
478 }
479 bool success = false;
480
481 // step 1 set the start flag.
482 std::string meta_name = kStartExchangeMetaPrefix + biz;
483 if (rank_id_ == 0) {
484 EXECUTE_WITH_TIMEOUT(PutMetadata(meta_name, kMetaFlagValue), kExecuteInterval,
485 "Failed to set the metadata exchange flag " + meta_name + ".", success, timeout);
486 }
487 // step 2 check the start flag.
488 EXECUTE_WITH_EXPECTED(GetMetadata(meta_name), kMetaFlagValue, kExecuteInterval,
489 "Failed to check the metadata exchange flag " << meta_name << ".", timeout);
490 // step 3 exchange the metadata.
491 for (size_t i = 0; i < names_prefix.size(); ++i) {
492 auto name = names_prefix[i] + std::to_string(rank_id_);
493 auto value = values[i];
494 EXECUTE_WITH_TIMEOUT(PutMetadata(name, value), kExecuteInterval,
495 "Failed to put metadata name: " + name + ", value: " + value + ".", success, timeout);
496 }
497 for (size_t i = 0; i < rank_size; ++i) {
498 for (size_t j = 0; j < names_prefix.size(); ++j) {
499 auto other_name = names_prefix[j] + std::to_string(i);
500 while (true) {
501 auto other_value = GetMetadata(other_name);
502 if (other_value.length() > 0) {
503 (*results)[other_name] = other_value;
504 break;
505 } else {
506 MS_LOG(WARNING) << "Failed to get metadata " << other_name << " from rank " << i;
507 (void)sleep(kExecuteInterval);
508 }
509 }
510 }
511 }
512 // step 4 set the exchange done flag.
513 auto done = kExchangeMetaDonePrefix + std::to_string(rank_id_);
514 EXECUTE_WITH_TIMEOUT(PutMetadata(done, kMetaFlagValue), kExecuteInterval,
515 "Failed to set the metadata exchange done flag " + done + ".", success, timeout);
516 // step 5 check all node done and then clear the metadata in meta server and remove the start flag finally.
517 if (rank_id_ == 0) {
518 for (size_t i = 0; i < rank_size; ++i) {
519 auto other_done = kExchangeMetaDonePrefix + std::to_string(i);
520 EXECUTE_WITH_EXPECTED(
521 GetMetadata(other_done), kMetaFlagValue, kExecuteInterval,
522 "Failed to check the metadata exchange done flag " << other_done << " for rank " << i << ".", timeout);
523 }
524 for (size_t i = 0; i < rank_size; ++i) {
525 auto other_done = kExchangeMetaDonePrefix + std::to_string(i);
526 EXECUTE_WITH_TIMEOUT(DeleteMetadata(other_done), kExecuteInterval,
527 "Failed to delete the metadata exchange done flag " + other_done + ".", success, timeout);
528 }
529 for (auto iter = results->begin(); iter != results->end(); ++iter) {
530 auto delete_name = iter->first;
531 EXECUTE_WITH_TIMEOUT(DeleteMetadata(delete_name), kExecuteInterval,
532 "Failed to delete the metadata: " + delete_name + ".", success, timeout);
533 }
534 EXECUTE_WITH_TIMEOUT(DeleteMetadata(meta_name), kExecuteInterval,
535 "Failed to delete the metadata flag: " + meta_name + ".", success, timeout);
536 }
537
538 // step 6 check the exchange finish flag.
539 EXECUTE_WITH_EXPECTED(GetMetadata(meta_name), kMetaDeleteFlagValue, kExecuteInterval,
540 "Failed to check the metadata exchange flag " << meta_name << ".", timeout);
541 MS_LOG(INFO) << "The metadata exchange for the biz: " << biz << " has been completed";
542 return true;
543 }
544
GetHostNames(const std::string & role)545 std::vector<std::string> ComputeGraphNode::GetHostNames(const std::string &role) {
546 auto retval = RetrieveMessageFromMSN(std::to_string(static_cast<int>(MessageName::kGetHostNames)), role);
547 if (retval != nullptr) {
548 MS_LOG(INFO) << "Worker gets host names " << *retval;
549 nlohmann::json hostnames;
550 size_t retry_num = 60;
551 do {
552 try {
553 if (retval != nullptr) {
554 hostnames = nlohmann::json::parse(*retval);
555 } else {
556 MS_LOG(ERROR) << "Get hostnames from sched failed, receive empty message.";
557 }
558 break;
559 } catch (const std::exception &e) {
560 MS_LOG(ERROR) << "Worker failed to parse hostname json " << e.what() << ". Retry number: " << retry_num;
561 retval = RetrieveMessageFromMSN(std::to_string(static_cast<int>(MessageName::kGetHostNames)), role);
562 retry_num--;
563 (void)sleep(kExecuteInterval);
564 }
565 } while (retry_num != 0);
566 MS_LOG(DEBUG) << "Successfully get hostnames from scheduler: " << hostnames.dump();
567 return hostnames.at(kHostNames).get<std::vector<std::string>>();
568 } else {
569 return std::vector<std::string>();
570 }
571 }
572
set_abnormal_callback(std::shared_ptr<std::function<void (void)>> abnormal_callback)573 void ComputeGraphNode::set_abnormal_callback(std::shared_ptr<std::function<void(void)>> abnormal_callback) {
574 abnormal_callback_ = abnormal_callback;
575 }
576
client_ip() const577 const std::string &ComputeGraphNode::client_ip() const { return client_ip_; }
578
RetrieveMessageFromMSN(const std::string & msg_name,const std::string & msg_body,uint32_t)579 std::shared_ptr<std::string> ComputeGraphNode::RetrieveMessageFromMSN(const std::string &msg_name,
580 const std::string &msg_body, uint32_t) {
581 MS_EXCEPTION_IF_NULL(tcp_client_);
582
583 auto message = CreateMessage(meta_server_addr_.GetUrl(), msg_name, msg_body);
584 MS_EXCEPTION_IF_NULL(message);
585
586 auto retval = tcp_client_->ReceiveSync(std::move(message));
587 if (retval != rpc::NULL_MSG) {
588 return std::make_shared<std::string>(retval->body);
589 }
590 return nullptr;
591 }
592 } // namespace topology
593 } // namespace cluster
594 } // namespace distributed
595 } // namespace mindspore
596