• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_CFG_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_CFG_H
18 
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <queue>
23 #include <vector>
24 #include "pipeline/jit/pi/pydef.h"
25 #include "pipeline/jit/pi/utils/ptr_list_ref.h"
26 #include "pybind11/pybind11.h"
27 #include "pipeline/jit/pi/graph_capture/local_liveness.h"
28 
29 namespace mindspore {
30 namespace pijit {
31 
32 namespace py = pybind11;
33 
34 class AbstractNode;
35 class Graph;
36 
37 class Instr : public PtrListNodeBase<Instr> {
38  public:
39   Instr(const Instr &) = delete;
40   Instr &operator=(const Instr &) = delete;
bci_(bci)41   Instr(int op, int arg, int bci = -1, int line = -1) : bci_(bci), op_(op), arg_(arg), line_(line) {
42     MS_EXCEPTION_IF_CHECK_FAIL(op != Opcode::k_ILLEGAL_OPCODE, "ILLEGAL OPCODE !!!");
43   }
Instr(int op,int arg,const std::string & name)44   Instr(int op, int arg, const std::string &name) : Instr(op, arg) { name_ = name; }
Instr(int op,int arg,const py::object & cnst)45   Instr(int op, int arg, const py::object &cnst) : Instr(op, arg) { cnst_ = cnst; }
Instr(int op)46   explicit Instr(int op) : Instr(op, 0) {}
47   virtual ~Instr() = default;
48 
bci()49   int bci() const { return bci_; }
set_bci(int i)50   void set_bci(int i) { bci_ = i; }
op()51   int op() const { return op_; }
set_op(int op)52   void set_op(int op) { op_ = op; }
arg()53   int arg() const { return arg_; }
set_arg(int arg)54   void set_arg(int arg) { arg_ = arg; }
line()55   int line() const { return line_; }
set_line(int l)56   void set_line(int l) { line_ = l; }
is_fall()57   bool is_fall() const { return is_fall_; }
set_is_fall(int is_fall)58   void set_is_fall(int is_fall) { is_fall_ = is_fall; }
extra_preds()59   const std::vector<Instr *> &extra_preds() const { return extra_preds_; }
extra_preds()60   std::vector<Instr *> &extra_preds() { return extra_preds_; }
extra_jump()61   Instr *extra_jump() const { return extra_jump_; }
set_extra_jump(Instr * j)62   void set_extra_jump(Instr *j) { extra_jump_ = j; }
63 
name()64   const std::string &name() const { return name_; }
set_name(const std::string & n)65   void set_name(const std::string &n) { name_ = n; }
cnst()66   const py::object &cnst() const { return cnst_; }
set_cnst(PyObject * cnst)67   void set_cnst(PyObject *cnst) { cnst_ = py::reinterpret_borrow<py::object>(cnst); }
set_cnst(const py::object & cnst)68   void set_cnst(const py::object &cnst) { cnst_ = cnst; }
69 
AddExtraPred(Instr * instr)70   void AddExtraPred(Instr *instr) { extra_preds_.push_back(instr); }
71   std::string Dump(const std::string &prefix = "") const;
72   std::string ToString() const;
73 
74  private:
75   int bci_;
76   int op_;
77   int arg_;
78   int line_;
79   std::string name_;
80   py::object cnst_;
81 
82   bool is_fall_ = true;
83   std::vector<Instr *> extra_preds_;
84   Instr *extra_jump_ = nullptr;
85 };
86 
class Block;

/// Comparator functor for ordered containers of Block pointers (used as the ordering of
/// Block's predecessor/successor sets). Defined out of line.
/// NOTE(review): the name suggests ordering by Block::id() — confirm against the definition.
struct BBIdCmp {
  bool operator()(const Block *first, const Block *second) const;
};

/// Companion comparator to BBIdCmp, defined out of line.
/// NOTE(review): the name suggests descending-id order — confirm against the definition.
struct BBIdGreaterCmp {
  bool operator()(const Block *first, const Block *second) const;
};
95 
96 using UniqueInstr = std::unique_ptr<Instr>;
97 using Instrs = PtrListRef<Instr>;
98 class CFG;
99 class Block {
100  public:
101   enum TrackResult {
102     kNotTrack,
103     kTrackHasTensor,
104     kTrackHasOpsPrimitive,
105     kTrackBreak,
106     kHasGlobalSideEffect,
107     kHasAttrSideEffect,
108     kHasClosureSideEffect,
109   };
110   Block() = default;
111   ~Block() = default;
id()112   uint32_t id() const { return id_; }
set_id(uint32_t arg)113   void set_id(uint32_t arg) { id_ = arg; }
instrs()114   Instrs &instrs() { return instrs_; }
instrs()115   const Instrs &instrs() const { return instrs_; }
AddInstr(Instr * i)116   void AddInstr(Instr *i) { instrs_.push_back(i); }
pred_bbs()117   const std::set<Block *, BBIdCmp> &pred_bbs() const { return pred_bbs_; }
pred_bbs()118   std::set<Block *, BBIdCmp> &pred_bbs() { return pred_bbs_; }
succ_bbs()119   const std::set<Block *, BBIdCmp> &succ_bbs() const { return succ_bbs_; }
succ_bbs()120   std::set<Block *, BBIdCmp> &succ_bbs() { return succ_bbs_; }
set_is_loop_head(bool flag)121   void set_is_loop_head(bool flag) { is_loop_head_ = flag; }
is_loop_head()122   bool is_loop_head() const { return is_loop_head_; }
set_is_loop_body(bool flag)123   void set_is_loop_body(bool flag) { is_loop_body_ = flag; }
is_loop_body()124   bool is_loop_body() const { return is_loop_body_; }
is_dead()125   bool is_dead() const { return is_dead_; }
set_is_dead(bool flag)126   void set_is_dead(bool flag) { is_dead_ = flag; }
127 
128   std::string Dump(bool dump_instr = true) const;
129 
begin_ci()130   int begin_ci() const { return begin_; }
end_ci()131   int end_ci() const { return end_; }
set_begin_ci(int i)132   void set_begin_ci(int i) { begin_ = i; }
set_end_ci(int i)133   void set_end_ci(int i) { end_ = i; }
GetFallBB()134   Block *GetFallBB() const { return fall_bb_; }
GetJumpBB()135   Block *GetJumpBB() const { return jump_bb_; }
136   void SetFallBB(Block *arg);
137   void SetJumpBB(Block *arg);
138   void RemoveInstr(Instr *instr);
139   void RemoveInstrs();
140 
IsTrackBreak()141   bool IsTrackBreak() const { return track_result_ & (1 << IntToSize(kTrackBreak)); }
HasPrimitive()142   bool HasPrimitive() const { return track_result_ & (1 << IntToSize(kTrackHasOpsPrimitive)); }
HasTensor()143   bool HasTensor() const { return track_result_ & (1 << IntToSize(kTrackHasTensor)); }
HasUnresolvedSideEffect()144   bool HasUnresolvedSideEffect() const { return track_result_ & (1 << IntToSize(kHasGlobalSideEffect)); }
HasAttrSideEffect()145   bool HasAttrSideEffect() const { return track_result_ & (1 << IntToSize(kHasAttrSideEffect)); }
HasClosureSideEffect()146   bool HasClosureSideEffect() const { return track_result_ & (1 << IntToSize(kHasClosureSideEffect)); }
SetTrackResult(TrackResult r)147   void SetTrackResult(TrackResult r) { track_result_ = (track_result_ & ~(1 << IntToSize(kNotTrack))) | (1 << r); }
148 
149   void AddSuccBB(Block *bb);
150   bool RemoveEdge(Block *bb);
151   void ClearOutEdges();
152 
153   Block *Clone(CFG *cfg);
154 
155  private:
156   uint32_t id_;  // start from 0
157   int begin_;
158   int end_;
159   std::set<Block *, BBIdCmp> pred_bbs_;
160   std::set<Block *, BBIdCmp> succ_bbs_;  // include fall_bb_ and jump_bb_
161   Block *fall_bb_ = nullptr;
162   Block *jump_bb_ = nullptr;
163 
164   bool is_loop_body_ = false;
165   bool is_loop_head_ = false;
166   bool is_dead_ = true;
167 
168   // (chaiyouheng): remove
169   Instrs instrs_;
170   int track_result_ = (1 << kNotTrack);
171 };
172 
173 class CFG {
174  public:
CFG(PyCodeObject * co)175   explicit CFG(PyCodeObject *co) : pycode_(co), nlocals_(0) {}
176 
177   // BFS Iterator
178   class BBIterator {
179    public:
180     BBIterator() = default;
BBIterator(const CFG * c)181     explicit BBIterator(const CFG *c) : visit_(c->bb_pool().size(), false) {
182       q_.push(c->GetFirstBB());
183       visit_[c->GetFirstBB()->id()] = true;
184     }
185 
BBIterator(const CFG * c,Block * bb)186     BBIterator(const CFG *c, Block *bb) : visit_(c->bb_pool().size(), false) {
187       q_.push(bb);
188       visit_[bb->id()] = true;
189     }
190 
GetVisitMap()191     const auto &GetVisitMap() const { return visit_; }
192     Block *operator*() const { return q_.front(); }
193     bool operator!=(const BBIterator &end) const { return !q_.empty(); }
194     BBIterator &operator++();
195 
196     std::queue<Block *> q_;
197     std::vector<bool> visit_;
198   };
199 
begin()200   BBIterator begin() const { return BBIterator(this); }
begin(Block * start)201   BBIterator begin(Block *start) const { return BBIterator(this, start); }
end()202   BBIterator end() const { return BBIterator(); }
203 
bb_pool()204   const std::vector<std::unique_ptr<Block>> &bb_pool() const { return bb_pool_; }
instr_pool()205   const std::vector<std::unique_ptr<Instr>> &instr_pool() const { return instrs_; }
liveness()206   const std::unique_ptr<Liveness> &liveness() const { return liveness_; }
instr_pool()207   std::vector<std::unique_ptr<Instr>> &instr_pool() { return instrs_; }
bb_pool()208   std::vector<std::unique_ptr<Block>> &bb_pool() { return bb_pool_; }
liveness()209   std::unique_ptr<Liveness> &liveness() { return liveness_; }
GetCodeObject()210   PyCodeObject *GetCodeObject() const { return pycode_; }
GetLocalCount()211   int GetLocalCount() const { return nlocals_; }
SetLocalCount(int n)212   void SetLocalCount(int n) { nlocals_ = n; }
ToString()213   std::string ToString() const { return DumpBBs(); }
214 
215   const Liveness *GetLiveness();
216 
217   void GenerateCFG();
218   void MarkDeadBB();
219 
220   // clear dead bb's edges
221   void ClearDeadBBEdges();
222 
GetFirstBB()223   Block *GetFirstBB() const { return bb_pool_.size() ? bb_pool_[0].get() : nullptr; }
224   Block *GetBlockByBci(int) const;
225 
226   std::string DumpBBs(std::string phase = "") const;
227   void DumpCFGGraph();
228   void DumpCFGGraph(std::ofstream &file);
229   void DumpCFGGraphForBB(std::ofstream &file, const Block &bb) const;
230   void DumpCFGGraphForEdge(std::ofstream &file);
231 
232   Block *NewBBAppend();
233   Instr *NewInstrNode(int bci, int op, int arg, int line);
234   Instr *NewInstrNode(const Instr &instr);
235   Instr *NewLoadInstrNode(int bci, int arg, int line, PyObject *cnst);
236   std::unique_ptr<CFG> Clone();
237 
238  private:
239   void BuildInst();
240   void BuildBB();
241   bool BuildCFG();
242 
243   PyCodeObject *const pycode_;
244   std::vector<std::unique_ptr<Instr>> instrs_;
245   std::vector<std::unique_ptr<Block>> bb_pool_;
246   std::unique_ptr<Liveness> liveness_;
247   int nlocals_;
248   bool is_generated_ = false;
249 };
250 }  // namespace pijit
251 }  // namespace mindspore
252 
253 #endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_CFG_H
254