/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/tensor_transform.h"
#include <functional>
#include <algorithm>
#include <numeric>
#include <memory>
#include <utility>
#include <string>
#include <unordered_map>
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/graph_util/graph_utils.h"

namespace mindspore {
namespace parallel {
const size_t kAllConcatSize = 3;
const size_t kIndex0 = 0;
const size_t kIndex1 = 1;
const size_t kIndex2 = 2;
const size_t kSize1 = 1;
const size_t kSize2 = 2;
const size_t kSize3 = 3;

TensorTransform::TensorTransform() {}

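// Singleton accessor: InitTransforOperator is invoked on every call, but it
// is a no-op once inited_function_ has been set.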
std::shared_ptr<TensorTransform> TensorTransform::GetInstance() {
  static std::shared_ptr<TensorTransform> inst_tensor_transform_ =
    std::shared_ptr<TensorTransform>(new TensorTransform());
  inst_tensor_transform_->InitTransforOperator();
  return inst_tensor_transform_;
}

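// Register one extractor per supported redistribution op, plus shape-inference
// functions for Reshape/AllGather/StridedSlice. Split and Concat get no infer
// entry: GetRedistributionOpShape raises on them, so shape replay is only
// valid on lists in which OptimizeAllConcat has fused them away.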
void TensorTransform::InitTransforOperator() {
  if (inited_function_) {
    return;
  }
  transform_operator_[RESHAPE] = [this](auto op_pair) { return ExtractReshapeOp(op_pair); };
  transform_operator_[ALL_GATHER] = [this](auto op_pair) { return ExtractAllGatherOp(op_pair); };
  transform_operator_[SPLIT] = [this](auto op_pair) { return ExtractSplitOp(op_pair); };
  transform_operator_[CONCAT] = [this](auto op_pair) { return ExtractConcatOp(op_pair); };
  transform_operator_[STRIDEDSLICE] = [this](auto op_pair) { return ExtractStridedSliceOp(op_pair); };
  infer_shape_operator_[RESHAPE] = [this](Shape ori_shape, std::vector<int64_t> op_pair) {
    return InferReshapeOp(ori_shape, op_pair);
  };
  infer_shape_operator_[ALL_GATHER] = [this](Shape ori_shape, std::vector<int64_t> op_pair) {
    return InferAllGatherOp(ori_shape, op_pair);
  };
  infer_shape_operator_[STRIDEDSLICE] = [this](Shape ori_shape, std::vector<int64_t> op_pair) {
    return InferStridedSliceOp(ori_shape, op_pair);
  };
  inited_function_ = true;
}

// return {op_name, dst_shape}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractReshapeOp(const Operator &reshape_op_pair) const {
  auto op_name = reshape_op_pair.first;
  auto op_params = reshape_op_pair.second.second;
  if (op_params.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The reshape op does not contain a dst_shape param.";
  }
  auto shape_value_ptr = op_params.front().first.second;
  auto dst_shape = GetValue<std::vector<int64_t>>(shape_value_ptr);
  return std::make_pair(op_name, dst_shape);
}

// return {op_name, group_ranks + axis}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractAllGatherOp(
  const Operator &allgather_op_pair) const {
  auto op_name = allgather_op_pair.first;
  auto op_attrs = allgather_op_pair.second.first;
  if (op_attrs.size() < kSize2) {
    MS_LOG(INTERNAL_EXCEPTION) << "The allgather op does not contain the group attr.";
  }
  auto group_attr = op_attrs[1].second;
  auto group_ranks = GetValue<std::vector<int64_t>>(group_attr);
  // default allgather axis is 0
  group_ranks.push_back(0);
  return std::make_pair(op_name, group_ranks);
}
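// e.g. an AllGather over ranks {0, 1, 2, 3} is encoded as {0, 1, 2, 3, 0}:
// the group ranks followed by the gather axis, so the group size is size() - 1.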

// return {op_name, [axis, output_num]}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractSplitOp(const Operator &split_op_pair) const {
  auto op_name = split_op_pair.first;
  auto op_attrs = split_op_pair.second.first;
  if (op_attrs.size() < kSize2) {
    MS_LOG(INTERNAL_EXCEPTION) << "The split op does not contain the output_num attr.";
  }
  auto axis_attr = op_attrs[0].second;
  auto axis = GetValue<int64_t>(axis_attr);
  auto output_num_attr = op_attrs[1].second;
  auto output_num = GetValue<int64_t>(output_num_attr);
  std::vector<int64_t> attr_list = {axis, output_num};
  return std::make_pair(op_name, attr_list);
}

// return {op_name, [axis]}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractConcatOp(const Operator &concat_op_pair) const {
  auto op_name = concat_op_pair.first;
  auto op_attrs = concat_op_pair.second.first;
  if (op_attrs.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The concat op does not contain the axis attr.";
  }
  auto axis_attr = op_attrs[0].second;
  auto axis = GetValue<int64_t>(axis_attr);
  std::vector<int64_t> attr_list = {axis};
  return std::make_pair(op_name, attr_list);
}

// return {op_name, begin + end + stride}
std::pair<std::string, std::vector<int64_t>> TensorTransform::ExtractStridedSliceOp(
  const Operator &slice_op_pair) const {
  auto op_name = slice_op_pair.first;
  auto op_params = slice_op_pair.second.second;
  if (op_params.size() < kSize3) {
    MS_LOG(INTERNAL_EXCEPTION) << "The stridedslice op does not contain begin/end/strides params.";
  }
  auto begin_value_ptr = op_params[0].first.second;
  auto begin = GetValue<std::vector<int64_t>>(begin_value_ptr);
  auto end_value_ptr = op_params[1].first.second;
  auto end = GetValue<std::vector<int64_t>>(end_value_ptr);
  auto stride_value_ptr = op_params[2].first.second;
  auto stride = GetValue<std::vector<int64_t>>(stride_value_ptr);
  std::vector<int64_t> stride_attr;
  (void)std::copy(begin.begin(), begin.end(), std::back_inserter(stride_attr));
  (void)std::copy(end.begin(), end.end(), std::back_inserter(stride_attr));
  (void)std::copy(stride.begin(), stride.end(), std::back_inserter(stride_attr));
  return std::make_pair(op_name, stride_attr);
}

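// Fuse each allgather -> split -> concat triple into a single "AllConcat"
// entry: rewrite the trailing axis slot of the allgather to the concat axis,
// then erase the split and concat entries (back to front, so the recorded
// indices stay valid).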
void TensorTransform::OptimizeAllConcat(std::vector<std::pair<std::string, std::vector<int64_t>>> *transform_op_list) {
  if (transform_op_list->size() < kAllConcatSize) {
    return;
  }
  std::vector<size_t> allconcat_index;
  for (size_t i = kAllConcatSize - 1; i < transform_op_list->size(); ++i) {
    if ((*transform_op_list)[i - kIndex2].first != ALL_GATHER || (*transform_op_list)[i - kIndex1].first != SPLIT ||
        (*transform_op_list)[i].first != CONCAT) {
      continue;
    }
    auto allgather_group_size = SizeToLong((*transform_op_list)[i - kIndex2].second.size() - 1);
    auto split_axis = ((*transform_op_list)[i - kIndex1].second)[kIndex0];
    auto split_size = ((*transform_op_list)[i - kIndex1].second)[kIndex1];
    auto concat_axis = (*transform_op_list)[i].second.front();
    if (allgather_group_size != split_size || split_axis != 0) {
      continue;
    }
    (*transform_op_list)[i - kIndex2].second.back() = concat_axis;
    allconcat_index.push_back(i);
  }
  for (int j = SizeToInt(allconcat_index.size()) - 1; j >= 0; --j) {
    auto erase_index = allconcat_index[IntToSize(j)];
    (void)transform_op_list->erase(transform_op_list->begin() + erase_index);
    (void)transform_op_list->erase(transform_op_list->begin() + erase_index - 1);
  }
}

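// Build the {op_name, attrs} transform list needed to redistribute a tensor
// between two layouts on dev_list, as seen by rank_id; from[0..2] and to[0..2]
// are the three layout vectors consumed by TensorLayout::InitFromVector.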
std::vector<std::pair<std::string, std::vector<int64_t>>> TensorTransform::TransformOperators(const Shapes &from,
                                                                                              const Shapes &to,
                                                                                              const RankList &dev_list,
                                                                                              int64_t rank_id) {
  TensorLayout from_layout;
  (void)from_layout.InitFromVector(from[kIndex0], from[kIndex1], from[kIndex2]);
  TensorLayout to_layout;
  (void)to_layout.InitFromVector(to[kIndex0], to[kIndex1], to[kIndex2]);
  ParallelContext::GetInstance()->set_do_transform(true);
  tensor_redistribution_.SetVirtualRank(rank_id);
  (void)tensor_redistribution_.Init(from_layout, to_layout, dev_list);
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution_.InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Infer tensor redistribution failed.";
  }
  if (redistribution_oplist_ptr->first.size() != redistribution_oplist_ptr->second.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The redistribution op list size cannot match redistribution output info list size.";
  }
  auto operators_vector = redistribution_oplist_ptr->first;
  std::vector<std::pair<std::string, std::vector<int64_t>>> transform_op_list;
  for (const auto &op_pair : operators_vector) {
    auto op_name = op_pair.first;
    auto it = transform_operator_.find(op_name);
    if (it == transform_operator_.end()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The op:" << op_name << " is not a valid redistribution op.";
    }
    transform_op_list.push_back(it->second(op_pair));
  }
  OptimizeAllConcat(&transform_op_list);
  ParallelContext::GetInstance()->set_do_transform(false);
  return transform_op_list;
}

Shape TensorTransform::InferReshapeOp(const Shape &ori_shape, const std::vector<int64_t> &op) const {
  if (std::find(op.begin(), op.end(), -1) != op.end()) {
    MS_LOG(DEBUG) << "It's dynamic shape. Reshape to " << op;
    return op;
  }
  if (std::find(ori_shape.begin(), ori_shape.end(), -1) != ori_shape.end()) {
    return op;
  }
  // accumulate in int64_t (the init value fixes the accumulator type) to avoid
  // truncating large shape products to int
  if (std::accumulate(ori_shape.begin(), ori_shape.end(), int64_t(1), std::multiplies<int64_t>()) !=
      std::accumulate(op.begin(), op.end(), int64_t(1), std::multiplies<int64_t>())) {
    MS_LOG(EXCEPTION) << "Infer redistribution error, cannot convert shape: " << ori_shape << " to shape:" << op;
  }
  MS_LOG(DEBUG) << "It's static shape. Reshape to " << op;
  return op;
}

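// op layout: group_ranks + [axis]; gathering multiplies the gathered dim by
// the group size unless that dim is dynamic (-1).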
Shape TensorTransform::InferAllGatherOp(const Shape &ori_shape, const std::vector<int64_t> &op) const {
  auto new_shape = ori_shape;
  auto axis = op.back();
  if (new_shape[LongToSize(axis)] != -1) {
    new_shape[LongToSize(axis)] = new_shape[LongToSize(axis)] * SizeToLong(op.size() - 1);
  }
  return new_shape;
}

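// op layout: begin + end + stride, each ori_shape.size() long; the sliced
// extent of dim i is (end[i] - begin[i]) / stride[i].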
Shape TensorTransform::InferStridedSliceOp(const Shape &ori_shape, const std::vector<int64_t> &op) const {
  size_t end_index = op.size() / kSize3;
  if (ori_shape.size() != end_index) {
    MS_LOG(EXCEPTION) << "Infer redistribution error, the shape:" << ori_shape
                      << " cannot be sliced with dimension size:" << end_index;
  }
  auto new_shape = ori_shape;
  for (size_t i = 0; i < ori_shape.size(); ++i) {
    new_shape[i] = (op[end_index + i] - op[i]) / op[kSize2 * end_index + i];
  }
  return new_shape;
}

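// Replay transform_op_list on ori_shape, recording the shape after every op.
// Ops without a registered infer function (e.g. a bare Split or Concat that
// OptimizeAllConcat did not fuse) raise an exception here.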
std::vector<Shape> TensorTransform::GetRedistributionOpShape(
  const Shape &ori_shape, const std::vector<std::pair<std::string, std::vector<int64_t>>> &transform_op_list) {
  std::vector<Shape> result_shape;
  auto cur_shape = ori_shape;
  for (const auto &op : transform_op_list) {
    auto op_name = op.first;
    auto it = infer_shape_operator_.find(op_name);
    if (it == infer_shape_operator_.end()) {
      MS_LOG(EXCEPTION) << "The op:" << op_name << " cannot infer shape in redistribution.";
    }
    cur_shape = it->second(cur_shape, op.second);
    result_shape.push_back(cur_shape);
  }
  return result_shape;
}

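// Build a Reshape Operator with the dst shape attached as an op param;
// the 2 is presumably the input-slot index recorded in OperatorParams.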
Operator ConstructReshapeOp(const std::vector<int64_t> &shape) {
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  return std::make_pair(RESHAPE, args);
}

RedistributionOpListPtr TensorTransform::OptimizeTensorRedistributionOperatorList(
  const RedistributionOpListPtr &redistribution_op_list, const Shape &input_shape) {
  MS_LOG(DEBUG) << "Do optimization for tensor redistributions.";
  // 1. convert operators_vector to transform_op_list
  // 2. fuse allgather->split->concat into allconcat
  MS_EXCEPTION_IF_NULL(redistribution_op_list);
  if ((redistribution_op_list->first).size() != (redistribution_op_list->second).size()) {
    return redistribution_op_list;
  }
  auto operators_vector = redistribution_op_list->first;
  std::vector<std::pair<std::string, std::vector<int64_t>>> transform_op_list;
  for (const auto &op_pair : operators_vector) {
    auto op_name = op_pair.first;
    auto it = transform_operator_.find(op_name);
    if (it == transform_operator_.end() || IsToBeInsertedSplitOp(op_pair)) {
      MS_LOG(INFO) << "The op:" << op_name << " would not be optimized.";
      return redistribution_op_list;
    }
    transform_op_list.push_back(it->second(op_pair));
  }
  OptimizeAllConcat(&transform_op_list);
  auto shape_list = GetRedistributionOpShape(input_shape, transform_op_list);
  size_t current_allgather_pos_in_origin_list = 0;
  std::unordered_map<size_t, std::vector<int64_t>> left_reshape_op_list;
  std::vector<size_t> allconcat_pos_list;
  // 3. remove the leading dims whose value is 1 so the AllConcat axis becomes 0
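  // The emitted AllGather always gathers on axis 0 (see ExtractAllGatherOp);
  // after fusion, the AllConcat axis may be nonzero. When every dim in front
  // of that axis is 1, reshaping those dims away turns the gather into an
  // axis-0 one, and the Reshape that already follows restores the layout, so
  // the Split/Concat pair becomes unnecessary.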
  for (size_t i = 0; i < transform_op_list.size(); ++i) {
    auto trans_op_pair = transform_op_list[i];
    if (trans_op_pair.first != ALL_GATHER) {
      current_allgather_pos_in_origin_list++;
      continue;
    }
    auto axis = transform_op_list[i].second.back();
    if (axis == 0) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    if (i == transform_op_list.size() - 1 || transform_op_list[i + 1].first != RESHAPE) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    auto src_shape = shape_list[i];
    if (src_shape[LongToSize(axis)] > 0 && transform_op_list[i].second.size() > 1) {
      src_shape[LongToSize(axis)] = src_shape[LongToSize(axis)] / (SizeToLong(transform_op_list[i].second.size()) - 1);
    }
    auto new_axis = axis;
    auto new_src_shape = src_shape;
    for (int64_t j = axis - 1; j >= 0; --j) {
      if (src_shape[LongToSize(j)] != 1) {
        continue;
      }
      new_src_shape.erase(new_src_shape.begin() + j);
      new_axis -= 1;
    }
    MS_LOG(INFO) << "src_shape:" << src_shape << ", new_src_shape:" << new_src_shape << ", axis:" << axis
                 << ", new_axis:" << new_axis;
    if (new_axis != 0) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    left_reshape_op_list[current_allgather_pos_in_origin_list] = new_src_shape;
    allconcat_pos_list.push_back(current_allgather_pos_in_origin_list);
    current_allgather_pos_in_origin_list += kSize3;
  }
  // 4. insert a reshape before each such allgather in redistribution_op_list and drop its split/concat
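  // Process positions back to front so earlier erase/insert calls do not
  // shift the offsets that are still pending.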
  std::reverse(allconcat_pos_list.begin(), allconcat_pos_list.end());
  for (auto pos : allconcat_pos_list) {
    // erase split and concat (largest offset first)
    (void)redistribution_op_list->first.erase(redistribution_op_list->first.begin() + pos + kSize2);
    (void)redistribution_op_list->first.erase(redistribution_op_list->first.begin() + pos + kSize1);
    (void)redistribution_op_list->second.erase(redistribution_op_list->second.begin() + pos + kSize2);
    (void)redistribution_op_list->second.erase(redistribution_op_list->second.begin() + pos + kSize1);
    // insert reshape before allgather
    Operator left_reshape_op = ConstructReshapeOp(left_reshape_op_list[pos]);
    (void)redistribution_op_list->first.insert(redistribution_op_list->first.begin() + pos, left_reshape_op);
    (void)redistribution_op_list->second.insert(redistribution_op_list->second.begin() + pos, {false, 0});
  }
  return redistribution_op_list;
}
}  // namespace parallel
}  // namespace mindspore