• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/static_analysis/evaluator.h"
18 
19 #include <algorithm>
20 #include <utility>
21 #include <unordered_set>
22 
23 #include "ir/func_graph_cloner.h"
24 #include "abstract/utils.h"
25 #include "debug/trace.h"
26 #include "utils/ms_context.h"
27 #include "pipeline/jit/static_analysis/stack_frame.h"
28 #include "pipeline/jit/static_analysis/async_eval_result.h"
29 
30 namespace mindspore {
31 namespace abstract {
32 namespace {
EvalEntryLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList & arg_spec_list,const AnfNodeConfigPtr & out_conf)33 string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
34                         const AnfNodeConfigPtr &out_conf) {
35   MS_EXCEPTION_IF_NULL(evaluator);
36   std::stringstream ss;
37   if (out_conf != nullptr) {
38     ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
39   }
40   for (size_t i = 0; i < arg_spec_list.size(); i++) {
41     ss << evaluator->ToString() << " input[" << i
42        << "] abstract value: " << (arg_spec_list[i] ? arg_spec_list[i]->ToString() : "null abstract.");
43   }
44   return ss.str();
45 }
46 
EvalFailLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList &,const AnfNodeConfigPtr & out_conf)47 void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
48   MS_EXCEPTION_IF_NULL(evaluator);
49   if (out_conf != nullptr) {
50     auto node = out_conf->node();
51     if (IsValueNode<Primitive>(node)) {
52       MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope()
53                     << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
54     } else {
55       MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString()
56                     << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
57     }
58   }
59 }
60 }  // namespace
61 
EnterStackFrame(const AnalysisEnginePtr & engine,const StackFramePtr & current_stack_frame,const StackFramePtr & new_stack_frame)62 void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
63                                              const StackFramePtr &new_stack_frame) {
64   MS_EXCEPTION_IF_NULL(current_stack_frame);
65   MS_EXCEPTION_IF_NULL(new_stack_frame);
66   MS_EXCEPTION_IF_NULL(engine);
67   // Enter new func graph.
68   auto &current_node = current_stack_frame->CurrentNode();
69   auto current_context = current_stack_frame->current_context();
70   AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context, current_context->func_graph());
71   auto evaluator = new_stack_frame->evaluator();
72   MS_EXCEPTION_IF_NULL(evaluator);
73   auto new_context = new_stack_frame->current_context();
74   trace::TraceGraphEvalEnter(new_context, call_conf);
75 
76   // Increase & Check the func graph call depth.
77   IncreaseFunctionCallDepth();
78   IncreaseStackFrameDepth();
79   const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
80   if (FunctionCallDepth() > max_depth) {
81     MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
82                       << ", (function call depth: " << FunctionCallDepth()
83                       << ", simulate call depth: " << StackFrameDepth() << ").\n"
84                       << "It's always happened with complex construction of code or infinite recursion or loop.\n"
85                       << "Please check the code if it's has the infinite recursion "
86                       << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
87                       << "If max_call_depth is set larger, the system max stack depth should be set larger too "
88                       << "to avoid stack overflow.\n"
89                       << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
90   }
91   MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
92                 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
93 }
94 
LeaveStackFrame(const AnalysisEnginePtr &,const StackFramePtr & current_stack_frame)95 void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr &current_stack_frame) {
96   MS_EXCEPTION_IF_NULL(current_stack_frame);
97   // Leave current func graph.
98   auto current_context = current_stack_frame->current_context();
99   trace::TraceGraphEvalLeave(current_context);
100 
101   // Decrease the func graph call depth.
102   DecreaseFunctionCallDepth();
103   DecreaseStackFrameDepth();
104 
105   auto evaluator = current_stack_frame->evaluator();
106   MS_EXCEPTION_IF_NULL(evaluator);
107   MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
108                 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
109 }
110 
111 // Start running stack frames in a Evaluator.
LaunchStackFrame(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const AnalysisContextPtr & context)112 AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
113                                                          const AnalysisContextPtr &context) {
114   EvalResultPtr eval_result = nullptr;
115   AbstractBasePtr res_base = nullptr;
116   std::stack<StackFramePtr> stack_frames;
117   auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_);
118   MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
119   stack_frames.push(current_stack_frame);
120   while (true) {
121     current_stack_frame = stack_frames.top();
122     if (current_stack_frame->Done()) {
123       MS_EXCEPTION_IF_NULL(res_base);
124       MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame;
125       stack_frames.pop();
126       if (stack_frames.empty()) {
127         MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame
128                       << ", res_base: " << res_base->ToString();
129         break;
130       }
131       // Leave current func graph.
132       LeaveStackFrame(engine, current_stack_frame);
133       // Switch the stack frame.
134       auto last_stack_frame = current_stack_frame;
135       current_stack_frame = stack_frames.top();
136       MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
137       current_stack_frame->Back(engine, last_stack_frame, eval_result);
138       continue;
139     }
140 
141     auto new_stack_frame = current_stack_frame->Jump(engine);
142     if (new_stack_frame != nullptr) {
143       // Enter new func graph.
144       EnterStackFrame(engine, current_stack_frame, new_stack_frame);
145       // Update current stack frame.
146       stack_frames.push(new_stack_frame);
147       current_stack_frame = new_stack_frame;
148       MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame;
149       continue;
150     }
151 
152     eval_result = current_stack_frame->Step(engine);
153     MS_EXCEPTION_IF_NULL(eval_result);
154     res_base = eval_result->abstract();
155   }
156   return res_base;
157 }
158 
LaunchRecursiveEval(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const AnalysisContextPtr & context)159 AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
160                                                             const AnalysisContextPtr &context) {
161   MS_EXCEPTION_IF_NULL(fg);
162   MS_EXCEPTION_IF_NULL(engine);
163   const AnfNodePtr &func_node = fg->get_return();
164   const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
165     MS_EXCEPTION_IF_NULL(node);
166     if (node->isa<ValueNode>() || node->isa<Parameter>()) {
167       return EXCLUDE;
168     }
169     return FOLLOW;
170   });
171   AbstractBasePtr res_base = nullptr;
172   for (const auto &node : all_nodes) {
173     AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context, fg);
174     MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
175                   << ", node_conf: " << node_conf->ToString();
176     auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
177     MS_EXCEPTION_IF_NULL(node_eval_result);
178     res_base = node_eval_result->abstract();
179     MS_EXCEPTION_IF_NULL(res_base);
180     MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString();
181   }
182   MS_EXCEPTION_IF_NULL(res_base);
183   return res_base;
184 }
185 
Eval(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)186 EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
187                                            const AnfNodeConfigPtr &out_conf) {
188   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
189   if (eval_result != nullptr) {
190     MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. There is something wrong.";
191     return eval_result;
192   } else {
193     MS_LOG(DEBUG) << ToString() << " entered first.";
194   }
195   MS_EXCEPTION_IF_NULL(engine);
196 
197   // Increase & Check the func graph call depth.
198   IncreaseFunctionCallDepth();
199   const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
200   if (FunctionCallDepth() > max_depth) {
201     MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
202                       << ", (function call depth: " << FunctionCallDepth()
203                       << ", simulate call depth: " << StackFrameDepth() << ").\n"
204                       << "It's always happened with complex construction of code or infinite recursion or loop.\n"
205                       << "Please check the code if it's has the infinite recursion "
206                       << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
207                       << "If max_call_depth is set larger, the system max stack depth should be set larger too "
208                       << "to avoid stack overflow.\n"
209                       << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
210   }
211   MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
212                 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
213 
214   FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list);
215   MS_EXCEPTION_IF_NULL(fg);
216   auto context = parent_context_->NewContext(fg, args_abs_list);
217   trace::TraceGraphEvalEnter(context, out_conf);
218 
219   std::size_t nargs = fg->parameters().size();
220   if (args_abs_list.size() != nargs) {
221     MS_EXCEPTION(TypeError) << "The parameters number of the function is " << fg->parameters().size()
222                             << ", but the number of provided arguments is " << args_abs_list.size() << ".\n"
223                             << "FunctionGraph : " << fg->ToString()
224                             << "\nNodeInfo: " << trace::GetDebugInfo(fg->debug_info());
225   }
226   MS_EXCEPTION_IF_NULL(parent_context_);
227   MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
228   if (parent_context_->func_graph() != nullptr) {
229     MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::GetThreadID() << ":"
230                   << parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::GetThreadID() << ":"
231                   << fg->ToString() << "();";
232   }
233 
234   auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
235   if (func_graph_evaluator != nullptr) {
236     if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
237       engine->set_root_context(context);
238     }
239   }
240   const auto &parameters = fg->parameters();
241   for (size_t i = 0; i < nargs; i++) {
242     const auto &arg = args_abs_list[i];
243     const auto &node = parameters[i];
244     AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
245     engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
246     MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << "   =   " << arg->ToString();
247   }
248   MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
249                 << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
250                 << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
251                 << ", current function call depth: " << FunctionCallDepth();
252   AbstractBasePtr res_base = nullptr;
253   if (engine->enable_recursive_eval()) {
254     res_base = LaunchRecursiveEval(engine, fg, context);
255   } else {
256     res_base = LaunchStackFrame(engine, fg, context);
257   }
258 
259   MS_EXCEPTION_IF_NULL(res_base);
260   MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
261                 << ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub();
262   if (fg->stub()) {
263     res_base = std::make_shared<AbstractUndetermined>();
264   }
265   MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << res_base->ToString();
266 
267   trace::TraceGraphEvalLeave(context);
268   // Decrease the func graph call depth.
269   DecreaseFunctionCallDepth();
270   MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
271                 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
272   auto res = std::make_shared<EvalResult>(res_base, nullptr);
273   return res;
274 }
275 
BroadenArgs(const AbstractBasePtrList & args_spec_list,AbstractBasePtrList * broaded_args)276 void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
277   MS_EXCEPTION_IF_NULL(broaded_args);
278   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
279                        [](const AbstractBasePtr &arg) -> AbstractBasePtr {
280                          MS_EXCEPTION_IF_NULL(arg);
281                          if (arg->GetValueTrack() != kAnyValue) {
282                            return arg->Broaden();
283                          }
284                          return arg;
285                        });
286 }
287 
// Returns the argument list to evaluate with: the arguments unchanged, unless the func graph carries
// FUNC_GRAPH_FLAG_IGNORE_VALUES, in which case the list produced by BroadenArgs() is returned.
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (!func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
    return args_spec_list;
  }
  AbstractBasePtrList widened_args;
  BroadenArgs(args_spec_list, &widened_args);
  MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                << ", broaded: " << mindspore::ToString(widened_args);
  return widened_args;
}
299 
// If the func graph carries kFuncGraphFlagUndetermined and does not yet carry
// FUNC_GRAPH_FLAG_IGNORE_VALUES, sets the IGNORE_VALUES flag on it and returns the arguments as
// normalized by NormalizeArgs(). In every other case the arguments are returned unchanged.
AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES) || !func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
    return args_spec_list;
  }

  func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  auto widened_args = NormalizeArgs(args_spec_list);
  MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
  MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(widened_args);
  return widened_args;
}
315 
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_spec_list)316 FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
317   auto iter = func_graph_cache_.find(args_spec_list);
318   FuncGraphPtr res;
319   if (iter == func_graph_cache_.end()) {
320     auto fg = func_graph();
321     MS_EXCEPTION_IF_NULL(fg);
322     FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
323     func_graph_cache_[args_spec_list] = generated_graph;
324     MS_EXCEPTION_IF_NULL(engine);
325     engine->func_graph_manager()->AddFuncGraph(generated_graph);
326     res = generated_graph;
327   } else {
328     res = iter->second;
329   }
330 
331   // For the top graph, if it is replaced by generated graph, update the top graph to the new one.
332   if (parse::Parser::GetTopFuncGraph() == func_graph()) {
333     if (res != func_graph()) {
334       parse::Parser::UpdateTopFuncGraph(res);
335     }
336   }
337   return res;
338 }
339 
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_spec_list)340 FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
341   auto iter = func_graph_cache_.find(args_spec_list);
342   if (iter != func_graph_cache_.end()) {
343     return iter->second;
344   }
345 
346   MS_EXCEPTION_IF_NULL(meta_func_graph_);
347   FuncGraphPtr generated_func_graph;
348   if (this->bound_node() != nullptr) {
349     TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
350     generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
351   } else {
352     generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
353   }
354 
355   FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph);
356   func_graph_cache_[args_spec_list] = cloned_func_graph;
357   MS_EXCEPTION_IF_NULL(engine);
358   engine->func_graph_manager()->AddFuncGraph(cloned_func_graph);
359   return cloned_func_graph;
360 }
361 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)362 EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
363                              const AnfNodeConfigPtr &out_conf) {
364   AbstractBasePtrList args_spec_list;
365   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
366                        [](const ConfigPtr &conf) -> AbstractBasePtr {
367                          MS_EXCEPTION_IF_NULL(conf);
368                          return conf->ObtainEvalResult()->abstract();
369                        });
370   args_spec_list = NormalizeArgs(args_spec_list);
371   args_spec_list = BroadenUndeterminedArgs(args_spec_list);
372 
373   MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
374   const std::string &evaluator_name = ToString();
375   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
376   auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
377   if (eval_result == nullptr) {
378     MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
379     eval_result = Eval(engine, args_spec_list, out_conf);
380     MS_EXCEPTION_IF_NULL(eval_result);
381     if (eval_result->abstract() == nullptr) {
382       EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
383       MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
384     }
385     MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << eval_result->abstract()->ToString() << ".";
386     evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
387   } else {
388     MS_EXCEPTION_IF_NULL(eval_result);
389     MS_EXCEPTION_IF_NULL(eval_result->abstract());
390     MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << ".";
391   }
392   return eval_result;
393 }
394 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)395 EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
396                                         const AnfNodeConfigPtr &) {
397   AbstractBasePtrList args_spec_list;
398   auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
399   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
400                        [is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
401                          MS_EXCEPTION_IF_NULL(conf);
402                          auto abstract = conf->ObtainEvalResult()->abstract();
403                          MS_EXCEPTION_IF_NULL(abstract);
404                          // broaden the ref_key, while infer python prim for cache
405                          if (is_py_eval && abstract->isa<AbstractRef>()) {
406                            auto abs_ref = abstract->cast<AbstractRefPtr>();
407                            abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref);
408                          }
409                          return abstract;
410                        });
411   return EvalPrim(engine, args_spec_list);
412 }
413 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)414 EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
415                                            const AnfNodeConfigPtr &out_conf) {
416   if (args_conf_list.empty()) {
417     MS_LOG(EXCEPTION) << "Size should greater than 0";
418   }
419   AbstractBasePtrList args_spec_list;
420   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
421                        [](const ConfigPtr &conf) -> AbstractBasePtr {
422                          MS_EXCEPTION_IF_NULL(conf);
423                          return conf->ObtainEvalResult()->abstract();
424                        });
425   EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
426   // No need to cache.
427   return res;
428 }
429 
Run(AnalysisEnginePtr,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)430 EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list,
431                                          const AnfNodeConfigPtr &) {
432   return EvalPrim(args_conf_list);
433 }
434 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)435 EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
436                                     const AnfNodeConfigPtr &out_conf) {
437   AbstractBasePtrList args_spec_list;
438   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
439                        [](const ConfigPtr &conf) -> AbstractBasePtr {
440                          MS_EXCEPTION_IF_NULL(conf);
441                          return conf->ObtainEvalResult()->abstract();
442                        });
443   EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf);
444   // Don't lookup from cache, as different out_conf with same node but different context
445   // may add different entry to anfnode_config_map_, like getattr primitive.
446   evaluator_cache_mgr_->SetValue(args_spec_list, res);
447   return res;
448 }
449 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)450 EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
451                                        const AnfNodeConfigPtr &out_conf) {
452   AbstractBasePtrList args_spec_list;
453   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
454                        [](const ConfigPtr &conf) -> AbstractBasePtr {
455                          MS_EXCEPTION_IF_NULL(conf);
456                          return conf->ObtainEvalResult()->abstract();
457                        });
458   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
459   auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
460   if (eval_result != nullptr) {
461     return eval_result;
462   }
463 
464   ConfigPtrList partial_args_conf_list;
465   // Join arguments in partial and the rest arguments from args_conf_list.
466   (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list),
467                        [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
468 
469   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
470                        [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
471   EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf);
472   evaluator_cache_mgr_->SetValue(args_spec_list, res);
473   return res;
474 }
475 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)476 EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) {
477   AbstractBasePtrList args_spec_list;
478   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
479                        [](const ConfigPtr &conf) -> AbstractBasePtr {
480                          MS_EXCEPTION_IF_NULL(conf);
481                          return conf->ObtainEvalResult()->abstract();
482                        });
483   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
484   auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
485   if (eval_result != nullptr) {
486     return eval_result;
487   }
488 
489   // Call the original evaluator, get the result: y = f(x)
490   EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
491   MS_EXCEPTION_IF_NULL(result);
492   // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
493   // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
494   AbstractBasePtrList bparams;
495   bparams.push_back(SensitivityTransform(orig_func_));
496   auto context = MsContext::GetInstance();
497   MS_EXCEPTION_IF_NULL(context);
498   bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
499   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
500                        [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
501                          MS_EXCEPTION_IF_NULL(arg_spec);
502                          if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
503                            return std::make_shared<AbstractUndetermined>();
504                          }
505                          return SensitivityTransform(arg_spec);
506                        });
507   AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
508   AbstractFunctionPtr bprop =
509     std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
510 
511   // J(f)(J(x)) return a tuple (y, bprop_f)
512   AbstractBasePtrList jargs = {result->abstract(), bprop};
513   AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
514   auto res = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
515   evaluator_cache_mgr_->SetValue(args_spec_list, res);
516   return res;
517 }
518 
Eval(AnalysisEnginePtr,const AbstractBasePtrList & args_spec_list,const AnfNodeConfigPtr & out_conf)519 EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list,
520                                      const AnfNodeConfigPtr &out_conf) {
521   if (args_spec_list.size() != args_spec_list_.size()) {
522     MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
523                       << ", arguments no: " << args_spec_list.size();
524   }
525   // Check each parameter and argument match;
526   for (std::size_t i = 0; i < args_spec_list.size(); i++) {
527     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
528     (void)args_spec_list[i]->Join(args_spec_list_[i]);
529   }
530   return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
531 }
SingleRun(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)532 EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
533                                    const AnfNodeConfigPtr &out_conf) {
534   EvalResultPtr result;
535   try {
536     result = this->Run(engine, args_conf_list, out_conf);
537   } catch (const std::exception &ex) {
538     MS_LOG(INFO) << "Eval " << ToString() << " throw exception.";
539     AnalysisSchedule::GetInstance().HandleException(ex);
540   }
541   AnalysisSchedule::GetInstance().Wait();
542   return result;
543 }
544 }  // namespace abstract
545 }  // namespace mindspore
546