1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_H 18 19 #include <map> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 #include "pipeline/jit/pi/graph_capture/cfg.h" 25 #include "pipeline/jit/pi/graph_capture/loop.h" 26 #include "pipeline/jit/pi/graph_capture/node.h" 27 #include "pipeline/jit/pi/utils/allocator.h" 28 #include "pipeline/jit/pi/graph_guard/trace.h" 29 #include "pipeline/jit/pi/graph_guard/guard.h" 30 #include "pipeline/jit/pi/graph_capture/side_effect.h" 31 #include "utils/convert_utils_base.h" 32 33 namespace mindspore { 34 namespace pijit { 35 36 class OptCode; 37 class GraphJitConfig; 38 39 class FrameStates { 40 public: Local(int i)41 ValueNode *Local(int i) const { 42 MS_ASSERT((int)locals.size() > i); 43 return locals[i]; 44 } SetLocal(int i,ValueNode * v)45 void SetLocal(int i, ValueNode *v) { 46 MS_ASSERT((int)locals.size() > i); 47 locals[i] = v; 48 } 49 Closure(int i)50 CellVarNode *Closure(int i) const { 51 MS_ASSERT((int)cell_free.size() > i); 52 return cell_free[i]; 53 } SetClosure(int i,CellVarNode * v)54 void SetClosure(int i, CellVarNode *v) { 55 MS_ASSERT((int)cell_free.size() > i); 56 cell_free[i] = v; 57 } 58 Peek(int p)59 ValueNode *&Peek(int p) { 60 MS_ASSERT((int)stack.size() > p); 61 return stack[stack.size() - p - 1]; 62 } 63 Peek(int p)64 ValueNode *Peek(int p) const { 65 MS_ASSERT((int)stack.size() > p); 66 return stack[stack.size() - p - 1]; 67 } 68 Pop()69 ValueNode *Pop() { 70 MS_ASSERT(stack.size() > 0); 71 auto r = stack[stack.size() - 1]; 72 stack.pop_back(); 73 return r; 74 } Popn(int n)75 void Popn(int n) { 76 for (int i = 0; i < n; i++) { 77 Pop(); 78 } 79 } Push(ValueNode * i)80 void Push(ValueNode *i) { stack.push_back(i); } 81 Rot(int i)82 void Rot(int i) { 83 MS_ASSERT((int)stack.size() - i >= 0); 84 ValueNode *v = Pop(); 85 stack.insert(stack.end() - i, v); 86 } 87 ResizeLocal(int i)88 void ResizeLocal(int i) { 89 MS_ASSERT((int)locals.size() <= i); 90 locals.resize(i, &ValueNode::kUnboundLocal); 91 } ResizeClosure(int i)92 void ResizeClosure(int i) { 93 MS_ASSERT((int)cell_free.size() <= i); 94 cell_free.resize(i); 95 } 96 GetLocals()97 const auto &GetLocals() const { return locals; } GetStacks()98 const auto &GetStacks() const { return stack; } GetClosures()99 const auto &GetClosures() const { return cell_free; } 100 GetLocals()101 auto &GetLocals() { return locals; } GetStacks()102 auto &GetStacks() { return stack; } GetClosures()103 auto &GetClosures() { return cell_free; } 104 105 std::string ToString() const; 106 107 private: 108 std::vector<ValueNode *> stack; 109 std::vector<ValueNode *> locals; 110 std::vector<CellVarNode *> cell_free; 111 }; 112 113 class Graph { 114 public: 115 Graph(PyCodeObject *co, PyObject *globals, const GraphJitConfig &conf); ~Graph()116 virtual ~Graph() {} 117 GetGeneratorResult()118 ValueNode *GetGeneratorResult() const { return generator_result_; } SetGeneratorResult(ValueNode * generator_result)119 void SetGeneratorResult(ValueNode *generator_result) { generator_result_ = generator_result; } 120 SetRetVal(ValueNode * v)121 void SetRetVal(ValueNode *v) { ret_val_ = v; } GetRetVal()122 ValueNode *GetRetVal() const { return ret_val_; } GetCodeObj()123 PyCodeObject *GetCodeObj() const { return reinterpret_cast<PyCodeObject *>(co_.ptr()); } GetGlobals()124 const py::object &GetGlobals() const { return f_globals_; } 125 StopTraceAt(int bci,StopTraceReason reason)126 void StopTraceAt(int bci, StopTraceReason reason) { stop_trace_info_ = {bci, reason}; } GetStopTraceBci()127 int GetStopTraceBci() const { return stop_trace_info_.bci; } GetStopTraceReason()128 StopTraceReason GetStopTraceReason() const { return stop_trace_info_.reason; } GetModuleName()129 const char *GetModuleName() const { return module_name_; } 130 GetCFG()131 auto &GetCFG() { return cfg_; } GetCFG()132 const auto &GetCFG() const { return cfg_; } Config()133 const GraphJitConfig &Config() const { return conf_; } 134 135 const FrameStates &GetFrame(int bci) const; 136 void SetFrame(int bci, const FrameStates &f); GetFrames()137 auto &GetFrames() { return frame_states_; } GetFrames()138 const auto &GetFrames() const { return frame_states_; } allocator()139 Allocator &allocator() { return alloc_; } 140 ValueNode *NewValueNode(AObject *, int op, int arg, const std::vector<ValueNode *> & = {}, 141 const std::string &name = ""); 142 CallNode *NewCallNode(int op, int arg, const std::vector<ValueNode *> &); loops()143 const std::vector<LoopInfo *> &loops() const { return loops_; } AddLoop(LoopInfo * loop)144 void AddLoop(LoopInfo *loop) { loops_.emplace_back(loop); } 145 146 // only func name GetCodeName()147 std::string GetCodeName() const { 148 PyCodeObject *c = reinterpret_cast<PyCodeObject *>(co_.ptr()); 149 return Utils::GetPyName(c->co_name); 150 } 151 152 bool GuardValueNode(ValueNode *, GuardLevel level = GuardLevel::GEqual); 153 bool GuardType(ValueNode *); 154 bool GuardSequenceNodeLength(ValueNode *, Py_ssize_t); 155 bool GuardInlinedFunc(CallNode *call_node); 156 157 TracePtr TraceValueNode(ValueNode *, int max_trace_depth = -1); GetPruneBranchCount()158 int GetPruneBranchCount() const { return prune_branch_count_; } SetPruneBranchCount(int count)159 void SetPruneBranchCount(int count) { prune_branch_count_ = count; } GetGuard()160 const std::shared_ptr<OptCode> &GetGuard() const { return guard_; } SetGuard(const std::shared_ptr<OptCode> & guard)161 void SetGuard(const std::shared_ptr<OptCode> &guard) { guard_ = guard; } 162 163 // (chaiyouheng): restore graph status at loop begin, clear trace values and operations and guards RestoreLoopStatus()164 bool RestoreLoopStatus() { return false; } 165 bool IsBreakAtLoop() const; 166 bool IsBreakAtLoopAfterUnrolling() const; GetTracedNodes()167 const std::vector<ValueNode *> &GetTracedNodes() const { return traced_nodes_; } GetTracedNodes()168 std::vector<ValueNode *> &GetTracedNodes() { return traced_nodes_; } 169 170 std::string ToString(int depth = 0) const; 171 172 std::string DumpBreakInfo() const; 173 SetParent(Graph * parent)174 void SetParent(Graph *parent) { parent_ = parent; } GetParent()175 Graph *GetParent() const { return parent_; } 176 177 const std::shared_ptr<SideEffect> &GetSideEffect() const; 178 void SetSideEffect(const std::shared_ptr<SideEffect> &handler); 179 180 // collect alive node, output bitmap 181 std::vector<ValueNode *> CollectAliveNode(int bci, std::vector<int> * = nullptr, BitMap * = nullptr) const; 182 183 // collect alive node, clear the bit if alive local is unbound 184 static std::vector<ValueNode *> CollectAliveNode(const FrameStates &, BitMap *, std::vector<int> * = nullptr); 185 186 private: 187 std::unique_ptr<CFG> cfg_; 188 std::vector<LoopInfo *> loops_; 189 190 // frame status 191 std::map<int, std::unique_ptr<FrameStates>> frame_states_; 192 std::vector<ValueNode *> traced_nodes_; 193 194 // return value 195 ValueNode *ret_val_; 196 197 // used to fold generator function call 198 ValueNode *generator_result_; 199 200 // the traced code object 201 py::object co_; 202 203 // globals that may be used by frame when the tracer start 204 py::object f_globals_; 205 206 const char *module_name_; 207 208 struct StopTraceInfo { 209 int bci; // trace stopped bci 210 StopTraceReason reason; 211 } stop_trace_info_; 212 213 Allocator alloc_; 214 215 const GraphJitConfig &conf_; 216 217 std::shared_ptr<OptCode> guard_; 218 int prune_branch_count_; 219 Graph *parent_{nullptr}; 220 std::shared_ptr<SideEffect> side_effect_; 221 }; 222 } // namespace pijit 223 } // namespace mindspore 224 225 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_H 226