1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/static_analysis/static_analysis.h"
20 #include <algorithm>
21 #include <set>
22 #include "abstract/abstract_value.h"
23 #include "pipeline/jit/static_analysis/prim.h"
24 #include "frontend/operator/ops.h"
25 #include "utils/symbolic.h"
26 #include "utils/ms_exception.h"
27 #include "ir/tensor.h"
28 #include "ir/func_graph_cloner.h"
29 #include "pipeline/jit/parse/data_converter.h"
30 #include "pipeline/jit/static_analysis/evaluator.h"
31 #include "debug/trace.h"
32 #include "debug/anf_ir_dump.h"
33 #include "pipeline/jit/static_analysis/async_eval_result.h"
34
35 namespace mindspore {
36 namespace abstract {
// Per-thread bookkeeping for analysis recursion. Each pair holds the current nesting level
// and the deepest level reached since the last reset. Thread-storage variables are
// zero-initialized anyway; the explicit initializers only make that visible.
// Function-call stack depth (this count includes `stack_frame_depth`).
thread_local size_t function_call_depth = 0;
thread_local size_t function_call_max_depth = 0;
// Stack-frame call depth.
thread_local size_t stack_frame_depth = 0;
thread_local size_t stack_frame_max_depth = 0;
43
ResetFunctionCallDepth()44 void ResetFunctionCallDepth() {
45 function_call_depth = 0;
46 function_call_max_depth = 0;
47 }
IncreaseFunctionCallDepth()48 void IncreaseFunctionCallDepth() {
49 function_call_depth++;
50 if (function_call_max_depth < function_call_depth) {
51 function_call_max_depth = function_call_depth;
52 }
53 }
DecreaseFunctionCallDepth()54 void DecreaseFunctionCallDepth() {
55 if (function_call_depth == 0) {
56 MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
57 }
58 function_call_depth--;
59 }
FunctionCallDepth()60 size_t FunctionCallDepth() { return function_call_depth; }
FunctionCallMaxDepth()61 size_t FunctionCallMaxDepth() { return function_call_max_depth; }
62
ResetStackFrameDepth()63 void ResetStackFrameDepth() {
64 stack_frame_depth = 0;
65 stack_frame_max_depth = 0;
66 }
IncreaseStackFrameDepth()67 void IncreaseStackFrameDepth() {
68 stack_frame_depth++;
69 if (stack_frame_max_depth < stack_frame_depth) {
70 stack_frame_max_depth = stack_frame_depth;
71 }
72 }
DecreaseStackFrameDepth()73 void DecreaseStackFrameDepth() {
74 if (stack_frame_depth == 0) {
75 MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
76 }
77 stack_frame_depth--;
78 }
StackFrameDepth()79 size_t StackFrameDepth() { return stack_frame_depth; }
StackFrameMaxDepth()80 size_t StackFrameMaxDepth() { return stack_frame_max_depth; }
81
IsIntermediateAbstract(const AbstractBasePtr & arg_spec)82 bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
83 MS_EXCEPTION_IF_NULL(arg_spec);
84 if (dyn_cast<AbstractScalar>(arg_spec)) {
85 auto v = arg_spec->GetValueTrack();
86 if (v->isa<SymbolicKeyInstance>()) {
87 return true;
88 }
89 }
90 return false;
91 }
92
// Joins two intermediate abstracts. Only a pair of AbstractScalar values can be joined;
// for any other combination the result is nullptr.
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
  const bool both_scalar = dyn_cast<AbstractScalar>(arg1) != nullptr && dyn_cast<AbstractScalar>(arg2) != nullptr;
  if (!both_scalar) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(arg1);
  return arg1->Join(arg2);
}
100
operator ()(const AnfNodeConfigPtr conf) const101 std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
102 MS_EXCEPTION_IF_NULL(conf);
103 MS_EXCEPTION_IF_NULL(conf->node());
104 std::size_t hash_value = conf->node()->hash();
105 if (!conf->context()->IsDummyContext()) {
106 hash_value = hash_combine(hash_value, std::hash<AnalysisContext *>{}(conf->context().get()));
107 }
108 return hash_value;
109 }
110
operator ()(const AnfNodeConfigPtr lhs,const AnfNodeConfigPtr rhs) const111 bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const {
112 if (lhs == nullptr || rhs == nullptr) {
113 return false;
114 }
115 if (lhs == rhs) {
116 return true;
117 }
118 return (*lhs == *rhs);
119 }
120
Run(const FuncGraphPtr & func_graph,const AbstractBasePtrList & args_spec_list)121 AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
122 StaticAnalysisException::Instance().ClearException();
123 AnalysisResult result;
124 try {
125 MS_EXCEPTION_IF_NULL(func_graph);
126 ConfigPtrList args_conf_list;
127 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
128 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
129 MS_EXCEPTION_IF_NULL(func_graph_manager_);
130 func_graph_manager_->AddFuncGraph(func_graph);
131 root_func_graph_ = func_graph;
132
133 // Running the analyzer.
134 ResetFunctionCallDepth();
135 ResetStackFrameDepth();
136 AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
137 AnalysisContextPtr root_context = Run(func_graph, dummy_context, args_conf_list);
138 MS_EXCEPTION_IF_NULL(root_context);
139 auto root_context_fg = root_context->func_graph();
140 MS_EXCEPTION_IF_NULL(root_context_fg);
141 AnfNodeConfigPtr output_conf = MakeConfig(root_context_fg->get_return(), root_context, root_context_fg);
142 MS_EXCEPTION_IF_NULL(func_graph);
143 MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
144
145 MS_EXCEPTION_IF_NULL(output_conf);
146 result.inferred = output_conf->ObtainEvalResult();
147 result.context = root_context;
148 } catch (const std::exception &ex) {
149 MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception.";
150 AnalysisSchedule::GetInstance().HandleException(ex);
151 }
152 AnalysisSchedule::GetInstance().Wait();
153 return result;
154 }
155
// Evaluates `func_graph` under `context` with the given argument configs, with no output
// config. The evaluation result itself is discarded. The function returns the engine's
// root_context_ as it stands after the run (presumably set during the evaluation; confirm
// in FuncGraphEvaluator).
AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                                       const ConfigPtrList &args_conf_list) {
  auto root_evaluator = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  (void)root_evaluator->Run(shared_from_this(), args_conf_list, nullptr);
  return root_context_;
}
162
SaveEvalResultInCache(const AnfNodeConfigPtr & conf,const EvalResultPtr & result)163 void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
164 MS_EXCEPTION_IF_NULL(conf);
165 MS_EXCEPTION_IF_NULL(result);
166 static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
167 cache_mgr.SetValue(conf, result);
168
169 // Set intermediate abstract value.
170 if (IsIntermediateAbstract(result->abstract())) {
171 if (conf->node()->intermediate_abstract() == nullptr) {
172 conf->node()->set_intermediate_abstract(result->abstract());
173 MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
174 } else {
175 auto old_spec = conf->node()->intermediate_abstract();
176 auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
177 conf->node()->set_intermediate_abstract(joined_spec);
178 MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
179 << result->abstract()->ToString() << "\njoined_spec:\t"
180 << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
181 }
182 }
183 }
184
ObtainEvalResultWithCache(const AnfNodeConfigPtr & conf)185 EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
186 MS_EXCEPTION_IF_NULL(conf);
187 static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
188 auto result = cache_mgr.GetValue(conf);
189 if (result != nullptr) {
190 return result;
191 }
192 MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
193 result = Eval(conf);
194 if (result == nullptr) {
195 MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
196 }
197 MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString()
198 << ", result: " << result->abstract().get() << ", " << result->abstract()->ToString();
199 SaveEvalResultInCache(conf, result);
200 return result;
201 }
202
Eval(const AnfNodeConfigPtr & conf)203 EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
204 MS_EXCEPTION_IF_NULL(conf);
205 AnfNodePtr node = conf->node();
206 EvalResultPtr eval_result = nullptr;
207 #ifdef DEBUG
208 compute_conf_stack_.push_back(node);
209 std::ostringstream buffer;
210 buffer << "Compute Config Begin:";
211 for (auto iter : compute_conf_stack_) {
212 buffer << " -> " << iter->DebugString();
213 }
214 MS_LOG(DEBUG) << buffer.str();
215 #endif
216 MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
217 MS_EXCEPTION_IF_NULL(node);
218 if (node->abstract() != nullptr) {
219 MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
220 eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
221 } else if (node->isa<ValueNode>()) {
222 auto value_node = node->cast<ValueNodePtr>();
223 auto abstract = EvalValueNode(value_node, conf);
224 eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>());
225 } else if (node->isa<CNode>()) {
226 auto cnode = node->cast<CNodePtr>();
227 trace::TraceEvalCNodeEnter(conf);
228 eval_result = EvalCNode(cnode, conf);
229 trace::TraceEvalCNodeLeave();
230 } else {
231 MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name()
232 << "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph");
233 }
234
235 #ifdef DEBUG
236 compute_conf_stack_.pop_back();
237 if (eval_result == nullptr) {
238 MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
239 << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
240 }
241 #endif
242 MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
243 return eval_result;
244 }
245
EvalValueNode(const ValueNodePtr & value_node,const AnfNodeConfigPtr & conf)246 AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
247 MS_EXCEPTION_IF_NULL(conf);
248 MS_EXCEPTION_IF_NULL(value_node);
249 auto out = ToAbstract(value_node->value(), conf->context(), conf);
250 if (value_node->has_new_value() && out->isa<AbstractTensor>()) {
251 out = out->Broaden();
252 }
253 return out;
254 }
255
GetCNodeOperatorAbstract(const CNodePtr & cnode,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)256 AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
257 const FuncGraphPtr &func_graph) {
258 MS_EXCEPTION_IF_NULL(cnode);
259 auto &inputs = cnode->inputs();
260 if (inputs.empty()) {
261 MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
262 }
263 AnfNodePtr func_node = inputs[0];
264 MS_EXCEPTION_IF_NULL(func_node);
265 MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
266 AnfNodeConfigPtr func_conf = MakeConfig(func_node, context, func_graph);
267 MS_EXCEPTION_IF_NULL(func_conf);
268 // Keep it in a local variable, otherwise smart pointer will free it.
269 auto possible_func_eval_result = func_conf->ObtainEvalResult();
270 AbstractBasePtr possible_func = possible_func_eval_result->abstract();
271 if (possible_func == nullptr) {
272 MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString();
273 }
274 return possible_func;
275 }
276
EvalCNode(const CNodePtr & cnode,const AnfNodeConfigPtr & conf)277 EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
278 MS_EXCEPTION_IF_NULL(conf);
279 MS_EXCEPTION_IF_NULL(cnode);
280 AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode, conf->context(), conf->func_graph());
281 if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
282 MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
283 return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
284 }
285
286 AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
287 if (func == nullptr) {
288 MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
289 MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
290 MS_EXCEPTION(ValueError) << "This may be not defined, or it can't be a operator. Please check code.";
291 }
292
293 ConfigPtrList args_conf_list;
294 // Ignore the first node which is function name
295 auto &inputs = cnode->inputs();
296 for (std::size_t i = 1; i < inputs.size(); i++) {
297 const AnfNodePtr &node = inputs[i];
298 args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));
299 }
300
301 std::vector<EvaluatorPtr> evaluators;
302 auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) {
303 auto evaluator = this->GetEvaluatorFor(poss);
304 evaluator->set_bound_node(cnode);
305 evaluators.push_back(evaluator);
306 };
307 func->Visit(build_evaluator);
308
309 auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
310 return eval_result;
311 }
312
Execute(const AbstractFunctionPtr & func,const AbstractBasePtrList & args_spec_list)313 EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
314 MS_EXCEPTION_IF_NULL(func);
315 ConfigPtrList args_conf_list;
316 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
317 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
318 std::vector<EvaluatorPtr> infs;
319 MS_EXCEPTION_IF_NULL(func);
320 auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
321 auto evaluator = this->GetEvaluatorFor(poss);
322 infs.push_back(evaluator);
323 };
324 func->Visit(build_evaluator);
325 return ExecuteEvaluators(infs, nullptr, args_conf_list);
326 }
327
ClearEvaluatorCache()328 void AnalysisEngine::ClearEvaluatorCache() {
329 for (auto &element : evaluators_) {
330 EvaluatorPtr evaluator = element.second;
331 MS_EXCEPTION_IF_NULL(evaluator);
332 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
333 evaluator->evaluator_cache_mgr()->Clear();
334 }
335 for (auto &element : prim_constructors_) {
336 EvaluatorPtr evaluator = element.second;
337 MS_EXCEPTION_IF_NULL(evaluator);
338 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
339 evaluator->evaluator_cache_mgr()->Clear();
340 }
341 for (auto &element : prim_py_evaluators_) {
342 EvaluatorPtr evaluator = element.second;
343 MS_EXCEPTION_IF_NULL(evaluator);
344 MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
345 evaluator->evaluator_cache_mgr()->Clear();
346 }
347 // Release Exception to avoid hup at exit.
348 StaticAnalysisException::Instance().ClearException();
349 }
350
Clear()351 void AnalysisEngine::Clear() {
352 AnalysisResultCacheMgr::GetInstance().Clear();
353 anfnode_config_map_.clear();
354 eval_trace_.clear();
355 evaluators_.clear();
356 constructors_app_.clear();
357 continued_evals_.clear();
358 root_func_graph_ = nullptr;
359 root_context_ = nullptr;
360 }
361
GetPrimEvaluator(const PrimitivePtr & prim,const AnalysisEnginePtr & engine)362 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
363 // Custom Primitive with python infer_shape, infer_type
364 MS_EXCEPTION_IF_NULL(prim);
365 if (prim->isa<prim::DoSignaturePrimitive>()) {
366 return std::make_shared<DoSignatureEvaluator>(prim);
367 }
368 if (prim->isa<prim::UnpackGraphPrimitive>()) {
369 return std::make_shared<UnpackGraphEvaluator>(prim);
370 }
371 if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
372 return std::make_shared<MixedPrecisionCastEvaluator>(prim);
373 }
374
375 // find prim infer function in the prim function map return a standard evaluator
376 auto eval_impl = GetPrimitiveInferImpl(prim);
377 if (eval_impl.infer_shape_impl_ != nullptr) {
378 return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
379 }
380
381 // use python infer function if the infer function not founded in the map return a python evaluator
382 EvaluatorPtr evaluator = nullptr;
383 if (prim->HasPyEvaluator()) {
384 auto prim_py = dyn_cast<PrimitivePy>(prim);
385 if (prim_py != nullptr) {
386 if (engine == nullptr) {
387 return std::make_shared<PythonPrimEvaluator>(prim_py);
388 }
389
390 const auto &iter = engine->prim_py_evaluators_.find(prim_py);
391 if (iter != engine->prim_py_evaluators_.end()) {
392 return iter->second;
393 }
394 evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
395 engine->prim_py_evaluators_[prim_py] = evaluator;
396 return evaluator;
397 }
398 MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
399 return nullptr;
400 }
401
402 // return a default evaluator
403 if (engine == nullptr) {
404 // If engine is nullptr, get constructor from default.
405 const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
406 auto iter = prim_evaluator_map.find(prim);
407 if (iter != prim_evaluator_map.end()) {
408 evaluator = iter->second;
409 }
410 } else {
411 // If engine is given, get constructor from engine resource.
412 const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
413 auto iter = prim_evaluator_map.find(prim);
414 if (iter != prim_evaluator_map.end()) {
415 evaluator = iter->second;
416 }
417 }
418 if (evaluator == nullptr) {
419 MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
420 }
421 return evaluator;
422 }
423
_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> & func)424 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
425 MS_EXCEPTION_IF_NULL(func);
426 auto inf_pair = evaluators_.find(func);
427 if (inf_pair != evaluators_.end()) {
428 return inf_pair->second;
429 }
430 auto primitive = func->prim();
431 auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
432 if (evaluator == nullptr) {
433 MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
434 }
435 evaluators_[func] = evaluator;
436 return evaluator;
437 }
438
_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> & func)439 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
440 MS_EXCEPTION_IF_NULL(func);
441 auto inf_pair = evaluators_.find(func);
442 if (inf_pair != evaluators_.end()) {
443 return inf_pair->second;
444 }
445 std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
446 std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
447 evaluators_[func] = func_graph_evaluator;
448 return func_graph_evaluator;
449 }
450
_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> & func)451 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
452 MS_EXCEPTION_IF_NULL(func);
453 auto inf_pair = evaluators_.find(func);
454 if (inf_pair != evaluators_.end()) {
455 return inf_pair->second;
456 }
457
458 std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
459 std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->GetScope());
460 evaluators_[func] = evaluator;
461 return evaluator;
462 }
463
_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> & func)464 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
465 MS_EXCEPTION_IF_NULL(func);
466 AbstractFunctionPtr func_orig = func->fn();
467 EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
468 auto jevaluator = std::make_shared<JEvaluator>(evaluator_orig, func_orig);
469 return jevaluator;
470 }
471
_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> & func)472 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
473 MS_EXCEPTION_IF_NULL(func);
474 std::shared_ptr<VirtualEvaluator> virtual_evaluator =
475 std::make_shared<VirtualEvaluator>(func->args_spec_list(), func->output());
476 return virtual_evaluator;
477 }
478
_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> & func)479 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
480 MS_EXCEPTION_IF_NULL(func);
481 AbstractFunctionPtr func_orig = func->fn();
482 EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
483 auto part_pair = std::make_pair(func_orig, func->args());
484 auto itr = constructors_app_.find(part_pair);
485 if (itr != constructors_app_.end()) {
486 return itr->second;
487 }
488 std::shared_ptr<PartialAppEvaluator> partial_evaluator =
489 std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
490 constructors_app_[part_pair] = partial_evaluator;
491 return partial_evaluator;
492 }
493
_GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &)494 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &) {
495 MS_LOG(EXCEPTION) << "Should not be called ";
496 }
497
498 // Forward to specific subclass of FunctionWrapper.
_GetEvaluatorFor(const AbstractFunctionPtr & func)499 EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
500 MS_EXCEPTION_IF_NULL(func);
501 if (func->isa<PrimitiveAbstractClosure>()) {
502 return _GetEvaluatorFor(func->cast<std::shared_ptr<PrimitiveAbstractClosure>>());
503 } else if (func->isa<FuncGraphAbstractClosure>()) {
504 return _GetEvaluatorFor(func->cast<std::shared_ptr<FuncGraphAbstractClosure>>());
505 } else if (func->isa<MetaFuncGraphAbstractClosure>()) {
506 return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
507 } else if (func->isa<JTransformedAbstractClosure>()) {
508 return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
509 } else if (func->isa<VirtualAbstractClosure>()) {
510 return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
511 } else if (func->isa<PartialAbstractClosure>()) {
512 return _GetEvaluatorFor(func->cast<std::shared_ptr<PartialAbstractClosure>>());
513 } else if (func->isa<TypedPrimitiveAbstractClosure>()) {
514 return _GetEvaluatorFor(func->cast<std::shared_ptr<TypedPrimitiveAbstractClosure>>());
515 } else if (func->isa<AbstractFuncAtom>()) {
516 MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
517 } else if (func->isa<AbstractFuncUnion>()) {
518 MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
519 } else if (func->isa<DummyAbstractClosure>()) {
520 MS_LOG(EXCEPTION) << "A dummy function cannot eval";
521 } else {
522 MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
523 }
524 }
525
GetEvaluatorFor(const AbstractFunctionPtr & func)526 EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
527 MS_EXCEPTION_IF_NULL(func);
528 MS_LOG(DEBUG) << "The func value: " << func->ToString();
529 if (func->tracking_id() != nullptr) {
530 MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
531 }
532
533 if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
534 func->isa<abstract::FuncGraphAbstractClosure>()) {
535 EvaluatorPtr evaluator = _GetEvaluatorFor(func);
536 return evaluator;
537 }
538 auto inf_pair = evaluators_.find(func);
539 if (inf_pair != evaluators_.end()) {
540 return inf_pair->second;
541 }
542
543 AbstractFunctionPtr func_generic = func->Copy();
544 func_generic->set_tracking_id(nullptr);
545 EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
546 auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
547 evaluators_[func] = tracked_eval;
548
549 return tracked_eval;
550 }
551
ForwardConfig(const AnfNodeConfigPtr & orig_conf,const AnfNodeConfigPtr new_conf)552 EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
553 MS_EXCEPTION_IF_NULL(orig_conf);
554 MS_EXCEPTION_IF_NULL(new_conf);
555 // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
556 (void)anfnode_config_map_.emplace(orig_conf, new_conf);
557 MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
558 << ", to new_conf: " << new_conf->node()->DebugString();
559 if (orig_conf->node()->isa<CNode>()) {
560 auto old_cnode = orig_conf->node()->cast<CNodePtr>();
561 MS_EXCEPTION_IF_NULL(old_cnode);
562 if (new_conf->node()->isa<CNode>()) {
563 auto new_cnode = new_conf->node()->cast<CNodePtr>();
564 MS_EXCEPTION_IF_NULL(new_cnode);
565 MS_EXCEPTION_IF_NULL(old_cnode->func_graph());
566 if (old_cnode->func_graph() == new_cnode->func_graph()) {
567 MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->ToString()
568 << ", as origin node should be in order list, origin_node: " << old_cnode->ToString();
569 old_cnode->func_graph()->EraseUnusedNodeInOrder(new_cnode);
570 } else {
571 MS_LOG(EXCEPTION) << "Forward orig_node to different func_graph, old_node: " << old_cnode->DebugString()
572 << ", new_node: " << new_cnode->DebugString();
573 }
574 }
575 }
576 (void)forward_count_++;
577 auto res = ObtainEvalResultWithCache(new_conf);
578 (void)forward_count_--;
579 return res;
580 }
581
ExecuteEvaluators(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)582 EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
583 const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
584 if (evaluators.size() == 1) {
585 EvaluatorPtr eval = evaluators[0];
586 MS_EXCEPTION_IF_NULL(eval);
587 return eval->Run(shared_from_this(), args_conf_list, out_conf);
588 }
589 static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1");
590 if (enable_singleThread) {
591 return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
592 } else {
593 return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
594 }
595 }
596
SetUndeterminedFlag(const EvaluatorPtr & evaluator,const FuncGraphPtr & possible_parent_fg)597 void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
598 MS_EXCEPTION_IF_NULL(evaluator);
599 static std::mutex fg_lock;
600 std::lock_guard<std::mutex> infer_lock(fg_lock);
601 auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
602 if (fg_eval == nullptr) {
603 return;
604 }
605
606 auto fg = fg_eval->func_graph();
607 MS_EXCEPTION_IF_NULL(fg);
608 auto undetermined_fgs = fg->recursive();
609 if (undetermined_fgs) {
610 auto fg_parent = fg->parent();
611 if (fg_parent != nullptr) {
612 fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
613 MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
614 return;
615 } else if (possible_parent_fg != nullptr) {
616 possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
617 MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString() << " for fg: " << fg->ToString();
618 } else {
619 MS_LOG(EXCEPTION) << "cannot find parent for fg: " << fg->ToString();
620 }
621 }
622 }
623
// Handles a candidate evaluator `eval` whose evaluation was found again on eval_trace_
// (nested recursion among the candidate `evaluators` of one multi-evaluator call).
//
// Parameters:
//   evaluators     - the candidate evaluators of the current call.
//   eval           - the candidate being handled.
//   args_spec_list - argument abstracts of the current call.
//   it             - reverse iterator into eval_trace_; the backward scan stops at the entry
//                    equal to `*it` (presumably the earlier occurrence found by the caller).
//   continue_flag  - out parameter, reset to false. Set to true when (a) a more recent
//                    candidate than `eval` is on the trace, or (b) the sub-loop holds no
//                    undetermined evaluator. Presumably tells the caller to skip `eval`.
// Returns the latest entry evaluator: the most recent candidate found on the trace before
// reaching `*it`, or `eval` itself when none is found.
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
                                                   const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
                                                   const EvalTraceRevIter &it, bool *continue_flag) {
  MS_EXCEPTION_IF_NULL(continue_flag);
  MS_EXCEPTION_IF_NULL(eval);
  *continue_flag = false;
  // Find latest entry function to handle nested recursion.
  // Walk eval_trace_ from its newest entry back towards `*it`, stopping at the first entry
  // whose evaluator is one of `evaluators`. If none matches, latest_entry_iter ends on the
  // entry equal to `*it`.
  // NOTE(review): assumes an entry equal to `*it` is reachable from rbegin(); otherwise the
  // loop would run past rend().
  EvaluatorPtr latest_entry = eval;
  auto latest_entry_iter = eval_trace_.rbegin();
  for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
    auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->evaluator_);
    if (it_temp != evaluators.end()) {
      latest_entry = *it_temp;
      latest_entry_iter = r_it;
      break;
    }
    latest_entry_iter = ++r_it;
  }
  // A different candidate is more recent on the trace: hand it back and ask to continue.
  if (latest_entry != eval) {
    MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
    *continue_flag = true;
    return latest_entry;
  }

  bool has_undetermined = false;
  // Check whether sub loop has untraced undetermined evaluator.
  // The sub loop is every trace entry newer than latest_entry_iter (exclusive).
  std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> undetermined_evals;
  for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
    undetermined_evals.insert(*r_it);
  }
  MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();

  // An entry counts as undetermined when its alternate evaluator (with args_spec_list) is not
  // itself in the sub loop, and either the alternate has no cached result for args_spec_list,
  // or it has one but the entry has not been continued yet.
  // NOTE(review): multi_poss_[...] default-inserts a null evaluator for a missing key, which
  // the next line would dereference. This assumes every trace entry has an alternate.
  // NOTE(review): the lookups use the current call's args_spec_list, not u_eval's own
  // arguments. Confirm this is intended.
  for (auto u_eval : undetermined_evals) {
    MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "check undetermined.";
    auto &alternate_evaluator = multi_poss_[u_eval.evaluator_];
    auto eval_cache = alternate_evaluator->evaluator_cache_mgr();
    const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list);
    if ((!undetermined_evals.count(alt_eval_args)) &&
        (((!continued_evals_.count(u_eval)) && (eval_cache->GetValue(args_spec_list) != nullptr)) ||
         (eval_cache->GetValue(args_spec_list) == nullptr))) {
      MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "has undetermined.";
      has_undetermined = true;
      break;
    }
  }
  // Nothing undetermined in the sub loop: ask the caller to continue.
  if (!has_undetermined) {
    MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
    *continue_flag = true;
    return latest_entry;
  }

  return latest_entry;
}
677
JoinBranchesFailedInfo(const AbstractBasePtr & spec,const AbstractBasePtr & last_spec,const AnfNodePtr & node,const std::string & error_info)678 std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBasePtr &last_spec,
679 const AnfNodePtr &node, const std::string &error_info) {
680 std::ostringstream buffer;
681 buffer << "The return values of different branches do not join. \n"
682 << error_info << "\nFor more details, please refer to the FAQ at https://www.mindspore.cn.\n"
683 << "The abstract type of the return value of the current branch is " << spec->ToString()
684 << ", and that of the previous branch is " << last_spec->ToString() << ".\n"
685 << "The node " << node->DebugString();
686 if (node->isa<CNode>()) {
687 auto cnode = node->cast<CNodePtr>()->input(0);
688 if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
689 // {prim::kPrimSwitch, cond, true_branch, false_branch}
690 constexpr int true_index = 2;
691 constexpr int false_index = 3;
692 auto inputs = cnode->cast<CNodePtr>()->inputs();
693 buffer << ", true branch: " << inputs.at(true_index)->ToString()
694 << ", false branch: " << inputs.at(false_index)->ToString();
695 } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
696 // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}}
697 constexpr int branch_index = 2;
698 auto tuple_node = cnode->cast<CNodePtr>()->input(branch_index);
699 if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
700 auto tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
701 for (size_t i = 1; i < tuple_inputs.size(); i++) {
702 buffer << ", branch" << i << ": " << tuple_inputs.at(i);
703 }
704 }
705 }
706 }
707 buffer << ". trace: " << trace::DumpSourceLines(node);
708 return buffer.str();
709 }
710
ProcessEvalResults(const AbstractBasePtrList & out_specs,const AnfNodePtr & node)711 EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) {
712 if (out_specs.empty()) {
713 MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
714 }
715
716 if (out_specs.size() == 1) {
717 MS_EXCEPTION_IF_NULL(out_specs[0]);
718 // If only one result derived, then broaden it to avoid wrong constant propagation.
719 return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
720 }
721 MS_EXCEPTION_IF_NULL(node);
722
723 AbstractBasePtr last_spec = out_specs[0];
724 AbstractBasePtr joined_spec = out_specs[0];
725 for (const auto &spec : out_specs) {
726 MS_EXCEPTION_IF_NULL(spec);
727 try {
728 joined_spec = joined_spec->Join(spec);
729 } catch (const py::type_error &ex) {
730 auto error_info = ExtractLoggingInfo(ex.what());
731 MS_EXCEPTION(TypeError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
732 } catch (const py::value_error &ex) {
733 auto error_info = ExtractLoggingInfo(ex.what());
734 MS_EXCEPTION(ValueError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
735 } catch (const std::exception &ex) {
736 auto error_info = ExtractLoggingInfo(ex.what());
737 MS_LOG(EXCEPTION) << JoinBranchesFailedInfo(spec, last_spec, node, error_info);
738 }
739 MS_EXCEPTION_IF_NULL(joined_spec);
740 last_spec = spec;
741 }
742
743 MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
744 return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
745 }
746
NeedWaitForBranches(const AbstractBasePtr & abstract)747 bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
748 MS_EXCEPTION_IF_NULL(abstract);
749 if (abstract->isa<AbstractFunction>()) {
750 return true;
751 }
752 if (abstract->isa<AbstractSequeue>()) {
753 auto elements = abstract->cast<AbstractSequeuePtr>()->elements();
754 if (std::any_of(elements.begin(), elements.end(),
755 [](const AbstractBasePtr &item) { return item->isa<AbstractFunction>(); })) {
756 return true;
757 }
758 }
759 return false;
760 }
761
// Thread entry point that evaluates one branch of a multi-branch node.
//
// It is started on a detached std::thread by ExecuteMultipleEvaluatorsMultiThread.
// eval / engine / args_conf_list / out_conf: the evaluator to run and its inputs.
// threadID:            id installed for this thread through AnalysisSchedule::SetThreadID.
// async_result_branch: receives this branch's broadened result on success.
// async_result_main:   receives the broadened result on success, or an AbstractError on failure.
// async_run_flag:      task whose GetResult() blocks until this thread may start evaluating.
// graph_evals / trace_c_node_evals: the caller's trace stacks, restored here so that an
//                      exception raised on this thread can dump the caller's stack.
// Exceptions derived from std::exception are caught, published as an AbstractError and
// forwarded to AnalysisSchedule::HandleException; they are not rethrown from here.
// NOTE(review): only std::exception is caught. Anything else thrown would escape the thread
// function, and the C++ runtime then calls std::terminate. Confirm this cannot happen.
// NOTE(review): on failure async_result_branch is never set. If DecreaseThreadCount() or the
// debug log after it throws on the success path, the catch block decrements the thread count
// a second time and sets async_result_main again. Verify both are acceptable.
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
                   const std::string &threadID, AsyncAbstractPtr async_result_branch,
                   AsyncAbstractPtr async_result_main, AsyncInferTaskPtr async_run_flag,
                   const trace::TraceGraphEvalStack &graph_evals,
                   const trace::TraceCNodeEvalStack &trace_c_node_evals) {
  AnalysisSchedule::SetThreadID(threadID);
  // Restore the caller's trace stacks so that the stack can be dumped if an exception occurs.
  trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
  trace::TraceGraphEvalStackPrepare(graph_evals);

  try {
    // Block until the scheduler signals that this branch may run.
    MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting.";
    (void)async_run_flag->GetResult();
    MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running.";

    // Acquire the GIL only around Run(), since the evaluator may call back into Python.
    EvalResultPtr result;
    {
      py::gil_scoped_acquire pyGuard;
      result = eval->Run(engine, args_conf_list, out_conf);
    }
    MS_EXCEPTION_IF_NULL(result);
    MS_EXCEPTION_IF_NULL(result->abstract());

    // Broaden the result of switch(c,t,f)().
    auto broadAbstract = result->abstract()->Broaden();
    // Publish the result: first to the switch-value cache, then to the waiters on this
    // branch, then to the main waiter.
    AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract);
    async_result_branch->SetResult(broadAbstract);
    async_result_main->SetResult(broadAbstract);
    // This branch's work is done; decrement the scheduler's thread count.
    AnalysisSchedule::GetInstance().DecreaseThreadCount();
    MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
                  << " asyncResult address = " << async_result_branch.get()
                  << " value = " << async_result_branch->TryGetResult()->ToString();
  } catch (const std::exception &e1) {
    // Publish an error abstract so that the cache and the main waiter are not left pending.
    auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
    AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
    async_result_main->SetResult(abstractErrPtr);
    MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
    AnalysisSchedule::GetInstance().HandleException(e1);
    try {
      // This branch is finishing; decrement the scheduler's thread count. Errors are only logged,
      // because nothing may propagate out of the thread entry point.
      AnalysisSchedule::GetInstance().DecreaseThreadCount();
    } catch (const std::exception &e2) {
      MS_LOG(DEBUG) << "AnalysisSchedule::GetInstance().DecreaseThreadCount() threw exception.";
    }
  }
}
812
ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)813 EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
814 const AnfNodeConfigPtr &out_conf,
815 const ConfigPtrList &args_conf_list) {
816 MS_EXCEPTION_IF_NULL(out_conf);
817 MS_EXCEPTION_IF_NULL(out_conf->node());
818 // Release GIL for C++
819 py::gil_scoped_release infer_gil_release;
820 // Wait for the last switch node to finish.
821 MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
822 auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
823 if (eval_result == nullptr) {
824 MS_LOG(DEBUG) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
825 AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
826 } else {
827 return std::make_shared<EvalResult>(eval_result, nullptr);
828 }
829 auto possible_parent_fg = out_conf->node()->func_graph();
830
831 // Eval result of the main.
832 AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
833 // Eval result of the branches
834 std::vector<AsyncAbstractPtr> branchAsyncResults;
835
836 for (auto &evaluator : evaluators) {
837 static std::atomic<int> idCount{0};
838 std::string threadId = AnalysisSchedule::GetThreadID() + "." + std::to_string(idCount.fetch_add(1));
839 MS_EXCEPTION_IF_NULL(evaluator);
840 SetUndeterminedFlag(evaluator, possible_parent_fg);
841 AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
842 // Control the order to run.
843 AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
844 AsyncInferTaskPtr asyncTask = AsyncInferTask::MakeShared(asyncRunOrder, threadId);
845 // Add point to the async thread.
846 AnalysisSchedule::GetInstance().IncreaseThreadCount();
847 MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
848 auto thread =
849 std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId, branchAsyncResult,
850 asyncResult_main, asyncTask, trace::GetCurrentGraphEvalStack(), trace::GetCNodeDebugStack());
851 thread.detach();
852 // Push to list of running loop
853 asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));
854 MS_LOG(DEBUG) << " add to schedule: " << asyncTask.get();
855 AnalysisSchedule::GetInstance().Add2Schedule(asyncTask); // Activate order witch child thread.
856 (void)branchAsyncResults.emplace_back(std::move(branchAsyncResult));
857 }
858
859 MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
860 << " or " << evaluators[1]->ToString() << "...";
861 auto async_main = AsyncInferTask::MakeShared(asyncResult_main);
862 MS_LOG(DEBUG) << " add to schedule: " << async_main.get();
863 AnalysisSchedule::GetInstance().Add2Schedule(async_main); // Third order
864 auto firstResult = async_main->GetResult();
865 MS_EXCEPTION_IF_NULL(firstResult);
866 MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
867 << firstResult->ToString();
868
869 AbstractBasePtrList out_specs;
870 size_t len = evaluators.size();
871 if (NeedWaitForBranches(firstResult)) {
872 for (size_t i = 0; i < len; ++i) {
873 MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
874 auto async_branch = AsyncInferTask::MakeShared(branchAsyncResults[i]);
875 MS_LOG(DEBUG) << " add to schedule: " << async_branch.get();
876 AnalysisSchedule::GetInstance().Add2Schedule(async_branch);
877 auto result = async_branch->GetResult();
878 MS_EXCEPTION_IF_NULL(result);
879 out_specs.push_back(result);
880 }
881 } else {
882 // Give one more chance to wait for the result of the branches.
883 auto async_tmp = AsyncInferTask::MakeShared(asyncResult_main);
884 MS_LOG(DEBUG) << " add to schedule: " << async_tmp.get();
885 AnalysisSchedule::GetInstance().Add2Schedule(async_tmp);
886 (void)async_tmp->GetResult();
887 for (size_t i = 0; i < len; ++i) {
888 // Not wait to get the result of branch.
889 auto result = branchAsyncResults[i]->TryGetResult();
890 if (result) {
891 MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
892 << " result: " << result->ToString();
893 out_specs.push_back(result);
894 }
895 }
896 }
897
898 return ProcessEvalResults(out_specs, out_conf->node());
899 }
900
ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> & evaluators,const AnfNodeConfigPtr & out_conf,const ConfigPtrList & args_conf_list)901 EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
902 const AnfNodeConfigPtr &out_conf,
903 const ConfigPtrList &args_conf_list) {
904 AbstractBasePtrList out_specs;
905 const size_t evaluators_size = 2;
906 if (evaluators.size() < evaluators_size) {
907 MS_LOG(ERROR) << "evaluators size is less than 2";
908 }
909 multi_poss_[evaluators[0]] = evaluators[1];
910 multi_poss_[evaluators[1]] = evaluators[0];
911 AbstractBasePtrList args_spec_list;
912 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
913 [](const ConfigPtr &conf) -> AbstractBasePtr {
914 MS_EXCEPTION_IF_NULL(conf);
915 return conf->ObtainEvalResult()->abstract();
916 });
917 MS_EXCEPTION_IF_NULL(out_conf);
918 MS_EXCEPTION_IF_NULL(out_conf->node());
919 auto possible_parent_fg = out_conf->node()->func_graph();
920 for (auto eval : evaluators) {
921 MS_EXCEPTION_IF_NULL(eval);
922 (void)SetUndeterminedFlag(eval, possible_parent_fg);
923 const auto current_inf = EvaluatorArgs(eval, args_spec_list);
924 MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
925 // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
926 auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
927 if (it == eval_trace_.rend()) {
928 eval_trace_.push_back(current_inf);
929 auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
930 auto eval_abstract = eval_result->abstract();
931 MS_EXCEPTION_IF_NULL(eval_abstract);
932
933 out_specs.push_back(eval_abstract);
934 eval_trace_.pop_back();
935 if (eval_trace_.empty()) {
936 multi_poss_.clear();
937 }
938 } else {
939 bool continue_flag = false;
940 auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
941 if (continue_flag) {
942 MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.evaluator_.get() << current_inf.evaluator_->ToString();
943 continued_evals_.insert(current_inf);
944 continue;
945 }
946
947 // Try to travel the latest undetermined.
948 if (latest_entry != eval_trace_.rbegin()->evaluator_) {
949 MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
950 auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
951 MS_EXCEPTION_IF_NULL(eval_result->abstract());
952 MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
953 << " return out_spec: " << eval_result->abstract()->ToString();
954 return eval_result;
955 }
956 }
957 }
958
959 return ProcessEvalResults(out_specs, out_conf->node());
960 }
961
ObtainEvalResult()962 EvalResultPtr AnfNodeConfig::ObtainEvalResult() {
963 AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
964 return engine_.lock()->ObtainEvalResultWithCache(self);
965 }
966
// Wraps `func_graph` in a FuncGraphAbstractClosure bound to `context` and `anf_node`.
// A null `context` is replaced by AnalysisContext::DummyContext().
abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph,
                                              const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) {
  const AnalysisContextPtr bound_context =
    (context != nullptr) ? context : abstract::AnalysisContext::DummyContext();
  return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, bound_context, anf_node);
}
975
MakeAbstractClosure(const MetaFuncGraphPtr & meta_func_graph,const AnfNodePtr & anf_node)976 abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
977 abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn;
978 if (anf_node == nullptr) {
979 meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph);
980 } else {
981 meta_func_graph_fn =
982 std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope());
983 }
984 return meta_func_graph_fn;
985 }
986
MakeAbstractClosure(const PrimitivePtr & primitive,const AnfNodePtr & anf_node)987 abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, const AnfNodePtr &anf_node) {
988 auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(primitive, anf_node);
989 return prim_func;
990 }
991
ToAbstract(const ValuePtr & value,const AnalysisContextPtr & context,const AnfNodeConfigPtr & conf)992 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
993 MS_EXCEPTION_IF_NULL(value);
994 AnfNodePtr anf_node = nullptr;
995 if (conf != nullptr) {
996 anf_node = conf->node();
997 }
998 if (value->isa<FuncGraph>()) {
999 auto func_graph = value->cast<FuncGraphPtr>();
1000 return MakeAbstractClosure(func_graph, context, anf_node);
1001 }
1002 if (value->isa<MetaFuncGraph>()) {
1003 auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
1004 return MakeAbstractClosure(meta_func_graph, anf_node);
1005 }
1006 if (value->isa<Primitive>()) {
1007 auto prim = value->cast<PrimitivePtr>();
1008 return MakeAbstractClosure(prim, anf_node);
1009 } else {
1010 return value->ToAbstract();
1011 }
1012 }
1013
FromValueInside(const ValuePtr & value,bool broaden)1014 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
1015 AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
1016 if (broaden) {
1017 a = a->Broaden();
1018 }
1019 return a;
1020 }
1021
EvalOnePrim(const PrimitivePtr & primitive,const AbstractBasePtrList & arg_specs)1022 EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
1023 auto evaluator = GetPrimEvaluator(primitive, nullptr);
1024 if (evaluator == nullptr) {
1025 MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
1026 }
1027 if (!evaluator->isa<TrivialPrimEvaluator>()) {
1028 MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
1029 << evaluator->ToString();
1030 }
1031 auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
1032 auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
1033 return eval_result;
1034 }
1035 } // namespace abstract
1036 } // namespace mindspore
1037