/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/construct_operator.h"

#include <functional>
#include <numeric>
#include <algorithm>
#include <memory>
#include "frontend/parallel/ops_info/ops_utils.h"
#include "include/common/utils/parallel_context.h"

namespace mindspore {
namespace parallel {
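// Caches the device list, device-matrix shape and the cost-model / dynamic-shape flags
// used by the operator-construction methods below.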
Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape, bool is_cost_model,
                               bool is_dynamic_shape) {
  dev_size_ = dev_matrix_shape.size();
  dev_matrix_shape_ = dev_matrix_shape;
  dev_list_ = dev_list;
  is_cost_model_ = is_cost_model;
  is_dynamic_shape_ = is_dynamic_shape;
  return Status::SUCCESS;
}

// skip redistribution for reshape operator
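// Returns a one-element OperatorVector containing a Reshape to the given shape.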
OperatorVector ConstructOperator::SkipRedisReshapeOP(const Shape &shape) const {
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(RESHAPE, args);
  OperatorVector opvector;
  opvector.push_back(op);
  return opvector;
}

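// Builds a Reshape operator to `shape` and stores it in op_. For static shapes the element
// count of `shape` must match tensor_shape_; flags for a user reshape (use_origin_shape) and
// the redistribution reshape mode are appended as extra inputs when set.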
Status ConstructOperator::ReshapeOP(const Shape &shape, bool use_origin_shape, enum ReshapeMode reshape_mode) {
  int64_t prod = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t prod_expect =
    std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  if (!IsDynamicShape(shape) && !IsDynamicShape(tensor_shape_) && prod != prod_expect) {
    ValuePtr ptr = MakeValue(shape);
    MS_EXCEPTION_IF_NULL(ptr);
    MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString()
                  << " when constructing Reshape operator! The expected element product is " << prod_expect
                  << ", computed from tensor shape " << tensor_shape_;
    return Status::INVALID_ARGUMENT;
  }
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  if (use_origin_shape) {
    // Only a user's reshape can reach this branch.
    ValuePtr use_origin_shape_flag = MakeValue(use_origin_shape);
    Attr use_origin_shape_param = std::make_pair(USE_ORIGIN_SHAPE, use_origin_shape_flag);
    params.emplace_back(std::make_pair(use_origin_shape_param, -1));
  }
  if (reshape_mode != ReshapeMode::NO_RESHAPE) {
    ValuePtr reshape_mode_flag = MakeValue(static_cast<int64_t>(reshape_mode));
    Attr reshape_mode_param = std::make_pair(REDISTRIBUTION_RESHAPE_MODE, reshape_mode_flag);
    params.emplace_back(std::make_pair(reshape_mode_param, -1));
  }
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(RESHAPE, args);
  return Status::SUCCESS;
}

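// Assembles a StridedSlice operator: begin/end/strides are passed as positional inputs and
// all five mask inputs are set to `value`.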
Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides) {
  ValuePtr param_begin_value = MakeValue(begin);
  Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), STRIDED_SLICE_BEGIN_INDEX + 1);
  ValuePtr param_end_value = MakeValue(end);
  Param param_end = std::make_pair(std::make_pair(END, param_end_value), STRIDED_SLICE_END_INDEX + 1);

  ValuePtr param_strides_value = MakeValue(strides);
  Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), STRIDED_SLICE_STRIDES_INDEX + 1);

  ValuePtr begin_mask = MakeValue(value);
  Param param_begin_mask = std::make_pair(std::make_pair(BEGIN_MASK, begin_mask), STRIDED_SLICE_BEGIN_MASK_INDEX + 1);
  ValuePtr end_mask = MakeValue(value);
  Param param_end_mask = std::make_pair(std::make_pair(END_MASK, end_mask), STRIDED_SLICE_END_MASK_INDEX + 1);
  ValuePtr ellipsis_mask = MakeValue(value);
  Param param_ellipsis_mask =
    std::make_pair(std::make_pair(ELLIPSIS_MASK, ellipsis_mask), STRIDED_SLICE_ELLIPSIS_MASK_INDEX + 1);
  ValuePtr new_axis_mask = MakeValue(value);
  Param param_new_axis_mask =
    std::make_pair(std::make_pair(NEW_AXIS_MASK, new_axis_mask), STRIDED_SLICE_NEW_AXIS_MASK_INDEX + 1);
  ValuePtr shrink_axis_mask = MakeValue(value);
  Param param_shrink_axis_mask =
    std::make_pair(std::make_pair(SHRINK_AXIS_MASK, shrink_axis_mask), STRIDED_SLICE_SHRINK_AXIS_MASK_INDEX + 1);

  OperatorParams params = {param_begin,    param_end,           param_strides,       param_begin_mask,
                           param_end_mask, param_ellipsis_mask, param_new_axis_mask, param_shrink_axis_mask};
  OperatorAttrs attrs;
  OperatorArgs op_args = std::make_pair(attrs, params);

  return std::make_pair(STRIDED_SLICE, op_args);
}

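// Assembles a Split operator from the given split size/sections, axis and output index,
// with the SPLIT_INSERT_LATER attribute set to true.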
Operator CreateSplitOp(int64_t split_size_or_sections, int64_t axis, int64_t index) {
  ValuePtr split_size_value_ptr = MakeValue(std::make_shared<Int64Imm>(split_size_or_sections));
  Param split_size_param = std::make_pair(std::make_pair(SPLIT_SIZE, split_size_value_ptr), 1);

  ValuePtr axis_value_ptr = MakeValue(std::make_shared<Int64Imm>(axis));
  Param axis_param = std::make_pair(std::make_pair(SPLIT_DIM, axis_value_ptr), 2);

  ValuePtr slice_index_value_ptr = MakeValue(std::make_shared<Int64Imm>(index));
  Param slice_index_param = std::make_pair(std::make_pair(SPLIT_OUTPUT_INDEX, slice_index_value_ptr), 3);

  ValuePtr skip_value = MakeValue(true);
  Attr skip_attr = std::make_pair(SPLIT_INSERT_LATER, skip_value);

  OperatorAttrs attrs = {skip_attr};
  OperatorParams params = {axis_param, split_size_param, slice_index_param};
  OperatorArgs op_args = std::make_pair(attrs, params);

  return std::make_pair(SPLIT, op_args);
}

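// Dynamic-shape replacement for StridedSliceOP: emits a Split operator that selects this rank's
// slice along the split dimension instead of a StridedSlice with static begin/end bounds.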
Status ConstructOperator::ReplaceStridedSliceOpToSplitOp(const Args &args) {
  // The Python API is split(tensor, split_size_or_sections, axis=0), while in MindSpore IR the
  // inputs look like (tensor, axis, split_count). So we fetch the axis to split, the number of
  // outputs and the output index of the current rank.
  // axis can be fetched here.
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  // split_size_or_sections can be fetched here.
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];

  if (split_dim >= SizeToLong(this->tensor_shape_.size()) ||
      (this->tensor_shape_[split_dim] != -1 && this->tensor_shape_[split_dim] % split_count != 0)) {
    MS_LOG(ERROR) << "Tensor with shape " << this->tensor_shape_ << " can not be split into " << split_count
                  << " slices in dimension " << split_dim << " when replacing StridedSlice with the Split operator";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "stride slice op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // this group has only one device, no need to do StridedSlice
    MS_LOG(INFO) << "no need stride slice op";
    return SUCCESS;
  }

  Group group = group_list[0];
  size_t rank;
  if (group.GetIndex(&rank) == Status::FAILED) {
    MS_LOG(ERROR) << "Get rank from group failed.";
    return Status::FAILED;
  }
  op_ = CreateSplitOp(split_count, split_dim, SizeToLong(rank));
  return Status::SUCCESS;
}

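// Builds the operator that extracts this rank's slice along the split dimension: a Split in the
// dynamic-shape case, otherwise a StridedSlice whose begin/end are computed from the rank's
// position in the communication group.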
Status ConstructOperator::StridedSliceOP(const Args &args) {
  if (this->is_dynamic_shape_) {
    // In the dynamic-shape scene, use Split instead of StridedSlice.
    if (ReplaceStridedSliceOpToSplitOp(args) != Status::SUCCESS) {
      MS_LOG(ERROR) << "Replacing StridedSlice with Split failed.";
      return Status::FAILED;
    }
    return Status::SUCCESS;
  }
  if (args.size() < STRIDED_SLICE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than " << STRIDED_SLICE_ARGS_SIZE << "!";
    return Status::FAILED;
  }
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  if (split_count <= 0) {
    MS_LOG(ERROR) << "split_count should be greater than 0!";
    return Status::FAILED;
  }
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  std::vector<Group> group_list;

  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "stride slice op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // this group has only one device, no need to do StridedSlice
    MS_LOG(INFO) << "no need stride slice op";
    return SUCCESS;
  }

  Group group = group_list[0];
  size_t rank;
  if (virtual_rank_ >= 0) {
    if (group.GetIndexByRank(virtual_rank_, &rank) == Status::FAILED) {
      return Status::FAILED;
    }
  } else {
    if (group.GetIndex(&rank) == Status::FAILED) {
      return Status::FAILED;
    }
  }
  size_t size = tensor_shape_.size();
  Shape begin(size);
  Shape end(size);
  Shape strides(size, 1);
  size_t index = 0;
  for (auto num : tensor_shape_) {
    if (index != LongToSize(split_dim)) {
      begin[index] = 0;
      end[index] = num;
    } else {
      if (num % split_count != 0) {
        MS_LOG(ERROR) << "Tensor with shape " << this->tensor_shape_ << " can not be split into " << split_count
                      << " slices in dimension " << split_dim << " when constructing StridedSlice operator";
        return Status::INVALID_ARGUMENT;
      }
      int64_t count = num / split_count;
      begin[index] = SizeToLong(rank) * count;
      end[index] = (SizeToLong(rank) + 1) * count;
    }
    index++;
  }
  op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides);

  return Status::SUCCESS;
}

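// Builds an AllGather over the communication group derived from device-matrix dimension
// dev_dim; a single-device group yields no operator.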
Status ConstructOperator::AllGatherOP(int64_t dev_dim) {
  if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing AllGather operator!";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "AllGather op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // this group has only one device, no need to do AllGather
    MS_LOG(INFO) << "no need all gather op";
    return SUCCESS;
  }

  std::string group_name = group_list[0].name();
  ValuePtr attr_value = MakeValue(group_name);
  Attr attr = std::make_pair(GROUP, attr_value);
  auto group_devices = group_list[0].GetDevicesList();
  std::vector<int64_t> group_ranks;
  (void)std::transform(group_devices.begin(), group_devices.end(), std::back_inserter(group_ranks),
                       [](const Device &dev) { return dev.rank(); });
  ValuePtr attr_ranks_value = MakeValue(group_ranks);
  Attr attr_ranks = std::make_pair(GROUP_RANKS, attr_ranks_value);
  OperatorAttrs attrs = {attr, attr_ranks};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(ALL_GATHER, args);
  return Status::SUCCESS;
}

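// Builds a Concat operator along concat_dim, which must be a valid tensor dimension.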
Status ConstructOperator::ConcatOP(int64_t concat_dim) {
  if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
    MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when constructing Concat operator!";
    return Status::INVALID_ARGUMENT;
  }
  ValuePtr attr_value = MakeValue(concat_dim);
  Attr attr = std::make_pair(AXIS, attr_value);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(CONCAT, args);
  return Status::SUCCESS;
}

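// Builds a Split operator with OUTPUT_NUM = split_count along the default axis.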
Status ConstructOperator::SplitOP(int64_t split_count) {
  // tensor_shape_ can not be validated here
  if (split_count <= 0) {
    MS_LOG(ERROR) << "Invalid split count when constructing Split operator!";
    return Status::FAILED;
  }
  OperatorAttrs attrs;
  ValuePtr attr_value_axis = MakeValue(DEFAULT);
  Attr attr_axis = std::make_pair(AXIS, attr_value_axis);
  ValuePtr attr_value_split = MakeValue(split_count);
  Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split);
  attrs = {attr_axis, attr_split};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(SPLIT, args);
  return Status::SUCCESS;
}

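// Builds an AlltoAll operator from the permute args (split_count, split_dim, concat_dim,
// dev_dim) over the group derived from device-matrix dimension dev_dim; a single-device
// group yields no operator.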
Status ConstructOperator::AlltoAllOP(const Args &args) {
  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than " << TRANSFER_PERMUTE_ARGS_SIZE << "!";
    return Status::FAILED;
  }
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  int64_t concat_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_DEV_DIM_INDEX];
  if (split_count <= 0) {
    MS_LOG(ERROR) << "Invalid split count when constructing AlltoAll operator!";
    return Status::FAILED;
  }
  if (tensor_shape_[LongToSize(split_dim)] % split_count != 0) {
    MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in dimension " << split_dim
                  << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
    MS_LOG(ERROR) << "Invalid concat dimension " << concat_dim << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "AlltoAll op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // this group has only one device, no need to do AlltoAll
    MS_LOG(INFO) << "no need all to all op";
    return SUCCESS;
  }

  std::string group_name = group_list[0].name();
  ValuePtr attr_value_group = MakeValue(group_name);
  Attr attr_group = std::make_pair(GROUP, attr_value_group);
  ValuePtr attr_value_split_count = MakeValue(split_count);
  Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count);
  ValuePtr attr_value_split_dim = MakeValue(split_dim);
  Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim);
  ValuePtr attr_value_concat_dim = MakeValue(concat_dim);
  Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim);
  OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group};
  OperatorParams params;
  OperatorArgs op_args = std::make_pair(attrs, params);
  op_ = std::make_pair(ALL_TO_ALL, op_args);
  return Status::SUCCESS;
}

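// Collects the devices along device-matrix dimension `axis` and creates the corresponding
// communication group. In cost-model or virtual-rank mode the group is only assembled locally
// rather than created through the device manager; the group list stays empty when only this
// device would be in it.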
Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
  MS_EXCEPTION_IF_NULL(group);
  auto rank = ParallelContext::GetInstance()->global_rank();
  if (check_group()) {
    CheckGlobalDeviceManager();
  } else {
    rank = virtual_rank_;
  }
  DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_);
  RankList group_devices;
  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
    return FAILED;
  }
  // this group has only one device, no need to create the group
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The group has only one device, no need to create it";
    return SUCCESS;
  }
  if (is_cost_model_ || !check_group()) {
    Group g;
    std::vector<Device> dev_list;
    (void)std::transform(group_devices.begin(), group_devices.end(), std::back_inserter(dev_list),
                         [](auto &rank_id) { return Device(rank_id); });
    (void)g.Init(HCCL_WORLD_GROUP, dev_list);
    group->push_back(g);
    if (!check_group()) {
      return SUCCESS;
    }
    return g_device_manager->CheckDeviceList(group_devices);
  }
  Group g;
  if (g_device_manager->CreateGroup(group_devices, &g) != SUCCESS) {
    MS_LOG(ERROR) << "Create communication group in redistribution failed, the rank_list is: " << group_devices;
    return FAILED;
  }
  group->push_back(g);
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore