• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_capture/loop_unrolling.h"
17 #include <fstream>
18 #include <map>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 #include "pipeline/jit/pi/graph_capture/graph.h"
24 #include "pipeline/jit/pi/graph_capture/loop.h"
25 #include "pipeline/jit/pi/pi_jit_config.h"
26 
27 namespace mindspore {
28 namespace pijit {
// Runs one unrolling phase: calls `func` with the given arguments and stores its verdict in res_.
// If LoopUnrolling::IsloopUnorlling() rejects that verdict, the enclosing function returns at once.
// Because it expands to a bare `return;`, the macro is only usable inside functions returning void.
#define CHECK_PHASE(func, ...)                                      \
  do {                                                              \
    res_ = func(__VA_ARGS__);                                       \
    const bool phase_passed = LoopUnrolling::IsloopUnorlling(res_); \
    if (!phase_passed) {                                            \
      return;                                                       \
    }                                                               \
  } while (false)
36 
ExecuteLoopUnroll(Block * header)37 LoopUnrollingReason LoopUnrolling::ExecuteLoopUnroll(Block *header) {
38   if (graph_.loops().empty()) {
39     return kCanNotUnroll;
40   }
41   for (auto *lp : graph_.loops()) {
42     if (lp->header() == header) {
43       loop_ = lp;
44       break;
45     }
46   }
47   if (loop_ == nullptr) {
48     return kCanNotUnroll;
49   }
50   Run();
51   // Dump loop unrolling info
52   if (graph_.Config().GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
53     GRAPH_JIT_LOG_F("%s\n\n", DumpLoopUnrolling().c_str());
54   }
55   if (IsloopUnorlling(res_) && graph_.Config().GetBoolConfig(GraphJitConfig::kPrintBB)) {
56     GRAPH_JIT_LOG_F("%s\n\n", graph_.GetCFG()->DumpBBs("after loop unrolling ").c_str());
57   }
58   return res_;
59 }
60 
Run()61 void LoopUnrolling::Run() {
62   // check only one exit and one backedges in loop
63   constexpr size_t max_size = 2;
64   if (loop_->exits().empty() || loop_->exits().size() >= max_size || loop_->backedges().size() >= max_size) {
65     res_ = kCanNotSplitGoto;
66     return;
67   }
68   // check pred of exit has only succ block
69   Block *exit = *loop_->exits().begin();
70   if (exit->pred_bbs().size() != 1) {
71     res_ = kCanNotSplitGoto;
72     return;
73   }
74   Block *backedge = *loop_->backedges().begin();
75   if (backedge->GetFallBB() != nullptr) {
76     res_ = kCanNotJumpBackedge;
77     return;
78   }
79   // check foritem
80   if (loop_->header()->instrs().front().op() == FOR_ITER) {
81     CHECK_PHASE(AnalyzeForItem);
82   }
83   CHECK_PHASE(CheckLoopUnrollingSideeffect);
84   is_cfg_changed_ = true;
85   RemoveBackedge();
86   CopyAndInsertBB();
87   FixupInstr();
88 }
89 
AnalyzeForItem()90 LoopUnrollingReason LoopUnrolling::AnalyzeForItem() {
91   // find GET_ITER opcode
92   MS_EXCEPTION_IF_NULL(loop_value_);
93   AObject *loop_vobj = loop_value_->GetVobj();
94   if (!loop_vobj) {
95     return kCanNotUnroll;
96   }
97   PyObject *obj = loop_vobj->GetPyObject().ptr();
98   // check unrolling count
99   if (loop_vobj->GetType() == AObject::kTypeList || loop_vobj->GetType() == AObject::kTypeTuple) {
100     AbstractTuple *list = static_cast<AbstractTuple *>(loop_vobj);
101     if (!list->IsElementValid()) {
102       return kCanNotUnroll;
103     }
104     AddLoopGurad(loop_value_);
105     unrolling_count_ = list->size();
106     loop_op_ = NOP;
107     loop_arg_ = 0;
108   } else if (loop_vobj->GetType() == AObject::kTypeNNCellList && obj != nullptr) {
109     AddLoopGurad(loop_value_);
110     unrolling_count_ = PyObject_Size(obj);
111     loop_op_ = NOP;
112     loop_arg_ = 0;
113   } else {
114     return kCanNotUnroll;
115   }
116   return kCanForItemUnroll;
117 }
118 
AddLoopGurad(ValueNode * value)119 bool LoopUnrolling::AddLoopGurad(ValueNode *value) { return graph_.GuardValueNode(value); }
120 
CheckLoopUnrollingSideeffect()121 LoopUnrollingReason LoopUnrolling::CheckLoopUnrollingSideeffect() {
122   // check length
123   if (unrolling_count_ <= 0 || unrolling_count_ > graph_.Config().getIntConfig(GraphJitConfig::kMaxLoopUnrolling)) {
124     return kCanNotMaxCount;
125   }
126   if (loop_value_ == nullptr && loop_value_->GetVobj()) {
127     return res_;
128   }
129   // check if loop_value is called by CFunction, e.g. list.append()
130   // check side effects
131   return res_;
132 }
133 
AddLoopUnrollingInstr(Block * bb,int count)134 void LoopUnrolling::AddLoopUnrollingInstr(Block *bb, int count) {
135   bb->set_is_loop_head(false);
136   const Instr &first_instr = bb->instrs().front();
137   bb->RemoveInstrs();
138   // remove GET_ITER and adding [count - 1] DUP_TOP
139   if (loop_op_ == NOP && count == 0) {
140     // GET_ITER --> DUP_TOP
141     if (iter_instr_ != nullptr) {
142       iter_instr_->set_op(DUP_TOP);
143       iter_instr_->set_arg(0);
144     }
145     if (unrolling_count_ == 1) {
146       Instr *instr = graph_.GetCFG()->NewInstrNode(first_instr.bci(), POP_TOP, 0, first_instr.line());
147       bb->AddInstr(instr);
148     }
149     for (int i = 0; i < unrolling_count_ - 2; ++i) {
150       Instr *instr = graph_.GetCFG()->NewInstrNode(first_instr.bci(), DUP_TOP, 0, first_instr.line());
151       bb->AddInstr(instr);
152     }
153   }
154   // get list or tuple ref
155   Instr *i = graph_.GetCFG()->NewInstrNode(-1, loop_op_, loop_arg_, first_instr.line());
156   bb->AddInstr(i);
157   if (count == 0) {
158     i->set_bci(first_instr.bci());
159   }
160   py::object value = py::int_(count);
161   // subscript index
162   i = graph_.GetCFG()->NewLoadInstrNode(-1, -1, first_instr.line(), value.ptr());
163   bb->AddInstr(i);
164   i = graph_.GetCFG()->NewInstrNode(-1, BINARY_SUBSCR, 0, first_instr.line());
165   bb->AddInstr(i);
166 }
167 
168 // while-do pattern loop
RemoveBackedge()169 void LoopUnrolling::RemoveBackedge() {
170   loop_->header()->SetJumpBB(nullptr);
171   MS_EXCEPTION_IF_CHECK_FAIL(loop_->backedges().size() == 1, "backedges has only one block");
172   Block *backedge = *loop_->backedges().begin();
173   if (!backedge->instrs().empty() && &backedge->instrs().front() == &backedge->instrs().back()) {
174     backedge->instrs().front().set_op(NOP);  // replace JUMP_ABSOLUTE
175     backedge->instrs().front().set_arg(0);
176   } else {
177     backedge->RemoveInstr(&backedge->instrs().back());  // remove JUMP_ABSOLUTE
178   }
179   backedge->SetJumpBB(nullptr);
180   backedge->SetFallBB(*loop_->exits().begin());
181 }
182 
CopyAndInsertBB()183 void LoopUnrolling::CopyAndInsertBB() {
184   Block *exit = *loop_->exits().begin();
185   MS_EXCEPTION_IF_CHECK_FAIL(exit->pred_bbs().size() == 1, "pred of exit has only succ block");
186   Block *exit_pred = *exit->pred_bbs().begin();
187   Block *header = loop_->header();
188   Block *start = nullptr;
189   Block *tail = nullptr;
190   for (int i = 0; i < unrolling_count_; ++i) {
191     if (i == 0) {
192       AddLoopUnrollingInstr(header, i);
193       continue;
194     }
195     std::map<int, Block *> bb_map = CopyBB();
196     header = bb_map[loop_->header()->id()];
197     AddLoopUnrollingInstr(header, i);
198     if (i == 1) {
199       start = bb_map[loop_->header()->id()];
200       tail = bb_map[exit_pred->id()];
201     }
202     if (i > 1) {
203       tail->SetFallBB(bb_map[loop_->header()->id()]);
204     }
205     tail = bb_map[exit_pred->id()];
206     if (i == unrolling_count_ - 1) {
207       exit_pred->SetFallBB(nullptr);
208       exit_pred->SetFallBB(start);
209       tail->SetFallBB(exit);
210     }
211   }
212 }
213 
CopyBB()214 std::map<int, Block *> LoopUnrolling::CopyBB() {
215   std::map<int, Block *> bb_map;
216   for (Block *memb : loop_->loop_members()) {
217     Block *new_bb = memb->Clone(graph_.GetCFG().get());
218     bb_map.insert(std::make_pair(memb->id(), new_bb));
219   }
220   // link active bb and instr
221   for (auto iter = graph_.GetCFG()->begin(loop_->header()); iter != graph_.GetCFG()->end(); ++iter) {
222     Block *bb = *iter;
223     if (loop_->loop_members().find(bb) == loop_->loop_members().cend()) {
224       break;
225     }
226     Block *dst_bb = bb_map[bb->id()];
227     if (dst_bb == nullptr) {
228       continue;
229     }
230     if (bb->GetFallBB() != nullptr) {
231       Block *dst_fall_bb = bb_map[bb->GetFallBB()->id()];
232       dst_bb->SetFallBB(dst_fall_bb);
233     }
234     if (bb->GetJumpBB() != nullptr) {
235       Block *dst_jump_bb = bb_map[bb->GetJumpBB()->id()];
236       dst_bb->SetJumpBB(dst_jump_bb);
237       // link instr jump
238       dst_bb->instrs().back().set_extra_jump(&dst_jump_bb->instrs().front());
239       dst_jump_bb->instrs().front().AddExtraPred(&dst_bb->instrs().back());
240     }
241   }
242   return bb_map;
243 }
244 
FixupInstr()245 void LoopUnrolling::FixupInstr() {
246   int head_bci = loop_->header()->instrs().front().bci();  // first instruction bci of header is computed
247   int bci = head_bci;
248   // fixup bci
249   std::priority_queue<Block *, std::vector<Block *>, BBIdGreaterCmp> queue;
250   queue.push(loop_->header());
251   std::vector<bool> visited(graph_.GetCFG()->bb_pool().size(), false);
252   // dfs search falled bb
253   while (!queue.empty()) {
254     Block *bb = queue.top();
255     queue.pop();
256     while (bb != nullptr) {
257       if (visited[bb->id()]) {
258         bb = bb->GetFallBB();
259         continue;
260       }
261       visited[bb->id()] = true;
262       for (auto &instr : bb->instrs()) {
263         instr.set_bci(bci++);
264       }
265       if (bb->GetJumpBB() != nullptr && !visited[bb->GetJumpBB()->id()]) {
266         queue.push(bb->GetJumpBB());
267       }
268       bb = bb->GetFallBB();
269     }
270   }
271   // fixup jump arg
272   for (Block *bb : *graph_.GetCFG()) {
273     if (bb->GetJumpBB() == nullptr) {
274       continue;
275     }
276     int jump_bci = bb->GetJumpBB()->instrs().front().bci();
277     Instr &curr_instr = bb->instrs().back();
278     int jump_arg = Opcode(curr_instr.op()).JumpOffset(curr_instr.bci(), jump_bci);
279     curr_instr.set_arg(jump_arg);
280   }
281   // fixup deaded bb bci, because BuildGraph traverse instructions in bci order
282   for (const auto &bb : graph_.GetCFG()->bb_pool()) {
283     if (bb && bb->is_dead() && !bb->instrs().empty() && bb->instrs().front().bci() > head_bci) {
284       for (auto &instr : bb->instrs()) {
285         instr.set_bci(bci++);
286       }
287     }
288   }
289 }
290 
DumpLoopUnrolling()291 std::string LoopUnrolling::DumpLoopUnrolling() {
292   std::ostringstream os;
293   os << "*** Dump info after loop unrolling on ["
294      << py::str(reinterpret_cast<PyObject *>(graph_.GetCodeObj())).cast<std::string>() << "] ***\n";
295   os << "loop unrolling reason: " << GetLoopUnrollingReasonDesc(res_) << '\n';
296   if (loop_ != nullptr) {
297     os << "loop header: " << loop_->header()->Dump(false);
298   }
299   if (unrolling_count_ > 0) {
300     os << "loop count: " << unrolling_count_ << '\n';
301   }
302   os << '\n';
303   return os.str();
304 }
305 }  // namespace pijit
306 }  // namespace mindspore
307