1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/collective/communication_group.h"
18
19 namespace mindspore {
20 namespace device {
CommunicationGroup(const std::string & name,const std::vector<uint32_t> & group_ranks,uint32_t global_rank,uint32_t local_group_rank,uint32_t local_group_size)21 CommunicationGroup::CommunicationGroup(const std::string &name, const std::vector<uint32_t> &group_ranks,
22 uint32_t global_rank, uint32_t local_group_rank, uint32_t local_group_size)
23 : initialized_(false),
24 global_rank_(global_rank),
25 local_group_rank_(local_group_rank),
26 local_group_size_(local_group_size),
27 size_(group_ranks.size()),
28 name_(name),
29 group_ranks_(group_ranks) {
30 uint32_t group_rank = 0;
31 // The input group_ranks contains the global ranks of the processes in this group.
32 (void)std::for_each(group_ranks.begin(), group_ranks.end(), [&](const uint32_t &global_rank) {
33 global_to_group_ranks_[global_rank] = group_rank;
34 group_to_global_ranks_[group_rank] = global_rank;
35 group_rank++;
36 });
37 }
38
GetGroupRank(uint32_t global_rank)39 uint32_t CommunicationGroup::GetGroupRank(uint32_t global_rank) {
40 CHECK_RET((global_to_group_ranks_.count(global_rank) != 0), true,
41 "Group " + name_ + " doesn't contain the global rank " + std::to_string(global_rank));
42 return global_to_group_ranks_[global_rank];
43 }
44
GetLocalGroupRank()45 uint32_t CommunicationGroup::GetLocalGroupRank() {
46 CHECK_RET((local_group_rank_ != UINT32_MAX), true,
47 "Group " + name_ + " doesn't contain the global rank " + std::to_string(global_rank_));
48 return local_group_rank_;
49 }
50
GetGlobalRank(uint32_t group_rank)51 uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
52 CHECK_RET((group_to_global_ranks_.count(group_rank) != 0), true,
53 "Group " + name_ + " doesn't contain the group rank " + std::to_string(group_rank));
54 return group_to_global_ranks_[group_rank];
55 }
56
group_size() const57 uint32_t CommunicationGroup::group_size() const { return size_; }
58
local_group_size() const59 uint32_t CommunicationGroup::local_group_size() const { return local_group_size_; }
60
group_ranks() const61 const std::vector<uint32_t> &CommunicationGroup::group_ranks() const { return group_ranks_; }
62
global_to_group_ranks() const63 const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; }
64
group_to_global_ranks() const65 const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; }
66 } // namespace device
67 } // namespace mindspore
68