1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
20 #include <algorithm>
21 #include <memory>
22 #include <mutex>
23 #include <set>
24 #include <unordered_set>
25 #include <utility>
26 #include <atomic>
27 #include "mindspore/core/ops/structure_ops.h"
28 #include "mindspore/core/ops/sequence_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "abstract/abstract_value.h"
31 #include "pipeline/jit/ps/fallback.h"
32 #include "pipeline/jit/ps/action.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "pipeline/jit/ps/static_analysis/prim.h"
35 #include "frontend/operator/ops.h"
36 #include "utils/ms_exception.h"
37 #include "utils/compile_config.h"
38 #include "ir/func_graph_cloner.h"
39 #include "pipeline/jit/ps/static_analysis/evaluator.h"
40 #include "pipeline/jit/ps/debug/trace.h"
41 #include "include/common/fallback.h"
42 #include "include/common/debug/anf_ir_dump.h"
43 #include "include/common/utils/convert_utils_py.h"
44 #include "include/common/utils/python_adapter.h"
45 #include "pipeline/jit/ps/static_analysis/async_eval_result.h"
46 #include "frontend/operator/ops_front_infer_function.h"
47 #include "frontend/operator/composite/composite.h"
48 #include "ops/op_def.h"
49
50 namespace mindspore {
51 namespace abstract {
// Current depth of the function call stack; this count also covers `stack_frame_depth`.
std::atomic<size_t> function_call_depth{0};
// Current depth of nested stack frame calls.
std::atomic<size_t> stack_frame_depth{0};
56
ResetFunctionCallDepth()57 void ResetFunctionCallDepth() { function_call_depth = 0; }
58
IncreaseFunctionCallDepth()59 void IncreaseFunctionCallDepth() { (void)(++function_call_depth); }
60
DecreaseFunctionCallDepth()61 void DecreaseFunctionCallDepth() {
62 if (function_call_depth == 0) {
63 MS_LOG(INTERNAL_EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
64 }
65 function_call_depth--;
66 }
67
FunctionCallDepth()68 size_t FunctionCallDepth() { return function_call_depth; }
69
ResetStackFrameDepth()70 void ResetStackFrameDepth() { stack_frame_depth = 0; }
71
IncreaseStackFrameDepth()72 void IncreaseStackFrameDepth() { (void)(++stack_frame_depth); }
73
DecreaseStackFrameDepth()74 void DecreaseStackFrameDepth() {
75 if (stack_frame_depth == 0) {
76 MS_LOG(INTERNAL_EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
77 }
78 stack_frame_depth--;
79 }
80
StackFrameDepth()81 size_t StackFrameDepth() { return stack_frame_depth; }
82
83 namespace {
ExecEvaluator(EvaluatorPtr eval,AnalysisEnginePtr engine,ConfigPtrList args_conf_list,AnfNodeConfigPtr out_conf,std::string thread_id,AsyncAbstractPtr async_result_branch,AsyncAbstractPtr async_result_main,AsyncInferTaskPtr async_task,trace::TraceGraphEvalStack graph_evals,trace::TraceCNodeEvalStack trace_c_node_evals)84 void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
85 std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
86 AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
87 trace::TraceCNodeEvalStack trace_c_node_evals) {
88 MS_EXCEPTION_IF_NULL(eval);
89 MS_EXCEPTION_IF_NULL(async_task);
90 AnalysisSchedule::set_thread_id(thread_id);
91 // Restore trace stack for dump stack when there is exception.
92 trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
93 trace_c_node_evals.clear();
94 trace::TraceGraphEvalStackPrepare(graph_evals);
95 graph_evals.clear();
96
97 try {
98 // Wait for Signal to run
99 MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
100 (void)async_task->GetResult();
101 MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
102
103 // Acquire GIL for eval to callback python.
104 EvalResultPtr result;
105 {
106 MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " begin.";
107 py::gil_scoped_acquire py_guard;
108 result = eval->Run(engine, args_conf_list, out_conf);
109 }
110 MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " end.";
111 MS_EXCEPTION_IF_NULL(result);
112 MS_EXCEPTION_IF_NULL(result->abstract());
113
114 // Check the branch value to be compatible with the other branch value.
115 AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
116 // Broaden the result of switch(c,t,f)()
117 auto broaden_abstract = result->abstract()->Broaden();
118
119 MS_EXCEPTION_IF_NULL(async_result_branch);
120 MS_EXCEPTION_IF_NULL(async_result_main);
121 // Notify the thread of waiting for branch value and the main thread to continue.
122 async_result_branch->set_result(broaden_abstract);
123 async_result_main->set_result(broaden_abstract);
124 MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
125 << " asyncResult address = " << async_result_branch.get();
126 if (async_result_branch->TryGetResult()) {
127 MS_LOG(DEBUG) << "value = " << (async_result_branch->TryGetResult())->ToString();
128 } else {
129 MS_LOG(DEBUG) << "value = null.";
130 }
131 } catch (const std::exception &ex) {
132 MS_EXCEPTION_IF_NULL(out_conf->node());
133 MS_LOG(INFO) << GetInferThread() << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString()
134 << " threw exception: " << ex.what();
135 AnalysisSchedule::GetInstance().HandleException(ex);
136 }
137 trace::ClearTraceStack();
138 ClearThreadLocal();
139 MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited.";
140 // Thread number will be drop when thread exits.
141 AnalysisSchedule::GetInstance().DecreaseThreadCount();
142 }
143
// Builds a copy of the sequence abstract `orig_abs` in which every function element is joined with one
// AsyncAbstractFuncAtom per entry of `pending_async_abstract_list`.
// `index` is the path of element positions from the outermost sequence down to `orig_abs`; each function element
// gets that path extended with its own position. Nested sequences are rebuilt recursively, other elements are
// reused as they are.
// Raises an internal exception (via MS_LOG) when `orig_abs` is not an AbstractTuple or AbstractList.
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
                                              const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
                                              const std::vector<std::size_t> &index) {
  MS_EXCEPTION_IF_NULL(orig_abs);
  auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
  if (sequence_abs == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
  }

  const auto &orig_elements = sequence_abs->elements();
  AbstractBasePtrList new_elements;
  new_elements.reserve(orig_elements.size());
  for (size_t i = 0; i < orig_elements.size(); ++i) {
    const auto &element = orig_elements[i];
    MS_EXCEPTION_IF_NULL(element);
    if (element->isa<AbstractFuncAtom>()) {
      // The index path is the same for every pending async abstract, so build it once per element.
      std::vector<std::size_t> element_index(index);
      element_index.push_back(i);
      AbstractFuncAtomPtrList abs_func_list{element->cast<AbstractFuncAtomPtr>()};
      abs_func_list.reserve(pending_async_abstract_list.size() + 1);
      for (const auto &pending_abs : pending_async_abstract_list) {
        abs_func_list.push_back(AsyncAbstractFuncAtom::MakeShared(pending_abs, element_index));
      }
      new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
    } else if (element->isa<AbstractSequence>()) {
      std::vector<std::size_t> element_index(index);
      element_index.push_back(i);
      new_elements.push_back(BuildAsyncAbstractRecursively(element, pending_async_abstract_list, element_index));
    } else {
      new_elements.push_back(element);
    }
  }

  // The sequence nodes are only carried over when the ENABLE_DDE compile config is not "0".
  static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
  AbstractBasePtr new_abs;
  if (orig_abs->isa<AbstractTuple>()) {
    new_abs = std::make_shared<AbstractTuple>(
      new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
  } else if (orig_abs->isa<AbstractList>()) {
    new_abs = std::make_shared<AbstractList>(
      new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
  }
  return new_abs;
}
186
BuildPossibleSpecs(const AbstractBasePtr & first_result,const std::vector<AsyncAbstractPtr> & branch_async_abstract_list,AbstractBasePtrList * out_abs_list)187 void BuildPossibleSpecs(const AbstractBasePtr &first_result,
188 const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
189 AbstractBasePtrList *out_abs_list) {
190 MS_EXCEPTION_IF_NULL(out_abs_list);
191 MS_EXCEPTION_IF_NULL(first_result);
192 std::vector<AsyncAbstractPtr> pending_async_abstract_list;
193 std::size_t len = branch_async_abstract_list.size();
194
195 for (size_t i = 0; i < len; ++i) {
196 AbstractBasePtr result;
197 MS_EXCEPTION_IF_NULL(branch_async_abstract_list[i]);
198 if (enable_waiting_branch_eval()) {
199 result = branch_async_abstract_list[i]->GetResult();
200 } else {
201 result = branch_async_abstract_list[i]->TryGetResult();
202 }
203
204 if (result) {
205 if (result->isa<AsyncAbstractFuncAtom>()) {
206 branch_async_abstract_list[i]->ClearPossibleResult();
207 pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
208 MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
209 << branch_async_abstract_list[i]->ToString();
210 } else {
211 out_abs_list->push_back(result);
212 }
213 } else {
214 pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
215 MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
216 << branch_async_abstract_list[i]->ToString();
217 }
218 }
219
220 if (first_result->isa<AbstractFunction>()) {
221 for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
222 auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
223 out_abs_list->push_back(async_func);
224 MS_LOG(DEBUG) << "out_abs_list add: " << async_func.get() << "_" << async_func->ToString();
225 }
226 } else if (first_result->isa<AbstractSequence>()) {
227 const auto &new_first_result =
228 BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector<size_t>());
229 MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
230 << ", new: " << new_first_result->ToString();
231 std::replace_if(
232 out_abs_list->begin(), out_abs_list->end(),
233 [first_result](const auto &element) { return element == first_result; }, new_first_result);
234 } else {
235 MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
236 }
237 }
238
ConvertToPyInterpretCall(const CNodePtr & cnode,const AnfNodeConfigPtr & conf,const AnfNodePtr & func_node=nullptr)239 EvalResultPtr ConvertToPyInterpretCall(const CNodePtr &cnode, const AnfNodeConfigPtr &conf,
240 const AnfNodePtr &func_node = nullptr) {
241 auto fg = cnode->func_graph();
242 MS_EXCEPTION_IF_NULL(fg);
243 auto out_node = conf->node();
244 MS_EXCEPTION_IF_NULL(out_node);
245 std::stringstream script_buffer;
246 AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
247 AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
248
249 // Handle call function
250 const std::string call_func_str = "__call_func_str__";
251 constexpr size_t call_func_index = 0;
252 script_buffer << call_func_str << "(";
253 (void)local_key_inputs.emplace_back(NewValueNode(call_func_str));
254 if (func_node == nullptr) {
255 (void)local_value_inputs.emplace_back(cnode->input(call_func_index));
256 } else {
257 (void)local_value_inputs.emplace_back(func_node);
258 }
259
260 // Handle inputs.
261 const std::string call_prefix = "__input_";
262 for (size_t i = 1; i < cnode->size(); ++i) {
263 auto cur_node = cnode->input(i);
264 if (IsPrimitiveCNode(cur_node, prim::kPrimMakeKeywordArg)) {
265 const std::string value_cur_str = call_prefix + "_value_" + std::to_string(i - 1) + "__";
266 constexpr size_t key_inputs_index = 1;
267 constexpr size_t value_inputs_index = 2;
268 constexpr size_t expect_inputs_size = 3;
269 if (cur_node->cast<CNodePtr>()->size() != expect_inputs_size) {
270 MS_LOG(INTERNAL_EXCEPTION) << "The make_keyword_arg node should have " << expect_inputs_size
271 << " inputs, but got " << cnode->size();
272 }
273 auto key_node = cur_node->cast<CNodePtr>()->input(key_inputs_index);
274 if (!IsValueNode<StringImm>(key_node)) {
275 MS_LOG(INTERNAL_EXCEPTION) << "The key in make_keyword args must be string, but got "
276 << key_node->DebugString();
277 }
278 auto key_string = GetValue<std::string>(GetValueNode(key_node));
279 std::string key_value_str = key_string + "=" + value_cur_str;
280 (void)local_key_inputs.emplace_back(NewValueNode(value_cur_str));
281 script_buffer << key_value_str << ",";
282 auto value_node = cur_node->cast<CNodePtr>()->input(value_inputs_index);
283 (void)local_value_inputs.emplace_back(value_node);
284 } else {
285 const std::string cur_str = call_prefix + std::to_string(i - 1) + "__";
286 script_buffer << cur_str << ",";
287 (void)local_key_inputs.emplace_back(NewValueNode(cur_str));
288 (void)local_value_inputs.emplace_back(cur_node);
289 }
290 }
291 script_buffer << ")";
292 const auto &script = script_buffer.str();
293 auto local_key_node = fg->NewCNode(local_key_inputs);
294 auto local_value_node = fg->NewCNode(local_value_inputs);
295 auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
296 auto obj_call_node =
297 fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node, out_node->debug_info());
298 MS_LOG(DEBUG) << "Created obj_call_node: " << obj_call_node->DebugString();
299 AnalysisEnginePtr eng = conf->engine();
300 MS_EXCEPTION_IF_NULL(eng);
301 AnfNodeConfigPtr fn_conf = eng->MakeConfig(obj_call_node, conf->context(), conf->func_graph());
302 return eng->ForwardConfig(conf, fn_conf);
303 }
304
ParsePyObjToFunc(const py::object & py_fn,const CNodePtr & cnode,const AnfNodeConfigPtr & conf)305 EvalResultPtr ParsePyObjToFunc(const py::object &py_fn, const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
306 FuncGraphPtr func_fg = nullptr;
307 {
308 MS_LOG_TRY_CATCH_SCOPE;
309 func_fg = parse::ParsePythonCode(py_fn);
310 }
311 if (func_fg != nullptr) {
312 auto fg = cnode->func_graph();
313 MS_EXCEPTION_IF_NULL(fg);
314 func_fg->set_manager(fg->manager());
315
316 std::vector<AnfNodePtr> new_cnode_inputs;
317 (void)new_cnode_inputs.emplace_back(NewValueNode(func_fg));
318 for (std::size_t i = 1; i < cnode->size(); ++i) {
319 (void)new_cnode_inputs.emplace_back(cnode->input(i));
320 }
321 auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
322 new_cnode->set_debug_info(cnode->debug_info());
323
324 AnalysisEnginePtr eng = conf->engine();
325 MS_EXCEPTION_IF_NULL(eng);
326 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
327 return eng->ForwardConfig(conf, fn_conf);
328 } else {
329 return ConvertToPyInterpretCall(cnode, conf);
330 }
331 }
332
GetClassName(const py::object & cls_obj)333 std::string GetClassName(const py::object &cls_obj) {
334 if (py::hasattr(cls_obj, "__class__")) {
335 return py::getattr(py::getattr(cls_obj, "__class__"), "__name__").cast<py::str>();
336 }
337 return py::getattr(cls_obj, "__name__").cast<py::str>();
338 }
339
ConvertCallPyObjCallFunc(const CNodePtr & cnode,const AbstractBasePtr & abs,const AnfNodeConfigPtr & conf)340 EvalResultPtr ConvertCallPyObjCallFunc(const CNodePtr &cnode, const AbstractBasePtr &abs,
341 const AnfNodeConfigPtr &conf) {
342 MS_EXCEPTION_IF_NULL(cnode);
343 MS_EXCEPTION_IF_NULL(abs);
344 auto val = abs->BuildValue();
345 MS_EXCEPTION_IF_NULL(val);
346 auto warp_obj = dyn_cast_ptr<parse::PyObjectWrapper>(val);
347 MS_EXCEPTION_IF_NULL(warp_obj);
348 py::object cls_obj = warp_obj->obj();
349 auto class_name = GetClassName(cls_obj);
350 py::object call_obj = py::none();
351 const std::string construct_func_name = "construct";
352 if (py::hasattr(cls_obj, common::SafeCStr(construct_func_name)) && py::isinstance<Cell>(cls_obj)) {
353 call_obj = py::getattr(cls_obj, common::SafeCStr(construct_func_name));
354 } else {
355 const std::string call_func_name = "__call__";
356 if (py::hasattr(cls_obj, common::SafeCStr(call_func_name))) {
357 call_obj = py::getattr(cls_obj, common::SafeCStr(call_func_name));
358 }
359 }
360 if (py::isinstance<py::none>(call_obj)) {
361 MS_EXCEPTION(ValueError) << class_name << "is not a callable object";
362 }
363 return ParsePyObjToFunc(call_obj, cnode, conf);
364 }
365
ConvertMsClassObjToFunc(const CNodePtr & cnode,const AbstractBasePtr & abs,const AnfNodeConfigPtr & conf)366 EvalResultPtr ConvertMsClassObjToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
367 MS_EXCEPTION_IF_NULL(cnode);
368 MS_EXCEPTION_IF_NULL(abs);
369 auto val = abs->BuildValue();
370 MS_EXCEPTION_IF_NULL(val);
371 auto class_val = dyn_cast_ptr<parse::MsClassObject>(val);
372 MS_EXCEPTION_IF_NULL(class_val);
373 py::object cls_obj = class_val->obj();
374 const std::string call_func_name = "__call__";
375 if (!py::hasattr(cls_obj, common::SafeCStr(call_func_name))) {
376 MS_EXCEPTION(ValueError) << class_val->name() << " has no " << call_func_name
377 << " function, please check the code.";
378 }
379 py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func_name));
380 return ParsePyObjToFunc(call_obj, cnode, conf);
381 }
382
CheckFuncSideEffect(const AbstractFunctionPtr & func)383 bool CheckFuncSideEffect(const AbstractFunctionPtr &func) {
384 // Check if func graph contains isolated side-effect, and sync.
385 auto func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(func);
386 if (func_graph_abs != nullptr) {
387 MS_EXCEPTION_IF_NULL(func_graph_abs->func_graph());
388 return func_graph_abs->func_graph()->has_side_effect_node();
389 } else {
390 auto meta_func_graph_abs = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
391 if (meta_func_graph_abs != nullptr) {
392 MS_EXCEPTION_IF_NULL(meta_func_graph_abs->meta_func_graph());
393 return meta_func_graph_abs->meta_func_graph()->has_side_effect_node();
394 }
395 if (func->isa<abstract::PartialAbstractClosure>()) {
396 const auto &abstract_partial_func = func->cast<abstract::PartialAbstractClosurePtr>();
397 const auto &abstract_fn = abstract_partial_func->fn();
398 MS_EXCEPTION_IF_NULL(abstract_fn);
399 return CheckFuncSideEffect(abstract_fn);
400 }
401 }
402 return false;
403 }
404
// Returns `possible_func` itself unless it is an AsyncAbstractFuncAtom; for an async atom, returns the function
// atom obtained from its GetUnique(), which must itself be an AbstractFuncAtom.
AbstractFuncAtomPtr GetRealFuncAtom(const AbstractFuncAtomPtr &possible_func) {
  MS_EXCEPTION_IF_NULL(possible_func);
  const auto async_atom = possible_func->cast_ptr<AsyncAbstractFuncAtom>();
  if (async_atom == nullptr) {
    return possible_func;
  }
  auto unique_func = async_atom->GetUnique();
  auto resolved_atom = dyn_cast<AbstractFuncAtom>(unique_func);
  MS_EXCEPTION_IF_NULL(resolved_atom);
  MS_LOG(DEBUG) << "Real AsyncAbstractFuncAtom is: " << resolved_atom->ToString();
  return resolved_atom;
}
417
418 template <typename T>
Match(const ValuePtr & prim)419 bool Match(const ValuePtr &prim) {
420 return prim->isa<T>();
421 }
422 using MetaFgMatchFunc = std::function<bool(const ValuePtr &)>;
423
MatchMetaFg(const ValuePtr & prim)424 bool MatchMetaFg(const ValuePtr &prim) {
425 static const std::vector<MetaFgMatchFunc> meta_fg_ops{
426 Match<prim::GradOperation>,
427 Match<prim::VmapOperation>,
428 Match<prim::Shard>,
429 };
430 return std::any_of(meta_fg_ops.cbegin(), meta_fg_ops.cend(),
431 [&prim](const MetaFgMatchFunc &match_func) { return match_func(prim); });
432 }
433
RemoveSequenceFromOrderList(const CNodePtr & origin_cnode)434 void RemoveSequenceFromOrderList(const CNodePtr &origin_cnode) {
435 constexpr size_t sequence_input_pos = 2;
436 if (origin_cnode->size() <= sequence_input_pos) {
437 return;
438 }
439 auto seq_node = origin_cnode->input(sequence_input_pos);
440 auto prim = GetCNodePrimitiveWithoutDoSignature(seq_node);
441 if (prim != nullptr &&
442 (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList))) {
443 auto seq_cnode = dyn_cast<CNode>(seq_node);
444 MS_EXCEPTION_IF_NULL(seq_cnode);
445 seq_cnode->func_graph()->EraseUnusedNodeInOrder(seq_cnode);
446 }
447 }
448
GetEvalResult(const AnfNodePtr & node,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & conf)449 AbstractBasePtr GetEvalResult(const AnfNodePtr &node, const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &conf) {
450 AnfNodeConfigPtr func_conf = std::make_shared<AnfNodeConfig>(engine, node, conf->context(), conf->func_graph());
451 auto possible_func_eval_result = func_conf->ObtainEvalResult();
452 MS_EXCEPTION_IF_NULL(possible_func_eval_result);
453 return possible_func_eval_result->abstract();
454 }
455
IsFuncGraphAbstractInput(const CNodePtr & origin_cnode,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & conf)456 bool IsFuncGraphAbstractInput(const CNodePtr &origin_cnode, const AnalysisEnginePtr &engine,
457 const AnfNodeConfigPtr &conf) {
458 auto possible_func = GetEvalResult(origin_cnode->input(1), engine, conf);
459 if (possible_func == nullptr || !possible_func->isa<FuncGraphAbstractClosure>()) {
460 return false;
461 }
462 // Check whether it is a high order scene such as GradOperation(GradOperation(net)), the meta_unpack_prepare doesn't
463 // handle before. To handle this later.
464 if (!origin_cnode->input(1)->isa<CNode>()) {
465 return true;
466 }
467 auto input1_cnode = origin_cnode->input(1)->cast<CNodePtr>();
468 auto possible_prim = GetEvalResult(input1_cnode->input(0), engine, conf);
469 if (possible_prim == nullptr || !possible_prim->isa<PrimitiveAbstractClosure>()) {
470 return true;
471 }
472 auto value = GetValueWithoutDoSignature(possible_prim->cast<PrimitiveAbstractClosurePtr>()->prim());
473 return !MatchMetaFg(value);
474 }
475
476 // {{meta_fg, g, w}, Ys} => {{meta_fg, {UnpackGraph, g, Ys}, w}, Ys}
477 // {UnpackCall, {meta_fg, g, w}, Ys} => {UnpackCall, {meta_fg, {UnpackGraph, g, Ys}, w}, Ys}
InsertUnpackGraph(const CNodePtr & origin_cnode,const ValuePtr & value,const AnfNodeConfigPtr & conf,const AnalysisEnginePtr & engine)478 AnfNodePtr InsertUnpackGraph(const CNodePtr &origin_cnode, const ValuePtr &value, const AnfNodeConfigPtr &conf,
479 const AnalysisEnginePtr &engine) {
480 // origin_cnode is {meta_fg, g, ...}
481 const size_t inputs_x_minimum_size = 2;
482 if (origin_cnode->size() < inputs_x_minimum_size) {
483 return nullptr;
484 }
485
486 if (value == nullptr || !MatchMetaFg(value)) {
487 return nullptr;
488 }
489
490 if (!IsFuncGraphAbstractInput(origin_cnode, engine, conf)) {
491 return nullptr;
492 }
493
494 auto manager = conf->engine()->func_graph_manager();
495 MS_EXCEPTION_IF_NULL(manager);
496 auto node_users = manager->node_users()[origin_cnode];
497 if (node_users.empty()) {
498 return nullptr;
499 }
500 auto meta_user = node_users.begin()->first->cast<CNodePtr>();
501 MS_EXCEPTION_IF_NULL(meta_user);
502 int index = node_users.begin()->second;
503 if (index != 0 && index != 1) {
504 return nullptr;
505 }
506
507 bool need_unpack_args = false;
508 if (index == 1) {
509 // The meta_fg user node should be UnpackCall.
510 auto input0_value = GetValueWithoutDoSignature(meta_user->input(0));
511 if (input0_value == nullptr || !input0_value->isa<prim::UnpackCall>()) {
512 return nullptr;
513 }
514 need_unpack_args = true;
515 }
516 // Create UnpackGraph node.
517 bool sens_param = false;
518 if (value->isa<prim::GradOperation>()) {
519 sens_param = value->cast<prim::GradOperationPtr>()->sens_param();
520 RemoveSequenceFromOrderList(origin_cnode);
521 }
522 auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, need_unpack_args);
523 std::vector<AnfNodePtr> unpack_graph_inputs{NewValueNode(unpack_graph), origin_cnode->input(1)};
524 const auto &meta_user_inputs = meta_user->inputs();
525 constexpr int64_t unpack_inputs_begin_index = 2;
526 int64_t offset = (need_unpack_args ? unpack_inputs_begin_index : 1);
527 (void)std::transform(meta_user_inputs.begin() + offset, meta_user_inputs.end(),
528 std::back_inserter(unpack_graph_inputs),
529 [](const AnfNodePtr &node) -> AnfNodePtr { return node; });
530 auto fg = origin_cnode->func_graph();
531 MS_EXCEPTION_IF_NULL(fg);
532 auto unpack_graph_node = fg->NewCNodeBefore(meta_user, unpack_graph_inputs);
533 // Create new call_node.
534 auto new_cnode_inputs = origin_cnode->inputs();
535 new_cnode_inputs[1] = unpack_graph_node;
536 auto new_cnode = fg->NewCNodeBefore(meta_user, new_cnode_inputs);
537 return new_cnode;
538 }
539 } // namespace
540
Get(const PrimitivePtr & prim,const AbstractBasePtrList & args) const541 EvalResultPtr PrimitiveEvalCache::Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
542 MS_EXCEPTION_IF_NULL(prim);
543 std::lock_guard<std::mutex> guard(mutex_);
544 auto cache_iter = prim_cache_.find(prim->name());
545 if (cache_iter == prim_cache_.end()) {
546 return nullptr;
547 }
548 auto &cache = cache_iter->second;
549 auto iter = cache.find(PrimitiveEvalCacheKey{prim->attrs(), args});
550 if (iter == cache.end()) {
551 return nullptr;
552 }
553 return iter->second;
554 }
555
Put(const PrimitivePtr & prim,AttrValueMap && attrs,const AbstractBasePtrList & args,const EvalResultPtr & result)556 void PrimitiveEvalCache::Put(const PrimitivePtr &prim, AttrValueMap &&attrs, const AbstractBasePtrList &args,
557 const EvalResultPtr &result) {
558 MS_EXCEPTION_IF_NULL(prim);
559 std::lock_guard<std::mutex> guard(mutex_);
560 (void)prim_cache_[prim->name()].emplace(PrimitiveEvalCacheKey{std::move(attrs), args}, result);
561 }
562
Clear()563 void PrimitiveEvalCache::Clear() {
564 std::lock_guard<std::mutex> guard(mutex_);
565 prim_cache_.clear();
566 }
567
Run(const FuncGraphPtr & func_graph,const AbstractBasePtrList & args_abs_list)568 AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_abs_list) {
569 StaticAnalysisException::Instance().ClearException();
570 AnalysisResult result;
571 try {
572 MS_EXCEPTION_IF_NULL(func_graph);
573 ConfigPtrList args_conf_list;
574 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(args_conf_list),
575 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
576 MS_EXCEPTION_IF_NULL(func_graph_manager_);
577 func_graph_manager_->AddFuncGraph(func_graph);
578 root_func_graph_ = func_graph;
579
580 // Running the analyzer.
581 ResetFunctionCallDepth();
582 ResetStackFrameDepth();
583 // Create a new root dummy context for the new analysis session.
584 AnalysisContextPtr dummy_context = AnalysisContext::NewDummyContext();
585 MS_LOG(DEBUG) << func_graph->ToString() << ": Run begin.";
586 AnalysisContextPtr root_context = Run(func_graph, dummy_context, args_conf_list);
587 AnalysisSchedule::GetInstance().Wait();
588 MS_EXCEPTION_IF_NULL(root_context);
589 auto root_context_fg = root_context->func_graph();
590 MS_EXCEPTION_IF_NULL(root_context_fg);
591 AnfNodeConfigPtr output_conf = MakeConfig(root_context_fg->get_return(), root_context, root_context_fg);
592 MS_LOG(DEBUG) << func_graph->ToString() << ": Run finished.";
593
594 MS_EXCEPTION_IF_NULL(output_conf);
595 auto eval_result = output_conf->ObtainEvalResult();
596 result.eval_result = eval_result;
597 result.context = root_context;
598 } catch (const std::exception &ex) {
599 MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception.";
600 AnalysisSchedule::GetInstance().HandleException(ex);
601 }
602 AnalysisSchedule::GetInstance().Wait();
603 MS_LOG(DEBUG) << func_graph->ToString() << ": Run end.";
604 // Set the sequence nodes' elements use flags all true.
605 SetSequenceElementsUseFlagsRecursively(result.eval_result->abstract(), true);
606 MS_LOG(DEBUG) << func_graph->ToString() << ":SetSequenceElementsUseFlagsRecursively Run end.";
607 return result;
608 }
609
// Evaluates `func_graph` with a FuncGraphEvaluator created for `context`, then returns `root_context_`.
// The evaluator's own return value is ignored.
// NOTE(review): `root_context_` is presumably filled in during the evaluation - confirm in the evaluator.
AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                                       const ConfigPtrList &args_conf_list) {
  const auto root_evaluator = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  const AnfNodeConfigPtr no_out_conf = nullptr;
  (void)root_evaluator->Run(shared_from_this(), args_conf_list, no_out_conf);
  return root_context_;
}
616
ObtainEvalResultFromCache(const AnfNodeConfigPtr & conf)617 EvalResultPtr ObtainEvalResultFromCache(const AnfNodeConfigPtr &conf) {
618 MS_EXCEPTION_IF_NULL(conf);
619 static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
620 auto result = cache_mgr.GetValue(conf);
621 if (result != nullptr) {
622 MS_EXCEPTION_IF_NULL(result->abstract());
623 MS_LOG(DEBUG) << "Evaluate cache found for NodeConfig: " << conf->ToString()
624 << ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
625 return result;
626 }
627 return nullptr;
628 }
629
ObtainEvalResultWithCache(const AnfNodeConfigPtr & conf)630 EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
631 MS_EXCEPTION_IF_NULL(conf);
632 auto result = ObtainEvalResultFromCache(conf);
633 if (result != nullptr) {
634 return result;
635 }
636 MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
637 result = ObtainEvalResultWithoutCache(conf);
638 return result;
639 }
640
ObtainEvalResultWithoutCache(const AnfNodeConfigPtr & conf)641 EvalResultPtr AnalysisEngine::ObtainEvalResultWithoutCache(const AnfNodeConfigPtr &conf) {
642 MS_EXCEPTION_IF_NULL(conf);
643 EvalResultPtr result = nullptr;
644 result = Eval(conf);
645 if (result == nullptr) {
646 MS_LOG(INTERNAL_EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
647 }
648 MS_EXCEPTION_IF_NULL(result->abstract());
649 MS_LOG(DEBUG) << "Always Evaluate node for NodeConfig: " << conf->ToString()
650 << ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
651 SaveEvalResultInCache(conf, result);
652 return result;
653 }
654
SaveEvalResultInCache(const AnfNodeConfigPtr & conf,const EvalResultPtr & result) const655 void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) const {
656 MS_EXCEPTION_IF_NULL(conf);
657 MS_EXCEPTION_IF_NULL(result);
658 static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
659 auto iter = cache_mgr.GetCache().find(conf);
660 if (iter != cache_mgr.GetCache().end()) {
661 MS_EXCEPTION_IF_NULL(iter->second);
662 MS_EXCEPTION_IF_NULL(iter->second->abstract());
663 MS_LOG(DEBUG) << "Found previous result for NodeConfig: " << conf->ToString()
664 << ", result: " << iter->second->abstract().get() << "/" << iter->second->abstract()->ToString();
665 // Update sequence nodes info, if matched in cache.
666 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
667 if (enable_eliminate_unused_element) {
668 auto new_sequence = dyn_cast<AbstractSequence>(result->abstract());
669 auto old_sequence = dyn_cast<AbstractSequence>(iter->second->abstract());
670 if (old_sequence != nullptr && new_sequence != nullptr) {
671 MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags for NodeConfig: " << conf->ToString()
672 << ", old_sequence: " << old_sequence->ToString()
673 << ", new_sequence: " << new_sequence->ToString();
674 SynchronizeSequenceElementsUseFlagsRecursively(old_sequence, new_sequence);
675 MS_LOG(DEBUG) << "After synchronize sequence nodes use flags for NodeConfig: " << conf->ToString()
676 << ", old_sequence: " << old_sequence->ToString()
677 << ", new_sequence: " << new_sequence->ToString();
678 }
679 }
680 }
681 MS_EXCEPTION_IF_NULL(result->abstract());
682 MS_LOG(DEBUG) << "Save result for NodeConfig: " << conf->ToString() << ", result: " << result->abstract().get() << "/"
683 << result->abstract()->ToString();
684 cache_mgr.SetValue(conf, result);
685 }
686
SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const CNodePtr & cnode,const AbstractFunctionPtr & base_func_graph_func,const AnalysisContextPtr & fg_context)687 void SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
688 const CNodePtr &cnode,
689 const AbstractFunctionPtr &base_func_graph_func,
690 const AnalysisContextPtr &fg_context) {
691 // Get the evaluator for func graph.
692 auto evaluator = engine->GetEvaluatorFor(base_func_graph_func);
693 MS_EXCEPTION_IF_NULL(evaluator);
694
695 AbstractBasePtrList args_abs_list;
696 for (std::size_t i = 1; i < cnode->size(); i++) {
697 auto config = engine->MakeConfig(cnode->input(i), fg_context, fg);
698 auto result = config->ObtainEvalResult();
699 MS_EXCEPTION_IF_NULL(result);
700 auto abs = result->abstract();
701 args_abs_list.push_back(abs);
702 }
703
704 // Check if already evaluated before.
705 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
706 auto &cache = evaluator->evaluator_cache_mgr()->GetCache();
707 auto iter = cache.find(args_abs_list);
708 if (iter != cache.end()) {
709 MS_EXCEPTION_IF_NULL(fg_context);
710 MS_LOG(DEBUG) << "Eval before, current_node: " << cnode->DebugString() << ", context: " << fg_context->ToString()
711 << ", args: " << args_abs_list;
712 // Update inputs sequence nodes info, if matched in cache.
713 for (std::size_t i = 0; i < args_abs_list.size(); ++i) {
714 auto new_sequence = dyn_cast<AbstractSequence>(args_abs_list[i]);
715 auto old_sequence = dyn_cast<AbstractSequence>(iter->first[i]);
716 if (old_sequence != nullptr && new_sequence != nullptr) {
717 MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags, old_sequence: " << old_sequence->ToString()
718 << ", new_sequence: " << new_sequence->ToString();
719 SynchronizeSequenceElementsUseFlagsRecursively(old_sequence, new_sequence);
720 MS_LOG(DEBUG) << "After synchronize sequence nodes use flags, old_sequence: " << old_sequence->ToString()
721 << ", new_sequence: " << new_sequence->ToString();
722 }
723 }
724 }
725 }
726
Eval(const AnfNodeConfigPtr & conf)727 EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
728 MS_EXCEPTION_IF_NULL(conf);
729 AnfNodePtr node = conf->node();
730 EvalResultPtr eval_result = nullptr;
731 #ifdef DEBUG
732 compute_conf_stack_.push_back(node);
733 std::ostringstream buffer;
734 buffer << "Compute Config Begin:";
735 for (auto iter : compute_conf_stack_) {
736 buffer << " -> " << iter->DebugString();
737 }
738 MS_LOG(DEBUG) << buffer.str();
739 #endif
740 MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
741 MS_EXCEPTION_IF_NULL(node);
742 if (node->abstract() != nullptr) {
743 MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
744 eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
745 } else if (node->isa<ValueNode>()) {
746 auto value_node = node->cast<ValueNodePtr>();
747 auto abstract = EvalValueNode(value_node, conf);
748 eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>());
749 } else if (node->isa<CNode>()) {
750 auto cnode = node->cast<CNodePtr>();
751 trace::TraceEvalCNodeEnter(conf);
752 MS_LOG(DEBUG) << "Begin Eval CNode: " << cnode->DebugString();
753 eval_result = EvalCNode(cnode, conf);
754 MS_LOG(DEBUG) << "End Eval CNode: " << cnode->DebugString();
755 trace::TraceEvalCNodeLeave();
756 } else {
757 MS_LOG(INTERNAL_EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString()
758 << "(type:" << node->type_name() << "), fg: "
759 << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
760 << " conf: " << conf->ToString();
761 }
762
763 #ifdef DEBUG
764 compute_conf_stack_.pop_back();
765 if (eval_result == nullptr) {
766 MS_LOG(INTERNAL_EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
767 << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
768 }
769 #endif
770 MS_EXCEPTION_IF_NULL(eval_result->abstract());
771 MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
772 return eval_result;
773 }
774
EvalValueNode(const ValueNodePtr & value_node,const AnfNodeConfigPtr & conf) const775 AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) const {
776 MS_EXCEPTION_IF_NULL(conf);
777 MS_EXCEPTION_IF_NULL(value_node);
778 auto out = ToAbstract(value_node->value(), conf->context(), conf);
779 if (value_node->has_new_value() && out->isa<AbstractTensor>()) {
780 out = out->Broaden();
781 }
782 return out;
783 }
784
// Follow the forwarding chain recorded in anfnode_config_map() starting at `conf` and return the
// last config reached. Returns `conf` itself when it has no forwarding entry.
// Raises a null-pointer exception if `conf` or any forwarded config in the chain is null.
AnfNodeConfigPtr AnalysisEngine::GetForwardConfig(const AnfNodeConfigPtr &conf) const {
  MS_EXCEPTION_IF_NULL(conf);
  AnfNodeConfigPtr new_conf = conf;
  const auto &forward_map = anfnode_config_map();
  for (auto link = forward_map.find(conf); link != forward_map.end(); link = forward_map.find(new_conf)) {
    new_conf = link->second;
    MS_EXCEPTION_IF_NULL(new_conf);
  }
  return new_conf;
}
796
InterpretedNodeCall(const CNodePtr & cnode,const AnfNodeConfigPtr & conf)797 EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
798 MS_EXCEPTION_IF_NULL(cnode);
799 if (cnode->empty()) {
800 MS_LOG(INTERNAL_EXCEPTION) << "CNode inputs should not be empty, CNode: " << cnode->DebugString();
801 }
802
803 // Check if the operator input is PyExecute CNode.
804 const auto &func_node = cnode->input(0);
805 MS_EXCEPTION_IF_NULL(func_node);
806 constexpr auto recursive_level = 2;
807 MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(recursive_level)
808 << ", func_node: " << func_node->DebugString(recursive_level);
809 auto prim = GetCNodePrimitiveWithoutDoSignature(func_node);
810 if (!IsPrimitiveEquals(prim, prim::kPrimResolve) && !IsPrimitiveEquals(prim, prim::kPrimGetAttr) &&
811 !IsPrimitiveEquals(prim, prim::kPrimPyExecute) && !IsPrimitiveEquals(prim, prim::kPrimPyInterpret)) {
812 // Optimize the performance.
813 return nullptr;
814 }
815 AnfNodeConfigPtr func_conf = MakeConfig(func_node, conf->context(), conf->func_graph());
816 MS_EXCEPTION_IF_NULL(func_conf);
817 const auto &forwarded_conf = GetForwardConfig(func_conf);
818 if (!IsPrimitiveCNode(forwarded_conf->node(), prim::kPrimPyExecute) &&
819 !IsPrimitiveCNode(forwarded_conf->node(), prim::kPrimPyInterpret)) {
820 return nullptr;
821 }
822
823 if (IsPrimitiveEquals(prim, prim::kPrimResolve)) {
824 return ConvertToPyInterpretCall(cnode, conf, forwarded_conf->node());
825 }
826 // Forward getattr CNode call to PyInterpreted CNode.
827 return ConvertToPyInterpretCall(cnode, conf);
828 }
829
GetCNodeOperatorAbstract(const CNodePtr & cnode,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)830 AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
831 const FuncGraphPtr &func_graph) {
832 MS_EXCEPTION_IF_NULL(cnode);
833 if (cnode->empty()) {
834 MS_LOG(INTERNAL_EXCEPTION) << "CNode inputs should not be empty, CNode: " << cnode->DebugString();
835 }
836 auto &func_node = cnode->input(0);
837 MS_EXCEPTION_IF_NULL(func_node);
838 MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
839 AnfNodeConfigPtr func_conf = MakeConfig(func_node, context, func_graph);
840 MS_EXCEPTION_IF_NULL(func_conf);
841 // Keep it in a local variable, otherwise smart pointer will free it.
842 auto possible_func_eval_result = func_conf->ObtainEvalResult();
843 MS_EXCEPTION_IF_NULL(possible_func_eval_result);
844 auto &possible_func = possible_func_eval_result->abstract();
845 if (possible_func == nullptr) {
846 MS_LOG(INTERNAL_EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString();
847 }
848 return possible_func;
849 }
850
ConvertClassTypeToFunc(const CNodePtr & cnode,const AbstractBasePtr & abs,const AnfNodeConfigPtr & conf)851 EvalResultPtr AnalysisEngine::ConvertClassTypeToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs,
852 const AnfNodeConfigPtr &conf) {
853 MS_EXCEPTION_IF_NULL(cnode);
854 const auto inputs_size = cnode->size();
855 AbstractBasePtrList input_abs;
856 input_abs.reserve(inputs_size - 1);
857 for (std::size_t i = 1; i < inputs_size; ++i) {
858 const AnfNodePtr &node = cnode->input(i);
859 auto cur_config = MakeConfig(node, conf->context(), conf->func_graph());
860 const auto &cur_eval_result = cur_config->ObtainEvalResult();
861 MS_EXCEPTION_IF_NULL(cur_eval_result);
862 auto cur_abs = cur_eval_result->abstract();
863 MS_EXCEPTION_IF_NULL(cur_abs);
864 input_abs.push_back(cur_abs);
865 }
866 bool has_non_graph_input = std::any_of(input_abs.begin(), input_abs.end(), [](const AbstractBasePtr &abs) {
867 MS_EXCEPTION_IF_NULL(abs);
868 return abs->isa<abstract::AbstractAny>() || abs->BuildValue()->isa<parse::InterpretedObject>();
869 });
870 if (has_non_graph_input) {
871 return ConvertToPyInterpretCall(cnode, conf);
872 }
873 MS_EXCEPTION_IF_NULL(abs);
874 auto val = abs->BuildValue();
875 MS_EXCEPTION_IF_NULL(val);
876 auto class_val = dyn_cast_ptr<parse::ClassType>(val);
877 MS_EXCEPTION_IF_NULL(class_val);
878 const auto &class_name = class_val->name();
879 std::vector<AnfNodePtr> new_cnode_inputs;
880 auto fg = cnode->func_graph();
881 MS_EXCEPTION_IF_NULL(fg);
882
883 std::map<std::string, ValueNodePtr> list_or_tuple_func_map = {
884 {"class 'list'", NewValueNode(std::make_shared<prim::ListFunc>("list_func"))},
885 {"class 'tuple'", NewValueNode(std::make_shared<prim::TupleFunc>("tuple_func"))}};
886 auto iter = list_or_tuple_func_map.find(class_name);
887 if (iter != list_or_tuple_func_map.end()) {
888 (void)new_cnode_inputs.emplace_back(iter->second);
889 } else {
890 auto class_obj = class_val->obj();
891 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
892 auto py_fn =
893 python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name), class_obj);
894 if (py::isinstance<py::none>(py_fn)) {
895 return ConvertToPyInterpretCall(cnode, conf);
896 }
897 auto func_fg = parse::ParsePythonCode(py_fn);
898 MS_EXCEPTION_IF_NULL(func_fg);
899 func_fg->set_manager(fg->manager());
900 (void)new_cnode_inputs.emplace_back(NewValueNode(func_fg));
901 }
902
903 for (std::size_t i = 1; i < cnode->size(); ++i) {
904 (void)new_cnode_inputs.emplace_back(cnode->input(i));
905 }
906 auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
907 new_cnode->set_debug_info(cnode->debug_info());
908 AnalysisEnginePtr eng = conf->engine();
909 MS_EXCEPTION_IF_NULL(eng);
910 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
911 return eng->ForwardConfig(conf, fn_conf);
912 }
913
EvalCNode(const CNodePtr & cnode,const AnfNodeConfigPtr & conf)914 EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
915 MS_EXCEPTION_IF_NULL(conf);
916 MS_EXCEPTION_IF_NULL(cnode);
917
918 // Handle the interpreted node call here.
919 const auto &interpreted_eval_result = InterpretedNodeCall(cnode, conf);
920 if (interpreted_eval_result != nullptr) {
921 return interpreted_eval_result;
922 }
923
924 AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode, conf->context(), conf->func_graph());
925 MS_EXCEPTION_IF_NULL(possible_func->BuildType());
926 if (possible_func->IsSameTypeId(AbstractUndetermined::kTypeId)) {
927 MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
928 return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
929 }
930
931 if (possible_func->isa<AbstractClass>()) {
932 return ConvertMsClassObjToFunc(cnode, possible_func, conf);
933 }
934 if (possible_func->isa<AbstractScalar>()) {
935 // Convert class to function, such as list(xxx).
936 auto val = possible_func->BuildValue();
937 MS_EXCEPTION_IF_NULL(val);
938 if (val->isa<parse::ClassType>()) {
939 return ConvertClassTypeToFunc(cnode, possible_func, conf);
940 }
941 if (val->isa<parse::InterpretedObject>()) {
942 return ConvertCallPyObjCallFunc(cnode, possible_func, conf);
943 }
944 }
945
946 if (possible_func->isa<AbstractAny>()) {
947 return ConvertToPyInterpretCall(cnode, conf);
948 }
949
950 if (possible_func->isa<PrimitiveAbstractClosure>()) {
951 auto value = GetValueWithoutDoSignature(possible_func->cast<PrimitiveAbstractClosurePtr>()->prim());
952 auto new_cnode = InsertUnpackGraph(cnode, value, conf, shared_from_this());
953 if (new_cnode != nullptr) {
954 AnalysisEnginePtr eng = conf->engine();
955 MS_EXCEPTION_IF_NULL(eng);
956 AnfNodeConfigPtr new_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
957 return eng->ForwardConfig(conf, new_conf);
958 }
959 }
960
961 auto func = dyn_cast_ptr<AbstractFunction>(possible_func);
962 if (func == nullptr) {
963 MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
964 MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
965 MS_EXCEPTION(ValueError) << "The object is not callable. Please check code.";
966 }
967
968 // Make arguments config list.
969 bool contains_side_effect = false;
970 const auto inputs_size = cnode->size();
971 ConfigPtrList args_conf_list;
972 args_conf_list.reserve(inputs_size - 1);
973 // Ignore the first node which is function name.
974 for (std::size_t i = 1; i < inputs_size; ++i) {
975 const AnfNodePtr &node = cnode->input(i);
976 (void)args_conf_list.emplace_back(MakeConfig(node, conf->context(), conf->func_graph()));
977 if (check_side_effect()) {
978 auto input_cnode = dyn_cast_ptr<CNode>(node);
979 if (input_cnode != nullptr) {
980 contains_side_effect = contains_side_effect || input_cnode->has_side_effect_node();
981 }
982 }
983 }
984
985 // Find evaluators.
986 std::vector<EvaluatorPtr> evaluators;
987 func->Visit([this, &evaluators, &cnode](const AbstractFuncAtomPtr &possible_func) {
988 const auto &real_func_atom = GetRealFuncAtom(possible_func);
989 auto evaluator = this->GetEvaluatorFor(real_func_atom);
990 evaluator->set_bound_node(cnode);
991 (void)evaluators.emplace_back(std::move(evaluator));
992 });
993
994 // Run evaluators.
995 auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
996 // Check if func graph contains isolated side-effect, and sync.
997 if (check_side_effect()) {
998 func->Visit([&contains_side_effect](const AbstractFuncAtomPtr &possible_func) {
999 const auto &real_func_atom = GetRealFuncAtom(possible_func);
1000 bool func_has_side_effect = CheckFuncSideEffect(real_func_atom);
1001 if (func_has_side_effect) {
1002 contains_side_effect = true;
1003 }
1004 });
1005 if (contains_side_effect) {
1006 MS_EXCEPTION_IF_NULL(conf->func_graph());
1007 MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString()
1008 << ", func_graph: " << conf->func_graph()->ToString();
1009 cnode->set_has_side_effect_node(true);
1010 conf->func_graph()->set_has_side_effect_node(true);
1011 eval_result->set_has_side_effect_node(true);
1012 }
1013 }
1014 return eval_result;
1015 }
1016
Execute(const AbstractFunctionPtr & func,const AbstractBasePtrList & args_abs_list)1017 EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_abs_list) {
1018 MS_EXCEPTION_IF_NULL(func);
1019 ConfigPtrList args_conf_list;
1020 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(args_conf_list),
1021 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
1022 std::vector<EvaluatorPtr> infs;
1023 MS_EXCEPTION_IF_NULL(func);
1024 auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
1025 auto evaluator = this->GetEvaluatorFor(poss);
1026 infs.push_back(evaluator);
1027 };
1028 func->Visit(build_evaluator);
1029 return ExecuteEvaluators(infs, nullptr, args_conf_list);
1030 }
1031
ClearEvaluatorCache()1032 void AnalysisEngine::ClearEvaluatorCache() {
1033 py::gil_scoped_acquire gil;
1034 for (auto &element : evaluators_) {
1035 EvaluatorPtr evaluator = element.second;
1036 if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
1037 continue;
1038 }
1039 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
1040 evaluator->evaluator_cache_mgr()->Clear();
1041 }
1042 for (auto &element : prim_constructors_) {
1043 EvaluatorPtr evaluator = element.second;
1044 if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
1045 continue;
1046 }
1047 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
1048 evaluator->evaluator_cache_mgr()->Clear();
1049 }
1050 for (auto &element : prim_py_evaluators_) {
1051 EvaluatorPtr evaluator = element.second;
1052 if (evaluator == nullptr || evaluator->evaluator_cache_mgr() == nullptr) {
1053 continue;
1054 }
1055 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
1056 evaluator->evaluator_cache_mgr()->Clear();
1057 }
1058 // Release exception to avoid hup at exit.
1059 StaticAnalysisException::Instance().ClearException();
1060 // Reset the EnvironGet sparse option.
1061 EnvSetSparseResultMgr::GetInstance().Set(false);
1062 }
1063
Clear()1064 void AnalysisEngine::Clear() {
1065 AnalysisResultCacheMgr::GetInstance().Clear();
1066 anfnode_config_map_.clear();
1067 eval_trace_.clear();
1068 evaluators_.clear();
1069 prim_py_evaluators_.clear();
1070 constructors_app_.clear();
1071 continued_evals_.clear();
1072 root_context_ = nullptr;
1073 }
1074
GetPyEvaluator(const PrimitivePtr & prim,const AnalysisEnginePtr & engine)1075 EvaluatorPtr GetPyEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
1076 auto prim_py = dyn_cast<PrimitivePy>(prim);
1077 if (prim_py != nullptr) {
1078 auto is_constexpr = prim_py->HasAttr(GRAPH_FLAG_CONSTEXPR_PRIM);
1079 if (is_constexpr) {
1080 return std::make_shared<ConstexprEvaluator>(prim_py);
1081 }
1082 if (engine == nullptr) {
1083 return std::make_shared<PythonPrimEvaluator>(prim_py);
1084 }
1085
1086 const auto &iter = engine->prim_py_evaluators_.find(prim_py);
1087 if (iter != engine->prim_py_evaluators_.end()) {
1088 return iter->second;
1089 }
1090 auto evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
1091 engine->prim_py_evaluators_[prim_py] = evaluator;
1092 return evaluator;
1093 }
1094 MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
1095 return nullptr;
1096 }
1097
GetStandardPrimEvaluator(const PrimitivePtr & prim)1098 inline StandardPrimEvaluatorPtr GetStandardPrimEvaluator(const PrimitivePtr &prim) {
1099 auto eval_impl_opt = GetFrontendPrimitiveInferImpl(prim);
1100 if (eval_impl_opt.has_value()) {
1101 // Find prim infer function in the prim function map return a standard evaluator
1102 auto eval_impl = eval_impl_opt.value();
1103 if (eval_impl.IsImplInferShapeAndType() && !IsPrimitiveEquals(prim, prim::kPrimMakeTuple) &&
1104 !IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
1105 return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
1106 }
1107 }
1108
1109 return nullptr;
1110 }
1111
GetPrimEvaluator(const PrimitivePtr & prim,const AnalysisEnginePtr & engine)1112 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
1113 // Custom Primitive with python infer_shape, infer_type
1114 MS_EXCEPTION_IF_NULL(prim);
1115 if (prim->isa<prim::DoSignaturePrimitive>()) {
1116 return std::make_shared<DoSignatureEvaluator>(prim);
1117 }
1118 if (prim->isa<prim::UnpackGraphPrimitive>()) {
1119 return std::make_shared<UnpackGraphEvaluator>(prim);
1120 }
1121 if (IsPrimitiveEquals(prim, prim::kPrimMixedPrecisionCast)) {
1122 return std::make_shared<MixedPrecisionCastEvaluator>(prim);
1123 }
1124 if (IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
1125 return std::make_shared<PyExecuteEvaluator>();
1126 }
1127 static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
1128 if (enable_pre_lift && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
1129 return std::make_shared<SwitchEvaluator>();
1130 }
1131
1132 if (prim->isa<prim::DoTransPrimitiveFunction>()) {
1133 return std::make_shared<DoTransPrimitiveFunctionEvaluator>(prim);
1134 }
1135 // Primitive is defined in OpTable.
1136 if (mindspore::ops::IsPrimitiveFunction(prim->name())) {
1137 if (prim->isa<PrimitivePy>()) {
1138 return std::make_shared<PrimitiveArgsToInputsEvaluator>(prim);
1139 }
1140 return std::make_shared<PrimitiveFunctionEvaluator>(prim);
1141 }
1142
1143 auto standard_evaluator = GetStandardPrimEvaluator(prim);
1144 if (standard_evaluator != nullptr) {
1145 return standard_evaluator;
1146 }
1147
1148 // Use python infer function if the infer function not founded in the map return a python evaluator
1149 EvaluatorPtr evaluator = nullptr;
1150 if (prim->HasPyEvaluator()) {
1151 return GetPyEvaluator(prim, engine);
1152 }
1153
1154 // Delete this when the infer value can be mapped to the CPU backend operator.
1155 if (PrimNeedFrontendInferValue(prim)) {
1156 return nullptr;
1157 }
1158
1159 // Return a default evaluator
1160 if (engine == nullptr) {
1161 // If engine is nullptr, get constructor from default.
1162 const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
1163 auto iter = prim_evaluator_map.find(prim);
1164 if (iter != prim_evaluator_map.end()) {
1165 evaluator = iter->second;
1166 }
1167 } else {
1168 // If engine is given, get constructor from engine resource.
1169 const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
1170 auto iter = prim_evaluator_map.find(prim);
1171 if (iter != prim_evaluator_map.end()) {
1172 evaluator = iter->second;
1173 }
1174 }
1175
1176 if (evaluator == nullptr) {
1177 MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
1178 }
1179 return evaluator;
1180 }
1181
_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> & func)1182 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
1183 MS_EXCEPTION_IF_NULL(func);
1184 const auto &primitive = func->prim();
1185 if (func->tracking_id() == 0) {
1186 // Create primitive evaluator if tracking_id == 0.
1187 auto [iter, is_new] = evaluators_.emplace(func, nullptr);
1188
1189 if (is_new) {
1190 iter->second = GetPrimEvaluator(primitive, shared_from_this());
1191 if (iter->second == nullptr) {
1192 MS_LOG(EXCEPTION) << "Operator '" << primitive->name()
1193 << "' is invalid, or no matching evaluator could be found.";
1194 }
1195 }
1196 return iter->second;
1197 }
1198 // Use TrackedEvaluator if tracking_id != 0.
1199 auto iter = evaluators_.find(func);
1200 if (iter != evaluators_.end()) {
1201 return iter->second;
1202 }
1203 auto prim_without_tracking_id = std::make_shared<PrimitiveAbstractClosure>(primitive, 0);
1204 EvaluatorPtr prim_evaluator = _GetEvaluatorFor(prim_without_tracking_id);
1205 static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
1206 if (enable_pre_lift && IsPrimitiveEquals(primitive, prim::kPrimSwitch)) {
1207 auto result = evaluators_.emplace(func, prim_evaluator);
1208 return result.first->second;
1209 } else {
1210 auto tracked_evaluator = std::make_shared<TrackedEvaluator>(prim_evaluator);
1211 auto result = evaluators_.emplace(func, std::move(tracked_evaluator));
1212 return result.first->second;
1213 }
1214 }
1215
_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> & func)1216 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
1217 MS_EXCEPTION_IF_NULL(func);
1218 auto [iter, is_new] = evaluators_.emplace(func, nullptr);
1219 if (is_new) {
1220 iter->second = std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
1221 }
1222 return iter->second;
1223 }
1224
_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> & func)1225 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
1226 MS_EXCEPTION_IF_NULL(func);
1227 auto [iter, is_new] = evaluators_.emplace(func, nullptr);
1228 if (is_new) {
1229 iter->second = std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->GetScope());
1230 }
1231 return iter->second;
1232 }
1233
_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> & func)1234 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
1235 MS_EXCEPTION_IF_NULL(func);
1236 const auto &primal_func = func->fn();
1237 auto primal_evaluator = GetEvaluatorFor(primal_func);
1238 return std::make_shared<JEvaluator>(primal_evaluator, primal_func);
1239 }
1240
_GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> & func)1241 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &func) {
1242 MS_EXCEPTION_IF_NULL(func);
1243 const auto &primal_func = func->fn();
1244 const auto &in_axes = func->in_axes();
1245 const auto &out_axes = func->out_axes();
1246 size_t cell_size = func->cell_size();
1247 auto primal_evaluator = GetEvaluatorFor(primal_func);
1248 return std::make_shared<VmapEvaluator>(primal_evaluator, primal_func, in_axes, out_axes, cell_size);
1249 }
1250
_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> & func)1251 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func) {
1252 MS_EXCEPTION_IF_NULL(func);
1253 const auto &primal_func = func->fn();
1254 auto primal_evaluator = GetEvaluatorFor(primal_func);
1255 return std::make_shared<TaylorEvaluator>(primal_evaluator, primal_func);
1256 }
1257
_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> & func)1258 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
1259 MS_EXCEPTION_IF_NULL(func);
1260 const auto &primal_func = func->fn();
1261 auto primal_evaluator = GetEvaluatorFor(primal_func);
1262 return std::make_shared<ShardEvaluator>(primal_evaluator, primal_func);
1263 }
1264
_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> & func)1265 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
1266 MS_EXCEPTION_IF_NULL(func);
1267 return std::make_shared<VirtualEvaluator>(func->args_abs_list(), func->output());
1268 }
1269
_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> & func)1270 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
1271 MS_EXCEPTION_IF_NULL(func);
1272 auto primal_func = func->fn();
1273 auto part_pair = std::make_pair(primal_func, func->args());
1274 auto iter = constructors_app_.find(part_pair);
1275 if (iter != constructors_app_.end()) {
1276 return iter->second;
1277 }
1278 EvaluatorPtr partial_evaluator = nullptr;
1279 if (func->need_append_to_end()) {
1280 partial_evaluator = std::make_shared<PartialToEndEvaluator>(primal_func);
1281 } else {
1282 auto primal_evaluator = GetEvaluatorFor(primal_func);
1283 partial_evaluator = std::make_shared<PartialAppEvaluator>(primal_evaluator, func->args());
1284 }
1285 auto result = constructors_app_.emplace(std::move(part_pair), std::move(partial_evaluator));
1286 return result.first->second;
1287 }
1288
GetEvaluatorFor(const AbstractFunctionPtr & func)1289 EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
1290 MS_EXCEPTION_IF_NULL(func);
1291 MS_LOG(DEBUG) << "GetEvaluatorFor: " << func->ToString() << " tracking_id: " << func->tracking_id();
1292
1293 if (func->isa<PrimitiveAbstractClosure>()) {
1294 return _GetEvaluatorFor(std::static_pointer_cast<PrimitiveAbstractClosure>(func));
1295 }
1296 if (func->isa<FuncGraphAbstractClosure>()) {
1297 return _GetEvaluatorFor(std::static_pointer_cast<FuncGraphAbstractClosure>(func));
1298 }
1299 if (func->isa<MetaFuncGraphAbstractClosure>()) {
1300 return _GetEvaluatorFor(std::static_pointer_cast<MetaFuncGraphAbstractClosure>(func));
1301 }
1302 if (func->isa<JTransformedAbstractClosure>()) {
1303 return _GetEvaluatorFor(std::static_pointer_cast<JTransformedAbstractClosure>(func));
1304 }
1305 if (func->isa<VmapTransformedAbstractClosure>()) {
1306 return _GetEvaluatorFor(std::static_pointer_cast<VmapTransformedAbstractClosure>(func));
1307 }
1308 if (func->isa<TaylorTransformedAbstractClosure>()) {
1309 return _GetEvaluatorFor(std::static_pointer_cast<TaylorTransformedAbstractClosure>(func));
1310 }
1311 if (func->isa<ShardTransformedAbstractClosure>()) {
1312 return _GetEvaluatorFor(std::static_pointer_cast<ShardTransformedAbstractClosure>(func));
1313 }
1314 if (func->isa<VirtualAbstractClosure>()) {
1315 return _GetEvaluatorFor(std::static_pointer_cast<VirtualAbstractClosure>(func));
1316 }
1317 if (func->isa<PartialAbstractClosure>()) {
1318 return _GetEvaluatorFor(std::static_pointer_cast<PartialAbstractClosure>(func));
1319 }
1320
1321 MS_LOG(INTERNAL_EXCEPTION) << "Cannot GetEvaluator from " << func->type_name();
1322 }
1323
ForwardConfig(const AnfNodeConfigPtr & orig_conf,const AnfNodeConfigPtr new_conf)1324 EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
1325 MS_EXCEPTION_IF_NULL(orig_conf);
1326 MS_EXCEPTION_IF_NULL(new_conf);
1327 // If always_eval_flag is true in BaseFuncGraphEvaluaotr, then the CNode with same orig_conf may be forwarded
1328 // again, so update the config_map with new_conf;
1329 anfnode_config_map_[orig_conf] = new_conf;
1330 MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString();
1331 MS_EXCEPTION_IF_NULL(orig_conf->node());
1332 MS_EXCEPTION_IF_NULL(new_conf->node());
1333 auto old_cnode = orig_conf->node()->cast_ptr<CNode>();
1334 auto new_cnode = new_conf->node()->cast<CNodePtr>();
1335 if (old_cnode != nullptr && new_cnode != nullptr) {
1336 if (old_cnode->func_graph() == new_cnode->func_graph()) {
1337 MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->DebugString()
1338 << ", as origin node should be in order list, origin_node: " << old_cnode->DebugString();
1339 old_cnode->func_graph()->EraseUnusedNodeInOrder(new_cnode);
1340 } else {
1341 MS_LOG(INTERNAL_EXCEPTION) << "Forward orig_node to different func_graph, old_node: " << old_cnode->DebugString()
1342 << ", new_node: " << new_cnode->DebugString();
1343 }
1344 }
1345 (void)forward_count_++;
1346 auto res = ObtainEvalResultWithCache(new_conf);
1347 (void)forward_count_--;
1348 return res;
1349 }
1350
ExecuteEvaluators(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)1351 EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
1352 const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
1353 if (evaluators.size() == 1) {
1354 auto &eval = evaluators[0];
1355 MS_EXCEPTION_IF_NULL(eval);
1356 return eval->Run(shared_from_this(), args_conf_list, out_conf);
1357 }
1358 static const bool enable_single_thread = (common::GetCompileConfig("SINGLE_EVAL") == "1");
1359 if (enable_single_thread) {
1360 return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
1361 }
1362 return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
1363 }
1364
SetUndeterminedFlag(const std::string & thread_id,const FuncGraph & fg)1365 void AnalysisEngine::SetUndeterminedFlag(const std::string &thread_id, const FuncGraph &fg) {
1366 static std::mutex fg_lock;
1367 std::lock_guard<std::mutex> infer_lock(fg_lock);
1368 MS_LOG(DEBUG) << "Record undetermined flag of fg:" << fg.ToString() << ", thread id:" << thread_id;
1369 func_graph_undetermined_flags_[&fg].push_front(thread_id);
1370 }
1371
SetIgnoreValueFlag(const std::string & thread_id,FuncGraph * fg)1372 void AnalysisEngine::SetIgnoreValueFlag(const std::string &thread_id, FuncGraph *fg) {
1373 MS_EXCEPTION_IF_NULL(fg);
1374 auto it = func_graph_undetermined_flags_.find(fg);
1375 if (it == func_graph_undetermined_flags_.cend()) {
1376 return;
1377 }
1378 for (const auto &id : it->second) {
1379 if (thread_id.find(id) != std::string::npos && thread_id != id) {
1380 MS_LOG(DEBUG) << "Set ignore value of fg:" << fg->ToString() << ", thread id:" << thread_id;
1381 fg->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
1382 return;
1383 }
1384 }
1385 }
1386
// Decide how to treat an evaluator that is already recorded in eval_trace_.
// eval_trace_ is walked from its newest record towards the record referenced by `it`, looking for the newest
// record whose evaluator is one of `evaluators`; that evaluator is the "latest entry" (`eval` when none is found).
// *continue_flag is set to true when the latest entry differs from `eval`, or when no record newer than the
// latest entry is judged undetermined; otherwise it is left false. The latest entry is always returned.
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
                                                   const EvaluatorPtr &eval, const AbstractBasePtrList &args_abs_list,
                                                   const EvalTraceRevIter &it, bool *continue_flag) {
  MS_EXCEPTION_IF_NULL(continue_flag);
  MS_EXCEPTION_IF_NULL(eval);
  *continue_flag = false;

  // Locate the latest entry function to handle nested recursion.
  EvaluatorPtr latest_entry = eval;
  auto latest_entry_iter = eval_trace_.crbegin();
  while (*latest_entry_iter != *it) {
    const auto candidate = std::find(evaluators.cbegin(), evaluators.cend(), latest_entry_iter->evaluator_);
    if (candidate != evaluators.cend()) {
      latest_entry = *candidate;
      break;
    }
    ++latest_entry_iter;
  }
  if (latest_entry != eval) {
    MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
    *continue_flag = true;
    return latest_entry;
  }

  // Collect every record that is newer than the latest entry.
  std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> undetermined_evals;
  for (auto trace_iter = eval_trace_.crbegin(); trace_iter != latest_entry_iter; ++trace_iter) {
    (void)undetermined_evals.insert(*trace_iter);
  }
  MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();

  // Check whether the sub loop has an untraced undetermined evaluator.
  bool has_undetermined = false;
  for (const auto &traced_eval : undetermined_evals) {
    MS_EXCEPTION_IF_NULL(traced_eval.evaluator_);
    MS_LOG(DEBUG) << traced_eval.evaluator_->ToString() << "check undetermined.";
    auto &alternate_evaluator = multi_poss_[traced_eval.evaluator_];
    MS_EXCEPTION_IF_NULL(alternate_evaluator);
    auto eval_cache = alternate_evaluator->evaluator_cache_mgr();
    MS_EXCEPTION_IF_NULL(eval_cache);
    const EvaluatorArgs alternate_args(alternate_evaluator, args_abs_list);
    const bool alternate_is_traced = (undetermined_evals.find(alternate_args) != undetermined_evals.cend());
    const bool already_continued = (continued_evals_.find(traced_eval) != continued_evals_.cend());
    const bool args_evaluated = (eval_cache->GetValue(args_abs_list) != nullptr);
    if (!alternate_is_traced && (!args_evaluated || !already_continued)) {
      MS_LOG(DEBUG) << traced_eval.evaluator_->ToString() << "has undetermined.";
      has_undetermined = true;
      break;
    }
  }
  if (!has_undetermined) {
    MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
    *continue_flag = true;
  }
  return latest_entry;
}
1444
GetFuncGraphFromBranchNode(const AnfNodePtr & branch_node)1445 FuncGraphPtr GetFuncGraphFromBranchNode(const AnfNodePtr &branch_node) {
1446 MS_EXCEPTION_IF_NULL(branch_node);
1447 auto fg = GetValueNode<FuncGraphPtr>(branch_node);
1448 if (fg != nullptr) {
1449 return fg;
1450 }
1451 if (IsPrimitiveCNode(branch_node, prim::kPrimPartial)) {
1452 fg = GetValueNode<FuncGraphPtr>(branch_node->cast<CNodePtr>()->input(kPartialGraphIndex));
1453 }
1454 if (fg != nullptr) {
1455 return fg;
1456 }
1457 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected branch node: " << branch_node->DebugString();
1458 }
1459
JoinBranchesFailedInfo(const AbstractBasePtr & abs,const AbstractBasePtr & last_out_abs,const AnfNodePtr & node,const std::string & error_info)1460 std::string JoinBranchesFailedInfo(const AbstractBasePtr &abs, const AbstractBasePtr &last_out_abs,
1461 const AnfNodePtr &node, const std::string &error_info) {
1462 constexpr int recursive_level = 2;
1463 std::ostringstream buffer;
1464 buffer << "Cannot join the return values of different branches, perhaps you need to make them equal.\n"
1465 << error_info
1466 << "#dmsg#Framework Error Message:#dmsg#The abstract type of the return value of the current branch is:\n"
1467 << abs->ToString() << ",\n and that of the previous branch is:\n"
1468 << last_out_abs->ToString() << ".\n"
1469 << "The node is " << node->DebugString(recursive_level);
1470 if (!node->isa<CNode>()) {
1471 buffer << "\n";
1472 return buffer.str();
1473 }
1474 auto input_node = node->cast_ptr<CNode>()->input(0);
1475 if (IsPrimitiveCNode(input_node, prim::kPrimSwitch)) {
1476 // {prim::kPrimSwitch, cond, true_branch, false_branch}
1477 const auto &cnode = input_node->cast_ptr<CNode>();
1478 auto true_out = GetFuncGraphFromBranchNode(cnode->input(kSwitchTrueBranchIndex))->get_return();
1479 auto false_out = GetFuncGraphFromBranchNode(cnode->input(kSwitchFalseBranchIndex))->get_return();
1480 buffer << ", true branch: " << cnode->input(kSwitchTrueBranchIndex)->ToString() << "\n"
1481 << trace::GetDebugInfoStr(true_out->debug_info())
1482 << "\n, false branch: " << cnode->input(kSwitchFalseBranchIndex)->ToString() << "\n"
1483 << trace::GetDebugInfoStr(false_out->debug_info());
1484 } else if (IsPrimitiveCNode(input_node, prim::kPrimSwitchLayer)) {
1485 // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
1486 constexpr int branch_index = 2;
1487 const auto &tuple_node = input_node->cast_ptr<CNode>()->input(branch_index);
1488 if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
1489 const auto &cnode = tuple_node->cast_ptr<CNode>();
1490 for (size_t i = 1; i < cnode->size(); i++) {
1491 auto out_node = GetValueNode<FuncGraphPtr>(cnode->input(i))->get_return();
1492 MS_EXCEPTION_IF_NULL(out_node);
1493 buffer << ", branch" << i << ": " << cnode->input(i)->ToString() << "\n"
1494 << trace::GetDebugInfoStr(out_node->debug_info());
1495 }
1496 }
1497 } else {
1498 buffer << trace::GetDebugInfoStr(node->debug_info());
1499 }
1500 buffer << "\n";
1501 return buffer.str();
1502 }
1503
SetUseFlagsForJoinedAny(const AbstractBasePtrList & out_abs_list)1504 void SetUseFlagsForJoinedAny(const AbstractBasePtrList &out_abs_list) {
1505 for (const auto &abs : out_abs_list) {
1506 SetSequenceElementsUseFlagsRecursively(abs, true);
1507 }
1508 }
1509
ProcessEvalResults(const AbstractBasePtrList & out_abs_list,const AnfNodePtr & node)1510 EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_abs_list, const AnfNodePtr &node) {
1511 if (out_abs_list.empty()) {
1512 MS_LOG(INTERNAL_EXCEPTION) << "There is an endless loop for evaluator.";
1513 }
1514
1515 if (out_abs_list.size() == 1) {
1516 MS_EXCEPTION_IF_NULL(out_abs_list[0]);
1517 // If only one result derived, then broaden it to avoid wrong constant propagation.
1518 return std::make_shared<EvalResult>(out_abs_list[0]->Broaden(), std::make_shared<AttrValueMap>());
1519 }
1520 MS_EXCEPTION_IF_NULL(node);
1521
1522 // Return Any if some branch returns Any.
1523 if (std::any_of(out_abs_list.cbegin(), out_abs_list.cend(), [](const AbstractBasePtr &abs) {
1524 MS_EXCEPTION_IF_NULL(abs);
1525 return abs->isa<AbstractAny>() && !abs->isa<AbstractNegligible>();
1526 })) {
1527 MS_LOG(INFO) << "The branches outputs contain Any output.\nJoin them to Any output.";
1528 return std::make_shared<EvalResult>(std::make_shared<AbstractAny>(), std::make_shared<AttrValueMap>());
1529 }
1530
1531 AbstractBasePtr last_out_abs = out_abs_list[0];
1532 MS_EXCEPTION_IF_NULL(last_out_abs);
1533 AbstractBasePtr joined_abs = out_abs_list[0];
1534 for (size_t i = 1; i < out_abs_list.size(); ++i) {
1535 const auto &abs = out_abs_list[i];
1536 MS_EXCEPTION_IF_NULL(abs);
1537 try {
1538 MS_LOG(DEBUG) << "Join node: " << node->DebugString() << ", " << joined_abs->ToString() << ", and "
1539 << abs->ToString();
1540 MS_LOG_TRY_CATCH_SCOPE;
1541 joined_abs = joined_abs->Join(abs);
1542 } catch (const py::type_error &ex) {
1543 auto error_info = ExtractLoggingInfo(ex.what());
1544 const auto info = JoinBranchesFailedInfo(abs, last_out_abs, node, error_info);
1545 MS_LOG(INFO) << info;
1546 auto joined_any = std::make_shared<AbstractJoinedAny>();
1547 joined_any->set_exception(AbstractJoinedAny::ExceptionType::kTypeError);
1548 joined_any->set_message(info);
1549 SetUseFlagsForJoinedAny(out_abs_list);
1550 return std::make_shared<EvalResult>(joined_any, std::make_shared<AttrValueMap>());
1551 } catch (const py::value_error &ex) {
1552 auto error_info = ExtractLoggingInfo(ex.what());
1553 const auto info = JoinBranchesFailedInfo(abs, last_out_abs, node, error_info);
1554 MS_LOG(INFO) << info;
1555 auto joined_any = std::make_shared<AbstractJoinedAny>();
1556 joined_any->set_exception(AbstractJoinedAny::ExceptionType::kValueError);
1557 joined_any->set_message(info);
1558 SetUseFlagsForJoinedAny(out_abs_list);
1559 return std::make_shared<EvalResult>(joined_any, std::make_shared<AttrValueMap>());
1560 } catch (const std::exception &ex) {
1561 auto error_info = ExtractLoggingInfo(ex.what());
1562 const auto info = JoinBranchesFailedInfo(abs, last_out_abs, node, error_info);
1563 MS_LOG(INFO) << info;
1564 auto joined_any = std::make_shared<AbstractJoinedAny>();
1565 joined_any->set_exception(AbstractJoinedAny::ExceptionType::kDefault);
1566 joined_any->set_message(info);
1567 // Remove it when the transform form dict to tuple is disabled in Compatible or Lax mode.
1568 if (joined_abs->isa<AbstractDictionary>()) {
1569 joined_any->set_user_data<bool>("from_dict", std::make_shared<bool>(true));
1570 }
1571 SetUseFlagsForJoinedAny(out_abs_list);
1572 return std::make_shared<EvalResult>(joined_any, std::make_shared<AttrValueMap>());
1573 }
1574 MS_EXCEPTION_IF_NULL(joined_abs);
1575 last_out_abs = abs;
1576 }
1577
1578 MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_abs->ToString();
1579 return std::make_shared<EvalResult>(joined_abs, std::make_shared<AttrValueMap>());
1580 }
1581
ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)1582 EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
1583 const AnfNodeConfigPtr &out_conf,
1584 const ConfigPtrList &args_conf_list) {
1585 MS_EXCEPTION_IF_NULL(out_conf);
1586 MS_EXCEPTION_IF_NULL(out_conf->node());
1587 MS_EXCEPTION_IF_NULL(out_conf->func_graph());
1588 // Release GIL for C++
1589 MS_LOG(DEBUG) << out_conf->func_graph()->ToString() << "_" << std::this_thread::get_id() << " begin.";
1590 py::gil_scoped_release infer_gil_release;
1591
1592 // Only one thread to run
1593 AnalysisSchedule::GetInstance().WaitForRun();
1594
1595 // Wait for the last switch node to finish.
1596 MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
1597 auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
1598 if (eval_result == nullptr) {
1599 MS_LOG(DEBUG) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
1600 AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
1601 } else {
1602 return std::make_shared<EvalResult>(eval_result, nullptr);
1603 }
1604 auto possible_parent_fg = out_conf->node()->func_graph();
1605 MS_EXCEPTION_IF_NULL(possible_parent_fg);
1606 // Eval result of the main.
1607 AsyncAbstractPtr async_result_main = std::make_shared<AsyncAbstract>();
1608 if (possible_parent_fg->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
1609 async_result_main->set_ignore_value(true);
1610 }
1611 // Eval result of the branches
1612 std::vector<AsyncAbstractPtr> async_result_branches;
1613 SetUndeterminedFlag(AnalysisSchedule::thread_id(), *possible_parent_fg);
1614 for (auto &evaluator : evaluators) {
1615 static std::atomic<int> id_count{0};
1616 std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
1617 MS_EXCEPTION_IF_NULL(evaluator);
1618 AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>(async_result_main);
1619 // Control the order to run.
1620 AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
1621 control_run_order->set_result(std::make_shared<AbstractScalar>(1));
1622 AsyncInferTaskPtr async_task = AsyncInferTask::MakeShared(control_run_order, thread_id);
1623 AnalysisSchedule::GetInstance().IncreaseThreadCount();
1624 MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
1625 auto thread = std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, thread_id,
1626 async_result_branch, async_result_main, async_task, trace::GetCurrentGraphEvalStack(),
1627 trace::GetCNodeDebugStack());
1628 thread.detach();
1629
1630 // Push to list of running loop
1631 MS_LOG(DEBUG) << "Add to schedule: " << async_task.get();
1632 AnalysisSchedule::GetInstance().Add2Schedule(async_task); // Activate order witch child thread.
1633 (void)async_result_branches.emplace_back(std::move(async_result_branch));
1634 }
1635
1636 size_t len = evaluators.size();
1637 size_t min_size = 2;
1638 if (len < min_size) {
1639 MS_LOG(EXCEPTION) << "There are at least 2 evaluators in multi thread, but got " << len << " evaluator.";
1640 }
1641
1642 MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
1643 << " or " << evaluators[1]->ToString() << "...";
1644
1645 auto first_result = async_result_main->GetResult();
1646 MS_EXCEPTION_IF_NULL(first_result);
1647 MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
1648 << first_result->ToString();
1649
1650 AbstractBasePtrList out_abs_list;
1651 if (NeedWaitForBranches(first_result)) {
1652 MS_LOG(DEBUG) << GetInferThread() << " BuildPossibleSpecs.";
1653 BuildPossibleSpecs(first_result, async_result_branches, &out_abs_list);
1654 } else {
1655 for (size_t i = 0; i < len; ++i) {
1656 AbstractBasePtr result;
1657 MS_EXCEPTION_IF_NULL(async_result_branches[i]);
1658 if (enable_waiting_branch_eval()) {
1659 // wait to get the result of branch.
1660 result = async_result_branches[i]->GetResult();
1661 } else {
1662 // Not wait to get the result of branch.
1663 result = async_result_branches[i]->TryGetResult();
1664 }
1665
1666 if (result) {
1667 MS_EXCEPTION_IF_NULL(evaluators[i]);
1668 MS_EXCEPTION_IF_NULL(result);
1669 MS_LOG(DEBUG) << "#" << i << ": " << GetInferThread() << " async get " << evaluators[i]->ToString()
1670 << ", result: " << result->ToString() << ", args: " << args_conf_list;
1671 out_abs_list.push_back(result);
1672 }
1673 }
1674 }
1675 MS_LOG(DEBUG) << GetInferThread() << " finish.";
1676 const auto &processed_result = ProcessEvalResults(out_abs_list, out_conf->node());
1677 if (processed_result != nullptr) {
1678 // This is the final switch()() value.
1679 AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, processed_result->abstract());
1680 }
1681 MS_LOG(DEBUG) << GetInferThread() << " join finish.";
1682 return processed_result;
1683 }
1684
ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)1685 EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
1686 const AnfNodeConfigPtr &out_conf,
1687 const ConfigPtrList &args_conf_list) {
1688 AbstractBasePtrList out_abs_list;
1689 const size_t evaluators_size = 2;
1690 if (evaluators.size() < evaluators_size) {
1691 MS_LOG(INTERNAL_EXCEPTION) << "Evaluators size is less than 2.";
1692 }
1693 multi_poss_[evaluators[0]] = evaluators[1];
1694 multi_poss_[evaluators[1]] = evaluators[0];
1695 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
1696 MS_EXCEPTION_IF_NULL(out_conf);
1697 MS_EXCEPTION_IF_NULL(out_conf->node());
1698 auto possible_parent_fg = out_conf->node()->func_graph();
1699 MS_EXCEPTION_IF_NULL(possible_parent_fg);
1700 possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
1701 MS_LOG(DEBUG) << "Set graph undetermined flag for " << possible_parent_fg->ToString();
1702 for (const auto &eval : evaluators) {
1703 MS_EXCEPTION_IF_NULL(eval);
1704 const auto current_inf = EvaluatorArgs(eval, args_abs_list);
1705 MS_LOG(DEBUG) << "Check evaluator " << eval->ToString();
1706 // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
1707 auto it = std::find(eval_trace_.crbegin(), eval_trace_.crend(), current_inf);
1708 if (it == eval_trace_.crend()) {
1709 eval_trace_.push_back(current_inf);
1710 auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
1711 MS_EXCEPTION_IF_NULL(eval_result);
1712 auto eval_abstract = eval_result->abstract();
1713 MS_EXCEPTION_IF_NULL(eval_abstract);
1714
1715 out_abs_list.push_back(eval_abstract);
1716 eval_trace_.pop_back();
1717 if (eval_trace_.empty()) {
1718 multi_poss_.clear();
1719 }
1720 } else {
1721 bool continue_flag = false;
1722 auto latest_entry = HandleNestedRecursion(evaluators, eval, args_abs_list, it, &continue_flag);
1723 if (continue_flag) {
1724 MS_EXCEPTION_IF_NULL(current_inf.evaluator_);
1725 MS_LOG(DEBUG) << "The continued_evals_ insert " << current_inf.evaluator_.get() << "/"
1726 << current_inf.evaluator_->ToString();
1727 continued_evals_.insert(current_inf);
1728 continue;
1729 }
1730
1731 // Try to travel the latest undetermined.
1732 if (latest_entry != eval_trace_.rbegin()->evaluator_) {
1733 MS_LOG(DEBUG) << "Direct run evaluator " << eval.get() << "/" << eval->ToString();
1734 auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
1735 MS_EXCEPTION_IF_NULL(eval_result);
1736 MS_EXCEPTION_IF_NULL(eval_result->abstract());
1737 MS_LOG(DEBUG) << "End direct evaluator " << latest_entry->ToString()
1738 << ", return out_abs: " << eval_result->abstract()->ToString();
1739 possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, false);
1740 return eval_result;
1741 }
1742 }
1743 }
1744 possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, false);
1745 return ProcessEvalResults(out_abs_list, out_conf->node());
1746 }
1747
ObtainEvalResult()1748 EvalResultPtr AnfNodeConfig::ObtainEvalResult() {
1749 AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
1750 return engine_.lock()->ObtainEvalResultWithCache(self);
1751 }
1752
MakeAbstractClosure(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context,const AnfNodePtr & anf_node)1753 AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
1754 const AnfNodePtr &anf_node) {
1755 AnalysisContextPtr temp_context = context;
1756 if (temp_context == nullptr) {
1757 temp_context = AnalysisContext::DummyContext();
1758 }
1759 return std::make_shared<FuncGraphAbstractClosure>(func_graph, temp_context, anf_node);
1760 }
1761
MakeAbstractClosure(const MetaFuncGraphPtr & meta_func_graph,const AnfNodePtr & anf_node)1762 AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
1763 MetaFuncGraphAbstractClosurePtr meta_func_graph_fn;
1764 if (anf_node == nullptr) {
1765 meta_func_graph_fn = std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph);
1766 } else {
1767 meta_func_graph_fn = std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope());
1768 }
1769 return meta_func_graph_fn;
1770 }
1771
MakeAbstractClosure(const PrimitivePtr & primitive,const AnfNodePtr & anf_node)1772 AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, const AnfNodePtr &anf_node) {
1773 auto prim_func = std::make_shared<PrimitiveAbstractClosure>(primitive, anf_node);
1774 return prim_func;
1775 }
1776
ToAbstract(const ValuePtr & value,const AnalysisContextPtr & context,const AnfNodeConfigPtr & conf)1777 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
1778 MS_EXCEPTION_IF_NULL(value);
1779 AnfNodePtr anf_node = nullptr;
1780 if (conf != nullptr) {
1781 anf_node = conf->node();
1782 }
1783 if (value->isa<Primitive>()) {
1784 auto prim = value->cast<PrimitivePtr>();
1785 return MakeAbstractClosure(prim, anf_node);
1786 }
1787 if (value->isa<FuncGraph>()) {
1788 auto func_graph = value->cast<FuncGraphPtr>();
1789 return MakeAbstractClosure(func_graph, context, anf_node);
1790 }
1791 if (value->isa<MetaFuncGraph>()) {
1792 auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
1793 return MakeAbstractClosure(meta_func_graph, anf_node);
1794 }
1795 if (value->isa<ValueSequence>() && anf_node != nullptr) {
1796 auto abs = value->ToAbstract();
1797 MS_EXCEPTION_IF_NULL(abs);
1798 // Attach corresponding python sequence object to AbstractSequence.
1799 py::object py_list_obj =
1800 fallback::HasPyObjectInNode(anf_node) ? fallback::GetPyObjectFromNode(anf_node) : ValueToPyData(value);
1801 fallback::AttachPyObjToAbs(abs, py_list_obj, !fallback::HasPyObjectInNode(anf_node));
1802 MS_LOG(DEBUG) << "Attach python list object " << fallback::GetPyObjectPtrStr(py_list_obj)
1803 << " to new abstract: " << abs->ToString();
1804 // Set sequence node for new AbstractSequence.
1805 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1806 if (enable_eliminate_unused_element) {
1807 auto sequence_abs = abs->cast<AbstractSequencePtr>();
1808 MS_EXCEPTION_IF_NULL(sequence_abs);
1809 SetSequenceNodeElementsUseFlags(anf_node, std::make_shared<std::vector<bool>>(sequence_abs->elements().size()));
1810 std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
1811 (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(anf_node));
1812 sequence_abs->set_sequence_nodes(sequence_nodes);
1813 }
1814 return abs;
1815 }
1816 if (value->isa<ValueDictionary>() && anf_node != nullptr) {
1817 auto abs = value->ToAbstract();
1818 MS_EXCEPTION_IF_NULL(abs);
1819 // Attach corresponding python dictionary object to AbstractDictionary.
1820 py::object py_dict_obj =
1821 fallback::HasPyObjectInNode(anf_node) ? fallback::GetPyObjectFromNode(anf_node) : fallback::GeneratePyObj(abs);
1822 fallback::AttachPyObjToAbs(abs, py_dict_obj, !fallback::HasPyObjectInNode(anf_node));
1823 MS_LOG(DEBUG) << "Attach python dict object " << fallback::GetPyObjectPtrStr(py_dict_obj)
1824 << " to new abstract: " << abs->ToString();
1825 return abs;
1826 }
1827 return value->ToAbstract();
1828 }
1829
FromValueInside(const ValuePtr & value,bool broaden)1830 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
1831 AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
1832 if (broaden) {
1833 a = a->Broaden();
1834 }
1835 return a;
1836 }
1837
EvalOnePrim(const PrimitivePtr & primitive,const AbstractBasePtrList & arg_specs)1838 EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
1839 auto evaluator = GetPrimEvaluator(primitive, nullptr);
1840 if (evaluator == nullptr) {
1841 MS_LOG(ERROR) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
1842 return nullptr;
1843 }
1844 auto trivial_evaluator = dyn_cast_ptr<TrivialPrimEvaluator>(evaluator);
1845 if (trivial_evaluator != nullptr) {
1846 return trivial_evaluator->EvalPrim(nullptr, arg_specs);
1847 }
1848 // Support MakeTuple/MakeList ops in PyNative mode.
1849 auto transition_evaluator = dyn_cast_ptr<TransitionPrimEvaluator>(evaluator);
1850 if (transition_evaluator != nullptr) {
1851 if (transition_evaluator->isa<MakeTupleEvaluator>() || transition_evaluator->isa<MakeListEvaluator>()) {
1852 return transition_evaluator->EvalPrim(nullptr, arg_specs, nullptr, nullptr);
1853 }
1854 return pipeline::AbstractAnalyze(primitive, arg_specs).eval_result;
1855 }
1856 // To add EvalPrim call of TransitionPrimEvaluator such as GetAttr.
1857 MS_LOG(ERROR) << "The primitive '" << primitive->ToString() << "' should be built as a TrivialPrimEvaluator, but "
1858 << evaluator->ToString();
1859 return nullptr;
1860 }
1861
EvalFunctionValue(const ValuePtr & func,const AbstractBasePtrList & args_spec)1862 AbstractBasePtr EvalFunctionValue(const ValuePtr &func, const AbstractBasePtrList &args_spec) {
1863 auto func_abs = func->ToAbstract();
1864 if (!func_abs->isa<AbstractFunction>()) {
1865 MS_LOG(EXCEPTION) << "The value : " << func->ToString() << " is not a callable object.";
1866 }
1867 if (func->isa<Primitive>() && !func->isa<prim::DoSignaturePrimitive>()) {
1868 return EvalOnePrim(func->cast<PrimitivePtr>(), args_spec)->abstract();
1869 } else {
1870 auto infer_graph = std::make_shared<FuncGraph>();
1871 std::vector<AnfNodePtr> inputs = {std::make_shared<ValueNode>(func)};
1872 (void)std::transform(args_spec.begin(), args_spec.end(), std::back_inserter(inputs),
1873 [infer_graph](const AbstractBasePtr &) -> AnfNodePtr { return infer_graph->add_parameter(); });
1874 auto infer_node = infer_graph->NewCNode(inputs);
1875 infer_graph->set_return(infer_node);
1876 auto manager = Manage(infer_graph, true);
1877 auto engine = std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager);
1878 auto res = engine->Run(infer_graph, args_spec);
1879 return res.eval_result->abstract();
1880 }
1881 }
1882
// Create the analysis context for evaluating `fg` with `args_abs_list`, derived from `current_context`.
// Raises an exception when `current_context` or `fg` is null. When current_context->NewContext() returns null,
// an internal exception describing the func graph and its parent is raised; with ENABLE_DUMP_IR defined, the
// func graph and its parent (if any) are dumped first.
AnalysisContextPtr NewContext(const AnalysisContextPtr &current_context, const FuncGraphPtr &fg,
                              const AbstractBasePtrList &args_abs_list) {
  MS_EXCEPTION_IF_NULL(current_context);
  MS_EXCEPTION_IF_NULL(fg);
  auto new_context = current_context->NewContext(fg, args_abs_list);
  if (new_context == nullptr) {  // Not obtain context for fg->parent() during create context.
    FuncGraphPtr parent_graph = fg->parent();
    const auto no_parent = parent_graph == nullptr;
#ifdef ENABLE_DUMP_IR
    DumpIR(std::string("EXCEPTION_NEW_CONTEXT_CURRENT_") + (no_parent ? "0" : "1") + "_" + fg->ToString() + ".ir", fg);
    if (!no_parent) {
      DumpIR("EXCEPTION_NEW_CONTEXT_PARENT_" + parent_graph->ToString() + ".ir", parent_graph);
    }
#endif
    // If parent context is not found, we'll raise exception.
    MS_LOG(INTERNAL_EXCEPTION) << "BUG: Failed to find parent context in current context: "
                               << current_context->ToString() << ", func_graph: " << fg->ToString()
                               << ", parent_graph: " << (no_parent ? "null" : parent_graph->ToString()) << ",\n"
                               << trace::GetDebugInfoStr(fg->debug_info());
  }
  return new_context;
}
1904 } // namespace abstract
1905 } // namespace mindspore
1906