/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/tensor_transform.h"
#include <functional>
#include <algorithm>
#include <numeric>
#include <memory>
#include <utility>
#include <string>
#include <unordered_map>
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/graph_util/graph_utils.h"

namespace mindspore {
namespace parallel {
const size_t kAllConcatSize = 3;
const size_t kIndex0 = 0;
const size_t kIndex1 = 1;
const size_t kIndex2 = 2;
const size_t kSize1 = 1;
const size_t kSize2 = 2;
const size_t kSize3 = 3;

TensorTransform::TensorTransform() {}

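// Returns the process-wide singleton, creating it on first use. The operator tables are
// (re)checked on every call; InitTransforOperator itself is a no-op once inited_function_ is set.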
std::shared_ptr<TensorTransform> TensorTransform::GetInstance() {
  static std::shared_ptr<TensorTransform> inst_tensor_transform_ =
    std::shared_ptr<TensorTransform>(new TensorTransform());
  inst_tensor_transform_->InitTransforOperator();
  return inst_tensor_transform_;
}

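// Registers two lambda tables: transform_operator_ converts a redistribution Operator into a
// compact {op_name, params} pair, and infer_shape_operator_ replays such a pair on an input
// shape to produce the op's output shape.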
void TensorTransform::InitTransforOperator() {
  if (inited_function_) {
    return;
  }
  transform_operator_[RESHAPE] = [this](auto op_pair) { return ExtractReshapeOp(op_pair); };
  transform_operator_[ALL_GATHER] = [this](auto op_pair) { return ExtractAllGatherOp(op_pair); };
  transform_operator_[SPLIT] = [this](auto op_pair) { return ExtractSplitOp(op_pair); };
  transform_operator_[CONCAT] = [this](auto op_pair) { return ExtractConcatOp(op_pair); };
  transform_operator_[STRIDEDSLICE] = [this](auto op_pair) { return ExtractStridedSliceOp(op_pair); };
  infer_shape_operator_[RESHAPE] = [this](Shape ori_shape, std::vector<int64_t> op_pair) {
    return InferReshapeOp(ori_shape, op_pair);
  };
  infer_shape_operator_[ALL_GATHER] = [this](Shape ori_shape, std::vector<int64_t> op_pair) {
    return InferAllGatherOp(ori_shape, op_pair);
  };
  infer_shape_operator_[STRIDEDSLICE] = [this](Shape ori_shape, std::vector<int64_t> op_pair) {
    return InferStridedSliceOp(ori_shape, op_pair);
  };
  inited_function_ = true;
}

// return {op_name, dst_shape}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractReshapeOp(const Operator &reshape_op_pair) const {
  auto op_name = reshape_op_pair.first;
  auto op_params = reshape_op_pair.second.second;
  if (op_params.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The reshape op does not contain dst_shape.";
  }
  auto shape_value_ptr = op_params.front().first.second;
  auto dst_shape = GetValue<std::vector<int64_t>>(shape_value_ptr);
  return std::make_pair(op_name, dst_shape);
}

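// The packed vector is group_ranks followed by the gather axis. The axis is initialized
// to 0 here and may later be rewritten to the concat axis by OptimizeAllConcat.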
// return {op_name, group_ranks + axis}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractAllGatherOp(
  const Operator &allgather_op_pair) const {
  auto op_name = allgather_op_pair.first;
  auto op_attrs = allgather_op_pair.second.first;
  if (op_attrs.size() < kSize2) {
    MS_LOG(INTERNAL_EXCEPTION) << "The allgather op does not contain the group attr.";
  }
  auto group_attr = op_attrs[1].second;
  auto group_ranks = GetValue<std::vector<int64_t>>(group_attr);
  // default allgather axis is 0
  group_ranks.push_back(0);
  return std::make_pair(op_name, group_ranks);
}

// return {op_name, [axis, output_num]}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractSplitOp(const Operator &split_op_pair) const {
  auto op_name = split_op_pair.first;
  auto op_attrs = split_op_pair.second.first;
  if (op_attrs.size() < kSize2) {
    MS_LOG(INTERNAL_EXCEPTION) << "The split op does not contain the output_num attr.";
  }
  auto axis_attr = op_attrs[0].second;
  auto axis = GetValue<int64_t>(axis_attr);
  auto output_num_attr = op_attrs[1].second;
  auto output_num = GetValue<int64_t>(output_num_attr);
  std::vector<int64_t> attr_list = {axis, output_num};
  return std::make_pair(op_name, attr_list);
}

// return {op_name, [axis]}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractConcatOp(const Operator &concat_op_pair) const {
  auto op_name = concat_op_pair.first;
  auto op_attrs = concat_op_pair.second.first;
  if (op_attrs.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The concat op does not contain the axis attr.";
  }
  auto axis_attr = op_attrs[0].second;
  auto axis = GetValue<int64_t>(axis_attr);
  std::vector<int64_t> attr_list = {axis};
  return std::make_pair(op_name, attr_list);
}

// return {op_name, begin + end + stride}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractStridedSliceOp(
  const Operator &slice_op_pair) const {
  auto op_name = slice_op_pair.first;
  auto op_params = slice_op_pair.second.second;
  if (op_params.size() < kSize3) {
    MS_LOG(INTERNAL_EXCEPTION) << "The stridedslice op does not contain begin/end/strides.";
  }
  auto begin_value_ptr = op_params[0].first.second;
  auto begin = GetValue<std::vector<int64_t>>(begin_value_ptr);
  auto end_value_ptr = op_params[1].first.second;
  auto end = GetValue<std::vector<int64_t>>(end_value_ptr);
  auto stride_value_ptr = op_params[2].first.second;
  auto stride = GetValue<std::vector<int64_t>>(stride_value_ptr);
  std::vector<int64_t> stride_attr;
  (void)std::copy(begin.begin(), begin.end(), std::back_inserter(stride_attr));
  (void)std::copy(end.begin(), end.end(), std::back_inserter(stride_attr));
  (void)std::copy(stride.begin(), stride.end(), std::back_inserter(stride_attr));
  return std::make_pair(op_name, stride_attr);
}

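// Fuses each AllGather->Split->Concat triple into a single gather along the concat axis
// ("AllConcat"): when the split axis is 0 and the split count equals the gather group size,
// the concat axis is written into the AllGather's packed axis slot and the Split/Concat
// entries are erased.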
void TensorTransform::OptimizeAllConcat(std::vector<std::pair<std::string, std::vector<int64_t>>> *transform_op_list) {
  if (transform_op_list->size() < kAllConcatSize) {
    return;
  }
  std::vector<size_t> allconcat_index;
  for (size_t i = kAllConcatSize - 1; i < transform_op_list->size(); ++i) {
    if ((*transform_op_list)[i - kIndex2].first != ALL_GATHER ||
        (*transform_op_list)[i - kIndex1].first != SPLIT || (*transform_op_list)[i].first != CONCAT) {
      continue;
    }
    auto allgather_group_size = SizeToLong((*transform_op_list)[i - kIndex2].second.size() - 1);
    auto split_axis = ((*transform_op_list)[i - kIndex1].second)[kIndex0];
    auto split_size = ((*transform_op_list)[i - kIndex1].second)[kIndex1];
    auto concat_axis = (*transform_op_list)[i].second.front();
    if (allgather_group_size != split_size || split_axis != 0) {
      continue;
    }
    (*transform_op_list)[i - kIndex2].second.back() = concat_axis;
    allconcat_index.push_back(i);
  }
  // erase the fused Split/Concat pairs from back to front so earlier indices stay valid
  for (int j = SizeToInt(allconcat_index.size()) - 1; j >= 0; --j) {
    auto erase_index = allconcat_index[IntToSize(j)];
    (void)transform_op_list->erase(transform_op_list->begin() + erase_index);
    (void)transform_op_list->erase(transform_op_list->begin() + erase_index - 1);
  }
}

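// Builds the source and destination layouts from the three vectors in `from` and `to`
// (device arrangement, tensor map, tensor shape), infers the redistribution operator list
// for rank_id on dev_list, and converts every op into a {op_name, params} pair.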
std::vector<std::pair<std::string, std::vector<int64_t>>> TensorTransform::TransformOperators(const Shapes &from,
                                                                                              const Shapes &to,
                                                                                              const RankList &dev_list,
                                                                                              int64_t rank_id) {
  TensorLayout from_layout;
  (void)from_layout.InitFromVector(from[kIndex0], from[kIndex1], from[kIndex2]);
  TensorLayout to_layout;
  (void)to_layout.InitFromVector(to[kIndex0], to[kIndex1], to[kIndex2]);
  ParallelContext::GetInstance()->set_do_transform(true);
  tensor_redistribution_.SetVirtualRank(rank_id);
  (void)tensor_redistribution_.Init(from_layout, to_layout, dev_list);
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution_.InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Infer tensor redistribution failed.";
  }
  if (redistribution_oplist_ptr->first.size() != redistribution_oplist_ptr->second.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The redistribution op list size cannot match redistribution output info list size.";
  }
  auto operators_vector = redistribution_oplist_ptr->first;
  std::vector<std::pair<std::string, std::vector<int64_t>>> transform_op_list;
  for (const auto &op_pair : operators_vector) {
    auto op_name = op_pair.first;
    auto it = transform_operator_.find(op_name);
    if (it == transform_operator_.end()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The op:" << op_name << " is not a valid redistribution op.";
    }
    transform_op_list.push_back(it->second(op_pair));
  }
  OptimizeAllConcat(&transform_op_list);
  ParallelContext::GetInstance()->set_do_transform(false);
  return transform_op_list;
}

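// Shape inference for Reshape: the packed params are the destination shape. Dynamic shapes
// (containing -1) are passed through; for static shapes the element counts must match.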
Shape TensorTransform::InferReshapeOp(const Shape &ori_shape, const std::vector<int64_t> &op) const {
  if (std::find(op.begin(), op.end(), -1) != op.end()) {
    MS_LOG(DEBUG) << "It's a dynamic shape. Reshape to " << op;
    return op;
  }
  if (std::find(ori_shape.begin(), ori_shape.end(), -1) != ori_shape.end()) {
    return op;
  }
  // use an int64_t init value so the accumulated products do not truncate to int
  if (std::accumulate(ori_shape.begin(), ori_shape.end(), int64_t(1), std::multiplies<int64_t>()) !=
      std::accumulate(op.begin(), op.end(), int64_t(1), std::multiplies<int64_t>())) {
    MS_LOG(EXCEPTION) << "Infer redistribution error, cannot convert shape: " << ori_shape << " to shape:" << op;
  }
  MS_LOG(DEBUG) << "It's a static shape. Reshape to " << op;
  return op;
}

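// Shape inference for AllGather: op packs group_ranks + axis, so the group size is
// op.size() - 1 and the gathered dimension is scaled by it (dynamic -1 dims are untouched).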
Shape TensorTransform::InferAllGatherOp(const Shape &ori_shape, const std::vector<int64_t> &op) const {
  auto new_shape = ori_shape;
  auto axis = op.back();
  if (new_shape[LongToSize(axis)] != -1) {
    new_shape[LongToSize(axis)] = new_shape[LongToSize(axis)] * SizeToLong(op.size() - 1);
  }
  return new_shape;
}

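// Shape inference for StridedSlice: op packs begin + end + strides back to back, so each
// output dim i is (end[i] - begin[i]) / stride[i].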
Shape TensorTransform::InferStridedSliceOp(const Shape &ori_shape, const std::vector<int64_t> &op) const {
  size_t end_index = op.size() / kSize3;
  if (ori_shape.size() != end_index) {
    MS_LOG(EXCEPTION) << "Infer redistribution error, the shape:" << ori_shape
                      << " cannot be sliced with dimension size:" << end_index;
  }
  auto new_shape = ori_shape;
  for (size_t i = 0; i < ori_shape.size(); ++i) {
    new_shape[i] = (op[end_index + i] - op[i]) / op[kSize2 * end_index + i];
  }
  return new_shape;
}

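// Replays transform_op_list on ori_shape and records the intermediate shape produced after
// every op.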
std::vector<Shape> TensorTransform::GetRedistributionOpShape(
  const Shape &ori_shape, const std::vector<std::pair<std::string, std::vector<int64_t>>> &transform_op_list) {
  std::vector<Shape> result_shape;
  auto cur_shape = ori_shape;
  for (const auto &op : transform_op_list) {
    auto op_name = op.first;
    auto it = infer_shape_operator_.find(op_name);
    if (it == infer_shape_operator_.end()) {
      MS_LOG(EXCEPTION) << "The op:" << op_name << " cannot infer shape in redistribution.";
    }
    cur_shape = it->second(cur_shape, op.second);
    result_shape.push_back(cur_shape);
  }
  return result_shape;
}

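// Builds a Reshape Operator whose only parameter is the destination shape; the trailing 2
// is the position the shape argument occupies in the op's input list.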
Operator ConstructReshapeOp(const std::vector<int64_t> &shape) {
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  return std::make_pair(RESHAPE, args);
}

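// In-place optimization of a redistribution op list: after fusing allgather-split-concat
// triples, any fused gather whose axis becomes 0 once leading size-1 dims are squeezed is
// rewritten as Reshape + AllGather, and its Split/Concat pair is dropped from the original
// list.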
RedistributionOpListPtr TensorTransform::OptimizeTensorRedistributionOperatorList(
  const RedistributionOpListPtr &redistribution_op_list, const Shape &input_shape) {
  MS_LOG(DEBUG) << "Do optimization for tensor redistributions.";
  // 1 convert operators_vector to transform_op_list
  // 2 fuse allgather->split->concat into allconcat
  MS_EXCEPTION_IF_NULL(redistribution_op_list);
  if ((redistribution_op_list->first).size() != (redistribution_op_list->second).size()) {
    return redistribution_op_list;
  }
  auto operators_vector = redistribution_op_list->first;
  std::vector<std::pair<std::string, std::vector<int64_t>>> transform_op_list;
  for (const auto &op_pair : operators_vector) {
    auto op_name = op_pair.first;
    auto it = transform_operator_.find(op_name);
    if (it == transform_operator_.end() || IsToBeInsertedSplitOp(op_pair)) {
      MS_LOG(INFO) << "The op:" << op_name << " would not be optimized.";
      return redistribution_op_list;
    }
    transform_op_list.push_back(it->second(op_pair));
  }
  OptimizeAllConcat(&transform_op_list);
  auto shape_list = GetRedistributionOpShape(input_shape, transform_op_list);
  size_t current_allgather_pos_in_origin_list = 0;
  std::unordered_map<size_t, std::vector<int64_t>> left_reshape_op_list;
  std::vector<size_t> allconcat_pos_list;
  // 3 remove the dims whose value is 1 for AllConcat
  for (size_t i = 0; i < transform_op_list.size(); ++i) {
    auto trans_op_pair = transform_op_list[i];
    if (trans_op_pair.first != ALL_GATHER) {
      current_allgather_pos_in_origin_list++;
      continue;
    }
    // a fused AllConcat still occupies three slots (allgather/split/concat) in the origin list
    auto axis = transform_op_list[i].second.back();
    if (axis == 0) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    if (i == transform_op_list.size() - 1 || transform_op_list[i + 1].first != RESHAPE) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    // recover the shape before the gather by dividing the gathered dim by the group size
    auto src_shape = shape_list[i];
    if (src_shape[LongToSize(axis)] > 0 && transform_op_list[i].second.size() > 1) {
      src_shape[LongToSize(axis)] = src_shape[LongToSize(axis)] / (SizeToLong(transform_op_list[i].second.size()) - 1);
    }
    // squeeze the size-1 dims in front of axis; if the axis then becomes 0,
    // the gather can be done along dim 0 after a reshape
    auto new_axis = axis;
    auto new_src_shape = src_shape;
    for (int64_t j = axis - 1; j >= 0; --j) {
      if (src_shape[LongToSize(j)] != 1) {
        continue;
      }
      new_src_shape.erase(new_src_shape.begin() + j);
      new_axis -= 1;
    }
    MS_LOG(INFO) << "src_shape:" << src_shape << ", new_src_shape:" << new_src_shape << ", axis:" << axis
                 << ", new_axis:" << new_axis;
    if (new_axis != 0) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    left_reshape_op_list[current_allgather_pos_in_origin_list] = new_src_shape;
    allconcat_pos_list.push_back(current_allgather_pos_in_origin_list);
    current_allgather_pos_in_origin_list += kSize3;
  }
  // Insert reshape and adjust allgather-split-concat for redistribution_op_list,
  // iterating from back to front so earlier positions stay valid
  std::reverse(allconcat_pos_list.begin(), allconcat_pos_list.end());
  for (auto pos : allconcat_pos_list) {
    // erase split and concat
    (void)redistribution_op_list->first.erase(redistribution_op_list->first.begin() + pos + kSize2);
    (void)redistribution_op_list->first.erase(redistribution_op_list->first.begin() + pos + kSize1);
    (void)redistribution_op_list->second.erase(redistribution_op_list->second.begin() + pos + kSize2);
    (void)redistribution_op_list->second.erase(redistribution_op_list->second.begin() + pos + kSize1);
    // insert reshape before allgather
    Operator left_reshape_op = ConstructReshapeOp(left_reshape_op_list[pos]);
    (void)redistribution_op_list->first.insert(redistribution_op_list->first.begin() + pos, left_reshape_op);
    (void)redistribution_op_list->second.insert(redistribution_op_list->second.begin() + pos, {false, 0});
  }
  return redistribution_op_list;
}
}  // namespace parallel
}  // namespace mindspore