/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/static_analysis/static_analysis.h"
#include <algorithm>
#include <set>
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
#include "utils/ms_exception.h"
#include "ir/tensor.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/evaluator.h"
#include "debug/trace.h"
#include "debug/anf_ir_dump.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"

namespace mindspore {
namespace abstract {
// Record the current depth of the function call stack, including `stack_frame_depth`.
thread_local size_t function_call_depth;
thread_local size_t function_call_max_depth;
// Record the current depth of stack frame calls.
thread_local size_t stack_frame_depth;
thread_local size_t stack_frame_max_depth;

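// Note: the helpers below maintain both the current and the maximum observed depth for the
// two thread-local counters above; the Decrease* helpers raise an exception if the counter
// is already zero.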
void ResetFunctionCallDepth() {
  function_call_depth = 0;
  function_call_max_depth = 0;
}
void IncreaseFunctionCallDepth() {
  function_call_depth++;
  if (function_call_max_depth < function_call_depth) {
    function_call_max_depth = function_call_depth;
  }
}
void DecreaseFunctionCallDepth() {
  if (function_call_depth == 0) {
    MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
  }
  function_call_depth--;
}
size_t FunctionCallDepth() { return function_call_depth; }
size_t FunctionCallMaxDepth() { return function_call_max_depth; }

void ResetStackFrameDepth() {
  stack_frame_depth = 0;
  stack_frame_max_depth = 0;
}
void IncreaseStackFrameDepth() {
  stack_frame_depth++;
  if (stack_frame_max_depth < stack_frame_depth) {
    stack_frame_max_depth = stack_frame_depth;
  }
}
void DecreaseStackFrameDepth() {
  if (stack_frame_depth == 0) {
    MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
  }
  stack_frame_depth--;
}
size_t StackFrameDepth() { return stack_frame_depth; }
size_t StackFrameMaxDepth() { return stack_frame_max_depth; }

bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
  MS_EXCEPTION_IF_NULL(arg_spec);
  if (dyn_cast<AbstractScalar>(arg_spec)) {
    auto v = arg_spec->GetValueTrack();
    if (v->isa<SymbolicKeyInstance>()) {
      return true;
    }
  }
  return false;
}

AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
  if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
    MS_EXCEPTION_IF_NULL(arg1);
    return arg1->Join(arg2);
  }
  return nullptr;
}

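// The hasher and equality functor below are presumably what allow AnfNodeConfig pointers to be
// used as keys in the analysis caches: the hash combines the node's hash with its analysis
// context (unless the context is a dummy), and equality falls back to AnfNodeConfig::operator==.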
std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(conf->node());
  std::size_t hash_value = conf->node()->hash();
  if (!conf->context()->IsDummyContext()) {
    hash_value = hash_combine(hash_value, std::hash<AnalysisContext *>{}(conf->context().get()));
  }
  return hash_value;
}

bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const {
  if (lhs == nullptr || rhs == nullptr) {
    return false;
  }
  if (lhs == rhs) {
    return true;
  }
  return (*lhs == *rhs);
}

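// Run() is the top-level entry of static analysis: it wraps the input abstracts in
// VirtualConfigs, registers the root func_graph with the manager, resets the recursion
// counters, evaluates the graph under a dummy context, and returns the eval result of the
// root graph's `return` node together with the root analysis context.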
AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
  StaticAnalysisException::Instance().ClearException();
  AnalysisResult result;
  try {
    MS_EXCEPTION_IF_NULL(func_graph);
    ConfigPtrList args_conf_list;
    (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
                         [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
    MS_EXCEPTION_IF_NULL(func_graph_manager_);
    func_graph_manager_->AddFuncGraph(func_graph);
    root_func_graph_ = func_graph;

    // Running the analyzer.
    ResetFunctionCallDepth();
    ResetStackFrameDepth();
    AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
    AnalysisContextPtr root_context = Run(func_graph, dummy_context, args_conf_list);
    MS_EXCEPTION_IF_NULL(root_context);
    auto root_context_fg = root_context->func_graph();
    MS_EXCEPTION_IF_NULL(root_context_fg);
    AnfNodeConfigPtr output_conf = MakeConfig(root_context_fg->get_return(), root_context, root_context_fg);
    MS_EXCEPTION_IF_NULL(func_graph);
    MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";

    MS_EXCEPTION_IF_NULL(output_conf);
    result.inferred = output_conf->ObtainEvalResult();
    result.context = root_context;
  } catch (const std::exception &ex) {
    MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception.";
    AnalysisSchedule::GetInstance().HandleException(ex);
  }
  AnalysisSchedule::GetInstance().Wait();
  return result;
}

AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                                       const ConfigPtrList &args_conf_list) {
  std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  (void)eval->Run(shared_from_this(), args_conf_list, nullptr);
  return root_context_;
}

void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(result);
  static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
  cache_mgr.SetValue(conf, result);

  // Set intermediate abstract value.
  if (IsIntermediateAbstract(result->abstract())) {
    if (conf->node()->intermediate_abstract() == nullptr) {
      conf->node()->set_intermediate_abstract(result->abstract());
      MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
    } else {
      auto old_spec = conf->node()->intermediate_abstract();
      auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
      conf->node()->set_intermediate_abstract(joined_spec);
      MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
                    << result->abstract()->ToString() << "\njoined_spec:\t"
                    << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
    }
  }
}

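// Cache-first lookup: ObtainEvalResultWithCache() consults AnalysisResultCacheMgr before
// falling back to Eval(), and stores any freshly computed result back into the cache.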
EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
  auto result = cache_mgr.GetValue(conf);
  if (result != nullptr) {
    return result;
  }
  MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
  result = Eval(conf);
  if (result == nullptr) {
    MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " got nullptr";
  }
  MS_LOG(DEBUG) << "Evaluate node on demand for NodeConfig: " << conf->ToString()
                << ", result: " << result->abstract().get() << ", " << result->abstract()->ToString();
  SaveEvalResultInCache(conf, result);
  return result;
}

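// Eval() dispatches on the node kind: nodes that already carry an abstract reuse it,
// ValueNodes go through EvalValueNode(), and CNodes go through EvalCNode() (wrapped in trace
// enter/leave so the evaluation stack can be dumped on error); any other node kind is treated
// as an internal error.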
EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  AnfNodePtr node = conf->node();
  EvalResultPtr eval_result = nullptr;
#ifdef DEBUG
  compute_conf_stack_.push_back(node);
  std::ostringstream buffer;
  buffer << "Compute Config Begin:";
  for (auto iter : compute_conf_stack_) {
    buffer << " -> " << iter->DebugString();
  }
  MS_LOG(DEBUG) << buffer.str();
#endif
  MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
  MS_EXCEPTION_IF_NULL(node);
  if (node->abstract() != nullptr) {
    MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
    eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
  } else if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    auto abstract = EvalValueNode(value_node, conf);
    eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>());
  } else if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    trace::TraceEvalCNodeEnter(conf);
    eval_result = EvalCNode(cnode, conf);
    trace::TraceEvalCNodeLeave();
  } else {
    MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name()
                      << "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph");
  }

#ifdef DEBUG
  compute_conf_stack_.pop_back();
  if (eval_result == nullptr) {
    MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
                      << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  }
#endif
  MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
  return eval_result;
}

AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(value_node);
  auto out = ToAbstract(value_node->value(), conf->context(), conf);
  if (value_node->has_new_value() && out->isa<AbstractTensor>()) {
    out = out->Broaden();
  }
  return out;
}

AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
                                                         const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
  }
  AnfNodePtr func_node = inputs[0];
  MS_EXCEPTION_IF_NULL(func_node);
  MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
  AnfNodeConfigPtr func_conf = MakeConfig(func_node, context, func_graph);
  MS_EXCEPTION_IF_NULL(func_conf);
  // Keep it in a local variable, otherwise smart pointer will free it.
  auto possible_func_eval_result = func_conf->ObtainEvalResult();
  AbstractBasePtr possible_func = possible_func_eval_result->abstract();
  if (possible_func == nullptr) {
    MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString();
  }
  return possible_func;
}

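// EvalCNode() first evaluates input[0] to an AbstractFunction, then builds one evaluator per
// possible function atom (via AbstractFunction::Visit) and runs them over the argument configs
// through ExecuteEvaluators(); an undetermined operator type short-circuits to a cloned abstract.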
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(cnode);
  AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode, conf->context(), conf->func_graph());
  if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
    MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
    return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
  }

  AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
  if (func == nullptr) {
    MS_LOG(ERROR) << "Cannot cast to an AbstractFunction from " << possible_func->ToString() << ".";
    MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
    MS_EXCEPTION(ValueError) << "This may not be defined, or it cannot be used as an operator. Please check the code.";
  }

  ConfigPtrList args_conf_list;
  // Ignore the first input, which is the function being called.
  auto &inputs = cnode->inputs();
  for (std::size_t i = 1; i < inputs.size(); i++) {
    const AnfNodePtr &node = inputs[i];
    args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));
  }

  std::vector<EvaluatorPtr> evaluators;
  auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) {
    auto evaluator = this->GetEvaluatorFor(poss);
    evaluator->set_bound_node(cnode);
    evaluators.push_back(evaluator);
  };
  func->Visit(build_evaluator);

  auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
  return eval_result;
}

EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(func);
  ConfigPtrList args_conf_list;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
                       [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  std::vector<EvaluatorPtr> infs;
  MS_EXCEPTION_IF_NULL(func);
  auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
    auto evaluator = this->GetEvaluatorFor(poss);
    infs.push_back(evaluator);
  };
  func->Visit(build_evaluator);
  return ExecuteEvaluators(infs, nullptr, args_conf_list);
}

void AnalysisEngine::ClearEvaluatorCache() {
  for (auto &element : evaluators_) {
    EvaluatorPtr evaluator = element.second;
    MS_EXCEPTION_IF_NULL(evaluator);
    MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
    evaluator->evaluator_cache_mgr()->Clear();
  }
  for (auto &element : prim_constructors_) {
    EvaluatorPtr evaluator = element.second;
    MS_EXCEPTION_IF_NULL(evaluator);
    MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
    evaluator->evaluator_cache_mgr()->Clear();
  }
  for (auto &element : prim_py_evaluators_) {
    EvaluatorPtr evaluator = element.second;
    MS_EXCEPTION_IF_NULL(evaluator);
    MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
    evaluator->evaluator_cache_mgr()->Clear();
  }
  // Release the exception to avoid hanging at exit.
  StaticAnalysisException::Instance().ClearException();
}

void AnalysisEngine::Clear() {
  AnalysisResultCacheMgr::GetInstance().Clear();
  anfnode_config_map_.clear();
  eval_trace_.clear();
  evaluators_.clear();
  constructors_app_.clear();
  continued_evals_.clear();
  root_func_graph_ = nullptr;
  root_context_ = nullptr;
}

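// GetPrimEvaluator() selects an evaluator for a primitive in this order: special-cased
// primitives (DoSignature, UnpackGraph, MixedPrecisionCast), then a StandardPrimEvaluator backed
// by a registered C++ infer implementation, then a PythonPrimEvaluator for primitives with a
// Python infer function (cached per engine when an engine is supplied), and finally a
// constructor looked up in the prim evaluator map; nullptr is returned (with a log) if nothing
// matches.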
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
  // Custom Primitive with python infer_shape, infer_type
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->isa<prim::DoSignaturePrimitive>()) {
    return std::make_shared<DoSignatureEvaluator>(prim);
  }
  if (prim->isa<prim::UnpackGraphPrimitive>()) {
    return std::make_shared<UnpackGraphEvaluator>(prim);
  }
  if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
    return std::make_shared<MixedPrecisionCastEvaluator>(prim);
  }

  // Find the primitive infer function in the prim function map; if found, return a standard evaluator.
  auto eval_impl = GetPrimitiveInferImpl(prim);
  if (eval_impl.infer_shape_impl_ != nullptr) {
    return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
  }

  // If the infer function is not found in the map, use the python infer function and return a python evaluator.
  EvaluatorPtr evaluator = nullptr;
  if (prim->HasPyEvaluator()) {
    auto prim_py = dyn_cast<PrimitivePy>(prim);
    if (prim_py != nullptr) {
      if (engine == nullptr) {
        return std::make_shared<PythonPrimEvaluator>(prim_py);
      }

      const auto &iter = engine->prim_py_evaluators_.find(prim_py);
      if (iter != engine->prim_py_evaluators_.end()) {
        return iter->second;
      }
      evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
      engine->prim_py_evaluators_[prim_py] = evaluator;
      return evaluator;
    }
    MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
    return nullptr;
  }

  // Return a default evaluator.
  if (engine == nullptr) {
    // If engine is nullptr, get constructor from default.
    const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
    auto iter = prim_evaluator_map.find(prim);
    if (iter != prim_evaluator_map.end()) {
      evaluator = iter->second;
    }
  } else {
    // If engine is given, get constructor from engine resource.
    const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
    auto iter = prim_evaluator_map.find(prim);
    if (iter != prim_evaluator_map.end()) {
      evaluator = iter->second;
    }
  }
  if (evaluator == nullptr) {
    MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
  }
  return evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  auto inf_pair = evaluators_.find(func);
  if (inf_pair != evaluators_.end()) {
    return inf_pair->second;
  }
  auto primitive = func->prim();
  auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
  if (evaluator == nullptr) {
    MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
  }
  evaluators_[func] = evaluator;
  return evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  auto inf_pair = evaluators_.find(func);
  if (inf_pair != evaluators_.end()) {
    return inf_pair->second;
  }
  std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
    std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
  evaluators_[func] = func_graph_evaluator;
  return func_graph_evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  auto inf_pair = evaluators_.find(func);
  if (inf_pair != evaluators_.end()) {
    return inf_pair->second;
  }

  std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
    std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->GetScope());
  evaluators_[func] = evaluator;
  return evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  AbstractFunctionPtr func_orig = func->fn();
  EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  auto jevaluator = std::make_shared<JEvaluator>(evaluator_orig, func_orig);
  return jevaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  std::shared_ptr<VirtualEvaluator> virtual_evaluator =
    std::make_shared<VirtualEvaluator>(func->args_spec_list(), func->output());
  return virtual_evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  AbstractFunctionPtr func_orig = func->fn();
  EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  auto part_pair = std::make_pair(func_orig, func->args());
  auto itr = constructors_app_.find(part_pair);
  if (itr != constructors_app_.end()) {
    return itr->second;
  }
  std::shared_ptr<PartialAppEvaluator> partial_evaluator =
    std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
  constructors_app_[part_pair] = partial_evaluator;
  return partial_evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &) {
  MS_LOG(EXCEPTION) << "Should not be called";
}

// Forward to specific subclass of FunctionWrapper.
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
  MS_EXCEPTION_IF_NULL(func);
  if (func->isa<PrimitiveAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<PrimitiveAbstractClosure>>());
  } else if (func->isa<FuncGraphAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<FuncGraphAbstractClosure>>());
  } else if (func->isa<MetaFuncGraphAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
  } else if (func->isa<JTransformedAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
  } else if (func->isa<VirtualAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
  } else if (func->isa<PartialAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<PartialAbstractClosure>>());
  } else if (func->isa<TypedPrimitiveAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<TypedPrimitiveAbstractClosure>>());
  } else if (func->isa<AbstractFuncAtom>()) {
    MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
  } else if (func->isa<AbstractFuncUnion>()) {
    MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
  } else if (func->isa<DummyAbstractClosure>()) {
    MS_LOG(EXCEPTION) << "A dummy function cannot eval";
  } else {
    MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
  }
}

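// GetEvaluatorFor() adds one layer on top of _GetEvaluatorFor(): when an abstract function
// carries a tracking_id (and is not a func-graph or meta-func-graph closure), the evaluator is
// built for a copy with the tracking_id cleared and wrapped in a TrackedEvaluator, which is then
// cached under the original function.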
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  MS_EXCEPTION_IF_NULL(func);
  MS_LOG(DEBUG) << "The func value: " << func->ToString();
  if (func->tracking_id() != nullptr) {
    MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
  }

  if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
      func->isa<abstract::FuncGraphAbstractClosure>()) {
    EvaluatorPtr evaluator = _GetEvaluatorFor(func);
    return evaluator;
  }
  auto inf_pair = evaluators_.find(func);
  if (inf_pair != evaluators_.end()) {
    return inf_pair->second;
  }

  AbstractFunctionPtr func_generic = func->Copy();
  func_generic->set_tracking_id(nullptr);
  EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
  auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
  evaluators_[func] = tracked_eval;

  return tracked_eval;
}

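// ForwardConfig() redirects evaluation of one node config to another (recording the mapping in
// anfnode_config_map_), removes the forwarded CNode from its graph's order list when both nodes
// live in the same func_graph, and then evaluates the new config through the cache.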
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
  MS_EXCEPTION_IF_NULL(orig_conf);
  MS_EXCEPTION_IF_NULL(new_conf);
  // Using anfnode_config_map_[orig_conf] = new_conf would require AnfNodeConfig to provide a copy constructor.
  (void)anfnode_config_map_.emplace(orig_conf, new_conf);
  MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
                << ", to new_conf: " << new_conf->node()->DebugString();
  if (orig_conf->node()->isa<CNode>()) {
    auto old_cnode = orig_conf->node()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_cnode);
    if (new_conf->node()->isa<CNode>()) {
      auto new_cnode = new_conf->node()->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(new_cnode);
      MS_EXCEPTION_IF_NULL(old_cnode->func_graph());
      if (old_cnode->func_graph() == new_cnode->func_graph()) {
        MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->ToString()
                      << ", as origin node should be in order list, origin_node: " << old_cnode->ToString();
        old_cnode->func_graph()->EraseUnusedNodeInOrder(new_cnode);
      } else {
        MS_LOG(EXCEPTION) << "Forward orig_node to different func_graph, old_node: " << old_cnode->DebugString()
                          << ", new_node: " << new_cnode->DebugString();
      }
    }
  }
  (void)forward_count_++;
  auto res = ObtainEvalResultWithCache(new_conf);
  (void)forward_count_--;
  return res;
}

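// ExecuteEvaluators() runs a single evaluator directly; with multiple candidates (e.g. the two
// branches of a switch) it either runs them sequentially (when the ENV_SINGLE_EVAL environment
// variable is "1") or spawns one analysis thread per evaluator.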
EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                                const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
  if (evaluators.size() == 1) {
    EvaluatorPtr eval = evaluators[0];
    MS_EXCEPTION_IF_NULL(eval);
    return eval->Run(shared_from_this(), args_conf_list, out_conf);
  }
  static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1");
  if (enable_singleThread) {
    return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
  } else {
    return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
  }
}

void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
  MS_EXCEPTION_IF_NULL(evaluator);
  static std::mutex fg_lock;
  std::lock_guard<std::mutex> infer_lock(fg_lock);
  auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
  if (fg_eval == nullptr) {
    return;
  }

  auto fg = fg_eval->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto undetermined_fgs = fg->recursive();
  if (undetermined_fgs) {
    auto fg_parent = fg->parent();
    if (fg_parent != nullptr) {
      fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
      MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
      return;
    } else if (possible_parent_fg != nullptr) {
      possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
      MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString() << " for fg: " << fg->ToString();
    } else {
      MS_LOG(EXCEPTION) << "cannot find parent for fg: " << fg->ToString();
    }
  }
}

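// HandleNestedRecursion() walks the evaluation trace backwards to find the latest evaluator in
// `evaluators` that is already on the trace (the recursion entry point). If the current
// evaluator is not that entry, or no untraced undetermined evaluator remains in the inner loop,
// the caller is told to continue (skip re-running); the entry evaluator is returned either way
// so the caller can decide whether to re-run it directly.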
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
                                                   const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
                                                   const EvalTraceRevIter &it, bool *continue_flag) {
  MS_EXCEPTION_IF_NULL(continue_flag);
  MS_EXCEPTION_IF_NULL(eval);
  *continue_flag = false;
  // Find latest entry function to handle nested recursion.
  EvaluatorPtr latest_entry = eval;
  auto latest_entry_iter = eval_trace_.rbegin();
  for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
    auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->evaluator_);
    if (it_temp != evaluators.end()) {
      latest_entry = *it_temp;
      latest_entry_iter = r_it;
      break;
    }
    latest_entry_iter = ++r_it;
  }
  if (latest_entry != eval) {
    MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
    *continue_flag = true;
    return latest_entry;
  }

  bool has_undetermined = false;
  // Check whether the sub loop has an untraced undetermined evaluator.
  std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> undetermined_evals;
  for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
    undetermined_evals.insert(*r_it);
  }
  MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();

  for (auto u_eval : undetermined_evals) {
    MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << " check undetermined.";
    auto &alternate_evaluator = multi_poss_[u_eval.evaluator_];
    auto eval_cache = alternate_evaluator->evaluator_cache_mgr();
    const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list);
    if ((!undetermined_evals.count(alt_eval_args)) &&
        (((!continued_evals_.count(u_eval)) && (eval_cache->GetValue(args_spec_list) != nullptr)) ||
         (eval_cache->GetValue(args_spec_list) == nullptr))) {
      MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << " has undetermined.";
      has_undetermined = true;
      break;
    }
  }
  if (!has_undetermined) {
    MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
    *continue_flag = true;
    return latest_entry;
  }

  return latest_entry;
}

std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBasePtr &last_spec,
                                   const AnfNodePtr &node, const std::string &error_info) {
  std::ostringstream buffer;
  buffer << "The return values of different branches do not join. \n"
         << error_info << "\nFor more details, please refer to the FAQ at https://www.mindspore.cn.\n"
         << "The abstract type of the return value of the current branch is " << spec->ToString()
         << ", and that of the previous branch is " << last_spec->ToString() << ".\n"
         << "The node " << node->DebugString();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>()->input(0);
    if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
      // {prim::kPrimSwitch, cond, true_branch, false_branch}
      constexpr int true_index = 2;
      constexpr int false_index = 3;
      auto inputs = cnode->cast<CNodePtr>()->inputs();
      buffer << ", true branch: " << inputs.at(true_index)->ToString()
             << ", false branch: " << inputs.at(false_index)->ToString();
    } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
      // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
      constexpr int branch_index = 2;
      auto tuple_node = cnode->cast<CNodePtr>()->input(branch_index);
      if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
        auto tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
        for (size_t i = 1; i < tuple_inputs.size(); i++) {
          buffer << ", branch" << i << ": " << tuple_inputs.at(i)->ToString();
        }
      }
    }
  }
  buffer << ". trace: " << trace::DumpSourceLines(node);
  return buffer.str();
}

EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) {
  if (out_specs.empty()) {
    MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
  }

  if (out_specs.size() == 1) {
    MS_EXCEPTION_IF_NULL(out_specs[0]);
    // If only one result derived, then broaden it to avoid wrong constant propagation.
    return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
  }
  MS_EXCEPTION_IF_NULL(node);

  AbstractBasePtr last_spec = out_specs[0];
  AbstractBasePtr joined_spec = out_specs[0];
  for (const auto &spec : out_specs) {
    MS_EXCEPTION_IF_NULL(spec);
    try {
      joined_spec = joined_spec->Join(spec);
    } catch (const py::type_error &ex) {
      auto error_info = ExtractLoggingInfo(ex.what());
      MS_EXCEPTION(TypeError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
    } catch (const py::value_error &ex) {
      auto error_info = ExtractLoggingInfo(ex.what());
      MS_EXCEPTION(ValueError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
    } catch (const std::exception &ex) {
      auto error_info = ExtractLoggingInfo(ex.what());
      MS_LOG(EXCEPTION) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
    }
    MS_EXCEPTION_IF_NULL(joined_spec);
    last_spec = spec;
  }

  MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
  return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}

bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (abstract->isa<AbstractFunction>()) {
    return true;
  }
  if (abstract->isa<AbstractSequeue>()) {
    auto elements = abstract->cast<AbstractSequeuePtr>()->elements();
    if (std::any_of(elements.begin(), elements.end(),
                    [](const AbstractBasePtr &item) { return item->isa<AbstractFunction>(); })) {
      return true;
    }
  }
  return false;
}

void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
                   const std::string &threadID, AsyncAbstractPtr async_result_branch,
                   AsyncAbstractPtr async_result_main, AsyncInferTaskPtr async_run_flag,
                   const trace::TraceGraphEvalStack &graph_evals,
                   const trace::TraceCNodeEvalStack &trace_c_node_evals) {
  AnalysisSchedule::SetThreadID(threadID);
  // Restore the trace stacks so the stack can be dumped if an exception occurs.
  trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
  trace::TraceGraphEvalStackPrepare(graph_evals);

  try {
    // Wait for the signal to run.
    MS_LOG(DEBUG) << async_run_flag.get() << "  " << eval->ToString() << " waiting.";
    (void)async_run_flag->GetResult();
    MS_LOG(DEBUG) << async_run_flag.get() << "  " << eval->ToString() << " running.";

    // Acquire the GIL so eval can call back into Python.
    EvalResultPtr result;
    {
      py::gil_scoped_acquire pyGuard;
      result = eval->Run(engine, args_conf_list, out_conf);
    }
    MS_EXCEPTION_IF_NULL(result);
    MS_EXCEPTION_IF_NULL(result->abstract());

    // Broaden the result of switch(c,t,f)().
    auto broadAbstract = result->abstract()->Broaden();
    // Notify the thread waiting for the switch node and the main thread to continue.
    AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract);
    async_result_branch->SetResult(broadAbstract);
    async_result_main->SetResult(broadAbstract);
    // The thread count will be dropped when the thread exits.
    AnalysisSchedule::GetInstance().DecreaseThreadCount();
    MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
                  << " asyncResult address = " << async_result_branch.get()
                  << " value = " << async_result_branch->TryGetResult()->ToString();
  } catch (const std::exception &e1) {
    auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
    AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
    async_result_main->SetResult(abstractErrPtr);
    MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << "  " << eval->ToString() << " threw exception.";
    AnalysisSchedule::GetInstance().HandleException(e1);
    try {
      // The thread count will be dropped when the thread exits.
      AnalysisSchedule::GetInstance().DecreaseThreadCount();
    } catch (const std::exception &e2) {
      MS_LOG(DEBUG) << "AnalysisSchedule::GetInstance().DecreaseThreadCount() threw exception.";
    }
  }
}

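// ExecuteMultipleEvaluatorsMultiThread() evaluates the candidate branches concurrently: it
// releases the GIL, initializes (or reuses) a cached switch value for the output config, spawns
// one detached thread per evaluator running ExecEvaluator(), waits for the first result via the
// AnalysisSchedule, and then either waits for every branch (when the first result may contain a
// function) or collects whatever branch results are already available before joining them with
// ProcessEvalResults().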
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
                                                                   const AnfNodeConfigPtr &out_conf,
                                                                   const ConfigPtrList &args_conf_list) {
  MS_EXCEPTION_IF_NULL(out_conf);
  MS_EXCEPTION_IF_NULL(out_conf->node());
  // Release the GIL for C++.
  py::gil_scoped_release infer_gil_release;
  // Wait for the last switch node to finish.
  MS_LOG(DEBUG) << GetInferThread() << "async : entry switch  " << out_conf->ToString();
  auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
  if (eval_result == nullptr) {
    MS_LOG(DEBUG) << GetInferThread() << "async : Init switch  " << out_conf->node()->ToString();
    AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
  } else {
    return std::make_shared<EvalResult>(eval_result, nullptr);
  }
  auto possible_parent_fg = out_conf->node()->func_graph();

  // Eval result of the main thread.
  AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
  // Eval results of the branches.
  std::vector<AsyncAbstractPtr> branchAsyncResults;

  for (auto &evaluator : evaluators) {
    static std::atomic<int> idCount{0};
    std::string threadId = AnalysisSchedule::GetThreadID() + "." + std::to_string(idCount.fetch_add(1));
    MS_EXCEPTION_IF_NULL(evaluator);
    SetUndeterminedFlag(evaluator, possible_parent_fg);
    AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
    // Control the order to run.
    AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
    AsyncInferTaskPtr asyncTask = AsyncInferTask::MakeShared(asyncRunOrder, threadId);
    // Account for the new async thread in the scheduler.
    AnalysisSchedule::GetInstance().IncreaseThreadCount();
    MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
    auto thread =
      std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId, branchAsyncResult,
                  asyncResult_main, asyncTask, trace::GetCurrentGraphEvalStack(), trace::GetCNodeDebugStack());
    thread.detach();
    // Push to the list of the running loop.
    asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));
    MS_LOG(DEBUG) << " add to schedule: " << asyncTask.get();
    AnalysisSchedule::GetInstance().Add2Schedule(asyncTask);  // Activate the run order with the child thread.
    (void)branchAsyncResults.emplace_back(std::move(branchAsyncResult));
  }

  MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish.  " << evaluators[0]->ToString()
                << " or  " << evaluators[1]->ToString() << "...";
  auto async_main = AsyncInferTask::MakeShared(asyncResult_main);
  MS_LOG(DEBUG) << " add to schedule: " << async_main.get();
  AnalysisSchedule::GetInstance().Add2Schedule(async_main);  // Third order
  auto firstResult = async_main->GetResult();
  MS_EXCEPTION_IF_NULL(firstResult);
  MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
                << firstResult->ToString();

  AbstractBasePtrList out_specs;
  size_t len = evaluators.size();
  if (NeedWaitForBranches(firstResult)) {
    for (size_t i = 0; i < len; ++i) {
      MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
      auto async_branch = AsyncInferTask::MakeShared(branchAsyncResults[i]);
      MS_LOG(DEBUG) << " add to schedule: " << async_branch.get();
      AnalysisSchedule::GetInstance().Add2Schedule(async_branch);
      auto result = async_branch->GetResult();
      MS_EXCEPTION_IF_NULL(result);
      out_specs.push_back(result);
    }
  } else {
    // Give one more chance to wait for the result of the branches.
    auto async_tmp = AsyncInferTask::MakeShared(asyncResult_main);
    MS_LOG(DEBUG) << " add to schedule: " << async_tmp.get();
    AnalysisSchedule::GetInstance().Add2Schedule(async_tmp);
    (void)async_tmp->GetResult();
    for (size_t i = 0; i < len; ++i) {
      // Do not wait for the result of the branch.
      auto result = branchAsyncResults[i]->TryGetResult();
      if (result) {
        MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
                      << " result: " << result->ToString();
        out_specs.push_back(result);
      }
    }
  }

  return ProcessEvalResults(out_specs, out_conf->node());
}

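// ExecuteMultipleEvaluators() is the sequential (single-thread) counterpart: it records the two
// candidate evaluators as alternatives of each other in multi_poss_, runs each evaluator that is
// not already on the evaluation trace, and relies on HandleNestedRecursion() to decide whether a
// recursively re-entered evaluator should be skipped or the recursion entry re-run directly.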
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                                        const AnfNodeConfigPtr &out_conf,
                                                        const ConfigPtrList &args_conf_list) {
  AbstractBasePtrList out_specs;
  const size_t evaluators_size = 2;
  if (evaluators.size() < evaluators_size) {
    MS_LOG(ERROR) << "The evaluators size is less than 2.";
  }
  multi_poss_[evaluators[0]] = evaluators[1];
  multi_poss_[evaluators[1]] = evaluators[0];
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->ObtainEvalResult()->abstract();
                       });
  MS_EXCEPTION_IF_NULL(out_conf);
  MS_EXCEPTION_IF_NULL(out_conf->node());
  auto possible_parent_fg = out_conf->node()->func_graph();
  for (auto eval : evaluators) {
    MS_EXCEPTION_IF_NULL(eval);
    (void)SetUndeterminedFlag(eval, possible_parent_fg);
    const auto current_inf = EvaluatorArgs(eval, args_spec_list);
    MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
    // If the current evaluator is under tracing, skip it to avoid recursive evaluation.
    auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
    if (it == eval_trace_.rend()) {
      eval_trace_.push_back(current_inf);
      auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
      auto eval_abstract = eval_result->abstract();
      MS_EXCEPTION_IF_NULL(eval_abstract);

      out_specs.push_back(eval_abstract);
      eval_trace_.pop_back();
      if (eval_trace_.empty()) {
        multi_poss_.clear();
      }
    } else {
      bool continue_flag = false;
      auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
      if (continue_flag) {
        MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.evaluator_.get() << current_inf.evaluator_->ToString();
        continued_evals_.insert(current_inf);
        continue;
      }

      // Try to run the latest undetermined entry directly.
      if (latest_entry != eval_trace_.rbegin()->evaluator_) {
        MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
        auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
        MS_EXCEPTION_IF_NULL(eval_result->abstract());
        MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
                      << " return out_spec: " << eval_result->abstract()->ToString();
        return eval_result;
      }
    }
  }

  return ProcessEvalResults(out_specs, out_conf->node());
}

EvalResultPtr AnfNodeConfig::ObtainEvalResult() {
  AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
  return engine_.lock()->ObtainEvalResultWithCache(self);
}

abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph,
                                              const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) {
  AnalysisContextPtr temp_context = context;
  if (temp_context == nullptr) {
    temp_context = abstract::AnalysisContext::DummyContext();
  }
  return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, anf_node);
}

abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
  abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn;
  if (anf_node == nullptr) {
    meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph);
  } else {
    meta_func_graph_fn =
      std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope());
  }
  return meta_func_graph_fn;
}

abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, const AnfNodePtr &anf_node) {
  auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(primitive, anf_node);
  return prim_func;
}

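// ToAbstract() converts a Value into an abstract value: FuncGraph, MetaFuncGraph and Primitive
// values become the corresponding abstract closures (via the MakeAbstractClosure overloads
// above), while every other value falls back to its own ToAbstract() implementation.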
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(value);
  AnfNodePtr anf_node = nullptr;
  if (conf != nullptr) {
    anf_node = conf->node();
  }
  if (value->isa<FuncGraph>()) {
    auto func_graph = value->cast<FuncGraphPtr>();
    return MakeAbstractClosure(func_graph, context, anf_node);
  }
  if (value->isa<MetaFuncGraph>()) {
    auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
    return MakeAbstractClosure(meta_func_graph, anf_node);
  }
  if (value->isa<Primitive>()) {
    auto prim = value->cast<PrimitivePtr>();
    return MakeAbstractClosure(prim, anf_node);
  } else {
    return value->ToAbstract();
  }
}

AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
  AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
  if (broaden) {
    a = a->Broaden();
  }
  return a;
}

EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
  auto evaluator = GetPrimEvaluator(primitive, nullptr);
  if (evaluator == nullptr) {
    MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
  }
  if (!evaluator->isa<TrivialPrimEvaluator>()) {
    MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
                      << evaluator->ToString();
  }
  auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
  auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
  return eval_result;
}
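// A minimal usage sketch for EvalOnePrim(), kept as a comment since it is illustrative only and
// `prim` / `args` are hypothetical locals, not names defined in this file:
//   EvalResultPtr res = EvalOnePrim(prim, args);   // prim: PrimitivePtr, args: AbstractBasePtrList
//   MS_LOG(DEBUG) << res->abstract()->ToString();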
}  // namespace abstract
}  // namespace mindspore