1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/collective/collective_communication_lib.h"
18
19 namespace mindspore {
20 namespace device {
Finalize()21 bool CollectiveCommunicationLib::Finalize() {
22 if (!initialized_ || finalized_.load()) {
23 return true;
24 }
25
26 for (const auto &group : groups_) {
27 CHECK_IF_NULL(group.second);
28 if (!group.second->Finalize()) {
29 return false;
30 }
31 }
32 groups_.clear();
33 initialized_ = false;
34 finalized_ = true;
35 return true;
36 }
37
DestroyCommunicationGroup(const std::string & group_name)38 bool CollectiveCommunicationLib::DestroyCommunicationGroup(const std::string &group_name) {
39 if (groups_.count(group_name) == 0) {
40 return false;
41 }
42 auto group = groups_[group_name];
43 CHECK_IF_NULL(group);
44 if (!group->Finalize()) {
45 return false;
46 }
47 (void)groups_.erase(group_name);
48 return true;
49 }
50
GetRankId(const std::string & group_name)51 uint32_t CollectiveCommunicationLib::GetRankId(const std::string &group_name) {
52 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
53 auto group = groups_[group_name];
54 CHECK_IF_NULL(group);
55 return group->GetGroupRank(global_rank_id_);
56 }
57
GetGroupSize(const std::string & group_name)58 uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name) {
59 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
60 auto group = groups_[group_name];
61 CHECK_IF_NULL(group);
62 return group->group_size();
63 }
64
GetLocalRankId(const std::string & group_name)65 uint32_t CollectiveCommunicationLib::GetLocalRankId(const std::string &group_name) {
66 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
67 auto group = groups_[group_name];
68 CHECK_IF_NULL(group);
69 return group->GetLocalGroupRank();
70 }
71
GetLocalGroupSize(const std::string & group_name)72 uint32_t CollectiveCommunicationLib::GetLocalGroupSize(const std::string &group_name) {
73 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
74 auto group = groups_[group_name];
75 CHECK_IF_NULL(group);
76 return group->local_group_size();
77 }
78
GetWorldRankFromGroupRank(const std::string & group_name,uint32_t local_rank)79 uint32_t CollectiveCommunicationLib::GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) {
80 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
81 auto group = groups_[group_name];
82 CHECK_IF_NULL(group);
83 return group->GetGlobalRank(local_rank);
84 }
85
GetGroupRankFromWorldRank(uint32_t global_rank,const std::string & group_name)86 uint32_t CollectiveCommunicationLib::GetGroupRankFromWorldRank(uint32_t global_rank, const std::string &group_name) {
87 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
88 auto group = groups_[group_name];
89 CHECK_IF_NULL(group);
90 return group->GetGroupRank(global_rank);
91 }
92
GetGroup(const std::string & group_name)93 CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &group_name) {
94 if (groups_.count(group_name) == 0) {
95 return nullptr;
96 }
97 return groups_[group_name];
98 }
99
SetLocalGroupRank(const std::string & group_name,uint32_t local_rank_id)100 void CollectiveCommunicationLib::SetLocalGroupRank(const std::string &group_name, uint32_t local_rank_id) {
101 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
102 auto group = groups_[group_name];
103 CHECK_IF_NULL(group);
104 group->set_local_rank(local_rank_id);
105 }
106
SetLocalGroupSize(const std::string & group_name,uint32_t local_group_size)107 void CollectiveCommunicationLib::SetLocalGroupSize(const std::string &group_name, uint32_t local_group_size) {
108 CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
109 auto group = groups_[group_name];
110 CHECK_IF_NULL(group);
111 group->set_local_size(local_group_size);
112 }
113
global_group_name() const114 const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; }
115
global_rank_id() const116 uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }
117
local_rank_id() const118 uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }
119
global_rank_size() const120 uint32_t CollectiveCommunicationLib::global_rank_size() const { return global_rank_size_; }
121 } // namespace device
122 } // namespace mindspore
123