/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/jit/pi/graph_capture/bytecode_inliner.h"
#include <set>
#include <utility>
#include <algorithm>
#include <string>
#include "pipeline/jit/pi/graph_capture/graph.h"
#include "pipeline/jit/pi/graph_capture/side_effect.h"
#include "pipeline/jit/pi/graph_guard/cache.h"
#include "pipeline/jit/pi/pi_jit_config.h"

namespace mindspore {
namespace pijit {

extern std::string PrintInstr(const std::vector<std::unique_ptr<Instr>> &list);
extern bool CheckMSConstexpr(const py::object &func);
extern bool CheckJitConstexpr(const py::object &func);
void BytecodeInliner::Run() {
  if (graph_->IsBreakAtLoop() && !graph_->RestoreLoopStatus()) {
    return;
  }

  inline_partial_ = graph_->Config().GetBoolConfig(GraphJitConfig::kFeatureBreakAtInlinedFunction);
  cfg_ = std::make_unique<CFG>(nullptr);

  // collect traced nodes; after a graph break, inline the second half of the bytecode
  if (graph_->GetStopTraceBci() != -1) {
    last_frame_ = std::make_unique<FrameStates>();
    ProcessGraph(graph_, 0);
  } else {
    CollectTracedNodes(graph_);
  }

  if (traced_nodes_.empty()) {
    return;
  }

  if (inline_partial_ && graph_->GetStopTraceBci() != -1) {
    InitCFG();
  }

  Rebuild();

  if (graph_->Config().GetBoolConfig(GraphJitConfig::kPrintBB)) {
    GRAPH_JIT_LOG_F("%s\n\n", cfg_->DumpBBs().c_str());
  }

  ResetGraphStat();

  if (graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
    std::stringstream s;
    s << "graph new break bci is " << new_break_bci_ << " after inline";
    if (reconstructed_value_ != nullptr) {
      const auto &instr = graph_->GetCFG()->instr_pool()[new_break_bci_];
      s << ", node is reconstructed by inliner [" << reconstructed_value_->ToString() << "] -> [" << instr->ToString()
        << "]";
    }
    GRAPH_JIT_LOG_F("%s", s.str().c_str());
  }
}

void BytecodeInliner::ResetGraphStat() {
  graph_->GetFrames().swap(new_frames_);
  graph_->GetCFG().swap(cfg_);
  graph_->GetTracedNodes().swap(traced_nodes_);
  if (graph_->GetStopTraceBci() != -1) {
    graph_->StopTraceAt(new_break_bci_, graph_->GetStopTraceReason());
  }
}

void BytecodeInliner::ResetCFG(CodeGenerator *cg) {
  std::vector<std::unique_ptr<Instr>> list = cg->MoveCode();
  std::move(cfg_->instr_pool().begin(), cfg_->instr_pool().end(), std::back_inserter(list));
  cfg_->instr_pool().swap(list);
  cfg_->bb_pool().clear();
  cfg_->liveness().reset();
  cfg_->SetLocalCount(cg->GetCode().co_nlocals);
  InitCFG();
}

void BytecodeInliner::Rebuild(CodeGenerator *cg) {
  FrameStates new_f = graph_->GetFrame(0);
  int new_bci = 0;
  cg->SetGlobals(extra_globals_);
  cg->Init();
  cg->MarkAlive();
  new_frames_[0] = std::make_unique<FrameStates>(new_f);
  for (size_t index = 0; index < traced_nodes_.size(); ++index) {
    ValueNode *node = traced_nodes_[index];
    if (IsNonLocalValue(node)) {
      node->set_bci(new_bci);
      continue;
    }
    cg->BuildOper(node, index);

    // reset bci: the node's value is produced by the instruction before a
    // trailing POP_TOP/STORE_FAST, so step back one extra slot in that case
    int last_op = cg->GetCode().co_code.back()->op();
    new_bci = SizeToInt(cg->GetCode().co_code.size()) - 1 - (last_op == POP_TOP || last_op == STORE_FAST);
    node->set_bci(new_bci);

    // reset frame status
    new_f.GetStacks() = node->getInputs();
    new_frames_[new_bci] = std::make_unique<FrameStates>(new_f);
    if (last_op == STORE_FAST) {
      int arg = cg->GetCode().co_code.back()->arg();
      new_f.ResizeLocal(std::max(new_f.GetLocals().size(), static_cast<size_t>(arg + 1)));
      new_f.SetLocal(arg, node);
    }
  }
  int nlocals = std::max(cg->GetLocalsMap().size(), new_f.GetLocals().size());
  nlocals = std::max(nlocals, cg->GetCode().co_nlocals);
  cg->SetLocalsCount(nlocals);
}

void BytecodeInliner::Rebuild() {
  NodeSet ns = {
    graph_->GetFrame(0).GetLocals(),
    std::vector<ValueNode *>(),
    traced_nodes_,
  };
  CodeGenerator cg(&ns);

  std::vector<int> alive_locals;
  if (last_frame_ != nullptr) {
    BitMap alive = inline_partial_ ? cfg_->GetLiveness()->CollectAlive(0)
                                   : graph_->GetCFG()->GetLiveness()->CollectAlive(graph_->GetStopTraceBci());
    ns.outputs = Graph::CollectAliveNode(*last_frame_, &alive, &alive_locals);
  } else {
    ns.outputs.push_back(graph_->GetRetVal());
  }

  if (graph_->Config().GetBoolConfig(GraphJitConfig::kEnableEliminateUnusedOperation)) {
    // erase locals that died between inlining and the code rebuild
    EraseDeadLocal(ns.outputs);
    EliminateClosureSideEffect();
  }
  Rebuild(&cg);

  if (last_frame_ != nullptr) {
    std::for_each(ns.outputs.begin(), ns.outputs.end(), [&cg](ValueNode *i) { cg.LoadValue(i); });
    std::for_each(alive_locals.rbegin(), alive_locals.rend(), [&cg](int i) { cg.NewInstr(STORE_FAST, i); });
    cg.SetLocalsCount(last_frame_->GetLocals().size());
  } else {
    cg.GenReturn();
  }

  if (last_frame_ != nullptr) {
    MS_EXCEPTION_IF_CHECK_FAIL(new_frames_.find(cg.GetCode().co_code.size()) == new_frames_.end(),
                               "duplicate frame status");
    new_break_bci_ = SizeToInt(cg.GetCode().co_code.size());
    new_frames_[new_break_bci_] = std::move(last_frame_);
  }

  ResetCFG(&cg);
}

void BytecodeInliner::CollectTracedNodes(Graph *graph) {
  for (ValueNode *n : graph->GetTracedNodes()) {
    if (n->GetType() != AbstractNode::Call) {
      traced_nodes_.push_back(n);
      continue;
    }
    CallNode *call_node = static_cast<CallNode *>(n);
    if (call_node->GetSubGraph() == nullptr || call_node->GetInlineReason() != InlineReason::kInline) {
      traced_nodes_.push_back(n);
      continue;
    }
    std::copy(call_node->GetParams().begin(), call_node->GetParams().end(), std::back_inserter(traced_nodes_));
    CollectTracedNodes(call_node->GetSubGraph());
  }
}

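// NOTE (editor's illustration, not part of the original source):
// CollectTracedNodes above flattens a fully inlined call in place, recursively, e.g.
//   [.., CallNode inner(a), ..] -> [.., param nodes of inner, inner's traced nodes, ..]
// so the rebuilt top-level code no longer contains the call itself.
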
void BytecodeInliner::ProcessGraph(Graph *graph, int local_off) {
  int break_bci = graph->GetStopTraceBci();
  if (break_bci == -1) {
    return;
  }

  // build last frame
  const FrameStates &f = graph->GetFrame(break_bci);
  last_frame_->GetLocals().insert(last_frame_->GetLocals().end(), f.GetLocals().begin(), f.GetLocals().end());
  last_frame_->GetStacks().insert(last_frame_->GetStacks().end(), f.GetStacks().begin(), f.GetStacks().end());

  CollectTracedNodes(graph);

  const auto &nodes = graph->GetTracedNodes();
  if (nodes.size() > 0 && nodes.back()->bci() == break_bci) {
    // break at traced value
    Reconstruct(nodes.back(), local_off + f.GetLocals().size());
    break_bci++;
  } else {
    // break at unsupported bytecode
    MS_EXCEPTION_IF_CHECK_FAIL(nodes.empty() || break_bci > nodes.back()->bci(), "check break bci");
    new_break_bci_ = 0;
  }

  std::vector<std::unique_ptr<Instr>> list = CodeGenerator::CopyInstr(graph->GetCFG()->instr_pool(), break_bci);
  if (inline_partial_) {
    FixInstr(graph, local_off, &list);
  }
  std::move(list.begin(), list.end(), std::back_inserter(cfg_->instr_pool()));
  cfg_->SetLocalCount(std::max(static_cast<size_t>(cfg_->GetLocalCount()), local_off + f.GetLocals().size()));
}

static bool EliminateSideEffect(Graph *top_graph, Graph *sub_graph) {
  /**
   * kFeatureBreakAtInlinedFunction
   * eliminate the side effects of untracked bytecode after a graph break:
   * 1. eliminate MAKE_FUNCTION if it accesses globals and its globals dict is not the same as the top function's
   * 2. eliminate STORE_GLOBAL and DELETE_GLOBAL if the globals dict is not the same as the top function's
   * 3. eliminate closure access operations if the function has cell or free variables
   */
  if (top_graph != sub_graph && sub_graph->GetFrame(0).GetClosures().size() != 0) {
    PyObject *frees = sub_graph->GetCodeObj()->co_freevars;
    if (PyTuple_GET_SIZE(frees) == 1 && std::string("__class__") == PyUnicode_AsUTF8(PyTuple_GET_ITEM(frees, 0))) {
      /**
       * BUGS: does not check super calls or free variable access after the break point
       **/
      return true;
    }
    return false;
  }
  /**
   * BUGS: does not check MAKE_FUNCTION with global access after the break point
   **/
  return true;
}

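// NOTE (editor's illustration, not part of the original source): the lone
// "__class__" free variable checked above is the cell CPython creates
// implicitly for zero-argument super(), e.g.
//   class Net(nn.Cell):
//       def construct(self, x):
//           return super().construct(x)  # co_freevars == ("__class__",)
// so such frames are treated as having no user-visible closure side effects.
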
static bool CanInlinePartial(Graph *top_graph, Graph *sub_graph) {
  if (sub_graph == nullptr) {
    return false;
  }
  if (sub_graph->IsBreakAtLoop()) {
    return false;
  }
  if (top_graph->GetGlobals() == sub_graph->GetGlobals()) {
    return true;
  }
  if (!EliminateSideEffect(top_graph, sub_graph)) {
    return false;
  }
  for (auto i : sub_graph->GetTracedNodes()) {
    int op = i->GetOpcode();
    if (op == MAKE_FUNCTION) {
      return false;
    }
  }
  return true;
}

void BytecodeInliner::Reconstruct(ValueNode *node, int local_off) {
  static const std::set<int> not_value_oper = {
    STORE_DEREF,  DELETE_DEREF,  STORE_GLOBAL, DELETE_GLOBAL, STORE_ATTR, DELETE_ATTR,
    STORE_SUBSCR, DELETE_SUBSCR, IMPORT_STAR,  RAISE_VARARGS, RERAISE,
  };
  traced_nodes_.pop_back();

  Graph *graph = node->GetGraph();
  const auto &instr = graph->GetCFG()->instr_pool()[node->bci()];

  int stack_effect = PyCompile_OpcodeStackEffect(instr->op(), instr->arg());
  bool is_value = not_value_oper.find(instr->op()) == not_value_oper.end();
  MS_EXCEPTION_IF_CHECK_FAIL(stack_effect <= 0 && stack_effect != PY_INVALID_STACK_EFFECT,
                             "check break bci, too many values produced");
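  // NOTE (worked example, editor's addition): the Popn below removes the
  // instruction's inputs from the stack: (#pops - #pushes) plus one more slot
  // when the op pushes a value. E.g. BINARY_ADD: stack_effect == -1, is_value
  // -> Popn(2) pops both operands; STORE_SUBSCR: stack_effect == -3, !is_value
  // -> Popn(3) pops key, container, and value.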
  last_frame_->Popn(-stack_effect + is_value);

  if (inline_partial_ && node->GetType() == AbstractNode::Call) {
    CallNode *call_node = static_cast<CallNode *>(node);
    if (CanInlinePartial(this->graph_, call_node->GetSubGraph())) {
      std::copy(call_node->GetParams().begin(), call_node->GetParams().end(), std::back_inserter(traced_nodes_));
      ProcessGraph(call_node->GetSubGraph(), local_off);
      return;
    }
  }
  for (auto i : node->getInputs()) {
    last_frame_->Push(i);
  }
  MS_EXCEPTION_IF_CHECK_FAIL(cfg_->instr_pool().empty(), "just call once if graph break at traced value");
  cfg_->NewInstrNode(*instr);
  reconstructed_value_ = node;

  /**
   * if the node does not match the instruction opcode, check its side effect
   */
}

void BytecodeInliner::FixInstr(Graph *graph, int local_off, std::vector<std::unique_ptr<Instr>> *list) {
  if (list->empty()) {
    return;
  }
  for (const auto &i : *list) {
    if (Opcode(i->op()).IsLocalAccess()) {
      i->set_arg(i->arg() + local_off);
      continue;
    }
    if (this->graph_ != graph && i->op() == RETURN_VALUE) {
      i->set_op(JUMP_FORWARD);
      i->set_extra_jump(list->back().get());
      continue;
    }
    if (graph->GetGlobals().ptr() == this->graph_->GetGlobals().ptr()) {
      continue;
    }
    if (i->op() == LOAD_GLOBAL) {
      PyObject *value = PyObject_GetItem(graph->GetGlobals().ptr(), py::str(i->name()).ptr());
      py::object _value_handle = py::reinterpret_steal<py::object>(value);
      if (value == nullptr) {
        PyErr_Clear();
        continue;
      }
      auto tr = std::make_shared<RootTrace>(value, TraceType::Global, -1, i->name(), graph->GetModuleName());
      graph->GetGuard()->GetGuard()->GuardOn(tr, GuardLevel::GId);
      std::string key = i->name();
      MapAdd(extra_globals_, key, _value_handle, &key);
      i->set_name(key);
      continue;
    }
  }

  if (list->back()->op() != JUMP_FORWARD || list->back()->extra_jump() != list->back().get()) {
    return;
  }
  list->back()->set_extra_jump(nullptr);
  if (graph != this->graph_) {
    list->back()->set_op(NOP);
  } else {
    list->back()->set_op(RETURN_VALUE);
  }
}

/**
 * unify the implementations of CFG initialization
 */
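// NOTE (editor's illustration, not part of the original source): blocks are
// split at jump targets and after every jump or non-fall-through instruction.
// E.g. for
//   0 LOAD_FAST  1 POP_JUMP_IF_FALSE -> 4  2 LOAD_CONST  3 RETURN_VALUE  4 LOAD_CONST
// the block leaders are bci 0, 2 (fall-through of the jump), and 4 (jump target).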
void BytecodeInliner::InitCFG() {
  const auto &list = cfg_->instr_pool();

  // reset bci, erase unused jump
  CodeGenerator::EraseUnusedInstr(&cfg_->instr_pool());

  // mark block labels; std::map keeps them ordered by bci
  std::map<int, Block *> blocks;
  blocks.insert({0, cfg_->NewBBAppend()});
  for (const auto &i : list) {
    size_t bci = list.size();
    if (Opcode(i->op()).IsNotFall()) {
      bci = (size_t)i->bci() + 1;
    }
    if (i->extra_jump() != nullptr) {
      bci = (size_t)i->bci() + 1;
      if (blocks.find(i->extra_jump()->bci()) == blocks.end()) {
        blocks.insert({i->extra_jump()->bci(), cfg_->NewBBAppend()});
      }
    }
    if (bci != list.size() && blocks.find(bci) == blocks.end()) {
      blocks.insert({bci, cfg_->NewBBAppend()});
    }
  }

  // link blocks, set range
  for (auto iter = blocks.begin(); iter != blocks.end();) {
    Block *cur = iter->second;
    int head = iter->first;
    int back;
    iter++;
    if (iter != blocks.end()) {
      back = iter->first;
    } else {
      back = SizeToInt(list.size());
    }
    cur->set_begin_ci(head);
    cur->set_end_ci(back);
    const auto &instr = list[back - 1];
    if (instr->extra_jump()) {
      cur->SetJumpBB(blocks[instr->extra_jump()->bci()]);
    }
    if (!Opcode(instr->op()).IsNotFall()) {
      cur->SetFallBB(iter->second);
    }
  }
  cfg_->MarkDeadBB();
  cfg_->GetLiveness();
}

static bool IsEliminate(ValueNode *v) {
  auto op = Opcode(v->GetOpcode());
  if (op.MayDelete()) {
    return true;
  }
  if (op.IsBinaryMath()) {
    // inplace binary ops may mutate the left operand; only eliminate when its type is known and immutable
    AObject::Type t = v->input(0)->GetVobj()->GetType();
    return t != AObject::kTypeAnyValue && t != AObject::kTypeList && t != AObject::kTypeCell &&
           t != AObject::kTypeNNCellList;
  }
  if (op.IsCall()) {
    py::object callable = v->input(0)->GetVobj()->GetPyObject();
    if (callable.ptr() == nullptr) {
      return false;
    }
    return CheckJitConstexpr(callable) || CheckMSConstexpr(callable);
  }
  if (op == GET_ITER) {
    return v->input(0)->GetVobj()->GetType() != AObject::kTypeAnyValue;
  }
  return false;
}

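// NOTE (editor's illustration, not part of the original source): by the rules
// in IsEliminate above, an unused `t = a + b` on known scalar types may be
// dropped, while an unused `lst += [1]` (an inplace binary on a list) must be
// kept, since eliminating it would lose the mutation of `lst`.
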
void BytecodeInliner::EraseDeadLocal(const std::vector<ValueNode *> &alive_nodes) {
  std::set<ValueNode *> alive;
  for (auto i : alive_nodes) {
    alive.insert(i);
  }

  // erase dead locals, iterating to a fixed point
  std::set<ValueNode *> used;
  do {
    used = alive;
    for (auto i : traced_nodes_) {
      for (auto j : i->getInputs()) {
        used.insert(j);
      }
    }
    auto iter = std::remove_if(traced_nodes_.begin(), traced_nodes_.end(), [&used](ValueNode *i) {
      // drop nodes that are no longer used and are safe to eliminate
      return used.find(i) == used.end() && IsEliminate(i);
    });
    if (iter == traced_nodes_.end()) {
      break;
    }
    traced_nodes_.erase(iter, traced_nodes_.end());
  } while (true);
}

void BytecodeInliner::EliminateClosureSideEffect() {
  PyCodeObject *co = graph_->GetCodeObj();
  int ncells = PyTuple_GET_SIZE(co->co_cellvars);
  int nfrees = PyTuple_GET_SIZE(co->co_freevars);
  if (ncells + nfrees == 0) {
    return;
  }
  std::set<InstrNode *> alive_closure_access;

  if (last_frame_ != nullptr) {
    auto iter = std::find_if(cfg_->instr_pool().begin(), cfg_->instr_pool().end(), [](const std::unique_ptr<Instr> &i) {
      return i->op() == LOAD_DEREF || (i->op() == MAKE_FUNCTION && ((signed)i->arg() & 0x08));
    });
    if (iter != cfg_->instr_pool().end()) {
      return;
    }
  }

  for (auto i : traced_nodes_) {
    if (i->GetOpcode() == MAKE_FUNCTION && (i->GetOparg() & 0x08)) {
      ValueNode *tuple = *(i->getInputs().end() - 3);
      for (auto c : tuple->getInputs()) {
        const auto &nodes = static_cast<CellVarNode *>(c)->GetCellOper();
        alive_closure_access.insert(nodes.begin(), nodes.end());
      }
    }
  }

  auto iter = std::remove_if(traced_nodes_.begin(), traced_nodes_.end(), [&alive_closure_access](ValueNode *i) {
    int op = i->GetOpcode();
    return (op == STORE_DEREF || op == DELETE_DEREF) && alive_closure_access.find(i) == alive_closure_access.end();
  });
  traced_nodes_.erase(iter, traced_nodes_.end());

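  // NOTE (editor's addition): the loop below erases no-op self-stores, i.e. a
  // STORE_DEREF whose input is a LOAD_DEREF of the same cell slot in the same
  // graph (the bytecode equivalent of `cell_var = cell_var`).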
  for (auto item = traced_nodes_.begin(); item != traced_nodes_.end();) {
    if ((*item)->GetOpcode() == STORE_DEREF) {
      if ((*item)->getInputs()[0]->GetOpcode() == LOAD_DEREF &&
          (*item)->getInputs()[0]->GetOparg() == (*item)->GetOparg() &&
          (*item)->getInputs()[0]->GetGraph() == (*item)->GetGraph()) {
        item = traced_nodes_.erase(item);
      } else {
        ++item;
      }
    } else {
      ++item;
    }
  }
}

}  // namespace pijit
}  // namespace mindspore