• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/collective/communication_group.h"
18 
19 namespace mindspore {
20 namespace device {
CommunicationGroup(const std::string & name,const std::vector<uint32_t> & group_ranks,uint32_t global_rank,uint32_t local_group_rank,uint32_t local_group_size)21 CommunicationGroup::CommunicationGroup(const std::string &name, const std::vector<uint32_t> &group_ranks,
22                                        uint32_t global_rank, uint32_t local_group_rank, uint32_t local_group_size)
23     : initialized_(false),
24       global_rank_(global_rank),
25       local_group_rank_(local_group_rank),
26       local_group_size_(local_group_size),
27       size_(group_ranks.size()),
28       name_(name),
29       group_ranks_(group_ranks) {
30   uint32_t group_rank = 0;
31   // The input group_ranks contains the global ranks of the processes in this group.
32   (void)std::for_each(group_ranks.begin(), group_ranks.end(), [&](const uint32_t &global_rank) {
33     global_to_group_ranks_[global_rank] = group_rank;
34     group_to_global_ranks_[group_rank] = global_rank;
35     group_rank++;
36   });
37 }
38 
GetGroupRank(uint32_t global_rank)39 uint32_t CommunicationGroup::GetGroupRank(uint32_t global_rank) {
40   CHECK_RET((global_to_group_ranks_.count(global_rank) != 0), true,
41             "Group " + name_ + " doesn't contain the global rank " + std::to_string(global_rank));
42   return global_to_group_ranks_[global_rank];
43 }
44 
GetLocalGroupRank()45 uint32_t CommunicationGroup::GetLocalGroupRank() {
46   CHECK_RET((local_group_rank_ != UINT32_MAX), true,
47             "Group " + name_ + " doesn't contain the global rank " + std::to_string(global_rank_));
48   return local_group_rank_;
49 }
50 
GetGlobalRank(uint32_t group_rank)51 uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
52   CHECK_RET((group_to_global_ranks_.count(group_rank) != 0), true,
53             "Group " + name_ + " doesn't contain the group rank " + std::to_string(group_rank));
54   return group_to_global_ranks_[group_rank];
55 }
56 
group_size() const57 uint32_t CommunicationGroup::group_size() const { return size_; }
58 
local_group_size() const59 uint32_t CommunicationGroup::local_group_size() const { return local_group_size_; }
60 
group_ranks() const61 const std::vector<uint32_t> &CommunicationGroup::group_ranks() const { return group_ranks_; }
62 
global_to_group_ranks() const63 const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; }
64 
group_to_global_ranks() const65 const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; }
66 }  // namespace device
67 }  // namespace mindspore
68