1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/static_analysis/evaluator.h"
18
19 #include <algorithm>
20 #include <utility>
21 #include <unordered_set>
22
23 #include "ir/func_graph_cloner.h"
24 #include "abstract/utils.h"
25 #include "debug/trace.h"
26 #include "utils/ms_context.h"
27 #include "pipeline/jit/static_analysis/stack_frame.h"
28 #include "pipeline/jit/static_analysis/async_eval_result.h"
29
30 namespace mindspore {
31 namespace abstract {
32 namespace {
EvalEntryLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList & arg_spec_list,const AnfNodeConfigPtr & out_conf)33 string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
34 const AnfNodeConfigPtr &out_conf) {
35 MS_EXCEPTION_IF_NULL(evaluator);
36 std::stringstream ss;
37 if (out_conf != nullptr) {
38 ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
39 }
40 for (size_t i = 0; i < arg_spec_list.size(); i++) {
41 ss << evaluator->ToString() << " input[" << i
42 << "] abstract value: " << (arg_spec_list[i] ? arg_spec_list[i]->ToString() : "null abstract.");
43 }
44 return ss.str();
45 }
46
EvalFailLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList &,const AnfNodeConfigPtr & out_conf)47 void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
48 MS_EXCEPTION_IF_NULL(evaluator);
49 if (out_conf != nullptr) {
50 auto node = out_conf->node();
51 if (IsValueNode<Primitive>(node)) {
52 MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope()
53 << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
54 } else {
55 MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString()
56 << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
57 }
58 }
59 }
60 } // namespace
61
EnterStackFrame(const AnalysisEnginePtr & engine,const StackFramePtr & current_stack_frame,const StackFramePtr & new_stack_frame)62 void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
63 const StackFramePtr &new_stack_frame) {
64 MS_EXCEPTION_IF_NULL(current_stack_frame);
65 MS_EXCEPTION_IF_NULL(new_stack_frame);
66 MS_EXCEPTION_IF_NULL(engine);
67 // Enter new func graph.
68 auto ¤t_node = current_stack_frame->CurrentNode();
69 auto current_context = current_stack_frame->current_context();
70 AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context, current_context->func_graph());
71 auto evaluator = new_stack_frame->evaluator();
72 MS_EXCEPTION_IF_NULL(evaluator);
73 auto new_context = new_stack_frame->current_context();
74 trace::TraceGraphEvalEnter(new_context, call_conf);
75
76 // Increase & Check the func graph call depth.
77 IncreaseFunctionCallDepth();
78 IncreaseStackFrameDepth();
79 const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
80 if (FunctionCallDepth() > max_depth) {
81 MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
82 << ", (function call depth: " << FunctionCallDepth()
83 << ", simulate call depth: " << StackFrameDepth() << ").\n"
84 << "It's always happened with complex construction of code or infinite recursion or loop.\n"
85 << "Please check the code if it's has the infinite recursion "
86 << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
87 << "If max_call_depth is set larger, the system max stack depth should be set larger too "
88 << "to avoid stack overflow.\n"
89 << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
90 }
91 MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
92 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
93 }
94
LeaveStackFrame(const AnalysisEnginePtr &,const StackFramePtr & current_stack_frame)95 void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr ¤t_stack_frame) {
96 MS_EXCEPTION_IF_NULL(current_stack_frame);
97 // Leave current func graph.
98 auto current_context = current_stack_frame->current_context();
99 trace::TraceGraphEvalLeave(current_context);
100
101 // Decrease the func graph call depth.
102 DecreaseFunctionCallDepth();
103 DecreaseStackFrameDepth();
104
105 auto evaluator = current_stack_frame->evaluator();
106 MS_EXCEPTION_IF_NULL(evaluator);
107 MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
108 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
109 }
110
111 // Start running stack frames in a Evaluator.
LaunchStackFrame(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const AnalysisContextPtr & context)112 AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
113 const AnalysisContextPtr &context) {
114 EvalResultPtr eval_result = nullptr;
115 AbstractBasePtr res_base = nullptr;
116 std::stack<StackFramePtr> stack_frames;
117 auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_);
118 MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
119 stack_frames.push(current_stack_frame);
120 while (true) {
121 current_stack_frame = stack_frames.top();
122 if (current_stack_frame->Done()) {
123 MS_EXCEPTION_IF_NULL(res_base);
124 MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame;
125 stack_frames.pop();
126 if (stack_frames.empty()) {
127 MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame
128 << ", res_base: " << res_base->ToString();
129 break;
130 }
131 // Leave current func graph.
132 LeaveStackFrame(engine, current_stack_frame);
133 // Switch the stack frame.
134 auto last_stack_frame = current_stack_frame;
135 current_stack_frame = stack_frames.top();
136 MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
137 current_stack_frame->Back(engine, last_stack_frame, eval_result);
138 continue;
139 }
140
141 auto new_stack_frame = current_stack_frame->Jump(engine);
142 if (new_stack_frame != nullptr) {
143 // Enter new func graph.
144 EnterStackFrame(engine, current_stack_frame, new_stack_frame);
145 // Update current stack frame.
146 stack_frames.push(new_stack_frame);
147 current_stack_frame = new_stack_frame;
148 MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame;
149 continue;
150 }
151
152 eval_result = current_stack_frame->Step(engine);
153 MS_EXCEPTION_IF_NULL(eval_result);
154 res_base = eval_result->abstract();
155 }
156 return res_base;
157 }
158
LaunchRecursiveEval(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const AnalysisContextPtr & context)159 AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
160 const AnalysisContextPtr &context) {
161 MS_EXCEPTION_IF_NULL(fg);
162 MS_EXCEPTION_IF_NULL(engine);
163 const AnfNodePtr &func_node = fg->get_return();
164 const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
165 MS_EXCEPTION_IF_NULL(node);
166 if (node->isa<ValueNode>() || node->isa<Parameter>()) {
167 return EXCLUDE;
168 }
169 return FOLLOW;
170 });
171 AbstractBasePtr res_base = nullptr;
172 for (const auto &node : all_nodes) {
173 AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context, fg);
174 MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
175 << ", node_conf: " << node_conf->ToString();
176 auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
177 MS_EXCEPTION_IF_NULL(node_eval_result);
178 res_base = node_eval_result->abstract();
179 MS_EXCEPTION_IF_NULL(res_base);
180 MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString();
181 }
182 MS_EXCEPTION_IF_NULL(res_base);
183 return res_base;
184 }
185
Eval(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)186 EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
187 const AnfNodeConfigPtr &out_conf) {
188 auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
189 if (eval_result != nullptr) {
190 MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. There is something wrong.";
191 return eval_result;
192 } else {
193 MS_LOG(DEBUG) << ToString() << " entered first.";
194 }
195 MS_EXCEPTION_IF_NULL(engine);
196
197 // Increase & Check the func graph call depth.
198 IncreaseFunctionCallDepth();
199 const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
200 if (FunctionCallDepth() > max_depth) {
201 MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
202 << ", (function call depth: " << FunctionCallDepth()
203 << ", simulate call depth: " << StackFrameDepth() << ").\n"
204 << "It's always happened with complex construction of code or infinite recursion or loop.\n"
205 << "Please check the code if it's has the infinite recursion "
206 << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
207 << "If max_call_depth is set larger, the system max stack depth should be set larger too "
208 << "to avoid stack overflow.\n"
209 << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
210 }
211 MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
212 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
213
214 FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list);
215 MS_EXCEPTION_IF_NULL(fg);
216 auto context = parent_context_->NewContext(fg, args_abs_list);
217 trace::TraceGraphEvalEnter(context, out_conf);
218
219 std::size_t nargs = fg->parameters().size();
220 if (args_abs_list.size() != nargs) {
221 MS_EXCEPTION(TypeError) << "The parameters number of the function is " << fg->parameters().size()
222 << ", but the number of provided arguments is " << args_abs_list.size() << ".\n"
223 << "FunctionGraph : " << fg->ToString()
224 << "\nNodeInfo: " << trace::GetDebugInfo(fg->debug_info());
225 }
226 MS_EXCEPTION_IF_NULL(parent_context_);
227 MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
228 if (parent_context_->func_graph() != nullptr) {
229 MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::GetThreadID() << ":"
230 << parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::GetThreadID() << ":"
231 << fg->ToString() << "();";
232 }
233
234 auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
235 if (func_graph_evaluator != nullptr) {
236 if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
237 engine->set_root_context(context);
238 }
239 }
240 const auto ¶meters = fg->parameters();
241 for (size_t i = 0; i < nargs; i++) {
242 const auto &arg = args_abs_list[i];
243 const auto &node = parameters[i];
244 AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
245 engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
246 MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
247 }
248 MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
249 << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
250 << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
251 << ", current function call depth: " << FunctionCallDepth();
252 AbstractBasePtr res_base = nullptr;
253 if (engine->enable_recursive_eval()) {
254 res_base = LaunchRecursiveEval(engine, fg, context);
255 } else {
256 res_base = LaunchStackFrame(engine, fg, context);
257 }
258
259 MS_EXCEPTION_IF_NULL(res_base);
260 MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
261 << ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub();
262 if (fg->stub()) {
263 res_base = std::make_shared<AbstractUndetermined>();
264 }
265 MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << res_base->ToString();
266
267 trace::TraceGraphEvalLeave(context);
268 // Decrease the func graph call depth.
269 DecreaseFunctionCallDepth();
270 MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
271 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
272 auto res = std::make_shared<EvalResult>(res_base, nullptr);
273 return res;
274 }
275
BroadenArgs(const AbstractBasePtrList & args_spec_list,AbstractBasePtrList * broaded_args)276 void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
277 MS_EXCEPTION_IF_NULL(broaded_args);
278 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
279 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
280 MS_EXCEPTION_IF_NULL(arg);
281 if (arg->GetValueTrack() != kAnyValue) {
282 return arg->Broaden();
283 }
284 return arg;
285 });
286 }
287
// If func_graph_ carries FUNC_GRAPH_FLAG_IGNORE_VALUES, returns a broadened copy of the arguments (see BroadenArgs);
// otherwise returns the arguments unchanged.
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (!func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
    return args_spec_list;
  }
  AbstractBasePtrList broadened;
  BroadenArgs(args_spec_list, &broadened);
  MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                << ", broaded: " << mindspore::ToString(broadened);
  return broadened;
}
299
// Handles graphs flagged kFuncGraphFlagUndetermined that do not yet ignore values: sets
// FUNC_GRAPH_FLAG_IGNORE_VALUES on func_graph_ (mutating the graph) and returns the arguments re-normalized
// under that flag. In every other case the arguments are returned unchanged.
AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(func_graph_);
  const bool ignores_values = func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES);
  if (ignores_values || !func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
    return args_spec_list;
  }
  func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  AbstractBasePtrList normalized = NormalizeArgs(args_spec_list);
  MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
  MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(normalized);
  return normalized;
}
315
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_spec_list)316 FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
317 auto iter = func_graph_cache_.find(args_spec_list);
318 FuncGraphPtr res;
319 if (iter == func_graph_cache_.end()) {
320 auto fg = func_graph();
321 MS_EXCEPTION_IF_NULL(fg);
322 FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
323 func_graph_cache_[args_spec_list] = generated_graph;
324 MS_EXCEPTION_IF_NULL(engine);
325 engine->func_graph_manager()->AddFuncGraph(generated_graph);
326 res = generated_graph;
327 } else {
328 res = iter->second;
329 }
330
331 // For the top graph, if it is replaced by generated graph, update the top graph to the new one.
332 if (parse::Parser::GetTopFuncGraph() == func_graph()) {
333 if (res != func_graph()) {
334 parse::Parser::UpdateTopFuncGraph(res);
335 }
336 }
337 return res;
338 }
339
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_spec_list)340 FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
341 auto iter = func_graph_cache_.find(args_spec_list);
342 if (iter != func_graph_cache_.end()) {
343 return iter->second;
344 }
345
346 MS_EXCEPTION_IF_NULL(meta_func_graph_);
347 FuncGraphPtr generated_func_graph;
348 if (this->bound_node() != nullptr) {
349 TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
350 generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
351 } else {
352 generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
353 }
354
355 FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph);
356 func_graph_cache_[args_spec_list] = cloned_func_graph;
357 MS_EXCEPTION_IF_NULL(engine);
358 engine->func_graph_manager()->AddFuncGraph(cloned_func_graph);
359 return cloned_func_graph;
360 }
361
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)362 EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
363 const AnfNodeConfigPtr &out_conf) {
364 AbstractBasePtrList args_spec_list;
365 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
366 [](const ConfigPtr &conf) -> AbstractBasePtr {
367 MS_EXCEPTION_IF_NULL(conf);
368 return conf->ObtainEvalResult()->abstract();
369 });
370 args_spec_list = NormalizeArgs(args_spec_list);
371 args_spec_list = BroadenUndeterminedArgs(args_spec_list);
372
373 MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
374 const std::string &evaluator_name = ToString();
375 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
376 auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
377 if (eval_result == nullptr) {
378 MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
379 eval_result = Eval(engine, args_spec_list, out_conf);
380 MS_EXCEPTION_IF_NULL(eval_result);
381 if (eval_result->abstract() == nullptr) {
382 EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
383 MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
384 }
385 MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << eval_result->abstract()->ToString() << ".";
386 evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
387 } else {
388 MS_EXCEPTION_IF_NULL(eval_result);
389 MS_EXCEPTION_IF_NULL(eval_result->abstract());
390 MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << ".";
391 }
392 return eval_result;
393 }
394
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)395 EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
396 const AnfNodeConfigPtr &) {
397 AbstractBasePtrList args_spec_list;
398 auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
399 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
400 [is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
401 MS_EXCEPTION_IF_NULL(conf);
402 auto abstract = conf->ObtainEvalResult()->abstract();
403 MS_EXCEPTION_IF_NULL(abstract);
404 // broaden the ref_key, while infer python prim for cache
405 if (is_py_eval && abstract->isa<AbstractRef>()) {
406 auto abs_ref = abstract->cast<AbstractRefPtr>();
407 abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref);
408 }
409 return abstract;
410 });
411 return EvalPrim(engine, args_spec_list);
412 }
413
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)414 EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
415 const AnfNodeConfigPtr &out_conf) {
416 if (args_conf_list.empty()) {
417 MS_LOG(EXCEPTION) << "Size should greater than 0";
418 }
419 AbstractBasePtrList args_spec_list;
420 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
421 [](const ConfigPtr &conf) -> AbstractBasePtr {
422 MS_EXCEPTION_IF_NULL(conf);
423 return conf->ObtainEvalResult()->abstract();
424 });
425 EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
426 // No need to cache.
427 return res;
428 }
429
Run(AnalysisEnginePtr,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)430 EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list,
431 const AnfNodeConfigPtr &) {
432 return EvalPrim(args_conf_list);
433 }
434
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)435 EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
436 const AnfNodeConfigPtr &out_conf) {
437 AbstractBasePtrList args_spec_list;
438 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
439 [](const ConfigPtr &conf) -> AbstractBasePtr {
440 MS_EXCEPTION_IF_NULL(conf);
441 return conf->ObtainEvalResult()->abstract();
442 });
443 EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf);
444 // Don't lookup from cache, as different out_conf with same node but different context
445 // may add different entry to anfnode_config_map_, like getattr primitive.
446 evaluator_cache_mgr_->SetValue(args_spec_list, res);
447 return res;
448 }
449
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)450 EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
451 const AnfNodeConfigPtr &out_conf) {
452 AbstractBasePtrList args_spec_list;
453 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
454 [](const ConfigPtr &conf) -> AbstractBasePtr {
455 MS_EXCEPTION_IF_NULL(conf);
456 return conf->ObtainEvalResult()->abstract();
457 });
458 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
459 auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
460 if (eval_result != nullptr) {
461 return eval_result;
462 }
463
464 ConfigPtrList partial_args_conf_list;
465 // Join arguments in partial and the rest arguments from args_conf_list.
466 (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list),
467 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
468
469 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
470 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
471 EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf);
472 evaluator_cache_mgr_->SetValue(args_spec_list, res);
473 return res;
474 }
475
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)476 EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) {
477 AbstractBasePtrList args_spec_list;
478 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
479 [](const ConfigPtr &conf) -> AbstractBasePtr {
480 MS_EXCEPTION_IF_NULL(conf);
481 return conf->ObtainEvalResult()->abstract();
482 });
483 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
484 auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
485 if (eval_result != nullptr) {
486 return eval_result;
487 }
488
489 // Call the original evaluator, get the result: y = f(x)
490 EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
491 MS_EXCEPTION_IF_NULL(result);
492 // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
493 // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
494 AbstractBasePtrList bparams;
495 bparams.push_back(SensitivityTransform(orig_func_));
496 auto context = MsContext::GetInstance();
497 MS_EXCEPTION_IF_NULL(context);
498 bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
499 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
500 [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
501 MS_EXCEPTION_IF_NULL(arg_spec);
502 if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
503 return std::make_shared<AbstractUndetermined>();
504 }
505 return SensitivityTransform(arg_spec);
506 });
507 AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
508 AbstractFunctionPtr bprop =
509 std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
510
511 // J(f)(J(x)) return a tuple (y, bprop_f)
512 AbstractBasePtrList jargs = {result->abstract(), bprop};
513 AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
514 auto res = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
515 evaluator_cache_mgr_->SetValue(args_spec_list, res);
516 return res;
517 }
518
Eval(AnalysisEnginePtr,const AbstractBasePtrList & args_spec_list,const AnfNodeConfigPtr & out_conf)519 EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list,
520 const AnfNodeConfigPtr &out_conf) {
521 if (args_spec_list.size() != args_spec_list_.size()) {
522 MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
523 << ", arguments no: " << args_spec_list.size();
524 }
525 // Check each parameter and argument match;
526 for (std::size_t i = 0; i < args_spec_list.size(); i++) {
527 MS_EXCEPTION_IF_NULL(args_spec_list[i]);
528 (void)args_spec_list[i]->Join(args_spec_list_[i]);
529 }
530 return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
531 }
SingleRun(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)532 EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
533 const AnfNodeConfigPtr &out_conf) {
534 EvalResultPtr result;
535 try {
536 result = this->Run(engine, args_conf_list, out_conf);
537 } catch (const std::exception &ex) {
538 MS_LOG(INFO) << "Eval " << ToString() << " throw exception.";
539 AnalysisSchedule::GetInstance().HandleException(ex);
540 }
541 AnalysisSchedule::GetInstance().Wait();
542 return result;
543 }
544 } // namespace abstract
545 } // namespace mindspore
546