• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_H
18 
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include "pipeline/jit/pi/graph_capture/cfg.h"
25 #include "pipeline/jit/pi/graph_capture/loop.h"
26 #include "pipeline/jit/pi/graph_capture/node.h"
27 #include "pipeline/jit/pi/utils/allocator.h"
28 #include "pipeline/jit/pi/graph_guard/trace.h"
29 #include "pipeline/jit/pi/graph_guard/guard.h"
30 #include "pipeline/jit/pi/graph_capture/side_effect.h"
31 #include "utils/convert_utils_base.h"
32 
33 namespace mindspore {
34 namespace pijit {
35 
36 class OptCode;
37 class GraphJitConfig;
38 
39 class FrameStates {
40  public:
Local(int i)41   ValueNode *Local(int i) const {
42     MS_ASSERT((int)locals.size() > i);
43     return locals[i];
44   }
SetLocal(int i,ValueNode * v)45   void SetLocal(int i, ValueNode *v) {
46     MS_ASSERT((int)locals.size() > i);
47     locals[i] = v;
48   }
49 
Closure(int i)50   CellVarNode *Closure(int i) const {
51     MS_ASSERT((int)cell_free.size() > i);
52     return cell_free[i];
53   }
SetClosure(int i,CellVarNode * v)54   void SetClosure(int i, CellVarNode *v) {
55     MS_ASSERT((int)cell_free.size() > i);
56     cell_free[i] = v;
57   }
58 
Peek(int p)59   ValueNode *&Peek(int p) {
60     MS_ASSERT((int)stack.size() > p);
61     return stack[stack.size() - p - 1];
62   }
63 
Peek(int p)64   ValueNode *Peek(int p) const {
65     MS_ASSERT((int)stack.size() > p);
66     return stack[stack.size() - p - 1];
67   }
68 
Pop()69   ValueNode *Pop() {
70     MS_ASSERT(stack.size() > 0);
71     auto r = stack[stack.size() - 1];
72     stack.pop_back();
73     return r;
74   }
Popn(int n)75   void Popn(int n) {
76     for (int i = 0; i < n; i++) {
77       Pop();
78     }
79   }
Push(ValueNode * i)80   void Push(ValueNode *i) { stack.push_back(i); }
81 
Rot(int i)82   void Rot(int i) {
83     MS_ASSERT((int)stack.size() - i >= 0);
84     ValueNode *v = Pop();
85     stack.insert(stack.end() - i, v);
86   }
87 
ResizeLocal(int i)88   void ResizeLocal(int i) {
89     MS_ASSERT((int)locals.size() <= i);
90     locals.resize(i, &ValueNode::kUnboundLocal);
91   }
ResizeClosure(int i)92   void ResizeClosure(int i) {
93     MS_ASSERT((int)cell_free.size() <= i);
94     cell_free.resize(i);
95   }
96 
GetLocals()97   const auto &GetLocals() const { return locals; }
GetStacks()98   const auto &GetStacks() const { return stack; }
GetClosures()99   const auto &GetClosures() const { return cell_free; }
100 
GetLocals()101   auto &GetLocals() { return locals; }
GetStacks()102   auto &GetStacks() { return stack; }
GetClosures()103   auto &GetClosures() { return cell_free; }
104 
105   std::string ToString() const;
106 
107  private:
108   std::vector<ValueNode *> stack;
109   std::vector<ValueNode *> locals;
110   std::vector<CellVarNode *> cell_free;
111 };
112 
113 class Graph {
114  public:
115   Graph(PyCodeObject *co, PyObject *globals, const GraphJitConfig &conf);
~Graph()116   virtual ~Graph() {}
117 
GetGeneratorResult()118   ValueNode *GetGeneratorResult() const { return generator_result_; }
SetGeneratorResult(ValueNode * generator_result)119   void SetGeneratorResult(ValueNode *generator_result) { generator_result_ = generator_result; }
120 
SetRetVal(ValueNode * v)121   void SetRetVal(ValueNode *v) { ret_val_ = v; }
GetRetVal()122   ValueNode *GetRetVal() const { return ret_val_; }
GetCodeObj()123   PyCodeObject *GetCodeObj() const { return reinterpret_cast<PyCodeObject *>(co_.ptr()); }
GetGlobals()124   const py::object &GetGlobals() const { return f_globals_; }
125 
StopTraceAt(int bci,StopTraceReason reason)126   void StopTraceAt(int bci, StopTraceReason reason) { stop_trace_info_ = {bci, reason}; }
GetStopTraceBci()127   int GetStopTraceBci() const { return stop_trace_info_.bci; }
GetStopTraceReason()128   StopTraceReason GetStopTraceReason() const { return stop_trace_info_.reason; }
GetModuleName()129   const char *GetModuleName() const { return module_name_; }
130 
GetCFG()131   auto &GetCFG() { return cfg_; }
GetCFG()132   const auto &GetCFG() const { return cfg_; }
Config()133   const GraphJitConfig &Config() const { return conf_; }
134 
135   const FrameStates &GetFrame(int bci) const;
136   void SetFrame(int bci, const FrameStates &f);
GetFrames()137   auto &GetFrames() { return frame_states_; }
GetFrames()138   const auto &GetFrames() const { return frame_states_; }
allocator()139   Allocator &allocator() { return alloc_; }
140   ValueNode *NewValueNode(AObject *, int op, int arg, const std::vector<ValueNode *> & = {},
141                           const std::string &name = "");
142   CallNode *NewCallNode(int op, int arg, const std::vector<ValueNode *> &);
loops()143   const std::vector<LoopInfo *> &loops() const { return loops_; }
AddLoop(LoopInfo * loop)144   void AddLoop(LoopInfo *loop) { loops_.emplace_back(loop); }
145 
146   // only func name
GetCodeName()147   std::string GetCodeName() const {
148     PyCodeObject *c = reinterpret_cast<PyCodeObject *>(co_.ptr());
149     return Utils::GetPyName(c->co_name);
150   }
151 
152   bool GuardValueNode(ValueNode *, GuardLevel level = GuardLevel::GEqual);
153   bool GuardType(ValueNode *);
154   bool GuardSequenceNodeLength(ValueNode *, Py_ssize_t);
155   bool GuardInlinedFunc(CallNode *call_node);
156 
157   TracePtr TraceValueNode(ValueNode *, int max_trace_depth = -1);
GetPruneBranchCount()158   int GetPruneBranchCount() const { return prune_branch_count_; }
SetPruneBranchCount(int count)159   void SetPruneBranchCount(int count) { prune_branch_count_ = count; }
GetGuard()160   const std::shared_ptr<OptCode> &GetGuard() const { return guard_; }
SetGuard(const std::shared_ptr<OptCode> & guard)161   void SetGuard(const std::shared_ptr<OptCode> &guard) { guard_ = guard; }
162 
163   // (chaiyouheng): restore graph status at loop begin, clear trace values and operations and guards
RestoreLoopStatus()164   bool RestoreLoopStatus() { return false; }
165   bool IsBreakAtLoop() const;
166   bool IsBreakAtLoopAfterUnrolling() const;
GetTracedNodes()167   const std::vector<ValueNode *> &GetTracedNodes() const { return traced_nodes_; }
GetTracedNodes()168   std::vector<ValueNode *> &GetTracedNodes() { return traced_nodes_; }
169 
170   std::string ToString(int depth = 0) const;
171 
172   std::string DumpBreakInfo() const;
173 
SetParent(Graph * parent)174   void SetParent(Graph *parent) { parent_ = parent; }
GetParent()175   Graph *GetParent() const { return parent_; }
176 
177   const std::shared_ptr<SideEffect> &GetSideEffect() const;
178   void SetSideEffect(const std::shared_ptr<SideEffect> &handler);
179 
180   // collect alive node, output bitmap
181   std::vector<ValueNode *> CollectAliveNode(int bci, std::vector<int> * = nullptr, BitMap * = nullptr) const;
182 
183   // collect alive node, clear the bit if alive local is unbound
184   static std::vector<ValueNode *> CollectAliveNode(const FrameStates &, BitMap *, std::vector<int> * = nullptr);
185 
186  private:
187   std::unique_ptr<CFG> cfg_;
188   std::vector<LoopInfo *> loops_;
189 
190   // frame status
191   std::map<int, std::unique_ptr<FrameStates>> frame_states_;
192   std::vector<ValueNode *> traced_nodes_;
193 
194   // return value
195   ValueNode *ret_val_;
196 
197   // used to fold generator function call
198   ValueNode *generator_result_;
199 
200   // the traced code object
201   py::object co_;
202 
203   // globals that may be used by frame when the tracer start
204   py::object f_globals_;
205 
206   const char *module_name_;
207 
208   struct StopTraceInfo {
209     int bci;  // trace stopped bci
210     StopTraceReason reason;
211   } stop_trace_info_;
212 
213   Allocator alloc_;
214 
215   const GraphJitConfig &conf_;
216 
217   std::shared_ptr<OptCode> guard_;
218   int prune_branch_count_;
219   Graph *parent_{nullptr};
220   std::shared_ptr<SideEffect> side_effect_;
221 };
222 }  // namespace pijit
223 }  // namespace mindspore
224 
225 #endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_H
226