1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "pipeline/jit/pi/common.h"
#include <algorithm>
#include <array>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "include/common/utils/convert_utils_py.h"
#include "pipeline/jit/pi/auto_grad/function_node.h"
#include "pipeline/jit/pi/external.h"
#include "pipeline/jit/pi/graph_capture/graph_build.h"
#include "pipeline/jit/pi/graph_capture/graph_analyzer.h"
#include "pipeline/jit/pi/graph_compiler/abstract_type_deducer.h"
#include "pipeline/jit/pi/graph_compiler/compiler.h"
#include "pipeline/jit/pi/graph_compiler/cg/byte_code_generator.h"
#include "pipeline/jit/pi/graph_compiler/inliner/func_inliner.h"
#include "pipeline/jit/pi/graph_compiler/parser/byte_code_parser.h"
#include "pipeline/jit/pi/graph_compiler/utils.h"
#include "pipeline/jit/pi/utils/utils.h"
#include "pipeline/jit/pi/graph_guard/guard.h"
#include "pipeline/jit/pi/graph_guard/strategy.h"
#include "pipeline/jit/pi/graph_guard/shape_ctx.h"
#include "pipeline/jit/ps/pipeline.h"
#include "pipeline/pynative/pynative_utils.h"
#include "runtime/pynative/op_executor.h"
#include "include/common/debug/anf_ir_dump.h"
#include "pipeline/jit/pi/graph_capture/code_generator.h"
#include "pipeline/jit/pi/graph_capture/bytecode_inliner.h"
#include "utils/convert_utils_base.h"
49
50 namespace mindspore {
51 namespace pijit {
52 static Py_tss_t *tss = NULL;
53
54 void AddConfigToGuard(const GraphJitConfig &c, OptGuardPtr guard);
55 void AddGuardForParam(const PyFrameObject *f, OptGuardPtr guard, bool detach);
56 void AddGuardForGlobals(const PyFrameObject *f, OptGuardPtr guard, bool detach);
57 static void AddGradFlagForParam(bool grad_flag, OptGuardPtr guard, bool detach);
58 static void CollectTraceBack(JitCompileResults *c, PyCodeObject *code, bool is_graph_mode);
59
60 class ByteCodeRunStatistic {
61 public:
~ByteCodeRunStatistic()62 ~ByteCodeRunStatistic() {
63 if (py_.empty() && graph_.empty()) {
64 return;
65 }
66 std::cout << ToString() << std::endl;
67 }
68
Count(PyObject * code,bool graph_preferred)69 void Count(PyObject *code, bool graph_preferred) {
70 if (graph_preferred) {
71 graph_[PyBytes_GET_SIZE(code)]++;
72 } else {
73 py_[PyBytes_GET_SIZE(code)]++;
74 }
75 }
76
ToString()77 std::string ToString() {
78 const auto SumFunc = [](size_t sum, const std::pair<uint64_t, size_t> &i) { return sum + (i.first * i.second); };
79 size_t sum_py = std::accumulate(py_.begin(), py_.end(), 0, SumFunc);
80 size_t sum_graph = std::accumulate(graph_.begin(), graph_.end(), 0, SumFunc);
81 double ratio = static_cast<double>(sum_graph) / (sum_graph + sum_py);
82 return "execute code ratio (graph / (graph + python)): " + std::to_string(ratio);
83 }
84
GetInstance()85 static ByteCodeRunStatistic *GetInstance() {
86 static ByteCodeRunStatistic instance;
87 return &instance;
88 }
89
90 private:
91 ByteCodeRunStatistic() = default;
92 std::map<uint64_t, size_t> py_;
93 std::map<uint64_t, size_t> graph_;
94 };
95
96 class StaticAnalysisExceptionCleaner {
97 public:
98 StaticAnalysisExceptionCleaner() = default;
~StaticAnalysisExceptionCleaner()99 ~StaticAnalysisExceptionCleaner() { StaticAnalysisException::Instance().ClearException(); }
100 };
101
102 class RunEnvironment {
103 public:
104 RunEnvironment() = default;
105
fetchAndSetRunEnv(const JitCompileResults * jcr)106 void fetchAndSetRunEnv(const JitCompileResults *jcr) {
107 auto ms_context = MsContext::GetInstance();
108 run_mode_ = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
109 jit_level_ = ms_context->GetJitLevel();
110 task_sink_ = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
111
112 auto jit_level = jcr->conf->getJitLevel();
113 auto grad_flag = pynative::PyNativeExecutor::GetInstance()->grad_flag();
114 auto run_mode = jit_level == "O2" && !grad_flag ? kGraphMode : kPynativeMode;
115 auto task_sink = jit_level == "O2" && !grad_flag;
116 ms_context->set_param(MS_CTX_EXECUTION_MODE, run_mode);
117 ms_context->set_param(MS_CTX_JIT_LEVEL, jit_level);
118 ms_context->SetJitLevel(jit_level);
119 ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink);
120 ms_context->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, task_sink);
121 }
122
resumePreviousRunEnv()123 void resumePreviousRunEnv() {
124 auto ms_context = MsContext::GetInstance();
125 ms_context->set_param(MS_CTX_EXECUTION_MODE, run_mode_);
126 ms_context->set_param(MS_CTX_JIT_LEVEL, jit_level_);
127 ms_context->SetJitLevel(jit_level_);
128 ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink_);
129 ms_context->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, task_sink_);
130 }
131
132 private:
133 int run_mode_ = kPynativeMode;
134 std::string jit_level_;
135 bool task_sink_ = false;
136 };
137
PrintGuardPerf()138 static void PrintGuardPerf() {
139 std::map<std::string, std::pair<size_t, size_t>> guard_info;
140 std::map<std::string, std::pair<size_t, size_t>> guard_freq_info;
141 std::map<std::string, std::pair<size_t, size_t>> trace_info;
142 std::map<std::string, std::pair<size_t, std::vector<size_t>>> item_info;
143 OptGuardPerf::GetGuardPerf()->GetGuardPerfInfo(&guard_info, &item_info, &trace_info, &guard_freq_info);
144 std::cout << "Guard performance info:" << std::endl;
145 std::cout << "guard, count, total time, success, fail" << std::endl;
146 for (const auto &item : guard_info) {
147 auto iter = guard_freq_info.find(item.first);
148 if (iter != guard_freq_info.end()) {
149 std::cout << "guard:" << item.first << ", " << item.second.first << ", " << item.second.second << ","
150 << iter->second.first << "," << iter->second.second << std::endl;
151 } else {
152 std::cout << "guard:" << item.first << ", " << item.second.first << ", " << item.second.second << std::endl;
153 }
154 }
155 std::cout << "trace, count, total time" << std::endl;
156 for (const auto &item : trace_info) {
157 std::cout << "trace:" << item.first << ", " << item.second.first << ", " << item.second.second << std::endl;
158 }
159 std::cout << "item, count, [stage time]" << std::endl;
160 for (const auto &item : item_info) {
161 std::cout << "item:" << item.first << "," << item.second.first << ", [";
162 for (auto stage : item.second.second) {
163 std::cout << stage << ",";
164 }
165 std::cout << "]" << std::endl;
166 }
167 }
168
169 // jit compiler initialize
ensureInitialize()170 static void ensureInitialize() {
171 static bool init = false;
172 if (init) {
173 return;
174 }
175 init = true;
176 if (tss == NULL) {
177 tss = PyThread_tss_alloc();
178 PyThread_tss_create(tss);
179 }
180 std::atexit([]() {
181 if (kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogGuardPerf)) {
182 PrintGuardPerf();
183 }
184 });
185 }
186
PushInlineInfo(InlineInfo info)187 void Tracebackes::PushInlineInfo(InlineInfo info) {
188 const auto &it = inline_infos_.find(info.root_name_);
189 if (it != inline_infos_.cend()) {
190 it->second.push_back(info);
191 } else {
192 std::list<InlineInfo> inlines;
193 inlines.push_back(info);
194 inline_infos_.emplace(info.root_name_, inlines);
195 }
196 }
197
// Write `str` left-aligned in a field of `distance` columns, followed by ": ".
// Leaves the stream's adjustfield set to left.
static void PrintLabel(std::stringstream &os, const std::string &str, int distance = 30) {
  os.setf(std::ios_base::left, std::ios_base::adjustfield);
  os.width(distance);
  os << str;
  os << ": ";
}
201
Dump(bool is_all) const202 std::string Tracebackes::Dump(bool is_all) const {
203 constexpr auto width = 10;
204
205 std::stringstream os;
206 std::string cur_name = tbs_.empty() ? "" : tbs_.back().func_name_;
207 if (is_all) {
208 os << "*** Dump Traceback on [" << raw_func_info_name_ << "] ***\n";
209 } else {
210 os << "*** Dump ByteCode After Traceback on [" << cur_name << "] ***\n";
211 }
212 if (tbs_.empty()) {
213 return os.str();
214 }
215 std::list<Tracebacke> candidates;
216 if (is_all) {
217 candidates = tbs_;
218 } else {
219 // last one traceback
220 candidates.emplace_back(tbs_.back());
221 }
222 // dump traceback list head
223 int name_length = FindMaxNameLength(candidates);
224 os << std::left << std::setw(name_length) << "func_name: --> " << std::left << std::setw(name_length)
225 << "changed_func:" << std::left << std::setw(width) << "run_mode:" << std::left << std::setw(kThree * width)
226 << "stop_trace:" << std::left << std::setw(width) << "code_size:" << std::endl;
227 os << "--------------------------------------------------------------------------------------\n";
228 // dump traceback list content
229 for (const auto &tb : candidates) {
230 os << std::left << std::setw(name_length) << tb.func_name_ << " --> ";
231 os << std::left << std::setw(name_length) << tb.changed_func_;
232 if (tb.is_graph_mode_) {
233 os << std::left << std::setw(width) << "[GRAPH]";
234 } else {
235 os << std::left << std::setw(width) << "PYNATIVE";
236 }
237 // dump stop trace reason
238 auto it_trace = stop_trace_res_.find(tb.func_name_);
239 if (it_trace != stop_trace_res_.cend()) {
240 os << std::left << std::setw(kThree * width) << GetStopTraceReasonDesc(it_trace->second);
241 } else {
242 os << std::left << std::setw(kThree * width) << "unknown";
243 }
244 os << std::left << std::setw(width) << tb.code_size_ << " =====>\n";
245 // dump inline info
246 DumpInlineInfo(os, tb.func_name_);
247 }
248 os << "\n\n";
249 if (is_all) {
250 os << DumpSummary();
251 }
252 return os.str();
253 }
254
DumpInlineInfo(std::stringstream & os,const std::string & func_name) const255 void Tracebackes::DumpInlineInfo(std::stringstream &os, const std::string &func_name) const {
256 const auto &it = inline_infos_.find(func_name);
257 if (it == inline_infos_.cend()) {
258 return;
259 }
260 for (const auto &info : it->second) {
261 std::string space((info.depth + 1) * kTwo, ' ');
262 os << space << "| inline_info:" << GetInlineReasonDesc(info.res) << " line:" << info.line;
263 if (!info.inline_name_.empty()) {
264 os << " func_name:" << info.inline_name_;
265 }
266 if (info.res == InlineReason::kInline || info.res == InlineReason::kInlinePartial) {
267 os << " code_size:" << info.code_size_;
268 }
269 os << "\n";
270 }
271 }
272
DumpSummary() const273 std::string Tracebackes::DumpSummary() const {
274 std::stringstream os;
275 if (tbs_.empty()) {
276 return os.str();
277 }
278 os << "*** Dump Summary on [" << raw_func_info_name_ << "] ***\n";
279 PrintLabel(os, "traceback_num");
280 os << tbs_.size() << "\n";
281
282 std::array<int, kStopTrace_Reason_Count> stop_trace_reason_array{0};
283 std::array<int, kInline_Reason_Count> inline_reason_array{0};
284 int graph_mode_num = 0;
285 int raw_code_size = raw_code_size_;
286 int pynative_code_size = 0;
287 int graph_mode_code_size = 0;
288 for (const auto &tb : tbs_) {
289 if (tb.is_graph_mode_) {
290 graph_mode_num++;
291 graph_mode_code_size += tb.code_size_;
292 } else {
293 pynative_code_size += tb.code_size_;
294 }
295 auto it_trace = stop_trace_res_.find(tb.func_name_);
296 if (it_trace != stop_trace_res_.cend()) {
297 // count stop trace reason
298 stop_trace_reason_array[it_trace->second]++;
299 }
300 const auto &it_inline = inline_infos_.find(tb.func_name_);
301 if (it_inline == inline_infos_.cend()) {
302 continue;
303 }
304 for (const auto &info : it_inline->second) {
305 // count inline reason
306 inline_reason_array[info.res]++;
307 if (info.res == InlineReason::kInline || info.res == InlineReason::kInlinePartial) {
308 raw_code_size += info.code_size_;
309 }
310 }
311 }
312 PrintLabel(os, "graph_mode_num");
313 os << graph_mode_num << "\n";
314 PrintLabel(os, "raw_code_size(+ inline)");
315 os << raw_code_size << "\n";
316 PrintLabel(os, "pynative_code_size");
317 os << pynative_code_size << "\n";
318 PrintLabel(os, "graph_mode_code_size");
319 os << graph_mode_code_size << "\n";
320 os << "----------stop_trace_reason----------\n";
321 for (size_t i = 0; i < stop_trace_reason_array.size(); ++i) {
322 PrintLabel(os, GetStopTraceReasonDesc(static_cast<StopTraceReason>(i)));
323 os << stop_trace_reason_array[i] << "\n";
324 }
325 os << "----------inline_reason----------\n";
326 for (size_t i = 0; i < inline_reason_array.size(); ++i) {
327 PrintLabel(os, GetInlineReasonDesc(static_cast<InlineReason>(i)));
328 os << inline_reason_array[i] << "\n";
329 }
330 os << "\n\n";
331 return os.str();
332 }
333
FindMaxNameLength(const std::list<Tracebacke> & tbs) const334 int Tracebackes::FindMaxNameLength(const std::list<Tracebacke> &tbs) const {
335 constexpr auto name_length = kFive * (kTwo + kFive);
336 int max_length = 15;
337 for (const auto &tb : tbs) {
338 int len1 = SizeToInt(tb.func_name_.length());
339 int len2 = SizeToInt(tb.changed_func_.length());
340 max_length = std::max(max_length, std::max(len1, len2)) + kTwo;
341 }
342 max_length = std::min(max_length, name_length);
343 return max_length;
344 }
345
freeJitCompileResults(void * jitCompileResults)346 static void freeJitCompileResults(void *jitCompileResults) {
347 // maybe nullptr if other module use _PyEval_RequestCodeExtraIndex
348 if (jitCompileResults == nullptr) {
349 return;
350 }
351 // called after code object freed
352 JitCompileResults *c = reinterpret_cast<JitCompileResults *>(jitCompileResults);
353
354 for (auto &oc : c->codehub->GetOptTarget(OptOption::CreateOptionByPoint(c))) {
355 PyCodeObject *co = oc->GetPythonCode();
356 MS_EXCEPTION_IF_CHECK_FAIL(co == nullptr || Py_REFCNT(co) == 1, "code handler must be only one");
357 }
358 c->code = nullptr;
359 c->codehub.reset();
360
361 std::for_each(c->children_.begin(), c->children_.end(), [](CodeExtra *i) { i->parent_ = nullptr; });
362 if (c->parent_ != nullptr) {
363 auto &leaf = c->parent_->children_;
364 leaf.erase(std::remove_if(leaf.begin(), leaf.end(), [c](CodeExtra *i) { return i == c; }), leaf.end());
365 }
366 MS_LOG(DEBUG) << __FUNCTION__ << " " << c;
367 delete c;
368 }
369
allocJitCompileResults()370 static JitCompileResults *allocJitCompileResults() {
371 JitCompileResults *c = new JitCompileResults();
372 c->parent_ = nullptr;
373 c->stat = JitCompileResults::NEVER_COMPILE;
374 c->tbs = std::make_shared<Tracebackes>();
375 c->codehub = std::make_shared<OptCodeHub>();
376 c->conf = std::make_shared<GraphJitConfig>();
377 c->break_count_ = 0;
378 c->signature_ = nullptr;
379 return c;
380 }
381
getJitCompileResults(PyObject * code,bool alloc)382 JitCompileResults *getJitCompileResults(PyObject *code, bool alloc) {
383 if (PyMethod_Check(code)) {
384 code = PyMethod_GET_FUNCTION(code);
385 }
386 if (PyFunction_Check(code)) {
387 code = PyFunction_GET_CODE(code);
388 }
389 if (!PyCode_Check(code)) {
390 return nullptr;
391 }
392 ensureInitialize();
393 Py_ssize_t index = (Py_ssize_t)PyThread_tss_get(tss);
394 if (index == 0) {
395 index = _PyEval_RequestCodeExtraIndex(freeJitCompileResults);
396 if (index == -1) {
397 return nullptr;
398 }
399 // ensure index is not 0
400 PyThread_tss_set(tss, reinterpret_cast<void *>(index + 1));
401 } else {
402 index = index - 1;
403 }
404
405 JitCompileResults *c = nullptr;
406 if (!_PyCode_GetExtra(code, index, reinterpret_cast<void **>(&c))) {
407 if (c != nullptr) {
408 return c;
409 }
410 if (!alloc) {
411 return nullptr;
412 }
413 c = allocJitCompileResults();
414 if (c == nullptr) {
415 return nullptr;
416 }
417 if (!_PyCode_SetExtra(code, index, c)) {
418 MS_LOG(DEBUG) << "allocJitCompileResults " << c << " for " << std::string(py::str(code));
419 return c;
420 }
421 freeJitCompileResults(c);
422 }
423 PyErr_Clear();
424 return nullptr;
425 }
426
RebuildFrame(PyThreadState * tstate,PyCodeObject * co,const PyFrameObject * f)427 static PyFrameObject *RebuildFrame(PyThreadState *tstate, PyCodeObject *co, const PyFrameObject *f) {
428 int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
429 MS_ASSERT(co != nullptr && argc == co->co_argcount + co->co_kwonlyargcount);
430 MS_ASSERT((f->f_code->co_flags & CO_VARARGS) == (co->co_flags & CO_VARARGS));
431 MS_ASSERT((f->f_code->co_flags & CO_VARKEYWORDS) == (co->co_flags & CO_VARKEYWORDS));
432 argc += (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARARGS) ? 1 : 0;
433 argc += (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARKEYWORDS) ? 1 : 0;
434
435 PyFrameObject *frame = PyFrame_New(tstate, co, f->f_globals, NULL);
436 // copy arguments
437 for (int i = 0; i < argc; i++) {
438 Py_XINCREF(f->f_localsplus[i]);
439 frame->f_localsplus[i] = f->f_localsplus[i];
440 }
441 // restore arguments from cell
442 std::vector<PyObject *> cells_content(f->f_code->co_nlocals, nullptr);
443 for (int i = 0; f->f_code->co_cell2arg != NULL && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
444 Py_ssize_t argi = f->f_code->co_cell2arg[i];
445 if (argi != CO_CELL_NOT_AN_ARG) {
446 PyObject *cell = f->f_localsplus[f->f_code->co_nlocals + i];
447 cells_content[argi] = PyCell_GET(cell);
448 }
449 }
450 // new cell
451 for (int i = 0; i < PyTuple_GET_SIZE(co->co_cellvars); ++i) {
452 PyObject *cell;
453 if (co->co_cell2arg != NULL && co->co_cell2arg[i] != CO_CELL_NOT_AN_ARG) {
454 Py_ssize_t argi = co->co_cell2arg[i];
455 MS_EXCEPTION_IF_CHECK_FAIL(cells_content[argi], "Unbound local exception");
456 cell = PyCell_New(cells_content[argi]);
457 } else {
458 cell = PyCell_New(NULL);
459 }
460 frame->f_localsplus[co->co_nlocals + i] = cell;
461 }
462
463 // copy closure
464 for (int i = 0; i < PyTuple_GET_SIZE(co->co_freevars); ++i) {
465 int a = f->f_code->co_nlocals + PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
466 int b = co->co_nlocals + PyTuple_GET_SIZE(co->co_cellvars) + i;
467 auto o = f->f_localsplus[a];
468 Py_XINCREF(o);
469 frame->f_localsplus[b] = o;
470 }
471 return frame;
472 }
473
GetClosure(const PyFrameObject * f)474 static PyObject *GetClosure(const PyFrameObject *f) {
475 int nfrees = PyTuple_GET_SIZE(f->f_code->co_freevars);
476 if (nfrees == 0) {
477 return nullptr;
478 }
479 PyObject *closure = PyTuple_New(nfrees);
480 int idx = f->f_code->co_nlocals + PyTuple_GET_SIZE(f->f_code->co_cellvars);
481 for (int i = 0; i < nfrees; ++i) {
482 PyObject *o = f->f_localsplus[idx + i];
483 Py_INCREF(o);
484 PyTuple_SET_ITEM(closure, i, o);
485 }
486 return closure;
487 }
488
PrepareCallCompiledCallable(PyThreadState * tstate,const PyFrameObject * f,const JitCompileResults * c)489 static PyFrameObject *PrepareCallCompiledCallable(PyThreadState *tstate, const PyFrameObject *f,
490 const JitCompileResults *c) {
491 return RebuildFrame(tstate, c->code->GetPythonCode(), f);
492 }
493
GuardForFrame(const PyFrameObject * frame,const OptCodePtr & oc,const GraphJitConfig & conf)494 static void GuardForFrame(const PyFrameObject *frame, const OptCodePtr &oc, const GraphJitConfig &conf) {
495 const char *code_name = PyUnicode_AsUTF8(frame->f_code->co_name);
496 AddConfigToGuard(conf, oc->GetGuard());
497 AddGuardForParam(frame, oc->GetGuard(), conf.GetBoolConfig(GraphJitConfig::kGuardDetachObject));
498 AddGradFlagForParam(pynative::PyNativeExecutor::GetInstance()->grad_flag(), oc->GetGuard(),
499 conf.GetBoolConfig(GraphJitConfig::kGuardDetachObject));
500 if (conf.GetBoolConfig(GraphJitConfig::kPrintGuard)) {
501 GRAPH_JIT_LOG_F("Guard on %s by %s!\n", code_name, oc->GetGuard()->GetDescript().c_str());
502 return;
503 }
504 if (IS_OUTPUT_ON(mindspore::kDebug)) {
505 // It tooks too much time in Guard's GetDescript function when trace depth is too large.
506 MS_LOG(DEBUG) << "Guard on " << code_name << " by " << oc->GetGuard()->GetDescript() << "!" << std::endl;
507 }
508 }
509
ValidateCompiledResults(const JitCompileResults * c)510 static void ValidateCompiledResults(const JitCompileResults *c) {
511 if (c->stat != JitCompileResults::GRAPH_CALLABLE) {
512 return;
513 }
514 bool valid_res;
515 if (c->code->GetNativeFunc()) {
516 valid_res = true;
517 } else {
518 valid_res = c->code->GetPythonCode() != nullptr;
519 }
520 MS_EXCEPTION_IF_CHECK_FAIL(valid_res, "check compiled result");
521 }
522
MarkBreak(Graph * g)523 static void MarkBreak(Graph *g) {
524 TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
525 int break_bci = g->GetStopTraceBci();
526 if (break_bci == -1) {
527 return;
528 }
529 PyCodeObject *code;
530 const auto &nodes = g->GetTracedNodes();
531 if (nodes.empty()) {
532 code = g->GetCodeObj();
533 } else {
534 auto iter = std::find_if(nodes.begin(), nodes.end(), [&break_bci](ValueNode *i) { return i->bci() >= break_bci; });
535 iter -= iter == nodes.end();
536 for (code = (*iter)->GetGraph()->GetCodeObj(); code == nullptr && iter != nodes.begin(); --iter) {
537 code = (*iter)->GetGraph()->GetCodeObj();
538 }
539 }
540 MS_EXCEPTION_IF_NULL(code);
541 auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(code), false);
542 if (jcr != nullptr) {
543 jcr->break_count_++;
544 }
545 }
546
GetAllArgs(JitCompileResults * jcr)547 std::vector<py::object> GetAllArgs(JitCompileResults *jcr) {
548 auto all_args = PackArgs(jcr->origin_frame_);
549 constexpr size_t arg_index = 0;
550 constexpr size_t vargs_index = 1;
551 constexpr size_t kwargs_index = 2;
552 auto args = py::cast<py::list>(all_args[arg_index]);
553 if (all_args[vargs_index].ptr() != nullptr) {
554 PyList_Append(args.ptr(), all_args[vargs_index].ptr()); // args + vargs
555 }
556 if (all_args[kwargs_index].ptr() != nullptr) {
557 PyList_Append(args.ptr(), all_args[kwargs_index].ptr()); // args + kwargs
558 }
559 return args.cast<std::vector<py::object>>();
560 }
561
562 static void GraphCapture(JitCompileResults *jcr);
HandleBreakAtLoop(JitCompileResults * jcr,const GraphBuilderPtr & g)563 static auto HandleBreakAtLoop(JitCompileResults *jcr, const GraphBuilderPtr &g) {
564 // one stage need adapter
565 if (g->GetGraph()->IsBreakAtLoopAfterUnrolling()) {
566 if (jcr->conf->GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
567 GRAPH_JIT_LOG_F("===> graph break after loop unrolling\n%s\n", g->GetGraph()->ToString(1).c_str());
568 }
569 // reset guard
570 jcr->code->SetGuard(std::make_shared<OptGuard>());
571 AddConfigToGuard(*jcr->conf, jcr->code->GetGuard());
572 // disable loop unroll
573 jcr->conf->SetBool<GraphJitConfig::kLoopUnrolling>(Py_False);
574 // restart captured
575 GraphCapture(jcr);
576 // reset config
577 jcr->conf->SetBool<GraphJitConfig::kLoopUnrolling>(Py_True);
578 return true;
579 }
580 return false;
581 }
582
HandleUnsupportedSyntax(JitCompileResults * jcr,const GraphBuilderPtr & g)583 static auto HandleUnsupportedSyntax(JitCompileResults *jcr, const GraphBuilderPtr &g) {
584 int break_bci = g->GetGraph()->GetStopTraceBci();
585 if (break_bci == -1) {
586 return false;
587 }
588 int break_op = g->GetGraph()->GetCFG()->instr_pool()[break_bci]->op();
589 bool unsupported = break_op == WITH_CLEANUP_START || break_op == WITH_CLEANUP_FINISH || break_op == END_FINALLY;
590 if (g->StackSize() > 0 || unsupported) {
591 // something happened in with syntax
592 jcr->code->SetGuard(std::make_shared<OptGuard>());
593 AddConfigToGuard(*jcr->conf, jcr->code->GetGuard());
594 jcr->conf->SetBool<GraphJitConfig::kSkipException>(Py_True);
595 GraphCapture(jcr);
596 g->GetTryBlockStacks().clear();
597 jcr->conf->SetBool<GraphJitConfig::kSkipException>(Py_False);
598 return true;
599 }
600 return false;
601 }
602
TraceRun(JitCompileResults * jcr)603 static auto TraceRun(JitCompileResults *jcr) {
604 TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
605
606 GraphJitConfig &conf = *jcr->conf;
607 GraphBuilderPtr g = GraphBuilder::Creator(jcr->origin_frame_, conf.GetBoolConfig(GraphJitConfig::kTraceFlag));
608
609 if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
610 auto mg = std::dynamic_pointer_cast<MindGraphBuilder>(g);
611 (void)mg->FGBuilder()->AddTopGraphInputs(PackArgs(jcr->origin_frame_));
612 }
613 (void)g->TraceRun();
614 return g;
615 }
616
Inline(JitCompileResults * jcr,const GraphBuilderPtr & g)617 static void Inline(JitCompileResults *jcr, const GraphBuilderPtr &g) {
618 TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
619 GraphJitConfig &conf = *jcr->conf;
620 // One stage should skip inline process.
621 if (!conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
622 BytecodeInliner inliner(g->GetGraph(), py::cast<py::dict>(jcr->origin_frame_->f_globals));
623 inliner.Run();
624 }
625 }
626
Analyze(const GraphBuilderPtr & g)627 static auto Analyze(const GraphBuilderPtr &g) {
628 TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
629
630 auto analyzer = GraphAnalyzer::Creator(g);
631 analyzer->Analyze();
632 return analyzer;
633 }
634
635 // preprocess before compile, split bytecode to sub-function
636 // return whether the code should be modified
GraphCapture(JitCompileResults * jcr)637 static void GraphCapture(JitCompileResults *jcr) {
638 TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
639 MS_EXCEPTION_IF_NULL(jcr->code);
640
641 GraphJitConfig &conf = *jcr->conf;
642 AObject::SetTraceFlag(conf.GetBoolConfig(GraphJitConfig::kTraceFlag));
643 GraphBuilderPtr g = TraceRun(jcr);
644 if (HandleUnsupportedSyntax(jcr, g)) {
645 return;
646 }
647 if (g->GetGraph()->IsBreakAtLoop() && !g->GetGraph()->RestoreLoopStatus()) {
648 jcr->stat = JitCompileResults::NEVER_COMPILE;
649 return;
650 }
651 Inline(jcr, g);
652 GraphAnalyzerPtr analyzer = Analyze(g);
653 if (HandleBreakAtLoop(jcr, g)) {
654 return;
655 }
656 MarkBreak(g->GetGraph());
657
658 // dump DFG
659 if (conf.GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
660 g->DumpDFG();
661 }
662
663 py::object new_code = MakeCodeFromCodeGen(g, analyzer, jcr->origin_frame_->f_globals);
664 if (new_code.ptr() != nullptr) {
665 jcr->code->SetPythonCode(new_code);
666 jcr->stat = JitCompileResults::GRAPH_CALLABLE;
667 }
668
669 if (conf.GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
670 if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
671 const auto &debug_str = analyzer->GetCaptureInfo().ToString();
672 PY_PRINT_F("*** Dump One Stage ByteCode Collection After CodeGen *** \n%s", debug_str.c_str());
673 }
674 Utils::DisFuncObject(new_code.ptr());
675 GRAPH_JIT_LOG_F("\n\n");
676 }
677
678 // collect stop trace reason to traceback
679 jcr->tbs->PushStopTraceRes(g->GetGraph()->GetCodeName(), g->GetGraph()->GetStopTraceReason());
680
681 bool captured = !analyzer->NeedInterpret() && !conf.GetBoolConfig(GraphJitConfig::kInterpretCapturedCode);
682 if (captured && !jcr->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
683 jcr->stat = JitCompileResults::GRAPH_CAPTURED;
684 }
685 }
686
CollectTraceBack(JitCompileResults * c,PyCodeObject * code,bool is_graph_mode)687 static void CollectTraceBack(JitCompileResults *c, PyCodeObject *code, bool is_graph_mode) {
688 if (code == nullptr) {
689 code = c->origin_frame_->f_code;
690 }
691 std::string name = Utils::GetPyName(c->origin_frame_->f_code->co_name);
692 std::string changed_name = Utils::GetPyName(code->co_name);
693 int code_size = SizeToInt((PyBytes_GET_SIZE(code->co_code)) / sizeof(_Py_CODEUNIT));
694 c->tbs->PushTbs({name, changed_name, code_size, is_graph_mode});
695 }
696
GetFuncGraphPhase(const PyFrameObject & frame,const OptCodePtr & oc)697 std::string GetFuncGraphPhase(const PyFrameObject &frame, const OptCodePtr &oc) {
698 std::string phase = py::cast<std::string>(frame.f_code->co_filename) + "_" +
699 std::to_string(frame.f_code->co_firstlineno) + "_" + py::cast<std::string>(frame.f_code->co_name);
700 if (oc != nullptr) {
701 phase += std::to_string(oc->GetGuard()->Info().Id());
702 } else {
703 for (int i = 0; i < frame.f_code->co_argcount; i++) {
704 PyObject *obj = PyTuple_GET_ITEM(frame.f_code->co_varnames, i);
705 py::object para = py::cast<py::object>(PyDict_GetItem(frame.f_locals, obj));
706 auto node = GraphUtils::ConvertPythonObjectToAnfNode(para);
707 phase += "_" + node->abstract()->ToString();
708 }
709 }
710 phase += ".pi_jit";
711 return phase;
712 }
713
AddConfigToGuard(const GraphJitConfig & c,OptGuardPtr guard)714 void AddConfigToGuard(const GraphJitConfig &c, OptGuardPtr guard) {
715 std::map<std::string, bool> bool_cfg;
716 std::map<std::string, int> int_cfg;
717 bool_cfg[kSpecializeScalar] = c.GetBoolConfig(GraphJitConfig::kGuardSpecializeScalar);
718 bool_cfg[kSpecializeContainer] = c.GetBoolConfig(GraphJitConfig::kGuardSpecializeContainer);
719 bool_cfg[kSpecializeTensor] = c.GetBoolConfig(GraphJitConfig::kGuardSpecializeTensor);
720 int_cfg[kGuardRelaxCnt] = c.getIntConfig(GraphJitConfig::kGuardRelaxCount);
721 guard->UpdateConfig(bool_cfg, int_cfg);
722 }
723
AddGuardForParam(const PyFrameObject * f,OptGuardPtr guard,bool detach)724 void AddGuardForParam(const PyFrameObject *f, OptGuardPtr guard, bool detach) {
725 int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
726 PyObject *vargs = NULL;
727 PyObject *kwargs = NULL;
728 if (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARARGS) {
729 vargs = f->f_localsplus[argc];
730 }
731 if (static_cast<unsigned int>(f->f_code->co_flags) & CO_VARKEYWORDS) {
732 kwargs = f->f_localsplus[argc + (vargs ? 1 : 0)];
733 }
734 for (int i = 0; i < argc; ++i) {
735 if (f->f_localsplus[i] == nullptr) {
736 continue;
737 }
738 RootTracePtr ptr = std::make_shared<RootTrace>(f->f_localsplus[i], mindspore::pijit::TraceType::Param, i);
739 guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
740 if (detach) {
741 ptr->Detach();
742 }
743 }
744 if (vargs != NULL) {
745 RootTracePtr ptr = std::make_shared<RootTrace>(f->f_localsplus[argc], mindspore::pijit::TraceType::Param, argc);
746 guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
747 if (detach) {
748 ptr->Detach();
749 }
750 }
751 if (kwargs != NULL) {
752 RootTracePtr ptr = std::make_shared<RootTrace>(f->f_localsplus[argc + (vargs ? 1 : 0)],
753 mindspore::pijit::TraceType::Param, argc + (vargs ? 1 : 0));
754 guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
755 if (detach) {
756 ptr->Detach();
757 }
758 }
759 for (int i = 0; f->f_code->co_cell2arg && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
760 Py_ssize_t arg = f->f_code->co_cell2arg[i];
761 if (arg != CO_CELL_NOT_AN_ARG) {
762 auto cell = f->f_localsplus[f->f_code->co_nlocals + i];
763 RootTracePtr ptr = std::make_shared<RootTrace>(PyCell_GET(cell), mindspore::pijit::TraceType::Deref, i);
764 guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
765 if (detach) {
766 ptr->Detach();
767 }
768 }
769 }
770 for (int i = 0; i < PyTuple_GET_SIZE(f->f_code->co_freevars); ++i) {
771 Py_ssize_t arg = PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
772 auto cell = f->f_localsplus[f->f_code->co_nlocals + arg];
773 RootTracePtr ptr = std::make_shared<RootTrace>(PyCell_GET(cell), mindspore::pijit::TraceType::Deref, arg);
774 guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GDeduce, false);
775 if (detach) {
776 ptr->Detach();
777 }
778 }
779 }
780
AddGuardForGlobals(const PyFrameObject * f,OptGuardPtr guard,bool detach)781 void AddGuardForGlobals(const PyFrameObject *f, OptGuardPtr guard, bool detach) {
782 PyCodeObject *co = f->f_code;
783 const _Py_CODEUNIT *bytecodes = reinterpret_cast<_Py_CODEUNIT *>(PyBytes_AsString(co->co_code));
784 int size = (PyBytes_GET_SIZE(co->co_code)) / SizeToInt(sizeof(_Py_CODEUNIT));
785 unsigned int exarg = 0;
786 for (int bci = 0; bci < size; ++bci) {
787 int opcode = _Py_OPCODE(bytecodes[bci]);
788 int oparg = (exarg << 8) | _Py_OPARG(bytecodes[bci]);
789 exarg = static_cast<unsigned>((opcode == EXTENDED_ARG) ? oparg : 0);
790 if (opcode != LOAD_GLOBAL) {
791 continue;
792 }
793 PyObject *k = PyTuple_GET_ITEM(co->co_names, oparg);
794 PyObject *v = PyDict_GetItem(f->f_globals, k);
795 std::string key = PyUnicode_AsUTF8(k);
796 if (v == nullptr) {
797 PyErr_Clear();
798 continue;
799 }
800
801 TracePtr ptr = std::make_shared<RootTrace>(v, TraceType::Global, -1, key);
802
803 AObject::Type t = AObject::GetPyType(v);
804 GuardLevel level = GuardLevel::GType;
805 if (t == AObject::kTypeCell || t == AObject::kTypePrimitive || t == AObject::kTypeMSDType) {
806 level = GuardLevel::GDeduce;
807 } else if (t == AObject::kTypeFunction) {
808 ptr = std::make_shared<OpTrace>(PyFunction_GET_CODE(v), LOAD_ATTR, -1, std::vector<TracePtr>({ptr}), "__code__");
809 level = GuardLevel::GId;
810 } else if (t == AObject::kTypeTuple || t == AObject::kTypeList || t == AObject::kTypeDict) {
811 /**
812 * graph treat tuple, list, dict as constant variable.
813 * add container guard and check it, check contains Tensor
814 */
815 continue;
816 }
817
818 guard->GuardOn(ptr, level, false);
819 if (detach) {
820 ptr->Detach();
821 }
822 }
823 }
824
AddGradFlagForParam(bool grad_flag,OptGuardPtr guard,bool detach)825 static void AddGradFlagForParam(bool grad_flag, OptGuardPtr guard, bool detach) {
826 CustomizedTracePtr ptr = std::make_shared<CustomizedTrace>(
827 grad_flag ? Py_True : Py_False,
828 [](PTraceContext context) -> PyObject * {
829 static pynative::PyNativeExecutor *pynative_exec = nullptr;
830 if (pynative_exec == nullptr) {
831 pynative_exec = pynative::PyNativeExecutor::GetInstance().get();
832 }
833 PyObject *ret = pynative_exec->grad_flag() ? Py_True : Py_False;
834 Py_INCREF(ret);
835 return ret;
836 },
837 [grad_flag](bool simple) -> std::string {
838 if (simple) {
839 return std::string("g\\") + std::to_string(grad_flag ? 1 : 0);
840 }
841 return std::string("{PyNativeExecutor::GetInstance()->grad_flag == ") + std::to_string(grad_flag) +
842 std::string("}(type:") + std::to_string(TraceType::Customized) + std::string(")");
843 });
844 guard->GuardOn(ptr, mindspore::pijit::GuardLevel::GEqual, true);
845 if (detach) {
846 ptr->Detach();
847 }
848 }
849
CallGraphCompiler(JitCompileResults * jcr,PyFunctionObject * func,const PyFrameObject * frame)850 static std::string CallGraphCompiler(JitCompileResults *jcr, PyFunctionObject *func, const PyFrameObject *frame) {
851 std::string phase = GetFuncGraphPhase(*frame, jcr->code);
852 MS_LOG(DEBUG) << "Phase is " << phase << "!";
853 CallableGraph callable = mindspore::pijit::Compiler::Compile(*func, *frame, phase);
854 if (callable == nullptr) {
855 jcr->stat = JitCompileResults::NEVER_COMPILE;
856 return std::string();
857 }
858
859 ReleaseFunc rFunc = nullptr;
860 if (jcr->conf->GetBoolConfig(GraphJitConfig::kAutoCleanCache)) {
861 rFunc = [phase]() {
862 auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
863 if (graph_executor->HasCompiled(phase)) {
864 py::str p(phase);
865 py::set s;
866 s.add(phase);
867 py::object o = py::none();
868 graph_executor->DelNetRes(o, s);
869 MS_LOG(DEBUG) << "To release " << phase;
870 }
871 };
872 }
873 jcr->code->SetNativeFunc(phase, callable, rFunc);
874 jcr->stat = JitCompileResults::GRAPH_CALLABLE;
875 return phase;
876 }
877
GraphToString(FuncGraphPtr graph)878 std::string GraphToString(FuncGraphPtr graph) {
879 std::ostringstream graph_buffer;
880 DumpIR(graph_buffer, graph);
881 auto ret = graph_buffer.str();
882 std::regex regAddress("(0x)([0-9a-f]+)");
883 ret = std::regex_replace(ret, regAddress, "");
884 std::regex regFunc(std::string("(") + graph->ToString() + std::string(")"));
885 ret = std::regex_replace(ret, regFunc, "");
886 std::regex regVar("(\\%[0-9]+\\()([A-Za-z0-9_]+)(\\))");
887 ret = std::regex_replace(ret, regVar, "$1$3");
888 std::regex regNode("CNode_([0-9]+)");
889 ret = std::regex_replace(ret, regNode, "");
890 return ret;
891 }
892
GraphCompile(JitCompileResults * jcr,const PyFrameObject * frame)893 static void GraphCompile(JitCompileResults *jcr, const PyFrameObject *frame) {
894 TimeRecorder recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
895 GuardForFrame(frame, jcr->code, *jcr->conf);
896 AddGuardForGlobals(frame, jcr->code->GetGuard(), jcr->conf->GetBoolConfig(GraphJitConfig::kGuardDetachObject));
897
898 bool enable_dynamicshape = jcr->conf->GetBoolConfig(GraphJitConfig::kEnableDynamicShape);
899 OptStrategy::MakeGCStrategy(jcr->codehub, jcr->conf->getIntConfig(GraphJitConfig::kLimitGraphSize),
900 jcr->conf->getIntConfig(GraphJitConfig::kLimitGraphCount), enable_dynamicshape,
901 jcr->code);
902 // restore function object from frame
903 PyObject *new_func = PyFunction_New(reinterpret_cast<PyObject *>(frame->f_code), frame->f_globals);
904 Py_XSETREF(PyFunction_GET_CLOSURE(new_func), GetClosure(frame));
905 PyFunctionObject *func = reinterpret_cast<PyFunctionObject *>(new_func);
906 PyFrameObject *f = const_cast<PyFrameObject *>(frame);
907 std::vector<PyObject *> backup;
908 if (enable_dynamicshape) {
909 backup = jcr->code->GetGuard()->ApplyDynamicShape(f);
910 PyFrame_FastToLocals(f);
911 }
912 RunEnvironment runEnvironment;
913 runEnvironment.fetchAndSetRunEnv(jcr);
914 std::string phase = CallGraphCompiler(jcr, func, frame);
915 runEnvironment.resumePreviousRunEnv();
916 if (enable_dynamicshape) {
917 jcr->code->GetGuard()->RevertDynamicShape(f, backup);
918 PyFrame_FastToLocals(f);
919 }
920
921 Py_DECREF(new_func);
922
923 if (jcr->conf->GetBoolConfig(GraphJitConfig::kReuseGraph)) {
924 auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
925 FuncGraphPtr ms_func_graph = graph_executor->GetFuncGraph(phase);
926 std::string key = GraphToString(ms_func_graph);
927 auto pcode = OptCodeHub::Filter(key, [jcr, graph_executor, ms_func_graph](OptCodePtr code) {
928 FuncGraphPtr func_graph = graph_executor->GetFuncGraph(code->GetPhase());
929 FuncGraphPairMapEquiv equiv_graph;
930 NodeMapEquiv equiv_node;
931 if (func_graph != nullptr && Isomorphic(ms_func_graph, func_graph, &equiv_graph, &equiv_node)) {
932 return true;
933 } else {
934 return false;
935 }
936 });
937 if (pcode != nullptr) {
938 if (jcr->conf->GetBoolConfig(GraphJitConfig::kPrintReuseGraph)) {
939 std::ostringstream graph_buffer;
940 DumpIR(graph_buffer, ms_func_graph);
941 std::cout << "Graph Duplicated:" << std::endl;
942 std::cout << " Graph:" << graph_buffer.str() << std::endl;
943 std::cout << " Bytecode:" << std::endl;
944 Utils::DisFuncObject(reinterpret_cast<PyObject *>(frame->f_code));
945 }
946 // find duplicate graph and reuse it
947 pcode->Copy(jcr->code);
948 } else {
949 // current graph is a new one and register it
950 OptCodeHub::Register(key, jcr->code);
951 }
952 }
953 }
954
955 extern bool UnsupportedCodeTypeCheck(PyCodeObject *co);
/**
 * Capture and compile the code object of c->origin_frame_.
 * Flow:
 *  1. Return false for code types rejected by UnsupportedCodeTypeCheck.
 *  2. Add a new optimization target (c->code) to the code hub and configure its guard from c->conf.
 *  3. If c->stat is GRAPH_CANDIDATE: run graph capture; when it reaches GRAPH_CAPTURED, a frame for the
 *     captured code (PrepareCallCompiledCallable) replaces the original frame for the following steps.
 *  4. If c->stat is GRAPH_CAPTURED: compile that frame into a graph (GraphCompile).
 *  5. Optimize the guard, collect the traceback and optionally print diagnostics.
 * @return true when compilation ended in GRAPH_CALLABLE; otherwise c->stat is set to NEVER_COMPILE and false is returned.
 */
