• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/static_analysis/prim.h"
20 
21 #include <algorithm>
22 #include <limits>
23 #include <map>
24 #include <mutex>
25 #include <string>
26 #include <utility>
27 
28 #include "abstract/abstract_value.h"
29 #include "abstract/ops/primitive_infer_map.h"
30 #include "abstract/param_validator.h"
31 #include "abstract/utils.h"
32 #include "frontend/operator/cc_implementations.h"
33 #include "frontend/operator/composite/do_signature.h"
34 #include "frontend/operator/ops.h"
35 #include "frontend/operator/ops_front_infer_function.h"
36 #include "frontend/operator/prim_to_function.h"
37 #include "frontend/operator/composite/unpack_call.h"
38 #include "include/common/fallback.h"
39 #include "include/common/utils/convert_utils.h"
40 #include "include/common/utils/convert_utils_py.h"
41 #include "include/common/utils/primfunc_utils.h"
42 #include "ir/anf.h"
43 #include "ir/cell.h"
44 #include "ops/arithmetic_ops.h"
45 #include "ops/comparison_ops.h"
46 #include "ops/framework_ops.h"
47 #include "ops/other_ops.h"
48 #include "ops/sequence_ops.h"
49 #include "ops/structure_ops.h"
50 #include "ops/array_op_name.h"
51 #include "ops/op_utils.h"
52 #include "pipeline/jit/ps/debug/trace.h"
53 #include "pipeline/jit/ps/fallback.h"
54 #include "pipeline/jit/ps/parse/data_converter.h"
55 #include "pipeline/jit/ps/parse/parse_base.h"
56 #include "pipeline/jit/ps/parse/resolve.h"
57 #include "pipeline/jit/ps/pipeline.h"
58 #include "pipeline/jit/ps/resource.h"
59 #include "pipeline/jit/ps/static_analysis/evaluator.h"
60 #include "pipeline/jit/ps/static_analysis/builtin_prim.h"
61 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
62 #include "utils/check_convert_utils.h"
63 #include "utils/hash_set.h"
64 #include "utils/log_adapter.h"
65 #include "utils/ms_context.h"
66 #include "utils/ms_utils.h"
67 #include "utils/parallel_node_check.h"
68 #include "utils/shape_utils.h"
69 #include "utils/symbolic.h"
70 #include "utils/compile_config.h"
71 
72 namespace mindspore {
73 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
74 namespace abstract {
75 using mindspore::parse::PyObjectWrapper;
76 
77 mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{kMakeTupleOpName,  kMakeListOpName,   kSwitchOpName,
78                                                                  kEnvironSetOpName, kEnvironGetOpName, kLoadOpName,
79                                                                  kUpdateStateOpName};
80 
81 // The Python primitives who visit tuple/list elements, but not consume all elements.
82 // Including:
83 // - Consume no element. For instance, MakeTuple.
84 // - Consume partial elements, not all. For instance, TupleGetItem.
85 // Map{"primitive name", {vector<int>:"index to transparent pass, -1 means all elements"}}
86 mindspore::HashMap<std::string, std::vector<int>> prims_transparent_pass_sequence{
87   {kReturnOpName, std::vector({0})},       {kDependOpName, std::vector({0})},     {kidentityOpName, std::vector({0})},
88   {kMakeTupleOpName, std::vector({-1})},   {kMakeListOpName, std::vector({-1})},  {kListAppendOpName, std::vector({0})},
89   {kTupleGetItemOpName, std::vector({0})}, {kListGetItemOpName, std::vector({0})}};
90 
OpDtypeToInt(ops::OP_DTYPE dtype)91 inline int64_t OpDtypeToInt(ops::OP_DTYPE dtype) { return static_cast<int64_t>(dtype); }
92 
// Optionally wraps `node` in a call to the Python dtype-cast helper.
//
// If op_arg.cast_dtype_ is empty the argument accepts no alternative dtypes and `node` is returned
// unchanged. Otherwise a new CNode `cast_graph(node, arg_dtype)` is created in `fg`, where arg_dtype
// is op_arg.arg_dtype_ converted with OpDtypeToInt, and that CNode is returned.
// Raises if `fg` is null or the cast helper does not resolve to a FuncGraph.
AnfNodePtr GetNodeAfterTypeConversion(const AnfNodePtr &node, const ops::OpInputArg &op_arg, const FuncGraphPtr &fg) {
  MS_EXCEPTION_IF_NULL(fg);
  const bool needs_conversion = !op_arg.cast_dtype_.empty();
  if (!needs_conversion) {
    // No source dtypes are declared for this argument, so no conversion is needed.
    return node;
  }
  const auto cast_func_value =
    prim::GetPythonOps(parse::PYTHON_MOD_PRIMITIVE_OP_TYPE_CAST, parse::PYTHON_MOD_PRIMITIVE_ARG_DTYPE_CAST_MODULE);
  auto cast_graph = dyn_cast<FuncGraph>(cast_func_value);
  MS_EXCEPTION_IF_NULL(cast_graph);
  // Attach the Python-provided helper graph to the manager of the graph that will call it.
  cast_graph->set_manager(fg->manager());
  const auto target_dtype = OpDtypeToInt(op_arg.arg_dtype_);
  return fg->NewCNodeInOrder({NewValueNode(cast_graph), node, NewValueNode(target_dtype)});
}
106 
// Optionally wraps `node` in a call to the Python arg handler declared for `op_arg`.
//
// Returns `node` unchanged when:
//   - op_arg has no arg handler, or
//   - the argument is optional and its abstract `abs` is AbstractNone.
// Otherwise the handler named op_arg.arg_handler_ is resolved from the arg-handler Python module and a
// new CNode `handler(op_name, arg_name, node)` is created in `fg`. The handler may be a Primitive, or a
// FuncGraph, which is first attached to fg's manager.
// Raises if `fg` is null, if `abs` is null for an optional argument, if the handler cannot be resolved,
// or if it is neither a Primitive nor a FuncGraph.
AnfNodePtr GetNodeAfterArgHandler(const AnfNodePtr &node, const std::string &op_name, const ops::OpInputArg &op_arg,
                                  const AbstractBasePtr &abs, const FuncGraphPtr &fg) {
  if (op_arg.arg_handler_.empty()) {
    return node;
  }
  if (op_arg.is_optional_) {
    MS_EXCEPTION_IF_NULL(abs);
    if (abs->isa<AbstractNone>()) {
      // An optional argument given as None is passed through without running the handler.
      return node;
    }
  }
  MS_EXCEPTION_IF_NULL(fg);
  const auto arg_handler_func = prim::GetPythonOps(op_arg.arg_handler_, parse::PYTHON_MOD_PRIMITIVE_ARG_HANDLER_MODULE);
  MS_EXCEPTION_IF_NULL(arg_handler_func);
  MS_LOG(DEBUG) << "The arg handler function for '" << op_arg.arg_name_ << "' of Primitive[" << op_name << "] is "
                << arg_handler_func->ToString() << ".";
  if (arg_handler_func->isa<Primitive>()) {
    auto arg_handler_prim = dyn_cast<Primitive>(arg_handler_func);
    MS_EXCEPTION_IF_NULL(arg_handler_prim);
    return fg->NewCNodeInOrder(
      {NewValueNode(arg_handler_prim), NewValueNode(op_name), NewValueNode(op_arg.arg_name_), node});
  }
  auto arg_handler_fg = dyn_cast<FuncGraph>(arg_handler_func);
  MS_EXCEPTION_IF_NULL(arg_handler_fg);
  // A graph-based handler is attached to the manager of the graph that will call it.
  arg_handler_fg->set_manager(fg->manager());
  return fg->NewCNodeInOrder(
    {NewValueNode(arg_handler_fg), NewValueNode(op_name), NewValueNode(op_arg.arg_name_), node});
}
130 
GenerateNewNodeBySignatures(const ValuePtr & func,const AbstractBasePtrList & args_abs_list,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf)131 CNodePtr DoSignatureEvaluator::GenerateNewNodeBySignatures(const ValuePtr &func,
132                                                            const AbstractBasePtrList &args_abs_list,
133                                                            const AnalysisEnginePtr &engine,
134                                                            const AnfNodeConfigPtr &out_conf) {
135   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
136     MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
137   }
138   auto out_cnode = dyn_cast<CNode>(out_conf->node());
139   MS_EXCEPTION_IF_NULL(out_cnode);
140   auto fg = out_cnode->func_graph();
141   MS_EXCEPTION_IF_NULL(fg);
142   if (out_cnode->size() == 0 || (out_cnode->size() - 1) != args_abs_list.size()) {
143     MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
144                       << args_abs_list.size() << ", inputs size " << out_cnode->size();
145   }
146 
147   // Handle primitive signatures.
148   AnfNodePtrList args_inputs;
149   (void)std::transform(out_cnode->weak_inputs().cbegin() + 1, out_cnode->weak_inputs().cend(),
150                        std::back_inserter(args_inputs), [](const AnfNodeWeakPtr &weak_node) {
151                          const auto &node = weak_node.lock();
152                          MS_EXCEPTION_IF_NULL(node);
153                          return node;
154                        });
155   auto op_inputs = prim::GetNewInputsBySignatures(fg, prim_->ToString(), func, args_abs_list, args_inputs);
156   AnfNodePtrList new_inputs{NewValueNode(func)};
157   (void)std::copy(op_inputs.begin(), op_inputs.end(), std::back_inserter(new_inputs));
158   return fg->NewCNodeInOrder(new_inputs);
159 }
160 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)161 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
162                                         const AnfNodeConfigPtr &out_conf) {
163   MS_EXCEPTION_IF_NULL(engine);
164   MS_EXCEPTION_IF_NULL(out_conf);
165   auto do_signature = prim_->cast_ptr<prim::DoSignaturePrimitive>();
166   MS_EXCEPTION_IF_NULL(do_signature);
167   auto &func = do_signature->function();
168   MS_EXCEPTION_IF_NULL(func);
169 
170   AbstractBasePtrList args_abs_list;
171   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
172                        [](const ConfigPtr &config) -> AbstractBasePtr {
173                          MS_EXCEPTION_IF_NULL(config);
174                          const auto &eval_result = config->ObtainEvalResult();
175                          MS_EXCEPTION_IF_NULL(eval_result);
176                          return eval_result->abstract();
177                        });
178   if (func->isa<Primitive>()) {
179     auto do_signature_func = func->cast<PrimitivePtr>();
180     if (do_signature_func->name() == kIsInstanceOpName) {
181       // Handle for DDE.
182       for (size_t i = 0; i < args_abs_list.size(); ++i) {
183         MS_EXCEPTION_IF_NULL(args_abs_list[i]);
184         if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
185           MS_LOG(DEBUG) << "Primitive \'IsInstance\' is consuming tuple/list arguments[" << i
186                         << "]: " << args_abs_list[i]->ToString();
187           SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
188         }
189       }
190     }
191     // Do undetermined infer firstly.
192     if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) {
193       auto res_abstract = EvalUndeterminedArgs(args_abs_list);
194       if (res_abstract != nullptr) {
195         MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined for " << do_signature_func->name()
196                       << ", res_abstract: " << res_abstract->ToString();
197         return res_abstract;
198       }
199     }
200   }
201 
202   CNodePtr new_cnode = nullptr;
203   ScopePtr scope = out_conf->node()->scope();
204   ScopeGuard scope_guard(scope);
205   if (bound_node() != nullptr) {
206     TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
207     new_cnode = GenerateNewNodeBySignatures(func, args_abs_list, engine, out_conf);
208   } else {
209     new_cnode = GenerateNewNodeBySignatures(func, args_abs_list, engine, out_conf);
210   }
211   // Update new CNode info.
212   auto out_cnode = dyn_cast<CNode>(out_conf->node());
213   MS_EXCEPTION_IF_NULL(out_cnode);
214   new_cnode->CloneCNodeInfo(out_cnode);
215 
216   // Do forward with old config and new config.
217   AnfNodeConfigPtr new_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
218   return engine->ForwardConfig(out_conf, new_conf);
219 }
220 
// Builds the argument abstracts used to specialize the graph unpacked by UnpackGraph.
//
// args_abs_list[0] is the graph being unpacked and is always skipped.
//   - need_unpack == false: the remaining abstracts are returned unchanged.
//   - need_unpack == true: each remaining abstract must be an AbstractTuple, whose elements are appended
//     in order, or an AbstractDictionary, whose entries each become an AbstractKeywordArg named by the
//     entry's string key. Any other abstract raises an internal exception.
// An empty args_abs_list yields an empty result. (Previously `begin() + 1` on an empty list was UB.)
static AbstractBasePtrList GetUnpackGraphSpecArgsList(const AbstractBasePtrList &args_abs_list, bool need_unpack) {
  if (args_abs_list.empty()) {
    return {};
  }
  if (!need_unpack) {
    // arg[0] is the func graph to unpack, ignore it
    return AbstractBasePtrList(args_abs_list.begin() + 1, args_abs_list.end());
  }

  AbstractBasePtrList graph_specialize_args;
  // arg[0] is the func graph to unpack, ignore it
  for (size_t index = 1; index < args_abs_list.size(); index++) {
    const auto &arg = args_abs_list[index];
    MS_EXCEPTION_IF_NULL(arg);
    if (arg->isa<AbstractTuple>()) {
      const auto arg_tuple = arg->cast_ptr<AbstractTuple>();
      MS_EXCEPTION_IF_NULL(arg_tuple);
      const auto &tuple_elems = arg_tuple->elements();
      (void)graph_specialize_args.insert(graph_specialize_args.end(), tuple_elems.cbegin(), tuple_elems.cend());
    } else if (arg->isa<AbstractDictionary>()) {
      auto arg_dict = arg->cast_ptr<AbstractDictionary>();
      MS_EXCEPTION_IF_NULL(arg_dict);
      const auto &dict_elems = arg_dict->elements();
      (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args),
                           [](const AbstractElementPair &item) {
                             MS_EXCEPTION_IF_NULL(item.first);
                             // Dict_elems's first element represents parameter names, which should be string type.
                             return std::make_shared<AbstractKeywordArg>(
                               GetValue<std::string>(item.first->BuildValue()), item.second);
                           });
    } else {
      MS_LOG(INTERNAL_EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " << arg->ToString();
    }
  }
  return graph_specialize_args;
}
254 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)255 EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
256                                         const AnfNodeConfigPtr &out_conf) {
257   MS_EXCEPTION_IF_NULL(engine);
258   MS_EXCEPTION_IF_NULL(out_conf);
259   MS_EXCEPTION_IF_NULL(out_conf->node());
260   if (!out_conf->node()->isa<CNode>()) {
261     MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
262   }
263   MS_EXCEPTION_IF_NULL(prim_);
264   auto unpack_graph = prim_->cast_ptr<prim::UnpackGraphPrimitive>();
265   MS_EXCEPTION_IF_NULL(unpack_graph);
266   auto out_cnode = out_conf->node()->cast_ptr<CNode>();
267   MS_EXCEPTION_IF_NULL(out_cnode);
268   if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
269     MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
270                       << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
271                       << ", inputs size " << out_cnode->size();
272   }
273   AbstractBasePtrList args_abs_list;
274   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
275                        [](const ConfigPtr &ref) -> AbstractBasePtr {
276                          MS_EXCEPTION_IF_NULL(ref);
277                          const auto &eval_result = ref->ObtainEvalResult();
278                          MS_EXCEPTION_IF_NULL(eval_result);
279                          return eval_result->abstract();
280                        });
281   // Get the forward graph
282   if (args_abs_list.empty()) {
283     MS_LOG(INTERNAL_EXCEPTION) << "args_abs_list can't be empty.";
284   }
285   MS_EXCEPTION_IF_NULL(args_abs_list[0]);
286   auto fn = args_abs_list[0]->cast_ptr<AbstractFunction>();
287   if (fn == nullptr) {
288     MS_LOG(INTERNAL_EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but "
289                                << args_abs_list[0]->ToString();
290   }
291   AbstractBasePtrList graph_specialize_args_without_sens;
292   FuncGraphAbstractClosure *real_fn = nullptr;
293   // If it's Partial closure, fetch the func graph from it.
294   const auto &partial_fn_abs = fn->cast_ptr<PartialAbstractClosure>();
295   if (partial_fn_abs != nullptr) {
296     const auto &partial_fn = partial_fn_abs->fn();
297     MS_EXCEPTION_IF_NULL(partial_fn);
298     real_fn = partial_fn->cast_ptr<FuncGraphAbstractClosure>();
299   } else {
300     real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
301   }
302   MS_EXCEPTION_IF_NULL(real_fn);
303   FuncGraphPtr forward_graph = real_fn->func_graph();
304   MS_EXCEPTION_IF_NULL(forward_graph);
305   AbstractBasePtrList graph_specialize_args =
306     GetUnpackGraphSpecArgsList(args_abs_list, unpack_graph->need_unpack_args());
307   if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
308     MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
309   }
310   // If it's Partial closure, copy the arg list in advance.
311   if (partial_fn_abs != nullptr) {
312     (void)std::copy(partial_fn_abs->args().begin(), partial_fn_abs->args().end(),
313                     std::back_inserter(graph_specialize_args_without_sens));
314   }
315   (void)std::transform(graph_specialize_args.begin(),
316                        graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
317                        std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
318   MS_LOG(DEBUG) << "forward_graph: " << forward_graph->ToString()
319                 << ", graph_specialize_args_without_sens size: " << graph_specialize_args_without_sens.size();
320   auto new_forward_graph = forward_graph->GenerateFuncGraph(graph_specialize_args_without_sens);
321   MS_EXCEPTION_IF_NULL(engine->func_graph_manager());
322   engine->func_graph_manager()->AddFuncGraph(new_forward_graph);
323   ScopePtr scope = kDefaultScope;
324   if (out_conf != nullptr) {
325     scope = out_conf->node()->scope();
326   }
327   ScopeGuard scope_guard(scope);
328   AnfNodePtr new_node = NewValueNode(new_forward_graph);
329   AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
330   return engine->ForwardConfig(out_conf, fn_conf);
331 }
332 
MixedPrecisionCastHelper(const AnfNodePtr & source_node,const AbstractBasePtr & node_type,const AnfNodePtr & target_type,const FuncGraphPtr & func_graph)333 AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
334                                     const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
335   MS_EXCEPTION_IF_NULL(node_type);
336   MS_EXCEPTION_IF_NULL(func_graph);
337   AnfNodePtr target_node = source_node;
338   if (node_type->isa<AbstractTensor>()) {
339     auto x = node_type->cast_ptr<AbstractTensor>();
340     MS_EXCEPTION_IF_NULL(x->element());
341     MS_EXCEPTION_IF_NULL(x->element()->BuildType());
342     if (x->element()->BuildType()->isa<Float>() || x->element()->BuildType()->isa<BFloat>()) {
343       auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
344       MS_EXCEPTION_IF_NULL(cast);
345       target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
346     }
347   } else if (node_type->isa<AbstractSequence>()) {
348     auto x = node_type->cast_ptr<AbstractSequence>();
349     auto &items = x->elements();
350     std::vector<AnfNodePtr> nodes;
351     (void)nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
352     int64_t idx = 0;
353     for (const auto &item : items) {
354       AnfNodePtr sequence_node = nullptr;
355       if (node_type->isa<AbstractList>()) {
356         sequence_node = func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), source_node, NewValueNode(idx)});
357       } else {
358         sequence_node = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
359       }
360       AnfNodePtr node = MixedPrecisionCastHelper(sequence_node, item, target_type, func_graph);
361       (void)nodes.emplace_back(node);
362       ++idx;
363     }
364     target_node = func_graph->NewCNode(nodes);
365   } else if (node_type->isa<AbstractDictionary>()) {
366     auto x = node_type->cast_ptr<AbstractDictionary>();
367     auto &items = x->elements();
368     std::vector<AnfNodePtr> dict_key_nodes;
369     std::vector<AnfNodePtr> dict_value_nodes;
370     (void)dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
371     (void)dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
372     for (const auto &item : items) {
373       MS_EXCEPTION_IF_NULL(item.first);
374       auto key_value = item.first->BuildValue();
375       MS_EXCEPTION_IF_NULL(key_value);
376       AnfNodePtr dict_key_node = NewValueNode(key_value);
377       AnfNodePtr dict_value_node =
378         func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(key_value)});
379       AnfNodePtr key_node = MixedPrecisionCastHelper(dict_key_node, item.first, target_type, func_graph);
380       AnfNodePtr value_node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
381       (void)dict_key_nodes.emplace_back(key_node);
382       (void)dict_value_nodes.emplace_back(value_node);
383     }
384     target_node =
385       func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(dict_key_nodes)),
386                             func_graph->NewCNode(std::move(dict_value_nodes))});
387   } else if (node_type->isa<AbstractKeywordArg>()) {
388     auto x = node_type->cast_ptr<AbstractKeywordArg>();
389     std::string kwarg_key = x->get_key();
390     AnfNodePtr kwarg_value_node =
391       func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
392     AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
393     target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
394   }
395   return target_node;
396 }
397 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)398 EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
399                                                const AnfNodeConfigPtr &out_conf) {
400   MS_EXCEPTION_IF_NULL(engine);
401   AbstractBasePtrList args_abs_list;
402   MS_EXCEPTION_IF_NULL(out_conf);
403   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
404     MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
405   }
406   auto out_cnode = out_conf->node()->cast<CNodePtr>();
407   MS_EXCEPTION_IF_NULL(out_cnode);
408   if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
409     MS_LOG(EXCEPTION) << "MixedPrecisionCast"
410                       << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
411                       << ", inputs size " << out_cnode->size();
412   }
413   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
414                        [](const ConfigPtr &ref) -> AbstractBasePtr {
415                          MS_EXCEPTION_IF_NULL(ref);
416                          const auto &eval_result = ref->ObtainEvalResult();
417                          MS_EXCEPTION_IF_NULL(eval_result);
418                          return eval_result->abstract();
419                        });
420 
421   ScopeGuard scope_guard(out_conf->node()->scope());
422   TraceGuard trace_guard(std::make_shared<TraceMixedPrecision>(out_conf->node()->debug_info()));
423 
424   FuncGraphPtr func_graph = out_cnode->func_graph();
425   constexpr size_t source_node_index = 2;
426   if (out_cnode->size() <= source_node_index) {
427     MS_LOG(EXCEPTION) << "Input size: " << out_cnode->size() << " should bigger than 2.";
428   }
429 
430   AnfNodePtr new_node =
431     MixedPrecisionCastHelper(out_cnode->input(source_node_index), args_abs_list[1], out_cnode->input(1), func_graph);
432   AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
433 
434   if (new_node->isa<CNode>()) {
435     auto new_cnode = new_node->cast_ptr<CNode>();
436     new_cnode->CloneCNodeInfo(out_cnode);
437   }
438   return engine->ForwardConfig(out_conf, fn_conf);
439 }
440 
441 namespace {
CheckTensorCondValid(const AbstractBasePtr & cond)442 void CheckTensorCondValid(const AbstractBasePtr &cond) {
443   // Tensor condition must be one element or dynamic shape.
444   auto base_shape = cond->BuildShape();
445   MS_EXCEPTION_IF_NULL(base_shape);
446   ShapeVector cond_shape = base_shape->cast<ShapePtr>()->shape();
447   if (cond_shape.empty()) {
448     return;
449   }
450   constexpr auto num_one = 1;
451   for (size_t i = 0; i < cond_shape.size(); i++) {
452     if (cond_shape[i] != num_one && cond_shape[i] != Shape::kShapeDimAny && cond_shape[i] != Shape::kShapeRankAny) {
453       MS_LOG(ERROR) << "The condition value of control flow can be a tensor with one element, "
454                     << "but got tensor with shape " << base_shape->ToString();
455       MS_EXCEPTION(ValueError) << "The truth value of an array with more than one element is ambiguous.";
456     }
457   }
458 }
459 }  // namespace
460 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)461 EvalResultPtr SwitchEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
462                                    const AnfNodeConfigPtr &out_conf) {
463   MS_EXCEPTION_IF_NULL(engine);
464   AbstractBasePtrList args_abs_list;
465   MS_EXCEPTION_IF_NULL(out_conf);
466   MS_EXCEPTION_IF_NULL(out_conf->node());
467   if (!out_conf->node()->isa<CNode>()) {
468     MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
469   }
470   auto out_cnode = out_conf->node()->cast<CNodePtr>();
471   MS_EXCEPTION_IF_NULL(out_cnode);
472   if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
473     MS_LOG(EXCEPTION) << "For 'Switch',"
474                       << " the args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
475                       << ", inputs size " << out_cnode->size();
476   }
477 
478   // Inputs: condition, true branch, false branch
479   constexpr auto switch_input_size = 3;
480   if (args_conf_list.size() != switch_input_size) {
481     MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_abs_list.size()
482                       << ".";
483   }
484 
485   auto eval_func = [](const ConfigPtr &conf) -> AbstractBasePtr {
486     MS_EXCEPTION_IF_NULL(conf);
487     const auto &eval_result = conf->ObtainEvalResult();
488     MS_EXCEPTION_IF_NULL(eval_result);
489     auto abs = eval_result->abstract();
490     MS_EXCEPTION_IF_NULL(abs);
491     return abs;
492   };
493 
494   auto cond_abstract = eval_func(args_conf_list[0]);
495   ValuePtr cond_value = cond_abstract->GetValueTrack();
496   MS_EXCEPTION_IF_NULL(cond_value);
497   // If the value of condition is ValueAny or the abstract of condition is AbstractTensor,
498   // keeps both true and false branch.
499   if (cond_value->isa<ValueAny>() || cond_abstract->isa<AbstractTensor>()) {
500     if (cond_abstract->isa<AbstractTensor>()) {
501       CheckTensorCondValid(cond_abstract);
502     }
503     auto true_branch = eval_func(args_conf_list[1]);
504     // Need record two func_graph
505     constexpr auto false_branch_index = 2;
506     auto false_branch = eval_func(args_conf_list[false_branch_index]);
507     SetVariableFlag(true_branch);
508     SetVariableFlag(false_branch);
509     auto res_abs = true_branch->Join(false_branch);
510     auto eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>());
511     return eval_result;
512   }
513 
514   if (cond_value->isa<Scalar>()) {
515     AbstractBasePtr res_abs = nullptr;
516     if (cond_value->cast<ScalarPtr>()->IsOne()) {
517       const auto &true_branch = eval_func(args_conf_list[1]);
518       res_abs = true_branch;
519     } else {
520       constexpr auto false_branch_index = 2;
521       auto false_branch = eval_func(args_conf_list[false_branch_index]);
522       res_abs = false_branch;
523     }
524     auto eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>());
525     return eval_result;
526   }
527   MS_LOG(EXCEPTION) << "Not support this condition value: " << cond_abstract->GetValueTrack()->ToString();
528 }
529 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)530 EvalResultPtr SwitchLayerEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
531                                         const AnfNodeConfigPtr &out_conf) {
532   MS_EXCEPTION_IF_NULL(engine);
533   AbstractBasePtrList args_abs_list;
534   MS_EXCEPTION_IF_NULL(out_conf);
535   MS_EXCEPTION_IF_NULL(out_conf->node());
536   if (!out_conf->node()->isa<CNode>()) {
537     MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
538   }
539   auto out_cnode = out_conf->node()->cast<CNodePtr>();
540   MS_EXCEPTION_IF_NULL(out_cnode);
541   if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
542     MS_LOG(EXCEPTION) << "For 'SwitchLayer',"
543                       << " the args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
544                       << ", inputs size " << out_cnode->size();
545   }
546 
547   // Inputs: condition, true branch, false branch
548   constexpr auto switch_input_size = 3;
549   if (args_conf_list.size() != switch_input_size) {
550     MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 3 parameters, while the input size is " << args_abs_list.size()
551                       << ".";
552   }
553   auto eval_func = [](const ConfigPtr &conf) -> AbstractBasePtr {
554     MS_EXCEPTION_IF_NULL(conf);
555     const auto &eval_result = conf->ObtainEvalResult();
556     MS_EXCEPTION_IF_NULL(eval_result);
557     auto abs = eval_result->abstract();
558     MS_EXCEPTION_IF_NULL(abs);
559     return abs;
560   };
561   auto cond_abstract = eval_func(args_conf_list[0]);
562   ValuePtr cond_value = cond_abstract->GetValueTrack();
563   MS_EXCEPTION_IF_NULL(cond_value);
564   MS_LOG(EXCEPTION) << "Not support this condition value: " << cond_value->ToString();
565 }
566 
567 namespace {
BuildPyObject(const ValuePtr & value_ptr)568 py::object BuildPyObject(const ValuePtr &value_ptr) {
569   if (value_ptr == nullptr) {
570     return py::none();
571   } else {
572     return ValueToPyData(value_ptr);
573   }
574 }
575 
AbstractTupleValueToPython(const AbstractTuple * tuple_abs)576 py::object AbstractTupleValueToPython(const AbstractTuple *tuple_abs) {
577   MS_EXCEPTION_IF_NULL(tuple_abs);
578   if (tuple_abs->dynamic_len()) {
579     return py::none();
580   }
581   const auto &elements = tuple_abs->elements();
582   size_t len = elements.size();
583   py::tuple value_tuple(len);
584   for (size_t i = 0; i < len; ++i) {
585     value_tuple[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
586   }
587   return value_tuple;
588 }
589 
AbstractTupleToPython(const AbstractBasePtr & abs_base,bool only_convert_value)590 py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
591   auto arg_tuple = dyn_cast_ptr<AbstractTuple>(abs_base);
592   MS_EXCEPTION_IF_NULL(arg_tuple);
593   auto dic = py::dict();
594   if (only_convert_value) {
595     dic[ATTR_VALUE] = AbstractTupleValueToPython(arg_tuple);
596     return dic;
597   }
598   if (arg_tuple->dynamic_len()) {
599     dic[ATTR_VALUE] = py::none();
600     dic[ATTR_SHAPE] = ShapeVector{abstract::Shape::kShapeDimAny};
601     dic[ATTR_DTYPE] = arg_tuple->BuildType();
602     return dic;
603   }
604   size_t len = arg_tuple->size();
605   py::tuple shape_tuple(len);
606   py::tuple dtype_tuple(len);
607   py::tuple value_tuple(len);
608   std::vector<py::dict> res;
609 
610   for (size_t i = 0; i < len; i++) {
611     py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
612     res.push_back(out);
613     shape_tuple[i] = out[ATTR_SHAPE];
614     dtype_tuple[i] = out[ATTR_DTYPE];
615     value_tuple[i] = out[ATTR_VALUE];
616   }
617   dic[ATTR_SHAPE] = shape_tuple;
618   dic[ATTR_DTYPE] = dtype_tuple;
619   dic[ATTR_VALUE] = value_tuple;
620 
621   return dic;
622 }
623 
AbstractDictionaryToPython(const AbstractBasePtr & abs_base)624 py::dict AbstractDictionaryToPython(const AbstractBasePtr &abs_base) {
625   auto arg_dict = dyn_cast_ptr<AbstractDictionary>(abs_base);
626   MS_EXCEPTION_IF_NULL(arg_dict);
627 
628   size_t len = arg_dict->size();
629   const auto &arg_dict_elements = arg_dict->elements();
630   py::list shape_list(len);
631   py::list dtype_list(len);
632   py::dict value_dict = py::dict();
633 
634   for (size_t i = 0; i < len; ++i) {
635     auto cur_attr = arg_dict_elements[i];
636     auto cur_key = cur_attr.first;
637     auto cur_value = cur_attr.second;
638 
639     py::dict cur_value_out = ConvertAbstractToPython(cur_value);
640     shape_list[i] = cur_value_out[ATTR_SHAPE];
641     dtype_list[i] = cur_value_out[ATTR_DTYPE];
642     MS_EXCEPTION_IF_NULL(cur_key);
643     value_dict[ValueToPyData(cur_key->BuildValue())] = cur_value_out[ATTR_VALUE];
644   }
645 
646   py::dict dic = py::dict();
647   dic[ATTR_SHAPE] = shape_list;
648   dic[ATTR_DTYPE] = dtype_list;
649   MS_EXCEPTION_IF_NULL(arg_dict->BuildValue());
650   dic[ATTR_VALUE] = value_dict;
651   return dic;
652 }
653 
AbstractKWArgsToPython(const AbstractBasePtr & abs_base)654 py::object AbstractKWArgsToPython(const AbstractBasePtr &abs_base) {
655   MS_EXCEPTION_IF_NULL(abs_base);
656   auto abs_keyword_arg = abs_base->cast_ptr<abstract::AbstractKeywordArg>();
657   MS_EXCEPTION_IF_NULL(abs_keyword_arg);
658   auto args_abs = abs_keyword_arg->get_arg();
659   auto args_obj = BuildPyObject(args_abs->BuildValue());
660   // if the args is none but the type is not none means the input is a variable.
661   if (!args_abs->isa<AbstractNone>() && py::isinstance<py::none>(args_obj)) {
662     return py::none();
663   }
664   return BuildPyObject(abs_base->BuildValue());
665 }
666 
AbstractListValueToPython(const AbstractList * list_abs)667 py::object AbstractListValueToPython(const AbstractList *list_abs) {
668   MS_EXCEPTION_IF_NULL(list_abs);
669   if (list_abs->dynamic_len()) {
670     return py::none();
671   }
672   const auto &elements = list_abs->elements();
673   size_t len = elements.size();
674   py::list value_list(len);
675   for (size_t i = 0; i < len; ++i) {
676     value_list[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
677   }
678   return value_list;
679 }
680 
AbstractListToPython(const AbstractBasePtr & abs_base,bool only_convert_value)681 py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
682   auto arg_list = dyn_cast_ptr<AbstractList>(abs_base);
683   MS_EXCEPTION_IF_NULL(arg_list);
684   auto dic = py::dict();
685   if (only_convert_value) {
686     dic[ATTR_VALUE] = AbstractListValueToPython(arg_list);
687     return dic;
688   }
689   if (arg_list->dynamic_len()) {
690     auto elem_out = ConvertAbstractToPython(arg_list->dynamic_len_element_abs());
691     dic[ATTR_VALUE] = py::none();
692     dic[ATTR_SHAPE] = elem_out[ATTR_SHAPE];
693     dic[ATTR_DTYPE] = elem_out[ATTR_DTYPE];
694     return dic;
695   }
696   size_t len = arg_list->size();
697   py::list shape_list(len);
698   py::list dtype_list(len);
699   py::list value_list(len);
700   std::vector<py::dict> res;
701 
702   for (size_t i = 0; i < len; i++) {
703     py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
704     res.push_back(out);
705     shape_list[i] = out[ATTR_SHAPE];
706     dtype_list[i] = out[ATTR_DTYPE];
707     value_list[i] = out[ATTR_VALUE];
708   }
709 
710   dic[ATTR_SHAPE] = shape_list;
711   dic[ATTR_DTYPE] = dtype_list;
712   dic[ATTR_VALUE] = value_list;
713   return dic;
714 }
715 
ConvertAbstractTensorToPython(const AbstractBasePtr & abs_base,bool only_convert_value,py::dict * dic)716 void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
717   auto arg_tensor = dyn_cast_ptr<AbstractTensor>(abs_base);
718   MS_EXCEPTION_IF_NULL(dic);
719   MS_EXCEPTION_IF_NULL(arg_tensor);
720   if (only_convert_value) {
721     (*dic)[ATTR_VALUE] = BuildPyObject(arg_tensor->BuildValue());
722     return;
723   }
724   MS_EXCEPTION_IF_NULL(arg_tensor->shape());
725   (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
726 
727   (*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
728   (*dic)[ATTR_VALUE] = BuildPyObject(arg_tensor->BuildValue());
729 }
730 namespace {
GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr & prim_abs)731 py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_abs) {
732   MS_EXCEPTION_IF_NULL(prim_abs);
733   auto prim = prim_abs->BuildValue();
734   if (prim == nullptr) {
735     return py::none();
736   }
737   if (prim->isa<prim::DoSignaturePrimitive>()) {
738     auto do_sig_prim = prim->cast_ptr<prim::DoSignaturePrimitive>();
739     auto value = do_sig_prim->function();
740     MS_EXCEPTION_IF_NULL(value);
741     if (!value->isa<PrimitivePy>()) {
742       return py::none();
743     }
744     auto prim_py = value->cast_ptr<PrimitivePy>();
745     return prim_py->GetPyObj();
746   }
747   if (prim->isa<PrimitivePy>()) {
748     auto prim_py = prim->cast_ptr<PrimitivePy>();
749     return prim_py->GetPyObj();
750   }
751   return py::none();
752 }
753 }  // namespace
754 
ConvertAbstractFunctionToPython(const AbstractBasePtr & abs_base,py::dict * dic)755 void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
756   MS_EXCEPTION_IF_NULL(dic);
757   MS_EXCEPTION_IF_NULL(abs_base);
758   (*dic)[ATTR_SHAPE] = py::none();
759   (*dic)[ATTR_DTYPE] = abs_base->BuildType();
760   (*dic)[ATTR_VALUE] = py::none();
761   if (abs_base->isa<PartialAbstractClosure>()) {
762     auto partial_abs = abs_base->cast<PartialAbstractClosurePtr>();
763     AbstractBasePtrList args = partial_abs->args();
764     if (!args.empty()) {
765       auto value = args[0]->BuildValue();
766       MS_EXCEPTION_IF_NULL(value);
767       auto value_obj = value->cast_ptr<parse::ClassType>();
768       if (value_obj != nullptr) {
769         (*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
770         (*dic)[ATTR_VALUE] = value_obj->obj();
771       }
772     }
773   }
774   if (abs_base->isa<PrimitiveAbstractClosure>()) {
775     (*dic)[ATTR_VALUE] = GetPyObjForPrimitiveAbstract(abs_base->cast<PrimitiveAbstractClosurePtr>());
776   }
777 }
778 
CheckType(const TypePtr & expected_type,const TypePtr & x)779 bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
780   // As x and predicate both are mindspore type statically, here we only to judge whether
781   // x is predicate or is a subclass of predicate.
782   return IsIdentidityOrSubclass(x, expected_type);
783 }
784 
785 // Join all types in args_type_list;
TypeJoin(const TypePtrList & args_type_list)786 TypePtr TypeJoin(const TypePtrList &args_type_list) {
787   if (args_type_list.empty()) {
788     MS_LOG(INTERNAL_EXCEPTION) << "args_type_list is empty";
789   }
790 
791   TypePtr type_tmp = args_type_list[0];
792   for (std::size_t i = 1; i < args_type_list.size(); i++) {
793     type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
794   }
795   return type_tmp;
796 }
797 
// Verifies that every type in args_type_list matches predicate (see CheckType) and returns their join.
// Raises an internal exception on the first mismatching type.
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  MS_EXCEPTION_IF_NULL(predicate);
  auto mismatch =
    std::find_if(args_type_list.cbegin(), args_type_list.cend(), [&predicate](const TypePtr &arg_type) {
      MS_EXCEPTION_IF_NULL(arg_type);
      return !CheckType(predicate, arg_type);
    });
  if (mismatch != args_type_list.cend()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << (*mismatch)->ToString();
  }
  return TypeJoin(args_type_list);
}
808 }  // namespace
809 
UnknownAbstract(const AbstractBasePtr & abs_base)810 void UnknownAbstract(const AbstractBasePtr &abs_base) {
811   auto value = abs_base->BuildValue();
812   MS_EXCEPTION_IF_NULL(value);
813   if ((*value == *kValueAny)) {
814     auto value_desc = abs_base->value_desc();
815     MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
816                             << " for python primitive." << abs_base->ToString();
817   }
818   MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
819                           << value->ToString();
820 }
821 
ConvertAbstractToPython(const AbstractBasePtr & abs_base,bool only_convert_value)822 py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
823   MS_EXCEPTION_IF_NULL(abs_base);
824   auto dic = py::dict();
825   if (abs_base->isa<AbstractTensor>()) {
826     ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
827   } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>()) {
828     ShapeVector shape;
829     dic[ATTR_SHAPE] = shape;
830     dic[ATTR_DTYPE] = abs_base->BuildType();
831     dic[ATTR_VALUE] = BuildPyObject(abs_base->BuildValue());
832   } else if (abs_base->isa<AbstractTuple>()) {
833     return AbstractTupleToPython(abs_base, only_convert_value);
834   } else if (abs_base->isa<AbstractList>()) {
835     return AbstractListToPython(abs_base, only_convert_value);
836   } else if (abs_base->isa<AbstractDictionary>()) {
837     return AbstractDictionaryToPython(abs_base);
838   } else if (abs_base->isa<AbstractSlice>()) {
839     auto arg_slice = dyn_cast_ptr<AbstractSlice>(abs_base);
840     ShapeVector shape;
841     dic[ATTR_SHAPE] = shape;
842     dic[ATTR_DTYPE] = arg_slice->BuildType();
843     dic[ATTR_VALUE] = BuildPyObject(arg_slice->BuildValue());
844   } else if (abs_base->isa<AbstractRowTensor>()) {
845     auto arg = dyn_cast_ptr<AbstractRowTensor>(abs_base);
846     MS_EXCEPTION_IF_NULL(arg->shape());
847     dic[ATTR_SHAPE] = arg->shape()->shape();
848     dic[ATTR_DTYPE] = arg->BuildType();
849     dic[ATTR_VALUE] = BuildPyObject(arg->BuildValue());
850   } else if (abs_base->isa<AbstractCOOTensor>()) {
851     auto arg = dyn_cast_ptr<AbstractCOOTensor>(abs_base);
852     MS_EXCEPTION_IF_NULL(arg->shape());
853     AbstractBasePtrList sparse_shape = arg->shape()->elements();
854     ShapeVector sparse_shape_vector;
855     (void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
856                          [](const AbstractBasePtr &e) -> int64_t {
857                            MS_EXCEPTION_IF_NULL(e);
858                            MS_EXCEPTION_IF_NULL(e->cast_ptr<AbstractScalar>());
859                            ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
860                            return GetValue<int64_t>(value);
861                          });
862     dic[ATTR_SHAPE] = sparse_shape_vector;
863     dic[ATTR_DTYPE] = arg->BuildType();
864     dic[ATTR_VALUE] = BuildPyObject(arg->BuildValue());
865   } else if (abs_base->isa<AbstractCSRTensor>()) {
866     auto arg = dyn_cast_ptr<AbstractCSRTensor>(abs_base);
867     MS_EXCEPTION_IF_NULL(arg->shape());
868     AbstractBasePtrList sparse_shape = arg->shape()->elements();
869     ShapeVector sparse_shape_vector;
870     (void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
871                          [](const AbstractBasePtr &e) -> int64_t {
872                            MS_EXCEPTION_IF_NULL(e);
873                            MS_EXCEPTION_IF_NULL(e->cast_ptr<AbstractScalar>());
874                            ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
875                            return GetValue<int64_t>(value);
876                          });
877     dic[ATTR_SHAPE] = sparse_shape_vector;
878     dic[ATTR_DTYPE] = arg->BuildType();
879     dic[ATTR_VALUE] = BuildPyObject(arg->BuildValue());
880   } else if (abs_base->isa<AbstractEllipsis>()) {
881     dic[ATTR_SHAPE] = py::none();
882     dic[ATTR_DTYPE] = py::ellipsis();
883     dic[ATTR_VALUE] = py::ellipsis();
884   } else if (abs_base->isa<AbstractNone>()) {
885     dic[ATTR_SHAPE] = py::none();
886     dic[ATTR_DTYPE] = py::none();
887     dic[ATTR_VALUE] = py::none();
888   } else if (abs_base->isa<AbstractFunction>()) {
889     ConvertAbstractFunctionToPython(abs_base, &dic);
890   } else if (abs_base->isa<AbstractClass>()) {
891     auto arg_class = dyn_cast_ptr<AbstractClass>(abs_base);
892     ShapeVector shape;
893     dic[ATTR_SHAPE] = shape;
894     dic[ATTR_DTYPE] = arg_class->BuildType();
895     dic[ATTR_VALUE] = BuildPyObject(arg_class->BuildValue());
896   } else if (abs_base->isa<AbstractUndetermined>()) {
897     auto arg = dyn_cast_ptr<AbstractUndetermined>(abs_base);
898     dic[ATTR_SHAPE] = py::none();
899     dic[ATTR_DTYPE] = arg->BuildType();
900     dic[ATTR_VALUE] = py::none();
901   } else if (abs_base->isa<AbstractMonad>()) {
902     dic[ATTR_SHAPE] = py::none();
903     dic[ATTR_DTYPE] = abs_base->BuildType();
904     dic[ATTR_VALUE] = py::none();
905   } else if (abs_base->isa<AbstractKeywordArg>()) {
906     dic[ATTR_SHAPE] = py::none();
907     dic[ATTR_DTYPE] = abs_base->BuildType();
908     dic[ATTR_VALUE] = AbstractKWArgsToPython(abs_base);
909   } else {
910     UnknownAbstract(abs_base);
911   }
912   return dic;
913 }
914 
915 namespace {
CheckCustomPrimOutputInferResult(const PrimitivePtr & prim,const AbstractBasePtr & res_spec)916 void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
917   MS_EXCEPTION_IF_NULL(prim);
918   MS_EXCEPTION_IF_NULL(res_spec);
919   const string kOutputNum = "output_num";
920   if (prim->IsCustomPrim()) {
921     // Raise error if output_num is not match the infer result.
922     auto output_num_value = prim->GetAttr(kOutputNum);
923     if (output_num_value == nullptr) {
924       MS_LOG(DEBUG) << "The output num may no need to check";
925       return;
926     }
927     int64_t output_num = GetValue<int64_t>(output_num_value);
928     if (res_spec->isa<AbstractTensor>() && output_num != 1) {
929       MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
930                         << "]'s attribute[output_num]: " << output_num << ", not matches the infer result "
931                         << res_spec->ToString();
932     } else if (res_spec->isa<AbstractTuple>() &&
933                (res_spec->cast_ptr<AbstractTuple>()->size() != LongToSize(output_num))) {
934       MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
935                         << "]'s attribute[output_num]: " << output_num << ", not matches the infer result "
936                         << res_spec->ToString();
937     }
938   }
939 }
940 
IsMonadType(const py::object & type_obj)941 static bool IsMonadType(const py::object &type_obj) {
942   if (py::isinstance<Type>(type_obj)) {
943     auto type = type_obj.cast<Type *>();
944     return type->isa<MonadType>();
945   }
946   return false;
947 }
948 
ToMonadAbstract(const py::object & type_obj)949 AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
950   if (py::isinstance<Type>(type_obj)) {
951     auto type = type_obj.cast<Type *>();
952     if (!type->isa<MonadType>()) {
953       MS_LOG(INTERNAL_EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
954     }
955     return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
956   }
957   MS_LOG(INTERNAL_EXCEPTION) << "Not a type object: " << py::str(type_obj);
958 }
959 
// Extracts the index-th (dtype, shape) pair from a sequence infer result into a new {dtype, shape, value=None} dict.
py::object GetPyAbsItemOfTupleOut(const py::object &output, const size_t index) {
  auto out_dict = output.cast<py::dict>();
  auto shape_seq = out_dict[ATTR_SHAPE].cast<py::tuple>();
  auto dtype_seq = out_dict[ATTR_DTYPE].cast<py::tuple>();
  py::dict item;
  item[ATTR_DTYPE] = dtype_seq[index];
  item[ATTR_SHAPE] = shape_seq[index];
  item[ATTR_VALUE] = py::none();
  return item;
}
972 
MakePyInferRes2AbstractTensor(const py::object & shape_obj,const py::object & type_obj)973 AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
974   auto res_vec = shape_obj.cast<ShapeVector>();
975   auto res_dtype = type_obj.cast<TypePtr>();
976 
977   auto res_shape = std::make_shared<abstract::Shape>(res_vec);
978   AbstractBasePtr tensor = MakeAbstractTensor(res_shape, res_dtype);
979   return tensor;
980 }
981 
MakePyInferRes2Abstract(const py::object & output)982 AbstractBasePtr MakePyInferRes2Abstract(const py::object &output) {
983   auto out_dict = output.cast<py::dict>();
984   auto type_obj = out_dict[ATTR_DTYPE];
985   auto shape_obj = out_dict[ATTR_SHAPE];
986   if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
987     auto res_vec = shape_obj.cast<ShapeVector>();
988     auto res_dtype = type_obj.cast<TypePtr>();
989     MS_EXCEPTION_IF_NULL(res_dtype);
990     // if the size of shape list is empty, return an scalar abstract
991     if (res_vec.empty() && (!res_dtype->isa<TensorType>())) {
992       abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kValueAny, res_dtype);
993       return abs_scalar;
994     }
995     return MakePyInferRes2AbstractTensor(shape_obj, type_obj);
996   } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
997     auto typeid_tuple = type_obj.cast<py::tuple>();
998     AbstractBasePtrList ptr_list;
999     for (size_t it = 0; it < typeid_tuple.size(); ++it) {
1000       auto output_it = GetPyAbsItemOfTupleOut(output, it);
1001       auto tensor_it = MakePyInferRes2Abstract(output_it);
1002       ptr_list.push_back(tensor_it);
1003     }
1004     auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
1005     return tuple;
1006   } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
1007     auto typeid_list = type_obj.cast<py::list>();
1008     AbstractBasePtrList ptr_list;
1009     for (size_t it = 0; it < typeid_list.size(); ++it) {
1010       auto output_it = GetPyAbsItemOfTupleOut(output, it);
1011       auto tensor_it = MakePyInferRes2Abstract(output_it);
1012       ptr_list.push_back(tensor_it);
1013     }
1014     auto list = std::make_shared<abstract::AbstractList>(ptr_list);
1015     return list;
1016   } else if (shape_obj.is_none() && type_obj.is_none()) {
1017     // AbstractNone indicates there is no output for this CNode node.
1018     auto abstract_none = std::make_shared<abstract::AbstractNone>();
1019     return abstract_none;
1020   } else if (IsMonadType(type_obj)) {
1021     // Return monad abstract if it is monad type.
1022     return ToMonadAbstract(type_obj);
1023   } else {
1024     MS_LOG(INTERNAL_EXCEPTION) << "Python evaluator return invalid shape or type. " << py::str(type_obj);
1025   }
1026 }
1027 }  // namespace
PreparePyInputs(const AbstractBasePtrList & args)1028 py::tuple PreparePyInputs(const AbstractBasePtrList &args) {
1029   // The monad parameter is defined at the end of the parameter and needs to be ignored
1030   std::size_t args_size = args.size() - GetAbstractMonadNum(args);
1031   py::tuple py_args(args_size);
1032   for (size_t i = 0; i < args_size; i++) {
1033     py_args[i] = ConvertAbstractToPython(args[i]);
1034   }
1035   return py_args;
1036 }
1037 
PyInferRes2Abstract(const PrimitivePyPtr & prim_py,const py::dict & output)1038 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
1039   // Convert to AbstractValue based on type and shape
1040   if (output[ATTR_VALUE].is_none()) {
1041     return MakePyInferRes2Abstract(output);
1042   }
1043 
1044   // Convert pyobject to Value, then to AbstractValue
1045   auto out_dtype = output[ATTR_DTYPE];
1046   TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
1047   ValuePtr converted_ret = nullptr;
1048   bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
1049   if (!converted) {
1050     MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
1051   }
1052   auto res_spec = FromValue(converted_ret);
1053   MS_EXCEPTION_IF_NULL(res_spec);
1054   if (res_spec->isa<AbstractTensor>()) {
1055     // Replace to tensor constant node in specialize
1056     auto res_tensor = res_spec->cast<AbstractTensorPtr>();
1057     res_tensor->set_value(converted_ret);
1058   }
1059   CheckCustomPrimOutputInferResult(prim_py, res_spec);
1060   return res_spec;
1061 }
1062 
RunPyInferValue(const AnalysisEnginePtr &,const AbstractBasePtr & abs_base,const AbstractBasePtrList & args)1063 EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &, const AbstractBasePtr &abs_base,
1064                                                      const AbstractBasePtrList &args) {
1065   auto prim_py = dyn_cast<PrimitivePy>(prim_);
1066   if (prim_py == nullptr) {
1067     MS_LOG(INTERNAL_EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
1068   }
1069   // Call checking method 'infer_value' for python primitive
1070   MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
1071   auto py_args = PreparePyInputs(args);
1072   py::tuple py_vals(py_args.size());
1073   MS_EXCEPTION_IF_NULL(prim_);
1074   auto added_attrs = prim_->evaluate_added_attrs();
1075   for (size_t i = 0; i < py_args.size(); ++i) {
1076     py_vals[i] = py_args[i][ATTR_VALUE];
1077   }
1078   py::object py_ret = prim_py->RunInferValue(py_vals);
1079   if (py::isinstance<py::none>(py_ret)) {
1080     return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1081   }
1082   // Convert pyobject to Value, then to AbstractValue
1083   ValuePtr converted_ret = nullptr;
1084   MS_EXCEPTION_IF_NULL(abs_base);
1085   TypePtr dtype = abs_base->BuildType();
1086   bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
1087   if (!converted) {
1088     MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
1089   }
1090   auto res_spec = FromValue(converted_ret);
1091   MS_EXCEPTION_IF_NULL(res_spec);
1092   if (res_spec->isa<AbstractTensor>()) {
1093     // Replace to tensor constant node in specialize
1094     auto res_tensor = res_spec->cast_ptr<AbstractTensor>();
1095     res_tensor->set_value(converted_ret);
1096   }
1097   return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
1098 }
1099 
1100 // Apply EvalResult from cached result for a given primitive.
ApplyCacheEvalResult(const PrimitivePtr & prim,const EvalResultPtr & result)1101 static inline EvalResultPtr ApplyCacheEvalResult(const PrimitivePtr &prim, const EvalResultPtr &result) {
1102   auto &attrs = result->attribute();
1103   if (attrs != nullptr) {
1104     prim->set_evaluate_added_attrs(*attrs);
1105   }
1106   return std::make_shared<EvalResult>(result->abstract()->Clone(), attrs);
1107 }
1108 
EvalPyCheckPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1109 EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1110   // Try to get infer result from evaluator cache.
1111   auto eval_result = evaluator_cache_mgr_->GetValue(args);
1112   if (eval_result != nullptr) {
1113     // Evaluator cache hit.
1114     return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
1115   }
1116   // In pynative mode (engine == nullptr), it is difficult to set added_attrs to
1117   // python object by C++ code, so we disable global eval cache in pynative mode.
1118   const bool enable_global_cache = (engine != nullptr);
1119   if (enable_global_cache) {
1120     // Try to get infer result from global primitive evaluate cache.
1121     eval_result = eval_cache_->Get(prim_, args);
1122     if (eval_result != nullptr) {
1123       // Global primitive evaluate cache hit.
1124       evaluator_cache_mgr_->SetValue(args, eval_result);
1125       return ApplyCacheEvalResult(prim_, eval_result);
1126     }
1127   }
1128   // PrimitivePy is expected for EvalPyCheckPrim.
1129   auto prim_py = dyn_cast<PrimitivePy>(prim_);
1130   if (prim_py == nullptr) {
1131     MS_LOG(INTERNAL_EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
1132   }
1133   // We should copy attributes before running check and infer,
1134   // since they may be changed during check and infer.
1135   auto input_attrs = prim_py->attrs();
1136   prim_py->BeginRecordAddAttr();
1137   auto py_args = PreparePyInputs(args);
1138   // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'.
1139   prim_py->RunCheck(py_args);
1140   auto abs = eval_impl_.InferShapeAndType(nullptr, prim_py, args);
1141   MS_EXCEPTION_IF_NULL(abs);
1142   prim_py->EndRecordAddAttr();
1143   auto &added_attrs = prim_py->evaluate_added_attrs();
1144   eval_result = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>(added_attrs));
1145   if (py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
1146     // Call 'infer_value()' method if it is existed, for constant propagation.
1147     eval_result = RunPyInferValue(engine, eval_result->abstract(), args);
1148   }
1149   // Save infer result to caches (evaluator cache and global cache).
1150   if (enable_global_cache) {
1151     eval_cache_->Put(prim_py, std::move(input_attrs), args, eval_result);
1152   }
1153   evaluator_cache_mgr_->SetValue(args, eval_result);
1154   return eval_result;
1155 }
1156 
1157 namespace {
CheckSequenceArgumentForCppPrimitive(const PrimitivePtr & prim,const AbstractBasePtrList & args)1158 void CheckSequenceArgumentForCppPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
1159   // To check tuple/list operations with a white list of Python primitive.
1160   MS_EXCEPTION_IF_NULL(prim);
1161   auto iter = prims_transparent_pass_sequence.find(prim->name());
1162   if (iter == prims_transparent_pass_sequence.end()) {
1163     // The primitive use all elements of each argument.
1164     for (size_t i = 0; i < args.size(); ++i) {
1165       MS_EXCEPTION_IF_NULL(args[i]);
1166       if (args[i]->isa<abstract::AbstractSequence>()) {
1167         MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
1168                       << "]: " << args[i]->ToString();
1169         SetSequenceElementsUseFlagsRecursively(args[i], true);
1170       }
1171     }
1172     return;
1173   }
1174 
1175   // It's transparent pass primitive or using partial elements primitive.
1176   auto index_list = iter->second;
1177   if (index_list.empty()) {
1178     MS_LOG(INTERNAL_EXCEPTION) << "The primitive list should not be empty for " << prim->name();
1179   }
1180   // Ignore all arguments, no need checking if AbstractSequence.
1181   if (index_list[0] == -1) {
1182     return;
1183   }
1184   // Check the specific arguments index.
1185   for (size_t i = 0; i < args.size(); ++i) {
1186     MS_EXCEPTION_IF_NULL(args[i]);
1187     if (!args[i]->isa<abstract::AbstractSequence>()) {
1188       continue;
1189     }
1190     if (std::find(index_list.begin(), index_list.end(), i) == index_list.end()) {
1191       // For current tuple/list argument, it's not a primitive of total transparent pass or partial element use.
1192       MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming specific tuple/list arguments[" << i
1193                     << "]: " << args[i]->ToString();
1194       SetSequenceElementsUseFlagsRecursively(args[i], true);
1195     }
1196   }
1197 }
1198 
CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr & prim,const AbstractBasePtrList & args)1199 void CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
1200   MS_EXCEPTION_IF_NULL(prim);
1201   // Consider all primitive implemented python infer() real use the tuple/list arguments.
1202   for (size_t i = 0; i < args.size(); ++i) {
1203     MS_EXCEPTION_IF_NULL(args[i]);
1204     if (args[i]->isa<abstract::AbstractSequence>()) {
1205       MS_EXCEPTION_IF_NULL(args[i]);
1206       MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
1207                     << "]: " << args[i]->ToString();
1208       SetSequenceElementsUseFlagsRecursively(args[i], true);
1209     }
1210   }
1211 }
1212 
ValidateArgOptional(const AbstractBasePtr & abs_arg,const ops::OpInputArg & input_arg)1213 bool ValidateArgOptional(const AbstractBasePtr &abs_arg, const ops::OpInputArg &input_arg) {
1214   if (!input_arg.is_optional_) {
1215     return false;
1216   }
1217 
1218   auto abs_type = abs_arg->BuildType();
1219   MS_EXCEPTION_IF_NULL(abs_type);
1220   return abs_type->isa<TypeNone>();
1221 }
1222 }  // namespace
1223 
PrimitiveFunctionEvaluator(const PrimitivePtr & prim_func)1224 PrimitiveFunctionEvaluator::PrimitiveFunctionEvaluator(const PrimitivePtr &prim_func)
1225     : TrivialPrimEvaluator("PrimitiveFunctionEvaluator"), prim_func_(prim_func) {
1226   frontend_func_impl_ = mindspore::ops::GetOpFrontendFuncImplPtr(prim_func->name());
1227   op_def_ = mindspore::ops::GetOpDef(prim_func->name());
1228 }
1229 
HasAbstractUndetermined(const AbstractBasePtr & abs)1230 bool HasAbstractUndetermined(const AbstractBasePtr &abs) {
1231   if (abs->isa<AbstractSequence>()) {
1232     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
1233     return std::any_of(abs_seq->elements().cbegin(), abs_seq->elements().cend(), HasAbstractUndetermined);
1234   }
1235   return abs->IsSameTypeId(AbstractUndetermined::kTypeId);
1236 }
1237 
CheckArgsSizeAndType(const AbstractBasePtrList & abs_args)1238 void PrimitiveFunctionEvaluator::CheckArgsSizeAndType(const AbstractBasePtrList &abs_args) {
1239   auto op_args = op_def_->args_;
1240   // Ignore monad.
1241   AbstractBasePtrList real_abs_args;
1242   (void)std::copy_if(abs_args.cbegin(), abs_args.cend(), std::back_inserter(real_abs_args),
1243                      [](const AbstractBasePtr &abs) {
1244                        MS_EXCEPTION_IF_NULL(abs);
1245                        return !abs->isa<abstract::AbstractMonad>();
1246                      });
1247   // Check inputs number.
1248   if (op_args.size() != real_abs_args.size()) {
1249     MS_EXCEPTION(TypeError) << "For Operator[" << op_def_->name_ << "], the inputs number should be " << op_args.size()
1250                             << " but got " << real_abs_args.size() << ".";
1251   }
1252 
1253   // Check inputs type.
1254   for (size_t i = 0; i < op_args.size(); i++) {
1255     if (HasAbstractUndetermined(real_abs_args[i])) {
1256       continue;
1257     }
1258     if (!ValidateArgOptional(real_abs_args[i], op_args[i]) &&
1259         !ops::ValidateArgsType(real_abs_args[i], op_args[i].arg_dtype_)) {
1260       std::vector<std::string> op_type_list;
1261       for (const auto &op_abs : real_abs_args) {
1262         (void)op_type_list.emplace_back(op_abs->BuildType()->ToString());
1263       }
1264       MS_INTERNAL_EXCEPTION(TypeError)
1265         << "For Operator[" << op_def_->name_ << "], " << op_args[i].arg_name_ << "'s type '"
1266         << real_abs_args[i]->BuildType()->ToString() << "' does not match expected type '"
1267         << ops::EnumToString(op_args[i].arg_dtype_)
1268         << "'.\nThe reason may be: lack of definition of type cast, or incorrect type when creating the node.";
1269     }
1270   }
1271 }
1272 
CheckAndInfer(const AbstractBasePtrList & args)1273 AbstractBasePtr PrimitiveFunctionEvaluator::CheckAndInfer(const AbstractBasePtrList &args) {
1274   if (op_def_ != nullptr) {
1275     (void)op_def_->func_impl_.CheckValidation(prim_func_, args);
1276     if (frontend_func_impl_ != nullptr) {
1277       auto infer_result = frontend_func_impl_->InferAbstract(prim_func_, args);
1278       if (infer_result != nullptr) {
1279         return infer_result;
1280       }
1281     }
1282 
1283     auto type = op_def_->func_impl_.InferType(prim_func_, args);
1284     auto shape = op_def_->func_impl_.InferShape(prim_func_, args);
1285     return MakeAbstract(shape, type);
1286   }
1287   MS_LOG(INTERNAL_EXCEPTION) << "Find infer function failed, primitive: " << prim_func_->ToString();
1288 }
1289 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1290 EvalResultPtr PrimitiveFunctionEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1291   MS_EXCEPTION_IF_NULL(prim_func_);
1292   CheckArgsSizeAndType(args);
1293   // To check tuple/list operations with a white list of Python primitive.
1294   CheckSequenceArgumentForCppPrimitive(prim_func_, args);
1295 
1296   bool need_infer_value = std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
1297     MS_EXCEPTION_IF_NULL(abs);
1298     auto value = abs->BuildValue();
1299     return (value != nullptr && !value->isa<Monad>() && !value->isa<FuncGraph>());
1300   });
1301 
1302   AbstractBasePtr abs_base = nullptr;
1303   prim_func_->BeginRecordAddAttr();
1304   if (need_infer_value && frontend_func_impl_ != nullptr) {
1305     auto value = frontend_func_impl_->InferValue(prim_func_, args);
1306     if (value != nullptr && !value->ContainsValueAny()) {
1307       abs_base = value->ToAbstract();
1308       prim_func_->EndRecordAddAttr();
1309       auto added_attrs = prim_func_->evaluate_added_attrs();
1310       return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1311     }
1312   }
1313   abs_base = CheckAndInfer(args);
1314   MS_EXCEPTION_IF_NULL(abs_base);
1315   prim_func_->EndRecordAddAttr();
1316   const auto &added_attrs = prim_func_->evaluate_added_attrs();
1317   return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1318 }
1319 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1320 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1321   // To check tuple/list operations with a white list of Python primitive.
1322   CheckSequenceArgumentForCppPrimitive(prim_, args);
1323   MS_EXCEPTION_IF_NULL(prim_);
1324   if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
1325     auto res_abstract = EvalUndeterminedArgs(args);
1326     if (res_abstract != nullptr) {
1327       MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
1328       return res_abstract;
1329     }
1330   }
1331   if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
1332     return EvalPyCheckPrim(engine, args);
1333   }
1334   bool need_infer_value = std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
1335     MS_EXCEPTION_IF_NULL(abs);
1336     auto value = abs->BuildValue();
1337     return (value != nullptr && !value->ContainsValueAny() && !value->isa<None>() && !value->isa<Monad>() &&
1338             !value->isa<FuncGraph>());
1339   });
1340 
1341   AbstractBasePtr abs_base = nullptr;
1342   ValuePtr value = nullptr;
1343   prim_->BeginRecordAddAttr();
1344   if (need_infer_value && eval_impl_.IsImplInferValue()) {
1345     value = eval_impl_.InferValue(prim_, args);
1346     if (value != nullptr) {
1347       abs_base = value->ToAbstract();
1348       prim_->EndRecordAddAttr();
1349       auto added_attrs = prim_->evaluate_added_attrs();
1350       return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1351     }
1352   }
1353   abs_base = eval_impl_.InferShapeAndType(nullptr, prim_, args);
1354   MS_EXCEPTION_IF_NULL(abs_base);
1355   prim_->EndRecordAddAttr();
1356   const auto &added_attrs = prim_->evaluate_added_attrs();
1357   return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1358 }
1359 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1360 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1361   // Consider all primitive implemented python infer() real use the tuple/list arguments.
1362   CheckSequenceArgumentForPythonPrimitive(prim_py_, args);
1363 
1364   // Ensure input arguments are evaluated.
1365   auto res_abstract = EvalUndeterminedArgs(args);
1366   if (res_abstract != nullptr) {
1367     MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
1368     return res_abstract;
1369   }
1370   MS_EXCEPTION_IF_NULL(prim_py_);
1371   auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
1372   if (!forbid_reuse) {
1373     // Try to get infer result from evaluator cache.
1374     EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args);
1375     if (eval_result != nullptr) {
1376       MS_EXCEPTION_IF_NULL(eval_result->abstract());
1377       return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
1378     }
1379   }
1380   // In pynative mode (engine == nullptr), it is difficult to set added_attrs to
1381   // python object by C++ code, so we disable global eval cache in pynative mode.
1382   const bool enable_global_cache = (engine != nullptr && !forbid_reuse);
1383   if (enable_global_cache) {
1384     // Try to get infer result from global primitive eval cache.
1385     EvalResultPtr eval_result = eval_cache_->Get(prim_py_, args);
1386     if (eval_result != nullptr) {
1387       // Global cache hit.
1388       evaluator_cache_mgr_->SetValue(args, eval_result);
1389       return ApplyCacheEvalResult(prim_py_, eval_result);
1390     }
1391   }
1392   // Cache miss, run infer. We should copy attributes before
1393   // running infer, since they may be changed during infer.
1394   auto input_attrs = prim_py_->attrs();
1395   auto py_args = PreparePyInputs(args);
1396   prim_py_->BeginRecordAddAttr();
1397   py::dict output = prim_py_->RunInfer(py_args);
1398   prim_py_->EndRecordAddAttr();
1399   const auto &added_attrs = prim_py_->evaluate_added_attrs();
1400   MS_LOG(DEBUG) << "Output type is " << py::str(output);
1401   auto res_abs = PyInferRes2Abstract(prim_py_, output);
1402   MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
1403   EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
1404   // Save result to global primitive eval cache.
1405   if (enable_global_cache) {
1406     eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);
1407   }
1408   evaluator_cache_mgr_->SetValue(args, eval_result);
1409   return eval_result;
1410 }
1411 
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args)1412 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
1413   auto res_abstract = EvalUndeterminedArgs(args);
1414   if (res_abstract != nullptr) {
1415     MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
1416     return res_abstract;
1417   }
1418   // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
1419   if (nargs_ != args.size()) {
1420     MS_LOG(INTERNAL_EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size()
1421                                << " inputs";
1422   }
1423   TypePtr res_value_type = return_value_type_;
1424   ValuePtrList value_list;
1425   for (const auto &arg : args) {
1426     // Check if all arguments are scalar type.
1427     MS_EXCEPTION_IF_NULL(arg);
1428     if (arg->isa<AbstractScalar>()) {
1429       auto arg_scalar = dyn_cast_ptr<AbstractScalar>(arg);
1430       const auto &arg_value = arg_scalar->GetValueTrack();
1431       value_list.push_back(arg_value);
1432     } else {
1433       // Raise TypeError Expected Scalar.
1434       MS_LOG(INTERNAL_EXCEPTION) << "Expect scalar arguments for uniform primitives.";
1435     }
1436   }
1437   for (const auto &item : type_map_) {
1438     TypePtrList selections;
1439     (void)std::transform(item.second.begin(), item.second.end(), std::back_inserter(selections),
1440                          [&args](size_t arg_idx) -> TypePtr {
1441                            if (arg_idx >= args.size()) {
1442                              MS_LOG(EXCEPTION) << "Index: " << arg_idx << " out of range: " << args.size();
1443                            }
1444                            MS_EXCEPTION_IF_NULL(args[arg_idx]);
1445                            return args[arg_idx]->GetTypeTrack();
1446                          });
1447     TypePtr res = CheckTypeList(item.first, selections);
1448     MS_EXCEPTION_IF_NULL(return_value_type_);
1449     MS_EXCEPTION_IF_NULL(item.first);
1450     if (*return_value_type_ == *(item.first)) {
1451       res_value_type = res;
1452     }
1453   }
1454 
1455   ValuePtr evaluated_value = RunImpl(value_list);
1456   MS_EXCEPTION_IF_NULL(evaluated_value);
1457   if (!(*evaluated_value == *kValueAny)) {
1458     res_value_type = evaluated_value->type();
1459   }
1460   // for comparison primitives , return type shall have be specified to be bool.
1461   if (specify_out_type_ != nullptr) {
1462     res_value_type = specify_out_type_;
1463   }
1464 
1465   AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, res_value_type);
1466   return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
1467 }
1468 
RunImpl(const ValuePtrList & args) const1469 ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
1470   if (!eval_value_) {
1471     return kValueAny;
1472   } else {
1473     if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
1474           MS_EXCEPTION_IF_NULL(arg);
1475           return arg->ContainsValueAny();
1476         })) {
1477       return kValueAny;
1478     }
1479     return impl_(args);
1480   }
1481 }
1482 
1483 // Primitive implementation
1484 // static function start
1485 namespace {
InitStandardPrimEvaluator(PrimitivePtr primitive,const StandardPrimitiveImplReg eval_impl)1486 EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
1487   EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
1488   return prim_evaluator;
1489 }
1490 
InitUniformPrimEvaluator(const PrimitivePtr & primitive,PrimitiveImpl prim_impl,bool eval_value,const TypePtr & specify_out_type)1491 EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
1492                                       const TypePtr &specify_out_type) {
1493   FunctionPtr func = nullptr;
1494   (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
1495   MS_EXCEPTION_IF_NULL(func);
1496 
1497   EvaluatorPtr uniform_primitive_evaluator =
1498     std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
1499   return uniform_primitive_evaluator;
1500 }
1501 
AddToManager(const AnalysisEnginePtr & engine,const FuncGraphPtr func_graph)1502 inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
1503   MS_EXCEPTION_IF_NULL(engine);
1504   FuncGraphManagerPtr manager = engine->func_graph_manager();
1505   MS_EXCEPTION_IF_NULL(manager);
1506   manager->AddFuncGraph(func_graph);
1507 }
1508 
// Selects what StaticGetterInferred builds: METHOD yields the partial application
// of the getter to the object, and ATTR additionally calls that partial.
enum class REQUIRE_TYPE {
  ATTR,
  METHOD,
};
1510 
IsPyExecuteData(const AbstractBasePtr & data_abstract)1511 bool IsPyExecuteData(const AbstractBasePtr &data_abstract) {
1512   MS_EXCEPTION_IF_NULL(data_abstract);
1513   return data_abstract->isa<abstract::AbstractAny>();
1514 }
1515 
CheckObjAttrValid(const TypePtr & data_type,const std::string & item_name,const AbstractBasePtr & data_args)1516 void CheckObjAttrValid(const TypePtr &data_type, const std::string &item_name, const AbstractBasePtr &data_args) {
1517   MS_EXCEPTION_IF_NULL(data_type);
1518   MS_EXCEPTION_IF_NULL(data_args);
1519   // Check if the obj's attr is invalid or decoratored by @jit_forbidden_register
1520   std::string data_type_str = TypeIdLabel(NormalizeTypeId(data_type->type_id()));
1521   if (data_args->isa<AbstractRefTensor>()) {
1522     data_type_str = "Parameter";
1523   } else if (data_args->isa<AbstractNamedTuple>()) {
1524     data_type_str = "NamedTuple";
1525   }
1526   py::module mod1 = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1527   py::object obj_define = python_adapter::CallPyModFn(mod1, parse::PYTHON_MOD_GET_OBJ_DEFINED, data_type_str);
1528   if (py::isinstance<py::none>(obj_define)) {
1529     return;
1530   }
1531   py::module mod2 = python_adapter::GetPyModule(parse::PYTHON_MOD_MODULE);
1532   auto is_jit_forbidden_method =
1533     python_adapter::CallPyModFn(mod2, parse::PYTHON_MOD_IS_INVALID_METHOD, obj_define, data_type_str, item_name);
1534   if (py::cast<bool>(is_jit_forbidden_method) || data_args->isa<AbstractRefTensor>()) {
1535     MS_LOG(EXCEPTION) << "Failed to compile in GRAPH_MODE because the '" << data_type_str << "' object's method '"
1536                       << item_name << "' is not supported in 'construct' or function with @jit decorator. "
1537                       << "Try to use the '" << data_type_str << "." << item_name << "' externally "
1538                       << "such as initialized in the method '__init__' before assigning"
1539                       << ".\nFor more details, please refer to "
1540                       << "https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n";
1541   }
1542 }
1543 
/// Copies the real type and shape of `value_abs` onto `getattr_node`. This lets a
/// getattr node keep the abstract of the value that a setattr assigned.
///
/// Only tensor and scalar abstracts are propagated. An adapter tensor additionally
/// tags the node with `fallback::kIsAdapter`. Any other (or null) abstract leaves
/// the node untouched.
///
/// @return `getattr_node` itself.
AnfNodePtr SetTypeForGetAttr(const AnfNodePtr &getattr_node, const AbstractBasePtr &value_abs) {
  const bool propagatable =
    value_abs != nullptr && (value_abs->isa<abstract::AbstractTensor>() || value_abs->isa<abstract::AbstractScalar>());
  if (!propagatable) {
    return getattr_node;
  }
  // The node is dereferenced from here on.
  MS_EXCEPTION_IF_NULL(getattr_node);
  auto type = value_abs->BuildType();
  auto shape = value_abs->BuildShape();
  fallback::SetRealType<AnfNode, Type>(getattr_node, type);
  fallback::SetRealShape<AnfNode, abstract::BaseShape>(getattr_node, shape);
  auto abs_tensor = value_abs->cast_ptr<abstract::AbstractTensor>();
  if (abs_tensor != nullptr && abs_tensor->is_adapter()) {
    getattr_node->set_user_data<bool>(fallback::kIsAdapter, std::make_shared<bool>(true));
  }
  return getattr_node;
}
1561 
InterpretGetAttrNode(const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)1562 EvalResultPtr InterpretGetAttrNode(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
1563   MS_EXCEPTION_IF_NULL(out_conf);
1564   auto out_node = out_conf->node();
1565   MS_EXCEPTION_IF_NULL(out_node);
1566   const auto cnode = dyn_cast<CNode>(out_node);
1567   MS_EXCEPTION_IF_NULL(cnode);
1568   auto fg = cnode->func_graph();
1569 
1570   auto data_args = args_abs_list[0];
1571   MS_EXCEPTION_IF_NULL(data_args);
1572   // Not check if the data is from PyExecute CNode.
1573   // Do not check the validity of the attribute in the variable scenario.
1574   if (!IsPyExecuteData(data_args) && !raiseutils::HasVariableCondition(fg)) {
1575     TypePtr data_type = data_args->BuildType();
1576     MS_EXCEPTION_IF_NULL(data_type);
1577     auto item_args = args_abs_list[1];
1578     MS_EXCEPTION_IF_NULL(item_args);
1579     ValuePtr item_value = item_args->BuildValue();
1580     auto item_str = item_value->cast_ptr<StringImm>();
1581     MS_EXCEPTION_IF_NULL(item_str);
1582     std::string item_name = item_str->value();
1583     CheckObjAttrValid(data_type, item_name, data_args);
1584   }
1585 
1586   constexpr auto debug_recursive_level = 2;
1587   const auto &debug_info = trace::GetSourceCodeDebugInfo(out_node->debug_info());
1588   const auto &location = debug_info->location();
1589   if (location == nullptr) {
1590     MS_LOG(WARNING) << "Location info is null, node: " << out_node->DebugString(debug_recursive_level);
1591     return nullptr;
1592   }
1593   const auto expr = location->expr_src();
1594   if (expr.empty()) {
1595     MS_LOG(WARNING) << "Location's expr is empty, node: " << out_node->DebugString(debug_recursive_level);
1596     return nullptr;
1597   }
1598 
1599   constexpr auto item_index = 1;
1600   auto item_arg = args_abs_list.at(item_index);
1601   MS_EXCEPTION_IF_NULL(item_arg);
1602   auto attr_name = GetValue<string>(item_arg->BuildValue());
1603   AnfNodePtr getattr_node;
1604   auto obj_change = cnode->user_data<bool>(fallback::kObjectAttrChange);
1605   if (obj_change != nullptr && *obj_change) {
1606     // The object is changed by setattr node, directly convert it to PyExecute node.
1607     getattr_node = fallback::ConvertCNodeToPyExecuteForPrim(cnode, "getattr");
1608     constexpr auto args_size = 3;
1609     if (args_abs_list.size() == args_size) {  // Has setattr node as input.
1610       auto getattr_cnode = getattr_node->cast<CNodePtr>();
1611       MS_EXCEPTION_IF_NULL(getattr_cnode);
1612       getattr_cnode->add_input(cnode->input(args_size));
1613       constexpr auto value_index = 2;
1614       getattr_node = SetTypeForGetAttr(getattr_cnode, args_abs_list[value_index]);
1615     }
1616   } else {
1617     getattr_node = fallback::ConvertGetAttrNodeToPyInterpret(fg, cnode, attr_name);
1618   }
1619   MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << getattr_node->DebugString();
1620   auto eng = out_conf->engine();
1621   MS_EXCEPTION_IF_NULL(eng);
1622   auto fn_conf = eng->MakeConfig(getattr_node, out_conf->context(), out_conf->func_graph());
1623   return eng->ForwardConfig(out_conf, fn_conf);
1624 }
1625 
InterpretSetAttrNode(const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)1626 EvalResultPtr InterpretSetAttrNode(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
1627   MS_EXCEPTION_IF_NULL(out_conf);
1628   auto out_node = out_conf->node();
1629   MS_EXCEPTION_IF_NULL(out_node);
1630   const auto cnode = dyn_cast<CNode>(out_node);
1631   MS_EXCEPTION_IF_NULL(cnode);
1632   auto fg = cnode->func_graph();
1633   MS_EXCEPTION_IF_NULL(fg);
1634   auto owner_abs = args_abs_list[0];
1635   MS_EXCEPTION_IF_NULL(owner_abs);
1636   if (owner_abs->isa<abstract::AbstractRefTensor>()) {
1637     MS_EXCEPTION(ValueError) << "Do not support to set attribute for a parameter.";
1638   }
1639   auto owner_value = owner_abs->BuildValue();
1640   auto owner_node = cnode->input(1);
1641   constexpr auto debug_recursive_level = 2;
1642   MS_EXCEPTION_IF_NULL(owner_value);
1643   MS_LOG(DEBUG) << "node: " << out_conf->node()->DebugString(debug_recursive_level)
1644                 << ", owner_value: " << owner_value->ToString();
1645   if (owner_value->isa<parse::InterpretedObject>()) {
1646     const auto &interpreted_value = dyn_cast<parse::InterpretedObject>(owner_value);
1647     const auto &key = interpreted_value->name();
1648     owner_node = fallback::ConvertPyObjectToPyExecute(fg, key, interpreted_value->obj(), owner_node, true);
1649   }
1650 
1651   ValuePtr attr_str_value = args_abs_list[1]->BuildValue();
1652   MS_EXCEPTION_IF_NULL(attr_str_value);
1653   if (!attr_str_value->isa<StringImm>()) {
1654     MS_LOG(EXCEPTION) << "Expect a string, but got: " << attr_str_value->ToString();
1655   }
1656   auto attr_str = attr_str_value->cast<StringImmPtr>();
1657   MS_EXCEPTION_IF_NULL(attr_str);
1658 
1659   constexpr auto internal_setattr_owner_str = "__internal_setattr_owner__";
1660   constexpr auto internal_setattr_value_str = "__internal_setattr_value__";
1661   std::stringstream script_buffer;
1662   script_buffer << "__import__('mindspore').common._utils._jit_fallback_set_attr(" << internal_setattr_owner_str << ", "
1663                 << attr_str->value() << ", " << internal_setattr_value_str << ")";
1664   MS_LOG(DEBUG) << "script: " << script_buffer.str();
1665   const auto script_setattr_str = std::make_shared<StringImm>(script_buffer.str());
1666 
1667   std::vector<ValuePtr> key_list;
1668   (void)key_list.emplace_back(std::make_shared<StringImm>(internal_setattr_owner_str));
1669   (void)key_list.emplace_back(attr_str);
1670   (void)key_list.emplace_back(std::make_shared<StringImm>(internal_setattr_value_str));
1671   const auto key_tuple = std::make_shared<ValueTuple>(key_list);
1672 
1673   std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
1674   (void)value_list.emplace_back(owner_node);
1675   (void)value_list.emplace_back(NewValueNode(attr_str));
1676   constexpr auto value_node_index = 3;
1677   (void)value_list.emplace_back(cnode->input(value_node_index));
1678   const auto value_tuple_node = fg->NewCNode(value_list);
1679 
1680   const auto setattr_node =
1681     fallback::CreatePyExecuteCNode(cnode, NewValueNode(script_setattr_str), NewValueNode(key_tuple), value_tuple_node);
1682   MS_LOG(DEBUG) << "setattr_node: " << setattr_node->DebugString(debug_recursive_level);
1683 
1684   // Save abstract for getattr.
1685   constexpr auto value_abs_index = 2;
1686   auto value_abs = args_abs_list[value_abs_index];
1687   if (value_abs != nullptr &&
1688       (value_abs->isa<abstract::AbstractTensor>() || value_abs->isa<abstract::AbstractScalar>())) {
1689     auto type = value_abs->BuildType();
1690     auto shape = value_abs->BuildShape();
1691     fallback::SetRealType<AnfNode, Type>(setattr_node, type);
1692     fallback::SetRealShape<AnfNode, abstract::BaseShape>(setattr_node, shape);
1693     auto abs_tensor = value_abs->cast_ptr<abstract::AbstractTensor>();
1694     if (abs_tensor != nullptr && abs_tensor->is_adapter()) {
1695       setattr_node->set_user_data<bool>(fallback::kIsAdapter, std::make_shared<bool>(true));
1696     }
1697   }
1698 
1699   auto eng = out_conf->engine();
1700   MS_EXCEPTION_IF_NULL(eng);
1701   auto fn_conf = eng->MakeConfig(setattr_node, out_conf->context(), out_conf->func_graph());
1702   return eng->ForwardConfig(out_conf, fn_conf);
1703 }
1704 
StaticGetterInferred(const ValuePtr & value,const ConfigPtr & data_conf,const AnfNodeConfigPtr & old_conf,REQUIRE_TYPE require_type=REQUIRE_TYPE::METHOD)1705 EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
1706                                    REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
1707   MS_EXCEPTION_IF_NULL(old_conf);
1708   AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
1709   // Create new cnode
1710   std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
1711   auto func_graph_func = dyn_cast_ptr<abstract::FuncGraphAbstractClosure>(abstract);
1712   if (func_graph_func != nullptr) {
1713     FuncGraphPtr fg = func_graph_func->func_graph();
1714     input.push_back(NewValueNode(fg));
1715   } else {
1716     auto prim_func = dyn_cast_ptr<abstract::PrimitiveAbstractClosure>(abstract);
1717     MS_EXCEPTION_IF_NULL(prim_func);
1718     PrimitivePtr prim = prim_func->prim();
1719     input.push_back(NewValueNode(prim));
1720   }
1721 
1722   auto conf = dyn_cast_ptr<abstract::AnfNodeConfig>(data_conf);
1723   MS_EXCEPTION_IF_NULL(conf);
1724   input.push_back(conf->node());
1725   MS_EXCEPTION_IF_NULL(old_conf);
1726   MS_EXCEPTION_IF_NULL(old_conf->node());
1727   FuncGraphPtr func_graph = old_conf->node()->func_graph();
1728   MS_EXCEPTION_IF_NULL(func_graph);
1729   CNodePtr new_cnode = func_graph->NewCNode(input);
1730   if (require_type == REQUIRE_TYPE::ATTR) {
1731     new_cnode = func_graph->NewCNode({new_cnode});
1732   }
1733   AnalysisEnginePtr eng = old_conf->engine();
1734   AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
1735   return eng->ForwardConfig(old_conf, fn_conf);
1736 }
1737 
SetSideEffectFlag(const PrimitivePtr & prim,const AnfNodeConfigPtr & out_conf)1738 void SetSideEffectFlag(const PrimitivePtr &prim, const AnfNodeConfigPtr &out_conf) {
1739   if (prim == nullptr) {
1740     return;
1741   }
1742   auto effect_info = GetPrimEffectInfo(prim);
1743   if (effect_info.memory || effect_info.io) {
1744     const auto &cnode = dyn_cast<CNode>(out_conf->node());
1745     MS_EXCEPTION_IF_NULL(cnode);
1746     MS_EXCEPTION_IF_NULL(out_conf->func_graph());
1747     MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString()
1748                   << ", func_graph: " << out_conf->func_graph()->ToString();
1749     cnode->set_has_side_effect_node(true);
1750     out_conf->func_graph()->set_has_side_effect_node(true);
1751   }
1752 }
1753 
SetOriginObject(const AnfNodePtr & node,const AnfNodeConfigPtr & out_conf)1754 void SetOriginObject(const AnfNodePtr &node, const AnfNodeConfigPtr &out_conf) {
1755   if (!node->isa<ValueNode>()) {
1756     return;
1757   }
1758   auto vnode = node->cast<ValueNodePtr>();
1759   if (vnode->value()->has_user_data("origin_object")) {
1760     auto origin_object = vnode->value()->user_data<py::object>("origin_object");
1761     out_conf->node()->set_user_data<py::object>("origin_object", origin_object);
1762   }
1763 }
1764 
SetSparseBpropFlag(const PrimitivePtr & prim,const AnfNodeConfigPtr & out_conf)1765 void SetSparseBpropFlag(const PrimitivePtr &prim, const AnfNodeConfigPtr &out_conf) {
1766   if (GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE)) {
1767     out_conf->func_graph()->set_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP, true);
1768     EnvSetSparseResultMgr::GetInstance().Set(true);
1769   }
1770 }
1771 
GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList & args_abs_list,const ValuePtr & data_value,const AnfNodeConfigPtr & out_conf,const std::string & data)1772 EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_abs_list, const ValuePtr &data_value,
1773                                                   const AnfNodeConfigPtr &out_conf, const std::string &data) {
1774   constexpr size_t item_index = 1;
1775   auto item_args = args_abs_list[item_index];
1776   MS_EXCEPTION_IF_NULL(item_args);
1777   ValuePtr item_value = item_args->BuildValue();
1778   MS_EXCEPTION_IF_NULL(data_value);
1779   MS_EXCEPTION_IF_NULL(item_value);
1780   if (item_value->isa<StringImm>()) {
1781     auto string_value = item_value->cast_ptr<StringImm>();
1782     MS_EXCEPTION_IF_NULL(string_value);
1783     item_value = std::make_shared<parse::Symbol>(string_value->value());
1784   }
1785   if (!item_value->isa<parse::Symbol>()) {
1786     MS_LOG(INTERNAL_EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
1787   }
1788 
1789   // item_name to func addr from obj_map
1790   auto symbol = item_value->cast<parse::SymbolPtr>();
1791   auto name_space = data_value->cast<parse::NameSpacePtr>();
1792   constexpr auto tensors_queue_attr = "__is_tensors_queue__";
1793   constexpr auto pop_attr = "pop";
1794   if (name_space != nullptr && py::hasattr(name_space->namespace_obj(), tensors_queue_attr) &&
1795       symbol->symbol() == pop_attr) {
1796     constexpr auto graph_pop_attr = "__graph_pop__";
1797     symbol = std::make_shared<parse::Symbol>(graph_pop_attr);
1798   }
1799   MS_EXCEPTION_IF_NULL(out_conf);
1800   auto out_node = out_conf->node();
1801   MS_EXCEPTION_IF_NULL(out_node);
1802   FuncGraphPtr func_graph = out_node->func_graph();
1803   MS_EXCEPTION_IF_NULL(func_graph);
1804   auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
1805   if (new_node == nullptr) {
1806     MS_LOG(INTERNAL_EXCEPTION) << "Resolve node failed";
1807   }
1808 
1809   auto prim = GetPrimitiveWithoutDoSignature(new_node);
1810   SetSparseBpropFlag(prim, out_conf);
1811   SetSideEffectFlag(prim, out_conf);
1812   SetOriginObject(new_node, out_conf);
1813 
1814   if (IsValueNode<TypeNull>(new_node)) {
1815     // Do not find the attribute.
1816     constexpr auto max_args_len = 3;
1817     bool has_default = (args_abs_list.size() == max_args_len);
1818     if (!has_default) {
1819       MS_EXCEPTION(AttributeError) << data << " object has no attribute " << symbol->symbol();
1820     }
1821     auto out_cnode = out_node->cast_ptr<CNode>();
1822     MS_EXCEPTION_IF_NULL(out_cnode);
1823     constexpr auto default_index = 3;
1824     auto default_node = out_cnode->input(default_index);
1825     auto eng = out_conf->engine();
1826     MS_EXCEPTION_IF_NULL(eng);
1827     auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
1828     return eng->ForwardConfig(out_conf, fn_conf);
1829   }
1830 
1831   auto new_node_to_fg = GetValueNode<FuncGraphPtr>(new_node);
1832   if (new_node_to_fg != nullptr) {
1833     bool has_recompute_scope = (out_node->scope() != nullptr &&
1834                                 out_node->scope()->name().compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0);
1835     if (has_recompute_scope) {
1836       parse::UpdateRecomputeScope(new_node_to_fg);
1837     } else if (MsContext::GetInstance()->get_param<int>(MS_CTX_DEBUG_LEVEL) == kLevelDebug) {
1838       UpdateDebugInfo(new_node_to_fg, out_node->scope(), out_node->debug_info());
1839     }
1840   }
1841 
1842   AnalysisEnginePtr eng = out_conf->engine();
1843   MS_EXCEPTION_IF_NULL(eng);
1844   AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
1845   return eng->ForwardConfig(out_conf, fn_conf);
1846 }
1847 
GenerateFuncGraphForOverriddenMethod(AnfNodePtr node,const ValuePtr & item_value,const AnfNodeConfigPtr & out_conf)1848 EvalResultPtr GenerateFuncGraphForOverriddenMethod(AnfNodePtr node, const ValuePtr &item_value,
1849                                                    const AnfNodeConfigPtr &out_conf) {
1850   const auto &item_str = item_value->cast_ptr<StringImm>();
1851   FuncGraphPtr inner_fg = nullptr;
1852   py::object overridden_method = py::none();
1853   py::object value_obj = py::none();
1854   if (item_str != nullptr) {
1855     const std::string &item_name = item_str->value();
1856     if (node->has_user_data(item_name)) {
1857       value_obj = *node->user_data<py::object>(item_name);
1858       overridden_method = value_obj.attr("__class__").attr(item_name.c_str());
1859     }
1860   }
1861   bool is_getattr = node->has_user_data("__getattr__");
1862   if (is_getattr) {
1863     value_obj = *node->user_data<py::object>("__getattr__");
1864     try {
1865       overridden_method = value_obj.attr("__class__").attr("__getattr__");
1866     } catch (const std::exception &e) {
1867       MS_LOG(DEBUG) << value_obj << " has no attribute getattr.";
1868     }
1869   }
1870   if (py::isinstance<py::none>(overridden_method) || py::isinstance<py::none>(value_obj)) {
1871     return nullptr;
1872   }
1873   {
1874     MS_LOG_TRY_CATCH_SCOPE;
1875     inner_fg = parse::ParsePythonCode(overridden_method);
1876   }
1877   MS_EXCEPTION_IF_NULL(out_conf);
1878   auto eng = out_conf->engine();
1879   MS_EXCEPTION_IF_NULL(eng);
1880   auto cnode = node->cast<CNodePtr>();
1881   MS_EXCEPTION_IF_NULL(cnode);
1882   FuncGraphPtr func_graph = node->func_graph();
1883   MS_EXCEPTION_IF_NULL(func_graph);
1884   const auto &interpreted_obj = std::make_shared<parse::InterpretedObject>(value_obj);
1885   const auto &value_node = NewValueNode(interpreted_obj);
1886   if (inner_fg == nullptr) {
1887     std::vector<AnfNodePtr> new_inputs;
1888     for (size_t i = 0; i < cnode->size(); i++) {
1889       if (i == 1) {
1890         new_inputs.push_back(value_node);
1891       } else {
1892         new_inputs.push_back(cnode->input(i));
1893       }
1894     }
1895     CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
1896     auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1897     return eng->ForwardConfig(out_conf, fn_conf);
1898   }
1899   AddToManager(eng, inner_fg);
1900   if (is_getattr) {
1901     std::vector<AnfNodePtr> new_inputs = {NewValueNode(inner_fg)};
1902     for (size_t i = 0; i < cnode->size(); i++) {
1903       if (i > 0) {
1904         new_inputs.push_back(cnode->input(i));
1905       }
1906     }
1907     CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
1908     auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1909     return eng->ForwardConfig(out_conf, fn_conf);
1910   }
1911   std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
1912   input.push_back(NewValueNode(inner_fg));
1913   input.push_back(value_node);
1914   CNodePtr new_cnode = func_graph->NewCNode(input);
1915   auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1916   return eng->ForwardConfig(out_conf, fn_conf);
1917 }
1918 
GetEvaluatedValueForNameSpace(const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf,const bool check_override=false)1919 EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf,
1920                                             const bool check_override = false) {
1921   // args_abs_list: same as StaticGetter
1922   constexpr size_t args_min_size = 2;
1923   if (args_abs_list.size() < args_min_size) {
1924     MS_LOG(INTERNAL_EXCEPTION) << "Size of args_abs_list is less than 2";
1925   }
1926   MS_EXCEPTION_IF_NULL(out_conf);
1927   // An external type.
1928   constexpr auto data_index = 0;
1929   constexpr auto item_index = 1;
1930   auto data = args_abs_list[data_index];
1931   auto item = args_abs_list[item_index];
1932   MS_EXCEPTION_IF_NULL(data);
1933   MS_EXCEPTION_IF_NULL(item);
1934   MS_EXCEPTION_IF_NULL(out_conf->node());
1935   auto data_value = data->BuildValue();
1936   MS_EXCEPTION_IF_NULL(data_value);
1937   auto data_type = data->BuildType();
1938   MS_EXCEPTION_IF_NULL(data_type);
1939   auto item_value = item->BuildValue();
1940   std::string data_id_str = TypeIdToString(data_type->type_id());
1941   if (check_override) {
1942     auto inner_fg_res = GenerateFuncGraphForOverriddenMethod(out_conf->node(), item_value, out_conf);
1943     if (inner_fg_res != nullptr) return inner_fg_res;
1944   }
1945   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1946   if (data_value->isa<parse::ClassType>()) {
1947     auto class_val = dyn_cast_ptr<parse::ClassType>(data_value);
1948     auto class_obj = class_val->obj();
1949     py::object ns_obj = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, class_obj);
1950     data_value = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
1951     data_id_str = class_val->name();
1952   }
1953   if (data_value->isa<parse::MsClassObject>()) {
1954     auto class_val = dyn_cast_ptr<parse::MsClassObject>(data_value);
1955     auto class_obj = class_val->obj();
1956     py::object ns_obj = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, class_obj);
1957     data_value = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
1958     data_id_str = class_val->name();
1959   }
1960   if (!data_value->isa<parse::NameSpace>()) {
1961     MS_EXCEPTION_IF_NULL(item_value);
1962     MS_LOG(DEBUG) << "Evaluate " << data_value->ToString() << " attribute: " << item_value->ToString()
1963                   << ".\nnode: " << out_conf->node()->DebugString() << "\n"
1964                   << trace::GetDebugInfoStr(out_conf->node()->debug_info());
1965     auto res = InterpretGetAttrNode(args_abs_list, out_conf);
1966     if (res == nullptr) {
1967       MS_EXCEPTION(AttributeError) << data_value->ToString() << " object has no attribute: " << item_value->ToString();
1968     }
1969     return res;
1970   }
1971   return GetEvaluatedValueForNameSpaceString(args_abs_list, data_value, out_conf, data_id_str);
1972 }
1973 
GetEvaluatedValueForPrimitiveAttr(const AbstractBasePtrList & args_abs_list,const AbstractFunctionPtr & data_args)1974 EvalResultPtr GetEvaluatedValueForPrimitiveAttr(const AbstractBasePtrList &args_abs_list,
1975                                                 const AbstractFunctionPtr &data_args) {
1976   MS_EXCEPTION_IF_NULL(data_args);
1977   if (!data_args->isa<PrimitiveAbstractClosure>()) {
1978     return nullptr;
1979   }
1980   auto prim_abs = dyn_cast_ptr<PrimitiveAbstractClosure>(data_args);
1981   const auto &prim = prim_abs->prim();
1982   MS_EXCEPTION_IF_NULL(prim);
1983   constexpr auto item_index = 1;
1984   auto item_arg = args_abs_list.at(item_index);
1985   MS_EXCEPTION_IF_NULL(item_arg);
1986   auto attr_name = GetValue<string>(item_arg->BuildValue());
1987   auto value = prim->GetAttr(attr_name);
1988   if (value == nullptr) {
1989     MS_LOG(INFO) << "The Primitive: " << prim->ToString() << " has not attr " << attr_name;
1990     MS_LOG(INFO) << "PrimAttr: " << prim->GetAttrsText();
1991     return nullptr;
1992   }
1993   return std::make_shared<EvalResult>(value->ToAbstract(), nullptr);
1994 }
1995 
GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr & engine,const AbstractBasePtr & data_args,const AbstractBasePtr & item_args,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)1996 EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr &engine,
1997                                                             const AbstractBasePtr &data_args,
1998                                                             const AbstractBasePtr &item_args,
1999                                                             const ConfigPtr &data_conf,
2000                                                             const AnfNodeConfigPtr &out_conf) {
2001   MS_EXCEPTION_IF_NULL(data_args);
2002   MS_EXCEPTION_IF_NULL(item_args);
2003   // Check whether it is AdapterTensor or AdapterParameter.
2004   auto abs = data_args->cast_ptr<abstract::AbstractTensor>();
2005   if (abs == nullptr || !abs->is_adapter()) {
2006     return nullptr;
2007   }
2008 
2009   // Get the name of attr/method.
2010   ValuePtr item_value = item_args->BuildValue();
2011   MS_EXCEPTION_IF_NULL(item_value);
2012   if (!item_value->isa<StringImm>()) {
2013     MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
2014   }
2015   std::string item_name = item_value->cast_ptr<StringImm>()->value();
2016 
2017   constexpr size_t attr_index = 0;
2018   constexpr size_t flag_index = 1;
2019   constexpr size_t info_required_size = 2;
2020   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2021   py::tuple attr_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_ADAPTER_TENSOR_ATTR, py::str(item_name));
2022   if (attr_info.size() != info_required_size) {
2023     MS_EXCEPTION(NameError) << "attr info size should be 2, but got " << attr_info.size();
2024   }
2025   // If func is none, it means there is no such attr or method.
2026   py::object func = attr_info[attr_index];
2027   if (py::isinstance<py::none>(func)) {
2028     return nullptr;
2029   }
2030   ValuePtr converted_value = nullptr;
2031   bool success = parse::ConvertData(func, &converted_value);
2032   if (!success || converted_value == nullptr || !converted_value->isa<FuncGraph>()) {
2033     return nullptr;
2034   }
2035   AddToManager(engine, converted_value->cast<FuncGraphPtr>());
2036 
2037   // Check whether it is an attribute or a method.
2038   bool is_attr = attr_info[flag_index].cast<bool>();
2039   REQUIRE_TYPE require_type = is_attr ? REQUIRE_TYPE::ATTR : REQUIRE_TYPE::METHOD;
2040   return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
2041 }
2042 
GetOriginObj(const AnfNodePtr & node)2043 py::object GetOriginObj(const AnfNodePtr &node) {
2044   MS_EXCEPTION_IF_NULL(node);
2045   py::object origin_obj;
2046   if (node->has_user_data("origin_object")) {
2047     return *node->user_data<py::object>("origin_object");
2048   }
2049   if (!node->isa<ValueNode>()) {
2050     return origin_obj;
2051   }
2052   auto vnode = node->cast<ValueNodePtr>();
2053   if (vnode->value()->has_user_data("origin_object")) {
2054     return *vnode->value()->user_data<py::object>("origin_object");
2055   }
2056   return origin_obj;
2057 }
2058 
GetEvaluatedValueForAttrOrMethodNotInMap(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf,const std::string & item_name,const TypePtr & data_type)2059 EvalResultPtr GetEvaluatedValueForAttrOrMethodNotInMap(const AnalysisEnginePtr &engine,
2060                                                        const AbstractBasePtrList &args_abs_list,
2061                                                        const AnfNodeConfigPtr &out_conf, const std::string &item_name,
2062                                                        const TypePtr &data_type) {
2063   constexpr auto max_args_len = 3;
2064   bool has_default = (args_abs_list.size() == max_args_len);
2065   auto out_node = out_conf->node();
2066   auto out_cnode = out_node->cast_ptr<CNode>();
2067   MS_EXCEPTION_IF_NULL(out_cnode);
2068   auto eng = out_conf->engine();
2069   MS_EXCEPTION_IF_NULL(eng);
2070   if (has_default) {
2071     constexpr auto default_index = 3;
2072     auto default_node = out_cnode->input(default_index);
2073     auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
2074     return eng->ForwardConfig(out_conf, fn_conf);
2075   }
2076 
2077   py::object value_obj = GetOriginObj(out_cnode->input(1));
2078   if (value_obj.ptr() != nullptr) {
2079     std::vector<AnfNodePtr> new_inputs;
2080     std::string data_type_str = TypeIdLabel(NormalizeTypeId(data_type->type_id()));
2081     py::module mod1 = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2082     py::object obj_define = python_adapter::CallPyModFn(mod1, parse::PYTHON_MOD_GET_OBJ_DEFINED, data_type_str);
2083     py::object check_res =
2084       python_adapter::CallPyModFn(mod1, parse::PYTHON_MOD_CHECK_IS_SUBCLASS, value_obj, obj_define);
2085     if (py::cast<bool>(check_res)) {
2086       for (size_t i = 0; i < out_cnode->size(); i++) {
2087         if (i == 1) {
2088           const auto &interpreted_obj = std::make_shared<parse::InterpretedObject>(value_obj);
2089           const auto &value_node = NewValueNode(interpreted_obj);
2090           new_inputs.push_back(value_node);
2091         } else {
2092           new_inputs.push_back(out_cnode->input(i));
2093         }
2094       }
2095       CNodePtr new_cnode = out_conf->func_graph()->NewCNode(new_inputs);
2096       auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
2097       return eng->ForwardConfig(out_conf, fn_conf);
2098     }
2099   }
2100   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
2101   if (!allow_fallback_runtime) {
2102     MS_EXCEPTION(AttributeError) << "In JIT strict mode, cannot get attributes " << item_name << " or the "
2103                                  << data_type->ToString() << " object has no attribute: " << item_name
2104                                  << "'. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' "
2105                                  << "to enable the JIT lax mode to support the current syntax.\n\n"
2106                                  << trace::GetDebugInfoStr(out_conf->node()->debug_info());
2107   }
2108 
2109   constexpr auto recursive_level = 3;
2110   MS_LOG(DEBUG) << "Evaluate " << data_type->ToString() << " attribute: " << item_name
2111                 << ".\nnode: " << out_conf->node()->DebugString(recursive_level) << "\n"
2112                 << trace::GetDebugInfoStr(out_conf->node()->debug_info());
2113   auto res = InterpretGetAttrNode(args_abs_list, out_conf);
2114   if (res == nullptr) {
2115     MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
2116   }
2117   return res;
2118 }
2119 
GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)2120 EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
2121                                                           const AbstractBasePtrList &args_abs_list,
2122                                                           const ConfigPtr &data_conf,
2123                                                           const AnfNodeConfigPtr &out_conf) {
2124   constexpr size_t data_index = 0;
2125   constexpr size_t item_index = 1;
2126   auto data_args = args_abs_list[data_index];
2127   auto item_args = args_abs_list[item_index];
2128   MS_EXCEPTION_IF_NULL(data_args);
2129   MS_EXCEPTION_IF_NULL(item_args);
2130   ValuePtr item_value = item_args->BuildValue();
2131   MS_EXCEPTION_IF_NULL(item_value);
2132   TypePtr data_type = data_args->BuildType();
2133   MS_EXCEPTION_IF_NULL(data_type);
2134   // Handle NameTuple: getattr(XX, item_value) -> ValueNode().
2135   if (data_args->isa<AbstractNamedTuple>()) {
2136     auto named_tuple = data_args->cast<AbstractNamedTuplePtr>();
2137     const auto &keys = named_tuple->key();
2138     for (size_t it = 0; it < keys.size(); ++it) {
2139       auto key_value = keys[it]->BuildValue();
2140       MS_EXCEPTION_IF_NULL(key_value);
2141       if (*item_value == *key_value) {
2142         auto getattr_node = NewValueNode(named_tuple->elements()[it]->BuildValue());
2143         auto eng = out_conf->engine();
2144         MS_EXCEPTION_IF_NULL(eng);
2145         auto fn_conf = eng->MakeConfig(getattr_node, out_conf->context(), out_conf->func_graph());
2146         return eng->ForwardConfig(out_conf, fn_conf);
2147       }
2148     }
2149   }
2150 
2151   // The method maybe a Primitive or Composite
2152   if (!item_value->isa<StringImm>()) {
2153     MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
2154   }
2155   auto item_str = item_value->cast_ptr<StringImm>();
2156   MS_EXCEPTION_IF_NULL(item_str);
2157   std::string item_name = item_str->value();
2158   REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
2159   Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
2160   MS_EXCEPTION_IF_NULL(out_conf->node());
2161   if (require.empty()) {
2162     require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
2163     if (require.empty()) {
2164       return GetEvaluatedValueForAttrOrMethodNotInMap(engine, args_abs_list, out_conf, item_name, data_type);
2165     }
2166     require_type = REQUIRE_TYPE::ATTR;
2167   }
2168 
2169   ValuePtr converted_value = nullptr;
2170   if (require.is<std::string>()) {
2171     // composite registered in standard_method_map go to this branch
2172     converted_value = prim::GetPythonOps(require.cast<std::string>());
2173     MS_EXCEPTION_IF_NULL(converted_value);
2174 
2175     auto converted_fg = converted_value->cast<FuncGraphPtr>();
2176     if (converted_fg != nullptr) {
2177       bool has_recompute_scope =
2178         (out_conf->node()->scope() != nullptr &&
2179          out_conf->node()->scope()->name().compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0);
2180       if (has_recompute_scope) {
2181         parse::UpdateRecomputeScope(converted_fg);
2182       } else if (MsContext::GetInstance()->get_param<int>(MS_CTX_DEBUG_LEVEL) == kLevelDebug) {
2183         UpdateDebugInfo(converted_fg, out_conf->node()->scope(), out_conf->node()->debug_info());
2184       }
2185     }
2186 
2187     if (!converted_value->isa<Primitive>()) {
2188       AddToManager(engine, converted_value->cast<FuncGraphPtr>());
2189     }
2190   } else if (require.is<PrimitivePtr>()) {
2191     converted_value = require.cast<PrimitivePtr>();
2192   } else {
2193     MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
2194   }
2195   return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
2196 }
2197 
TransPropertyToFunc(const AnfNodeConfigPtr & out_conf,py::object property_net_obj,std::string item_name)2198 EvalResultPtr TransPropertyToFunc(const AnfNodeConfigPtr &out_conf, py::object property_net_obj,
2199                                   std::string item_name) {
2200   py::object property_func = py::none();
2201   try {
2202     property_func = property_net_obj.attr("__class__").attr(py::str(item_name));
2203   } catch (const std::exception &e) {
2204     MS_LOG(ERROR) << property_net_obj << " has no attribute " << item_name;
2205   }
2206   py::object property_func_fget = property_func.attr(py::str("fget"));
2207   auto inner_fg = parse::ParsePythonCode(property_func_fget);
2208   auto eng = out_conf->engine();
2209   MS_EXCEPTION_IF_NULL(eng);
2210   AddToManager(eng, inner_fg);
2211   auto node = out_conf->node();
2212   auto cnode = node->cast<CNodePtr>();
2213   MS_EXCEPTION_IF_NULL(cnode);
2214   FuncGraphPtr func_graph = node->func_graph();
2215   MS_EXCEPTION_IF_NULL(func_graph);
2216   std::vector<AnfNodePtr> new_inputs = {NewValueNode(inner_fg)};
2217   new_inputs.push_back(cnode->input(1));
2218   CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
2219   MS_LOG(DEBUG) << "new_cnode:" << new_cnode->DebugString();
2220   auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
2221   return eng->ForwardConfig(out_conf, fn_conf);
2222 }
2223 
GetClassAttrFromPyObject(const py::object & cls_obj,const std::string & cls_name,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)2224 EvalResultPtr GetClassAttrFromPyObject(const py::object &cls_obj, const std::string &cls_name,
2225                                        const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
2226   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2227   constexpr auto item_index = 1;
2228   auto item_arg = args_abs_list.at(item_index);
2229   MS_EXCEPTION_IF_NULL(item_arg);
2230   auto attr_name = GetValue<string>(item_arg->BuildValue());
2231   bool is_property =
2232     (python_adapter::CallPyModFn(mod, parse::PYTHON_PARSE_CHECK_ATTR_IS_PROPERTY, cls_obj, attr_name)).cast<bool>();
2233   if (is_property) {
2234     ValuePtr item_value = item_arg->BuildValue();
2235     MS_EXCEPTION_IF_NULL(item_value);
2236     const auto &item_str = item_value->cast_ptr<StringImm>();
2237     const std::string &item_name = item_str->value();
2238     return TransPropertyToFunc(out_conf, cls_obj, item_name);
2239   }
2240   py::object ns_obj = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, cls_obj);
2241   auto ns = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
2242   return GetEvaluatedValueForNameSpaceString(args_abs_list, ns, out_conf, cls_name);
2243 }
2244 
GetFuncAbstractAttr(const AbstractFunctionPtr & data_args,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)2245 EvalResultPtr GetFuncAbstractAttr(const AbstractFunctionPtr &data_args, const AbstractBasePtrList &args_abs_list,
2246                                   const AnfNodeConfigPtr &out_conf) {
2247   if (data_args == nullptr) {
2248     return nullptr;
2249   }
2250   // Get attribute or method of PartialAbstractClosure, the object could be nn.Cell/ms_class object.
2251   auto data_partial = dyn_cast_ptr<PartialAbstractClosure>(data_args);
2252   if (data_partial != nullptr) {
2253     const auto &partial_args = data_partial->args();
2254     auto prim_abs = dyn_cast_ptr<PrimitiveAbstractClosure>(data_partial->fn());
2255     if (prim_abs != nullptr && !partial_args.empty()) {
2256       MS_EXCEPTION_IF_NULL(prim_abs->prim());
2257       const auto &prim_name = prim_abs->prim()->name();
2258       if (prim_name == prim::kPrimCreateInstance->name()) {
2259         constexpr size_t class_index = 0;
2260         MS_EXCEPTION_IF_NULL(partial_args[class_index]);
2261         auto class_val = partial_args[class_index]->BuildValue();
2262         MS_EXCEPTION_IF_NULL(class_val);
2263         auto wrapper = dyn_cast_ptr<parse::PyObjectWrapper>(class_val);
2264         MS_EXCEPTION_IF_NULL(wrapper);
2265         return GetClassAttrFromPyObject(wrapper->obj(), wrapper->name(), args_abs_list, out_conf);
2266       }
2267     }
2268     return nullptr;
2269   }
2270   // Get attribute or method of FuncGraphAbstractClosure, the object could be nn.Cell/ms_class object.
2271   const auto &cls_obj = fallback::GetPyObjForFuncGraphAbstractClosure(data_args);
2272   if (py::isinstance<Cell>(cls_obj) || py::hasattr(cls_obj, PYTHON_MS_CLASS)) {
2273     return GetClassAttrFromPyObject(cls_obj, py::str(cls_obj), args_abs_list, out_conf);
2274   }
2275   return GetEvaluatedValueForPrimitiveAttr(args_abs_list, data_args);
2276 }
2277 
CheckHasOverriddenMethod(AnfNodePtr node,ValuePtr item_value)2278 bool CheckHasOverriddenMethod(AnfNodePtr node, ValuePtr item_value) {
2279   const auto &item_str = item_value->cast_ptr<StringImm>();
2280   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2281   if (item_str != nullptr) {
2282     const std::string &item_name = item_str->value();
2283     if (node->has_user_data(item_name)) {
2284       auto value_obj = *node->user_data<py::object>(item_name);
2285       py::bool_ check = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CHECK_ATTRS, value_obj, item_name);
2286       return py::cast<bool>(check);
2287     }
2288   }
2289   if (node->has_user_data("__getattr__")) {
2290     auto value_obj = *node->user_data<py::object>("__getattr__");
2291     py::bool_ check = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CHECK_ATTRS, value_obj, "__getattr__");
2292     return py::cast<bool>(check);
2293   }
2294   return false;
2295 }
2296 
StaticGetter(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)2297 EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
2298                            const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
2299   // Inputs: namespace and its static function; or class and its member function
2300   constexpr size_t data_index = 0;
2301   constexpr size_t item_index = 1;
2302   auto data_args = args_abs_list[data_index];
2303   auto item_args = args_abs_list[item_index];
2304   MS_EXCEPTION_IF_NULL(data_args);
2305   MS_EXCEPTION_IF_NULL(item_args);
2306   MS_EXCEPTION_IF_NULL(out_conf);
2307   MS_EXCEPTION_IF_NULL(out_conf->node());
2308   constexpr auto recursive_level = 2;
2309   MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString()
2310                 << ", node: " << out_conf->node()->DebugString(recursive_level);
2311   ScopePtr scope = out_conf->node()->scope();
2312   ScopeGuard scope_guard(scope);
2313   ValuePtr item_value = item_args->BuildValue();
2314   MS_EXCEPTION_IF_NULL(item_value);
2315   if (item_value->ContainsValueAny()) {
2316     MS_LOG(INTERNAL_EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
2317   }
2318 
2319   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
2320   constexpr auto max_args_size = 3;
2321   if (!allow_fallback_runtime && args_abs_list.size() == max_args_size) {
2322     constexpr size_t default_index = 2;
2323     auto default_args = args_abs_list[default_index];
2324     MS_EXCEPTION_IF_NULL(default_args);
2325     if (default_args->isa<abstract::AbstractScalar>()) {
2326       ValuePtr default_value = default_args->BuildValue();
2327       MS_EXCEPTION_IF_NULL(default_value);
2328       if (default_value->isa<parse::InterpretedObject>()) {
2329         auto obj = ValueToPyData(default_value);
2330         auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
2331         MS_EXCEPTION(TypeError) << "For 'getattr', the third input 'default' can not be " << py::str(type_str)
2332                                 << " object " << py::str(obj);
2333       }
2334     }
2335   }
2336 
2337   auto res = GetFuncAbstractAttr(data_args->cast<AbstractFunctionPtr>(), args_abs_list, out_conf);
2338   if (res != nullptr) {
2339     return res;
2340   }
2341 
2342   // Get attribute or method of AdapterTensor object.
2343   res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
2344   if (res != nullptr) {
2345     return res;
2346   }
2347   // Try to search method map, if not found, the data_type should be External type.
2348   TypePtr data_type = data_args->BuildType();
2349   MS_EXCEPTION_IF_NULL(data_type);
2350   // Check if attr is a overridden method.
2351   bool check_override = CheckHasOverriddenMethod(out_conf->node(), item_value);
2352   // Not check if the data is from PyExecute CNode, since its Tensor output is pseud.
2353   if (!IsPyExecuteData(data_args) && pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id()) && !check_override) {
2354     return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_abs_list, data_conf, out_conf);
2355   }
2356   return GetEvaluatedValueForNameSpace(args_abs_list, out_conf, check_override);
2357 }
2358 
GetAnnotationType(const AnfNodePtr & node,const AbstractBasePtrList & args_abs_list)2359 TypePtr GetAnnotationType(const AnfNodePtr &node, const AbstractBasePtrList &args_abs_list) {
2360   MS_EXCEPTION_IF_NULL(node);
2361   fallback::FormatedVariableTypeFunc func = [&node, &args_abs_list](const std::string &type_var_str) -> TypePtr {
2362     // For PyInterpret, the args[1] is global dict, and the args[2] is local dict.
2363     // For PyExecute, the args[1] is local dict keys, and the args[2] is local dict values.
2364     ValuePtr type_value = nullptr;
2365     const auto &keys_tuple_abs = args_abs_list[1];
2366     MS_EXCEPTION_IF_NULL(keys_tuple_abs);
2367     const auto &keys_tuple = keys_tuple_abs->BuildValue();
2368     const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
2369     bool is_py_execute = (keys != nullptr);
2370     if (is_py_execute) {  // PyExecute.
2371       bool found = false;
2372       size_t i = 0;
2373       for (; i < keys->value().size(); ++i) {
2374         const auto &key = dyn_cast<StringImm>(keys->value()[i]);
2375         MS_EXCEPTION_IF_NULL(key);
2376         if (key->value() == type_var_str) {
2377           found = true;
2378           break;
2379         }
2380       }
2381 
2382       if (!found) {
2383         MS_LOG(INFO) << "Not valid PyExecute CNode. node: " << node->DebugString() << ", keys: " << keys->ToString()
2384                      << ", not found " << type_var_str;
2385         return nullptr;
2386       }
2387       constexpr auto values_index = 2;
2388       const auto &values_tuple_abs = dyn_cast<AbstractSequence>(args_abs_list[values_index]);
2389       MS_EXCEPTION_IF_NULL(values_tuple_abs);
2390       const auto &type_value_abs = values_tuple_abs->elements()[i];
2391       if (type_value_abs == nullptr) {
2392         MS_LOG(INFO) << "Not valid PyExecute CNode. node: " << node->DebugString() << ", key: " << type_var_str
2393                      << ", values_tuple_abs: " << values_tuple_abs->ToString();
2394         return nullptr;
2395       }
2396       bool only_has_real_type = !fallback::HasRealShape(type_value_abs) && fallback::HasRealType(type_value_abs);
2397       type_value =
2398         only_has_real_type ? fallback::GetRealType<AbstractBase, Type>(type_value_abs) : type_value_abs->BuildValue();
2399     } else {  // PyInterpret
2400       constexpr auto local_dict_index = 2;
2401       const auto &local_dict_abs = args_abs_list[local_dict_index];
2402       const auto &dict = dyn_cast<AbstractDictionary>(local_dict_abs);
2403       if (dict == nullptr || dict->elements().empty()) {
2404         MS_EXCEPTION_IF_NULL(local_dict_abs);
2405         MS_LOG(INFO) << "Not valid PyInterpret CNode. node: " << node->DebugString() << ", key: " << type_var_str
2406                      << ", local_dict_abs: " << local_dict_abs->ToString();
2407         return nullptr;
2408       }
2409       for (const auto &element : dict->elements()) {
2410         MS_EXCEPTION_IF_NULL(element.first);
2411         const auto &key = element.first->BuildValue();
2412         if (key == nullptr || !key->isa<StringImm>()) {
2413           continue;
2414         }
2415         if (key->cast<StringImmPtr>()->value() == type_var_str) {
2416           MS_EXCEPTION_IF_NULL(element.second);
2417           type_value = element.second->BuildValue();
2418           break;
2419         }
2420       }
2421     }
2422 
2423     if (type_value == nullptr) {
2424       MS_LOG(INFO) << "Not valid " << (is_py_execute ? "PyExecute" : "PyInterpret")
2425                    << " CNode. node: " << node->DebugString() << ", key: " << type_var_str << ", type value is null.";
2426       return nullptr;
2427     }
2428     const auto &py_type = BuildPyObject(type_value);
2429     MS_LOG(DEBUG) << "type_value: " << type_value->ToString() << ", py_type: " << py_type;
2430     if (!py::isinstance<py::none>(py_type)) {
2431       return py::cast<TypePtr>(py_type);
2432     }
2433     MS_LOG(INFO) << "Not valid " << (is_py_execute ? "PyExecute" : "PyInterpret")
2434                  << " CNode. node: " << node->DebugString() << ", key: " << type_var_str << ", type value is None.";
2435     return nullptr;
2436   };
2437   const auto &type = fallback::GetJitAnnotationTypeFromComment(node, func);
2438   return type;
2439 }
2440 
GetLocalArgsUniqueDtype(const AnfNodePtr & node,const AbstractBasePtrList & args_abs_list)2441 TypePtr GetLocalArgsUniqueDtype(const AnfNodePtr &node, const AbstractBasePtrList &args_abs_list) {
2442   // If force to use ANY.
2443   static const auto force_any = (common::GetCompileConfig("FALLBACK_FORCE_ANY") == "1");
2444   if (force_any) {
2445     return nullptr;
2446   }
2447 
2448   TypePtr res = nullptr;
2449   // Check the abstract, return true if continue, otherwise return false.
2450   auto unique_dtype_check = [&node, &res](const AbstractBasePtr &element_value_abs) -> bool {
2451     MS_EXCEPTION_IF_NULL(element_value_abs);
2452     if (!element_value_abs->isa<abstract::AbstractTensor>()) {
2453       return true;
2454     }
2455     // Fetch the dtype from element_value_abs of tensor.
2456     auto element_abs_tensor = element_value_abs->cast_ptr<abstract::AbstractTensor>();
2457     MS_EXCEPTION_IF_NULL(element_abs_tensor);
2458     MS_EXCEPTION_IF_NULL(element_abs_tensor->element());
2459     const auto dtype = element_abs_tensor->element()->BuildType();
2460     MS_EXCEPTION_IF_NULL(dtype);
2461     // Check default dtype if it's AbstractAny(AbstractTensor)
2462     if (element_value_abs->isa<abstract::AbstractAny>() &&
2463         !element_value_abs->cast_ptr<abstract::AbstractAny>()->supposed_tensor_dtype()) {
2464       return true;
2465     }
2466     if (res == nullptr) {
2467       MS_EXCEPTION_IF_NULL(node);
2468       MS_LOG(INFO) << "Tensor dtype found, set as unique dtype: " << dtype->ToString()
2469                    << ", node: " << node->DebugString() << "\n\n"
2470                    << trace::GetDebugInfoStr(node->debug_info());
2471       res = dtype;
2472       return true;
2473     }
2474     if (res != dtype) {
2475       MS_EXCEPTION_IF_NULL(node);
2476       MS_LOG(INFO) << "More than one tensor dtype found, not set unique dtype. node: " << node->DebugString() << "\n\n"
2477                    << trace::GetDebugInfoStr(node->debug_info());
2478       return false;
2479     }
2480     return true;
2481   };
2482   constexpr auto values_index = 2;
2483   if (args_abs_list.size() <= values_index) {
2484     return nullptr;
2485   }
2486   const auto &values_tuple_abs = dyn_cast<AbstractSequence>(args_abs_list[values_index]);
2487   bool is_py_execute = (values_tuple_abs != nullptr);
2488   if (is_py_execute) {  // PyExecute CNode.
2489     const auto &elements_abs = values_tuple_abs->elements();
2490     for (const auto &element_abs : elements_abs) {
2491       if (!unique_dtype_check(element_abs)) {
2492         return nullptr;
2493       }
2494     }
2495   } else {  // PyInterpret CNode.
2496     const auto &local_dict_abs = dyn_cast<AbstractDictionary>(args_abs_list[values_index]);
2497     MS_EXCEPTION_IF_NULL(local_dict_abs);
2498     const auto &elements_abs = local_dict_abs->elements();
2499     for (const auto &element_abs_pair : elements_abs) {
2500       const auto &element_value_abs = element_abs_pair.second;
2501       if (!unique_dtype_check(element_value_abs)) {
2502         return nullptr;
2503       }
2504     }
2505   }
2506 
2507   if (res != nullptr) {
2508     MS_LOG(INFO) << "Apply unique dtype: " << res->ToString() << " to node: " << node->DebugString() << "\n\n"
2509                  << trace::GetDebugInfoStr(node->debug_info());
2510   }
2511   return res;
2512 }
2513 
AddLabelsToPrimitiveFunction(const PrimitivePtr & prim_func)2514 void AddLabelsToPrimitiveFunction(const PrimitivePtr &prim_func) {
2515   auto prim_name = prim_func->name();
2516   py::module mod = py::module::import(parse::PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE);
2517   if (!py::hasattr(mod, parse::PYTHON_MOD_PRIMITIVE_OP_LABELS_DICT)) {
2518     MS_LOG(INTERNAL_EXCEPTION) << "Can not found " << parse::PYTHON_MOD_PRIMITIVE_OP_LABELS_DICT << " in "
2519                                << parse::PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE << ".";
2520   }
2521   py::dict op_labels = mod.attr(parse::PYTHON_MOD_PRIMITIVE_OP_LABELS_DICT);
2522   if (!op_labels.contains(py::str(prim_name))) {
2523     return;
2524   }
2525   py::dict labels = op_labels[py::str(prim_name)];
2526   for (const auto &p : labels) {
2527     auto attr_name = py::cast<std::string>(p.first);
2528     auto attr_obj = py::reinterpret_borrow<py::object>(p.second);
2529     ValuePtr converted_ret = nullptr;
2530     bool converted = parse::ConvertData(attr_obj, &converted_ret);
2531     if (!converted) {
2532       MS_LOG(INTERNAL_EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
2533                                  << " convert python obj to MindSpore obj failed; primitive name: " << prim_name
2534                                  << ", attribute name:" << attr_name << ", attribute value:" << py::str(attr_obj)
2535                                  << ", attribute type:"
2536                                  << py::cast<std::string>(attr_obj.attr("__class__").attr("__name__"));
2537     }
2538     MS_LOG(DEBUG) << "Add attr {" << attr_name << ": " << converted_ret->ToString() << "} to " << prim_name;
2539     (void)prim_func->AddAttr(attr_name, converted_ret);
2540   }
2541 }
2542 
GeneratePrimitiveDefaultArgs(const std::string & op_name,const std::vector<AnfNodePtr> & args_list,const std::vector<ops::OpInputArg> & op_args,bool check_init)2543 std::vector<AnfNodePtr> GeneratePrimitiveDefaultArgs(const std::string &op_name,
2544                                                      const std::vector<AnfNodePtr> &args_list,
2545                                                      const std::vector<ops::OpInputArg> &op_args, bool check_init) {
2546   size_t args_size = args_list.size();
2547   std::vector<AnfNodePtr> nodes;
2548   for (const auto &input : args_list) {
2549     if (HasAbstractMonad(input) || (IsPrimitiveCNode(input, prim::kPrimUpdateState) || IsValueNode<UMonad>(input) ||
2550                                     IsValueNode<IOMonad>(input))) {
2551       continue;
2552     }
2553     (void)nodes.emplace_back(input);
2554   }
2555   if (args_size < op_args.size()) {
2556     for (size_t i = args_size; i < op_args.size(); i++) {
2557       auto default_arg = parse::GetArgDefaultValue(op_name, op_args[i].arg_name_);
2558       if (default_arg == nullptr) {
2559         break;
2560       }
2561       MS_LOG(DEBUG) << "Get the default value of '" << op_args[i].arg_name_ << "' attribute of Primitive[" << op_name
2562                     << "], which is " << default_arg->ToString() << ".";
2563       (void)nodes.emplace_back(NewValueNode(default_arg));
2564     }
2565   }
2566   if (nodes.size() != op_args.size()) {
2567     std::string args_type_str = check_init ? "init arguments" : "inputs";
2568     MS_EXCEPTION(TypeError) << "For Operator[" << op_name << "], the number of " << args_type_str
2569                             << " (including default arguments) should be " << op_args.size()
2570                             << ", but the actual number of inputs is not satisfied, which is " << args_size << ".";
2571   }
2572   return nodes;
2573 }
2574 
ValidateAndConvertArgsType(const std::string & op_name,const std::vector<ops::OpInputArg> & op_args,const AbstractBasePtrList & abs_list,const FuncGraphPtr & fg,std::vector<AnfNodePtr> * nodes)2575 bool ValidateAndConvertArgsType(const std::string &op_name, const std::vector<ops::OpInputArg> &op_args,
2576                                 const AbstractBasePtrList &abs_list, const FuncGraphPtr &fg,
2577                                 std::vector<AnfNodePtr> *nodes) {
2578   bool exist_undetermined_arg = false;
2579   for (size_t i = 0; i < op_args.size(); i++) {
2580     auto op_arg = op_args[i];
2581     auto abs_arg = abs_list[i];
2582     if (abs_arg->isa<abstract::AbstractKeywordArg>()) {
2583       MS_EXCEPTION(TypeError) << "For Primitive[" << op_name
2584                               << "], only positional arguments as inputs are supported, but got "
2585                               << abs_arg->ToString();
2586     }
2587     if (HasAbstractUndetermined(abs_arg)) {
2588       exist_undetermined_arg = true;
2589     }
2590     if (ValidateArgOptional(abs_arg, op_arg) || ops::ValidateArgsType(abs_arg, op_arg.arg_dtype_)) {
2591       continue;
2592     }
2593     if (fallback::ContainsSequenceAnyType(abs_arg)) {
2594       continue;
2595     }
2596     bool match = false;
2597     auto cast_dtypes = op_arg.cast_dtype_;
2598     for (size_t j = 0; j < cast_dtypes.size(); j++) {
2599       if (ops::ValidateArgsType(abs_arg, cast_dtypes[j])) {
2600         (*nodes)[i] = GetNodeAfterTypeConversion((*nodes)[i], op_arg, fg);
2601         match = true;
2602         break;
2603       }
2604     }
2605     if (!match && !exist_undetermined_arg) {
2606       return false;
2607     }
2608   }
2609   return true;
2610 }
2611 
BuilidArgsTypeString(const AbstractBasePtr & arg_abs)2612 std::string BuilidArgsTypeString(const AbstractBasePtr &arg_abs) {
2613   auto arg_type = arg_abs->BuildType();
2614   MS_EXCEPTION_IF_NULL(arg_type);
2615   if (arg_type->isa<Bool>()) {
2616     return "bool";
2617   }
2618   if (arg_type->isa<Int>() || arg_type->isa<UInt>()) {
2619     return "int";
2620   }
2621   if (arg_type->isa<Float>() || arg_type->isa<BFloat>()) {
2622     return "float";
2623   }
2624   if (arg_type->isa<String>()) {
2625     return "string";
2626   }
2627   if (arg_type->isa<TypeNone>()) {
2628     return "None";
2629   }
2630   if (arg_type->isa<TensorType>()) {
2631     return "Tensor";
2632   }
2633   if (arg_type->isa<Tuple>() || arg_type->isa<List>()) {
2634     auto seq_abs = arg_abs->cast_ptr<abstract::AbstractSequence>();
2635     MS_EXCEPTION_IF_NULL(seq_abs);
2636     std::string seq_type = arg_type->isa<Tuple>() ? "tuple" : "list";
2637     if (seq_abs->dynamic_len()) {
2638       return seq_type;
2639     }
2640     std::stringstream ss;
2641     ss << seq_type << "<";
2642     for (size_t i = 0; i < seq_abs->size(); i++) {
2643       if (i == 0) {
2644         ss << BuilidArgsTypeString(seq_abs->elements()[i]);
2645       } else {
2646         ss << ", " << BuilidArgsTypeString(seq_abs->elements()[i]);
2647       }
2648     }
2649     ss << ">";
2650     return ss.str();
2651   }
2652   return arg_type->ToString();
2653 }
2654 
CheckAndConvertPrimitiveArgs(const PrimitivePtr & prim,const FuncGraphPtr & graph,const std::pair<std::vector<AnfNodePtr>,std::vector<AnfNodePtr>> & args_pair,const std::function<AbstractBasePtr (const AnfNodePtr &)> & eval_func,bool is_preprocessed)2655 CNodePtr CheckAndConvertPrimitiveArgs(const PrimitivePtr &prim, const FuncGraphPtr &graph,
2656                                       const std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> &args_pair,
2657                                       const std::function<AbstractBasePtr(const AnfNodePtr &)> &eval_func,
2658                                       bool is_preprocessed) {
2659   auto init_args_list = args_pair.first;
2660   auto call_args_list = args_pair.second;
2661   auto prim_name = prim->name();
2662   auto op_def = mindspore::ops::GetOpDef(prim_name);
2663   MS_EXCEPTION_IF_NULL(op_def);
2664   MS_EXCEPTION_IF_NULL(graph);
2665   // Check args size.
2666   std::vector<ops::OpInputArg> op_call_args;
2667   std::vector<ops::OpInputArg> op_init_args;
2668   auto op_args = op_def->args_;
2669   for (const auto &op_arg : op_args) {
2670     if (op_arg.as_init_arg_) {
2671       (void)op_init_args.emplace_back(op_arg);
2672     } else {
2673       (void)op_call_args.emplace_back(op_arg);
2674     }
2675   }
2676 
2677   MS_LOG(DEBUG) << "For Primitive[" << prim_name << "], the number of init args is expected to be "
2678                 << op_init_args.size() << ", and the number of call args is expected to be " << op_call_args.size();
2679   // Generate primitive default args.
2680   MS_LOG(DEBUG) << "For Primitive[ " << prim_name << "], before processing default args, the number of init args is "
2681                 << init_args_list.size() << " and the number of call args is " << call_args_list.size();
2682   auto call_nodes = GeneratePrimitiveDefaultArgs(prim_name, call_args_list, op_call_args, false);
2683   auto init_nodes = GeneratePrimitiveDefaultArgs(prim_name, init_args_list, op_init_args, true);
2684   MS_LOG(DEBUG) << "For Primitive[ " << prim_name << "], after processing default args, the number of init args is "
2685                 << init_args_list.size() << " and the number of call args is " << call_args_list.size();
2686   // If it is not preprocessed, signatures and need to be processed.
2687   if (!is_preprocessed) {
2688     // Process signatures.
2689     MS_LOG(DEBUG) << "Process signatures for Primitive[" << prim_name << "].";
2690     AbstractBasePtrList call_abs_list;
2691     (void)std::transform(call_nodes.cbegin(), call_nodes.cend(), std::back_inserter(call_abs_list), eval_func);
2692     call_nodes = prim::GetNewInputsBySignatures(graph, prim_name, prim, call_abs_list, call_nodes);
2693     // Process arg_handler.
2694     for (size_t i = 0; i < op_init_args.size(); i++) {
2695       auto abs_node = eval_func(init_nodes[i]);
2696       init_nodes[i] = GetNodeAfterArgHandler(init_nodes[i], prim_name, op_init_args[i], abs_node, graph);
2697     }
2698   }
2699   for (size_t i = 0; i < op_call_args.size(); i++) {
2700     auto abs_node = eval_func(call_nodes[i]);
2701     call_nodes[i] = GetNodeAfterArgHandler(call_nodes[i], prim_name, op_call_args[i], abs_node, graph);
2702   }
2703 
2704   // Check args type and do type conversion.
2705   AbstractBasePtrList call_abs_list;
2706   AbstractBasePtrList init_abs_list;
2707   (void)std::transform(call_nodes.cbegin(), call_nodes.cend(), std::back_inserter(call_abs_list), eval_func);
2708   (void)std::transform(init_nodes.cbegin(), init_nodes.cend(), std::back_inserter(init_abs_list), eval_func);
2709   MS_LOG(DEBUG) << "For Primitive[" << prim_name << "], the number of init args is " << init_nodes.size()
2710                 << " and the number of call args is " << call_nodes.size();
2711   if (!ValidateAndConvertArgsType(prim_name, op_call_args, call_abs_list, graph, &call_nodes) ||
2712       !ValidateAndConvertArgsType(prim_name, op_init_args, init_abs_list, graph, &init_nodes)) {
2713     std::vector<std::string> op_type_list;
2714     (void)std::transform(call_abs_list.cbegin(), call_abs_list.cend(), std::back_inserter(op_type_list),
2715                          [](const AbstractBasePtr &op_abs) { return BuilidArgsTypeString(op_abs); });
2716     (void)std::transform(init_abs_list.cbegin(), init_abs_list.cend(), std::back_inserter(op_type_list),
2717                          [](const AbstractBasePtr &op_abs) { return BuilidArgsTypeString(op_abs); });
2718     MS_EXCEPTION(TypeError) << ops::BuildOpErrorMsg(op_def, op_type_list);
2719   }
2720 
2721   // Create New node.
2722   AnfNodePtrList input_nodes{NewValueNode(prim)};
2723   (void)std::copy(call_nodes.cbegin(), call_nodes.cend(), std::back_inserter(input_nodes));
2724   (void)std::copy(init_nodes.cbegin(), init_nodes.cend(), std::back_inserter(input_nodes));
2725   auto new_cnode = graph->NewCNodeInOrder(input_nodes);
2726   return new_cnode;
2727 }
2728 
CheckAndConvertPrimitiveArgs(const PrimitivePtr & prim,const std::pair<std::vector<AnfNodePtr>,std::vector<AnfNodePtr>> & args_pair,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf,bool is_preprocessed)2729 AnfNodePtr CheckAndConvertPrimitiveArgs(const PrimitivePtr &prim,
2730                                         const std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> &args_pair,
2731                                         const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &out_conf,
2732                                         bool is_preprocessed) {
2733   auto graph = out_conf->node()->func_graph();
2734   MS_EXCEPTION_IF_NULL(graph);
2735 
2736   auto eval_func = [&engine, &out_conf](const AnfNodePtr &node) {
2737     AnfNodeConfigPtr config = engine->MakeConfig(node, out_conf->context(), out_conf->func_graph());
2738     MS_EXCEPTION_IF_NULL(config);
2739     const auto &eval_result = config->ObtainEvalResult();
2740     MS_EXCEPTION_IF_NULL(eval_result);
2741     return eval_result->abstract();
2742   };
2743 
2744   auto new_cnode = CheckAndConvertPrimitiveArgs(prim, graph, args_pair, eval_func, is_preprocessed);
2745   MS_LOG(INFO) << "Convert primitive args: " << prim->name() << ". node: " << out_conf->node()->DebugString()
2746                << ", new_node: " << new_cnode->DebugString();
2747   return new_cnode;
2748 }
2749 
ConvertArgsToInputs(const PrimitivePtr & prim,const AnfNodeWeakPtrList & inputs,const FuncGraphPtr & fg,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf)2750 AnfNodePtr ConvertArgsToInputs(const PrimitivePtr &prim, const AnfNodeWeakPtrList &inputs, const FuncGraphPtr &fg,
2751                                const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &out_conf) {
2752   // Append Primitive arguments to the inputs.
2753   auto prim_py = prim->cast<PrimitivePyPtr>();
2754   MS_EXCEPTION_IF_NULL(prim_py);
2755   auto op_def = mindspore::ops::GetOpDef(prim->name());
2756   MS_EXCEPTION_IF_NULL(op_def);
2757   // Get init args.
2758   const AnfNodePtrList &prim_init_arg_nodes = GetPrimitiveInitArgs(prim_py, op_def);
2759 
2760   // Get call args.
2761   AnfNodePtrList prim_call_arg_nodes;
2762   (void)std::transform(inputs.cbegin() + 1, inputs.cend(), std::back_inserter(prim_call_arg_nodes),
2763                        [](const AnfNodeWeakPtr &weak_node) {
2764                          const auto &node = weak_node.lock();
2765                          MS_EXCEPTION_IF_NULL(node);
2766                          return node;
2767                        });
2768   // Create new node.
2769   auto new_prim = std::make_shared<Primitive>(*prim);
2770   auto args_pair = std::make_pair(prim_init_arg_nodes, prim_call_arg_nodes);
2771   return CheckAndConvertPrimitiveArgs(new_prim, args_pair, engine, out_conf, true);
2772 }
2773 }  // namespace
2774 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)2775 EvalResultPtr PrimitiveArgsToInputsEvaluator::EvalPrim(const AnalysisEnginePtr &engine,
2776                                                        const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
2777                                                        const AnfNodeConfigPtr &out_conf) {
2778   // Convert primitive args to inputs.
2779   MS_EXCEPTION_IF_NULL(out_conf);
2780   auto cnode = out_conf->node()->cast<CNodePtr>();
2781   MS_EXCEPTION_IF_NULL(cnode);
2782   auto fg = cnode->func_graph();
2783   MS_EXCEPTION_IF_NULL(fg);
2784 
2785   constexpr size_t index_op = 0;
2786   constexpr size_t index_data = 1;
2787   auto op_node = cnode->input(index_op);
2788   AnfNodePtr new_node = nullptr;
2789   parse::SymbolPtr symbol_node = nullptr;
2790   if (op_node->isa<CNode>()) {
2791     auto inner_op_node = op_node->cast<CNodePtr>()->input(index_op);
2792     if (IsPrimitiveCNode(inner_op_node, prim::kPrimResolve)) {
2793       auto resolve_node = inner_op_node->cast<CNodePtr>();
2794       constexpr size_t index_symbol = 2;
2795       symbol_node = GetValueNode<parse::SymbolPtr>(resolve_node->input(index_symbol));
2796     }
2797   }
2798   if (IsPrimitiveCNode(op_node, prim::kPrimPartial)) {
2799     // The input may be a Partial node, such as {{prim::kPrimPartial, prim::kPrimRank, x}} -> {prim::kPrimRank, x}.
2800     AnfNodeWeakPtrList partial_inputs;
2801     auto op_cnode = op_node->cast<CNodePtr>();
2802     (void)std::copy(op_cnode->weak_inputs().begin() + index_data, op_cnode->weak_inputs().end(),
2803                     std::back_inserter(partial_inputs));
2804     (void)std::copy(cnode->weak_inputs().begin() + index_data, cnode->weak_inputs().end(),
2805                     std::back_inserter(partial_inputs));
2806     new_node = ConvertArgsToInputs(prim_, partial_inputs, fg, engine, out_conf);
2807   } else if (IsPrimitiveCNode(op_node, prim::kPrimGetAttr) ||
2808              IsPrimitiveCNodeWithoutDoSignature(op_node, prim::kPrimGetAttr) ||
2809              (symbol_node != nullptr && symbol_node->symbol() == "getattr")) {
2810     // The input may be a GetAttr node, such as x.abs(): {{prim::kPrimGetAttr, x, abs}} -> {prim::kPrimAbs, x}
2811     auto op_cnode = op_node->cast<CNodePtr>();
2812     AnfNodeWeakPtrList getattr_inputs;
2813     auto new_prim = std::make_shared<Primitive>(prim_->name());
2814     auto new_prim_node = NewValueNode(new_prim);
2815     (void)getattr_inputs.emplace_back(new_prim_node);
2816     (void)getattr_inputs.emplace_back(op_cnode->input(index_data));
2817     (void)std::copy(cnode->weak_inputs().begin() + index_data, cnode->weak_inputs().end(),
2818                     std::back_inserter(getattr_inputs));
2819     new_node = ConvertArgsToInputs(prim_, getattr_inputs, fg, engine, out_conf);
2820   } else {
2821     constexpr int recursive_level = 2;
2822     new_node = ConvertArgsToInputs(prim_, cnode->weak_inputs(), fg, engine, out_conf);
2823     MS_LOG(DEBUG) << "Convert args to inputs for Operator[" << prim_->name()
2824                   << "], node: " << cnode->DebugString(recursive_level);
2825   }
2826 
2827   new_node->set_debug_info(cnode->debug_info());
2828   auto new_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
2829   MS_LOG(INFO) << "Convert primitive args to inputs: " << prim_->ToString() << ". node: " << cnode->DebugString()
2830                << ", new node: " << new_node->DebugString();
2831   return engine->ForwardConfig(out_conf, new_conf);
2832 }
2833 
2834 namespace {
ConvertWeakNode(const AnfNodeWeakPtr & weak_node)2835 AnfNodePtr ConvertWeakNode(const AnfNodeWeakPtr &weak_node) {
2836   const auto &node = weak_node.lock();
2837   MS_EXCEPTION_IF_NULL(node);
2838   return node;
2839 }
2840 }  // namespace
2841 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)2842 EvalResultPtr DoTransPrimitiveFunctionEvaluator::EvalPrim(const AnalysisEnginePtr &engine,
2843                                                           const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
2844                                                           const AnfNodeConfigPtr &out_conf) {
2845   // For PrimitiveFunction generated by CreateInstance, its args, labels, signatures and
2846   // implicit conversion need to be processed.
2847   auto do_trans_prim_func = prim_->cast<prim::DoTransPrimitiveFunctionPtr>();
2848   MS_EXCEPTION_IF_NULL(do_trans_prim_func);
2849   auto prim_func = do_trans_prim_func->function();
2850   MS_EXCEPTION_IF_NULL(prim_func);
2851   auto cnode = out_conf->node()->cast<CNodePtr>();
2852   MS_EXCEPTION_IF_NULL(cnode);
2853   auto fg = cnode->func_graph();
2854   MS_EXCEPTION_IF_NULL(fg);
2855 
2856   auto prim_name = prim_func->name();
2857   auto op_def = mindspore::ops::GetOpDef(prim_name);
2858   if (op_def == nullptr) {
2859     MS_LOG(INTERNAL_EXCEPTION) << "DoTransPrimitiveFunction only supports Primitive with OpDef, but got " << prim_name
2860                                << ".";
2861   }
2862   if (cnode->size() != args_abs_list.size() + 1) {
2863     MS_LOG(INTERNAL_EXCEPTION) << "For Operator[" << prim_name << "], the number of cnode inputs should be "
2864                                << args_abs_list.size() + 1 << ", but got " << cnode->size()
2865                                << ".\nnode: " << cnode->DebugString();
2866   }
2867   // Handle primitive labels.
2868   AddLabelsToPrimitiveFunction(prim_func);
2869   // Handle primitive signatures.
2870   auto arg_signatures = op_def->signatures_;
2871   prim_func->set_signatures(arg_signatures);
2872   prim_func->set_has_signature(!arg_signatures.empty());
2873   // Get init args size.
2874   size_t init_args_size = 0;
2875   if (do_trans_prim_func->has_given_init_size()) {
2876     // Might need to handle default arguments.
2877     init_args_size = do_trans_prim_func->given_init_size();
2878   } else {
2879     // All call args and init args should have been provided.
2880     size_t op_args_size = op_def->args_.size();
2881     if (op_args_size != args_abs_list.size()) {
2882       MS_EXCEPTION(TypeError) << "For Operator['" << prim_name
2883                               << "]', the number of inputs and init args (including default arguments) should be "
2884                               << op_args_size << ", but got " << args_abs_list.size() << ". ";
2885     }
2886     for (size_t i = 0; i < op_args_size; i++) {
2887       if (op_def->args_[i].as_init_arg_) {
2888         ++init_args_size;
2889       }
2890     }
2891   }
2892 
2893   // Get init args and call args.
2894   AnfNodePtrList prim_init_arg_nodes;
2895   (void)std::transform(cnode->weak_inputs().cbegin() + cnode->size() - init_args_size, cnode->weak_inputs().cend(),
2896                        std::back_inserter(prim_init_arg_nodes), ConvertWeakNode);
2897   AnfNodePtrList prim_call_arg_nodes;
2898   (void)std::transform(cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend() - init_args_size,
2899                        std::back_inserter(prim_call_arg_nodes), ConvertWeakNode);
2900 
2901   auto args_pair = std::make_pair(prim_init_arg_nodes, prim_call_arg_nodes);
2902   auto new_cnode = CheckAndConvertPrimitiveArgs(prim_func, args_pair, engine, out_conf, false);
2903   auto new_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
2904   MS_LOG(INFO) << "Prim: " << prim_func->name() << ", " << cnode->DebugString() << ", " << new_cnode->DebugString();
2905   return engine->ForwardConfig(out_conf, new_conf);
2906 }
2907 
GetInitArgsFromUnpackCall(const prim::DoTransPrimitiveFunctionPtr & do_trans_prim,const CNodePtr & unpack_call_cnode,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf)2908 AnfNodePtrList GetInitArgsFromUnpackCall(const prim::DoTransPrimitiveFunctionPtr &do_trans_prim,
2909                                          const CNodePtr &unpack_call_cnode, const AnalysisEnginePtr &engine,
2910                                          const AnfNodeConfigPtr &out_conf) {
2911   auto prim = do_trans_prim->function();
2912   auto op_def = mindspore::ops::GetOpDef(prim->name());
2913   MS_EXCEPTION_IF_NULL(op_def);
2914 
2915   AnfNodePtrList new_inputs;
2916   std::map<std::string, AnfNodePtr> key_map;
2917   auto fg = out_conf->node()->func_graph();
2918   constexpr size_t inputs_start_index = 2;
2919   for (size_t index = inputs_start_index; index < unpack_call_cnode->size(); index++) {
2920     auto input = unpack_call_cnode->input(index);
2921     AnfNodeConfigPtr config = engine->MakeConfig(input, out_conf->context(), out_conf->func_graph());
2922     MS_EXCEPTION_IF_NULL(config);
2923     const auto &eval_result = config->ObtainEvalResult();
2924     MS_EXCEPTION_IF_NULL(eval_result);
2925     auto input_abs = eval_result->abstract();
2926     if (input_abs->isa<AbstractDictionary>()) {
2927       auto dict_elems = input_abs->cast<AbstractDictionaryPtr>()->elements();
2928       for (const auto &elem : dict_elems) {
2929         auto key = GetValue<std::string>(elem.first->BuildValue());
2930         auto elem_value = fg->NewCNode({NewValueNode(prim::kPrimDictGetItem), input, NewValueNode(key)});
2931         key_map[key] = elem_value;
2932       }
2933     } else if (input_abs->isa<AbstractTuple>()) {
2934       auto arg_tuple = input_abs->cast<AbstractTuplePtr>();
2935       for (size_t i = 0; i < arg_tuple->size(); ++i) {
2936         MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << input->DebugString() << ", i: " << i;
2937         (void)new_inputs.emplace_back(
2938           fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(SizeToLong(i))}));
2939       }
2940     } else if (input_abs->isa<AbstractList>()) {
2941       auto arg_list = input_abs->cast<AbstractListPtr>();
2942       for (size_t i = 0; i < arg_list->size(); ++i) {
2943         MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << input->DebugString() << ", i: " << i;
2944         (void)new_inputs.emplace_back(
2945           fg->NewCNode({NewValueNode(prim::kPrimListGetItem), input, NewValueNode(SizeToLong(i))}));
2946       }
2947     } else {
2948       MS_LOG(INTERNAL_EXCEPTION) << "The arguments of UnpackCall operator should be tuple, list or dict, but got "
2949                                  << input_abs->ToString();
2950     }
2951   }
2952 
2953   // Handle variable arguments.
2954   auto op_args = op_def->args_;
2955   auto inputs_size = new_inputs.size();
2956   size_t index = 0;
2957   size_t init_args_num = 0;
2958   for (const auto &op_arg : op_args) {
2959     if (!(op_arg.as_init_arg_)) {
2960       continue;
2961     }
2962     init_args_num++;
2963     if (index < inputs_size) {
2964       index++;
2965       continue;
2966     }
2967     auto arg_name = op_arg.arg_name_;
2968     auto iter = key_map.find(arg_name);
2969     if (iter != key_map.end()) {
2970       MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << iter->second->DebugString();
2971       (void)new_inputs.emplace_back(iter->second);
2972       (void)key_map.erase(arg_name);
2973     } else {
2974       auto default_value = parse::GetArgDefaultValue(prim->name(), arg_name);
2975       if (default_value == nullptr) {
2976         MS_EXCEPTION(TypeError) << "For Operator[" << prim->name() << "], there is no matching input for argument '"
2977                                 << arg_name << "'.";
2978       }
2979       MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << default_value->ToString();
2980       (void)new_inputs.emplace_back(NewValueNode(default_value));
2981     }
2982   }
2983   if (init_args_num < new_inputs.size()) {
2984     MS_EXCEPTION(TypeError) << "For Operator[" << prim->name() << "], the number of init arguments should be "
2985                             << init_args_num << ", but got " << new_inputs.size() << ".";
2986   }
2987   if (!key_map.empty()) {
2988     std::stringstream ss;
2989     ss << "For Operator[" << prim->name() << "], there are unmatched arguments: ";
2990     for (const auto &elem : key_map) {
2991       ss << elem.first << " ";
2992     }
2993     ss << ".";
2994     MS_EXCEPTION(TypeError) << ss.str();
2995   }
2996   do_trans_prim->set_given_init_size(new_inputs.size());
2997   return new_inputs;
2998 }
2999 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3000 EvalResultPtr PartialToEndEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3001                                               const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3002   // Convert Partial{Prim, a, b}(x, y) to {Prim, x, y, a, b}.
3003   auto prim = primal_func_->BuildValue();
3004   MS_EXCEPTION_IF_NULL(prim);
3005   AnfNodePtrList new_inputs{NewValueNode(prim)};
3006   auto do_trans_prim = prim->cast<prim::DoTransPrimitiveFunctionPtr>();
3007   MS_EXCEPTION_IF_NULL(do_trans_prim);
3008   // Add inputs: x, y.
3009   MS_EXCEPTION_IF_NULL(out_conf);
3010   auto cnode = out_conf->node()->cast<CNodePtr>();
3011   MS_EXCEPTION_IF_NULL(cnode);
3012   for (size_t i = 1; i < cnode->size(); i++) {
3013     (void)new_inputs.emplace_back(cnode->input(i));
3014   }
3015   // Add args: a, b.
3016   constexpr size_t op_index = 0;
3017   auto partial_node = cnode->input(op_index);
3018   MS_EXCEPTION_IF_NULL(partial_node);
3019   auto partial_cnode = partial_node->cast<CNodePtr>();
3020   if (partial_cnode == nullptr) {
3021     MS_EXCEPTION(TypeError) << "For Primitive[" << prim->ToString()
3022                             << "], only positional arguments as inputs are supported, but got "
3023                             << partial_node->DebugString() << ".";
3024   }
3025   if (IsValueNode<prim::UnpackCall>(partial_cnode->input(op_index))) {
3026     auto unpack_call_args = GetInitArgsFromUnpackCall(do_trans_prim, partial_cnode, engine, out_conf);
3027     (void)std::copy(unpack_call_args.cbegin(), unpack_call_args.cend(), std::back_inserter(new_inputs));
3028   } else {
3029     (void)std::transform(partial_cnode->weak_inputs().cbegin() + 1, partial_cnode->weak_inputs().cend(),
3030                          std::back_inserter(new_inputs), [](const auto &weak_node) {
3031                            const auto &node = weak_node.lock();
3032                            MS_EXCEPTION_IF_NULL(node);
3033                            return node;
3034                          });
3035   }
3036 
3037   auto fg = cnode->func_graph();
3038   MS_EXCEPTION_IF_NULL(fg);
3039   auto new_cnode = fg->NewCNodeInOrder(new_inputs);
3040   auto new_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
3041   constexpr auto recursive_level = 2;
3042   MS_LOG(INFO) << "For Primitive[" << prim->ToString() << "], convert partial node "
3043                << cnode->DebugString(recursive_level) << " to new cnode " << new_cnode->DebugString(recursive_level);
3044   return engine->ForwardConfig(out_conf, new_conf);
3045 }
3046 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3047 EvalResultPtr ConstexprEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3048                                            const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3049   // Consider all primitive implemented python infer() real use the tuple/list arguments.
3050   CheckSequenceArgumentForPythonPrimitive(prim_py_, args_abs_list);
3051   MS_EXCEPTION_IF_NULL(prim_py_);
3052   auto py_args = PreparePyInputs(args_abs_list);
3053   prim_py_->BeginRecordAddAttr();
3054   py::dict output = prim_py_->RunInfer(py_args);
3055   prim_py_->EndRecordAddAttr();
3056   if (output.contains("fn")) {
3057     // The inputs contain variable, the constexpr will run as graph.
3058     py::tuple values = output["fn"];
3059     if (values.empty()) {
3060       MS_LOG(EXCEPTION) << "Can not get origin function from constexpr.";
3061     }
3062     auto inner_val = parse::ParsePythonCode(values[0]);
3063     MS_EXCEPTION_IF_NULL(inner_val);
3064     auto inner_fg = dyn_cast<FuncGraph>(inner_val);
3065     MS_EXCEPTION_IF_NULL(inner_fg);
3066     MS_EXCEPTION_IF_NULL(out_conf);
3067     auto cur_graph = out_conf->func_graph();
3068     MS_EXCEPTION_IF_NULL(cur_graph);
3069     auto mng = cur_graph->manager();
3070     MS_EXCEPTION_IF_NULL(mng);
3071     inner_fg->set_manager(mng);
3072     auto out_node = out_conf->node();
3073     MS_EXCEPTION_IF_NULL(out_node);
3074     auto out_cnode = dyn_cast<CNode>(out_node);
3075     MS_EXCEPTION_IF_NULL(out_cnode);
3076     FuncGraphPtr func_graph = out_node->func_graph();
3077     MS_EXCEPTION_IF_NULL(func_graph);
3078     std::vector<AnfNodePtr> new_cnode_inputs = {NewValueNode(inner_fg)};
3079     const auto &out_cnode_inputs = out_cnode->weak_inputs();
3080     (void)std::transform(out_cnode_inputs.cbegin() + 1, out_cnode_inputs.cend(), std::back_inserter(new_cnode_inputs),
3081                          [](const auto &weak_node) {
3082                            const auto &node = weak_node.lock();
3083                            MS_EXCEPTION_IF_NULL(node);
3084                            return node;
3085                          });
3086     auto new_node = func_graph->NewCNodeInOrder(new_cnode_inputs);
3087     AnalysisEnginePtr eng = out_conf->engine();
3088     MS_EXCEPTION_IF_NULL(eng);
3089     AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
3090     return eng->ForwardConfig(out_conf, fn_conf);
3091   }
3092   // If all inputs are constant value, use python prim evaluator.
3093   // Ensure input arguments are evaluated.
3094   auto res_abstract = EvalUndeterminedArgs(args_abs_list);
3095   if (res_abstract != nullptr) {
3096     MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
3097     return res_abstract;
3098   }
3099   auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
3100   if (!forbid_reuse) {
3101     // Try to get infer result from evaluator cache.
3102     EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
3103     if (eval_result != nullptr) {
3104       MS_EXCEPTION_IF_NULL(eval_result->abstract());
3105       return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
3106     }
3107   }
3108   const auto &added_attrs = prim_py_->evaluate_added_attrs();
3109   MS_LOG(DEBUG) << "Output type is " << py::str(output);
3110   auto res_abs = PyInferRes2Abstract(prim_py_, output);
3111   MS_EXCEPTION_IF_NULL(res_abs);
3112   MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
3113   EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
3114   evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
3115   return eval_result;
3116 }
3117 
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3118 EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
3119                                            const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3120   std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
3121   auto abs = std::make_shared<AbstractTuple>(args_abs_list, sequence_nodes);
3122   if (out_conf != nullptr) {  // 'out_conf' maybe nullptr in PyNative mode.
3123     if (args_abs_list.empty()) {
3124       MS_EXCEPTION_IF_NULL(out_conf->node());
3125       MS_LOG(INFO) << "For MakeTuple, the inputs should not be empty. node: " << out_conf->node()->DebugString();
3126     }
3127     static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
3128     if (enable_eliminate_unused_element) {
3129       auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
3130       if (flags == nullptr) {
3131         SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
3132       }
3133       bool has_any = fallback::ContainsSequenceAnyType(abs);
3134       if (has_any) {
3135         SetSequenceElementsUseFlagsRecursively(abs, true);
3136       }
3137       (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
3138     }
3139   }
3140   auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
3141   evaluator_cache_mgr_->SetValue(args_abs_list, res);
3142   // pass the need_unpack tag from the AnfNode to the abstract
3143   if (out_conf != nullptr) {
3144     auto node = out_conf->node();
3145     constexpr auto need_unpack_str = "need_unpack";
3146     auto need_unpack = node->user_data<bool>(need_unpack_str);
3147     if (need_unpack != nullptr && *need_unpack) {
3148       abs->SetData<bool>(need_unpack_str, std::make_shared<bool>(true));
3149     }
3150   }
3151   return res;
3152 }
3153 
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3154 EvalResultPtr MakeListEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
3155                                           const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3156   std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
3157   auto abs = std::make_shared<AbstractList>(args_abs_list, sequence_nodes);
3158   if (out_conf != nullptr) {  // 'out_conf' maybe nullptr in PyNative mode.
3159     if (args_abs_list.empty()) {
3160       MS_LOG(INFO) << "For MakeList, the inputs should not be empty. node: " << out_conf->node()->DebugString();
3161     }
3162     static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
3163     if (enable_eliminate_unused_element) {
3164       auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
3165       if (flags == nullptr) {
3166         SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
3167       }
3168 
3169       (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
3170       bool has_any = fallback::ContainsSequenceAnyType(abs);
3171       if (has_any) {
3172         SetSequenceElementsUseFlagsRecursively(abs, true);
3173       }
3174     }
3175   }
3176   MS_LOG(DEBUG) << "Generate python object for new value node.";
3177   if (fallback::EnableFallbackListDictInplace()) {
3178     py::object py_list_obj = fallback::GeneratePyObj(abs);
3179     fallback::AttachPyObjToAbs(abs, py_list_obj, true);
3180   }
3181   auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
3182   evaluator_cache_mgr_->SetValue(args_abs_list, res);
3183   return res;
3184 }
3185 
CreateRealAbstract(const TypePtr & preset_type,const BaseShapePtr & shape,const AnfNodePtr & node,const AbstractBasePtrList & args_abs_list)3186 AbstractBasePtr CreateRealAbstract(const TypePtr &preset_type, const BaseShapePtr &shape, const AnfNodePtr &node,
3187                                    const AbstractBasePtrList &args_abs_list) {
3188   AbstractBasePtr res = nullptr;
3189   if (preset_type->isa<Scalar>()) {
3190     res = std::make_shared<AbstractScalar>(preset_type);
3191   } else if (preset_type->isa<List>() || preset_type->isa<Tuple>()) {
3192     res = fallback::GenerateAbstractSequence(shape, preset_type, true);
3193   } else if (preset_type->isa<TensorType>() && !preset_type->isa<AnyType>()) {
3194     auto tensor_type = preset_type->cast_ptr<TensorType>();
3195     MS_EXCEPTION_IF_NULL(tensor_type);
3196     auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, tensor_type->element());
3197     res = std::make_shared<abstract::AbstractTensor>(element, shape);
3198     auto abs_tensor = res->cast_ptr<abstract::AbstractTensor>();
3199     if (node->has_user_data(fallback::kIsAdapter)) {
3200       abs_tensor->set_is_adapter(true);
3201     }
3202   } else {
3203     const auto any_abstract = std::make_shared<AbstractAny>();
3204     // If no annotation dtype, try to use unique tensor dtype.
3205     auto dtype = GetLocalArgsUniqueDtype(node, args_abs_list);
3206     if (dtype != nullptr) {
3207       MS_EXCEPTION_IF_NULL(any_abstract->element());
3208       any_abstract->element()->set_type(dtype);
3209       any_abstract->set_supposed_tensor_dtype(true);
3210     }
3211     res = any_abstract;
3212   }
3213   return res;
3214 }
3215 
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3216 EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
3217                                            const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3218   MS_EXCEPTION_IF_NULL(out_conf);
3219   if (args_abs_list.empty()) {
3220     MS_LOG(INTERNAL_EXCEPTION) << "'args_abs_list' should not be empty";
3221   }
3222 
3223   // Handle for DDE.
3224   for (size_t i = 0; i < args_abs_list.size(); ++i) {
3225     MS_EXCEPTION_IF_NULL(args_abs_list[i]);
3226     if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
3227       MS_LOG(DEBUG) << "Primitive \'PyExecute\' is consuming tuple/list arguments[" << i
3228                     << "]: " << args_abs_list[i]->ToString();
3229       SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
3230     }
3231   }
3232 
3233   auto node = out_conf->node();
3234   MS_EXCEPTION_IF_NULL(node);
3235   MS_LOG(DEBUG) << "The current pyexecute node: " << node->DebugString();
3236   // Get the type parameter.
3237   MS_EXCEPTION_IF_NULL(args_abs_list[0]);
3238   ValuePtr script_value_track = args_abs_list[0]->GetValueTrack();
3239   MS_EXCEPTION_IF_NULL(script_value_track);
3240   auto script_obj = dyn_cast_ptr<StringImm>(script_value_track);
3241   if (script_obj == nullptr) {
3242     MS_LOG(INTERNAL_EXCEPTION) << "Cast value failed, not PyObjectWrapper: " << script_value_track->ToString() << ".";
3243   }
3244 
3245   // Make global and local parameters.
3246   const std::string &script = script_obj->value();
3247   // Call python script string.
3248   MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
3249   // Make abstract by type and shape.
3250   AbstractBasePtr res = nullptr;
3251   // Support Tensor annotation type. Add list and tuple here later.
3252   TypePtr dtype = nullptr;
3253   TypePtr type = GetAnnotationType(node, args_abs_list);
3254   if (type != nullptr && type->isa<TensorType>()) {
3255     dtype = type->cast<TensorTypePtr>()->element();
3256   }
3257   // Create output abstract.
3258   if (dtype != nullptr) {
3259     res = std::make_shared<AbstractTensor>(dtype, std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny})));
3260   } else if (fallback::HasRealType(node) && fallback::HasRealShape(node)) {
3261     const auto &preset_type = fallback::GetRealType<AnfNode, Type>(node);
3262     MS_LOG(DEBUG) << "preset_type: " << preset_type->ToString();
3263     const auto &shape = fallback::GetRealShape<AnfNode, BaseShape>(node);
3264     MS_LOG(DEBUG) << "shape: " << shape->ToString();
3265     res = CreateRealAbstract(preset_type, shape, node, args_abs_list);
3266   } else if (fallback::HasRealType(node) && fallback::GetRealType<AnfNode, Type>(node)->isa<NegligibleType>()) {
3267     res = std::make_shared<AbstractNegligible>();
3268   } else {
3269     const auto any_abstract = std::make_shared<AbstractAny>();
3270     // If no annotation dtype, try to use unique tensor dtype.
3271     dtype = GetLocalArgsUniqueDtype(node, args_abs_list);
3272     if (dtype != nullptr) {
3273       MS_EXCEPTION_IF_NULL(any_abstract->element());
3274       any_abstract->element()->set_type(dtype);
3275       any_abstract->set_supposed_tensor_dtype(true);
3276     }
3277     res = any_abstract;
3278   }
3279 
3280   // Set input real type and shape for caller.
3281   if (fallback::HasRealType(node)) {
3282     const auto &real_type = fallback::GetRealType<AnfNode, Type>(node);
3283     fallback::SetRealType<AbstractBase, Type>(res, real_type);
3284   }
3285   if (fallback::HasRealShape(node)) {
3286     const auto &real_shape = fallback::GetRealShape<AnfNode, BaseShape>(node);
3287     fallback::SetRealShape<AbstractBase, BaseShape>(res, real_shape);
3288   }
3289   if (res->isa<AbstractTensor>() && node->has_user_data(fallback::kAdapterTensor) &&
3290       *node->user_data<bool>(fallback::kAdapterTensor)) {
3291     auto res_tensor = res->cast<AbstractTensorPtr>();
3292     res_tensor->set_is_adapter(true);
3293   }
3294   auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3295   evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3296   return infer_result;
3297 }
3298 
3299 namespace {
3300 class PyInterpretEvaluator : public TransitionPrimEvaluator {
3301  public:
PyInterpretEvaluator()3302   PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
3303   ~PyInterpretEvaluator() override = default;
3304   MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3305   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
3306                          const AnfNodeConfigPtr &out_conf) override {
3307     if (args_abs_list.empty()) {
3308       MS_LOG(INTERNAL_EXCEPTION) << "'args_abs_list' should not be empty";
3309     }
3310     auto node = out_conf->node();
3311     MS_EXCEPTION_IF_NULL(node);
3312     MS_LOG(DEBUG) << "The current interpret node: " << node->DebugString();
3313 
3314     // If the interpret node contains FuncGraph node input, need to convert the Graph node to Interpreted object.
3315     AnfNodePtr converted_interpret_node = ConvertPyInterpretNode(node, args_abs_list);
3316     if (converted_interpret_node != nullptr) {
3317       AnalysisEnginePtr eng = out_conf->engine();
3318       MS_EXCEPTION_IF_NULL(eng);
3319       AnfNodeConfigPtr fn_conf = eng->MakeConfig(converted_interpret_node, out_conf->context(), out_conf->func_graph());
3320       return eng->ForwardConfig(out_conf, fn_conf);
3321     }
3322 
3323     non_const_err_ = false;
3324     check_list_dict_inplace_ =
3325       node->has_user_data(fallback::kCheckListDictInplace) && *node->user_data<bool>(fallback::kCheckListDictInplace);
3326 
3327     constexpr size_t script_index = 0;
3328     const std::string &script = GetScriptStr(args_abs_list[script_index]);
3329     // Make global and local parameters.
3330     py::tuple params = MakeParameters(args_abs_list, script);
3331     // Would convert PyInterpret to PyExecute then.
3332     if (non_const_err_ || fallback::GetJitAnnotationSideEffectFromComment(node)) {
3333       // Make abstract by type and shape.
3334       AbstractBasePtr res = nullptr;
3335       // Support Tensor annotation type. Add list and tuple here later.
3336       TypePtr dtype = nullptr;
3337       TypePtr type = GetAnnotationType(node, args_abs_list);
3338       if (type != nullptr && type->isa<TensorType>()) {
3339         dtype = type->cast<TensorTypePtr>()->element();
3340       }
3341       // Create output abstract.
3342       if (dtype != nullptr) {
3343         res = std::make_shared<AbstractTensor>(dtype, std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny})));
3344       } else {
3345         const auto any_abstract = std::make_shared<AbstractAny>();
3346         // If no annotation dtype, try to use unique tensor dtype.
3347         dtype = GetLocalArgsUniqueDtype(node, args_abs_list);
3348         if (dtype != nullptr) {
3349           MS_EXCEPTION_IF_NULL(any_abstract->element());
3350           any_abstract->element()->set_type(dtype);
3351           any_abstract->set_supposed_tensor_dtype(true);
3352         }
3353         res = any_abstract;
3354       }
3355       auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3356       evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3357       return infer_result;
3358     }
3359 
3360     // Call python script string.
3361     MS_LOG(DEBUG) << "Call script: " << script << ", params: " << py::str(params);
3362     auto obj = parse::data_converter::CallPythonScript(py::str(script), params);
3363     if (py::isinstance<py::none>(obj)) {
3364       AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
3365       auto infer_result = std::make_shared<EvalResult>(res, nullptr);
3366       evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3367       return infer_result;
3368     }
3369 
3370     ValuePtr converted_val = nullptr;
3371     bool converted = false;
3372     // converted_val could be a InterpretedObject.
3373     if (node->has_user_data("__keep_metafg_obj_flag__")) {
3374       converted_val = std::make_shared<parse::InterpretedObject>(obj);
3375       converted = true;
3376     } else {
3377       converted = parse::ConvertData(obj, &converted_val, true);
3378     }
3379     if (!converted) {
3380       MS_LOG(INTERNAL_EXCEPTION) << "Convert the python object failed";
3381     }
3382     MS_EXCEPTION_IF_NULL(converted_val);
3383     auto fg = node->func_graph();
3384     MS_EXCEPTION_IF_NULL(fg);
3385     auto mng = fg->manager();
3386     MS_EXCEPTION_IF_NULL(mng);
3387     AddManagerForFuncGraphValue(converted_val, mng);
3388     if (converted_val->isa<tensor::Tensor>() && HasConstArgAttr(obj)) {
3389       MS_LOG(WARNING) << "The tensor " << converted_val->ToString()
3390                       << " which is not used for network input argument should not be set const.";
3391     }
3392     if (converted_val->isa<parse::InterpretedObject>()) {
3393       const auto interpreted_value = dyn_cast<parse::InterpretedObject>(converted_val);
3394       MS_LOG(DEBUG) << "The InterpretedObject(" << converted_val->ToString() << ") is converted by PyInterpret"
3395                     << " node: " << node->DebugString();
3396       interpreted_value->set_has_converted(true);
3397     }
3398 
3399     AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
3400     auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3401     evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3402     return infer_result;
3403   }
3404 
AddManagerForFuncGraphValue(const ValuePtr & val,const FuncGraphManagerPtr & mng) const3405   void AddManagerForFuncGraphValue(const ValuePtr &val, const FuncGraphManagerPtr &mng) const {
3406     // mng has been checked before using.
3407     MS_EXCEPTION_IF_NULL(val);
3408     if (val->isa<ValueSequence>()) {
3409       auto val_seq = val->cast<ValueSequencePtr>();
3410       const auto &values = val_seq->value();
3411       std::for_each(values.begin(), values.end(),
3412                     [this, &mng](const ValuePtr &e) { AddManagerForFuncGraphValue(e, mng); });
3413       return;
3414     }
3415     if (val->isa<ValueDictionary>()) {
3416       auto val_dict = val->cast<ValueDictionaryPtr>();
3417       const auto &values = val_dict->value();
3418       std::for_each(values.begin(), values.end(), [this, &mng](const std::pair<ValuePtr, ValuePtr> &pair) {
3419         // Key for value dictionary can not have function graph.
3420         AddManagerForFuncGraphValue(pair.second, mng);
3421       });
3422       return;
3423     }
3424     if (val->isa<FuncGraph>()) {
3425       auto val_fg = val->cast<FuncGraphPtr>();
3426       if (val_fg->manager() == nullptr) {
3427         mng->AddFuncGraph(val_fg);
3428         val_fg->set_manager(mng);
3429       }
3430     }
3431     return;
3432   }
3433 
CheckInterpretInput(const AbstractDictionaryPtr & abstract_dict,const std::string & script) const3434   void CheckInterpretInput(const AbstractDictionaryPtr &abstract_dict, const std::string &script) const {
3435     // Check whether this node should be interpretive executed.
3436     MS_EXCEPTION_IF_NULL(abstract_dict);
3437     const auto &elements = abstract_dict->elements();
3438     if (elements.empty()) {
3439       return;
3440     }
3441     for (const auto &element : elements) {
3442       const auto &name = element.first;
3443       const auto &local_abs = element.second;
3444       MS_EXCEPTION_IF_NULL(local_abs);
3445       const auto &local_abs_val = local_abs->BuildValue();
3446       MS_EXCEPTION_IF_NULL(local_abs_val);
3447       MS_EXCEPTION_IF_NULL(name);
3448       auto py_data_name = py::str(ValueToPyData(name->BuildValue()));
3449       bool has_python_obj = check_list_dict_inplace_ && fallback::HasObjInExtraInfoHolder(local_abs);
3450       if (local_abs_val->ContainsValueAny() || has_python_obj) {
3451         const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
3452         if (allow_fallback_runtime) {
3453           MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script
3454                        << "', the inputs should be constant, but found variable '" << py_data_name
3455                        << "' to be nonconstant. To convert to PyExecute() afterwards";
3456           non_const_err_ = true;
3457         } else {
3458           MS_EXCEPTION(ValueError) << "When handling script '" << script << " in graph mode"
3459                                    << "', the inputs should be constant, but found variable '" << py_data_name
3460                                    << "' to be nonconstant. Try to set jit_syntax_level to LAX.";
3461         }
3462       }
3463     }
3464   }
3465 
AddGlobalPythonFunction(const AbstractDictionaryPtr & global_dict,py::object * global_params_dict) const3466   void AddGlobalPythonFunction(const AbstractDictionaryPtr &global_dict, py::object *global_params_dict) const {
3467     MS_EXCEPTION_IF_NULL(global_dict);
3468     MS_EXCEPTION_IF_NULL(global_params_dict);
3469     const auto &global_dict_elements = global_dict->elements();
3470     for (const auto &element : global_dict_elements) {
3471       const auto &element_name = element.first;
3472       const auto &element_abs = element.second;
3473       MS_EXCEPTION_IF_NULL(element_name);
3474       MS_EXCEPTION_IF_NULL(element_abs);
3475       const auto &fn_py_obj = fallback::GetPyObjForFuncGraphAbstractClosure(element_abs);
3476       if (!py::isinstance<py::none>(fn_py_obj)) {
3477         (*global_params_dict)[ValueToPyData(element_name->BuildValue())] = fn_py_obj;
3478         MS_LOG(DEBUG) << "Found global python function object for " << element_name << ", add it to global dict.";
3479       }
3480     }
3481     return;
3482   }
3483 
MakeParameters(const AbstractBasePtrList & args_abs_list,const std::string & script) const3484   py::tuple MakeParameters(const AbstractBasePtrList &args_abs_list, const std::string &script) const {
3485     constexpr int params_size = 3;
3486     auto args_size = std::count_if(args_abs_list.begin(), args_abs_list.end(),
3487                                    [](const AbstractBasePtr &arg) -> bool { return !arg->isa<AbstractMonad>(); });
3488     if (params_size != args_size) {
3489       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected params_size: " << params_size
3490                                  << ", not equal to arguments.size: " << args_abs_list.size();
3491     }
3492     // The first argument is script string, ignore it.
3493     auto params = py::tuple(params_size - 1);
3494 
3495     // Make the global parameters.
3496     constexpr size_t global_index = 1;
3497     auto global_abs = args_abs_list[global_index];
3498     const py::object &global_params_dict = GetGlobalObject(global_abs);
3499     params[0] = global_params_dict;
3500 
3501     // Make the local parameters.
3502     constexpr size_t local_index = 2;
3503     auto local_dict = dyn_cast<AbstractDictionary>(args_abs_list[local_index]);  // Local parameters dict.
3504     if (local_dict == nullptr) {
3505       MS_EXCEPTION_IF_NULL(args_abs_list[local_index]);
3506       MS_LOG(INTERNAL_EXCEPTION) << "The third argument should be a dictionary, but got "
3507                                  << args_abs_list[local_index]->ToString();
3508     }
3509     auto filtered_local_dict = FilterParameters(local_dict);
3510     MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
3511                   << ", filtered_local_dict: " << filtered_local_dict->ToString();
3512     ValuePtr local_dict_value = filtered_local_dict->BuildValue();
3513     MS_EXCEPTION_IF_NULL(local_dict_value);
3514     py::dict local_params_dict = ReCheckLocalDict(filtered_local_dict);
3515     MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
3516                   << py::str(local_params_dict);
3517     params[1] = local_params_dict;
3518     CheckInterpretInput(filtered_local_dict, script);
3519 
3520     return params;
3521   }
3522 
ReCheckLocalDict(const AbstractDictionaryPtr & filtered_local_dict) const3523   py::dict ReCheckLocalDict(const AbstractDictionaryPtr &filtered_local_dict) const {
3524     const auto &keys_values = filtered_local_dict->elements();
3525     py::dict local_params_dict;
3526     for (auto &key_value : keys_values) {
3527       MS_EXCEPTION_IF_NULL(key_value.second);
3528       ValuePtr element_value = key_value.second->BuildValue();
3529       MS_EXCEPTION_IF_NULL(element_value);
3530       auto py_data = ValueToPyData(element_value);
3531       MS_EXCEPTION_IF_NULL(key_value.first);
3532       local_params_dict[ValueToPyData(key_value.first->BuildValue())] = py_data;
3533     }
3534     return local_params_dict;
3535   }
3536 
/**
 * Return a new dictionary abstract without the entries whose value is an AbstractFunction.
 *
 * Entry order is preserved; every value abstract must be non-null.
 *
 * @param abstract_dict Dictionary to filter; must not be null.
 * @return A freshly created AbstractDictionary holding the remaining entries.
 */
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
  MS_EXCEPTION_IF_NULL(abstract_dict);
  std::vector<AbstractElementPair> kept;
  for (const auto &item : abstract_dict->elements()) {
    MS_EXCEPTION_IF_NULL(item.second);
    if (item.second->isa<abstract::AbstractFunction>()) {
      continue;
    }
    kept.push_back(item);
  }
  return std::make_shared<AbstractDictionary>(kept);
}
3549 
HasConstArgAttr(const py::object & obj) const3550   bool HasConstArgAttr(const py::object &obj) const {
3551     constexpr char const_arg_attr[] = "const_arg";
3552     return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
3553   }
3554 
GetScriptStr(const AbstractBasePtr & abs) const3555   std::string GetScriptStr(const AbstractBasePtr &abs) const {
3556     // When PyInterpret node is built in python, the value of script abstract should be StringImm.
3557     // Otherwise, the value of script should be Script type.
3558     MS_EXCEPTION_IF_NULL(abs);
3559     ValuePtr value_track = abs->GetValueTrack();
3560     MS_EXCEPTION_IF_NULL(value_track);
3561     if (value_track->isa<parse::Script>()) {
3562       auto script_value_track = dyn_cast_ptr<parse::Script>(value_track);
3563       return script_value_track->script();
3564     }
3565     if (!value_track->isa<StringImm>()) {
3566       MS_INTERNAL_EXCEPTION(TypeError) << "Wrong script type for PyInterpret node, script abs: " << abs->ToString();
3567     }
3568     return value_track->ToString();
3569   }
3570 
GetGlobalObject(const AbstractBasePtr & abs) const3571   py::object GetGlobalObject(const AbstractBasePtr &abs) const {
3572     MS_EXCEPTION_IF_NULL(abs);
3573     if (!abs->isa<abstract::AbstractScalar>() && !abs->isa<abstract::AbstractDictionary>()) {
3574       MS_INTERNAL_EXCEPTION(TypeError) << "The second argument should be a scalar(InterpretedObject) or dictionary, "
3575                                        << "but got " << abs->ToString();
3576     }
3577     auto val = abs->BuildValue();
3578     MS_EXCEPTION_IF_NULL(val);
3579     AbstractDictionaryPtr global_dict = nullptr;
3580     // Some functions in global_dict are not used and will be released early,
3581     // resulting in the func_graph pointer in AbstractClosure being released.
3582     ValuePtr globals_converted_value = nullptr;
3583     py::object global_params_dict;
3584     if (abs->isa<abstract::AbstractDictionary>()) {
3585       global_dict = abs->cast<abstract::AbstractDictionaryPtr>();
3586       auto filtered_global_dict = FilterParameters(global_dict);
3587       global_params_dict = ValueToPyData(filtered_global_dict->BuildValue());
3588     } else {
3589       auto global_dict_interpreted = dyn_cast<parse::InterpretedObject>(val);
3590       MS_EXCEPTION_IF_NULL(global_dict_interpreted);
3591       const py::object &global_params_dict_obj = global_dict_interpreted->obj();
3592       if (!parse::ConvertData(global_params_dict_obj, &globals_converted_value)) {
3593         MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
3594       }
3595       MS_EXCEPTION_IF_NULL(globals_converted_value);
3596       // Filter global parameters dict.
3597       global_dict = dyn_cast<AbstractDictionary>(globals_converted_value->ToAbstract());
3598       if (global_dict == nullptr) {
3599         MS_LOG(INTERNAL_EXCEPTION) << "The second argument should be a dictionary, but got "
3600                                    << globals_converted_value->ToAbstract()->ToString();
3601       }
3602       auto filtered_global_dict = FilterParameters(global_dict);
3603       MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString()
3604                     << ", filtered_global_dict: " << filtered_global_dict->ToString();
3605       ValuePtr global_dict_value = filtered_global_dict->BuildValue();
3606       global_params_dict = ValueToPyData(global_dict_value);
3607     }
3608     // Add filtered global python function to global_params_dict.
3609     AddGlobalPythonFunction(global_dict, &global_params_dict);
3610     return global_params_dict;
3611   }
3612 
/**
 * Replace FuncGraph-closure local inputs with InterpretedObject value nodes.
 *
 * For a MakeTuple/MakeList CNode with a sequence abstract, each element input is converted recursively and a new
 * CNode (same primitive at input 0) is built. Otherwise, if the abstract maps to a Python function object, a value
 * node holding an InterpretedObject of it is returned; if not, `local_node` is returned unchanged. Dictionary
 * inputs are not handled yet. Newly created nodes inherit `local_node`'s debug info.
 *
 * @param local_node Node of a local value; must not be null.
 * @param local_abs Abstract of that node; must not be null.
 * @return The converted node, or `local_node` itself when nothing needs converting.
 */
