1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/backend/distributed/collective/collective_manager.h"
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <csignal>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <thread>
#include <vector>
#include "utils/ms_context.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "distributed/persistent/storage/json_utils.h"
#include "runtime/collective/dummy_collective_communication_lib.h"
29
30 namespace mindspore {
31 namespace distributed {
32 namespace collective {
33 using recovery::RecoveryContext;
34
CollectiveManager()35 CollectiveManager::CollectiveManager()
36 : inited_(false),
37 finalized_(true),
38 need_init_(false),
39 need_reinit_(false),
40 host_ctx_(nullptr),
41 device_ctx_(nullptr),
42 host_comm_lib_instance_(nullptr),
43 device_comm_lib_instance_(nullptr),
44 comm_lib_instance_(nullptr),
45 global_rank_id_(0),
46 local_rank_id_(0),
47 global_rank_size_(1),
48 global_group_ranks_({}),
49 device_lib_supported_(true),
50 need_host_collective_(false) {}
51
~CollectiveManager()52 CollectiveManager::~CollectiveManager() {
53 if (!finalized_) {
54 try {
55 (void)Finalize();
56 } catch (std::exception &) {
57 MS_LOG(ERROR) << "Failed to finalize collective manager.";
58 }
59 }
60 finalized_ = true;
61 host_ctx_ = nullptr;
62 device_ctx_ = nullptr;
63 host_comm_lib_instance_ = nullptr;
64 device_comm_lib_instance_ = nullptr;
65 comm_lib_instance_ = nullptr;
66 }
67
instance()68 std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
69 static std::shared_ptr<CollectiveManager> instance = nullptr;
70 if (instance == nullptr) {
71 instance.reset(new (std::nothrow) CollectiveManager());
72 MS_EXCEPTION_IF_NULL(instance);
73 }
74 return instance;
75 }
76
77 namespace {
78 // The wrapper to provide a timeout mechanism for executing functions.
79 // We also need to log the functionality of the function.
ExecuteFuncInThread(const std::function<bool ()> & func,const int64_t timeout,const std::string & func_name,const std::string & functionality)80 bool ExecuteFuncInThread(const std::function<bool()> &func, const int64_t timeout, const std::string &func_name,
81 const std::string &functionality) {
82 bool execute_success = false;
83 bool execute_fail = false;
84 std::mutex exec_ret_mutex;
85 std::condition_variable thread_blocker;
86
87 std::unique_ptr<std::thread> executive_thread = std::make_unique<std::thread>([&] {
88 if (!func()) {
89 MS_LOG(ERROR) << "Failed to execute function: " << func_name << " " << functionality
90 << ". Please check error log above.";
91 std::unique_lock<std::mutex> lock(exec_ret_mutex);
92 execute_fail = true;
93 thread_blocker.notify_one();
94 return;
95 }
96
97 {
98 std::unique_lock<std::mutex> lock(exec_ret_mutex);
99 execute_success = true;
100 thread_blocker.notify_one();
101 }
102 });
103 MS_EXCEPTION_IF_NULL(executive_thread);
104 executive_thread->detach();
105
106 std::unique_lock<std::mutex> locker(exec_ret_mutex);
107 (void)thread_blocker.wait_for(locker, std::chrono::seconds(timeout), [&] { return execute_success || execute_fail; });
108
109 if (!execute_success && !execute_fail) {
110 std::string node_id = common::GetEnv("MS_NODE_ID");
111 #if !defined(_WIN32) && !defined(_WIN64)
112 MS_LOG(ERROR) << "Execute function: " << func_name << " " << functionality << " timeout, this node id: " << node_id
113 << " exit process";
114 (void)kill(getpid(), SIGTERM);
115 #endif
116 }
117 return execute_success;
118 }
119
120 // In a disaster recovery scenario, the comparison between the current unique id and the last generated unique id
121 // ensures that the acquired unique id is newly generated, and the latest unique id will be persisted.
CheckUniqueIDLatest(const std::string & group_name,size_t root_info_size,const void * root_info)122 bool CheckUniqueIDLatest(const std::string &group_name, size_t root_info_size, const void *root_info) {
123 MS_EXCEPTION_IF_NULL(root_info);
124 auto persistent_json = RecoveryContext::GetInstance()->persistent_json();
125 MS_EXCEPTION_IF_NULL(persistent_json);
126
127 std::string new_unique_id(static_cast<const char *>(root_info), root_info_size);
128 std::vector<int> new_unique_id_integer_seq;
129 (void)std::transform(new_unique_id.begin(), new_unique_id.end(), std::back_inserter(new_unique_id_integer_seq),
130 [](char c) { return static_cast<int>(c); });
131
132 const char unique_id_str[] = "_unique_id";
133 std::string unique_id_key = group_name + unique_id_str;
134 if (!persistent_json->Exists(unique_id_key)) {
135 persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
136 return true;
137 }
138
139 std::vector<int> old_unique_id_integer_seq = persistent_json->Get<std::vector<int>>(unique_id_key);
140 if (new_unique_id_integer_seq == old_unique_id_integer_seq) {
141 return false;
142 }
143
144 persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
145 return true;
146 }
147 } // namespace
148
// Initializes collective communication for the configured device target.
//
// Returns true immediately if already initialized and no re-initialization was requested (need_reinit_).
// Raises via MS_LOG(EXCEPTION) when ranktable startup (!need_host_collective_) is used on a non-Ascend target or in
// PyNative mode.
// Launch paths:
// - Simulation mode (kSimulationLevel env var set): delegates entirely to InitializeDummyCommLib().
// - Ranktable launch: only the device communication library is initialized and used.
// - Host-collective launch: the host library, local rank assignment, device library and global device group are set up
//   in order.
// Returns false if any RETURN_IF_FALSE_WITH_LOG step fails. That project macro presumably logs the message and
// returns false; confirm against its definition. In that case inited_ is left unchanged.
bool CollectiveManager::Initialize() {
  need_init_ = true;
  if (inited_ && !need_reinit_) {
    return true;
  }

  need_host_collective_ = common::UseHostCollective();
  std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // !need_host_collective_ means using rank_table to initialize collective communication, which is only supported by
  // Ascend. On other types of devices, exception should be thrown.
  if (device_type != kAscendDevice && !need_host_collective_) {
    MS_LOG(EXCEPTION) << kDetailedFailureReason;
  }
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && !need_host_collective_) {
    MS_LOG(EXCEPTION) << "Ranktable startup method doesn't support pynative mode. Please switch to msrun method.";
  }

  MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type << "...";

  // Use dummy collective libs in simulation mode.
  if (!common::GetEnv(kSimulationLevel).empty()) {
    MS_LOG(WARNING) << "This is simulation mode with level " << common::GetEnv(kSimulationLevel)
                    << ". Process's RANK_ID: " << common::GetEnv("RANK_ID")
                    << ", RANK_SIZE: " << common::GetEnv("RANK_SIZE");

    return InitializeDummyCommLib();
  }

  // Initialize real collective libs.
  if (!need_host_collective_) {
    // Ranktable launch: the device library is the only communication library and serves all queries.
    RETURN_IF_FALSE_WITH_LOG(InitDeviceCommLib(), "Failed to initialize device communication library.");
    comm_lib_instance_ = device_comm_lib_instance_;
  } else {
    // Step 1: Initialize host side collective communication.
    RETURN_IF_FALSE_WITH_LOG(InitHostCommlib(), "Failed to initialize host communication library.");
    comm_lib_instance_ = host_comm_lib_instance_;

    // Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will
    // not be necessary.
    // Step 2: Assign local rank id(device id) for this process.
    RETURN_IF_FALSE_WITH_LOG(AssignLocalRank(), "Failed to assign local rank id.");

    // Step 3: Initialize device side collective communication.
    RETURN_IF_FALSE_WITH_LOG(InitDeviceCommLib(), "Failed to initialize device communication library.");

    // Step 4: Create global communication group.
    MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
    auto group_name = device_comm_lib_instance_->global_group_name();
    RETURN_IF_FALSE_WITH_LOG(CreateCommunicationGroup(group_name, global_group_ranks_),
                             "Failed to create group " + group_name);
  }

  MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type;
  inited_ = true;
  finalized_ = false;
  need_reinit_ = false;
  return true;
}
207
InitializeDummyCommLib()208 bool CollectiveManager::InitializeDummyCommLib() {
209 dummy_comm_lib_instance_ = std::make_shared<device::DummyCollectiveCommunicationLib>();
210 comm_lib_instance_ = dummy_comm_lib_instance_.get();
211 MS_EXCEPTION_IF_NULL(comm_lib_instance_);
212 RETURN_IF_FALSE_WITH_LOG(comm_lib_instance_->Initialize(0, 1, local_rank_id_),
213 "Failed to initialize dummy communication library.");
214 global_rank_id_ = comm_lib_instance_->global_rank_id();
215 global_rank_size_ = comm_lib_instance_->global_rank_size();
216 MS_LOG(WARNING) << "Initializing dummy collective communication with rank size: " << global_rank_size_
217 << ", rank id: " << global_rank_id_ << ". Real rank size: 1.";
218
219 std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
220 // If this is Ascend backend and uses host collective(OpenMPI or Dynamic Cluster/msrun), initialize dummy ascend
221 // collective lib.
222 if (device_type == kAscendDevice) {
223 MS_LOG(WARNING) << "Initialize dummy Ascend collective communication lib.";
224 RETURN_IF_FALSE_WITH_LOG(InitDeviceCommLib(), "Failed to initialize dummy device communication library on Ascend.");
225 }
226 inited_ = true;
227 finalized_ = false;
228 need_reinit_ = false;
229 return true;
230 }
231
FinalizeDummyCommLib()232 bool CollectiveManager::FinalizeDummyCommLib() {
233 std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
234 if (need_host_collective_ && device_type == kAscendDevice) {
235 MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
236 if (!device_comm_lib_instance_->Finalize()) {
237 MS_LOG(WARNING) << "Failed to finalize dummy device communication library.";
238 }
239 }
240 MS_EXCEPTION_IF_NULL(comm_lib_instance_);
241 (void)comm_lib_instance_->Finalize();
242
243 inited_ = false;
244 finalized_ = true;
245 need_init_ = false;
246 return true;
247 }
248
GetLocalGroupRankAndSize(const std::vector<uint32_t> & group_ranks,uint32_t * local_group_rank,uint32_t * local_group_size)249 bool CollectiveManager::GetLocalGroupRankAndSize(const std::vector<uint32_t> &group_ranks, uint32_t *local_group_rank,
250 uint32_t *local_group_size) {
251 MS_EXCEPTION_IF_NULL(local_group_rank);
252 MS_EXCEPTION_IF_NULL(local_group_size);
253 auto it =
254 std::find_if(group_ranks.begin(), group_ranks.end(), [&](uint32_t rank) { return rank > global_rank_size_; });
255 if (it != group_ranks.end()) {
256 MS_LOG(ERROR) << "The rank " << *it << "is out of global rank size.";
257 return false;
258 }
259 if (all_host_hashs_.size() != static_cast<size_t>(global_rank_size_)) {
260 MS_LOG(ERROR) << "The host hash size should be equal to global rank size " << global_rank_size_ << ", but got "
261 << all_host_hashs_.size();
262 return false;
263 }
264 *local_group_size = static_cast<uint32_t>(std::count_if(group_ranks.begin(), group_ranks.end(), [&](uint32_t rank) {
265 return all_host_hashs_[rank] == all_host_hashs_[global_rank_id_];
266 }));
267 auto pos = find(group_ranks.begin(), group_ranks.end(), global_rank_id_);
268 if (pos == group_ranks.end()) {
269 *local_group_rank = UINT32_MAX;
270 return true;
271 }
272 *local_group_rank = static_cast<uint32_t>(std::count_if(group_ranks.begin(), pos, [&](uint32_t rank) {
273 return all_host_hashs_[rank] == all_host_hashs_[global_rank_id_];
274 }));
275 return true;
276 }
277
CreateCommunicationGroup(const std::string & group_name,const std::vector<uint32_t> & group_ranks)278 bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
279 const std::vector<uint32_t> &group_ranks) {
280 MS_LOG(WARNING) << "Start to create communication group: " << group_name << " " << group_ranks;
281 if (std::find(group_ranks.begin(), group_ranks.end(), global_rank_id_) == group_ranks.end()) {
282 MS_LOG(WARNING) << "This rank: " << global_rank_id_ << " is not in the group ranks: " << group_ranks
283 << ". This may cause some exception when initializing the group.";
284 }
285 group_map_[group_name] = group_ranks;
286
287 // Create simulation communication group.
288 if (!common::GetEnv(kSimulationLevel).empty()) {
289 return CreateSimulationGroup(group_name, group_ranks);
290 }
291
292 MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
293 if (!need_host_collective_) {
294 RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->CreateDeviceCommunicationGroup(group_name, group_ranks),
295 "Failed to create device communication group " + group_name);
296 return true;
297 }
298 uint32_t local_group_rank = 0;
299 uint32_t local_group_size = 0;
300 RETURN_IF_FALSE_WITH_LOG(GetLocalGroupRankAndSize(group_ranks, &local_group_rank, &local_group_size),
301 "GetLocalGroupRankAndSize failed for group " + group_name);
302 MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
303 // Step 1: Create communication group on host side.
304 RETURN_IF_FALSE_WITH_LOG(
305 host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_group_rank, local_group_size),
306 "Failed to create host communication group" + group_name);
307
308 // Step 2: Create communication group on device side.
309 RETURN_IF_FALSE_WITH_LOG(
310 device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_group_rank, local_group_size),
311 "Failed to create device communication group" + group_name);
312
313 // Step 3: Generate device information of the root node.
314 CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
315 MS_EXCEPTION_IF_NULL(group);
316 size_t root_info_size = 0;
317 void *root_info = group->GenerateRootInfo(&root_info_size);
318 MS_EXCEPTION_IF_NULL(root_info);
319
320 bool ret = false;
321 // Step 4: Broadcast the device root information to all nodes on host side.
322 while (!ret) {
323 RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->BroadcastUniqueID(group_name, root_info_size, root_info),
324 "Broadcast for device root info failed on the host side.");
325 ret = true;
326 // In disaster recovery scenarios, it is necessary to ensure that the unique id obtained from the Scheduler is a
327 // newly generated one.
328 if (RecoveryContext::GetInstance()->enable_recovery()) {
329 ret = CheckUniqueIDLatest(group_name, root_info_size, root_info);
330 if (!ret) {
331 // The time interval for querying latest unique id from scheduler: 3 second.
332 constexpr uint32_t kWaitDuration = 3;
333 std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
334 }
335 }
336 MS_LOG(INFO) << "Successfully send/fetch unqiueid for communication group " << group_name;
337 }
338
339 // Step 5: Initialize communication group on the device side.
340 std::function<bool()> init_device_comm_group_func = [&, this]() {
341 MS_EXCEPTION_IF_NULL(device_ctx_);
342 device_ctx_->Initialize();
343 return group->Initialize(root_info);
344 };
345 MS_LOG(WARNING) << "Begin initialize communication group on the device side: " << group_name;
346
347 // Timeout limit in seconds to wait finish initializing device communication group.
348 int64_t comm_init_timout = GetCommunicatorInitTimeout();
349 MS_LOG(INFO) << "Communicator initializing timeout is " << comm_init_timout << " seconds.";
350 // Initialize communication group on the device side in thread with timeout limit.
351 ret = ExecuteFuncInThread(init_device_comm_group_func, comm_init_timout, "init_device_comm_group_func",
352 "to initialize communicator for group " + group_name);
353 if (!ret) {
354 MS_LOG(ERROR) << "Failed to create comm group on device side for " << group_name;
355 }
356 MS_LOG(WARNING) << "End initialize communication group on the device side: " << group_name;
357 return ret;
358 }
359
DestroyCommunicationGroup(const std::string & group_name)360 bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
361 MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
362 if (!need_host_collective_ || !common::GetEnv(kSimulationLevel).empty()) {
363 RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->DestroyDeviceCommunicationGroup(group_name),
364 "Failed to destroy device communication group " + group_name);
365 return true;
366 }
367 MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
368 RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->DestroyCommunicationGroup(group_name),
369 "Failed to destroy host communication group " + group_name);
370 RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->DestroyCommunicationGroup(group_name),
371 "Failed to destroy device communication group " + group_name);
372 return true;
373 }
374
// Returns this process's rank in `group_name` by delegating to the active communication library (comm_lib_instance_,
// chosen during initialization). Throws if no library has been set yet.
// BY_PASS_SCHED_RANK_ID is a project macro defined elsewhere; it presumably short-circuits the query for scheduler
// processes. Confirm against its definition.
uint32_t CollectiveManager::GetRankId(const std::string &group_name) {
  BY_PASS_SCHED_RANK_ID;
  MS_EXCEPTION_IF_NULL(comm_lib_instance_);
  return comm_lib_instance_->GetRankId(group_name);
}
380
// Returns the number of ranks in `group_name` by delegating to the active communication library (comm_lib_instance_).
// Throws if no library has been set yet.
// BY_PASS_SCHED_RANK_SIZE is a project macro defined elsewhere; it presumably short-circuits the query for scheduler
// processes. Confirm against its definition.
uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
  BY_PASS_SCHED_RANK_SIZE;
  MS_EXCEPTION_IF_NULL(comm_lib_instance_);
  return comm_lib_instance_->GetGroupSize(group_name);
}
386
// Returns this process's local (same-host) rank in `group_name` by delegating to the active communication library
// (comm_lib_instance_). Throws if no library has been set yet.
// BY_PASS_SCHED_RANK_ID is a project macro defined elsewhere; presumably a scheduler short-circuit. Confirm.
uint32_t CollectiveManager::GetLocalRankId(const std::string &group_name) {
  BY_PASS_SCHED_RANK_ID;
  MS_EXCEPTION_IF_NULL(comm_lib_instance_);
  return comm_lib_instance_->GetLocalRankId(group_name);
}
392
// Returns the number of ranks of `group_name` on this process's host by delegating to the active communication
// library (comm_lib_instance_). Throws if no library has been set yet.
// BY_PASS_SCHED_RANK_SIZE is a project macro defined elsewhere; presumably a scheduler short-circuit. Confirm.
uint32_t CollectiveManager::GetLocalGroupSize(const std::string &group_name) {
  BY_PASS_SCHED_RANK_SIZE;
  MS_EXCEPTION_IF_NULL(comm_lib_instance_);
  return comm_lib_instance_->GetLocalGroupSize(group_name);
}
398
// Delegates to comm_lib_instance_->GetWorldRankFromGroupRank(group_name, local_rank). Presumably this maps rank
// `local_rank` within `group_name` to its global (world) rank. Throws if no library has been set yet.
// BY_PASS_SCHED_RANK_ID is a project macro defined elsewhere; presumably a scheduler short-circuit. Confirm.
uint32_t CollectiveManager::GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) {
  BY_PASS_SCHED_RANK_ID;
  MS_EXCEPTION_IF_NULL(comm_lib_instance_);
  return comm_lib_instance_->GetWorldRankFromGroupRank(group_name, local_rank);
}
404
// Delegates to comm_lib_instance_->GetGroupRankFromWorldRank(global_rank, group_name). Presumably this maps the world
// rank `global_rank` to its rank within `group_name`. Throws if no library has been set yet.
// BY_PASS_SCHED_RANK_ID is a project macro defined elsewhere; presumably a scheduler short-circuit. Confirm.
uint32_t CollectiveManager::GetGroupRankFromWorldRank(uint32_t global_rank, const std::string &group_name) {
  BY_PASS_SCHED_RANK_ID;
  MS_EXCEPTION_IF_NULL(comm_lib_instance_);
  return comm_lib_instance_->GetGroupRankFromWorldRank(global_rank, group_name);
}
410
GetGroupRanks(const std::string & group_name)411 std::vector<uint32_t> CollectiveManager::GetGroupRanks(const std::string &group_name) {
412 const auto &group = comm_lib_instance_->GetGroup(group_name);
413 if (group == nullptr) {
414 MS_LOG(EXCEPTION) << "Group " << group_name << " doesn't include this rank " << global_rank_id_ << " process.";
415 }
416 return group->group_ranks();
417 }
418
Finalize()419 bool CollectiveManager::Finalize() {
420 if (!inited_.load() || finalized_.load()) {
421 return true;
422 }
423
424 if (!common::GetEnv(kSimulationLevel).empty() || dummy_comm_lib_instance_ != nullptr) {
425 return FinalizeDummyCommLib();
426 }
427
428 std::function<bool()> finalize_comm_lib_func = [&, this]() {
429 if (need_host_collective_) {
430 MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
431 MS_LOG(INFO) << "Start finalizing host communication lib.";
432 if (!host_comm_lib_instance_->Finalize()) {
433 MS_LOG(WARNING) << "Failed to finalize device communication library.";
434 }
435 MS_LOG(INFO) << "End finalizing host communication lib.";
436 }
437
438 MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
439
440 MS_LOG(INFO) << "Start finalizing device communication lib.";
441 if (!device_comm_lib_instance_->Finalize()) {
442 MS_LOG(WARNING) << "Failed to finalize device communication library.";
443 }
444 MS_LOG(INFO) << "End finalizing device communication lib.";
445
446 inited_ = false;
447 finalized_ = true;
448 need_init_ = false;
449 return true;
450 };
451
452 MS_LOG(INFO) << "Begin finalize collective manager.";
453
454 // Timeout limit 30 seconds to wait to finish finalizing device communication group.
455 const int64_t kTimeToWait = 30;
456 // Finalize collective manager in thread with timeout limit.
457 bool ret = ExecuteFuncInThread(finalize_comm_lib_func, kTimeToWait, "finalize_comm_lib_func",
458 "to destroy communication groups and finalize communication lib");
459
460 MS_LOG(INFO) << "End finalize collective manager.";
461 return ret;
462 }
463
set_global_rank_id(uint32_t global_rank_id)464 void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
465
set_global_rank_size(uint32_t global_rank_size)466 void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
467
global_rank_id() const468 uint32_t CollectiveManager::global_rank_id() const { return global_rank_id_; }
469
local_rank_id() const470 uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; }
471
InitHostCommlib()472 bool CollectiveManager::InitHostCommlib() {
473 device::DeviceContextKey host_key = {"CPU", 0};
474 host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
475 MS_EXCEPTION_IF_NULL(host_ctx_);
476 MS_EXCEPTION_IF_NULL(host_ctx_->device_res_manager_);
477 RETURN_IF_FALSE_WITH_LOG(host_ctx_->device_res_manager_->LoadCollectiveCommLib(),
478 "Failed to load communication library on the host side.");
479
480 host_comm_lib_instance_ = host_ctx_->device_res_manager_->collective_comm_lib();
481 MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
482
483 // For some communication libraries, global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
484 // MindSpore communication. For other communication libraries, global rank id and size is generated by itself, e.g.,
485 // OpenMPI, and parameters 'global_rank_id_', 'global_rank_size_' will not be used.
486 MS_LOG(INFO) << "Start initializing communication library on host side...";
487 RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_),
488 "Failed to initialize communication library on host side.");
489
490 if (!global_group_ranks_.empty()) {
491 global_group_ranks_.clear();
492 }
493
494 // Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
495 global_rank_id_ = host_comm_lib_instance_->global_rank_id();
496 global_rank_size_ = host_comm_lib_instance_->global_rank_size();
497 for (uint32_t i = 0; i < global_rank_size_; i++) {
498 global_group_ranks_.push_back(i);
499 }
500
501 // Create world group on host side for AllGather operation of host name while assigning local rank.
502 host_global_group_name_ = host_comm_lib_instance_->global_group_name();
503 RETURN_IF_FALSE_WITH_LOG(
504 host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_, 0, 0),
505 "Failed to create host communication group " + host_global_group_name_);
506 MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
507 << ", global rank size: " << global_rank_size_;
508 return true;
509 }
510
InitDeviceCommLib()511 bool CollectiveManager::InitDeviceCommLib() {
512 std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
513 // If library on device side is not supported, replace it with host library.
514 if (!device_lib_supported_) {
515 device_type = kCPUDevice;
516 }
517 device::DeviceContextKey device_key = {device_type, local_rank_id_};
518 device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
519 MS_EXCEPTION_IF_NULL(device_ctx_);
520 // We can initialize device context now because device id(local_rank_id_) is already assigned.
521 device_ctx_->Initialize();
522
523 MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
524 RETURN_IF_FALSE_WITH_LOG(device_ctx_->device_res_manager_->LoadCollectiveCommLib(),
525 "Failed to load communication library on the device side.");
526 device_comm_lib_instance_ = device_ctx_->device_res_manager_->collective_comm_lib();
527 MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
528
529 MS_LOG(INFO) << "Start initializing communication library on device side...";
530 RETURN_IF_FALSE_WITH_LOG(device_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_, local_rank_id_),
531 "Failed to initialize communication library on device side.");
532 MS_LOG(INFO) << "Communication library on device side is successfully initialized.";
533 return true;
534 }
535
AssignLocalRank()536 bool CollectiveManager::AssignLocalRank() {
537 char host_name[MAX_HOSTNAME_LEN] = {0};
538 #ifndef _WIN32
539 if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
540 MS_LOG(ERROR) << "Failed to get host name.";
541 return false;
542 }
543 #endif
544 MS_LOG(INFO) << "Host name for rank " << global_rank_id_ << " is " << host_name;
545
546 // Generate host name hash for every process. The host names of different physical machine should not be the same so
547 // that local rank id won't repeat.
548 size_t host_hash = std::hash<std::string>()(host_name);
549 const uint32_t kGlobalRankSize = global_rank_size_;
550 all_host_hashs_.resize(kGlobalRankSize);
551 if (global_rank_id_ >= global_rank_size_) {
552 MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
553 << global_rank_size_;
554 return false;
555 }
556 all_host_hashs_[global_rank_id_] = host_hash;
557 // some case, call init("hccl"), though is one card case and DEVICE_ID is set by user.
558 if (global_rank_size_ <= 1) {
559 local_rank_id_ = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
560 return true;
561 }
562 MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
563 RETURN_IF_FALSE_WITH_LOG(host_comm_lib_instance_->AllGatherHostHashName(host_hash, &all_host_hashs_),
564 "AllGather for host names failed.");
565 MS_LOG(INFO) << "Successfully get all nodes' hostname.";
566
567 // Accumulate rank id.
568 // In disaster recovery scenario, this function will enter multiple times when the network is reconfigured, so old
569 // local rank id need to be cleaned.
570 std::vector<uint32_t> world_ranks(global_rank_size_);
571 std::iota(world_ranks.begin(), world_ranks.end(), 0);
572 uint32_t local_group_size = 0;
573 RETURN_IF_FALSE_WITH_LOG(GetLocalGroupRankAndSize(world_ranks, &local_rank_id_, &local_group_size),
574 "GetLocalGroupRankAndSize for world group failed.");
575 host_comm_lib_instance_->SetLocalGroupRank(host_comm_lib_instance_->global_group_name(), local_rank_id_);
576 host_comm_lib_instance_->SetLocalGroupSize(host_comm_lib_instance_->global_group_name(), local_group_size);
577 // No need to reset device_id if library on device side is not supported, e.g., ascend.
578 if (device_lib_supported_) {
579 MsContext::GetInstance()->set_param_inner<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
580 MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
581 << ". device_id of ms_context is set.";
582 common::SetEnv("RANK_ID", std::to_string(global_rank_id_).c_str());
583 common::SetEnv("DEVICE_ID", std::to_string(local_rank_id_).c_str());
584 common::SetEnv("RANK_SIZE", std::to_string(global_rank_size_).c_str());
585 }
586
587 return true;
588 }
589
CreateSimulationGroup(const std::string & group_name,const std::vector<uint32_t> & group_ranks)590 bool CollectiveManager::CreateSimulationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
591 // Set local rank id to 0 and local group size to 8 in simulation mode. These two values should not affect compiling.
592 uint32_t local_rank = 0;
593 uint32_t local_rank_size = 8;
594 MS_LOG(WARNING) << "Create dummy communication group with group name: " << group_name
595 << ", group ranks: " << group_ranks << ". Real group size: 1.";
596 RETURN_IF_FALSE_WITH_LOG(
597 dummy_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_rank, local_rank_size),
598 "Failed to create dummy communication group " + group_name);
599
600 std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
601 // If this is Ascend backend and uses host collective(OpenMPI or Dynamic Cluster/msrun), initialize real HCCL
602 // communicator through dummy Ascend collective communication lib.
603 if (device_type == kAscendDevice) {
604 MS_LOG(WARNING) << "Create Ascend communication group with group name: " << group_name
605 << ", group ranks: " << group_ranks
606 << ". Real HCCL communicator will be initialized with group size 1.";
607 RETURN_IF_FALSE_WITH_LOG(
608 device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks, local_rank, local_rank_size),
609 "Failed to create dummy device communication group " + group_name);
610
611 CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
612 size_t root_info_size = 0;
613 void *root_info = group->GenerateRootInfo(&root_info_size);
614 MS_EXCEPTION_IF_NULL(device_ctx_);
615 device_ctx_->Initialize();
616 auto ret = group->Initialize(root_info);
617 if (!ret) {
618 MS_LOG(ERROR) << "Failed to create comm group on device side for " << group_name;
619 }
620 }
621 return true;
622 }
623
GetCommunicatorInitTimeout()624 int64_t CollectiveManager::GetCommunicatorInitTimeout() {
625 // The default timeout is 600 seconds.
626 int64_t default_comm_init_timeout = 600;
627 std::string device_type = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
628 if (device_type == kAscendDevice) {
629 std::string str_comm_init_timeout = common::GetEnv("HCCL_CONNECT_TIMEOUT");
630 return str_comm_init_timeout.empty() ? default_comm_init_timeout : std::stoi(str_comm_init_timeout);
631 }
632 return default_comm_init_timeout;
633 }
634 } // namespace collective
635 } // namespace distributed
636 } // namespace mindspore
637