static bool JitCompile(PyThreadState *tstate, JitCompileResults *c) {
  if (UnsupportedCodeTypeCheck(c->origin_frame_->f_code)) {
    return false;
  }

  ShapeContext sc(c->origin_frame_, c->signature_);
  std::string code_str = py::str(reinterpret_cast<PyObject *>(c->origin_frame_->f_code));
  MS_LOG(DEBUG) << "---start compile " << code_str << "---";

  // create a new optimization target (with its own guard) for this compilation
  c->code = c->codehub->AddOptTarget(OptOption::CreateOptionByPoint(c));
  AddConfigToGuard(*c->conf, c->code->GetGuard());

  // frame to compile; starts as the original frame (borrowed) and may be replaced after capture
  py::object frame = py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(c->origin_frame_));
  if (c->stat == JitCompileResults::GRAPH_CANDIDATE) {
    TimeRecorder time_recorder("kTimeCompileCapture", kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureProcess,
                                       "PIJitCapture");
    // mark as building so that re-entering this code while capturing is detected (see CodeHook)
    c->stat = JitCompileResults::GRAPH_BUILDING;
    auto aobject_resource = AObject::MakeResource();
    GraphCapture(c);
    // NOTE(review): ApplySignature is called here and again after this branch; presumably idempotent — confirm
    sc.ApplySignature();
    if (c->stat == JitCompileResults::GRAPH_CAPTURED) {
      // compile the frame of the captured code instead of the original frame (new reference, owned by `frame`)
      PyFrameObject *f = PrepareCallCompiledCallable(tstate, c->origin_frame_, c);
      frame = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(f));
    }
    if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
      PyFrameObject *f = reinterpret_cast<PyFrameObject *>(frame.ptr());
      GuardForFrame(f, c->code, *c->conf);
      AddGuardForGlobals(f, c->code->GetGuard(), c->conf->GetBoolConfig(GraphJitConfig::kGuardDetachObject));
    }
    aobject_resource.Release();
  }
  sc.ApplySignature();

  if (c->stat == JitCompileResults::GRAPH_CAPTURED) {
    TimeRecorder time_recorder("kTimeCompileGraph", kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureCompile,
                                       "PIJitCompile");
    c->stat = JitCompileResults::GRAPH_BUILDING;
    PyFrameObject *f = reinterpret_cast<PyFrameObject *>(frame.ptr());
    PyFrame_FastToLocals(f);
    // on success this sets c->stat to GRAPH_CALLABLE (via CallGraphCompiler)
    GraphCompile(c, f);
  }

  // Optimize() may return a replacement guard; keep the current guard when it returns null
  auto guard = c->code->GetGuard()->Optimize();
  if (guard != nullptr) {
    c->code->SetGuard(guard);
  }

  CollectTraceBack(c, c->code->GetPythonCode(), c->code->GetNativeFunc() != nullptr);

  if (c->conf->GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
    GRAPH_JIT_LOG_F("%s\n", c->tbs->Dump().c_str());

    GRAPH_JIT_LOG_F("generated guard at %s\n", code_str.c_str());
    GRAPH_JIT_LOG_F("%s\n", c->code->GetGuard()->ToString().c_str());
  }
  // any state other than GRAPH_CALLABLE means compilation did not succeed
  if (c->stat != JitCompileResults::GRAPH_CALLABLE) {
    c->stat = JitCompileResults::NEVER_COMPILE;
    return false;
  }
  return true;
}
1020
PackArgs(const PyFrameObject * frame)1021 std::vector<py::object> PackArgs(const PyFrameObject *frame) {
1022 const Py_ssize_t argc = frame->f_code->co_argcount + frame->f_code->co_kwonlyargcount;
1023 bool has_varg = static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARARGS;
1024 py::list args(argc);
1025 py::object vargs;
1026 py::object kwvargs;
1027 for (Py_ssize_t i = 0; i < argc; ++i) {
1028 args[i] = py::reinterpret_borrow<py::object>(frame->f_localsplus[i]);
1029 }
1030 if (has_varg) {
1031 vargs = py::reinterpret_borrow<py::object>(frame->f_localsplus[argc]);
1032 }
1033 if (static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARKEYWORDS) {
1034 kwvargs = py::reinterpret_borrow<py::object>(frame->f_localsplus[argc + has_varg]);
1035 }
1036
1037 const Py_ssize_t ncells = PyTuple_GET_SIZE(frame->f_code->co_cellvars);
1038 for (Py_ssize_t i = 0; frame->f_code->co_cell2arg && i < ncells; ++i) {
1039 Py_ssize_t argi = frame->f_code->co_cell2arg[i];
1040 if (argi != CO_CELL_NOT_AN_ARG) {
1041 PyObject *cell = frame->f_localsplus[frame->f_code->co_nlocals + i];
1042 args[argi] = py::reinterpret_borrow<py::object>(PyCell_GET(cell));
1043 }
1044 }
1045 return {args, vargs, kwvargs};
1046 }
1047
ResultMutable(py::object obj)1048 static py::object ResultMutable(py::object obj) {
1049 py::object mutable_func = Utils::GetModuleAttr("mindspore.common", "mutable", false, true);
1050 if (py::isinstance<py::tuple>(obj)) {
1051 auto tuple_obj = obj.cast<py::tuple>();
1052 py::list mutable_list(tuple_obj);
1053 for (size_t i = 0; i < tuple_obj.size(); i++) {
1054 try {
1055 auto mutable_element = mutable_func(tuple_obj[i]);
1056 mutable_list[i] = mutable_element;
1057 } catch (py::error_already_set &e) {
1058 if (PyErr_Occurred()) {
1059 PyErr_Clear();
1060 }
1061 continue;
1062 }
1063 }
1064 auto mutable_tuple = py::tuple(mutable_list);
1065 return mutable_tuple;
1066 } else {
1067 try {
1068 auto mutable_obj = mutable_func(obj);
1069 return mutable_obj;
1070 } catch (py::error_already_set &e) {
1071 if (PyErr_Occurred()) {
1072 PyErr_Clear();
1073 }
1074 }
1075 }
1076 return obj;
1077 }
1078
/**
 * Run the compiled graph (native function) of c->code with the packed call arguments.
 * When kPerfStatistics is enabled and fewer than kPerfStatisticsCount graph runs have been recorded,
 * the call goes through CallFunction with the kPerfGraph perf record, waiting for the op executor
 * before and after (presumably so that the statistics include asynchronous device work).
 * @param args positional arguments as a tuple.
 * @param kwvargs the **kwargs dict, or a null object.
 * @return the result passed through ResultMutable, or a null object with a Python error set on failure.
 */
