/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/construct_operator.h"

#include <algorithm>
#include <functional>
#include <numeric>

namespace mindspore {
namespace parallel {
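// Record the device list and device matrix shape; both are needed later to
// build communication groups along device-matrix axes.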
Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) {
  dev_size_ = dev_matrix_shape.size();
  dev_matrix_shape_ = dev_matrix_shape;
  dev_list_ = dev_list;
  return Status::SUCCESS;
}
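
// Usage sketch (illustrative only; UpdateTensorShape and GetOperator are
// assumed to be declared in construct_operator.h, they are not defined here):
//   ConstructOperator constructor;
//   constructor.Init(dev_list, dev_matrix_shape);
//   constructor.UpdateTensorShape(tensor_shape);
//   if (constructor.AllGatherOP(dev_dim) == Status::SUCCESS) {
//     Operator op = constructor.GetOperator();
//   }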

// Skip redistribution for the Reshape operator
OperatorVector ConstructOperator::SkipRedisReshapeOP(const Shape &shape) {
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(RESHAPE, args);
  OperatorVector opvector;
  opvector.push_back(op);
  return opvector;
}

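// Construct a Reshape operator; the requested shape must contain exactly as
// many elements as tensor_shape_.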
Status ConstructOperator::ReshapeOP(const Shape &shape) {
  // Reshape must preserve the total element count; use an int64_t initial
  // value so the products are accumulated in 64-bit arithmetic.
  int64_t prod = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t prod_expect =
    std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  if (prod != prod_expect) {
    ValuePtr ptr = MakeValue(shape);
    MS_EXCEPTION_IF_NULL(ptr);
    MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString() << " when constructing the Reshape operator!";
    return Status::INVALID_ARGUMENT;
  }
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(RESHAPE, args);
  return Status::SUCCESS;
}

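// Build a StridedSlice operator. All five mask attributes share the same
// value (callers in this file pass DEFAULT), and begin/end/strides are
// positional parameters placed at input indices STRIDED_SLICE_*_INDEX + 1.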
Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides) {
  ValuePtr attr_value = MakeValue(value);
  Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value);
  Attr attr_end_mask = std::make_pair(END_MASK, attr_value);
  Attr attr_ellipsis_mask = std::make_pair(ELLIPSIS_MASK, attr_value);
  Attr attr_new_axis_mask = std::make_pair(NEW_AXIS_MASK, attr_value);
  Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value);
  OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask,
                         attr_shrink_axis_mask};

  ValuePtr param_begin_value = MakeValue(begin);
  Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), STRIDED_SLICE_BEGIN_INDEX + 1);
  ValuePtr param_end_value = MakeValue(end);
  Param param_end = std::make_pair(std::make_pair(END, param_end_value), STRIDED_SLICE_END_INDEX + 1);

  ValuePtr param_strides_value = MakeValue(strides);
  Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), STRIDED_SLICE_STRIDES_INDEX + 1);
  OperatorParams params = {param_begin, param_end, param_strides};
  OperatorArgs op_args = std::make_pair(attrs, params);

  return std::make_pair(STRIDED_SLICE, op_args);
}

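// Construct a StridedSlice operator that extracts this rank's shard of
// tensor_shape_ along split_dim, based on the rank's index in the
// communication group created along device dimension dev_dim.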
Status ConstructOperator::StridedSliceOP(const Args &args) {
  if (args.size() < STRIDED_SLICE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than " << STRIDED_SLICE_ARGS_SIZE << "!";
    return Status::FAILED;
  }
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  if (split_count <= 0) {
    MS_LOG(ERROR) << "split_count must be greater than 0!";
    return Status::FAILED;
  }
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  std::vector<Group> group_list;

  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != Status::SUCCESS) {
    MS_LOG(ERROR) << "StridedSlice op: create group failed";
    return Status::FAILED;
  } else if (group_list.empty()) {  // the group has only one device, so no StridedSlice is needed
    MS_LOG(INFO) << "No need for StridedSlice op";
    return Status::SUCCESS;
  }

  Group group = group_list[0];
  size_t rank;
  if (group.GetIndex(&rank) == Status::FAILED) {
    return Status::FAILED;
  }
  size_t size = tensor_shape_.size();
  Shape begin(size);
  Shape end(size);
  Shape strides(size, 1);
  size_t index = 0;
  for (auto num : tensor_shape_) {
    if (index != LongToSize(split_dim)) {
      // Dimensions other than split_dim are kept whole.
      begin[index] = 0;
      end[index] = num;
    } else {
      if (num % split_count != 0) {
        MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
                      << " when constructing the StridedSlice operator!";
        return Status::INVALID_ARGUMENT;
      }
      // Each rank takes a contiguous 1/split_count chunk of split_dim.
      int64_t count = num / split_count;
      begin[index] = SizeToLong(rank) * count;
      end[index] = (SizeToLong(rank) + 1) * count;
    }
    index++;
  }
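  // Example (illustrative values): with tensor_shape_ = {8, 4}, split_dim = 0,
  // split_count = 2 and rank = 1 within the group, this yields
  // begin = {4, 0}, end = {8, 4}, strides = {1, 1}, i.e. the second half of dim 0.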

  op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides);

  return Status::SUCCESS;
}

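// Construct an AllGather operator over the communication group formed along
// device dimension dev_dim of the device matrix.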
Status ConstructOperator::AllGatherOP(int64_t dev_dim) {
  if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing the AllGather operator!";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != Status::SUCCESS) {
    MS_LOG(ERROR) << "AllGather op: create group failed";
    return Status::FAILED;
  } else if (group_list.empty()) {  // the group has only one device, so no AllGather is needed
    MS_LOG(INFO) << "No need for AllGather op";
    return Status::SUCCESS;
  }

  std::string group_name = group_list[0].name();
  ValuePtr attr_value = MakeValue(group_name);
  Attr attr = std::make_pair(GROUP, attr_value);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(ALL_GATHER, args);
  return Status::SUCCESS;
}

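// Construct a Concat operator that joins its inputs along tensor dimension
// concat_dim.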
Status ConstructOperator::ConcatOP(int64_t concat_dim) {
  if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
    MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when constructing the Concat operator!";
    return Status::INVALID_ARGUMENT;
  }
  ValuePtr attr_value = MakeValue(concat_dim);
  Attr attr = std::make_pair(AXIS, attr_value);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(CONCAT, args);
  return Status::SUCCESS;
}

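// Construct a Split operator that splits its input into split_count outputs;
// the axis attribute is fixed to DEFAULT.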
Status ConstructOperator::SplitOP(int64_t split_count) {
  // tensor_shape_ cannot be validated here
  if (split_count <= 0) {
    MS_LOG(ERROR) << "Invalid split count when constructing the Split operator!";
    return Status::FAILED;
  }
  OperatorAttrs attrs;
  ValuePtr attr_value_axis = MakeValue(DEFAULT);
  Attr attr_axis = std::make_pair(AXIS, attr_value_axis);
  ValuePtr attr_value_split = MakeValue(split_count);
  Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split);
  attrs = {attr_axis, attr_split};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  op_ = std::make_pair(SPLIT, args);
  return Status::SUCCESS;
}

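// Construct an AlltoAll operator: every device in the group along dev_dim
// splits its local tensor into split_count slices along split_dim, exchanges
// the slices, and concatenates what it receives along concat_dim.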
Status ConstructOperator::AlltoAllOP(const Args &args) {
  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
    MS_LOG(ERROR) << "args size should not be less than " << TRANSFER_PERMUTE_ARGS_SIZE << "!";
    return Status::FAILED;
  }
  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
  int64_t concat_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
  int64_t dev_dim = args[TRANSFER_PERMUTE_DEV_DIM_INDEX];
  if (split_count <= 0) {
    MS_LOG(ERROR) << "Invalid split count " << split_count << " when constructing the AlltoAll operator!";
    return Status::FAILED;
  }
  if (tensor_shape_[LongToSize(split_dim)] % split_count != 0) {
    MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
                  << " when constructing the AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) {
    MS_LOG(ERROR) << "Invalid concat dimension " << concat_dim << " when constructing the AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }
  if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing the AlltoAll operator!";
    return Status::INVALID_ARGUMENT;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != Status::SUCCESS) {
    MS_LOG(ERROR) << "AlltoAll op: create group failed";
    return Status::FAILED;
  } else if (group_list.empty()) {  // the group has only one device, so no AlltoAll is needed
    MS_LOG(INFO) << "No need for AlltoAll op";
    return Status::SUCCESS;
  }

  std::string group_name = group_list[0].name();
  ValuePtr attr_value_group = MakeValue(group_name);
  Attr attr_group = std::make_pair(GROUP, attr_value_group);
  ValuePtr attr_value_split_count = MakeValue(split_count);
  Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count);
  ValuePtr attr_value_split_dim = MakeValue(split_dim);
  Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim);
  ValuePtr attr_value_concat_dim = MakeValue(concat_dim);
  Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim);
  OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group};
  OperatorParams params;
  OperatorArgs op_args = std::make_pair(attrs, params);
  op_ = std::make_pair(ALL_TO_ALL, op_args);
  return Status::SUCCESS;
}

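// Create the communication group of devices that differ only along the given
// axis of the device matrix. Callers convert a device dimension to this axis
// via dev_size_ - dev_dim - 1, which assumes dev_matrix_shape_ is ordered
// from the outermost dimension first.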
Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
  MS_EXCEPTION_IF_NULL(group);
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_);
  RankList group_devices;
  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != Status::SUCCESS) {
    return Status::FAILED;
  }
  // the group has only one device, so there is no need to create it
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The group has only one device; no group is created";
    return Status::SUCCESS;
  }

  Group g = g_device_manager->CreateGroup(group_devices);
  group->push_back(g);
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore