• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/common/utils/anfalgo.h"
17 #include <memory>
18 #include <algorithm>
19 #include <map>
20 #include <numeric>
21 #include <set>
22 #include <complex>
23 #include "mindapi/base/shape_vector.h"
24 #include "ops/ascend_op_name.h"
25 #include "ops/nn_optimizer_op_name.h"
26 #include "ops/lite_op_name.h"
27 #include "ops/structure_ops.h"
28 #include "ops/sequence_ops.h"
29 #include "ops/other_ops.h"
30 #include "ops/nn_ops.h"
31 #include "ops/math_ops.h"
32 #include "ops/array_ops.h"
33 #include "ops/arithmetic_ops.h"
34 #include "ops/framework_ops.h"
35 #include "ops/op_utils.h"
36 #include "ops/op_def.h"
37 #include "ops/auto_generate/gen_ops_primitive.h"
38 #include "ir/anf.h"
39 #include "ir/func_graph.h"
40 #include "include/common/utils/utils.h"
41 #include "utils/shape_utils.h"
42 #include "utils/trace_base.h"
43 #include "utils/anf_utils.h"
44 #include "include/common/utils/parallel_context.h"
45 #include "utils/ms_context.h"
46 #include "pybind_api/ir/primitive_py.h"
47 #include "kernel/kernel_build_info.h"
48 #include "include/backend/anf_runtime_algorithm.h"
49 
50 namespace mindspore {
51 namespace common {
52 using abstract::AbstractSparseTensor;
53 using abstract::AbstractTensor;
54 using abstract::AbstractTuple;
55 
56 namespace {
57 constexpr size_t kNopNodeRealInputIndex = 1;
58 using complex64 = std::complex<float>;
59 using complex128 = std::complex<double>;
60 
61 const PrimitiveSet expand_prims = {prim::kPrimMakeTuple};
62 const std::set<std::string> kNodeTupleOutSet = {kMakeTupleOpName, kGetNextOpName};
63 
GetRealOutputRecursively(const AnfNodePtr & node,size_t output_index,std::vector<KernelWithIndex> * inputs)64 void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, std::vector<KernelWithIndex> *inputs) {
65   MS_EXCEPTION_IF_NULL(node);
66   if (node->isa<ValueNode>() || node->isa<Parameter>()) {
67     return inputs->push_back(std::make_pair(node, 0));
68   }
69 
70   // Skip control node
71   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
72       AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
73     return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
74   }
75 
76   // Bypass TupleGetItem
77   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
78     auto tuple_get_item = node->cast<CNodePtr>();
79     MS_EXCEPTION_IF_NULL(tuple_get_item);
80     auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
81     auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
82     // Conceal MakeTuple + TupleGetItem pair.
83     if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
84       auto make_tuple = input->cast<CNodePtr>();
85       MS_EXCEPTION_IF_NULL(make_tuple);
86       auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
87       return GetRealOutputRecursively(real_input, 0, inputs);
88     }
89 
90     // Skip TupleGetItem.
91     return GetRealOutputRecursively(input, index, inputs);
92   }
93 
94   // Flatten MakeTuple inputs.
95   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
96     auto make_tuple = node->cast<CNodePtr>();
97     MS_EXCEPTION_IF_NULL(make_tuple);
98     size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
99     for (size_t input_index = 0; input_index < input_num; ++input_index) {
100       auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
101       GetRealOutputRecursively(input_node, 0, inputs);
102     }
103     return;
104   }
105 
106   return inputs->push_back(std::make_pair(node, output_index));
107 }
108 
IsMultiLayerTuple(const abstract::AbstractBasePtr & abstract)109 bool IsMultiLayerTuple(const abstract::AbstractBasePtr &abstract) {
110   MS_EXCEPTION_IF_NULL(abstract);
111   if (!abstract->isa<abstract::AbstractSequence>()) {
112     return false;
113   }
114   const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
115   MS_EXCEPTION_IF_NULL(sequence_abstract);
116   if (sequence_abstract->dynamic_len()) {
117     return false;
118   }
119   return std::any_of(sequence_abstract->elements().begin(), sequence_abstract->elements().end(),
120                      [](const abstract::AbstractBasePtr &sub_abstract) {
121                        return sub_abstract != nullptr && sub_abstract->isa<abstract::AbstractSequence>();
122                      });
123 }
124 
GetAllOutputWithIndexInner(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types)125 std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node,
126                                                         const std::vector<PrimitivePtr> &return_types) {
127   MS_EXCEPTION_IF_NULL(node);
128   MS_LOG(DEBUG) << "Output node: " << node->fullname_with_scope();
129   std::vector<KernelWithIndex> ret;
130   std::vector<KernelWithIndex> ret_empty;
131   // The MakeTuple/MakeSparse node need expand and recurse.
132   if (IsOneOfPrimitiveCNode(node, expand_prims)) {
133     auto make_tuple = node->cast<CNodePtr>();
134     MS_EXCEPTION_IF_NULL(make_tuple);
135     for (size_t i = 1; i < make_tuple->size(); i++) {
136       auto make_tuple_output = GetAllOutputWithIndexInner(make_tuple->input(i), return_types);
137       (void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
138     }
139     return ret;
140   }
141   // The depend node need get the real node.
142   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
143     auto depend_node = node->cast<CNodePtr>();
144     MS_EXCEPTION_IF_NULL(depend_node);
145     auto real_output = GetAllOutputWithIndexInner(depend_node->input(kRealInputIndexInDepend), return_types);
146     (void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
147     return ret;
148   }
149 
150   // Value node need get all the elements.
151   if (node->isa<ValueNode>()) {
152     auto value = node->cast<ValueNodePtr>()->value();
153     MS_EXCEPTION_IF_NULL(value);
154     if (value->isa<ValueSequence>()) {
155       auto value_tuple = value->cast<ValueSequencePtr>();
156       auto value_tuple_size = CountValueNum(value_tuple);
157       for (size_t i = 0; i < value_tuple_size; ++i) {
158         (void)ret.emplace_back(node, i);
159       }
160     } else {
161       (void)ret.emplace_back(node, 0);
162     }
163     MS_LOG(DEBUG) << "Output value node: " << node->fullname_with_scope() << ", value num: " << ret.size();
164     return ret;
165   }
166 
167   // Output num must be exactly equal to the number of outputs of the node.
168   size_t outputs_num = 1;
169   if (AnfUtils::IsRealCNodeKernel(node)) {
170     if (node->abstract() != nullptr &&
171         (common::AnfAlgo::IsDynamicSequence(node) || IsMultiLayerTuple(node->abstract()))) {
172       outputs_num = common::AnfAlgo::GetOutputNumByAbstract(node->abstract());
173     } else {
174       outputs_num = AnfUtils::GetOutputTensorNum(node);
175     }
176     MS_LOG(DEBUG) << "Output num:" << outputs_num << " for node:" << node->DebugString();
177   }
178   // Call node maybe a real cnode and the unreal node cannot get output num exactly, so we should get
179   // output num from abstract again. For example the TupleGetItem/Makeple multi-level nesting:
180   // '''G = op()  ---> Assume that the output of G is a multi-member tuple
181   //    A = MakeTuple(E, F, G)
182   //    B = MakeTuple(H, A)
183   //    C = TupleGetItem(B, 1) ---> Euqal the A
184   //    D = TupleGetItem(C, 2)  ---> VisitKernel will return the {G, 0}, but expect the whole G with all the members
185   //    return D'''
186   if (common::AnfAlgo::IsCallNode(node) || (!AnfUtils::IsRealCNodeKernel(node))) {
187     MS_EXCEPTION_IF_NULL(node->abstract());
188     outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
189   }
190 
191   // The output may be the tuple of node, so need visit all the outputs of node.
192   // Since output num represents the number of all outputs of node, only one output is obtained per loop.
193   for (size_t i = 0; i < outputs_num; ++i) {
194     // Maybe this scene: tupleGetItem + depend + makeTuple, can be done correctly in VisitKernelWithReturnType.
195     // The output may be updataState/load node for connecting dependencies between subgraphs.
196     auto output_with_index = AnfAlgo::VisitKernelWithReturnType(
197       node, i, false, {prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimLoad}, nullptr, true);
198     MS_EXCEPTION_IF_NULL(output_with_index.first);
199 
200     // The MakeTuple/MakeSparse node need recurse.
201     if (IsOneOfPrimitiveCNode(output_with_index.first, expand_prims)) {
202       auto output_vector = GetAllOutputWithIndexInner(output_with_index.first, return_types);
203       if (output_vector.size() <= output_with_index.second) {
204         MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << output_with_index.second
205                                    << " for outputs of node:" << output_with_index.first->DebugString();
206       }
207       (void)ret.emplace_back(output_vector[output_with_index.second]);
208       continue;
209     }
210 
211     // The InitDataSetQueue node has no output.
212     if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
213       return ret_empty;
214     }
215 
216     MS_LOG(DEBUG) << "Output node: " << output_with_index.first->fullname_with_scope()
217                   << " with output index: " << output_with_index.second;
218     ret.push_back(output_with_index);
219   }
220   return ret;
221 }
222 
IsNodeDynamicShape(const AnfNodePtr & node)223 bool IsNodeDynamicShape(const AnfNodePtr &node) {
224   MS_EXCEPTION_IF_NULL(node);
225   if (!node->isa<CNode>()) {
226     MS_LOG(DEBUG) << "Node is not a cnode";
227     return false;
228   }
229   auto cnode = node->cast<CNodePtr>();
230   auto in_dynamic = AnfAlgo::IsNodeInputDynamicShape(cnode);
231   auto out_dynamic = AnfAlgo::IsNodeOutputDynamicShape(cnode);
232   if (in_dynamic && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicShape, cnode)) {
233     AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicShape, MakeValue(true), cnode);
234     MS_LOG(DEBUG) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope()
235                   << " debug string:" << cnode->DebugString();
236   }
237   if (out_dynamic && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicShape, cnode)) {
238     AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
239     MS_LOG(DEBUG) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope()
240                   << " debug string:" << cnode->DebugString();
241   }
242   if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && node->abstract()->isa<abstract::AbstractSequence>()) {
243     AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
244     MS_LOG(DEBUG) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
245     return true;
246   }
247   return in_dynamic || out_dynamic;
248 }
249 }  // namespace
250 
GetTupleGetItemRealInput(const CNodePtr & tuple_get_item)251 AnfNodePtr AnfAlgo::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
252   MS_EXCEPTION_IF_NULL(tuple_get_item);
253   if (tuple_get_item->size() != kTupleGetItemInputSize) {
254     MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
255   }
256   return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
257 }
258 
GetTupleGetItemOutIndex(const CNodePtr & tuple_get_item)259 size_t AnfAlgo::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
260   MS_EXCEPTION_IF_NULL(tuple_get_item);
261   if (tuple_get_item->size() != kTupleGetItemInputSize) {
262     MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
263   }
264   auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
265   MS_EXCEPTION_IF_NULL(output_index_value_node);
266   auto value_node = output_index_value_node->cast<ValueNodePtr>();
267   MS_EXCEPTION_IF_NULL(value_node);
268   auto value = value_node->value();
269   MS_EXCEPTION_IF_NULL(value);
270   auto idx = value->isa<Int64Imm>() ? GetValue<int64_t>(value) : GetValue<int>(value);
271   return LongToSize(idx);
272 }
273 
VisitKernel(const AnfNodePtr & anf_node,size_t index)274 KernelWithIndex AnfAlgo::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
275   // this function was moved to AnfUtils.
276   return AnfUtils::VisitKernel(anf_node, index);
277 }
278 
279 namespace {
VisitKernelWithReturnTypeForTupleGetItem(const AnfNodePtr & anf_node,size_t index,bool skip_nop_node,const std::vector<PrimitivePtr> & return_types,abstract::AbstractBasePtr * abstract,bool is_index_valid)280 KernelWithIndex VisitKernelWithReturnTypeForTupleGetItem(const AnfNodePtr &anf_node, size_t index, bool skip_nop_node,
281                                                          const std::vector<PrimitivePtr> &return_types,
282                                                          abstract::AbstractBasePtr *abstract, bool is_index_valid) {
283   MS_EXCEPTION_IF_NULL(anf_node);
284   if (!common::AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
285     MS_LOG(EXCEPTION) << "Invalid tuple get item node:" << anf_node->DebugString();
286   }
287   auto cnode = anf_node->cast<CNodePtr>();
288   MS_EXCEPTION_IF_NULL(cnode);
289   if (cnode->HasAttr(kAttrReplaceRealKernelInBackend)) {
290     MS_LOG(INFO) << "cnode:" << cnode->DebugString() << " has replace flag";
291     return KernelWithIndex(anf_node, index);
292   }
293   abstract::AbstractBasePtr abs = nullptr;
294   auto item_with_index_tmp = common::AnfAlgo::VisitKernelWithReturnType(
295     common::AnfAlgo::GetTupleGetItemRealInput(cnode), common::AnfAlgo::GetTupleGetItemOutIndex(cnode), skip_nop_node,
296     return_types, &abs, true);
297   if (IsOneOfPrimitiveCNode(item_with_index_tmp.first, expand_prims)) {
298     MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
299     auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
300     MS_EXCEPTION_IF_NULL(make_tuple);
301     const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
302     size_t make_tuple_input_index = item_with_index_tmp.second + 1;
303     if (make_tuple_input_index >= make_tuple_inputs.size()) {
304       MS_LOG(INTERNAL_EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
305                                  << "].\nPlease check node: " << cnode->DebugString()
306                                  << ".\nLine: " << trace::GetDebugInfoStr(cnode->debug_info())
307                                  << ".\nAnd check node: " << make_tuple->DebugString()
308                                  << ".\nLine: " << trace::GetDebugInfoStr(make_tuple->debug_info()) << ".";
309     }
310     return common::AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], index, skip_nop_node,
311                                                       return_types);
312   }
313   if (common::AnfAlgo::IsCallNode(item_with_index_tmp.first) || item_with_index_tmp.first->isa<Parameter>()) {
314     size_t real_index = item_with_index_tmp.second;
315     if (abs == nullptr) {
316       abs = item_with_index_tmp.first->abstract();
317       real_index = 0;
318     }
319     MS_EXCEPTION_IF_NULL(abs);
320     if (abs->isa<abstract::AbstractSequence>()) {
321       auto tuple_abstract = abs->cast<abstract::AbstractSequencePtr>();
322       MS_EXCEPTION_IF_NULL(tuple_abstract);
323       if (tuple_abstract->dynamic_len()) {
324         return item_with_index_tmp;
325       }
326       auto sub_abstracts = tuple_abstract->elements();
327       if (sub_abstracts.size() <= common::AnfAlgo::GetTupleGetItemOutIndex(cnode)) {
328         MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << common::AnfAlgo::GetTupleGetItemOutIndex(cnode)
329                                    << " for abstract:" << abs->ToString();
330       }
331       for (size_t i = 0; i < common::AnfAlgo::GetTupleGetItemOutIndex(cnode); ++i) {
332         MS_EXCEPTION_IF_NULL(sub_abstracts[i]);
333         real_index += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
334       }
335       if (abstract != nullptr) {
336         (*abstract) = sub_abstracts[common::AnfAlgo::GetTupleGetItemOutIndex(cnode)];
337         MS_EXCEPTION_IF_NULL((*abstract));
338       } else {
339         // In recursion of getitem node, the index of the first input of its real node is returned.
340         // When the recursion ends, the outermost index needs to be accumulated.
341         real_index += index;
342       }
343       return {item_with_index_tmp.first, real_index};
344     }
345   }
346   if (is_index_valid) {
347     if (anf_node->abstract() != nullptr && anf_node->abstract()->isa<abstract::AbstractSequence>()) {
348       const auto &seq_abs = anf_node->abstract()->cast<abstract::AbstractSequencePtr>();
349       MS_EXCEPTION_IF_NULL(seq_abs);
350       if (!seq_abs->dynamic_len()) {
351         return {anf_node, index};
352       }
353     }
354   }
355   return item_with_index_tmp;
356 }
357 }  // namespace
358 
VisitKernelWithReturnType(const AnfNodePtr & anf_node,size_t index,bool skip_nop_node,const std::vector<PrimitivePtr> & return_types,abstract::AbstractBasePtr * abstract,bool is_index_valid)359 KernelWithIndex AnfAlgo::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, bool skip_nop_node,
360                                                    const std::vector<PrimitivePtr> &return_types,
361                                                    abstract::AbstractBasePtr *abstract, bool is_index_valid) {
362   MS_EXCEPTION_IF_NULL(anf_node);
363   if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
364         return CheckPrimitiveType(anf_node, prim_type);
365       })) {
366     return KernelWithIndex(anf_node, index);
367   }
368   if (!anf_node->isa<CNode>()) {
369     return KernelWithIndex(anf_node, index);
370   }
371   auto cnode = anf_node->cast<CNodePtr>();
372   MS_EXCEPTION_IF_NULL(cnode);
373   // TupleGetItem and SparseGetAttr needs to find real input
374   if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
375     return VisitKernelWithReturnTypeForTupleGetItem(anf_node, index, skip_nop_node, return_types, abstract,
376                                                     is_index_valid);
377   }
378   if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
379     return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, skip_nop_node, return_types);
380   }
381   const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad, prim::kPrimDynamicLossScale};
382   if (IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
383     return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, skip_nop_node, return_types);
384   }
385   if (IsNopNode(cnode) && skip_nop_node) {
386     return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, skip_nop_node, return_types);
387   }
388   return KernelWithIndex(anf_node, index);
389 }
390 
// If `node_with_index.first` is a Depend or Load CNode, returns the result of visiting it with
// VisitKernelWithReturnType (nop nodes not skipped); otherwise returns `node_with_index` unchanged.
// A null node throws.
KernelWithIndex AnfAlgo::FetchRealNodeSkipMonadControl(const KernelWithIndex &node_with_index) {
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  // Read-only lookup set: build it once instead of on every call.
  static const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
    prim::kPrimDepend, prim::kPrimLoad};
  if (!IsOneOfPrimitiveCNode(node_with_index.first, auto_monad_prims)) {
    return node_with_index;
  }
  return common::AnfAlgo::VisitKernelWithReturnType(node_with_index.first, node_with_index.second, false);
}
401 
GetAllOutput(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types)402 std::vector<AnfNodePtr> AnfAlgo::GetAllOutput(const AnfNodePtr &node, const std::vector<PrimitivePtr> &return_types) {
403   std::vector<AnfNodePtr> ret;
404   const auto &output_pair = GetAllOutputIndexByReturnTypes(node, return_types);
405   std::transform(output_pair.begin(), output_pair.end(), std::back_inserter(ret),
406                  [](const KernelWithIndex &ele) { return ele.first; });
407   return ret;
408 }
409 
GetAllOutputIndexByReturnTypes(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types,bool need_make_tuple)410 std::vector<KernelWithIndex> AnfAlgo::GetAllOutputIndexByReturnTypes(const AnfNodePtr &node,
411                                                                      const std::vector<PrimitivePtr> &return_types,
412                                                                      bool need_make_tuple) {
413   std::vector<KernelWithIndex> ret;
414   auto return_prim_type = return_types;
415   // if visited make_tuple should return back
416   return_prim_type.push_back(prim::kPrimMakeTuple);
417   auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
418   if (need_make_tuple) {
419     ret.push_back(item_with_index);
420   }
421   if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
422     MS_EXCEPTION_IF_NULL(item_with_index.first);
423     auto make_tuple = item_with_index.first->cast<CNodePtr>();
424     MS_EXCEPTION_IF_NULL(make_tuple);
425     for (size_t i = 1; i < make_tuple->size(); i++) {
426       auto input_i_vector = GetAllOutputIndexByReturnTypes(make_tuple->input(i), return_types);
427       (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
428     }
429     return ret;
430   }
431   ret.push_back(item_with_index);
432   return ret;
433 }
434 
GetOutputNumByAbstract(const AbstractBasePtr & node_abstract)435 size_t AnfAlgo::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) {
436   MS_EXCEPTION_IF_NULL(node_abstract);
437   size_t result = 0;
438 
439   if (!node_abstract->isa<abstract::AbstractSequence>() ||
440       node_abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len() ||
441       node_abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len_element_abs() != nullptr) {
442     return 1;
443   }
444 
445   auto tuple_abstract = node_abstract->cast<abstract::AbstractSequencePtr>();
446   MS_EXCEPTION_IF_NULL(tuple_abstract);
447   const auto &sub_abstracts = tuple_abstract->elements();
448   for (const auto &sub_abstract : sub_abstracts) {
449     MS_EXCEPTION_IF_NULL(sub_abstract);
450     result += GetOutputNumByAbstract(sub_abstract);
451   }
452   return result;
453 }
454 
GetAllOutputWithOutMonadAndParameter(const AnfNodePtr & node)455 std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithOutMonadAndParameter(const AnfNodePtr &node) {
456   MS_EXCEPTION_IF_NULL(node);
457   const auto &graph_outputs = common::AnfAlgo::GetAllOutputWithIndex(node);
458   std::vector<KernelWithIndex> real_output;
459   for (const auto &node_with_index : graph_outputs) {
460     MS_EXCEPTION_IF_NULL(node_with_index.first);
461     if (HasAbstractMonad(node_with_index.first) || node_with_index.first->isa<Parameter>() ||
462         node_with_index.first->isa<ValueNode>()) {
463       continue;
464     }
465     real_output.emplace_back(node_with_index);
466   }
467   return real_output;
468 }
469 
GetAllOutputWithIndex(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types)470 std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithIndex(const AnfNodePtr &node,
471                                                             const std::vector<PrimitivePtr> &return_types) {
472   auto ret = GetAllOutputWithIndexInner(node, return_types);
473   std::map<AnfNodePtr, size_t> value_node_index;
474 
475   // Unify the output of the front and back end to the ValueTuple
476   for (auto &output_with_index : ret) {
477     auto value_node = output_with_index.first;
478     MS_EXCEPTION_IF_NULL(value_node);
479     if (!value_node->isa<ValueNode>()) {
480       continue;
481     }
482     if (value_node_index.find(value_node) == value_node_index.end() ||
483         value_node_index[value_node] < output_with_index.second) {
484       value_node_index[value_node] = output_with_index.second;
485     } else {
486       value_node_index[value_node]++;
487       MS_LOG(DEBUG) << "Set output value node new index, value node: " << value_node->fullname_with_scope()
488                     << ", original index: " << output_with_index.second
489                     << ", new index:" << value_node_index[value_node];
490       output_with_index.second = value_node_index[value_node];
491     }
492   }
493   return ret;
494 }
495 
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)496 bool AnfAlgo::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
497   MS_EXCEPTION_IF_NULL(node);
498   if (!node->isa<CNode>()) {
499     return false;
500   }
501   auto cnode = node->cast<CNodePtr>();
502   MS_EXCEPTION_IF_NULL(cnode);
503   return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
504 }
505 
GetCNodeFuncGraphPtr(const AnfNodePtr & node)506 FuncGraphPtr AnfAlgo::GetCNodeFuncGraphPtr(const AnfNodePtr &node) {
507   MS_EXCEPTION_IF_NULL(node);
508   auto cnode = node->cast<CNodePtr>();
509   MS_EXCEPTION_IF_NULL(cnode);
510   auto attr_input = cnode->input(kAnfPrimitiveIndex);
511   MS_EXCEPTION_IF_NULL(attr_input);
512   auto value_node = attr_input->cast<ValueNodePtr>();
513   MS_EXCEPTION_IF_NULL(value_node);
514   auto value = value_node->value();
515   MS_EXCEPTION_IF_NULL(value);
516   return value->cast<FuncGraphPtr>();
517 }
518 
GetCNodeName(const AnfNodePtr & node)519 std::string AnfAlgo::GetCNodeName(const AnfNodePtr &node) {
520   // this function was moved to AnfUtils.
521   return AnfUtils::GetCNodeName(node);
522 }
523 
IsGetNextNode(const AnfNodePtr & node)524 bool AnfAlgo::IsGetNextNode(const AnfNodePtr &node) {
525   auto node_name = AnfUtils::GetCNodeName(node);
526   return node_name == kGetNextOpName || node_name == kDynamicGetNextV2OpName;
527 }
528 
GetNodeDebugString(const AnfNodePtr & node)529 std::string AnfAlgo::GetNodeDebugString(const AnfNodePtr &node) {
530   MS_EXCEPTION_IF_NULL(node);
531   return node->DebugString();
532 }
533 
SetNodeAttr(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)534 void AnfAlgo::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
535   // this function was moved to AnfUtils.
536   return AnfUtils::SetNodeAttr(key, value, node);
537 }
538 
SetNodeAttrSafely(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)539 void AnfAlgo::SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
540   // Make CNode safe to set attr firstly.
541   auto cnode = node->cast<CNodePtr>();
542   if (cnode == nullptr) {
543     return;
544   }
545   auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
546   if (prim != nullptr) {
547     auto new_prim = prim->isa<PrimitivePy>() ? prim : prim->Clone();
548     cnode->set_input(0, NewValueNode(new_prim));
549   }
550 
551   // Set attr secondly.
552   common::AnfAlgo::SetNodeAttr(key, value, node);
553 }
554 
CopyNodeAttr(const std::string & key,const AnfNodePtr & from,const AnfNodePtr & to)555 void AnfAlgo::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
556   CopyNodeAttr(key, key, from, to);
557 }
558 
CopyNodeAttr(const std::string & old_key,const std::string & new_key,const AnfNodePtr & from,const AnfNodePtr & to)559 void AnfAlgo::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
560                            const AnfNodePtr &to) {
561   MS_EXCEPTION_IF_NULL(from);
562   MS_EXCEPTION_IF_NULL(to);
563   if (!from->isa<CNode>() || !to->isa<CNode>()) {
564     MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
565                                << to->DebugString() << trace::DumpSourceLines(from);
566   }
567   auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
568   MS_EXCEPTION_IF_NULL(from_primitive);
569   auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
570   MS_EXCEPTION_IF_NULL(to_primitive);
571   to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
572 }
573 
CopyNodeAttrs(const AnfNodePtr & from,const AnfNodePtr & to)574 void AnfAlgo::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
575   MS_EXCEPTION_IF_NULL(from);
576   MS_EXCEPTION_IF_NULL(to);
577   if (!from->isa<CNode>() || !to->isa<CNode>()) {
578     MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
579                                << from->DebugString() << trace::DumpSourceLines(from);
580   }
581   auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
582   MS_EXCEPTION_IF_NULL(from_primitive);
583   auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
584   MS_EXCEPTION_IF_NULL(to_primitive);
585   auto from_cnode = from->cast<CNodePtr>();
586   auto to_cnode = to->cast<CNodePtr>();
587   if (from_cnode->HasPrimalAttr(kAttrMicro)) {
588     to_cnode->AddPrimalAttr(kAttrMicro, from_cnode->GetPrimalAttr(kAttrMicro));
589   }
590   (void)to_primitive->SetAttrs(from_primitive->attrs());
591 }
592 
EraseNodeAttr(const std::string & key,const AnfNodePtr & node)593 void AnfAlgo::EraseNodeAttr(const std::string &key, const AnfNodePtr &node) {
594   MS_EXCEPTION_IF_NULL(node);
595   if (!node->isa<CNode>()) {
596     MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
597                                << trace::DumpSourceLines(node);
598   }
599   // single op cnode.
600   auto primitive = AnfAlgo::GetCNodePrimitive(node);
601   if (primitive != nullptr) {
602     primitive->EraseAttr(key);
603     return;
604   }
605   // graph kernel cnode.
606   auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
607   MS_EXCEPTION_IF_NULL(fg);
608   fg->erase_flag(key);
609 }
610 
HasNodeAttr(const std::string & key,const CNodePtr & node)611 bool AnfAlgo::HasNodeAttr(const std::string &key, const CNodePtr &node) {
612   MS_EXCEPTION_IF_NULL(node);
613   // call node's input0 is not a primitive.
614   if (!IsValueNode<FuncGraph>(node->input(0)) && !IsValueNode<Primitive>(node->input(0))) {
615     return false;
616   }
617   // single op cnode.
618   auto primitive = AnfAlgo::GetCNodePrimitive(node);
619   if (primitive != nullptr) {
620     return primitive->HasAttr(key);
621   }
622   // graph kernel cnode.
623   auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
624   MS_EXCEPTION_IF_NULL(fg);
625   return fg->has_attr(key);
626 }
627 
GetInputNum(const CNodePtr & cnode)628 size_t AnfAlgo::GetInputNum(const CNodePtr &cnode) {
629   MS_EXCEPTION_IF_NULL(cnode);
630   size_t input_num = cnode->size();
631   if (input_num == 0) {
632     MS_LOG(INTERNAL_EXCEPTION) << "Cnode inputs size can't be zero." << trace::DumpSourceLines(cnode);
633   }
634   return input_num - 1;
635 }
636 
GetInputTensorNum(const AnfNodePtr & node)637 size_t AnfAlgo::GetInputTensorNum(const AnfNodePtr &node) {
638   // this function was moved to AnfUtils.
639   return AnfUtils::GetInputTensorNum(node);
640 }
641 
IsPrevNodeHasTupleGetItem(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)642 bool AnfAlgo::IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
643   if (!anf_node->isa<CNode>()) {
644     MS_LOG(INTERNAL_EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
645                                << trace::DumpSourceLines(anf_node);
646   }
647   auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
648   MS_EXCEPTION_IF_NULL(input_node);
649   auto res = VisitKernelWithReturnType(input_node, 0, skip_nop_node, {prim::kPrimTupleGetItem});
650   if (CheckPrimitiveType(res.first, prim::kPrimTupleGetItem)) {
651     return true;
652   }
653   return false;
654 }
655 
GetPrevNodeOutput(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)656 KernelWithIndex AnfAlgo::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
657   MS_EXCEPTION_IF_NULL(anf_node);
658   if (!anf_node->isa<CNode>()) {
659     MS_LOG(INTERNAL_EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
660                                << trace::DumpSourceLines(anf_node);
661   }
662   auto kernel_info = anf_node->kernel_info();
663   if (kernel_info) {
664     auto runtime_cache = kernel_info->runtime_cache();
665     if (runtime_cache.runtime_cache().is_valid()) {
666       auto output = runtime_cache.runtime_cache().get_prev_node_output(input_idx);
667       if (output.first != nullptr) {
668         return output;
669       }
670     }
671   }
672   KernelWithIndex res;
673   if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
674     res = VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
675   } else {
676     auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
677     MS_EXCEPTION_IF_NULL(input_node);
678     res = VisitKernelWithReturnType(input_node, 0, skip_nop_node);
679   }
680   if (kernel_info) {
681     auto runtime_cache = kernel_info->runtime_cache();
682     if (runtime_cache.runtime_cache().is_valid()) {
683       runtime_cache.runtime_cache().set_prev_node_output(input_idx, res);
684     }
685   }
686   return res;
687 }
688 
689 // if the prev_node is MakeTuple, get all the input_nodes recursively, else use the ori GetPrevNodeOutput function
GetRealPrevNodesOutput(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)690 std::vector<KernelWithIndex> AnfAlgo::GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
691                                                              bool skip_nop_node) {
692   MS_EXCEPTION_IF_NULL(anf_node);
693   auto cnode = anf_node->cast<CNodePtr>();
694   MS_EXCEPTION_IF_NULL(cnode);
695 
696   std::vector<KernelWithIndex> res;
697   auto input_node = AnfAlgo::GetInputNode(cnode, input_idx);
698   MS_EXCEPTION_IF_NULL(input_node);
699   if (CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
700     auto maketuple_input_num = GetInputTensorNum(input_node);
701     for (size_t i = 0; i < maketuple_input_num; ++i) {
702       auto inputs_i = GetRealPrevNodesOutput(input_node, i, skip_nop_node);
703       (void)res.insert(res.end(), inputs_i.begin(), inputs_i.end());
704     }
705   } else {
706     (void)res.emplace_back(GetPrevNodeOutput(cnode, input_idx, skip_nop_node));
707   }
708   return res;
709 }
710 
GetRealPrevNodesOutputInferDataType(const AnfNodePtr & node,size_t input_idx)711 std::vector<TypeId> AnfAlgo::GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
712   std::vector<KernelWithIndex> kernels_with_index = AnfAlgo::GetRealPrevNodesOutput(node, input_idx);
713   std::vector<TypeId> res;
714   (void)std::transform(kernels_with_index.begin(), kernels_with_index.end(), std::back_inserter(res),
715                        [](auto kernel_with_index) {
716                          return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
717                        });
718   return res;
719 }
720 
721 namespace {
GetShape(const abstract::BaseShapePtr & base_shape)722 inline ShapeVector GetShape(const abstract::BaseShapePtr &base_shape) {
723   MS_EXCEPTION_IF_NULL(base_shape);
724   if (base_shape->isa<abstract::Shape>()) {
725     auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
726     MS_EXCEPTION_IF_NULL(shape_ptr);
727     return shape_ptr->shape();
728   }
729   return {};
730 }
731 
GetOutputShape(const abstract::AbstractBasePtr & abstract,size_t output_idx,bool is_real_squence_output)732 ShapeVector GetOutputShape(const abstract::AbstractBasePtr &abstract, size_t output_idx, bool is_real_squence_output) {
733   MS_EXCEPTION_IF_NULL(abstract);
734   if (abstract->isa<abstract::AbstractTensor>() || abstract->isa<abstract::AbstractMapTensor>()) {
735     if (output_idx != 0) {
736       MS_LOG(INTERNAL_EXCEPTION) << "The abstract " << abstract->ToString()
737                                  << "is single output but got index:" << output_idx;
738     }
739     const auto &shape = abstract->GetShape();
740     return GetShape(shape);
741   } else if (abstract->isa<abstract::AbstractScalar>() || abstract->isa<abstract::AbstractMonad>()) {
742     return ShapeVector();
743   } else if (abstract->isa<abstract::AbstractSparseTensor>()) {
744     const auto &shape = abstract->GetShape();
745     MS_EXCEPTION_IF_NULL(shape);
746     const auto &tuple_shape = shape->cast<abstract::TupleShapePtr>();
747     MS_EXCEPTION_IF_NULL(tuple_shape);
748     if (output_idx >= tuple_shape->size()) {
749       MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << "is larger than output number "
750                                  << tuple_shape->size() << " of tuple shape:" << tuple_shape->ToString()
751                                  << " in abstract:" << abstract;
752     }
753     return GetShape(tuple_shape->shape()[output_idx]);
754   }
755 
756   if (!abstract->isa<abstract::AbstractSequence>()) {
757     MS_LOG(INFO) << "Unknown abstract for get shape:" << abstract->ToString();
758     return {};
759   }
760 
761   const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
762   MS_EXCEPTION_IF_NULL(sequence_abstract);
763   if (sequence_abstract->dynamic_len()) {
764     const auto &element_abstract = sequence_abstract->dynamic_len_element_abs();
765     if (element_abstract == nullptr) {
766       MS_LOG(ERROR) << "Invalid abstract for get shape:" << sequence_abstract->ToString();
767       return ShapeVector();
768     }
769     return GetOutputShape(element_abstract, 0, true);
770   }
771 
772   if (sequence_abstract->size() == 0) {
773     return ShapeVector();
774   }
775 
776   if (!is_real_squence_output) {
777     if (output_idx >= sequence_abstract->size()) {
778       MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << "is larger than output number "
779                                  << sequence_abstract->size() << " of abstract:" << sequence_abstract->ToString();
780     }
781     MS_EXCEPTION_IF_NULL(sequence_abstract->elements()[output_idx]);
782     return GetOutputShape(sequence_abstract->elements()[output_idx], 0, true);
783   }
784 
785   // For real sequence output, if the inner elements' shape is same, the output is {element_num, *actual_shape},
786   // otherwise is {element_num, inner_max_size}.
787   // For example:
788   //   1) Output abstract: ((3,4,5), (3,4,5)), output shape: (2, 3, 4, 5).
789   //   2) Output abstract: ((3,4,5), (3,4,6)), output shape: (2, 72).
790   ShapeVector elem_shape_vector;
791   size_t change_cnt = 0;
792   ShapeValueDType elem_size = 0;
793   for (const auto &elem_abs : sequence_abstract->elements()) {
794     MS_EXCEPTION_IF_NULL(elem_abs);
795     elem_shape_vector = GetOutputShape(elem_abs, 0, true);
796     auto cur_size = std::accumulate(elem_shape_vector.begin(), elem_shape_vector.end(), 1L, std::multiplies<int64_t>());
797     if (elem_size < cur_size) {
798       elem_size = cur_size;
799       ++change_cnt;
800     }
801   }
802 
803   ShapeVector shape_vector = {SizeToLong(sequence_abstract->size())};
804   if (change_cnt == 1) {
805     (void)shape_vector.insert(shape_vector.end(), elem_shape_vector.begin(), elem_shape_vector.end());
806   } else {
807     shape_vector.push_back(elem_size);
808   }
809   return shape_vector;
810 }
811 }  // namespace
812 
GetOutputInferShape(const AnfNodePtr & node,size_t output_idx,bool is_real_squence_output)813 ShapeVector AnfAlgo::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx, bool is_real_squence_output) {
814   MS_EXCEPTION_IF_NULL(node);
815   return GetOutputShape(node->abstract(), output_idx, is_real_squence_output || AnfAlgo::IsDynamicSequence(node));
816 }
817 
GetPrevNodeOutputInferShape(const AnfNodePtr & node,size_t input_idx)818 ShapeVector AnfAlgo::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
819   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
820   return AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
821 }
822 
GetOutputInferType(const AnfNodePtr & node,size_t output_idx,bool is_real_tuple)823 TypePtr AnfAlgo::GetOutputInferType(const AnfNodePtr &node, size_t output_idx, bool is_real_tuple) {
824   MS_EXCEPTION_IF_NULL(node);
825   MS_EXCEPTION_IF_NULL(node->abstract());
826   const auto &type = node->abstract()->BuildType();
827   MS_EXCEPTION_IF_NULL(type);
828   if (!type->isa<Tuple>() && !type->isa<List>()) {
829     if (output_idx != 0) {
830       MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for node:" << node->DebugString()
831                         << " abstract:" << node->abstract()->ToString() << " type:" << type->ToString();
832     }
833     return type;
834   }
835   if (is_real_tuple) {
836     return type;
837   }
838   if (type->isa<Tuple>()) {
839     const auto &tuple_type = type->cast<TuplePtr>();
840     MS_EXCEPTION_IF_NULL(tuple_type);
841     if (tuple_type->dynamic_len()) {
842       if (output_idx != 0) {
843         MS_LOG(EXCEPTION) << "Failed to get type by index:" << output_idx << " type:" << type->ToString();
844       }
845       return tuple_type;
846     }
847     if (output_idx >= tuple_type->size()) {
848       MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for node:" << node->DebugString()
849                         << " abstract:" << node->abstract()->ToString() << " type:" << type->ToString();
850     }
851     return tuple_type->elements()[output_idx];
852   }
853   const auto &list_type = type->cast<ListPtr>();
854   MS_EXCEPTION_IF_NULL(list_type);
855   if (list_type->dynamic_len()) {
856     if (output_idx != 0) {
857       MS_LOG(EXCEPTION) << "Failed to get type by index:" << output_idx << " type:" << type->ToString();
858     }
859     return list_type;
860   }
861   if (output_idx >= list_type->size()) {
862     MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for node:" << node->DebugString()
863                       << " abstract:" << node->abstract()->ToString() << " type:" << type->ToString();
864   }
865   return list_type->elements()[output_idx];
866 }
867 
GetOutputInferDataType(const TypePtr & type,size_t output_idx)868 TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
869   auto type_ptr = type;
870   MS_EXCEPTION_IF_NULL(type_ptr);
871   if (type_ptr->isa<Tuple>()) {
872     auto tuple_ptr = type_ptr->cast<TuplePtr>();
873     MS_EXCEPTION_IF_NULL(tuple_ptr);
874     if (tuple_ptr->size() == 0) {
875       if (tuple_ptr->dynamic_len() && tuple_ptr->dynamic_element_type() != nullptr) {
876         MS_LOG(INFO) << "Dynamic empty tuple type has an dynamic element type:"
877                      << tuple_ptr->dynamic_element_type()->type_id();
878         return tuple_ptr->dynamic_element_type()->type_id();
879       }
880       return kTypeUnknown;
881     }
882     if (tuple_ptr->dynamic_len()) {
883       MS_EXCEPTION_IF_NULL(tuple_ptr->dynamic_element_type());
884       return GetOutputInferDataType(tuple_ptr->dynamic_element_type(), 0);
885     }
886     MS_EXCEPTION_IF_NULL(tuple_ptr);
887     if (output_idx >= tuple_ptr->size()) {
888       MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << " must be less than output number "
889                                  << tuple_ptr->size();
890     }
891     type_ptr = (*tuple_ptr)[output_idx];
892     MS_EXCEPTION_IF_NULL(type_ptr);
893   }
894 
895   if (type_ptr->isa<List>()) {
896     auto list_ptr = type_ptr->cast<ListPtr>();
897     MS_EXCEPTION_IF_NULL(list_ptr);
898     if (list_ptr->size() == 0) {
899       if (list_ptr->dynamic_len() && list_ptr->dynamic_element_type() != nullptr) {
900         MS_LOG(INFO) << "Dynamic empty list type has an dynamic element type:"
901                      << list_ptr->dynamic_element_type()->type_id();
902         return list_ptr->dynamic_element_type()->type_id();
903       }
904       return kTypeUnknown;
905     }
906     if (list_ptr->dynamic_len()) {
907       MS_EXCEPTION_IF_NULL(list_ptr->dynamic_element_type());
908       return GetOutputInferDataType(list_ptr->dynamic_element_type(), 0);
909     }
910     MS_EXCEPTION_IF_NULL(list_ptr);
911     if (output_idx >= list_ptr->size()) {
912       MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << " must be less than output number "
913                                  << list_ptr->size();
914     }
915     type_ptr = (*list_ptr)[output_idx];
916     MS_EXCEPTION_IF_NULL(type_ptr);
917   }
918 
919   if (type_ptr->isa<SparseTensorType>()) {
920     auto tensor_ptr = type_ptr->cast<SparseTensorTypePtr>();
921     MS_EXCEPTION_IF_NULL(tensor_ptr);
922     type_ptr = (*tensor_ptr)[output_idx];
923     MS_EXCEPTION_IF_NULL(type_ptr);
924   }
925 
926   if (type_ptr->isa<TensorType>()) {
927     auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
928     MS_EXCEPTION_IF_NULL(tensor_ptr);
929     TypePtr elem = tensor_ptr->element();
930     MS_EXCEPTION_IF_NULL(elem);
931     return elem->type_id();
932   }
933   if (type_ptr->isa<Tuple>() || type_ptr->isa<List>()) {
934     return GetOutputInferDataType(type_ptr, 0);
935   }
936   return type_ptr->type_id();
937 }
938 
939 namespace {
IsTupleInTupleValueNode(const AnfNodePtr & node)940 bool IsTupleInTupleValueNode(const AnfNodePtr &node) {
941   if (node == nullptr || !node->isa<ValueNode>()) {
942     return false;
943   }
944   const auto &value_node = node->cast<ValueNodePtr>();
945   MS_EXCEPTION_IF_NULL(value_node);
946   const auto &value = value_node->value();
947   if (value == nullptr || !value->isa<ValueSequence>()) {
948     return false;
949   }
950   const auto &value_sequence = value->cast<ValueSequencePtr>();
951   MS_EXCEPTION_IF_NULL(value_sequence);
952   return std::any_of(value_sequence->value().begin(), value_sequence->value().end(),
953                      [](const ValuePtr &sub_value) { return sub_value != nullptr && sub_value->isa<ValueSequence>(); });
954 }
955 }  // namespace
956 
GetOutputInferDataType(const AnfNodePtr & node,size_t output_idx)957 TypeId AnfAlgo::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
958   MS_EXCEPTION_IF_NULL(node);
959   if (IsCallNode(node) || IsTupleInTupleValueNode(node)) {
960     if (node->abstract() == nullptr) {
961       MS_LOG(INTERNAL_EXCEPTION) << "Empty abstract of call node:" << node->DebugString();
962     }
963     const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), output_idx);
964     MS_EXCEPTION_IF_NULL(abs);
965     const auto &type = abs->BuildType();
966     MS_EXCEPTION_IF_NULL(type);
967     if (type->isa<TensorType>()) {
968       const auto &tensor_type = type->cast<TensorTypePtr>();
969       MS_EXCEPTION_IF_NULL(tensor_type);
970       const auto &element = tensor_type->element();
971       return element->type_id();
972     } else {
973       return type->type_id();
974     }
975   }
976   return GetOutputInferDataType(node->Type(), output_idx);
977 }
978 
GetPrevNodeOutputInferDataType(const AnfNodePtr & node,size_t input_idx)979 TypeId AnfAlgo::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
980   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
981   return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
982 }
983 
GetPrevNodeOutputInferType(const AnfNodePtr & node,size_t input_idx)984 TypePtr AnfAlgo::GetPrevNodeOutputInferType(const AnfNodePtr &node, size_t input_idx) {
985   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
986   return AnfAlgo::GetOutputInferType(kernel_with_index.first, kernel_with_index.second);
987 }
988 
989 // set infer shapes and types of anf node
SetOutputTypeAndDetailShape(const std::vector<TypeId> & types,const std::vector<abstract::BaseShapePtr> & shapes,AnfNode * node)990 void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
991                                           const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) {
992   MS_EXCEPTION_IF_NULL(node);
993   auto node_ptr = node->cast<AnfNodePtr>();
994   MS_EXCEPTION_IF_NULL(node_ptr);
995   std::string node_name = "";
996   if (node_ptr->isa<CNode>()) {
997     node_name = GetCNodeName(node_ptr);
998   }
999   if (types.size() != shapes.size()) {
1000     MS_LOG(INTERNAL_EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1001                                << " for node " << node->fullname_with_scope() << "." << trace::DumpSourceLines(node);
1002   }
1003 
1004   auto tuple_node = kNodeTupleOutSet.find(node_name);
1005   if (shapes.empty() && tuple_node == kNodeTupleOutSet.end()) {
1006     node->set_abstract(std::make_shared<abstract::AbstractNone>());
1007   } else if (shapes.size() == 1 && tuple_node == kNodeTupleOutSet.end()) {
1008     // single output handle
1009     if (shapes[0]->isa<abstract::NoShape>()) {
1010       auto abstract = std::make_shared<abstract::AbstractScalar>(TypeIdToType(types[0]));
1011       node->set_abstract(abstract);
1012     } else {
1013       auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1014       node->set_abstract(abstract);
1015     }
1016   } else {
1017     // multiple output handle
1018     std::vector<AbstractBasePtr> abstract_list;
1019     for (size_t i = 0; i < types.size(); ++i) {
1020       if (shapes[0]->isa<abstract::NoShape>()) {
1021         auto abstract = std::make_shared<abstract::AbstractScalar>(TypeIdToType(types[i]));
1022         abstract_list.emplace_back(abstract);
1023       } else {
1024         auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shapes[i]);
1025         abstract_list.emplace_back(abstract);
1026       }
1027     }
1028     auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1029     node->set_abstract(abstract_tuple);
1030   }
1031 }
1032 
SetSingleOutputTypeAndDetailShape(const std::vector<TypeId> & types,const std::vector<abstract::BaseShapePtr> & shapes,AnfNode * node)1033 void AnfAlgo::SetSingleOutputTypeAndDetailShape(const std::vector<TypeId> &types,
1034                                                 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) {
1035   MS_EXCEPTION_IF_NULL(node);
1036   auto node_ptr = node->cast<AnfNodePtr>();
1037   MS_EXCEPTION_IF_NULL(node_ptr);
1038   if (types.size() != shapes.size()) {
1039     MS_LOG(INTERNAL_EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1040                                << " for node " << node->fullname_with_scope() << "." << trace::DumpSourceLines(node);
1041   }
1042   auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1043   node->set_abstract(abstract);
1044 }
1045 
1046 namespace {
DeleteDynamicLen(AnfNode * node)1047 void DeleteDynamicLen(AnfNode *node) {
1048   MS_EXCEPTION_IF_NULL(node);
1049   if (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractSequence>()) {
1050     const auto &tuple_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
1051     MS_EXCEPTION_IF_NULL(tuple_abs);
1052     if (tuple_abs->dynamic_len()) {
1053       auto cloned_abstract = tuple_abs->Clone()->cast<abstract::AbstractSequencePtr>();
1054       cloned_abstract->set_dynamic_len(false);
1055       node->set_abstract(cloned_abstract);
1056     }
1057   }
1058 }
1059 }  // namespace
1060 
1061 // set infer shapes and types of anf node
SetOutputInferTypeAndShape(const std::vector<TypeId> & types,const std::vector<ShapeVector> & shapes,AnfNode * node,bool disable_dynamic_len)1062 void AnfAlgo::SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
1063                                          AnfNode *node, bool disable_dynamic_len) {
1064   MS_EXCEPTION_IF_NULL(node);
1065   if (disable_dynamic_len) {
1066     DeleteDynamicLen(node);
1067   }
1068   auto node_ptr = node->cast<AnfNodePtr>();
1069   MS_EXCEPTION_IF_NULL(node_ptr);
1070   std::string node_name = "";
1071   if (node_ptr->isa<CNode>()) {
1072     node_name = GetCNodeName(node_ptr);
1073   }
1074   if (types.size() != shapes.size()) {
1075     MS_LOG(INTERNAL_EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1076                                << "." << trace::DumpSourceLines(node);
1077   }
1078   auto abstract_ptr = node_ptr->abstract();
1079 
1080   auto tuple_node = kNodeTupleOutSet.find(node_name);
1081   if (shapes.empty() && tuple_node == kNodeTupleOutSet.end()) {
1082     node->set_abstract(std::make_shared<abstract::AbstractNone>());
1083   } else if (shapes.size() == 1 && tuple_node == kNodeTupleOutSet.end()) {
1084     // single output handle
1085     if (abstract_ptr != nullptr && abstract_ptr->isa<abstract::AbstractMapTensor>()) {
1086       // For AbstractMapTensor.
1087       abstract_ptr->set_shape(std::make_shared<abstract::Shape>(shapes[0]));
1088       return;
1089     }
1090 
1091     abstract::AbstractTensorPtr abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1092     node->set_abstract(abstract);
1093   } else {
1094     // multiple output handle
1095     std::vector<AbstractBasePtr> abstract_list;
1096     for (size_t i = 0; i < types.size(); ++i) {
1097       abstract::AbstractTensorPtr abstract =
1098         std::make_shared<AbstractTensor>(TypeIdToType(types[i]), std::make_shared<abstract::Shape>(shapes[i]));
1099       abstract_list.emplace_back(abstract);
1100     }
1101     auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1102     node->set_abstract(abstract_tuple);
1103   }
1104 }
1105 // copy an abstract of a node to another node
CopyAbstract(const AnfNodePtr & from_node,AnfNode * to_node)1106 void AnfAlgo::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
1107   MS_EXCEPTION_IF_NULL(from_node);
1108   MS_EXCEPTION_IF_NULL(to_node);
1109   to_node->set_abstract(from_node->abstract());
1110 }
1111 
IsNodeInGraphKernel(const AnfNodePtr & node)1112 bool AnfAlgo::IsNodeInGraphKernel(const AnfNodePtr &node) {
1113   // this function was moved to AnfUtils.
1114   return AnfUtils::IsNodeInGraphKernel(node);
1115 }
1116 
GetOutputOfGraphkernel(const KernelWithIndex & kernel_with_index)1117 AnfNodePtr AnfAlgo::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {
1118   auto func_graph = GetCNodeFuncGraph(kernel_with_index.first);
1119   if (func_graph == nullptr) {
1120     return kernel_with_index.first;
1121   }
1122   auto output = func_graph->output();
1123   if (CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
1124     return output->cast<CNodePtr>()->input(kernel_with_index.second + 1);
1125   }
1126   return output;
1127 }
1128 
IsParameterWeight(const ParameterPtr & node)1129 bool AnfAlgo::IsParameterWeight(const ParameterPtr &node) {
1130   MS_EXCEPTION_IF_NULL(node);
1131   return node->has_default();
1132 }
1133 
IsLabelIndexInNode(const AnfNodePtr & node,size_t label_index)1134 bool AnfAlgo::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) {
1135   MS_EXCEPTION_IF_NULL(node);
1136   if (!node->isa<CNode>()) {
1137     return false;
1138   }
1139   auto cnode = node->cast<CNodePtr>();
1140   MS_EXCEPTION_IF_NULL(cnode);
1141   if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName &&
1142       (AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) {
1143     return true;
1144   } else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
1145     auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
1146     if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) {
1147       return true;
1148     }
1149   }
1150   return false;
1151 }
1152 
IsUpdateParameterKernel(const CNodePtr & node)1153 bool AnfAlgo::IsUpdateParameterKernel(const CNodePtr &node) {
1154   MS_EXCEPTION_IF_NULL(node);
1155   auto node_name = GetCNodeName(node);
1156   if (HasNodeAttr(kAttrAsync, node) && GetNodeAttr<bool>(node, kAttrAsync)) {
1157     return false;
1158   }
1159   if (!IsOneOfOperator(node_name) && node_name.find("Assign") == string::npos) {
1160     return false;
1161   }
1162   return true;
1163 }
1164 
IsTupleOutput(const AnfNodePtr & anf)1165 bool AnfAlgo::IsTupleOutput(const AnfNodePtr &anf) {
1166   MS_EXCEPTION_IF_NULL(anf);
1167   TypePtr type = anf->Type();
1168   if (type == nullptr) {
1169     return false;
1170   }
1171 
1172   // For dynamic sequence node, all output should be emplaced in single tensor.
1173   if (anf->abstract() && IsDynamicSequence(anf)) {
1174     return false;
1175   }
1176 
1177   MS_EXCEPTION_IF_NULL(type);
1178   return type->isa<Tuple>() || type->isa<List>() || type->isa<SparseTensorType>();
1179 }
1180 
GetInputNode(const CNodePtr & node,size_t index)1181 AnfNodePtr AnfAlgo::GetInputNode(const CNodePtr &node, size_t index) {
1182   MS_EXCEPTION_IF_NULL(node);
1183   auto get_input_index = index + 1;
1184   if (get_input_index >= node->size()) {
1185     MS_LOG(INTERNAL_EXCEPTION) << "Input index size " << get_input_index << ", but the node input size just "
1186                                << node->size() << ". node: " << node->DebugString() << "."
1187                                << trace::DumpSourceLines(node);
1188   }
1189   // input 0 is primitive node
1190   return node->input(get_input_index);
1191 }
1192 
SetNodeInput(const CNodePtr & node,const AnfNodePtr & input_node,size_t index)1193 void AnfAlgo::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
1194   MS_EXCEPTION_IF_NULL(node);
1195   MS_EXCEPTION_IF_NULL(input_node);
1196   if (node->func_graph() != nullptr) {
1197     auto manager = node->func_graph()->manager();
1198     if (manager != nullptr) {
1199       manager->SetEdge(node, SizeToInt(index + 1), input_node);
1200       return;
1201     }
1202   }
1203   node->set_input(index + 1, input_node);
1204 }
1205 
GetCNodePrimitiveNode(const CNodePtr & node)1206 AnfNodePtr AnfAlgo::GetCNodePrimitiveNode(const CNodePtr &node) {
1207   MS_EXCEPTION_IF_NULL(node);
1208   return node->input(kAnfPrimitiveIndex);
1209 }
1210 
GetCNodePrimitive(const AnfNodePtr & node)1211 PrimitivePtr AnfAlgo::GetCNodePrimitive(const AnfNodePtr &node) {
1212   MS_EXCEPTION_IF_NULL(node);
1213   auto cnode = node->cast<CNodePtr>();
1214   MS_EXCEPTION_IF_NULL(cnode);
1215   auto attr_input = GetCNodePrimitiveNode(cnode);
1216   MS_EXCEPTION_IF_NULL(attr_input);
1217   auto value_node = attr_input->cast<ValueNodePtr>();
1218   MS_EXCEPTION_IF_NULL(value_node);
1219   auto value = value_node->value();
1220   MS_EXCEPTION_IF_NULL(value);
1221   auto primitive = value->cast<PrimitivePtr>();
1222   return primitive;
1223 }
1224 
IsInplaceNode(const mindspore::AnfNodePtr & kernel,const string & type)1225 bool AnfAlgo::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
1226   MS_EXCEPTION_IF_NULL(kernel);
1227   auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
1228   if (!primitive) {
1229     return false;
1230   }
1231 
1232   auto inplace_attr = primitive->GetAttr(type);
1233   if (inplace_attr == nullptr) {
1234     return false;
1235   }
1236 
1237   return true;
1238 }
1239 
IsCommunicationOp(const AnfNodePtr & node)1240 bool AnfAlgo::IsCommunicationOp(const AnfNodePtr &node) {
1241   static const std::set<std::string> kCommunicationOpNames = {
1242     kAllReduceOpName,       kAllGatherOpName,       kBroadcastOpName, kReduceScatterOpName,     kSendOpName,
1243     kReceiveOpName,         kAlltoAllOpName,        kAllToAllOpName,  kAllToAllvOpName,         kMuxReceiveOpName,
1244     kMuxSendOpName,         kReduceOpName,          kBarrierOpName,   kCollectiveScatterOpName, kCollectiveGatherOpName,
1245     kMatMulAllReduceOpName, kBatchISendIRecvOpName, kAlltoAllVOpName};
1246   MS_EXCEPTION_IF_NULL(node);
1247   if (!node->isa<CNode>()) {
1248     return false;
1249   }
1250   auto kernel_name = AnfAlgo::GetCNodeName(node);
1251   return (kCommunicationOpNames.find(kernel_name) != kCommunicationOpNames.end());
1252 }
1253 
IsDtypeFormatSensitiveOp(const AnfNodePtr & node)1254 bool AnfAlgo::IsDtypeFormatSensitiveOp(const AnfNodePtr &node) {
1255   static const std::set<std::string> kDtypeFormatSensitiveOpNames = {kCastOpName};
1256   MS_EXCEPTION_IF_NULL(node);
1257   if (!node->isa<CNode>()) {
1258     return false;
1259   }
1260   auto kernel_name = AnfAlgo::GetCNodeName(node);
1261   return (kDtypeFormatSensitiveOpNames.find(kernel_name) != kDtypeFormatSensitiveOpNames.end());
1262 }
1263 
IsFusedCommunicationOp(const AnfNodePtr & node)1264 bool AnfAlgo::IsFusedCommunicationOp(const AnfNodePtr &node) {
1265   if (!IsCommunicationOp(node)) {
1266     return false;
1267   }
1268   auto primitive = AnfAlgo::GetCNodePrimitive(node);
1269   MS_EXCEPTION_IF_NULL(primitive);
1270   ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
1271   ValuePtr attr_not_delay_fusion = primitive->GetAttr(kAttrNotDelayFusion);
1272   if (attr_fusion == nullptr) {
1273     return false;
1274   }
1275 
1276   auto fusion = GetValue<int64_t>(attr_fusion);
1277   if (fusion == 0) {
1278     return false;
1279   }
1280   if (attr_not_delay_fusion && GetValue<bool>(attr_not_delay_fusion)) {
1281     return false;
1282   }
1283   return true;
1284 }
1285 
IsGetNext(const NotNull<AnfNodePtr> & node)1286 bool AnfAlgo::IsGetNext(const NotNull<AnfNodePtr> &node) {
1287   auto kernel_name = AnfAlgo::GetCNodeName(node);
1288   return kernel_name == kGetNextOpName || kernel_name == kDynamicGetNextV2OpName;
1289 }
1290 
IsGraphKernel(const AnfNodePtr & node)1291 bool AnfAlgo::IsGraphKernel(const AnfNodePtr &node) {
1292   // this function was moved to AnfUtils.
1293   return AnfUtils::IsGraphKernel(node);
1294 }
1295 
IsNeedSkipNopOpAddr(const AnfNodePtr & node)1296 bool AnfAlgo::IsNeedSkipNopOpAddr(const AnfNodePtr &node) {
1297   MS_EXCEPTION_IF_NULL(node);
1298   if (!node->isa<CNode>()) {
1299     return false;
1300   }
1301 
1302   auto primitive = AnfAlgo::GetCNodePrimitive(node);
1303   if (primitive == nullptr) {
1304     return false;
1305   }
1306 
1307   auto skip_nop_op_addr_attr = primitive->GetAttr(kAttrSkipNopOpAddr);
1308   if (skip_nop_op_addr_attr == nullptr) {
1309     return false;
1310   }
1311 
1312   return GetValue<bool>(skip_nop_op_addr_attr);
1313 }
1314 
IsNeedSkipNopOpExecution(const AnfNodePtr & node)1315 bool AnfAlgo::IsNeedSkipNopOpExecution(const AnfNodePtr &node) {
1316   MS_EXCEPTION_IF_NULL(node);
1317   if (!node->isa<CNode>()) {
1318     return false;
1319   }
1320 
1321   auto primitive = AnfAlgo::GetCNodePrimitive(node);
1322   if (primitive == nullptr) {
1323     return false;
1324   }
1325 
1326   auto skip_nop_execution_attr = primitive->GetAttr(kAttrSkipNopOpExecution);
1327   if (skip_nop_execution_attr == nullptr) {
1328     return false;
1329   }
1330 
1331   return GetValue<bool>(skip_nop_execution_attr);
1332 }
1333 
GetValueNodeFuncGraph(const AnfNodePtr & node)1334 FuncGraphPtr AnfAlgo::GetValueNodeFuncGraph(const AnfNodePtr &node) {
1335   MS_EXCEPTION_IF_NULL(node);
1336   auto value_node = node->cast<ValueNodePtr>();
1337   if (value_node == nullptr) {
1338     return nullptr;
1339   }
1340   auto value = value_node->value();
1341   if (value == nullptr) {
1342     return nullptr;
1343   }
1344   auto func_graph = value->cast<FuncGraphPtr>();
1345   return func_graph;
1346 }
1347 
IsSwitchCall(const CNodePtr & call_node)1348 bool AnfAlgo::IsSwitchCall(const CNodePtr &call_node) {
1349   MS_EXCEPTION_IF_NULL(call_node);
1350   if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
1351     MS_LOG(INTERNAL_EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString() << "."
1352                                << trace::DumpSourceLines(call_node);
1353   }
1354   auto input1 = call_node->input(1);
1355   MS_EXCEPTION_IF_NULL(input1);
1356   if (input1->isa<ValueNode>()) {
1357     return false;
1358   } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
1359     return true;
1360   }
1361   MS_LOG(INTERNAL_EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString() << "."
1362                              << trace::DumpSourceLines(call_node);
1363 }
1364 
IsScalarInput(const CNodePtr & cnode,size_t index)1365 bool AnfAlgo::IsScalarInput(const CNodePtr &cnode, size_t index) {
1366   auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
1367   if (shape.empty()) {
1368     return true;
1369   }
1370   return shape.size() == kShape1dDims && shape[0] == 1;
1371 }
1372 
IsScalarOutput(const CNodePtr & cnode,size_t index)1373 bool AnfAlgo::IsScalarOutput(const CNodePtr &cnode, size_t index) {
1374   auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
1375   if (shape.empty()) {
1376     return true;
1377   }
1378   return shape.size() == kShape1dDims && shape[0] == 1;
1379 }
1380 
1381 namespace {
FindDelayExecPosition(const std::vector<CNodePtr> & nodes,size_t current_index,std::set<size_t> * invalid_position,std::map<size_t,std::vector<CNodePtr>> * insert_nodes)1382 void FindDelayExecPosition(const std::vector<CNodePtr> &nodes, size_t current_index, std::set<size_t> *invalid_position,
1383                            std::map<size_t, std::vector<CNodePtr>> *insert_nodes) {
1384   MS_EXCEPTION_IF_NULL(invalid_position);
1385   MS_EXCEPTION_IF_NULL(insert_nodes);
1386   if (current_index >= nodes.size()) {
1387     return;
1388   }
1389   auto &node = nodes[current_index];
1390   for (size_t j = current_index + 1; j < nodes.size(); ++j) {
1391     auto &child = nodes[j];
1392     auto child_name = AnfAlgo::GetCNodeName(child);
1393     if (child_name == kAssignAddOpName || child_name == kAssignSubOpName || child_name == kAssignOpName ||
1394         IsOneOfOperator(child_name)) {
1395       return;
1396     }
1397 
1398     auto input_size = child->size() - 1;
1399     for (size_t k = 0; k < input_size; ++k) {
1400       auto kernel_index = AnfAlgo::GetPrevNodeOutput(child, k, true);
1401       if (kernel_index.first != node) {
1402         continue;
1403       }
1404       (void)invalid_position->insert(current_index);
1405       auto iter = insert_nodes->find(j);
1406       if (iter != insert_nodes->end()) {
1407         iter->second.emplace_back(node);
1408       } else {
1409         (*insert_nodes)[j] = {node};
1410       }
1411       return;
1412     }
1413   }
1414 }
1415 
DelayExecNode(const std::vector<CNodePtr> & nodes,const std::string & node_name,bool only_seed)1416 std::vector<CNodePtr> DelayExecNode(const std::vector<CNodePtr> &nodes, const std::string &node_name, bool only_seed) {
1417   std::map<size_t, std::vector<CNodePtr>> insert_nodes;
1418   std::set<size_t> invalid_position;
1419   for (size_t i = 0; i < nodes.size(); ++i) {
1420     auto &node = nodes[i];
1421     if (AnfAlgo::GetCNodeName(node) != node_name) {
1422       continue;
1423     }
1424     if (only_seed) {
1425       bool is_seed = true;
1426       auto input_size = node->size() - 1;
1427       for (size_t k = 0; k < input_size; ++k) {
1428         auto input = AnfAlgo::GetPrevNodeOutput(node, k, true).first;
1429         if (input != nullptr && input->isa<CNode>()) {
1430           is_seed = false;
1431           break;
1432         }
1433       }
1434       if (!is_seed) {
1435         continue;
1436       }
1437     }
1438     FindDelayExecPosition(nodes, i, &invalid_position, &insert_nodes);
1439   }
1440   std::vector<CNodePtr> result;
1441   for (size_t i = 0; i < nodes.size(); ++i) {
1442     auto iter = insert_nodes.find(i);
1443     if (iter != insert_nodes.end()) {
1444       (void)result.insert(result.end(), iter->second.rbegin(), iter->second.rend());
1445     }
1446     if (invalid_position.find(i) != invalid_position.end()) {
1447       continue;
1448     }
1449     result.emplace_back(nodes[i]);
1450   }
1451   return result;
1452 }
1453 }  // namespace
1454 
ReorderExecList(NotNull<std::vector<CNodePtr> * > node_list)1455 void AnfAlgo::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1456   std::vector<CNodePtr> result;
1457   std::copy(node_list->begin(), node_list->end(), std::back_inserter(result));
1458   result = DelayExecNode(result, kTransDataOpName, true);
1459   result = DelayExecNode(result, kCastOpName, true);
1460   result = DelayExecNode(result, kAdamApplyOneWithDecayOpName, false);
1461   result = DelayExecNode(result, kAdamApplyOneOpName, false);
1462   result = DelayExecNode(result, kQuantDTypeCastOpName, false);
1463   result = DelayExecNode(result, kFSEDecodeOpName, false);
1464   if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
1465     result = DelayExecNode(result, kDropoutGenMaskOpName, true);
1466     result = DelayExecNode(result, kStatelessDropOutGenMaskOpName, true);
1467   }
1468   node_list->clear();
1469   std::copy(result.begin(), result.end(), std::back_inserter(*node_list));
1470 }
1471 
ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> * > node_list)1472 void AnfAlgo::ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1473   std::vector<CNodePtr> ordinary_node_list;
1474   std::vector<CNodePtr> posterior_node_list;
1475 
1476   for (const auto &node : *node_list) {
1477     MS_EXCEPTION_IF_NULL(node);
1478     if (IsOneOfPosteriorOperator(AnfAlgo::GetCNodeName(node))) {
1479       posterior_node_list.emplace_back(node);
1480     } else {
1481       ordinary_node_list.emplace_back(node);
1482     }
1483   }
1484   node_list->clear();
1485   std::copy(ordinary_node_list.begin(), ordinary_node_list.end(), std::back_inserter(*node_list));
1486   std::copy(posterior_node_list.begin(), posterior_node_list.end(), std::back_inserter(*node_list));
1487 }
1488 
GetCNodeOutputPrecision(const AnfNodePtr & node)1489 TypeId AnfAlgo::GetCNodeOutputPrecision(const AnfNodePtr &node) {
1490   MS_EXCEPTION_IF_NULL(node);
1491   auto prim = AnfAlgo::GetCNodePrimitive(node);
1492   if (prim == nullptr) {
1493     return kTypeUnknown;
1494   }
1495 
1496   TypeId except_type = kTypeUnknown;
1497   if (prim->GetAttr(kAttrOutputPrecision) != nullptr) {
1498     auto output_type_str = GetValue<std::string>(prim->GetAttr(kAttrOutputPrecision));
1499     if (output_type_str == "float16") {
1500       except_type = kNumberTypeFloat16;
1501     } else if (output_type_str == "float32") {
1502       except_type = kNumberTypeFloat32;
1503     } else {
1504       MS_LOG(INTERNAL_EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str << "."
1505                                  << trace::DumpSourceLines(node);
1506     }
1507   }
1508 
1509   return except_type;
1510 }
1511 
GetPrevNodeOutputPrecision(const AnfNodePtr & node,size_t input_idx)1512 TypeId AnfAlgo::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) {
1513   MS_EXCEPTION_IF_NULL(node);
1514   if (!node->isa<CNode>()) {
1515     MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << ", input node is not CNode." << trace::DumpSourceLines(node);
1516   }
1517   auto cnode = node->cast<CNodePtr>();
1518   MS_EXCEPTION_IF_NULL(cnode);
1519   if (input_idx + 1 >= cnode->size()) {
1520     MS_LOG(INTERNAL_EXCEPTION) << "Input index " << input_idx << " is larger than input number "
1521                                << GetInputTensorNum(cnode) << "." << trace::DumpSourceLines(node);
1522   }
1523   auto input_node = cnode->input(input_idx + 1);
1524   MS_EXCEPTION_IF_NULL(input_node);
1525   auto kernel_with_index = VisitKernel(input_node, 0);
1526   if (!kernel_with_index.first->isa<CNode>()) {
1527     return kTypeUnknown;
1528   }
1529   return GetCNodeOutputPrecision(kernel_with_index.first);
1530 }
1531 
IsCondControlKernel(const CNodePtr & node)1532 bool AnfAlgo::IsCondControlKernel(const CNodePtr &node) {
1533   MS_EXCEPTION_IF_NULL(node);
1534   if (node->inputs().empty()) {
1535     MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode." << trace::DumpSourceLines(node);
1536   }
1537   auto input = node->input(kAnfPrimitiveIndex);
1538   return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
1539 }
1540 
GetBooleanAttr(const AnfNodePtr & node,const std::string & attr)1541 bool AnfAlgo::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) {
1542   MS_EXCEPTION_IF_NULL(node);
1543   if (!node->isa<CNode>()) {
1544     return false;
1545   }
1546   auto cnode = node->cast<CNodePtr>();
1547   MS_EXCEPTION_IF_NULL(cnode);
1548   auto has_attr = AnfAlgo::HasNodeAttr(attr, cnode);
1549   if (!has_attr) {
1550     return false;
1551   }
1552   return AnfAlgo::GetNodeAttr<bool>(node, attr);
1553 }
1554 
GetDumpFlag(const AnfNodePtr & node)1555 std::optional<string> AnfAlgo::GetDumpFlag(const AnfNodePtr &node) {
1556   MS_EXCEPTION_IF_NULL(node);
1557   auto cnode = node->cast<CNodePtr>();
1558   if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrDump, cnode)) {
1559     return {};
1560   }
1561   return std::optional<string>{AnfAlgo::GetNodeAttr<string>(node, kAttrDump)};
1562 }
1563 
IsNodeDynamicRank(const AnfNodePtr & node)1564 bool IsNodeDynamicRank(const AnfNodePtr &node) {
1565   MS_EXCEPTION_IF_NULL(node);
1566   if (!node->isa<CNode>()) {
1567     MS_LOG(DEBUG) << "Node is not a cnode";
1568     return false;
1569   }
1570   auto cnode = node->cast<CNodePtr>();
1571   MS_EXCEPTION_IF_NULL(cnode);
1572   auto in_dyn_rank = AnfAlgo::IsNodeInputDynamicRank(cnode);
1573   auto out_dyn_rank = AnfAlgo::IsNodeOutputDynamicRank(cnode);
1574   if (in_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicRank, cnode)) {
1575     AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicRank, MakeValue(true), cnode);
1576     MS_LOG(DEBUG) << "Set input dynamic rank attr for node:" << cnode->fullname_with_scope();
1577   }
1578   if (out_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicRank, cnode)) {
1579     AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicRank, MakeValue(true), cnode);
1580     MS_LOG(DEBUG) << "Set output dynamic rank attr for node:" << cnode->fullname_with_scope();
1581   }
1582   return in_dyn_rank || out_dyn_rank;
1583 }
1584 
IsDynamicRankNode(const AnfNodePtr & node)1585 bool AnfAlgo::IsDynamicRankNode(const AnfNodePtr &node) {
1586   MS_EXCEPTION_IF_NULL(node);
1587   if (node->isa<Parameter>()) {
1588     return IsOutputAnchorDynamicRank(node, 0);
1589   }
1590   auto cnode = node->cast<CNodePtr>();
1591   MS_EXCEPTION_IF_NULL(cnode);
1592   if ((!HasNodeAttr(kAttrInputIsDynamicRank, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicRank, cnode))) {
1593     auto ret = IsNodeDynamicRank(node);
1594     MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic rank: [" << ret << "]";
1595     return ret;
1596   }
1597   return GetBooleanAttr(node, kAttrInputIsDynamicRank) || GetBooleanAttr(node, kAttrOutputIsDynamicRank) ||
1598          GetBooleanAttr(node, kAttrIsDynamicRank);
1599 }
1600 
IsInputAnchorDynamicRank(const AnfNodePtr & node,size_t idx)1601 bool AnfAlgo::IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
1602   MS_EXCEPTION_IF_NULL(node);
1603   if (!node->isa<CNode>()) {
1604     MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has inputs, node: " << node->fullname_with_scope();
1605   }
1606   const auto &in_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, idx);
1607   if (mindspore::IsDynamicRank(in_shape)) {
1608     return true;
1609   }
1610   return false;
1611 }
1612 
IsOutputAnchorDynamicRank(const AnfNodePtr & node,size_t idx)1613 bool AnfAlgo::IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
1614   MS_EXCEPTION_IF_NULL(node);
1615   const auto &out_shape = common::AnfAlgo::GetOutputInferShape(node, idx);
1616   if (mindspore::IsDynamicRank(out_shape)) {
1617     return true;
1618   }
1619   return false;
1620 }
1621 
IsNodeInputDynamicRank(const CNodePtr & anf_node_ptr)1622 bool AnfAlgo::IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr) {
1623   MS_EXCEPTION_IF_NULL(anf_node_ptr);
1624   const auto &inputs = anf_node_ptr->inputs();
1625   for (size_t i = 1; i < inputs.size(); ++i) {
1626     const auto &input = inputs[i];
1627     MS_EXCEPTION_IF_NULL(input);
1628     if (IsNodeOutputDynamicRank(input)) {
1629       return true;
1630     }
1631   }
1632   return false;
1633 }
1634 
IsNodeOutputDynamicRank(const AnfNodePtr & node)1635 bool AnfAlgo::IsNodeOutputDynamicRank(const AnfNodePtr &node) {
1636   MS_EXCEPTION_IF_NULL(node);
1637   auto base_shape = node->Shape();
1638   if (base_shape == nullptr) {
1639     MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
1640     return false;
1641   }
1642   if (base_shape->isa<abstract::DynamicSequenceShape>()) {
1643     auto b_ptr = base_shape->cast<abstract::DynamicSequenceShapePtr>();
1644     if (b_ptr->IsDimUnknown()) {
1645       return true;
1646     }
1647   }
1648   return base_shape->IsDimUnknown();
1649 }
1650 
IsDynamicShape(const AnfNodePtr & node)1651 bool AnfAlgo::IsDynamicShape(const AnfNodePtr &node) {
1652   MS_EXCEPTION_IF_NULL(node);
1653   if (!node->isa<CNode>()) {
1654     MS_LOG(DEBUG) << "Node is not a cnode.";
1655     return false;
1656   }
1657   auto cnode = node->cast<CNodePtr>();
1658   if ((!HasNodeAttr(kAttrInputIsDynamicShape, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicShape, cnode))) {
1659     auto ret = IsNodeDynamicShape(node);
1660     MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic shape or not:" << ret;
1661     return ret;
1662   }
1663   return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape);
1664 }
1665 
IsDynamicValue(const AnfNodePtr & node)1666 bool AnfAlgo::IsDynamicValue(const AnfNodePtr &node) {
1667   MS_EXCEPTION_IF_NULL(node);
1668   if (!node->isa<CNode>()) {
1669     MS_LOG(DEBUG) << "Node is not a cnode.";
1670     return false;
1671   }
1672   if (AnfAlgo::IsGraphKernel(node)) {
1673     MS_LOG(DEBUG) << "Node(" << node->fullname_with_scope() << ") is GraphKernel node, it's not dynamic value type.";
1674     return false;
1675   }
1676 
1677   auto cnode = node->cast<CNodePtr>();
1678   if (cnode->HasAttr(ops::kHasDynamicValue)) {
1679     return true;
1680   }
1681   auto depend_list = abstract::GetValueDependArgIndices(cnode);
1682   if (!depend_list.empty()) {
1683     size_t real_input_num = cnode->size() - 1;  // exclude primitive in input[0]
1684     for (auto i = depend_list.begin(); i != depend_list.end(); i++) {
1685       if (*i >= SizeToInt(real_input_num)) {
1686         continue;
1687       }
1688       if (!cnode->input(*i + 1)->isa<ValueNode>()) {
1689         cnode->AddAttr(mindspore::ops::kHasDynamicValue, MakeValue(true));
1690         MS_LOG(DEBUG) << "The input index[" << *i << "]"
1691                       << " of node: " << cnode->fullname_with_scope() << " is a dynamic value input";
1692         return true;
1693       }
1694     }
1695   }
1696   return false;
1697 }
1698 
GetRealDynamicShape(const std::vector<size_t> & shape,NotNull<std::vector<int64_t> * > dynamic_shape)1699 void AnfAlgo::GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape) {
1700   for (auto size : shape) {
1701     if (size == SIZE_MAX) {
1702       dynamic_shape->push_back(-1);
1703     } else {
1704       dynamic_shape->push_back(SizeToLong(size));
1705     }
1706   }
1707 }
1708 
GetShapeFromSequenceShape(const abstract::SequenceShapePtr & sequeue_shape_ptr,size_t index)1709 static ShapeVector GetShapeFromSequenceShape(const abstract::SequenceShapePtr &sequeue_shape_ptr, size_t index) {
1710   MS_EXCEPTION_IF_NULL(sequeue_shape_ptr);
1711   auto shape_list = sequeue_shape_ptr->shape();
1712   if (index >= shape_list.size()) {
1713     MS_LOG(INTERNAL_EXCEPTION) << "Output Index:" << index << " >= " << shape_list.size();
1714   }
1715 
1716   auto shape = shape_list[index];
1717   MS_EXCEPTION_IF_NULL(shape);
1718   if (shape->isa<abstract::NoShape>()) {
1719     // For scalar in sequeue case.
1720     return {};
1721   } else if (!shape->isa<abstract::Shape>()) {
1722     MS_LOG(INTERNAL_EXCEPTION) << "Invalid Shape Type(" << shape->ToString() << ") In Shape List";
1723   }
1724 
1725   auto shape_ptr = shape->cast<abstract::ShapePtr>();
1726   return shape_ptr->max_shape();
1727 }
1728 
GetOutputMaxShape(const AnfNodePtr & anf_node,size_t index)1729 ShapeVector AnfAlgo::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) {
1730   MS_EXCEPTION_IF_NULL(anf_node);
1731   auto shape = anf_node->Shape();
1732   MS_EXCEPTION_IF_NULL(shape);
1733   if (shape->isa<abstract::Shape>()) {
1734     auto shape_ptr = shape->cast<abstract::ShapePtr>();
1735     return shape_ptr->max_shape();
1736   } else if (shape->isa<abstract::SequenceShape>()) {
1737     auto sequeue_shape_ptr = shape->cast<abstract::SequenceShapePtr>();
1738     return GetShapeFromSequenceShape(sequeue_shape_ptr, index);
1739   } else if (shape->isa<abstract::NoShape>()) {
1740     return {};
1741   } else if (shape->isa<abstract::DynamicSequenceShape>()) {
1742     return {1};
1743   } else {
1744     MS_LOG(INTERNAL_EXCEPTION) << "Invalid shape type." << trace::DumpSourceLines(anf_node);
1745   }
1746 }
1747 
IsNodeOutputDynamicShape(const AnfNodePtr & node)1748 bool AnfAlgo::IsNodeOutputDynamicShape(const AnfNodePtr &node) {
1749   MS_EXCEPTION_IF_NULL(node);
1750   auto base_shape = node->Shape();
1751   if (base_shape == nullptr) {
1752     MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
1753     return false;
1754   }
1755   if (base_shape->isa<abstract::DynamicSequenceShape>()) {
1756     return true;
1757   }
1758   return base_shape->IsDynamic();
1759 }
1760 
IsNodeInputDynamicShape(const CNodePtr & anf_node_ptr)1761 bool AnfAlgo::IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
1762   MS_EXCEPTION_IF_NULL(anf_node_ptr);
1763   const auto &inputs = anf_node_ptr->inputs();
1764   for (size_t i = 1; i < inputs.size(); ++i) {
1765     const auto &input = inputs[i];
1766     MS_EXCEPTION_IF_NULL(input);
1767     if (IsNodeOutputDynamicShape(input)) {
1768       return true;
1769     }
1770   }
1771   return false;
1772 }
1773 
GetGraphSplitGroup(const AnfNodePtr & node)1774 std::string AnfAlgo::GetGraphSplitGroup(const AnfNodePtr &node) {
1775   return HasNodeAttr(kAttrGraphSplitGroup, node->cast<CNodePtr>())
1776            ? GetNodeAttr<std::string>(node->cast<CNodePtr>(), kAttrGraphSplitGroup)
1777            : "DefaultGroup";
1778 }
1779 
GetAllVisitedCNode(const CNodePtr & node,std::vector<AnfNodePtr> * used_kernels,std::set<AnfNodePtr> * visited)1780 void AnfAlgo::GetAllVisitedCNode(const CNodePtr &node, std::vector<AnfNodePtr> *used_kernels,
1781                                  std::set<AnfNodePtr> *visited) {
1782   MS_EXCEPTION_IF_NULL(node);
1783   MS_EXCEPTION_IF_NULL(used_kernels);
1784   MS_EXCEPTION_IF_NULL(visited);
1785   if (visited->find(node) != visited->end()) {
1786     MS_LOG(INFO) << "Node:" << node->fullname_with_scope() << " has already been visited";
1787     return;
1788   }
1789   (void)visited->insert(node);
1790   auto input_size = node->size() - 1;
1791   for (size_t i = 0; i < input_size; ++i) {
1792     auto input = AnfAlgo::GetInputNode(node, i);
1793     if (!input->isa<CNode>()) {
1794       continue;
1795     }
1796     if (!AnfUtils::IsRealKernel(input) || IsNopNode(input)) {
1797       GetAllVisitedCNode(input->cast<CNodePtr>(), used_kernels, visited);
1798     } else {
1799       used_kernels->push_back(input);
1800     }
1801   }
1802 }
1803 
GetAllFatherRealNode(const AnfNodePtr & anf_node,std::vector<AnfNodePtr> * result,std::set<AnfNodePtr> * visited)1804 void AnfAlgo::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
1805                                    std::set<AnfNodePtr> *visited) {
1806   MS_EXCEPTION_IF_NULL(anf_node);
1807   MS_EXCEPTION_IF_NULL(result);
1808   MS_EXCEPTION_IF_NULL(visited);
1809   if (visited->find(anf_node) != visited->end()) {
1810     MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
1811     return;
1812   }
1813   visited->insert(anf_node);
1814   if (AnfUtils::IsRealKernel(anf_node)) {
1815     result->emplace_back(anf_node);
1816     return;
1817   }
1818   if (!anf_node->isa<CNode>()) {
1819     return;
1820   }
1821   auto cnode = anf_node->cast<CNodePtr>();
1822   MS_EXCEPTION_IF_NULL(cnode);
1823   if (cnode->inputs().empty()) {
1824     MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString() << "."
1825                                << trace::DumpSourceLines(cnode);
1826   }
1827   auto input0 = cnode->input(0);
1828   if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
1829     for (size_t i = 1; i < cnode->size(); ++i) {
1830       GetAllFatherRealNode(cnode->input(i), result, visited);
1831     }
1832   } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
1833     if (cnode->size() != kTupleGetItemInputSize) {
1834       MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
1835     }
1836     GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
1837   } else if (IsPrimitive(input0, prim::kPrimDepend)) {
1838     if (cnode->size() != kDependInputSize) {
1839       MS_LOG(INTERNAL_EXCEPTION) << "Depend node must have 2 inputs!" << trace::DumpSourceLines(cnode);
1840     }
1841     GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
1842     GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
1843   }
1844 }
1845 
IsHostKernel(const CNodePtr & kernel_node)1846 bool AnfAlgo::IsHostKernel(const CNodePtr &kernel_node) {
1847   static const std::map<std::string, std::pair<size_t, size_t>> host_kernel_input_output_num = {
1848     {prim::kPrimDynamicShape->name(), {1, 1}},
1849     {prim::kPrimReshape->name(), {2, 1}},
1850     {prim::kPrimTensorShape->name(), {1, 1}}};
1851 
1852   auto op_name = AnfAlgo::GetCNodeName(kernel_node);
1853   auto iter = host_kernel_input_output_num.find(op_name);
1854   if (iter == host_kernel_input_output_num.end()) {
1855     return false;
1856   }
1857 
1858   auto input_num = GetInputTensorNum(kernel_node);
1859   auto output_num = AnfUtils::GetOutputTensorNum(kernel_node);
1860   auto kernel_input_num = iter->second.first;
1861   auto kernel_output_num = iter->second.second;
1862   if (kernel_input_num != input_num || kernel_output_num != output_num) {
1863     return false;
1864   }
1865   return true;
1866 }
1867 
AddArgList(AbstractBasePtrList * args_spec_list,const AnfNodePtr & real_input,size_t real_input_index)1868 void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &real_input, size_t real_input_index) {
1869   MS_EXCEPTION_IF_NULL(args_spec_list);
1870   MS_EXCEPTION_IF_NULL(real_input);
1871 
1872   // cppcheck-suppress unreadVariable
1873   auto lock = AnfUtils::GetAbstractLock(real_input.get());
1874   auto real_abs = real_input->abstract();
1875   MS_EXCEPTION_IF_NULL(real_abs);
1876   if (real_abs->isa<abstract::AbstractTuple>() && (!common::AnfAlgo::IsDynamicSequence(real_input))) {
1877     auto abs_tuple = real_abs->Clone()->cast<abstract::AbstractTuplePtr>();
1878     MS_EXCEPTION_IF_NULL(abs_tuple);
1879     MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
1880     auto abs_index = abs_tuple->elements()[real_input_index];
1881     (void)args_spec_list->emplace_back(abs_index);
1882   } else {
1883     (void)args_spec_list->emplace_back(real_abs->Clone());
1884   }
1885 }
1886 
GetUpdateStateUsers(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)1887 AnfNodeIndexSet AnfAlgo::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
1888   AnfNodeIndexSet update_states;
1889   for (auto &user : manager->node_users()[node]) {
1890     if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
1891       update_states.insert(user);
1892     }
1893   }
1894   return update_states;
1895 }
1896 
GetRealInputs(const AnfNodePtr & node,std::vector<KernelWithIndex> * inputs)1897 void AnfAlgo::GetRealInputs(const AnfNodePtr &node, std::vector<KernelWithIndex> *inputs) {
1898   size_t input_num = AnfAlgo::GetInputTensorNum(node);
1899   for (size_t input_index = 0; input_index < input_num; ++input_index) {
1900     auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
1901     GetRealOutputRecursively(input_node, 0, inputs);
1902   }
1903 }
1904 
IsBpropCutOpExecInBackend(const AnfNodePtr & node)1905 bool AnfAlgo::IsBpropCutOpExecInBackend(const AnfNodePtr &node) {
1906   MS_EXCEPTION_IF_NULL(node);
1907   if (!node->isa<CNode>()) {
1908     return false;
1909   }
1910   // Operators in set control_ops_exec_in_backend will be compiled into kernel graph, rather than be cut into single op
1911   // and executed in VM.
1912   static std::set<std::string> bprop_cut_ops_exec_in_backend = {kBpropCutOpName};
1913   return bprop_cut_ops_exec_in_backend.find(AnfAlgo::GetCNodeName(node)) != bprop_cut_ops_exec_in_backend.end();
1914 }
1915 
IsNodeInputContainMonad(const AnfNodePtr & node)1916 bool AnfAlgo::IsNodeInputContainMonad(const AnfNodePtr &node) {
1917   MS_EXCEPTION_IF_NULL(node);
1918   auto input_size = GetInputTensorNum(node);
1919   for (size_t i = 0; i < input_size; ++i) {
1920     auto input_with_index = GetPrevNodeOutput(node, i);
1921     if (HasAbstractMonad(input_with_index.first)) {
1922       return true;
1923     }
1924   }
1925   return false;
1926 }
1927 
HasMonadInput(const AnfNodePtr & node)1928 bool AnfAlgo::HasMonadInput(const AnfNodePtr &node) {
1929   MS_EXCEPTION_IF_NULL(node);
1930   if (!node->isa<CNode>()) {
1931     return false;
1932   }
1933 
1934   auto cnode = node->cast<CNodePtr>();
1935   MS_EXCEPTION_IF_NULL(cnode);
1936   const auto &inputs = cnode->inputs();
1937   for (const auto &input : inputs) {
1938     MS_EXCEPTION_IF_NULL(input);
1939     if (HasAbstractMonad(input)) {
1940       return true;
1941     }
1942   }
1943   return false;
1944 }
1945 
IsNonTaskOp(const CNodePtr & node)1946 bool AnfAlgo::IsNonTaskOp(const CNodePtr &node) {
1947   auto op_name = GetCNodeName(node);
1948   return (op_name == kSplitOpName || op_name == kSplitDOpName || op_name == kSplitVDOpName) &&
1949          AnfAlgo::HasNodeAttr(kAttrNonTask, node);
1950 }
1951 
IsNoneInput(const AnfNodePtr & node,size_t index)1952 bool AnfAlgo::IsNoneInput(const AnfNodePtr &node, size_t index) {
1953   MS_EXCEPTION_IF_NULL(node);
1954   auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, index);
1955   auto prev_node = kernel_with_index.first;
1956   MS_EXCEPTION_IF_NULL(prev_node);
1957   // Only const optional input(None) support now.
1958   if (prev_node->isa<ValueNode>()) {
1959     auto value = prev_node->cast<ValueNodePtr>()->value();
1960     MS_EXCEPTION_IF_NULL(value);
1961     if (value->isa<None>()) {
1962       return true;
1963     }
1964   }
1965 
1966   return false;
1967 }
1968 
IsCallNode(const AnfNodePtr & node)1969 bool AnfAlgo::IsCallNode(const AnfNodePtr &node) {
1970   MS_EXCEPTION_IF_NULL(node);
1971   if (!node->isa<CNode>()) {
1972     return false;
1973   }
1974   auto input0 = node->cast<CNodePtr>()->input(0);
1975   if (IsValueNode<Primitive>(input0)) {
1976     return false;
1977   }
1978   return true;
1979 }
1980 
GetAttrGroups(const AnfNodePtr & node,size_t index)1981 int64_t AnfAlgo::GetAttrGroups(const AnfNodePtr &node, size_t index) {
1982   if (node == nullptr) {
1983     return 1;
1984   }
1985   if (node->isa<CNode>()) {
1986     auto cnode = node->cast<CNodePtr>();
1987     if (HasNodeAttr(kAttrFracZGroupIdx, cnode)) {
1988       auto fz_group_idx = GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
1989       if (index >= fz_group_idx.size()) {
1990         MS_LOG(INTERNAL_EXCEPTION) << "Index out of range, attr fracz_group_idx of node[" << node->fullname_with_scope()
1991                                    << "] only have " << fz_group_idx.size() << " numbers, but get index " << index;
1992       }
1993       return fz_group_idx[index];
1994     } else if (HasNodeAttr(kAttrFracZGroup, cnode)) {
1995       return GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
1996     }
1997   }
1998   if (node->isa<Parameter>()) {
1999     auto param = node->cast<ParameterPtr>();
2000     MS_EXCEPTION_IF_NULL(param);
2001     return param->fracz_group();
2002   }
2003   if (node->isa<ValueNode>()) {
2004     auto value_node = node->cast<ValueNodePtr>();
2005     MS_EXCEPTION_IF_NULL(value_node);
2006     return value_node->fracz_group();
2007   }
2008   return 1;
2009 }
2010 
// Walks from `node` towards its producer through TupleGetItem, MakeTuple, Depend and Load wrappers.
// - Each TupleGetItem pushes its selected element index onto `index_stack`.
// - Each MakeTuple pops the most recent index (LIFO) to pick its element.
// Returns the first node that is none of these wrappers.
// Returns nullptr (with a warning) when a MakeTuple is reached with an empty stack. Raises when a popped index is out
// of range for that MakeTuple.
AnfNodePtr AnfAlgo::GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *const index_stack) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(index_stack);

  if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
    auto tuple_getitem = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    // Get cur index
    auto output_index_value_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(output_index_value_node);
    auto value_node = output_index_value_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto output_idx = LongToSize(GetValue<int64_t>(value_node->value()));
    index_stack->push_back(output_idx);
    auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
    return GetTupleIndexes(real_input, index_stack);
  }
  if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    // If make_tuple in make_tuple, visit may start with inner tuple_getitem.
    if (index_stack->empty()) {
      MS_LOG(WARNING) << "Visit make tuple: " << node->DebugString()
                      << ", but index are empty, visit should not start with inner tuple_getitem.";
      return nullptr;
    }
    auto make_tuple = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    auto output_idx = index_stack->back();
    index_stack->pop_back();
    // Input 0 of MakeTuple is the primitive, so element `output_idx` is input `output_idx + 1`.
    if (output_idx + 1 >= make_tuple->size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Tuple index " << output_idx << " is out of range of make tuple: "
                                 << node->DebugString() << trace::DumpSourceLines(node);
    }
    return GetTupleIndexes(make_tuple->input(1 + output_idx), index_stack);
  }
  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
    return GetTupleIndexes(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), index_stack);
  }
  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
    return GetTupleIndexes(node->cast<CNodePtr>()->input(1), index_stack);
  }
  MS_LOG(DEBUG) << "Get real node:" << node->DebugString();
  return node;
}
2050 
IsNopNode(const AnfNodePtr & node)2051 bool AnfAlgo::IsNopNode(const AnfNodePtr &node) {
2052   static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(),
2053                                                       kExpandDimsOpName,
2054                                                       prim::kPrimSqueeze->name(),
2055                                                       prim::kPrimFlatten->name(),
2056                                                       kFlattenGradOpName,
2057                                                       prim::kPrimReformat->name(),
2058                                                       prim::kPrimTupleToList->name(),
2059                                                       prim::kPrimListToTuple->name(),
2060                                                       prim::kPrimTupleToTensor->name(),
2061                                                       prim::kPrimScalarToTensor->name(),
2062                                                       prim::kPrimTensorToTuple->name(),
2063                                                       prim::kPrimTensorToScalar->name(),
2064                                                       "ReshapeExt"};
2065   if (node == nullptr || !node->isa<CNode>()) {
2066     return false;
2067   }
2068   CNodePtr cnode = node->cast<CNodePtr>();
2069   MS_EXCEPTION_IF_NULL(cnode);
2070   if (cnode->inputs().empty()) {
2071     return false;
2072   }
2073   auto input0 = cnode->input(0);
2074   MS_EXCEPTION_IF_NULL(input0);
2075   if (!input0->isa<ValueNode>()) {
2076     return false;
2077   }
2078   bool is_nop_node = false;
2079   if (AnfAlgo::HasNodeAttr(kAttrNopOp, cnode)) {
2080     is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrNopOp);
2081   }
2082   if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
2083     return false;
2084   }
2085 
2086   // Check the input type and output type.
2087   if (GetOutputInferDataType(node, 0) != GetPrevNodeOutputInferDataType(node, 0)) {
2088     return false;
2089   }
2090 
2091   return true;
2092 }
2093 
2094 template <typename T>
CheckAbsType(const AnfNodePtr & node)2095 bool AnfAlgo::CheckAbsType(const AnfNodePtr &node) {
2096   MS_EXCEPTION_IF_NULL(node);
2097   MS_EXCEPTION_IF_NULL(node->abstract());
2098   return node->abstract()->cast<T>() != nullptr;
2099 }
2100 
CheckAbsSparseTensor(const AnfNodePtr & node)2101 bool AnfAlgo::CheckAbsSparseTensor(const AnfNodePtr &node) {
2102   return CheckAbsType<abstract::AbstractSparseTensorPtr>(node);
2103 }
2104 
CheckAbsSparseTensor(const abstract::AbstractBasePtr & abs)2105 bool AnfAlgo::CheckAbsSparseTensor(const abstract::AbstractBasePtr &abs) {
2106   return abs->cast<abstract::AbstractSparseTensorPtr>() != nullptr;
2107 }
2108 
GetSparseTypeIdAt(const AnfNodePtr & node,size_t idx)2109 TypeId AnfAlgo::GetSparseTypeIdAt(const AnfNodePtr &node, size_t idx) {
2110   if (CheckAbsType<abstract::AbstractSparseTensorPtr>(node)) {
2111     auto abs_sparse = node->abstract()->cast<abstract::AbstractSparseTensorPtr>();
2112     auto shape_idx = abs_sparse->size() - 1;
2113     // idx points to a tensor element
2114     if (idx < shape_idx) {
2115       return abs_sparse->GetTensorTypeIdAt(idx);
2116     }
2117     return abs_sparse->GetShapeTypeIdAt(idx - shape_idx);
2118   }
2119   MS_LOG(INTERNAL_EXCEPTION) << "Expect AbstractCSRTensor or AbstractCOOTensor, but got "
2120                              << node->abstract()->ToString();
2121 }
2122 
GetTensorValueString(const tensor::BaseTensorPtr & tensor)2123 std::string AnfAlgo::GetTensorValueString(const tensor::BaseTensorPtr &tensor) {
2124   MS_EXCEPTION_IF_NULL(tensor);
2125   auto dtype = tensor->Dtype();
2126   MS_EXCEPTION_IF_NULL(dtype);
2127   size_t data_size = tensor->DataSize();
2128   auto shape = tensor->shape();
2129   std::ostringstream buf;
2130   auto fn = [&buf, data_size, &shape](auto addr) {
2131     // Tensor value.
2132     buf << "v";
2133     for (size_t i = 0; i < data_size; ++i) {
2134       buf << *(addr + i) << ",";
2135     }
2136     // Tensor shape is necessary.
2137     // For example, the value of ones[3x4] and ones[4x3] are the same, but the shape is different.
2138     buf << "s" << tensor::ShapeToString(shape);
2139   };
2140 
2141   if (dtype->type_id() == kNumberTypeBool) {
2142     fn(reinterpret_cast<bool *>(tensor->data_c()));
2143   } else if (dtype->type_id() == kNumberTypeInt) {
2144     fn(reinterpret_cast<int *>(tensor->data_c()));
2145   } else if (dtype->type_id() == kNumberTypeInt8) {
2146     fn(reinterpret_cast<int8_t *>(tensor->data_c()));
2147   } else if (dtype->type_id() == kNumberTypeUInt8) {
2148     fn(reinterpret_cast<uint8_t *>(tensor->data_c()));
2149   } else if (dtype->type_id() == kNumberTypeInt16) {
2150     fn(reinterpret_cast<int16_t *>(tensor->data_c()));
2151   } else if (dtype->type_id() == kNumberTypeUInt16) {
2152     fn(reinterpret_cast<uint16_t *>(tensor->data_c()));
2153   } else if (dtype->type_id() == kNumberTypeInt32) {
2154     fn(reinterpret_cast<int32_t *>(tensor->data_c()));
2155   } else if (dtype->type_id() == kNumberTypeUInt32) {
2156     fn(reinterpret_cast<uint32_t *>(tensor->data_c()));
2157   } else if (dtype->type_id() == kNumberTypeInt64) {
2158     fn(reinterpret_cast<int64_t *>(tensor->data_c()));
2159   } else if (dtype->type_id() == kNumberTypeUInt64) {
2160     fn(reinterpret_cast<uint64_t *>(tensor->data_c()));
2161   } else if (dtype->type_id() == kNumberTypeFloat16) {
2162     fn(reinterpret_cast<float16 *>(tensor->data_c()));
2163   } else if (dtype->type_id() == kNumberTypeFloat64) {
2164     fn(reinterpret_cast<double *>(tensor->data_c()));
2165   } else if (dtype->type_id() == kNumberTypeFloat || dtype->type_id() == kNumberTypeFloat32) {
2166     fn(reinterpret_cast<float *>(tensor->data_c()));
2167   } else if (dtype->type_id() == kNumberTypeBFloat16) {
2168     fn(reinterpret_cast<bfloat16 *>(tensor->data_c()));
2169   } else if (dtype->type_id() == kNumberTypeComplex64) {
2170     fn(reinterpret_cast<complex64 *>(tensor->data_c()));
2171   } else if (dtype->type_id() == kNumberTypeComplex128) {
2172     fn(reinterpret_cast<complex128 *>(tensor->data_c()));
2173   } else {
2174     MS_LOG(INTERNAL_EXCEPTION) << "The dtype of the constant input is " << dtype->ToString();
2175   }
2176   return buf.str();
2177 }
2178 
FrontendGetNodeAbstractByIndex(const AnfNodePtr & node,size_t index)2179 abstract::AbstractBasePtr AnfAlgo::FrontendGetNodeAbstractByIndex(const AnfNodePtr &node, size_t index) {
2180   MS_EXCEPTION_IF_NULL(node);
2181   const auto &abstract = node->abstract();
2182   if (abstract == nullptr) {
2183     return abstract;
2184   }
2185 
2186   // Return output abstract directly for : 1.not sequence type, 2.dynamic sequence type, 3.real tuple/list type.
2187   if (!abstract->isa<abstract::AbstractSequence>() || common::AnfAlgo::IsDynamicSequence(node)) {
2188     MS_EXCEPTION_IF_CHECK_FAIL((index == 0),
2189                                "Cannot get " + std::to_string(index) + " child abstract from " + abstract->ToString());
2190     return abstract;
2191   }
2192 
2193   // Return element abstract by index for tuple type.
2194   const auto &abstract_tuple = abstract->cast<abstract::AbstractSequencePtr>();
2195   MS_EXCEPTION_IF_NULL(abstract_tuple);
2196   const auto &elements = abstract_tuple->elements();
2197   if (elements.size() <= index) {
2198     const auto sub_abstract = FetchAbstractByIndex(node->abstract(), index);
2199     return sub_abstract;
2200   }
2201   return elements[index];
2202 }
2203 
GetJitLevel(const FuncGraphPtr & func_graph)2204 std::string AnfAlgo::GetJitLevel(const FuncGraphPtr &func_graph) {
2205   MS_EXCEPTION_IF_NULL(func_graph);
2206   if (!func_graph->has_attr(kAttrJitLevel)) {
2207     MS_LOG(INFO) << "The func_graph:" << func_graph->ToString() << " has no jit_level attr, return default: None.";
2208     return "";
2209   }
2210   auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
2211   auto jit_level = GetValue<std::string>(jit_level_value);
2212   return jit_level;
2213 }
2214 
IsNodeMutableScalar(const AnfNodePtr & node)2215 bool AnfAlgo::IsNodeMutableScalar(const AnfNodePtr &node) {
2216   MS_EXCEPTION_IF_NULL(node);
2217   if (!node->isa<CNode>()) {
2218     return false;
2219   }
2220   // Check if the node is mutable scalar by all_inputs are scalar or output is scalar.
2221   const auto &is_mutable_scalar_func = [](const AnfNodePtr &cur_node) {
2222     const auto &abstract = cur_node->abstract();
2223     if (abstract == nullptr || (!abstract->isa<abstract::AbstractScalar>())) {
2224       return false;
2225     }
2226     if (abstract->BuildValue()->ContainsValueAny() && abstract->BuildType()->isa<Number>()) {
2227       return true;
2228     }
2229     return false;
2230   };
2231   bool is_output_mutable_scalar = is_mutable_scalar_func(node);
2232   bool is_scalar_to_tensor = IsPrimitiveCNode(node, prim::kPrimScalarToTensor);
2233   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
2234     const auto &cnode = node->cast<CNodePtr>();
2235     MS_EXCEPTION_IF_NULL(cnode);
2236     if (!is_mutable_scalar_func(cnode->input(kRealInputIndexInDepend))) {
2237       return false;
2238     }
2239   }
2240   return is_output_mutable_scalar || is_scalar_to_tensor;
2241 }
2242 
IsDynamicSequence(const AnfNodePtr & node)2243 bool AnfAlgo::IsDynamicSequence(const AnfNodePtr &node) {
2244   MS_EXCEPTION_IF_NULL(node);
2245   // Check if the node is dynamic sequence by sign in abstract.
2246   const auto &is_dynamic_len_func = [&node]() {
2247     const auto &abstract = node->abstract();
2248     if (abstract == nullptr || (!abstract->isa<abstract::AbstractSequence>())) {
2249       return false;
2250     }
2251 
2252     const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
2253     MS_EXCEPTION_IF_NULL(sequence_abstract);
2254     return sequence_abstract->dynamic_len() || sequence_abstract->dynamic_len_element_abs() != nullptr;
2255   };
2256 
2257   // Check if the node is dynamic sequence by sign in node, in cnode it is an attr in primitive, in parameter, it is
2258   // an sign.
2259   if (node->isa<Parameter>()) {
2260     const auto &parameter = node->cast<ParameterPtr>();
2261     MS_EXCEPTION_IF_NULL(parameter);
2262     if (parameter->dynamic_len()) {
2263       return true;
2264     }
2265     bool is_dynamic = is_dynamic_len_func();
2266     if (is_dynamic) {
2267       parameter->set_dynamic_len(true);
2268     }
2269     return is_dynamic;
2270   } else if (node->isa<CNode>()) {
2271     if (IsCallNode(node)) {
2272       return is_dynamic_len_func();
2273     }
2274     const auto &cnode = node->cast<CNodePtr>();
2275     MS_EXCEPTION_IF_NULL(cnode);
2276     if (cnode->HasAttr(kAttrDynamicLenName)) {
2277       return GetValue<bool>(cnode->GetAttr(kAttrDynamicLenName));
2278     } else {
2279       bool is_dynamic = is_dynamic_len_func();
2280       cnode->AddAttr(kAttrDynamicLenName, MakeValue(is_dynamic));
2281       return is_dynamic;
2282     }
2283   } else if (node->isa<ValueNode>()) {
2284     return is_dynamic_len_func();
2285   }
2286   return false;
2287 }
2288 
IsAnyTypeOutput(const AnfNodePtr & node)2289 bool AnfAlgo::IsAnyTypeOutput(const AnfNodePtr &node) {
2290   MS_EXCEPTION_IF_NULL(node);
2291   if (node->isa<CNode>()) {
2292     if (IsCallNode(node)) {
2293       if (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractAny>()) {
2294         return true;
2295       }
2296       return false;
2297     }
2298     const auto &cnode = node->cast<CNodePtr>();
2299     MS_EXCEPTION_IF_NULL(cnode);
2300     if (cnode->HasAttr(kAttrAnyOutputName)) {
2301       return GetValue<bool>(cnode->GetAttr(kAttrAnyOutputName));
2302     } else {
2303       bool is_any_output = (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractAny>());
2304       cnode->AddAttr(kAttrAnyOutputName, MakeValue(is_any_output));
2305       return is_any_output;
2306     }
2307   }
2308   return false;
2309 }
2310 
2311 namespace {
IsIncludeAny(const abstract::AbstractBasePtr & abstract)2312 bool IsIncludeAny(const abstract::AbstractBasePtr &abstract) {
2313   if (abstract == nullptr) {
2314     return false;
2315   }
2316   if (abstract->isa<abstract::AbstractAny>()) {
2317     return true;
2318   }
2319   if (!abstract->isa<abstract::AbstractSequence>()) {
2320     return false;
2321   }
2322   const auto &seq_abstract = abstract->cast<abstract::AbstractSequencePtr>();
2323   MS_EXCEPTION_IF_NULL(seq_abstract);
2324   if (std::any_of(seq_abstract->elements().begin(), seq_abstract->elements().end(),
2325                   [](const auto &abstract) { return IsIncludeAny(abstract); })) {
2326     return true;
2327   }
2328   return false;
2329 }
2330 }  // namespace
2331 
IsAnyTypeInput(const std::vector<AnfNodePtr> & inputs)2332 bool AnfAlgo::IsAnyTypeInput(const std::vector<AnfNodePtr> &inputs) {
2333   for (const auto &input : inputs) {
2334     MS_EXCEPTION_IF_NULL(input);
2335     if (IsIncludeAny(input->abstract())) {
2336       return true;
2337     }
2338   }
2339   return false;
2340 }
2341 
HasTupleInput(const CNodePtr & node)2342 bool AnfAlgo::HasTupleInput(const CNodePtr &node) {
2343   MS_EXCEPTION_IF_NULL(node);
2344   size_t input_num = node->size() - 1;
2345   for (size_t i = 0; i < input_num; ++i) {
2346     auto input_node = common::AnfAlgo::GetInputNode(node, i);
2347     MS_EXCEPTION_IF_NULL(input_node);
2348     if (common::AnfAlgo::IsTupleOutput(input_node)) {
2349       return true;
2350     }
2351   }
2352   return false;
2353 }
2354 
HasDynamicTupleInput(const CNodePtr & node)2355 bool AnfAlgo::HasDynamicTupleInput(const CNodePtr &node) {
2356   MS_EXCEPTION_IF_NULL(node);
2357   size_t input_num = node->size() - 1;
2358   for (size_t i = 0; i < input_num; ++i) {
2359     auto input_node = common::AnfAlgo::GetInputNode(node, i);
2360     MS_EXCEPTION_IF_NULL(input_node);
2361     if (common::AnfAlgo::IsDynamicSequence(input_node)) {
2362       return true;
2363     }
2364   }
2365   return false;
2366 }
2367 
IsReduceOp(const std::string & op_name)2368 bool AnfAlgo::IsReduceOp(const std::string &op_name) {
2369   static const std::set<std::string> reduce_op_type = {prim::kPrimReduceAll->name(),  prim::kPrimReduceAny->name(),
2370                                                        prim::kPrimReduceMean->name(), prim::kPrimReduceMax->name(),
2371                                                        prim::kPrimReduceMin->name(),  prim::kPrimReduceProd->name(),
2372                                                        prim::kPrimReduceSum->name(),  prim::kPrimSquareSumV1->name()};
2373   return reduce_op_type.find(op_name) != reduce_op_type.end();
2374 }
2375 
IsTypeTransformOp(const std::string & op_name)2376 bool AnfAlgo::IsTypeTransformOp(const std::string &op_name) {
2377   static const std::set<std::string> type_trans_op_names = {
2378     prim::kPrimTupleToTensor->name(),  prim::kPrimTensorToTuple->name(), prim::kPrimScalarToTensor->name(),
2379     prim::kPrimTensorToScalar->name(), prim::kPrimRealMakeTuple->name(), prim::kPrimRealTupleGetItem->name()};
2380   return type_trans_op_names.find(op_name) != type_trans_op_names.end();
2381 }
2382 
GetDynamicSequenceShape(const AnfNodePtr & node,size_t output_idx)2383 abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx) {
2384   MS_EXCEPTION_IF_NULL(node);
2385   abstract::AbstractSequencePtr sequence_abs = nullptr;
2386   if (node->Shape() == nullptr || (!node->Shape()->isa<abstract::DynamicSequenceShape>())) {
2387     MS_LOG(INFO) << "node:" << node->fullname_with_scope() << " index:" << output_idx
2388                  << " abs:" << node->abstract()->ToString();
2389     if (!node->abstract()->isa<abstract::AbstractSequence>()) {
2390       MS_LOG(INTERNAL_EXCEPTION) << "Not sequence abstract in node:" << node->DebugString()
2391                                  << " for dynamic sequence shape.";
2392     }
2393     const auto &top_sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
2394     MS_EXCEPTION_IF_NULL(top_sequence_abs);
2395     if (output_idx >= top_sequence_abs->elements().size()) {
2396       MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << output_idx << " for abs:" << top_sequence_abs->ToString()
2397                                  << "node:" << node->fullname_with_scope();
2398     }
2399     const auto &sub_abs = top_sequence_abs->elements()[output_idx];
2400     MS_EXCEPTION_IF_NULL(sub_abs);
2401     if (!sub_abs->isa<abstract::AbstractSequence>()) {
2402       MS_LOG(INTERNAL_EXCEPTION) << "Not sequence abstract in node:" << node->DebugString()
2403                                  << " for dynamic sequence shape.";
2404     }
2405     sequence_abs = sub_abs->cast<abstract::AbstractSequencePtr>();
2406   } else {
2407     if (node->abstract() == nullptr) {
2408       MS_LOG(INTERNAL_EXCEPTION) << "Empty abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
2409     }
2410     if (!node->abstract()->isa<abstract::AbstractSequence>()) {
2411       MS_LOG(INTERNAL_EXCEPTION) << "Not sequence abstract in node:" << node->DebugString()
2412                                  << " for dynamic sequence shape.";
2413     }
2414     sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
2415   }
2416   MS_EXCEPTION_IF_NULL(sequence_abs);
2417   if (!sequence_abs->dynamic_len()) {
2418     MS_LOG(INTERNAL_EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString()
2419                                << " for dynamic sequence shape.";
2420   }
2421   const auto &element_abs = sequence_abs->dynamic_len_element_abs();
2422   if (element_abs == nullptr) {
2423     MS_LOG(INFO) << "No element abs for node:" << node->DebugString() << " index:" << output_idx;
2424     ShapeVector empty_shape{0};
2425     return std::make_shared<abstract::Shape>(empty_shape);
2426   }
2427   return element_abs->BuildShape();
2428 }
2429 
FetchAbstractByIndex(const AbstractBasePtr & abstract,size_t index)2430 abstract::AbstractBasePtr AnfAlgo::FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
2431   MS_EXCEPTION_IF_NULL(abstract);
2432   if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
2433     if (index != 0) {
2434       MS_LOG(INTERNAL_EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
2435     }
2436     return abstract;
2437   }
2438 
2439   auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
2440   MS_EXCEPTION_IF_NULL(tuple_abstract);
2441   const auto &sub_abstracts = tuple_abstract->elements();
2442   size_t real_index = index;
2443   for (const auto &sub_abstract : sub_abstracts) {
2444     size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
2445     if (real_index >= tmp_index) {
2446       real_index -= tmp_index;
2447       continue;
2448     }
2449     return FetchAbstractByIndex(sub_abstract, real_index);
2450   }
2451   MS_LOG(INTERNAL_EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
2452 }
2453 
GetInputName(const CNodePtr & origin_op,size_t input_index)2454 std::string AnfAlgo::GetInputName(const CNodePtr &origin_op, size_t input_index) {
2455   auto prim_func_input_name = ops::GetInputNameByIndex(GetCNodeName(origin_op), input_index);
2456   if (prim_func_input_name != "") {
2457     return prim_func_input_name;
2458   }
2459   auto origin_primitive = GetCNodePrimitive(origin_op);
2460   MS_EXCEPTION_IF_NULL(origin_primitive);
2461   auto input_names = origin_primitive->GetAttr(kAttrInputNames);
2462   if (input_names == nullptr) {
2463     MS_LOG(INTERNAL_EXCEPTION) << "input_names are nullptr in cnode " << origin_op->fullname_with_scope()
2464                                << ", debug string:" << origin_op->DebugString()
2465                                << ", attr text:" << origin_primitive->GetAttrsText();
2466   }
2467 
2468   auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
2469   if (input_index >= input_names_vec.size()) {
2470     MS_LOG(INFO) << "Input index is invalid. input index: " << input_index << ", input name size "
2471                  << input_names_vec.size();
2472     return "";
2473   }
2474   return input_names_vec[input_index];
2475 }
2476 
IsNoOuputNode(const AnfNodePtr & node)2477 bool AnfAlgo::IsNoOuputNode(const AnfNodePtr &node) {
2478   const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> no_output_prims = {
2479     prim::kPrimSend,
2480     prim::kPrimNPUClearFloatStatusV2,
2481     prim::kPrimInitPartitionMap,
2482     prim::kPrimInitEmbeddingHashmap,
2483     prim::kPrimEmbeddingTableImport,
2484     prim::kPrimEmbeddingComputeVarExport,
2485     prim::kPrimEmbeddingComputeVarImport,
2486     prim::kPrimEmbeddingTableExport};
2487   if (IsOneOfPrimitiveCNode(node, no_output_prims)) {
2488     return true;
2489   }
2490   return false;
2491 }
2492 
/// @brief Reads the first element of a KernelTensorValue as a scalar Value of the given `type_id`.
/// @return The scalar value, or nullptr when `value` is not a KernelTensorValue or `type_id` is unsupported.
ValuePtr AnfAlgo::ValueToScalar(const ValuePtr &value, TypeId type_id) {
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<KernelTensorValue>()) {
    return nullptr;
  }
  const auto &kernel_tensor_value = value->cast<KernelTensorValuePtr>();
  MS_EXCEPTION_IF_NULL(kernel_tensor_value);
  MS_EXCEPTION_IF_NULL(kernel_tensor_value->GetDataPtr());
  const void *raw_data = kernel_tensor_value->GetDataPtr();
  switch (type_id) {
    case kNumberTypeBool:
      return MakeValue(*static_cast<const bool *>(raw_data));
    case kNumberTypeInt16:
      return MakeValue(*static_cast<const int16_t *>(raw_data));
    case kNumberTypeUInt16:
      return MakeValue(*static_cast<const uint16_t *>(raw_data));
    case kNumberTypeInt8:
      return MakeValue(*static_cast<const int8_t *>(raw_data));
    case kNumberTypeUInt8:
      return MakeValue(*static_cast<const uint8_t *>(raw_data));
    case kNumberTypeInt32:
      return MakeValue(*static_cast<const int32_t *>(raw_data));
    case kNumberTypeUInt32:
      return MakeValue(*static_cast<const uint32_t *>(raw_data));
    case kNumberTypeInt64:
      return MakeValue(*static_cast<const int64_t *>(raw_data));
    case kNumberTypeUInt64:
      return MakeValue(*static_cast<const uint64_t *>(raw_data));
    case kNumberTypeFloat16:
      // NOTE(review): the raw 16-bit pattern is returned as a uint16 value, not as a floating value.
      // Confirm callers expect this.
      return MakeValue(*static_cast<const uint16_t *>(raw_data));
    case kNumberTypeFloat32:
      return MakeValue(*static_cast<const float *>(raw_data));
    case kNumberTypeFloat64:
      return MakeValue(*static_cast<const double *>(raw_data));
    case kNumberTypeBFloat16:
      // NOTE(review): same raw-bit behavior as Float16 above.
      return MakeValue(*static_cast<const uint16_t *>(raw_data));
    default:
      MS_LOG(DEBUG) << "Not support scalar type:" << type_id;
  }
  return nullptr;
}
2533 
2534 namespace {
IterateFindTensor(ValuePtrList * value_list,const VectorRef & ref_list)2535 void IterateFindTensor(ValuePtrList *value_list, const VectorRef &ref_list) {
2536   MS_EXCEPTION_IF_NULL(value_list);
2537   for (size_t i = 0; i < ref_list.size(); ++i) {
2538     if (utils::isa<tensor::BaseTensorPtr>(ref_list[i])) {
2539       auto tensor_ptr = utils::cast<std::shared_ptr<tensor::BaseTensor>>(ref_list[i]);
2540       MS_EXCEPTION_IF_NULL(tensor_ptr);
2541       (void)value_list->emplace_back(tensor_ptr);
2542     } else if (utils::isa<VectorRef>(ref_list[i])) {
2543       auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
2544       IterateFindTensor(value_list, ref_iter);
2545     } else if (utils::isa<tensor::CSRTensorPtr>(ref_list[i])) {
2546       auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(ref_list[i]);
2547       MS_EXCEPTION_IF_NULL(csr_tensor);
2548       (void)value_list->emplace_back(csr_tensor);
2549     } else {
2550       MS_LOG(EXCEPTION) << "The ref value " << ref_list[i].ToString() << " is not a vector ref or a tensor!";
2551     }
2552   }
2553 }
2554 
HasAbstractFunction(const AbstractBasePtr & abs)2555 bool HasAbstractFunction(const AbstractBasePtr &abs) {
2556   if (abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractSparseTensor>()) {
2557     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2558     if (abs_seq->dynamic_len()) {
2559       return HasAbstractFunction(abs_seq->dynamic_len_element_abs());
2560     }
2561     return std::any_of(abs_seq->elements().cbegin(), abs_seq->elements().cend(), HasAbstractFunction);
2562   }
2563   // if abs it not AbstractSequence.
2564   return abs->isa<abstract::AbstractFunction>();
2565 }
2566 
IsCellReuse(const AnfNodePtr & input)2567 bool IsCellReuse(const AnfNodePtr &input) {
2568   if (IsValueNode<FuncGraph>(input)) {
2569     auto fg = GetValueNode<FuncGraphPtr>(input);
2570     MS_EXCEPTION_IF_NULL(fg);
2571     if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
2572       return true;
2573     }
2574   }
2575   return false;
2576 }
2577 
AcceptableReturnValue(const CNodePtr & cnode,const AnfNodePtr & input0)2578 bool AcceptableReturnValue(const CNodePtr &cnode, const AnfNodePtr &input0) {
2579   if (IsCellReuse(input0)) {
2580     return true;
2581   }
2582   auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
2583   auto graph_has_function_output = [](const FuncGraphPtr &fg) { return HasAbstractFunction(fg->output()->abstract()); };
2584   if (std::all_of(func_graphs.cbegin(), func_graphs.cend(), std::not_fn(graph_has_function_output))) {
2585     return true;
2586   }
2587   return false;
2588 }
2589 
SupportInlinePartial(const AnfNodePtr & input0)2590 bool SupportInlinePartial(const AnfNodePtr &input0) {
2591   // inline partial
2592   if (IsPrimitiveCNode(input0, prim::kPrimTupleGetItem)) {
2593     auto tuple_get_node = input0->cast<CNodePtr>();
2594     MS_EXCEPTION_IF_NULL(tuple_get_node);
2595     auto get_from_node = tuple_get_node->input(1);
2596     auto idx = common::AnfAlgo::GetTupleGetItemOutIndex(tuple_get_node);
2597     MS_EXCEPTION_IF_NULL(get_from_node);
2598     // tuple get item from a call subgraph output
2599     if (get_from_node->isa<CNode>() && IsValueNode<FuncGraph>(get_from_node->cast<CNodePtr>()->input(0))) {
2600       auto call_graph = GetValueNode<FuncGraphPtr>(get_from_node->cast<CNodePtr>()->input(0));
2601       MS_EXCEPTION_IF_NULL(call_graph);
2602       auto graph_out = call_graph->output();
2603       MS_EXCEPTION_IF_NULL(graph_out);
2604       size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(graph_out);
2605       // the partial must be the last output
2606       if (graph_out->isa<CNode>() && tuple_input_num == idx + 1) {
2607         int partial_cnt = 0;
2608         for (size_t i = 0; i < tuple_input_num; i++) {
2609           auto input = graph_out->cast<CNodePtr>()->input(i + 1);
2610           if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
2611             partial_cnt++;
2612           }
2613         }
2614         auto partial = graph_out->cast<CNodePtr>()->input(idx + 1);
2615         MS_EXCEPTION_IF_NULL(partial);
2616         // we only support one partial func at the last return value now
2617         if (partial_cnt != 1 || !IsPrimitiveCNode(partial, prim::kPrimPartial)) {
2618           if (partial_cnt != 0) {
2619             MS_LOG(INFO) << "Partial func cnt: " << partial_cnt
2620                          << ", last return value: " << partial->fullname_with_scope();
2621           }
2622           return false;
2623         }
2624         auto partial_inputs = partial->cast<CNodePtr>()->inputs();
2625         // the input of partial can't be FuncGraph/Partial
2626         bool has_illegal_input = std::any_of(
2627           partial_inputs.begin() + kPartialMinInputSize, partial_inputs.end(), [](const AnfNodePtr &partial_input) {
2628             return IsValueNode<FuncGraph>(partial_input) || IsPrimitiveCNode(partial_input, prim::kPrimPartial);
2629           });
2630         return !has_illegal_input;
2631       }
2632     }
2633   }
2634   return false;
2635 }
2636 }  // namespace
2637 
TransformVectorRefToMultiValue(const VectorRef & base_ref)2638 ValuePtrList AnfAlgo::TransformVectorRefToMultiValue(const VectorRef &base_ref) {
2639   ValuePtrList value_list;
2640   if (utils::isa<VectorRef>(base_ref)) {
2641     auto ref_list = utils::cast<VectorRef>(base_ref);
2642     IterateFindTensor(&value_list, ref_list);
2643   } else if (utils::isa<tensor::Tensor>(base_ref)) {
2644     auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
2645     MS_EXCEPTION_IF_NULL(tensor_ptr);
2646     (void)value_list.emplace_back(tensor_ptr);
2647   } else {
2648     MS_LOG(EXCEPTION) << "The ref value " << base_ref.ToString() << " is not a vector ref or a tensor!";
2649   }
2650   return value_list;
2651 }
2652 
HasIncorporateCallNode(const CNodePtr & cnode)2653 bool AnfAlgo::HasIncorporateCallNode(const CNodePtr &cnode) {
2654   if (!IsValueNode<Primitive>(cnode->input(0))) {  // If cnode is a call node.
2655     auto input0 = cnode->input(0);
2656     if (IsPrimitiveCNode(input0, prim::kPrimSwitch) || IsPrimitiveCNode(input0, prim::kPrimSwitchLayer) ||
2657         IsValueNode<FuncGraph>(input0)) {
2658       if (IsCellReuse(input0) && IsEnableRefMode()) {
2659         MS_LOG(INFO) << "Use cell reuse when enable ge mode: " << cnode->DebugString();
2660         return true;
2661       }
2662       if (AcceptableReturnValue(cnode, input0)) {
2663         return false;
2664       }
2665     }
2666     if (SupportInlinePartial(input0)) {
2667       return false;
2668     }
2669     MS_LOG(INFO) << "Call has indirect call: " << cnode->DebugString();
2670     return true;
2671   }
2672   return false;
2673 }
2674 
IsDynamicGraph(const FuncGraphPtr & func_graph)2675 bool AnfAlgo::IsDynamicGraph(const FuncGraphPtr &func_graph) {
2676   MS_EXCEPTION_IF_NULL(func_graph);
2677   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return(), SuccDeeperSimple);
2678   AnfNodePtr dynamic_node = nullptr;
2679   AnfNodePtr pyexecute_node = nullptr;
2680   for (const auto &node : node_list) {
2681     if (node->abstract() == nullptr) {
2682       MS_LOG(INFO) << "Null abstract of node: " << node->DebugString();
2683       continue;
2684     }
2685     if (node->abstract() != nullptr) {
2686       auto shape = node->abstract()->GetShape();
2687       // Dynamic shape tensor.
2688       if (shape->isa<abstract::TensorShape>() && IsDynamic(shape->GetShapeVector())) {
2689         dynamic_node = node;
2690         break;
2691       }
2692       // Dynamic len sequence.
2693       if (node->abstract()->isa<abstract::AbstractSequence>() &&
2694           node->abstract()->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
2695         dynamic_node = node;
2696         break;
2697       }
2698       // PyExecute node exist
2699       if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
2700         pyexecute_node = node;
2701       }
2702     }
2703   }
2704   if (dynamic_node != nullptr) {
2705     MS_LOG(INFO) << "Func graph:" << func_graph->ToString()
2706                  << " is dynamic shape graph, because find dynamic shape node:" << dynamic_node->DebugString()
2707                  << ", abstract: " << dynamic_node->abstract()->ToString();
2708     return true;
2709   }
2710   if (pyexecute_node != nullptr) {
2711     MS_LOG(INFO) << "Func graph:" << func_graph->ToString() << " has pyexecute node:" << pyexecute_node->DebugString();
2712     return true;
2713   }
2714   return false;
2715 }
2716 }  // namespace common
2717 }  // namespace mindspore
2718