• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
18 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
19 
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "pipeline/jit/ps/static_analysis/evaluator.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/compile_config.h"
28 
29 namespace mindspore {
30 namespace abstract {
31 class StackFrame;
32 using EvaluatorWeakPtr = std::weak_ptr<Evaluator>;
33 using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
34 
35 class StackFrame final : public Base {
36  public:
StackFrame(const EvaluatorPtr & evaluator,const FuncGraphPtr & func_graph,const AnalysisContextPtr & current_context,const AnalysisContextPtr & parent_context)37   StackFrame(const EvaluatorPtr &evaluator, const FuncGraphPtr &func_graph, const AnalysisContextPtr &current_context,
38              const AnalysisContextPtr &parent_context)
39       : evaluator_(EvaluatorWeakPtr(evaluator)),
40         func_graph_(func_graph),
41         current_context_(current_context),
42         parent_context_(parent_context),
43         slot_index_(0),
44         done_(false) {
45     Load();
46   }
47   virtual ~StackFrame() = default;
48 
49   MS_DECLARE_PARENT(StackFrame, Base);
50 
Load()51   void Load() {
52     MS_EXCEPTION_IF_NULL(func_graph_);
53     node_slots_ = TopoSort(func_graph_->get_return(), SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
54       static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
55       if (node->isa<ValueNode>() || node->isa<Parameter>() ||
56           (enable_pre_lift && IsPrimitiveCNode(node, prim::kPrimPartial))) {
57         return EXCLUDE;
58       }
59       return FOLLOW;
60     });
61     if (node_slots_.empty()) {
62       MS_LOG(INTERNAL_EXCEPTION) << "The func graph is empty, func graph: " << func_graph_ << "/"
63                                  << func_graph_->ToString()
64                                  << ", has return: " << (func_graph_->get_return() != nullptr);
65     }
66     slot_index_ = 0;
67     args_abs_list_.clear();
68   }
69 
70   // Check if we need branch to another func graph.
71   StackFramePtr Jump(const AnalysisEnginePtr &engine);
72   // Run one step in current func graph.
73   EvalResultPtr Step(const AnalysisEnginePtr &engine);
74   // Return back from branch func graph.
75   void Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame, const EvalResultPtr &eval_result);
76 
Done()77   bool Done() const { return done_; }
78 
CurrentNode()79   AnfNodePtr &CurrentNode() {
80     if (slot_index_ >= node_slots_.size()) {
81       MS_LOG(INTERNAL_EXCEPTION) << "The stack frame of " << func_graph_->ToString()
82                                  << " is invalid. Try to access frame sequence by index " << slot_index_
83                                  << ", while the size is " << node_slots_.size() << ".";
84     }
85     return node_slots_[slot_index_];
86   }
87 
NextNode()88   AnfNodePtr &NextNode() {
89     auto &current_node = CurrentNode();
90     MS_EXCEPTION_IF_NULL(func_graph_);
91     // Set `done_` true, if the stack frames is being exhausted.
92     if (current_node == func_graph_->get_return()) {
93       done_ = true;
94     }
95     // Move cursor to next node.
96     slot_index_++;
97     return current_node;
98   }
99 
evaluator()100   EvaluatorPtr evaluator() const { return evaluator_.lock(); }
func_graph()101   FuncGraphPtr func_graph() const { return func_graph_; }
current_context()102   AnalysisContextPtr current_context() const { return current_context_; }
parent_context()103   AnalysisContextPtr parent_context() const { return parent_context_; }
104 
args_abs_list()105   const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; }
set_args_abs_list(const AbstractBasePtrList && args_abs_list)106   void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; }
107 
ToString()108   std::string ToString() const override {
109     MS_EXCEPTION_IF_NULL(func_graph_);
110     std::ostringstream buffer;
111     buffer << "StackFrame: " << this << ", " << func_graph_->ToString();
112     if (slot_index_ < node_slots_.size()) {
113       auto current_node = node_slots_[slot_index_];
114       buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")";
115     } else {
116       buffer << "(Exhausted..)";
117     }
118     buffer << ", parent: ";
119     MS_EXCEPTION_IF_NULL(parent_context_);
120     auto parent_graph = parent_context_->func_graph();
121     if (parent_graph != nullptr) {
122       buffer << parent_graph << "/" << parent_graph->ToString();
123     } else {
124       buffer << "NULL";
125     }
126     return buffer.str();
127   }
128 
129   friend std::ostream &operator<<(std::ostream &os, const StackFramePtr &frame) {
130     MS_EXCEPTION_IF_NULL(frame);
131     os << frame->ToString();
132     return os;
133   }
134 
135  private:
136   AbstractBasePtrList GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
137                                           const CNodePtr &current_cnode);
138   AnalysisContextPtr GetParentContext(const BaseFuncGraphEvaluatorPtr &fg_evaluator,
139                                       const AbstractFunctionPtr &graph_func) const;
140   StackFramePtr DoJump(const AnalysisEnginePtr &engine, const CNodePtr &current_cnode,
141                        const AbstractFunctionPtr &graph_func);
142 
143   EvaluatorWeakPtr evaluator_;
144   FuncGraphPtr func_graph_;
145   AnalysisContextPtr current_context_;
146   AnalysisContextPtr parent_context_;
147   AbstractBasePtrList args_abs_list_;
148   std::vector<AnfNodePtr> node_slots_;
149   size_t slot_index_;
150   bool done_;
151 };
152 }  // namespace abstract
153 }  // namespace mindspore
154 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
155