/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"

#include <utility>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/context.h"

namespace mindspore {
namespace parallel {
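// Initialize the inferer from the source layout and the target tensor map.
// map_ records, for each tensor dimension, the device-matrix dimension it is
// currently sharded on (NONE means unsharded); the Infer* passes below rewrite
// map_ entry by entry until every dimension matches out_tensor_map_.
//
// A minimal usage sketch (hypothetical call site; names are illustrative):
//   RedistributionOperatorInfer infer;
//   if (infer.Init(from_layout, to_layout.tensor_map(), dev_list, /*is_cost_model=*/false) == Status::SUCCESS) {
//     (void)infer.InferRedistributionOperator();
//   }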
Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map,
                                         RankList dev_list, bool is_cost_model) {
  in_tensor_map_ = tensor_layout.tensor_map();
  dev_mat_ = tensor_layout.device_arrangement();

  if (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize()) {
    MS_LOG(ERROR) << "Invalid input when initializing RedistributionOperatorInfer!";
    return Status::FAILED;
  }

  cur_tensor_layout_ = tensor_layout;
  out_tensor_map_ = out_tensor_map;
  dev_list_ = std::move(dev_list);

  operator_list_.clear();
  operator_vector_.clear();
  output_info_vector_.clear();

  if (constructor_.Init(dev_list_, dev_mat_.array()) != Status::SUCCESS) {
    MS_LOG(ERROR) << "Init constructor failed";
    return Status::FAILED;
  }
  constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());

  size_t key = 0;
  Shape map = in_tensor_map_.array();
  for (int64_t item : map) {
    map_[key++] = item;
  }

  is_cost_model_ = is_cost_model;
  return Status::SUCCESS;
}

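// Main driver: repeatedly apply split_by_axis and permute_by_axis until they
// reach a fixed point, then try concat_by_axis. If a full round inserts no
// operator while map_ is still non-empty, the remaining dimensions block each
// other; force a concat_by_axis on the first entry to break the cycle and
// iterate again.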
Status RedistributionOperatorInfer::InferRedistributionOperator() {
  while (!map_.empty()) {
    size_t len_global = operator_list_.size();

    while (!map_.empty()) {
      size_t len_split_by_axis = operator_list_.size();
      // split_by_axis operation
      if (InferSplitByAxis() == Status::FAILED) {
        return Status::FAILED;
      }
      // permute_by_axis operation
      while (!map_.empty()) {
        size_t len_permute_by_axis = operator_list_.size();
        if (InferPermuteByAxis() == Status::FAILED) {
          return Status::FAILED;
        }
        if (len_permute_by_axis == operator_list_.size()) break;
      }
      if (len_split_by_axis == operator_list_.size()) break;
    }
    // concat_by_axis operation
    if (InferConcatByAxis() == Status::FAILED) {
      return Status::FAILED;
    }
    // break loop structure with concat_by_axis
    if (len_global == operator_list_.size() && !map_.empty()) {
      size_t index = map_.begin()->first;
      int64_t in_dim = map_[index];
      map_[index] = NONE;
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
  }
  return Status::SUCCESS;
}

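// Shard a currently unsharded tensor dimension along the device-matrix
// dimension the output map requires, provided no other tensor dimension still
// occupies that device dimension. Dimensions that already match the output
// map are simply dropped from map_.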
Status RedistributionOperatorInfer::InferSplitByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      (void)map_.erase(iter++);
      continue;
    }
    if (in_dim == NONE &&
        !std::any_of(map_.begin(), map_.end(),
                     [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      Args args = {dev_mat_.GetDimByReverseIdx(LongToUlong(out_dim)), UlongToLong(index), out_dim};
      if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert SplitByAxis Error!";
        return Status::FAILED;
      }
      (void)map_.erase(iter++);
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

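// Move a shard between tensor dimensions: the output wants dimension `index`
// sharded on a device dimension currently held by another tensor dimension
// (cat_dim). With all-to-all enabled this is a single permute; otherwise it is
// emulated by a concat along cat_dim followed by a split of `index`. Either
// way, cat_dim becomes unsharded afterwards.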
Status RedistributionOperatorInfer::InferPermuteByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = map_[index];
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      (void)map_.erase(iter++);
      continue;
    }
    if (in_dim == NONE &&
        std::any_of(map_.begin(), map_.end(),
                    [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
      int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim));
      if (ParallelContext::GetInstance()->enable_all2all()) {
        int64_t dev_dim = in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim));
        Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim,
                              dev_num};
        if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
          return Status::FAILED;
        }
      } else {
        Args args_allconcat = {cat_dim, out_dim, dev_num};
        Args args_allsplit = {dev_num, UlongToLong(index), out_dim};
        if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
          return Status::FAILED;
        }
        if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert SplitByAxis Error!";
          return Status::FAILED;
        }
      }
      (void)map_.erase(iter++);
      map_[LongToSize(cat_dim)] = NONE;
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

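// Unshard tensor dimensions whose current device dimension is not used
// anywhere in the output map, by gathering them back (concat_by_axis).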
Status RedistributionOperatorInfer::InferConcatByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = map_[index];
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) {
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
        return Status::FAILED;
      }
      if (out_dim == NONE) {
        (void)map_.erase(iter++);
      } else {
        map_[index] = NONE;
        (void)++iter;
      }
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

// Transfer communication operators into primitives and insert them into the vector
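// operator_list_ keeps the abstract (name, args, shape) records for the cost
// model; operator_vector_ and output_info_vector_ receive the concrete
// primitives, which are only built when construct_op_flag_ is set.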
Status RedistributionOperatorInfer::InsertOperator(const OperatorName &name, const Args &args) {
  OperatorR op = std::make_pair(name, args);
  OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array());
  operator_list_.push_back(op_cost);
  if (construct_op_flag_) {
    if (name == SPLIT_BY_AXIS) {
      if (TransferSplitByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    } else if (name == PERMUTE_BY_AXIS) {
      if (TransferPermuteByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    } else {
      if (TransferConcatByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
    constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
  }
  return Status::SUCCESS;
}

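// Realize split_by_axis as a StridedSlice primitive and mark the tensor
// dimension as sharded on the target device dimension.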
Status RedistributionOperatorInfer::TransferSplitByAxis(const Args &args) {
  if (args.size() < TRANSFER_SPLIT_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
  if (constructor_.StridedSliceOP(args) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  if (cur_tensor_layout_.UpdateTensorMap(index, args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

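// Realize permute_by_axis as an AlltoAll primitive: the concat dimension gives
// up its shard (set to NONE) and the split dimension takes the output shard.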
Status RedistributionOperatorInfer::TransferPermuteByAxis(const Args &args) {
  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 5!";
    return Status::FAILED;
  }
  if (constructor_.AlltoAllOP(args) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
  int64_t val = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  int64_t out_dim = out_tensor_map_.GetDimByIdx(index);

  if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) {
    return Status::FAILED;
  }
  if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

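// Realize concat_by_axis as an AllGather along dev_dim. AllGather concatenates
// on axis 0, so when the logical concat axis (tensor_dim) is not 0, the result
// is rearranged with a Split into split_count blocks followed by a Concat
// along tensor_dim. The gathered tensor dimension then becomes unsharded.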
Status RedistributionOperatorInfer::TransferConcatByAxis(const Args &args) {
  if (args.size() < TRANSFER_CONCAT_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  int64_t tensor_dim = args[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_CONCAT_DEV_DIM_INDEX];
  int64_t split_count = args[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
  if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  if (tensor_dim != 0) {
    if (constructor_.SplitOP(split_count) != Status::SUCCESS) {
      return Status::FAILED;
    } else {
      operator_vector_.push_back(constructor_.GetOperator());
      output_info_vector_.push_back(std::make_pair(true, split_count));
    }
    if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) {
      return Status::FAILED;
    } else {
      operator_vector_.push_back(constructor_.GetOperator());
      output_info_vector_.push_back(std::make_pair(false, 0));
    }
  }
  if (cur_tensor_layout_.UpdateTensorMap(LongToSize(tensor_dim), NONE) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore