1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
18 #include <functional>
19 #include <numeric>
20 #include <memory>
21 #include <set>
22 #include <utility>
23 #include <algorithm>
24 #include <string>
25 #include "frontend/parallel/status.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "frontend/parallel/graph_util/graph_utils.h"
28 #include "frontend/parallel/tensor_layout/shape_util.h"
29 #include "frontend/parallel/step_parallel_utils.h"
30 #include "frontend/parallel/tensor_layout/prime_generator.h"
31 #include "frontend/parallel/tensor_layout/layout_utils.h"
32 
33 namespace mindspore {
34 namespace parallel {
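// Prepares from_origin_ / to_origin_ for redistribution. When only one side uses micro
// interleaving (GetVirtualRank().size() > 1), the other side's device matrix gets an extra
// interleaving dimension appended and its non-negative tensor-map indices are shifted by one,
// so that both layouts describe the same device arrangement.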
35 Status TensorRedistribution::MakeFromToLayout(const TensorLayout &from, const TensorLayout &to) {
36   auto from_layout = from.LayoutForRedistribution();
37   auto to_layout = to.LayoutForRedistribution();
38   if (virtual_rank_ >= 0) {
39     from_origin_ = from_layout;
40     to_origin_ = to_layout;
41     virtual_rank_list_ = {virtual_rank_};
42     return SUCCESS;
43   }
44   if (from.GetVirtualRank().size() == to.GetVirtualRank().size()) {
45     from_origin_ = from_layout;
46     to_origin_ = to_layout;
47     virtual_rank_list_ = from.GetVirtualRank();
48     return SUCCESS;
49   }
50   if (from.GetVirtualRank().size() == 1) {
51     auto device_matrix = from_layout.device_arrangement_origin().array();
52     device_matrix.push_back(to.GetVirtualRank().size());
53     virtual_rank_list_ = to.GetVirtualRank();
54     to_origin_ = to_layout;
55     if (!from_layout.tensor_map_before().empty()) {
56       auto new_tensor_map = from_layout.tensor_map_before();
57       std::for_each(new_tensor_map.begin(), new_tensor_map.end(), [](auto &inner_vec) {
58         std::for_each(inner_vec.begin(), inner_vec.end(), [](auto &val) {
59           if (val >= 0) {
60             val++;
61           }
62         });
63       });
64       return from_origin_.InitFromExtendVector(device_matrix, new_tensor_map, from_layout.tensor_shape_before().array(),
65                                                false, false);
66     }
67     auto new_map = from_layout.origin_tensor_map().array();
68     std::transform(new_map.begin(), new_map.end(), new_map.begin(),
69                    [](const auto &val) { return val >= 0 ? val + 1 : val; });
70     return from_origin_.InitFromVector(device_matrix, new_map, from_layout.tensor_shape().array());
71   }
72   if (to.GetVirtualRank().size() == 1) {
73     auto device_matrix = to_layout.device_arrangement_origin().array();
74     device_matrix.push_back(from.GetVirtualRank().size());
75     virtual_rank_list_ = from.GetVirtualRank();
76     from_origin_ = from_layout;
77     if (!to_layout.tensor_map_before().empty()) {
78       auto new_tensor_map = to_layout.tensor_map_before();
79       std::for_each(new_tensor_map.begin(), new_tensor_map.end(), [](auto &inner_vec) {
80         std::for_each(inner_vec.begin(), inner_vec.end(), [](auto &val) {
81           if (val >= 0) {
82             val++;
83           }
84         });
85       });
86       return to_origin_.InitFromExtendVector(device_matrix, new_tensor_map, to_layout.tensor_shape_before().array(),
87                                              false, false);
88     }
89     auto new_map = to_layout.origin_tensor_map().array();
90     std::transform(new_map.begin(), new_map.end(), new_map.begin(),
91                    [](const auto &val) { return val >= 0 ? val + 1 : val; });
92     return to_origin_.InitFromVector(device_matrix, new_map, to_layout.tensor_shape().array());
93   }
94   MS_LOG(ERROR) << "The from layout sharding micro interleaved num:" << from.GetVirtualRank().size()
95                 << " dose not match the to layout sharding micro interleaved num:" << to.GetVirtualRank().size();
96   return FAILED;
97 }
98 
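// Entry point of a redistribution: builds the (interleaving-unified) from/to layouts, assembles
// fake static shapes when dynamic dims are present, checks that the static ranks match, expands
// dev_list by the virtual-rank count, and squeezes size-1 dims into from_ / to_.
// Minimal usage sketch (assumed caller-side setup for layouts and dev_list):
//   TensorRedistribution tensor_redistribution;
//   if (tensor_redistribution.Init(from_layout, to_layout, dev_list) != SUCCESS) { /* handle */ }
//   auto op_list = tensor_redistribution.InferTensorRedistributionOperatorList();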
99 Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) {
100   if (MakeFromToLayout(from, to) != SUCCESS) {
101     MS_LOG(ERROR) << "Make from_layout and to_layout failed.";
102     return FAILED;
103   }
104   this->is_dynamic_shape_ = CheckDynamicShape(from, to);
105   if (this->is_dynamic_shape_) {
106     // Dynamic info of func_graph should be considered.
107     MS_LOG(INFO) << "LayoutTransfer inited with dynamic shape.";
108     this->from_origin_no_assembled_ = this->from_origin_;
109     this->to_origin_no_assembled_ = this->to_origin_;
110     Status ret = this->AssembleStaticTensorShape(this->from_origin_no_assembled_, this->to_origin_no_assembled_,
111                                                  &this->from_origin_, &this->to_origin_);
112     if (ret != Status::SUCCESS) {
113       return ret;
114     }
115     this->is_assembled_static_shape_ = true;
116   }
117   const Shape from_origin_shape = from_origin_.tensor_shape().array();
118   const Shape to_origin_shape = to_origin_.tensor_shape().array();
119   bool is_from_dyn = std::find(from_origin_shape.begin(), from_origin_shape.end(), -1) != from_origin_shape.end();
120   bool is_to_dyn = std::find(to_origin_shape.begin(), to_origin_shape.end(), -1) != to_origin_shape.end();
121   if (!is_from_dyn && !is_to_dyn && from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) {
122     MS_LOG(ERROR) << "from shape size must be equal to to shape size! from shape size is "
123                   << from_origin_.tensor_shape().size() << ", to shape size is " << to_origin_.tensor_shape().size();
124     MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString();
125     MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString();
126     return Status::FAILED;
127   }
128 
129   if (virtual_rank_list_.size() == 1) {
130     dev_list_ = dev_list;
131   } else {
132     for (const auto &rank : dev_list) {
133       for (size_t i = 0; i < virtual_rank_list_.size(); ++i) {
134         dev_list_.push_back(int64_t(rank * virtual_rank_list_.size() + i));
135       }
136     }
137   }
138   from_ = from_origin_.SqueezeShape();
139   to_ = to_origin_.SqueezeShape();
140 
141   this->is_inited_ = true;
142   return Status::SUCCESS;
143 }
144 
145 Status TensorRedistribution::CalculateFromTensorShape(Shape *from_shape, const Array &from_factors,
146                                                       const Shape &to_shape, const Array &to_factors) {
147   if (from_shape->size() != from_factors.GetDimSize() || to_shape.size() != to_factors.GetDimSize()) {
148     MS_LOG(ERROR) << "Shape size is not equal to factor size.";
149     return Status::FAILED;
150   }
151   int64_t to_layout_added_factor = GetLeastFactorWithoutConstDims(to_shape, to_factors);
152   int64_t to_layout_const_size = GetTensorSize(to_shape);
153   int64_t from_layout_const_size = GetTensorSize(*from_shape);
154   if (to_layout_const_size > from_layout_const_size && to_layout_const_size % from_layout_const_size == 0) {
155     to_layout_added_factor *= (to_layout_const_size / from_layout_const_size);
156   }
157   MS_LOG(INFO) << "from_shape=" << (*from_shape) << ", from_factors=" << from_factors.array()
158                << ", to_shape=" << to_shape << ", to_factors=" << to_factors.array()
159                << ", to_layout_added_factor=" << to_layout_added_factor;
160   if (from_layout_const_size > to_layout_const_size && from_layout_const_size % to_layout_const_size == 0) {
161     int64_t merged_const_factor = from_layout_const_size / to_layout_const_size;
162     // Existing dims in from_layout already satisfy to_layout_added_factor.
163     if (to_layout_added_factor > merged_const_factor && to_layout_added_factor % merged_const_factor == 0) {
164       to_layout_added_factor /= merged_const_factor;
165     }
166     if (to_layout_added_factor == 1) {
167       to_layout_added_factor = -1;
168     }
169   }
170   bool strict_mode = UseStrictMode(*from_shape, to_shape);
171   std::vector<int64_t> known_dims;
172   (void)std::copy_if(from_shape->begin(), from_shape->end(), std::back_inserter(known_dims),
173                      [](int64_t dim) -> bool { return dim != -1; });
174   constexpr size_t INVALID_TENSOR_RANK = 9999;
175   size_t last_dyn_dim = INVALID_TENSOR_RANK;
176   auto last_dyn_dim_iter = std::find(from_shape->rbegin(), from_shape->rend(), -1);
177   if (last_dyn_dim_iter != from_shape->rend()) {
178     last_dyn_dim = from_shape->size() - (last_dyn_dim_iter - from_shape->rbegin()) - 1;
179   }
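  // Each remaining -1 dim is replaced by (coprime prime) * from_factor, so the fake static value
  // stays divisible by its sharding factor yet remains distinguishable from the real dims;
  // strict mode may multiply in extra factors required by the to-layout.
  // Illustrative example (assumed values): from_shape=[-1, 4096], from_factors=[8, 1]
  // -> from_shape becomes [8 * p, 4096] for some prime p coprime with the known dims.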
180   for (size_t i = 0; i < from_shape->size(); ++i) {
181     if (from_shape->at(i) != -1) {
182       continue;
183     }
184     int64_t prime_num = PrimeGenerator::GetInstance()->GetCoprimeNum(known_dims);
185     if (prime_num == -1) {
186       return Status::FAILED;
187     }
188     (*from_shape)[i] = prime_num * from_factors.GetDimByIdx(i);
189     if (strict_mode && from_shape->at(i) < to_factors.GetDimByIdx(i) &&
190         from_factors.GetDimByIdx(i) < to_factors.GetDimByIdx(i)) {
191       int64_t common_factor = std::gcd(from_factors.GetDimByIdx(i), to_factors.GetDimByIdx(i));
192       int64_t left_factor = to_factors.GetDimByIdx(i) / common_factor;
193       (*from_shape)[i] *= left_factor;
194       if (to_layout_added_factor >= left_factor && to_layout_added_factor % left_factor == 0) {
195         to_layout_added_factor /= left_factor;
196       }
197       if (to_layout_added_factor < left_factor) {
198         to_layout_added_factor = -1;
199       }
200     }
201     if (strict_mode && from_shape->at(i) >= to_factors.GetDimByIdx(i) &&
202         from_shape->at(i) % to_factors.GetDimByIdx(i) != 0) {
203       (*from_shape)[i] *= to_factors.GetDimByIdx(i);
204       if (to_layout_added_factor >= to_factors.GetDimByIdx(i) &&
205           to_layout_added_factor % to_factors.GetDimByIdx(i) == 0) {
206         to_layout_added_factor /= to_factors.GetDimByIdx(i);
207       }
208     }
209     if (i == last_dyn_dim && to_layout_added_factor > 0) {
210       if (from_shape->at(i) % to_layout_added_factor != 0) {
211         (*from_shape)[i] *= to_layout_added_factor;
212       }
213       to_layout_added_factor = -1;
214     }
215     known_dims.emplace_back(from_shape->at(i));
216     MS_LOG(DEBUG) << "Replace dim " << i << " with value " << from_shape->at(i) << ", prime " << prime_num;
217     if (!RecordDimsChange(i, from_shape->at(i), &this->from_dims_replace_memo_)) {
218       MS_LOG(ERROR) << "Index " << i << " conflicts.";
219       return Status::FAILED;
220     }
221   }
222   return Status::SUCCESS;
223 }
224 
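// Returns the first `length` positive multiples of base_n (base_n, 2*base_n, ..., length*base_n),
// i.e. the candidate static values for a dynamic dim whose sharding factor is base_n.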
225 static std::vector<int64_t> EnumerateArray(int64_t base_n, size_t length = 100) {
226   static std::map<int64_t, std::vector<int64_t>> enum_numbers;
227   if (enum_numbers.find(base_n) != enum_numbers.end()) {
228     return enum_numbers.at(base_n);
229   }
230   std::vector<int64_t> array(length);
231   for (size_t i = 1; i < length + 1; ++i) {
232     array[i - 1] = base_n * SizeToLong(i);
233   }
234   return enum_numbers.emplace(base_n, std::move(array)).first->second;  // Cache the result so repeated calls with the same base_n reuse it.
235 }
236 
237 Status TensorRedistribution::CalculateToTensorShapeUsingEnumeration(const Shape &from_tsr_shape, Shape *to_tsr_shape,
238                                                                     const Array &factors) {
239   int64_t src_element_size = GetTensorSize(from_tsr_shape);
240   int64_t dst_element_size = GetTensorSize(*to_tsr_shape);
241   if (src_element_size % dst_element_size != 0) {
242     MS_LOG(ERROR) << "Calculate to tensor shape failed. Tensor shape size is not matched.";
243     return Status::FAILED;
244   }
245   const int64_t dyn_dim_val = -1;
246   int64_t dyn_axis_cnt = std::count(to_tsr_shape->begin(), to_tsr_shape->end(), dyn_dim_val);
247   int64_t left_size = src_element_size / dst_element_size;
248 
249   if (dyn_axis_cnt == 0) {
250     if (left_size != 1) {
251       MS_LOG(ERROR) << "Calculate to tensor shape failed. Tensor shape size is not matched.";
252       return Status::FAILED;
253     }
254     return Status::SUCCESS;
255   }
256 
257   if (dyn_axis_cnt == 1) {
258     /**
259      * Case1:
260      * from: c1, -1(32), c3, c4; to: c1/2, -1(32)*c3, c4
261      */
262     auto iter = std::find(to_tsr_shape->begin(), to_tsr_shape->end(), dyn_dim_val);
263     size_t index = static_cast<size_t>(iter - to_tsr_shape->begin());
264     if (left_size % factors.GetDimByIdx(index) != 0) {
265       MS_LOG(ERROR) << "Generate static shape failed, the shape cannot be divided by factor. dim=" << left_size
266                     << ", factor=" << factors.GetDimByIdx(index);
267       return Status::FAILED;
268     }
269     (*iter) = left_size;
270     if (!RecordDimsChange(index, left_size, &this->to_dims_replace_memo_)) {
271       MS_LOG(ERROR) << "Index " << iter - to_tsr_shape->begin() << " conflicts.";
272       return Status::FAILED;
273     }
274     return Status::SUCCESS;
275   } else {
276     /**
277      * Case2:
278      * from: -1(16), c1, c2; to: -1(2), c1*c2/2, 2*-1(8)
279      * Solution:
280      * -1(16), c1*c2/2, 2
281      *      A,       B, c1*c2/2, 2
282      *      A, c1*c2/2, 2* B
283      *
284      * A*B=3*16 && A%2=0 && B%8=0
285      */
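    // Enumerate candidate values (multiples of the per-dim factor) for every dynamic dim of
    // to_tsr_shape and let SolveCombination pick one value per dim so that their product
    // matches left_size.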
286     std::vector<std::vector<int64_t>> enum_numbers;
287     for (size_t i = 0; i < to_tsr_shape->size(); ++i) {
288       if (to_tsr_shape->at(i) == -1) {
289         std::vector<int64_t> array = EnumerateArray(factors.GetDimByIdx(i));
290         enum_numbers.emplace_back(array);
291       }
292     }
293     std::vector<int64_t> candidates(enum_numbers.size());
294     if (!SolveCombination(from_tsr_shape, 0, enum_numbers, 0, left_size, &candidates)) {
295       MS_LOG(ERROR) << "Not supported for now.";
296       return Status::FAILED;
297     }
298     size_t cnt = 0;
299     for (size_t i = 0; i < to_tsr_shape->size(); ++i) {
300       if (to_tsr_shape->at(i) == -1) {
301         (*to_tsr_shape)[i] = candidates[cnt++];
302         if (!RecordDimsChange(i, to_tsr_shape->at(i), &this->to_dims_replace_memo_)) {
303           MS_LOG(ERROR) << "Index " << i << " conflicts.";
304           return Status::FAILED;
305         }
306       }
307     }
308     return Status::SUCCESS;
309   }
310 }
311 
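// With exactly one dynamic axis in to_shape, every constant to-dim is divided out of from_shape
// via gcd and the remaining factor of from_shape is assigned to the dynamic axis, keeping the
// element counts equal.
// Illustrative example (assumed values): from=[6, 4, 2], to=[3, -1, 2] -> the dynamic axis gets 8.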
312 void CalculateToTensorShapeForOneDynamicAxis(const Shape &from_shape, const Shape &origin_to_shape, Shape *to_shape) {
313   Shape from_shape_divisor(from_shape);
314   size_t dynamic_axis = 0;
315   for (size_t i = 0; i < origin_to_shape.size(); ++i) {
316     int64_t dim_val = origin_to_shape[i];
317     (*to_shape)[i] = dim_val;
318     if (dim_val == -1) {
319       dynamic_axis = i;
320       continue;
321     }
322     for (int64_t &from_dim_val : from_shape_divisor) {
323       if (dim_val == 1) {
324         break;
325       }
326       int64_t f = std::gcd(dim_val, from_dim_val);
327       from_dim_val /= f;
328       dim_val /= f;
329     }
330   }
331   (*to_shape)[dynamic_axis] = GetTensorSize(from_shape_divisor);
332   MS_LOG(INFO) << "to_shape=" << (*to_shape) << ", from_shape_divisor=" << from_shape_divisor;
333 }
334 
335 Status TensorRedistribution::CalculateToTensorShape(const Shape &from_shape, const Shape &origin_to_shape,
336                                                     const Array &to_in_factors, Shape *to_shape) {
337   MS_LOG(INFO) << "from_shape=" << from_shape << ", origin_to_shape=" << origin_to_shape
338                << ", to_in_factors=" << to_in_factors.array();
339   // Use forward and backward matching first, if failed, turn to enumeration.
340   if (std::count(origin_to_shape.begin(), origin_to_shape.end(), -1) == 1) {
341     CalculateToTensorShapeForOneDynamicAxis(from_shape, origin_to_shape, to_shape);
342     return Status::SUCCESS;
343   }
344   bool flag_forward_match = ForwardMatching(from_shape, origin_to_shape, to_shape, to_in_factors);
345   if (!flag_forward_match && !BackwardMatching(origin_to_shape, to_shape, to_in_factors)) {
346     MS_LOG(DEBUG) << "Backward matching failed.";
347     if (CalculateToTensorShapeUsingEnumeration(from_shape, to_shape, to_in_factors) != Status::SUCCESS) {
348       MS_LOG(ERROR) << "Calculate to tensor shape failed trying to use enumeration method.";
349       return Status::FAILED;
350     }
351   }
352   return Status::SUCCESS;
353 }
354 
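// Replaces the dynamic (-1) dims of both layouts with consistent fake static values: the from
// shape is generated first (prime based), the to shape is derived from it, and the two are
// balanced to the same element count before the new layouts are re-initialized with the original
// device arrangement and tensor map.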
355 Status TensorRedistribution::AssembleStaticTensorShape(const TensorLayout &from_in, const TensorLayout &to_in,
356                                                        TensorLayout *new_from_layout, TensorLayout *new_to_layout) {
357   Shape new_from_shape(from_in.tensor_shape().array());
358   Shape original_to_shape = to_in.tensor_shape().array();
359   Array from_in_factors;
360   if (GetFactors(from_in, &from_in_factors) != Status::SUCCESS) {
361     MS_LOG(ERROR) << "Get from_in factors failed.";
362     return Status::FAILED;
363   }
364   Array to_in_factors;
365   if (GetFactors(to_in, &to_in_factors) != Status::SUCCESS) {
366     MS_LOG(ERROR) << "Get to_in factors failed.";
367     return Status::FAILED;
368   }
369   if (CalculateFromTensorShape(&new_from_shape, from_in_factors, original_to_shape, to_in_factors) != Status::SUCCESS) {
370     MS_LOG(ERROR) << "Failed to generate static shape for from_tensor layout: " << from_in.ToString();
371     return Status::FAILED;
372   }
373   Shape new_to_shape(to_in_factors.GetDimSize(), 1);
374   if (CalculateToTensorShape(new_from_shape, original_to_shape, to_in_factors, &new_to_shape) != Status::SUCCESS) {
375     MS_LOG(ERROR) << "Failed to generate static shape for to_tensor layout: " << to_in.ToString() << std::endl
376                   << "from_in layout: " << from_in.ToString() << std::endl
377                   << "Already generate from_in shape: " << new_from_shape;
378     return Status::FAILED;
379   }
380   size_t size = std::min(new_from_shape.size(), new_to_shape.size());
381   if (GetTensorSize(new_from_shape) != GetTensorSize(new_to_shape)) {
382     int64_t acc_scalar = 1;
383     for (size_t i = 0; i < size; ++i) {
384       if (new_from_shape.at(i) > new_to_shape.at(i) && new_from_shape.at(i) % new_to_shape.at(i) == 0) {
385         int64_t scalar = new_from_shape.at(i) / new_to_shape.at(i);
386         new_to_shape[i] = new_to_shape[i] * scalar;
387         acc_scalar *= scalar;
388       }
389     }
390     const Shape &f_in_tensor_shape = from_in.tensor_shape().array();
391     auto last_dyn_dim_iter = std::find(f_in_tensor_shape.rbegin(), f_in_tensor_shape.rend(), -1);
392     if (last_dyn_dim_iter != f_in_tensor_shape.rend()) {
393       size_t last_dyn_dim =
394         f_in_tensor_shape.size() - static_cast<size_t>(last_dyn_dim_iter - f_in_tensor_shape.rbegin()) - 1;
395       new_from_shape[static_cast<size_t>(last_dyn_dim)] *= acc_scalar;
396     }
397   }
398 
399   // Unify shape from begin to end.
400   UnifyFromAndToShape(&new_from_shape, &new_to_shape, from_in, to_in, &this->from_dims_replace_memo_);
401 
402   MS_LOG(INFO) << "new_from_shape=" << new_from_shape << ", new_to_shape=" << new_to_shape;
403   if (new_from_layout->InitFromVector(from_in.device_arrangement().array(), from_in.tensor_map().array(),
404                                       new_from_shape) != Status::SUCCESS) {
405     MS_LOG(ERROR) << "Failed to init new from_tensor layout.";
406     return Status::FAILED;
407   }
408   MS_LOG(DEBUG) << "Init new_from_tensor layout, origin:" << from_in.ToString()
409                 << ", new:" << new_from_layout->ToString();
410 
411   if (new_to_layout->InitFromVector(to_in.device_arrangement().array(), to_in.tensor_map().array(), new_to_shape) !=
412       Status::SUCCESS) {
413     MS_LOG(ERROR) << "Failed to init new to_tensor layout.";
414     return Status::FAILED;
415   }
416   MS_LOG(DEBUG) << "Init new_to_layout layout, origin:" << to_in.ToString() << ", new:" << new_to_layout->ToString();
417 
418   return Status::SUCCESS;
419 }
420 
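// Depth-limited DFS over cnode's inputs; returns true if dst_cnode is reachable from cnode.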
421 bool IsVirtualDatasetNextInput(const CNodePtr &cnode, const CNodePtr &dst_cnode, size_t depth = 0) {
422   if (depth >= MAX_RECURSIVE_DEPTH) {
423     return false;
424   }
425   for (size_t j = 1; j < cnode->inputs().size(); ++j) {
426     auto cur_cnode = cnode->input(j)->cast<CNodePtr>();
427     if (cur_cnode == nullptr) {
428       continue;
429     }
430     if (cur_cnode->UniqueId() == dst_cnode->UniqueId()) {
431       return true;
432     }
433     if (IsVirtualDatasetNextInput(cur_cnode, dst_cnode, depth + 1)) {
434       return true;
435     }
436   }
437   return false;
438 }
439 
440 CNodePtr UpdateShapeNodeInput(const CNodePtr &current_cnode, const CNodePtr &dst_cnode, size_t redistribution_index) {
441   for (size_t i = redistribution_index; i < current_cnode->inputs().size(); ++i) {
442     auto prev_cnode = current_cnode->input(i)->cast<CNodePtr>();
443     if (prev_cnode == nullptr) {
444       continue;
445     }
446     bool found = IsVirtualDatasetNextInput(prev_cnode, dst_cnode);
447     if (found) {
448       MS_LOG(INFO) << "change input to " << prev_cnode->fullname_with_scope();
449       return prev_cnode;
450     }
451   }
452   return nullptr;
453 }
454 
455 std::pair<int64_t, AnfNodePtr> GetDimMapping(const AssembledDynamicDimsMapping &mapping, int64_t index) {
456   for (const auto &iter : mapping) {
457     if (SizeToLong(iter.second.first) == index) {
458       return std::make_pair(iter.first, iter.second.second);
459     }
460   }
461   MS_LOG(EXCEPTION) << "Cannot find index " << index << " in AssembledDynamicDimsMapping.";
462 }
463 
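// from_origin_ carries an extra size-1 dim that the unified layouts squeezed away, so every
// TupleGetItem in the mapping is rebuilt with its dim index shifted by one while the mapped
// runtime value stays unchanged.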
464 void TensorRedistribution::UnifyAssembledMappingWithSqueezedFromShape() {
465   AssembledDynamicDimsMapping new_mapping;
466   for (const auto &iter : this->dynamic_dim_mapping_) {
467     auto origin_tuple_get_item = iter.second.second;
468     auto origin_tuple_get_item_cnode = origin_tuple_get_item->cast<CNodePtr>();
469     MS_EXCEPTION_IF_NULL(origin_tuple_get_item_cnode);
470     auto func_graph = origin_tuple_get_item->func_graph();
471     MS_EXCEPTION_IF_NULL(func_graph);
472     auto prim_tuple_get_item = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
473     int64_t index = SizeToLong(iter.second.first) + 1;
474     AnfNodePtrList inputs{NewValueNode(prim_tuple_get_item), origin_tuple_get_item_cnode->input(1),
475                           NewValueNode(MakeValue(index))};
476     auto tuple_get_item_cnode = func_graph->NewCNode(inputs);
477     tuple_get_item_cnode->set_fullname_with_scope(iter.second.second->fullname_with_scope());
478     prim_tuple_get_item->set_instance_name("tuple_getitem_for_value_" + std::to_string(iter.first));
479     if (iter.second.second->isa<CNode>()) {
480       auto raw_cnode = iter.second.second->cast<CNodePtr>();
481       if (IsValueNode<Primitive>(raw_cnode->input(0))) {
482         auto prim_node = raw_cnode->input(0)->cast<ValueNodePtr>();
483         auto prim = GetValueNode<PrimitivePtr>(prim_node);
484         prim_tuple_get_item->set_instance_name(prim->instance_name());
485       }
486     }
487     new_mapping.insert({iter.first, {iter.second.first, tuple_get_item_cnode}});
488     MS_LOG(WARNING) << "Adjust TupleGetItem for dim=" << iter.second.first << " to " << iter.second.first + 1
489                     << " to replace value=" << iter.first;
490   }
491   this->dynamic_dim_mapping_ = new_mapping;
492 }
493 
494 void TensorRedistribution::UnifyAssembledMappingWithSameSize(const std::set<int64_t> &index_mapping) {
495   Shape from_shape = this->assembled_static_origin_from_.tensor_shape().array();
496   Shape origin_slice_shape = this->assembled_static_origin_from_.slice_shape().array();
497   AssembledDynamicDimsMapping new_mapping;
498   for (int64_t i = SizeToLong(from_shape.size()) - 1; i >= 0; --i) {
499     if (index_mapping.find(i) == index_mapping.end()) {
500       continue;
501     }
502     auto dyn_dim = GetDimMapping(this->dynamic_dim_mapping_, i);
503     int64_t real_dim_value = origin_slice_shape[i];
504     new_mapping.insert({real_dim_value, {i, dyn_dim.second}});
505     MS_LOG(INFO) << "insert at " << i << " with " << real_dim_value;
506   }
507   this->dynamic_dim_mapping_ = new_mapping;
508 }
509 
510 void TensorRedistribution::UnifyAssembledMappingWithDiffSize(const std::set<int64_t> &index_mapping) {
511   auto func_graph = this->next_cnode_->func_graph();
512   MS_EXCEPTION_IF_NULL(func_graph);
513 
514   Shape from_shape = this->assembled_static_origin_from_.tensor_shape().array();
515   Shape origin_slice_shape = this->assembled_static_origin_from_.slice_shape().array();
516   Shape unified_from_shape = this->layout_transfer_.from_in().tensor_shape().array();
517   Shape unified_slice_shape = this->layout_transfer_.from_in().slice_shape().array();
518 
519   AssembledDynamicDimsMapping new_mapping;
520   // Assume the length of unified_from_shape is greater than the length of from_shape.
521   int64_t unified_offset = SizeToLong(unified_from_shape.size()) - 1;
522   for (int64_t i = SizeToLong(from_shape.size()) - 1; i >= 0 && unified_offset >= 0; --i) {
523     int64_t real_dim_value = origin_slice_shape[i];
524     // It means it's a const dim.
525     if (index_mapping.find(i) == index_mapping.end()) {
526       MS_EXCEPTION_IF_CHECK_FAIL(real_dim_value >= unified_slice_shape[unified_offset] &&
527                                    real_dim_value % unified_slice_shape[unified_offset] == 0,
528                                  "Tensor layout tensor shape is illegal.");
529       int64_t left_size = real_dim_value / unified_slice_shape[unified_offset];
530       --unified_offset;
531       if (left_size == 1) {
532         continue;
533       }
534       while (left_size != 1 && unified_offset >= 0) {
535         MS_EXCEPTION_IF_CHECK_FAIL(left_size % unified_slice_shape[unified_offset] == 0,
536                                    "Tensor layout tensor shape is illegal, left_size is " + std::to_string(left_size) +
537                                      ", factor is " + std::to_string(unified_slice_shape[unified_offset]));
538         left_size = left_size / unified_slice_shape[unified_offset];
539         --unified_offset;
540       }
541       continue;
542     }
543     auto dyn_dim = GetDimMapping(this->dynamic_dim_mapping_, i);
544     // It means it's a dynamic dim.
545     if (from_shape[i] == unified_from_shape[unified_offset]) {
546       new_mapping.insert({real_dim_value, {unified_offset, dyn_dim.second}});
547       MS_LOG(INFO) << "insert at " << unified_offset << " with " << real_dim_value;
548       --unified_offset;
549     } else if (from_shape[i] > unified_slice_shape[unified_offset] &&
550                from_shape[i] % unified_slice_shape[unified_offset] == 0) {
551       // left_size must be greater than 1.
552       int64_t left_size = real_dim_value / unified_slice_shape[unified_offset];
553       MS_EXCEPTION_IF_CHECK_FAIL(left_size >= 1, "left_size must be greater than or equal to 1.");
554       int64_t divisor = real_dim_value / unified_slice_shape[unified_offset];
555       if (GetPrimeFactor(unified_slice_shape[unified_offset]) != -1) {
556         AnfNodePtr new_dim_node = CreateDiv(dyn_dim.second, divisor, func_graph, true, "assemble_dynamic_shape_op");
557         new_mapping.insert({unified_slice_shape[unified_offset], {unified_offset, new_dim_node}});
558         MS_LOG(INFO) << "insert at " << unified_offset << " with " << unified_slice_shape[unified_offset];
559       } else {
560         new_mapping.insert({unified_slice_shape[unified_offset], {unified_offset, dyn_dim.second}});
561         MS_LOG(INFO) << "insert at " << unified_offset << " with " << unified_slice_shape[unified_offset];
562       }
563       --unified_offset;
564       while (left_size != 1 && unified_offset >= 0) {
565         left_size = left_size / unified_slice_shape[unified_offset];
566         // If it's prime then add it to mapping.
567         if (GetPrimeFactor(unified_slice_shape[unified_offset]) != -1) {
568           new_mapping.insert({unified_slice_shape[unified_offset], {unified_offset, dyn_dim.second}});
569           MS_LOG(INFO) << "insert at " << unified_offset << " with " << unified_slice_shape[unified_offset];
570         } else {
571           MS_LOG(INFO) << "skip at " << unified_offset << " for " << unified_slice_shape[unified_offset]
572                        << ", because it's not a prime.";
573         }
574         --unified_offset;
575       }
576       if (left_size != 1 && unified_offset < 0) {
577         MS_LOG(EXCEPTION) << "Tensor shape cannot be unified.";
578       }
579     } else {
580       MS_LOG(EXCEPTION) << "Tensor shape cannot be unified.";
581     }
582   }
583   this->dynamic_dim_mapping_ = new_mapping;
584 }
585 
586 void TensorRedistribution::UnifyAssembledMapping() {
587   // 12,10,2,2 -> 2,6,10,2,2, 12 and 10 are all dynamic.
588   //  4, 6,2,2 -> 2,2, 6,2,2, 4 is static and 6 is dynamic.
589   // After the refactor, from_origin_ and layout_transfer_.from_in are both in static shape.
590   // 1. If origin_from_shape.size > before_unified_from_shape.size, it means the shape has been squeezed.
591   //    The squeeze can be at the head as well as at the tail.
592   // 2. If before_unified_from_shape.size < unified_from_shape.size, it means the shape has been expanded.
593   Shape origin_from_shape = this->from_origin_.tensor_shape().array();
594   Shape origin_from_slice_shape = this->from_origin_.slice_shape().array();
595   Shape before_unified_from_shape = this->assembled_static_origin_from_.tensor_shape().array();
596   Shape before_unified_from_slice_shape = this->assembled_static_origin_from_.slice_shape().array();
597   Shape unified_from_shape = this->layout_transfer_.from_in().tensor_shape().array();
598   Shape unified_from_slice_shape = this->layout_transfer_.from_in().slice_shape().array();
599 
600   std::set<int64_t> index_mapping;
601   for (const auto &iter : this->dynamic_dim_mapping_) {
602     index_mapping.insert(iter.second.first);
603   }
604   MS_LOG(INFO) << "\norigin_from_shape=" << origin_from_shape << ", origin_from_slice_shape=" << origin_from_slice_shape
605                << ", \nbefore_unified_from_shape=" << before_unified_from_shape
606                << ", before_unified_from_slice_shape=" << before_unified_from_slice_shape
607                << ", \nunified_from_shape=" << unified_from_shape
608                << ", unified_from_slice_shape=" << unified_from_slice_shape;
609   if (before_unified_from_shape.size() == origin_from_shape.size() - 1 &&
610       (origin_from_shape.front() == 1 || origin_from_shape.back() == 1)) {
611     // It means unified_from_shape and before_unified_from_shape are squeezed,
612     // origin_from_shape has no squeezed info.
613     MS_LOG(WARNING) << "before_unified_from_shape == origin_from_shape - 1.";
614     this->UnifyAssembledMappingWithSqueezedFromShape();
615     return;
616   }
617   if (unified_from_shape.size() == origin_from_shape.size()) {
618     MS_LOG(WARNING) << "unified_from_shape == origin_from_shape.";
619     this->UnifyAssembledMappingWithSameSize(index_mapping);
620     return;
621   }
622   if (unified_from_shape.size() > before_unified_from_shape.size()) {
623     // In this branch, it means the unified_from_shape is expanded,
624     // or it's reshaped to another shape.
625     MS_LOG(WARNING) << "unified_from_shape > before_unified_from_shape.";
626     if (before_unified_from_shape.size() == origin_from_shape.size() - 1 &&
627         (origin_from_shape.front() == 1 || origin_from_shape.back() == 1)) {
628       // It means shape has been squeezed, so add one to index in mapping.
629       this->UnifyAssembledMappingWithSqueezedFromShape();
630     }
631     this->UnifyAssembledMappingWithDiffSize(index_mapping);
632     return;
633   }
634   MS_LOG(EXCEPTION) << "unified_from_shape.size() must be greater than before_unified_from_shape.size().";
635 }
636 
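// After static shapes have been assembled, build a mapping value -> (dim index, TupleGetItem):
// a Shape op is created on the redistribution input (or on the real producer behind
// VirtualDataset / multi-output ops), and one TupleGetItem per replaced dim fetches the runtime
// size that the fake static value stands for, so the graph can later be rolled back to dynamic shape.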
637 void TensorRedistribution::CreateAssembledDynamicMapping(const CNodePtr &cur_cnode, const AnfNodePtr &pre_cnode,
638                                                          const FuncGraphPtr &func_graph, int64_t redistribution_index) {
639   MS_EXCEPTION_IF_NULL(func_graph);
640   if (!this->IsAssembledStaticShape()) {
641     return;
642   }
643   MS_LOG(INFO) << "Start to create assembled dynamic shape mapping for " << pre_cnode->fullname_with_scope() << "->"
644                << cur_cnode->fullname_with_scope();
645   this->dynamic_dim_mapping_.clear();
646 
647   AnfNodePtr shape_root = pre_cnode;
648   if (pre_cnode->isa<CNode>() && IsPrimitiveCNode(pre_cnode, std::make_shared<Primitive>(VIRTUAL_DATA_SET))) {
649     // Find VirtualDataset successor.
650     auto shape_input = UpdateShapeNodeInput(cur_cnode, pre_cnode->cast<CNodePtr>(), redistribution_index);
651     if (shape_input == nullptr) {
652       MS_LOG(WARNING) << "Cannot find real input of shape node.";
653     } else {
654       shape_root = shape_input;
655     }
656   }
657   const std::set<std::string> multi_output_op = {ARGMAXWITHVALUE, LAYER_NORM};
658   if (pre_cnode->isa<CNode>() && IsSomePrimitiveList(pre_cnode->cast<CNodePtr>(), multi_output_op)) {
659     shape_root = cur_cnode->input(redistribution_index);
660     MS_LOG(INFO) << "Change shape_root to " << shape_root->fullname_with_scope();
661   }
662 
663   ReplacementMemo from_layout_memo = this->from_dims_replace_memo_;
664   Shape assembled_origin_slice_shape = this->from_origin_.slice_shape().array();
665   MS_LOG(INFO) << "Start to create assembled dynamic shape mapping: " << pre_cnode->fullname_with_scope() << "->"
666                << cur_cnode->fullname_with_scope() << ", shape_root=" << shape_root->fullname_with_scope()
667                << ", assembled_origin_slice_shape=" << assembled_origin_slice_shape;
668   // 1. Create a Shape node that takes shape_root as its input.
669   std::string instance_name = std::string(REDISTRIBUTION_OP) + "_" + pre_cnode->fullname_with_scope();
670   auto shape_cnode = CreateShape(shape_root, func_graph, instance_name + "_get_shape");
671   // 2. Create TupleGetItem node to get dim value and insert to mapping.
672   for (const auto &iter : from_layout_memo) {
673     int64_t dim = SizeToLong(iter.first);
674     int64_t replacement = iter.second;
675     MS_EXCEPTION_IF_CHECK_FAIL(replacement % assembled_origin_slice_shape[LongToSize(dim)] == 0,
676                                "Slice shape is not matched.");
677     MS_EXCEPTION_IF_CHECK_FAIL(LongToSize(dim) < assembled_origin_slice_shape.size(), "Slice shape is not matched.");
678     replacement = assembled_origin_slice_shape[dim];
679     auto prim_tuple_get_item = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
680     AnfNodePtrList inputs{NewValueNode(prim_tuple_get_item), shape_cnode, NewValueNode(MakeValue(dim))};
681     auto tuple_get_item_cnode = func_graph->NewCNode(inputs);
682     tuple_get_item_cnode->set_fullname_with_scope(std::string(REDISTRIBUTION_OP) + "_getitem");
683     prim_tuple_get_item->set_instance_name(instance_name + "_getitem");
684     this->dynamic_dim_mapping_.insert({replacement, {iter.first, tuple_get_item_cnode}});
685     MS_LOG(INFO) << "Create TupleGetItem for dim=" << dim << " to replace value=" << replacement;
686   }
687 }
688 
689 void AppendOperatorVecStr(const OperatorVector &vec, std::string *res) {
690   for (size_t i = 0; i < vec.size(); ++i) {
691     res->append(vec.at(i).first);
692     if (i != vec.size() - 1) {
693       res->append(", ");
694     }
695   }
696 }
697 
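// Fallback used when the unified layouts cannot be expanded: redistribute from_origin to its
// repeated layout, reshape the repeated slice if needed, then redistribute the repeated layout
// to to_origin.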
698 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) {
699   MS_LOG(INFO) << "Start to infer tensor redistribution with unexpanded.";
700   TensorLayout from_origin = this->from_origin_;
701   TensorLayout to_origin = this->to_origin_;
702   TensorLayout from_repeat = from_origin.TransferRepeatLayout();
703   TensorLayout to_repeat = to_origin.TransferRepeatLayout();
704   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin.ToString();
705   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin.ToString();
706   MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
707   MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
708 
709   OperatorVector operator_vector;
710   OutPutInfoVector output_info_vector;
711   if (InferRedistribution(from_origin, from_repeat, &operator_vector, &output_info_vector, is_cost_model) ==
712       Status::FAILED) {
713     return nullptr;
714   }
715   std::string operator_vec_str;
716   AppendOperatorVecStr(operator_vector, &operator_vec_str);
717   MS_LOG(INFO) << "After InferRedistribution, operator_vector size: " << operator_vector.size()
718                << ", operator_vector: " << operator_vec_str;
719   if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) {
720     reshape_flag_ = true;
721     ConstructOperator constructor;
722     constructor.UpdateTensorShape(from_repeat.slice_shape().array());
723     Arrangement shape = to_repeat.slice_shape();
724     MS_LOG(INFO) << "from_repeat.slice_shape is not same with to_repeat.slice_shape: "
725                  << "from_repeat.slice_shape=" << from_repeat.slice_shape().array()
726                  << ", to_repeat.slice_shape=" << to_repeat.slice_shape().array() << ", reshape to "
727                  << shape.ToString();
728     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
729       return nullptr;
730     } else {
731       operator_vector.push_back(constructor.GetOperator());
732       output_info_vector.emplace_back(std::make_pair(false, 0));
733     }
734   }
735   if (InferRedistribution(to_repeat, to_origin, &operator_vector, &output_info_vector, is_cost_model) ==
736       Status::FAILED) {
737     return nullptr;
738   }
739   operator_vec_str.clear();
740   AppendOperatorVecStr(operator_vector, &operator_vec_str);
741   MS_LOG(INFO) << "After InferRedistribution, operator_vector size: " << operator_vector.size()
742                << ", operator_vector: " << operator_vec_str;
743   return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
744     std::make_pair(operator_vector, output_info_vector));
745 }
746 
747 void GetRedistributionOperators(const RedistributionOperatorInfer &operator_infer, OperatorVector *operator_vector,
748                                 OutPutInfoVector *output_info_vector, OperatorList *operator_list) {
749   for (const auto &op : operator_infer.operator_vector()) {
750     (void)operator_vector->emplace_back(op);
751   }
752   for (auto info : operator_infer.output_info_vector()) {
753     (void)output_info_vector->emplace_back(info);
754   }
755   for (const auto &opc : operator_infer.operator_list()) {
756     (void)operator_list->emplace_back(opc);
757   }
758 }
759 
760 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListForMultiDynamicReshape(
761   bool is_cost_model) {
762   MS_LOG(INFO) << "Start to infer tensor redistribution for multi dynamic axis reshape.";
763   if (this->pre_cnode_ != nullptr && this->next_cnode_ != nullptr) {
764     MS_LOG(DEBUG) << this->PrintRedistribution();
765   }
766   OperatorVector operator_vector;
767   OutPutInfoVector output_info_vector;
768   RedistributionOperatorInfer allgather_infer(this->construct_op_flag_);
769   if (allgather_infer.Init(this->from_origin_no_assembled_, this->to_origin_no_assembled_.tensor_map(), this->dev_list_,
770                            is_cost_model, this->is_dynamic_shape_) == Status::FAILED) {
771     MS_LOG(EXCEPTION) << "Init operatorInfer failed.";
772   }
773   // 1. Do AllGather on dynamic axis, skip const axis?
774   if (allgather_infer.MergePartialToFullForReshapeHasMultiDynamicAxis() != Status::SUCCESS) {
775     MS_LOG(EXCEPTION) << "Insert AllGather for Reshape which has multi dynamic axis failed.";
776   }
777   GetRedistributionOperators(allgather_infer, &operator_vector, &output_info_vector, &this->operator_list_);
778   // 2. Do Reshape. Const axis value should be divided later?
779   ConstructOperator constructor;
780   // Actually, there is no need to create a virtual shape; store the original inputs and replace them later in the replace-op step.
781   Shape full_shape = this->to_origin_no_assembled_.tensor_shape().array();
782   MS_LOG(INFO) << "before ReshapeOP, full_shape:" << full_shape;
783   if (constructor.ReshapeOP(full_shape, true) == Status::FAILED) {
784     MS_LOG(EXCEPTION) << "Cannot construct Reshape op for shape " << full_shape;
785   }
786   (void)operator_vector.emplace_back(constructor.GetOperator());
787   (void)output_info_vector.emplace_back(std::make_pair(false, 0));
788   // 3. Do Split, skip const axis?
789   RedistributionOperatorInfer allsplit_infer(this->construct_op_flag_);
790   if (allsplit_infer.Init(this->to_origin_no_assembled_, this->to_origin_no_assembled_.tensor_map(), this->dev_list_,
791                           is_cost_model, this->is_dynamic_shape_) == Status::FAILED) {
792     MS_LOG(ERROR) << "Init operatorInfer failed";
793     return nullptr;
794   }
795   if (allsplit_infer.SegmentFullShapeToPartial() != Status::SUCCESS) {
796     MS_LOG(EXCEPTION) << "Insert AllSplit for Reshape which has multi dynamic axis failed.";
797   }
798   GetRedistributionOperators(allsplit_infer, &operator_vector, &output_info_vector, &this->operator_list_);
799   std::string operator_vec_str;
800   AppendOperatorVecStr(operator_vector, &operator_vec_str);
801   MS_LOG(INFO) << "After InferAllSplit, operator_vector size: " << operator_vector.size()
802                << ", operator_vector: " << operator_vec_str;
803   return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
804     std::make_pair(operator_vector, output_info_vector));
805 }
806 
807 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
808   MS_LOG(INFO) << "Start to infer tensor redistribution.";
809   if (this->pre_cnode_ != nullptr && this->next_cnode_ != nullptr) {
810     MS_LOG(DEBUG) << this->PrintRedistribution();
811   }
812   // Step 1: Match device arrangement between from_ and to_.
813   // RedistributionLayoutTransfer layout_transfer;
814   // Step 0: Do dynamic shape to static shape conversion.
815   // TensorRedistribution::Init() only saves the from and to tensor layouts and their squeezed forms.
816   // We can change from_ and to_ in the RedistributionLayoutTransfer object directly.
817   // RedistributionLayoutTransfer::Init() will check whether the shape is dynamic;
818   // if a static shape cannot be created, the early process is reused.
819   Status status = this->layout_transfer_.Init(from_, to_);
820   if (status != Status::SUCCESS) {
821     return nullptr;
822   }
823   TensorLayout from_layout;
824   TensorLayout to_layout;
825   if (this->is_dynamic_shape_ && !this->is_assembled_static_shape_) {
826     from_layout = this->layout_transfer_.from_in();
827     to_layout = this->layout_transfer_.to_in();
828   } else {
829     // init a new layout_transfer
830     // The function of assembled_static_origin_from_ is used to record layout before unify.
831     // When device matrix or tensor shape is needed to unified, it could insert 1 in front of tensor shape
832     // or split a dim into multi dim.
833     this->assembled_static_origin_from_ = this->layout_transfer_.from_in();
834     std::shared_ptr<ReshapeLayoutTransfer> ptr = this->layout_transfer_.UnifyDeviceArrangementAndTensorShape();
835     if (ptr == nullptr) {
836       MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
837       return nullptr;
838     }
839     this->layout_transfer_.Init(ptr->from_in(), ptr->to_in());
840     if (!ptr->ExpandAble()) {
841       expand_able_ = false;
842       return InferTensorRedistributionOperatorListUnExpand(is_cost_model);
843     }
844     from_layout = ptr->from_in();
845     to_layout = ptr->to_in();
846   }
847   MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
848   MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString();
849   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
850   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
851   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
852   MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
853 
854   // Step 2: Infer redistribution and insert operators
855   OperatorVector operator_vector;
856   OutPutInfoVector output_info_vector;
857   if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=
858       Status::SUCCESS) {
859     return nullptr;
860   }
861   //  Step 3: Infer reshape and insert operators
862   if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) {
863     MS_LOG(ERROR) << "Construct Reshape operator failed!";
864     return nullptr;
865   }
866   std::string operator_vec_str;
867   AppendOperatorVecStr(operator_vector, &operator_vec_str);
868   MS_LOG(INFO) << "After InferRedistribution, operator_vector size: " << operator_vector.size()
869                << ", operator_vector: " << operator_vec_str;
870   return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
871     std::make_pair(operator_vector, output_info_vector));
872 }
873 
874 std::vector<RedistributionOpListPtr> TensorRedistribution::InferTensorRedistributionOperatorVirtualGraphs() {
875   std::vector<RedistributionOpListPtr> redis_list_vector;
876   for (const auto &virtual_rank : virtual_rank_list_) {
877     this->SetVirtualRank(virtual_rank);
878     auto redis_list = this->InferTensorRedistributionOperatorList();
879     if (!redis_list) {
880       MS_LOG(INTERNAL_EXCEPTION) << "Infer tensor redistribution failed. from_layout:" << from_origin_.ToString()
881                                  << ", to_layout:" << to_origin_.ToString();
882     }
883     redis_list_vector.push_back(redis_list);
884   }
885   return redis_list_vector;
886 }
887 
888 bool IsSameShape(const Shape &src, const Shape &tgt) {
889   if (src.size() != tgt.size()) {
890     return false;
891   }
892   for (size_t i = 0; i < src.size(); ++i) {
893     if (src[i] == -1 || tgt[i] == -1) {
894       continue;
895     }
896     if (src[i] != tgt[i]) {
897       return false;
898     }
899   }
900   return true;
901 }
902 
903 Shape AlignToLayoutShape(const Shape &to_origin_shape, const Shape &to_layout_shape) {
904   Shape target_shape(to_origin_shape);
905   auto cnt = std::count(target_shape.begin(), target_shape.end(), -1);
906   if (cnt < SizeToInt(SIZE_TWO) || to_layout_shape[0] != 1 || to_layout_shape.size() - 1 != target_shape.size()) {
907     return target_shape;
908   }
909   for (size_t i = 0; i < target_shape.size(); ++i) {
910     if (target_shape[i] != -1) {
911       continue;
912     }
913     target_shape[i] = to_layout_shape[i + 1];
914   }
915   return target_shape;
916 }
917 
918 Status TensorRedistribution::OperatorListIsEmpty(ConstructOperator *constructor, OperatorVector *const operator_vector,
919                                                  OutPutInfoVector *const output_info_vector) {
920   if (from_origin_.base_slice_shape().array() != to_origin_.base_slice_shape().array() || keep_reshape_) {
921     reshape_flag_ = true;
922     constructor->UpdateTensorShape(from_origin_.base_slice_shape().array());
923     Arrangement shape = to_origin_.base_slice_shape();
924     MS_LOG(INFO) << "from_origin_.base_slice_shape is not same with to_origin_.base_slice_shape: "
925                  << "from_origin_.base_slice_shape=" << from_origin_.base_slice_shape().array()
926                  << ", to_origin_.base_slice_shape=" << to_origin_.base_slice_shape().array() << ", reshape to "
927                  << shape.ToString();
928     auto reshape_mode = ReshapeMode::FROM_ORIGIN_BASE_SLICE_TO_TO_ORIGIN_BASE_SLICE;
929     reshape_mode = this->is_dynamic_shape_ ? reshape_mode : ReshapeMode::NO_RESHAPE;
930     if (constructor->ReshapeOP(shape.array(), false, reshape_mode) == Status::FAILED) {
931       return Status::FAILED;
932     } else {
933       (void)operator_vector->insert(operator_vector->cbegin(), constructor->GetOperator());
934       (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
935     }
936   }
937   return Status::SUCCESS;
938 }
939 
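// Inserts up to four Reshape ops around the communication operators: base_slice -> slice and
// slice -> from_layout.slice in front of them, and to_layout.slice -> to_origin.slice and
// slice -> base_slice behind them.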
940 Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
941                                           OperatorVector *const operator_vector,
942                                           OutPutInfoVector *const output_info_vector) {
943   MS_EXCEPTION_IF_NULL(operator_vector);
944   MS_EXCEPTION_IF_NULL(output_info_vector);
945   ConstructOperator constructor;
946   if (operator_list_.empty()) {
947     return OperatorListIsEmpty(&constructor, operator_vector, output_info_vector);
948   }
949   // 1. We need to know which axes are dynamic and which are constant, and only the constant axes are compared.
950   //    Can we guarantee that from_origin_ and from_layout have the same rank? If from_origin_ is static, from_layout is static too; if from_origin_ is dynamic, from_layout is dynamic too.
951   //    For now, only the case where from_origin_ and from_layout have the same rank is supported.
952   if (!IsSameShape(from_origin_.slice_shape().array(), from_layout.slice_shape().array())) {
953     reshape_flag_ = true;
954     constructor.UpdateTensorShape(from_origin_.slice_shape().array());
955     Arrangement shape = from_layout.slice_shape();
956     MS_LOG(INFO) << "from_origin.slice_shape is not same with from_layout.slice_shape: "
957                  << "from_origin_.slice_shape=" << from_origin_.slice_shape().array()
958                  << ", from_layout.slice_shape=" << from_layout.slice_shape().array() << ", reshape to "
959                  << shape.ToString();
960     auto reshape_mode = ReshapeMode::FROM_ORIGIN_SLICE_TO_FROM_LAYOUT_SLICE;
961     reshape_mode = this->is_dynamic_shape_ ? reshape_mode : ReshapeMode::NO_RESHAPE;
962     if (constructor.ReshapeOP(shape.array(), false, reshape_mode) == Status::FAILED) {
963       return Status::FAILED;
964     } else {
965       // Before all-gather.
966       (void)operator_vector->insert(operator_vector->cbegin(), constructor.GetOperator());
967       (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
968     }
969   }
970 
971   if (from_origin_.base_slice_shape().array() != from_origin_.slice_shape().array()) {
972     reshape_flag_ = true;
973     constructor.UpdateTensorShape(from_origin_.base_slice_shape().array());
974     Arrangement shape = from_origin_.slice_shape();
975     MS_LOG(INFO) << "from_origin_.base_slice_shape is not same with from_origin_.slice_shape: "
976                  << "from_origin_.base_slice_shape=" << from_origin_.base_slice_shape().array()
977                  << ", from_origin_.slice_shape=" << from_origin_.slice_shape().array() << ", reshape to "
978                  << shape.ToString();
979     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
980       return Status::FAILED;
981     } else {
982       // Before all-gather.
983       (void)operator_vector->insert(operator_vector->cbegin(), constructor.GetOperator());
984       (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
985     }
986   }
987 
988   if (!IsSameShape(to_origin_.slice_shape().array(), to_layout.slice_shape().array())) {
989     reshape_flag_ = true;
990     constructor.UpdateTensorShape(to_layout.slice_shape().array());
991     // If to_origin_ is all -1, it can not be reshape.
992     Shape target_shape = to_origin_.slice_shape().array();
993     size_t cnt = std::count(target_shape.begin(), target_shape.end(), -1);
994     if (this->IsAssembledStaticShape() && cnt >= SIZE_TWO) {
995       target_shape = AlignToLayoutShape(to_origin_.slice_shape().array(), to_layout.slice_shape().array());
996       MS_LOG(INFO) << "update reshape target shape.";
997     }
998     MS_LOG(INFO) << "to_origin_.slice_shape is not same with to_layout.slice_shape: "
999                  << "to_origin_.slice_shape=" << to_origin_.slice_shape().array()
1000                  << ", to_layout.slice_shape=" << to_layout.slice_shape().array() << ", reshape to " << target_shape;
1001     auto reshape_mode = ReshapeMode::TO_ORIGIN_SLICE_TO_TO_LAYOUT_SLICE;
1002     reshape_mode = this->is_dynamic_shape_ ? reshape_mode : ReshapeMode::NO_RESHAPE;
1003     if (constructor.ReshapeOP(target_shape, false, reshape_mode) == Status::FAILED) {
1004       return Status::FAILED;
1005     } else {
1006       // After all-gather.
1007       (void)operator_vector->insert(operator_vector->cend(), constructor.GetOperator());
1008       (void)output_info_vector->insert(output_info_vector->cend(), std::make_pair(false, 0));
1009     }
1010   }
1011 
1012   if (to_origin_.slice_shape().array() != to_origin_.base_slice_shape().array()) {
1013     reshape_flag_ = true;
1014     constructor.UpdateTensorShape(to_origin_.slice_shape().array());
1015     Arrangement shape = to_origin_.base_slice_shape();
1016     MS_LOG(INFO) << "to_origin_.slice_shape is not same with to_origin_.base_slice_shape: "
1017                  << "to_origin_.slice_shape=" << to_origin_.slice_shape().array()
1018                  << ", to_origin_.base_slice_shape=" << to_origin_.base_slice_shape().array() << ", reshape to "
1019                  << shape.ToString();
1020     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
1021       return Status::FAILED;
1022     } else {
1023       // After all-gather.
1024       (void)operator_vector->insert(operator_vector->cend(), constructor.GetOperator());
1025       (void)output_info_vector->insert(output_info_vector->cend(), std::make_pair(false, 0));
1026     }
1027   }
1028   return Status::SUCCESS;
1029 }
1030 
1031 Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
1032                                                  OperatorVector *const operator_vector,
1033                                                  OutPutInfoVector *const output_info_vector, bool is_cost_model) {
1034   MS_EXCEPTION_IF_NULL(operator_vector);
1035   MS_EXCEPTION_IF_NULL(output_info_vector);
1036   MS_LOG(DEBUG) << "Start to infer redistribution.";
1037   RedistributionOperatorInfer operator_infer(construct_op_flag_);
1038   if (virtual_rank_ >= 0) {
1039     operator_infer.SetVirtualRank(virtual_rank_);
1040   }
1041   if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model, this->is_dynamic_shape_) ==
1042       Status::FAILED) {
1043     MS_LOG(ERROR) << "Init operatorInfer failed";
1044     return Status::FAILED;
1045   }
1046   if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
1047     MS_LOG(ERROR) << "Infer redistribution failed";
1048     return Status::FAILED;
1049   } else {
1050     for (auto op : operator_infer.operator_vector()) {
1051       (void)operator_vector->insert(operator_vector->cend(), op);
1052     }
1053     for (auto info : operator_infer.output_info_vector()) {
1054       (void)output_info_vector->insert(output_info_vector->cend(), info);
1055     }
1056     for (auto opc : operator_infer.operator_list()) {
1057       (void)operator_list_.insert(operator_list_.cend(), opc);
1058     }
1059   }
1060   return Status::SUCCESS;
1061 }
1062 
1063 Status TensorRedistribution::RollbackToDynamicShape() {
1064   if (!this->IsAssembledStaticShape()) {
1065     return Status::FAILED;
1066   }
1067   for (auto &iter : this->from_dims_replace_memo_) {
1068     MS_LOG(DEBUG) << "from index=" << iter.first << ", value=" << iter.second << std::endl;
1069   }
1070   for (auto &iter : this->to_dims_replace_memo_) {
1071     MS_LOG(DEBUG) << "to index=" << iter.first << ", value=" << iter.second << std::endl;
1072   }
1073   MS_LOG(DEBUG) << "RollbackToDynamicShape: from_in_=" << this->from_origin_.ToString() << std::endl
1074                 << "to_in_=" << this->to_origin_.ToString() << std::endl;
1075   return Status::SUCCESS;
1076 }
1077 
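// Walks operator_list_ and accumulates communication, computation and memory costs:
// PERMUTE_BY_AXIS and CONCAT_BY_AXIS use dedicated cost models, every other operator is charged
// the product of its slice shape, and an extra term is added when a reshape was generated.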
1078 Status TensorRedistribution::ComputeCost() {
1079   RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
1080   if (redistribution_oplist_ptr == nullptr) {
1081     MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
1082     return Status::FAILED;
1083   }
1084   // Compute redistribution communication cost and computation cost
1085   for (auto &op_cost : operator_list_) {
1086     OperatorR op = op_cost.first;
1087     Shape slice_shape = op_cost.second;
1088     double prod =
1089       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
1090     std::string str = op.first;
1091     if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
1092       return Status::FAILED;
1093     } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
1094       return Status::FAILED;
1095     } else {
1096       // There is only computation cost in SplitByAxis.
1097       // computation cost = before_slice_shape
1098       computation_cost_ += prod;
1099       // This addition may be erroneous
1100       memory_cost_ += prod;
1101     }
1102   }
1103   if (reshape_flag()) {
1104     Shape prev_shape;
1105     if (expand_able_) {
1106       prev_shape = from_.slice_shape().array();
1107     } else {
1108       prev_shape = from_.tensor_shape().array();
1109     }
1110     double prev_prod =
1111       std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
1112     computation_cost_ += COST_FACTOR * prev_prod;
1113     memory_cost_ += COST_FACTOR * prev_prod;
1114   }
1115   return Status::SUCCESS;
1116 }
1117 
1118 Status TensorRedistribution::ComputePermuteCost(double input_size, const Shape &attrs) {
1119   // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
1120   // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
1121   if (attrs.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
1122     MS_LOG(ERROR) << "attrs size should not be less than 5!";
1123     return Status::FAILED;
1124   }
1125   forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
1126   backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
1127   comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
1128   int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
1129   if (concat_dim == 0) {
1130     // memory cost = all_gather
1131     computation_cost_ += input_size;
1132     memory_cost_ += input_size;
1133   } else {
1134     // memory cost = all_gather + split + concat
1135     int64_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
1136     computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
1137     memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
1138   }
1139   return Status::SUCCESS;
1140 }
1141 
1142 Status TensorRedistribution::ComputeConcatCost(double input_size, const Shape &attrs) {
1143   // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
1144   // computation cost = before_slice_shape
1145   if (attrs.size() < TRANSFER_CONCAT_ARGS_SIZE) {
1146     MS_LOG(ERROR) << "attrs size should not be less than 3!";
1147     return Status::FAILED;
1148   }
1149   double dev_num = attrs[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
1150   // here, communication cost = all_gather + reduce_scatter
1151   forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
1152   backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
1153   comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
1154   int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
1155   if (concat_dim == 0) {
1156     // computation cost = all_gather
1157     computation_cost_ += input_size;
1158     memory_cost_ += input_size * dev_num;
1159   } else {
1160     // computation cost = all_gather + split + concat
1161     computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
1162     memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
1163   }
1164   return Status::SUCCESS;
1165 }
1166 }  // namespace parallel
1167 }  // namespace mindspore
1168