/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
#include <utility>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "include/common/utils/parallel_context.h"

namespace mindspore {
namespace parallel {
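// Initialize the inference with the input tensor layout, the target (output) tensor map, the device list and the
// cost-model/dynamic-shape flags. Clears any previously inferred operators and builds map_, which maps each tensor
// dimension index to the device dimension it is currently sharded on.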
Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map,
                                         RankList dev_list, bool is_cost_model, bool is_dynamic_shape) {
  in_tensor_map_ = tensor_layout.tensor_map();
  dev_mat_ = tensor_layout.device_arrangement();
  if (!is_dynamic_shape &&
      (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize())) {
    MS_LOG(ERROR) << "Invalid input when initializing RedistributionOperatorInfer!";
    return Status::FAILED;
  }

  cur_tensor_layout_ = tensor_layout;
  out_tensor_map_ = out_tensor_map;
  dev_list_ = std::move(dev_list);

  operator_list_.clear();
  operator_vector_.clear();
  output_info_vector_.clear();
  if (constructor_.Init(dev_list_, dev_mat_.array(), is_cost_model, is_dynamic_shape) != Status::SUCCESS) {
    MS_LOG(ERROR) << "Init constructor failed";
    return Status::FAILED;
  }
  if (virtual_rank_ >= 0) {
    constructor_.SetVirtualRank(virtual_rank_);
  }
  constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());

  size_t key = 0;
  Shape map = in_tensor_map_.array();
  for (int64_t item : map) {
    map_[key++] = item;
  }

  is_cost_model_ = is_cost_model;
  return Status::SUCCESS;
}

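// For a reshape whose input has multiple dynamic axes, gather every sharded input dimension back to the full shape:
// each mapped dimension of the input tensor map gets a ConcatByAxis (lowered to AllGather).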
Status RedistributionOperatorInfer::MergePartialToFullForReshapeHasMultiDynamicAxis() {
  for (size_t i = 0; i < this->in_tensor_map_.array().size(); ++i) {
    int64_t matrix_index = this->in_tensor_map_.GetDimByIdx(i);
    if (matrix_index == -1) {
      continue;
    }
    int64_t shard_value = this->dev_mat_.GetDimByReverseIdx(LongToSize(matrix_index));
    Args args = {
      SizeToLong(i),  // TRANSFER_CONCAT_TENSOR_DIM_INDEX
      matrix_index,   // TRANSFER_CONCAT_DEV_DIM_INDEX
      shard_value     // TRANSFER_CONCAT_SPLIT_COUNT_INDEX
    };
    if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
      return Status::FAILED;
    }
  }
  return Status::SUCCESS;
}

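// Re-shard the full tensor according to the output tensor map by inserting a SplitByAxis on every mapped output
// dimension.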
Status RedistributionOperatorInfer::SegmentFullShapeToPartial() {
  // According to the output layout tensor map, insert split operations.
  for (size_t i = 0; i < this->out_tensor_map_.array().size(); ++i) {
    int64_t matrix_index = this->out_tensor_map_.GetDimByIdx(i);
    if (matrix_index == -1) {
      continue;
    }
    constructor_.UpdateTensorShape(cur_tensor_layout_.tensor_shape().array());
    // Insert Split on each dim.
    Args args = {dev_mat_.GetDimByReverseIdx(LongToSize(matrix_index)), SizeToLong(i), matrix_index};
    if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) {
      MS_LOG(ERROR) << "Insert SplitByAxis Error!";
      return Status::FAILED;
    }
  }
  return Status::SUCCESS;
}

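// Main inference loop: repeatedly apply SplitByAxis, PermuteByAxis and ConcatByAxis until the current tensor map
// matches the output tensor map (map_ becomes empty). If a full pass makes no progress, force a ConcatByAxis on the
// first remaining dimension to break the cycle.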
Status RedistributionOperatorInfer::InferRedistributionOperator() {
  this->constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
  while (!map_.empty()) {
    size_t len_global = operator_list_.size();

    while (!map_.empty()) {
      size_t len_split_by_axis = operator_list_.size();
      // split_by_axis operation
      if (InferSplitByAxis() == Status::FAILED) {
        return Status::FAILED;
      }
      // permute_by_axis operation
      while (!map_.empty()) {
        size_t len_permute_by_axis = operator_list_.size();
        if (InferPermuteByAxis() == Status::FAILED) {
          return Status::FAILED;
        }
        if (len_permute_by_axis == operator_list_.size()) {
          break;
        }
      }
      if (len_split_by_axis == operator_list_.size()) {
        break;
      }
    }
    // concat_by_axis operation
    if (InferConcatByAxis() == Status::FAILED) {
      return Status::FAILED;
    }
    // break loop structure with concat_by_axis
    if (len_global == operator_list_.size() && !map_.empty()) {
      size_t index = map_.begin()->first;
      int64_t in_dim = map_[index];
      map_[index] = NONE;
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
  }
  return Status::SUCCESS;
}

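// Insert SplitByAxis for every tensor dimension that is unsharded in the current layout (in_dim == NONE) but sharded
// in the output layout, provided no other dimension currently occupies the target device dimension.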
Status RedistributionOperatorInfer::InferSplitByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      iter = map_.erase(iter);
      continue;
    }
    if (in_dim == NONE &&
        !std::any_of(map_.begin(), map_.end(),
                     [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      Args args = {dev_mat_.GetDimByReverseIdx(LongToUlong(out_dim)), UlongToLong(index), out_dim};
      if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert SplitByAxis Error!";
        return Status::FAILED;
      }
      iter = map_.erase(iter);
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

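// Handle dimensions whose target device dimension is currently occupied by another tensor dimension: either emit a
// single AlltoAll (when all2all is enabled), or an equivalent ConcatByAxis followed by SplitByAxis. The source
// dimension (cat_dim) is then marked as unsharded in map_.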
Status RedistributionOperatorInfer::InferPermuteByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      iter = map_.erase(iter);
      continue;
    }
    if (in_dim == NONE &&
        std::any_of(map_.begin(), map_.end(),
                    [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
      int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim));
      if (ParallelContext::GetInstance()->enable_all2all() && !ParallelContext::GetInstance()->do_transform()) {
        int64_t dev_dim = in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim));
        Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim,
                              dev_num};
        if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
          return Status::FAILED;
        }
      } else {
        Args args_allconcat = {cat_dim, out_dim, dev_num};
        Args args_allsplit = {dev_num, UlongToLong(index), out_dim};
        if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
          return Status::FAILED;
        }
        if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert SplitByAxis Error!";
          return Status::FAILED;
        }
      }
      iter = map_.erase(iter);
      map_[LongToSize(cat_dim)] = NONE;
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

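// Insert ConcatByAxis (AllGather) for every dimension that is sharded in the current layout but whose device
// dimension does not appear anywhere in the output tensor map.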
Status RedistributionOperatorInfer::InferConcatByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) {
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
        return Status::FAILED;
      }
      if (out_dim == NONE) {
        iter = map_.erase(iter);
      } else {
        iter->second = NONE;
        (void)++iter;
      }
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

// Transfer communication operators into primitives and insert them into the operator vector.
Status RedistributionOperatorInfer::InsertOperator(const OperatorName &name, const Args &args) {
  OperatorR op = std::make_pair(name, args);
  OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array());
  operator_list_.push_back(op_cost);
  if (construct_op_flag_) {
    if (name == SPLIT_BY_AXIS) {
      if (TransferSplitByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    } else if (name == PERMUTE_BY_AXIS) {
      if (TransferPermuteByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    } else {
      if (TransferConcatByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
    constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
  }
  return Status::SUCCESS;
}

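// Lower a SplitByAxis operator to a StridedSlice primitive and record the new device dimension of the split tensor
// dimension in the current tensor map.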
Status RedistributionOperatorInfer::TransferSplitByAxis(const Args &args) {
  if (args.size() < TRANSFER_SPLIT_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
  if (constructor_.StridedSliceOP(args) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  if (cur_tensor_layout_.UpdateTensorMap(index, args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

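// Lower a PermuteByAxis operator to an AlltoAll primitive: the concat dimension becomes unsharded and the split
// dimension takes the output device dimension in the current tensor map.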
Status RedistributionOperatorInfer::TransferPermuteByAxis(const Args &args) {
  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 5!";
    return Status::FAILED;
  }
  if (constructor_.AlltoAllOP(args) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
  int64_t val = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  int64_t out_dim = out_tensor_map_.GetDimByIdx(index);

  if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) {
    return Status::FAILED;
  }
  if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

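// Lower a ConcatByAxis operator to AllGather; when the concat axis is not dim 0, follow with Split and Concat to
// move the gathered slices onto the target axis, then mark that tensor dimension as unsharded.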
Status RedistributionOperatorInfer::TransferConcatByAxis(const Args &args) {
  if (args.size() < TRANSFER_CONCAT_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  int64_t tensor_dim = args[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_CONCAT_DEV_DIM_INDEX];
  int64_t split_count = args[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
  if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    (void)output_info_vector_.emplace_back(false, 0);
  }
  if (tensor_dim != 0) {
    if (constructor_.SplitOP(split_count) != Status::SUCCESS) {
      return Status::FAILED;
    } else {
      operator_vector_.push_back(constructor_.GetOperator());
      (void)output_info_vector_.emplace_back(true, split_count);
    }
    if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) {
      return Status::FAILED;
    } else {
      operator_vector_.push_back(constructor_.GetOperator());
      (void)output_info_vector_.emplace_back(false, 0);
    }
  }
  if (cur_tensor_layout_.UpdateTensorMap(LongToSize(tensor_dim), NONE) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore