1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/common/utils/anfalgo.h"
17 #include <memory>
18 #include <algorithm>
19 #include <map>
20 #include <numeric>
21 #include <set>
22 #include <complex>
23 #include "mindapi/base/shape_vector.h"
24 #include "ops/ascend_op_name.h"
25 #include "ops/nn_optimizer_op_name.h"
26 #include "ops/lite_op_name.h"
27 #include "ops/structure_ops.h"
28 #include "ops/sequence_ops.h"
29 #include "ops/other_ops.h"
30 #include "ops/nn_ops.h"
31 #include "ops/math_ops.h"
32 #include "ops/array_ops.h"
33 #include "ops/arithmetic_ops.h"
34 #include "ops/framework_ops.h"
35 #include "ops/op_utils.h"
36 #include "ops/op_def.h"
37 #include "ops/auto_generate/gen_ops_primitive.h"
38 #include "ir/anf.h"
39 #include "ir/func_graph.h"
40 #include "include/common/utils/utils.h"
41 #include "utils/shape_utils.h"
42 #include "utils/trace_base.h"
43 #include "utils/anf_utils.h"
44 #include "include/common/utils/parallel_context.h"
45 #include "utils/ms_context.h"
46 #include "pybind_api/ir/primitive_py.h"
47 #include "kernel/kernel_build_info.h"
48 #include "include/backend/anf_runtime_algorithm.h"
49
50 namespace mindspore {
51 namespace common {
52 using abstract::AbstractSparseTensor;
53 using abstract::AbstractTensor;
54 using abstract::AbstractTuple;
55
56 namespace {
// Index of the input that is followed when a nop node is skipped (input 0 of a CNode holds the primitive).
constexpr size_t kNopNodeRealInputIndex = 1;
// Shorthand aliases for the single- and double-precision standard complex types.
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

// Primitives whose CNodes are expanded into their individual inputs when outputs are collected.
const PrimitiveSet expand_prims = {prim::kPrimMakeTuple};
// Names of ops listed as having tuple outputs. NOTE(review): not referenced in the visible part of this file.
const std::set<std::string> kNodeTupleOutSet = {kMakeTupleOpName, kGetNextOpName};
63
GetRealOutputRecursively(const AnfNodePtr & node,size_t output_index,std::vector<KernelWithIndex> * inputs)64 void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, std::vector<KernelWithIndex> *inputs) {
65 MS_EXCEPTION_IF_NULL(node);
66 if (node->isa<ValueNode>() || node->isa<Parameter>()) {
67 return inputs->push_back(std::make_pair(node, 0));
68 }
69
70 // Skip control node
71 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
72 AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
73 return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
74 }
75
76 // Bypass TupleGetItem
77 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
78 auto tuple_get_item = node->cast<CNodePtr>();
79 MS_EXCEPTION_IF_NULL(tuple_get_item);
80 auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
81 auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
82 // Conceal MakeTuple + TupleGetItem pair.
83 if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
84 auto make_tuple = input->cast<CNodePtr>();
85 MS_EXCEPTION_IF_NULL(make_tuple);
86 auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
87 return GetRealOutputRecursively(real_input, 0, inputs);
88 }
89
90 // Skip TupleGetItem.
91 return GetRealOutputRecursively(input, index, inputs);
92 }
93
94 // Flatten MakeTuple inputs.
95 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
96 auto make_tuple = node->cast<CNodePtr>();
97 MS_EXCEPTION_IF_NULL(make_tuple);
98 size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
99 for (size_t input_index = 0; input_index < input_num; ++input_index) {
100 auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
101 GetRealOutputRecursively(input_node, 0, inputs);
102 }
103 return;
104 }
105
106 return inputs->push_back(std::make_pair(node, output_index));
107 }
108
IsMultiLayerTuple(const abstract::AbstractBasePtr & abstract)109 bool IsMultiLayerTuple(const abstract::AbstractBasePtr &abstract) {
110 MS_EXCEPTION_IF_NULL(abstract);
111 if (!abstract->isa<abstract::AbstractSequence>()) {
112 return false;
113 }
114 const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
115 MS_EXCEPTION_IF_NULL(sequence_abstract);
116 if (sequence_abstract->dynamic_len()) {
117 return false;
118 }
119 return std::any_of(sequence_abstract->elements().begin(), sequence_abstract->elements().end(),
120 [](const abstract::AbstractBasePtr &sub_abstract) {
121 return sub_abstract != nullptr && sub_abstract->isa<abstract::AbstractSequence>();
122 });
123 }
124
GetAllOutputWithIndexInner(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types)125 std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node,
126 const std::vector<PrimitivePtr> &return_types) {
127 MS_EXCEPTION_IF_NULL(node);
128 MS_LOG(DEBUG) << "Output node: " << node->fullname_with_scope();
129 std::vector<KernelWithIndex> ret;
130 std::vector<KernelWithIndex> ret_empty;
131 // The MakeTuple/MakeSparse node need expand and recurse.
132 if (IsOneOfPrimitiveCNode(node, expand_prims)) {
133 auto make_tuple = node->cast<CNodePtr>();
134 MS_EXCEPTION_IF_NULL(make_tuple);
135 for (size_t i = 1; i < make_tuple->size(); i++) {
136 auto make_tuple_output = GetAllOutputWithIndexInner(make_tuple->input(i), return_types);
137 (void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
138 }
139 return ret;
140 }
141 // The depend node need get the real node.
142 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
143 auto depend_node = node->cast<CNodePtr>();
144 MS_EXCEPTION_IF_NULL(depend_node);
145 auto real_output = GetAllOutputWithIndexInner(depend_node->input(kRealInputIndexInDepend), return_types);
146 (void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
147 return ret;
148 }
149
150 // Value node need get all the elements.
151 if (node->isa<ValueNode>()) {
152 auto value = node->cast<ValueNodePtr>()->value();
153 MS_EXCEPTION_IF_NULL(value);
154 if (value->isa<ValueSequence>()) {
155 auto value_tuple = value->cast<ValueSequencePtr>();
156 auto value_tuple_size = CountValueNum(value_tuple);
157 for (size_t i = 0; i < value_tuple_size; ++i) {
158 (void)ret.emplace_back(node, i);
159 }
160 } else {
161 (void)ret.emplace_back(node, 0);
162 }
163 MS_LOG(DEBUG) << "Output value node: " << node->fullname_with_scope() << ", value num: " << ret.size();
164 return ret;
165 }
166
167 // Output num must be exactly equal to the number of outputs of the node.
168 size_t outputs_num = 1;
169 if (AnfUtils::IsRealCNodeKernel(node)) {
170 if (node->abstract() != nullptr &&
171 (common::AnfAlgo::IsDynamicSequence(node) || IsMultiLayerTuple(node->abstract()))) {
172 outputs_num = common::AnfAlgo::GetOutputNumByAbstract(node->abstract());
173 } else {
174 outputs_num = AnfUtils::GetOutputTensorNum(node);
175 }
176 MS_LOG(DEBUG) << "Output num:" << outputs_num << " for node:" << node->DebugString();
177 }
178 // Call node maybe a real cnode and the unreal node cannot get output num exactly, so we should get
179 // output num from abstract again. For example the TupleGetItem/Makeple multi-level nesting:
180 // '''G = op() ---> Assume that the output of G is a multi-member tuple
181 // A = MakeTuple(E, F, G)
182 // B = MakeTuple(H, A)
183 // C = TupleGetItem(B, 1) ---> Euqal the A
184 // D = TupleGetItem(C, 2) ---> VisitKernel will return the {G, 0}, but expect the whole G with all the members
185 // return D'''
186 if (common::AnfAlgo::IsCallNode(node) || (!AnfUtils::IsRealCNodeKernel(node))) {
187 MS_EXCEPTION_IF_NULL(node->abstract());
188 outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
189 }
190
191 // The output may be the tuple of node, so need visit all the outputs of node.
192 // Since output num represents the number of all outputs of node, only one output is obtained per loop.
193 for (size_t i = 0; i < outputs_num; ++i) {
194 // Maybe this scene: tupleGetItem + depend + makeTuple, can be done correctly in VisitKernelWithReturnType.
195 // The output may be updataState/load node for connecting dependencies between subgraphs.
196 auto output_with_index = AnfAlgo::VisitKernelWithReturnType(
197 node, i, false, {prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimLoad}, nullptr, true);
198 MS_EXCEPTION_IF_NULL(output_with_index.first);
199
200 // The MakeTuple/MakeSparse node need recurse.
201 if (IsOneOfPrimitiveCNode(output_with_index.first, expand_prims)) {
202 auto output_vector = GetAllOutputWithIndexInner(output_with_index.first, return_types);
203 if (output_vector.size() <= output_with_index.second) {
204 MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << output_with_index.second
205 << " for outputs of node:" << output_with_index.first->DebugString();
206 }
207 (void)ret.emplace_back(output_vector[output_with_index.second]);
208 continue;
209 }
210
211 // The InitDataSetQueue node has no output.
212 if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
213 return ret_empty;
214 }
215
216 MS_LOG(DEBUG) << "Output node: " << output_with_index.first->fullname_with_scope()
217 << " with output index: " << output_with_index.second;
218 ret.push_back(output_with_index);
219 }
220 return ret;
221 }
222
IsNodeDynamicShape(const AnfNodePtr & node)223 bool IsNodeDynamicShape(const AnfNodePtr &node) {
224 MS_EXCEPTION_IF_NULL(node);
225 if (!node->isa<CNode>()) {
226 MS_LOG(DEBUG) << "Node is not a cnode";
227 return false;
228 }
229 auto cnode = node->cast<CNodePtr>();
230 auto in_dynamic = AnfAlgo::IsNodeInputDynamicShape(cnode);
231 auto out_dynamic = AnfAlgo::IsNodeOutputDynamicShape(cnode);
232 if (in_dynamic && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicShape, cnode)) {
233 AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicShape, MakeValue(true), cnode);
234 MS_LOG(DEBUG) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope()
235 << " debug string:" << cnode->DebugString();
236 }
237 if (out_dynamic && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicShape, cnode)) {
238 AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
239 MS_LOG(DEBUG) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope()
240 << " debug string:" << cnode->DebugString();
241 }
242 if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && node->abstract()->isa<abstract::AbstractSequence>()) {
243 AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
244 MS_LOG(DEBUG) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
245 return true;
246 }
247 return in_dynamic || out_dynamic;
248 }
249 } // namespace
250
GetTupleGetItemRealInput(const CNodePtr & tuple_get_item)251 AnfNodePtr AnfAlgo::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
252 MS_EXCEPTION_IF_NULL(tuple_get_item);
253 if (tuple_get_item->size() != kTupleGetItemInputSize) {
254 MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
255 }
256 return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
257 }
258
GetTupleGetItemOutIndex(const CNodePtr & tuple_get_item)259 size_t AnfAlgo::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
260 MS_EXCEPTION_IF_NULL(tuple_get_item);
261 if (tuple_get_item->size() != kTupleGetItemInputSize) {
262 MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
263 }
264 auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
265 MS_EXCEPTION_IF_NULL(output_index_value_node);
266 auto value_node = output_index_value_node->cast<ValueNodePtr>();
267 MS_EXCEPTION_IF_NULL(value_node);
268 auto value = value_node->value();
269 MS_EXCEPTION_IF_NULL(value);
270 auto idx = value->isa<Int64Imm>() ? GetValue<int64_t>(value) : GetValue<int>(value);
271 return LongToSize(idx);
272 }
273
// Forwards to AnfUtils::VisitKernel with the same node and index and returns its result unchanged.
KernelWithIndex AnfAlgo::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  // this function was moved to AnfUtils.
  return AnfUtils::VisitKernel(anf_node, index);
}
278
279 namespace {
VisitKernelWithReturnTypeForTupleGetItem(const AnfNodePtr & anf_node,size_t index,bool skip_nop_node,const std::vector<PrimitivePtr> & return_types,abstract::AbstractBasePtr * abstract,bool is_index_valid)280 KernelWithIndex VisitKernelWithReturnTypeForTupleGetItem(const AnfNodePtr &anf_node, size_t index, bool skip_nop_node,
281 const std::vector<PrimitivePtr> &return_types,
282 abstract::AbstractBasePtr *abstract, bool is_index_valid) {
283 MS_EXCEPTION_IF_NULL(anf_node);
284 if (!common::AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
285 MS_LOG(EXCEPTION) << "Invalid tuple get item node:" << anf_node->DebugString();
286 }
287 auto cnode = anf_node->cast<CNodePtr>();
288 MS_EXCEPTION_IF_NULL(cnode);
289 if (cnode->HasAttr(kAttrReplaceRealKernelInBackend)) {
290 MS_LOG(INFO) << "cnode:" << cnode->DebugString() << " has replace flag";
291 return KernelWithIndex(anf_node, index);
292 }
293 abstract::AbstractBasePtr abs = nullptr;
294 auto item_with_index_tmp = common::AnfAlgo::VisitKernelWithReturnType(
295 common::AnfAlgo::GetTupleGetItemRealInput(cnode), common::AnfAlgo::GetTupleGetItemOutIndex(cnode), skip_nop_node,
296 return_types, &abs, true);
297 if (IsOneOfPrimitiveCNode(item_with_index_tmp.first, expand_prims)) {
298 MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
299 auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
300 MS_EXCEPTION_IF_NULL(make_tuple);
301 const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
302 size_t make_tuple_input_index = item_with_index_tmp.second + 1;
303 if (make_tuple_input_index >= make_tuple_inputs.size()) {
304 MS_LOG(INTERNAL_EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
305 << "].\nPlease check node: " << cnode->DebugString()
306 << ".\nLine: " << trace::GetDebugInfoStr(cnode->debug_info())
307 << ".\nAnd check node: " << make_tuple->DebugString()
308 << ".\nLine: " << trace::GetDebugInfoStr(make_tuple->debug_info()) << ".";
309 }
310 return common::AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], index, skip_nop_node,
311 return_types);
312 }
313 if (common::AnfAlgo::IsCallNode(item_with_index_tmp.first) || item_with_index_tmp.first->isa<Parameter>()) {
314 size_t real_index = item_with_index_tmp.second;
315 if (abs == nullptr) {
316 abs = item_with_index_tmp.first->abstract();
317 real_index = 0;
318 }
319 MS_EXCEPTION_IF_NULL(abs);
320 if (abs->isa<abstract::AbstractSequence>()) {
321 auto tuple_abstract = abs->cast<abstract::AbstractSequencePtr>();
322 MS_EXCEPTION_IF_NULL(tuple_abstract);
323 if (tuple_abstract->dynamic_len()) {
324 return item_with_index_tmp;
325 }
326 auto sub_abstracts = tuple_abstract->elements();
327 if (sub_abstracts.size() <= common::AnfAlgo::GetTupleGetItemOutIndex(cnode)) {
328 MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << common::AnfAlgo::GetTupleGetItemOutIndex(cnode)
329 << " for abstract:" << abs->ToString();
330 }
331 for (size_t i = 0; i < common::AnfAlgo::GetTupleGetItemOutIndex(cnode); ++i) {
332 MS_EXCEPTION_IF_NULL(sub_abstracts[i]);
333 real_index += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
334 }
335 if (abstract != nullptr) {
336 (*abstract) = sub_abstracts[common::AnfAlgo::GetTupleGetItemOutIndex(cnode)];
337 MS_EXCEPTION_IF_NULL((*abstract));
338 } else {
339 // In recursion of getitem node, the index of the first input of its real node is returned.
340 // When the recursion ends, the outermost index needs to be accumulated.
341 real_index += index;
342 }
343 return {item_with_index_tmp.first, real_index};
344 }
345 }
346 if (is_index_valid) {
347 if (anf_node->abstract() != nullptr && anf_node->abstract()->isa<abstract::AbstractSequence>()) {
348 const auto &seq_abs = anf_node->abstract()->cast<abstract::AbstractSequencePtr>();
349 MS_EXCEPTION_IF_NULL(seq_abs);
350 if (!seq_abs->dynamic_len()) {
351 return {anf_node, index};
352 }
353 }
354 }
355 return item_with_index_tmp;
356 }
357 } // namespace
358
VisitKernelWithReturnType(const AnfNodePtr & anf_node,size_t index,bool skip_nop_node,const std::vector<PrimitivePtr> & return_types,abstract::AbstractBasePtr * abstract,bool is_index_valid)359 KernelWithIndex AnfAlgo::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, bool skip_nop_node,
360 const std::vector<PrimitivePtr> &return_types,
361 abstract::AbstractBasePtr *abstract, bool is_index_valid) {
362 MS_EXCEPTION_IF_NULL(anf_node);
363 if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
364 return CheckPrimitiveType(anf_node, prim_type);
365 })) {
366 return KernelWithIndex(anf_node, index);
367 }
368 if (!anf_node->isa<CNode>()) {
369 return KernelWithIndex(anf_node, index);
370 }
371 auto cnode = anf_node->cast<CNodePtr>();
372 MS_EXCEPTION_IF_NULL(cnode);
373 // TupleGetItem and SparseGetAttr needs to find real input
374 if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
375 return VisitKernelWithReturnTypeForTupleGetItem(anf_node, index, skip_nop_node, return_types, abstract,
376 is_index_valid);
377 }
378 if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
379 return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, skip_nop_node, return_types);
380 }
381 const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad, prim::kPrimDynamicLossScale};
382 if (IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
383 return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, skip_nop_node, return_types);
384 }
385 if (IsNopNode(cnode) && skip_nop_node) {
386 return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, skip_nop_node, return_types);
387 }
388 return KernelWithIndex(anf_node, index);
389 }
390
// Looks through a Depend or Load node to the node it stands for; any other node is returned unchanged.
KernelWithIndex AnfAlgo::FetchRealNodeSkipMonadControl(const KernelWithIndex &node_with_index) {
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {prim::kPrimDepend,
                                                                                              prim::kPrimLoad};
  const bool is_monad_control = IsOneOfPrimitiveCNode(node_with_index.first, auto_monad_prims);
  if (!is_monad_control) {
    return node_with_index;
  }
  // Nop nodes are not skipped while resolving the real node.
  return common::AnfAlgo::VisitKernelWithReturnType(node_with_index.first, node_with_index.second, false);
}
401
GetAllOutput(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types)402 std::vector<AnfNodePtr> AnfAlgo::GetAllOutput(const AnfNodePtr &node, const std::vector<PrimitivePtr> &return_types) {
403 std::vector<AnfNodePtr> ret;
404 const auto &output_pair = GetAllOutputIndexByReturnTypes(node, return_types);
405 std::transform(output_pair.begin(), output_pair.end(), std::back_inserter(ret),
406 [](const KernelWithIndex &ele) { return ele.first; });
407 return ret;
408 }
409
GetAllOutputIndexByReturnTypes(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types,bool need_make_tuple)410 std::vector<KernelWithIndex> AnfAlgo::GetAllOutputIndexByReturnTypes(const AnfNodePtr &node,
411 const std::vector<PrimitivePtr> &return_types,
412 bool need_make_tuple) {
413 std::vector<KernelWithIndex> ret;
414 auto return_prim_type = return_types;
415 // if visited make_tuple should return back
416 return_prim_type.push_back(prim::kPrimMakeTuple);
417 auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
418 if (need_make_tuple) {
419 ret.push_back(item_with_index);
420 }
421 if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
422 MS_EXCEPTION_IF_NULL(item_with_index.first);
423 auto make_tuple = item_with_index.first->cast<CNodePtr>();
424 MS_EXCEPTION_IF_NULL(make_tuple);
425 for (size_t i = 1; i < make_tuple->size(); i++) {
426 auto input_i_vector = GetAllOutputIndexByReturnTypes(make_tuple->input(i), return_types);
427 (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
428 }
429 return ret;
430 }
431 ret.push_back(item_with_index);
432 return ret;
433 }
434
GetOutputNumByAbstract(const AbstractBasePtr & node_abstract)435 size_t AnfAlgo::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) {
436 MS_EXCEPTION_IF_NULL(node_abstract);
437 size_t result = 0;
438
439 if (!node_abstract->isa<abstract::AbstractSequence>() ||
440 node_abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len() ||
441 node_abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len_element_abs() != nullptr) {
442 return 1;
443 }
444
445 auto tuple_abstract = node_abstract->cast<abstract::AbstractSequencePtr>();
446 MS_EXCEPTION_IF_NULL(tuple_abstract);
447 const auto &sub_abstracts = tuple_abstract->elements();
448 for (const auto &sub_abstract : sub_abstracts) {
449 MS_EXCEPTION_IF_NULL(sub_abstract);
450 result += GetOutputNumByAbstract(sub_abstract);
451 }
452 return result;
453 }
454
GetAllOutputWithOutMonadAndParameter(const AnfNodePtr & node)455 std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithOutMonadAndParameter(const AnfNodePtr &node) {
456 MS_EXCEPTION_IF_NULL(node);
457 const auto &graph_outputs = common::AnfAlgo::GetAllOutputWithIndex(node);
458 std::vector<KernelWithIndex> real_output;
459 for (const auto &node_with_index : graph_outputs) {
460 MS_EXCEPTION_IF_NULL(node_with_index.first);
461 if (HasAbstractMonad(node_with_index.first) || node_with_index.first->isa<Parameter>() ||
462 node_with_index.first->isa<ValueNode>()) {
463 continue;
464 }
465 real_output.emplace_back(node_with_index);
466 }
467 return real_output;
468 }
469
GetAllOutputWithIndex(const AnfNodePtr & node,const std::vector<PrimitivePtr> & return_types)470 std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithIndex(const AnfNodePtr &node,
471 const std::vector<PrimitivePtr> &return_types) {
472 auto ret = GetAllOutputWithIndexInner(node, return_types);
473 std::map<AnfNodePtr, size_t> value_node_index;
474
475 // Unify the output of the front and back end to the ValueTuple
476 for (auto &output_with_index : ret) {
477 auto value_node = output_with_index.first;
478 MS_EXCEPTION_IF_NULL(value_node);
479 if (!value_node->isa<ValueNode>()) {
480 continue;
481 }
482 if (value_node_index.find(value_node) == value_node_index.end() ||
483 value_node_index[value_node] < output_with_index.second) {
484 value_node_index[value_node] = output_with_index.second;
485 } else {
486 value_node_index[value_node]++;
487 MS_LOG(DEBUG) << "Set output value node new index, value node: " << value_node->fullname_with_scope()
488 << ", original index: " << output_with_index.second
489 << ", new index:" << value_node_index[value_node];
490 output_with_index.second = value_node_index[value_node];
491 }
492 }
493 return ret;
494 }
495
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)496 bool AnfAlgo::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
497 MS_EXCEPTION_IF_NULL(node);
498 if (!node->isa<CNode>()) {
499 return false;
500 }
501 auto cnode = node->cast<CNodePtr>();
502 MS_EXCEPTION_IF_NULL(cnode);
503 return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
504 }
505
GetCNodeFuncGraphPtr(const AnfNodePtr & node)506 FuncGraphPtr AnfAlgo::GetCNodeFuncGraphPtr(const AnfNodePtr &node) {
507 MS_EXCEPTION_IF_NULL(node);
508 auto cnode = node->cast<CNodePtr>();
509 MS_EXCEPTION_IF_NULL(cnode);
510 auto attr_input = cnode->input(kAnfPrimitiveIndex);
511 MS_EXCEPTION_IF_NULL(attr_input);
512 auto value_node = attr_input->cast<ValueNodePtr>();
513 MS_EXCEPTION_IF_NULL(value_node);
514 auto value = value_node->value();
515 MS_EXCEPTION_IF_NULL(value);
516 return value->cast<FuncGraphPtr>();
517 }
518
// Forwards to AnfUtils::GetCNodeName for `node` and returns its result unchanged.
std::string AnfAlgo::GetCNodeName(const AnfNodePtr &node) {
  // this function was moved to AnfUtils.
  return AnfUtils::GetCNodeName(node);
}
523
IsGetNextNode(const AnfNodePtr & node)524 bool AnfAlgo::IsGetNextNode(const AnfNodePtr &node) {
525 auto node_name = AnfUtils::GetCNodeName(node);
526 return node_name == kGetNextOpName || node_name == kDynamicGetNextV2OpName;
527 }
528
// Returns node->DebugString(); raises when `node` is null.
std::string AnfAlgo::GetNodeDebugString(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->DebugString();
}
533
SetNodeAttr(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)534 void AnfAlgo::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
535 // this function was moved to AnfUtils.
536 return AnfUtils::SetNodeAttr(key, value, node);
537 }
538
SetNodeAttrSafely(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)539 void AnfAlgo::SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
540 // Make CNode safe to set attr firstly.
541 auto cnode = node->cast<CNodePtr>();
542 if (cnode == nullptr) {
543 return;
544 }
545 auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
546 if (prim != nullptr) {
547 auto new_prim = prim->isa<PrimitivePy>() ? prim : prim->Clone();
548 cnode->set_input(0, NewValueNode(new_prim));
549 }
550
551 // Set attr secondly.
552 common::AnfAlgo::SetNodeAttr(key, value, node);
553 }
554
// Copies attribute `key` from `from` to `to` under the same name, via the four-argument overload.
void AnfAlgo::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}
558
CopyNodeAttr(const std::string & old_key,const std::string & new_key,const AnfNodePtr & from,const AnfNodePtr & to)559 void AnfAlgo::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
560 const AnfNodePtr &to) {
561 MS_EXCEPTION_IF_NULL(from);
562 MS_EXCEPTION_IF_NULL(to);
563 if (!from->isa<CNode>() || !to->isa<CNode>()) {
564 MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
565 << to->DebugString() << trace::DumpSourceLines(from);
566 }
567 auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
568 MS_EXCEPTION_IF_NULL(from_primitive);
569 auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
570 MS_EXCEPTION_IF_NULL(to_primitive);
571 to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
572 }
573
CopyNodeAttrs(const AnfNodePtr & from,const AnfNodePtr & to)574 void AnfAlgo::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
575 MS_EXCEPTION_IF_NULL(from);
576 MS_EXCEPTION_IF_NULL(to);
577 if (!from->isa<CNode>() || !to->isa<CNode>()) {
578 MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
579 << from->DebugString() << trace::DumpSourceLines(from);
580 }
581 auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
582 MS_EXCEPTION_IF_NULL(from_primitive);
583 auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
584 MS_EXCEPTION_IF_NULL(to_primitive);
585 auto from_cnode = from->cast<CNodePtr>();
586 auto to_cnode = to->cast<CNodePtr>();
587 if (from_cnode->HasPrimalAttr(kAttrMicro)) {
588 to_cnode->AddPrimalAttr(kAttrMicro, from_cnode->GetPrimalAttr(kAttrMicro));
589 }
590 (void)to_primitive->SetAttrs(from_primitive->attrs());
591 }
592
EraseNodeAttr(const std::string & key,const AnfNodePtr & node)593 void AnfAlgo::EraseNodeAttr(const std::string &key, const AnfNodePtr &node) {
594 MS_EXCEPTION_IF_NULL(node);
595 if (!node->isa<CNode>()) {
596 MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
597 << trace::DumpSourceLines(node);
598 }
599 // single op cnode.
600 auto primitive = AnfAlgo::GetCNodePrimitive(node);
601 if (primitive != nullptr) {
602 primitive->EraseAttr(key);
603 return;
604 }
605 // graph kernel cnode.
606 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
607 MS_EXCEPTION_IF_NULL(fg);
608 fg->erase_flag(key);
609 }
610
HasNodeAttr(const std::string & key,const CNodePtr & node)611 bool AnfAlgo::HasNodeAttr(const std::string &key, const CNodePtr &node) {
612 MS_EXCEPTION_IF_NULL(node);
613 // call node's input0 is not a primitive.
614 if (!IsValueNode<FuncGraph>(node->input(0)) && !IsValueNode<Primitive>(node->input(0))) {
615 return false;
616 }
617 // single op cnode.
618 auto primitive = AnfAlgo::GetCNodePrimitive(node);
619 if (primitive != nullptr) {
620 return primitive->HasAttr(key);
621 }
622 // graph kernel cnode.
623 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
624 MS_EXCEPTION_IF_NULL(fg);
625 return fg->has_attr(key);
626 }
627
GetInputNum(const CNodePtr & cnode)628 size_t AnfAlgo::GetInputNum(const CNodePtr &cnode) {
629 MS_EXCEPTION_IF_NULL(cnode);
630 size_t input_num = cnode->size();
631 if (input_num == 0) {
632 MS_LOG(INTERNAL_EXCEPTION) << "Cnode inputs size can't be zero." << trace::DumpSourceLines(cnode);
633 }
634 return input_num - 1;
635 }
636
// Forwards to AnfUtils::GetInputTensorNum for `node` and returns its result unchanged.
size_t AnfAlgo::GetInputTensorNum(const AnfNodePtr &node) {
  // this function was moved to AnfUtils.
  return AnfUtils::GetInputTensorNum(node);
}
641
IsPrevNodeHasTupleGetItem(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)642 bool AnfAlgo::IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
643 if (!anf_node->isa<CNode>()) {
644 MS_LOG(INTERNAL_EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
645 << trace::DumpSourceLines(anf_node);
646 }
647 auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
648 MS_EXCEPTION_IF_NULL(input_node);
649 auto res = VisitKernelWithReturnType(input_node, 0, skip_nop_node, {prim::kPrimTupleGetItem});
650 if (CheckPrimitiveType(res.first, prim::kPrimTupleGetItem)) {
651 return true;
652 }
653 return false;
654 }
655
GetPrevNodeOutput(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)656 KernelWithIndex AnfAlgo::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
657 MS_EXCEPTION_IF_NULL(anf_node);
658 if (!anf_node->isa<CNode>()) {
659 MS_LOG(INTERNAL_EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
660 << trace::DumpSourceLines(anf_node);
661 }
662 auto kernel_info = anf_node->kernel_info();
663 if (kernel_info) {
664 auto runtime_cache = kernel_info->runtime_cache();
665 if (runtime_cache.runtime_cache().is_valid()) {
666 auto output = runtime_cache.runtime_cache().get_prev_node_output(input_idx);
667 if (output.first != nullptr) {
668 return output;
669 }
670 }
671 }
672 KernelWithIndex res;
673 if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
674 res = VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
675 } else {
676 auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
677 MS_EXCEPTION_IF_NULL(input_node);
678 res = VisitKernelWithReturnType(input_node, 0, skip_nop_node);
679 }
680 if (kernel_info) {
681 auto runtime_cache = kernel_info->runtime_cache();
682 if (runtime_cache.runtime_cache().is_valid()) {
683 runtime_cache.runtime_cache().set_prev_node_output(input_idx, res);
684 }
685 }
686 return res;
687 }
688
689 // if the prev_node is MakeTuple, get all the input_nodes recursively, else use the ori GetPrevNodeOutput function
GetRealPrevNodesOutput(const AnfNodePtr & anf_node,size_t input_idx,bool skip_nop_node)690 std::vector<KernelWithIndex> AnfAlgo::GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
691 bool skip_nop_node) {
692 MS_EXCEPTION_IF_NULL(anf_node);
693 auto cnode = anf_node->cast<CNodePtr>();
694 MS_EXCEPTION_IF_NULL(cnode);
695
696 std::vector<KernelWithIndex> res;
697 auto input_node = AnfAlgo::GetInputNode(cnode, input_idx);
698 MS_EXCEPTION_IF_NULL(input_node);
699 if (CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
700 auto maketuple_input_num = GetInputTensorNum(input_node);
701 for (size_t i = 0; i < maketuple_input_num; ++i) {
702 auto inputs_i = GetRealPrevNodesOutput(input_node, i, skip_nop_node);
703 (void)res.insert(res.end(), inputs_i.begin(), inputs_i.end());
704 }
705 } else {
706 (void)res.emplace_back(GetPrevNodeOutput(cnode, input_idx, skip_nop_node));
707 }
708 return res;
709 }
710
GetRealPrevNodesOutputInferDataType(const AnfNodePtr & node,size_t input_idx)711 std::vector<TypeId> AnfAlgo::GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
712 std::vector<KernelWithIndex> kernels_with_index = AnfAlgo::GetRealPrevNodesOutput(node, input_idx);
713 std::vector<TypeId> res;
714 (void)std::transform(kernels_with_index.begin(), kernels_with_index.end(), std::back_inserter(res),
715 [](auto kernel_with_index) {
716 return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
717 });
718 return res;
719 }
720
721 namespace {
GetShape(const abstract::BaseShapePtr & base_shape)722 inline ShapeVector GetShape(const abstract::BaseShapePtr &base_shape) {
723 MS_EXCEPTION_IF_NULL(base_shape);
724 if (base_shape->isa<abstract::Shape>()) {
725 auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
726 MS_EXCEPTION_IF_NULL(shape_ptr);
727 return shape_ptr->shape();
728 }
729 return {};
730 }
731
GetOutputShape(const abstract::AbstractBasePtr & abstract,size_t output_idx,bool is_real_squence_output)732 ShapeVector GetOutputShape(const abstract::AbstractBasePtr &abstract, size_t output_idx, bool is_real_squence_output) {
733 MS_EXCEPTION_IF_NULL(abstract);
734 if (abstract->isa<abstract::AbstractTensor>() || abstract->isa<abstract::AbstractMapTensor>()) {
735 if (output_idx != 0) {
736 MS_LOG(INTERNAL_EXCEPTION) << "The abstract " << abstract->ToString()
737 << "is single output but got index:" << output_idx;
738 }
739 const auto &shape = abstract->GetShape();
740 return GetShape(shape);
741 } else if (abstract->isa<abstract::AbstractScalar>() || abstract->isa<abstract::AbstractMonad>()) {
742 return ShapeVector();
743 } else if (abstract->isa<abstract::AbstractSparseTensor>()) {
744 const auto &shape = abstract->GetShape();
745 MS_EXCEPTION_IF_NULL(shape);
746 const auto &tuple_shape = shape->cast<abstract::TupleShapePtr>();
747 MS_EXCEPTION_IF_NULL(tuple_shape);
748 if (output_idx >= tuple_shape->size()) {
749 MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << "is larger than output number "
750 << tuple_shape->size() << " of tuple shape:" << tuple_shape->ToString()
751 << " in abstract:" << abstract;
752 }
753 return GetShape(tuple_shape->shape()[output_idx]);
754 }
755
756 if (!abstract->isa<abstract::AbstractSequence>()) {
757 MS_LOG(INFO) << "Unknown abstract for get shape:" << abstract->ToString();
758 return {};
759 }
760
761 const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
762 MS_EXCEPTION_IF_NULL(sequence_abstract);
763 if (sequence_abstract->dynamic_len()) {
764 const auto &element_abstract = sequence_abstract->dynamic_len_element_abs();
765 if (element_abstract == nullptr) {
766 MS_LOG(ERROR) << "Invalid abstract for get shape:" << sequence_abstract->ToString();
767 return ShapeVector();
768 }
769 return GetOutputShape(element_abstract, 0, true);
770 }
771
772 if (sequence_abstract->size() == 0) {
773 return ShapeVector();
774 }
775
776 if (!is_real_squence_output) {
777 if (output_idx >= sequence_abstract->size()) {
778 MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << "is larger than output number "
779 << sequence_abstract->size() << " of abstract:" << sequence_abstract->ToString();
780 }
781 MS_EXCEPTION_IF_NULL(sequence_abstract->elements()[output_idx]);
782 return GetOutputShape(sequence_abstract->elements()[output_idx], 0, true);
783 }
784
785 // For real sequence output, if the inner elements' shape is same, the output is {element_num, *actual_shape},
786 // otherwise is {element_num, inner_max_size}.
787 // For example:
788 // 1) Output abstract: ((3,4,5), (3,4,5)), output shape: (2, 3, 4, 5).
789 // 2) Output abstract: ((3,4,5), (3,4,6)), output shape: (2, 72).
790 ShapeVector elem_shape_vector;
791 size_t change_cnt = 0;
792 ShapeValueDType elem_size = 0;
793 for (const auto &elem_abs : sequence_abstract->elements()) {
794 MS_EXCEPTION_IF_NULL(elem_abs);
795 elem_shape_vector = GetOutputShape(elem_abs, 0, true);
796 auto cur_size = std::accumulate(elem_shape_vector.begin(), elem_shape_vector.end(), 1L, std::multiplies<int64_t>());
797 if (elem_size < cur_size) {
798 elem_size = cur_size;
799 ++change_cnt;
800 }
801 }
802
803 ShapeVector shape_vector = {SizeToLong(sequence_abstract->size())};
804 if (change_cnt == 1) {
805 (void)shape_vector.insert(shape_vector.end(), elem_shape_vector.begin(), elem_shape_vector.end());
806 } else {
807 shape_vector.push_back(elem_size);
808 }
809 return shape_vector;
810 }
811 } // namespace
812
GetOutputInferShape(const AnfNodePtr & node,size_t output_idx,bool is_real_squence_output)813 ShapeVector AnfAlgo::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx, bool is_real_squence_output) {
814 MS_EXCEPTION_IF_NULL(node);
815 return GetOutputShape(node->abstract(), output_idx, is_real_squence_output || AnfAlgo::IsDynamicSequence(node));
816 }
817
GetPrevNodeOutputInferShape(const AnfNodePtr & node,size_t input_idx)818 ShapeVector AnfAlgo::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
819 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
820 return AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
821 }
822
GetOutputInferType(const AnfNodePtr & node,size_t output_idx,bool is_real_tuple)823 TypePtr AnfAlgo::GetOutputInferType(const AnfNodePtr &node, size_t output_idx, bool is_real_tuple) {
824 MS_EXCEPTION_IF_NULL(node);
825 MS_EXCEPTION_IF_NULL(node->abstract());
826 const auto &type = node->abstract()->BuildType();
827 MS_EXCEPTION_IF_NULL(type);
828 if (!type->isa<Tuple>() && !type->isa<List>()) {
829 if (output_idx != 0) {
830 MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for node:" << node->DebugString()
831 << " abstract:" << node->abstract()->ToString() << " type:" << type->ToString();
832 }
833 return type;
834 }
835 if (is_real_tuple) {
836 return type;
837 }
838 if (type->isa<Tuple>()) {
839 const auto &tuple_type = type->cast<TuplePtr>();
840 MS_EXCEPTION_IF_NULL(tuple_type);
841 if (tuple_type->dynamic_len()) {
842 if (output_idx != 0) {
843 MS_LOG(EXCEPTION) << "Failed to get type by index:" << output_idx << " type:" << type->ToString();
844 }
845 return tuple_type;
846 }
847 if (output_idx >= tuple_type->size()) {
848 MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for node:" << node->DebugString()
849 << " abstract:" << node->abstract()->ToString() << " type:" << type->ToString();
850 }
851 return tuple_type->elements()[output_idx];
852 }
853 const auto &list_type = type->cast<ListPtr>();
854 MS_EXCEPTION_IF_NULL(list_type);
855 if (list_type->dynamic_len()) {
856 if (output_idx != 0) {
857 MS_LOG(EXCEPTION) << "Failed to get type by index:" << output_idx << " type:" << type->ToString();
858 }
859 return list_type;
860 }
861 if (output_idx >= list_type->size()) {
862 MS_LOG(EXCEPTION) << "Invalid index:" << output_idx << " for node:" << node->DebugString()
863 << " abstract:" << node->abstract()->ToString() << " type:" << type->ToString();
864 }
865 return list_type->elements()[output_idx];
866 }
867
GetOutputInferDataType(const TypePtr & type,size_t output_idx)868 TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
869 auto type_ptr = type;
870 MS_EXCEPTION_IF_NULL(type_ptr);
871 if (type_ptr->isa<Tuple>()) {
872 auto tuple_ptr = type_ptr->cast<TuplePtr>();
873 MS_EXCEPTION_IF_NULL(tuple_ptr);
874 if (tuple_ptr->size() == 0) {
875 if (tuple_ptr->dynamic_len() && tuple_ptr->dynamic_element_type() != nullptr) {
876 MS_LOG(INFO) << "Dynamic empty tuple type has an dynamic element type:"
877 << tuple_ptr->dynamic_element_type()->type_id();
878 return tuple_ptr->dynamic_element_type()->type_id();
879 }
880 return kTypeUnknown;
881 }
882 if (tuple_ptr->dynamic_len()) {
883 MS_EXCEPTION_IF_NULL(tuple_ptr->dynamic_element_type());
884 return GetOutputInferDataType(tuple_ptr->dynamic_element_type(), 0);
885 }
886 MS_EXCEPTION_IF_NULL(tuple_ptr);
887 if (output_idx >= tuple_ptr->size()) {
888 MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << " must be less than output number "
889 << tuple_ptr->size();
890 }
891 type_ptr = (*tuple_ptr)[output_idx];
892 MS_EXCEPTION_IF_NULL(type_ptr);
893 }
894
895 if (type_ptr->isa<List>()) {
896 auto list_ptr = type_ptr->cast<ListPtr>();
897 MS_EXCEPTION_IF_NULL(list_ptr);
898 if (list_ptr->size() == 0) {
899 if (list_ptr->dynamic_len() && list_ptr->dynamic_element_type() != nullptr) {
900 MS_LOG(INFO) << "Dynamic empty list type has an dynamic element type:"
901 << list_ptr->dynamic_element_type()->type_id();
902 return list_ptr->dynamic_element_type()->type_id();
903 }
904 return kTypeUnknown;
905 }
906 if (list_ptr->dynamic_len()) {
907 MS_EXCEPTION_IF_NULL(list_ptr->dynamic_element_type());
908 return GetOutputInferDataType(list_ptr->dynamic_element_type(), 0);
909 }
910 MS_EXCEPTION_IF_NULL(list_ptr);
911 if (output_idx >= list_ptr->size()) {
912 MS_LOG(INTERNAL_EXCEPTION) << "Output index " << output_idx << " must be less than output number "
913 << list_ptr->size();
914 }
915 type_ptr = (*list_ptr)[output_idx];
916 MS_EXCEPTION_IF_NULL(type_ptr);
917 }
918
919 if (type_ptr->isa<SparseTensorType>()) {
920 auto tensor_ptr = type_ptr->cast<SparseTensorTypePtr>();
921 MS_EXCEPTION_IF_NULL(tensor_ptr);
922 type_ptr = (*tensor_ptr)[output_idx];
923 MS_EXCEPTION_IF_NULL(type_ptr);
924 }
925
926 if (type_ptr->isa<TensorType>()) {
927 auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
928 MS_EXCEPTION_IF_NULL(tensor_ptr);
929 TypePtr elem = tensor_ptr->element();
930 MS_EXCEPTION_IF_NULL(elem);
931 return elem->type_id();
932 }
933 if (type_ptr->isa<Tuple>() || type_ptr->isa<List>()) {
934 return GetOutputInferDataType(type_ptr, 0);
935 }
936 return type_ptr->type_id();
937 }
938
939 namespace {
IsTupleInTupleValueNode(const AnfNodePtr & node)940 bool IsTupleInTupleValueNode(const AnfNodePtr &node) {
941 if (node == nullptr || !node->isa<ValueNode>()) {
942 return false;
943 }
944 const auto &value_node = node->cast<ValueNodePtr>();
945 MS_EXCEPTION_IF_NULL(value_node);
946 const auto &value = value_node->value();
947 if (value == nullptr || !value->isa<ValueSequence>()) {
948 return false;
949 }
950 const auto &value_sequence = value->cast<ValueSequencePtr>();
951 MS_EXCEPTION_IF_NULL(value_sequence);
952 return std::any_of(value_sequence->value().begin(), value_sequence->value().end(),
953 [](const ValuePtr &sub_value) { return sub_value != nullptr && sub_value->isa<ValueSequence>(); });
954 }
955 } // namespace
956
GetOutputInferDataType(const AnfNodePtr & node,size_t output_idx)957 TypeId AnfAlgo::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
958 MS_EXCEPTION_IF_NULL(node);
959 if (IsCallNode(node) || IsTupleInTupleValueNode(node)) {
960 if (node->abstract() == nullptr) {
961 MS_LOG(INTERNAL_EXCEPTION) << "Empty abstract of call node:" << node->DebugString();
962 }
963 const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), output_idx);
964 MS_EXCEPTION_IF_NULL(abs);
965 const auto &type = abs->BuildType();
966 MS_EXCEPTION_IF_NULL(type);
967 if (type->isa<TensorType>()) {
968 const auto &tensor_type = type->cast<TensorTypePtr>();
969 MS_EXCEPTION_IF_NULL(tensor_type);
970 const auto &element = tensor_type->element();
971 return element->type_id();
972 } else {
973 return type->type_id();
974 }
975 }
976 return GetOutputInferDataType(node->Type(), output_idx);
977 }
978
GetPrevNodeOutputInferDataType(const AnfNodePtr & node,size_t input_idx)979 TypeId AnfAlgo::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
980 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
981 return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
982 }
983
GetPrevNodeOutputInferType(const AnfNodePtr & node,size_t input_idx)984 TypePtr AnfAlgo::GetPrevNodeOutputInferType(const AnfNodePtr &node, size_t input_idx) {
985 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
986 return AnfAlgo::GetOutputInferType(kernel_with_index.first, kernel_with_index.second);
987 }
988
989 // set infer shapes and types of anf node
SetOutputTypeAndDetailShape(const std::vector<TypeId> & types,const std::vector<abstract::BaseShapePtr> & shapes,AnfNode * node)990 void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
991 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) {
992 MS_EXCEPTION_IF_NULL(node);
993 auto node_ptr = node->cast<AnfNodePtr>();
994 MS_EXCEPTION_IF_NULL(node_ptr);
995 std::string node_name = "";
996 if (node_ptr->isa<CNode>()) {
997 node_name = GetCNodeName(node_ptr);
998 }
999 if (types.size() != shapes.size()) {
1000 MS_LOG(INTERNAL_EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1001 << " for node " << node->fullname_with_scope() << "." << trace::DumpSourceLines(node);
1002 }
1003
1004 auto tuple_node = kNodeTupleOutSet.find(node_name);
1005 if (shapes.empty() && tuple_node == kNodeTupleOutSet.end()) {
1006 node->set_abstract(std::make_shared<abstract::AbstractNone>());
1007 } else if (shapes.size() == 1 && tuple_node == kNodeTupleOutSet.end()) {
1008 // single output handle
1009 if (shapes[0]->isa<abstract::NoShape>()) {
1010 auto abstract = std::make_shared<abstract::AbstractScalar>(TypeIdToType(types[0]));
1011 node->set_abstract(abstract);
1012 } else {
1013 auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1014 node->set_abstract(abstract);
1015 }
1016 } else {
1017 // multiple output handle
1018 std::vector<AbstractBasePtr> abstract_list;
1019 for (size_t i = 0; i < types.size(); ++i) {
1020 if (shapes[0]->isa<abstract::NoShape>()) {
1021 auto abstract = std::make_shared<abstract::AbstractScalar>(TypeIdToType(types[i]));
1022 abstract_list.emplace_back(abstract);
1023 } else {
1024 auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shapes[i]);
1025 abstract_list.emplace_back(abstract);
1026 }
1027 }
1028 auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1029 node->set_abstract(abstract_tuple);
1030 }
1031 }
1032
SetSingleOutputTypeAndDetailShape(const std::vector<TypeId> & types,const std::vector<abstract::BaseShapePtr> & shapes,AnfNode * node)1033 void AnfAlgo::SetSingleOutputTypeAndDetailShape(const std::vector<TypeId> &types,
1034 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node) {
1035 MS_EXCEPTION_IF_NULL(node);
1036 auto node_ptr = node->cast<AnfNodePtr>();
1037 MS_EXCEPTION_IF_NULL(node_ptr);
1038 if (types.size() != shapes.size()) {
1039 MS_LOG(INTERNAL_EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1040 << " for node " << node->fullname_with_scope() << "." << trace::DumpSourceLines(node);
1041 }
1042 auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1043 node->set_abstract(abstract);
1044 }
1045
1046 namespace {
DeleteDynamicLen(AnfNode * node)1047 void DeleteDynamicLen(AnfNode *node) {
1048 MS_EXCEPTION_IF_NULL(node);
1049 if (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractSequence>()) {
1050 const auto &tuple_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
1051 MS_EXCEPTION_IF_NULL(tuple_abs);
1052 if (tuple_abs->dynamic_len()) {
1053 auto cloned_abstract = tuple_abs->Clone()->cast<abstract::AbstractSequencePtr>();
1054 cloned_abstract->set_dynamic_len(false);
1055 node->set_abstract(cloned_abstract);
1056 }
1057 }
1058 }
1059 } // namespace
1060
1061 // set infer shapes and types of anf node
SetOutputInferTypeAndShape(const std::vector<TypeId> & types,const std::vector<ShapeVector> & shapes,AnfNode * node,bool disable_dynamic_len)1062 void AnfAlgo::SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
1063 AnfNode *node, bool disable_dynamic_len) {
1064 MS_EXCEPTION_IF_NULL(node);
1065 if (disable_dynamic_len) {
1066 DeleteDynamicLen(node);
1067 }
1068 auto node_ptr = node->cast<AnfNodePtr>();
1069 MS_EXCEPTION_IF_NULL(node_ptr);
1070 std::string node_name = "";
1071 if (node_ptr->isa<CNode>()) {
1072 node_name = GetCNodeName(node_ptr);
1073 }
1074 if (types.size() != shapes.size()) {
1075 MS_LOG(INTERNAL_EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
1076 << "." << trace::DumpSourceLines(node);
1077 }
1078 auto abstract_ptr = node_ptr->abstract();
1079
1080 auto tuple_node = kNodeTupleOutSet.find(node_name);
1081 if (shapes.empty() && tuple_node == kNodeTupleOutSet.end()) {
1082 node->set_abstract(std::make_shared<abstract::AbstractNone>());
1083 } else if (shapes.size() == 1 && tuple_node == kNodeTupleOutSet.end()) {
1084 // single output handle
1085 if (abstract_ptr != nullptr && abstract_ptr->isa<abstract::AbstractMapTensor>()) {
1086 // For AbstractMapTensor.
1087 abstract_ptr->set_shape(std::make_shared<abstract::Shape>(shapes[0]));
1088 return;
1089 }
1090
1091 abstract::AbstractTensorPtr abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shapes[0]);
1092 node->set_abstract(abstract);
1093 } else {
1094 // multiple output handle
1095 std::vector<AbstractBasePtr> abstract_list;
1096 for (size_t i = 0; i < types.size(); ++i) {
1097 abstract::AbstractTensorPtr abstract =
1098 std::make_shared<AbstractTensor>(TypeIdToType(types[i]), std::make_shared<abstract::Shape>(shapes[i]));
1099 abstract_list.emplace_back(abstract);
1100 }
1101 auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
1102 node->set_abstract(abstract_tuple);
1103 }
1104 }
1105 // copy an abstract of a node to another node
CopyAbstract(const AnfNodePtr & from_node,AnfNode * to_node)1106 void AnfAlgo::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
1107 MS_EXCEPTION_IF_NULL(from_node);
1108 MS_EXCEPTION_IF_NULL(to_node);
1109 to_node->set_abstract(from_node->abstract());
1110 }
1111
IsNodeInGraphKernel(const AnfNodePtr & node)1112 bool AnfAlgo::IsNodeInGraphKernel(const AnfNodePtr &node) {
1113 // this function was moved to AnfUtils.
1114 return AnfUtils::IsNodeInGraphKernel(node);
1115 }
1116
GetOutputOfGraphkernel(const KernelWithIndex & kernel_with_index)1117 AnfNodePtr AnfAlgo::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {
1118 auto func_graph = GetCNodeFuncGraph(kernel_with_index.first);
1119 if (func_graph == nullptr) {
1120 return kernel_with_index.first;
1121 }
1122 auto output = func_graph->output();
1123 if (CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
1124 return output->cast<CNodePtr>()->input(kernel_with_index.second + 1);
1125 }
1126 return output;
1127 }
1128
IsParameterWeight(const ParameterPtr & node)1129 bool AnfAlgo::IsParameterWeight(const ParameterPtr &node) {
1130 MS_EXCEPTION_IF_NULL(node);
1131 return node->has_default();
1132 }
1133
IsLabelIndexInNode(const AnfNodePtr & node,size_t label_index)1134 bool AnfAlgo::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) {
1135 MS_EXCEPTION_IF_NULL(node);
1136 if (!node->isa<CNode>()) {
1137 return false;
1138 }
1139 auto cnode = node->cast<CNodePtr>();
1140 MS_EXCEPTION_IF_NULL(cnode);
1141 if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName &&
1142 (AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) {
1143 return true;
1144 } else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
1145 auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
1146 if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) {
1147 return true;
1148 }
1149 }
1150 return false;
1151 }
1152
IsUpdateParameterKernel(const CNodePtr & node)1153 bool AnfAlgo::IsUpdateParameterKernel(const CNodePtr &node) {
1154 MS_EXCEPTION_IF_NULL(node);
1155 auto node_name = GetCNodeName(node);
1156 if (HasNodeAttr(kAttrAsync, node) && GetNodeAttr<bool>(node, kAttrAsync)) {
1157 return false;
1158 }
1159 if (!IsOneOfOperator(node_name) && node_name.find("Assign") == string::npos) {
1160 return false;
1161 }
1162 return true;
1163 }
1164
IsTupleOutput(const AnfNodePtr & anf)1165 bool AnfAlgo::IsTupleOutput(const AnfNodePtr &anf) {
1166 MS_EXCEPTION_IF_NULL(anf);
1167 TypePtr type = anf->Type();
1168 if (type == nullptr) {
1169 return false;
1170 }
1171
1172 // For dynamic sequence node, all output should be emplaced in single tensor.
1173 if (anf->abstract() && IsDynamicSequence(anf)) {
1174 return false;
1175 }
1176
1177 MS_EXCEPTION_IF_NULL(type);
1178 return type->isa<Tuple>() || type->isa<List>() || type->isa<SparseTensorType>();
1179 }
1180
GetInputNode(const CNodePtr & node,size_t index)1181 AnfNodePtr AnfAlgo::GetInputNode(const CNodePtr &node, size_t index) {
1182 MS_EXCEPTION_IF_NULL(node);
1183 auto get_input_index = index + 1;
1184 if (get_input_index >= node->size()) {
1185 MS_LOG(INTERNAL_EXCEPTION) << "Input index size " << get_input_index << ", but the node input size just "
1186 << node->size() << ". node: " << node->DebugString() << "."
1187 << trace::DumpSourceLines(node);
1188 }
1189 // input 0 is primitive node
1190 return node->input(get_input_index);
1191 }
1192
SetNodeInput(const CNodePtr & node,const AnfNodePtr & input_node,size_t index)1193 void AnfAlgo::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
1194 MS_EXCEPTION_IF_NULL(node);
1195 MS_EXCEPTION_IF_NULL(input_node);
1196 if (node->func_graph() != nullptr) {
1197 auto manager = node->func_graph()->manager();
1198 if (manager != nullptr) {
1199 manager->SetEdge(node, SizeToInt(index + 1), input_node);
1200 return;
1201 }
1202 }
1203 node->set_input(index + 1, input_node);
1204 }
1205
GetCNodePrimitiveNode(const CNodePtr & node)1206 AnfNodePtr AnfAlgo::GetCNodePrimitiveNode(const CNodePtr &node) {
1207 MS_EXCEPTION_IF_NULL(node);
1208 return node->input(kAnfPrimitiveIndex);
1209 }
1210
GetCNodePrimitive(const AnfNodePtr & node)1211 PrimitivePtr AnfAlgo::GetCNodePrimitive(const AnfNodePtr &node) {
1212 MS_EXCEPTION_IF_NULL(node);
1213 auto cnode = node->cast<CNodePtr>();
1214 MS_EXCEPTION_IF_NULL(cnode);
1215 auto attr_input = GetCNodePrimitiveNode(cnode);
1216 MS_EXCEPTION_IF_NULL(attr_input);
1217 auto value_node = attr_input->cast<ValueNodePtr>();
1218 MS_EXCEPTION_IF_NULL(value_node);
1219 auto value = value_node->value();
1220 MS_EXCEPTION_IF_NULL(value);
1221 auto primitive = value->cast<PrimitivePtr>();
1222 return primitive;
1223 }
1224
IsInplaceNode(const mindspore::AnfNodePtr & kernel,const string & type)1225 bool AnfAlgo::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type) {
1226 MS_EXCEPTION_IF_NULL(kernel);
1227 auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
1228 if (!primitive) {
1229 return false;
1230 }
1231
1232 auto inplace_attr = primitive->GetAttr(type);
1233 if (inplace_attr == nullptr) {
1234 return false;
1235 }
1236
1237 return true;
1238 }
1239
IsCommunicationOp(const AnfNodePtr & node)1240 bool AnfAlgo::IsCommunicationOp(const AnfNodePtr &node) {
1241 static const std::set<std::string> kCommunicationOpNames = {
1242 kAllReduceOpName, kAllGatherOpName, kBroadcastOpName, kReduceScatterOpName, kSendOpName,
1243 kReceiveOpName, kAlltoAllOpName, kAllToAllOpName, kAllToAllvOpName, kMuxReceiveOpName,
1244 kMuxSendOpName, kReduceOpName, kBarrierOpName, kCollectiveScatterOpName, kCollectiveGatherOpName,
1245 kMatMulAllReduceOpName, kBatchISendIRecvOpName, kAlltoAllVOpName};
1246 MS_EXCEPTION_IF_NULL(node);
1247 if (!node->isa<CNode>()) {
1248 return false;
1249 }
1250 auto kernel_name = AnfAlgo::GetCNodeName(node);
1251 return (kCommunicationOpNames.find(kernel_name) != kCommunicationOpNames.end());
1252 }
1253
IsDtypeFormatSensitiveOp(const AnfNodePtr & node)1254 bool AnfAlgo::IsDtypeFormatSensitiveOp(const AnfNodePtr &node) {
1255 static const std::set<std::string> kDtypeFormatSensitiveOpNames = {kCastOpName};
1256 MS_EXCEPTION_IF_NULL(node);
1257 if (!node->isa<CNode>()) {
1258 return false;
1259 }
1260 auto kernel_name = AnfAlgo::GetCNodeName(node);
1261 return (kDtypeFormatSensitiveOpNames.find(kernel_name) != kDtypeFormatSensitiveOpNames.end());
1262 }
1263
IsFusedCommunicationOp(const AnfNodePtr & node)1264 bool AnfAlgo::IsFusedCommunicationOp(const AnfNodePtr &node) {
1265 if (!IsCommunicationOp(node)) {
1266 return false;
1267 }
1268 auto primitive = AnfAlgo::GetCNodePrimitive(node);
1269 MS_EXCEPTION_IF_NULL(primitive);
1270 ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
1271 ValuePtr attr_not_delay_fusion = primitive->GetAttr(kAttrNotDelayFusion);
1272 if (attr_fusion == nullptr) {
1273 return false;
1274 }
1275
1276 auto fusion = GetValue<int64_t>(attr_fusion);
1277 if (fusion == 0) {
1278 return false;
1279 }
1280 if (attr_not_delay_fusion && GetValue<bool>(attr_not_delay_fusion)) {
1281 return false;
1282 }
1283 return true;
1284 }
1285
IsGetNext(const NotNull<AnfNodePtr> & node)1286 bool AnfAlgo::IsGetNext(const NotNull<AnfNodePtr> &node) {
1287 auto kernel_name = AnfAlgo::GetCNodeName(node);
1288 return kernel_name == kGetNextOpName || kernel_name == kDynamicGetNextV2OpName;
1289 }
1290
IsGraphKernel(const AnfNodePtr & node)1291 bool AnfAlgo::IsGraphKernel(const AnfNodePtr &node) {
1292 // this function was moved to AnfUtils.
1293 return AnfUtils::IsGraphKernel(node);
1294 }
1295
IsNeedSkipNopOpAddr(const AnfNodePtr & node)1296 bool AnfAlgo::IsNeedSkipNopOpAddr(const AnfNodePtr &node) {
1297 MS_EXCEPTION_IF_NULL(node);
1298 if (!node->isa<CNode>()) {
1299 return false;
1300 }
1301
1302 auto primitive = AnfAlgo::GetCNodePrimitive(node);
1303 if (primitive == nullptr) {
1304 return false;
1305 }
1306
1307 auto skip_nop_op_addr_attr = primitive->GetAttr(kAttrSkipNopOpAddr);
1308 if (skip_nop_op_addr_attr == nullptr) {
1309 return false;
1310 }
1311
1312 return GetValue<bool>(skip_nop_op_addr_attr);
1313 }
1314
IsNeedSkipNopOpExecution(const AnfNodePtr & node)1315 bool AnfAlgo::IsNeedSkipNopOpExecution(const AnfNodePtr &node) {
1316 MS_EXCEPTION_IF_NULL(node);
1317 if (!node->isa<CNode>()) {
1318 return false;
1319 }
1320
1321 auto primitive = AnfAlgo::GetCNodePrimitive(node);
1322 if (primitive == nullptr) {
1323 return false;
1324 }
1325
1326 auto skip_nop_execution_attr = primitive->GetAttr(kAttrSkipNopOpExecution);
1327 if (skip_nop_execution_attr == nullptr) {
1328 return false;
1329 }
1330
1331 return GetValue<bool>(skip_nop_execution_attr);
1332 }
1333
GetValueNodeFuncGraph(const AnfNodePtr & node)1334 FuncGraphPtr AnfAlgo::GetValueNodeFuncGraph(const AnfNodePtr &node) {
1335 MS_EXCEPTION_IF_NULL(node);
1336 auto value_node = node->cast<ValueNodePtr>();
1337 if (value_node == nullptr) {
1338 return nullptr;
1339 }
1340 auto value = value_node->value();
1341 if (value == nullptr) {
1342 return nullptr;
1343 }
1344 auto func_graph = value->cast<FuncGraphPtr>();
1345 return func_graph;
1346 }
1347
IsSwitchCall(const CNodePtr & call_node)1348 bool AnfAlgo::IsSwitchCall(const CNodePtr &call_node) {
1349 MS_EXCEPTION_IF_NULL(call_node);
1350 if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
1351 MS_LOG(INTERNAL_EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString() << "."
1352 << trace::DumpSourceLines(call_node);
1353 }
1354 auto input1 = call_node->input(1);
1355 MS_EXCEPTION_IF_NULL(input1);
1356 if (input1->isa<ValueNode>()) {
1357 return false;
1358 } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
1359 return true;
1360 }
1361 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString() << "."
1362 << trace::DumpSourceLines(call_node);
1363 }
1364
IsScalarInput(const CNodePtr & cnode,size_t index)1365 bool AnfAlgo::IsScalarInput(const CNodePtr &cnode, size_t index) {
1366 auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
1367 if (shape.empty()) {
1368 return true;
1369 }
1370 return shape.size() == kShape1dDims && shape[0] == 1;
1371 }
1372
IsScalarOutput(const CNodePtr & cnode,size_t index)1373 bool AnfAlgo::IsScalarOutput(const CNodePtr &cnode, size_t index) {
1374 auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
1375 if (shape.empty()) {
1376 return true;
1377 }
1378 return shape.size() == kShape1dDims && shape[0] == 1;
1379 }
1380
1381 namespace {
FindDelayExecPosition(const std::vector<CNodePtr> & nodes,size_t current_index,std::set<size_t> * invalid_position,std::map<size_t,std::vector<CNodePtr>> * insert_nodes)1382 void FindDelayExecPosition(const std::vector<CNodePtr> &nodes, size_t current_index, std::set<size_t> *invalid_position,
1383 std::map<size_t, std::vector<CNodePtr>> *insert_nodes) {
1384 MS_EXCEPTION_IF_NULL(invalid_position);
1385 MS_EXCEPTION_IF_NULL(insert_nodes);
1386 if (current_index >= nodes.size()) {
1387 return;
1388 }
1389 auto &node = nodes[current_index];
1390 for (size_t j = current_index + 1; j < nodes.size(); ++j) {
1391 auto &child = nodes[j];
1392 auto child_name = AnfAlgo::GetCNodeName(child);
1393 if (child_name == kAssignAddOpName || child_name == kAssignSubOpName || child_name == kAssignOpName ||
1394 IsOneOfOperator(child_name)) {
1395 return;
1396 }
1397
1398 auto input_size = child->size() - 1;
1399 for (size_t k = 0; k < input_size; ++k) {
1400 auto kernel_index = AnfAlgo::GetPrevNodeOutput(child, k, true);
1401 if (kernel_index.first != node) {
1402 continue;
1403 }
1404 (void)invalid_position->insert(current_index);
1405 auto iter = insert_nodes->find(j);
1406 if (iter != insert_nodes->end()) {
1407 iter->second.emplace_back(node);
1408 } else {
1409 (*insert_nodes)[j] = {node};
1410 }
1411 return;
1412 }
1413 }
1414 }
1415
DelayExecNode(const std::vector<CNodePtr> & nodes,const std::string & node_name,bool only_seed)1416 std::vector<CNodePtr> DelayExecNode(const std::vector<CNodePtr> &nodes, const std::string &node_name, bool only_seed) {
1417 std::map<size_t, std::vector<CNodePtr>> insert_nodes;
1418 std::set<size_t> invalid_position;
1419 for (size_t i = 0; i < nodes.size(); ++i) {
1420 auto &node = nodes[i];
1421 if (AnfAlgo::GetCNodeName(node) != node_name) {
1422 continue;
1423 }
1424 if (only_seed) {
1425 bool is_seed = true;
1426 auto input_size = node->size() - 1;
1427 for (size_t k = 0; k < input_size; ++k) {
1428 auto input = AnfAlgo::GetPrevNodeOutput(node, k, true).first;
1429 if (input != nullptr && input->isa<CNode>()) {
1430 is_seed = false;
1431 break;
1432 }
1433 }
1434 if (!is_seed) {
1435 continue;
1436 }
1437 }
1438 FindDelayExecPosition(nodes, i, &invalid_position, &insert_nodes);
1439 }
1440 std::vector<CNodePtr> result;
1441 for (size_t i = 0; i < nodes.size(); ++i) {
1442 auto iter = insert_nodes.find(i);
1443 if (iter != insert_nodes.end()) {
1444 (void)result.insert(result.end(), iter->second.rbegin(), iter->second.rend());
1445 }
1446 if (invalid_position.find(i) != invalid_position.end()) {
1447 continue;
1448 }
1449 result.emplace_back(nodes[i]);
1450 }
1451 return result;
1452 }
1453 } // namespace
1454
ReorderExecList(NotNull<std::vector<CNodePtr> * > node_list)1455 void AnfAlgo::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1456 std::vector<CNodePtr> result;
1457 std::copy(node_list->begin(), node_list->end(), std::back_inserter(result));
1458 result = DelayExecNode(result, kTransDataOpName, true);
1459 result = DelayExecNode(result, kCastOpName, true);
1460 result = DelayExecNode(result, kAdamApplyOneWithDecayOpName, false);
1461 result = DelayExecNode(result, kAdamApplyOneOpName, false);
1462 result = DelayExecNode(result, kQuantDTypeCastOpName, false);
1463 result = DelayExecNode(result, kFSEDecodeOpName, false);
1464 if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
1465 result = DelayExecNode(result, kDropoutGenMaskOpName, true);
1466 result = DelayExecNode(result, kStatelessDropOutGenMaskOpName, true);
1467 }
1468 node_list->clear();
1469 std::copy(result.begin(), result.end(), std::back_inserter(*node_list));
1470 }
1471
ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> * > node_list)1472 void AnfAlgo::ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list) {
1473 std::vector<CNodePtr> ordinary_node_list;
1474 std::vector<CNodePtr> posterior_node_list;
1475
1476 for (const auto &node : *node_list) {
1477 MS_EXCEPTION_IF_NULL(node);
1478 if (IsOneOfPosteriorOperator(AnfAlgo::GetCNodeName(node))) {
1479 posterior_node_list.emplace_back(node);
1480 } else {
1481 ordinary_node_list.emplace_back(node);
1482 }
1483 }
1484 node_list->clear();
1485 std::copy(ordinary_node_list.begin(), ordinary_node_list.end(), std::back_inserter(*node_list));
1486 std::copy(posterior_node_list.begin(), posterior_node_list.end(), std::back_inserter(*node_list));
1487 }
1488
GetCNodeOutputPrecision(const AnfNodePtr & node)1489 TypeId AnfAlgo::GetCNodeOutputPrecision(const AnfNodePtr &node) {
1490 MS_EXCEPTION_IF_NULL(node);
1491 auto prim = AnfAlgo::GetCNodePrimitive(node);
1492 if (prim == nullptr) {
1493 return kTypeUnknown;
1494 }
1495
1496 TypeId except_type = kTypeUnknown;
1497 if (prim->GetAttr(kAttrOutputPrecision) != nullptr) {
1498 auto output_type_str = GetValue<std::string>(prim->GetAttr(kAttrOutputPrecision));
1499 if (output_type_str == "float16") {
1500 except_type = kNumberTypeFloat16;
1501 } else if (output_type_str == "float32") {
1502 except_type = kNumberTypeFloat32;
1503 } else {
1504 MS_LOG(INTERNAL_EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str << "."
1505 << trace::DumpSourceLines(node);
1506 }
1507 }
1508
1509 return except_type;
1510 }
1511
GetPrevNodeOutputPrecision(const AnfNodePtr & node,size_t input_idx)1512 TypeId AnfAlgo::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) {
1513 MS_EXCEPTION_IF_NULL(node);
1514 if (!node->isa<CNode>()) {
1515 MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << ", input node is not CNode." << trace::DumpSourceLines(node);
1516 }
1517 auto cnode = node->cast<CNodePtr>();
1518 MS_EXCEPTION_IF_NULL(cnode);
1519 if (input_idx + 1 >= cnode->size()) {
1520 MS_LOG(INTERNAL_EXCEPTION) << "Input index " << input_idx << " is larger than input number "
1521 << GetInputTensorNum(cnode) << "." << trace::DumpSourceLines(node);
1522 }
1523 auto input_node = cnode->input(input_idx + 1);
1524 MS_EXCEPTION_IF_NULL(input_node);
1525 auto kernel_with_index = VisitKernel(input_node, 0);
1526 if (!kernel_with_index.first->isa<CNode>()) {
1527 return kTypeUnknown;
1528 }
1529 return GetCNodeOutputPrecision(kernel_with_index.first);
1530 }
1531
IsCondControlKernel(const CNodePtr & node)1532 bool AnfAlgo::IsCondControlKernel(const CNodePtr &node) {
1533 MS_EXCEPTION_IF_NULL(node);
1534 if (node->inputs().empty()) {
1535 MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode." << trace::DumpSourceLines(node);
1536 }
1537 auto input = node->input(kAnfPrimitiveIndex);
1538 return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
1539 }
1540
GetBooleanAttr(const AnfNodePtr & node,const std::string & attr)1541 bool AnfAlgo::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) {
1542 MS_EXCEPTION_IF_NULL(node);
1543 if (!node->isa<CNode>()) {
1544 return false;
1545 }
1546 auto cnode = node->cast<CNodePtr>();
1547 MS_EXCEPTION_IF_NULL(cnode);
1548 auto has_attr = AnfAlgo::HasNodeAttr(attr, cnode);
1549 if (!has_attr) {
1550 return false;
1551 }
1552 return AnfAlgo::GetNodeAttr<bool>(node, attr);
1553 }
1554
GetDumpFlag(const AnfNodePtr & node)1555 std::optional<string> AnfAlgo::GetDumpFlag(const AnfNodePtr &node) {
1556 MS_EXCEPTION_IF_NULL(node);
1557 auto cnode = node->cast<CNodePtr>();
1558 if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrDump, cnode)) {
1559 return {};
1560 }
1561 return std::optional<string>{AnfAlgo::GetNodeAttr<string>(node, kAttrDump)};
1562 }
1563
IsNodeDynamicRank(const AnfNodePtr & node)1564 bool IsNodeDynamicRank(const AnfNodePtr &node) {
1565 MS_EXCEPTION_IF_NULL(node);
1566 if (!node->isa<CNode>()) {
1567 MS_LOG(DEBUG) << "Node is not a cnode";
1568 return false;
1569 }
1570 auto cnode = node->cast<CNodePtr>();
1571 MS_EXCEPTION_IF_NULL(cnode);
1572 auto in_dyn_rank = AnfAlgo::IsNodeInputDynamicRank(cnode);
1573 auto out_dyn_rank = AnfAlgo::IsNodeOutputDynamicRank(cnode);
1574 if (in_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicRank, cnode)) {
1575 AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicRank, MakeValue(true), cnode);
1576 MS_LOG(DEBUG) << "Set input dynamic rank attr for node:" << cnode->fullname_with_scope();
1577 }
1578 if (out_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicRank, cnode)) {
1579 AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicRank, MakeValue(true), cnode);
1580 MS_LOG(DEBUG) << "Set output dynamic rank attr for node:" << cnode->fullname_with_scope();
1581 }
1582 return in_dyn_rank || out_dyn_rank;
1583 }
1584
IsDynamicRankNode(const AnfNodePtr & node)1585 bool AnfAlgo::IsDynamicRankNode(const AnfNodePtr &node) {
1586 MS_EXCEPTION_IF_NULL(node);
1587 if (node->isa<Parameter>()) {
1588 return IsOutputAnchorDynamicRank(node, 0);
1589 }
1590 auto cnode = node->cast<CNodePtr>();
1591 MS_EXCEPTION_IF_NULL(cnode);
1592 if ((!HasNodeAttr(kAttrInputIsDynamicRank, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicRank, cnode))) {
1593 auto ret = IsNodeDynamicRank(node);
1594 MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic rank: [" << ret << "]";
1595 return ret;
1596 }
1597 return GetBooleanAttr(node, kAttrInputIsDynamicRank) || GetBooleanAttr(node, kAttrOutputIsDynamicRank) ||
1598 GetBooleanAttr(node, kAttrIsDynamicRank);
1599 }
1600
IsInputAnchorDynamicRank(const AnfNodePtr & node,size_t idx)1601 bool AnfAlgo::IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
1602 MS_EXCEPTION_IF_NULL(node);
1603 if (!node->isa<CNode>()) {
1604 MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has inputs, node: " << node->fullname_with_scope();
1605 }
1606 const auto &in_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, idx);
1607 if (mindspore::IsDynamicRank(in_shape)) {
1608 return true;
1609 }
1610 return false;
1611 }
1612
IsOutputAnchorDynamicRank(const AnfNodePtr & node,size_t idx)1613 bool AnfAlgo::IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
1614 MS_EXCEPTION_IF_NULL(node);
1615 const auto &out_shape = common::AnfAlgo::GetOutputInferShape(node, idx);
1616 if (mindspore::IsDynamicRank(out_shape)) {
1617 return true;
1618 }
1619 return false;
1620 }
1621
IsNodeInputDynamicRank(const CNodePtr & anf_node_ptr)1622 bool AnfAlgo::IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr) {
1623 MS_EXCEPTION_IF_NULL(anf_node_ptr);
1624 const auto &inputs = anf_node_ptr->inputs();
1625 for (size_t i = 1; i < inputs.size(); ++i) {
1626 const auto &input = inputs[i];
1627 MS_EXCEPTION_IF_NULL(input);
1628 if (IsNodeOutputDynamicRank(input)) {
1629 return true;
1630 }
1631 }
1632 return false;
1633 }
1634
IsNodeOutputDynamicRank(const AnfNodePtr & node)1635 bool AnfAlgo::IsNodeOutputDynamicRank(const AnfNodePtr &node) {
1636 MS_EXCEPTION_IF_NULL(node);
1637 auto base_shape = node->Shape();
1638 if (base_shape == nullptr) {
1639 MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
1640 return false;
1641 }
1642 if (base_shape->isa<abstract::DynamicSequenceShape>()) {
1643 auto b_ptr = base_shape->cast<abstract::DynamicSequenceShapePtr>();
1644 if (b_ptr->IsDimUnknown()) {
1645 return true;
1646 }
1647 }
1648 return base_shape->IsDimUnknown();
1649 }
1650
IsDynamicShape(const AnfNodePtr & node)1651 bool AnfAlgo::IsDynamicShape(const AnfNodePtr &node) {
1652 MS_EXCEPTION_IF_NULL(node);
1653 if (!node->isa<CNode>()) {
1654 MS_LOG(DEBUG) << "Node is not a cnode.";
1655 return false;
1656 }
1657 auto cnode = node->cast<CNodePtr>();
1658 if ((!HasNodeAttr(kAttrInputIsDynamicShape, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicShape, cnode))) {
1659 auto ret = IsNodeDynamicShape(node);
1660 MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic shape or not:" << ret;
1661 return ret;
1662 }
1663 return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape);
1664 }
1665
IsDynamicValue(const AnfNodePtr & node)1666 bool AnfAlgo::IsDynamicValue(const AnfNodePtr &node) {
1667 MS_EXCEPTION_IF_NULL(node);
1668 if (!node->isa<CNode>()) {
1669 MS_LOG(DEBUG) << "Node is not a cnode.";
1670 return false;
1671 }
1672 if (AnfAlgo::IsGraphKernel(node)) {
1673 MS_LOG(DEBUG) << "Node(" << node->fullname_with_scope() << ") is GraphKernel node, it's not dynamic value type.";
1674 return false;
1675 }
1676
1677 auto cnode = node->cast<CNodePtr>();
1678 if (cnode->HasAttr(ops::kHasDynamicValue)) {
1679 return true;
1680 }
1681 auto depend_list = abstract::GetValueDependArgIndices(cnode);
1682 if (!depend_list.empty()) {
1683 size_t real_input_num = cnode->size() - 1; // exclude primitive in input[0]
1684 for (auto i = depend_list.begin(); i != depend_list.end(); i++) {
1685 if (*i >= SizeToInt(real_input_num)) {
1686 continue;
1687 }
1688 if (!cnode->input(*i + 1)->isa<ValueNode>()) {
1689 cnode->AddAttr(mindspore::ops::kHasDynamicValue, MakeValue(true));
1690 MS_LOG(DEBUG) << "The input index[" << *i << "]"
1691 << " of node: " << cnode->fullname_with_scope() << " is a dynamic value input";
1692 return true;
1693 }
1694 }
1695 }
1696 return false;
1697 }
1698
GetRealDynamicShape(const std::vector<size_t> & shape,NotNull<std::vector<int64_t> * > dynamic_shape)1699 void AnfAlgo::GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape) {
1700 for (auto size : shape) {
1701 if (size == SIZE_MAX) {
1702 dynamic_shape->push_back(-1);
1703 } else {
1704 dynamic_shape->push_back(SizeToLong(size));
1705 }
1706 }
1707 }
1708
GetShapeFromSequenceShape(const abstract::SequenceShapePtr & sequeue_shape_ptr,size_t index)1709 static ShapeVector GetShapeFromSequenceShape(const abstract::SequenceShapePtr &sequeue_shape_ptr, size_t index) {
1710 MS_EXCEPTION_IF_NULL(sequeue_shape_ptr);
1711 auto shape_list = sequeue_shape_ptr->shape();
1712 if (index >= shape_list.size()) {
1713 MS_LOG(INTERNAL_EXCEPTION) << "Output Index:" << index << " >= " << shape_list.size();
1714 }
1715
1716 auto shape = shape_list[index];
1717 MS_EXCEPTION_IF_NULL(shape);
1718 if (shape->isa<abstract::NoShape>()) {
1719 // For scalar in sequeue case.
1720 return {};
1721 } else if (!shape->isa<abstract::Shape>()) {
1722 MS_LOG(INTERNAL_EXCEPTION) << "Invalid Shape Type(" << shape->ToString() << ") In Shape List";
1723 }
1724
1725 auto shape_ptr = shape->cast<abstract::ShapePtr>();
1726 return shape_ptr->max_shape();
1727 }
1728
GetOutputMaxShape(const AnfNodePtr & anf_node,size_t index)1729 ShapeVector AnfAlgo::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) {
1730 MS_EXCEPTION_IF_NULL(anf_node);
1731 auto shape = anf_node->Shape();
1732 MS_EXCEPTION_IF_NULL(shape);
1733 if (shape->isa<abstract::Shape>()) {
1734 auto shape_ptr = shape->cast<abstract::ShapePtr>();
1735 return shape_ptr->max_shape();
1736 } else if (shape->isa<abstract::SequenceShape>()) {
1737 auto sequeue_shape_ptr = shape->cast<abstract::SequenceShapePtr>();
1738 return GetShapeFromSequenceShape(sequeue_shape_ptr, index);
1739 } else if (shape->isa<abstract::NoShape>()) {
1740 return {};
1741 } else if (shape->isa<abstract::DynamicSequenceShape>()) {
1742 return {1};
1743 } else {
1744 MS_LOG(INTERNAL_EXCEPTION) << "Invalid shape type." << trace::DumpSourceLines(anf_node);
1745 }
1746 }
1747
IsNodeOutputDynamicShape(const AnfNodePtr & node)1748 bool AnfAlgo::IsNodeOutputDynamicShape(const AnfNodePtr &node) {
1749 MS_EXCEPTION_IF_NULL(node);
1750 auto base_shape = node->Shape();
1751 if (base_shape == nullptr) {
1752 MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
1753 return false;
1754 }
1755 if (base_shape->isa<abstract::DynamicSequenceShape>()) {
1756 return true;
1757 }
1758 return base_shape->IsDynamic();
1759 }
1760
IsNodeInputDynamicShape(const CNodePtr & anf_node_ptr)1761 bool AnfAlgo::IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
1762 MS_EXCEPTION_IF_NULL(anf_node_ptr);
1763 const auto &inputs = anf_node_ptr->inputs();
1764 for (size_t i = 1; i < inputs.size(); ++i) {
1765 const auto &input = inputs[i];
1766 MS_EXCEPTION_IF_NULL(input);
1767 if (IsNodeOutputDynamicShape(input)) {
1768 return true;
1769 }
1770 }
1771 return false;
1772 }
1773
GetGraphSplitGroup(const AnfNodePtr & node)1774 std::string AnfAlgo::GetGraphSplitGroup(const AnfNodePtr &node) {
1775 return HasNodeAttr(kAttrGraphSplitGroup, node->cast<CNodePtr>())
1776 ? GetNodeAttr<std::string>(node->cast<CNodePtr>(), kAttrGraphSplitGroup)
1777 : "DefaultGroup";
1778 }
1779
GetAllVisitedCNode(const CNodePtr & node,std::vector<AnfNodePtr> * used_kernels,std::set<AnfNodePtr> * visited)1780 void AnfAlgo::GetAllVisitedCNode(const CNodePtr &node, std::vector<AnfNodePtr> *used_kernels,
1781 std::set<AnfNodePtr> *visited) {
1782 MS_EXCEPTION_IF_NULL(node);
1783 MS_EXCEPTION_IF_NULL(used_kernels);
1784 MS_EXCEPTION_IF_NULL(visited);
1785 if (visited->find(node) != visited->end()) {
1786 MS_LOG(INFO) << "Node:" << node->fullname_with_scope() << " has already been visited";
1787 return;
1788 }
1789 (void)visited->insert(node);
1790 auto input_size = node->size() - 1;
1791 for (size_t i = 0; i < input_size; ++i) {
1792 auto input = AnfAlgo::GetInputNode(node, i);
1793 if (!input->isa<CNode>()) {
1794 continue;
1795 }
1796 if (!AnfUtils::IsRealKernel(input) || IsNopNode(input)) {
1797 GetAllVisitedCNode(input->cast<CNodePtr>(), used_kernels, visited);
1798 } else {
1799 used_kernels->push_back(input);
1800 }
1801 }
1802 }
1803
GetAllFatherRealNode(const AnfNodePtr & anf_node,std::vector<AnfNodePtr> * result,std::set<AnfNodePtr> * visited)1804 void AnfAlgo::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
1805 std::set<AnfNodePtr> *visited) {
1806 MS_EXCEPTION_IF_NULL(anf_node);
1807 MS_EXCEPTION_IF_NULL(result);
1808 MS_EXCEPTION_IF_NULL(visited);
1809 if (visited->find(anf_node) != visited->end()) {
1810 MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
1811 return;
1812 }
1813 visited->insert(anf_node);
1814 if (AnfUtils::IsRealKernel(anf_node)) {
1815 result->emplace_back(anf_node);
1816 return;
1817 }
1818 if (!anf_node->isa<CNode>()) {
1819 return;
1820 }
1821 auto cnode = anf_node->cast<CNodePtr>();
1822 MS_EXCEPTION_IF_NULL(cnode);
1823 if (cnode->inputs().empty()) {
1824 MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString() << "."
1825 << trace::DumpSourceLines(cnode);
1826 }
1827 auto input0 = cnode->input(0);
1828 if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
1829 for (size_t i = 1; i < cnode->size(); ++i) {
1830 GetAllFatherRealNode(cnode->input(i), result, visited);
1831 }
1832 } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
1833 if (cnode->size() != kTupleGetItemInputSize) {
1834 MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
1835 }
1836 GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
1837 } else if (IsPrimitive(input0, prim::kPrimDepend)) {
1838 if (cnode->size() != kDependInputSize) {
1839 MS_LOG(INTERNAL_EXCEPTION) << "Depend node must have 2 inputs!" << trace::DumpSourceLines(cnode);
1840 }
1841 GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
1842 GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
1843 }
1844 }
1845
IsHostKernel(const CNodePtr & kernel_node)1846 bool AnfAlgo::IsHostKernel(const CNodePtr &kernel_node) {
1847 static const std::map<std::string, std::pair<size_t, size_t>> host_kernel_input_output_num = {
1848 {prim::kPrimDynamicShape->name(), {1, 1}},
1849 {prim::kPrimReshape->name(), {2, 1}},
1850 {prim::kPrimTensorShape->name(), {1, 1}}};
1851
1852 auto op_name = AnfAlgo::GetCNodeName(kernel_node);
1853 auto iter = host_kernel_input_output_num.find(op_name);
1854 if (iter == host_kernel_input_output_num.end()) {
1855 return false;
1856 }
1857
1858 auto input_num = GetInputTensorNum(kernel_node);
1859 auto output_num = AnfUtils::GetOutputTensorNum(kernel_node);
1860 auto kernel_input_num = iter->second.first;
1861 auto kernel_output_num = iter->second.second;
1862 if (kernel_input_num != input_num || kernel_output_num != output_num) {
1863 return false;
1864 }
1865 return true;
1866 }
1867
AddArgList(AbstractBasePtrList * args_spec_list,const AnfNodePtr & real_input,size_t real_input_index)1868 void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &real_input, size_t real_input_index) {
1869 MS_EXCEPTION_IF_NULL(args_spec_list);
1870 MS_EXCEPTION_IF_NULL(real_input);
1871
1872 // cppcheck-suppress unreadVariable
1873 auto lock = AnfUtils::GetAbstractLock(real_input.get());
1874 auto real_abs = real_input->abstract();
1875 MS_EXCEPTION_IF_NULL(real_abs);
1876 if (real_abs->isa<abstract::AbstractTuple>() && (!common::AnfAlgo::IsDynamicSequence(real_input))) {
1877 auto abs_tuple = real_abs->Clone()->cast<abstract::AbstractTuplePtr>();
1878 MS_EXCEPTION_IF_NULL(abs_tuple);
1879 MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
1880 auto abs_index = abs_tuple->elements()[real_input_index];
1881 (void)args_spec_list->emplace_back(abs_index);
1882 } else {
1883 (void)args_spec_list->emplace_back(real_abs->Clone());
1884 }
1885 }
1886
GetUpdateStateUsers(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)1887 AnfNodeIndexSet AnfAlgo::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
1888 AnfNodeIndexSet update_states;
1889 for (auto &user : manager->node_users()[node]) {
1890 if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
1891 update_states.insert(user);
1892 }
1893 }
1894 return update_states;
1895 }
1896
GetRealInputs(const AnfNodePtr & node,std::vector<KernelWithIndex> * inputs)1897 void AnfAlgo::GetRealInputs(const AnfNodePtr &node, std::vector<KernelWithIndex> *inputs) {
1898 size_t input_num = AnfAlgo::GetInputTensorNum(node);
1899 for (size_t input_index = 0; input_index < input_num; ++input_index) {
1900 auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
1901 GetRealOutputRecursively(input_node, 0, inputs);
1902 }
1903 }
1904
IsBpropCutOpExecInBackend(const AnfNodePtr & node)1905 bool AnfAlgo::IsBpropCutOpExecInBackend(const AnfNodePtr &node) {
1906 MS_EXCEPTION_IF_NULL(node);
1907 if (!node->isa<CNode>()) {
1908 return false;
1909 }
1910 // Operators in set control_ops_exec_in_backend will be compiled into kernel graph, rather than be cut into single op
1911 // and executed in VM.
1912 static std::set<std::string> bprop_cut_ops_exec_in_backend = {kBpropCutOpName};
1913 return bprop_cut_ops_exec_in_backend.find(AnfAlgo::GetCNodeName(node)) != bprop_cut_ops_exec_in_backend.end();
1914 }
1915
IsNodeInputContainMonad(const AnfNodePtr & node)1916 bool AnfAlgo::IsNodeInputContainMonad(const AnfNodePtr &node) {
1917 MS_EXCEPTION_IF_NULL(node);
1918 auto input_size = GetInputTensorNum(node);
1919 for (size_t i = 0; i < input_size; ++i) {
1920 auto input_with_index = GetPrevNodeOutput(node, i);
1921 if (HasAbstractMonad(input_with_index.first)) {
1922 return true;
1923 }
1924 }
1925 return false;
1926 }
1927
HasMonadInput(const AnfNodePtr & node)1928 bool AnfAlgo::HasMonadInput(const AnfNodePtr &node) {
1929 MS_EXCEPTION_IF_NULL(node);
1930 if (!node->isa<CNode>()) {
1931 return false;
1932 }
1933
1934 auto cnode = node->cast<CNodePtr>();
1935 MS_EXCEPTION_IF_NULL(cnode);
1936 const auto &inputs = cnode->inputs();
1937 for (const auto &input : inputs) {
1938 MS_EXCEPTION_IF_NULL(input);
1939 if (HasAbstractMonad(input)) {
1940 return true;
1941 }
1942 }
1943 return false;
1944 }
1945
IsNonTaskOp(const CNodePtr & node)1946 bool AnfAlgo::IsNonTaskOp(const CNodePtr &node) {
1947 auto op_name = GetCNodeName(node);
1948 return (op_name == kSplitOpName || op_name == kSplitDOpName || op_name == kSplitVDOpName) &&
1949 AnfAlgo::HasNodeAttr(kAttrNonTask, node);
1950 }
1951
IsNoneInput(const AnfNodePtr & node,size_t index)1952 bool AnfAlgo::IsNoneInput(const AnfNodePtr &node, size_t index) {
1953 MS_EXCEPTION_IF_NULL(node);
1954 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, index);
1955 auto prev_node = kernel_with_index.first;
1956 MS_EXCEPTION_IF_NULL(prev_node);
1957 // Only const optional input(None) support now.
1958 if (prev_node->isa<ValueNode>()) {
1959 auto value = prev_node->cast<ValueNodePtr>()->value();
1960 MS_EXCEPTION_IF_NULL(value);
1961 if (value->isa<None>()) {
1962 return true;
1963 }
1964 }
1965
1966 return false;
1967 }
1968
IsCallNode(const AnfNodePtr & node)1969 bool AnfAlgo::IsCallNode(const AnfNodePtr &node) {
1970 MS_EXCEPTION_IF_NULL(node);
1971 if (!node->isa<CNode>()) {
1972 return false;
1973 }
1974 auto input0 = node->cast<CNodePtr>()->input(0);
1975 if (IsValueNode<Primitive>(input0)) {
1976 return false;
1977 }
1978 return true;
1979 }
1980
GetAttrGroups(const AnfNodePtr & node,size_t index)1981 int64_t AnfAlgo::GetAttrGroups(const AnfNodePtr &node, size_t index) {
1982 if (node == nullptr) {
1983 return 1;
1984 }
1985 if (node->isa<CNode>()) {
1986 auto cnode = node->cast<CNodePtr>();
1987 if (HasNodeAttr(kAttrFracZGroupIdx, cnode)) {
1988 auto fz_group_idx = GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
1989 if (index >= fz_group_idx.size()) {
1990 MS_LOG(INTERNAL_EXCEPTION) << "Index out of range, attr fracz_group_idx of node[" << node->fullname_with_scope()
1991 << "] only have " << fz_group_idx.size() << " numbers, but get index " << index;
1992 }
1993 return fz_group_idx[index];
1994 } else if (HasNodeAttr(kAttrFracZGroup, cnode)) {
1995 return GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
1996 }
1997 }
1998 if (node->isa<Parameter>()) {
1999 auto param = node->cast<ParameterPtr>();
2000 MS_EXCEPTION_IF_NULL(param);
2001 return param->fracz_group();
2002 }
2003 if (node->isa<ValueNode>()) {
2004 auto value_node = node->cast<ValueNodePtr>();
2005 MS_EXCEPTION_IF_NULL(value_node);
2006 return value_node->fracz_group();
2007 }
2008 return 1;
2009 }
2010
// Walks from `node` towards the node that really produces the value, using `index_stack` as a stack of tuple
// indexes: TupleGetItem pushes its index and continues with its real input; MakeTuple pops an index and continues
// with the selected element; Depend and Load continue with their real input. Any other node is returned as is.
// Returns nullptr when a MakeTuple is reached while the stack is empty.
AnfNodePtr AnfAlgo::GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *const index_stack) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(index_stack);

  if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
    const auto tuple_getitem = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    // Get cur index
    const auto output_index_value_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(output_index_value_node);
    const auto value_node = output_index_value_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    index_stack->push_back(LongToSize(GetValue<int64_t>(value_node->value())));
    return GetTupleIndexes(tuple_getitem->input(kRealInputNodeIndexInTupleGetItem), index_stack);
  }

  if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    // If make_tuple in make_tuple, visit may start with inner tuple_getitem.
    if (index_stack->empty()) {
      MS_LOG(WARNING) << "Visit make tuple: " << node->DebugString()
                      << ", but index are empty, visit should not start with inner tuple_getitem.";
      return nullptr;
    }
    const auto make_tuple = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    const auto element_idx = index_stack->back();
    index_stack->pop_back();
    return GetTupleIndexes(make_tuple->input(1 + element_idx), index_stack);
  }

  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
    const auto depend_real_input = node->cast<CNodePtr>()->input(kRealInputIndexInDepend);
    return GetTupleIndexes(depend_real_input, index_stack);
  }

  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
    const auto load_real_input = node->cast<CNodePtr>()->input(1);
    return GetTupleIndexes(load_real_input, index_stack);
  }

  MS_LOG(DEBUG) << "Get real node:" << node->DebugString();
  return node;
}
2050
IsNopNode(const AnfNodePtr & node)2051 bool AnfAlgo::IsNopNode(const AnfNodePtr &node) {
2052 static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(),
2053 kExpandDimsOpName,
2054 prim::kPrimSqueeze->name(),
2055 prim::kPrimFlatten->name(),
2056 kFlattenGradOpName,
2057 prim::kPrimReformat->name(),
2058 prim::kPrimTupleToList->name(),
2059 prim::kPrimListToTuple->name(),
2060 prim::kPrimTupleToTensor->name(),
2061 prim::kPrimScalarToTensor->name(),
2062 prim::kPrimTensorToTuple->name(),
2063 prim::kPrimTensorToScalar->name(),
2064 "ReshapeExt"};
2065 if (node == nullptr || !node->isa<CNode>()) {
2066 return false;
2067 }
2068 CNodePtr cnode = node->cast<CNodePtr>();
2069 MS_EXCEPTION_IF_NULL(cnode);
2070 if (cnode->inputs().empty()) {
2071 return false;
2072 }
2073 auto input0 = cnode->input(0);
2074 MS_EXCEPTION_IF_NULL(input0);
2075 if (!input0->isa<ValueNode>()) {
2076 return false;
2077 }
2078 bool is_nop_node = false;
2079 if (AnfAlgo::HasNodeAttr(kAttrNopOp, cnode)) {
2080 is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrNopOp);
2081 }
2082 if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
2083 return false;
2084 }
2085
2086 // Check the input type and output type.
2087 if (GetOutputInferDataType(node, 0) != GetPrevNodeOutputInferDataType(node, 0)) {
2088 return false;
2089 }
2090
2091 return true;
2092 }
2093
2094 template <typename T>
CheckAbsType(const AnfNodePtr & node)2095 bool AnfAlgo::CheckAbsType(const AnfNodePtr &node) {
2096 MS_EXCEPTION_IF_NULL(node);
2097 MS_EXCEPTION_IF_NULL(node->abstract());
2098 return node->abstract()->cast<T>() != nullptr;
2099 }
2100
CheckAbsSparseTensor(const AnfNodePtr & node)2101 bool AnfAlgo::CheckAbsSparseTensor(const AnfNodePtr &node) {
2102 return CheckAbsType<abstract::AbstractSparseTensorPtr>(node);
2103 }
2104
CheckAbsSparseTensor(const abstract::AbstractBasePtr & abs)2105 bool AnfAlgo::CheckAbsSparseTensor(const abstract::AbstractBasePtr &abs) {
2106 return abs->cast<abstract::AbstractSparseTensorPtr>() != nullptr;
2107 }
2108
GetSparseTypeIdAt(const AnfNodePtr & node,size_t idx)2109 TypeId AnfAlgo::GetSparseTypeIdAt(const AnfNodePtr &node, size_t idx) {
2110 if (CheckAbsType<abstract::AbstractSparseTensorPtr>(node)) {
2111 auto abs_sparse = node->abstract()->cast<abstract::AbstractSparseTensorPtr>();
2112 auto shape_idx = abs_sparse->size() - 1;
2113 // idx points to a tensor element
2114 if (idx < shape_idx) {
2115 return abs_sparse->GetTensorTypeIdAt(idx);
2116 }
2117 return abs_sparse->GetShapeTypeIdAt(idx - shape_idx);
2118 }
2119 MS_LOG(INTERNAL_EXCEPTION) << "Expect AbstractCSRTensor or AbstractCOOTensor, but got "
2120 << node->abstract()->ToString();
2121 }
2122
GetTensorValueString(const tensor::BaseTensorPtr & tensor)2123 std::string AnfAlgo::GetTensorValueString(const tensor::BaseTensorPtr &tensor) {
2124 MS_EXCEPTION_IF_NULL(tensor);
2125 auto dtype = tensor->Dtype();
2126 MS_EXCEPTION_IF_NULL(dtype);
2127 size_t data_size = tensor->DataSize();
2128 auto shape = tensor->shape();
2129 std::ostringstream buf;
2130 auto fn = [&buf, data_size, &shape](auto addr) {
2131 // Tensor value.
2132 buf << "v";
2133 for (size_t i = 0; i < data_size; ++i) {
2134 buf << *(addr + i) << ",";
2135 }
2136 // Tensor shape is necessary.
2137 // For example, the value of ones[3x4] and ones[4x3] are the same, but the shape is different.
2138 buf << "s" << tensor::ShapeToString(shape);
2139 };
2140
2141 if (dtype->type_id() == kNumberTypeBool) {
2142 fn(reinterpret_cast<bool *>(tensor->data_c()));
2143 } else if (dtype->type_id() == kNumberTypeInt) {
2144 fn(reinterpret_cast<int *>(tensor->data_c()));
2145 } else if (dtype->type_id() == kNumberTypeInt8) {
2146 fn(reinterpret_cast<int8_t *>(tensor->data_c()));
2147 } else if (dtype->type_id() == kNumberTypeUInt8) {
2148 fn(reinterpret_cast<uint8_t *>(tensor->data_c()));
2149 } else if (dtype->type_id() == kNumberTypeInt16) {
2150 fn(reinterpret_cast<int16_t *>(tensor->data_c()));
2151 } else if (dtype->type_id() == kNumberTypeUInt16) {
2152 fn(reinterpret_cast<uint16_t *>(tensor->data_c()));
2153 } else if (dtype->type_id() == kNumberTypeInt32) {
2154 fn(reinterpret_cast<int32_t *>(tensor->data_c()));
2155 } else if (dtype->type_id() == kNumberTypeUInt32) {
2156 fn(reinterpret_cast<uint32_t *>(tensor->data_c()));
2157 } else if (dtype->type_id() == kNumberTypeInt64) {
2158 fn(reinterpret_cast<int64_t *>(tensor->data_c()));
2159 } else if (dtype->type_id() == kNumberTypeUInt64) {
2160 fn(reinterpret_cast<uint64_t *>(tensor->data_c()));
2161 } else if (dtype->type_id() == kNumberTypeFloat16) {
2162 fn(reinterpret_cast<float16 *>(tensor->data_c()));
2163 } else if (dtype->type_id() == kNumberTypeFloat64) {
2164 fn(reinterpret_cast<double *>(tensor->data_c()));
2165 } else if (dtype->type_id() == kNumberTypeFloat || dtype->type_id() == kNumberTypeFloat32) {
2166 fn(reinterpret_cast<float *>(tensor->data_c()));
2167 } else if (dtype->type_id() == kNumberTypeBFloat16) {
2168 fn(reinterpret_cast<bfloat16 *>(tensor->data_c()));
2169 } else if (dtype->type_id() == kNumberTypeComplex64) {
2170 fn(reinterpret_cast<complex64 *>(tensor->data_c()));
2171 } else if (dtype->type_id() == kNumberTypeComplex128) {
2172 fn(reinterpret_cast<complex128 *>(tensor->data_c()));
2173 } else {
2174 MS_LOG(INTERNAL_EXCEPTION) << "The dtype of the constant input is " << dtype->ToString();
2175 }
2176 return buf.str();
2177 }
2178
FrontendGetNodeAbstractByIndex(const AnfNodePtr & node,size_t index)2179 abstract::AbstractBasePtr AnfAlgo::FrontendGetNodeAbstractByIndex(const AnfNodePtr &node, size_t index) {
2180 MS_EXCEPTION_IF_NULL(node);
2181 const auto &abstract = node->abstract();
2182 if (abstract == nullptr) {
2183 return abstract;
2184 }
2185
2186 // Return output abstract directly for : 1.not sequence type, 2.dynamic sequence type, 3.real tuple/list type.
2187 if (!abstract->isa<abstract::AbstractSequence>() || common::AnfAlgo::IsDynamicSequence(node)) {
2188 MS_EXCEPTION_IF_CHECK_FAIL((index == 0),
2189 "Cannot get " + std::to_string(index) + " child abstract from " + abstract->ToString());
2190 return abstract;
2191 }
2192
2193 // Return element abstract by index for tuple type.
2194 const auto &abstract_tuple = abstract->cast<abstract::AbstractSequencePtr>();
2195 MS_EXCEPTION_IF_NULL(abstract_tuple);
2196 const auto &elements = abstract_tuple->elements();
2197 if (elements.size() <= index) {
2198 const auto sub_abstract = FetchAbstractByIndex(node->abstract(), index);
2199 return sub_abstract;
2200 }
2201 return elements[index];
2202 }
2203
GetJitLevel(const FuncGraphPtr & func_graph)2204 std::string AnfAlgo::GetJitLevel(const FuncGraphPtr &func_graph) {
2205 MS_EXCEPTION_IF_NULL(func_graph);
2206 if (!func_graph->has_attr(kAttrJitLevel)) {
2207 MS_LOG(INFO) << "The func_graph:" << func_graph->ToString() << " has no jit_level attr, return default: None.";
2208 return "";
2209 }
2210 auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
2211 auto jit_level = GetValue<std::string>(jit_level_value);
2212 return jit_level;
2213 }
2214
IsNodeMutableScalar(const AnfNodePtr & node)2215 bool AnfAlgo::IsNodeMutableScalar(const AnfNodePtr &node) {
2216 MS_EXCEPTION_IF_NULL(node);
2217 if (!node->isa<CNode>()) {
2218 return false;
2219 }
2220 // Check if the node is mutable scalar by all_inputs are scalar or output is scalar.
2221 const auto &is_mutable_scalar_func = [](const AnfNodePtr &cur_node) {
2222 const auto &abstract = cur_node->abstract();
2223 if (abstract == nullptr || (!abstract->isa<abstract::AbstractScalar>())) {
2224 return false;
2225 }
2226 if (abstract->BuildValue()->ContainsValueAny() && abstract->BuildType()->isa<Number>()) {
2227 return true;
2228 }
2229 return false;
2230 };
2231 bool is_output_mutable_scalar = is_mutable_scalar_func(node);
2232 bool is_scalar_to_tensor = IsPrimitiveCNode(node, prim::kPrimScalarToTensor);
2233 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
2234 const auto &cnode = node->cast<CNodePtr>();
2235 MS_EXCEPTION_IF_NULL(cnode);
2236 if (!is_mutable_scalar_func(cnode->input(kRealInputIndexInDepend))) {
2237 return false;
2238 }
2239 }
2240 return is_output_mutable_scalar || is_scalar_to_tensor;
2241 }
2242
IsDynamicSequence(const AnfNodePtr & node)2243 bool AnfAlgo::IsDynamicSequence(const AnfNodePtr &node) {
2244 MS_EXCEPTION_IF_NULL(node);
2245 // Check if the node is dynamic sequence by sign in abstract.
2246 const auto &is_dynamic_len_func = [&node]() {
2247 const auto &abstract = node->abstract();
2248 if (abstract == nullptr || (!abstract->isa<abstract::AbstractSequence>())) {
2249 return false;
2250 }
2251
2252 const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
2253 MS_EXCEPTION_IF_NULL(sequence_abstract);
2254 return sequence_abstract->dynamic_len() || sequence_abstract->dynamic_len_element_abs() != nullptr;
2255 };
2256
2257 // Check if the node is dynamic sequence by sign in node, in cnode it is an attr in primitive, in parameter, it is
2258 // an sign.
2259 if (node->isa<Parameter>()) {
2260 const auto ¶meter = node->cast<ParameterPtr>();
2261 MS_EXCEPTION_IF_NULL(parameter);
2262 if (parameter->dynamic_len()) {
2263 return true;
2264 }
2265 bool is_dynamic = is_dynamic_len_func();
2266 if (is_dynamic) {
2267 parameter->set_dynamic_len(true);
2268 }
2269 return is_dynamic;
2270 } else if (node->isa<CNode>()) {
2271 if (IsCallNode(node)) {
2272 return is_dynamic_len_func();
2273 }
2274 const auto &cnode = node->cast<CNodePtr>();
2275 MS_EXCEPTION_IF_NULL(cnode);
2276 if (cnode->HasAttr(kAttrDynamicLenName)) {
2277 return GetValue<bool>(cnode->GetAttr(kAttrDynamicLenName));
2278 } else {
2279 bool is_dynamic = is_dynamic_len_func();
2280 cnode->AddAttr(kAttrDynamicLenName, MakeValue(is_dynamic));
2281 return is_dynamic;
2282 }
2283 } else if (node->isa<ValueNode>()) {
2284 return is_dynamic_len_func();
2285 }
2286 return false;
2287 }
2288
IsAnyTypeOutput(const AnfNodePtr & node)2289 bool AnfAlgo::IsAnyTypeOutput(const AnfNodePtr &node) {
2290 MS_EXCEPTION_IF_NULL(node);
2291 if (node->isa<CNode>()) {
2292 if (IsCallNode(node)) {
2293 if (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractAny>()) {
2294 return true;
2295 }
2296 return false;
2297 }
2298 const auto &cnode = node->cast<CNodePtr>();
2299 MS_EXCEPTION_IF_NULL(cnode);
2300 if (cnode->HasAttr(kAttrAnyOutputName)) {
2301 return GetValue<bool>(cnode->GetAttr(kAttrAnyOutputName));
2302 } else {
2303 bool is_any_output = (node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractAny>());
2304 cnode->AddAttr(kAttrAnyOutputName, MakeValue(is_any_output));
2305 return is_any_output;
2306 }
2307 }
2308 return false;
2309 }
2310
2311 namespace {
IsIncludeAny(const abstract::AbstractBasePtr & abstract)2312 bool IsIncludeAny(const abstract::AbstractBasePtr &abstract) {
2313 if (abstract == nullptr) {
2314 return false;
2315 }
2316 if (abstract->isa<abstract::AbstractAny>()) {
2317 return true;
2318 }
2319 if (!abstract->isa<abstract::AbstractSequence>()) {
2320 return false;
2321 }
2322 const auto &seq_abstract = abstract->cast<abstract::AbstractSequencePtr>();
2323 MS_EXCEPTION_IF_NULL(seq_abstract);
2324 if (std::any_of(seq_abstract->elements().begin(), seq_abstract->elements().end(),
2325 [](const auto &abstract) { return IsIncludeAny(abstract); })) {
2326 return true;
2327 }
2328 return false;
2329 }
2330 } // namespace
2331
IsAnyTypeInput(const std::vector<AnfNodePtr> & inputs)2332 bool AnfAlgo::IsAnyTypeInput(const std::vector<AnfNodePtr> &inputs) {
2333 for (const auto &input : inputs) {
2334 MS_EXCEPTION_IF_NULL(input);
2335 if (IsIncludeAny(input->abstract())) {
2336 return true;
2337 }
2338 }
2339 return false;
2340 }
2341
HasTupleInput(const CNodePtr & node)2342 bool AnfAlgo::HasTupleInput(const CNodePtr &node) {
2343 MS_EXCEPTION_IF_NULL(node);
2344 size_t input_num = node->size() - 1;
2345 for (size_t i = 0; i < input_num; ++i) {
2346 auto input_node = common::AnfAlgo::GetInputNode(node, i);
2347 MS_EXCEPTION_IF_NULL(input_node);
2348 if (common::AnfAlgo::IsTupleOutput(input_node)) {
2349 return true;
2350 }
2351 }
2352 return false;
2353 }
2354
HasDynamicTupleInput(const CNodePtr & node)2355 bool AnfAlgo::HasDynamicTupleInput(const CNodePtr &node) {
2356 MS_EXCEPTION_IF_NULL(node);
2357 size_t input_num = node->size() - 1;
2358 for (size_t i = 0; i < input_num; ++i) {
2359 auto input_node = common::AnfAlgo::GetInputNode(node, i);
2360 MS_EXCEPTION_IF_NULL(input_node);
2361 if (common::AnfAlgo::IsDynamicSequence(input_node)) {
2362 return true;
2363 }
2364 }
2365 return false;
2366 }
2367
IsReduceOp(const std::string & op_name)2368 bool AnfAlgo::IsReduceOp(const std::string &op_name) {
2369 static const std::set<std::string> reduce_op_type = {prim::kPrimReduceAll->name(), prim::kPrimReduceAny->name(),
2370 prim::kPrimReduceMean->name(), prim::kPrimReduceMax->name(),
2371 prim::kPrimReduceMin->name(), prim::kPrimReduceProd->name(),
2372 prim::kPrimReduceSum->name(), prim::kPrimSquareSumV1->name()};
2373 return reduce_op_type.find(op_name) != reduce_op_type.end();
2374 }
2375
IsTypeTransformOp(const std::string & op_name)2376 bool AnfAlgo::IsTypeTransformOp(const std::string &op_name) {
2377 static const std::set<std::string> type_trans_op_names = {
2378 prim::kPrimTupleToTensor->name(), prim::kPrimTensorToTuple->name(), prim::kPrimScalarToTensor->name(),
2379 prim::kPrimTensorToScalar->name(), prim::kPrimRealMakeTuple->name(), prim::kPrimRealTupleGetItem->name()};
2380 return type_trans_op_names.find(op_name) != type_trans_op_names.end();
2381 }
2382
GetDynamicSequenceShape(const AnfNodePtr & node,size_t output_idx)2383 abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx) {
2384 MS_EXCEPTION_IF_NULL(node);
2385 abstract::AbstractSequencePtr sequence_abs = nullptr;
2386 if (node->Shape() == nullptr || (!node->Shape()->isa<abstract::DynamicSequenceShape>())) {
2387 MS_LOG(INFO) << "node:" << node->fullname_with_scope() << " index:" << output_idx
2388 << " abs:" << node->abstract()->ToString();
2389 if (!node->abstract()->isa<abstract::AbstractSequence>()) {
2390 MS_LOG(INTERNAL_EXCEPTION) << "Not sequence abstract in node:" << node->DebugString()
2391 << " for dynamic sequence shape.";
2392 }
2393 const auto &top_sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
2394 MS_EXCEPTION_IF_NULL(top_sequence_abs);
2395 if (output_idx >= top_sequence_abs->elements().size()) {
2396 MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << output_idx << " for abs:" << top_sequence_abs->ToString()
2397 << "node:" << node->fullname_with_scope();
2398 }
2399 const auto &sub_abs = top_sequence_abs->elements()[output_idx];
2400 MS_EXCEPTION_IF_NULL(sub_abs);
2401 if (!sub_abs->isa<abstract::AbstractSequence>()) {
2402 MS_LOG(INTERNAL_EXCEPTION) << "Not sequence abstract in node:" << node->DebugString()
2403 << " for dynamic sequence shape.";
2404 }
2405 sequence_abs = sub_abs->cast<abstract::AbstractSequencePtr>();
2406 } else {
2407 if (node->abstract() == nullptr) {
2408 MS_LOG(INTERNAL_EXCEPTION) << "Empty abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
2409 }
2410 if (!node->abstract()->isa<abstract::AbstractSequence>()) {
2411 MS_LOG(INTERNAL_EXCEPTION) << "Not sequence abstract in node:" << node->DebugString()
2412 << " for dynamic sequence shape.";
2413 }
2414 sequence_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
2415 }
2416 MS_EXCEPTION_IF_NULL(sequence_abs);
2417 if (!sequence_abs->dynamic_len()) {
2418 MS_LOG(INTERNAL_EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString()
2419 << " for dynamic sequence shape.";
2420 }
2421 const auto &element_abs = sequence_abs->dynamic_len_element_abs();
2422 if (element_abs == nullptr) {
2423 MS_LOG(INFO) << "No element abs for node:" << node->DebugString() << " index:" << output_idx;
2424 ShapeVector empty_shape{0};
2425 return std::make_shared<abstract::Shape>(empty_shape);
2426 }
2427 return element_abs->BuildShape();
2428 }
2429
FetchAbstractByIndex(const AbstractBasePtr & abstract,size_t index)2430 abstract::AbstractBasePtr AnfAlgo::FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
2431 MS_EXCEPTION_IF_NULL(abstract);
2432 if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
2433 if (index != 0) {
2434 MS_LOG(INTERNAL_EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
2435 }
2436 return abstract;
2437 }
2438
2439 auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
2440 MS_EXCEPTION_IF_NULL(tuple_abstract);
2441 const auto &sub_abstracts = tuple_abstract->elements();
2442 size_t real_index = index;
2443 for (const auto &sub_abstract : sub_abstracts) {
2444 size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
2445 if (real_index >= tmp_index) {
2446 real_index -= tmp_index;
2447 continue;
2448 }
2449 return FetchAbstractByIndex(sub_abstract, real_index);
2450 }
2451 MS_LOG(INTERNAL_EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
2452 }
2453
GetInputName(const CNodePtr & origin_op,size_t input_index)2454 std::string AnfAlgo::GetInputName(const CNodePtr &origin_op, size_t input_index) {
2455 auto prim_func_input_name = ops::GetInputNameByIndex(GetCNodeName(origin_op), input_index);
2456 if (prim_func_input_name != "") {
2457 return prim_func_input_name;
2458 }
2459 auto origin_primitive = GetCNodePrimitive(origin_op);
2460 MS_EXCEPTION_IF_NULL(origin_primitive);
2461 auto input_names = origin_primitive->GetAttr(kAttrInputNames);
2462 if (input_names == nullptr) {
2463 MS_LOG(INTERNAL_EXCEPTION) << "input_names are nullptr in cnode " << origin_op->fullname_with_scope()
2464 << ", debug string:" << origin_op->DebugString()
2465 << ", attr text:" << origin_primitive->GetAttrsText();
2466 }
2467
2468 auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
2469 if (input_index >= input_names_vec.size()) {
2470 MS_LOG(INFO) << "Input index is invalid. input index: " << input_index << ", input name size "
2471 << input_names_vec.size();
2472 return "";
2473 }
2474 return input_names_vec[input_index];
2475 }
2476
IsNoOuputNode(const AnfNodePtr & node)2477 bool AnfAlgo::IsNoOuputNode(const AnfNodePtr &node) {
2478 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> no_output_prims = {
2479 prim::kPrimSend,
2480 prim::kPrimNPUClearFloatStatusV2,
2481 prim::kPrimInitPartitionMap,
2482 prim::kPrimInitEmbeddingHashmap,
2483 prim::kPrimEmbeddingTableImport,
2484 prim::kPrimEmbeddingComputeVarExport,
2485 prim::kPrimEmbeddingComputeVarImport,
2486 prim::kPrimEmbeddingTableExport};
2487 if (IsOneOfPrimitiveCNode(node, no_output_prims)) {
2488 return true;
2489 }
2490 return false;
2491 }
2492
// Convert a KernelTensorValue into a scalar value of the requested type.
// The data pointer of the kernel tensor value is reinterpreted as the C++ type matching `type_id` and the
// first element is wrapped with MakeValue.
// Returns nullptr when `value` is not a KernelTensorValue or `type_id` is not handled below.
// NOTE(review): Float16 and BFloat16 are read as uint16_t, i.e. the stored bits are returned as an unsigned
// integer value - confirm this is what callers expect.
ValuePtr AnfAlgo::ValueToScalar(const ValuePtr &value, TypeId type_id) {
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<KernelTensorValue>()) {
    return nullptr;
  }
  const auto &tensor_value = value->cast<KernelTensorValuePtr>();
  MS_EXCEPTION_IF_NULL(tensor_value);
  // Start of the stored data; every case below reads one element of its own type from here.
  const auto *raw_data = tensor_value->GetDataPtr();
  MS_EXCEPTION_IF_NULL(raw_data);
  switch (type_id) {
    case kNumberTypeBool:
      return MakeValue(*reinterpret_cast<const bool *>(raw_data));
    case kNumberTypeInt16:
      return MakeValue(*reinterpret_cast<const int16_t *>(raw_data));
    case kNumberTypeUInt16:
      return MakeValue(*reinterpret_cast<const uint16_t *>(raw_data));
    case kNumberTypeInt8:
      return MakeValue(*reinterpret_cast<const int8_t *>(raw_data));
    case kNumberTypeUInt8:
      return MakeValue(*reinterpret_cast<const uint8_t *>(raw_data));
    case kNumberTypeInt32:
      return MakeValue(*reinterpret_cast<const int32_t *>(raw_data));
    case kNumberTypeUInt32:
      return MakeValue(*reinterpret_cast<const uint32_t *>(raw_data));
    case kNumberTypeInt64:
      return MakeValue(*reinterpret_cast<const int64_t *>(raw_data));
    case kNumberTypeUInt64:
      return MakeValue(*reinterpret_cast<const uint64_t *>(raw_data));
    case kNumberTypeFloat16:
      return MakeValue(*reinterpret_cast<const uint16_t *>(raw_data));
    case kNumberTypeFloat32:
      return MakeValue(*reinterpret_cast<const float *>(raw_data));
    case kNumberTypeFloat64:
      return MakeValue(*reinterpret_cast<const double *>(raw_data));
    case kNumberTypeBFloat16:
      return MakeValue(*reinterpret_cast<const uint16_t *>(raw_data));
    default:
      MS_LOG(DEBUG) << "Not support scalar type:" << type_id;
  }
  return nullptr;
}
2533
2534 namespace {
IterateFindTensor(ValuePtrList * value_list,const VectorRef & ref_list)2535 void IterateFindTensor(ValuePtrList *value_list, const VectorRef &ref_list) {
2536 MS_EXCEPTION_IF_NULL(value_list);
2537 for (size_t i = 0; i < ref_list.size(); ++i) {
2538 if (utils::isa<tensor::BaseTensorPtr>(ref_list[i])) {
2539 auto tensor_ptr = utils::cast<std::shared_ptr<tensor::BaseTensor>>(ref_list[i]);
2540 MS_EXCEPTION_IF_NULL(tensor_ptr);
2541 (void)value_list->emplace_back(tensor_ptr);
2542 } else if (utils::isa<VectorRef>(ref_list[i])) {
2543 auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
2544 IterateFindTensor(value_list, ref_iter);
2545 } else if (utils::isa<tensor::CSRTensorPtr>(ref_list[i])) {
2546 auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(ref_list[i]);
2547 MS_EXCEPTION_IF_NULL(csr_tensor);
2548 (void)value_list->emplace_back(csr_tensor);
2549 } else {
2550 MS_LOG(EXCEPTION) << "The ref value " << ref_list[i].ToString() << " is not a vector ref or a tensor!";
2551 }
2552 }
2553 }
2554
HasAbstractFunction(const AbstractBasePtr & abs)2555 bool HasAbstractFunction(const AbstractBasePtr &abs) {
2556 if (abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractSparseTensor>()) {
2557 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2558 if (abs_seq->dynamic_len()) {
2559 return HasAbstractFunction(abs_seq->dynamic_len_element_abs());
2560 }
2561 return std::any_of(abs_seq->elements().cbegin(), abs_seq->elements().cend(), HasAbstractFunction);
2562 }
2563 // if abs it not AbstractSequence.
2564 return abs->isa<abstract::AbstractFunction>();
2565 }
2566
IsCellReuse(const AnfNodePtr & input)2567 bool IsCellReuse(const AnfNodePtr &input) {
2568 if (IsValueNode<FuncGraph>(input)) {
2569 auto fg = GetValueNode<FuncGraphPtr>(input);
2570 MS_EXCEPTION_IF_NULL(fg);
2571 if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
2572 return true;
2573 }
2574 }
2575 return false;
2576 }
2577
AcceptableReturnValue(const CNodePtr & cnode,const AnfNodePtr & input0)2578 bool AcceptableReturnValue(const CNodePtr &cnode, const AnfNodePtr &input0) {
2579 if (IsCellReuse(input0)) {
2580 return true;
2581 }
2582 auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
2583 auto graph_has_function_output = [](const FuncGraphPtr &fg) { return HasAbstractFunction(fg->output()->abstract()); };
2584 if (std::all_of(func_graphs.cbegin(), func_graphs.cend(), std::not_fn(graph_has_function_output))) {
2585 return true;
2586 }
2587 return false;
2588 }
2589
SupportInlinePartial(const AnfNodePtr & input0)2590 bool SupportInlinePartial(const AnfNodePtr &input0) {
2591 // inline partial
2592 if (IsPrimitiveCNode(input0, prim::kPrimTupleGetItem)) {
2593 auto tuple_get_node = input0->cast<CNodePtr>();
2594 MS_EXCEPTION_IF_NULL(tuple_get_node);
2595 auto get_from_node = tuple_get_node->input(1);
2596 auto idx = common::AnfAlgo::GetTupleGetItemOutIndex(tuple_get_node);
2597 MS_EXCEPTION_IF_NULL(get_from_node);
2598 // tuple get item from a call subgraph output
2599 if (get_from_node->isa<CNode>() && IsValueNode<FuncGraph>(get_from_node->cast<CNodePtr>()->input(0))) {
2600 auto call_graph = GetValueNode<FuncGraphPtr>(get_from_node->cast<CNodePtr>()->input(0));
2601 MS_EXCEPTION_IF_NULL(call_graph);
2602 auto graph_out = call_graph->output();
2603 MS_EXCEPTION_IF_NULL(graph_out);
2604 size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(graph_out);
2605 // the partial must be the last output
2606 if (graph_out->isa<CNode>() && tuple_input_num == idx + 1) {
2607 int partial_cnt = 0;
2608 for (size_t i = 0; i < tuple_input_num; i++) {
2609 auto input = graph_out->cast<CNodePtr>()->input(i + 1);
2610 if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
2611 partial_cnt++;
2612 }
2613 }
2614 auto partial = graph_out->cast<CNodePtr>()->input(idx + 1);
2615 MS_EXCEPTION_IF_NULL(partial);
2616 // we only support one partial func at the last return value now
2617 if (partial_cnt != 1 || !IsPrimitiveCNode(partial, prim::kPrimPartial)) {
2618 if (partial_cnt != 0) {
2619 MS_LOG(INFO) << "Partial func cnt: " << partial_cnt
2620 << ", last return value: " << partial->fullname_with_scope();
2621 }
2622 return false;
2623 }
2624 auto partial_inputs = partial->cast<CNodePtr>()->inputs();
2625 // the input of partial can't be FuncGraph/Partial
2626 bool has_illegal_input = std::any_of(
2627 partial_inputs.begin() + kPartialMinInputSize, partial_inputs.end(), [](const AnfNodePtr &partial_input) {
2628 return IsValueNode<FuncGraph>(partial_input) || IsPrimitiveCNode(partial_input, prim::kPrimPartial);
2629 });
2630 return !has_illegal_input;
2631 }
2632 }
2633 }
2634 return false;
2635 }
2636 } // namespace
2637
TransformVectorRefToMultiValue(const VectorRef & base_ref)2638 ValuePtrList AnfAlgo::TransformVectorRefToMultiValue(const VectorRef &base_ref) {
2639 ValuePtrList value_list;
2640 if (utils::isa<VectorRef>(base_ref)) {
2641 auto ref_list = utils::cast<VectorRef>(base_ref);
2642 IterateFindTensor(&value_list, ref_list);
2643 } else if (utils::isa<tensor::Tensor>(base_ref)) {
2644 auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
2645 MS_EXCEPTION_IF_NULL(tensor_ptr);
2646 (void)value_list.emplace_back(tensor_ptr);
2647 } else {
2648 MS_LOG(EXCEPTION) << "The ref value " << base_ref.ToString() << " is not a vector ref or a tensor!";
2649 }
2650 return value_list;
2651 }
2652
HasIncorporateCallNode(const CNodePtr & cnode)2653 bool AnfAlgo::HasIncorporateCallNode(const CNodePtr &cnode) {
2654 if (!IsValueNode<Primitive>(cnode->input(0))) { // If cnode is a call node.
2655 auto input0 = cnode->input(0);
2656 if (IsPrimitiveCNode(input0, prim::kPrimSwitch) || IsPrimitiveCNode(input0, prim::kPrimSwitchLayer) ||
2657 IsValueNode<FuncGraph>(input0)) {
2658 if (IsCellReuse(input0) && IsEnableRefMode()) {
2659 MS_LOG(INFO) << "Use cell reuse when enable ge mode: " << cnode->DebugString();
2660 return true;
2661 }
2662 if (AcceptableReturnValue(cnode, input0)) {
2663 return false;
2664 }
2665 }
2666 if (SupportInlinePartial(input0)) {
2667 return false;
2668 }
2669 MS_LOG(INFO) << "Call has indirect call: " << cnode->DebugString();
2670 return true;
2671 }
2672 return false;
2673 }
2674
IsDynamicGraph(const FuncGraphPtr & func_graph)2675 bool AnfAlgo::IsDynamicGraph(const FuncGraphPtr &func_graph) {
2676 MS_EXCEPTION_IF_NULL(func_graph);
2677 std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return(), SuccDeeperSimple);
2678 AnfNodePtr dynamic_node = nullptr;
2679 AnfNodePtr pyexecute_node = nullptr;
2680 for (const auto &node : node_list) {
2681 if (node->abstract() == nullptr) {
2682 MS_LOG(INFO) << "Null abstract of node: " << node->DebugString();
2683 continue;
2684 }
2685 if (node->abstract() != nullptr) {
2686 auto shape = node->abstract()->GetShape();
2687 // Dynamic shape tensor.
2688 if (shape->isa<abstract::TensorShape>() && IsDynamic(shape->GetShapeVector())) {
2689 dynamic_node = node;
2690 break;
2691 }
2692 // Dynamic len sequence.
2693 if (node->abstract()->isa<abstract::AbstractSequence>() &&
2694 node->abstract()->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
2695 dynamic_node = node;
2696 break;
2697 }
2698 // PyExecute node exist
2699 if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
2700 pyexecute_node = node;
2701 }
2702 }
2703 }
2704 if (dynamic_node != nullptr) {
2705 MS_LOG(INFO) << "Func graph:" << func_graph->ToString()
2706 << " is dynamic shape graph, because find dynamic shape node:" << dynamic_node->DebugString()
2707 << ", abstract: " << dynamic_node->abstract()->ToString();
2708 return true;
2709 }
2710 if (pyexecute_node != nullptr) {
2711 MS_LOG(INFO) << "Func graph:" << func_graph->ToString() << " has pyexecute node:" << pyexecute_node->DebugString();
2712 return true;
2713 }
2714 return false;
2715 }
2716 } // namespace common
2717 } // namespace mindspore
2718