• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/common.h"
17 #include <algorithm>
18 #include <iomanip>
19 #include <iterator>
20 #include <regex>
21 #include <list>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <vector>
26 #include "pybind11/pybind11.h"
27 #include "include/common/utils/convert_utils_py.h"
28 #include "pipeline/jit/pi/auto_grad/function_node.h"
29 #include "pipeline/jit/pi/external.h"
30 #include "pipeline/jit/pi/graph_capture/graph_build.h"
31 #include "pipeline/jit/pi/graph_capture/graph_analyzer.h"
32 #include "pipeline/jit/pi/graph_compiler/abstract_type_deducer.h"
33 #include "pipeline/jit/pi/graph_compiler/compiler.h"
34 #include "pipeline/jit/pi/graph_compiler/cg/byte_code_generator.h"
35 #include "pipeline/jit/pi/graph_compiler/inliner/func_inliner.h"
36 #include "pipeline/jit/pi/graph_compiler/parser/byte_code_parser.h"
37 #include "pipeline/jit/pi/graph_compiler/utils.h"
38 #include "pipeline/jit/pi/utils/utils.h"
39 #include "pipeline/jit/pi/graph_guard/guard.h"
40 #include "pipeline/jit/pi/graph_guard/strategy.h"
41 #include "pipeline/jit/pi/graph_guard/shape_ctx.h"
42 #include "pipeline/jit/ps/pipeline.h"
43 #include "pipeline/pynative/pynative_utils.h"
44 #include "runtime/pynative/op_executor.h"
45 #include "include/common/debug/anf_ir_dump.h"
46 #include "pipeline/jit/pi/graph_capture/code_generator.h"
47 #include "pipeline/jit/pi/graph_capture/bytecode_inliner.h"
48 #include "utils/convert_utils_base.h"
49 
50 namespace mindspore {
51 namespace pijit {
52 static Py_tss_t *tss = NULL;
53 
54 void AddConfigToGuard(const GraphJitConfig &c, OptGuardPtr guard);
55 void AddGuardForParam(const PyFrameObject *f, OptGuardPtr guard, bool detach);
56 void AddGuardForGlobals(const PyFrameObject *f, OptGuardPtr guard, bool detach);
57 static void AddGradFlagForParam(bool grad_flag, OptGuardPtr guard, bool detach);
58 static void CollectTraceBack(JitCompileResults *c, PyCodeObject *code, bool is_graph_mode);
59 
60 class ByteCodeRunStatistic {
61  public:
~ByteCodeRunStatistic()62   ~ByteCodeRunStatistic() {
63     if (py_.empty() && graph_.empty()) {
64       return;
65     }
66     std::cout << ToString() << std::endl;
67   }
68 
Count(PyObject * code,bool graph_preferred)69   void Count(PyObject *code, bool graph_preferred) {
70     if (graph_preferred) {
71       graph_[PyBytes_GET_SIZE(code)]++;
72     } else {
73       py_[PyBytes_GET_SIZE(code)]++;
74     }
75   }
76 
ToString()77   std::string ToString() {
78     const auto SumFunc = [](size_t sum, const std::pair<uint64_t, size_t> &i) { return sum + (i.first * i.second); };
79     size_t sum_py = std::accumulate(py_.begin(), py_.end(), 0, SumFunc);
80     size_t sum_graph = std::accumulate(graph_.begin(), graph_.end(), 0, SumFunc);
81     double ratio = static_cast<double>(sum_graph) / (sum_graph + sum_py);
82     return "execute code ratio (graph / (graph + python)): " + std::to_string(ratio);
83   }
84 
GetInstance()85   static ByteCodeRunStatistic *GetInstance() {
86     static ByteCodeRunStatistic instance;
87     return &instance;
88   }
89 
90  private:
91   ByteCodeRunStatistic() = default;
92   std::map<uint64_t, size_t> py_;
93   std::map<uint64_t, size_t> graph_;
94 };
95 
96 class StaticAnalysisExceptionCleaner {
97  public:
98   StaticAnalysisExceptionCleaner() = default;
~StaticAnalysisExceptionCleaner()99   ~StaticAnalysisExceptionCleaner() { StaticAnalysisException::Instance().ClearException(); }
100 };
101 
102 class RunEnvironment {
103  public:
104   RunEnvironment() = default;
105 
fetchAndSetRunEnv(const JitCompileResults * jcr)106   void fetchAndSetRunEnv(const JitCompileResults *jcr) {
107     auto ms_context = MsContext::GetInstance();
108     run_mode_ = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
109     jit_level_ = ms_context->GetJitLevel();
110     task_sink_ = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
111 
112     auto jit_level = jcr->conf->getJitLevel();
113     auto grad_flag = pynative::PyNativeExecutor::GetInstance()->grad_flag();
114     auto run_mode = jit_level == "O2" && !grad_flag ? kGraphMode : kPynativeMode;
115     auto task_sink = jit_level == "O2" && !grad_flag;
116     ms_context->set_param(MS_CTX_EXECUTION_MODE, run_mode);
117     ms_context->set_param(MS_CTX_JIT_LEVEL, jit_level);
118     ms_context->SetJitLevel(jit_level);
119     ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink);
120     ms_context->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, task_sink);
121   }
122 
resumePreviousRunEnv()123   void resumePreviousRunEnv() {
124     auto ms_context = MsContext::GetInstance();
125     ms_context->set_param(MS_CTX_EXECUTION_MODE, run_mode_);
126     ms_context->set_param(MS_CTX_JIT_LEVEL, jit_level_);
127     ms_context->SetJitLevel(jit_level_);
128     ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink_);
129     ms_context->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, task_sink_);
130   }
131 
132  private:
133   int run_mode_ = kPynativeMode;
134   std::string jit_level_;
135   bool task_sink_ = false;
136 };
137 
PrintGuardPerf()138 static void PrintGuardPerf() {
139   std::map<std::string, std::pair<size_t, size_t>> guard_info;
140   std::map<std::string, std::pair<size_t, size_t>> guard_freq_info;
141   std::map<std::string, std::pair<size_t, size_t>> trace_info;
142   std::map<std::string, std::pair<size_t, std::vector<size_t>>> item_info;
143   OptGuardPerf::GetGuardPerf()->GetGuardPerfInfo(&guard_info, &item_info, &trace_info, &guard_freq_info);
144   std::cout << "Guard performance info:" << std::endl;
145   std::cout << "guard, count, total time, success, fail" << std::endl;
146   for (const auto &item : guard_info) {
147     auto iter = guard_freq_info.find(item.first);
148     if (iter != guard_freq_info.end()) {
149       std::cout << "guard:" << item.first << ", " << item.second.first << ", " << item.second.second << ","
150                 << iter->second.first << "," << iter->second.second << std::endl;
151     } else {
152       std::cout << "guard:" << item.first << ", " << item.second.first << ", " << item.second.second << std::endl;
153     }
154   }
155   std::cout << "trace, count, total time" << std::endl;
156   for (const auto &item : trace_info) {
157     std::cout << "trace:" << item.first << ", " << item.second.first << ", " << item.second.second << std::endl;
158   }
159   std::cout << "item, count, [stage time]" << std::endl;
160   for (const auto &item : item_info) {
161     std::cout << "item:" << item.first << "," << item.second.first << ", [";
162     for (auto stage : item.second.second) {
163       std::cout << stage << ",";
164     }
165     std::cout << "]" << std::endl;
166   }
167 }
168 
169 // jit compiler initialize
ensureInitialize()170 static void ensureInitialize() {
171   static bool init = false;
172   if (init) {
173     return;
174   }
175   init = true;
176   if (tss == NULL) {
177     tss = PyThread_tss_alloc();
178     PyThread_tss_create(tss);
179   }
180   std::atexit([]() {
181     if (kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogGuardPerf)) {
182       PrintGuardPerf();
183     }
184   });
185 }
186 
PushInlineInfo(InlineInfo info)187 void Tracebackes::PushInlineInfo(InlineInfo info) {
188   const auto &it = inline_infos_.find(info.root_name_);
189   if (it != inline_infos_.cend()) {
190     it->second.push_back(info);
191   } else {
192     std::list<InlineInfo> inlines;
193     inlines.push_back(info);
194     inline_infos_.emplace(info.root_name_, inlines);
195   }
196 }
197 
// Write `str` left-aligned into a field of `distance` characters, followed by ": ".
// The left adjustment stays set on the stream; the field width only applies to `str`.
static void PrintLabel(std::stringstream &os, const std::string &str, int distance = 30) {
  os.setf(std::ios_base::left, std::ios_base::adjustfield);
  os.width(distance);
  os << str;
  os << ": ";
}
201 
Dump(bool is_all) const202 std::string Tracebackes::Dump(bool is_all) const {
203   constexpr auto width = 10;
204 
205   std::stringstream os;
206   std::string cur_name = tbs_.empty() ? "" : tbs_.back().func_name_;
207   if (is_all) {
208     os << "*** Dump Traceback on [" << raw_func_info_name_ << "] ***\n";
209   } else {
210     os << "*** Dump ByteCode After Traceback on [" << cur_name << "] ***\n";
211   }
212   if (tbs_.empty()) {
213     return os.str();
214   }
215   std::list<Tracebacke> candidates;
216   if (is_all) {
217     candidates = tbs_;
218   } else {
219     // last one traceback
220     candidates.emplace_back(tbs_.back());
221   }
222   // dump traceback list head
223   int name_length = FindMaxNameLength(candidates);
224   os << std::left << std::setw(name_length) << "func_name:  -->  " << std::left << std::setw(name_length)
225      << "changed_func:" << std::left << std::setw(width) << "run_mode:" << std::left << std::setw(kThree * width)
226      << "stop_trace:" << std::left << std::setw(width) << "code_size:" << std::endl;
227   os << "--------------------------------------------------------------------------------------\n";
228   // dump traceback list content
229   for (const auto &tb : candidates) {
230     os << std::left << std::setw(name_length) << tb.func_name_ << "  -->  ";
231     os << std::left << std::setw(name_length) << tb.changed_func_;
232     if (tb.is_graph_mode_) {
233       os << std::left << std::setw(width) << "[GRAPH]";
234     } else {
235       os << std::left << std::setw(width) << "PYNATIVE";
236     }
237     // dump stop trace reason
238     auto it_trace = stop_trace_res_.find(tb.func_name_);
239     if (it_trace != stop_trace_res_.cend()) {
240       os << std::left << std::setw(kThree * width) << GetStopTraceReasonDesc(it_trace->second);
241     } else {
242       os << std::left << std::setw(kThree * width) << "unknown";
243     }
244     os << std::left << std::setw(width) << tb.code_size_ << " =====>\n";
245     // dump inline info
246     DumpInlineInfo(os, tb.func_name_);
247   }
248   os << "\n\n";
249   if (is_all) {
250     os << DumpSummary();
251   }
252   return os.str();
253 }
254 
DumpInlineInfo(std::stringstream & os,const std::string & func_name) const255 void Tracebackes::DumpInlineInfo(std::stringstream &os, const std::string &func_name) const {
256   const auto &it = inline_infos_.find(func_name);
257   if (it == inline_infos_.cend()) {
258     return;
259   }
260   for (const auto &info : it->second) {
261     std::string space((info.depth + 1) * kTwo, ' ');
262     os << space << "| inline_info:" << GetInlineReasonDesc(info.res) << " line:" << info.line;
263     if (!info.inline_name_.empty()) {
264       os << " func_name:" << info.inline_name_;
265     }
266     if (info.res == InlineReason::kInline || info.res == InlineReason::kInlinePartial) {
267       os << " code_size:" << info.code_size_;
268     }
269     os << "\n";
270   }
271 }
272 
DumpSummary() const273 std::string Tracebackes::DumpSummary() const {
274   std::stringstream os;
275   if (tbs_.empty()) {
276     return os.str();
277   }
278   os << "*** Dump Summary on [" << raw_func_info_name_ << "] ***\n";
279   PrintLabel(os, "traceback_num");
280   os << tbs_.size() << "\n";
281 
282   std::array<int, kStopTrace_Reason_Count> stop_trace_reason_array{0};
283   std::array<int, kInline_Reason_Count> inline_reason_array{0};
284   int graph_mode_num = 0;
285   int raw_code_size = raw_code_size_;
286   int pynative_code_size = 0;
287   int graph_mode_code_size = 0;
288   for (const auto &tb : tbs_) {
289     if (tb.is_graph_mode_) {
290       graph_mode_num++;
291       graph_mode_code_size += tb.code_size_;
292     } else {
293       pynative_code_size += tb.code_size_;
294     }
295     auto it_trace = stop_trace_res_.find(tb.func_name_);
296     if (it_trace != stop_trace_res_.cend()) {
297       // count stop trace reason
298       stop_trace_reason_array[it_trace->second]++;
299     }
300     const auto &it_inline = inline_infos_.find(tb.func_name_);
301     if (it_inline == inline_infos_.cend()) {
302       continue;
303     }
304     for (const auto &info : it_inline->second) {
305       // count inline reason
306       inline_reason_array[info.res]++;
307       if (info.res == InlineReason::kInline || info.res == InlineReason::kInlinePartial) {
308         raw_code_size += info.code_size_;
309       }
310     }
311   }
312   PrintLabel(os, "graph_mode_num");
313   os << graph_mode_num << "\n";
314   PrintLabel(os, "raw_code_size(+ inline)");
315   os << raw_code_size << "\n";
316   PrintLabel(os, "pynative_code_size");
317   os << pynative_code_size << "\n";
318   PrintLabel(os, "graph_mode_code_size");
319   os << graph_mode_code_size << "\n";
320   os << "----------stop_trace_reason----------\n";
321   for (size_t i = 0; i < stop_trace_reason_array.size(); ++i) {
322     PrintLabel(os, GetStopTraceReasonDesc(static_cast<StopTraceReason>(i)));
323     os << stop_trace_reason_array[i] << "\n";
324   }
325   os << "----------inline_reason----------\n";
326   for (size_t i = 0; i < inline_reason_array.size(); ++i) {
327     PrintLabel(os, GetInlineReasonDesc(static_cast<InlineReason>(i)));
328     os << inline_reason_array[i] << "\n";
329   }
330   os << "\n\n";
331   return os.str();
332 }
333 
FindMaxNameLength(const std::list<Tracebacke> & tbs) const334 int Tracebackes::FindMaxNameLength(const std::list<Tracebacke> &tbs) const {
335   constexpr auto name_length = kFive * (kTwo + kFive);
336   int max_length = 15;
337   for (const auto &tb : tbs) {
338     int len1 = SizeToInt(tb.func_name_.length());
339     int len2 = SizeToInt(tb.changed_func_.length());
340     max_length = std::max(max_length, std::max(len1, len2)) + kTwo;
341   }
342   max_length = std::min(max_length, name_length);
343   return max_length;
344 }
345 
freeJitCompileResults(void * jitCompileResults)346 static void freeJitCompileResults(void *jitCompileResults) {
347   // maybe nullptr if other module use _PyEval_RequestCodeExtraIndex
348   if (jitCompileResults == nullptr) {
349     return;
350   }
351   // called after code object freed
352   JitCompileResults *c = reinterpret_cast<JitCompileResults *>(jitCompileResults);
353 
354   for (auto &oc : c->codehub->GetOptTarget(OptOption::CreateOptionByPoint(c))) {
355     PyCodeObject *co = oc->GetPythonCode();
356     MS_EXCEPTION_IF_CHECK_FAIL(co == nullptr || Py_REFCNT(co) == 1, "code handler must be only one");
357   }
358   c->code = nullptr;
359   c->codehub.reset();
360 
361   std::for_each(c->children_.begin(), c->children_.end(), [](CodeExtra *i) { i->parent_ = nullptr; });
362   if (c->parent_ != nullptr) {
363     auto &leaf = c->parent_->children_;
364     leaf.erase(std::remove_if(leaf.begin(), leaf.end(), [c](CodeExtra *i) { return i == c; }), leaf.end());
365   }
366   MS_LOG(DEBUG) << __FUNCTION__ << " " << c;
367   delete c;
368 }
369 
allocJitCompileResults()370 static JitCompileResults *allocJitCompileResults() {
371   JitCompileResults *c = new JitCompileResults();
372   c->parent_ = nullptr;
373   c->stat = JitCompileResults::NEVER_COMPILE;
374   c->tbs = std::make_shared<Tracebackes>();
375   c->codehub = std::make_shared<OptCodeHub>();
376   c->conf = std::make_shared<GraphJitConfig>();
377   c->break_count_ = 0;
378   c->signature_ = nullptr;
379   return c;
380 }
381 
getJitCompileResults(PyObject * code,bool alloc)382 JitCompileResults *getJitCompileResults(PyObject *code, bool alloc) {
383   if (PyMethod_Check(code)) {
384     code = PyMethod_GET_FUNCTION(code);
385   }
386   if (PyFunction_Check(code)) {
387     code = PyFunction_GET_CODE(code);
388   }
389   if (!PyCode_Check(code)) {
390     return nullptr;
391   }
392   ensureInitialize();
393   Py_ssize_t index = (Py_ssize_t)PyThread_tss_get(tss);
394   if (index == 0) {
395     index = _PyEval_RequestCodeExtraIndex(freeJitCompileResults);
396     if (index == -1) {
397       return nullptr;
398     }
399     // ensure index is not 0
400     PyThread_tss_set(tss, reinterpret_cast<void *>(index + 1));
401   } else {
402     index = index - 1;
403   }
404 
405   JitCompileResults *c = nullptr;
406   if (!_PyCode_GetExtra(code, index, reinterpret_cast<void **>(&c))) {
407     if (c != nullptr) {
408       return c;
409     }
410     if (!alloc) {
411       return nullptr;
412     }
413     c = allocJitCompileResults();
414     if (c == nullptr) {
415       return nullptr;
416     }
417     if (!_PyCode_SetExtra(code, index, c)) {
418       MS_LOG(DEBUG) << "allocJitCompileResults " << c << " for " << std::string(py::str(code));
419       return c;
420     }
421     freeJitCompileResults(c);
422   }
423   PyErr_Clear();
424   return nullptr;
425 }
426 
RebuildFrame(PyThreadState * tstate,PyCodeObject * co,const PyFrameObject * f)427 static PyFrameObject *RebuildFrame(PyThreadState *tstate, PyCodeObject *co, const PyFrameObject *f) {
428   int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
429   MS_ASSERT(co != nullptr && argc == co->co_argcount + co->co_kwonlyargcount);
430   MS_ASSERT((f->f_code->co_flags & CO_VARARGS) == (co->co_flags & CO_VARARGS));
431   MS_ASSERT((f->f_code->co_flags & CO_VARKEYWORDS) == (co->co_flags & CO_VARKEYWORDS));
432   argc += (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARARGS) ? 1 : 0;
433   argc += (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARKEYWORDS) ? 1 : 0;
434 
435   PyFrameObject *frame = PyFrame_New(tstate, co, f->f_globals, NULL);
436   // copy arguments
437   for (int i = 0; i < argc; i++) {
438     Py_XINCREF(f->f_localsplus[i]);
439     frame->f_localsplus[i] = f->f_localsplus[i];
440   }
441   // restore arguments from cell
442   std::vector<PyObject *> cells_content(f->f_code->co_nlocals, nullptr);
443   for (int i = 0; f->f_code->co_cell2arg != NULL && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
444     Py_ssize_t argi = f->f_code->co_cell2arg[i];
445     if (argi != CO_CELL_NOT_AN_ARG) {
446       PyObject *cell = f->f_localsplus[f->f_code->co_nlocals + i];
447       cells_content[argi] = PyCell_GET(cell);
448     }
449   }
450   // new cell
451   for (int i = 0; i < PyTuple_GET_SIZE(co->co_cellvars); ++i) {
452     PyObject *cell;
453     if (co->co_cell2arg != NULL && co->co_cell2arg[i] != CO_CELL_NOT_AN_ARG) {
454       Py_ssize_t argi = co->co_cell2arg[i];
455       MS_EXCEPTION_IF_CHECK_FAIL(cells_content[argi], "Unbound local exception");
456       cell = PyCell_New(cells_content[argi]);
457     } else {
458       cell = PyCell_New(NULL);
459     }
460     frame->f_localsplus[co->co_nlocals + i] = cell;
461   }
462 
463   // copy closure
464   for (int i = 0; i < PyTuple_GET_SIZE(co->co_freevars); ++i) {
465     int a = f->f_code->co_nlocals + PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
466     int b = co->co_nlocals + PyTuple_GET_SIZE(co->co_cellvars) + i;
467     auto o = f->f_localsplus[a];
468     Py_XINCREF(o);
469     frame->f_localsplus[b] = o;
470   }
471   return frame;
472 }
473 
GetClosure(const PyFrameObject * f)474 static PyObject *GetClosure(const PyFrameObject *f) {
475   int nfrees = PyTuple_GET_SIZE(f->f_code->co_freevars);
476   if (nfrees == 0) {
477     return nullptr;
478   }
479   PyObject *closure = PyTuple_New(nfrees);
480   int idx = f->f_code->co_nlocals + PyTuple_GET_SIZE(f->f_code->co_cellvars);
481   for (int i = 0; i < nfrees; ++i) {
482     PyObject *o = f->f_localsplus[idx + i];
483     Py_INCREF(o);
484     PyTuple_SET_ITEM(closure, i, o);
485   }
486   return closure;
487 }
488 
PrepareCallCompiledCallable(PyThreadState * tstate,const PyFrameObject * f,const JitCompileResults * c)489 static PyFrameObject *PrepareCallCompiledCallable(PyThreadState *tstate, const PyFrameObject *f,
490                                                   const JitCompileResults *c) {
491   return RebuildFrame(tstate, c->code->GetPythonCode(), f);
492 }
493 
GuardForFrame(const PyFrameObject * frame,const OptCodePtr & oc,const GraphJitConfig & conf)494 static void GuardForFrame(const PyFrameObject *frame, const OptCodePtr &oc, const GraphJitConfig &conf) {
495   const char *code_name = PyUnicode_AsUTF8(frame->f_code->co_name);
496   AddConfigToGuard(conf, oc->GetGuard());
497   AddGuardForParam(frame, oc->GetGuard(), conf.GetBoolConfig(GraphJitConfig::kGuardDetachObject));
498   AddGradFlagForParam(pynative::PyNativeExecutor::GetInstance()->grad_flag(), oc->GetGuard(),
499                       conf.GetBoolConfig(GraphJitConfig::kGuardDetachObject));
500   if (conf.GetBoolConfig(GraphJitConfig::kPrintGuard)) {
501     GRAPH_JIT_LOG_F("Guard on %s by %s!\n", code_name, oc->GetGuard()->GetDescript().c_str());
502     return;
503   }
504   if (IS_OUTPUT_ON(mindspore::kDebug)) {
505     // It tooks too much time in Guard's GetDescript function when trace depth is too large.
506     MS_LOG(DEBUG) << "Guard on " << code_name << " by " << oc->GetGuard()->GetDescript() << "!" << std::endl;
507   }
508 }
509 
ValidateCompiledResults(const JitCompileResults * c)510 static void ValidateCompiledResults(const JitCompileResults *c) {
511   if (c->stat != JitCompileResults::GRAPH_CALLABLE) {
512     return;
513   }
514   bool valid_res;
515   if (c->code->GetNativeFunc()) {
516     valid_res = true;
517   } else {
518     valid_res = c->code->GetPythonCode() != nullptr;
519   }
520   MS_EXCEPTION_IF_CHECK_FAIL(valid_res, "check compiled result");
521 }
522 
MarkBreak(Graph * g)523 static void MarkBreak(Graph *g) {
524   TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
525   int break_bci = g->GetStopTraceBci();
526   if (break_bci == -1) {
527     return;
528   }
529   PyCodeObject *code;
530   const auto &nodes = g->GetTracedNodes();
531   if (nodes.empty()) {
532     code = g->GetCodeObj();
533   } else {
534     auto iter = std::find_if(nodes.begin(), nodes.end(), [&break_bci](ValueNode *i) { return i->bci() >= break_bci; });
535     iter -= iter == nodes.end();
536     for (code = (*iter)->GetGraph()->GetCodeObj(); code == nullptr && iter != nodes.begin(); --iter) {
537       code = (*iter)->GetGraph()->GetCodeObj();
538     }
539   }
540   MS_EXCEPTION_IF_NULL(code);
541   auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(code), false);
542   if (jcr != nullptr) {
543     jcr->break_count_++;
544   }
545 }
546 
GetAllArgs(JitCompileResults * jcr)547 std::vector<py::object> GetAllArgs(JitCompileResults *jcr) {
548   auto all_args = PackArgs(jcr->origin_frame_);
549   constexpr size_t arg_index = 0;
550   constexpr size_t vargs_index = 1;
551   constexpr size_t kwargs_index = 2;
552   auto args = py::cast<py::list>(all_args[arg_index]);
553   if (all_args[vargs_index].ptr() != nullptr) {
554     PyList_Append(args.ptr(), all_args[vargs_index].ptr());  // args + vargs
555   }
556   if (all_args[kwargs_index].ptr() != nullptr) {
557     PyList_Append(args.ptr(), all_args[kwargs_index].ptr());  // args + kwargs
558   }
559   return args.cast<std::vector<py::object>>();
560 }
561 
562 static void GraphCapture(JitCompileResults *jcr);
HandleBreakAtLoop(JitCompileResults * jcr,const GraphBuilderPtr & g)563 static auto HandleBreakAtLoop(JitCompileResults *jcr, const GraphBuilderPtr &g) {
564   // one stage need adapter
565   if (g->GetGraph()->IsBreakAtLoopAfterUnrolling()) {
566     if (jcr->conf->GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
567       GRAPH_JIT_LOG_F("===> graph break after loop unrolling\n%s\n", g->GetGraph()->ToString(1).c_str());
568     }
569     // reset guard
570     jcr->code->SetGuard(std::make_shared<OptGuard>());
571     AddConfigToGuard(*jcr->conf, jcr->code->GetGuard());
572     // disable loop unroll
573     jcr->conf->SetBool<GraphJitConfig::kLoopUnrolling>(Py_False);
574     // restart captured
575     GraphCapture(jcr);
576     // reset config
577     jcr->conf->SetBool<GraphJitConfig::kLoopUnrolling>(Py_True);
578     return true;
579   }
580   return false;
581 }
582 
HandleUnsupportedSyntax(JitCompileResults * jcr,const GraphBuilderPtr & g)583 static auto HandleUnsupportedSyntax(JitCompileResults *jcr, const GraphBuilderPtr &g) {
584   int break_bci = g->GetGraph()->GetStopTraceBci();
585   if (break_bci == -1) {
586     return false;
587   }
588   int break_op = g->GetGraph()->GetCFG()->instr_pool()[break_bci]->op();
589   bool unsupported = break_op == WITH_CLEANUP_START || break_op == WITH_CLEANUP_FINISH || break_op == END_FINALLY;
590   if (g->StackSize() > 0 || unsupported) {
591     // something happened in with syntax
592     jcr->code->SetGuard(std::make_shared<OptGuard>());
593     AddConfigToGuard(*jcr->conf, jcr->code->GetGuard());
594     jcr->conf->SetBool<GraphJitConfig::kSkipException>(Py_True);
595     GraphCapture(jcr);
596     g->GetTryBlockStacks().clear();
597     jcr->conf->SetBool<GraphJitConfig::kSkipException>(Py_False);
598     return true;
599   }
600   return false;
601 }
602 
TraceRun(JitCompileResults * jcr)603 static auto TraceRun(JitCompileResults *jcr) {
604   TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
605 
606   GraphJitConfig &conf = *jcr->conf;
607   GraphBuilderPtr g = GraphBuilder::Creator(jcr->origin_frame_, conf.GetBoolConfig(GraphJitConfig::kTraceFlag));
608 
609   if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
610     auto mg = std::dynamic_pointer_cast<MindGraphBuilder>(g);
611     (void)mg->FGBuilder()->AddTopGraphInputs(PackArgs(jcr->origin_frame_));
612   }
613   (void)g->TraceRun();
614   return g;
615 }
616 
Inline(JitCompileResults * jcr,const GraphBuilderPtr & g)617 static void Inline(JitCompileResults *jcr, const GraphBuilderPtr &g) {
618   TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
619   GraphJitConfig &conf = *jcr->conf;
620   // One stage should skip inline process.
621   if (!conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
622     BytecodeInliner inliner(g->GetGraph(), py::cast<py::dict>(jcr->origin_frame_->f_globals));
623     inliner.Run();
624   }
625 }
626 
Analyze(const GraphBuilderPtr & g)627 static auto Analyze(const GraphBuilderPtr &g) {
628   TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
629 
630   auto analyzer = GraphAnalyzer::Creator(g);
631   analyzer->Analyze();
632   return analyzer;
633 }
634 
635 // preprocess before compile, split bytecode to sub-function
636 // return whether the code should be modified
GraphCapture(JitCompileResults * jcr)637 static void GraphCapture(JitCompileResults *jcr) {
638   TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
639   MS_EXCEPTION_IF_NULL(jcr->code);
640 
641   GraphJitConfig &conf = *jcr->conf;
642   AObject::SetTraceFlag(conf.GetBoolConfig(GraphJitConfig::kTraceFlag));
643   GraphBuilderPtr g = TraceRun(jcr);
644   if (HandleUnsupportedSyntax(jcr, g)) {
645     return;
646   }
647   if (g->GetGraph()->IsBreakAtLoop() && !g->GetGraph()->RestoreLoopStatus()) {
648     jcr->stat = JitCompileResults::NEVER_COMPILE;
649     return;
650   }
651   Inline(jcr, g);
652   GraphAnalyzerPtr analyzer = Analyze(g);
653   if (HandleBreakAtLoop(jcr, g)) {
654     return;
655   }
656   MarkBreak(g->GetGraph());
657 
658   // dump DFG
659   if (conf.GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
660     g->DumpDFG();
661   }
662 
663   py::object new_code = MakeCodeFromCodeGen(g, analyzer, jcr->origin_frame_->f_globals);
664   if (new_code.ptr() != nullptr) {
665     jcr->code->SetPythonCode(new_code);
666     jcr->stat = JitCompileResults::GRAPH_CALLABLE;
667   }
668 
669   if (conf.GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
670     if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
671       const auto &debug_str = analyzer->GetCaptureInfo().ToString();
672       PY_PRINT_F("*** Dump One Stage ByteCode Collection After CodeGen *** \n%s", debug_str.c_str());
673     }
674     Utils::DisFuncObject(new_code.ptr());
675     GRAPH_JIT_LOG_F("\n\n");
676   }
677 
678   // collect stop trace reason to traceback
679   jcr->tbs->PushStopTraceRes(g->GetGraph()->GetCodeName(), g->GetGraph()->GetStopTraceReason());
680 
681   bool captured = !analyzer->NeedInterpret() && !conf.GetBoolConfig(GraphJitConfig::kInterpretCapturedCode);
682   if (captured && !jcr->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
683     jcr->stat = JitCompileResults::GRAPH_CAPTURED;
684   }
685 }
686 
CollectTraceBack(JitCompileResults * c,PyCodeObject * code,bool is_graph_mode)687 static void CollectTraceBack(JitCompileResults *c, PyCodeObject *code, bool is_graph_mode) {
688   if (code == nullptr) {
689     code = c->origin_frame_->f_code;
690   }
691   std::string name = Utils::GetPyName(c->origin_frame_->f_code->co_name);
692   std::string changed_name = Utils::GetPyName(code->co_name);
693   int code_size = SizeToInt((PyBytes_GET_SIZE(code->co_code)) / sizeof(_Py_CODEUNIT));
694   c->tbs->PushTbs({name, changed_name, code_size, is_graph_mode});
695 }
696 
GetFuncGraphPhase(const PyFrameObject & frame,const OptCodePtr & oc)697 std::string GetFuncGraphPhase(const PyFrameObject &frame, const OptCodePtr &oc) {
698   std::string phase = py::cast<std::string>(frame.f_code->co_filename) + "_" +
699                       std::to_string(frame.f_code->co_firstlineno) + "_" + py::cast<std::string>(frame.f_code->co_name);
700   if (oc != nullptr) {
701     phase += std::to_string(oc->GetGuard()->Info().Id());
702   } else {
703     for (int i = 0; i < frame.f_code->co_argcount; i++) {
704       PyObject *obj = PyTuple_GET_ITEM(frame.f_code->co_varnames, i);
705       py::object para = py::cast<py::object>(PyDict_GetItem(frame.f_locals, obj));
706       auto node = GraphUtils::ConvertPythonObjectToAnfNode(para);
707       phase += "_" + node->abstract()->ToString();
708     }
709   }
710   phase += ".pi_jit";
711   return phase;
712 }
713 
AddConfigToGuard(const GraphJitConfig & c,OptGuardPtr guard)714 void AddConfigToGuard(const GraphJitConfig &c, OptGuardPtr guard) {
715   std::map<std::string, bool> bool_cfg;
716   std::map<std::string, int> int_cfg;
717   bool_cfg[kSpecializeScalar] = c.GetBoolConfig(GraphJitConfig::kGuardSpecializeScalar);
718   bool_cfg[kSpecializeContainer] = c.GetBoolConfig(GraphJitConfig::kGuardSpecializeContainer);
719   bool_cfg[kSpecializeTensor] = c.GetBoolConfig(GraphJitConfig::kGuardSpecializeTensor);
720   int_cfg[kGuardRelaxCnt] = c.getIntConfig(GraphJitConfig::kGuardRelaxCount);
721   guard->UpdateConfig(bool_cfg, int_cfg);
722 }
723 
AddGuardForParam(const PyFrameObject * f,OptGuardPtr guard,bool detach)724 void AddGuardForParam(const PyFrameObject *f, OptGuardPtr guard, bool detach) {
725   int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
726   PyObject *vargs = NULL;
727   PyObject *kwargs = NULL;
728   if (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARARGS) {
729     vargs = f->f_localsplus[argc];
730   }
731   if (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARKEYWORDS) {
732     kwargs = f->f_localsplus[argc + (vargs ? 1 : 0)];
733   }
734   for (int i = 0; i < argc; ++i) {
735     if (f->f_localsplus[i] == nullptr) {
736       continue;
737     }
738     RootTracePtr ptr = std::make_shared<RootTrace>(f->f_localsplus[i], mindspore::pijit::TraceType::Param, i);
739     guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
740     if (detach) {
741       ptr->Detach();
742     }
743   }
744   if (vargs != NULL) {
745     RootTracePtr ptr = std::make_shared<RootTrace>(f->f_localsplus[argc], mindspore::pijit::TraceType::Param, argc);
746     guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
747     if (detach) {
748       ptr->Detach();
749     }
750   }
751   if (kwargs != NULL) {
752     RootTracePtr ptr = std::make_shared<RootTrace>(f->f_localsplus[argc + (vargs ? 1 : 0)],
753                                                    mindspore::pijit::TraceType::Param, argc + (vargs ? 1 : 0));
754     guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
755     if (detach) {
756       ptr->Detach();
757     }
758   }
759   for (int i = 0; f->f_code->co_cell2arg && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
760     Py_ssize_t arg = f->f_code->co_cell2arg[i];
761     if (arg != CO_CELL_NOT_AN_ARG) {
762       auto cell = f->f_localsplus[f->f_code->co_nlocals + i];
763       RootTracePtr ptr = std::make_shared<RootTrace>(PyCell_GET(cell), mindspore::pijit::TraceType::Deref, i);
764       guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
765       if (detach) {
766         ptr->Detach();
767       }
768     }
769   }
770   for (int i = 0; i < PyTuple_GET_SIZE(f->f_code->co_freevars); ++i) {
771     Py_ssize_t arg = PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
772     auto cell = f->f_localsplus[f->f_code->co_nlocals + arg];
773     RootTracePtr ptr = std::make_shared<RootTrace>(PyCell_GET(cell), mindspore::pijit::TraceType::Deref, arg);
774     guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
775     if (detach) {
776       ptr->Detach();
777     }
778   }
779 }
780 
AddGuardForGlobals(const PyFrameObject * f,OptGuardPtr guard,bool detach)781 void AddGuardForGlobals(const PyFrameObject *f, OptGuardPtr guard, bool detach) {
782   PyCodeObject *co = f->f_code;
783   const _Py_CODEUNIT *bytecodes = reinterpret_cast<_Py_CODEUNIT *>(PyBytes_AsString(co->co_code));
784   int size = (PyBytes_GET_SIZE(co->co_code)) / SizeToInt(sizeof(_Py_CODEUNIT));
785   unsigned int exarg = 0;
786   for (int bci = 0; bci < size; ++bci) {
787     int opcode = _Py_OPCODE(bytecodes[bci]);
788     int oparg = (exarg << 8) | _Py_OPARG(bytecodes[bci]);
789     exarg = static_cast<unsigned>((opcode == EXTENDED_ARG) ? oparg : 0);
790     if (opcode != LOAD_GLOBAL) {
791       continue;
792     }
793     PyObject *k = PyTuple_GET_ITEM(co->co_names, oparg);
794     PyObject *v = PyDict_GetItem(f->f_globals, k);
795     std::string key = PyUnicode_AsUTF8(k);
796     if (v == nullptr) {
797       PyErr_Clear();
798       continue;
799     }
800 
801     TracePtr ptr = std::make_shared<RootTrace>(v, TraceType::Global, -1, key);
802 
803     AObject::Type t = AObject::GetPyType(v);
804     GuardLevel level = GuardLevel::GType;
805     if (t == AObject::kTypeCell || t == AObject::kTypePrimitive || t == AObject::kTypeMSDType) {
806       level = GuardLevel::GDeduce;
807     } else if (t == AObject::kTypeFunction) {
808       ptr = std::make_shared<OpTrace>(PyFunction_GET_CODE(v), LOAD_ATTR, -1, std::vector<TracePtr>({ptr}), "__code__");
809       level = GuardLevel::GId;
810     } else if (t == AObject::kTypeTuple || t == AObject::kTypeList || t == AObject::kTypeDict) {
811       /**
812        * graph treat tuple, list, dict as constant variable.
813        * add container guard and check it, check contains Tensor
814        */
815       continue;
816     }
817 
818     guard->GuardOn(ptr, level, false);
819     if (detach) {
820       ptr->Detach();
821     }
822   }
823 }
824 
AddGradFlagForParam(bool grad_flag,OptGuardPtr guard,bool detach)825 static void AddGradFlagForParam(bool grad_flag, OptGuardPtr guard, bool detach) {
826   CustomizedTracePtr ptr = std::make_shared<CustomizedTrace>(
827     grad_flag ? Py_True : Py_False,
828     [](PTraceContext context) -> PyObject * {
829       static pynative::PyNativeExecutor *pynative_exec = nullptr;
830       if (pynative_exec == nullptr) {
831         pynative_exec = pynative::PyNativeExecutor::GetInstance().get();
832       }
833       PyObject *ret = pynative_exec->grad_flag() ? Py_True : Py_False;
834       Py_INCREF(ret);
835       return ret;
836     },
837     [grad_flag](bool simple) -> std::string {
838       if (simple) {
839         return std::string("g\\") + std::to_string(grad_flag ? 1 : 0);
840       }
841       return std::string("{PyNativeExecutor::GetInstance()->grad_flag == ") + std::to_string(grad_flag) +
842              std::string("}(type:") + std::to_string(TraceType::Customized) + std::string(")");
843     });
844   guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GEqual, true);
845   if (detach) {
846     ptr->Detach();
847   }
848 }
849 
CallGraphCompiler(JitCompileResults * jcr,PyFunctionObject * func,const PyFrameObject * frame)850 static std::string CallGraphCompiler(JitCompileResults *jcr, PyFunctionObject *func, const PyFrameObject *frame) {
851   std::string phase = GetFuncGraphPhase(*frame, jcr->code);
852   MS_LOG(DEBUG) << "Phase is " << phase << "!";
853   CallableGraph callable = mindspore::pijit::Compiler::Compile(*func, *frame, phase);
854   if (callable == nullptr) {
855     jcr->stat = JitCompileResults::NEVER_COMPILE;
856     return std::string();
857   }
858 
859   ReleaseFunc rFunc = nullptr;
860   if (jcr->conf->GetBoolConfig(GraphJitConfig::kAutoCleanCache)) {
861     rFunc = [phase]() {
862       auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
863       if (graph_executor->HasCompiled(phase)) {
864         py::str p(phase);
865         py::set s;
866         s.add(phase);
867         py::object o = py::none();
868         graph_executor->DelNetRes(o, s);
869         MS_LOG(DEBUG) << "To release " << phase;
870       }
871     };
872   }
873   jcr->code->SetNativeFunc(phase, callable, rFunc);
874   jcr->stat = JitCompileResults::GRAPH_CALLABLE;
875   return phase;
876 }
877 
GraphToString(FuncGraphPtr graph)878 std::string GraphToString(FuncGraphPtr graph) {
879   std::ostringstream graph_buffer;
880   DumpIR(graph_buffer, graph);
881   auto ret = graph_buffer.str();
882   std::regex regAddress("(0x)([0-9a-f]+)");
883   ret = std::regex_replace(ret, regAddress, "");
884   std::regex regFunc(std::string("(") + graph->ToString() + std::string(")"));
885   ret = std::regex_replace(ret, regFunc, "");
886   std::regex regVar("(\\%[0-9]+\\()([A-Za-z0-9_]+)(\\))");
887   ret = std::regex_replace(ret, regVar, "$1$3");
888   std::regex regNode("CNode_([0-9]+)");
889   ret = std::regex_replace(ret, regNode, "");
890   return ret;
891 }
892 
GraphCompile(JitCompileResults * jcr,const PyFrameObject * frame)893 static void GraphCompile(JitCompileResults *jcr, const PyFrameObject *frame) {
894   TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
895   GuardForFrame(frame, jcr->code, *jcr->conf);
896   AddGuardForGlobals(frame, jcr->code->GetGuard(), jcr->conf->GetBoolConfig(GraphJitConfig::kGuardDetachObject));
897 
898   bool enable_dynamicshape = jcr->conf->GetBoolConfig(GraphJitConfig::kEnableDynamicShape);
899   OptStrategy::MakeGCStrategy(jcr->codehub, jcr->conf->getIntConfig(GraphJitConfig::kLimitGraphSize),
900                               jcr->conf->getIntConfig(GraphJitConfig::kLimitGraphCount), enable_dynamicshape,
901                               jcr->code);
902   // restore function object from frame
903   PyObject *new_func = PyFunction_New(reinterpret_cast<PyObject *>(frame->f_code), frame->f_globals);
904   Py_XSETREF(PyFunction_GET_CLOSURE(new_func), GetClosure(frame));
905   PyFunctionObject *func = reinterpret_cast<PyFunctionObject *>(new_func);
906   PyFrameObject *f = const_cast<PyFrameObject *>(frame);
907   std::vector<PyObject *> backup;
908   if (enable_dynamicshape) {
909     backup = jcr->code->GetGuard()->ApplyDynamicShape(f);
910     PyFrame_FastToLocals(f);
911   }
912   RunEnvironment runEnvironment;
913   runEnvironment.fetchAndSetRunEnv(jcr);
914   std::string phase = CallGraphCompiler(jcr, func, frame);
915   runEnvironment.resumePreviousRunEnv();
916   if (enable_dynamicshape) {
917     jcr->code->GetGuard()->RevertDynamicShape(f, backup);
918     PyFrame_FastToLocals(f);
919   }
920 
921   Py_DECREF(new_func);
922 
923   if (jcr->conf->GetBoolConfig(GraphJitConfig::kReuseGraph)) {
924     auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
925     FuncGraphPtr ms_func_graph = graph_executor->GetFuncGraph(phase);
926     std::string key = GraphToString(ms_func_graph);
927     auto pcode = OptCodeHub::Filter(key, [jcr, graph_executor, ms_func_graph](OptCodePtr code) {
928       FuncGraphPtr func_graph = graph_executor->GetFuncGraph(code->GetPhase());
929       FuncGraphPairMapEquiv equiv_graph;
930       NodeMapEquiv equiv_node;
931       if (func_graph != nullptr && Isomorphic(ms_func_graph, func_graph, &equiv_graph, &equiv_node)) {
932         return true;
933       } else {
934         return false;
935       }
936     });
937     if (pcode != nullptr) {
938       if (jcr->conf->GetBoolConfig(GraphJitConfig::kPrintReuseGraph)) {
939         std::ostringstream graph_buffer;
940         DumpIR(graph_buffer, ms_func_graph);
941         std::cout << "Graph Duplicated:" << std::endl;
942         std::cout << "  Graph:" << graph_buffer.str() << std::endl;
943         std::cout << "  Bytecode:" << std::endl;
944         Utils::DisFuncObject(reinterpret_cast<PyObject *>(frame->f_code));
945       }
946       // find duplicate graph and reuse it
947       pcode->Copy(jcr->code);
948     } else {
949       // current graph is a new one and register it
950       OptCodeHub::Register(key, jcr->code);
951     }
952   }
953 }
954 
955 extern bool UnsupportedCodeTypeCheck(PyCodeObject *co);
JitCompile(PyThreadState * tstate,JitCompileResults * c)956 static bool JitCompile(PyThreadState *tstate, JitCompileResults *c) {
957   if (UnsupportedCodeTypeCheck(c->origin_frame_->f_code)) {
958     return false;
959   }
960 
961   ShapeContext sc(c->origin_frame_, c->signature_);
962   std::string code_str = py::str(reinterpret_cast<PyObject *>(c->origin_frame_->f_code));
963   MS_LOG(DEBUG) << "---start compile " << code_str << "---";
964 
965   // new guard code
966   c->code = c->codehub->AddOptTarget(OptOption::CreateOptionByPoint(c));
967   AddConfigToGuard(*c->conf, c->code->GetGuard());
968 
969   py::object frame = py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(c->origin_frame_));
970   if (c->stat == JitCompileResults::GRAPH_CANDIDATE) {
971     TimeRecorder time_recorder("kTimeCompileCapture", kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
972     runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureProcess,
973                                        "PIJitCapture");
974     c->stat = JitCompileResults::GRAPH_BUILDING;
975     auto aobject_resource = AObject::MakeResource();
976     GraphCapture(c);
977     sc.ApplySignature();
978     if (c->stat == JitCompileResults::GRAPH_CAPTURED) {
979       PyFrameObject *f = PrepareCallCompiledCallable(tstate, c->origin_frame_, c);
980       frame = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(f));
981     }
982     if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
983       PyFrameObject *f = reinterpret_cast<PyFrameObject *>(frame.ptr());
984       GuardForFrame(f, c->code, *c->conf);
985       AddGuardForGlobals(f, c->code->GetGuard(), c->conf->GetBoolConfig(GraphJitConfig::kGuardDetachObject));
986     }
987     aobject_resource.Release();
988   }
989   sc.ApplySignature();
990 
991   if (c->stat == JitCompileResults::GRAPH_CAPTURED) {
992     TimeRecorder time_recorder("kTimeCompileGraph", kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
993     runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureCompile,
994                                        "PIJitCompile");
995     c->stat = JitCompileResults::GRAPH_BUILDING;
996     PyFrameObject *f = reinterpret_cast<PyFrameObject *>(frame.ptr());
997     PyFrame_FastToLocals(f);
998     GraphCompile(c, f);
999   }
1000 
1001   auto guard = c->code->GetGuard()->Optimize();
1002   if (guard != nullptr) {
1003     c->code->SetGuard(guard);
1004   }
1005 
1006   CollectTraceBack(c, c->code->GetPythonCode(), c->code->GetNativeFunc() != nullptr);
1007 
1008   if (c->conf->GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
1009     GRAPH_JIT_LOG_F("%s\n", c->tbs->Dump().c_str());
1010 
1011     GRAPH_JIT_LOG_F("generated guard at %s\n", code_str.c_str());
1012     GRAPH_JIT_LOG_F("%s\n", c->code->GetGuard()->ToString().c_str());
1013   }
1014   if (c->stat != JitCompileResults::GRAPH_CALLABLE) {
1015     c->stat = JitCompileResults::NEVER_COMPILE;
1016     return false;
1017   }
1018   return true;
1019 }
1020 
PackArgs(const PyFrameObject * frame)1021 std::vector<py::object> PackArgs(const PyFrameObject *frame) {
1022   const Py_ssize_t argc = frame->f_code->co_argcount + frame->f_code->co_kwonlyargcount;
1023   bool has_varg = static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARARGS;
1024   py::list args(argc);
1025   py::object vargs;
1026   py::object kwvargs;
1027   for (Py_ssize_t i = 0; i < argc; ++i) {
1028     args[i] = py::reinterpret_borrow<py::object>(frame->f_localsplus[i]);
1029   }
1030   if (has_varg) {
1031     vargs = py::reinterpret_borrow<py::object>(frame->f_localsplus[argc]);
1032   }
1033   if (static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARKEYWORDS) {
1034     kwvargs = py::reinterpret_borrow<py::object>(frame->f_localsplus[argc + has_varg]);
1035   }
1036 
1037   const Py_ssize_t ncells = PyTuple_GET_SIZE(frame->f_code->co_cellvars);
1038   for (Py_ssize_t i = 0; frame->f_code->co_cell2arg && i < ncells; ++i) {
1039     Py_ssize_t argi = frame->f_code->co_cell2arg[i];
1040     if (argi != CO_CELL_NOT_AN_ARG) {
1041       PyObject *cell = frame->f_localsplus[frame->f_code->co_nlocals + i];
1042       args[argi] = py::reinterpret_borrow<py::object>(PyCell_GET(cell));
1043     }
1044   }
1045   return {args, vargs, kwvargs};
1046 }
1047 
ResultMutable(py::object obj)1048 static py::object ResultMutable(py::object obj) {
1049   py::object mutable_func = Utils::GetModuleAttr("mindspore.common", "mutable", false, true);
1050   if (py::isinstance<py::tuple>(obj)) {
1051     auto tuple_obj = obj.cast<py::tuple>();
1052     py::list mutable_list(tuple_obj);
1053     for (size_t i = 0; i < tuple_obj.size(); i++) {
1054       try {
1055         auto mutable_element = mutable_func(tuple_obj[i]);
1056         mutable_list[i] = mutable_element;
1057       } catch (py::error_already_set &e) {
1058         if (PyErr_Occurred()) {
1059           PyErr_Clear();
1060         }
1061         continue;
1062       }
1063     }
1064     auto mutable_tuple = py::tuple(mutable_list);
1065     return mutable_tuple;
1066   } else {
1067     try {
1068       auto mutable_obj = mutable_func(obj);
1069       return mutable_obj;
1070     } catch (py::error_already_set &e) {
1071       if (PyErr_Occurred()) {
1072         PyErr_Clear();
1073       }
1074     }
1075   }
1076   return obj;
1077 }
1078 
// Runs the compiled native function of c->code with `args` / `kwvargs`.
// While perf statistics are enabled and fewer than kPerfStatisticsCount samples have
// been collected, the call goes through CallFunction so it is recorded, with
// OpExecutor::WaitAll() before and after the native call.
// Returns the result passed through ResultMutable. If the native function returns
// NULL, a null py::object is returned and a Python error is guaranteed to be set.
static py::object CallGraph(const JitCompileResults *c, const py::object &args, const py::object &kwvargs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureRunGraph,
                                     "PIJitRunGraph");

  StaticAnalysisExceptionCleaner exception_cleaner;
  RunEnvironment runEnvironment;
  runEnvironment.fetchAndSetRunEnv(c);
  PyObject *py_args = args.ptr();
  PyObject *py_kwvargs = kwvargs.ptr();
  PyObject *res = nullptr;
  if (c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics) &&
      c->code->GetPerf(OptPerf::PerfKind::kPerfGraph)->GetStatistics()->GetTotalCount() <
        c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsCount)) {
    std::function<PyObject *(PyObject * py_args, PyObject * py_kwvargs)> func = [c](PyObject *py_args,
                                                                                    PyObject *py_kwvargs) {
      auto ret = c->code->GetNativeFunc()(py_args, py_kwvargs);
      runtime::OpExecutor::GetInstance().WaitAll();
      return ret;
    };
    runtime::OpExecutor::GetInstance().WaitAll();
    res = CallFunction(c->code->GetPerf(OptPerf::PerfKind::kPerfGraph), func, py_args, py_kwvargs);
  } else {
    res = c->code->GetNativeFunc()(py_args, py_kwvargs);
  }
  runEnvironment.resumePreviousRunEnv();
  if (res == nullptr) {
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_RuntimeError, "compiled graph execute failed");
    }
    // Do not hand a null object to ResultMutable: it inspects the object's type and
    // calls into Python, neither of which is valid for a null handle with an error pending.
    return py::object();
  }
  return ResultMutable(py::reinterpret_steal<py::object>(res));
}
1110 
CallCompiledCallable(PyThreadState * tstate,PyFrameObject * f,const JitCompileResults * c)1111 static py::object CallCompiledCallable(PyThreadState *tstate, PyFrameObject *f, const JitCompileResults *c) {
1112   PyFrameObject *new_f;
1113   PyObject *res;
1114   int bci;
1115 
1116   if (c->code->GetPythonCode() != nullptr) {
1117     new_f = PrepareCallCompiledCallable(tstate, f, c);
1118   } else {
1119     Py_INCREF(f);
1120     new_f = f;
1121   }
1122 
1123   if (c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics) &&
1124       c->code->GetPerf(OptPerf::PerfKind::kPerfPyNative)->GetStatistics()->GetTotalCount() <
1125         c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsCount)) {
1126     std::function<PyObject *(PyThreadState * tstate, PyFrameObject * f, int exc)> func = [](PyThreadState *tstate,
1127                                                                                             PyFrameObject *f, int exc) {
1128       auto ret = _PyEval_EvalFrameDefault(tstate, f, exc);
1129       runtime::OpExecutor::GetInstance().WaitAll();
1130       return ret;
1131     };
1132     runtime::OpExecutor::GetInstance().WaitAll();
1133     // use function pointer not std::function
1134     res = CallFunction(c->code->GetPerf(OptPerf::PerfKind::kPerfPyNative), func, tstate, new_f, 0);
1135   } else {
1136     res = _PyEval_EvalFrameDefault(tstate, new_f, 0);
1137   }
1138 
1139   bci = new_f->f_lasti;
1140   Py_DECREF(new_f);
1141 
1142   if (res == NULL && !PyErr_Occurred()) {
1143     PyErr_Format(PyExc_RuntimeError, "compiled function failed with unknown error, error bci %d", bci);
1144   }
1145   return py::reinterpret_steal<py::object>(res);
1146 }
1147 
CheckTensorInContainer(py::object args)1148 static bool CheckTensorInContainer(py::object args) {
1149   if (py::isinstance<py::tuple>(args)) {
1150     py::tuple t = py::cast<py::tuple>(args);
1151     for (size_t i = 0; i < t.size(); ++i) {
1152       if (CheckTensorInContainer(t[i])) {
1153         return true;
1154       }
1155     }
1156   } else if (py::isinstance<py::list>(args)) {
1157     py::list l = py::cast<py::list>(args);
1158     for (size_t i = 0; i < l.size(); ++i) {
1159       if (CheckTensorInContainer(l[i])) {
1160         return true;
1161       }
1162     }
1163   }
1164   if (IsStubTensor(args) || py::isinstance<mindspore::tensor::Tensor>(args.ptr())) {
1165     return true;
1166   } else {
1167     return false;
1168   }
1169 }
1170 
1171 static bool CheckAbstract(abstract::AbstractBasePtr abs, bool incontainer);
1172 
CheckContainer(abstract::AbstractBasePtr abs)1173 static bool CheckContainer(abstract::AbstractBasePtr abs) {
1174   if (abs->isa<abstract::AbstractTuple>()) {
1175     auto elems = abs->cast<abstract::AbstractTuplePtr>()->elements();
1176     for (size_t idx = 0; idx < elems.size(); ++idx) {
1177       if (!CheckAbstract(elems[idx], true)) {
1178         return false;
1179       }
1180     }
1181   }
1182   if (abs->isa<abstract::AbstractList>()) {
1183     auto elems = abs->cast<abstract::AbstractListPtr>()->elements();
1184     for (size_t idx = 0; idx < elems.size(); ++idx) {
1185       if (!CheckAbstract(elems[idx], true)) {
1186         return false;
1187       }
1188     }
1189   }
1190   if (abs->isa<abstract::AbstractSequence>()) {
1191     auto elems = abs->cast<abstract::AbstractSequencePtr>()->elements();
1192     for (size_t idx = 0; idx < elems.size(); ++idx) {
1193       if (!CheckAbstract(elems[idx], true)) {
1194         return false;
1195       }
1196     }
1197   }
1198   if (abs->isa<abstract::AbstractDictionary>()) {
1199     auto elems = abs->cast<abstract::AbstractDictionaryPtr>()->elements();
1200     for (size_t idx = 0; idx < elems.size(); ++idx) {
1201       if (!CheckAbstract(elems[idx].first, true) || !CheckAbstract(elems[idx].first, true)) {
1202         return false;
1203       }
1204     }
1205   }
1206   if (abs->isa<abstract::AbstractSlice>()) {
1207     auto slice = abs->cast<abstract::AbstractSlicePtr>();
1208     return !CheckAbstract(slice->start(), true) || !CheckAbstract(slice->stop(), true) ||
1209            !CheckAbstract(slice->step(), true);
1210   }
1211   return true;
1212 }
1213 
CheckAbstract(abstract::AbstractBasePtr abs,bool incontainer)1214 static bool CheckAbstract(abstract::AbstractBasePtr abs, bool incontainer) {
1215   if (incontainer && abs->isa<abstract::AbstractAny>()) {
1216     return false;
1217   }
1218   if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>() ||
1219       abs->isa<abstract::AbstractSequence>() || abs->isa<abstract::AbstractDictionary>() ||
1220       abs->isa<abstract::AbstractSlice>()) {
1221     return CheckContainer(abs);
1222   }
1223   if (abs->isa<abstract::AbstractNone>() || abs->isa<abstract::AbstractNull>() || abs->isa<abstract::AbstractType>() ||
1224       abs->isa<abstract::AbstractFunction>() || abs->isa<abstract::AbstractAny>()) {
1225     return false;
1226   }
1227   if (abs->isa<abstract::AbstractScalar>()) {
1228     auto tp = abs->GetTypeTrack()->type_id();
1229     return tp != kMetaTypeNone && tp != kMetaTypeNull && tp != kNumberTypeBool;
1230   }
1231   return true;
1232 }
1233 
CheckValidReturn(const JitCompileResults * c)1234 static bool CheckValidReturn(const JitCompileResults *c) {
1235   auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
1236   FuncGraphPtr ms_func_graph = graph_executor->GetFuncGraph(c->code->GetPhase());
1237   auto abs = ms_func_graph->output()->abstract();
1238   return CheckAbstract(abs, false);
1239 }
1240 
PreferCallGraph(const JitCompileResults * c,py::object args)1241 static bool PreferCallGraph(const JitCompileResults *c, py::object args) {
1242   if (c->code->GetNativeFunc() == nullptr) {
1243     return false;
1244   }
1245   if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
1246     return true;
1247   }
1248   if (!CheckValidReturn(c)) {
1249     return false;
1250   }
1251   py::tuple t = py::cast<py::tuple>(args);
1252   for (size_t i = 0; i < t.size(); ++i) {
1253     py::object obj = t[i];
1254     if (IsMutableObj(obj)) {
1255       continue;
1256     }
1257     if ((py::isinstance<py::list>(t[i]) || py::isinstance<py::tuple>(t[i])) && CheckTensorInContainer(t[i])) {
1258       return false;
1259     }
1260   }
1261   OptStrategy::ExecKind stat = OptStrategy::ExecKind::kExecGraph;
1262   if (c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics)) {
1263     constexpr auto kStatisticsScale = 10000.0;
1264     int scale_statistics = c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsScale10000x);
1265     stat = OptStrategy::MakeExecStrategyByPerf(
1266       c->code->GetPerf(OptPerf::PerfKind::kPerfGraph), c->code->GetPerf(OptPerf::PerfKind::kPerfPyNative),
1267       c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsCount), scale_statistics / kStatisticsScale);
1268   }
1269   int graph_bytecode_min = c->conf->getIntConfig(GraphJitConfig::kStaticGraphBytecodeMin);
1270   if (graph_bytecode_min > 0 && stat == OptStrategy::ExecKind::kExecGraph) {
1271     stat = OptStrategy::MakeExecStrategyByComplex(c->code->GetPythonCode(), graph_bytecode_min);
1272   }
1273   return stat == OptStrategy::ExecKind::kExecGraph;
1274 }
1275 
SetExecStatus(const JitCompileResults * c,const PyFrameObject * f,bool graph_preferred)1276 static void SetExecStatus(const JitCompileResults *c, const PyFrameObject *f, bool graph_preferred) {
1277   bool enable_statistics = c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics);
1278   int graph_bytecode_min = c->conf->getIntConfig(GraphJitConfig::kStaticGraphBytecodeMin);
1279   if (enable_statistics || (graph_bytecode_min > 0)) {
1280     PyObject_SetItem(f->f_globals, reinterpret_cast<PyObject *>(f->f_code), (graph_preferred ? Py_True : Py_False));
1281   }
1282 }
1283 
CallCompiledResults(PyThreadState * tstate,PyFrameObject * f,JitCompileResults * c)1284 static py::object CallCompiledResults(PyThreadState *tstate, PyFrameObject *f, JitCompileResults *c) {
1285   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
1286     return py::none();
1287   }
1288 
1289   ValidateCompiledResults(c);
1290 
1291   std::vector<py::object> packed_args = PackArgs(f);
1292   if (packed_args[1].ptr() != nullptr) {
1293     PyList_Append(packed_args[0].ptr(), packed_args[1].ptr());
1294   }
1295 
1296   py::object args = py::reinterpret_steal<py::object>(PyList_AsTuple(packed_args[0].ptr()));
1297   py::object kwvargs = packed_args[2];
1298   bool graph_preferred = PreferCallGraph(c, args);
1299   SetExecStatus(c, f, graph_preferred);
1300   py::object res;
1301   if (!graph_preferred) {
1302     res = CallCompiledCallable(tstate, f, c);
1303   } else if (!c->conf->GetBoolConfig(GraphJitConfig::kCompileWithTry)) {
1304     res = CallGraph(c, args, kwvargs);
1305   } else {
1306     try {
1307       res = CallGraph(c, args, kwvargs);
1308     } catch (std::exception &e) {
1309       MS_LOG(WARNING) << "compile result has an error, de-optimization\n" << e.what();
1310       res = CallCompiledCallable(tstate, f, c);
1311       c->stat = JitCompileResults::NEVER_COMPILE;
1312     }
1313   }
1314   c->code->Inc();
1315 
1316   if (kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf)) {
1317     PyObject *new_code = c->code->GetPythonCode() ? c->code->GetPythonCode()->co_code : f->f_code->co_code;
1318     ByteCodeRunStatistic::GetInstance()->Count(graph_preferred ? f->f_code->co_code : new_code, graph_preferred);
1319   }
1320 
1321   // dump traceback
1322   if (c->conf->GetBoolConfig(GraphJitConfig::kPrintTraceback)) {
1323     // dump all traceback for the root function
1324     GRAPH_JIT_LOG_F("%s\n", c->tbs->Dump(true).c_str());
1325   }
1326   if (!PyErr_Occurred()) {
1327     c->tbs->Clear();
1328   }
1329   return res;
1330 }
1331 
CheckGuard(JitCompileResults * c,const PyFrameObject * f)1332 static bool CheckGuard(JitCompileResults *c, const PyFrameObject *f) {
1333   TimeRecorder time_recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
1334 
1335   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureGuard,
1336                                      "PIJitGuard");
1337 
1338   StaticAnalysisExceptionCleaner exception_cleaner;
1339 
1340   c->code = nullptr;
1341   std::map<size_t, PyObject *> cache;
1342   std::map<size_t, bool> success;
1343   std::map<size_t, bool> fail;
1344   OptOptionPtr opt = OptOption::CreateOptionByPoint(c);
1345   auto set = c->codehub->GetOptTarget(opt);
1346   set = OptStrategy::MakeGuardListStrategyByFrame(f, set);
1347   for (size_t i = set.size(); i != 0; i--) {
1348     auto oc = set[i - 1];
1349     OptGuardPtr guard = oc->GetGuard();
1350     bool print_guard = c->conf->GetBoolConfig(GraphJitConfig::kPrintGuard);
1351     if (guard != nullptr &&
1352         guard->Check(f, print_guard, &cache, &success, &fail, c->conf->GetBoolConfig(GraphJitConfig::kLogGuardPerf))) {
1353       c->code = oc;
1354       c->codehub->UpdateOptTarget(opt, oc);
1355       break;
1356     }
1357   }
1358   for (auto item : cache) {
1359     Py_XDECREF(item.second);
1360   }
1361   MS_LOG(DEBUG) << __FUNCTION__ << (c->code != nullptr ? " success !" : " failed !");
1362   return c->code != nullptr;
1363 }
1364 
1365 class JitSyntaxLevelScope {
1366  public:
JitSyntaxLevelScope(bool enable)1367   explicit JitSyntaxLevelScope(bool enable) : enable_(enable) {
1368     if (enable_) {
1369       MS_LOG(INFO) << "Start run PIJit with one stage mode";
1370       origin_jit_syntax_level_ = common::GetEnv("MS_DEV_JIT_SYNTAX_LEVEL");
1371       common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0");
1372     }
1373   }
~JitSyntaxLevelScope()1374   ~JitSyntaxLevelScope() {
1375     if (enable_) {
1376       common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", origin_jit_syntax_level_.c_str());
1377     }
1378   }
1379 
1380  private:
1381   std::string origin_jit_syntax_level_;
1382   bool enable_;
1383 };
1384 
JitCompileWithTry(PyThreadState * tstate,JitCompileResults * c)1385 static bool JitCompileWithTry(PyThreadState *tstate, JitCompileResults *c) {
1386   TimeRecorder time_recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
1387 
1388   JitSyntaxLevelScope jit_syntax_level_scope(c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag));
1389   StaticAnalysisExceptionCleaner exception_cleaner;
1390 
1391   if (!c->conf->GetBoolConfig(GraphJitConfig::kCompileWithTry)) {
1392     return JitCompile(tstate, c);
1393   }
1394 
1395   bool compiled = false;
1396   try {
1397     compiled = JitCompile(tstate, c);
1398   } catch (std::exception &e) {
1399     MS_LOG(ERROR) << "got an unexpected c++ error [" << e.what() << "]";
1400   }
1401   if (PyErr_Occurred()) {
1402     MS_LOG(ERROR) << "got an unexpected python error [" << py::error_already_set().what() << "]";
1403     PyErr_Clear();
1404     compiled = false;
1405   }
1406   return compiled;
1407 }
1408 
// Returns a new tuple of the same length as `args` in which every stub tensor is
// replaced by the result of calling its `stub_sync` method; other elements are
// passed through unchanged.
py::tuple EliminateStubTensor(const py::tuple &args) {
  const size_t count = args.size();
  py::tuple synced = py::reinterpret_steal<py::tuple>(PyTuple_New(count));
  for (size_t pos = 0; pos < count; ++pos) {
    if (IsStubTensor(args[pos])) {
      synced[pos] = python_adapter::CallPyObjMethod(args[pos], "stub_sync");
    } else {
      py::object item = args[pos];
      synced[pos] = item;
    }
  }
  return synced;
}
1416 
// The code below is used for debugging code generation and will be removed soon.
test_graph_ir_code_gen(PyFrameObject * frame)1418 py::object test_graph_ir_code_gen(PyFrameObject *frame) {
1419   PyFrame_FastToLocals(frame);
1420   auto func =
1421     py::reinterpret_steal<py::object>(PyFunction_New(reinterpret_cast<PyObject *>(frame->f_code), frame->f_globals));
1422   mindspore::pijit::Utils::DisFuncObject(func.ptr());
1423   auto byteCodeParser = std::make_shared<mindspore::pijit::ByteCodeParser>(func);
1424   mindspore::pijit::ir::FunctionNodePtr func_node = byteCodeParser->Parse();
1425   auto inliner = std::make_shared<mindspore::pijit::FuncInliner>(func_node);
1426   inliner->Run();
1427   int arg_cnt = frame->f_code->co_argcount + frame->f_code->co_kwonlyargcount;
1428   if (static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARARGS) {
1429     arg_cnt++;
1430   }
1431   py::list locals = py::reinterpret_steal<py::list>(PyDict_Values(frame->f_locals));
1432   py::tuple args = py::reinterpret_steal<py::tuple>(PyList_AsTuple(PyList_GetSlice(locals.ptr(), 0, arg_cnt)));
1433   py::dict kwargs = (static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARKEYWORDS) == 0x0
1434                       ? py::dict()
1435                       : py::cast<py::dict>(locals[arg_cnt]);
1436   args = EliminateStubTensor(args);
1437   mindspore::pijit::AbstractTypeDeducer::Deduce(func_node, args, kwargs);
1438   func_node->Sort();
1439   std::cout << func_node->ToString() << std::endl;
1440   auto func_obj = mindspore::pijit::ByteCodeGenerator::GenFunction(func_node);
1441   mindspore::pijit::Utils::DisFuncObject(func_obj.ptr());
1442   if ((static_cast<unsigned int>(func_node->GetFlags()) & CO_VARARGS) != 0) {
1443     auto pos_cnt = args.size() - 1;
1444     auto var_vargs = py::cast<py::tuple>(args[pos_cnt]);
1445     auto new_args = py::reinterpret_steal<py::tuple>(PyTuple_New(pos_cnt + var_vargs.size()));
1446     size_t index = 0;
1447     std::for_each(args.begin(), args.end() - 1, [&index, &new_args](const py::handle &arg) {
1448       new_args[index] = arg;
1449       index++;
1450     });
1451     std::for_each(var_vargs.begin(), var_vargs.end(), [&index, &new_args](const py::handle &arg) {
1452       new_args[index] = arg;
1453       index++;
1454     });
1455     args = new_args;
1456   }
1457   auto res = py::reinterpret_steal<py::object>(PyObject_Call(func_obj.ptr(), args.ptr(), kwargs.ptr()));
1458   res.inc_ref();
1459   return res;
1460 }
1461 
CodeHook(PyThreadState * tstate,JitCompileResults * c,PyFrameObject * frame)1462 static py::object CodeHook(PyThreadState *tstate, JitCompileResults *c, PyFrameObject *frame) {
1463   if (c->conf->GetBoolConfig(GraphJitConfig::kTestGraphIR)) {
1464     return test_graph_ir_code_gen(frame);
1465   }
1466   bool just_compiled = false;
1467   switch (c->stat) {
1468     case JitCompileResults::NEVER_COMPILE:
1469       break;
1470     case JitCompileResults::GRAPH_CAPTURED:
1471       if (c->conf->GetBoolConfig(GraphJitConfig::kInterpretCapturedCode)) {
1472         break;
1473       }
1474     /* fallthrough */
1475     case JitCompileResults::GRAPH_CANDIDATE:
1476       MS_EXCEPTION_IF_CHECK_FAIL(c->origin_frame_ == nullptr || c->origin_frame_ == frame,
1477                                  "check recursive call compiling function");
1478       c->origin_frame_ = frame;
1479       if (c->conf->GetBoolConfig(GraphJitConfig::kCompileWithoutCapture)) {
1480         c->stat = JitCompileResults::GRAPH_CAPTURED;
1481       }
1482       if (!JitCompileWithTry(tstate, c)) {
1483         c->stat = JitCompileResults::NEVER_COMPILE;
1484         break;
1485       }
1486       just_compiled = true;
1487     /* fallthrough */
1488     case JitCompileResults::GRAPH_CALLABLE: {
1489       if (CheckGuard(c, frame)) {
1490         c->origin_frame_ = nullptr;
1491         return CallCompiledResults(tstate, frame, c);
1492       }
1493       if (c->stat == JitCompileResults::NEVER_COMPILE) {
1494         break;
1495       }
1496       if (!just_compiled) {
1497         c->stat = JitCompileResults::GRAPH_CANDIDATE;
1498         return CodeHook(tstate, c, frame);
1499       }
1500       MS_LOG(EXCEPTION) << "shouldn't reach here";
1501     }
1502     case JitCompileResults::GRAPH_BUILDING:
1503       MS_LOG(ERROR) << "recursive call, compiler call the code "
1504                     << std::string(py::str(reinterpret_cast<PyObject *>(frame->f_code))) << " which is compiling";
1505       break;
1506     default:
1507       MS_LOG(EXCEPTION) << "shouldn't reach here";
1508       break;
1509   }
1510   PyObject *res = _PyEval_EvalFrameDefault(tstate, frame, 0);
1511   return py::reinterpret_steal<py::object>(res);
1512 }
1513 
ApplyAutoJit(PyFrameObject * f)1514 static void ApplyAutoJit(PyFrameObject *f) {
1515   if (!kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kAutoJit)) {
1516     return;
1517   }
1518 
1519   PyObject *code = reinterpret_cast<PyObject *>(f->f_code);
1520   if (getJitCompileResults(code, false) != nullptr) {
1521     return;
1522   }
1523 
1524   // first reached this code
1525   // allocate for all code while auto jit
1526   (void)getJitCompileResults(code, true);
1527   if (!kPIJitConfigDefault.ShouldAutoJit(f)) {
1528     return;
1529   }
1530   (void)pi_jit_should_compile(py::cast<py::object>(code), py::dict(), py::none());
1531 }
1532 
CollectGradientArguments(const PyFrameObject & frame)1533 py::list CollectGradientArguments(const PyFrameObject &frame) {
1534   py::list arguments;
1535 
1536   // Collect Positional Arguments
1537   for (int index = 1; index < frame.f_code->co_argcount; index++) {
1538     arguments.append(py::cast<py::object>(frame.f_localsplus[index]));
1539   }
1540 
1541   // Collect Variable Arguments
1542   if ((static_cast<unsigned int>(frame.f_code->co_flags) & CO_VARARGS) != 0x0) {
1543     auto var_args = py::cast<py::tuple>(frame.f_localsplus[frame.f_code->co_argcount]);
1544     std::for_each(var_args.begin(), var_args.end(), [&arguments](const auto &arg) { arguments.append(arg); });
1545   }
1546 
1547   // Collect Variable Arguments
1548   if ((static_cast<unsigned int>(frame.f_code->co_flags) & CO_VARKEYWORDS) != 0x0) {
1549     auto kw_args = py::cast<py::dict>(frame.f_localsplus[frame.f_code->co_argcount + 1]);
1550     std::for_each(kw_args.begin(), kw_args.end(), [&arguments](const auto &item) { arguments.append(item.second); });
1551   }
1552 
1553   return arguments;
1554 }
1555 
AutoGrad(PyFrameObject * f,PyObject * ret)1556 void AutoGrad(PyFrameObject *f, PyObject *ret) {
1557   // improve performance for infer
1558   if (kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kInferOnly)) {
1559     return;
1560   }
1561   // must have a return value and prim must have argument
1562   if (ret == nullptr || f->f_code->co_argcount <= 0) {
1563     return;
1564   }
1565   // the call function of primitive
1566   if (py::cast<py::object>(f->f_code->co_name).cast<std::string>() != "__call__") {
1567     return;
1568   }
1569   // only record primitvie now
1570   if (f->f_localsplus[0] == nullptr) {
1571     return;
1572   }
1573   if (!py::isinstance<Primitive>(f->f_localsplus[0]) && !py::isinstance<PrimitivePy>(f->f_localsplus[0]) &&
1574       !py::isinstance<PrimitivePyAdapter>(f->f_localsplus[0])) {
1575     return;
1576   }
1577   // gradient info check
1578   if (!grad::FunctionNode::HasAttrReqGrad(ret) && !py::isinstance<py::tuple>(ret)) {
1579     return;
1580   }
1581   MS_EXCEPTION_IF_CHECK_FAIL(f->f_code->co_kwonlyargcount == 0, "Must not have kw only args.");
1582   auto inputs = CollectGradientArguments(*f);
1583   if (!std::any_of(inputs.begin(), inputs.end(),
1584                    [](const auto &input) { return grad::FunctionNode::IsRequiresGradient(input); })) {
1585     return;
1586   }
1587   grad::FunctionNode::RecordPrimitive(py::cast<py::object>(f->f_localsplus[0]), py::cast<py::object>(ret), inputs);
1588 }
1589 
1590 #if (PY_MAJOR_VERSION == 3) && (PY_MINOR_VERSION < 9)
EvalFrame(PyFrameObject * f,int exc)1591 PyObject *EvalFrame(PyFrameObject *f, int exc) {
1592   PyThreadState *tstate = PyThreadState_Get();
1593 
1594 #else
1595 PyObject *EvalFrame(PyThreadState *tstate, PyFrameObject *f, int exc) {
1596 #endif
1597 
1598   // exception handler
1599   if (exc != 0) {
1600     return _PyEval_EvalFrameDefault(tstate, f, exc);
1601   }
1602 
1603   ApplyAutoJit(f);
1604 
1605   PyObject *code = reinterpret_cast<PyObject *>(f->f_code);
1606   JitCompileResults *c = getJitCompileResults(code, false);
1607   if (c == nullptr) {
1608     auto ret = _PyEval_EvalFrameDefault(tstate, f, exc);
1609     AutoGrad(f, ret);
1610     return ret;
1611   }
1612   py::object res;
1613   try {
1614     res = CodeHook(tstate, c, f);
1615   } catch (py::error_already_set &e) {
1616     e.restore();
1617   } catch (py::builtin_exception &e) {
1618     e.set_error();
1619   }
1620   return res.inc_ref().ptr();
1621 }
1622 }  // namespace pijit
1623 }  // namespace mindspore
1624 
1625 namespace mindspore {
1626 
1627 #if (PY_MAJOR_VERSION == 3) && (PY_MINOR_VERSION >= 7) && (PY_MINOR_VERSION <= 10)
1628 
pi_jit_enable()1629 py::bool_ pi_jit_enable() {
1630   PyInterpreterState *inter = PyInterpreterState_Main();
1631   _PyFrameEvalFunction prev = _PyInterpreterState_GetEvalFrameFunc(inter);
1632   _PyFrameEvalFunction def = _PyEval_EvalFrameDefault;
1633   if (prev != def) {
1634     return false;
1635   }
1636   mindspore::pijit::ensureInitialize();
1637   _PyInterpreterState_SetEvalFrameFunc(inter, mindspore::pijit::EvalFrame);
1638   return true;
1639 }
1640 
pi_jit_disable()1641 py::bool_ pi_jit_disable() {
1642   PyInterpreterState *inter = PyInterpreterState_Main();
1643   _PyFrameEvalFunction prev = _PyInterpreterState_GetEvalFrameFunc(inter);
1644   _PyFrameEvalFunction def = _PyEval_EvalFrameDefault;
1645   if (prev != mindspore::pijit::EvalFrame) {
1646     return false;
1647   }
1648   _PyInterpreterState_SetEvalFrameFunc(inter, def);
1649   return true;
1650 }
1651 
// Marks the code object behind `funcHandle` as a candidate for compilation.
//
// funcHandle: a Python function, a bound method wrapping a Python function,
//             or a code object. Anything else is rejected.
// tag:        used to construct the GraphJitConfig stored for the code object.
// signature:  when neither null nor None, stored on the compile results.
//
// Returns false when `funcHandle` is not one of the accepted kinds or when no
// compile-result entry is available for the code object; true otherwise.
py::bool_ pi_jit_should_compile(const py::object &funcHandle, const py::object &tag, const py::object &signature) {
  PyObject *func = funcHandle.ptr();
  if (func == nullptr) {
    return false;
  }
  PyObject *code = nullptr;
  if (PyFunction_Check(func)) {
    code = PyFunction_GET_CODE(func);
  } else if (PyMethod_Check(func)) {
    func = PyMethod_GET_FUNCTION(func);
    // A bound method may wrap any callable; PyFunction_GET_CODE is only valid
    // on real Python function objects.
    if (!PyFunction_Check(func)) {
      return false;
    }
    code = PyFunction_GET_CODE(func);
  } else if (PyCode_Check(func)) {
    code = func;
  } else {
    return false;
  }
  mindspore::pijit::JitCompileResults *c = mindspore::pijit::getJitCompileResults(code);
  if (c == nullptr) {
    return false;
  }
  PyObject *sig = signature.ptr();
  if (sig != nullptr && sig != Py_None) {
    // NOTE(review): a previously stored signature_ is overwritten here without
    // being released - confirm who owns the old reference.
    c->signature_ = sig;
    Py_INCREF(sig);
  }
  auto new_config = mindspore::pijit::GraphJitConfig(tag);
  // When switching between one-stage and two-stage, reset the config.
  if (c->conf->GetBoolConfig(pijit::GraphJitConfig::kTraceFlag) !=
      new_config.GetBoolConfig(pijit::GraphJitConfig::kTraceFlag)) {
    c->code = nullptr;
    c->codehub = std::make_shared<pijit::OptCodeHub>();
  }
  // Already a candidate (or further along): only refresh the configuration.
  if (c->stat != mindspore::pijit::JitCompileResults::NEVER_COMPILE) {
    *c->conf = new_config;
    return true;
  }

  auto raw_code_size = (PyBytes_GET_SIZE(reinterpret_cast<PyCodeObject *>(code)->co_code)) / sizeof(_Py_CODEUNIT);
  std::string raw_func_info_name = py::str(code).cast<std::string>();
  std::string raw_func_name = "";
  if (PyFunction_Check(func)) {
    // The function's module may be missing or not a str, and the UTF-8
    // conversion itself can fail; skip the inline-module registration in
    // those cases instead of dereferencing a null pointer.
    PyObject *module = PyFunction_GET_MODULE(func);
    if (module != nullptr && PyUnicode_Check(module)) {
      const char *module_name = PyUnicode_AsUTF8(module);
      if (module_name == nullptr) {
        PyErr_Clear();
      } else {
        const std::string full_module(module_name);
        // Everything before the first '.' (or the whole name when there is none).
        std::string top_module = full_module.substr(0, full_module.find('.'));
        mindspore::pijit::kPIJitConfigDefault.AddAllowedInlineModules(top_module);
      }
    }

    raw_func_name = mindspore::pijit::Utils::GetPyName(reinterpret_cast<PyFunctionObject *>(func)->func_qualname);
  }

  c->stat = mindspore::pijit::JitCompileResults::GRAPH_CANDIDATE;
  *c->conf = new_config;
  *c->tbs = mindspore::pijit::Tracebackes(raw_func_name, raw_func_info_name, raw_code_size);
  return true;
}
1703 #else
1704 
1705 py::bool_ pi_jit_enable() {
1706   MS_LOG(ERROR) << "PiJit not support this python version " << PY_MAJOR_VERSION << '.' << PY_MINOR_VERSION
1707                 << " only support on python3.7, python3.8, python3.9, python3.10";
1708   return py::bool_(false);
1709 }
1710 py::bool_ pi_jit_disable() { return py::bool_(false); }
1711 py::bool_ pi_jit_should_compile(const py::object &func, const py::object &tag, const py::object &signature) {
1712   return py::bool_(false);
1713 }
1714 
1715 #endif
1716 
ConvertCodeExtra(mindspore::pijit::CodeExtra * c)1717 static py::object ConvertCodeExtra(mindspore::pijit::CodeExtra *c) {
1718   if (c->code == nullptr) {
1719     return py::object();
1720   }
1721   PyCodeObject *compiled_code = c->code->GetPythonCode();
1722   auto compiled_func = c->code->GetNativeFunc();
1723   auto guard = c->code->GetGuard();
1724   if (compiled_func == nullptr && compiled_code == nullptr) {
1725     return py::object();
1726   }
1727   py::dict code;
1728   if (compiled_code != nullptr) {
1729     PyDict_SetItemString(code.ptr(), "compiled_code_", reinterpret_cast<PyObject *>(compiled_code));
1730   }
1731   if (compiled_func != nullptr) {
1732     PyDict_SetItemString(code.ptr(), "phase_", py::str(c->code->GetPhase()).ptr());
1733   }
1734   if (guard != nullptr && !guard->IsEmpty()) {
1735     PyDict_SetItemString(code.ptr(), "guard_", py::str(guard->ToString()).ptr());
1736   }
1737   PyDict_SetItemString(code.ptr(), "call_count_", py::int_(c->code->Count()).ptr());
1738   return code;
1739 }
1740 
// Returns a dict describing the JIT state recorded for `func`, or None when no
// code object can be obtained from `func` or no compile results exist for it.
//
// Dict keys: "code" (only when ConvertCodeExtra() yields an object), "stat"
// (name of the compile state), "compile_count_" and "break_count_".
py::object get_code_extra(const py::object &func) {
  py::object code_obj = mindspore::pijit::GetPyCodeObject(func);
  if (code_obj.ptr() == nullptr) {
    return py::none();
  }
  auto c = mindspore::pijit::getJitCompileResults(code_obj.ptr(), false);
  if (c == nullptr) {
    return py::none();
  }

  // Names indexed by the numeric value of the compile state.
  constexpr const char *kStatNames[] = {
    "NEVER_COMPILE", "GRAPH_CANDIDATE", "GRAPH_CAPTURED", "GRAPH_BUILDING", "GRAPH_CALLABLE",
  };

  py::dict info;
  // Stores `value` under `key`; the dict takes its own reference.
  auto put = [&info](const char *key, const py::handle &value) {
    (void)PyDict_SetItemString(info.ptr(), key, value.ptr());
  };

  py::object code_summary = ConvertCodeExtra(c);
  if (code_summary.ptr() != nullptr) {
    put("code", code_summary);
  }
  put("stat", py::str(kStatNames[c->stat]));
  put("compile_count_", py::int_(c->compile_count_));
  put("break_count_", py::int_(c->break_count_));
  return info;
}
1765 
FunctionId(const py::object & callable)1766 size_t FunctionId(const py::object &callable) {
1767   // filter special cpp function
1768   auto py_cfunction_filter = [](PyObject *op) -> void * {
1769     // pybind11::cpp_function::dispatcher;
1770     static PyCFunction pybind_dispatcher = PyCFunction_GET_FUNCTION(py::cpp_function([]() {}).ptr());
1771     PyCFunction result = PyCFunction_GET_FUNCTION(op);
1772     return result == pybind_dispatcher ? op : reinterpret_cast<void *>(result);
1773   };
1774   PyObject *op = callable.ptr();
1775   if (PyMethod_Check(op)) {
1776     op = PyMethod_GET_FUNCTION(op);
1777   }
1778   if (PyInstanceMethod_Check(op)) {
1779     op = PyInstanceMethod_GET_FUNCTION(op);
1780   }
1781   void *result = op;
1782   if (PyCFunction_Check(op)) {
1783     // types.BuiltinFunctionType = type(len) same as types.BuiltinMethodType = type(list().append)
1784     result = py_cfunction_filter(op);
1785   } else if (Py_IS_TYPE(op, &PyMethodDescr_Type)) {
1786     // types.MethodDescriptorType = type(list.append)
1787     PyCFunction func = reinterpret_cast<PyMethodDescrObject *>(op)->d_method->ml_meth;
1788     result = reinterpret_cast<void *>(func);
1789   } else if (Py_IS_TYPE(op, &PyWrapperDescr_Type)) {
1790     // types.WrapperDescriptorType = type(object.__init__)
1791     result = reinterpret_cast<PyWrapperDescrObject *>(op)->d_wrapped;
1792   } else if (Py_IS_TYPE(op, &_PyMethodWrapper_Type)) {
1793     // types.WrapperDescriptorType = type(object().__str__)
1794     PyObject *self = PyObject_GetAttrString(op, "__self__");
1795     PyObject *attr = PyObject_GetAttrString(op, "__name__");
1796     PyObject *descr = PyObject_GetAttr(reinterpret_cast<PyObject *>(Py_TYPE(self)), attr);
1797     result = reinterpret_cast<PyWrapperDescrObject *>(descr)->d_wrapped;
1798     Py_DECREF(self);
1799     Py_DECREF(attr);
1800     Py_DECREF(descr);
1801   }
1802   return reinterpret_cast<size_t>(result);
1803 }
1804 
1805 }  // namespace mindspore
1806