static py::object CallGraph(const JitCompileResults *c, const py::object &args, const py::object &kwvargs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureRunGraph,
                                     "PIJitRunGraph");

  StaticAnalysisExceptionCleaner exception_cleaner;
  RunEnvironment runEnvironment;
  runEnvironment.fetchAndSetRunEnv(c);
  PyObject *py_args = args.ptr();
  PyObject *py_kwvargs = kwvargs.ptr();
  PyObject *res;
  if (c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics) &&
      c->code->GetPerf(OptPerf::PerfKind::kPerfGraph)->GetStatistics()->GetTotalCount() <
        c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsCount)) {
    std::function<PyObject *(PyObject * py_args, PyObject * py_kwvargs)> func = [c](PyObject *py_args,
                                                                                     PyObject *py_kwvargs) {
      auto ret = c->code->GetNativeFunc()(py_args, py_kwvargs);
      runtime::OpExecutor::GetInstance().WaitAll();
      return ret;
    };
    runtime::OpExecutor::GetInstance().WaitAll();
    res = CallFunction(c->code->GetPerf(OptPerf::PerfKind::kPerfGraph), func, py_args, py_kwvargs);
  } else {
    res = c->code->GetNativeFunc()(py_args, py_kwvargs);
  }
  runEnvironment.resumePreviousRunEnv();
  if (res == NULL) {
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_RuntimeError, "compiled graph execute failed");
    }
    // Do not post-process a failed call: passing a null object to `mutable` raises a C++ cast_error
    // on top of the pending Python error. Return null and let the Python error propagate.
    return py::object();
  }
  return ResultMutable(py::reinterpret_steal<py::object>(res));
}
1110
CallCompiledCallable(PyThreadState * tstate,PyFrameObject * f,const JitCompileResults * c)1111 static py::object CallCompiledCallable(PyThreadState *tstate, PyFrameObject *f, const JitCompileResults *c) {
1112 PyFrameObject *new_f;
1113 PyObject *res;
1114 int bci;
1115
1116 if (c->code->GetPythonCode() != nullptr) {
1117 new_f = PrepareCallCompiledCallable(tstate, f, c);
1118 } else {
1119 Py_INCREF(f);
1120 new_f = f;
1121 }
1122
1123 if (c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics) &&
1124 c->code->GetPerf(OptPerf::PerfKind::kPerfPyNative)->GetStatistics()->GetTotalCount() <
1125 c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsCount)) {
1126 std::function<PyObject *(PyThreadState * tstate, PyFrameObject * f, int exc)> func = [](PyThreadState *tstate,
1127 PyFrameObject *f, int exc) {
1128 auto ret = _PyEval_EvalFrameDefault(tstate, f, exc);
1129 runtime::OpExecutor::GetInstance().WaitAll();
1130 return ret;
1131 };
1132 runtime::OpExecutor::GetInstance().WaitAll();
1133 // use function pointer not std::function
1134 res = CallFunction(c->code->GetPerf(OptPerf::PerfKind::kPerfPyNative), func, tstate, new_f, 0);
1135 } else {
1136 res = _PyEval_EvalFrameDefault(tstate, new_f, 0);
1137 }
1138
1139 bci = new_f->f_lasti;
1140 Py_DECREF(new_f);
1141
1142 if (res == NULL && !PyErr_Occurred()) {
1143 PyErr_Format(PyExc_RuntimeError, "compiled function failed with unknown error, error bci %d", bci);
1144 }
1145 return py::reinterpret_steal<py::object>(res);
1146 }
1147
CheckTensorInContainer(py::object args)1148 static bool CheckTensorInContainer(py::object args) {
1149 if (py::isinstance<py::tuple>(args)) {
1150 py::tuple t = py::cast<py::tuple>(args);
1151 for (size_t i = 0; i < t.size(); ++i) {
1152 if (CheckTensorInContainer(t[i])) {
1153 return true;
1154 }
1155 }
1156 } else if (py::isinstance<py::list>(args)) {
1157 py::list l = py::cast<py::list>(args);
1158 for (size_t i = 0; i < l.size(); ++i) {
1159 if (CheckTensorInContainer(l[i])) {
1160 return true;
1161 }
1162 }
1163 }
1164 if (IsStubTensor(args) || py::isinstance<mindspore::tensor::Tensor>(args.ptr())) {
1165 return true;
1166 } else {
1167 return false;
1168 }
1169 }
1170
1171 static bool CheckAbstract(abstract::AbstractBasePtr abs, bool incontainer);
1172
CheckContainer(abstract::AbstractBasePtr abs)1173 static bool CheckContainer(abstract::AbstractBasePtr abs) {
1174 if (abs->isa<abstract::AbstractTuple>()) {
1175 auto elems = abs->cast<abstract::AbstractTuplePtr>()->elements();
1176 for (size_t idx = 0; idx < elems.size(); ++idx) {
1177 if (!CheckAbstract(elems[idx], true)) {
1178 return false;
1179 }
1180 }
1181 }
1182 if (abs->isa<abstract::AbstractList>()) {
1183 auto elems = abs->cast<abstract::AbstractListPtr>()->elements();
1184 for (size_t idx = 0; idx < elems.size(); ++idx) {
1185 if (!CheckAbstract(elems[idx], true)) {
1186 return false;
1187 }
1188 }
1189 }
1190 if (abs->isa<abstract::AbstractSequence>()) {
1191 auto elems = abs->cast<abstract::AbstractSequencePtr>()->elements();
1192 for (size_t idx = 0; idx < elems.size(); ++idx) {
1193 if (!CheckAbstract(elems[idx], true)) {
1194 return false;
1195 }
1196 }
1197 }
1198 if (abs->isa<abstract::AbstractDictionary>()) {
1199 auto elems = abs->cast<abstract::AbstractDictionaryPtr>()->elements();
1200 for (size_t idx = 0; idx < elems.size(); ++idx) {
1201 if (!CheckAbstract(elems[idx].first, true) || !CheckAbstract(elems[idx].first, true)) {
1202 return false;
1203 }
1204 }
1205 }
1206 if (abs->isa<abstract::AbstractSlice>()) {
1207 auto slice = abs->cast<abstract::AbstractSlicePtr>();
1208 return !CheckAbstract(slice->start(), true) || !CheckAbstract(slice->stop(), true) ||
1209 !CheckAbstract(slice->step(), true);
1210 }
1211 return true;
1212 }
1213
CheckAbstract(abstract::AbstractBasePtr abs,bool incontainer)1214 static bool CheckAbstract(abstract::AbstractBasePtr abs, bool incontainer) {
1215 if (incontainer && abs->isa<abstract::AbstractAny>()) {
1216 return false;
1217 }
1218 if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>() ||
1219 abs->isa<abstract::AbstractSequence>() || abs->isa<abstract::AbstractDictionary>() ||
1220 abs->isa<abstract::AbstractSlice>()) {
1221 return CheckContainer(abs);
1222 }
1223 if (abs->isa<abstract::AbstractNone>() || abs->isa<abstract::AbstractNull>() || abs->isa<abstract::AbstractType>() ||
1224 abs->isa<abstract::AbstractFunction>() || abs->isa<abstract::AbstractAny>()) {
1225 return false;
1226 }
1227 if (abs->isa<abstract::AbstractScalar>()) {
1228 auto tp = abs->GetTypeTrack()->type_id();
1229 return tp != kMetaTypeNone && tp != kMetaTypeNull && tp != kNumberTypeBool;
1230 }
1231 return true;
1232 }
1233
CheckValidReturn(const JitCompileResults * c)1234 static bool CheckValidReturn(const JitCompileResults *c) {
1235 auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
1236 FuncGraphPtr ms_func_graph = graph_executor->GetFuncGraph(c->code->GetPhase());
1237 auto abs = ms_func_graph->output()->abstract();
1238 return CheckAbstract(abs, false);
1239 }
1240
PreferCallGraph(const JitCompileResults * c,py::object args)1241 static bool PreferCallGraph(const JitCompileResults *c, py::object args) {
1242 if (c->code->GetNativeFunc() == nullptr) {
1243 return false;
1244 }
1245 if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
1246 return true;
1247 }
1248 if (!CheckValidReturn(c)) {
1249 return false;
1250 }
1251 py::tuple t = py::cast<py::tuple>(args);
1252 for (size_t i = 0; i < t.size(); ++i) {
1253 py::object obj = t[i];
1254 if (IsMutableObj(obj)) {
1255 continue;
1256 }
1257 if ((py::isinstance<py::list>(t[i]) || py::isinstance<py::tuple>(t[i])) && CheckTensorInContainer(t[i])) {
1258 return false;
1259 }
1260 }
1261 OptStrategy::ExecKind stat = OptStrategy::ExecKind::kExecGraph;
1262 if (c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics)) {
1263 constexpr auto kStatisticsScale = 10000.0;
1264 int scale_statistics = c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsScale10000x);
1265 stat = OptStrategy::MakeExecStrategyByPerf(
1266 c->code->GetPerf(OptPerf::PerfKind::kPerfGraph), c->code->GetPerf(OptPerf::PerfKind::kPerfPyNative),
1267 c->conf->getIntConfig(GraphJitConfig::kPerfStatisticsCount), scale_statistics / kStatisticsScale);
1268 }
1269 int graph_bytecode_min = c->conf->getIntConfig(GraphJitConfig::kStaticGraphBytecodeMin);
1270 if (graph_bytecode_min > 0 && stat == OptStrategy::ExecKind::kExecGraph) {
1271 stat = OptStrategy::MakeExecStrategyByComplex(c->code->GetPythonCode(), graph_bytecode_min);
1272 }
1273 return stat == OptStrategy::ExecKind::kExecGraph;
1274 }
1275
SetExecStatus(const JitCompileResults * c,const PyFrameObject * f,bool graph_preferred)1276 static void SetExecStatus(const JitCompileResults *c, const PyFrameObject *f, bool graph_preferred) {
1277 bool enable_statistics = c->conf->GetBoolConfig(GraphJitConfig::kPerfStatistics);
1278 int graph_bytecode_min = c->conf->getIntConfig(GraphJitConfig::kStaticGraphBytecodeMin);
1279 if (enable_statistics || (graph_bytecode_min > 0)) {
1280 PyObject_SetItem(f->f_globals, reinterpret_cast<PyObject *>(f->f_code), (graph_preferred ? Py_True : Py_False));
1281 }
1282 }
1283
/**
 * Execute the compiled result selected for frame `f`.
 * Returns None without executing anything when MS_CTX_PRECOMPILE_ONLY is set. Otherwise packs the
 * frame's arguments, lets PreferCallGraph choose between the compiled graph (CallGraph) and the bytecode
 * path (CallCompiledCallable), and runs it. With kCompileWithTry, a C++ exception from the graph call
 * falls back to the bytecode path and marks c as NEVER_COMPILE.
 * @return the execution result (null when execution failed with a Python error).
 */
