1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/static_analysis/stack_frame.h"
18 #include "debug/trace.h"
19 #include "pipeline/jit/static_analysis/async_eval_result.h"
20
21 namespace mindspore {
22 namespace abstract {
GenerateArgsAbsList(const AnalysisEnginePtr & engine,const EvaluatorPtr & evaluator,const CNodePtr current_cnode)23 AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
24 const CNodePtr current_cnode) {
25 MS_EXCEPTION_IF_NULL(current_cnode);
26 MS_EXCEPTION_IF_NULL(evaluator);
27 AbstractBasePtrList args_abs_list;
28 auto &inputs = current_cnode->inputs();
29 for (std::size_t i = 1; i < inputs.size(); i++) {
30 auto config = engine->MakeConfig(inputs[i], current_context_, current_context_->func_graph());
31 auto result = config->ObtainEvalResult();
32 MS_EXCEPTION_IF_NULL(result);
33 auto abs = result->abstract();
34 args_abs_list.push_back(abs);
35 }
36 args_abs_list = evaluator->NormalizeArgs(args_abs_list);
37 args_abs_list = evaluator->BroadenUndeterminedArgs(args_abs_list);
38 return args_abs_list;
39 }
40
GetParentContext(const BaseFuncGraphEvaluatorPtr & fg_evaluator,const AbstractFunctionPtr & graph_func)41 AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr &fg_evaluator,
42 const AbstractFunctionPtr &graph_func) {
43 MS_EXCEPTION_IF_NULL(graph_func);
44 MS_EXCEPTION_IF_NULL(fg_evaluator);
45 AnalysisContextPtr parent_context = nullptr;
46 auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
47 if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
48 parent_context = func_graph_abs->context();
49 } else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
50 parent_context = fg_evaluator->parent_context();
51 if (parent_context == nullptr) {
52 parent_context = AnalysisContext::DummyContext();
53 fg_evaluator->set_parent_context(parent_context);
54 }
55 } else { // Not call FuncGraph or MetaFuncGraph.
56 MS_LOG(EXCEPTION) << "Should be FuncGraphAbstractClosure or MetaFuncGraphAbstractClosure.";
57 }
58 return parent_context;
59 }
60
61 // Inner jump implementation.
DoJump(const AnalysisEnginePtr & engine,const CNodePtr current_cnode,const AbstractFunctionPtr & graph_func)62 StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
63 const AbstractFunctionPtr &graph_func) {
64 MS_EXCEPTION_IF_NULL(engine);
65 MS_EXCEPTION_IF_NULL(current_cnode);
66 // Get the evaluator for func graph.
67 auto evaluator = engine->GetEvaluatorFor(graph_func);
68 auto fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
69 if (fg_evaluator == nullptr) {
70 MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString();
71 }
72
73 // Evaluate the inputs firstly. Build arguments for the func graph.
74 AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
75 // Check if already evaluated before.
76 if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) {
77 return nullptr;
78 }
79
80 // Generate func graph with arguments.
81 auto fg = fg_evaluator->GetFuncGraph(engine, args_abs_list);
82 MS_EXCEPTION_IF_NULL(fg);
83 std::size_t nargs = fg->parameters().size();
84 if (args_abs_list.size() != nargs) {
85 MS_EXCEPTION(TypeError) << "The parameters number of the function is " << fg->parameters().size()
86 << ", but the number of provided arguments is " << args_abs_list.size() << ".\n"
87 << "FunctionGraph ID : " << fg->ToString()
88 << "\nNodeInfo: " << trace::GetDebugInfo(fg->debug_info());
89 }
90 MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString()
91 << ", current_context_: " << current_context_->ToString();
92
93 // Find parent context and create new context.
94 AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func);
95 MS_EXCEPTION_IF_NULL(parent_context);
96 auto new_context = parent_context->NewContext(fg, args_abs_list);
97
98 // Evaluate the parameters with new context.
99 for (size_t i = 0; i < nargs; i++) {
100 const auto &arg_abs = args_abs_list[i];
101 const auto &node = fg->parameters()[i];
102 AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context, new_context->func_graph());
103 engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr));
104 }
105
106 // Create a new stack frame and set arguments for it.
107 auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, parent_context);
108 new_stack_frame->set_args_abs_list(std::move(args_abs_list));
109 return new_stack_frame;
110 }
111
112 // Check if we need branch to another func graph.
Jump(const AnalysisEnginePtr & engine)113 StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
114 MS_EXCEPTION_IF_NULL(engine);
115 auto ¤t_node = CurrentNode();
116 if (!current_node->isa<CNode>()) {
117 return nullptr;
118 }
119 auto cnode = current_node->cast<CNodePtr>();
120 auto maybe_func = engine->GetCNodeOperatorAbstract(cnode, current_context_, current_context_->func_graph());
121 if (!maybe_func->isa<abstract::MetaFuncGraphAbstractClosure>() &&
122 !maybe_func->isa<abstract::FuncGraphAbstractClosure>()) {
123 return nullptr; // Not call FuncGraph or MetaFuncGraph.
124 }
125
126 // It's FuncGraph Call or MetaFuncGraph Call. `maybe_func` is definitely a AbstractFunction.
127 AnfNodeConfigPtr call_node_conf = engine->MakeConfig(cnode, current_context_, current_context_->func_graph());
128 // Enter the call CNode.
129 trace::TraceEvalCNodeEnter(call_node_conf);
130 auto res = DoJump(engine, cnode, dyn_cast<AbstractFunction>(maybe_func));
131 if (res == nullptr) {
132 trace::TraceEvalCNodeLeave();
133 }
134 return res;
135 }
136
137 // Run one step in current func graph.
Step(const AnalysisEnginePtr & engine)138 EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
139 MS_EXCEPTION_IF_NULL(engine);
140 auto ¤t_node = NextNode();
141 MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
142 << ", current_context_: " << current_context_->ToString();
143 AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
144 auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
145 MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = "
146 << (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null");
147 return node_eval_result;
148 }
149
150 // Return back from child func graph.
Back(const AnalysisEnginePtr & engine,const StackFramePtr & last_stack_frame,const EvalResultPtr & eval_result)151 void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame,
152 const EvalResultPtr &eval_result) {
153 MS_EXCEPTION_IF_NULL(engine);
154 MS_EXCEPTION_IF_NULL(last_stack_frame);
155 MS_EXCEPTION_IF_NULL(eval_result);
156 // Overwrite the result if func graph is stub.
157 EvalResultPtr result = eval_result;
158 if (last_stack_frame->func_graph()->stub()) {
159 result = std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
160 }
161 // Save func graph eval result for specialize.
162 auto evaluator = last_stack_frame->evaluator();
163 MS_EXCEPTION_IF_NULL(evaluator);
164 evaluator->evaluator_cache_mgr()->SetValue(last_stack_frame->args_abs_list(), result);
165
166 // Continue saving node's result for parent func graph.
167 auto ¤t_node = NextNode();
168 MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
169 << ", current_context_: " << current_context_->ToString();
170 AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
171 engine->SaveEvalResultInCache(node_conf, result);
172
173 // Leave the call CNode.
174 trace::TraceEvalCNodeLeave();
175 }
176 } // namespace abstract
177 } // namespace mindspore
178