1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_capture/cfg.h"
17 #include <fstream>
18 #include <map>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 #include "pipeline/jit/pi/graph_capture/node.h"
23 #include "pipeline/jit/pi/pi_jit_config.h"
24 #include "pipeline/jit/pi/utils/utils.h"
25
26 namespace mindspore {
27 namespace pijit {
28
// Size in bytes of one CPython code unit (_Py_CODEUNIT). Used in this file to convert between
// instruction indices (bci) and byte offsets into the code object.
constexpr const int PY_BCSIZE = sizeof(_Py_CODEUNIT);
30
ToString() const31 std::string Instr::ToString() const {
32 std::stringstream s;
33 s << bci_ << ' ' << Opcode(op_).name() << ' ' << arg_;
34 if (!name().empty()) {
35 s << " " << name();
36 }
37 if (cnst().ptr()) {
38 s << " " << std::string(py::str(cnst().ptr()));
39 }
40 if (extra_jump()) {
41 s << " -> " << extra_jump()->bci();
42 }
43 return s.str();
44 }
45
Dump(const std::string & prefix) const46 std::string Instr::Dump(const std::string &prefix) const {
47 std::stringstream os;
48 os << prefix << " " << bci_ << ' ' << Opcode(op_).name() << ' ' << arg_;
49 return os.str();
50 }
51
AddSuccBB(Block * bb)52 void Block::AddSuccBB(Block *bb) {
53 succ_bbs_.insert(bb);
54 bb->pred_bbs_.insert(this);
55 }
56
SetFallBB(Block * arg)57 void Block::SetFallBB(Block *arg) {
58 if (arg != nullptr) {
59 fall_bb_ = arg;
60 AddSuccBB(arg);
61 } else if (fall_bb_ != nullptr) {
62 // remove fall_bb_
63 succ_bbs_.erase(fall_bb_);
64 fall_bb_->pred_bbs().erase(this);
65 fall_bb_ = nullptr;
66 }
67 }
68
SetJumpBB(Block * arg)69 void Block::SetJumpBB(Block *arg) {
70 if (arg != nullptr) {
71 jump_bb_ = arg;
72 AddSuccBB(arg);
73 } else if (jump_bb_ != nullptr) {
74 // remove jump_bb_
75 succ_bbs_.erase(jump_bb_);
76 jump_bb_->pred_bbs().erase(this);
77 jump_bb_ = nullptr;
78 }
79 }
80
RemoveInstr(Instr * instr)81 void Block::RemoveInstr(Instr *instr) {
82 Instr *jump = instr->extra_jump();
83 if (jump != nullptr) {
84 auto &v = jump->extra_preds();
85 v.erase(std::remove(v.begin(), v.end(), jump), v.end());
86 jump->set_extra_jump(nullptr);
87 }
88 for (Instr *pred : instr->extra_preds()) {
89 pred->set_extra_jump(nullptr);
90 }
91 instr->extra_preds().clear();
92 instrs_.erase(instr);
93 }
94
RemoveInstrs()95 void Block::RemoveInstrs() {
96 if (instrs_.empty()) {
97 return;
98 }
99 RemoveInstr(&instrs_.front());
100 if (instrs_.empty()) {
101 return;
102 }
103 RemoveInstr(&instrs_.back());
104 instrs_.clear();
105 }
106
RemoveEdge(Block * bb)107 bool Block::RemoveEdge(Block *bb) {
108 bb->pred_bbs_.erase(this);
109 jump_bb_ = jump_bb_ == bb ? nullptr : jump_bb_;
110 fall_bb_ = fall_bb_ == bb ? nullptr : fall_bb_;
111 return succ_bbs_.erase(bb);
112 }
113
ClearOutEdges()114 void Block::ClearOutEdges() {
115 while (!succ_bbs_.empty()) {
116 RemoveEdge(*succ_bbs_.begin());
117 }
118 }
119
Dump(bool dump_instr) const120 std::string Block::Dump(bool dump_instr) const {
121 std::stringstream os;
122 os << "Block [" << (begin_ci() * PY_BCSIZE) << ',' << (end_ci() * PY_BCSIZE) << "), (id=" << id_
123 << ", is_dead=" << is_dead_ << ", is_loop_head=" << is_loop_head_ << ", is_loop_body_=" << is_loop_body_
124 << ", preds={";
125 for (Block *bb : pred_bbs_) {
126 os << bb->id() << " ";
127 }
128 os << "}, succs={";
129 for (Block *bb : succ_bbs_) {
130 if (bb == jump_bb_) {
131 os << "jump:";
132 } else {
133 os << "fall:";
134 }
135 os << bb->id() << " ";
136 }
137 os << "}";
138 if (IsTrackBreak()) {
139 os << " Break";
140 }
141 if (HasPrimitive()) {
142 os << " HasPrimitive";
143 }
144 if (HasTensor()) {
145 os << " HasTensor";
146 }
147 if (HasAttrSideEffect()) {
148 os << " HasAttrSideEffect";
149 }
150 os << ")";
151 if (!dump_instr) {
152 return os.str();
153 }
154 os << "\n";
155 for (const auto &instr : instrs_) {
156 os << instr.Dump(" ") << "\n";
157 }
158 return os.str();
159 }
160
Clone(CFG * cfg)161 Block *Block::Clone(CFG *cfg) {
162 Block *new_bb = cfg->NewBBAppend();
163 new_bb->set_is_dead(is_dead_);
164 new_bb->set_is_loop_head(is_loop_head_);
165 new_bb->begin_ = this->begin_;
166 new_bb->end_ = this->end_;
167 // clone instr list
168 for (const auto &instr : instrs_) {
169 Instr *new_instr = cfg->NewInstrNode(instr);
170 new_bb->AddInstr(new_instr);
171 }
172 return new_bb;
173 }
174
operator ()(const Block * lhs,const Block * rhs) const175 bool BBIdCmp::operator()(const Block *lhs, const Block *rhs) const { return (lhs->id() < rhs->id()); }
176
operator ()(const Block * lhs,const Block * rhs) const177 bool BBIdGreaterCmp::operator()(const Block *lhs, const Block *rhs) const { return (lhs->id() > rhs->id()); }
178
NewBBAppend()179 Block *CFG::NewBBAppend() {
180 std::unique_ptr<Block> bb_node = std::make_unique<Block>();
181 bb_node->set_id(bb_pool_.size());
182 bb_pool_.push_back(std::move(bb_node));
183 Block *bb = bb_pool_.back().get();
184 return bb;
185 }
186
NewInstrNode(int bci,int op,int arg,int line)187 Instr *CFG::NewInstrNode(int bci, int op, int arg, int line) {
188 instrs_.emplace_back(std::make_unique<Instr>(op, arg, bci, line));
189 Instr *i = instrs_.back().get();
190 if (op == LOAD_CONST) {
191 i->set_cnst(PyTuple_GET_ITEM(pycode_->co_consts, arg));
192 }
193 if (Opcode(op).HasName()) {
194 i->set_name(PyUnicode_AsUTF8(PyTuple_GET_ITEM(pycode_->co_names, arg)));
195 }
196 return i;
197 }
198
NewLoadInstrNode(int bci,int arg,int line,PyObject * cnst)199 Instr *CFG::NewLoadInstrNode(int bci, int arg, int line, PyObject *cnst) {
200 Instr *i = NewInstrNode(bci, 0, -1, line);
201 i->set_cnst(cnst);
202 i->set_op(LOAD_CONST);
203 return i;
204 }
205
// Creates a copy of `instr` owned by this CFG: opcode, argument, bci, line, constant and name
// are copied; jump links are not.
Instr *CFG::NewInstrNode(const Instr &instr) {
  auto owned = std::make_unique<Instr>(instr.op(), instr.arg(), instr.bci(), instr.line());
  Instr *node = owned.get();
  instrs_.emplace_back(std::move(owned));
  node->set_cnst(instr.cnst());
  node->set_name(instr.name());
  return node;
}
213
GenerateCFG()214 void CFG::GenerateCFG() {
215 MS_EXCEPTION_IF_CHECK_FAIL(pycode_, "shouldn't use this function to generate empty cfg");
216 if (!is_generated_) {
217 nlocals_ = pycode_->co_nlocals;
218 is_generated_ = true;
219 BuildInst();
220 BuildBB();
221 BuildCFG();
222 MarkDeadBB();
223 }
224 }
225
BuildInst()226 void CFG::BuildInst() {
227 PyObject *bytes = pycode_->co_code;
228 const _Py_CODEUNIT *bytecode_ = reinterpret_cast<_Py_CODEUNIT *>(PyBytes_AsString(bytes));
229 int size = (PyBytes_GET_SIZE(bytes)) / PY_BCSIZE;
230 int exarg = 0;
231 std::map<int, std::vector<Instr *>> succ_jump;
232 for (int bci = 0; bci < size; ++bci) {
233 Opcode opcode(_Py_OPCODE(bytecode_[bci]));
234 int oparg = (exarg << 8) | _Py_OPARG(bytecode_[bci]);
235 exarg = (opcode == EXTENDED_ARG) ? oparg : 0;
236 int line = PyCode_Addr2Line(pycode_, PY_BCSIZE * bci);
237 if (opcode == LOAD_METHOD) {
238 opcode = LOAD_ATTR;
239 } else if (opcode == CALL_METHOD) {
240 opcode = CALL_FUNCTION;
241 }
242 Instr *instr = NewInstrNode(bci, opcode, oparg, line);
243 instr->set_is_fall(!opcode.IsNotFall());
244 // link instr jump relation
245 if (opcode.IsJRel() || opcode.IsJAbs()) {
246 int dest = opcode.JumpTarget(bci, oparg);
247 if (dest < bci) {
248 Instr *succ = instr_pool()[dest].get();
249 succ->AddExtraPred(instr);
250 instr->set_extra_jump(succ);
251 } else {
252 // record succ jump
253 succ_jump[dest].push_back(instr);
254 }
255 }
256 auto it = succ_jump.find(bci);
257 if (it != succ_jump.cend()) {
258 for (Instr *pred : it->second) {
259 instr->AddExtraPred(pred);
260 MS_EXCEPTION_IF_CHECK_FAIL(pred->extra_jump() == nullptr, "Python bytecode has at most one jump branch");
261 pred->set_extra_jump(instr);
262 }
263 }
264 }
265 }
266
BuildBB()267 void CFG::BuildBB() {
268 Block *curr_bb = nullptr;
269 for (const auto &instr : instr_pool()) {
270 if (instr == nullptr) {
271 continue;
272 }
273 // check start of BB
274 if (curr_bb == nullptr || !instr->extra_preds().empty()) {
275 curr_bb = NewBBAppend();
276 }
277 curr_bb->AddInstr(instr.get());
278 // check end of BB
279 if (!instr->is_fall() || instr->extra_jump() != nullptr) {
280 curr_bb = nullptr;
281 }
282 }
283 for (const auto &i : bb_pool()) {
284 i->set_begin_ci(i->instrs().front().bci());
285 i->set_end_ci(i->instrs().back().bci() + 1);
286 }
287 }
288
BuildCFG()289 bool CFG::BuildCFG() {
290 // build target map
291 std::map<const Instr *, Block *> target_instr_bb_map;
292 for (const auto &unique_bb : bb_pool_) {
293 Block *bb = unique_bb.get();
294 if (bb->instrs().empty()) {
295 continue;
296 }
297 const Instr *instr_head = &(bb->instrs().front());
298 target_instr_bb_map[instr_head] = bb;
299 }
300 // link
301 for (size_t i = 0; i < bb_pool_.size(); ++i) {
302 Block *bb = bb_pool_[i].get();
303 const Instr *instr_tail = &(bb->instrs().back());
304 if (instr_tail->is_fall()) {
305 if (i + 1 >= bb_pool_.size()) {
306 MS_EXCEPTION_IF_CHECK_FAIL(false, "Method without return");
307 return false;
308 }
309 Block *bb_next = bb_pool_[i + 1].get();
310 bb->SetFallBB(bb_next);
311 }
312 if (instr_tail->extra_jump() != nullptr) {
313 Instr *instr = instr_tail->extra_jump();
314 const auto &it_bb = target_instr_bb_map.find(instr);
315 MS_EXCEPTION_IF_CHECK_FAIL(it_bb != target_instr_bb_map.cend(), "Target BB is not found");
316 Block *bb_next = it_bb->second;
317 bb->SetJumpBB(bb_next);
318 }
319 }
320 return true;
321 }
322
VisitBlock(Block * blk,std::vector<bool> * reach,std::vector<bool> * mark,int * loop_count)323 static bool VisitBlock(Block *blk, std::vector<bool> *reach, std::vector<bool> *mark, int *loop_count) {
324 if (reach->operator[](blk->id())) {
325 if (mark->operator[](blk->id()) && !blk->is_loop_head()) {
326 blk->set_is_loop_head(true);
327 blk->set_is_loop_body(true);
328 (*loop_count)++;
329 }
330 return blk->is_loop_body();
331 }
332 bool loop_body = false;
333
334 blk->set_is_dead(false);
335 reach->operator[](blk->id()) = true;
336 mark->operator[](blk->id()) = true;
337 auto iter = blk->succ_bbs().begin();
338 for (; iter != blk->succ_bbs().end(); ++iter) {
339 loop_body |= VisitBlock(*iter, reach, mark, loop_count);
340 }
341 mark->operator[](blk->id()) = false;
342 if (blk->is_loop_head()) {
343 (*loop_count)--;
344 return (*loop_count) != 0;
345 }
346 blk->set_is_loop_body(loop_body);
347 return loop_body;
348 }
349
MarkDeadBB()350 void CFG::MarkDeadBB() {
351 if (bb_pool_.empty()) {
352 return;
353 }
354 std::vector<bool> reach(bb_pool_.size());
355 std::vector<bool> mark(bb_pool_.size());
356 int loop_count = 0;
357 VisitBlock(bb_pool_[0].get(), &reach, &mark, &loop_count);
358 for (const auto &i : bb_pool_) {
359 if (reach[i->id()]) {
360 continue;
361 }
362 i->set_is_dead(true);
363 }
364 }
365
366 // Simplified cfg
ClearDeadBBEdges()367 void CFG::ClearDeadBBEdges() {
368 MarkDeadBB();
369 for (auto &i : bb_pool_) {
370 if (i->is_dead()) {
371 i->ClearOutEdges();
372 }
373 }
374 }
375
GetBlockByBci(int bci) const376 Block *CFG::GetBlockByBci(int bci) const {
377 auto iter = std::find_if(bb_pool().begin(), bb_pool().end(), [bci](const std::unique_ptr<Block> &i) {
378 return i->begin_ci() <= bci && bci < i->end_ci();
379 });
380 if (iter == bb_pool().end()) {
381 MS_LOG(INTERNAL_EXCEPTION) << "can't find block at " << bci;
382 }
383 return iter->get();
384 }
385
Clone()386 std::unique_ptr<CFG> CFG::Clone() {
387 std::unique_ptr<CFG> new_cfg = std::make_unique<CFG>(pycode_);
388 if (bb_pool_.empty()) {
389 return new_cfg;
390 }
391 for (const auto &bb : bb_pool_) {
392 (void)bb->Clone(new_cfg.get());
393 }
394 // link active bb and instr
395 for (Block *bb : *this) {
396 Block *dst_bb = new_cfg->bb_pool()[bb->id()].get();
397 if (bb->GetFallBB() != nullptr) {
398 Block *dst_fall_bb = new_cfg->bb_pool()[bb->GetFallBB()->id()].get();
399 dst_bb->SetFallBB(dst_fall_bb);
400 }
401 if (bb->GetJumpBB() != nullptr) {
402 Block *dst_jump_bb = new_cfg->bb_pool()[bb->GetJumpBB()->id()].get();
403 dst_bb->SetJumpBB(dst_jump_bb);
404 // link instr jump
405 dst_bb->instrs().back().set_extra_jump(&dst_jump_bb->instrs().front());
406 dst_jump_bb->instrs().front().AddExtraPred(&dst_bb->instrs().back());
407 }
408 }
409 return new_cfg;
410 }
411
DumpBBs(std::string phase) const412 std::string CFG::DumpBBs(std::string phase) const {
413 std::ostringstream os;
414 os << "*** Dump BB " << phase << "on [" << py::str(reinterpret_cast<PyObject *>(pycode_)).cast<std::string>()
415 << "] ***\n";
416 for (const auto &bb : bb_pool_) {
417 os << bb->Dump();
418 }
419 return os.str();
420 }
421
DumpCFGGraph()422 void CFG::DumpCFGGraph() {
423 std::string file_name = Utils::GetPyName(pycode_->co_name);
424 file_name = file_name + ".dot";
425 std::ofstream file(file_name);
426 MS_EXCEPTION_IF_CHECK_FAIL(file.is_open(), "Failed to open General CFG Graph FileName:" + file_name);
427 file << "digraph {\n";
428 file << " label=\"" << file_name << "\"\n";
429 file << " labelloc=t\n";
430 DumpCFGGraph(file);
431 file.close();
432 }
433
DumpCFGGraph(std::ofstream & file)434 void CFG::DumpCFGGraph(std::ofstream &file) {
435 for (const auto &bb : bb_pool_) {
436 DumpCFGGraphForBB(file, *bb);
437 }
438 DumpCFGGraphForEdge(file);
439 file << "}\n";
440 }
441
DumpCFGGraphForBB(std::ofstream & file,const Block & bb) const442 void CFG::DumpCFGGraphForBB(std::ofstream &file, const Block &bb) const {
443 file << " BB" << bb.id() << " [shape=record,label=\"{\n";
444 for (const auto &instr : bb.instrs()) {
445 file << " <instr" << instr.bci() << "> " << instr.Dump();
446 if (&instr == &bb.instrs().back()) {
447 file << "\n";
448 break;
449 } else {
450 file << " |\n";
451 }
452 }
453 file << " }\"];\n";
454 }
455
DumpCFGGraphForEdge(std::ofstream & file)456 void CFG::DumpCFGGraphForEdge(std::ofstream &file) {
457 file << " subgraph cfg_edges {\n";
458 file << " edge [color=\"#000000\",weight=0.3,len=3];\n";
459 for (const auto &bb : bb_pool_) {
460 const Instr &instrS = bb->instrs().back();
461 for (Block *bb_next : bb->succ_bbs()) {
462 const Instr &instrE = bb_next->instrs().front();
463 file << " BB" << bb->id() << ":instr" << instrS.bci() << " -> ";
464 file << "BB" << bb_next->id() << ":instr" << instrE.bci() << "\n";
465 }
466 }
467 file << " }\n";
468 }
469
operator ++()470 CFG::BBIterator &CFG::BBIterator::operator++() {
471 if (q_.empty()) {
472 return *this;
473 }
474 Block *bb = q_.front();
475 q_.pop();
476 for (Block *bb_next : bb->succ_bbs()) {
477 if (visit_[bb_next->id()]) {
478 continue;
479 }
480 q_.push(bb_next);
481 visit_[bb_next->id()] = true;
482 }
483 return *this;
484 }
485
GetLiveness()486 const Liveness *CFG::GetLiveness() {
487 if (liveness_ == nullptr) {
488 liveness_ = std::make_unique<Liveness>(this);
489 liveness_->Init();
490 }
491 return liveness_.get();
492 }
493
494 } // namespace pijit
495 } // namespace mindspore
496