1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <map>
18 #include <vector>
19 #include <string>
20 #include <memory>
21 #include <utility>
22 #include "frontend/parallel/graph_util/graph_utils.h"
23 #include "frontend/parallel/ops_info/ops_utils.h"
24 #include "frontend/parallel/step_parallel_utils.h"
25 #include "frontend/parallel/parameter_manager.h"
26 #include "frontend/parallel/graph_util/generate_graph.h"
27 #include "frontend/parallel/graph_util/graph_info.h"
28 #include "frontend/parallel/tensor_layout/prime_generator.h"
29 #include "mindspore/core/ir/primitive.h"
30 #include "mindspore/core/ir/func_graph.h"
31 #include "include/common/utils/anfalgo.h"
32 
33 namespace mindspore::parallel {
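// Returns the smallest prime in the shared prime table that divides `value`, or -1 when no table
// entry divides it (e.g. for value <= 1 or once the primes exceed `value`).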
34 int64_t GetPrimeFactor(int64_t value) {
35   static const std::vector<int64_t> prime_table = PrimeGenerator::GetInstance()->GetPrimeTable();
36   for (const auto &prime : prime_table) {
37     if (prime > value) {
38       return -1;
39     }
40     if (value % prime == 0) {
41       return prime;
42     }
43   }
44   return -1;
45 }
46 
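// Builds a Shape CNode that takes pre_cnode as its only input and tags the primitive with
// inst_name, so the runtime shape of pre_cnode can be queried in the graph.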
47 CNodePtr CreateShape(const AnfNodePtr &pre_cnode, const FuncGraphPtr &func_graph, const std::string &inst_name) {
48   auto prim = std::make_shared<Primitive>(SHAPE_OP);
49   prim->set_instance_name(inst_name);
50   AnfNodePtrList shape_node_inputs(SIZE_TWO);
51   shape_node_inputs[0] = NewValueNode(prim);
52   shape_node_inputs[1] = pre_cnode;
53   auto shape_cnode = func_graph->NewCNode(shape_node_inputs);
54   return shape_cnode;
55 }
56 
57 inline bool IsTargetOp(const CNodePtr &cnode, const std::string &target) { return GetPrimName(cnode) == target; }
58 
59 bool IsTupleGetItem(const CNodePtr &cnode) { return IsTargetOp(cnode, TUPLE_GETITEM_OP); }
60 
61 bool IsReshapeOp(const CNodePtr &cnode) { return IsTargetOp(cnode, RESHAPE); }
62 
63 bool IsShapeOp(const CNodePtr &cnode) { return IsTargetOp(cnode, SHAPE_OP); }
64 
65 TensorRedistributionPtr GetTensorRedistributionFromCNode(const CNodePtr &node) {
66   OperatorInfoPtr distribute_operator = GetDistributeOperator(node);
67   if (distribute_operator == nullptr) {
68     MS_LOG(WARNING) << node->fullname_with_scope() << " has no OperatorInfo.";
69     return nullptr;
70   }
71   if (IsReshapeOp(node)) {
72     return distribute_operator->reshape_tensor_redistribution();
73   }
74   return distribute_operator->tensor_redistribution();
75 }
76 
77 bool IsDynamicOp(const CNodePtr &node) {
78   TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(node);
79   if (tensor_redistribution == nullptr) {
80     return false;
81   }
82   return tensor_redistribution->IsAssembledStaticShape();
83 }
84 
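// Collects the forward graphs referenced by J/Shard nodes among the root nodes, plus every
// func_graph those graphs use transitively.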
85 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const std::vector<AnfNodePtr> &root_all_nodes) {
86   // J->CNode->Graph
87   std::set<FuncGraphPtr> graph_set;
88   for (auto &node : root_all_nodes) {
89     MS_EXCEPTION_IF_NULL(node);
90     if (!node->isa<CNode>()) {
91       continue;
92     }
93 
94     auto cnode = node->cast<CNodePtr>();
95     if ((cnode->size() < SIZE_TWO) || !IsValueNode<Primitive>(cnode->input(0))) {
96       continue;
97     }
98     auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
99     if (expect_prim->name() != J && expect_prim->name() != SHARD) {
100       continue;
101     }
102     if (IsValueNode<FuncGraph>(cnode->input(1))) {
103       auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
104       MS_LOG(DEBUG) << "Find the forward graph success";
105       (void)graph_set.insert(graph);
106       auto manager = graph->manager();
107       MS_EXCEPTION_IF_NULL(manager);
108       auto graph_used = manager->func_graphs_used_total(graph);
109       for (auto iter = graph_used.cbegin(); iter != graph_used.cend(); ++iter) {
110         (void)graph_set.insert(*iter);
111       }
112     }
113   }
114   return graph_set;
115 }
116 
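// Searches the cloned parameters for the accumulated-gradient parameter named
// "accu_grads.<weight_name>"; returns nullptr if it does not exist.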
117 AnfNodePtr GetAccuGrad(const std::vector<AnfNodePtr> &parameters, const std::string &weight_name) {
118   for (auto &param : parameters) {
119     if (!ParameterIsCloned(param)) {
120       continue;
121     }
122 
123     auto param_ptr = param->cast<ParameterPtr>();
124     MS_EXCEPTION_IF_NULL(param_ptr);
125     auto accu_grads_name = std::string(ACCU_GRADS) + "." + weight_name;
126     if (param_ptr->name() == accu_grads_name) {
127       MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
128       return param;
129     }
130   }
131   return nullptr;
132 }
133 
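// Assembles the input list for a mirror / (mini|micro)-step AllGather operator. When gradient
// accumulation or pipeline stages are enabled, the matching accu_grads parameter is appended as
// the second input of the communication op.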
134 std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
135                                           const std::string &instance_name, const std::string &weight_name) {
136   MS_EXCEPTION_IF_NULL(root);
137   MS_EXCEPTION_IF_NULL(node);
138   MS_EXCEPTION_IF_NULL(root->manager());
139 
140   std::string op_name = op.first;
141   OperatorArgs arg_forward = op.second;
142   AnfNodePtr grad_accu = nullptr;
143 
144   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
145   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
146   if (grad_accumulation_step > 1 || split_stage_num > 1) {
147     auto parameters = root->parameters();
148     grad_accu = GetAccuGrad(parameters, weight_name);
149     if (!grad_accu && op_name == MICRO_STEP_ALL_GATHER) {
150       MS_LOG(EXCEPTION) << "You should define `accu_grads` when using " << op_name << " for parameter: " << weight_name;
151     }
152   }
153 
154   OperatorParams params = arg_forward.second;
155 
156   std::vector<AnfNodePtr> new_node_input;
157   if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
158       op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) {
159     MS_EXCEPTION_IF_NULL(grad_accu);
160     new_node_input = {node, grad_accu};
161     MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
162   } else {
163     new_node_input = {node};
164   }
165 
166   if (!params.empty()) {
167     for (auto &param : params) {
168       AnfNodePtr val = NewValueNode(param.first.second);
169       MS_EXCEPTION_IF_NULL(val);
170       int64_t position = param.second;
171       (void)new_node_input.insert(new_node_input.cbegin() + position - 1, val);
172     }
173   }
174 
175   new_node_input = ConvertToRealInputs(op_name, instance_name, new_node_input, arg_forward.first);
176   // If the op has the 'group' attr, set the rank list name for the op
177   SetCommunicationOpGroupLabel(new_node_input);
178   return new_node_input;
179 }
180 
181 CNodePtr CreateMakeTuple(const std::vector<AnfNodePtr> &tuple_inputs, const FuncGraphPtr &func_graph,
182                          const std::string &instance_name = "") {
183   MS_EXCEPTION_IF_NULL(func_graph);
184   std::vector<AnfNodePtr> make_tuple_inputs(tuple_inputs.size() + 1);
185   auto prim = std::make_shared<Primitive>(MAKE_TUPLE);
186   if (!instance_name.empty()) {
187     prim->set_instance_name(instance_name);
188   }
189   make_tuple_inputs[0] = NewValueNode(prim);
190   for (size_t i = 0; i < tuple_inputs.size(); ++i) {
191     make_tuple_inputs[i + 1] = tuple_inputs[i];
192   }
193   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
194   return make_tuple;
195 }
196 
197 CNodePtr CreateSplit(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph,
198                      const std::string &inst_name) {
199   MS_EXCEPTION_IF_NULL(func_graph);
200   MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() == SIZE_THREE, "CreateSplit expects exactly three inputs.");
201   auto prim = std::make_shared<Primitive>(SPLIT);
202   if (!inst_name.empty()) {
203     prim->set_instance_name(inst_name);
204   }
205   std::vector<AnfNodePtr> split_inputs(SIZE_FOUR);
206   split_inputs[INDEX_ZERO] = NewValueNode(prim);
207   split_inputs[INDEX_ONE] = inputs[INDEX_ZERO];   // split_input
208   split_inputs[INDEX_TWO] = inputs[INDEX_ONE];    // split_axis
209   split_inputs[INDEX_THREE] = inputs[INDEX_TWO];  // split_size
210   auto split = func_graph->NewCNode(split_inputs);
211   return split;
212 }
213 
214 CNodePtr CreateCast(const AnfNodePtr &cast_input, const ValueNodePtr &dest_type, const FuncGraphPtr &func_graph) {
215   auto cast_prim = NewValueNode(prim::kPrimScalarCast);
216   auto cast = func_graph->NewCNode({cast_prim, cast_input, dest_type});
217   return cast;
218 }
219 
220 AnfNodePtr CreateDiv(const AnfNodePtr &input_node, int64_t divisor, const FuncGraphPtr &func_graph, bool to_long,
221                      const std::string &inst_name) {
222   MS_EXCEPTION_IF_NULL(input_node);
223   MS_EXCEPTION_IF_NULL(func_graph);
224   MS_EXCEPTION_IF_ZERO("div_divisor", divisor);
225   if (divisor == 1) {
226     return input_node;
227   }
228   auto prim = std::make_shared<Primitive>(SCALAR_FLOOR_DIV);
229   if (!inst_name.empty()) {
230     prim->set_instance_name(inst_name);
231   }
232   std::vector<AnfNodePtr> inputs(SIZE_THREE);
233   inputs[INDEX_ZERO] = NewValueNode(prim);
234   inputs[INDEX_ONE] = input_node;
235   inputs[INDEX_TWO] = CreatInt64Imm(divisor);
236   auto div = func_graph->NewCNode(inputs);
237   if (to_long) {
238     auto type_id = NewValueNode(MakeValue(static_cast<int64_t>(kInt64->type_id())));
239     return CreateCast(div, type_id, func_graph);
240   }
241   return div;
242 }
243 
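// Counterpart of CreateDiv: emits ScalarMul(input_node, factor) and, when to_long is set, casts
// the result to int64 via ScalarCast.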
244 CNodePtr CreateMul(const AnfNodePtr &input_node, const int64_t factor, const FuncGraphPtr &func_graph,
245                    bool to_long = false, const std::string &inst_name = "") {
246   MS_EXCEPTION_IF_NULL(input_node);
247   MS_EXCEPTION_IF_NULL(func_graph);
248   MS_EXCEPTION_IF_ZERO("mul_factor", factor);
249   auto prim = std::make_shared<Primitive>(SCALAR_MUL);
250   if (!inst_name.empty()) {
251     prim->set_instance_name(inst_name);
252   }
253   std::vector<AnfNodePtr> inputs(SIZE_THREE);
254   inputs[INDEX_ZERO] = NewValueNode(prim);
255   inputs[INDEX_ONE] = input_node;
256   inputs[INDEX_TWO] = CreatInt64Imm(factor);
257   auto mul = func_graph->NewCNode(inputs);
258   if (to_long) {
259     auto type_id = NewValueNode(MakeValue(static_cast<int64_t>(kInt64->type_id())));
260     return CreateCast(mul, type_id, func_graph);
261   }
262   return mul;
263 }
264 
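// Note: an assembled static shape appears to tag each dynamic dimension with a distinct prime
// factor (see PrimeGenerator), so a dimension can be re-associated with its dynamic source by
// comparing prime factors rather than raw values.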
265 bool MatchWithPrime(const AssembledDynamicDimsMapping &dyn_dims_mapping, int64_t prime) {
266   for (const auto &iter : dyn_dims_mapping) {
267     int64_t prime_base = GetPrimeFactor(iter.first);
268     if (prime_base == prime) {
269       return true;
270     }
271   }
272   return false;
273 }
274 
275 inline bool IsSameRank(const Shape &shape_vec, const Shape &target_shape_vec) {
276   return shape_vec.size() == target_shape_vec.size();
277 }
278 
279 bool HasAssebledDynamicDim(const Shape &shape_vec, const AssembledDynamicDimsMapping &dyn_dims_mapping,
280                            const TensorRedistributionPtr &tensor_redistribution, bool is_same_rank) {
281   for (int64_t dim : shape_vec) {
282     auto iter = dyn_dims_mapping.find(dim);
283     if (iter != dyn_dims_mapping.end()) {
284       return true;
285     }
286     int64_t prime_base = dim;
287     while (prime_base > 1) {
288       int64_t prime_of_dim = GetPrimeFactor(prime_base);
289       if (prime_of_dim == -1) {
290         break;
291       }
292       if (MatchWithPrime(dyn_dims_mapping, prime_of_dim)) {
293         return true;
294       }
295       prime_base /= prime_of_dim;
296     }
297   }
298   return false;
299 }
300 
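// For each dim in shape_vec, find the recorded dynamic dim that shares its prime factor and reuse
// the corresponding TupleGetItem node, scaled with ScalarFloorDiv/ScalarMul when the two values
// differ; dims without a match are emitted as constant value nodes.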
301 void MatchingAccordingToPrime(const Shape &shape_vec, const AssembledDynamicDimsMapping &dyn_dims_mapping,
302                               const TensorRedistributionPtr &tensor_redistribution, const FuncGraphPtr &func_graph,
303                               std::vector<AnfNodePtr> *shape_input,
304                               enum ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE) {
305   MS_LOG(INFO) << "Match with prime, shape_vec=" << shape_vec << ", reshape_mode=" << reshape_mode;
306   MS_EXCEPTION_IF_NULL(shape_input);
307   // If the shape is not changed, it means there is no reshape.
308   // So the dynamic dim can be matched according to index.
309   std::string instance_name = std::string(REDISTRIBUTION_OP) + "_" + "assemble_shape";
310   for (size_t i = 0; i < shape_vec.size(); ++i) {
311     int64_t dim = shape_vec[i];
312     // TODO(liuchongming): dim could have more than one prime factor; have to get all primes in dim.
313     int64_t dim_prime = GetPrimeFactor(dim);
314     bool found = false;
315     if (dim != -1 && dim_prime != -1) {
316       for (const auto &iter : dyn_dims_mapping) {
317         int64_t dim_value_in_graph = iter.first;
318         AnfNodePtr tuple_getitem = iter.second.second;
319         int64_t dyn_prime = GetPrimeFactor(dim_value_in_graph);
320         if (dyn_prime != dim_prime) {
321           continue;
322         }
323         MS_LOG(INFO) << "i=" << i << ", dim_value_in_graph=" << dim_value_in_graph << ", dim_prime=" << dim_prime
324                      << ", dim=" << dim;
325         if (dim_value_in_graph > dim) {
326           int64_t divisor = dim_value_in_graph / dim;
327           AnfNodePtr div_op = CreateDiv(tuple_getitem, divisor, func_graph, false, instance_name);
328           (void)shape_input->emplace_back(div_op);
329           found = true;
330           break;
331         } else if (dim_value_in_graph < dim) {
332           int64_t divisor = dim / dim_value_in_graph;
333           AnfNodePtr mul_op = CreateMul(tuple_getitem, divisor, func_graph, false, instance_name);
334           (void)shape_input->emplace_back(mul_op);
335           found = true;
336           break;
337         } else {
338           (void)shape_input->emplace_back(tuple_getitem);
339           found = true;
340           break;
341         }
342       }
343     }
344     if (!found) {
345       MS_LOG(INFO) << "Cannot find " << dim << " in shape param.";
346       AnfNodePtr val = CreatInt64Imm(dim);
347       (void)shape_input->emplace_back(val);
348     }
349   }
350 }
351 
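// Same idea as MatchingAccordingToPrime, but for same-rank shapes: each dynamic dim is looked up
// by its index, and the prime factors are only used as a consistency check.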
352 void MatchingAccordingToIndex(const Shape &shape_vec, const AssembledDynamicDimsMapping &dyn_dims_mapping,
353                               const TensorRedistributionPtr &tensor_redistribution, const FuncGraphPtr &func_graph,
354                               std::vector<AnfNodePtr> *shape_input,
355                               enum ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE) {
356   MS_LOG(INFO) << "Match with index, shape_vec=" << shape_vec;
357   MS_EXCEPTION_IF_NULL(shape_input);
358   TensorLayout to_layout = tensor_redistribution->layout_transfer().to_in();
359   TensorLayout from_layout = tensor_redistribution->layout_transfer().from_in();
360   // If the shape is not changed, it means there is no reshape.
361   // So the dynamic dim can be matched according to index.
362   // {index, {prime_dim, AnfNode}}
363   std::map<size_t, std::pair<int64_t, AnfNodePtr>> mapping_table;
364   for (const auto &iter : dyn_dims_mapping) {
365     mapping_table.insert({iter.second.first, {iter.first, iter.second.second}});
366   }
367   for (size_t i = 0; i < shape_vec.size(); ++i) {
368     int64_t dim = shape_vec[i];
369     if (dim != -1 && mapping_table.find(i) != mapping_table.end()) {
370       std::pair<int64_t, AnfNodePtr> tuple_getitem_input_pair = mapping_table[i];
371       int64_t dim_value_in_graph = tuple_getitem_input_pair.first;
372       int64_t dim_prime = GetPrimeFactor(dim);
373       int64_t tuple_getitem_prime = GetPrimeFactor(tuple_getitem_input_pair.first);
374       if (dim_prime != tuple_getitem_prime) {
375         MS_LOG(EXCEPTION) << "Prime in dim and dynamic input are not matched, " << dim_prime << " for " << dim
376                           << " and " << tuple_getitem_prime << " for " << tuple_getitem_input_pair.first;
377       }
378       // After matching with prime, fetch the real dim value in graph and
379       //  calculate whether it needs mul/div.
380       if (dim_value_in_graph > dim) {
381         int64_t divisor = dim_value_in_graph / dim;
382         AnfNodePtr div_op =
383           CreateDiv(tuple_getitem_input_pair.second, divisor, func_graph, true, "assemble_dynamic_shape_op");
384         (void)shape_input->emplace_back(div_op);
385         continue;
386       }
387       if (dim_value_in_graph < dim) {
388         int64_t divisor = dim / dim_value_in_graph;
389         AnfNodePtr mul_op =
390           CreateMul(tuple_getitem_input_pair.second, divisor, func_graph, true, "assemble_dynamic_shape_op");
391         (void)shape_input->emplace_back(mul_op);
392         continue;
393       }
394       (void)shape_input->emplace_back(tuple_getitem_input_pair.second);
395       continue;
396     }
397     MS_LOG(INFO) << "Cannot find " << dim << " in shape param.";
398     AnfNodePtr val = CreatInt64Imm(dim);
399     (void)shape_input->emplace_back(val);
400   }
401 }
402 
403 int64_t CountDynamicAxis(const AnfNodePtrList &shape_input) {
404   int64_t dyn_axis_cnt = 0;
405   for (size_t i = 0; i < shape_input.size(); ++i) {
406     if (shape_input[i]->isa<ValueNode>()) {
407       auto val_node = shape_input[i]->cast<ValueNodePtr>();
408       MS_EXCEPTION_IF_NULL(val_node->value());
409       int64_t index = GetValue<int64_t>(val_node->value());
410       if (index == -1) {
411         dyn_axis_cnt += 1;
412       }
413     } else {
414       dyn_axis_cnt += 1;
415     }
416   }
417   return dyn_axis_cnt;
418 }
419 
420 inline bool WhetherIsValueNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); }
421 
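// Rebuilds a constant shape-like parameter of a redistribution op as a dynamic expression: dims
// that correspond to assembled dynamic axes are replaced by nodes derived from the runtime Shape,
// and the result is either a constant shape again or a MakeTuple of scalars.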
422 AnfNodePtr ConvertConstParamToDynamic(const TensorRedistributionPtr &tensor_redistribution, const Param &param,
423                                       const FuncGraphPtr &func_graph, bool is_reshape,
424                                       enum ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE) {
425   // Used by ConvertReshapeInputs and ConvertStridedSliceInputs (for the END parameter).
426   MS_EXCEPTION_IF_NULL(tensor_redistribution);
427   AssembledDynamicDimsMapping dyn_dims_mapping = tensor_redistribution->GetDynamicDimsMapping();
428   if (dyn_dims_mapping.empty()) {
429     MS_LOG(ERROR) << "Doesn't have dynamic dims mapping.";
430     return nullptr;
431   }
432   std::vector<int64_t> shape_vec = GetValue<std::vector<int64_t>>(param.first.second);
433   if (shape_vec.empty()) {
434     MS_LOG(ERROR) << "Cannot get shape from param.";
435     return nullptr;
436   }
437 
438   // After refactor, dyn_dims_mapping is generated according to origin_from_shape.
439   // Reshape has 3 scenarios:
440   // 1. from_origin_->from_layout.from: when the shape is squeezed, leading or trailing 1s are removed from from_origin.
441   // 2. to_layout.to->to_origin_: when the shape is unified, it could be expanded.
442   // 3. User's reshape: written in the user's script.
443   Shape origin_from_shape = tensor_redistribution->from_origin_layout().tensor_shape().array();
444   Shape origin_slice_from_shape = tensor_redistribution->from_origin_layout().slice_shape().array();
445   Shape from_shape = tensor_redistribution->from_layout().tensor_shape().array();
446   Shape unified_from_shape = tensor_redistribution->layout_transfer().from_in().tensor_shape().array();
447   Shape unified_slice_from_shape = tensor_redistribution->layout_transfer().from_in().slice_shape().array();
448   MS_LOG(INFO) << "reshape_mode=" << reshape_mode << ", shape_vec: " << shape_vec
449                << ", origin_from_shape: " << origin_from_shape
450                << ", \norigin_slice_from_shape: " << origin_slice_from_shape << ", \nfrom_shape: " << from_shape
451                << ", \nunified_from_shape: " << unified_from_shape
452                << ", \nunified_slice_from_shape:" << unified_slice_from_shape;
453   // The rank should be compared between shape_vec and origin_from_shape, because
454   // the mapping is generated according to origin_from_shape.
455   bool is_same_rank = IsSameRank(shape_vec, origin_from_shape);
456   if (!HasAssebledDynamicDim(shape_vec, dyn_dims_mapping, tensor_redistribution, is_same_rank)) {
457     // If shape_vec is (-1, dim_1) and dim_1 is not a fake value generated by tensor redistribution,
458     // then no matching is needed.
459     AnfNodePtr val = NewValueNode(param.first.second);
460     MS_EXCEPTION_IF_NULL(val);
461     val->set_abstract(param.first.second->ToAbstract());
462     return val;
463   }
464   if (shape_vec.size() == 1) {
465     std::vector<int64_t> const_shape{-1};
466     AnfNodePtr val = NewValueNode(const_shape);
467     val->set_abstract(param.first.second->ToAbstract());
468     return val;
469   }
470   std::vector<AnfNodePtr> shape_input;
471   if (reshape_mode == ReshapeMode::FROM_ORIGIN_SLICE_TO_FROM_LAYOUT_SLICE ||
472       reshape_mode == ReshapeMode::TO_ORIGIN_SLICE_TO_TO_LAYOUT_SLICE ||
473       reshape_mode == ReshapeMode::FROM_ORIGIN_BASE_SLICE_TO_TO_ORIGIN_BASE_SLICE) {
474     MatchingAccordingToPrime(shape_vec, dyn_dims_mapping, tensor_redistribution, func_graph, &shape_input,
475                              reshape_mode);
476   } else {
477     if (is_same_rank) {
478       MatchingAccordingToIndex(shape_vec, dyn_dims_mapping, tensor_redistribution, func_graph, &shape_input,
479                                reshape_mode);
480     } else {
481       MatchingAccordingToPrime(shape_vec, dyn_dims_mapping, tensor_redistribution, func_graph, &shape_input,
482                                reshape_mode);
483     }
484   }
485   if (shape_input.size() != shape_vec.size()) {
486     MS_LOG(ERROR) << "shape size is not equal.";
487     return nullptr;
488   }
489 
490   if (is_reshape) {
491     // If there is only one dynamic axis, set it to -1.
492     size_t dyn_axis_cnt = LongToSize(CountDynamicAxis(shape_input));
493     MS_LOG(INFO) << "For shape_vec=" << shape_vec << ", has " << dyn_axis_cnt << " dynamic axis.";
494     if (dyn_axis_cnt == 1) {
495       constexpr int64_t unknown = -1;
496       for (size_t i = 0; i < shape_input.size(); ++i) {
497         if (shape_input[i]->isa<CNode>()) {
498           shape_input[i] = NewValueNode(MakeValue(unknown));
499           MS_LOG(INFO) << "change index " << i << " to -1.";
500           break;
501         }
502       }
503     }
504   }
505   if (std::all_of(shape_input.begin(), shape_input.end(), &WhetherIsValueNode)) {
506     std::vector<int64_t> const_shape(shape_input.size());
507     for (size_t i = 0; i < shape_input.size(); ++i) {
508       auto val_node = shape_input[i]->cast<ValueNodePtr>();
509       MS_EXCEPTION_IF_NULL(val_node->value());
510       int64_t value = GetValue<int64_t>(val_node->value());
511       const_shape[i] = value;
512     }
513     return NewValueNode(const_shape);
514   }
515   auto make_tuple = CreateMakeTuple(shape_input, func_graph, REDISTRIBUTION_OP);
516   return make_tuple;
517 }
518 
519 Status ConvertStridedSliceInputs(const OperatorParams &params,
520                                  const TensorRedistributionPtr &tensor_redistribution_from_cnode,
521                                  const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *new_node_input) {
522   for (auto &param : params) {
523     if (param.first.first == BEGIN_MASK || param.first.first == END_MASK || param.first.first == ELLIPSIS_MASK ||
524         param.first.first == NEW_AXIS_MASK || param.first.first == SHRINK_AXIS_MASK) {
525       int64_t value = GetValue<int64_t>(param.first.second);
526       MS_LOG(INFO) << "STRIDEDSLICE: param=" << param.first.first << ", param.second=" << value;
527       AnfNodePtr val = NewValueNode(value);
528       val->set_abstract(param.first.second->ToAbstract());
529       (void)new_node_input->emplace_back(val);
530       continue;
531     }
532     Shape shape_vec = GetValue<Shape>(param.first.second);
533     MS_LOG(INFO) << "STRIDEDSLICE: param=" << param.first.first << ", " << shape_vec;
534     if (param.first.first == END) {
535       auto dynamic_input = ConvertConstParamToDynamic(tensor_redistribution_from_cnode, param, func_graph, false);
536       MS_ERROR_IF_NULL_W_RET_VAL(dynamic_input, FAILED);
537       new_node_input->emplace_back(dynamic_input);
538       continue;
539     }
540     AnfNodePtr val = NewValueNode(shape_vec);
541     MS_ERROR_IF_NULL_W_RET_VAL(val, FAILED);
542     val->set_abstract(param.first.second->ToAbstract());
543     (void)new_node_input->emplace_back(val);
544   }
545   return SUCCESS;
546 }
547 
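// If the target shape already contains exactly one -1 and every other dim equals the unified
// slice shape, the Reshape can keep its constant shape and no dynamic matching is needed.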
548 bool WhetherMatchingIsNeededForReshape(const Shape &shape_vec, const TensorRedistributionPtr &tensor_redistribution) {
549   size_t user_specific_dynamic_dim_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
550   TensorLayout to_layout = tensor_redistribution->layout_transfer().to_in();
551   Shape to_shape_in_layout = to_layout.slice_shape().array();
552   MS_LOG(INFO) << "shape_vec=" << shape_vec << ", to_shape_in_layout=" << to_shape_in_layout;
553   if (user_specific_dynamic_dim_cnt == 1 && shape_vec.size() == to_shape_in_layout.size()) {
554     size_t dyn_index = static_cast<size_t>(std::find(shape_vec.begin(), shape_vec.end(), -1) - shape_vec.begin());
555     for (size_t i = 0; i < shape_vec.size(); ++i) {
556       if (i != dyn_index && shape_vec[i] != to_shape_in_layout[i]) {
557         return true;
558       }
559     }
560     MS_LOG(INFO) << "No need to matching for shape: " << shape_vec << ", to_shape_in_layout: " << to_shape_in_layout;
561     return false;
562   }
563   return true;
564 }
565 
566 inline bool HasOnlyOneDynamicAxis(const Shape &shape_vec,
567                                   const TensorRedistributionPtr &tensor_redistribution_from_cnode) {
568   Shape origin_to_no_assembled = tensor_redistribution_from_cnode->to_origin_no_assembled().tensor_shape().array();
569   Shape origin_to_no_assembled_slice = tensor_redistribution_from_cnode->to_origin_no_assembled().slice_shape().array();
570   bool has_only_one_dynamic_axis = std::count(origin_to_no_assembled.begin(), origin_to_no_assembled.end(), -1) == 1;
571   MS_LOG(INFO) << "shape_vec: " << shape_vec << ", origin_to_no_assembled: " << origin_to_no_assembled
572                << ", origin_to_no_assembled_slice: " << origin_to_no_assembled_slice;
573   return (origin_to_no_assembled.size() == shape_vec.size()) && has_only_one_dynamic_axis;
574 }
575 
576 void ReplaceDynamicAxisToNegOne(const TensorRedistributionPtr &tensor_redistribution_from_cnode, Shape *shape_vec) {
577   Shape origin_to_no_assembled = tensor_redistribution_from_cnode->to_origin_no_assembled().tensor_shape().array();
578   for (size_t i = 0; i < origin_to_no_assembled.size(); ++i) {
579     if (origin_to_no_assembled[i] == -1) {
580       (*shape_vec)[i] = -1;
581     }
582   }
583 }
584 
585 Status ConvertReshapeInputs(const OperatorParams &params,
586                             const TensorRedistributionPtr &tensor_redistribution_from_cnode,
587                             const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *new_node_input) {
588   Param shape_param;
589   bool use_origin_shape = false;
590   ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE;
591   for (auto &param : params) {
592     if (param.first.first == SHAPE) {
593       shape_param = param;
594       continue;
595     }
596     if (param.first.first == USE_ORIGIN_SHAPE) {
597       use_origin_shape = GetValue<bool>(param.first.second);
598       MS_LOG(INFO) << "Has USE_ORIGIN_SHAPE = " << use_origin_shape;
599       continue;
600     }
601     if (param.first.first == REDISTRIBUTION_RESHAPE_MODE) {
602       reshape_mode = static_cast<ReshapeMode>(GetValue<int64_t>(param.first.second));
603       MS_LOG(INFO) << "Has REDISTRIBUTION_RESHAPE_MODE = " << reshape_mode;
604       continue;
605     }
606   }
607   Shape shape_vec = GetValue<Shape>(shape_param.first.second);
608   if (shape_vec.size() == 1) {
609     std::vector<int64_t> const_shape{-1};
610     AnfNodePtr val = NewValueNode(const_shape);
611     (void)new_node_input->emplace_back(val);
612     return SUCCESS;
613   }
614   if (use_origin_shape && tensor_redistribution_from_cnode->original_reshape_shape() != nullptr) {
615     // Only reshape in user's code should be in this branch.
616     // original_reshape_shape could be ValueNode, MakeTuple, Shape.
617     (void)new_node_input->emplace_back(tensor_redistribution_from_cnode->original_reshape_shape());
618     return SUCCESS;
619   }
620   size_t dynamic_axis_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
621   if (shape_vec.size() > 1 && dynamic_axis_cnt >= SIZE_TWO) {
622     MS_LOG(WARNING) << "The shape of Reshape op has more than one -1, cannot be supported for now.";
623   }
624   Shape origin_to_no_assembled = tensor_redistribution_from_cnode->to_origin_no_assembled().tensor_shape().array();
625   Shape origin_to_no_assembled_slice = tensor_redistribution_from_cnode->to_origin_no_assembled().slice_shape().array();
626   MS_LOG(INFO) << "shape_vec: " << shape_vec << ", reshape_mode: " << reshape_mode
627                << ", origin_to_no_assembled: " << origin_to_no_assembled
628                << ", origin_to_no_assembled_slice: " << origin_to_no_assembled_slice;
629   // If there is only one dynamic axis, simply replace it with -1.
630   if (reshape_mode == ReshapeMode::NO_RESHAPE && HasOnlyOneDynamicAxis(shape_vec, tensor_redistribution_from_cnode)) {
631     // After the HasOnlyOneDynamicAxis check, shape_vec must have exactly one dynamic axis, and it must be the prime axis.
632     Shape new_shape_vec(shape_vec);
633     ReplaceDynamicAxisToNegOne(tensor_redistribution_from_cnode, &new_shape_vec);
634     MS_LOG(INFO) << "Replace shape: " << shape_vec << " to new_shape_vec: " << new_shape_vec;
635     AnfNodePtr val = NewValueNode(new_shape_vec);
636     (void)new_node_input->emplace_back(val);
637     return SUCCESS;
638   }
639   if (!WhetherMatchingIsNeededForReshape(shape_vec, tensor_redistribution_from_cnode)) {
640     MS_LOG(INFO) << "No need to matching for " << shape_vec;
641     AnfNodePtr val = NewValueNode(shape_param.first.second);
642     val->set_abstract(shape_param.first.second->ToAbstract());
643     (void)new_node_input->emplace_back(val);
644     return SUCCESS;
645   }
646   auto dynamic_input =
647     ConvertConstParamToDynamic(tensor_redistribution_from_cnode, shape_param, func_graph, true, reshape_mode);
648   MS_ERROR_IF_NULL_W_RET_VAL(dynamic_input, FAILED);
649   (void)new_node_input->emplace_back(dynamic_input);
650   return SUCCESS;
651 }
652 
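// Rewrites new_node_input so that it holds the inputs of a TupleGetItem over a newly created
// Split: the Split is built from the remaining params, and the SPLIT_OUTPUT_INDEX attribute
// selects which slice replaces the original input.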
653 Status ConvertSplitInputs(const OperatorParams &params, const FuncGraphPtr &func_graph,
654                           std::vector<AnfNodePtr> *new_node_input) {
655   MS_EXCEPTION_IF_CHECK_FAIL(new_node_input->size() == 1,
656                              "new_node_input must contain exactly the single input of Split.");
657   auto split_target = (*new_node_input)[0];
658   std::vector<AnfNodePtr> split_inputs = {split_target};
659   ValuePtr output_index;
660   for (auto &param : params) {
661     if (param.first.first == SPLIT_OUTPUT_INDEX) {
662       output_index = param.first.second;
663       continue;
664     }
665     AnfNodePtr val = NewValueNode(param.first.second);
666     MS_EXCEPTION_IF_NULL(val);
667     val->set_abstract(param.first.second->ToAbstract());
668     (void)split_inputs.emplace_back(val);
669   }
670   constexpr char tag[] = "redistribution_allsplit";
671   auto split_op = CreateSplit(split_inputs, func_graph, tag);
672   auto split_output_index = NewValueNode(output_index);
673   auto tuple_get_item_prim = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
674   auto prim_value_node = NewValueNode(tuple_get_item_prim);
675   tuple_get_item_prim->set_instance_name(tag);
676   new_node_input->resize(SIZE_THREE);
677   (*new_node_input)[INDEX_ZERO] = prim_value_node;
678   (*new_node_input)[INDEX_ONE] = split_op;
679   (*new_node_input)[INDEX_TWO] = split_output_index;
680   return SUCCESS;
681 }
682 
683 bool IsToBeInsertedSplitOp(const Operator &op) {
684   // If the split op has the attr SPLIT_INSERT_LATER, skip it in the OptimizeTensorRedistributionOperatorList stage,
685   // and insert it in CreateInput instead.
686   if (op.first != SPLIT) {
687     return false;
688   }
689   OperatorAttrs op_attrs = op.second.first;
690   auto is_skip_func = [](const Attr &attr) -> bool {
691     return attr.first == SPLIT_INSERT_LATER && GetValue<bool>(attr.second);
692   };
693   return std::any_of(op_attrs.begin(), op_attrs.end(), is_skip_func);
694 }
695 
696 Status ConvertParamsToInputs(const Operator &op, const TensorRedistributionPtr &tensor_redistribution_from_cnode,
697                              const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *new_node_input) {
698   MS_ERROR_IF_NULL_W_RET_VAL(tensor_redistribution_from_cnode, FAILED);
699   MS_EXCEPTION_IF_NULL(func_graph);
700   OperatorArgs arg_forward = op.second;
701   OperatorParams params = arg_forward.second;
702 
703   if (op.first == RESHAPE) {
704     if (ConvertReshapeInputs(params, tensor_redistribution_from_cnode, func_graph, new_node_input) != SUCCESS) {
705       return FAILED;
706     }
707   } else if (op.first == STRIDEDSLICE) {
708     if (ConvertStridedSliceInputs(params, tensor_redistribution_from_cnode, func_graph, new_node_input) != SUCCESS) {
709       return FAILED;
710     }
711   } else if (IsToBeInsertedSplitOp(op)) {
712     if (ConvertSplitInputs(params, func_graph, new_node_input) != SUCCESS) {
713       return FAILED;
714     }
715   } else {
716     MS_LOG(DEBUG) << op.first << " is not supported.";
717     return FAILED;
718   }
719   return SUCCESS;
720 }
721 
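// Builds the input list for a redistribution operator inserted in front of pre_node. In dynamic
// (assembled static shape) scenes the constant params are first converted into graph nodes via
// ConvertParamsToInputs; otherwise they are inserted at their recorded positions as value nodes.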
722 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &pre_node, const std::string &instance_name,
723                                     const CNodePtr &cur_cnode) {
724   MS_EXCEPTION_IF_NULL(pre_node);
725   OperatorArgs arg_forward = op.second;
726   OperatorParams params = arg_forward.second;
727 
728   std::vector<AnfNodePtr> new_node_input = {pre_node};
729   MS_LOG(INFO) << "CreateInput param.empty=" << params.empty() << ", pre_node=" << pre_node->fullname_with_scope()
730                << ", op=" << op.first;
731   bool is_done = false;
732   if (cur_cnode != nullptr) {
733     TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(cur_cnode);
734     // 1. Only deal with Reshape in user scripts.
735     // 2. Deal with non-user Reshape. If there is only StridedSlice, Concat and Split cannot reach here.
736     if (tensor_redistribution != nullptr && tensor_redistribution->IsAssembledStaticShape()) {
737       MS_LOG(DEBUG) << cur_cnode->fullname_with_scope() << " distribute_operator is not nullptr";
738       if (ConvertParamsToInputs(op, tensor_redistribution, cur_cnode->func_graph(), &new_node_input) == SUCCESS) {
739         is_done = true;
740       } else {
741         MS_LOG(DEBUG) << "Convert params to inputs failed.";
742       }
743     } else {
744       MS_LOG(INFO) << "cur_cnode=" << cur_cnode->fullname_with_scope() << " is not dynamic node.";
745     }
746   }
747 
748   if (IsToBeInsertedSplitOp(op) && !is_done && cur_cnode != nullptr) {
749     // It means Split in a static-shape scene.
750     auto ret = ConvertSplitInputs(params, cur_cnode->func_graph(), &new_node_input);
751     MS_EXCEPTION_IF_CHECK_FAIL(ret == SUCCESS, "Insert split op failed.");
752     is_done = true;
753   }
754 
755   if (!is_done && !params.empty()) {
756     for (const auto &param : params) {
757       AnfNodePtr val = NewValueNode(param.first.second);
758       MS_EXCEPTION_IF_NULL(val);
759       val->set_abstract(param.first.second->ToAbstract());
760       int64_t position = param.second;
761       (void)new_node_input.insert(new_node_input.cbegin() + position - 1, val);
762     }
763   }
764 
765   if (!IsToBeInsertedSplitOp(op)) {
766     new_node_input = ConvertToRealInputs(op.first, instance_name, new_node_input, arg_forward.first);
767   }
768   // If the op has the 'group' attr, set the rank list name for the op
769   SetCommunicationOpGroupLabel(new_node_input);
770   return new_node_input;
771 }
772 
773 std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
774                                        const CNodePtr &node) {
775   MS_EXCEPTION_IF_NULL(node);
776   MS_EXCEPTION_IF_NULL(node->func_graph());
777   OperatorArgs arg_replace_op = replace_op.second;
778   OperatorParams params = arg_replace_op.second;
779   if (node->size() < SIZE_TWO) {
780     // GetNext operator does not have inputs
781     if (node->size() == 1) {
782       return ConvertToRealInputs(replace_op.first, instance_name, AnfNodePtrList{}, arg_replace_op.first);
783     }
784     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
785   }
786   std::vector<AnfNodePtr> replace_input = {node->input(1)};
787 
788   if (replace_op.first == EMBEDDING_LOOKUP) {
789     replace_input = {node->input(1), node->input(2)};
790   }
791   if (!params.empty() && replace_op.first != SYNC_BATCH_NORM) {
792     Param param_first = *(params.begin());
793     int64_t first_position = param_first.second;
794     if (first_position == 1) {
795       replace_input.pop_back();
796     }
797   }
798   bool is_done = false;
799   bool to_be_converted = replace_op.first == SPLIT || replace_op.first == STRIDEDSLICE || replace_op.first == RESHAPE;
800   if (!params.empty() && to_be_converted && IsDynamicOp(node)) {
801     TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(node);
802     auto ret = ConvertParamsToInputs(replace_op, tensor_redistribution, node->func_graph(), &replace_input);
803     MS_EXCEPTION_IF_CHECK_FAIL(ret == SUCCESS, "ConvertStridedSliceInputs failed.");
804     is_done = true;
805   } else if (!params.empty() && !IsToBeInsertedSplitOp(replace_op)) {
806     for (auto &param : params) {
807       AnfNodePtr val = NewValueNode(param.first.second);
808       if (val == nullptr) {
809         MS_LOG(EXCEPTION) << "Failure:val is nullptr";
810       }
811       int64_t position = param.second;
812       (void)replace_input.insert(replace_input.cbegin() + position - 1, val);
813     }
814   } else if (replace_op.first == SYNC_BATCH_NORM) {
815     for (size_t i = 2; i < node->size(); ++i) {
816       replace_input.push_back(node->input(i));
817     }
818   }
819 
820   if (!IsToBeInsertedSplitOp(replace_op)) {
821     replace_input = ConvertToRealInputs(replace_op.first, instance_name, replace_input, arg_replace_op.first);
822   } else if (IsToBeInsertedSplitOp(replace_op) && !is_done) {
823     // It means Split in a static-shape scene.
824     auto ret = ConvertSplitInputs(params, node->func_graph(), &replace_input);
825     MS_EXCEPTION_IF_CHECK_FAIL(ret == SUCCESS, "Insert split op failed.");
826   }
827   SetCommunicationOpGroupLabel(replace_input);
828   return replace_input;
829 }
830 
831 CNodePtr InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
832                     const FuncGraphPtr &func_graph, const std::string &instance_name, const std::string &param_name,
833                     const FuncGraphPtr &root, const TensorRedistributionPtr &tensor_redistribution) {
834   // insert new node before the node
835   MS_EXCEPTION_IF_NULL(node);
836   MS_EXCEPTION_IF_NULL(func_graph);
837   FuncGraphManagerPtr manager = func_graph->manager();
838   MS_EXCEPTION_IF_NULL(manager);
839   ScopePtr scope = node->scope();
840   MS_EXCEPTION_IF_NULL(scope);
841   std::vector<AnfNodePtr> node_input;
842 
843   if (root && !param_name.empty()) {
844     node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
845   } else {
846     node_input = CreateInput(op, pre_node, instance_name, node);
847   }
848 
849   CNodePtr new_node = func_graph->NewCNode(node_input);
850   MS_EXCEPTION_IF_NULL(new_node);
851   if (instance_name.find(SPLIT_SENS) == std::string::npos) {
852     new_node->set_in_forward_flag(true);  // mark forward flag
853   }
854   auto new_node_value = node_input[0]->cast<ValueNodePtr>();
855   MS_EXCEPTION_IF_NULL(new_node_value);
856   auto new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
857   new_node_prim->set_instance_name(instance_name);
858   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
859   if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
860     new_node_prim->set_attr("recompute", MakeValue(false));
861   } else if (instance_name.find(RECOMPUTE) != std::string::npos) {
862     new_node_prim->set_attr("recompute", MakeValue(true));
863   }
864 
865   auto primitive = common::AnfAlgo::GetCNodePrimitive(new_node);
866   MS_EXCEPTION_IF_NULL(primitive);
867   if (node->HasPrimalAttr(SEGMENT)) {
868     primitive->AddAttr(SEGMENT, node->GetPrimalAttr(SEGMENT));
869     new_node->AddPrimalAttr(SEGMENT, node->GetPrimalAttr(SEGMENT));
870   }
871   if (node->HasPrimalAttr(MICRO)) {
872     new_node->AddPrimalAttr(MICRO, node->GetPrimalAttr(MICRO));
873   }
874   new_node->set_scope(scope);
875   node_input[0]->set_scope(scope);
876   if (instance_name.find(REDISTRIBUTION_OP) != std::string::npos) {
877     new_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(new_node->UniqueId()));
878     if (node->HasPrimalAttr(MICRO)) {
879       new_node->AddPrimalAttr(MICRO, node->GetPrimalAttr(MICRO));
880     }
881   }
882   manager->SetEdge(node, SizeToInt(index), new_node);
883   MS_LOG(INFO) << "Insert " << instance_name << " success";
884   return new_node;
885 }
886 
887 bool IsRootNode(const CNodePtr &cnode, const AnfNodePtr &root_node) {
888   // cnode is a TupleGetItem.
889   // Check whether the first input of the op is Shape, and whether Shape's first input is the same as the Reshape's.
890   // Sometimes Reshape's first input may not be the same as Shape's first input.
891   auto first_input_of_tuple_getitem = cnode->input(1)->cast<CNodePtr>();
892   if (!IsTargetOp(first_input_of_tuple_getitem, SHAPE_OP)) {
893     return false;
894   }
895   auto first_input_of_shape = first_input_of_tuple_getitem->input(1);
896   if (first_input_of_shape == root_node) {
897     return true;
898   } else {
899     MS_LOG(WARNING) << "Shape's first input is not the same as the root node.";
900   }
901   return true;
902 }
903 
904 std::pair<CNodePtr, int64_t> FindPreviousNodeAndSkipTupleGetItem(const CNodePtr &current, int32_t depth = 0) {
905   // current is TupleGetItem
906   if (depth == MAX_RECURSIVE_DEPTH) {
907     return {nullptr, -1};
908   }
909   auto prev = current->input(1);
910   auto cnode = prev->cast<CNodePtr>();
911   if (IsTupleGetItem(cnode)) {
912     return FindPreviousNodeAndSkipTupleGetItem(cnode, depth + 1);
913   }
914   int64_t index = GetTupleGetItemIndex(current);
915   return {cnode, index};
916 }
917 
918 bool ModifyGraph(const CNodePtr &current_cnode, const CNodePtr &previous_tuple_getitem_cnode, size_t input_index) {
919   /**
920    * This function must be called after IsRootNode() has been called and has returned true.
921    *
922    * TupleGetItem(tensor, index)
923    * ->
924    * ScalarMul(scalar)
925    * ->
926    * current_cnode
927    */
928   int64_t index = GetTupleGetItemIndex(previous_tuple_getitem_cnode);
929   auto root_node = previous_tuple_getitem_cnode->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
930   if (IsTupleGetItem(root_node)) {
931     // Keep searching for the previous node.
932     auto output = FindPreviousNodeAndSkipTupleGetItem(root_node);
933     root_node = output.first;
934   }
935   // Get tensor layout from root_node.
936   if (!root_node->has_user_data<OperatorInfo>()) {
937     // Default/TupleGetItem-op0 has no operator info.
938     MS_LOG(INFO) << root_node->fullname_with_scope() << " has no operator info.";
939     return true;
940   }
941   OperatorInfoPtr distribute_operator = GetDistributeOperator(root_node);
942   MS_EXCEPTION_IF_NULL(distribute_operator);
943   std::vector<TensorInfo> root_tensor_info = distribute_operator->outputs_tensor_info();
944   if (root_tensor_info.size() != 1) {
945     MS_LOG(ERROR) << "Outputs number cannot be larger than 1.";
946     return false;
947   }
948   TensorInfo tensor_info = root_tensor_info[0];
949   Map tensor_map = tensor_info.tensor_layout().tensor_map();
950   Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
951   if (LongToSize(index) >= tensor_map.GetDimSize()) {
952     MS_LOG(ERROR) << "Index cannot be larger than tensor_map size.";
953     return false;
954   }
955   int64_t scalar = dev_arr.GetDimByReverseIdx(tensor_map.GetDimByIdx(index));
956   // Create ValueNode for scalar->Create Mul Cnode->Modify inputs and edges
957   Operator scalar_mul_op = CreateScalarMulOp(scalar);
958   InsertNode(scalar_mul_op,                 // to be inserted op
959              current_cnode,                 // current node
960              input_index,                   // input index of current_node
961              previous_tuple_getitem_cnode,  // insert scalar_mul_op between previous and current
962              current_cnode->func_graph(),   // current func_graph
963              "instance_name", "", nullptr);
964   MS_LOG(DEBUG) << tensor_info.tensor_layout().ToString() << ", " << previous_tuple_getitem_cnode->fullname_with_scope()
965                 << " index: " << index << ", scalar: " << scalar;
966   return true;
967 }
968 
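// Recursively walks the inputs of cnode looking for a TupleGetItem over Shape(root_node); when one
// is found, ModifyGraph inserts a ScalarMul (by the device-arrangement factor for that dim) on
// that path.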
969 Status UpdateShapeToRootPath(const CNodePtr &cnode, const AnfNodePtr &root_node, int32_t depth = 0) {
970   if (depth == MAX_RECURSIVE_DEPTH) {
971     return REACH_MAX_RECURSIVE_DEPTH;
972   }
973   auto value_node = cnode->input(0)->cast<ValueNodePtr>();
974   auto prim = value_node->value()->cast<PrimitivePtr>();
975   for (size_t i = 1; i < cnode->inputs().size(); ++i) {
976     auto input = cnode->input(i)->cast<CNodePtr>();
977     if (input == nullptr) {
978       continue;
979     }
980     if (IsTupleGetItem(input) && IsRootNode(input, root_node)) {
981       // Modify this graph path.
982       if (!ModifyGraph(cnode, input, i)) {
983         MS_LOG(ERROR) << "Failed to modify graph.";
984         return Status::FAILED;
985       }
986       return Status::SUCCESS;
987     }
988     // Keep tracing back.
989     Status ret = UpdateShapeToRootPath(input, root_node, depth + 1);
990     if (ret != Status::SUCCESS) {
991       return Status::FAILED;
992     }
993   }
994   return Status::SUCCESS;
995 }
996 
997 Status UpdatePartialShape(const CNodePtr &cnode) {
998   // Trace back the shape_of_reshape input of the Reshape op.
999   MS_EXCEPTION_IF_NULL(cnode);
1000   MS_EXCEPTION_IF_CHECK_FAIL(cnode->inputs().size() == RESHAPE_INPUT_SIZE,
1001                              "Reshape op must have " + std::to_string(RESHAPE_INPUT_SIZE) + " inputs.");
1002   // Step1. Get the second input of the Reshape op, which represents shape_of_reshape.
1003   // Step2. Visit shape_of_reshape and trace back to the dynamic axes.
1004   auto input_of_reshape = cnode->input(RESHAPE_INPUT_SIZE - 2);
1005   auto shape_of_reshape = cnode->input(RESHAPE_INPUT_SIZE - 1);
1006   auto shape_cnode = shape_of_reshape->cast<CNodePtr>();  // MakeTuple
1007   if (shape_cnode == nullptr) {
1008     return Status::SUCCESS;
1009   }
1010   for (const auto &input : shape_cnode->inputs()) {
1011     auto cnode_input = input->cast<CNodePtr>();
1012     if (cnode_input == nullptr) {
1013       continue;
1014     }
1015     if (UpdateShapeToRootPath(cnode_input, input_of_reshape) != Status::SUCCESS) {
1016       MS_LOG(ERROR) << "Update " << cnode->fullname_with_scope() << " previous shape failed.";
1017       return Status::FAILED;
1018     }
1019   }
1020   return Status::SUCCESS;
1021 }
1022 
1023 CNodePtr FindPreviousCareNode(const CNodePtr &current, int32_t depth = 0) {
1024   if (depth == MAX_RECURSIVE_DEPTH) {
1025     return nullptr;
1026   }
1027   auto prev = current->input(1);
1028   // If prev is a Parameter, there may be a problem here.
1029   auto cnode = prev->cast<CNodePtr>();
1030   if (cnode == nullptr) {
1031     MS_LOG(INFO) << "Input of node is not a cnode: " << prev->fullname_with_scope();
1032     return nullptr;
1033   }
1034   if (!IsParallelCareNode(cnode) && (IsTargetOp(cnode, "Cast") || IsTupleGetItem(cnode))) {
1035     return FindPreviousCareNode(cnode, depth + 1);
1036   }
1037   return cnode;
1038 }
1039 
1040 Status GetDistributeOperatorFromCNode(const CNodePtr &cnode, TensorInfo *tensor_info) {
1041   MS_EXCEPTION_IF_NULL(cnode);
1042   CNodePtr target_cnode = cnode;
1043   if (!IsParallelCareNode(cnode)) {
1044     // Keep searching for the previous node.
1045     target_cnode = FindPreviousCareNode(cnode);
1046   }
1047   if (target_cnode == nullptr) {
1048     return Status::FAILED;
1049   }
1050   if (!target_cnode->has_user_data<OperatorInfo>()) {
1051     MS_LOG(EXCEPTION) << "Found " << cnode->fullname_with_scope() << " previous node is "
1052                       << target_cnode->fullname_with_scope() << " and it has no operator info.";
1053   }
1054 
1055   OperatorInfoPtr distribute_operator = GetDistributeOperator(target_cnode);
1056   MS_EXCEPTION_IF_NULL(distribute_operator);
1057   std::vector<TensorInfo> root_tensor_info = distribute_operator->outputs_tensor_info();
1058   if (root_tensor_info.size() != 1) {
1059     if (IsTupleGetItem(cnode)) {
1060       int64_t output_index = GetTupleGetItemIndex(cnode);
1061       MS_EXCEPTION_IF_CHECK_FAIL(
1062         (output_index >= 0 && output_index < SizeToLong(root_tensor_info.size())),
1063         "TupleGetItem index is not matched with its input length, TupleGetItem is " + cnode->fullname_with_scope());
1064       MS_LOG(INFO) << "Replace tensor info use " << target_cnode->fullname_with_scope() << " with index "
1065                    << output_index;
1066       (*tensor_info) = root_tensor_info[output_index];
1067       return Status::SUCCESS;
1068     }
1069     MS_LOG(WARNING) << "Outputs number cannot be larger than 1, but " << target_cnode->fullname_with_scope() << " has "
1070                     << root_tensor_info.size() << " outputs.";
1071   }
1072   (*tensor_info) = root_tensor_info[0];
1073   return Status::SUCCESS;
1074 }
1075 
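// For a TupleGetItem that reads one dim of a sharded tensor's Shape, inserts a ScalarMul by the
// corresponding device-arrangement factor between the TupleGetItem and each of its users; dims
// whose tensor_map entry is negative (not sharded) are skipped.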
1076 Status UpdateTupleGetItemShapeValue(const CNodePtr &tuple_getitem, const TensorInfo &tensor_info,
1077                                     const FuncGraphPtr &func_graph) {
1078   MS_LOG(INFO) << "into UpdateTupleGetItemShapeValue";
1079   Map tensor_map = tensor_info.tensor_layout().tensor_map();
1080   Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
1081   auto manager = func_graph->manager();
1082   MS_EXCEPTION_IF_NULL(manager);
1083   auto node_users_map = manager->node_users();
1084 
1085   int64_t index = GetTupleGetItemIndex(tuple_getitem);
1086   if (LongToSize(index) >= tensor_map.GetDimSize()) {
1087     MS_LOG(ERROR) << "Index cannot be larger than tensor_map size.";
1088     return Status::FAILED;
1089   }
1090   if (tensor_map.GetDimByIdx(index) < 0) {
1091     MS_LOG(DEBUG) << "Skip index " << index << ", because it's " << tensor_map.GetDimByIdx(index);
1092     return Status::SUCCESS;
1093   }
1094   int64_t scalar = dev_arr.GetDimByReverseIdx(tensor_map.GetDimByIdx(index));
1095   for (const auto &next_node : node_users_map[tuple_getitem]) {
1096     auto tuple_getitem_user = next_node.first->cast<CNodePtr>();
1097     if (tuple_getitem_user == nullptr) {
1098       MS_LOG(DEBUG) << "tuple_getitem_user is nullptr";
1099       continue;
1100     }
1101     MS_LOG(INFO) << tuple_getitem->input(1)->fullname_with_scope() << "->" << tuple_getitem->fullname_with_scope()
1102                  << "->ScalarMul(" << scalar << ")->" << next_node.first->fullname_with_scope() << "["
1103                  << next_node.second << "]" << std::endl;
1104     Operator scalar_mul_op = CreateScalarMulOp(scalar);
1105     (void)InsertNode(scalar_mul_op,                     // to be inserted op
1106                      tuple_getitem_user,                // current node
1107                      next_node.second,                  // tuple_getitem_user[input_index] = scalar_mul_op
1108                      tuple_getitem,                     // insert scalar_mul_op between previous and current
1109                      tuple_getitem_user->func_graph(),  // current func_graph
1110                      "update_partial_shape", "", nullptr);
1111   }
1112   return Status::SUCCESS;
1113 }
1114 
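// Rebuild the target shape of a dynamic Reshape as MakeTuple(...): constant dimensions come from
// `shape`, dynamic dimensions come from the Shape node, and every sharded dimension is multiplied
// by its device-arrangement factor to restore the full shape.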
1115 Status UpdateReshapeShapeValue(const CNodePtr &reshape_cnode, const CNodePtr &shape_cnode, const Shape &shape,
1116                                const TensorInfo &tensor_info, const FuncGraphPtr &func_graph) {
1117   // Replace the shape input with MakeTuple(shape[0]*factor0, shape[1]*factor1, ...).
1118   MS_LOG(INFO) << "Into UpdateReshapeShapeValue: " << shape;
1119   MS_EXCEPTION_IF_NULL(reshape_cnode);
1120   MS_EXCEPTION_IF_NULL(shape_cnode);
1121   MS_EXCEPTION_IF_NULL(func_graph);
1122   FuncGraphManagerPtr manager = func_graph->manager();
1123   MS_EXCEPTION_IF_NULL(manager);
1124   Map tensor_map = tensor_info.tensor_layout().tensor_map();
1125   Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
1126   TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(reshape_cnode);
1127 
1128   std::vector<AnfNodePtr> make_tuple_inputs;
1129   std::string instance_name = std::string(REDISTRIBUTION_OP) + "_replace_reshape";
1130   for (size_t i = 0; i < shape.size(); ++i) {
1131     if (shape[i] > 0) {
1132       // The dimension is a constant, so add it to make_tuple_inputs as a ValueNode.
1133       auto const_val_node = NewValueNode(MakeValue(shape[i]));
1134       make_tuple_inputs.emplace_back(const_val_node);
1135       MS_LOG(INFO) << "Create ValueNode " << shape[i];
1136       continue;
1137     }
1138     // Get shape from shape node.
1139     auto prim_tuple_get_item = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
1140     AnfNodePtrList inputs{NewValueNode(prim_tuple_get_item), shape_cnode, NewValueNode(MakeValue(SizeToLong(i)))};
1141     auto tuple_get_item_cnode = func_graph->NewCNode(inputs);
1142     tuple_get_item_cnode->set_fullname_with_scope("tuple_getitem_replace_reshape");
1143     prim_tuple_get_item->set_instance_name(instance_name);
1144     make_tuple_inputs.emplace_back(tuple_get_item_cnode);
1145     MS_LOG(INFO) << "Create TupleGetItem for " << i;
1146   }
1147   auto make_tuple = CreateMakeTuple(make_tuple_inputs, func_graph, instance_name);
1148   make_tuple->set_in_forward_flag(true);
1149   std::string fullname = shape_cnode->fullname_with_scope() + "_replace";
1150   make_tuple->set_fullname_with_scope(fullname);
1151   manager->SetEdge(reshape_cnode, INDEX_TWO, make_tuple);
1152   MS_LOG(INFO) << shape_cnode->fullname_with_scope() << "->" << make_tuple->fullname_with_scope() << "->"
1153                << reshape_cnode->fullname_with_scope();
1154   MS_LOG(INFO) << "reshape shape is : " << shape;
1155   MS_LOG(INFO) << "reshape tensor_map is : " << tensor_map.array();
1156   MS_LOG(INFO) << "reshape dev_arr is : " << dev_arr.array();
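  // Multiply every sharded, non-constant dimension of the new shape tuple by its shard factor
  // so that the Reshape target describes the full (unsharded) shape.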
1157   for (size_t i = 0; i < tensor_map.array().size(); ++i) {
1158     if (tensor_map.GetDimByIdx(i) == -1) {
1159       continue;
1160     }
1161     if (make_tuple_inputs[i]->isa<ValueNode>()) {
1162       continue;
1163     }
1164     int64_t scalar = dev_arr.GetDimByReverseIdx(tensor_map.GetDimByIdx(i));
1165     Operator scalar_mul_op = CreateScalarMulOp(scalar);
1166     (void)InsertNode(scalar_mul_op,             // to be inserted op
1167                      make_tuple,                // current node
1168                      i + 1,                     // make_tuple[input_index] = scalar_mul_op
1169                      make_tuple->input(i + 1),  // insert scalar_mul_op between previous and current
1170                      func_graph,                // current func_graph
1171                      "update_partial_shape", "", nullptr);
1172   }
1173   if (tensor_redistribution != nullptr && tensor_redistribution->original_reshape_shape() != nullptr) {
1174     tensor_redistribution->set_original_reshape_shape(make_tuple);
1175     MS_LOG(INFO) << "Change original_reshape_shape";
1176   }
1177   return Status::SUCCESS;
1178 }
1179 
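// Return true if `cnode` is a Reshape whose primitive has skip_redistribution set, in which case
// its shape input does not need to be supplied here.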
1180 bool SkipSupplyForReshape(const CNodePtr &cnode) {
1181   if (!IsReshapeOp(cnode)) {
1182     return false;
1183   }
1184   auto prim = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
1185   if (prim->HasAttr(SKIP_REDISTRIBUTION)) {
1186     bool skip_redistribution = GetValue<bool>(prim->GetAttr(SKIP_REDISTRIBUTION));
1187     return skip_redistribution;
1188   }
1189   return false;
1190 }
1191 
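// `cnode` is a Shape op. Look up the tensor layout of its input, then rewrite the users of the
// Shape result (Reshape or TupleGetItem) so that dynamic dimensions reflect the full, unsharded shape.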
1192 Status UpdateShapeNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
1193   MS_EXCEPTION_IF_NULL(cnode);
1194   // Step1. Get shape input tensor layout. cnode is Shape op.
1195   auto input_of_shape = cnode->input(1);
1196   auto input_cnode = input_of_shape->cast<CNodePtr>();
1197   if (input_cnode == nullptr) {
1198     return Status::SUCCESS;
1199   }
1200   if (SkipSupplyForReshape(input_cnode)) {
1201     MS_LOG(INFO) << "Skip " << cnode->fullname_with_scope() << " because its input is a Reshape with skip_redistribution set.";
1202     return Status::SUCCESS;
1203   }
1204   if (IsValueNode<FuncGraph>(input_cnode->input(0))) {
1205     // It means it's a sub-graph call node.
1206     MS_LOG(WARNING) << "The input of Shape is a sub-graph call; if its output sharding strategy "
1207                        "is not all-1, this could be a problem.";
1208     return Status::SUCCESS;
1209   }
1210   TensorInfo tensor_info;
1211   if (GetDistributeOperatorFromCNode(input_cnode, &tensor_info) != Status::SUCCESS) {
1212     return Status::SUCCESS;
1213   }
1214   Map tensor_map = tensor_info.tensor_layout().tensor_map();
1215   Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
1216 
1217   // Step2. Get shape node users.
1218   auto node_users_map = func_graph->manager()->node_users();
1219   auto shape_node_users = node_users_map[cnode];
1220   for (const auto &node_user : shape_node_users) {
1221     MS_EXCEPTION_IF_NULL(node_user.first);
1222     auto shape_user = node_user.first->cast<CNodePtr>();
    if (shape_user == nullptr) {
      continue;  // Skip users that are not CNodes before they are dereferenced below.
    }
1223     if (IsReshapeOp(shape_user)) {
1224       std::vector<Shape> input_shapes = GetNodeShape(input_of_shape);
1225       if (input_shapes.size() != 1) {
1226         MS_LOG(EXCEPTION) << "Shape's input size is illegal.";
1227       }
1228       if (UpdateReshapeShapeValue(shape_user, cnode, input_shapes[0], tensor_info, func_graph) != Status::SUCCESS) {
1229         MS_LOG(EXCEPTION) << "Update reshape shape value failed.";
1230       }
1231       continue;
1232     }
1233     if (IsTargetOp(shape_user, ZEROS)) {
1234       MS_LOG(ERROR) << "Won't supply shape for " << shape_user->fullname_with_scope();
1235       continue;
1236     }
1237     MS_EXCEPTION_IF_CHECK_FAIL(IsTupleGetItem(shape_user),
1238                                "Only support TupleGetItem here, but got " + GetPrimName(shape_user));
1239     if (UpdateTupleGetItemShapeValue(shape_user, tensor_info, func_graph) != Status::SUCCESS) {
1241       MS_LOG(EXCEPTION) << "Update tuple get item shape value failed.";
1242     }
1243   }
1244   return Status::SUCCESS;
1245 }
1246 
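// For every MakeTuple input whose dimension is sharded (present in factor_mapping), insert
// ScalarDiv(shard factor) followed by a cast back to int64, turning the full shape value into
// the local (per-device) value.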
1247 Status UpdateMakeTupleShapeValue(const CNodePtr &make_tuple, const std::map<size_t, int64_t> &factor_mapping,
1248                                  const FuncGraphPtr &func_graph) {
1249   for (size_t i = 1; i < make_tuple->inputs().size(); ++i) {
1250     if (factor_mapping.find(i - 1) == factor_mapping.end()) {
1251       continue;
1252     }
1253     auto make_tuple_input = make_tuple->input(i);
1254     if (make_tuple_input->isa<ValueNode>()) {
1255       auto val_node = make_tuple_input->cast<ValueNodePtr>();
1256       MS_EXCEPTION_IF_NULL(val_node->value());
1257       auto dim_value = GetValue<int64_t>(val_node->value());
1258       if (dim_value == -1) {
1259         continue;
1260       }
1261     }
1262     Operator scalar_div_op = CreateScalarDivOp(factor_mapping.at(i - 1));
1263     // TODO(liuchongming): If make_tuple_input is a Mul op, consider merging the two ops.
1264     auto div_cnode = InsertNode(scalar_div_op,     // to be inserted op
1265                                 make_tuple,        // current node
1266                                 i,                 // tuple_getitem_user[i] = scalar_div_op
1267                                 make_tuple_input,  // insert scalar_div_op between previous and current
1268                                 func_graph,        // current func_graph
1269                                 "segment_partial_shape", "", nullptr);
1270     Operator cast_op = CreateScalarCastOp(kInt64);
1271     (void)InsertNode(cast_op,     // to be inserted op
1272                      make_tuple,  // current node
1273                      i,           // tuple_getitem_user[i] = cast_op
1274                      div_cnode,   // div_cnode->scalar_div_op->make_tuple
1275                      func_graph,  // current func_graph
1276                      "segment_partial_shape", "", nullptr);
1277   }
1278   return Status::SUCCESS;
1279 }
1280 
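// For a dynamic-shape Reshape, divide the sharded dimensions of its target shape (a MakeTuple)
// by their shard factors so that the Reshape consumes the local (per-device) shape.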
1281 Status SegmentEntireShapeToPartialForDynamic(const CNodePtr &reshape_node, const FuncGraphPtr &func_graph) {
1282   MS_EXCEPTION_IF_NULL(reshape_node);
1283   // reshape_node is a Reshape node.
1284   // Step1. Get the output tensor layout of reshape_node from its tensor redistribution.
1285   // Step2. Shard reshape_node's second input (only when it is a MakeTuple).
1286   auto tensor_redistribution = GetTensorRedistributionFromCNode(reshape_node);
1287   if (tensor_redistribution == nullptr) {
1288     MS_LOG(WARNING) << "Cannot find layout in " << reshape_node->fullname_with_scope();
1289     return Status::FAILED;
1290   }
1291   if (!tensor_redistribution->is_dynamic_shape()) {
1292     MS_LOG(INFO) << reshape_node->fullname_with_scope() << " is static shape.";
1293     return Status::SUCCESS;
1294   }
1295   TensorLayout out_layout = tensor_redistribution->to_origin_no_assembled();
1296   auto tensor_map = out_layout.tensor_map();
1297   auto dev_mat = out_layout.device_arrangement();
1298   std::map<size_t, int64_t> factor_mapping;
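  // Map each sharded dimension index to its shard factor taken from the device arrangement.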
1299   for (size_t i = 0; i < tensor_map.array().size(); ++i) {
1300     if (tensor_map.GetDimByIdx(i) != -1) {
1301       factor_mapping.insert({i, dev_mat.GetDimByReverseIdx(tensor_map.GetDimByIdx(i))});
1302     }
1303   }
1304   auto shape_input = reshape_node->input(INDEX_TWO);
1305   if (!shape_input->isa<CNode>()) {
1306     MS_LOG(DEBUG) << "Reshape's second input is not a CNode.";
1307     return Status::SUCCESS;
1308   }
1309   auto shape_input_cnode = shape_input->cast<CNodePtr>();
1310   if (IsTargetOp(shape_input_cnode, MAKE_TUPLE)) {
1311     (void)UpdateMakeTupleShapeValue(shape_input_cnode, factor_mapping, func_graph);
1312   }
1313   return Status::SUCCESS;
1314 }
1315 
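// Entry point. Traverse the forward graphs (or every graph of root when no forward graph is
// found) and call UpdateShapeNode on each Shape op to supply the full shape to its users.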
1316 Status MergeEntireShapeForDynamic(const FuncGraphPtr &root) {
1317   MS_LOG(INFO) << "Into MergeEntireShapeForDynamic";
1318   MS_EXCEPTION_IF_NULL(root);
1319   // Step1. Judge whether the graph has dynamic shapes.
1320   // Step2. Find all Shape nodes and get their shard factors.
1321   // Step3. Multiply the factors from Step2 into the Shape users (TupleGetItem).
1322   // Step4. Update the downstream nodes of the TupleGetItem nodes.
1323   auto ret_node = root->get_return();
1324   MS_EXCEPTION_IF_NULL(ret_node);
1325   auto all_nodes = DeepScopedGraphSearch(ret_node);
1326   std::reverse(all_nodes.begin(), all_nodes.end());
1327   std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
1328 
1329   if (graph_set.empty()) {
1330     MS_LOG(INFO) << "Cannot find the forward graph, so process the Shape ops in all graphs of root";
1331     auto fgs = root->manager()->func_graphs();
1332     for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
1333       // Traverse all nodes and find Shape ops.
1334       auto fg_nodes_set = (*fg)->nodes();
1335       for (auto const &node : fg_nodes_set) {
1336         if (!node->isa<CNode>()) {
1337           continue;
1338         }
1339         auto cnode = node->cast<CNodePtr>();
1340         if (IsShapeOp(cnode)) {
1341           UpdateShapeNode(cnode, *fg);
1342           continue;
1343         }
1344       }
1345     }
1346   } else {
1347     MS_LOG(INFO) << "The number of sub-graphs used by root is " << root->func_graphs_used().size();
1348     for (auto func_graph = graph_set.cbegin(); func_graph != graph_set.cend(); ++func_graph) {
1349       auto return_node = (*func_graph)->get_return();
1350       MS_EXCEPTION_IF_NULL(return_node);
1351       std::vector<AnfNodePtr> all_dfs_nodes = DeepLinkedGraphSearch(return_node);
1352       for (const auto &node : all_dfs_nodes) {
1353         if (!node->isa<CNode>()) {
1354           continue;
1355         }
1356         auto cnode = node->cast<CNodePtr>();
1357         if (IsShapeOp(cnode)) {
1358           UpdateShapeNode(cnode, *func_graph);
1359           continue;
1360         }
1361       }
1362     }
1363   }
1364   return Status::SUCCESS;
1365 }
1366 }  // namespace mindspore::parallel
1367