• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/tensor_layout/construct_operator.h"
18 
19 #include <functional>
20 #include <numeric>
21 #include <algorithm>
22 #include <memory>
23 #include "frontend/parallel/ops_info/ops_utils.h"
24 #include "include/common/utils/parallel_context.h"
25 
26 namespace mindspore {
27 namespace parallel {
Init(const RankList & dev_list,const Shape & dev_matrix_shape,bool is_cost_model,bool is_dynamic_shape)28 Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape, bool is_cost_model,
29                                bool is_dynamic_shape) {
30   dev_size_ = dev_matrix_shape.size();
31   dev_matrix_shape_ = dev_matrix_shape;
32   dev_list_ = dev_list;
33   is_cost_model_ = is_cost_model;
34   is_dynamic_shape_ = is_dynamic_shape;
35   return Status::SUCCESS;
36 }
37 
38 // skip redistribution for reshape operator
SkipRedisReshapeOP(const Shape & shape) const39 OperatorVector ConstructOperator::SkipRedisReshapeOP(const Shape &shape) const {
40   OperatorAttrs attrs;
41   ValuePtr param_value = MakeValue(shape);
42   Attr param = std::make_pair(SHAPE, param_value);
43   OperatorParams params = {std::make_pair(param, 2)};
44   OperatorArgs args = std::make_pair(attrs, params);
45   Operator op = std::make_pair(RESHAPE, args);
46   OperatorVector opvector;
47   opvector.push_back(op);
48   return opvector;
49 }
50 
ReshapeOP(const Shape & shape,bool use_origin_shape,enum ReshapeMode reshape_mode)51 Status ConstructOperator::ReshapeOP(const Shape &shape, bool use_origin_shape, enum ReshapeMode reshape_mode) {
52   int64_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
53   int64_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int64_t>());
54   if (!IsDynamicShape(shape) && !IsDynamicShape(tensor_shape_) && prod != prod_expect) {
55     ValuePtr ptr = MakeValue(shape);
56     MS_EXCEPTION_IF_NULL(ptr);
57     MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString()
58                   << " when construct Reshape operator! Expect production is " << prod_expect << " which shape is "
59                   << tensor_shape_;
60     return Status::INVALID_ARGUMENT;
61   }
62   OperatorAttrs attrs;
63   ValuePtr param_value = MakeValue(shape);
64   Attr param = std::make_pair(SHAPE, param_value);
65   OperatorParams params = {std::make_pair(param, 2)};
66   if (use_origin_shape) {
67     // Only user's reshape could be in this branch.
68     ValuePtr use_origin_shape_flag = MakeValue(use_origin_shape);
69     Attr use_origin_shape_param = std::make_pair(USE_ORIGIN_SHAPE, use_origin_shape_flag);
70     params.emplace_back(std::make_pair(use_origin_shape_param, -1));
71   }
72   if (reshape_mode != ReshapeMode::NO_RESHAPE) {
73     ValuePtr reshape_mode_flag = MakeValue(static_cast<int64_t>(reshape_mode));
74     Attr reshape_mode_param = std::make_pair(REDISTRIBUTION_RESHAPE_MODE, reshape_mode_flag);
75     params.emplace_back(std::make_pair(reshape_mode_param, -1));
76   }
77   OperatorArgs args = std::make_pair(attrs, params);
78   op_ = std::make_pair(RESHAPE, args);
79   return Status::SUCCESS;
80 }
81 
CreateStridedSliceOp(int64_t value,const Shape & begin,const Shape & end,const Shape & strides)82 Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides) {
83   ValuePtr param_begin_value = MakeValue(begin);
84   Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), STRIDED_SLICE_BEGIN_INDEX + 1);
85   ValuePtr param_end_value = MakeValue(end);
86   Param param_end = std::make_pair(std::make_pair(END, param_end_value), STRIDED_SLICE_END_INDEX + 1);
87 
88   ValuePtr param_strides_value = MakeValue(strides);
89   Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), STRIDED_SLICE_STRIDES_INDEX + 1);
90 
91   ValuePtr begin_mask = MakeValue(value);
92   Param param_begin_mask = std::make_pair(std::make_pair(BEGIN_MASK, begin_mask), STRIDED_SLICE_BEGIN_MASK_INDEX + 1);
93   ValuePtr end_mask = MakeValue(value);
94   Param param_end_mask = std::make_pair(std::make_pair(END_MASK, end_mask), STRIDED_SLICE_END_MASK_INDEX + 1);
95   ValuePtr ellipsis_mask = MakeValue(value);
96   Param param_ellipsis_mask =
97     std::make_pair(std::make_pair(ELLIPSIS_MASK, ellipsis_mask), STRIDED_SLICE_ELLIPSIS_MASK_INDEX + 1);
98   ValuePtr new_axis_mask = MakeValue(value);
99   Param param_new_axis_mask =
100     std::make_pair(std::make_pair(NEW_AXIS_MASK, new_axis_mask), STRIDED_SLICE_NEW_AXIS_MASK_INDEX + 1);
101   ValuePtr shrink_axis_mask = MakeValue(value);
102   Param param_shrink_axis_mask =
103     std::make_pair(std::make_pair(SHRINK_AXIS_MASK, shrink_axis_mask), STRIDED_SLICE_SHRINK_AXIS_MASK_INDEX + 1);
104 
105   OperatorParams params = {param_begin,    param_end,           param_strides,       param_begin_mask,
106                            param_end_mask, param_ellipsis_mask, param_new_axis_mask, param_shrink_axis_mask};
107   OperatorAttrs attrs;
108   OperatorArgs op_args = std::make_pair(attrs, params);
109 
110   return std::make_pair(STRIDED_SLICE, op_args);
111 }
112 
CreateSplitOp(int64_t split_size_or_sections,int64_t axis,int64_t index)113 Operator CreateSplitOp(int64_t split_size_or_sections, int64_t axis, int64_t index) {
114   ValuePtr split_size_value_ptr = MakeValue(std::make_shared<Int64Imm>(split_size_or_sections));
115   Param split_size_param = std::make_pair(std::make_pair(SPLIT_SIZE, split_size_value_ptr), 1);
116 
117   ValuePtr axis_value_ptr = MakeValue(std::make_shared<Int64Imm>(axis));
118   Param axis_param = std::make_pair(std::make_pair(SPLIT_DIM, axis_value_ptr), 2);
119 
120   ValuePtr slice_index_value_ptr = MakeValue(std::make_shared<Int64Imm>(index));
121   Param slice_index_param = std::make_pair(std::make_pair(SPLIT_OUTPUT_INDEX, slice_index_value_ptr), 3);
122 
123   ValuePtr skip_value = MakeValue(true);
124   Attr skip_attr = std::make_pair(SPLIT_INSERT_LATER, skip_value);
125 
126   OperatorAttrs attrs = {skip_attr};
127   OperatorParams params = {axis_param, split_size_param, slice_index_param};
128   OperatorArgs op_args = std::make_pair(attrs, params);
129 
130   return std::make_pair(SPLIT, op_args);
131 }
132 
ReplaceStridedSliceOpToSplitOp(const Args & args)133 Status ConstructOperator::ReplaceStridedSliceOpToSplitOp(const Args &args) {
134   // Python api defines: split(tensor, split_size_or_sections, axis=0)
135   // In MindSpore ir, it looks like (tensor, axis, split_count)
136   // So we get axis to split, output number and index of current index.
137   // axis can be fetched here.
138   int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
139   // split_size_or_sections can be fetched here.
140   int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
141   int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
142 
143   if (split_dim >= SizeToLong(this->tensor_shape_.size()) ||
144       (this->tensor_shape_[split_dim] != -1 && this->tensor_shape_[split_dim] % split_count != 0)) {
145     MS_LOG(ERROR) << "Tensor with shape " << this->tensor_shape_ << " can not be split into " << split_count
146                   << " slices in the dimension " << split_dim << " when construct StridedSlice operator";
147     return Status::INVALID_ARGUMENT;
148   }
149 
150   std::vector<Group> group_list;
151   if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
152     MS_LOG(ERROR) << "stride slice op: create group failed";
153     return FAILED;
154   } else if (group_list.empty()) {  // this group only has one device, don't need do StridedSlice
155     MS_LOG(INFO) << "no need stride slice op";
156     return SUCCESS;
157   }
158 
159   Group group = group_list[0];
160   size_t rank;
161   if (group.GetIndex(&rank) == Status::FAILED) {
162     MS_LOG(ERROR) << "Get rank from group failed.";
163     return Status::FAILED;
164   }
165   op_ = CreateSplitOp(split_count, split_dim, SizeToLong(rank));
166   return Status::SUCCESS;
167 }
168 
StridedSliceOP(const Args & args)169 Status ConstructOperator::StridedSliceOP(const Args &args) {
170   if (this->is_dynamic_shape_) {
171     // When it's dynamic shape scene, use Split instead of StridedSlice.
172     if (ReplaceStridedSliceOpToSplitOp(args) != Status::SUCCESS) {
173       MS_LOG(ERROR) << "Replace StridedSlice to Split failed.";
174       return Status::FAILED;
175     }
176     return Status::SUCCESS;
177   }
178   if (args.size() < STRIDED_SLICE_ARGS_SIZE) {
179     MS_LOG(ERROR) << "args size should not be less than 3!";
180     return Status::FAILED;
181   }
182   int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
183   if (split_count <= 0) {
184     MS_LOG(ERROR) << "split_count should not be less than 0!";
185     return Status::FAILED;
186   }
187   int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
188   int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
189   std::vector<Group> group_list;
190 
191   if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
192     MS_LOG(ERROR) << "stride slice op: create group failed";
193     return FAILED;
194   } else if (group_list.empty()) {  // this group only has one device, don't need do StridedSlice
195     MS_LOG(INFO) << "no need stride slice op";
196     return SUCCESS;
197   }
198 
199   Group group = group_list[0];
200   size_t rank;
201   if (virtual_rank_ >= 0) {
202     if (group.GetIndexByRank(virtual_rank_, &rank) == Status::FAILED) {
203       return Status::FAILED;
204     }
205   } else {
206     if (group.GetIndex(&rank) == Status::FAILED) {
207       return Status::FAILED;
208     }
209   }
210   size_t size = tensor_shape_.size();
211   Shape begin(size);
212   Shape end(size);
213   Shape strides(size, 1);
214   size_t index = 0;
215   for (auto num : tensor_shape_) {
216     if (index != LongToSize(split_dim)) {
217       begin[index] = 0;
218       end[index] = num;
219     } else {
220       if (num % split_count != 0) {
221         MS_LOG(ERROR) << "Tensor with shape " << this->tensor_shape_ << " can not be split into " << split_count
222                       << " slices in the dimension " << split_dim << " when construct StridedSlice operator";
223         return Status::INVALID_ARGUMENT;
224       }
225       int64_t count = num / split_count;
226       begin[index] = SizeToLong(rank) * count;
227       end[index] = (SizeToLong(rank) + 1) * count;
228     }
229     index++;
230   }
231   op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides);
232 
233   return Status::SUCCESS;
234 }
235 
AllGatherOP(int64_t dev_dim)236 Status ConstructOperator::AllGatherOP(int64_t dev_dim) {
237   if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
238     MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AllGather operator!";
239     return Status::INVALID_ARGUMENT;
240   }
241 
242   std::vector<Group> group_list;
243   if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
244     MS_LOG(ERROR) << "AllGather op: create group failed";
245     return FAILED;
246   } else if (group_list.empty()) {  // this group only has one device, don't need do allgather
247     MS_LOG(INFO) << "no need all gather op";
248     return SUCCESS;
249   }
250 
251   std::string group_name = group_list[0].name();
252   ValuePtr attr_value = MakeValue(group_name);
253   Attr attr = std::make_pair(GROUP, attr_value);
254   auto group_devices = group_list[0].GetDevicesList();
255   std::vector<int64_t> group_ranks;
256   (void)std::transform(group_devices.begin(), group_devices.end(), std::back_inserter(group_ranks),
257                        [](const Device &dev) { return dev.rank(); });
258   ValuePtr attr_ranks_value = MakeValue(group_ranks);
259   Attr attr_ranks = std::make_pair(GROUP_RANKS, attr_ranks_value);
260   OperatorAttrs attrs = {attr, attr_ranks};
261   OperatorParams params;
262   OperatorArgs args = std::make_pair(attrs, params);
263   op_ = std::make_pair(ALL_GATHER, args);
264   return Status::SUCCESS;
265 }
266 
ConcatOP(int64_t concat_dim)267 Status ConstructOperator::ConcatOP(int64_t concat_dim) {
268   if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
269     MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!";
270     return Status::INVALID_ARGUMENT;
271   }
272   ValuePtr attr_value = MakeValue(concat_dim);
273   Attr attr = std::make_pair(AXIS, attr_value);
274   OperatorAttrs attrs = {attr};
275   OperatorParams params;
276   OperatorArgs args = std::make_pair(attrs, params);
277   op_ = std::make_pair(CONCAT, args);
278   return Status::SUCCESS;
279 }
280 
SplitOP(int64_t split_count)281 Status ConstructOperator::SplitOP(int64_t split_count) {
282   // tensor_shape_ can not be validated here
283   if (split_count <= 0) {
284     MS_LOG(ERROR) << "Invalid split count when construct Split operator!";
285     return Status::FAILED;
286   }
287   OperatorAttrs attrs;
288   ValuePtr attr_value_axis = MakeValue(DEFAULT);
289   Attr attr_axis = std::make_pair(AXIS, attr_value_axis);
290   ValuePtr attr_value_split = MakeValue(split_count);
291   Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split);
292   attrs = {attr_axis, attr_split};
293   OperatorParams params;
294   OperatorArgs args = std::make_pair(attrs, params);
295   op_ = std::make_pair(SPLIT, args);
296   return Status::SUCCESS;
297 }
298 
AlltoAllOP(const Args & args)299 Status ConstructOperator::AlltoAllOP(const Args &args) {
300   if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
301     MS_LOG(ERROR) << "args size should not be less than 5!";
302     return Status::FAILED;
303   }
304   int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
305   int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
306   int64_t concat_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
307   int64_t dev_dim = args[TRANSFER_PERMUTE_DEV_DIM_INDEX];
308   if (split_count <= 0) {
309     MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!";
310     return Status::FAILED;
311   }
312   if (tensor_shape_[LongToSize(split_dim)] % split_count != 0) {
313     MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
314                   << "when construct AlltoAll operator!";
315     return Status::INVALID_ARGUMENT;
316   }
317   if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
318     MS_LOG(ERROR) << "Invalid split count " << split_count << " when construct AlltoAll operator!";
319     return Status::INVALID_ARGUMENT;
320   }
321   if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
322     MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AlltoAll operator!";
323     return Status::INVALID_ARGUMENT;
324   }
325 
326   std::vector<Group> group_list;
327   if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
328     MS_LOG(ERROR) << "AlltoAll op: create group failed";
329     return FAILED;
330   } else if (group_list.empty()) {  // this group only has one device, don't need do alltoall
331     MS_LOG(INFO) << "no need all to all op";
332     return SUCCESS;
333   }
334 
335   std::string group_name = group_list[0].name();
336   ValuePtr attr_value_group = MakeValue(group_name);
337   Attr attr_group = std::make_pair(GROUP, attr_value_group);
338   ValuePtr attr_value_split_count = MakeValue(split_count);
339   Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count);
340   ValuePtr attr_value_split_dim = MakeValue(split_dim);
341   Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim);
342   ValuePtr attr_value_concat_dim = MakeValue(concat_dim);
343   Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim);
344   OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group};
345   OperatorParams params;
346   OperatorArgs op_args = std::make_pair(attrs, params);
347   op_ = std::make_pair(ALL_TO_ALL, op_args);
348   return Status::SUCCESS;
349 }
350 
CreateGroupByDim(size_t axis,std::vector<Group> * group)351 Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
352   MS_EXCEPTION_IF_NULL(group);
353   auto rank = ParallelContext::GetInstance()->global_rank();
354   if (check_group()) {
355     CheckGlobalDeviceManager();
356   } else {
357     rank = virtual_rank_;
358   }
359   DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_);
360   RankList group_devices;
361   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
362     return FAILED;
363   }
364   // this group only has one device, don't need create the group
365   if (group_devices.size() == 1) {
366     MS_LOG(INFO) << "the group is empty";
367     return SUCCESS;
368   }
369   if (is_cost_model_ || !check_group()) {
370     Group g;
371     std::vector<Device> dev_list;
372     (void)std::transform(group_devices.begin(), group_devices.end(), std::back_inserter(dev_list),
373                          [](auto &rank_id) { return Device(rank_id); });
374     (void)g.Init(HCCL_WORLD_GROUP, dev_list);
375     group->push_back(g);
376     if (!check_group()) {
377       return SUCCESS;
378     }
379     return g_device_manager->CheckDeviceList(group_devices);
380   }
381   Group g;
382   if (g_device_manager->CreateGroup(group_devices, &g) != SUCCESS) {
383     MS_LOG(ERROR) << "Create communication group in redistribution failed, the rank_list is: " << group_devices;
384     return FAILED;
385   }
386   group->push_back(g);
387   return SUCCESS;
388 }
389 }  // namespace parallel
390 }  // namespace mindspore
391