• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "runtime/collective/collective_communication_lib.h"
18 
19 namespace mindspore {
20 namespace device {
Finalize()21 bool CollectiveCommunicationLib::Finalize() {
22   if (!initialized_ || finalized_.load()) {
23     return true;
24   }
25 
26   for (const auto &group : groups_) {
27     CHECK_IF_NULL(group.second);
28     if (!group.second->Finalize()) {
29       return false;
30     }
31   }
32   groups_.clear();
33   initialized_ = false;
34   finalized_ = true;
35   return true;
36 }
37 
DestroyCommunicationGroup(const std::string & group_name)38 bool CollectiveCommunicationLib::DestroyCommunicationGroup(const std::string &group_name) {
39   if (groups_.count(group_name) == 0) {
40     return false;
41   }
42   auto group = groups_[group_name];
43   CHECK_IF_NULL(group);
44   if (!group->Finalize()) {
45     return false;
46   }
47   (void)groups_.erase(group_name);
48   return true;
49 }
50 
GetRankId(const std::string & group_name)51 uint32_t CollectiveCommunicationLib::GetRankId(const std::string &group_name) {
52   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
53   auto group = groups_[group_name];
54   CHECK_IF_NULL(group);
55   return group->GetGroupRank(global_rank_id_);
56 }
57 
GetGroupSize(const std::string & group_name)58 uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name) {
59   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
60   auto group = groups_[group_name];
61   CHECK_IF_NULL(group);
62   return group->group_size();
63 }
64 
GetLocalRankId(const std::string & group_name)65 uint32_t CollectiveCommunicationLib::GetLocalRankId(const std::string &group_name) {
66   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
67   auto group = groups_[group_name];
68   CHECK_IF_NULL(group);
69   return group->GetLocalGroupRank();
70 }
71 
GetLocalGroupSize(const std::string & group_name)72 uint32_t CollectiveCommunicationLib::GetLocalGroupSize(const std::string &group_name) {
73   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
74   auto group = groups_[group_name];
75   CHECK_IF_NULL(group);
76   return group->local_group_size();
77 }
78 
GetWorldRankFromGroupRank(const std::string & group_name,uint32_t local_rank)79 uint32_t CollectiveCommunicationLib::GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) {
80   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
81   auto group = groups_[group_name];
82   CHECK_IF_NULL(group);
83   return group->GetGlobalRank(local_rank);
84 }
85 
GetGroupRankFromWorldRank(uint32_t global_rank,const std::string & group_name)86 uint32_t CollectiveCommunicationLib::GetGroupRankFromWorldRank(uint32_t global_rank, const std::string &group_name) {
87   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
88   auto group = groups_[group_name];
89   CHECK_IF_NULL(group);
90   return group->GetGroupRank(global_rank);
91 }
92 
GetGroup(const std::string & group_name)93 CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &group_name) {
94   if (groups_.count(group_name) == 0) {
95     return nullptr;
96   }
97   return groups_[group_name];
98 }
99 
SetLocalGroupRank(const std::string & group_name,uint32_t local_rank_id)100 void CollectiveCommunicationLib::SetLocalGroupRank(const std::string &group_name, uint32_t local_rank_id) {
101   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
102   auto group = groups_[group_name];
103   CHECK_IF_NULL(group);
104   group->set_local_rank(local_rank_id);
105 }
106 
SetLocalGroupSize(const std::string & group_name,uint32_t local_group_size)107 void CollectiveCommunicationLib::SetLocalGroupSize(const std::string &group_name, uint32_t local_group_size) {
108   CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
109   auto group = groups_[group_name];
110   CHECK_IF_NULL(group);
111   group->set_local_size(local_group_size);
112 }
113 
global_group_name() const114 const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; }
115 
global_rank_id() const116 uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }
117 
local_rank_id() const118 uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }
119 
global_rank_size() const120 uint32_t CollectiveCommunicationLib::global_rank_size() const { return global_rank_size_; }
121 }  // namespace device
122 }  // namespace mindspore
123