• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GROUPED_PAIRWISE_EXCHANGE_ALLTOALL_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GROUPED_PAIRWISE_EXCHANGE_ALLTOALL_H_
19 
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <string>
#include <vector>

#include "ir/anf.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/parallel_context.h"
#include "pipeline/jit/ps/resource.h"
28 
29 namespace mindspore {
30 namespace opt {
// The three functions below are only declared here; their definitions live outside this header.
// NOTE(review): presumably applies the grouped pairwise exchange all-to-all (GPEA) rewrite to the graph held by
// `resource` — confirm against the definition.
void SetGroupedPairwiseExchangeAllToAll(const pipeline::ResourcePtr &resource);
// Presumably the total number of devices (ranks) in the job — TODO confirm. GroupedPairwiseExchangeAllToAllInfo
// divides this value into GPEA_NUM groups.
size_t GetDeviceNum();
// Presumably the global rank id of the current process — TODO confirm. GroupedPairwiseExchangeAllToAllInfo uses it
// to compute which group the current rank belongs to.
size_t GetGlobalRankID();
34 
35 class GroupedPairwiseExchangeAllToAllInfo {
36  public:
GroupedPairwiseExchangeAllToAllInfo()37   GroupedPairwiseExchangeAllToAllInfo() {
38     SetGroupNum();
39     SetRanksPerGroup();
40     SetGroupRank();
41     SetSendGroupRanks();
42     SetSortedInputsIdx();
43     SetTotalSendRankIds();
44     SetTotalRecvRankIds();
45     SetReshapeScaleAxisVec();
46   }
47   ~GroupedPairwiseExchangeAllToAllInfo() = default;
GetGroupNum()48   int64_t GetGroupNum() const { return gpea_num_; }
GetRanksPerGroup()49   int64_t GetRanksPerGroup() const { return ranks_per_group_; }
GetGroupRank()50   int64_t GetGroupRank() const { return group_rank_; }
GetSendRankIds(int64_t step_id)51   std::vector<int64_t> GetSendRankIds(int64_t step_id) { return total_send_rank_ids_[step_id]; }
GetRecvRankIds(int64_t step_id)52   std::vector<int64_t> GetRecvRankIds(int64_t step_id) { return total_recv_rank_ids_[step_id]; }
GetSendGroupRanks()53   std::vector<int64_t> GetSendGroupRanks() { return send_group_ranks_; }
GetSortedInputsIdx()54   std::vector<int64_t> GetSortedInputsIdx() { return sorted_inputs_idx_; }
GetReshapeScaleAxisVec()55   std::vector<uint32_t> GetReshapeScaleAxisVec() { return reshape_scale_axis_vec_; }
56 
DisplayInfo()57   void DisplayInfo() {
58     MS_LOG(DEBUG) << "gpea_num_ " << GetGroupNum();
59     MS_LOG(DEBUG) << "ranks_per_group_ " << GetRanksPerGroup();
60     MS_LOG(DEBUG) << "group_rank_ " << GetGroupRank();
61     MS_LOG(DEBUG) << "send_group_ranks_ " << GetSendGroupRanks();
62     MS_LOG(DEBUG) << "sorted_inputs_idx_ " << GetSortedInputsIdx();
63     for (int64_t step = 0; step < GetGroupNum(); step++) {
64       MS_LOG(DEBUG) << "step " << step << " recv_rank_ids " << GetRecvRankIds(step);
65       MS_LOG(DEBUG) << "step " << step << " send_rank_ids " << GetSendRankIds(step);
66     }
67     MS_LOG(DEBUG) << "reshape_scale_axis_vec_ " << GetReshapeScaleAxisVec();
68   }
69 
70  private:
71   int64_t gpea_num_;
72   int64_t ranks_per_group_;
73   int64_t group_rank_;
74   std::vector<std::vector<int64_t>> total_send_rank_ids_;
75   std::vector<std::vector<int64_t>> total_recv_rank_ids_;
76   std::vector<int64_t> send_group_ranks_;
77   std::vector<int64_t> sorted_inputs_idx_;
78   std::vector<uint32_t> reshape_scale_axis_vec_;
79 
SetGroupNum()80   void SetGroupNum() {
81     // for example, env['GPEA_NUM'] = "1"
82     std::string gpea_num_str = common::GetEnv("GPEA_NUM");
83     gpea_num_ = 1;
84     if (!gpea_num_str.empty()) {
85       const int decimal = 10;
86       gpea_num_ = std::strtol(gpea_num_str.c_str(), nullptr, decimal);
87     }
88   }
89 
SetRanksPerGroup()90   void SetRanksPerGroup() { ranks_per_group_ = SizeToLong(GetDeviceNum()) / gpea_num_; }
91 
SetGroupRank()92   void SetGroupRank() { group_rank_ = SizeToLong(GetGlobalRankID()) / ranks_per_group_; }
93 
SetTotalSendRankIds()94   void SetTotalSendRankIds() {
95     for (int64_t step = 0; step < gpea_num_; step++) {
96       std::vector<int64_t> curr_rank_ids;
97       int64_t send_group_id = (group_rank_ + step) % gpea_num_;
98       for (int64_t i = 0; i < ranks_per_group_; i++) {
99         curr_rank_ids.push_back(send_group_id * ranks_per_group_ + i);
100       }
101       total_send_rank_ids_.push_back(curr_rank_ids);
102     }
103   }
104 
SetTotalRecvRankIds()105   void SetTotalRecvRankIds() {
106     for (int64_t step = 0; step < gpea_num_; step++) {
107       std::vector<int64_t> curr_rank_ids;
108       int64_t recv_group_id = (group_rank_ - step + gpea_num_) % gpea_num_;
109       for (int64_t i = 0; i < ranks_per_group_; i++) {
110         curr_rank_ids.push_back(recv_group_id * ranks_per_group_ + i);
111       }
112       total_recv_rank_ids_.push_back(curr_rank_ids);
113     }
114   }
115 
SetSendGroupRanks()116   void SetSendGroupRanks() {
117     for (int64_t step = 0; step < gpea_num_; step++) {
118       int64_t curr_group_rank = (group_rank_ + step) % gpea_num_;
119       send_group_ranks_.push_back(curr_group_rank);
120     }
121   }
122 
123   template <typename T>
sort_indexes(const std::vector<T> & v)124   std::vector<int64_t> sort_indexes(const std::vector<T> &v) const {
125     std::vector<int64_t> idx(v.size());
126     for (size_t i = 0; i < idx.size(); ++i) {
127       idx[i] = SizeToLong(i);
128     }
129     sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) { return v[i1] < v[i2]; });
130     return idx;
131   }
132 
SetSortedInputsIdx()133   void SetSortedInputsIdx() { sorted_inputs_idx_ = sort_indexes(send_group_ranks_); }
134 
SetReshapeScaleAxisVec()135   void SetReshapeScaleAxisVec() {
136     // for example, env['GPEA_RESHAPE_SCALE_AXIS'] = "2,1"
137     std::string reshape_scale_axis_str = common::GetEnv("GPEA_RESHAPE_SCALE_AXIS");
138     if (reshape_scale_axis_str.empty()) {
139       reshape_scale_axis_vec_ = {kIndex2, kIndex1};
140       return;
141     }
142 
143     std::string value_str;
144     std::vector<std::string> result_str_vec;
145     for (size_t i = 0; i < reshape_scale_axis_str.size(); ++i) {
146       if (reshape_scale_axis_str[i] == ',') {
147         result_str_vec.push_back(value_str);
148         value_str.clear();
149       } else {
150         value_str += reshape_scale_axis_str[i];
151       }
152     }
153     if (reshape_scale_axis_str.back() != ',') {
154       result_str_vec.push_back(value_str);
155     }
156 
157     for (size_t i = 0; i < result_str_vec.size(); i++) {
158       reshape_scale_axis_vec_.push_back(atoi(result_str_vec[i].c_str()));
159     }
160     return;
161   }
162 };
163 }  // namespace opt
164 }  // namespace mindspore
165 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GROUPED_PAIRWISE_EXCHANGE_ALLTOALL_H_
166