static py::object CallCompiledResults(PyThreadState *tstate, PyFrameObject *f, JitCompileResults *c) {
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    return py::none();
  }

  ValidateCompiledResults(c);

  // packed_args = {named-argument list, *args tuple or null, **kwargs dict or null}
  std::vector<py::object> packed_args = PackArgs(f);
  if (packed_args[1].ptr() != nullptr) {
    // the *args tuple is appended as one extra trailing element of the positional list
    PyList_Append(packed_args[0].ptr(), packed_args[1].ptr());
  }

  py::object args = py::reinterpret_steal<py::object>(PyList_AsTuple(packed_args[0].ptr()));
  py::object kwvargs = packed_args[2];
  bool graph_preferred = PreferCallGraph(c, args);
  SetExecStatus(c, f, graph_preferred);
  py::object res;
  if (!graph_preferred) {
    res = CallCompiledCallable(tstate, f, c);
  } else if (!c->conf->GetBoolConfig(GraphJitConfig::kCompileWithTry)) {
    res = CallGraph(c, args, kwvargs);
  } else {
    try {
      res = CallGraph(c, args, kwvargs);
    } catch (std::exception &e) {
      // de-optimize: run the bytecode path now and never try to compile this code again
      MS_LOG(WARNING) << "compile result has an error, de-optimization\n" << e.what();
      res = CallCompiledCallable(tstate, f, c);
      c->stat = JitCompileResults::NEVER_COMPILE;
    }
  }
  // presumably counts executions of this compiled result
  c->code->Inc();

  if (kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf)) {
    // NOTE(review): graph runs are counted against the original co_code and bytecode runs against the
    // generated code (or the original when none was generated) — confirm this pairing is intended.
    PyObject *new_code = c->code->GetPythonCode() ? c->code->GetPythonCode()->co_code : f->f_code->co_code;
    ByteCodeRunStatistic::GetInstance()->Count(graph_preferred ? f->f_code->co_code : new_code, graph_preferred);
  }

  // dump traceback
  if (c->conf->GetBoolConfig(GraphJitConfig::kPrintTraceback)) {
    // dump all traceback for the root function
    GRAPH_JIT_LOG_F("%s\n", c->tbs->Dump(true).c_str());
  }
  // keep the traceback while a Python error is pending
  if (!PyErr_Occurred()) {
    c->tbs->Clear();
  }
  return res;
}
1331
CheckGuard(JitCompileResults * c,const PyFrameObject * f)1332 static bool CheckGuard(JitCompileResults *c, const PyFrameObject *f) {
1333 TimeRecorder time_recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
1334
1335 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kCapture, runtime::ProfilerEvent::kCaptureGuard,
1336 "PIJitGuard");
1337
1338 StaticAnalysisExceptionCleaner exception_cleaner;
1339
1340 c->code = nullptr;
1341 std::map<size_t, PyObject *> cache;
1342 std::map<size_t, bool> success;
1343 std::map<size_t, bool> fail;
1344 OptOptionPtr opt = OptOption::CreateOptionByPoint(c);
1345 auto set = c->codehub->GetOptTarget(opt);
1346 set = OptStrategy::MakeGuardListStrategyByFrame(f, set);
1347 for (size_t i = set.size(); i != 0; i--) {
1348 auto oc = set[i - 1];
1349 OptGuardPtr guard = oc->GetGuard();
1350 bool print_guard = c->conf->GetBoolConfig(GraphJitConfig::kPrintGuard);
1351 if (guard != nullptr &&
1352 guard->Check(f, print_guard, &cache, &success, &fail, c->conf->GetBoolConfig(GraphJitConfig::kLogGuardPerf))) {
1353 c->code = oc;
1354 c->codehub->UpdateOptTarget(opt, oc);
1355 break;
1356 }
1357 }
1358 for (auto item : cache) {
1359 Py_XDECREF(item.second);
1360 }
1361 MS_LOG(DEBUG) << __FUNCTION__ << (c->code != nullptr ? " success !" : " failed !");
1362 return c->code != nullptr;
1363 }
1364
1365 class JitSyntaxLevelScope {
1366 public:
JitSyntaxLevelScope(bool enable)1367 explicit JitSyntaxLevelScope(bool enable) : enable_(enable) {
1368 if (enable_) {
1369 MS_LOG(INFO) << "Start run PIJit with one stage mode";
1370 origin_jit_syntax_level_ = common::GetEnv("MS_DEV_JIT_SYNTAX_LEVEL");
1371 common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0");
1372 }
1373 }
~JitSyntaxLevelScope()1374 ~JitSyntaxLevelScope() {
1375 if (enable_) {
1376 common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", origin_jit_syntax_level_.c_str());
1377 }
1378 }
1379
1380 private:
1381 std::string origin_jit_syntax_level_;
1382 bool enable_;
1383 };
1384
JitCompileWithTry(PyThreadState * tstate,JitCompileResults * c)1385 static bool JitCompileWithTry(PyThreadState *tstate, JitCompileResults *c) {
1386 TimeRecorder time_recorder(__FUNCTION__, kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kLogPerf));
1387
1388 JitSyntaxLevelScope jit_syntax_level_scope(c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag));
1389 StaticAnalysisExceptionCleaner exception_cleaner;
1390
1391 if (!c->conf->GetBoolConfig(GraphJitConfig::kCompileWithTry)) {
1392 return JitCompile(tstate, c);
1393 }
1394
1395 bool compiled = false;
1396 try {
1397 compiled = JitCompile(tstate, c);
1398 } catch (std::exception &e) {
1399 MS_LOG(ERROR) << "got an unexpected c++ error [" << e.what() << "]";
1400 }
1401 if (PyErr_Occurred()) {
1402 MS_LOG(ERROR) << "got an unexpected python error [" << py::error_already_set().what() << "]";
1403 PyErr_Clear();
1404 compiled = false;
1405 }
1406 return compiled;
1407 }
1408
/**
 * Return a new tuple in which every stub tensor of `args` is replaced by the result of its
 * `stub_sync` method; other elements are kept as is.
 */
