• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
18 #include <utility>
19 #include "frontend/parallel/device_manager.h"
20 #include "frontend/parallel/ops_info/ops_utils.h"
21 #include "include/common/utils/parallel_context.h"
22 
23 namespace mindspore {
24 namespace parallel {
Init(const TensorLayout & tensor_layout,const Map & out_tensor_map,RankList dev_list,bool is_cost_model,bool is_dynamic_shape)25 Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map,
26                                          RankList dev_list, bool is_cost_model, bool is_dynamic_shape) {
27   in_tensor_map_ = tensor_layout.tensor_map();
28   dev_mat_ = tensor_layout.device_arrangement();
29   if (!is_dynamic_shape &&
30       (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize())) {
31     MS_LOG(ERROR) << "Invalid input when initialize RedistributionOperatorInfer!";
32     return Status::FAILED;
33   }
34 
35   cur_tensor_layout_ = tensor_layout;
36   out_tensor_map_ = out_tensor_map;
37   dev_list_ = std::move(dev_list);
38 
39   operator_list_.clear();
40   operator_vector_.clear();
41   output_info_vector_.clear();
42   if (constructor_.Init(dev_list_, dev_mat_.array(), is_cost_model, is_dynamic_shape) != Status::SUCCESS) {
43     MS_LOG(ERROR) << "Init constructor failed";
44     return Status::FAILED;
45   }
46   if (virtual_rank_ >= 0) {
47     constructor_.SetVirtualRank(virtual_rank_);
48   }
49   constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
50 
51   size_t key = 0;
52   Shape map = in_tensor_map_.array();
53   for (int64_t item : map) {
54     map_[key++] = item;
55   }
56 
57   is_cost_model_ = is_cost_model;
58   return Status::SUCCESS;
59 }
60 
MergePartialToFullForReshapeHasMultiDynamicAxis()61 Status RedistributionOperatorInfer::MergePartialToFullForReshapeHasMultiDynamicAxis() {
62   for (size_t i = 0; i < this->in_tensor_map_.array().size(); ++i) {
63     int64_t matrix_index = this->in_tensor_map_.GetDimByIdx(i);
64     if (matrix_index == -1) {
65       continue;
66     }
67     int64_t shard_value = this->dev_mat_.GetDimByReverseIdx(LongToSize(matrix_index));
68     Args args = {
69       SizeToLong(i),  // TRANSFER_CONCAT_TENSOR_DIM_INDEX
70       matrix_index,   // TRANSFER_CONCAT_DEV_DIM_INDEX
71       shard_value     // TRANSFER_CONCAT_SPLIT_COUNT_INDEX
72     };
73     if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
74       return Status::FAILED;
75     }
76   }
77   return Status::SUCCESS;
78 }
79 
SegmentFullShapeToPartial()80 Status RedistributionOperatorInfer::SegmentFullShapeToPartial() {
81   // According to out layout tensor map, insert split.
82   for (size_t i = 0; i < this->out_tensor_map_.array().size(); ++i) {
83     int64_t matrix_index = this->out_tensor_map_.GetDimByIdx(i);
84     if (matrix_index == -1) {
85       continue;
86     }
87     constructor_.UpdateTensorShape(cur_tensor_layout_.tensor_shape().array());
88     // Insert Split on each dim.
89     Args args = {dev_mat_.GetDimByReverseIdx(LongToSize(matrix_index)), SizeToLong(i), matrix_index};
90     if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) {
91       MS_LOG(ERROR) << "Insert SplitByAxis Error!";
92       return Status::FAILED;
93     }
94   }
95   return Status::SUCCESS;
96 }
97 
// Drives the inference to a fixpoint: repeatedly apply split / permute /
// concat passes until the pending-dimension map is empty. Each inner loop
// runs its pass until it stops producing new operators (detected by
// comparing operator_list_ size before and after).
Status RedistributionOperatorInfer::InferRedistributionOperator() {
  this->constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
  while (!map_.empty()) {
    // Operator count at the start of this outer round; used below to detect
    // that no pass made progress.
    size_t len_global = operator_list_.size();

    while (!map_.empty()) {
      size_t len_split_by_axis = operator_list_.size();
      // split_by_axis operation
      if (InferSplitByAxis() == Status::FAILED) {
        return Status::FAILED;
      }
      // permute_by_axis operation
      while (!map_.empty()) {
        size_t len_permute_by_axis = operator_list_.size();
        if (InferPermuteByAxis() == Status::FAILED) {
          return Status::FAILED;
        }
        if (len_permute_by_axis == operator_list_.size()) {
          break;  // permute pass converged
        }
      }
      if (len_split_by_axis == operator_list_.size()) {
        break;  // split+permute passes converged
      }
    }
    // concat_by_axis operation
    if (InferConcatByAxis() == Status::FAILED) {
      return Status::FAILED;
    }
    // break loop structure with concat_by_axis: if a full round produced no
    // operator but work remains, force progress by gathering the first
    // pending dimension (marking it unsharded) to break the cyclic
    // dependency among the remaining dimensions.
    if (len_global == operator_list_.size() && !map_.empty()) {
      size_t index = map_.begin()->first;
      int64_t in_dim = map_[index];
      map_[index] = NONE;
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
  }
  return Status::SUCCESS;
}
140 
// One split pass over the pending map: drop dimensions that already match the
// output layout, and insert a SplitByAxis for every dimension that is
// currently unsharded (NONE) but sharded in the output layout — provided no
// other pending dimension still occupies that device-matrix dim (that case
// needs a permute first).
Status RedistributionOperatorInfer::InferSplitByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      // Already consistent with the output layout; this dimension is done.
      iter = map_.erase(iter);
      continue;
    }
    if (in_dim == NONE &&
        !std::any_of(map_.begin(), map_.end(),
                     [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      // Args: {split count, tensor dim to split, target device-matrix dim}.
      Args args = {dev_mat_.GetDimByReverseIdx(LongToUlong(out_dim)), UlongToLong(index), out_dim};
      if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert SplitByAxis Error!";
        return Status::FAILED;
      }
      iter = map_.erase(iter);
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}
165 
// One permute pass: handles a dimension that is unsharded (NONE) in the
// current layout but whose target device-matrix dim is still held by another
// tensor dimension (cat_dim). The shard is moved either with a single
// AlltoAll (when enabled) or with an AllConcat followed by an AllSplit.
// The donor dimension cat_dim is re-queued as NONE so later passes can
// finish redistributing it.
Status RedistributionOperatorInfer::InferPermuteByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      // Already consistent with the output layout; this dimension is done.
      iter = map_.erase(iter);
      continue;
    }
    if (in_dim == NONE &&
        std::any_of(map_.begin(), map_.end(),
                    [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      // cat_dim: the tensor dimension that currently owns out_dim's shard.
      int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
      int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim));
      if (ParallelContext::GetInstance()->enable_all2all() && !ParallelContext::GetInstance()->do_transform()) {
        // Single AlltoAll moves the shard from cat_dim to this dimension.
        int64_t dev_dim = in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim));
        Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim,
                              dev_num};
        if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
          return Status::FAILED;
        }
      } else {
        // Fallback: gather cat_dim fully, then split this dimension.
        Args args_allconcat = {cat_dim, out_dim, dev_num};
        Args args_allsplit = {dev_num, UlongToLong(index), out_dim};
        if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
          return Status::FAILED;
        }
        if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert SplitByAxis Error!";
          return Status::FAILED;
        }
      }
      iter = map_.erase(iter);
      // Re-queue the donor dimension: it is now unsharded and may still need
      // further redistribution in a later pass.
      map_[LongToSize(cat_dim)] = NONE;
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}
208 
// One concat pass: any dimension still sharded on a device-matrix dim that
// the output layout does not use at all is gathered (unsharded). The entry
// is removed when its output mapping is NONE (fully resolved), otherwise it
// stays queued with value NONE so a later split pass can shard it correctly.
Status RedistributionOperatorInfer::InferConcatByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) {
      // Args: {tensor dim, device-matrix dim, split count}.
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
        return Status::FAILED;
      }
      if (out_dim == NONE) {
        iter = map_.erase(iter);
      } else {
        iter->second = NONE;
        (void)++iter;
      }
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}
232 
233 // Transfer communicative operators into primitives and insert them into vector
InsertOperator(const OperatorName & name,const Args & args)234 Status RedistributionOperatorInfer::InsertOperator(const OperatorName &name, const Args &args) {
235   OperatorR op = std::make_pair(name, args);
236   OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array());
237   operator_list_.push_back(op_cost);
238   if (construct_op_flag_) {
239     if (name == SPLIT_BY_AXIS) {
240       if (TransferSplitByAxis(args) == Status::FAILED) {
241         return Status::FAILED;
242       }
243     } else if (name == PERMUTE_BY_AXIS) {
244       if (TransferPermuteByAxis(args) == Status::FAILED) {
245         return Status::FAILED;
246       }
247     } else {
248       if (TransferConcatByAxis(args) == Status::FAILED) {
249         return Status::FAILED;
250       }
251     }
252     constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
253   }
254   return Status::SUCCESS;
255 }
256 
TransferSplitByAxis(const Args & args)257 Status RedistributionOperatorInfer::TransferSplitByAxis(const Args &args) {
258   if (args.size() < TRANSFER_SPLIT_ARGS_SIZE) {
259     MS_LOG(ERROR) << "args size should not be less than 3!";
260     return Status::FAILED;
261   }
262   size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
263   if (constructor_.StridedSliceOP(args) != Status::SUCCESS) {
264     return Status::FAILED;
265   } else {
266     operator_vector_.push_back(constructor_.GetOperator());
267     output_info_vector_.push_back(std::make_pair(false, 0));
268   }
269   if (cur_tensor_layout_.UpdateTensorMap(index, args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]) == Status::FAILED) {
270     return Status::FAILED;
271   }
272   return Status::SUCCESS;
273 }
274 
TransferPermuteByAxis(const Args & args)275 Status RedistributionOperatorInfer::TransferPermuteByAxis(const Args &args) {
276   if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
277     MS_LOG(ERROR) << "args size should not be less than 5!";
278     return Status::FAILED;
279   }
280   if (constructor_.AlltoAllOP(args) != Status::SUCCESS) {
281     return Status::FAILED;
282   } else {
283     operator_vector_.push_back(constructor_.GetOperator());
284     output_info_vector_.push_back(std::make_pair(false, 0));
285   }
286   size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
287   int64_t val = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
288   int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
289 
290   if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) {
291     return Status::FAILED;
292   }
293   if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) {
294     return Status::FAILED;
295   }
296   return Status::SUCCESS;
297 }
298 
TransferConcatByAxis(const Args & args)299 Status RedistributionOperatorInfer::TransferConcatByAxis(const Args &args) {
300   if (args.size() < TRANSFER_CONCAT_ARGS_SIZE) {
301     MS_LOG(ERROR) << "args size should not be less than 3!";
302     return Status::FAILED;
303   }
304   int64_t tensor_dim = args[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
305   int64_t dev_dim = args[TRANSFER_CONCAT_DEV_DIM_INDEX];
306   int64_t split_count = args[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
307   if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) {
308     return Status::FAILED;
309   } else {
310     operator_vector_.push_back(constructor_.GetOperator());
311     (void)output_info_vector_.emplace_back(false, 0);
312   }
313   if (tensor_dim != 0) {
314     if (constructor_.SplitOP(split_count) != Status::SUCCESS) {
315       return Status::FAILED;
316     } else {
317       operator_vector_.push_back(constructor_.GetOperator());
318       (void)output_info_vector_.emplace_back(true, split_count);
319     }
320     if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) {
321       return Status::FAILED;
322     } else {
323       operator_vector_.push_back(constructor_.GetOperator());
324       (void)output_info_vector_.emplace_back(false, 0);
325     }
326   }
327   if (cur_tensor_layout_.UpdateTensorMap(LongToSize(tensor_dim), NONE) == Status::FAILED) {
328     return Status::FAILED;
329   }
330   return Status::SUCCESS;
331 }
332 }  // namespace parallel
333 }  // namespace mindspore
334