• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/distributed/collective/collective_manager.h"
18 #include <algorithm>
19 #include <string>
20 #include <numeric>
21 #include <vector>
22 #include <functional>
23 #include <csignal>
24 #include <memory>
25 #include "utils/ms_context.h"
26 #include "include/backend/distributed/recovery/recovery_context.h"
27 #include "distributed/persistent/storage/json_utils.h"
28 #include "runtime/collective/dummy_collective_communication_lib.h"
29 
30 namespace mindspore {
31 namespace distributed {
32 namespace collective {
33 using recovery::RecoveryContext;
34 
CollectiveManager()35 CollectiveManager::CollectiveManager()
36     : inited_(false),
37       finalized_(true),
38       need_init_(false),
39       need_reinit_(false),
40       host_ctx_(nullptr),
41       device_ctx_(nullptr),
42       host_comm_lib_instance_(nullptr),
43       device_comm_lib_instance_(nullptr),
44       comm_lib_instance_(nullptr),
45       global_rank_id_(0),
46       local_rank_id_(0),
47       global_rank_size_(1),
48       global_group_ranks_({}),
49       device_lib_supported_(true),
50       need_host_collective_(false) {}
51 
~CollectiveManager()52 CollectiveManager::~CollectiveManager() {
53   if (!finalized_) {
54     try {
55       (void)Finalize();
56     } catch (std::exception &) {
57       MS_LOG(ERROR) << "Failed to finalize collective manager.";
58     }
59   }
60   finalized_ = true;
61   host_ctx_ = nullptr;
62   device_ctx_ = nullptr;
63   host_comm_lib_instance_ = nullptr;
64   device_comm_lib_instance_ = nullptr;
65   comm_lib_instance_ = nullptr;
66 }
67 
instance()68 std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
69   static std::shared_ptr<CollectiveManager> instance = nullptr;
70   if (instance == nullptr) {
71     instance.reset(new (std::nothrow) CollectiveManager());
72     MS_EXCEPTION_IF_NULL(instance);
73   }
74   return instance;
75 }
76 
77 namespace {
78 // The wrapper to provide a timeout mechanism for executing functions.
79 // We also need to log the functionality of the function.
ExecuteFuncInThread(const std::function<bool ()> & func,const int64_t timeout,const std::string & func_name,const std::string & functionality)80 bool ExecuteFuncInThread(const std::function<bool()> &func, const int64_t timeout, const std::string &func_name,
81                          const std::string &functionality) {
82   bool execute_success = false;
83   bool execute_fail = false;
84   std::mutex exec_ret_mutex;
85   std::condition_variable thread_blocker;
86 
87   std::unique_ptr<std::thread> executive_thread = std::make_unique<std::thread>([&] {
88     if (!func()) {
89       MS_LOG(ERROR) << "Failed to execute function: " << func_name << " " << functionality
90                     << ". Please check error log above.";
91       std::unique_lock<std::mutex> lock(exec_ret_mutex);
92       execute_fail = true;
93       thread_blocker.notify_one();
94       return;
95     }
96 
97     {
98       std::unique_lock<std::mutex> lock(exec_ret_mutex);
99       execute_success = true;
100       thread_blocker.notify_one();
101     }
102   });
103   MS_EXCEPTION_IF_NULL(executive_thread);
104   executive_thread->detach();
105 
106   std::unique_lock<std::mutex> locker(exec_ret_mutex);
107   (void)thread_blocker.wait_for(locker, std::chrono::seconds(timeout), [&] { return execute_success || execute_fail; });
108 
109   if (!execute_success && !execute_fail) {
110     std::string node_id = common::GetEnv("MS_NODE_ID");
111 #if !defined(_WIN32) && !defined(_WIN64)
112     MS_LOG(ERROR) << "Execute function: " << func_name << " " << functionality << " timeout, this node id: " << node_id
113                   << " exit process";
114     (void)kill(getpid(), SIGTERM);
115 #endif
116   }
117   return execute_success;
118 }
119 
120 // In a disaster recovery scenario, the comparison between the current unique id and the last generated unique id
121 // ensures that the acquired unique id is newly generated, and the latest unique id will be persisted.
CheckUniqueIDLatest(const std::string & group_name,size_t root_info_size,const void * root_info)122 bool CheckUniqueIDLatest(const std::string &group_name, size_t root_info_size, const void *root_info) {
123   MS_EXCEPTION_IF_NULL(root_info);
124   auto persistent_json = RecoveryContext::GetInstance()->persistent_json();
125   MS_EXCEPTION_IF_NULL(persistent_json);
126 
127   std::string new_unique_id(static_cast<const char *>(root_info), root_info_size);
128   std::vector<int> new_unique_id_integer_seq;
129   (void)std::transform(new_unique_id.begin(), new_unique_id.end(), std::back_inserter(new_unique_id_integer_seq),
130                        [](char c) { return static_cast<int>(c); });
131 
132   const char unique_id_str[] = "_unique_id";
133   std::string unique_id_key = group_name + unique_id_str;
134   if (!persistent_json->Exists(unique_id_key)) {
135     persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
136     return true;
137   }
138 
139   std::vector<int> old_unique_id_integer_seq = persistent_json->Get<std::vector<int>>(unique_id_key);
140   if (new_unique_id_integer_seq == old_unique_id_integer_seq) {
141     return false;
142   }
143 
144   persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
145   return true;
146 }
147 }  // namespace
148 
Initialize()149 bool CollectiveManager::Initialize() {
150   need_init_ = true;
151   if (inited_ && !need_reinit_) {
152     return true;
153   }
154 
155   need_host_collective_ = common::UseHostCollective();
156   std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
157   // need_host_collective_ means using rank_table to initialize collective communication, which is only supported by
158   // Ascend. On other types of devices, exception should be thrown.
159   if (device_type != kAscendDevice && !need_host_collective_) {
160     MS_LOG(EXCEPTION) << kDetailedFailureReason;
161   }
162   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && !need_host_collective_) {
163     MS_LOG(EXCEPTION) << "Ranktable startup method doesn't support pynative mode. Please switch to msrun method.";
164   }
165 
166   MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type << "...";
167 
168   // Use dummy collective libs in simulation mode.
169   if (!common::GetEnv(kSimulationLevel).empty()) {
170     MS_LOG(WARNING) << "This is simulation mode with level " << common::GetEnv(kSimulationLevel)
171                     << ". Process's RANK_ID: " << common::GetEnv("RANK_ID")
172                     << ", RANK_SIZE: " << common::GetEnv("RANK_SIZE");
173 
174     return InitializeDummyCommLib();
175   }
176 
177   // Initialize real collective libs.
178   if (!need_host_collective_) {
179     RETURN_IF_FALSE_WITH_LOG(InitDeviceCommLib(), "Failed to initialize device communication library.");
180     comm_lib_instance_ = device_comm_lib_instance_;
181   } else {
182     // Step 1: Initialize host side collective communication.
183     RETURN_IF_FALSE_WITH_LOG(InitHostCommlib(), "Failed to initialize host communication library.");
184     comm_lib_instance_ = host_comm_lib_instance_;
185 
186     // Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will
187     // not be necessary.
188     // Step 2: Assign local rank id(device id) for this process.
189     RETURN_IF_FALSE_WITH_LOG(AssignLocalRank(), "Failed to assign local rank id.");
190 
191     // Step 3: Initialize device side collective communication.
192     RETURN_IF_FALSE_WITH_LOG(InitDeviceCommLib(), "Failed to initialize device communication library.");
193 
194     // Step 4: Create global communication group.
195     MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
196     auto group_name = device_comm_lib_instance_->global_group_name();
197     RETURN_IF_FALSE_WITH_LOG(CreateCommunicationGroup(group_name, global_group_ranks_),
198                              "Failed to create group " + group_name);
199   }
200 
201   MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type;
202   inited_ = true;
203   finalized_ = false;
204   need_reinit_ = false;
205   return true;
206 }
207 
InitializeDummyCommLib()208 bool CollectiveManager::InitializeDummyCommLib() {
209   dummy_comm_lib_instance_ = std::make_shared<device::DummyCollectiveCommunicationLib>();
210   comm_lib_instance_ = dummy_comm_lib_instance_.get();
211   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
212   RETURN_IF_FALSE_WITH_LOG(comm_lib_instance_->Initialize(0, 1, local_rank_id_),
213                            "Failed to initialize dummy communication library.");
214   global_rank_id_ = comm_lib_instance_->global_rank_id();
215   global_rank_size_ = comm_lib_instance_->global_rank_size();
216   MS_LOG(WARNING) << "Initializing dummy collective communication with rank size: " << global_rank_size_
217                   << ", rank id: " << global_rank_id_ << ". Real rank size: 1.";
218 
219   std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
220   // If this is Ascend backend and uses host collective(OpenMPI or Dynamic Cluster/msrun), initialize dummy ascend
221   // collective lib.
222   if (device_type == kAscendDevice) {
223     MS_LOG(WARNING) << "Initialize dummy Ascend collective communication lib.";
224     RETURN_IF_FALSE_WITH_LOG(InitDeviceCommLib(), "Failed to initialize dummy device communication library on Ascend.");
225   }
226   inited_ = true;
227   finalized_ = false;
228   need_reinit_ = false;
229   return true;
230 }
231 
FinalizeDummyCommLib()232 bool CollectiveManager::FinalizeDummyCommLib() {
233   std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
234   if (need_host_collective_ && device_type == kAscendDevice) {
235     MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
236     if (!device_comm_lib_instance_->Finalize()) {
237       MS_LOG(WARNING) << "Failed to finalize dummy device communication library.";
238     }
239   }
240   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
241   (void)comm_lib_instance_->Finalize();
242 
243   inited_ = false;
244   finalized_ = true;
245   need_init_ = false;
246   return true;
247 }
248 
GetLocalGroupRankAndSize(const std::vector<uint32_t> & group_ranks,uint32_t * local_group_rank,uint32_t * local_group_size)249 bool CollectiveManager::GetLocalGroupRankAndSize(const std::vector<uint32_t> &group_ranks, uint32_t *local_group_rank,
250                                                  uint32_t *local_group_size) {
251   MS_EXCEPTION_IF_NULL(local_group_rank);
252   MS_EXCEPTION_IF_NULL(local_group_size);
253   auto it =
254     std::find_if(group_ranks.begin(), group_ranks.end(), [&](uint32_t rank) { return rank > global_rank_size_; });
255   if (it != group_ranks.end()) {
256     MS_LOG(ERROR) << "The rank " << *it << "is out of global rank size.";
257     return false;
258   }
259   if (all_host_hashs_.size() != static_cast<size_t>(global_rank_size_)) {
260     MS_LOG(ERROR) << "The host hash size should be equal to global rank size " << global_rank_size_ << ", but got "
261                   << all_host_hashs_.size();
262     return false;
263   }
264   *local_group_size = static_cast<uint32_t>(std::count_if(group_ranks.begin(), group_ranks.end(), [&](uint32_t rank) {
265     return all_host_hashs_[rank] == all_host_hashs_[global_rank_id_];
266   }));
267   auto pos = find(group_ranks.begin(), group_ranks.end(), global_rank_id_);
268   if (pos == group_ranks.end()) {
269     *local_group_rank = UINT32_MAX;
270     return true;
271   }
272   *local_group_rank = static_cast<uint32_t>(std::count_if(group_ranks.begin(), pos, [&](uint32_t rank) {
273     return all_host_hashs_[rank] == all_host_hashs_[global_rank_id_];
274   }));
275   return true;
276 }
277 
CreateCommunicationGroup(const std::string & group_name,const std::vector<uint32_t> & group_ranks)278 bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
279                                                  const std::vector<uint32_t> &group_ranks) {
280   MS_LOG(WARNING) << "Start to create communication group: " << group_name << " " << group_ranks;
281   if (std::find(group_ranks.begin(), group_ranks.end(), global_rank_id_) == group_ranks.end()) {
282     MS_LOG(WARNING) << "This rank: " << global_rank_id_ << " is not in the group ranks: " << group_ranks
283                     << ". This may cause some exception when initializing the group.";
284   }
285   group_map_[group_name] = group_ranks;
286 
287   // Create simulation communication group.
288   if (!common::GetEnv(kSimulationLevel).empty()) {
289     return CreateSimulationGroup(group_name, group_ranks);
290   }
291 
292   MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
293   if (!need_host_collective_) {
294     RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->CreateDeviceCommunicationGroup(group_name, group_ranks),
295                              "Failed to create device communication group " + group_name);
296     return true;
297   }
298   uint32_t local_group_rank = 0;
299   uint32_t local_group_size = 0;
300   RETURN_IF_FALSE_WITH_LOG(GetLocalGroupRankAndSize(group_ranks, &local_group_rank, &local_group_size),
301                            "GetLocalGroupRankAndSize failed for group " + group_name);
302   MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
303   // Step 1: Create communication group on host side.
304   RETURN_IF_FALSE_WITH_LOG(
305     host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_group_rank, local_group_size),
306     "Failed to create host communication group" + group_name);
307 
308   // Step 2: Create communication group on device side.
309   RETURN_IF_FALSE_WITH_LOG(
310     device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_group_rank, local_group_size),
311     "Failed to create device communication group" + group_name);
312 
313   // Step 3: Generate device information of the root node.
314   CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
315   MS_EXCEPTION_IF_NULL(group);
316   size_t root_info_size = 0;
317   void *root_info = group->GenerateRootInfo(&root_info_size);
318   MS_EXCEPTION_IF_NULL(root_info);
319 
320   bool ret = false;
321   // Step 4: Broadcast the device root information to all nodes on host side.
322   while (!ret) {
323     RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->BroadcastUniqueID(group_name, root_info_size, root_info),
324                              "Broadcast for device root info failed on the host side.");
325     ret = true;
326     // In disaster recovery scenarios, it is necessary to ensure that the unique id obtained from the Scheduler is a
327     // newly generated one.
328     if (RecoveryContext::GetInstance()->enable_recovery()) {
329       ret = CheckUniqueIDLatest(group_name, root_info_size, root_info);
330       if (!ret) {
331         // The time interval for querying latest unique id from scheduler: 3 second.
332         constexpr uint32_t kWaitDuration = 3;
333         std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
334       }
335     }
336     MS_LOG(INFO) << "Successfully send/fetch unqiueid for communication group " << group_name;
337   }
338 
339   // Step 5: Initialize communication group on the device side.
340   std::function<bool()> init_device_comm_group_func = [&, this]() {
341     MS_EXCEPTION_IF_NULL(device_ctx_);
342     device_ctx_->Initialize();
343     return group->Initialize(root_info);
344   };
345   MS_LOG(WARNING) << "Begin initialize communication group on the device side: " << group_name;
346 
347   // Timeout limit in seconds to wait finish initializing device communication group.
348   int64_t comm_init_timout = GetCommunicatorInitTimeout();
349   MS_LOG(INFO) << "Communicator initializing timeout is " << comm_init_timout << " seconds.";
350   // Initialize communication group on the device side in thread with timeout limit.
351   ret = ExecuteFuncInThread(init_device_comm_group_func, comm_init_timout, "init_device_comm_group_func",
352                             "to initialize communicator for group " + group_name);
353   if (!ret) {
354     MS_LOG(ERROR) << "Failed to create comm group on device side for " << group_name;
355   }
356   MS_LOG(WARNING) << "End initialize communication group on the device side: " << group_name;
357   return ret;
358 }
359 
DestroyCommunicationGroup(const std::string & group_name)360 bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
361   MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
362   if (!need_host_collective_ || !common::GetEnv(kSimulationLevel).empty()) {
363     RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->DestroyDeviceCommunicationGroup(group_name),
364                              "Failed to destroy device communication group " + group_name);
365     return true;
366   }
367   MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
368   RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->DestroyCommunicationGroup(group_name),
369                            "Failed to destroy host communication group " + group_name);
370   RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->DestroyCommunicationGroup(group_name),
371                            "Failed to destroy device communication group " + group_name);
372   return true;
373 }
374 
GetRankId(const std::string & group_name)375 uint32_t CollectiveManager::GetRankId(const std::string &group_name) {
376   BY_PASS_SCHED_RANK_ID;
377   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
378   return comm_lib_instance_->GetRankId(group_name);
379 }
380 
GetGroupSize(const std::string & group_name)381 uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
382   BY_PASS_SCHED_RANK_SIZE;
383   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
384   return comm_lib_instance_->GetGroupSize(group_name);
385 }
386 
GetLocalRankId(const std::string & group_name)387 uint32_t CollectiveManager::GetLocalRankId(const std::string &group_name) {
388   BY_PASS_SCHED_RANK_ID;
389   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
390   return comm_lib_instance_->GetLocalRankId(group_name);
391 }
392 
GetLocalGroupSize(const std::string & group_name)393 uint32_t CollectiveManager::GetLocalGroupSize(const std::string &group_name) {
394   BY_PASS_SCHED_RANK_SIZE;
395   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
396   return comm_lib_instance_->GetLocalGroupSize(group_name);
397 }
398 
GetWorldRankFromGroupRank(const std::string & group_name,uint32_t local_rank)399 uint32_t CollectiveManager::GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) {
400   BY_PASS_SCHED_RANK_ID;
401   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
402   return comm_lib_instance_->GetWorldRankFromGroupRank(group_name, local_rank);
403 }
404 
GetGroupRankFromWorldRank(uint32_t global_rank,const std::string & group_name)405 uint32_t CollectiveManager::GetGroupRankFromWorldRank(uint32_t global_rank, const std::string &group_name) {
406   BY_PASS_SCHED_RANK_ID;
407   MS_EXCEPTION_IF_NULL(comm_lib_instance_);
408   return comm_lib_instance_->GetGroupRankFromWorldRank(global_rank, group_name);
409 }
410 
GetGroupRanks(const std::string & group_name)411 std::vector<uint32_t> CollectiveManager::GetGroupRanks(const std::string &group_name) {
412   const auto &group = comm_lib_instance_->GetGroup(group_name);
413   if (group == nullptr) {
414     MS_LOG(EXCEPTION) << "Group " << group_name << " doesn't include this rank " << global_rank_id_ << " process.";
415   }
416   return group->group_ranks();
417 }
418 
Finalize()419 bool CollectiveManager::Finalize() {
420   if (!inited_.load() || finalized_.load()) {
421     return true;
422   }
423 
424   if (!common::GetEnv(kSimulationLevel).empty() || dummy_comm_lib_instance_ != nullptr) {
425     return FinalizeDummyCommLib();
426   }
427 
428   std::function<bool()> finalize_comm_lib_func = [&, this]() {
429     if (need_host_collective_) {
430       MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
431       MS_LOG(INFO) << "Start finalizing host communication lib.";
432       if (!host_comm_lib_instance_->Finalize()) {
433         MS_LOG(WARNING) << "Failed to finalize device communication library.";
434       }
435       MS_LOG(INFO) << "End finalizing host communication lib.";
436     }
437 
438     MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
439 
440     MS_LOG(INFO) << "Start finalizing device communication lib.";
441     if (!device_comm_lib_instance_->Finalize()) {
442       MS_LOG(WARNING) << "Failed to finalize device communication library.";
443     }
444     MS_LOG(INFO) << "End finalizing device communication lib.";
445 
446     inited_ = false;
447     finalized_ = true;
448     need_init_ = false;
449     return true;
450   };
451 
452   MS_LOG(INFO) << "Begin finalize collective manager.";
453 
454   // Timeout limit 30 seconds to wait to finish finalizing device communication group.
455   const int64_t kTimeToWait = 30;
456   // Finalize collective manager in thread with timeout limit.
457   bool ret = ExecuteFuncInThread(finalize_comm_lib_func, kTimeToWait, "finalize_comm_lib_func",
458                                  "to destroy communication groups and finalize communication lib");
459 
460   MS_LOG(INFO) << "End finalize collective manager.";
461   return ret;
462 }
463 
set_global_rank_id(uint32_t global_rank_id)464 void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
465 
set_global_rank_size(uint32_t global_rank_size)466 void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
467 
global_rank_id() const468 uint32_t CollectiveManager::global_rank_id() const { return global_rank_id_; }
469 
local_rank_id() const470 uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; }
471 
InitHostCommlib()472 bool CollectiveManager::InitHostCommlib() {
473   device::DeviceContextKey host_key = {"CPU", 0};
474   host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
475   MS_EXCEPTION_IF_NULL(host_ctx_);
476   MS_EXCEPTION_IF_NULL(host_ctx_->device_res_manager_);
477   RETURN_IF_FALSE_WITH_LOG(host_ctx_->device_res_manager_->LoadCollectiveCommLib(),
478                            "Failed to load communication library on the host side.");
479 
480   host_comm_lib_instance_ = host_ctx_->device_res_manager_->collective_comm_lib();
481   MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
482 
483   // For some communication libraries, global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
484   // MindSpore communication. For other communication libraries, global rank id and size is generated by itself, e.g.,
485   // OpenMPI, and parameters 'global_rank_id_', 'global_rank_size_' will not be used.
486   MS_LOG(INFO) << "Start initializing communication library on host side...";
487   RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_),
488                            "Failed to initialize communication library on host side.");
489 
490   if (!global_group_ranks_.empty()) {
491     global_group_ranks_.clear();
492   }
493 
494   // Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
495   global_rank_id_ = host_comm_lib_instance_->global_rank_id();
496   global_rank_size_ = host_comm_lib_instance_->global_rank_size();
497   for (uint32_t i = 0; i < global_rank_size_; i++) {
498     global_group_ranks_.push_back(i);
499   }
500 
501   // Create world group on host side for AllGather operation of host name while assigning local rank.
502   host_global_group_name_ = host_comm_lib_instance_->global_group_name();
503   RETURN_IF_FALSE_WITH_LOG(
504     host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_, 0, 0),
505     "Failed to create host communication group " + host_global_group_name_);
506   MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
507                << ", global rank size: " << global_rank_size_;
508   return true;
509 }
510 
InitDeviceCommLib()511 bool CollectiveManager::InitDeviceCommLib() {
512   std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
513   // If library on device side is not supported, replace it with host library.
514   if (!device_lib_supported_) {
515     device_type = kCPUDevice;
516   }
517   device::DeviceContextKey device_key = {device_type, local_rank_id_};
518   device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
519   MS_EXCEPTION_IF_NULL(device_ctx_);
520   // We can initialize device context now because device id(local_rank_id_) is already assigned.
521   device_ctx_->Initialize();
522 
523   MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
524   RETURN_IF_FALSE_WITH_LOG(device_ctx_->device_res_manager_->LoadCollectiveCommLib(),
525                            "Failed to load communication library on the device side.");
526   device_comm_lib_instance_ = device_ctx_->device_res_manager_->collective_comm_lib();
527   MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
528 
529   MS_LOG(INFO) << "Start initializing communication library on device side...";
530   RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_, local_rank_id_),
531                            "Failed to initialize communication library on device side.");
532   MS_LOG(INFO) << "Communication library on device side is successfully initialized.";
533   return true;
534 }
535 
AssignLocalRank()536 bool CollectiveManager::AssignLocalRank() {
537   char host_name[MAX_HOSTNAME_LEN] = {0};
538 #ifndef _WIN32
539   if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
540     MS_LOG(ERROR) << "Failed to get host name.";
541     return false;
542   }
543 #endif
544   MS_LOG(INFO) << "Host name for rank " << global_rank_id_ << " is " << host_name;
545 
546   // Generate host name hash for every process. The host names of different physical machine should not be the same so
547   // that local rank id won't repeat.
548   size_t host_hash = std::hash<std::string>()(host_name);
549   const uint32_t kGlobalRankSize = global_rank_size_;
550   all_host_hashs_.resize(kGlobalRankSize);
551   if (global_rank_id_ >= global_rank_size_) {
552     MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
553                   << global_rank_size_;
554     return false;
555   }
556   all_host_hashs_[global_rank_id_] = host_hash;
557   // some case, call init("hccl"), though is one card case and DEVICE_ID is set by user.
558   if (global_rank_size_ <= 1) {
559     local_rank_id_ = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
560     return true;
561   }
562   MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
563   RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->AllGatherHostHashName(host_hash, &all_host_hashs_),
564                            "AllGather for host names failed.");
565   MS_LOG(INFO) << "Successfully get all nodes' hostname.";
566 
567   // Accumulate rank id.
568   // In disaster recovery scenario, this function will enter multiple times when the network is reconfigured, so old
569   // local rank id need to be cleaned.
570   std::vector<uint32_t> world_ranks(global_rank_size_);
571   std::iota(world_ranks.begin(), world_ranks.end(), 0);
572   uint32_t local_group_size = 0;
573   RETURN_IF_FALSE_WITH_LOG(GetLocalGroupRankAndSize(world_ranks, &local_rank_id_, &local_group_size),
574                            "GetLocalGroupRankAndSize for world group failed.");
575   host_comm_lib_instance_->SetLocalGroupRank(host_comm_lib_instance_->global_group_name(), local_rank_id_);
576   host_comm_lib_instance_->SetLocalGroupSize(host_comm_lib_instance_->global_group_name(), local_group_size);
577   // No need to reset device_id if library on device side is not supported, e.g., ascend.
578   if (device_lib_supported_) {
579     MsContext::GetInstance()->set_param_inner<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
580     MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
581                  << ". device_id of ms_context is set.";
582     common::SetEnv("RANK_ID", std::to_string(global_rank_id_).c_str());
583     common::SetEnv("DEVICE_ID", std::to_string(local_rank_id_).c_str());
584     common::SetEnv("RANK_SIZE", std::to_string(global_rank_size_).c_str());
585   }
586 
587   return true;
588 }
589 
CreateSimulationGroup(const std::string & group_name,const std::vector<uint32_t> & group_ranks)590 bool CollectiveManager::CreateSimulationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
591   // Set local rank id to 0 and local group size to 8 in simulation mode. These two values should not affect compiling.
592   uint32_t local_rank = 0;
593   uint32_t local_rank_size = 8;
594   MS_LOG(WARNING) << "Create dummy communication group with group name: " << group_name
595                   << ", group ranks: " << group_ranks << ". Real group size: 1.";
596   RETURN_IF_FALSE_WITH_LOG(
597     dummy_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_rank, local_rank_size),
598     "Failed to create dummy communication group " + group_name);
599 
600   std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
601   // If this is Ascend backend and uses host collective(OpenMPI or Dynamic Cluster/msrun), initialize real HCCL
602   // communicator through dummy Ascend collective communication lib.
603   if (device_type == kAscendDevice) {
604     MS_LOG(WARNING) << "Create Ascend communication group with group name: " << group_name
605                     << ", group ranks: " << group_ranks
606                     << ". Real HCCL communicator will be initialized with group size 1.";
607     RETURN_IF_FALSE_WITH_LOG(
608       device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_rank, local_rank_size),
609       "Failed to create dummy device communication group " + group_name);
610 
611     CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
612     size_t root_info_size = 0;
613     void *root_info = group->GenerateRootInfo(&root_info_size);
614     MS_EXCEPTION_IF_NULL(device_ctx_);
615     device_ctx_->Initialize();
616     auto ret = group->Initialize(root_info);
617     if (!ret) {
618       MS_LOG(ERROR) << "Failed to create comm group on device side for " << group_name;
619     }
620   }
621   return true;
622 }
623 
GetCommunicatorInitTimeout()624 int64_t CollectiveManager::GetCommunicatorInitTimeout() {
625   // The default timeout is 600 seconds.
626   int64_t default_comm_init_timeout = 600;
627   std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
628   if (device_type == kAscendDevice) {
629     std::string str_comm_init_timeout = common::GetEnv("HCCL_CONNECT_TIMEOUT");
630     return str_comm_init_timeout.empty() ? default_comm_init_timeout : std::stoi(str_comm_init_timeout);
631   }
632   return default_comm_init_timeout;
633 }
634 }  // namespace collective
635 }  // namespace distributed
636 }  // namespace mindspore
637