py::tuple EliminateStubTensor(const py::tuple &args) {
  const size_t count = args.size();
  py::tuple result = py::reinterpret_steal<py::tuple>(PyTuple_New(count));
  for (size_t pos = 0; pos < count; ++pos) {
    py::object item = args[pos];
    if (IsStubTensor(item)) {
      result[pos] = python_adapter::CallPyObjMethod(item, "stub_sync");
    } else {
      result[pos] = item;
    }
  }
  return result;
}
1416
1417 // bellowing code is used for debugging code generate, and will be remove soon
test_graph_ir_code_gen(PyFrameObject * frame)1418 py::object test_graph_ir_code_gen(PyFrameObject *frame) {
1419 PyFrame_FastToLocals(frame);
1420 auto func =
1421 py::reinterpret_steal<py::object>(PyFunction_New(reinterpret_cast<PyObject *>(frame->f_code), frame->f_globals));
1422 mindspore::pijit::Utils::DisFuncObject(func.ptr());
1423 auto byteCodeParser = std::make_shared<mindspore::pijit::ByteCodeParser>(func);
1424 mindspore::pijit::ir::FunctionNodePtr func_node = byteCodeParser->Parse();
1425 auto inliner = std::make_shared<mindspore::pijit::FuncInliner>(func_node);
1426 inliner->Run();
1427 int arg_cnt = frame->f_code->co_argcount + frame->f_code->co_kwonlyargcount;
1428 if (static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARARGS) {
1429 arg_cnt++;
1430 }
1431 py::list locals = py::reinterpret_steal<py::list>(PyDict_Values(frame->f_locals));
1432 py::tuple args = py::reinterpret_steal<py::tuple>(PyList_AsTuple(PyList_GetSlice(locals.ptr(), 0, arg_cnt)));
1433 py::dict kwargs = (static_cast<unsigned int>(frame->f_code->co_flags) & CO_VARKEYWORDS) == 0x0
1434 ? py::dict()
1435 : py::cast<py::dict>(locals[arg_cnt]);
1436 args = EliminateStubTensor(args);
1437 mindspore::pijit::AbstractTypeDeducer::Deduce(func_node, args, kwargs);
1438 func_node->Sort();
1439 std::cout << func_node->ToString() << std::endl;
1440 auto func_obj = mindspore::pijit::ByteCodeGenerator::GenFunction(func_node);
1441 mindspore::pijit::Utils::DisFuncObject(func_obj.ptr());
1442 if ((static_cast<unsigned int>(func_node->GetFlags()) & CO_VARARGS) != 0) {
1443 auto pos_cnt = args.size() - 1;
1444 auto var_vargs = py::cast<py::tuple>(args[pos_cnt]);
1445 auto new_args = py::reinterpret_steal<py::tuple>(PyTuple_New(pos_cnt + var_vargs.size()));
1446 size_t index = 0;
1447 std::for_each(args.begin(), args.end() - 1, [&index, &new_args](const py::handle &arg) {
1448 new_args[index] = arg;
1449 index++;
1450 });
1451 std::for_each(var_vargs.begin(), var_vargs.end(), [&index, &new_args](const py::handle &arg) {
1452 new_args[index] = arg;
1453 index++;
1454 });
1455 args = new_args;
1456 }
1457 auto res = py::reinterpret_steal<py::object>(PyObject_Call(func_obj.ptr(), args.ptr(), kwargs.ptr()));
1458 res.inc_ref();
1459 return res;
1460 }
1461
/**
 * Dispatch a frame whose code object has JIT compile results, based on c->stat:
 *  - NEVER_COMPILE: evaluate the frame with the default evaluator.
 *  - GRAPH_CAPTURED: evaluate normally if kInterpretCapturedCode is set, otherwise handled like GRAPH_CANDIDATE.
 *  - GRAPH_CANDIDATE: compile (JitCompileWithTry). On failure mark NEVER_COMPILE and evaluate normally;
 *    on success continue as GRAPH_CALLABLE.
 *  - GRAPH_CALLABLE: if a compiled result's guard accepts the frame, execute it (CallCompiledResults).
 *    Otherwise, unless the code was just compiled, go back to GRAPH_CANDIDATE and recompile via a recursive call.
 *  - GRAPH_BUILDING: the code is re-entered while it is being compiled; log an error and evaluate normally.
 * With kTestGraphIR set, the frame is handed to test_graph_ir_code_gen instead.
 */
