1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_CFG_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_CFG_H 18 19 #include <memory> 20 #include <set> 21 #include <string> 22 #include <queue> 23 #include <vector> 24 #include "pipeline/jit/pi/pydef.h" 25 #include "pipeline/jit/pi/utils/ptr_list_ref.h" 26 #include "pybind11/pybind11.h" 27 #include "pipeline/jit/pi/graph_capture/local_liveness.h" 28 29 namespace mindspore { 30 namespace pijit { 31 32 namespace py = pybind11; 33 34 class AbstractNode; 35 class Graph; 36 37 class Instr : public PtrListNodeBase<Instr> { 38 public: 39 Instr(const Instr &) = delete; 40 Instr &operator=(const Instr &) = delete; bci_(bci)41 Instr(int op, int arg, int bci = -1, int line = -1) : bci_(bci), op_(op), arg_(arg), line_(line) { 42 MS_EXCEPTION_IF_CHECK_FAIL(op != Opcode::k_ILLEGAL_OPCODE, "ILLEGAL OPCODE !!!"); 43 } Instr(int op,int arg,const std::string & name)44 Instr(int op, int arg, const std::string &name) : Instr(op, arg) { name_ = name; } Instr(int op,int arg,const py::object & cnst)45 Instr(int op, int arg, const py::object &cnst) : Instr(op, arg) { cnst_ = cnst; } Instr(int op)46 explicit Instr(int op) : Instr(op, 0) {} 47 virtual ~Instr() = default; 48 bci()49 int bci() const { return bci_; } set_bci(int i)50 void set_bci(int i) { bci_ = i; } op()51 int op() const { return op_; } set_op(int op)52 void set_op(int op) { op_ = op; } arg()53 int arg() const { return arg_; } set_arg(int arg)54 void set_arg(int arg) { arg_ = arg; } line()55 int line() const { return line_; } set_line(int l)56 void set_line(int l) { line_ = l; } is_fall()57 bool is_fall() const { return is_fall_; } set_is_fall(int is_fall)58 void set_is_fall(int is_fall) { is_fall_ = is_fall; } extra_preds()59 const std::vector<Instr *> &extra_preds() const { return extra_preds_; } extra_preds()60 std::vector<Instr *> &extra_preds() { return extra_preds_; } extra_jump()61 Instr *extra_jump() const { return extra_jump_; } set_extra_jump(Instr * j)62 void set_extra_jump(Instr *j) { extra_jump_ = j; } 63 name()64 const std::string &name() const { return name_; } set_name(const std::string & n)65 void set_name(const std::string &n) { name_ = n; } cnst()66 const py::object &cnst() const { return cnst_; } set_cnst(PyObject * cnst)67 void set_cnst(PyObject *cnst) { cnst_ = py::reinterpret_borrow<py::object>(cnst); } set_cnst(const py::object & cnst)68 void set_cnst(const py::object &cnst) { cnst_ = cnst; } 69 AddExtraPred(Instr * instr)70 void AddExtraPred(Instr *instr) { extra_preds_.push_back(instr); } 71 std::string Dump(const std::string &prefix = "") const; 72 std::string ToString() const; 73 74 private: 75 int bci_; 76 int op_; 77 int arg_; 78 int line_; 79 std::string name_; 80 py::object cnst_; 81 82 bool is_fall_ = true; 83 std::vector<Instr *> extra_preds_; 84 Instr *extra_jump_ = nullptr; 85 }; 86 87 class Block; 88 struct BBIdCmp { 89 bool operator()(const Block *lhs, const Block *rhs) const; 90 }; 91 92 struct BBIdGreaterCmp { 93 bool operator()(const Block *lhs, const Block *rhs) const; 94 }; 95 96 using UniqueInstr = std::unique_ptr<Instr>; 97 using Instrs = PtrListRef<Instr>; 98 class CFG; 99 class Block { 100 public: 101 enum TrackResult { 102 kNotTrack, 103 kTrackHasTensor, 104 kTrackHasOpsPrimitive, 105 kTrackBreak, 106 kHasGlobalSideEffect, 107 kHasAttrSideEffect, 108 kHasClosureSideEffect, 109 }; 110 Block() = default; 111 ~Block() = default; id()112 uint32_t id() const { return id_; } set_id(uint32_t arg)113 void set_id(uint32_t arg) { id_ = arg; } instrs()114 Instrs &instrs() { return instrs_; } instrs()115 const Instrs &instrs() const { return instrs_; } AddInstr(Instr * i)116 void AddInstr(Instr *i) { instrs_.push_back(i); } pred_bbs()117 const std::set<Block *, BBIdCmp> &pred_bbs() const { return pred_bbs_; } pred_bbs()118 std::set<Block *, BBIdCmp> &pred_bbs() { return pred_bbs_; } succ_bbs()119 const std::set<Block *, BBIdCmp> &succ_bbs() const { return succ_bbs_; } succ_bbs()120 std::set<Block *, BBIdCmp> &succ_bbs() { return succ_bbs_; } set_is_loop_head(bool flag)121 void set_is_loop_head(bool flag) { is_loop_head_ = flag; } is_loop_head()122 bool is_loop_head() const { return is_loop_head_; } set_is_loop_body(bool flag)123 void set_is_loop_body(bool flag) { is_loop_body_ = flag; } is_loop_body()124 bool is_loop_body() const { return is_loop_body_; } is_dead()125 bool is_dead() const { return is_dead_; } set_is_dead(bool flag)126 void set_is_dead(bool flag) { is_dead_ = flag; } 127 128 std::string Dump(bool dump_instr = true) const; 129 begin_ci()130 int begin_ci() const { return begin_; } end_ci()131 int end_ci() const { return end_; } set_begin_ci(int i)132 void set_begin_ci(int i) { begin_ = i; } set_end_ci(int i)133 void set_end_ci(int i) { end_ = i; } GetFallBB()134 Block *GetFallBB() const { return fall_bb_; } GetJumpBB()135 Block *GetJumpBB() const { return jump_bb_; } 136 void SetFallBB(Block *arg); 137 void SetJumpBB(Block *arg); 138 void RemoveInstr(Instr *instr); 139 void RemoveInstrs(); 140 IsTrackBreak()141 bool IsTrackBreak() const { return track_result_ & (1 << IntToSize(kTrackBreak)); } HasPrimitive()142 bool HasPrimitive() const { return track_result_ & (1 << IntToSize(kTrackHasOpsPrimitive)); } HasTensor()143 bool HasTensor() const { return track_result_ & (1 << IntToSize(kTrackHasTensor)); } HasUnresolvedSideEffect()144 bool HasUnresolvedSideEffect() const { return track_result_ & (1 << IntToSize(kHasGlobalSideEffect)); } HasAttrSideEffect()145 bool HasAttrSideEffect() const { return track_result_ & (1 << IntToSize(kHasAttrSideEffect)); } HasClosureSideEffect()146 bool HasClosureSideEffect() const { return track_result_ & (1 << IntToSize(kHasClosureSideEffect)); } SetTrackResult(TrackResult r)147 void SetTrackResult(TrackResult r) { track_result_ = (track_result_ & ~(1 << IntToSize(kNotTrack))) | (1 << r); } 148 149 void AddSuccBB(Block *bb); 150 bool RemoveEdge(Block *bb); 151 void ClearOutEdges(); 152 153 Block *Clone(CFG *cfg); 154 155 private: 156 uint32_t id_; // start from 0 157 int begin_; 158 int end_; 159 std::set<Block *, BBIdCmp> pred_bbs_; 160 std::set<Block *, BBIdCmp> succ_bbs_; // include fall_bb_ and jump_bb_ 161 Block *fall_bb_ = nullptr; 162 Block *jump_bb_ = nullptr; 163 164 bool is_loop_body_ = false; 165 bool is_loop_head_ = false; 166 bool is_dead_ = true; 167 168 // (chaiyouheng): remove 169 Instrs instrs_; 170 int track_result_ = (1 << kNotTrack); 171 }; 172 173 class CFG { 174 public: CFG(PyCodeObject * co)175 explicit CFG(PyCodeObject *co) : pycode_(co), nlocals_(0) {} 176 177 // BFS Iterator 178 class BBIterator { 179 public: 180 BBIterator() = default; BBIterator(const CFG * c)181 explicit BBIterator(const CFG *c) : visit_(c->bb_pool().size(), false) { 182 q_.push(c->GetFirstBB()); 183 visit_[c->GetFirstBB()->id()] = true; 184 } 185 BBIterator(const CFG * c,Block * bb)186 BBIterator(const CFG *c, Block *bb) : visit_(c->bb_pool().size(), false) { 187 q_.push(bb); 188 visit_[bb->id()] = true; 189 } 190 GetVisitMap()191 const auto &GetVisitMap() const { return visit_; } 192 Block *operator*() const { return q_.front(); } 193 bool operator!=(const BBIterator &end) const { return !q_.empty(); } 194 BBIterator &operator++(); 195 196 std::queue<Block *> q_; 197 std::vector<bool> visit_; 198 }; 199 begin()200 BBIterator begin() const { return BBIterator(this); } begin(Block * start)201 BBIterator begin(Block *start) const { return BBIterator(this, start); } end()202 BBIterator end() const { return BBIterator(); } 203 bb_pool()204 const std::vector<std::unique_ptr<Block>> &bb_pool() const { return bb_pool_; } instr_pool()205 const std::vector<std::unique_ptr<Instr>> &instr_pool() const { return instrs_; } liveness()206 const std::unique_ptr<Liveness> &liveness() const { return liveness_; } instr_pool()207 std::vector<std::unique_ptr<Instr>> &instr_pool() { return instrs_; } bb_pool()208 std::vector<std::unique_ptr<Block>> &bb_pool() { return bb_pool_; } liveness()209 std::unique_ptr<Liveness> &liveness() { return liveness_; } GetCodeObject()210 PyCodeObject *GetCodeObject() const { return pycode_; } GetLocalCount()211 int GetLocalCount() const { return nlocals_; } SetLocalCount(int n)212 void SetLocalCount(int n) { nlocals_ = n; } ToString()213 std::string ToString() const { return DumpBBs(); } 214 215 const Liveness *GetLiveness(); 216 217 void GenerateCFG(); 218 void MarkDeadBB(); 219 220 // clear dead bb's edges 221 void ClearDeadBBEdges(); 222 GetFirstBB()223 Block *GetFirstBB() const { return bb_pool_.size() ? bb_pool_[0].get() : nullptr; } 224 Block *GetBlockByBci(int) const; 225 226 std::string DumpBBs(std::string phase = "") const; 227 void DumpCFGGraph(); 228 void DumpCFGGraph(std::ofstream &file); 229 void DumpCFGGraphForBB(std::ofstream &file, const Block &bb) const; 230 void DumpCFGGraphForEdge(std::ofstream &file); 231 232 Block *NewBBAppend(); 233 Instr *NewInstrNode(int bci, int op, int arg, int line); 234 Instr *NewInstrNode(const Instr &instr); 235 Instr *NewLoadInstrNode(int bci, int arg, int line, PyObject *cnst); 236 std::unique_ptr<CFG> Clone(); 237 238 private: 239 void BuildInst(); 240 void BuildBB(); 241 bool BuildCFG(); 242 243 PyCodeObject *const pycode_; 244 std::vector<std::unique_ptr<Instr>> instrs_; 245 std::vector<std::unique_ptr<Block>> bb_pool_; 246 std::unique_ptr<Liveness> liveness_; 247 int nlocals_; 248 bool is_generated_ = false; 249 }; 250 } // namespace pijit 251 } // namespace mindspore 252 253 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_CFG_H 254