/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/construct_operator.h"

#include <functional>
#include <numeric>
#include <algorithm>

namespace mindspore {
namespace parallel {
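// Store the device list and device matrix shape; CreateGroupByDim uses them
// later to build communication groups along a chosen device-matrix dimension.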
Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) {
  dev_size_ = dev_matrix_shape.size();
  dev_matrix_shape_ = dev_matrix_shape;
  dev_list_ = dev_list;
  return Status::SUCCESS;
}

// Skip redistribution for the Reshape operator: build the Reshape directly
// from the target shape, without validating it against tensor_shape_.
OperatorVector ConstructOperator::SkipRedisReshapeOP(const Shape &shape) {
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(RESHAPE, args);
  OperatorVector opvector;
  opvector.push_back(op);
  return opvector;
}

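// Construct a Reshape operator, checking first that the requested shape
// preserves the total element count of tensor_shape_.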
Status ConstructOperator::ReshapeOP(const Shape &shape) {
  int64_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
  int64_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int64_t>());
  if (prod != prod_expect) {
    ValuePtr ptr = MakeValue(shape);
    MS_EXCEPTION_IF_NULL(ptr);
    MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString() << " when constructing Reshape operator!";
    return Status::INVALID_ARGUMENT;
  }
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(RESHAPE, args);
  return Status::SUCCESS;
}

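// Build a StridedSlice operator: all five mask attributes are set to `value`,
// and begin/end/strides are attached as positional operator parameters.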
Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides) {
  ValuePtr attr_value = MakeValue(value);
  Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value);
  Attr attr_end_mask = std::make_pair(END_MASK, attr_value);
  Attr attr_ellipsis_mask = std::make_pair(ELLIPSIS_MASK, attr_value);
  Attr attr_new_axis_mask = std::make_pair(NEW_AXIS_MASK, attr_value);
  Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value);
  OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask};

  ValuePtr param_begin_value = MakeValue(begin);
  Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), STRIDED_SLICE_BEGIN_INDEX + 1);
  ValuePtr param_end_value = MakeValue(end);
  Param param_end = std::make_pair(std::make_pair(END, param_end_value), STRIDED_SLICE_END_INDEX + 1);

  ValuePtr param_strides_value = MakeValue(strides);
  Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), STRIDED_SLICE_STRIDES_INDEX + 1);
  OperatorParams params = {param_begin, param_end, param_strides};
  OperatorArgs op_args = std::make_pair(attrs, params);

  return std::make_pair(STRIDED_SLICE, op_args);
}

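// Construct a StridedSlice operator that extracts this rank's slice of the
// tensor along split_dim, dividing that dimension into split_count pieces.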
Status ConstructOperator::StridedSliceOP(const Args &args) {
  if (args.size() < STRIDED_SLICE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 3!";
    return Status::FAILED;
  }
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  if (split_count <= 0) {
    MS_LOG(ERROR) << "split_count must be greater than 0!";
    return Status::FAILED;
  }
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  std::vector<Group> group_list;

  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "StridedSlice op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // the group has only one device, no need to do StridedSlice
    MS_LOG(INFO) << "no need to do StridedSlice op";
    return SUCCESS;
  }

  Group group = group_list[0];
  size_t rank;
  if (group.GetIndex(&rank) == Status::FAILED) {
    return Status::FAILED;
  }
  size_t size = tensor_shape_.size();
  Shape begin(size);
  Shape end(size);
  Shape strides(size, 1);
  size_t index = 0;
  for (auto num : tensor_shape_) {
    if (index != LongToSize(split_dim)) {
      // dimensions other than split_dim are kept whole
      begin[index] = 0;
      end[index] = num;
    } else {
      if (num % split_count != 0) {
        MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
                      << " when constructing StridedSlice operator!";
        return Status::INVALID_ARGUMENT;
      }
      // this rank takes the [rank * count, (rank + 1) * count) range of split_dim
      int64_t count = num / split_count;
      begin[index] = SizeToLong(rank) * count;
      end[index] = (SizeToLong(rank) + 1) * count;
    }
    index++;
  }

  op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides);

  return Status::SUCCESS;
}

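// Construct an AllGather operator over the communication group lying along
// dev_dim of the device matrix (note the axis is reversed: dev_size_ - dev_dim - 1).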
Status ConstructOperator::AllGatherOP(int64_t dev_dim) {
  if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing AllGather operator!";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "AllGather op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // the group has only one device, no need to do AllGather
    MS_LOG(INFO) << "no need to do AllGather op";
    return SUCCESS;
  }

  std::string group_name = group_list[0].name();
  ValuePtr attr_value = MakeValue(group_name);
  Attr attr = std::make_pair(GROUP, attr_value);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(ALL_GATHER, args);
  return Status::SUCCESS;
}

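// Construct a Concat operator along the concat_dim axis of the tensor.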
Status ConstructOperator::ConcatOP(int64_t concat_dim) {
  if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
    MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when constructing Concat operator!";
    return Status::INVALID_ARGUMENT;
  }
  ValuePtr attr_value = MakeValue(concat_dim);
  Attr attr = std::make_pair(AXIS, attr_value);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(CONCAT, args);
  return Status::SUCCESS;
}

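// Construct a Split operator that splits the tensor into split_count outputs
// along the DEFAULT axis.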
Status ConstructOperator::SplitOP(int64_t split_count) {
  // tensor_shape_ cannot be validated here
  if (split_count <= 0) {
    MS_LOG(ERROR) << "Invalid split count when constructing Split operator!";
    return Status::FAILED;
  }
  OperatorAttrs attrs;
  ValuePtr attr_value_axis = MakeValue(DEFAULT);
  Attr attr_axis = std::make_pair(AXIS, attr_value_axis);
  ValuePtr attr_value_split = MakeValue(split_count);
  Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split);
  attrs = {attr_axis, attr_split};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(SPLIT, args);
  return Status::SUCCESS;
}

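// Construct an AlltoAll operator: each device in the group along dev_dim splits
// its tensor into split_count slices along split_dim, exchanges slices with the
// other devices, and concatenates what it receives along concat_dim.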
Status ConstructOperator::AlltoAllOP(const Args &args) {
  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than 5!";
    return Status::FAILED;
  }
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  int64_t concat_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_DEV_DIM_INDEX];
  if (split_count <= 0) {
    MS_LOG(ERROR) << "Invalid split count when constructing AlltoAll operator!";
    return Status::FAILED;
  }
  // validate split_dim before using it to index tensor_shape_
  if (LongToSize(split_dim) >= tensor_shape_.size() || split_dim < 0) {
    MS_LOG(ERROR) << "Invalid split dimension " << split_dim << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if (tensor_shape_[LongToSize(split_dim)] % split_count != 0) {
    MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
                  << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
    MS_LOG(ERROR) << "Invalid concat dimension " << concat_dim << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << "AlltoAll op: create group failed";
    return FAILED;
  } else if (group_list.empty()) {  // the group has only one device, no need to do AlltoAll
    MS_LOG(INFO) << "no need to do AlltoAll op";
    return SUCCESS;
  }

  std::string group_name = group_list[0].name();
  ValuePtr attr_value_group = MakeValue(group_name);
  Attr attr_group = std::make_pair(GROUP, attr_value_group);
  ValuePtr attr_value_split_count = MakeValue(split_count);
  Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count);
  ValuePtr attr_value_split_dim = MakeValue(split_dim);
  Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim);
  ValuePtr attr_value_concat_dim = MakeValue(concat_dim);
  Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim);
  OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group};
  OperatorParams params;
  OperatorArgs op_args = std::make_pair(attrs, params);
  op_ = std::make_pair(ALL_TO_ALL, op_args);
  return Status::SUCCESS;
}

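// Create the communication group containing the devices that lie along `axis`
// of the device matrix; if that group would hold only this device, leave
// `group` empty and return SUCCESS.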
Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
  MS_EXCEPTION_IF_NULL(group);
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_);
  RankList group_devices;
  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
    return FAILED;
  }
  // the group has only one device, no need to create the group
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "the group has only one device, skip creating it";
    return SUCCESS;
  }

  Group g = g_device_manager->CreateGroup(group_devices);
  group->push_back(g);
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore