/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/jit/pi/graph_capture/bytecode_inliner.h"
#include <set>
#include <utility>
#include <algorithm>
#include <string>
#include "pipeline/jit/pi/graph_capture/graph.h"
#include "pipeline/jit/pi/graph_capture/side_effect.h"
#include "pipeline/jit/pi/graph_guard/cache.h"
#include "pipeline/jit/pi/pi_jit_config.h"

namespace mindspore {
namespace pijit {

extern std::string PrintInstr(const std::vector<std::unique_ptr<Instr>> &list);
extern bool CheckMSConstexpr(const py::object &func);
extern bool CheckJitConstexpr(const py::object &func);

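// A rough, Python-level sketch of what this pass does (illustrative only;
// the function names below are made up, not from the sources). Given a
// traced call that breaks inside an inlined callee:
//
//   def inner(x):
//       y = x + 1
//       unsupported_op()   # graph break here
//       return y
//
//   def outer(a):
//       return inner(a) * 2
//
// the inliner flattens inner's traced nodes into outer's node list, remaps
// the callee's locals past the caller's (see FixInstr's local_off), and
// rewrites the callee's RETURN_VALUE into a JUMP_FORWARD to the
// continuation, so a single flat bytecode sequence is rebuilt around the
// break point.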
void BytecodeInliner::Run() {
  if (graph_->IsBreakAtLoop() && !graph_->RestoreLoopStatus()) {
    return;
  }

  inline_partial_ = graph_->Config().GetBoolConfig(GraphJitConfig::kFeatureBreakAtInlinedFunction);
  cfg_ = std::make_unique<CFG>(nullptr);

  // collect traced nodes and inline the second half of the bytecode
  if (graph_->GetStopTraceBci() != -1) {
    last_frame_ = std::make_unique<FrameStates>();
    ProcessGraph(graph_, 0);
  } else {
    CollectTracedNodes(graph_);
  }

  if (traced_nodes_.empty()) {
    return;
  }

  if (inline_partial_ && graph_->GetStopTraceBci() != -1) {
    InitCFG();
  }

  Rebuild();

  if (graph_->Config().GetBoolConfig(GraphJitConfig::kPrintBB)) {
    GRAPH_JIT_LOG_F("%s\n\n", cfg_->DumpBBs().c_str());
  }

  ResetGraphStat();

  if (graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
    std::stringstream s;
    s << "graph new break bci is " << new_break_bci_ << " after inline";
    if (reconstructed_value_ != nullptr) {
      const auto &instr = graph_->GetCFG()->instr_pool()[new_break_bci_];
      s << ", node is reconstructed by inliner [" << reconstructed_value_->ToString() << "] -> [" << instr->ToString()
        << "]";
    }
    GRAPH_JIT_LOG_F("%s", s.str().c_str());
  }
}

void BytecodeInliner::ResetGraphStat() {
  graph_->GetFrames().swap(new_frames_);
  graph_->GetCFG().swap(cfg_);
  graph_->GetTracedNodes().swap(traced_nodes_);
  if (graph_->GetStopTraceBci() != -1) {
    graph_->StopTraceAt(new_break_bci_, graph_->GetStopTraceReason());
  }
}

void BytecodeInliner::ResetCFG(CodeGenerator *cg) {
  std::vector<std::unique_ptr<Instr>> list = cg->MoveCode();
  std::move(cfg_->instr_pool().begin(), cfg_->instr_pool().end(), std::back_inserter(list));
  cfg_->instr_pool().swap(list);
  cfg_->bb_pool().clear();
  cfg_->liveness().reset();
  cfg_->SetLocalCount(cg->GetCode().co_nlocals);
  InitCFG();
}

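// bci bookkeeping sketch for the loop below (illustrative): if BuildOper
// appends
//   ... LOAD_FAST a; LOAD_FAST b; BINARY_ADD; STORE_FAST t
// the node's bci must point at BINARY_ADD, i.e. co_code.size() - 1 - 1,
// because the trailing STORE_FAST (or POP_TOP) does not produce the value.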
void BytecodeInliner::Rebuild(CodeGenerator *cg) {
  FrameStates new_f = graph_->GetFrame(0);
  int new_bci = 0;
  cg->SetGlobals(extra_globals_);
  cg->Init();
  cg->MarkAlive();
  new_frames_[0] = std::make_unique<FrameStates>(new_f);
  for (size_t index = 0; index < traced_nodes_.size(); ++index) {
    ValueNode *node = traced_nodes_[index];
    if (IsNonLocalValue(node)) {
      node->set_bci(new_bci);
      continue;
    }
    cg->BuildOper(node, index);

    // reset bci; skip a trailing POP_TOP or STORE_FAST, which is not the value-producing instruction
    int last_op = cg->GetCode().co_code.back()->op();
    new_bci = SizeToInt(cg->GetCode().co_code.size()) - 1 - (last_op == POP_TOP || last_op == STORE_FAST);
    node->set_bci(new_bci);

    // reset frame status
    new_f.GetStacks() = node->getInputs();
    new_frames_[new_bci] = std::make_unique<FrameStates>(new_f);
    if (last_op == STORE_FAST) {
      int arg = cg->GetCode().co_code.back()->arg();
      new_f.ResizeLocal(std::max(new_f.GetLocals().size(), static_cast<size_t>(arg + 1)));
      new_f.SetLocal(arg, node);
    }
  }
  int nlocals = std::max(cg->GetLocalsMap().size(), new_f.GetLocals().size());
  nlocals = std::max(nlocals, cg->GetCode().co_nlocals);
  cg->SetLocalsCount(nlocals);
}

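// Output selection note (illustrative): on a graph break, the outputs are
// the stack and local values still alive at the break bci according to the
// liveness analysis; with no break, the sole output is the traced return
// value and a RETURN_VALUE epilogue is generated instead of STORE_FASTs.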
void BytecodeInliner::Rebuild() {
  NodeSet ns = {
    graph_->GetFrame(0).GetLocals(),
    std::vector<ValueNode *>(),
    traced_nodes_,
  };
  CodeGenerator cg(&ns);

  std::vector<int> alive_locals;
  if (last_frame_ != nullptr) {
    BitMap alive = inline_partial_ ? cfg_->GetLiveness()->CollectAlive(0)
                                   : graph_->GetCFG()->GetLiveness()->CollectAlive(graph_->GetStopTraceBci());
    ns.outputs = Graph::CollectAliveNode(*last_frame_, &alive, &alive_locals);
  } else {
    ns.outputs.push_back(graph_->GetRetVal());
  }

  if (graph_->Config().GetBoolConfig(GraphJitConfig::kEnableEliminateUnusedOperation)) {
    // erase dead locals between inlining and code rebuild
    EraseDeadLocal(ns.outputs);
    EliminateClosureSideEffect();
  }
  Rebuild(&cg);

  if (last_frame_ != nullptr) {
    std::for_each(ns.outputs.begin(), ns.outputs.end(), [&cg](ValueNode *i) { cg.LoadValue(i); });
    std::for_each(alive_locals.rbegin(), alive_locals.rend(), [&cg](int i) { cg.NewInstr(STORE_FAST, i); });
    cg.SetLocalsCount(last_frame_->GetLocals().size());
  } else {
    cg.GenReturn();
  }

  if (last_frame_ != nullptr) {
    MS_EXCEPTION_IF_CHECK_FAIL(new_frames_.find(cg.GetCode().co_code.size()) == new_frames_.end(),
                               "duplicate frame status");
    new_break_bci_ = SizeToInt(cg.GetCode().co_code.size());
    new_frames_[new_break_bci_] = std::move(last_frame_);
  }

  ResetCFG(&cg);
}

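// Flattening sketch (illustrative): for a traced call "z = f(a)" whose
// callee was fully inlined, the CallNode itself is not kept; its parameter
// nodes are appended, then the callee's traced nodes are collected
// recursively, yielding one flat node list.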
void BytecodeInliner::CollectTracedNodes(Graph *graph) {
  for (ValueNode *n : graph->GetTracedNodes()) {
    if (n->GetType() != AbstractNode::Call) {
      traced_nodes_.push_back(n);
      continue;
    }
    CallNode *call_node = static_cast<CallNode *>(n);
    if (call_node->GetSubGraph() == nullptr || call_node->GetInlineReason() != InlineReason::kInline) {
      traced_nodes_.push_back(n);
      continue;
    }
    std::copy(call_node->GetParams().begin(), call_node->GetParams().end(), std::back_inserter(traced_nodes_));
    CollectTracedNodes(call_node->GetSubGraph());
  }
}

void BytecodeInliner::ProcessGraph(Graph *graph, int local_off) {
  int break_bci = graph->GetStopTraceBci();
  if (break_bci == -1) {
    return;
  }

  // build the last frame
  const FrameStates &f = graph->GetFrame(break_bci);
  last_frame_->GetLocals().insert(last_frame_->GetLocals().end(), f.GetLocals().begin(), f.GetLocals().end());
  last_frame_->GetStacks().insert(last_frame_->GetStacks().end(), f.GetStacks().begin(), f.GetStacks().end());

  CollectTracedNodes(graph);

  const auto &nodes = graph->GetTracedNodes();
  if (nodes.size() > 0 && nodes.back()->bci() == break_bci) {
    // break at a traced value
    Reconstruct(nodes.back(), local_off + f.GetLocals().size());
    break_bci++;
  } else {
    // break at an unsupported bytecode
    MS_EXCEPTION_IF_CHECK_FAIL(nodes.empty() || break_bci > nodes.back()->bci(), "check break bci");
    new_break_bci_ = 0;
  }

  std::vector<std::unique_ptr<Instr>> list = CodeGenerator::CopyInstr(graph->GetCFG()->instr_pool(), break_bci);
  if (inline_partial_) {
    FixInstr(graph, local_off, &list);
  }
  std::move(list.begin(), list.end(), std::back_inserter(cfg_->instr_pool()));
  cfg_->SetLocalCount(std::max(static_cast<size_t>(cfg_->GetLocalCount()), local_off + f.GetLocals().size()));
}

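// The "__class__" special case below (illustrative sketch): CPython compiles
//
//   class Net(Base):
//       def construct(self, x):
//           return super().construct(x)   # zero-argument super()
//
// with a single implicit free variable named "__class__"; a closure whose
// only free variable is "__class__" is treated as side-effect free here.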
static bool EliminateSideEffect(Graph *top_graph, Graph *sub_graph) {
  /**
   * kFeatureBreakAtInlinedFunction
   * eliminate untracked bytecode side effects after a graph break:
   * 1. eliminate MAKE_FUNCTION if it accesses globals and its globals dict differs from the top function's
   * 2. eliminate STORE_GLOBAL and DELETE_GLOBAL if the globals dict differs from the top function's
   * 3. eliminate closure access operations if the function has cell or free variables
   */
  if (top_graph != sub_graph && sub_graph->GetFrame(0).GetClosures().size() != 0) {
    PyObject *frees = sub_graph->GetCodeObj()->co_freevars;
    if (PyTuple_GET_SIZE(frees) == 1 && std::string("__class__") == PyUnicode_AsUTF8(PyTuple_GET_ITEM(frees, 0))) {
      /**
       * BUG: does not check super() calls or free-variable accesses after the break point
       **/
      return true;
    }
    return false;
  }
  /**
   * BUG: does not check MAKE_FUNCTION with global access after the break point
   **/
  return true;
}

static bool CanInlinePartial(Graph *top_graph, Graph *sub_graph) {
  if (sub_graph == nullptr) {
    return false;
  }
  if (sub_graph->IsBreakAtLoop()) {
    return false;
  }
  if (top_graph->GetGlobals() == sub_graph->GetGlobals()) {
    return true;
  }
  if (!EliminateSideEffect(top_graph, sub_graph)) {
    return false;
  }
  for (auto i : sub_graph->GetTracedNodes()) {
    int op = i->GetOpcode();
    if (op == MAKE_FUNCTION) {
      return false;
    }
  }
  return true;
}

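// Stack-effect sketch for the Popn below (illustrative): for BINARY_ADD,
// PyCompile_OpcodeStackEffect returns -1 (pops two, pushes one) and the op
// produces a value, so Popn(-(-1) + 1) = Popn(2) removes both operands; the
// node's inputs are then re-pushed so the re-emitted instruction can run.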
void BytecodeInliner::Reconstruct(ValueNode *node, int local_off) {
  static const std::set<int> not_value_oper = {
    STORE_DEREF,  DELETE_DEREF,  STORE_GLOBAL, DELETE_GLOBAL, STORE_ATTR, DELETE_ATTR,
    STORE_SUBSCR, DELETE_SUBSCR, IMPORT_STAR,  RAISE_VARARGS, RERAISE,
  };
  traced_nodes_.pop_back();

  Graph *graph = node->GetGraph();
  const auto &instr = graph->GetCFG()->instr_pool()[node->bci()];

  int stack_effect = PyCompile_OpcodeStackEffect(instr->op(), instr->arg());
  bool is_value = not_value_oper.find(instr->op()) == not_value_oper.end();
  MS_EXCEPTION_IF_CHECK_FAIL(stack_effect <= 0 && stack_effect != PY_INVALID_STACK_EFFECT,
                             "check break bci, too many values produced");
  last_frame_->Popn(-stack_effect + is_value);

  if (inline_partial_ && node->GetType() == AbstractNode::Call) {
    CallNode *call_node = static_cast<CallNode *>(node);
    if (CanInlinePartial(this->graph_, call_node->GetSubGraph())) {
      std::copy(call_node->GetParams().begin(), call_node->GetParams().end(), std::back_inserter(traced_nodes_));
      ProcessGraph(call_node->GetSubGraph(), local_off);
      return;
    }
  }
  for (auto i : node->getInputs()) {
    last_frame_->Push(i);
  }
  MS_EXCEPTION_IF_CHECK_FAIL(cfg_->instr_pool().empty(), "just call once if graph break at traced value");
  cfg_->NewInstrNode(*instr);
  reconstructed_value_ = node;

  /**
   * if the node does not match the instruction opcode, check its side effect
   */
}

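// FixInstr sketch (illustrative): with local_off == 4, an inlined callee's
//   LOAD_FAST 0  ->  LOAD_FAST 4
// so callee locals land past the caller's; a callee RETURN_VALUE becomes a
// JUMP_FORWARD to the tail of its copied block, and a LOAD_GLOBAL resolved
// against a different globals dict is guarded and renamed into
// extra_globals_ so the merged code reads from a single namespace.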
void BytecodeInliner::FixInstr(Graph *graph, int local_off, std::vector<std::unique_ptr<Instr>> *list) {
  if (list->empty()) {
    return;
  }
  for (const auto &i : *list) {
    if (Opcode(i->op()).IsLocalAccess()) {
      i->set_arg(i->arg() + local_off);
      continue;
    }
    if (this->graph_ != graph && i->op() == RETURN_VALUE) {
      i->set_op(JUMP_FORWARD);
      i->set_extra_jump(list->back().get());
      continue;
    }
    if (graph->GetGlobals().ptr() == this->graph_->GetGlobals().ptr()) {
      continue;
    }
    if (i->op() == LOAD_GLOBAL) {
      PyObject *value = PyObject_GetItem(graph->GetGlobals().ptr(), py::str(i->name()).ptr());
      py::object _value_handle = py::reinterpret_steal<py::object>(value);
      if (value == nullptr) {
        PyErr_Clear();
        continue;
      }
      auto tr = std::make_shared<RootTrace>(value, TraceType::Global, -1, i->name(), graph->GetModuleName());
      graph->GetGuard()->GetGuard()->GuardOn(tr, GuardLevel::GId);
      std::string key = i->name();
      MapAdd(extra_globals_, key, _value_handle, &key);
      i->set_name(key);
      continue;
    }
  }

  if (list->back()->op() != JUMP_FORWARD || list->back()->extra_jump() != list->back().get()) {
    return;
  }
  list->back()->set_extra_jump(nullptr);
  if (graph != this->graph_) {
    list->back()->set_op(NOP);
  } else {
    list->back()->set_op(RETURN_VALUE);
  }
}

/**
 * unify the implementations of cfg initialization
 */
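// Label-marking example (illustrative) for the loop below:
//   0 LOAD_FAST 0
//   1 POP_JUMP_IF_FALSE -> 4
//   2 LOAD_CONST 1
//   3 RETURN_VALUE
//   4 LOAD_CONST 2
//   5 RETURN_VALUE
// creates blocks at bci 0 (entry), 2 (fall-through after the jump), and 4
// (jump target), then links them with fall and jump edges.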
void BytecodeInliner::InitCFG() {
  const auto &list = cfg_->instr_pool();

  // reset bci, erase unused jumps
  CodeGenerator::EraseUnusedInstr(&cfg_->instr_pool());

  // mark labels (ordered map keyed by bci)
  std::map<int, Block *> blocks;
  blocks.insert({0, cfg_->NewBBAppend()});
  for (const auto &i : list) {
    size_t bci = list.size();
    if (Opcode(i->op()).IsNotFall()) {
      bci = (size_t)i->bci() + 1;
    }
    if (i->extra_jump() != nullptr) {
      bci = (size_t)i->bci() + 1;
      if (blocks.find(i->extra_jump()->bci()) == blocks.end()) {
        blocks.insert({i->extra_jump()->bci(), cfg_->NewBBAppend()});
      }
    }
    if (bci != list.size() && blocks.find(bci) == blocks.end()) {
      blocks.insert({bci, cfg_->NewBBAppend()});
    }
  }

  // link blocks, set ranges
  for (auto iter = blocks.begin(); iter != blocks.end();) {
    Block *cur = iter->second;
    int head = iter->first;
    int back;
    iter++;
    if (iter != blocks.end()) {
      back = iter->first;
    } else {
      back = SizeToInt(list.size());
    }
    cur->set_begin_ci(head);
    cur->set_end_ci(back);
    const auto &instr = list[back - 1];
    if (instr->extra_jump()) {
      cur->SetJumpBB(blocks[instr->extra_jump()->bci()]);
    }
    if (!Opcode(instr->op()).IsNotFall()) {
      // relies on the final instruction never falling through; otherwise
      // iter would be the end iterator here
      cur->SetFallBB(iter->second);
    }
  }
  cfg_->MarkDeadBB();
  cfg_->GetLiveness();
}

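// Elimination intuition (illustrative): "t = x + y" on plain scalars is
// removable when t is dead, but a binary op on a list-typed operand is kept,
// since list-like binary math may observe or mutate shared state in place;
// calls are only removed when the callee is a known constexpr-like function.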
static bool IsEliminate(ValueNode *v) {
  auto op = Opcode(v->GetOpcode());
  if (op.MayDelete()) {
    return true;
  }
  if (op.IsBinaryMath()) {
    // in-place binary ops may mutate list-like operands, so only eliminate known value types
    AObject::Type t = v->input(0)->GetVobj()->GetType();
    return t != AObject::kTypeAnyValue && t != AObject::kTypeList && t != AObject::kTypeCell &&
           t != AObject::kTypeNNCellList;
  }
  if (op.IsCall()) {
    py::object callable = v->input(0)->GetVobj()->GetPyObject();
    if (callable.ptr() == nullptr) {
      return false;
    }
    return CheckJitConstexpr(callable) || CheckMSConstexpr(callable);
  }
  if (op == GET_ITER) {
    return v->input(0)->GetVobj()->GetType() != AObject::kTypeAnyValue;
  }
  return false;
}

void BytecodeInliner::EraseDeadLocal(const std::vector<ValueNode *> &alive_nodes) {
  std::set<ValueNode *> alive;
  for (auto i : alive_nodes) {
    alive.insert(i);
  }

  // iteratively erase dead locals until a fixed point is reached
  std::set<ValueNode *> used;
  do {
    used = alive;
    for (auto i : traced_nodes_) {
      for (auto j : i->getInputs()) {
        used.insert(j);
      }
    }
    auto iter = std::remove_if(traced_nodes_.begin(), traced_nodes_.end(), [&used](ValueNode *i) {
      // the node is not used by any alive value and has no observable effect
      return used.find(i) == used.end() && IsEliminate(i);
    });
    if (iter == traced_nodes_.end()) {
      break;
    }
    traced_nodes_.erase(iter, traced_nodes_.end());
  } while (true);
}

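// Closure-elimination sketch (illustrative): a round trip like
//   LOAD_DEREF 0; STORE_DEREF 0        (same slot, same graph)
// stores back the value it just loaded and is dropped; STORE_DEREF and
// DELETE_DEREF whose cell is never captured by a surviving MAKE_FUNCTION
// closure are likewise unobservable and removed.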
void BytecodeInliner::EliminateClosureSideEffect() {
  PyCodeObject *co = graph_->GetCodeObj();
  int ncells = PyTuple_GET_SIZE(co->co_cellvars);
  int nfrees = PyTuple_GET_SIZE(co->co_freevars);
  if (ncells + nfrees == 0) {
    return;
  }
  std::set<InstrNode *> alive_closure_access;

  if (last_frame_ != nullptr) {
    auto iter = std::find_if(cfg_->instr_pool().begin(), cfg_->instr_pool().end(), [](const std::unique_ptr<Instr> &i) {
      return i->op() == LOAD_DEREF || (i->op() == MAKE_FUNCTION && ((signed)i->arg() & 0x08));
    });
    if (iter != cfg_->instr_pool().end()) {
      return;
    }
  }

  for (auto i : traced_nodes_) {
    if (i->GetOpcode() == MAKE_FUNCTION && (i->GetOparg() & 0x08)) {
      ValueNode *tuple = *(i->getInputs().end() - 3);
      for (auto c : tuple->getInputs()) {
        const auto &nodes = static_cast<CellVarNode *>(c)->GetCellOper();
        alive_closure_access.insert(nodes.begin(), nodes.end());
      }
    }
  }

  auto iter = std::remove_if(traced_nodes_.begin(), traced_nodes_.end(), [&alive_closure_access](ValueNode *i) {
    int op = i->GetOpcode();
    return (op == STORE_DEREF || op == DELETE_DEREF) && alive_closure_access.find(i) == alive_closure_access.end();
  });
  traced_nodes_.erase(iter, traced_nodes_.end());

  for (auto item = traced_nodes_.begin(); item != traced_nodes_.end();) {
    if ((*item)->GetOpcode() == STORE_DEREF) {
      if ((*item)->getInputs()[0]->GetOpcode() == LOAD_DEREF &&
          (*item)->getInputs()[0]->GetOparg() == (*item)->GetOparg() &&
          (*item)->getInputs()[0]->GetGraph() == (*item)->GetGraph()) {
        item = traced_nodes_.erase(item);
      } else {
        ++item;
      }
    } else {
      ++item;
    }
  }
}

}  // namespace pijit
}  // namespace mindspore