AnfNodePtr ConvertLocalValueInputNode(const AnfNodePtr &local_node, const AbstractBasePtr &local_abs) const {
  MS_EXCEPTION_IF_NULL(local_node);
  MS_EXCEPTION_IF_NULL(local_abs);
  AnfNodePtr ret_node = nullptr;
  // Not consider AbstractDictionary scene yet.
  if (local_abs->isa<abstract::AbstractSequence>() &&
      IsOneOfPrimitiveCNode(local_node, {prim::kPrimMakeTuple, prim::kPrimMakeList})) {
    auto local_cnode = local_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(local_cnode);
    auto local_abs_seq = local_abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(local_abs_seq);
    // input(0) is the MakeTuple/MakeList primitive, so the number of element inputs is size() - 1.
    const size_t element_input_size = local_cnode->size() - 1;
    if (element_input_size != local_abs_seq->size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "For node: " << local_node->DebugString() << ", element input size is "
                                 << element_input_size << " and abstract size is " << local_abs_seq->size()
                                 << ". Size not matched.";
    }
    const auto &local_elements_abs = local_abs_seq->elements();
    AnfNodePtrList new_inputs;
    (void)new_inputs.emplace_back(local_cnode->input(0));
    for (size_t i = 1; i < local_cnode->size(); ++i) {
      (void)new_inputs.emplace_back(ConvertLocalValueInputNode(local_cnode->input(i), local_elements_abs[i - 1]));
    }
    auto fg = local_cnode->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    ret_node = fg->NewCNode(new_inputs);
  } else {
    auto py_obj = fallback::GetPyObjForFuncGraphAbstractClosure(local_abs);
    if (py::isinstance<py::none>(py_obj)) {
      return local_node;
    }
    ret_node = NewValueNode(std::make_shared<parse::InterpretedObject>(py_obj));
  }
  MS_EXCEPTION_IF_NULL(ret_node);
  ret_node->set_debug_info(local_node->debug_info());
  return ret_node;
}
3647 
ConvertPyInterpretNode(const AnfNodePtr & node,const AbstractBasePtrList & args_abs_list) const3648   AnfNodePtr ConvertPyInterpretNode(const AnfNodePtr &node, const AbstractBasePtrList &args_abs_list) const {
3649     MS_EXCEPTION_IF_NULL(node);
3650     // Ensure the same node only check local dict once.
3651     if (node->has_user_data(fallback::kLocalDictCheck) && *node->user_data<bool>(fallback::kLocalDictCheck)) {
3652       return nullptr;
3653     }
3654     node->set_user_data(fallback::kLocalDictCheck, std::make_shared<bool>(true));
3655     auto cnode = node->cast<CNodePtr>();
3656     MS_EXCEPTION_IF_NULL(cnode);
3657     constexpr size_t interpret_min_len = 4;
3658     if (cnode->size() < interpret_min_len) {
3659       MS_LOG(INTERNAL_EXCEPTION) << "The minimum input number for PyInterpret node should be " << interpret_min_len
3660                                  << " but got " << cnode->size();
3661     }
3662     if (args_abs_list.size() < interpret_min_len - 1) {
3663       MS_LOG(INTERNAL_EXCEPTION) << "The minimum number for PyInterpret input abstract should be "
3664                                  << interpret_min_len - 1 << " but got " << args_abs_list.size();
3665     }
3666     constexpr size_t local_index = 3;
3667     auto local_node = cnode->input(local_index);
3668     auto local_node_abs = args_abs_list[local_index - 1];
3669     MS_EXCEPTION_IF_NULL(local_node);
3670     MS_EXCEPTION_IF_NULL(local_node_abs);
3671     if (!IsPrimitiveCNode(local_node, prim::kPrimMakeDict)) {
3672       return nullptr;
3673     }
3674     auto local_cnode = local_node->cast<CNodePtr>();
3675     constexpr size_t make_dict_len = 3;
3676     if (local_cnode->size() != make_dict_len) {
3677       MS_LOG(INTERNAL_EXCEPTION) << "Make dict mode input size should be " << make_dict_len << " but got "
3678                                  << local_cnode->size();
3679     }
3680 
3681     const auto &check_abs_function = [](const AbstractBasePtr &input) {
3682       std::function<bool(const AbstractBasePtr &)> check_abs_function_inner;
3683       check_abs_function_inner = [&](const AbstractBasePtr &abs) {
3684         MS_EXCEPTION_IF_NULL(abs);
3685         if (abs->isa<abstract::AbstractSequence>()) {
3686           auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
3687           const auto &elements = abs_seq->elements();
3688           return std::any_of(elements.begin(), elements.end(),
3689                              [check_abs_function_inner](const AbstractBasePtr &inner_abs) {
3690                                return check_abs_function_inner(inner_abs);
3691                              });
3692         }
3693         if (abs->isa<abstract::AbstractDictionary>()) {
3694           auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
3695           const auto elements = abs_dict->elements();
3696           return std::any_of(elements.begin(), elements.end(),
3697                              [check_abs_function_inner](const abstract::AbstractElementPair &inner_abs) {
3698                                // Dictionary key can not be abstract function, no need to check.
3699                                return check_abs_function_inner(inner_abs.second);
3700                              });
3701         }
3702         return abs->isa<abstract::FuncGraphAbstractClosure>();
3703       };
3704       return check_abs_function_inner(input);
3705     };
3706 
3707     if (!check_abs_function(local_node_abs)) {
3708       return nullptr;
3709     }
3710     auto local_node_abs_dict = local_node_abs->cast<abstract::AbstractDictionaryPtr>();
3711     MS_EXCEPTION_IF_NULL(local_node_abs_dict);
3712     const auto &elements_pair = local_node_abs_dict->elements();
3713     std::vector<abstract::AbstractBasePtr> element_abs{};
3714     (void)std::transform(elements_pair.begin(), elements_pair.end(), std::back_inserter(element_abs),
3715                          [](const AbstractElementPair &pairs) { return pairs.second; });
3716     auto local_value_abs = std::make_shared<abstract::AbstractTuple>(element_abs);
3717     constexpr size_t value_index = 2;
3718     auto local_value_node = local_cnode->input(value_index);
3719     auto new_local_value_node = ConvertLocalValueInputNode(local_value_node, local_value_abs);
3720     std::vector<AnfNodePtr> new_local_node_inputs;
3721     for (size_t i = 0; i < value_index; ++i) {
3722       new_local_node_inputs.push_back(local_cnode->input(i));
3723     }
3724     new_local_node_inputs.push_back(new_local_value_node);
3725     auto fg = node->func_graph();
3726     MS_EXCEPTION_IF_NULL(fg);
3727     auto new_local_cnode = fg->NewCNode(new_local_node_inputs);
3728     new_local_cnode->set_debug_info(local_cnode->debug_info());
3729     std::vector<AnfNodePtr> new_cnode_inputs;
3730     for (size_t i = 0; i < local_index; ++i) {
3731       new_cnode_inputs.push_back(cnode->input(i));
3732     }
3733     new_cnode_inputs.push_back(new_local_cnode);
3734     for (size_t i = local_index + 1; i < cnode->size(); ++i) {
3735       new_cnode_inputs.push_back(cnode->input(i));
3736     }
3737     auto new_cnode = fg->NewCNode(new_cnode_inputs);
3738     new_cnode->set_debug_info(cnode->debug_info());
3739     new_cnode->set_user_data(fallback::kLocalDictCheck, std::make_shared<bool>(true));
3740     return new_cnode;
3741   }
3742 
3743  private:
3744   mutable bool non_const_err_{false};
3745   mutable bool check_list_dict_inplace_{false};
3746 };
3747 
3748 class EmbedEvaluator : public SymbolicPrimEvaluator {
3749  public:
EmbedEvaluator()3750   EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
3751   ~EmbedEvaluator() override = default;
3752   MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)3753   EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
3754     // arg: free variable to be embedded
3755     if (args_conf_list.size() != 1) {
3756       MS_LOG(INTERNAL_EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
3757     }
3758     auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
3759     MS_EXCEPTION_IF_NULL(node_conf);
3760     const auto &eval_result = node_conf->ObtainEvalResult();
3761     MS_EXCEPTION_IF_NULL(eval_result);
3762     AbstractBasePtr x = eval_result->abstract();
3763     x = SensitivityTransform(x);
3764     SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
3765     AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
3766     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
3767   }
3768 };
3769 
FindParameterNodeByString(const FuncGraphManagerPtr & manager,const std::string & name)3770 static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
3771   MS_EXCEPTION_IF_NULL(manager);
3772   auto root_g_set = manager->roots();
3773   if (root_g_set.size() != 1) {
3774     return nullptr;
3775   }
3776   const FuncGraphPtr &root_g = root_g_set.back();
3777   MS_EXCEPTION_IF_NULL(root_g);
3778   for (auto &param_node : root_g->parameters()) {
3779     auto param = param_node->cast<ParameterPtr>();
3780     if (param != nullptr && param->name() == name) {
3781       return param;
3782     }
3783   }
3784   return nullptr;
3785 }
3786 
3787 class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
3788  public:
RefToEmbedEvaluator()3789   RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
3790   ~RefToEmbedEvaluator() override = default;
3791   MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)3792   EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
3793     if (args_conf_list.size() != 1) {
3794       MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
3795       return nullptr;
3796     }
3797     static TypePtr type = std::make_shared<SymbolicKeyType>();
3798     auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
3799     if (node_conf == nullptr) {
3800       MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
3801       return nullptr;
3802     }
3803     const auto &eval_result = node_conf->ObtainEvalResult();
3804     MS_EXCEPTION_IF_NULL(eval_result);
3805     AbstractBasePtr abs = eval_result->abstract();
3806     MS_EXCEPTION_IF_NULL(abs);
3807     auto ref_key_value = abstract::GetRefKeyValue(abs);
3808     if (ref_key_value == nullptr) {
3809       MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
3810       return nullptr;
3811     }
3812     // Check if the input of RefEmbed is a weight parameter, if not, don't create the
3813     // specific SymbolicKey.
3814     // Notes: when different weight parameter have same type and shape passed as parameter to same funcgraph
3815     // which has RefToEmbed CNode, that funcgraph will not be specialized to different funcgraph, so the
3816     // RefToEmbed CNode in that funcgraph also should not be evaluated to specific SymbolicKey.
3817     // Only after that funcgrpah is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey.
3818     bool embed_is_weight = false;
3819     if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
3820       auto param = node_conf->node()->cast_ptr<Parameter>();
3821       MS_EXCEPTION_IF_NULL(param);
3822       embed_is_weight = param->has_default();
3823     }
3824     auto refkey = ref_key_value->cast_ptr<StringImm>();
3825     if (refkey == nullptr || !embed_is_weight) {
3826       auto res = std::make_shared<AbstractScalar>(type);
3827       return std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3828     }
3829 
3830     std::string name = refkey->value();
3831     MS_EXCEPTION_IF_NULL(node_conf->node());
3832     if (node_conf->node()->func_graph() == nullptr) {
3833       MS_LOG(INTERNAL_EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
3834     }
3835     const auto &manager = node_conf->node()->func_graph()->manager();
3836     auto node = FindParameterNodeByString(manager, name);
3837     if (node == nullptr) {
3838       MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
3839       return nullptr;
3840     }
3841     AbstractBasePtr x = SensitivityTransform(abs);
3842     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
3843     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
3844     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
3845   }
3846 };
3847 
3848 class GetAttrEvaluator : public TransitionPrimEvaluator {
3849  public:
GetAttrEvaluator()3850   GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
3851   ~GetAttrEvaluator() override = default;
3852   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)3853   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3854                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
3855     constexpr auto args_min_size = 2;
3856     constexpr auto args_max_size = 3;
3857     const auto args_size = args_abs_list.size();
3858     if (args_size != args_min_size && args_size != args_max_size) {
3859       MS_LOG(EXCEPTION) << "For Primitive GetAttr, the input size should be " << args_min_size << " or "
3860                         << args_max_size << ", but got size: " << args_size;
3861     }
3862     auto res_abstract = EvalUndeterminedArgs(args_abs_list);
3863     if (res_abstract != nullptr) {
3864       return res_abstract;
3865     }
3866 
3867     constexpr auto attr_index = 1;
3868     auto attr_abs = args_abs_list[attr_index];
3869     MS_EXCEPTION_IF_NULL(attr_abs);
3870     auto attr_abs_type = attr_abs->BuildType();
3871     MS_EXCEPTION_IF_NULL(attr_abs_type);
3872     auto type_id = attr_abs_type->type_id();
3873     if (type_id != TypeId::kObjectTypeString) {
3874       MS_EXCEPTION(TypeError) << "getattr(): attribute name must be string but got: " << TypeIdToString(type_id);
3875     }
3876     EvalResultPtr res = nullptr;
3877     if (bound_node() != nullptr) {
3878       TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
3879       res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3880     } else {
3881       res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3882     }
3883     // Don't lookup from cache, as different out_conf with same node but different context
3884     // may add different entry to anfnode_config_map, like getattr primitive.
3885     evaluator_cache_mgr_->SetValue(args_abs_list, res);
3886     return res;
3887   }
3888 };
3889 
3890 class SetAttrEvaluator : public TransitionPrimEvaluator {
3891  public:
SetAttrEvaluator()3892   SetAttrEvaluator() : TransitionPrimEvaluator("SetAttrEvaluator") {}
3893   ~SetAttrEvaluator() override = default;
3894   MS_DECLARE_PARENT(SetAttrEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3895   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
3896                          const AnfNodeConfigPtr &out_conf) override {
3897     constexpr size_t min_args_size = 3;
3898     constexpr size_t max_args_size = 4;
3899     size_t args_size = args_abs_list.size();
3900     if (args_size != min_args_size && args_size != max_args_size) {
3901       MS_LOG(EXCEPTION) << "For Primitive SetAttr, the input size should be " << min_args_size << " or "
3902                         << max_args_size << ", but got size: " << args_size;
3903     }
3904     auto res_abstract = EvalUndeterminedArgs(args_abs_list);
3905     if (res_abstract != nullptr) {
3906       return res_abstract;
3907     }
3908 
3909     return InterpretSetAttrNode(args_abs_list, out_conf);
3910   }
3911 };
3912 
3913 class ResolveEvaluator : public TransitionPrimEvaluator {
3914  public:
ResolveEvaluator()3915   ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
3916   ~ResolveEvaluator() override = default;
3917   MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)3918   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3919                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
3920     constexpr auto resolve_args_size = 2;       // (namespace, symbol)
3921     constexpr auto resolve_with_args_size = 3;  // (namespace, symbol, arguments)
3922     // Inputs: namespace, symbol
3923     if (args_abs_list.size() != resolve_args_size && args_abs_list.size() != resolve_with_args_size) {
3924       MS_LOG(EXCEPTION) << "Expected args_abs_list size is 2 or 3, but has size: " << args_abs_list.size();
3925     }
3926     EvalResultPtr res = nullptr;
3927     if (bound_node() != nullptr) {
3928       TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
3929       res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3930     } else {
3931       res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3932     }
3933     return res;
3934   }
3935 };
3936 
3937 class CreateInstanceEvaluator : public TransitionPrimEvaluator {
3938  public:
CreateInstanceEvaluator()3939   CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
3940   ~CreateInstanceEvaluator() override = default;
3941   MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3942   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
3943                          const AnfNodeConfigPtr &out_conf) override {
3944     // Check the type parameter.
3945     if (args_abs_list.empty()) {
3946       MS_LOG(INTERNAL_EXCEPTION) << "'args_abs_list' should not be empty";
3947     }
3948     constexpr size_t class_index = 0;
3949     auto class_obj = GetPythonObject(args_abs_list[class_index]);
3950     py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
3951     std::string class_name =
3952       python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MS_CLASS_NAME, class_obj).cast<std::string>();
3953     // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
3954     auto params = py::tuple(args_abs_list.size() - 1);
3955     bool is_prim_variable = GetParameters(args_abs_list, class_obj, class_name, &params);
3956     if (is_prim_variable) {
3957       return CreatePrimitiveInstanceWithVariableArgs(args_abs_list, class_name, class_obj, engine, out_conf);
3958     }
3959     // Create class instance.
3960     auto obj = parse::data_converter::CreatePythonObject(class_obj, params);
3961     if (py::isinstance<py::none>(obj)) {
3962       MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_obj)
3963                         << "` failed, only support to create 'Cell', 'Primitive' or "
3964                         << "user-defined Class decorated with 'jit_class'.";
3965     }
3966 
3967     // Process the object.
3968     MS_EXCEPTION_IF_NULL(out_conf->node());
3969     TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
3970     ValuePtr converted_res = nullptr;
3971     bool converted = parse::ConvertData(obj, &converted_res, true);
3972     if (!converted) {
3973       MS_LOG(INTERNAL_EXCEPTION) << "Convert the python object failed";
3974     }
3975     MS_EXCEPTION_IF_NULL(converted_res);
3976     // To check isolated side effect for the func graph who returns constant.
3977     HandleSideEffect(obj, converted_res, engine, out_conf);
3978 
3979     if (converted_res->isa<FuncGraph>()) {
3980       AddToManager(engine, converted_res->cast<FuncGraphPtr>());
3981     }
3982     AbstractBasePtr res = ToAbstract(converted_res, AnalysisContext::DummyContext(), out_conf);
3983     auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3984     evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3985     return infer_result;
3986   }
3987 
GetPythonObject(const AbstractBasePtr & arg_class_type) const3988   py::object GetPythonObject(const AbstractBasePtr &arg_class_type) const {
3989     MS_EXCEPTION_IF_NULL(arg_class_type);
3990     TypePtr type = arg_class_type->GetTypeTrack();
3991     MS_EXCEPTION_IF_NULL(type);
3992     if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) {
3993       MS_LOG(EXCEPTION)
3994         << "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got "
3995         << type->ToString();
3996     }
3997 
3998     ValuePtr value_track = arg_class_type->GetValueTrack();
3999     MS_EXCEPTION_IF_NULL(value_track);
4000     auto type_obj = dyn_cast_ptr<parse::PyObjectWrapper>(value_track);
4001     if (type_obj == nullptr) {
4002       MS_LOG(INTERNAL_EXCEPTION) << "Cast value failed, not PyObjectWrapper: " << value_track->ToString() << ".";
4003     }
4004     if (!type_obj->isa<parse::ClassType>() && !type_obj->isa<parse::MsClassObject>()) {
4005       MS_LOG(EXCEPTION)
4006         << "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got "
4007         << type_obj->ToString() << ".";
4008     }
4009     MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
4010     return type_obj->obj();
4011   }
4012 
HandleSideEffect(const py::object & obj,const ValuePtr & converted_res,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf) const4013   void HandleSideEffect(const py::object &obj, const ValuePtr &converted_res, const AnalysisEnginePtr &engine,
4014                         const AnfNodeConfigPtr &out_conf) const {
4015     if (engine->check_side_effect()) {
4016       MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", converted_res: " << converted_res->ToString();
4017       auto prim = GetValueWithoutDoSignature(converted_res)->cast<PrimitivePtr>();
4018       if (prim != nullptr) {
4019         auto effect_info = GetPrimEffectInfo(prim);
4020         if (effect_info.memory || effect_info.io) {
4021           const auto &cnode = dyn_cast<CNode>(out_conf->node());
4022           MS_EXCEPTION_IF_NULL(cnode);
4023           MS_EXCEPTION_IF_NULL(out_conf->func_graph());
4024           MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString()
4025                         << ", func_graph: " << out_conf->func_graph()->ToString();
4026           cnode->set_has_side_effect_node(true);
4027           out_conf->func_graph()->set_has_side_effect_node(true);
4028         }
4029       }
4030     }
4031   }
4032 
GetParameters(const AbstractBasePtrList & args_abs_list,const py::object & obj,const std::string & cls_name,py::tuple * params)4033   bool GetParameters(const AbstractBasePtrList &args_abs_list, const py::object &obj, const std::string &cls_name,
4034                      py::tuple *params) {
4035     auto params_size = (*params).size();
4036     for (size_t i = 0; i < params_size; i++) {
4037       // Only support the Scalar parameters type. Bypass class type by offset with 1.
4038       auto arg = args_abs_list[i + 1];
4039       MS_EXCEPTION_IF_NULL(arg);
4040       auto param_value = arg->BuildValue();
4041       MS_EXCEPTION_IF_NULL(param_value);
4042       if (param_value->ContainsValueAny() && !arg->isa<AbstractFunction>()) {
4043         // If obj is a Primitive class and has variable arguments, just return and go through another process.
4044         if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG) && mindspore::ops::GetOpDef(cls_name) != nullptr) {
4045           return true;
4046         }
4047         MS_EXCEPTION(TypeError) << "When creating an instance of '" << cls_name
4048                                 << "', all arguments are required to be constants, but input " << i
4049                                 << " is a variable, which is " << arg->ToString() << ".";
4050       }
4051       py::object param = ValueToPyData(param_value);
4052       (*params)[i] = param;
4053     }
4054     return false;
4055   }
4056 
CreatePrimitiveInstanceWithVariableArgs(const AbstractBasePtrList & args_abs_list,const std::string & cls_name,const py::object & cls_obj,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf) const4057   EvalResultPtr CreatePrimitiveInstanceWithVariableArgs(const AbstractBasePtrList &args_abs_list,
4058                                                         const std::string &cls_name, const py::object &cls_obj,
4059                                                         const AnalysisEnginePtr &engine,
4060                                                         const AnfNodeConfigPtr &out_conf) const {
4061     // Create Primitive instance with variable arguments.
4062     auto prim_func = std::make_shared<Primitive>(cls_name);
4063     auto do_trans_prim_func = std::make_shared<prim::DoTransPrimitiveFunction>(prim_func);
4064     // Ignore the first input which is ClassType.
4065     AbstractBasePtrList partial_args_abs_list(args_abs_list.begin() + 1, args_abs_list.end());
4066     do_trans_prim_func->set_given_init_size(partial_args_abs_list.size());
4067     auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(do_trans_prim_func);
4068     auto ret_val =
4069       std::make_shared<abstract::PartialAbstractClosure>(func_ptr, partial_args_abs_list, out_conf->node());
4070     ret_val->set_need_append_to_end(true);
4071     return std::make_shared<EvalResult>(ret_val, std::make_shared<AttrValueMap>());
4072   }
4073 };
4074 
4075 class PartialEvaluator : public Evaluator {
4076  public:
PartialEvaluator()4077   PartialEvaluator() : Evaluator("PartialEvaluator") {}
4078   ~PartialEvaluator() override = default;
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)4079   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
4080                     const AnfNodeConfigPtr &out_conf) override {
4081     if (args_conf_list.size() == 0) {
4082       MS_LOG(INTERNAL_EXCEPTION) << "Args size should be greater than 0";
4083     }
4084     MS_EXCEPTION_IF_NULL(out_conf);
4085     MS_EXCEPTION_IF_NULL(out_conf->node());
4086     MS_EXCEPTION_IF_NULL(args_conf_list[0]);
4087     const auto &arg0_eval_result = args_conf_list[0]->ObtainEvalResult();
4088     MS_EXCEPTION_IF_NULL(arg0_eval_result);
4089     auto arg0_value = arg0_eval_result->abstract();
4090     MS_EXCEPTION_IF_NULL(arg0_value);
4091     AbstractBasePtrList args_abs_list{arg0_value};
4092     auto cnode = out_conf->node()->cast<CNodePtr>();
4093     MS_EXCEPTION_IF_NULL(cnode);
4094 
4095     // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
4096     if (arg0_value->isa<AbstractProblem>()) {
4097       MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
4098       const auto &value_problem = arg0_value->GetValueTrack()->cast<ValueProblemPtr>();
4099       auto res = std::make_shared<AbstractProblem>(value_problem, out_conf->node());
4100       MS_LOG(DEBUG) << "AbstractProblem for node: " << out_conf->node()->DebugString()
4101                     << " as func is: " << arg0_value->ToString();
4102       auto eval_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
4103       evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
4104       return eval_result;
4105     }
4106     auto func = CheckArg<AbstractFunction>("partial", args_abs_list, 0);
4107     // Sometimes, node[0] in out_conf becomes phi0;
4108     if (func->isa<PrimitiveAbstractClosure>()) {
4109       auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func);
4110       MS_EXCEPTION_IF_NULL(prim_func);
4111       MS_EXCEPTION_IF_NULL(prim_func->prim());
4112       if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
4113         auto do_signature_prim = dyn_cast_ptr<prim::DoSignaturePrimitive>(prim_func->prim());
4114         return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
4115       }
4116     }
4117 
4118     (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_abs_list),
4119                          [](const ConfigPtr &config) -> AbstractBasePtr {
4120                            MS_EXCEPTION_IF_NULL(config);
4121                            const auto &eval_result = config->ObtainEvalResult();
4122                            MS_EXCEPTION_IF_NULL(eval_result);
4123                            return eval_result->abstract();
4124                          });
4125     AbstractBasePtrList args(args_abs_list.begin() + 1, args_abs_list.end());
4126 
4127     if (cnode->size() != (args_conf_list.size() + 1)) {
4128       MS_LOG(INTERNAL_EXCEPTION) << "Out_conf node: " << cnode->DebugString()
4129                                  << ", args_conf_list: " << mindspore::ToString(args_conf_list);
4130     }
4131     AbstractFuncAtomPtrList partial_funcs_list;
4132     auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
4133       auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
4134       partial_funcs_list.push_back(new_func);
4135     };
4136     func->Visit(build_partial);
4137 
4138     auto res = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
4139     auto eval_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
4140     MS_LOG(DEBUG) << "args_abs_list: " << args_abs_list << ", eval_result: " << eval_result->abstract()->ToString();
4141     evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
4142     return eval_result;
4143   }
4144 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)4145   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
4146     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
4147   }
4148 
HandleDoSignature(const AnalysisEnginePtr & engine,const ValuePtr & signature_value,const AnfNodeConfigPtr & out_conf) const4149   EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
4150                                   const AnfNodeConfigPtr &out_conf) const {
4151     MS_EXCEPTION_IF_NULL(engine);
4152     MS_EXCEPTION_IF_NULL(out_conf);
4153     MS_EXCEPTION_IF_NULL(out_conf->node());
4154     auto cnode = out_conf->node()->cast_ptr<CNode>();
4155     MS_EXCEPTION_IF_NULL(cnode);
4156 
4157     ScopeGuard scope_guard(out_conf->node()->scope());
4158     TraceGuard trace_guard(std::make_shared<TraceDoSignature>(out_conf->node()->debug_info()));
4159     auto new_nodes_inputs = cnode->weak_inputs();
4160     auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
4161     auto new_sig_node = NewValueNode(new_signature_value);
4162     new_nodes_inputs[1] = AnfNodeWeakPtr(new_sig_node);
4163     FuncGraphPtr func_graph = cnode->func_graph();
4164     MS_EXCEPTION_IF_NULL(func_graph);
4165     CNodePtr new_cnode = func_graph->NewCNodeWeak(std::move(new_nodes_inputs));
4166     AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
4167     return engine->ForwardConfig(out_conf, fn_conf);
4168   }
4169 };
4170 
4171 class RaiseEvaluator : public TransitionPrimEvaluator {
4172  public:
RaiseEvaluator()4173   RaiseEvaluator() : TransitionPrimEvaluator("RaiseEvaluator") {}
4174   ~RaiseEvaluator() override = default;
4175   MS_DECLARE_PARENT(RaiseEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4176   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4177                          const AnfNodeConfigPtr &out_conf) override {
4178     MS_EXCEPTION_IF_NULL(out_conf);
4179     // Handle for DDE.
4180     for (size_t i = 0; i < args_abs_list.size(); ++i) {
4181       MS_EXCEPTION_IF_NULL(args_abs_list[i]);
4182       if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
4183         MS_LOG(DEBUG) << "Primitive \'Raise\' is consuming tuple/list arguments[" << i
4184                       << "]: " << args_abs_list[i]->ToString();
4185         SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
4186       }
4187     }
4188     auto node = out_conf->node();
4189     MS_EXCEPTION_IF_NULL(node);
4190     auto cur_graph = node->func_graph();
4191     MS_EXCEPTION_IF_NULL(cur_graph);
4192     if (args_abs_list.empty()) {
4193       // Process raise.
4194       MS_LOG(INTERNAL_EXCEPTION) << "No active exception to reraise.";
4195     }
4196     const auto &cnode = node->cast<CNodePtr>();
4197     MS_EXCEPTION_IF_NULL(cnode);
4198 
4199     // Return Any directly if meet variable condition or content.
4200     bool is_variable_condition = raiseutils::HasVariableCondition(cur_graph);
4201     bool has_variable = false;
4202     size_t index_begin = 2;
4203     size_t index_end = cnode->size() - 1;
4204     for (size_t index = index_begin; index < cnode->size(); ++index) {
4205       if (raiseutils::CheckHasVariable(args_abs_list[index - 1])) {
4206         has_variable = true;
4207         break;
4208       }
4209     }
4210     if (is_variable_condition || has_variable) {
4211       AbstractBasePtr res = std::make_shared<AbstractNegligible>();
4212       cnode->set_has_side_effect_node(true);
4213       cur_graph->set_has_side_effect_node(true);
4214       auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
4215       evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
4216       return infer_result;
4217     }
4218 
4219     // Continue to handle raise in compile time.
4220     std::shared_ptr<raiseutils::KeyValueInfo> key_value = std::make_shared<raiseutils::KeyValueInfo>();
4221     std::string exception_type =
4222       raiseutils::GetExceptionType(args_abs_list[0], cnode->input(index_end), key_value, false);
4223     std::string exception_string;
4224     // Process raise ValueError()
4225     if (args_abs_list.size() == 1) {
4226       RaiseConstant(exception_type);
4227     }
4228     // Processed in units of nodes. Raise ValueError(xxxx)
4229     for (size_t index = index_begin; index < cnode->size() - 1; ++index) {
4230       const auto input = cnode->input(index);
4231       auto input_abs = args_abs_list[index - 1];
4232       MS_EXCEPTION_IF_NULL(input_abs);
4233       const bool need_symbol = raiseutils::CheckNeedSymbol(input_abs);
4234       if (need_symbol) {
4235         exception_string += "'";
4236       }
4237       bool need_comma = !IsPrimitiveCNode(input, prim::kPrimMakeTuple);
4238       exception_string += raiseutils::GetExceptionString(input_abs, input, key_value, need_symbol, need_comma);
4239       if (need_symbol) {
4240         exception_string += "'";
4241       }
4242       constexpr auto end_index = 2;
4243       if (index < cnode->size() - end_index) {
4244         exception_string += ", ";
4245       }
4246     }
4247     bool need_out_symbol = cnode->size() > 4;
4248     if (need_out_symbol) {
4249       exception_string = "(" + exception_string + ")";
4250     }
4251     RaiseConstant(exception_type, exception_string);
4252     MS_LOG(EXCEPTION) << "Constant raise is not raising exception correctly";
4253   }
4254 
4255  private:
RaiseConstant(const std::string & type,const std::string & exception_string="")4256   void RaiseConstant(const std::string &type, const std::string &exception_string = "") {
4257     auto iter = exception_types_map.find(type);
4258     if (iter == exception_types_map.end()) {
4259       MS_LOG(EXCEPTION) << "Unsupported exception type: " << type
4260                         << ". Raise only support some Python standard exception types: "
4261                         << SupportedExceptionsToString();
4262     }
4263     ExceptionType error_type = iter->second;
4264     if (exception_string.empty()) {
4265       MS_EXCEPTION(error_type);
4266     } else {
4267       MS_EXCEPTION(error_type) << exception_string;
4268     }
4269   }
4270 };
4271 
4272 class WithEnterEvaluator : public TransitionPrimEvaluator {
4273  public:
WithEnterEvaluator()4274   WithEnterEvaluator() : TransitionPrimEvaluator("WithEnterEvaluator") {}
4275   ~WithEnterEvaluator() override = default;
4276   MS_DECLARE_PARENT(WithEnterEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4277   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4278                          const AnfNodeConfigPtr &out_conf) override {
4279     MS_EXCEPTION_IF_NULL(out_conf);
4280     MS_EXCEPTION_IF_NULL(out_conf->node());
4281     auto node = out_conf->node()->cast<CNodePtr>();
4282     MS_EXCEPTION_IF_NULL(node);
4283     auto cur_graph = node->func_graph();
4284     MS_EXCEPTION_IF_NULL(cur_graph);
4285 
4286     if (args_abs_list.size() != 1) {
4287       MS_LOG(INTERNAL_EXCEPTION) << "The enter node has wrong input." << node->debug_info();
4288     }
4289 
4290     // Check class object
4291     constexpr size_t cls_index = 0;
4292     MS_EXCEPTION_IF_NULL(args_abs_list[cls_index]);
4293     auto cls_val = args_abs_list[cls_index]->BuildValue();
4294     MS_EXCEPTION_IF_NULL(cls_val);
4295     auto value_obj = cls_val->cast<parse::MsClassObjectPtr>();
4296     if (value_obj == nullptr) {
4297       MS_EXCEPTION(TypeError) << "Only support jit_class instance, but got " << cls_val->ToString();
4298     }
4299     auto cls_obj = value_obj->obj();
4300 
4301     const std::string call_func = "__enter__";
4302     if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
4303       MS_LOG(EXCEPTION) << value_obj->name() << " has no " << call_func << " function, please check the code.";
4304     }
4305     py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
4306     FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
4307     if (call_func_graph == nullptr) {
4308       MS_LOG(INTERNAL_EXCEPTION) << "Parse python object " << call_func << " failed.";
4309     }
4310     FuncGraphManagerPtr manager = engine->func_graph_manager();
4311     MS_EXCEPTION_IF_NULL(manager);
4312     manager->AddFuncGraph(call_func_graph);
4313 
4314     std::vector<AnfNodePtr> enter_inputs{NewValueNode(call_func_graph)};
4315     //  __enter__(self)
4316     auto call_enter_node = cur_graph->NewCNodeInOrder(enter_inputs);
4317     // Continue to eval call_enter_node.
4318     AnfNodeConfigPtr fn_conf = engine->MakeConfig(call_enter_node, out_conf->context(), out_conf->func_graph());
4319     return engine->ForwardConfig(out_conf, fn_conf);
4320   }
4321 };
4322 
4323 class WithExitEvaluator : public TransitionPrimEvaluator {
4324  public:
WithExitEvaluator()4325   WithExitEvaluator() : TransitionPrimEvaluator("WithExitEvaluator") {}
4326   ~WithExitEvaluator() override = default;
4327   MS_DECLARE_PARENT(WithExitEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4328   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4329                          const AnfNodeConfigPtr &out_conf) override {
4330     MS_EXCEPTION_IF_NULL(out_conf);
4331     MS_EXCEPTION_IF_NULL(out_conf->node());
4332     auto node = out_conf->node()->cast<CNodePtr>();
4333     MS_EXCEPTION_IF_NULL(node);
4334     auto cur_graph = node->func_graph();
4335     MS_EXCEPTION_IF_NULL(cur_graph);
4336 
4337     if (args_abs_list.size() != 1) {
4338       MS_LOG(INTERNAL_EXCEPTION) << "The exit node has wrong input." << node->debug_info();
4339     }
4340 
4341     // Check class object
4342     constexpr size_t cls_index = 0;
4343     MS_EXCEPTION_IF_NULL(args_abs_list[cls_index]);
4344     auto cls_val = args_abs_list[cls_index]->BuildValue();
4345     MS_EXCEPTION_IF_NULL(cls_val);
4346     auto value_obj = cls_val->cast<parse::MsClassObjectPtr>();
4347     if (value_obj == nullptr) {
4348       MS_EXCEPTION(TypeError) << "Only support jit_class instance, but got " << cls_val->ToString();
4349     }
4350     auto cls_obj = value_obj->obj();
4351 
4352     const std::string call_func = "__exit__";
4353     if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
4354       MS_LOG(EXCEPTION) << value_obj->name() << " has no " << call_func << " function, please check the code.";
4355     }
4356     py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
4357     FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
4358     if (call_func_graph == nullptr) {
4359       MS_LOG(INTERNAL_EXCEPTION) << "Parse python object " << call_func << " failed.";
4360     }
4361     FuncGraphManagerPtr manager = engine->func_graph_manager();
4362     MS_EXCEPTION_IF_NULL(manager);
4363     manager->AddFuncGraph(call_func_graph);
4364 
4365     std::vector<AnfNodePtr> exit_inputs{NewValueNode(call_func_graph)};
4366     constexpr size_t arg_size = 3;
4367     //  __exit__(self, type, value, trace)
4368     for (size_t i = 0; i < arg_size; ++i) {
4369       (void)exit_inputs.emplace_back(NewValueNode(kNone));
4370     }
4371     auto call_exit_node = cur_graph->NewCNodeInOrder(exit_inputs);
4372     // Continue to eval call_exit_node.
4373     AnfNodeConfigPtr fn_conf = engine->MakeConfig(call_exit_node, out_conf->context(), out_conf->func_graph());
4374     return engine->ForwardConfig(out_conf, fn_conf);
4375   }
4376 };
4377 
4378 class CondEvaluator : public TransitionPrimEvaluator {
4379  public:
CondEvaluator()4380   CondEvaluator() : TransitionPrimEvaluator("CondEvaluator") {}
4381   ~CondEvaluator() override = default;
4382   MS_DECLARE_PARENT(CondEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4383   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4384                          const AnfNodeConfigPtr &out_conf) override {
4385     auto res_abstract = EvalUndeterminedArgs(args_abs_list);
4386     if (res_abstract != nullptr) {
4387       return res_abstract;
4388     }
4389     MS_EXCEPTION_IF_NULL(out_conf);
4390     MS_EXCEPTION_IF_NULL(out_conf->node());
4391     auto cnode = out_conf->node()->cast<CNodePtr>();
4392     MS_EXCEPTION_IF_NULL(cnode);
4393     auto cur_graph = cnode->func_graph();
4394     MS_EXCEPTION_IF_NULL(cur_graph);
4395     constexpr size_t input_size = 2;
4396     if (args_abs_list.size() != input_size) {
4397       MS_LOG(INTERNAL_EXCEPTION) << "The input size to cond node should be " << std::to_string(input_size)
4398                                  << ", but got " << std::to_string(args_abs_list.size());
4399     }
4400 
4401     AnfNodePtr new_node = nullptr;
4402     constexpr size_t cond_abs_index = 0;
4403     constexpr size_t cond_input_index = 1;
4404     constexpr size_t flag_input_index = 2;
4405     auto cond_abs = args_abs_list[cond_abs_index];
4406     auto cond_node = cnode->input(cond_input_index);
4407     auto flag_node = cnode->input(flag_input_index);
4408     MS_EXCEPTION_IF_NULL(cond_abs);
4409     if (cond_abs->isa<AbstractAny>()) {
4410       // If the input to cond node is AbstractAny, genenrate pyexecute node 'bool(input)';
4411       const auto script_str = std::make_shared<StringImm>("bool(__input__)");
4412 
4413       const auto input_str = std::make_shared<StringImm>("__input__");
4414       std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
4415       (void)key_value_names_list.emplace_back(NewValueNode(input_str));
4416       const auto key_value_name_tuple = cur_graph->NewCNode(key_value_names_list);
4417 
4418       std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple), cond_node};
4419       const auto key_value_tuple = cur_graph->NewCNode(key_value_list);
4420       new_node =
4421         fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
4422       fallback::SetRealType<AnfNode, Type>(new_node, std::make_shared<TensorType>(kBool));
4423       fallback::SetRealShape(new_node, std::make_shared<abstract::Shape>(std::vector<int64_t>{Shape::kShapeDimAny}));
4424     } else if (cond_abs->isa<AbstractTensor>() && is_while_condition(flag_node)) {
4425       // When the condition of while is a tensor, do not use standard_method.tensor_bool
4426       // to avoid turning the tensor into scalar to cause a loop.
4427       constexpr auto operations_module = "mindspore.ops.operations";
4428       auto cast_op = python_adapter::GetPyFn(operations_module, kCastOpName)();
4429       auto cast_node = NewValueNode(parse::data_converter::PyDataToValue(cast_op));
4430       auto type_node = NewValueNode(TypeIdToType(kNumberTypeBool));
4431       new_node = cur_graph->NewCNodeInOrder({cast_node, cond_node, type_node});
4432       new_node->set_debug_info(cnode->debug_info());
4433     } else if (cond_abs->isa<AbstractFunction>()) {
4434       auto abs = std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
4435       return std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
4436     } else {
4437       // The logic of truth value testing:
4438       //   1. If the object has __bool__ attribute, call __bool__()
4439       //   2. Else if the object has __len__ attribute, call __len__()
4440       //   3. Else return true.
4441       auto cond_type = cond_abs->BuildType();
4442       MS_EXCEPTION_IF_NULL(cond_type);
4443       auto cond_type_id = cond_type->type_id();
4444       constexpr auto bool_attr_str = "__bool__";
4445       constexpr auto len_attr_str = "__len__";
4446       ValuePtr prim_func;
4447       if (!pipeline::Resource::GetMethodPtr(cond_type_id, bool_attr_str).empty()) {
4448         prim_func = prim::GetPythonOps(parse::NAMED_PRIMITIVE_BOOL);
4449       } else if (!pipeline::Resource::GetMethodPtr(cond_type_id, len_attr_str).empty()) {
4450         prim_func = prim::GetPythonOps(parse::NAMED_PRIMITIVE_CHECK_LEN);
4451       } else {
4452         prim_func = prim::GetPythonOps(parse::NAMED_PRIMITIVE_REAL_BOOL);
4453       }
4454       auto prim_fg = dyn_cast<FuncGraph>(prim_func);
4455       MS_EXCEPTION_IF_NULL(prim_fg);
4456       auto mng = cur_graph->manager();
4457       MS_EXCEPTION_IF_NULL(mng);
4458       prim_fg->set_manager(mng);
4459       new_node = cur_graph->NewCNodeInOrder({NewValueNode(prim_fg), cond_node});
4460     }
4461     AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
4462     return engine->ForwardConfig(out_conf, fn_conf);
4463   }
4464 
is_while_condition(const AnfNodePtr & flag_node) const4465   bool is_while_condition(const AnfNodePtr &flag_node) const {
4466     MS_EXCEPTION_IF_NULL(flag_node);
4467     auto vnode = GetValueNode(flag_node);
4468     MS_EXCEPTION_IF_NULL(vnode);
4469     return GetValue<bool>(vnode);
4470   }
4471 };
4472 
4473 struct PrimitiveImplInferValue {
4474   PrimitiveImpl impl_;        // implement function of primitive
4475   bool eval_value_;           // whether evaluate value
4476   TypePtr specify_out_type_;  // whether specify return type
4477   bool in_white_list_;        // true if this Primitive in white list, else false.
4478 };
4479 
4480 using PrimitiveToImplMap = mindspore::HashMap<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
GetUniformPrimitiveToImplMap()4481 PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
4482   using R = PrimitiveToImplMap::mapped_type;
4483   static PrimitiveToImplMap uniform_prim_implement_map{
4484     {prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}},
4485     {prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}},
4486     {prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}},
4487     {prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}},
4488     {prim::kPrimBitXor, R{prim::BitXor, true, nullptr, true}},
4489     {prim::kPrimBitLeftShift, R{prim::BitLeftShift, true, nullptr, true}},
4490     {prim::kPrimBitRightShift, R{prim::BitRightShift, true, nullptr, true}},
4491     {prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared<Bool>(), true}},
4492     {prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared<Bool>(), true}},
4493     {prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared<Bool>(), true}},
4494     {prim::kPrimBoolOr, R{prim::BoolOr, true, std::make_shared<Bool>(), true}},
4495     {prim::kPrimStringConcat, R{prim::StringConcat, true, nullptr, true}},
4496     {prim::kPrimStringEq, R{prim::StringEq, true, std::make_shared<Bool>(), true}},
4497     {prim::kPrimStringLt, R{prim::StringLt, true, std::make_shared<Bool>(), true}},
4498     {prim::kPrimStringGt, R{prim::StringGt, true, std::make_shared<Bool>(), true}},
4499     {prim::kPrimStringLe, R{prim::StringLe, true, std::make_shared<Bool>(), true}},
4500     {prim::kPrimStringGe, R{prim::StringGe, true, std::make_shared<Bool>(), true}},
4501     {prim::kPrimStringNot, R{prim::StringNot, true, std::make_shared<Bool>(), true}},
4502     {prim::kPrimStringIn, R{prim::StringIn, true, std::make_shared<Bool>(), true}},
4503   };
4504   return uniform_prim_implement_map;
4505 }
4506 
4507 PrimEvaluatorMap prim_evaluator_constructors = PrimEvaluatorMap();
4508 std::mutex PrimEvaluatorConstructorMutex;
4509 
InitPrimEvaluatorConstructors()4510 void InitPrimEvaluatorConstructors() {
4511   PrimEvaluatorMap &constructor = prim_evaluator_constructors;
4512 
4513   for (const auto &iter : GetPrimitiveInferMap()) {
4514     constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
4515   }
4516 
4517   for (const auto &iter : GetUniformPrimitiveToImplMap()) {
4518     constructor[iter.first] =
4519       InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
4520   }
4521   constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
4522   constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
4523   constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
4524   constructor[prim::kPrimSetAttr] = std::make_shared<SetAttrEvaluator>();
4525   constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
4526   constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
4527   constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
4528   constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
4529   constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();
4530   constructor[prim::kPrimMakeList] = std::make_shared<MakeListEvaluator>();
4531   constructor[prim::kPrimRaise] = std::make_shared<RaiseEvaluator>();
4532   constructor[prim::kPrimWithEnter] = std::make_shared<WithEnterEvaluator>();
4533   constructor[prim::kPrimWithExit] = std::make_shared<WithExitEvaluator>();
4534   constructor[prim::kPrimCond] = std::make_shared<CondEvaluator>();
4535 }
4536 
InitBuiltinPrimEvaluatorConstructors()4537 void InitBuiltinPrimEvaluatorConstructors() {
4538   PrimEvaluatorMap &constructor = prim_evaluator_constructors;
4539   constructor[prim::kPrimInnerAbs] = std::make_shared<InnerAbsEvaluator>();
4540   constructor[prim::kPrimInnerRound] = std::make_shared<InnerRoundEvaluator>();
4541 }
4542 }  // namespace
4543 
ClearPrimEvaluatorMap()4544 void ClearPrimEvaluatorMap() {
4545   prim_evaluator_constructors.clear();
4546   GetFrontendPrimitiveInferMapPtr()->clear();
4547   GetUniformPrimitiveToImplMap().clear();
4548 }
4549 
IsInWhiteList(const PrimitivePtr & primitive)4550 bool IsInWhiteList(const PrimitivePtr &primitive) {
4551   MS_EXCEPTION_IF_NULL(primitive);
4552 
4553   using WhiteList = mindspore::HashMap<PrimitivePtr, bool, PrimitiveHasher, PrimitiveEqual>;
4554 
4555   static WhiteList whitelist = {{prim::kPrimPartial, true}};
4556   auto iter = whitelist.find(primitive);
4557   if (iter != whitelist.end()) {
4558     return iter->second;
4559   }
4560 
4561   auto found = abstract::GetFrontendPrimitiveInferImpl(primitive);
4562   if (found.has_value()) {
4563     auto infer = found.value();
4564     return infer.IsInWhiteList();
4565   }
4566 
4567   auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
4568   if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
4569     return uni_iter->second.in_white_list_;
4570   }
4571 
4572   return true;
4573 }
4574 
GetPrimEvaluatorConstructors()4575 PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
4576   PrimEvaluatorMap &constructor = prim_evaluator_constructors;
4577   if (!constructor.empty()) {
4578     return constructor;
4579   }
4580   std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
4581   if (constructor.empty()) {
4582     InitPrimEvaluatorConstructors();
4583     InitBuiltinPrimEvaluatorConstructors();
4584   }
4585 
4586   return constructor;
4587 }
4588 
4589 namespace {
IsSubtypeTuple(const AbstractBasePtr x,const TypePtr model)4590 bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
4591   MS_EXCEPTION_IF_NULL(x);
4592   MS_EXCEPTION_IF_NULL(model);
4593   auto x_tuple = dyn_cast_ptr<AbstractTuple>(x);
4594   auto model_tuple = dyn_cast_ptr<Tuple>(model);
4595 
4596   if (x_tuple == nullptr || model_tuple == nullptr) {
4597     return false;
4598   }
4599 
4600   if (model->IsGeneric()) {
4601     return true;
4602   }
4603 
4604   if (x_tuple->size() != model_tuple->size()) {
4605     return false;
4606   }
4607 
4608   for (size_t i = 0; i < x_tuple->size(); i++) {
4609     bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
4610     if (!is_subtype) {
4611       return false;
4612     }
4613   }
4614   return true;
4615 }
4616 
IsSubtypeArray(const AbstractBasePtr x,const TypePtr model)4617 bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
4618   MS_EXCEPTION_IF_NULL(x);
4619   MS_EXCEPTION_IF_NULL(model);
4620   auto x_tensor = dyn_cast_ptr<AbstractTensor>(x);
4621   auto model_tensor = dyn_cast_ptr<TensorType>(model);
4622 
4623   if (x_tensor == nullptr || model_tensor == nullptr) {
4624     return false;
4625   }
4626 
4627   if (model->IsGeneric()) {
4628     return true;
4629   }
4630 
4631   return IsSubtype(x_tensor->element(), model_tensor->element());
4632 }
4633 
IsSubtypeList(const AbstractBasePtr x,const TypePtr model)4634 bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
4635   MS_EXCEPTION_IF_NULL(x);
4636   MS_EXCEPTION_IF_NULL(model);
4637   auto x_list = dyn_cast_ptr<AbstractList>(x);
4638   auto model_list = dyn_cast_ptr<List>(model);
4639 
4640   if (x_list == nullptr || model_list == nullptr) {
4641     return false;
4642   }
4643 
4644   if (model->IsGeneric()) {
4645     return true;
4646   }
4647 
4648   if (x_list->size() != model_list->size()) {
4649     return false;
4650   }
4651 
4652   bool is_subtype = true;
4653   for (size_t i = 0; i < x_list->size(); i++) {
4654     is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
4655     if (!is_subtype) {
4656       return false;
4657     }
4658   }
4659   return is_subtype;
4660 }
4661 
IsSubtypeScalar(const AbstractBasePtr x,const TypePtr model)4662 inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
4663   MS_EXCEPTION_IF_NULL(x);
4664   MS_EXCEPTION_IF_NULL(model);
4665   if (dyn_cast_ptr<AbstractScalar>(x) == nullptr) {
4666     return false;
4667   }
4668   auto &x_type = x->GetTypeTrack();
4669   return IsSubType(x_type, model);
4670 }
4671 }  // namespace
4672 
IsSubtype(const AbstractBasePtr x,const TypePtr model)4673 bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
4674   MS_EXCEPTION_IF_NULL(x);
4675   MS_EXCEPTION_IF_NULL(model);
4676   TypeId model_typeid = model->type_id();
4677   switch (model_typeid) {
4678     case kMetaTypeObject:
4679       return true;
4680     case kObjectTypeTuple:
4681       return IsSubtypeTuple(x, model);
4682     case kObjectTypeTensorType:
4683       return IsSubtypeArray(x, model);
4684     case kObjectTypeList:
4685       return IsSubtypeList(x, model);
4686     default:
4687       if (IsSubType(model, std::make_shared<Number>())) {
4688         return IsSubtypeScalar(x, model);
4689       }
4690       MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
4691   }
4692 }
4693 
GetPrimitiveInitArgs(const PrimitivePyPtr & prim_py,const ops::OpDef * op_def)4694 AnfNodePtrList GetPrimitiveInitArgs(const PrimitivePyPtr &prim_py, const ops::OpDef *op_def) {
4695   MS_EXCEPTION_IF_NULL(prim_py);
4696   MS_EXCEPTION_IF_NULL(op_def);
4697 
4698   std::vector<AnfNodePtr> prim_init_arg_nodes;
4699   auto obj = prim_py->GetPyObj();
4700 
4701   for (const auto &op_arg : op_def->args_) {
4702     if (op_arg.as_init_arg_) {
4703       auto arg_name = op_arg.arg_name_;
4704       py::object arg_value = py::getattr(obj, common::SafeCStr(arg_name));
4705       ValuePtr converted_ret = nullptr;
4706       bool converted = parse::ConvertData(arg_value, &converted_ret);
4707       if (!converted) {
4708         MS_LOG(INTERNAL_EXCEPTION) << "Cannot convert initialization arg: (" << arg_name << ": " << py::str(arg_value)
4709                                    << ") in Primitive '" << prim_py->name() << "'.";
4710       }
4711       (void)prim_init_arg_nodes.emplace_back(NewValueNode(converted_ret));
4712     }
4713   }
4714   MS_LOG(DEBUG) << "PrimitivePy " << prim_py->name() << " has " << prim_init_arg_nodes.size() << " __init__() args";
4715   return prim_init_arg_nodes;
4716 }
4717 
GeneratePrimitiveCNode(const PrimitivePtr & primitive,const ops::OpDef * op_def,const FuncGraphPtr & graph,const AnfNodePtrList & init_args_nodes,const AnfNodePtrList & call_args_nodes,const std::function<AbstractBasePtr (const AnfNodePtr &)> & eval_func)4718 CNodePtr GeneratePrimitiveCNode(const PrimitivePtr &primitive, const ops::OpDef *op_def, const FuncGraphPtr &graph,
4719                                 const AnfNodePtrList &init_args_nodes, const AnfNodePtrList &call_args_nodes,
4720                                 const std::function<AbstractBasePtr(const AnfNodePtr &)> &eval_func) {
4721   MS_EXCEPTION_IF_NULL(primitive);
4722   MS_EXCEPTION_IF_NULL(op_def);
4723 
4724   auto args_pair = std::make_pair(init_args_nodes, call_args_nodes);
4725 
4726   // Follow the implementations in PrimitiveArgsToInputsEvaluator, convert to base Primitive, and is_preprocessed=true
4727   auto new_prim = std::make_shared<Primitive>(*primitive);
4728   auto new_cnode = CheckAndConvertPrimitiveArgs(new_prim, graph, args_pair, eval_func, true);
4729 
4730   MS_LOG(INFO) << "Convert primitive args: " << primitive->name() << ", new node: " << new_cnode->DebugString();
4731   return new_cnode;
4732 }
4733 }  // namespace abstract
4734 }  // namespace mindspore
4735