1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "pipeline/jit/pi/graph_capture/loop_unrolling.h"
#include <fstream>
#include <map>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "pipeline/jit/pi/graph_capture/graph.h"
#include "pipeline/jit/pi/graph_capture/loop.h"
#include "pipeline/jit/pi/pi_jit_config.h"
26
27 namespace mindspore {
28 namespace pijit {
// Runs one analysis phase `func(...)` and stores its verdict in `res_`.
// If the verdict does not permit unrolling, the enclosing function returns
// immediately, so this macro may only be used inside functions returning void.
// The `break` only leaves the do/while wrapper and then skips the `return`.
#define CHECK_PHASE(func, ...)                   \
  do {                                           \
    res_ = func(__VA_ARGS__);                    \
    if (LoopUnrolling::IsloopUnorlling(res_)) {  \
      break;                                     \
    }                                            \
    return;                                      \
  } while (false)
36
ExecuteLoopUnroll(Block * header)37 LoopUnrollingReason LoopUnrolling::ExecuteLoopUnroll(Block *header) {
38 if (graph_.loops().empty()) {
39 return kCanNotUnroll;
40 }
41 for (auto *lp : graph_.loops()) {
42 if (lp->header() == header) {
43 loop_ = lp;
44 break;
45 }
46 }
47 if (loop_ == nullptr) {
48 return kCanNotUnroll;
49 }
50 Run();
51 // Dump loop unrolling info
52 if (graph_.Config().GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
53 GRAPH_JIT_LOG_F("%s\n\n", DumpLoopUnrolling().c_str());
54 }
55 if (IsloopUnorlling(res_) && graph_.Config().GetBoolConfig(GraphJitConfig::kPrintBB)) {
56 GRAPH_JIT_LOG_F("%s\n\n", graph_.GetCFG()->DumpBBs("after loop unrolling ").c_str());
57 }
58 return res_;
59 }
60
Run()61 void LoopUnrolling::Run() {
62 // check only one exit and one backedges in loop
63 constexpr size_t max_size = 2;
64 if (loop_->exits().empty() || loop_->exits().size() >= max_size || loop_->backedges().size() >= max_size) {
65 res_ = kCanNotSplitGoto;
66 return;
67 }
68 // check pred of exit has only succ block
69 Block *exit = *loop_->exits().begin();
70 if (exit->pred_bbs().size() != 1) {
71 res_ = kCanNotSplitGoto;
72 return;
73 }
74 Block *backedge = *loop_->backedges().begin();
75 if (backedge->GetFallBB() != nullptr) {
76 res_ = kCanNotJumpBackedge;
77 return;
78 }
79 // check foritem
80 if (loop_->header()->instrs().front().op() == FOR_ITER) {
81 CHECK_PHASE(AnalyzeForItem);
82 }
83 CHECK_PHASE(CheckLoopUnrollingSideeffect);
84 is_cfg_changed_ = true;
85 RemoveBackedge();
86 CopyAndInsertBB();
87 FixupInstr();
88 }
89
AnalyzeForItem()90 LoopUnrollingReason LoopUnrolling::AnalyzeForItem() {
91 // find GET_ITER opcode
92 MS_EXCEPTION_IF_NULL(loop_value_);
93 AObject *loop_vobj = loop_value_->GetVobj();
94 if (!loop_vobj) {
95 return kCanNotUnroll;
96 }
97 PyObject *obj = loop_vobj->GetPyObject().ptr();
98 // check unrolling count
99 if (loop_vobj->GetType() == AObject::kTypeList || loop_vobj->GetType() == AObject::kTypeTuple) {
100 AbstractTuple *list = static_cast<AbstractTuple *>(loop_vobj);
101 if (!list->IsElementValid()) {
102 return kCanNotUnroll;
103 }
104 AddLoopGurad(loop_value_);
105 unrolling_count_ = list->size();
106 loop_op_ = NOP;
107 loop_arg_ = 0;
108 } else if (loop_vobj->GetType() == AObject::kTypeNNCellList && obj != nullptr) {
109 AddLoopGurad(loop_value_);
110 unrolling_count_ = PyObject_Size(obj);
111 loop_op_ = NOP;
112 loop_arg_ = 0;
113 } else {
114 return kCanNotUnroll;
115 }
116 return kCanForItemUnroll;
117 }
118
AddLoopGurad(ValueNode * value)119 bool LoopUnrolling::AddLoopGurad(ValueNode *value) { return graph_.GuardValueNode(value); }
120
CheckLoopUnrollingSideeffect()121 LoopUnrollingReason LoopUnrolling::CheckLoopUnrollingSideeffect() {
122 // check length
123 if (unrolling_count_ <= 0 || unrolling_count_ > graph_.Config().getIntConfig(GraphJitConfig::kMaxLoopUnrolling)) {
124 return kCanNotMaxCount;
125 }
126 if (loop_value_ == nullptr && loop_value_->GetVobj()) {
127 return res_;
128 }
129 // check if loop_value is called by CFunction, e.g. list.append()
130 // check side effects
131 return res_;
132 }
133
AddLoopUnrollingInstr(Block * bb,int count)134 void LoopUnrolling::AddLoopUnrollingInstr(Block *bb, int count) {
135 bb->set_is_loop_head(false);
136 const Instr &first_instr = bb->instrs().front();
137 bb->RemoveInstrs();
138 // remove GET_ITER and adding [count - 1] DUP_TOP
139 if (loop_op_ == NOP && count == 0) {
140 // GET_ITER --> DUP_TOP
141 if (iter_instr_ != nullptr) {
142 iter_instr_->set_op(DUP_TOP);
143 iter_instr_->set_arg(0);
144 }
145 if (unrolling_count_ == 1) {
146 Instr *instr = graph_.GetCFG()->NewInstrNode(first_instr.bci(), POP_TOP, 0, first_instr.line());
147 bb->AddInstr(instr);
148 }
149 for (int i = 0; i < unrolling_count_ - 2; ++i) {
150 Instr *instr = graph_.GetCFG()->NewInstrNode(first_instr.bci(), DUP_TOP, 0, first_instr.line());
151 bb->AddInstr(instr);
152 }
153 }
154 // get list or tuple ref
155 Instr *i = graph_.GetCFG()->NewInstrNode(-1, loop_op_, loop_arg_, first_instr.line());
156 bb->AddInstr(i);
157 if (count == 0) {
158 i->set_bci(first_instr.bci());
159 }
160 py::object value = py::int_(count);
161 // subscript index
162 i = graph_.GetCFG()->NewLoadInstrNode(-1, -1, first_instr.line(), value.ptr());
163 bb->AddInstr(i);
164 i = graph_.GetCFG()->NewInstrNode(-1, BINARY_SUBSCR, 0, first_instr.line());
165 bb->AddInstr(i);
166 }
167
168 // while-do pattern loop
RemoveBackedge()169 void LoopUnrolling::RemoveBackedge() {
170 loop_->header()->SetJumpBB(nullptr);
171 MS_EXCEPTION_IF_CHECK_FAIL(loop_->backedges().size() == 1, "backedges has only one block");
172 Block *backedge = *loop_->backedges().begin();
173 if (!backedge->instrs().empty() && &backedge->instrs().front() == &backedge->instrs().back()) {
174 backedge->instrs().front().set_op(NOP); // replace JUMP_ABSOLUTE
175 backedge->instrs().front().set_arg(0);
176 } else {
177 backedge->RemoveInstr(&backedge->instrs().back()); // remove JUMP_ABSOLUTE
178 }
179 backedge->SetJumpBB(nullptr);
180 backedge->SetFallBB(*loop_->exits().begin());
181 }
182
CopyAndInsertBB()183 void LoopUnrolling::CopyAndInsertBB() {
184 Block *exit = *loop_->exits().begin();
185 MS_EXCEPTION_IF_CHECK_FAIL(exit->pred_bbs().size() == 1, "pred of exit has only succ block");
186 Block *exit_pred = *exit->pred_bbs().begin();
187 Block *header = loop_->header();
188 Block *start = nullptr;
189 Block *tail = nullptr;
190 for (int i = 0; i < unrolling_count_; ++i) {
191 if (i == 0) {
192 AddLoopUnrollingInstr(header, i);
193 continue;
194 }
195 std::map<int, Block *> bb_map = CopyBB();
196 header = bb_map[loop_->header()->id()];
197 AddLoopUnrollingInstr(header, i);
198 if (i == 1) {
199 start = bb_map[loop_->header()->id()];
200 tail = bb_map[exit_pred->id()];
201 }
202 if (i > 1) {
203 tail->SetFallBB(bb_map[loop_->header()->id()]);
204 }
205 tail = bb_map[exit_pred->id()];
206 if (i == unrolling_count_ - 1) {
207 exit_pred->SetFallBB(nullptr);
208 exit_pred->SetFallBB(start);
209 tail->SetFallBB(exit);
210 }
211 }
212 }
213
CopyBB()214 std::map<int, Block *> LoopUnrolling::CopyBB() {
215 std::map<int, Block *> bb_map;
216 for (Block *memb : loop_->loop_members()) {
217 Block *new_bb = memb->Clone(graph_.GetCFG().get());
218 bb_map.insert(std::make_pair(memb->id(), new_bb));
219 }
220 // link active bb and instr
221 for (auto iter = graph_.GetCFG()->begin(loop_->header()); iter != graph_.GetCFG()->end(); ++iter) {
222 Block *bb = *iter;
223 if (loop_->loop_members().find(bb) == loop_->loop_members().cend()) {
224 break;
225 }
226 Block *dst_bb = bb_map[bb->id()];
227 if (dst_bb == nullptr) {
228 continue;
229 }
230 if (bb->GetFallBB() != nullptr) {
231 Block *dst_fall_bb = bb_map[bb->GetFallBB()->id()];
232 dst_bb->SetFallBB(dst_fall_bb);
233 }
234 if (bb->GetJumpBB() != nullptr) {
235 Block *dst_jump_bb = bb_map[bb->GetJumpBB()->id()];
236 dst_bb->SetJumpBB(dst_jump_bb);
237 // link instr jump
238 dst_bb->instrs().back().set_extra_jump(&dst_jump_bb->instrs().front());
239 dst_jump_bb->instrs().front().AddExtraPred(&dst_bb->instrs().back());
240 }
241 }
242 return bb_map;
243 }
244
FixupInstr()245 void LoopUnrolling::FixupInstr() {
246 int head_bci = loop_->header()->instrs().front().bci(); // first instruction bci of header is computed
247 int bci = head_bci;
248 // fixup bci
249 std::priority_queue<Block *, std::vector<Block *>, BBIdGreaterCmp> queue;
250 queue.push(loop_->header());
251 std::vector<bool> visited(graph_.GetCFG()->bb_pool().size(), false);
252 // dfs search falled bb
253 while (!queue.empty()) {
254 Block *bb = queue.top();
255 queue.pop();
256 while (bb != nullptr) {
257 if (visited[bb->id()]) {
258 bb = bb->GetFallBB();
259 continue;
260 }
261 visited[bb->id()] = true;
262 for (auto &instr : bb->instrs()) {
263 instr.set_bci(bci++);
264 }
265 if (bb->GetJumpBB() != nullptr && !visited[bb->GetJumpBB()->id()]) {
266 queue.push(bb->GetJumpBB());
267 }
268 bb = bb->GetFallBB();
269 }
270 }
271 // fixup jump arg
272 for (Block *bb : *graph_.GetCFG()) {
273 if (bb->GetJumpBB() == nullptr) {
274 continue;
275 }
276 int jump_bci = bb->GetJumpBB()->instrs().front().bci();
277 Instr &curr_instr = bb->instrs().back();
278 int jump_arg = Opcode(curr_instr.op()).JumpOffset(curr_instr.bci(), jump_bci);
279 curr_instr.set_arg(jump_arg);
280 }
281 // fixup deaded bb bci, because BuildGraph traverse instructions in bci order
282 for (const auto &bb : graph_.GetCFG()->bb_pool()) {
283 if (bb && bb->is_dead() && !bb->instrs().empty() && bb->instrs().front().bci() > head_bci) {
284 for (auto &instr : bb->instrs()) {
285 instr.set_bci(bci++);
286 }
287 }
288 }
289 }
290
DumpLoopUnrolling()291 std::string LoopUnrolling::DumpLoopUnrolling() {
292 std::ostringstream os;
293 os << "*** Dump info after loop unrolling on ["
294 << py::str(reinterpret_cast<PyObject *>(graph_.GetCodeObj())).cast<std::string>() << "] ***\n";
295 os << "loop unrolling reason: " << GetLoopUnrollingReasonDesc(res_) << '\n';
296 if (loop_ != nullptr) {
297 os << "loop header: " << loop_->header()->Dump(false);
298 }
299 if (unrolling_count_ > 0) {
300 os << "loop count: " << unrolling_count_ << '\n';
301 }
302 os << '\n';
303 return os.str();
304 }
305 } // namespace pijit
306 } // namespace mindspore
307