static py::object CodeHook(PyThreadState *tstate, JitCompileResults *c, PyFrameObject *frame) {
  if (c->conf->GetBoolConfig(GraphJitConfig::kTestGraphIR)) {
    return test_graph_ir_code_gen(frame);
  }
  bool just_compiled = false;
  switch (c->stat) {
    case JitCompileResults::NEVER_COMPILE:
      break;
    case JitCompileResults::GRAPH_CAPTURED:
      if (c->conf->GetBoolConfig(GraphJitConfig::kInterpretCapturedCode)) {
        break;
      }
      /* fallthrough */
    case JitCompileResults::GRAPH_CANDIDATE:
      // only one frame of this code may be compiling at a time
      MS_EXCEPTION_IF_CHECK_FAIL(c->origin_frame_ == nullptr || c->origin_frame_ == frame,
                                 "check recursive call compiling function");
      c->origin_frame_ = frame;
      // skip capture and go straight to graph compilation
      if (c->conf->GetBoolConfig(GraphJitConfig::kCompileWithoutCapture)) {
        c->stat = JitCompileResults::GRAPH_CAPTURED;
      }
      if (!JitCompileWithTry(tstate, c)) {
        c->stat = JitCompileResults::NEVER_COMPILE;
        break;
      }
      just_compiled = true;
      /* fallthrough */
    case JitCompileResults::GRAPH_CALLABLE: {
      if (CheckGuard(c, frame)) {
        c->origin_frame_ = nullptr;
        return CallCompiledResults(tstate, frame, c);
      }
      if (c->stat == JitCompileResults::NEVER_COMPILE) {
        break;
      }
      // no cached result matches this frame: compile a new one for it
      if (!just_compiled) {
        c->stat = JitCompileResults::GRAPH_CANDIDATE;
        return CodeHook(tstate, c, frame);
      }
      // a freshly compiled result must accept the frame it was compiled for;
      // MS_LOG(EXCEPTION) is expected to throw, so control does not fall into GRAPH_BUILDING
      MS_LOG(EXCEPTION) << "shouldn't reach here";
    }
    case JitCompileResults::GRAPH_BUILDING:
      MS_LOG(ERROR) << "recursive call, compiler call the code "
                    << std::string(py::str(reinterpret_cast<PyObject *>(frame->f_code))) << " which is compiling";
      break;
    default:
      MS_LOG(EXCEPTION) << "shouldn't reach here";
      break;
  }
  // fall back to the default CPython evaluator
  PyObject *res = _PyEval_EvalFrameDefault(tstate, frame, 0);
  return py::reinterpret_steal<py::object>(res);
}
1513
ApplyAutoJit(PyFrameObject * f)1514 static void ApplyAutoJit(PyFrameObject *f) {
1515 if (!kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kAutoJit)) {
1516 return;
1517 }
1518
1519 PyObject *code = reinterpret_cast<PyObject *>(f->f_code);
1520 if (getJitCompileResults(code, false) != nullptr) {
1521 return;
1522 }
1523
1524 // first reached this code
1525 // allocate for all code while auto jit
1526 (void)getJitCompileResults(code, true);
1527 if (!kPIJitConfigDefault.ShouldAutoJit(f)) {
1528 return;
1529 }
1530 (void)pi_jit_should_compile(py::cast<py::object>(code), py::dict(), py::none());
1531 }
1532
CollectGradientArguments(const PyFrameObject & frame)1533 py::list CollectGradientArguments(const PyFrameObject &frame) {
1534 py::list arguments;
1535
1536 // Collect Positional Arguments
1537 for (int index = 1; index < frame.f_code->co_argcount; index++) {
1538 arguments.append(py::cast<py::object>(frame.f_localsplus[index]));
1539 }
1540
1541 // Collect Variable Arguments
1542 if ((static_cast<unsigned int>(frame.f_code->co_flags) & CO_VARARGS) != 0x0) {
1543 auto var_args = py::cast<py::tuple>(frame.f_localsplus[frame.f_code->co_argcount]);
1544 std::for_each(var_args.begin(), var_args.end(), [&arguments](const auto &arg) { arguments.append(arg); });
1545 }
1546
1547 // Collect Variable Arguments
1548 if ((static_cast<unsigned int>(frame.f_code->co_flags) & CO_VARKEYWORDS) != 0x0) {
1549 auto kw_args = py::cast<py::dict>(frame.f_localsplus[frame.f_code->co_argcount + 1]);
1550 std::for_each(kw_args.begin(), kw_args.end(), [&arguments](const auto &item) { arguments.append(item.second); });
1551 }
1552
1553 return arguments;
1554 }
1555
AutoGrad(PyFrameObject * f,PyObject * ret)1556 void AutoGrad(PyFrameObject *f, PyObject *ret) {
1557 // improve performance for infer
1558 if (kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kInferOnly)) {
1559 return;
1560 }
1561 // must have a return value and prim must have argument
1562 if (ret == nullptr || f->f_code->co_argcount <= 0) {
1563 return;
1564 }
1565 // the call function of primitive
1566 if (py::cast<py::object>(f->f_code->co_name).cast<std::string>() != "__call__") {
1567 return;
1568 }
1569 // only record primitvie now
1570 if (f->f_localsplus[0] == nullptr) {
1571 return;
1572 }
1573 if (!py::isinstance<Primitive>(f->f_localsplus[0]) && !py::isinstance<PrimitivePy>(f->f_localsplus[0]) &&
1574 !py::isinstance<PrimitivePyAdapter>(f->f_localsplus[0])) {
1575 return;
1576 }
1577 // gradient info check
1578 if (!grad::FunctionNode::HasAttrReqGrad(ret) && !py::isinstance<py::tuple>(ret)) {
1579 return;
1580 }
1581 MS_EXCEPTION_IF_CHECK_FAIL(f->f_code->co_kwonlyargcount == 0, "Must not have kw only args.");
1582 auto inputs = CollectGradientArguments(*f);
1583 if (!std::any_of(inputs.begin(), inputs.end(),
1584 [](const auto &input) { return grad::FunctionNode::IsRequiresGradient(input); })) {
1585 return;
1586 }
1587 grad::FunctionNode::RecordPrimitive(py::cast<py::object>(f->f_localsplus[0]), py::cast<py::object>(ret), inputs);
1588 }
1589
1590 #if (PY_MAJOR_VERSION == 3) && (PY_MINOR_VERSION < 9)
EvalFrame(PyFrameObject * f,int exc)1591 PyObject *EvalFrame(PyFrameObject *f, int exc) {
1592 PyThreadState *tstate = PyThreadState_Get();
1593
1594 #else
1595 PyObject *EvalFrame(PyThreadState *tstate, PyFrameObject *f, int exc) {
1596 #endif
1597
1598 // exception handler
1599 if (exc != 0) {
1600 return _PyEval_EvalFrameDefault(tstate, f, exc);
1601 }
1602
1603 ApplyAutoJit(f);
1604
1605 PyObject *code = reinterpret_cast<PyObject *>(f->f_code);
1606 JitCompileResults *c = getJitCompileResults(code, false);
1607 if (c == nullptr) {
1608 auto ret = _PyEval_EvalFrameDefault(tstate, f, exc);
1609 AutoGrad(f, ret);
1610 return ret;
1611 }
1612 py::object res;
1613 try {
1614 res = CodeHook(tstate, c, f);
1615 } catch (py::error_already_set &e) {
1616 e.restore();
1617 } catch (py::builtin_exception &e) {
1618 e.set_error();
1619 }
1620 return res.inc_ref().ptr();
1621 }
1622 } // namespace pijit
1623 } // namespace mindspore
1624
1625 namespace mindspore {
1626
1627 #if (PY_MAJOR_VERSION == 3) && (PY_MINOR_VERSION >= 7) && (PY_MINOR_VERSION <= 10)
1628
pi_jit_enable()1629 py::bool_ pi_jit_enable() {
1630 PyInterpreterState *inter = PyInterpreterState_Main();
1631 _PyFrameEvalFunction prev = _PyInterpreterState_GetEvalFrameFunc(inter);
1632 _PyFrameEvalFunction def = _PyEval_EvalFrameDefault;
1633 if (prev != def) {
1634 return false;
1635 }
1636 mindspore::pijit::ensureInitialize();
1637 _PyInterpreterState_SetEvalFrameFunc(inter, mindspore::pijit::EvalFrame);
1638 return true;
1639 }
1640
pi_jit_disable()1641 py::bool_ pi_jit_disable() {
1642 PyInterpreterState *inter = PyInterpreterState_Main();
1643 _PyFrameEvalFunction prev = _PyInterpreterState_GetEvalFrameFunc(inter);
1644 _PyFrameEvalFunction def = _PyEval_EvalFrameDefault;
1645 if (prev != mindspore::pijit::EvalFrame) {
1646 return false;
1647 }
1648 _PyInterpreterState_SetEvalFrameFunc(inter, def);
1649 return true;
1650 }
1651
/**
 * Marks a function/method/code object as a PIJit graph-compile candidate and applies its config.
 *
 * Accepted inputs are a Python function, a bound method whose underlying callable is a Python
 * function, or a code object. Anything else (including methods wrapping builtins or other callables)
 * returns false.
 *
 * Behavior:
 * - A non-None `signature` is stored on the compile results (with a new reference).
 * - If the trace flag of the new config differs from the stored one, the compiled code and code hub
 *   are reset.
 * - If the code is already past NEVER_COMPILE, only the config is replaced.
 * - Otherwise the code becomes GRAPH_CANDIDATE. For functions, the top-level package of the defining
 *   module is also added to the allowed inline modules.
 *
 * @param funcHandle function, method or code object to mark.
 * @param tag        config object used to build the GraphJitConfig.
 * @param signature  optional signature object; ignored when None.
 * @return true if the code was (or already is) registered for compilation, false otherwise.
 */
py::bool_ pi_jit_should_compile(const py::object &funcHandle, const py::object &tag, const py::object &signature) {
  PyObject *func = funcHandle.ptr();
  PyObject *code = nullptr;
  if (PyFunction_Check(func)) {
    code = PyFunction_GET_CODE(func);
  } else if (PyMethod_Check(func) && PyFunction_Check(PyMethod_GET_FUNCTION(func))) {
    // Only unwrap methods over real Python functions; PyFunction_GET_CODE on anything else is UB.
    func = PyMethod_GET_FUNCTION(func);
    code = PyFunction_GET_CODE(func);
  } else if (PyCode_Check(func)) {
    code = func;
  } else {
    return false;
  }
  mindspore::pijit::JitCompileResults *c = mindspore::pijit::getJitCompileResults(code);
  if (c == nullptr) {
    return false;
  }
  PyObject *sig = signature.ptr();
  if (sig != nullptr && sig != Py_None) {
    // NOTE(review): a previously stored signature_ is overwritten without a DECREF; confirm who owns
    // signature_ before releasing the old value here.
    c->signature_ = sig;
    Py_INCREF(sig);
  }
  auto new_config = mindspore::pijit::GraphJitConfig(tag);
  // When switching between one-stage and two-stage, reset the config.
  if (c->conf->GetBoolConfig(pijit::GraphJitConfig::kTraceFlag) !=
      new_config.GetBoolConfig(pijit::GraphJitConfig::kTraceFlag)) {
    c->code = nullptr;
    c->codehub = std::make_shared<pijit::OptCodeHub>();
  }
  if (c->stat != mindspore::pijit::JitCompileResults::NEVER_COMPILE) {
    *c->conf = new_config;
    return true;
  }

  auto raw_code_size = (PyBytes_GET_SIZE(reinterpret_cast<PyCodeObject *>(code)->co_code)) / sizeof(_Py_CODEUNIT);
  std::string raw_func_info_name = py::str(code).cast<std::string>();
  std::string raw_func_name;
  if (PyFunction_Check(func)) {
    // func_module comes from the globals' __name__: it may be missing (NULL) or not a str.
    PyObject *module = PyFunction_GET_MODULE(func);
    if (module != nullptr && PyUnicode_Check(module)) {
      const char *module_name = PyUnicode_AsUTF8(module);
      if (module_name == nullptr) {
        // Encoding failure (e.g. lone surrogates): skip the module registration, don't leak the error.
        PyErr_Clear();
      } else {
        const char *dot = strchr(module_name, '.');
        std::string top_module = dot != nullptr ? std::string(module_name, dot - module_name) : std::string(module_name);
        mindspore::pijit::kPIJitConfigDefault.AddAllowedInlineModules(top_module);
      }
    }

    raw_func_name = mindspore::pijit::Utils::GetPyName(reinterpret_cast<PyFunctionObject *>(func)->func_qualname);
  }

  c->stat = mindspore::pijit::JitCompileResults::GRAPH_CANDIDATE;
  *c->conf = new_config;
  *c->tbs = mindspore::pijit::Tracebackes(raw_func_name, raw_func_info_name, raw_code_size);
  return true;
}
1703 #else
1704
1705 py::bool_ pi_jit_enable() {
1706 MS_LOG(ERROR) << "PiJit not support this python version " << PY_MAJOR_VERSION << '.' << PY_MINOR_VERSION
1707 << " only support on python3.7, python3.8, python3.9, python3.10";
1708 return py::bool_(false);
1709 }
1710 py::bool_ pi_jit_disable() { return py::bool_(false); }
1711 py::bool_ pi_jit_should_compile(const py::object &func, const py::object &tag, const py::object &signature) {
1712 return py::bool_(false);
1713 }
1714
1715 #endif
1716
ConvertCodeExtra(mindspore::pijit::CodeExtra * c)1717 static py::object ConvertCodeExtra(mindspore::pijit::CodeExtra *c) {
1718 if (c->code == nullptr) {
1719 return py::object();
1720 }
1721 PyCodeObject *compiled_code = c->code->GetPythonCode();
1722 auto compiled_func = c->code->GetNativeFunc();
1723 auto guard = c->code->GetGuard();
1724 if (compiled_func == nullptr && compiled_code == nullptr) {
1725 return py::object();
1726 }
1727 py::dict code;
1728 if (compiled_code != nullptr) {
1729 PyDict_SetItemString(code.ptr(), "compiled_code_", reinterpret_cast<PyObject *>(compiled_code));
1730 }
1731 if (compiled_func != nullptr) {
1732 PyDict_SetItemString(code.ptr(), "phase_", py::str(c->code->GetPhase()).ptr());
1733 }
1734 if (guard != nullptr && !guard->IsEmpty()) {
1735 PyDict_SetItemString(code.ptr(), "guard_", py::str(guard->ToString()).ptr());
1736 }
1737 PyDict_SetItemString(code.ptr(), "call_count_", py::int_(c->code->Count()).ptr());
1738 return code;
1739 }
1740
/**
 * Returns PIJit compile information for `func` as a dict, or None.
 *
 * Result keys: "code" (only when ConvertCodeExtra produced something), "stat" (status name),
 * "compile_count_" and "break_count_".
 *
 * @param func any object accepted by GetPyCodeObject.
 * @return None if no code object can be obtained or no compile results exist for it.
 */
py::object get_code_extra(const py::object &func) {
  py::object code = mindspore::pijit::GetPyCodeObject(func);
  if (code.ptr() == nullptr) {
    return py::none();
  }
  auto c = mindspore::pijit::getJitCompileResults(code.ptr(), false);
  if (c == nullptr) {
    return py::none();
  }

  // Indexed by c->stat; assumes the status enum's values are 0..4 in exactly this order.
  constexpr const char *stat_str[] = {
    "NEVER_COMPILE", "GRAPH_CANDIDATE", "GRAPH_CAPTURED", "GRAPH_BUILDING", "GRAPH_CALLABLE",
  };
  // Guard the raw array index so a new/unknown status cannot read out of bounds.
  const auto stat_index = static_cast<size_t>(c->stat);
  const char *stat_name = stat_index < std::size(stat_str) ? stat_str[stat_index] : "UNKNOWN";

  py::dict result;
  py::object compiled_code = ConvertCodeExtra(c);
  if (compiled_code.ptr() != nullptr) {
    PyDict_SetItemString(result.ptr(), "code", compiled_code.ptr());
  }
  PyDict_SetItemString(result.ptr(), "stat", py::str(stat_name).ptr());
  PyDict_SetItemString(result.ptr(), "compile_count_", py::int_(c->compile_count_).ptr());
  PyDict_SetItemString(result.ptr(), "break_count_", py::int_(c->break_count_).ptr());
  return result;
}
1765
FunctionId(const py::object & callable)1766 size_t FunctionId(const py::object &callable) {
1767 // filter special cpp function
1768 auto py_cfunction_filter = [](PyObject *op) -> void * {
1769 // pybind11::cpp_function::dispatcher;
1770 static PyCFunction pybind_dispatcher = PyCFunction_GET_FUNCTION(py::cpp_function([]() {}).ptr());
1771 PyCFunction result = PyCFunction_GET_FUNCTION(op);
1772 return result == pybind_dispatcher ? op : reinterpret_cast<void *>(result);
1773 };
1774 PyObject *op = callable.ptr();
1775 if (PyMethod_Check(op)) {
1776 op = PyMethod_GET_FUNCTION(op);
1777 }
1778 if (PyInstanceMethod_Check(op)) {
1779 op = PyInstanceMethod_GET_FUNCTION(op);
1780 }
1781 void *result = op;
1782 if (PyCFunction_Check(op)) {
1783 // types.BuiltinFunctionType = type(len) same as types.BuiltinMethodType = type(list().append)
1784 result = py_cfunction_filter(op);
1785 } else if (Py_IS_TYPE(op, &PyMethodDescr_Type)) {
1786 // types.MethodDescriptorType = type(list.append)
1787 PyCFunction func = reinterpret_cast<PyMethodDescrObject *>(op)->d_method->ml_meth;
1788 result = reinterpret_cast<void *>(func);
1789 } else if (Py_IS_TYPE(op, &PyWrapperDescr_Type)) {
1790 // types.WrapperDescriptorType = type(object.__init__)
1791 result = reinterpret_cast<PyWrapperDescrObject *>(op)->d_wrapped;
1792 } else if (Py_IS_TYPE(op, &_PyMethodWrapper_Type)) {
1793 // types.WrapperDescriptorType = type(object().__str__)
1794 PyObject *self = PyObject_GetAttrString(op, "__self__");
1795 PyObject *attr = PyObject_GetAttrString(op, "__name__");
1796 PyObject *descr = PyObject_GetAttr(reinterpret_cast<PyObject *>(Py_TYPE(self)), attr);
1797 result = reinterpret_cast<PyWrapperDescrObject *>(descr)->d_wrapped;
1798 Py_DECREF(self);
1799 Py_DECREF(attr);
1800 Py_DECREF(descr);
1801 }
1802 return reinterpret_cast<size_t>(result);
1803 }
1804
1805 } // namespace mindspore
1806