• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
20 #include <algorithm>
21 #include <memory>
22 #include <mutex>
23 #include <set>
24 #include <unordered_set>
25 #include <utility>
26 #include <atomic>
27 #include "mindspore/core/ops/structure_ops.h"
28 #include "mindspore/core/ops/sequence_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "abstract/abstract_value.h"
31 #include "pipeline/jit/ps/fallback.h"
32 #include "pipeline/jit/ps/action.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "pipeline/jit/ps/static_analysis/prim.h"
35 #include "frontend/operator/ops.h"
36 #include "utils/ms_exception.h"
37 #include "utils/compile_config.h"
38 #include "ir/func_graph_cloner.h"
39 #include "pipeline/jit/ps/static_analysis/evaluator.h"
40 #include "pipeline/jit/ps/debug/trace.h"
41 #include "include/common/fallback.h"
42 #include "include/common/debug/anf_ir_dump.h"
43 #include "include/common/utils/convert_utils_py.h"
44 #include "include/common/utils/python_adapter.h"
45 #include "pipeline/jit/ps/static_analysis/async_eval_result.h"
46 #include "frontend/operator/ops_front_infer_function.h"
47 #include "frontend/operator/composite/composite.h"
48 #include "ops/op_def.h"
49 
50 namespace mindspore {
51 namespace abstract {
52 // Record current depth of function call stack, including `stack_frame_depth`.
53 std::atomic<size_t> function_call_depth;
54 // Record current depth of stack frames call.
55 std::atomic<size_t> stack_frame_depth;
56 
ResetFunctionCallDepth()57 void ResetFunctionCallDepth() { function_call_depth = 0; }
58 
IncreaseFunctionCallDepth()59 void IncreaseFunctionCallDepth() { (void)(++function_call_depth); }
60 
DecreaseFunctionCallDepth()61 void DecreaseFunctionCallDepth() {
62   if (function_call_depth == 0) {
63     MS_LOG(INTERNAL_EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
64   }
65   function_call_depth--;
66 }
67 
FunctionCallDepth()68 size_t FunctionCallDepth() { return function_call_depth; }
69 
ResetStackFrameDepth()70 void ResetStackFrameDepth() { stack_frame_depth = 0; }
71 
IncreaseStackFrameDepth()72 void IncreaseStackFrameDepth() { (void)(++stack_frame_depth); }
73 
DecreaseStackFrameDepth()74 void DecreaseStackFrameDepth() {
75   if (stack_frame_depth == 0) {
76     MS_LOG(INTERNAL_EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
77   }
78   stack_frame_depth--;
79 }
80 
StackFrameDepth()81 size_t StackFrameDepth() { return stack_frame_depth; }
82 
83 namespace {
ExecEvaluator(EvaluatorPtr eval,AnalysisEnginePtr engine,ConfigPtrList args_conf_list,AnfNodeConfigPtr out_conf,std::string thread_id,AsyncAbstractPtr async_result_branch,AsyncAbstractPtr async_result_main,AsyncInferTaskPtr async_task,trace::TraceGraphEvalStack graph_evals,trace::TraceCNodeEvalStack trace_c_node_evals)84 void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
85                    std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
86                    AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
87                    trace::TraceCNodeEvalStack trace_c_node_evals) {
88   MS_EXCEPTION_IF_NULL(eval);
89   MS_EXCEPTION_IF_NULL(async_task);
90   AnalysisSchedule::set_thread_id(thread_id);
91   // Restore trace stack for dump stack when there is exception.
92   trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
93   trace_c_node_evals.clear();
94   trace::TraceGraphEvalStackPrepare(graph_evals);
95   graph_evals.clear();
96 
97   try {
98     // Wait for Signal to run
99     MS_LOG(DEBUG) << async_task.get() << "  " << eval->ToString() << " waiting.";
100     (void)async_task->GetResult();
101     MS_LOG(DEBUG) << async_task.get() << "  " << eval->ToString() << " running.";
102 
103     // Acquire GIL for eval to callback python.
104     EvalResultPtr result;
105     {
106       MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " begin.";
107       py::gil_scoped_acquire py_guard;
108       result = eval->Run(engine, args_conf_list, out_conf);
109     }
110     MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " end.";
111     MS_EXCEPTION_IF_NULL(result);
112     MS_EXCEPTION_IF_NULL(result->abstract());
113 
114     // Check the branch value to be compatible with the other branch value.
115     AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
116     // Broaden the result of switch(c,t,f)()
117     auto broaden_abstract = result->abstract()->Broaden();
118 
119     MS_EXCEPTION_IF_NULL(async_result_branch);
120     MS_EXCEPTION_IF_NULL(async_result_main);
121     // Notify the thread of waiting for branch value and the main thread to continue.
122     async_result_branch->set_result(broaden_abstract);
123     async_result_main->set_result(broaden_abstract);
124     MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
125                   << " asyncResult address = " << async_result_branch.get();
126     if (async_result_branch->TryGetResult()) {
127       MS_LOG(DEBUG) << "value = " << (async_result_branch->TryGetResult())->ToString();
128     } else {
129       MS_LOG(DEBUG) << "value = null.";
130     }
131   } catch (const std::exception &ex) {
132     MS_EXCEPTION_IF_NULL(out_conf->node());
133     MS_LOG(INFO) << GetInferThread() << "Eval node: " << out_conf->node()->ToString() << "  " << eval->ToString()
134                  << " threw exception: " << ex.what();
135     AnalysisSchedule::GetInstance().HandleException(ex);
136   }
137   trace::ClearTraceStack();
138   ClearThreadLocal();
139   MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited.";
140   // Thread number will be drop when thread exits.
141   AnalysisSchedule::GetInstance().DecreaseThreadCount();
142 }
143 
// Rebuilds the tuple/list abstract `orig_abs` so that every function element at any nesting depth also carries one
// AsyncAbstractFuncAtom per entry of `pending_async_abstract_list`. Each atom is tagged with the element's index path:
// `index` (the path to `orig_abs`) followed by the element's position. Non-function, non-sequence elements are kept
// unchanged. Raises an internal exception if `orig_abs` (or a nested sequence) is not an AbstractTuple/AbstractList.
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
                                              const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
                                              const std::vector<std::size_t> &index) {
  MS_EXCEPTION_IF_NULL(orig_abs);
  auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
  if (sequence_abs != nullptr) {
    const auto &orig_elements = sequence_abs->elements();
    AbstractBasePtrList new_elements;
    new_elements.reserve(orig_elements.size());
    for (size_t i = 0; i < orig_elements.size(); ++i) {
      MS_EXCEPTION_IF_NULL(orig_elements[i]);
      const auto &element = orig_elements[i];
      if (element->isa<AbstractFuncAtom>()) {
        // The element's index path is the same for every pending branch, so build it once per element.
        std::vector<std::size_t> element_index(index);
        element_index.push_back(i);
        AbstractFuncAtomPtrList abs_func_list{element->cast<AbstractFuncAtomPtr>()};
        for (const auto &pending_async_abstract : pending_async_abstract_list) {
          auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract, element_index);
          abs_func_list.push_back(async_func);
        }
        new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
      } else if (element->isa<AbstractSequence>()) {
        std::vector<std::size_t> element_index(index);
        element_index.push_back(i);
        new_elements.push_back(BuildAsyncAbstractRecursively(element, pending_async_abstract_list, element_index));
      } else {
        new_elements.push_back(element);
      }
    }
    static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
    AbstractBasePtr new_abs;
    if (orig_abs->isa<AbstractTuple>()) {
      new_abs = std::make_shared<AbstractTuple>(
        new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
    } else if (orig_abs->isa<AbstractList>()) {
      new_abs = std::make_shared<AbstractList>(
        new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
    } else {
      MS_LOG(INTERNAL_EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
    }
    return new_abs;
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
}
186 
BuildPossibleSpecs(const AbstractBasePtr & first_result,const std::vector<AsyncAbstractPtr> & branch_async_abstract_list,AbstractBasePtrList * out_abs_list)187 void BuildPossibleSpecs(const AbstractBasePtr &first_result,
188                         const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
189                         AbstractBasePtrList *out_abs_list) {
190   MS_EXCEPTION_IF_NULL(out_abs_list);
191   MS_EXCEPTION_IF_NULL(first_result);
192   std::vector<AsyncAbstractPtr> pending_async_abstract_list;
193   std::size_t len = branch_async_abstract_list.size();
194 
195   for (size_t i = 0; i < len; ++i) {
196     AbstractBasePtr result;
197     MS_EXCEPTION_IF_NULL(branch_async_abstract_list[i]);
198     if (enable_waiting_branch_eval()) {
199       result = branch_async_abstract_list[i]->GetResult();
200     } else {
201       result = branch_async_abstract_list[i]->TryGetResult();
202     }
203 
204     if (result) {
205       if (result->isa<AsyncAbstractFuncAtom>()) {
206         branch_async_abstract_list[i]->ClearPossibleResult();
207         pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
208         MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
209                       << branch_async_abstract_list[i]->ToString();
210       } else {
211         out_abs_list->push_back(result);
212       }
213     } else {
214       pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
215       MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
216                     << branch_async_abstract_list[i]->ToString();
217     }
218   }
219 
220   if (first_result->isa<AbstractFunction>()) {
221     for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
222       auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
223       out_abs_list->push_back(async_func);
224       MS_LOG(DEBUG) << "out_abs_list add: " << async_func.get() << "_" << async_func->ToString();
225     }
226   } else if (first_result->isa<AbstractSequence>()) {
227     const auto &new_first_result =
228       BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector<size_t>());
229     MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
230                   << ", new: " << new_first_result->ToString();
231     std::replace_if(
232       out_abs_list->begin(), out_abs_list->end(),
233       [first_result](const auto &element) { return element == first_result; }, new_first_result);
234   } else {
235     MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
236   }
237 }
238 
ConvertToPyInterpretCall(const CNodePtr & cnode,const AnfNodeConfigPtr & conf,const AnfNodePtr & func_node=nullptr)239 EvalResultPtr ConvertToPyInterpretCall(const CNodePtr &cnode, const AnfNodeConfigPtr &conf,
240                                        const AnfNodePtr &func_node = nullptr) {
241   auto fg = cnode->func_graph();
242   MS_EXCEPTION_IF_NULL(fg);
243   auto out_node = conf->node();
244   MS_EXCEPTION_IF_NULL(out_node);
245   std::stringstream script_buffer;
246   AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
247   AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
248 
249   // Handle call function
250   const std::string call_func_str = "__call_func_str__";
251   constexpr size_t call_func_index = 0;
252   script_buffer << call_func_str << "(";
253   (void)local_key_inputs.emplace_back(NewValueNode(call_func_str));
254   if (func_node == nullptr) {
255     (void)local_value_inputs.emplace_back(cnode->input(call_func_index));
256   } else {
257     (void)local_value_inputs.emplace_back(func_node);
258   }
259 
260   // Handle inputs.
261   const std::string call_prefix = "__input_";
262   for (size_t i = 1; i < cnode->size(); ++i) {
263     auto cur_node = cnode->input(i);
264     if (IsPrimitiveCNode(cur_node, prim::kPrimMakeKeywordArg)) {
265       const std::string value_cur_str = call_prefix + "_value_" + std::to_string(i - 1) + "__";
266       constexpr size_t key_inputs_index = 1;
267       constexpr size_t value_inputs_index = 2;
268       constexpr size_t expect_inputs_size = 3;
269       if (cur_node->cast<CNodePtr>()->size() != expect_inputs_size) {
270         MS_LOG(INTERNAL_EXCEPTION) << "The make_keyword_arg node should have " << expect_inputs_size
271                                    << " inputs, but got " << cnode->size();
272       }
273       auto key_node = cur_node->cast<CNodePtr>()->input(key_inputs_index);
274       if (!IsValueNode<StringImm>(key_node)) {
275         MS_LOG(INTERNAL_EXCEPTION) << "The key in make_keyword args must be string, but got "
276                                    << key_node->DebugString();
277       }
278       auto key_string = GetValue<std::string>(GetValueNode(key_node));
279       std::string key_value_str = key_string + "=" + value_cur_str;
280       (void)local_key_inputs.emplace_back(NewValueNode(value_cur_str));
281       script_buffer << key_value_str << ",";
282       auto value_node = cur_node->cast<CNodePtr>()->input(value_inputs_index);
283       (void)local_value_inputs.emplace_back(value_node);
284     } else {
285       const std::string cur_str = call_prefix + std::to_string(i - 1) + "__";
286       script_buffer << cur_str << ",";
287       (void)local_key_inputs.emplace_back(NewValueNode(cur_str));
288       (void)local_value_inputs.emplace_back(cur_node);
289     }
290   }
291   script_buffer << ")";
292   const auto &script = script_buffer.str();
293   auto local_key_node = fg->NewCNode(local_key_inputs);
294   auto local_value_node = fg->NewCNode(local_value_inputs);
295   auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
296   auto obj_call_node =
297     fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node, out_node->debug_info());
298   MS_LOG(DEBUG) << "Created obj_call_node: " << obj_call_node->DebugString();
299   AnalysisEnginePtr eng = conf->engine();
300   MS_EXCEPTION_IF_NULL(eng);
301   AnfNodeConfigPtr fn_conf = eng->MakeConfig(obj_call_node, conf->context(), conf->func_graph());
302   return eng->ForwardConfig(conf, fn_conf);
303 }
304 
ParsePyObjToFunc(const py::object & py_fn,const CNodePtr & cnode,const AnfNodeConfigPtr & conf)305 EvalResultPtr ParsePyObjToFunc(const py::object &py_fn, const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
306   FuncGraphPtr func_fg = nullptr;
307   {
308     MS_LOG_TRY_CATCH_SCOPE;
309     func_fg = parse::ParsePythonCode(py_fn);
310   }
311   if (func_fg != nullptr) {
312     auto fg = cnode->func_graph();
313     MS_EXCEPTION_IF_NULL(fg);
314     func_fg->set_manager(fg->manager());
315 
316     std::vector<AnfNodePtr> new_cnode_inputs;
317     (void)new_cnode_inputs.emplace_back(NewValueNode(func_fg));
318     for (std::size_t i = 1; i < cnode->size(); ++i) {
319       (void)new_cnode_inputs.emplace_back(cnode->input(i));
320     }
321     auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
322     new_cnode->set_debug_info(cnode->debug_info());
323 
324     AnalysisEnginePtr eng = conf->engine();
325     MS_EXCEPTION_IF_NULL(eng);
326     AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
327     return eng->ForwardConfig(conf, fn_conf);
328   } else {
329     return ConvertToPyInterpretCall(cnode, conf);
330   }
331 }
332 
GetClassName(const py::object & cls_obj)333 std::string GetClassName(const py::object &cls_obj) {
334   if (py::hasattr(cls_obj, "__class__")) {
335     return py::getattr(py::getattr(cls_obj, "__class__"), "__name__").cast<py::str>();
336   }
337   return py::getattr(cls_obj, "__name__").cast<py::str>();
338 }
339 
ConvertCallPyObjCallFunc(const CNodePtr & cnode,const AbstractBasePtr & abs,const AnfNodeConfigPtr & conf)340 EvalResultPtr ConvertCallPyObjCallFunc(const CNodePtr &cnode, const AbstractBasePtr &abs,
341                                        const AnfNodeConfigPtr &conf) {
342   MS_EXCEPTION_IF_NULL(cnode);
343   MS_EXCEPTION_IF_NULL(abs);
344   auto val = abs->BuildValue();
345   MS_EXCEPTION_IF_NULL(val);
346   auto warp_obj = dyn_cast_ptr<parse::PyObjectWrapper>(val);
347   MS_EXCEPTION_IF_NULL(warp_obj);
348   py::object cls_obj = warp_obj->obj();
349   auto class_name = GetClassName(cls_obj);
350   py::object call_obj = py::none();
351   const std::string construct_func_name = "construct";
352   if (py::hasattr(cls_obj, common::SafeCStr(construct_func_name)) && py::isinstance<Cell>(cls_obj)) {
353     call_obj = py::getattr(cls_obj, common::SafeCStr(construct_func_name));
354   } else {
355     const std::string call_func_name = "__call__";
356     if (py::hasattr(cls_obj, common::SafeCStr(call_func_name))) {
357       call_obj = py::getattr(cls_obj, common::SafeCStr(call_func_name));
358     }
359   }
360   if (py::isinstance<py::none>(call_obj)) {
361     MS_EXCEPTION(ValueError) << class_name << "is not a callable object";
362   }
363   return ParsePyObjToFunc(call_obj, cnode, conf);
364 }
365 
ConvertMsClassObjToFunc(const CNodePtr & cnode,const AbstractBasePtr & abs,const AnfNodeConfigPtr & conf)366 EvalResultPtr ConvertMsClassObjToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
367   MS_EXCEPTION_IF_NULL(cnode);
368   MS_EXCEPTION_IF_NULL(abs);
369   auto val = abs->BuildValue();
370   MS_EXCEPTION_IF_NULL(val);
371   auto class_val = dyn_cast_ptr<parse::MsClassObject>(val);
372   MS_EXCEPTION_IF_NULL(class_val);
373   py::object cls_obj = class_val->obj();
374   const std::string call_func_name = "__call__";
375   if (!py::hasattr(cls_obj, common::SafeCStr(call_func_name))) {
376     MS_EXCEPTION(ValueError) << class_val->name() << " has no " << call_func_name
377                              << " function, please check the code.";
378   }
379   py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func_name));
380   return ParsePyObjToFunc(call_obj, cnode, conf);
381 }
382 
CheckFuncSideEffect(const AbstractFunctionPtr & func)383 bool CheckFuncSideEffect(const AbstractFunctionPtr &func) {
384   // Check if func graph contains isolated side-effect, and sync.
385   auto func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(func);
386   if (func_graph_abs != nullptr) {
387     MS_EXCEPTION_IF_NULL(func_graph_abs->func_graph());
388     return func_graph_abs->func_graph()->has_side_effect_node();
389   } else {
390     auto meta_func_graph_abs = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
391     if (meta_func_graph_abs != nullptr) {
392       MS_EXCEPTION_IF_NULL(meta_func_graph_abs->meta_func_graph());
393       return meta_func_graph_abs->meta_func_graph()->has_side_effect_node();
394     }
395     if (func->isa<abstract::PartialAbstractClosure>()) {
396       const auto &abstract_partial_func = func->cast<abstract::PartialAbstractClosurePtr>();
397       const auto &abstract_fn = abstract_partial_func->fn();
398       MS_EXCEPTION_IF_NULL(abstract_fn);
399       return CheckFuncSideEffect(abstract_fn);
400     }
401   }
402   return false;
403 }
404 
// Resolves an AsyncAbstractFuncAtom to the single atom it stands for (via GetUnique()); any other atom is
// returned as-is. Raises if the resolved function is not an AbstractFuncAtom.
AbstractFuncAtomPtr GetRealFuncAtom(const AbstractFuncAtomPtr &possible_func) {
  MS_EXCEPTION_IF_NULL(possible_func);
  auto async_abs_func = possible_func->cast_ptr<AsyncAbstractFuncAtom>();
  if (async_abs_func == nullptr) {
    return possible_func;
  }
  AbstractFuncAtomPtr real_atom = dyn_cast<AbstractFuncAtom>(async_abs_func->GetUnique());
  MS_EXCEPTION_IF_NULL(real_atom);
  MS_LOG(DEBUG) << "Real AsyncAbstractFuncAtom is: " << real_atom->ToString();
  return real_atom;
}
417 
418 template <typename T>
Match(const ValuePtr & prim)419 bool Match(const ValuePtr &prim) {
420   return prim->isa<T>();
421 }
422 using MetaFgMatchFunc = std::function<bool(const ValuePtr &)>;
423 
MatchMetaFg(const ValuePtr & prim)424 bool MatchMetaFg(const ValuePtr &prim) {
425   static const std::vector<MetaFgMatchFunc> meta_fg_ops{
426     Match<prim::GradOperation>,
427     Match<prim::VmapOperation>,
428     Match<prim::Shard>,
429   };
430   return std::any_of(meta_fg_ops.cbegin(), meta_fg_ops.cend(),
431                      [&prim](const MetaFgMatchFunc &match_func) { return match_func(prim); });
432 }
433 
RemoveSequenceFromOrderList(const CNodePtr & origin_cnode)434 void RemoveSequenceFromOrderList(const CNodePtr &origin_cnode) {
435   constexpr size_t sequence_input_pos = 2;
436   if (origin_cnode->size() <= sequence_input_pos) {
437     return;
438   }
439   auto seq_node = origin_cnode->input(sequence_input_pos);
440   auto prim = GetCNodePrimitiveWithoutDoSignature(seq_node);
441   if (prim != nullptr &&
442       (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList))) {
443     auto seq_cnode = dyn_cast<CNode>(seq_node);
444     MS_EXCEPTION_IF_NULL(seq_cnode);
445     seq_cnode->func_graph()->EraseUnusedNodeInOrder(seq_cnode);
446   }
447 }
448 
GetEvalResult(const AnfNodePtr & node,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & conf)449 AbstractBasePtr GetEvalResult(const AnfNodePtr &node, const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &conf) {
450   AnfNodeConfigPtr func_conf = std::make_shared<AnfNodeConfig>(engine, node, conf->context(), conf->func_graph());
451   auto possible_func_eval_result = func_conf->ObtainEvalResult();
452   MS_EXCEPTION_IF_NULL(possible_func_eval_result);
453   return possible_func_eval_result->abstract();
454 }
455 
IsFuncGraphAbstractInput(const CNodePtr & origin_cnode,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & conf)456 bool IsFuncGraphAbstractInput(const CNodePtr &origin_cnode, const AnalysisEnginePtr &engine,
457                               const AnfNodeConfigPtr &conf) {
458   auto possible_func = GetEvalResult(origin_cnode->input(1), engine, conf);
459   if (possible_func == nullptr || !possible_func->isa<FuncGraphAbstractClosure>()) {
460     return false;
461   }
462   // Check whether it is a high order scene such as GradOperation(GradOperation(net)), the meta_unpack_prepare doesn't
463   // handle before. To handle this later.
464   if (!origin_cnode->input(1)->isa<CNode>()) {
465     return true;
466   }
467   auto input1_cnode = origin_cnode->input(1)->cast<CNodePtr>();
468   auto possible_prim = GetEvalResult(input1_cnode->input(0), engine, conf);
469   if (possible_prim == nullptr || !possible_prim->isa<PrimitiveAbstractClosure>()) {
470     return true;
471   }
472   auto value = GetValueWithoutDoSignature(possible_prim->cast<PrimitiveAbstractClosurePtr>()->prim());
473   return !MatchMetaFg(value);
474 }
475 
476 // {{meta_fg, g, w}, Ys} => {{meta_fg, {UnpackGraph, g, Ys}, w}, Ys}
477 // {UnpackCall, {meta_fg, g, w}, Ys} => {UnpackCall, {meta_fg, {UnpackGraph, g, Ys}, w}, Ys}
InsertUnpackGraph(const CNodePtr & origin_cnode,const ValuePtr & value,const AnfNodeConfigPtr & conf,const AnalysisEnginePtr & engine)478 AnfNodePtr InsertUnpackGraph(const CNodePtr &origin_cnode, const ValuePtr &value, const AnfNodeConfigPtr &conf,
479                              const AnalysisEnginePtr &engine) {
480   // origin_cnode is {meta_fg, g, ...}
481   const size_t inputs_x_minimum_size = 2;
482   if (origin_cnode->size() < inputs_x_minimum_size) {
483     return nullptr;
484   }
485 
486   if (value == nullptr || !MatchMetaFg(value)) {
487     return nullptr;
488   }
489 
490   if (!IsFuncGraphAbstractInput(origin_cnode, engine, conf)) {
491     return nullptr;
492   }
493 
494   auto manager = conf->engine()->func_graph_manager();
495   MS_EXCEPTION_IF_NULL(manager);
496   auto node_users = manager->node_users()[origin_cnode];
497   if (node_users.empty()) {
498     return nullptr;
499   }
500   auto meta_user = node_users.begin()->first->cast<CNodePtr>();
501   MS_EXCEPTION_IF_NULL(meta_user);
502   int index = node_users.begin()->second;
503   if (index != 0 && index != 1) {
504     return nullptr;
505   }
506 
507   bool need_unpack_args = false;
508   if (index == 1) {
509     // The meta_fg user node should be UnpackCall.
510     auto input0_value = GetValueWithoutDoSignature(meta_user->input(0));
511     if (input0_value == nullptr || !input0_value->isa<prim::UnpackCall>()) {
512       return nullptr;
513     }
514     need_unpack_args = true;
515   }
516   // Create UnpackGraph node.
517   bool sens_param = false;
518   if (value->isa<prim::GradOperation>()) {
519     sens_param = value->cast<prim::GradOperationPtr>()->sens_param();
520     RemoveSequenceFromOrderList(origin_cnode);
521   }
522   auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, need_unpack_args);
523   std::vector<AnfNodePtr> unpack_graph_inputs{NewValueNode(unpack_graph), origin_cnode->input(1)};
524   const auto &meta_user_inputs = meta_user->inputs();
525   constexpr int64_t unpack_inputs_begin_index = 2;
526   int64_t offset = (need_unpack_args ? unpack_inputs_begin_index : 1);
527   (void)std::transform(meta_user_inputs.begin() + offset, meta_user_inputs.end(),
528                        std::back_inserter(unpack_graph_inputs),
529                        [](const AnfNodePtr &node) -> AnfNodePtr { return node; });
530   auto fg = origin_cnode->func_graph();
531   MS_EXCEPTION_IF_NULL(fg);
532   auto unpack_graph_node = fg->NewCNodeBefore(meta_user, unpack_graph_inputs);
533   // Create new call_node.
534   auto new_cnode_inputs = origin_cnode->inputs();
535   new_cnode_inputs[1] = unpack_graph_node;
536   auto new_cnode = fg->NewCNodeBefore(meta_user, new_cnode_inputs);
537   return new_cnode;
538 }
539 }  // namespace
540 
Get(const PrimitivePtr & prim,const AbstractBasePtrList & args) const541 EvalResultPtr PrimitiveEvalCache::Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
542   MS_EXCEPTION_IF_NULL(prim);
543   std::lock_guard<std::mutex> guard(mutex_);
544   auto cache_iter = prim_cache_.find(prim->name());
545   if (cache_iter == prim_cache_.end()) {
546     return nullptr;
547   }
548   auto &cache = cache_iter->second;
549   auto iter = cache.find(PrimitiveEvalCacheKey{prim->attrs(), args});
550   if (iter == cache.end()) {
551     return nullptr;
552   }
553   return iter->second;
554 }
555 
Put(const PrimitivePtr & prim,AttrValueMap && attrs,const AbstractBasePtrList & args,const EvalResultPtr & result)556 void PrimitiveEvalCache::Put(const PrimitivePtr &prim, AttrValueMap &&attrs, const AbstractBasePtrList &args,
557                              const EvalResultPtr &result) {
558   MS_EXCEPTION_IF_NULL(prim);
559   std::lock_guard<std::mutex> guard(mutex_);
560   (void)prim_cache_[prim->name()].emplace(PrimitiveEvalCacheKey{std::move(attrs), args}, result);
561 }
562 
Clear()563 void PrimitiveEvalCache::Clear() {
564   std::lock_guard<std::mutex> guard(mutex_);
565   prim_cache_.clear();
566 }
567 
Run(const FuncGraphPtr & func_graph,const AbstractBasePtrList & args_abs_list)568 AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_abs_list) {
569   StaticAnalysisException::Instance().ClearException();
570   AnalysisResult result;
571   try {
572     MS_EXCEPTION_IF_NULL(func_graph);
573     ConfigPtrList args_conf_list;
574     (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(args_conf_list),
575                          [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
576     MS_EXCEPTION_IF_NULL(func_graph_manager_);
577     func_graph_manager_->AddFuncGraph(func_graph);
578     root_func_graph_ = func_graph;
579 
580     // Running the analyzer.
581     ResetFunctionCallDepth();
582     ResetStackFrameDepth();
583     // Create a new root dummy context for the new analysis session.
584     AnalysisContextPtr dummy_context = AnalysisContext::NewDummyContext();
585     MS_LOG(DEBUG) << func_graph->ToString() << ": Run begin.";
586     AnalysisContextPtr root_context = Run(func_graph, dummy_context, args_conf_list);
587     AnalysisSchedule::GetInstance().Wait();
588     MS_EXCEPTION_IF_NULL(root_context);
589     auto root_context_fg = root_context->func_graph();
590     MS_EXCEPTION_IF_NULL(root_context_fg);
591     AnfNodeConfigPtr output_conf = MakeConfig(root_context_fg->get_return(), root_context, root_context_fg);
592     MS_LOG(DEBUG) << func_graph->ToString() << ": Run finished.";
593 
594     MS_EXCEPTION_IF_NULL(output_conf);
595     auto eval_result = output_conf->ObtainEvalResult();
596     result.eval_result = eval_result;
597     result.context = root_context;
598   } catch (const std::exception &ex) {
599     MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception.";
600     AnalysisSchedule::GetInstance().HandleException(ex);
601   }
602   AnalysisSchedule::GetInstance().Wait();
603   MS_LOG(DEBUG) << func_graph->ToString() << ": Run end.";
604   // Set the sequence nodes' elements use flags all true.
605   SetSequenceElementsUseFlagsRecursively(result.eval_result->abstract(), true);
606   MS_LOG(DEBUG) << func_graph->ToString() << ":SetSequenceElementsUseFlagsRecursively Run end.";
607   return result;
608 }
609 
// Evaluates `func_graph` under `context` with `args_conf_list` through a FuncGraphEvaluator.
// The evaluator's own result is discarded. The engine's `root_context_` is returned instead (presumably set during
// the evaluation — TODO confirm).
AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                                       const ConfigPtrList &args_conf_list) {
  const auto self = shared_from_this();
  auto graph_evaluator = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  (void)graph_evaluator->Run(self, args_conf_list, nullptr);
  return root_context_;
}
616 
ObtainEvalResultFromCache(const AnfNodeConfigPtr & conf)617 EvalResultPtr ObtainEvalResultFromCache(const AnfNodeConfigPtr &conf) {
618   MS_EXCEPTION_IF_NULL(conf);
619   static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
620   auto result = cache_mgr.GetValue(conf);
621   if (result != nullptr) {
622     MS_EXCEPTION_IF_NULL(result->abstract());
623     MS_LOG(DEBUG) << "Evaluate cache found for NodeConfig: " << conf->ToString()
624                   << ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
625     return result;
626   }
627   return nullptr;
628 }
629 
ObtainEvalResultWithCache(const AnfNodeConfigPtr & conf)630 EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
631   MS_EXCEPTION_IF_NULL(conf);
632   auto result = ObtainEvalResultFromCache(conf);
633   if (result != nullptr) {
634     return result;
635   }
636   MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
637   result = ObtainEvalResultWithoutCache(conf);
638   return result;
639 }
640 
ObtainEvalResultWithoutCache(const AnfNodeConfigPtr & conf)641 EvalResultPtr AnalysisEngine::ObtainEvalResultWithoutCache(const AnfNodeConfigPtr &conf) {
642   MS_EXCEPTION_IF_NULL(conf);
643   EvalResultPtr result = nullptr;
644   result = Eval(conf);
645   if (result == nullptr) {
646     MS_LOG(INTERNAL_EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
647   }
648   MS_EXCEPTION_IF_NULL(result->abstract());
649   MS_LOG(DEBUG) << "Always Evaluate node for NodeConfig: " << conf->ToString()
650                 << ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
651   SaveEvalResultInCache(conf, result);
652   return result;
653 }
654 
SaveEvalResultInCache(const AnfNodeConfigPtr & conf,const EvalResultPtr & result) const655 void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) const {
656   MS_EXCEPTION_IF_NULL(conf);
657   MS_EXCEPTION_IF_NULL(result);
658   static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
659   auto iter = cache_mgr.GetCache().find(conf);
660   if (iter != cache_mgr.GetCache().end()) {
661     MS_EXCEPTION_IF_NULL(iter->second);
662     MS_EXCEPTION_IF_NULL(iter->second->abstract());
663     MS_LOG(DEBUG) << "Found previous result for NodeConfig: " << conf->ToString()
664                   << ", result: " << iter->second->abstract().get() << "/" << iter->second->abstract()->ToString();
665     // Update sequence nodes info, if matched in cache.
666     static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
667     if (enable_eliminate_unused_element) {
668       auto new_sequence = dyn_cast<AbstractSequence>(result->abstract());
669       auto old_sequence = dyn_cast<AbstractSequence>(iter->second->abstract());
670       if (old_sequence != nullptr && new_sequence != nullptr) {
671         MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags for NodeConfig: " << conf->ToString()
672                       << ", old_sequence: " << old_sequence->ToString()
673                       << ", new_sequence: " << new_sequence->ToString();
674         SynchronizeSequenceElementsUseFlagsRecursively(old_sequence, new_sequence);
675         MS_LOG(DEBUG) << "After synchronize sequence nodes use flags for NodeConfig: " << conf->ToString()
676                       << ", old_sequence: " << old_sequence->ToString()
677                       << ", new_sequence: " << new_sequence->ToString();
678       }
679     }
680   }
681   MS_EXCEPTION_IF_NULL(result->abstract());
682   MS_LOG(DEBUG) << "Save result for NodeConfig: " << conf->ToString() << ", result: " << result->abstract().get() << "/"
683                 << result->abstract()->ToString();
684   cache_mgr.SetValue(conf, result);
685 }
686 
SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const CNodePtr & cnode,const AbstractFunctionPtr & base_func_graph_func,const AnalysisContextPtr & fg_context)687 void SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
688                                                          const CNodePtr &cnode,
689                                                          const AbstractFunctionPtr &base_func_graph_func,
690                                                          const AnalysisContextPtr &fg_context) {
691   // Get the evaluator for func graph.
692   auto evaluator = engine->GetEvaluatorFor(base_func_graph_func);
693   MS_EXCEPTION_IF_NULL(evaluator);
694 
695   AbstractBasePtrList args_abs_list;
696   for (std::size_t i = 1; i < cnode->size(); i++) {
697     auto config = engine->MakeConfig(cnode->input(i), fg_context, fg);
698     auto result = config->ObtainEvalResult();
699     MS_EXCEPTION_IF_NULL(result);
700     auto abs = result->abstract();
701     args_abs_list.push_back(abs);
702   }
703 
704   // Check if already evaluated before.
705   MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
706   auto &cache = evaluator->evaluator_cache_mgr()->GetCache();
707   auto iter = cache.find(args_abs_list);
708   if (iter != cache.end()) {
709     MS_EXCEPTION_IF_NULL(fg_context);
710     MS_LOG(DEBUG) << "Eval before, current_node: " << cnode->DebugString() << ", context: " << fg_context->ToString()
711                   << ", args: " << args_abs_list;
712     // Update inputs sequence nodes info, if matched in cache.
713     for (std::size_t i = 0; i < args_abs_list.size(); ++i) {
714       auto new_sequence = dyn_cast<AbstractSequence>(args_abs_list[i]);
715       auto old_sequence = dyn_cast<AbstractSequence>(iter->first[i]);
716       if (old_sequence != nullptr && new_sequence != nullptr) {
717         MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags, old_sequence: " << old_sequence->ToString()
718                       << ", new_sequence: " << new_sequence->ToString();
719         SynchronizeSequenceElementsUseFlagsRecursively(old_sequence, new_sequence);
720         MS_LOG(DEBUG) << "After synchronize sequence nodes use flags, old_sequence: " << old_sequence->ToString()
721                       << ", new_sequence: " << new_sequence->ToString();
722       }
723     }
724   }
725 }
726 
Eval(const AnfNodeConfigPtr & conf)727 EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
728   MS_EXCEPTION_IF_NULL(conf);
729   AnfNodePtr node = conf->node();
730   EvalResultPtr eval_result = nullptr;
731 #ifdef DEBUG
732   compute_conf_stack_.push_back(node);
733   std::ostringstream buffer;
734   buffer << "Compute Config Begin:";
735   for (auto iter : compute_conf_stack_) {
736     buffer << " -> " << iter->DebugString();
737   }
738   MS_LOG(DEBUG) << buffer.str();
739 #endif
740   MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
741   MS_EXCEPTION_IF_NULL(node);
742   if (node->abstract() != nullptr) {
743     MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
744     eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
745   } else if (node->isa<ValueNode>()) {
746     auto value_node = node->cast<ValueNodePtr>();
747     auto abstract = EvalValueNode(value_node, conf);
748     eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>());
749   } else if (node->isa<CNode>()) {
750     auto cnode = node->cast<CNodePtr>();
751     trace::TraceEvalCNodeEnter(conf);
752     MS_LOG(DEBUG) << "Begin Eval CNode: " << cnode->DebugString();
753     eval_result = EvalCNode(cnode, conf);
754     MS_LOG(DEBUG) << "End Eval CNode: " << cnode->DebugString();
755     trace::TraceEvalCNodeLeave();
756   } else {
757     MS_LOG(INTERNAL_EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString()
758                                << "(type:" << node->type_name() << "), fg: "
759                                << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
760                                << " conf: " << conf->ToString();
761   }
762 
763 #ifdef DEBUG
764   compute_conf_stack_.pop_back();
765   if (eval_result == nullptr) {
766     MS_LOG(INTERNAL_EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
767                                << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
768   }
769 #endif
770   MS_EXCEPTION_IF_NULL(eval_result->abstract());
771   MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
772   return eval_result;
773 }
774 
EvalValueNode(const ValueNodePtr & value_node,const AnfNodeConfigPtr & conf) const775 AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) const {
776   MS_EXCEPTION_IF_NULL(conf);
777   MS_EXCEPTION_IF_NULL(value_node);
778   auto out = ToAbstract(value_node->value(), conf->context(), conf);
779   if (value_node->has_new_value() && out->isa<AbstractTensor>()) {
780     out = out->Broaden();
781   }
782   return out;
783 }
784 
// Follows the forwarding chain recorded in anfnode_config_map(), starting at `conf`, and
// returns the last config reached. Returns `conf` itself when it was never forwarded.
// NOTE(review): there is no cycle detection; a cyclic forwarding map would loop forever.
AnfNodeConfigPtr AnalysisEngine::GetForwardConfig(const AnfNodeConfigPtr &conf) const {
  MS_EXCEPTION_IF_NULL(conf);
  AnfNodeConfigPtr current = conf;
  for (auto hop = anfnode_config_map().find(current); hop != anfnode_config_map().end();
       hop = anfnode_config_map().find(current)) {
    current = hop->second;
    MS_EXCEPTION_IF_NULL(current);
  }
  return current;
}
796 
InterpretedNodeCall(const CNodePtr & cnode,const AnfNodeConfigPtr & conf)797 EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
798   MS_EXCEPTION_IF_NULL(cnode);
799   if (cnode->empty()) {
800     MS_LOG(INTERNAL_EXCEPTION) << "CNode inputs should not be empty, CNode: " << cnode->DebugString();
801   }
802 
803   // Check if the operator input is PyExecute CNode.
804   const auto &func_node = cnode->input(0);
805   MS_EXCEPTION_IF_NULL(func_node);
806   constexpr auto recursive_level = 2;
807   MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(recursive_level)
808                 << ", func_node: " << func_node->DebugString(recursive_level);
809   auto prim = GetCNodePrimitiveWithoutDoSignature(func_node);
810   if (!IsPrimitiveEquals(prim, prim::kPrimResolve) && !IsPrimitiveEquals(prim, prim::kPrimGetAttr) &&
811       !IsPrimitiveEquals(prim, prim::kPrimPyExecute) && !IsPrimitiveEquals(prim, prim::kPrimPyInterpret)) {
812     // Optimize the performance.
813     return nullptr;
814   }
815   AnfNodeConfigPtr func_conf = MakeConfig(func_node, conf->context(), conf->func_graph());
816   MS_EXCEPTION_IF_NULL(func_conf);
817   const auto &forwarded_conf = GetForwardConfig(func_conf);
818   if (!IsPrimitiveCNode(forwarded_conf->node(), prim::kPrimPyExecute) &&
819       !IsPrimitiveCNode(forwarded_conf->node(), prim::kPrimPyInterpret)) {
820     return nullptr;
821   }
822 
823   if (IsPrimitiveEquals(prim, prim::kPrimResolve)) {
824     return ConvertToPyInterpretCall(cnode, conf, forwarded_conf->node());
825   }
826   // Forward getattr CNode call to PyInterpreted CNode.
827   return ConvertToPyInterpretCall(cnode, conf);
828 }
829 
GetCNodeOperatorAbstract(const CNodePtr & cnode,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)830 AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
831                                                          const FuncGraphPtr &func_graph) {
832   MS_EXCEPTION_IF_NULL(cnode);
833   if (cnode->empty()) {
834     MS_LOG(INTERNAL_EXCEPTION) << "CNode inputs should not be empty, CNode: " << cnode->DebugString();
835   }
836   auto &func_node = cnode->input(0);
837   MS_EXCEPTION_IF_NULL(func_node);
838   MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
839   AnfNodeConfigPtr func_conf = MakeConfig(func_node, context, func_graph);
840   MS_EXCEPTION_IF_NULL(func_conf);
841   // Keep it in a local variable, otherwise smart pointer will free it.
842   auto possible_func_eval_result = func_conf->ObtainEvalResult();
843   MS_EXCEPTION_IF_NULL(possible_func_eval_result);
844   auto &possible_func = possible_func_eval_result->abstract();
845   if (possible_func == nullptr) {
846     MS_LOG(INTERNAL_EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString();
847   }
848   return possible_func;
849 }
850 
ConvertClassTypeToFunc(const CNodePtr & cnode,const AbstractBasePtr & abs,const AnfNodeConfigPtr & conf)851 EvalResultPtr AnalysisEngine::ConvertClassTypeToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs,
852                                                      const AnfNodeConfigPtr &conf) {
853   MS_EXCEPTION_IF_NULL(cnode);
854   const auto inputs_size = cnode->size();
855   AbstractBasePtrList input_abs;
856   input_abs.reserve(inputs_size - 1);
857   for (std::size_t i = 1; i < inputs_size; ++i) {
858     const AnfNodePtr &node = cnode->input(i);
859     auto cur_config = MakeConfig(node, conf->context(), conf->func_graph());
860     const auto &cur_eval_result = cur_config->ObtainEvalResult();
861     MS_EXCEPTION_IF_NULL(cur_eval_result);
862     auto cur_abs = cur_eval_result->abstract();
863     MS_EXCEPTION_IF_NULL(cur_abs);
864     input_abs.push_back(cur_abs);
865   }
866   bool has_non_graph_input = std::any_of(input_abs.begin(), input_abs.end(), [](const AbstractBasePtr &abs) {
867     MS_EXCEPTION_IF_NULL(abs);
868     return abs->isa<abstract::AbstractAny>() || abs->BuildValue()->isa<parse::InterpretedObject>();
869   });
870   if (has_non_graph_input) {
871     return ConvertToPyInterpretCall(cnode, conf);
872   }
873   MS_EXCEPTION_IF_NULL(abs);
874   auto val = abs->BuildValue();
875   MS_EXCEPTION_IF_NULL(val);
876   auto class_val = dyn_cast_ptr<parse::ClassType>(val);
877   MS_EXCEPTION_IF_NULL(class_val);
878   const auto &class_name = class_val->name();
879   std::vector<AnfNodePtr> new_cnode_inputs;
880   auto fg = cnode->func_graph();
881   MS_EXCEPTION_IF_NULL(fg);
882 
883   std::map<std::string, ValueNodePtr> list_or_tuple_func_map = {
884     {"class 'list'", NewValueNode(std::make_shared<prim::ListFunc>("list_func"))},
885     {"class 'tuple'", NewValueNode(std::make_shared<prim::TupleFunc>("tuple_func"))}};
886   auto iter = list_or_tuple_func_map.find(class_name);
887   if (iter != list_or_tuple_func_map.end()) {
888     (void)new_cnode_inputs.emplace_back(iter->second);
889   } else {
890     auto class_obj = class_val->obj();
891     py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
892     auto py_fn =
893       python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name), class_obj);
894     if (py::isinstance<py::none>(py_fn)) {
895       return ConvertToPyInterpretCall(cnode, conf);
896     }
897     auto func_fg = parse::ParsePythonCode(py_fn);
898     MS_EXCEPTION_IF_NULL(func_fg);
899     func_fg->set_manager(fg->manager());
900     (void)new_cnode_inputs.emplace_back(NewValueNode(func_fg));
901   }
902 
903   for (std::size_t i = 1; i < cnode->size(); ++i) {
904     (void)new_cnode_inputs.emplace_back(cnode->input(i));
905   }
906   auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
907   new_cnode->set_debug_info(cnode->debug_info());
908   AnalysisEnginePtr eng = conf->engine();
909   MS_EXCEPTION_IF_NULL(eng);
910   AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
911   return eng->ForwardConfig(conf, fn_conf);
912 }
913 
EvalCNode(const CNodePtr & cnode,const AnfNodeConfigPtr & conf)914 EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
915   MS_EXCEPTION_IF_NULL(conf);
916   MS_EXCEPTION_IF_NULL(cnode);
917 
918   // Handle the interpreted node call here.
919   const auto &interpreted_eval_result = InterpretedNodeCall(cnode, conf);
920   if (interpreted_eval_result != nullptr) {
921     return interpreted_eval_result;
922   }
923 
924   AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode, conf->context(), conf->func_graph());
925   MS_EXCEPTION_IF_NULL(possible_func->BuildType());
926   if (possible_func->IsSameTypeId(AbstractUndetermined::kTypeId)) {
927     MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
928     return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
929   }
930 
931   if (possible_func->isa<AbstractClass>()) {
932     return ConvertMsClassObjToFunc(cnode, possible_func, conf);
933   }
934   if (possible_func->isa<AbstractScalar>()) {
935     // Convert class to function, such as list(xxx).
936     auto val = possible_func->BuildValue();
937     MS_EXCEPTION_IF_NULL(val);
938     if (val->isa<parse::ClassType>()) {
939       return ConvertClassTypeToFunc(cnode, possible_func, conf);
940     }
941     if (val->isa<parse::InterpretedObject>()) {
942       return ConvertCallPyObjCallFunc(cnode, possible_func, conf);
943     }
944   }
945 
946   if (possible_func->isa<AbstractAny>()) {
947     return ConvertToPyInterpretCall(cnode, conf);
948   }
949 
950   if (possible_func->isa<PrimitiveAbstractClosure>()) {
951     auto value = GetValueWithoutDoSignature(possible_func->cast<PrimitiveAbstractClosurePtr>()->prim());
952     auto new_cnode = InsertUnpackGraph(cnode, value, conf, shared_from_this());
953     if (new_cnode != nullptr) {
954       AnalysisEnginePtr eng = conf->engine();
955       MS_EXCEPTION_IF_NULL(eng);
956       AnfNodeConfigPtr new_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
957       return eng->ForwardConfig(conf, new_conf);
958     }
959   }
960 
961   auto func = dyn_cast_ptr<AbstractFunction>(possible_func);
962   if (func == nullptr) {
963     MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
964     MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
965     MS_EXCEPTION(ValueError) << "The object is not callable. Please check code.";
966   }
967 
968   // Make arguments config list.
969   bool contains_side_effect = false;
970   const auto inputs_size = cnode->size();
971   ConfigPtrList args_conf_list;
972   args_conf_list.reserve(inputs_size - 1);
973   // Ignore the first node which is function name.
974   for (std::size_t i = 1; i < inputs_size; ++i) {
975     const AnfNodePtr &node = cnode->input(i);
976     (void)args_conf_list.emplace_back(MakeConfig(node, conf->context(), conf->func_graph()));
977     if (check_side_effect()) {
978       auto input_cnode = dyn_cast_ptr<CNode>(node);
979       if (input_cnode != nullptr) {
980         contains_side_effect = contains_side_effect || input_cnode->has_side_effect_node();
981       }
982     }
983   }
984 
985   // Find evaluators.
986   std::vector<EvaluatorPtr> evaluators;
987   func->Visit([this, &evaluators, &cnode](const AbstractFuncAtomPtr &possible_func) {
988     const auto &real_func_atom = GetRealFuncAtom(possible_func);
989     auto evaluator = this->GetEvaluatorFor(real_func_atom);
990     evaluator->set_bound_node(cnode);
991     (void)evaluators.emplace_back(std::move(evaluator));
992   });
993 
994   // Run evaluators.
995   auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
996   // Check if func graph contains isolated side-effect, and sync.
997   if (check_side_effect()) {
998     func->Visit([&contains_side_effect](const AbstractFuncAtomPtr &possible_func) {
999       const auto &real_func_atom = GetRealFuncAtom(possible_func);
1000       bool func_has_side_effect = CheckFuncSideEffect(real_func_atom);
1001       if (func_has_side_effect) {
1002         contains_side_effect = true;
1003       }
1004     });
1005     if (contains_side_effect) {
1006       MS_EXCEPTION_IF_NULL(conf->func_graph());
1007       MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString()
1008                     << ", func_graph: " << conf->func_graph()->ToString();
1009       cnode->set_has_side_effect_node(true);
1010       conf->func_graph()->set_has_side_effect_node(true);
1011       eval_result->set_has_side_effect_node(true);
1012     }
1013   }
1014   return eval_result;
1015 }
1016 
Execute(const AbstractFunctionPtr & func,const AbstractBasePtrList & args_abs_list)1017 EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_abs_list) {
1018   MS_EXCEPTION_IF_NULL(func);
1019   ConfigPtrList args_conf_list;
1020   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(args_conf_list),
1021                        [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
1022   std::vector<EvaluatorPtr> infs;
1023   MS_EXCEPTION_IF_NULL(func);
1024   auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
1025     auto evaluator = this->GetEvaluatorFor(poss);
1026     infs.push_back(evaluator);
1027   };
1028   func->Visit(build_evaluator);
1029   return ExecuteEvaluators(infs, nullptr, args_conf_list);
1030 }
1031 
ClearEvaluatorCache()1032 void AnalysisEngine::ClearEvaluatorCache() {
1033   py::gil_scoped_acquire gil;
1034   for (auto &element : evaluators_) {
1035     EvaluatorPtr evaluator = element.second;
1036     if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
1037       continue;
1038     }
1039     MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
1040     evaluator->evaluator_cache_mgr()->Clear();
1041   }
1042   for (auto &element : prim_constructors_) {
1043     EvaluatorPtr evaluator = element.second;
1044     if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
1045       continue;
1046     }
1047     MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
1048     evaluator->evaluator_cache_mgr()->Clear();
1049   }
1050   for (auto &element : prim_py_evaluators_) {
1051     EvaluatorPtr evaluator = element.second;
1052     if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
1053       continue;
1054     }
1055     MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
1056     evaluator->evaluator_cache_mgr()->Clear();
1057   }
1058   // Release exception to avoid hup at exit.
1059   StaticAnalysisException::Instance().ClearException();
1060   // Reset the EnvironGet sparse option.
1061   EnvSetSparseResultMgr::GetInstance().Set(false);
1062 }
1063 
Clear()1064 void AnalysisEngine::Clear() {
1065   AnalysisResultCacheMgr::GetInstance().Clear();
1066   anfnode_config_map_.clear();
1067   eval_trace_.clear();
1068   evaluators_.clear();
1069   prim_py_evaluators_.clear();
1070   constructors_app_.clear();
1071   continued_evals_.clear();
1072   root_context_ = nullptr;
1073 }
1074 
GetPyEvaluator(const PrimitivePtr & prim,const AnalysisEnginePtr & engine)1075 EvaluatorPtr GetPyEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
1076   auto prim_py = dyn_cast<PrimitivePy>(prim);
1077   if (prim_py != nullptr) {
1078     auto is_constexpr = prim_py->HasAttr(GRAPH_FLAG_CONSTEXPR_PRIM);
1079     if (is_constexpr) {
1080       return std::make_shared<ConstexprEvaluator>(prim_py);
1081     }
1082     if (engine == nullptr) {
1083       return std::make_shared<PythonPrimEvaluator>(prim_py);
1084     }
1085 
1086     const auto &iter = engine->prim_py_evaluators_.find(prim_py);
1087     if (iter != engine->prim_py_evaluators_.end()) {
1088       return iter->second;
1089     }
1090     auto evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
1091     engine->prim_py_evaluators_[prim_py] = evaluator;
1092     return evaluator;
1093   }
1094   MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
1095   return nullptr;
1096 }
1097 
GetStandardPrimEvaluator(const PrimitivePtr & prim)1098 inline StandardPrimEvaluatorPtr GetStandardPrimEvaluator(const PrimitivePtr &prim) {
1099   auto eval_impl_opt = GetFrontendPrimitiveInferImpl(prim);
1100   if (eval_impl_opt.has_value()) {
1101     // Find prim infer function in the prim function map return a standard evaluator
1102     auto eval_impl = eval_impl_opt.value();
1103     if (eval_impl.IsImplInferShapeAndType() && !IsPrimitiveEquals(prim, prim::kPrimMakeTuple) &&
1104         !IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
1105       return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
1106     }
1107   }
1108 
1109   return nullptr;
1110 }
1111 
GetPrimEvaluator(const PrimitivePtr & prim,const AnalysisEnginePtr & engine)1112 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
1113   // Custom Primitive with python infer_shape, infer_type
1114   MS_EXCEPTION_IF_NULL(prim);
1115   if (prim->isa<prim::DoSignaturePrimitive>()) {
1116     return std::make_shared<DoSignatureEvaluator>(prim);
1117   }
1118   if (prim->isa<prim::UnpackGraphPrimitive>()) {
1119     return std::make_shared<UnpackGraphEvaluator>(prim);
1120   }
1121   if (IsPrimitiveEquals(prim, prim::kPrimMixedPrecisionCast)) {
1122     return std::make_shared<MixedPrecisionCastEvaluator>(prim);
1123   }
1124   if (IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
1125     return std::make_shared<PyExecuteEvaluator>();
1126   }
1127   static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
1128   if (enable_pre_lift && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
1129     return std::make_shared<SwitchEvaluator>();
1130   }
1131 
1132   if (prim->isa<prim::DoTransPrimitiveFunction>()) {
1133     return std::make_shared<DoTransPrimitiveFunctionEvaluator>(prim);
1134   }
1135   // Primitive is defined in OpTable.
1136   if (mindspore::ops::IsPrimitiveFunction(prim->name())) {
1137     if (prim->isa<PrimitivePy>()) {
1138       return std::make_shared<PrimitiveArgsToInputsEvaluator>(prim);
1139     }
1140     return std::make_shared<PrimitiveFunctionEvaluator>(prim);
1141   }
1142 
1143   auto standard_evaluator = GetStandardPrimEvaluator(prim);
1144   if (standard_evaluator != nullptr) {
1145     return standard_evaluator;
1146   }
1147 
1148   // Use python infer function if the infer function not founded in the map return a python evaluator
1149   EvaluatorPtr evaluator = nullptr;
1150   if (prim->HasPyEvaluator()) {
1151     return GetPyEvaluator(prim, engine);
1152   }
1153 
1154   // Delete this when the infer value can be mapped to the CPU backend operator.
1155   if (PrimNeedFrontendInferValue(prim)) {
1156     return nullptr;
1157   }
1158 
1159   // Return a default evaluator
1160   if (engine == nullptr) {
1161     // If engine is nullptr, get constructor from default.
1162     const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
1163     auto iter = prim_evaluator_map.find(prim);
1164     if (iter != prim_evaluator_map.end()) {
1165       evaluator = iter->second;
1166     }
1167   } else {
1168     // If engine is given, get constructor from engine resource.
1169     const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
1170     auto iter = prim_evaluator_map.find(prim);
1171     if (iter != prim_evaluator_map.end()) {
1172       evaluator = iter->second;
1173     }
1174   }
1175 
1176   if (evaluator == nullptr) {
1177     MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
1178   }
1179   return evaluator;
1180 }
1181 
_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> & func)1182 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
1183   MS_EXCEPTION_IF_NULL(func);
1184   const auto &primitive = func->prim();
1185   if (func->tracking_id() == 0) {
1186     // Create primitive evaluator if tracking_id == 0.
1187     auto [iter, is_new] = evaluators_.emplace(func, nullptr);
1188 
1189     if (is_new) {
1190       iter->second = GetPrimEvaluator(primitive, shared_from_this());
1191       if (iter->second == nullptr) {
1192         MS_LOG(EXCEPTION) << "Operator '" << primitive->name()
1193                           << "' is invalid, or no matching evaluator could be found.";
1194       }
1195     }
1196     return iter->second;
1197   }
1198   // Use TrackedEvaluator if tracking_id != 0.
1199   auto iter = evaluators_.find(func);
1200   if (iter != evaluators_.end()) {
1201     return iter->second;
1202   }
1203   auto prim_without_tracking_id = std::make_shared<PrimitiveAbstractClosure>(primitive, 0);
1204   EvaluatorPtr prim_evaluator = _GetEvaluatorFor(prim_without_tracking_id);
1205   static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
1206   if (enable_pre_lift && IsPrimitiveEquals(primitive, prim::kPrimSwitch)) {
1207     auto result = evaluators_.emplace(func, prim_evaluator);
1208     return result.first->second;
1209   } else {
1210     auto tracked_evaluator = std::make_shared<TrackedEvaluator>(prim_evaluator);
1211     auto result = evaluators_.emplace(func, std::move(tracked_evaluator));
1212     return result.first->second;
1213   }
1214 }
1215 
_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> & func)1216 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
1217   MS_EXCEPTION_IF_NULL(func);
1218   auto [iter, is_new] = evaluators_.emplace(func, nullptr);
1219   if (is_new) {
1220     iter->second = std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
1221   }
1222   return iter->second;
1223 }
1224 
_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> & func)1225 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
1226   MS_EXCEPTION_IF_NULL(func);
1227   auto [iter, is_new] = evaluators_.emplace(func, nullptr);
1228   if (is_new) {
1229     iter->second = std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->GetScope());
1230   }
1231   return iter->second;
1232 }
1233 
_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> & func)1234 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
1235   MS_EXCEPTION_IF_NULL(func);
1236   const auto &primal_func = func->fn();
1237   auto primal_evaluator = GetEvaluatorFor(primal_func);
1238   return std::make_shared<JEvaluator>(primal_evaluator, primal_func);
1239 }
1240 
_GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> & func)1241 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &func) {
1242   MS_EXCEPTION_IF_NULL(func);
1243   const auto &primal_func = func->fn();
1244   const auto &in_axes = func->in_axes();
1245   const auto &out_axes = func->out_axes();
1246   size_t cell_size = func->cell_size();
1247   auto primal_evaluator = GetEvaluatorFor(primal_func);
1248   return std::make_shared<VmapEvaluator>(primal_evaluator, primal_func, in_axes, out_axes, cell_size);
1249 }
1250 
_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> & func)1251 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func) {
1252   MS_EXCEPTION_IF_NULL(func);
1253   const auto &primal_func = func->fn();
1254   auto primal_evaluator = GetEvaluatorFor(primal_func);
1255   return std::make_shared<TaylorEvaluator>(primal_evaluator, primal_func);
1256 }
1257 
_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> & func)1258 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
1259   MS_EXCEPTION_IF_NULL(func);
1260   const auto &primal_func = func->fn();
1261   auto primal_evaluator = GetEvaluatorFor(primal_func);
1262   return std::make_shared<ShardEvaluator>(primal_evaluator, primal_func);
1263 }
1264 
_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> & func)1265 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
1266   MS_EXCEPTION_IF_NULL(func);
1267   return std::make_shared<VirtualEvaluator>(func->args_abs_list(), func->output());
1268 }
1269 
_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> & func)1270 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
1271   MS_EXCEPTION_IF_NULL(func);
1272   auto primal_func = func->fn();
1273   auto part_pair = std::make_pair(primal_func, func->args());
1274   auto iter = constructors_app_.find(part_pair);
1275   if (iter != constructors_app_.end()) {
1276     return iter->second;
1277   }
1278   EvaluatorPtr partial_evaluator = nullptr;
1279   if (func->need_append_to_end()) {
1280     partial_evaluator = std::make_shared<PartialToEndEvaluator>(primal_func);
1281   } else {
1282     auto primal_evaluator = GetEvaluatorFor(primal_func);
1283     partial_evaluator = std::make_shared<PartialAppEvaluator>(primal_evaluator, func->args());
1284   }
1285   auto result = constructors_app_.emplace(std::move(part_pair), std::move(partial_evaluator));
1286   return result.first->second;
1287 }
1288 
GetEvaluatorFor(const AbstractFunctionPtr & func)1289 EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
1290   MS_EXCEPTION_IF_NULL(func);
1291   MS_LOG(DEBUG) << "GetEvaluatorFor: " << func->ToString() << " tracking_id: " << func->tracking_id();
1292 
1293   if (func->isa<PrimitiveAbstractClosure>()) {
1294     return _GetEvaluatorFor(std::static_pointer_cast<PrimitiveAbstractClosure>(func));
1295   }
1296   if (func->isa<FuncGraphAbstractClosure>()) {
1297     return _GetEvaluatorFor(std::static_pointer_cast<FuncGraphAbstractClosure>(func));
1298   }
1299   if (func->isa<MetaFuncGraphAbstractClosure>()) {
1300     return _GetEvaluatorFor(std::static_pointer_cast<MetaFuncGraphAbstractClosure>(func));
1301   }
1302   if (func->isa<JTransformedAbstractClosure>()) {
1303     return _GetEvaluatorFor(std::static_pointer_cast<JTransformedAbstractClosure>(func));
1304   }
1305   if (func->isa<VmapTransformedAbstractClosure>()) {
1306     return _GetEvaluatorFor(std::static_pointer_cast<VmapTransformedAbstractClosure>(func));
1307   }
1308   if (func->isa<TaylorTransformedAbstractClosure>()) {
1309     return _GetEvaluatorFor(std::static_pointer_cast<TaylorTransformedAbstractClosure>(func));
1310   }
1311   if (func->isa<ShardTransformedAbstractClosure>()) {
1312     return _GetEvaluatorFor(std::static_pointer_cast<ShardTransformedAbstractClosure>(func));
1313   }
1314   if (func->isa<VirtualAbstractClosure>()) {
1315     return _GetEvaluatorFor(std::static_pointer_cast<VirtualAbstractClosure>(func));
1316   }
1317   if (func->isa<PartialAbstractClosure>()) {
1318     return _GetEvaluatorFor(std::static_pointer_cast<PartialAbstractClosure>(func));
1319   }
1320 
1321   MS_LOG(INTERNAL_EXCEPTION) << "Cannot GetEvaluator from " << func->type_name();
1322 }
1323 
ForwardConfig(const AnfNodeConfigPtr & orig_conf,const AnfNodeConfigPtr new_conf)1324 EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
1325   MS_EXCEPTION_IF_NULL(orig_conf);
1326   MS_EXCEPTION_IF_NULL(new_conf);
1327   // If always_eval_flag is true in BaseFuncGraphEvaluaotr, then the CNode with same orig_conf may be forwarded
1328   // again, so update the config_map with new_conf;
1329   anfnode_config_map_[orig_conf] = new_conf;
1330   MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString();
1331   MS_EXCEPTION_IF_NULL(orig_conf->node());
1332   MS_EXCEPTION_IF_NULL(new_conf->node());
1333   auto old_cnode = orig_conf->node()->cast_ptr<CNode>();
1334   auto new_cnode = new_conf->node()->cast<CNodePtr>();
1335   if (old_cnode != nullptr && new_cnode != nullptr) {
1336     if (old_cnode->func_graph() == new_cnode->func_graph()) {
1337       MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->DebugString()
1338                     << ", as origin node should be in order list, origin_node: " << old_cnode->DebugString();
1339       old_cnode->func_graph()->EraseUnusedNodeInOrder(new_cnode);
1340     } else {
1341       MS_LOG(INTERNAL_EXCEPTION) << "Forward orig_node to different func_graph, old_node: " << old_cnode->DebugString()
1342                                  << ", new_node: " << new_cnode->DebugString();
1343     }
1344   }
1345   (void)forward_count_++;
1346   auto res = ObtainEvalResultWithCache(new_conf);
1347   (void)forward_count_--;
1348   return res;
1349 }
1350 
ExecuteEvaluators(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)1351 EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
1352                                                 const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
1353   if (evaluators.size() == 1) {
1354     auto &eval = evaluators[0];
1355     MS_EXCEPTION_IF_NULL(eval);
1356     return eval->Run(shared_from_this(), args_conf_list, out_conf);
1357   }
1358   static const bool enable_single_thread = (common::GetCompileConfig("SINGLE_EVAL") == "1");
1359   if (enable_single_thread) {
1360     return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
1361   }
1362   return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
1363 }
1364 
SetUndeterminedFlag(const std::string & thread_id,const FuncGraph & fg)1365 void AnalysisEngine::SetUndeterminedFlag(const std::string &thread_id, const FuncGraph &fg) {
1366   static std::mutex fg_lock;
1367   std::lock_guard<std::mutex> infer_lock(fg_lock);
1368   MS_LOG(DEBUG) << "Record undetermined flag of fg:" << fg.ToString() << ", thread id:" << thread_id;
1369   func_graph_undetermined_flags_[&fg].push_front(thread_id);
1370 }
1371 
SetIgnoreValueFlag(const std::string & thread_id,FuncGraph * fg)1372 void AnalysisEngine::SetIgnoreValueFlag(const std::string &thread_id, FuncGraph *fg) {
1373   MS_EXCEPTION_IF_NULL(fg);
1374   auto it = func_graph_undetermined_flags_.find(fg);
1375   if (it == func_graph_undetermined_flags_.cend()) {
1376     return;
1377   }
1378   for (const auto &id : it->second) {
1379     if (thread_id.find(id) != std::string::npos && thread_id != id) {
1380       MS_LOG(DEBUG) << "Set ignore value of fg:" << fg->ToString() << ", thread id:" << thread_id;
1381       fg->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
1382       return;
1383     }
1384   }
1385 }
1386 
HandleNestedRecursion(const std::vector<EvaluatorPtr> & evaluators,const EvaluatorPtr & eval,const AbstractBasePtrList & args_abs_list,const EvalTraceRevIter & it,bool * continue_flag)1387 EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
1388                                                    const EvaluatorPtr &eval, const AbstractBasePtrList &args_abs_list,
1389                                                    const EvalTraceRevIter &it, bool *continue_flag) {
1390   MS_EXCEPTION_IF_NULL(continue_flag);
1391   MS_EXCEPTION_IF_NULL(eval);
1392   *continue_flag = false;
1393   // Find latest entry function to handle nested recursion.
1394   EvaluatorPtr latest_entry = eval;
1395   auto latest_entry_iter = eval_trace_.crbegin();
1396   for (auto r_it = eval_trace_.crbegin(); *r_it != *it;) {
1397     auto it_temp = std::find(evaluators.cbegin(), evaluators.cend(), r_it->evaluator_);
1398     if (it_temp != evaluators.cend()) {
1399       latest_entry = *it_temp;
1400       latest_entry_iter = r_it;
1401       break;
1402     }
1403     latest_entry_iter = ++r_it;
1404   }
1405   if (latest_entry != eval) {
1406     MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
1407     *continue_flag = true;
1408     return latest_entry;
1409   }
1410 
1411   bool has_undetermined = false;
1412   // Check whether sub loop has untraced undetermined evaluator.
1413   std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> undetermined_evals;
1414   for (auto r_it = eval_trace_.crbegin(); r_it != latest_entry_iter; r_it++) {
1415     (void)undetermined_evals.insert(*r_it);
1416   }
1417   MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
1418 
1419   for (const auto &u_eval : undetermined_evals) {
1420     MS_EXCEPTION_IF_NULL(u_eval.evaluator_);
1421     MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "check undetermined.";
1422     auto &alternate_evaluator = multi_poss_[u_eval.evaluator_];
1423     MS_EXCEPTION_IF_NULL(alternate_evaluator);
1424     auto eval_cache = alternate_evaluator->evaluator_cache_mgr();
1425     MS_EXCEPTION_IF_NULL(eval_cache);
1426     const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_abs_list);
1427     auto is_not_undetermined_eval = (undetermined_evals.find(alt_eval_args) == undetermined_evals.cend());
1428     auto is_not_continued_eval = (continued_evals_.find(u_eval) == continued_evals_.cend());
1429     auto args_not_evaluated = (eval_cache->GetValue(args_abs_list) == nullptr);
1430     if (is_not_undetermined_eval && (args_not_evaluated || is_not_continued_eval)) {
1431       MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "has undetermined.";
1432       has_undetermined = true;
1433       break;
1434     }
1435   }
1436   if (!has_undetermined) {
1437     MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
1438     *continue_flag = true;
1439     return latest_entry;
1440   }
1441 
1442   return latest_entry;
1443 }
1444 
GetFuncGraphFromBranchNode(const AnfNodePtr & branch_node)1445 FuncGraphPtr GetFuncGraphFromBranchNode(const AnfNodePtr &branch_node) {
1446   MS_EXCEPTION_IF_NULL(branch_node);
1447   auto fg = GetValueNode<FuncGraphPtr>(branch_node);
1448   if (fg != nullptr) {
1449     return fg;
1450   }
1451   if (IsPrimitiveCNode(branch_node, prim::kPrimPartial)) {
1452     fg = GetValueNode<FuncGraphPtr>(branch_node->cast<CNodePtr>()->input(kPartialGraphIndex));
1453   }
1454   if (fg != nullptr) {
1455     return fg;
1456   }
1457   MS_LOG(INTERNAL_EXCEPTION) << "Unexpected branch node: " << branch_node->DebugString();
1458 }
1459 
JoinBranchesFailedInfo(const AbstractBasePtr & abs,const AbstractBasePtr & last_out_abs,const AnfNodePtr & node,const std::string & error_info)1460 std::string JoinBranchesFailedInfo(const AbstractBasePtr &abs, const AbstractBasePtr &last_out_abs,
1461                                    const AnfNodePtr &node, const std::string &error_info) {
1462   constexpr int recursive_level = 2;
1463   std::ostringstream buffer;
1464   buffer << "Cannot join the return values of different branches, perhaps you need to make them equal.\n"
1465          << error_info
1466          << "#dmsg#Framework Error Message:#dmsg#The abstract type of the return value of the current branch is:\n"
1467          << abs->ToString() << ",\n and that of the previous branch is:\n"
1468          << last_out_abs->ToString() << ".\n"
1469          << "The node is " << node->DebugString(recursive_level);
1470   if (!node->isa<CNode>()) {
1471     buffer << "\n";
1472     return buffer.str();
1473   }
1474   auto input_node = node->cast_ptr<CNode>()->input(0);
1475   if (IsPrimitiveCNode(input_node, prim::kPrimSwitch)) {
1476     // {prim::kPrimSwitch, cond, true_branch, false_branch}
1477     const auto &cnode = input_node->cast_ptr<CNode>();
1478     auto true_out = GetFuncGraphFromBranchNode(cnode->input(kSwitchTrueBranchIndex))->get_return();
1479     auto false_out = GetFuncGraphFromBranchNode(cnode->input(kSwitchFalseBranchIndex))->get_return();
1480     buffer << ", true branch: " << cnode->input(kSwitchTrueBranchIndex)->ToString() << "\n"
1481            << trace::GetDebugInfoStr(true_out->debug_info())
1482            << "\n, false branch: " << cnode->input(kSwitchFalseBranchIndex)->ToString() << "\n"
1483            << trace::GetDebugInfoStr(false_out->debug_info());
1484   } else if (IsPrimitiveCNode(input_node, prim::kPrimSwitchLayer)) {
1485     // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
1486     constexpr int branch_index = 2;
1487     const auto &tuple_node = input_node->cast_ptr<CNode>()->input(branch_index);
1488     if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
1489       const auto &cnode = tuple_node->cast_ptr<CNode>();
1490       for (size_t i = 1; i < cnode->size(); i++) {
1491         auto out_node = GetValueNode<FuncGraphPtr>(cnode->input(i))->get_return();
1492         MS_EXCEPTION_IF_NULL(out_node);
1493         buffer << ", branch" << i << ": " << cnode->input(i)->ToString() << "\n"
1494                << trace::GetDebugInfoStr(out_node->debug_info());
1495       }
1496     }
1497   } else {
1498     buffer << trace::GetDebugInfoStr(node->debug_info());
1499   }
1500   buffer << "\n";
1501   return buffer.str();
1502 }
1503 
SetUseFlagsForJoinedAny(const AbstractBasePtrList & out_abs_list)1504 void SetUseFlagsForJoinedAny(const AbstractBasePtrList &out_abs_list) {
1505   for (const auto &abs : out_abs_list) {
1506     SetSequenceElementsUseFlagsRecursively(abs, true);
1507   }
1508 }
1509 
ProcessEvalResults(const AbstractBasePtrList & out_abs_list,const AnfNodePtr & node)1510 EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_abs_list, const AnfNodePtr &node) {
1511   if (out_abs_list.empty()) {
1512     MS_LOG(INTERNAL_EXCEPTION) << "There is an endless loop for evaluator.";
1513   }
1514 
1515   if (out_abs_list.size() == 1) {
1516     MS_EXCEPTION_IF_NULL(out_abs_list[0]);
1517     // If only one result derived, then broaden it to avoid wrong constant propagation.
1518     return std::make_shared<EvalResult>(out_abs_list[0]->Broaden(), std::make_shared<AttrValueMap>());
1519   }
1520   MS_EXCEPTION_IF_NULL(node);
1521 
1522   // Return Any if some branch returns Any.
1523   if (std::any_of(out_abs_list.cbegin(), out_abs_list.cend(), [](const AbstractBasePtr &abs) {
1524         MS_EXCEPTION_IF_NULL(abs);
1525         return abs->isa<AbstractAny>() && !abs->isa<AbstractNegligible>();
1526       })) {
1527     MS_LOG(INFO) << "The branches outputs contain Any output.\nJoin them to Any output.";
1528     return std::make_shared<EvalResult>(std::make_shared<AbstractAny>(), std::make_shared<AttrValueMap>());
1529   }
1530 
1531   AbstractBasePtr last_out_abs = out_abs_list[0];
1532   MS_EXCEPTION_IF_NULL(last_out_abs);
1533   AbstractBasePtr joined_abs = out_abs_list[0];
1534   for (size_t i = 1; i < out_abs_list.size(); ++i) {
1535     const auto &abs = out_abs_list[i];
1536     MS_EXCEPTION_IF_NULL(abs);
1537     try {
1538       MS_LOG(DEBUG) << "Join node: " << node->DebugString() << ", " << joined_abs->ToString() << ", and "
1539                     << abs->ToString();
1540       MS_LOG_TRY_CATCH_SCOPE;
1541       joined_abs = joined_abs->Join(abs);
1542     } catch (const py::type_error &ex) {
1543       auto error_info = ExtractLoggingInfo(ex.what());
1544       const auto info = JoinBranchesFailedInfo(abs, last_out_abs, node, error_info);
1545       MS_LOG(INFO) << info;
1546       auto joined_any = std::make_shared<AbstractJoinedAny>();
1547       joined_any->set_exception(AbstractJoinedAny::ExceptionType::kTypeError);
1548       joined_any->set_message(info);
1549       SetUseFlagsForJoinedAny(out_abs_list);
1550       return std::make_shared<EvalResult>(joined_any, std::make_shared<AttrValueMap>());
1551     } catch (const py::value_error &ex) {
1552       auto error_info = ExtractLoggingInfo(ex.what());
1553       const auto info = JoinBranchesFailedInfo(abs, last_out_abs, node, error_info);
1554       MS_LOG(INFO) << info;
1555       auto joined_any = std::make_shared<AbstractJoinedAny>();
1556       joined_any->set_exception(AbstractJoinedAny::ExceptionType::kValueError);
1557       joined_any->set_message(info);
1558       SetUseFlagsForJoinedAny(out_abs_list);
1559       return std::make_shared<EvalResult>(joined_any, std::make_shared<AttrValueMap>());
1560     } catch (const std::exception &ex) {
1561       auto error_info = ExtractLoggingInfo(ex.what());
1562       const auto info = JoinBranchesFailedInfo(abs, last_out_abs, node, error_info);
1563       MS_LOG(INFO) << info;
1564       auto joined_any = std::make_shared<AbstractJoinedAny>();
1565       joined_any->set_exception(AbstractJoinedAny::ExceptionType::kDefault);
1566       joined_any->set_message(info);
1567       // Remove it when the transform form dict to tuple is disabled in Compatible or Lax mode.
1568       if (joined_abs->isa<AbstractDictionary>()) {
1569         joined_any->set_user_data<bool>("from_dict", std::make_shared<bool>(true));
1570       }
1571       SetUseFlagsForJoinedAny(out_abs_list);
1572       return std::make_shared<EvalResult>(joined_any, std::make_shared<AttrValueMap>());
1573     }
1574     MS_EXCEPTION_IF_NULL(joined_abs);
1575     last_out_abs = abs;
1576   }
1577 
1578   MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_abs->ToString();
1579   return std::make_shared<EvalResult>(joined_abs, std::make_shared<AttrValueMap>());
1580 }
1581 
ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)1582 EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
1583                                                                    const AnfNodeConfigPtr &out_conf,
1584                                                                    const ConfigPtrList &args_conf_list) {
1585   MS_EXCEPTION_IF_NULL(out_conf);
1586   MS_EXCEPTION_IF_NULL(out_conf->node());
1587   MS_EXCEPTION_IF_NULL(out_conf->func_graph());
1588   // Release GIL for C++
1589   MS_LOG(DEBUG) << out_conf->func_graph()->ToString() << "_" << std::this_thread::get_id() << " begin.";
1590   py::gil_scoped_release infer_gil_release;
1591 
1592   // Only one thread to run
1593   AnalysisSchedule::GetInstance().WaitForRun();
1594 
1595   // Wait for the last switch node to finish.
1596   MS_LOG(DEBUG) << GetInferThread() << "async : entry switch  " << out_conf->ToString();
1597   auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
1598   if (eval_result == nullptr) {
1599     MS_LOG(DEBUG) << GetInferThread() << "async : Init switch  " << out_conf->node()->ToString();
1600     AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
1601   } else {
1602     return std::make_shared<EvalResult>(eval_result, nullptr);
1603   }
1604   auto possible_parent_fg = out_conf->node()->func_graph();
1605   MS_EXCEPTION_IF_NULL(possible_parent_fg);
1606   // Eval result of the main.
1607   AsyncAbstractPtr async_result_main = std::make_shared<AsyncAbstract>();
1608   if (possible_parent_fg->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
1609     async_result_main->set_ignore_value(true);
1610   }
1611   // Eval result of the branches
1612   std::vector<AsyncAbstractPtr> async_result_branches;
1613   SetUndeterminedFlag(AnalysisSchedule::thread_id(), *possible_parent_fg);
1614   for (auto &evaluator : evaluators) {
1615     static std::atomic<int> id_count{0};
1616     std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
1617     MS_EXCEPTION_IF_NULL(evaluator);
1618     AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>(async_result_main);
1619     // Control the order to run.
1620     AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
1621     control_run_order->set_result(std::make_shared<AbstractScalar>(1));
1622     AsyncInferTaskPtr async_task = AsyncInferTask::MakeShared(control_run_order, thread_id);
1623     AnalysisSchedule::GetInstance().IncreaseThreadCount();
1624     MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
1625     auto thread = std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, thread_id,
1626                               async_result_branch, async_result_main, async_task, trace::GetCurrentGraphEvalStack(),
1627                               trace::GetCNodeDebugStack());
1628     thread.detach();
1629 
1630     // Push to list of running loop
1631     MS_LOG(DEBUG) << "Add to schedule: " << async_task.get();
1632     AnalysisSchedule::GetInstance().Add2Schedule(async_task);  // Activate order witch child thread.
1633     (void)async_result_branches.emplace_back(std::move(async_result_branch));
1634   }
1635 
1636   size_t len = evaluators.size();
1637   size_t min_size = 2;
1638   if (len < min_size) {
1639     MS_LOG(EXCEPTION) << "There are at least 2 evaluators in multi thread, but got " << len << " evaluator.";
1640   }
1641 
1642   MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish.  " << evaluators[0]->ToString()
1643                 << " or  " << evaluators[1]->ToString() << "...";
1644 
1645   auto first_result = async_result_main->GetResult();
1646   MS_EXCEPTION_IF_NULL(first_result);
1647   MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
1648                 << first_result->ToString();
1649 
1650   AbstractBasePtrList out_abs_list;
1651   if (NeedWaitForBranches(first_result)) {
1652     MS_LOG(DEBUG) << GetInferThread() << " BuildPossibleSpecs.";
1653     BuildPossibleSpecs(first_result, async_result_branches, &out_abs_list);
1654   } else {
1655     for (size_t i = 0; i < len; ++i) {
1656       AbstractBasePtr result;
1657       MS_EXCEPTION_IF_NULL(async_result_branches[i]);
1658       if (enable_waiting_branch_eval()) {
1659         // wait to get the result of branch.
1660         result = async_result_branches[i]->GetResult();
1661       } else {
1662         // Not wait to get the result of branch.
1663         result = async_result_branches[i]->TryGetResult();
1664       }
1665 
1666       if (result) {
1667         MS_EXCEPTION_IF_NULL(evaluators[i]);
1668         MS_EXCEPTION_IF_NULL(result);
1669         MS_LOG(DEBUG) << "#" << i << ": " << GetInferThread() << " async get " << evaluators[i]->ToString()
1670                       << ", result: " << result->ToString() << ", args: " << args_conf_list;
1671         out_abs_list.push_back(result);
1672       }
1673     }
1674   }
1675   MS_LOG(DEBUG) << GetInferThread() << " finish.";
1676   const auto &processed_result = ProcessEvalResults(out_abs_list, out_conf->node());
1677   if (processed_result != nullptr) {
1678     // This is the final switch()() value.
1679     AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, processed_result->abstract());
1680   }
1681   MS_LOG(DEBUG) << GetInferThread() << " join finish.";
1682   return processed_result;
1683 }
1684 
ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)1685 EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
1686                                                         const AnfNodeConfigPtr &out_conf,
1687                                                         const ConfigPtrList &args_conf_list) {
1688   AbstractBasePtrList out_abs_list;
1689   const size_t evaluators_size = 2;
1690   if (evaluators.size() < evaluators_size) {
1691     MS_LOG(INTERNAL_EXCEPTION) << "Evaluators size is less than 2.";
1692   }
1693   multi_poss_[evaluators[0]] = evaluators[1];
1694   multi_poss_[evaluators[1]] = evaluators[0];
1695   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
1696   MS_EXCEPTION_IF_NULL(out_conf);
1697   MS_EXCEPTION_IF_NULL(out_conf->node());
1698   auto possible_parent_fg = out_conf->node()->func_graph();
1699   MS_EXCEPTION_IF_NULL(possible_parent_fg);
1700   possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
1701   MS_LOG(DEBUG) << "Set graph undetermined flag for " << possible_parent_fg->ToString();
1702   for (const auto &eval : evaluators) {
1703     MS_EXCEPTION_IF_NULL(eval);
1704     const auto current_inf = EvaluatorArgs(eval, args_abs_list);
1705     MS_LOG(DEBUG) << "Check evaluator " << eval->ToString();
1706     // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
1707     auto it = std::find(eval_trace_.crbegin(), eval_trace_.crend(), current_inf);
1708     if (it == eval_trace_.crend()) {
1709       eval_trace_.push_back(current_inf);
1710       auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
1711       MS_EXCEPTION_IF_NULL(eval_result);
1712       auto eval_abstract = eval_result->abstract();
1713       MS_EXCEPTION_IF_NULL(eval_abstract);
1714 
1715       out_abs_list.push_back(eval_abstract);
1716       eval_trace_.pop_back();
1717       if (eval_trace_.empty()) {
1718         multi_poss_.clear();
1719       }
1720     } else {
1721       bool continue_flag = false;
1722       auto latest_entry = HandleNestedRecursion(evaluators, eval, args_abs_list, it, &continue_flag);
1723       if (continue_flag) {
1724         MS_EXCEPTION_IF_NULL(current_inf.evaluator_);
1725         MS_LOG(DEBUG) << "The continued_evals_ insert " << current_inf.evaluator_.get() << "/"
1726                       << current_inf.evaluator_->ToString();
1727         continued_evals_.insert(current_inf);
1728         continue;
1729       }
1730 
1731       // Try to travel the latest undetermined.
1732       if (latest_entry != eval_trace_.rbegin()->evaluator_) {
1733         MS_LOG(DEBUG) << "Direct run evaluator " << eval.get() << "/" << eval->ToString();
1734         auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
1735         MS_EXCEPTION_IF_NULL(eval_result);
1736         MS_EXCEPTION_IF_NULL(eval_result->abstract());
1737         MS_LOG(DEBUG) << "End direct evaluator " << latest_entry->ToString()
1738                       << ", return out_abs: " << eval_result->abstract()->ToString();
1739         possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, false);
1740         return eval_result;
1741       }
1742     }
1743   }
1744   possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, false);
1745   return ProcessEvalResults(out_abs_list, out_conf->node());
1746 }
1747 
ObtainEvalResult()1748 EvalResultPtr AnfNodeConfig::ObtainEvalResult() {
1749   AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
1750   return engine_.lock()->ObtainEvalResultWithCache(self);
1751 }
1752 
MakeAbstractClosure(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context,const AnfNodePtr & anf_node)1753 AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
1754                                     const AnfNodePtr &anf_node) {
1755   AnalysisContextPtr temp_context = context;
1756   if (temp_context == nullptr) {
1757     temp_context = AnalysisContext::DummyContext();
1758   }
1759   return std::make_shared<FuncGraphAbstractClosure>(func_graph, temp_context, anf_node);
1760 }
1761 
MakeAbstractClosure(const MetaFuncGraphPtr & meta_func_graph,const AnfNodePtr & anf_node)1762 AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
1763   MetaFuncGraphAbstractClosurePtr meta_func_graph_fn;
1764   if (anf_node == nullptr) {
1765     meta_func_graph_fn = std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph);
1766   } else {
1767     meta_func_graph_fn = std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope());
1768   }
1769   return meta_func_graph_fn;
1770 }
1771 
MakeAbstractClosure(const PrimitivePtr & primitive,const AnfNodePtr & anf_node)1772 AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, const AnfNodePtr &anf_node) {
1773   auto prim_func = std::make_shared<PrimitiveAbstractClosure>(primitive, anf_node);
1774   return prim_func;
1775 }
1776 
ToAbstract(const ValuePtr & value,const AnalysisContextPtr & context,const AnfNodeConfigPtr & conf)1777 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
1778   MS_EXCEPTION_IF_NULL(value);
1779   AnfNodePtr anf_node = nullptr;
1780   if (conf != nullptr) {
1781     anf_node = conf->node();
1782   }
1783   if (value->isa<Primitive>()) {
1784     auto prim = value->cast<PrimitivePtr>();
1785     return MakeAbstractClosure(prim, anf_node);
1786   }
1787   if (value->isa<FuncGraph>()) {
1788     auto func_graph = value->cast<FuncGraphPtr>();
1789     return MakeAbstractClosure(func_graph, context, anf_node);
1790   }
1791   if (value->isa<MetaFuncGraph>()) {
1792     auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
1793     return MakeAbstractClosure(meta_func_graph, anf_node);
1794   }
1795   if (value->isa<ValueSequence>() && anf_node != nullptr) {
1796     auto abs = value->ToAbstract();
1797     MS_EXCEPTION_IF_NULL(abs);
1798     // Attach corresponding python sequence object to AbstractSequence.
1799     py::object py_list_obj =
1800       fallback::HasPyObjectInNode(anf_node) ? fallback::GetPyObjectFromNode(anf_node) : ValueToPyData(value);
1801     fallback::AttachPyObjToAbs(abs, py_list_obj, !fallback::HasPyObjectInNode(anf_node));
1802     MS_LOG(DEBUG) << "Attach python list object " << fallback::GetPyObjectPtrStr(py_list_obj)
1803                   << " to new abstract: " << abs->ToString();
1804     // Set sequence node for new AbstractSequence.
1805     static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1806     if (enable_eliminate_unused_element) {
1807       auto sequence_abs = abs->cast<AbstractSequencePtr>();
1808       MS_EXCEPTION_IF_NULL(sequence_abs);
1809       SetSequenceNodeElementsUseFlags(anf_node, std::make_shared<std::vector<bool>>(sequence_abs->elements().size()));
1810       std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
1811       (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(anf_node));
1812       sequence_abs->set_sequence_nodes(sequence_nodes);
1813     }
1814     return abs;
1815   }
1816   if (value->isa<ValueDictionary>() && anf_node != nullptr) {
1817     auto abs = value->ToAbstract();
1818     MS_EXCEPTION_IF_NULL(abs);
1819     // Attach corresponding python dictionary object to AbstractDictionary.
1820     py::object py_dict_obj =
1821       fallback::HasPyObjectInNode(anf_node) ? fallback::GetPyObjectFromNode(anf_node) : fallback::GeneratePyObj(abs);
1822     fallback::AttachPyObjToAbs(abs, py_dict_obj, !fallback::HasPyObjectInNode(anf_node));
1823     MS_LOG(DEBUG) << "Attach python dict object " << fallback::GetPyObjectPtrStr(py_dict_obj)
1824                   << " to new abstract: " << abs->ToString();
1825     return abs;
1826   }
1827   return value->ToAbstract();
1828 }
1829 
FromValueInside(const ValuePtr & value,bool broaden)1830 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
1831   AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
1832   if (broaden) {
1833     a = a->Broaden();
1834   }
1835   return a;
1836 }
1837 
EvalOnePrim(const PrimitivePtr & primitive,const AbstractBasePtrList & arg_specs)1838 EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
1839   auto evaluator = GetPrimEvaluator(primitive, nullptr);
1840   if (evaluator == nullptr) {
1841     MS_LOG(ERROR) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
1842     return nullptr;
1843   }
1844   auto trivial_evaluator = dyn_cast_ptr<TrivialPrimEvaluator>(evaluator);
1845   if (trivial_evaluator != nullptr) {
1846     return trivial_evaluator->EvalPrim(nullptr, arg_specs);
1847   }
1848   // Support MakeTuple/MakeList ops in PyNative mode.
1849   auto transition_evaluator = dyn_cast_ptr<TransitionPrimEvaluator>(evaluator);
1850   if (transition_evaluator != nullptr) {
1851     if (transition_evaluator->isa<MakeTupleEvaluator>() || transition_evaluator->isa<MakeListEvaluator>()) {
1852       return transition_evaluator->EvalPrim(nullptr, arg_specs, nullptr, nullptr);
1853     }
1854     return pipeline::AbstractAnalyze(primitive, arg_specs).eval_result;
1855   }
1856   // To add EvalPrim call of TransitionPrimEvaluator such as GetAttr.
1857   MS_LOG(ERROR) << "The primitive '" << primitive->ToString() << "' should be built as a TrivialPrimEvaluator, but "
1858                 << evaluator->ToString();
1859   return nullptr;
1860 }
1861 
EvalFunctionValue(const ValuePtr & func,const AbstractBasePtrList & args_spec)1862 AbstractBasePtr EvalFunctionValue(const ValuePtr &func, const AbstractBasePtrList &args_spec) {
1863   auto func_abs = func->ToAbstract();
1864   if (!func_abs->isa<AbstractFunction>()) {
1865     MS_LOG(EXCEPTION) << "The value : " << func->ToString() << " is not a callable object.";
1866   }
1867   if (func->isa<Primitive>() && !func->isa<prim::DoSignaturePrimitive>()) {
1868     return EvalOnePrim(func->cast<PrimitivePtr>(), args_spec)->abstract();
1869   } else {
1870     auto infer_graph = std::make_shared<FuncGraph>();
1871     std::vector<AnfNodePtr> inputs = {std::make_shared<ValueNode>(func)};
1872     (void)std::transform(args_spec.begin(), args_spec.end(), std::back_inserter(inputs),
1873                          [infer_graph](const AbstractBasePtr &) -> AnfNodePtr { return infer_graph->add_parameter(); });
1874     auto infer_node = infer_graph->NewCNode(inputs);
1875     infer_graph->set_return(infer_node);
1876     auto manager = Manage(infer_graph, true);
1877     auto engine = std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager);
1878     auto res = engine->Run(infer_graph, args_spec);
1879     return res.eval_result->abstract();
1880   }
1881 }
1882 
/// Derives the analysis context for evaluating `fg` with `args_abs_list` from `current_context`.
///
/// @param current_context The context the call originates from; must not be null.
/// @param fg The func graph a context is created for; must not be null.
/// @param args_abs_list Abstracts of the arguments passed to `fg`.
/// @return The new context (never null).
/// @throws Internal exception when `current_context` cannot produce a context for `fg` (the context for
///         fg->parent() was not found). When ENABLE_DUMP_IR is defined, `fg` and its parent graph (if any)
///         are dumped first.
AnalysisContextPtr NewContext(const AnalysisContextPtr &current_context, const FuncGraphPtr &fg,
                              const AbstractBasePtrList &args_abs_list) {
  MS_EXCEPTION_IF_NULL(current_context);
  MS_EXCEPTION_IF_NULL(fg);
  auto new_context = current_context->NewContext(fg, args_abs_list);
  if (new_context == nullptr) {  // Not obtain context for fg->parent() during create context.
    FuncGraphPtr parent_graph = fg->parent();
    const auto no_parent = parent_graph == nullptr;
#ifdef ENABLE_DUMP_IR
    DumpIR(std::string("EXCEPTION_NEW_CONTEXT_CURRENT_") + (no_parent ? "0" : "1") + "_" + fg->ToString() + ".ir", fg);
    if (!no_parent) {
      DumpIR("EXCEPTION_NEW_CONTEXT_PARENT_" + parent_graph->ToString() + ".ir", parent_graph);
    }
#endif
    // If parent context is not found, we'll raise exception.
    MS_LOG(INTERNAL_EXCEPTION) << "BUG: Failed to find parent context in current context: "
                               << current_context->ToString() << ", func_graph: " << fg->ToString()
                               << ", parent_graph: " << (no_parent ? "null" : parent_graph->ToString()) << ",\n"
                               << trace::GetDebugInfoStr(fg->debug_info());
  }
  return new_context;
}
1904 }  // namespace abstract
1905 }  // namespace mindspore
1906