/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"

#include <utility>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/context.h"

namespace mindspore {
namespace parallel {
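// Initialize the inference state: record the input tensor map and device matrix, check that the
// input and output tensor maps have the same dimension count, reset any previously accumulated
// operators, and index every input dimension into map_ (NONE marks a dimension with no device split).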
Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map,
                                         RankList dev_list, bool is_cost_model) {
  in_tensor_map_ = tensor_layout.tensor_map();
  dev_mat_ = tensor_layout.device_arrangement();

  if (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize()) {
    MS_LOG(ERROR) << "Invalid input when initializing RedistributionOperatorInfer!";
    return Status::FAILED;
  }

  cur_tensor_layout_ = tensor_layout;
  out_tensor_map_ = out_tensor_map;
  dev_list_ = std::move(dev_list);

  operator_list_.clear();
  operator_vector_.clear();
  output_info_vector_.clear();

  if (constructor_.Init(dev_list_, dev_mat_.array()) != Status::SUCCESS) {
    MS_LOG(ERROR) << "Init constructor failed";
    return Status::FAILED;
  }
  constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());

  size_t key = 0;
  Shape map = in_tensor_map_.array();
  for (int64_t item : map) {
    map_[key++] = item;
  }

  is_cost_model_ = is_cost_model;
  return Status::SUCCESS;
}

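// Main inference loop. Each pass tries split_by_axis first, then permute_by_axis, then
// concat_by_axis, until map_ is empty (every dimension matches the output tensor map).
// If a full pass inserts no new operator but map_ is still non-empty, one remaining dimension is
// gathered unconditionally with concat_by_axis to break the dependency cycle.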
Status RedistributionOperatorInfer::InferRedistributionOperator() {
  while (!map_.empty()) {
    size_t len_global = operator_list_.size();

    while (!map_.empty()) {
      size_t len_split_by_axis = operator_list_.size();
      // split_by_axis operation
      if (InferSplitByAxis() == Status::FAILED) {
        return Status::FAILED;
      }
      // permute_by_axis operation
      while (!map_.empty()) {
        size_t len_permute_by_axis = operator_list_.size();
        if (InferPermuteByAxis() == Status::FAILED) {
          return Status::FAILED;
        }
        if (len_permute_by_axis == operator_list_.size()) break;
      }
      if (len_split_by_axis == operator_list_.size()) break;
    }
    // concat_by_axis operation
    if (InferConcatByAxis() == Status::FAILED) {
      return Status::FAILED;
    }
    // break loop structure with concat_by_axis
    if (len_global == operator_list_.size() && !map_.empty()) {
      size_t index = map_.begin()->first;
      int64_t in_dim = map_[index];
      map_[index] = NONE;
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
  }
  return Status::SUCCESS;
}

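// For each dimension that is currently unsplit (in_dim == NONE) but split in the output layout,
// insert a SPLIT_BY_AXIS operator, provided the target device dimension is not still occupied by
// another tensor dimension. Dimensions that already match the output map are simply dropped.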
Status RedistributionOperatorInfer::InferSplitByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = iter->second;
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      (void)map_.erase(iter++);
      continue;
    }
    if (in_dim == NONE &&
        !std::any_of(map_.begin(), map_.end(),
                     [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      Args args = {dev_mat_.GetDimByReverseIdx(LongToUlong(out_dim)), UlongToLong(index), out_dim};
      if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert SplitByAxis Error!";
        return Status::FAILED;
      }
      (void)map_.erase(iter++);
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

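// For each dimension that needs a device dimension currently held by a different tensor dimension
// (cat_dim), move the split: with all-to-all enabled this lowers to a single PERMUTE_BY_AXIS
// (AlltoAll); otherwise it is expressed as a CONCAT_BY_AXIS (gather from cat_dim) followed by a
// SPLIT_BY_AXIS on the target dimension. cat_dim is then marked as unsplit in map_.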
Status RedistributionOperatorInfer::InferPermuteByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = map_[index];
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim == out_dim) {
      (void)map_.erase(iter++);
      continue;
    }
    if (in_dim == NONE &&
        std::any_of(map_.begin(), map_.end(),
                    [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
      int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
      int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim));
      if (ParallelContext::GetInstance()->enable_all2all()) {
        int64_t dev_dim = in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim));
        Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim,
                              dev_num};
        if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
          return Status::FAILED;
        }
      } else {
        Args args_allconcat = {cat_dim, out_dim, dev_num};
        Args args_allsplit = {dev_num, UlongToLong(index), out_dim};
        if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
          return Status::FAILED;
        }
        if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
          MS_LOG(ERROR) << "Insert SplitByAxis Error!";
          return Status::FAILED;
        }
      }
      (void)map_.erase(iter++);
      map_[LongToSize(cat_dim)] = NONE;
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

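// For each dimension whose current device dimension no longer appears anywhere in the output
// tensor map, gather its slices back with a CONCAT_BY_AXIS operator and clear its device split.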
Status RedistributionOperatorInfer::InferConcatByAxis() {
  for (auto iter = map_.begin(); iter != map_.end();) {
    uint64_t index = iter->first;
    int64_t in_dim = map_[index];
    int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
    if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) {
      Args args = {SizeToLong(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
      if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
        MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
        return Status::FAILED;
      }
      if (out_dim == NONE) {
        (void)map_.erase(iter++);
      } else {
        map_[index] = NONE;
        (void)++iter;
      }
    } else {
      (void)++iter;
    }
  }
  return Status::SUCCESS;
}

// Transfer communication operators into primitives and insert them into the operator vector.
Status RedistributionOperatorInfer::InsertOperator(const OperatorName &name, const Args &args) {
  OperatorR op = std::make_pair(name, args);
  OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array());
  operator_list_.push_back(op_cost);
  if (construct_op_flag_) {
    if (name == SPLIT_BY_AXIS) {
      if (TransferSplitByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    } else if (name == PERMUTE_BY_AXIS) {
      if (TransferPermuteByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    } else {
      if (TransferConcatByAxis(args) == Status::FAILED) {
        return Status::FAILED;
      }
    }
    constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
  }
  return Status::SUCCESS;
}

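// Lower SPLIT_BY_AXIS to a StridedSlice primitive and record the new device dimension of the
// split tensor dimension in the current tensor layout.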
Status RedistributionOperatorInfer::TransferSplitByAxis(const Args &args) {
  if (args.size() < TRANSFER_SPLIT_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
  if (constructor_.StridedSliceOP(args) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  if (cur_tensor_layout_.UpdateTensorMap(index, args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

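// Lower PERMUTE_BY_AXIS to an AlltoAll primitive, then move the device dimension in the current
// tensor layout: the source (concat) dimension becomes unsplit and the target (split) dimension
// takes its device dimension from the output tensor map.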
Status RedistributionOperatorInfer::TransferPermuteByAxis(const Args &args) {
  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 5!";
    return Status::FAILED;
  }
  if (constructor_.AlltoAllOP(args) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
  int64_t val = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  int64_t out_dim = out_tensor_map_.GetDimByIdx(index);

  if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) {
    return Status::FAILED;
  }
  if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

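// Lower CONCAT_BY_AXIS to an AllGather primitive. AllGather concatenates along dimension 0, so
// when the gathered tensor dimension is not 0 the result is additionally Split into split_count
// pieces and re-Concat along the intended dimension; the tensor map entry is then cleared.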
Status RedistributionOperatorInfer::TransferConcatByAxis(const Args &args) {
  if (args.size() < TRANSFER_CONCAT_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  int64_t tensor_dim = args[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_CONCAT_DEV_DIM_INDEX];
  int64_t split_count = args[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
  if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) {
    return Status::FAILED;
  } else {
    operator_vector_.push_back(constructor_.GetOperator());
    output_info_vector_.push_back(std::make_pair(false, 0));
  }
  if (tensor_dim != 0) {
    if (constructor_.SplitOP(split_count) != Status::SUCCESS) {
      return Status::FAILED;
    } else {
      operator_vector_.push_back(constructor_.GetOperator());
      output_info_vector_.push_back(std::make_pair(true, split_count));
    }
    if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) {
      return Status::FAILED;
    } else {
      operator_vector_.push_back(constructor_.GetOperator());
      output_info_vector_.push_back(std::make_pair(false, 0));
    }
  }
  if (cur_tensor_layout_.UpdateTensorMap(LongToSize(tensor_dim), NONE) == Status::FAILED) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore