1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GROUPED_PAIRWISE_EXCHANGE_ALLTOALL_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GROUPED_PAIRWISE_EXCHANGE_ALLTOALL_H_ 19 20 #include <vector> 21 #include <string> 22 #include <algorithm> 23 24 #include "ir/anf.h" 25 #include "include/common/utils/utils.h" 26 #include "include/common/utils/parallel_context.h" 27 #include "pipeline/jit/ps/resource.h" 28 29 namespace mindspore { 30 namespace opt { 31 void SetGroupedPairwiseExchangeAllToAll(const pipeline::ResourcePtr &resource); 32 size_t GetDeviceNum(); 33 size_t GetGlobalRankID(); 34 35 class GroupedPairwiseExchangeAllToAllInfo { 36 public: GroupedPairwiseExchangeAllToAllInfo()37 GroupedPairwiseExchangeAllToAllInfo() { 38 SetGroupNum(); 39 SetRanksPerGroup(); 40 SetGroupRank(); 41 SetSendGroupRanks(); 42 SetSortedInputsIdx(); 43 SetTotalSendRankIds(); 44 SetTotalRecvRankIds(); 45 SetReshapeScaleAxisVec(); 46 } 47 ~GroupedPairwiseExchangeAllToAllInfo() = default; GetGroupNum()48 int64_t GetGroupNum() const { return gpea_num_; } GetRanksPerGroup()49 int64_t GetRanksPerGroup() const { return ranks_per_group_; } GetGroupRank()50 int64_t GetGroupRank() const { return group_rank_; } GetSendRankIds(int64_t step_id)51 std::vector<int64_t> GetSendRankIds(int64_t step_id) { return total_send_rank_ids_[step_id]; } GetRecvRankIds(int64_t step_id)52 std::vector<int64_t> GetRecvRankIds(int64_t step_id) { return total_recv_rank_ids_[step_id]; } GetSendGroupRanks()53 std::vector<int64_t> GetSendGroupRanks() { return send_group_ranks_; } GetSortedInputsIdx()54 std::vector<int64_t> GetSortedInputsIdx() { return sorted_inputs_idx_; } GetReshapeScaleAxisVec()55 std::vector<uint32_t> GetReshapeScaleAxisVec() { return reshape_scale_axis_vec_; } 56 DisplayInfo()57 void DisplayInfo() { 58 MS_LOG(DEBUG) << "gpea_num_ " << GetGroupNum(); 59 MS_LOG(DEBUG) << "ranks_per_group_ " << GetRanksPerGroup(); 60 MS_LOG(DEBUG) << "group_rank_ " << GetGroupRank(); 61 MS_LOG(DEBUG) << "send_group_ranks_ " << GetSendGroupRanks(); 62 MS_LOG(DEBUG) << "sorted_inputs_idx_ " << GetSortedInputsIdx(); 63 for (int64_t step = 0; step < GetGroupNum(); step++) { 64 MS_LOG(DEBUG) << "step " << step << " recv_rank_ids " << GetRecvRankIds(step); 65 MS_LOG(DEBUG) << "step " << step << " send_rank_ids " << GetSendRankIds(step); 66 } 67 MS_LOG(DEBUG) << "reshape_scale_axis_vec_ " << GetReshapeScaleAxisVec(); 68 } 69 70 private: 71 int64_t gpea_num_; 72 int64_t ranks_per_group_; 73 int64_t group_rank_; 74 std::vector<std::vector<int64_t>> total_send_rank_ids_; 75 std::vector<std::vector<int64_t>> total_recv_rank_ids_; 76 std::vector<int64_t> send_group_ranks_; 77 std::vector<int64_t> sorted_inputs_idx_; 78 std::vector<uint32_t> reshape_scale_axis_vec_; 79 SetGroupNum()80 void SetGroupNum() { 81 // for example, env['GPEA_NUM'] = "1" 82 std::string gpea_num_str = common::GetEnv("GPEA_NUM"); 83 gpea_num_ = 1; 84 if (!gpea_num_str.empty()) { 85 const int decimal = 10; 86 gpea_num_ = std::strtol(gpea_num_str.c_str(), nullptr, decimal); 87 } 88 } 89 SetRanksPerGroup()90 void SetRanksPerGroup() { ranks_per_group_ = SizeToLong(GetDeviceNum()) / gpea_num_; } 91 SetGroupRank()92 void SetGroupRank() { group_rank_ = SizeToLong(GetGlobalRankID()) / ranks_per_group_; } 93 SetTotalSendRankIds()94 void SetTotalSendRankIds() { 95 for (int64_t step = 0; step < gpea_num_; step++) { 96 std::vector<int64_t> curr_rank_ids; 97 int64_t send_group_id = (group_rank_ + step) % gpea_num_; 98 for (int64_t i = 0; i < ranks_per_group_; i++) { 99 curr_rank_ids.push_back(send_group_id * ranks_per_group_ + i); 100 } 101 total_send_rank_ids_.push_back(curr_rank_ids); 102 } 103 } 104 SetTotalRecvRankIds()105 void SetTotalRecvRankIds() { 106 for (int64_t step = 0; step < gpea_num_; step++) { 107 std::vector<int64_t> curr_rank_ids; 108 int64_t recv_group_id = (group_rank_ - step + gpea_num_) % gpea_num_; 109 for (int64_t i = 0; i < ranks_per_group_; i++) { 110 curr_rank_ids.push_back(recv_group_id * ranks_per_group_ + i); 111 } 112 total_recv_rank_ids_.push_back(curr_rank_ids); 113 } 114 } 115 SetSendGroupRanks()116 void SetSendGroupRanks() { 117 for (int64_t step = 0; step < gpea_num_; step++) { 118 int64_t curr_group_rank = (group_rank_ + step) % gpea_num_; 119 send_group_ranks_.push_back(curr_group_rank); 120 } 121 } 122 123 template <typename T> sort_indexes(const std::vector<T> & v)124 std::vector<int64_t> sort_indexes(const std::vector<T> &v) const { 125 std::vector<int64_t> idx(v.size()); 126 for (size_t i = 0; i < idx.size(); ++i) { 127 idx[i] = SizeToLong(i); 128 } 129 sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) { return v[i1] < v[i2]; }); 130 return idx; 131 } 132 SetSortedInputsIdx()133 void SetSortedInputsIdx() { sorted_inputs_idx_ = sort_indexes(send_group_ranks_); } 134 SetReshapeScaleAxisVec()135 void SetReshapeScaleAxisVec() { 136 // for example, env['GPEA_RESHAPE_SCALE_AXIS'] = "2,1" 137 std::string reshape_scale_axis_str = common::GetEnv("GPEA_RESHAPE_SCALE_AXIS"); 138 if (reshape_scale_axis_str.empty()) { 139 reshape_scale_axis_vec_ = {kIndex2, kIndex1}; 140 return; 141 } 142 143 std::string value_str; 144 std::vector<std::string> result_str_vec; 145 for (size_t i = 0; i < reshape_scale_axis_str.size(); ++i) { 146 if (reshape_scale_axis_str[i] == ',') { 147 result_str_vec.push_back(value_str); 148 value_str.clear(); 149 } else { 150 value_str += reshape_scale_axis_str[i]; 151 } 152 } 153 if (reshape_scale_axis_str.back() != ',') { 154 result_str_vec.push_back(value_str); 155 } 156 157 for (size_t i = 0; i < result_str_vec.size(); i++) { 158 reshape_scale_axis_vec_.push_back(atoi(result_str_vec[i].c_str())); 159 } 160 return; 161 } 162 }; 163 } // namespace opt 164 } // namespace mindspore 165 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GROUPED_PAIRWISE_EXCHANGE_ALLTOALL_H_ 166