• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_capture/graph_build.h"
17 #include <algorithm>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <vector>
22 #include <utility>
23 #include <unordered_map>
24 #include <map>
25 #include "pipeline/jit/pi/common.h"
26 #include "pipeline/jit/pi/graph_capture/loop_unrolling.h"
27 #include "pipeline/jit/pi/graph_capture/special_func_infer.h"
28 #include "pipeline/jit/pi/graph_guard/infer.h"
29 #include "pipeline/jit/pi/external.h"
30 #include "pipeline/jit/pi/graph_build/func_graph_builder.h"
31 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
32 #include "include/common/debug/anf_ir_dump.h"
33 #include "pipeline/jit/pi/graph_compiler/utils.h"
34 #include "ops/sequence_ops.h"
35 #include "ops/framework_ops.h"
36 #include "ops/structure_ops.h"
37 #include "mindspore/core/ir/cell.h"
38 #include "pybind_api/ir/primitive_py.h"
39 #include "ops/auto_generate/gen_ops_primitive.h"
40 
41 namespace mindspore {
42 namespace pijit {
43 extern TracePtr GetTrace(ValueNode *node, bool strict, bool print, int depth, int max_depth);
44 
45 void LogGuardFailed(ValueNode *node, const GraphJitConfig &conf, const std::string &msg);
46 static bool GuardLoopSequence(Graph *graph, ValueNode *seq_node, Py_ssize_t seq_size = -1);
47 
48 const char *GraphBuilder::ID___self__ = "__self__";
49 const char *GraphBuilder::ID___globals__ = "__globals__";
50 const char *GraphBuilder::ID___call__ = "__call__";
51 const char *GraphBuilder::ID_construct = "construct";
52 
53 static const int infer_primitive_create = 1;
54 static const int infer_primitive_object = 2;
55 static const int infer_primitive_func = 4;
56 static int infer_func_count = 0;
57 static constexpr const char *kPIJitCopyFuncKey = ".<pijit.copy>.";
58 
59 const std::unordered_map<int, bool (GraphBuilder::*)(const Instr &)> GraphBuilder::bytecode_meth_map_ = {
60   {POP_TOP, &GraphBuilder::DoStackOp},
61   {ROT_TWO, &GraphBuilder::DoStackOp},
62   {ROT_THREE, &GraphBuilder::DoStackOp},
63   {ROT_FOUR, &GraphBuilder::DoStackOp},
64   {DUP_TOP, &GraphBuilder::DoStackOp},
65   {DUP_TOP_TWO, &GraphBuilder::DoStackOp},
66   {NOP, &GraphBuilder::DoNop},
67   {EXTENDED_ARG, &GraphBuilder::DoNop},
68   {RETURN_VALUE, &GraphBuilder::DoReturn},
69   {UNARY_POSITIVE, &GraphBuilder::DoUnary},
70   {UNARY_NEGATIVE, &GraphBuilder::DoUnary},
71   {UNARY_NOT, &GraphBuilder::DoUnary},
72   {UNARY_INVERT, &GraphBuilder::DoUnary},
73   {BINARY_MATRIX_MULTIPLY, &GraphBuilder::DoBinary},
74   {BINARY_MULTIPLY, &GraphBuilder::DoBinaryMul},
75   {BINARY_MODULO, &GraphBuilder::DoBinary},
76   {BINARY_POWER, &GraphBuilder::DoBinary},
77   {BINARY_ADD, &GraphBuilder::DoBinaryAdd},
78   {BINARY_SUBTRACT, &GraphBuilder::DoBinary},
79   {BINARY_FLOOR_DIVIDE, &GraphBuilder::DoBinary},
80   {BINARY_TRUE_DIVIDE, &GraphBuilder::DoBinary},
81   {BINARY_LSHIFT, &GraphBuilder::DoBinary},
82   {BINARY_RSHIFT, &GraphBuilder::DoBinary},
83   {BINARY_AND, &GraphBuilder::DoBinary},
84   {BINARY_XOR, &GraphBuilder::DoBinary},
85   {BINARY_OR, &GraphBuilder::DoBinary},
86   {INPLACE_MATRIX_MULTIPLY, &GraphBuilder::DoBinary},
87   {INPLACE_MULTIPLY, &GraphBuilder::DoBinary},
88   {INPLACE_MODULO, &GraphBuilder::DoBinary},
89   {INPLACE_POWER, &GraphBuilder::DoBinary},
90   {INPLACE_ADD, &GraphBuilder::DoInplaceAdd},
91   {INPLACE_SUBTRACT, &GraphBuilder::DoBinary},
92   {INPLACE_FLOOR_DIVIDE, &GraphBuilder::DoBinary},
93   {INPLACE_TRUE_DIVIDE, &GraphBuilder::DoBinary},
94   {INPLACE_LSHIFT, &GraphBuilder::DoBinary},
95   {INPLACE_RSHIFT, &GraphBuilder::DoBinary},
96   {INPLACE_AND, &GraphBuilder::DoBinary},
97   {INPLACE_XOR, &GraphBuilder::DoBinary},
98   {INPLACE_OR, &GraphBuilder::DoBinary},
99   {IS_OP, &GraphBuilder::DoIsOp},
100   {CONTAINS_OP, &GraphBuilder::DoIsOp},
101   {BUILD_TUPLE, &GraphBuilder::DoBuildOp},
102   {BUILD_LIST, &GraphBuilder::DoBuildOp},
103   {BUILD_SET, &GraphBuilder::DoBuildOp},
104   {BUILD_MAP, &GraphBuilder::DoBuildOp},
105   {BUILD_SLICE, &GraphBuilder::DoBuildOp},
106   {BUILD_CONST_KEY_MAP, &GraphBuilder::DoBuildOp},
107   {BUILD_STRING, &GraphBuilder::DoBuildOp},
108   {LIST_APPEND, &GraphBuilder::DoMergeOp},
109   {LIST_EXTEND, &GraphBuilder::DoMergeOp},
110   {DICT_MERGE, &GraphBuilder::DoMergeOp},
111   {DICT_UPDATE, &GraphBuilder::DoMergeOp},
112   {SET_UPDATE, &GraphBuilder::DoMergeOp},
113   {SET_ADD, &GraphBuilder::DoMergeOp},
114   {MAP_ADD, &GraphBuilder::DoMergeOp},
115   {COMPARE_OP, &GraphBuilder::DoCompare},
116   {MAKE_FUNCTION, &GraphBuilder::DoMakeFunction},
117   {FORMAT_VALUE, &GraphBuilder::DoFormatValue},
118   {LIST_TO_TUPLE, &GraphBuilder::DoListToTuple},
119   {LOAD_CONST, &GraphBuilder::DoLoadConst},
120   {IMPORT_STAR, &GraphBuilder::DoImport},
121   {IMPORT_NAME, &GraphBuilder::DoImport},
122   {IMPORT_FROM, &GraphBuilder::DoImport},
123   {CALL_FUNCTION, &GraphBuilder::DoCall},
124   {CALL_FUNCTION_KW, &GraphBuilder::DoCall},
125   {CALL_FUNCTION_EX, &GraphBuilder::DoCall},
126   {CALL_METHOD, &GraphBuilder::DoCall},
127   {UNPACK_SEQUENCE, &GraphBuilder::DoUnpack},
128   {UNPACK_EX, &GraphBuilder::DoUnpack},
129   {BINARY_SUBSCR, &GraphBuilder::DoItemAccess},
130   {STORE_SUBSCR, &GraphBuilder::DoItemAccess},
131   {DELETE_SUBSCR, &GraphBuilder::DoItemAccess},
132   {LOAD_GLOBAL, &GraphBuilder::DoGlobalAccess},
133   {STORE_GLOBAL, &GraphBuilder::DoGlobalAccess},
134   {DELETE_GLOBAL, &GraphBuilder::DoGlobalAccess},
135   {LOAD_METHOD, &GraphBuilder::DoAttrAccess},
136   {LOAD_ATTR, &GraphBuilder::DoAttrAccess},
137   {STORE_ATTR, &GraphBuilder::DoAttrAccess},
138   {DELETE_ATTR, &GraphBuilder::DoAttrAccess},
139   {LOAD_CLOSURE, &GraphBuilder::DoCellAccess},
140   {LOAD_DEREF, &GraphBuilder::DoCellAccess},
141   {STORE_DEREF, &GraphBuilder::DoCellAccess},
142   {DELETE_DEREF, &GraphBuilder::DoCellAccess},
143   {LOAD_FAST, &GraphBuilder::DoLocalAccess},
144   {STORE_FAST, &GraphBuilder::DoLocalAccess},
145   {DELETE_FAST, &GraphBuilder::DoLocalAccess},
146   {GET_ITER, &GraphBuilder::DoGetIter},
147   {FOR_ITER, &GraphBuilder::TraceRunForIter},
148   {POP_JUMP_IF_FALSE, &GraphBuilder::TraceRunControl},
149   {POP_JUMP_IF_TRUE, &GraphBuilder::TraceRunControl},
150   {JUMP_IF_FALSE_OR_POP, &GraphBuilder::TraceRunControl},
151   {JUMP_IF_TRUE_OR_POP, &GraphBuilder::TraceRunControl},
152   {JUMP_FORWARD, &GraphBuilder::TraceRunControl},
153   {JUMP_ABSOLUTE, &GraphBuilder::TraceRunControl},
154   {YIELD_VALUE, &GraphBuilder::DoYieldValue},
155   {POP_BLOCK, &GraphBuilder::DoException},
156   {SETUP_WITH, &GraphBuilder::DoException},
157   {SETUP_FINALLY, &GraphBuilder::DoException},
158   {WITH_CLEANUP_START, &GraphBuilder::DoException},
159   {WITH_CLEANUP_FINISH, &GraphBuilder::DoException},
160   {END_FINALLY, &GraphBuilder::DoException},
161   {SETUP_EXCEPT, &GraphBuilder::DoException},
162 };
163 
DoOtherBytecode(const Instr & instr)164 bool GraphBuilder::DoOtherBytecode(const Instr &instr) {
165   MS_LOG(ERROR) << "TODO: resolve for instruction " << instr.ToString();
166   return false;
167 }
168 
ReplaceAll(ValueNode * old_node,ValueNode * new_node,bool * is_referenced)169 bool GraphBuilder::ReplaceAll(ValueNode *old_node, ValueNode *new_node, bool *is_referenced) {
170   static const std::set<int> ref_op = {
171     BUILD_TUPLE, BUILD_LIST, BUILD_SET, BUILD_MAP, BUILD_CONST_KEY_MAP,
172   };
173 
174   // check reference relationship
175   const auto &nodes = graph_->GetTracedNodes();
176   bool find = std::any_of(nodes.begin(), nodes.end(), [&old_node](ValueNode *node) {
177     if (Opcode(node->GetOpcode()).MayDelete() && ref_op.find(node->GetOpcode()) == ref_op.end()) {
178       return false;
179     }
180     const auto &args = node->getInputs();
181     return std::any_of(args.begin(), args.end(), [&old_node](ValueNode *i) { return i == old_node; });
182   });
183   if (is_referenced != nullptr) {
184     *is_referenced |= find;
185   } else if (find) {
186     return false;
187   }
188 
189   if (parent_ != nullptr && !parent_->ReplaceAll(old_node, new_node, is_referenced)) {
190     return false;
191   }
192   // find id_map, replace all nodes......
193   const auto pred = [&old_node](ValueNode *i) { return i == old_node; };
194   std::replace_if(frame_.GetLocals().begin(), frame_.GetLocals().end(), pred, new_node);
195   std::replace_if(frame_.GetStacks().begin(), frame_.GetStacks().end(), pred, new_node);
196   std::for_each(frame_.GetClosures().begin(), frame_.GetClosures().end(), [&old_node, &new_node](CellVarNode *i) {
197     if (i->GetValue() == old_node) {
198       i->SetValue(new_node);
199     }
200   });
201   return true;
202 }
203 
NewValueNode(AObject * o,int op,int arg,const std::vector<ValueNode * > & p,const std::string & name)204 ValueNode *GraphBuilder::NewValueNode(AObject *o, int op, int arg, const std::vector<ValueNode *> &p,
205                                       const std::string &name) {
206   ValueNode *v;
207   if (Opcode(op).IsCall()) {
208     v = graph_->NewCallNode(op, arg, p);
209     v->SetVobj(o);
210   } else {
211     v = graph_->NewValueNode(o, op, arg, p, name);
212   }
213   v->set_bci(cur_bci_);
214   return v;
215 }
216 
NewValueNode(AObject * o,const Instr & i,const std::vector<ValueNode * > & p)217 ValueNode *GraphBuilder::NewValueNode(AObject *o, const Instr &i, const std::vector<ValueNode *> &p) {
218   ValueNode *v = NewValueNode(o, i.op(), i.arg(), p, i.name());
219   v->SetLineNo(i.line());
220   graph_->GetTracedNodes().push_back(v);
221   return v;
222 }
223 
NewGraph(PyCodeObject * co,PyObject * globals)224 Graph *GraphBuilder::NewGraph(PyCodeObject *co, PyObject *globals) {
225   std::vector<Graph *> &graphs = (root_ != nullptr) ? root_->graph_pool_ : this->graph_pool_;
226   if ((root_ == nullptr || root_ == this) && graph_ == nullptr) {
227     JitCompileResults *jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co), false);
228     MS_EXCEPTION_IF_CHECK_FAIL(jcr && jcr->code != nullptr, "must be create guard code before trace start");
229     graphs.push_back(new Graph(co, globals, *jcr->conf));
230     graphs.back()->SetGuard(jcr->code);
231     // initialize side-effect handler, set unique data
232     graphs.back()->SetSideEffect(std::make_shared<SideEffect>());
233     graphs.back()->GetSideEffect()->set_data(std::make_shared<SideEffectData>());
234   } else {
235     graphs.push_back(new Graph(co, globals, root_->GetGraph()->Config()));
236     graphs.back()->SetGuard(root_->GetGraph()->GetGuard());
237     graphs.back()->SetSideEffect(root_->GetGraph()->GetSideEffect());
238   }
239   return graphs.back();
240 }
241 
CheckValueValid(AObject * obj)242 static bool CheckValueValid(AObject *obj) {
243   if (obj->GetType() == AObject::kTypeTensor) {
244     AbstractTensor *tensor = static_cast<AbstractTensor *>(obj);
245     return tensor->IsStubTensor() || CheckTensorDataInitialized(obj->GetPyObject());
246   } else {
247     return true;
248   }
249 }
250 
CondIsTrue(ValueNode * cond)251 int CondIsTrue(ValueNode *cond) {
252   // if cond is tensor attrs, infer tensor attrs
253   // if tensor is return node of cell, if tensor is return node of primitive
254   // if tensor is result of math operation(+-*/...)
255   AObject *cond_value = cond->GetVobj();
256   int ret = -1;
257   if (cond_value == nullptr || cond_value->GetPyObject().ptr() == nullptr) {
258     return ret;
259   }
260   py::object value = cond_value->GetPyObject();
261   if (CheckValueValid(cond_value)) {
262     ret = PyObject_IsTrue(value.ptr());
263     PyErr_Clear();
264   }
265   return ret;
266 }
267 
CollectObjects(const std::vector<ValueNode * > & nodes)268 static std::vector<AObject *> CollectObjects(const std::vector<ValueNode *> &nodes) {
269   std::vector<AObject *> res;
270   std::transform(nodes.begin(), nodes.end(), std::back_inserter(res),
271                  [](const ValueNode *node) { return node->GetVobj(); });
272   return res;
273 }
274 
UnpackConstObject(const py::object & iterable)275 std::vector<ValueNode *> GraphBuilder::UnpackConstObject(const py::object &iterable) {
276   std::vector<ValueNode *> outputs;
277   std::transform(iterable.begin(), iterable.end(), std::back_inserter(outputs), [this](const auto &item) {
278     return this->NewValueNode(AObject::Convert(item.ptr()), LOAD_CONST, -1, {});
279   });
280   return outputs;
281 }
282 
UnpackSequenceElements(ValueNode * node)283 bool GraphBuilder::UnpackSequenceElements(ValueNode *node) {
284   py::object seq = node->GetVobj()->GetPyObject();
285   if (seq.ptr() == nullptr || !PySequence_Check(seq.ptr()) || !GuardLoopSequence(this->graph_, node)) {
286     return false;
287   }
288 
289   Py_ssize_t size = PySequence_Size(seq.ptr());
290   for (Py_ssize_t index = 0; index < size; ++index) {
291     push(node);
292     DoLoadConst({LOAD_CONST, -1, py::object(py::int_(index))});
293     DoItemAccess({BINARY_SUBSCR, 0});
294   }
295   return true;
296 }
297 
UnpackElements(ValueNode * node)298 bool GraphBuilder::UnpackElements(ValueNode *node) {
299   int opcode = node->GetOpcode();
300   if (opcode == BUILD_LIST || opcode == BUILD_TUPLE) {
301     std::for_each(node->getInputs().begin(), node->getInputs().end(), [this](ValueNode *i) { this->push(i); });
302   } else if (node->IsConstantValue()) {
303     std::vector<ValueNode *> nodes = UnpackConstObject(node->GetVobj()->GetPyObject());
304     std::for_each(nodes.begin(), nodes.end(), [this](ValueNode *i) { this->push(i); });
305   } else {
306     return UnpackSequenceElements(node);
307   }
308   return true;
309 }
310 
// Emits UNPACK_SEQUENCE / UNPACK_EX results in push order, so the first target ends up on top of the stack.
//
// gen_item(i, -1) means "push element i"; gen_item(i, j) means "pack elements [i, j) into a list".
//
// With a starred target (cnt_after != -1), the pushes are:
//  1. the trailing cnt_after elements, last first;
//  2. the starred slice [cnt, size - cnt_after);
//  3. the leading cnt elements, last first.
// Without a starred target, only step 3 runs.
static void GenUnpackValue(const std::function<void(int, int)> &gen_item, int cnt, int cnt_after, Py_ssize_t size) {
  const bool has_starred = (cnt_after != -1);
  if (has_starred) {
    const int end_pos = size - cnt_after;
    int index = size;
    while (index > end_pos) {
      --index;
      gen_item(index, -1);
    }
    gen_item(cnt, end_pos);
  }
  while (cnt > 0) {
    --cnt;
    gen_item(cnt, -1);
  }
}
323 
GetUnpackSize(ValueNode * iterable,int cnt,int cnt_after)324 Py_ssize_t GetUnpackSize(ValueNode *iterable, int cnt, int cnt_after) {
325   int op = iterable->GetOpcode();
326   Py_ssize_t total_args = cnt + cnt_after + 1;
327   Py_ssize_t size;
328   if (op == BUILD_LIST || op == BUILD_TUPLE) {
329     size = iterable->getInputs().size();
330   } else {
331     AObject *seq = iterable->GetVobj();
332     PyObject *o = (seq == nullptr) ? nullptr : seq->GetPyObject().ptr();
333     size = (o == nullptr) ? -1 : PyObject_Size(o);
334   }
335   if (size == -1 || (cnt_after == -1 && cnt != size) || total_args > size + 1) {
336     PyErr_Clear();
337     return -1;
338   }
339   return size;
340 }
341 
DoUnpack(const Instr & instr)342 bool GraphBuilder::DoUnpack(const Instr &instr) {
343   int opcode = instr.op();
344   int oparg = instr.arg();
345   int cnt = (opcode == UNPACK_EX) ? (oparg & 0xFF) : oparg;
346   int cnt_after = (opcode == UNPACK_EX) ? (oparg >> 8) : -1;
347   Py_ssize_t size = GetUnpackSize(seek(0), cnt, cnt_after);
348   if (size == -1) {
349     return false;
350   }
351   ValueNode *iterable = pop();
352 
353   size_t elements_size = frame_.GetStacks().size();
354   int iterable_opcode = iterable->GetOpcode();
355   if (iterable_opcode == BUILD_LIST || iterable_opcode == BUILD_TUPLE) {
356     std::for_each(iterable->getInputs().begin(), iterable->getInputs().end(), [this](ValueNode *i) { this->push(i); });
357   } else if (iterable->IsConstantValue()) {
358     std::vector<ValueNode *> nodes = UnpackConstObject(iterable->GetVobj()->GetPyObject());
359     std::for_each(nodes.begin(), nodes.end(), [this](ValueNode *i) { this->push(i); });
360   } else {
361     for (Py_ssize_t index = 0; index < size; ++index) {
362       push(iterable);
363       DoLoadConst({LOAD_CONST, -1, py::object(py::int_(index))});
364       DoItemAccess({BINARY_SUBSCR, 0});
365     }
366   }
367   elements_size = frame_.GetStacks().size() - elements_size;
368   std::vector<ValueNode *> elements(frame_.GetStacks().end() - elements_size, frame_.GetStacks().end());
369   popn(elements_size);
370 
371   auto gen_item = [this, &elements](int i, int j) {
372     if (j == -1) {
373       this->push(elements[i]);
374       return;
375     }
376     MS_EXCEPTION_IF_CHECK_FAIL(j >= i, "check UNPACK_EX oparg");
377     auto in_iter = elements.begin();
378     std::for_each(in_iter + i, in_iter + j, [this](ValueNode *i) { this->push(i); });
379     DoBuildOp({BUILD_LIST, j - i});
380   };
381   GenUnpackValue(gen_item, cnt, cnt_after, size);
382   return true;
383 }
384 
DoCall(const Instr & instr)385 bool GraphBuilder::DoCall(const Instr &instr) {
386   Opcode opcode(instr.op());
387   int oparg = instr.arg();
388   int tmp_arg = oparg;
389   std::vector<ValueNode *> params;
390   if (opcode == CALL_FUNCTION_EX) {
391     tmp_arg = (tmp_arg & 0x01) + 1;
392   } else if (opcode == CALL_FUNCTION_KW) {
393     tmp_arg += 1;
394   }
395   MS_EXCEPTION_IF_CHECK_FAIL(opcode.IsCall(), "must be call");
396   params = {frame_.GetStacks().end() - tmp_arg - 1, frame_.GetStacks().end()};
397   opcode = (opcode == CALL_METHOD) ? CALL_FUNCTION : opcode;
398   popn(tmp_arg + 1);
399   push(NewValueNode(nullptr, opcode, oparg, params));
400 
401   CallNode *call_node = static_cast<CallNode *>(seek(0));
402   call_node->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
403   call_node->SetLineNo(instr.line());
404   call_node->set_bci(instr.bci());
405   this->graph_->GetTracedNodes().push_back(call_node);
406 
407   StopTraceReason r = HandleCall(0);
408   if (r != StopTraceReason::kNonStopTrace) {
409     graph_->StopTraceAt(cur_bci_, r);
410     return false;
411   }
412   return true;
413 }
414 
DoNop(const Instr & instr)415 bool GraphBuilder::DoNop(const Instr &instr) { return true; }
NotImplementBytecode(const Instr & instr)416 bool GraphBuilder::NotImplementBytecode(const Instr &instr) { return false; }
417 
DoYieldValue(const Instr & instr)418 bool GraphBuilder::DoYieldValue(const Instr &instr) {
419   ValueNode *result = graph_->GetGeneratorResult();
420   if (result == nullptr) {
421     result = NewValueNode(nullptr, BUILD_TUPLE, 0);
422     graph_->SetGeneratorResult(result);
423   }
424   ValueNode *value = seek(0);
425   result->AddInput(value);
426   return true;
427 }
428 
DoReturn(const Instr & instr)429 bool GraphBuilder::DoReturn(const Instr &instr) {
430   graph_->SetRetVal(pop());
431   if (graph_->GetGeneratorResult() == nullptr) {
432     return true;
433   }
434   const auto &inputs = graph_->GetGeneratorResult()->getInputs();
435   std::for_each(inputs.begin(), inputs.end(), [this](ValueNode *i) { this->push(i); });
436   DoBuildOp({BUILD_TUPLE, SizeToInt(inputs.size())});
437   ValueNode *new_node = pop();
438   graph_->SetGeneratorResult(new_node);
439   graph_->SetRetVal(new_node);
440   return true;
441 }
442 
GetCallFunctionNode(ValueNode * node,PyObject * dst_dtype)443 ValueNode *GraphBuilder::GetCallFunctionNode(ValueNode *node, PyObject *dst_dtype) {
444   py::object prim_cast = Utils::GetModuleAttr("mindspore.ops.functional", "cast", false, true);
445   ValueNode *prim_node = NewValueNode(AObject::Convert(prim_cast), LOAD_CONST, {});
446   ValueNode *dtype_node = NewValueNode(AObject::Convert(dst_dtype), LOAD_CONST, -1, {});
447   std::vector<ValueNode *> cast_args = {prim_node, node, dtype_node};
448   ValueNode *call_node = NewValueNode(nullptr, CALL_FUNCTION, cast_args.size() - 1, cast_args);
449   return call_node;
450 }
451 
DoMixedPrecisionLocalAccess(const Instr & instr,ValueNode * node)452 bool GraphBuilder::DoMixedPrecisionLocalAccess(const Instr &instr, ValueNode *node) {
453   auto param_node = static_cast<ParamNode *>(node);
454   auto dst_dtype = param_node->GetMixedPrecisionType();
455   ValueNode *call_node = GetCallFunctionNode(node, dst_dtype);
456   push(call_node);
457   auto *call = static_cast<CallNode *>(call_node);
458   call->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
459   call->SetLineNo(instr.line());
460   call->set_bci(instr.bci());
461   StopTraceReason r = HandleCall(0);
462   if (r != StopTraceReason::kNonStopTrace) {
463     graph_->StopTraceAt(cur_bci_, r);
464     return false;
465   }
466   this->graph_->GetTracedNodes().push_back(call_node);
467   return true;
468 }
469 
DoLocalAccess(const Instr & instr)470 bool GraphBuilder::DoLocalAccess(const Instr &instr) {
471   if (instr.op() == LOAD_FAST) {
472     auto local = getLocal(instr.arg());
473     if (local->GetType() == AbstractNode::Param && reinterpret_cast<ParamNode *>(local)->IsMixedPrecisionType()) {
474       // TODO(lvxudong): fix multi cast
475       DoMixedPrecisionLocalAccess(instr, local);
476     } else {
477       push(local);
478     }
479   } else if (instr.op() == STORE_FAST) {
480     setLocal(instr.arg(), pop());
481   } else if (instr.op() == DELETE_FAST) {
482     setLocal(instr.arg(), &ValueNode::kUnboundLocal);
483   } else {
484     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
485   }
486   return true;
487 }
488 
DoCellAccess(const Instr & instr)489 bool GraphBuilder::DoCellAccess(const Instr &instr) {
490   int opcode = instr.op();
491   int oparg = instr.arg();
492   ValueNode *node;
493   ValueNode *value;
494   PyObject *cell = frame_.Closure(oparg)->GetVobj()->GetPyObject().ptr();
495   MS_EXCEPTION_IF_CHECK_FAIL(cell && PyCell_Check(cell), "must be a cell object");
496   if (opcode == LOAD_CLOSURE) {
497     push(frame_.Closure(oparg));
498   } else if (opcode == LOAD_DEREF) {
499     MS_EXCEPTION_IF_NULL(frame_.Closure(oparg)->GetValue());
500     push(frame_.Closure(oparg)->GetValue());
501   } else if (opcode == STORE_DEREF) {
502     value = pop();
503     bool is_same = value->GetOpcode() == LOAD_DEREF && frame_.Closure(oparg) == frame_.Closure(value->GetOparg());
504     if (!is_same) {
505       node = NewValueNode(nullptr, instr, {value});
506       graph_->GetSideEffect()->Record(node);
507       frame_.Closure(oparg)->SetValue(value);
508       frame_.Closure(oparg)->AddCellOper(node);
509     }
510   } else if (opcode == DELETE_DEREF) {
511     node = NewValueNode(nullptr, instr, {});
512     graph_->GetSideEffect()->Record(node);
513     frame_.Closure(oparg)->SetValue(&ValueNode::kUnboundLocal);
514     frame_.Closure(oparg)->AddCellOper(node);
515   } else {
516     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
517   }
518   return true;
519 }
520 
521 // Parse byteCode -- SETUP_WITH
DoWith(const Instr & instr)522 bool GraphBuilder::DoWith(const Instr &instr) {
523   if (graph_->Config().GetBoolConfig(GraphJitConfig::kSkipException) || PyErr_Occurred()) {
524     graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceSkip_Exception);
525     return false;
526   }
527   auto node = pop();
528   push(node);
529   DoAttrAccess({LOAD_ATTR, 0, "__exit__"});
530 
531   push(node);
532   DoAttrAccess({LOAD_ATTR, 0, "__enter__"});
533 
534   if (!DoCall({CALL_FUNCTION, 0})) {
535     MS_LOG(ERROR) << "function '__enter__' runs failed here, it should be successful!";
536     return false;
537   }
538   PushStack(TryBlock{SETUP_WITH, instr.extra_jump()->bci(), instr.bci(), false});
539   cur_bci_++;
540   return true;
541 }
542 
DoException(const Instr & instr)543 bool GraphBuilder::DoException(const Instr &instr) {
544 #if (PY_MAJOR_VERSION == 3) && (PY_MINOR_VERSION == 8)
545   return false;
546 #else
547   int opCode = instr.op();
548   if (opCode == SETUP_WITH) {
549     return DoWith(instr);
550   } else if (opCode == POP_BLOCK) {
551     PopStack();
552     return true;
553   } else if (opCode == SETUP_FINALLY) {
554     /*
555       ByteCode like this in python3.9
556       0 SETUP_FINALLY    xxx
557       1 SETUP_FINALLY    xxx
558       the first SETUP_FINALLY points to finally block, the second points to exception block
559     */
560     if (graph_->Config().GetBoolConfig(GraphJitConfig::kSkipException)) {
561       graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceSkip_Exception);
562       return false;
563     }
564     if (StackSize() == 0 || GetTryBlockStacks().back().type != SETUP_FINALLY) {
565       PushStack(TryBlock{SETUP_FINALLY, instr.extra_jump()->bci(), instr.bci(), true});
566     } else {
567       assert(StackSize() > 0 || GetTryBlockStacks().back().type == SETUP_FINALLY);
568       PushStack(TryBlock{SETUP_FINALLY, instr.extra_jump()->bci(), instr.bci(), false});
569     }
570     cur_bci_++;
571     return true;
572   } else if (opCode == WITH_CLEANUP_START) {
573     /* python3.7 only */
574     ValueNode *exc = seek(0);
575     ValueNode *exit_func = seek(1);
576     if (exc->GetVobj()->GetType() != AObject::kTypeNone) {
577       return false;
578     }
579     if (exit_func->GetName() != "__exit__") {
580       MS_LOG(ERROR) << "it should call function '__exit__' here!";
581       return false;
582     }
583     // run exit func
584     push(exc);
585     push(exc);
586     if (!DoCall({CALL_FUNCTION, 3})) {
587       MS_LOG(ERROR) << "function '__exit__' runs failed here, it should be successful!";
588       return false;
589     }
590     push(exc);
591     return true;
592   } else if (opCode == WITH_CLEANUP_FINISH) {
593     auto exc = pop();
594     (void)pop();
595     push(exc);
596     return true;
597   } else if (opCode == END_FINALLY) {
598     (void)pop();
599     return true;
600   } else if (opCode == SETUP_EXCEPT) {
601     if (graph_->Config().GetBoolConfig(GraphJitConfig::kSkipException)) {
602       graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceSkip_Exception);
603       return false;
604     }
605     PushStack(TryBlock{SETUP_EXCEPT, instr.extra_jump()->bci(), instr.bci(), false});
606     cur_bci_++;
607     return true;
608   } else {
609     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
610   }
611   return false;
612 #endif
613 }
614 
PeekStack(int p)615 TryBlock &GraphBuilder::PeekStack(int p) {
616   MS_ASSERT(tryBlockStacks_.size() > p);
617   return tryBlockStacks_[tryBlockStacks_.size() - p - 1];
618 }
619 
PopStack()620 TryBlock &GraphBuilder::PopStack() {
621   MS_ASSERT(tryBlockStacks_.size() > 0);
622   auto &tb = tryBlockStacks_[tryBlockStacks_.size() - 1];
623   tryBlockStacks_.pop_back();
624   return tb;
625 }
626 
DoGlobalAccess(const Instr & instr)627 bool GraphBuilder::DoGlobalAccess(const Instr &instr) {
628   int opcode = instr.op();
629   int oparg = instr.arg();
630   if (opcode == LOAD_GLOBAL) {
631     auto cache_result = graph_->GetSideEffect()->LoadGlobal(graph_->GetModuleName(), instr.name());
632     if (cache_result.is_deleted_value_) {
633       return false;  // name error
634     } else if (cache_result.cache_value_ != nullptr) {
635       push(cache_result.cache_value_);
636     } else {
637       auto co = graph_->GetCodeObj();
638       PyObject *key = PyTuple_GET_ITEM(co->co_names, oparg);
639       // NOTE: will run __get__, __hash__ function
640       PyObject *obj = PyObject_GetItem(graph_->GetGlobals().ptr(), key);
641       if (obj == nullptr) {
642         PyErr_Clear();
643         obj = PyObject_GetItem(PyEval_GetBuiltins(), key);
644         if (obj == nullptr) {
645           PyErr_Clear();
646         }
647       }
648       py::object pyobj = py::reinterpret_steal<py::object>(obj);
649       auto n = NewValueNode(AObject::Convert(pyobj), instr, {});
650       n->SetName(PyUnicode_AsUTF8(key));
651       push(n);
652     }
653   } else if (opcode == STORE_GLOBAL) {
654     auto global_node = pop();
655     auto node = NewValueNode(nullptr, instr, {global_node});
656     graph_->GetSideEffect()->Record(node);
657   } else if (opcode == DELETE_GLOBAL) {
658     auto node = NewValueNode(nullptr, instr, {});
659     graph_->GetSideEffect()->Record(node);
660   } else {
661     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
662   }
663   return true;
664 }
665 
HandleSuper(const Instr & instr,AObject * super)666 bool GraphBuilder::HandleSuper(const Instr &instr, AObject *super) {
667   if (super != nullptr && super->GetTypeObject() != &PySuper_Type) {
668     return false;
669   }
670   ValueNode *self_super = SearchSelfPyObject(graph_->GetCodeObj()).second;
671   if (self_super == nullptr) {
672     return false;
673   }
674   py::object method = super->GetPyObject().attr(instr.name().c_str());
675   if (!PyMethod_Check(method.ptr())) {
676     return false;
677   }
678 
679   // method type object
680   auto mtype_obj = reinterpret_cast<PyObject *>(&PyMethod_Type);
681   DoLoadConst({LOAD_CONST, -1, py::cast<py::object>(mtype_obj)});
682 
683   // function object
684   PyObject *m = PyMethod_GET_FUNCTION(method.ptr());
685   DoLoadConst({LOAD_CONST, -1, py::cast<py::object>(m)});
686 
687   push(self_super);
688 
689   // call method type
690   return DoCall({CALL_FUNCTION, 2});
691 }
692 
SetLocalPyObject(ValueNode * node)693 PyObject *SetLocalPyObject(ValueNode *node) {
694   if (node == nullptr || node->GetVobj() == nullptr) {
695     return NULL;
696   } else {
697     return node->GetVobj()->GetPyObject().ptr();
698   }
699 }
700 
SearchSelfPyObject(PyCodeObject * co)701 std::pair<PyObject *, ValueNode *> GraphBuilder::SearchSelfPyObject(PyCodeObject *co) {
702   if (co->co_argcount < 1) {
703     return {nullptr, nullptr};
704   }
705   std::pair<PyObject *, ValueNode *> obj_value;
706   ValueNode *value = frame_.Local(0);
707   // get self or son class, eg.super(Son, self)
708   PyObject *obj = SetLocalPyObject(frame_.Local(0));
709   Py_ssize_t i, n;
710   if (obj == NULL && co->co_cell2arg) {
711     // the first argument might be a cell
712     n = PyTuple_GET_SIZE(co->co_cellvars);
713     for (i = 0; i < n; i++) {
714       if (co->co_cell2arg[i] == 0) {
715         value = frame_.Closure(i)->GetValue();
716         obj = SetLocalPyObject(frame_.Closure(i));
717         break;
718       }
719     }
720   }
721   obj_value = std::make_pair(obj, value);
722   return obj_value;
723 }
724 
HandleGetattr(ValueNode * target_node,const Instr & instr)725 ValueNode *GraphBuilder::HandleGetattr(ValueNode *target_node, const Instr &instr) {
726   return NewValueNode(target_node->get_attr(instr.name()), instr, {target_node});
727 }
728 
DoMixedPrecisionAttrAccess(const Instr & instr,ValueNode * node,ValueNode * attr)729 ValueNode *GraphBuilder::DoMixedPrecisionAttrAccess(const Instr &instr, ValueNode *node, ValueNode *attr) {
730   if (node->GetVobj() == nullptr || node->GetVobj()->GetPyObject().ptr() == nullptr ||
731       node->GetVobj()->GetType() != AbstractObjectBase::kTypeCell) {
732     return nullptr;
733   }
734   auto cell = py::cast<CellPtr>(node->GetVobj()->GetPyObject());
735   auto mixed_type = cell->GetMixedPrecisionType();
736   if (mixed_type == kNotSet) {
737     return nullptr;
738   }
739   if (attr->GetVobj() == nullptr || attr->GetVobj()->GetPyObject().ptr() == nullptr) {
740     return nullptr;
741   }
742   if (attr->GetVobj()->GetType() == AObject::kTypeTensor && !attr->GetVobj()->GetPyObject().attr("dtype").is_none()) {
743     auto src_dtype = attr->GetVobj()->GetPyObject().attr("dtype");
744     bool is_cast = false;
745     if (py::isinstance<Float>(src_dtype)) {
746       auto float_nbits = py::cast<Float>(src_dtype).nbits();
747       if (float_nbits == 64 || (float_nbits == 32 && mixed_type != kFP32) ||
748           (float_nbits == 16 && mixed_type != kFP16)) {
749         is_cast = true;
750       }
751     }
752     if (py::isinstance<BFloat>(src_dtype) && mixed_type != kBF16) {
753       is_cast = true;
754     }
755     if (is_cast) {
756       auto dst_dtype = Utils::MixedPrecisionTypeToDType(mixed_type);
757       ValueNode *call_node = GetCallFunctionNode(attr, dst_dtype);
758       CallNode *call = static_cast<CallNode *>(call_node);
759       call->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
760       call->SetLineNo(instr.line());
761       call->set_bci(instr.bci());
762       push(call_node);
763       StopTraceReason r = HandleCall(0);
764       if (r != StopTraceReason::kNonStopTrace) {
765         graph_->StopTraceAt(cur_bci_, r);
766         return nullptr;
767       }
768       this->graph_->GetTracedNodes().push_back(call_node);
769       return pop();
770     }
771   }
772   return nullptr;
773 }
774 
DoAttrAccess(const Instr & instr)775 bool GraphBuilder::DoAttrAccess(const Instr &instr) {
776   int opcode = instr.op();
777   if (opcode == LOAD_METHOD || opcode == LOAD_ATTR) {
778     auto o = pop();
779     if (HandleSuper(instr, o->GetVobj())) {
780       return true;
781     }
782     auto cache_result = graph_->GetSideEffect()->LoadAttr(o, instr.name());
783     if (cache_result.is_deleted_value_) {  // attribute error
784       return false;
785     } else if (cache_result.cache_value_ != nullptr) {
786       push(cache_result.cache_value_);
787     } else {
788       push(HandleGetattr(o, instr));
789       auto attr = DoMixedPrecisionAttrAccess(instr, o, seek(0));
790       if (attr) {
791         seek(0) = attr;
792       }
793     }
794   } else if (opcode == STORE_ATTR) {
795     if (trace_flag() && parent_ != nullptr) {
796       return false;
797     }
798     auto o = pop();
799     auto v = pop();
800     auto node = NewValueNode(nullptr, instr, {v, o});
801     graph_->GetSideEffect()->Record(node);
802   } else if (opcode == DELETE_ATTR) {
803     auto o = pop();
804     auto node = NewValueNode(nullptr, instr, {o});
805     graph_->GetSideEffect()->Record(node);
806   } else {
807     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
808   }
809   return true;
810 }
811 
812 // for unpack call optimize
TupleDictGetItem(ValueNode * container,ValueNode * index_node)813 static ValueNode *TupleDictGetItem(ValueNode *container, ValueNode *index_node) {
814   if (!index_node->IsConstantValue()) {
815     return nullptr;
816   }
817   PyObject *index_object = index_node->GetVobj()->GetPyObject().ptr();
818   int opcode = container->GetOpcode();
819   if ((opcode == BUILD_TUPLE || opcode == BUILD_LIST) && PyLong_Check(index_object)) {
820     Py_ssize_t index = PyLong_AsSsize_t(index_object);
821     Py_ssize_t size = container->getInputs().size();
822     if (index < -size || index >= size) {
823       return nullptr;
824     }
825     index = index < 0 ? (size + index) : index;
826     return container->input(index);
827   }
828   if (container->GetOpcode() == BUILD_MAP && PyUnicode_Check(index_object)) {
829     std::string k = PyUnicode_AsUTF8(index_object);
830     size_t element_count = container->GetOparg() << 1;
831     MS_EXCEPTION_IF_CHECK_FAIL(element_count == container->getInputs().size(), "check BUILD_MAP oparg");
832     for (int i = 0; i < container->GetOparg(); ++i) {
833       AObject *tmp = container->input(i * 2)->GetVobj();
834       PyObject *str = tmp ? tmp->GetPyObject().ptr() : nullptr;
835       if (str == nullptr || !PyUnicode_Check(str) || k != PyUnicode_AsUTF8(str)) {
836         continue;
837       }
838       return container->input((i << 1) + 1);
839     }
840   }
841   return nullptr;
842 }
843 
DoGetItem(const Instr & instr)844 bool GraphBuilder::DoGetItem(const Instr &instr) {
845   constexpr const char *kNameGetItem = "__getitem__";
846   auto r = pop();
847   auto l = pop();
848   ValueNode *v = TupleDictGetItem(l, r);
849   if (v != nullptr) {
850     push(v);
851     return true;
852   }
853 
854   AObject *container = l->GetVobj();
855   PyObject *op = container ? container->GetPyObject().ptr() : nullptr;
856   AObject *meth = nullptr;
857 
858   bool call_getitem = op == nullptr || container->GetType() != AObject::kTypeAnyValue;
859   if (!call_getitem) {
860     call_getitem = PyDict_Check(op) || PyTuple_Check(op) || PyList_Check(op);
861   }
862   if (!call_getitem) {
863     meth = container->GetAttr(kNameGetItem);
864     PyObject *m = meth ? meth->GetPyObject().ptr() : nullptr;
865     call_getitem = m == nullptr || !PyMethod_Check(m) || !PyFunction_Check(PyMethod_GET_FUNCTION(m));
866   }
867   if (call_getitem) {
868     /**
869      * check safe callable of __getitem__ if user defined.
870      */
871     AObject *vo = l->binary_subscr(r);
872     v = NewValueNode(vo, instr, {l, r});
873     push(v);
874     return true;
875   }
876 
877   push(l);
878   DoAttrAccess({LOAD_ATTR, 0, kNameGetItem});
879   push(r);
880   return DoCall({CALL_FUNCTION, 1});
881 }
882 
TransformDictSetItem(ValueNode * map,ValueNode * key,ValueNode * value,bool ignore_key_error)883 ValueNode *GraphBuilder::TransformDictSetItem(ValueNode *map, ValueNode *key, ValueNode *value, bool ignore_key_error) {
884   PyObject *index_object = key->GetVobj()->GetPyObject().ptr();
885   if (index_object == nullptr || !key->IsConstantValue()) {
886     return nullptr;  // only supported constant key
887   }
888   constexpr const int kNumberTwo = 2;
889   PyObject *map_object = map->GetVobj()->GetPyObject().ptr();
890   std::vector<ValueNode *> elements;
891   if (map->GetOpcode() == BUILD_MAP) {
892     elements = map->getInputs();
893   } else if (map_object != nullptr) {
894     auto keys = py::reinterpret_steal<py::object>(PyDict_Keys(map_object));
895     // guard dict keys, transform to const key map......
896     Py_ssize_t size = PyList_GET_SIZE(keys.ptr());
897     for (Py_ssize_t i = 0; i < size; ++i) {
898       Instr instr(LOAD_CONST, 0, py::reinterpret_borrow<py::object>(PyList_GET_ITEM(keys.ptr(), i)));
899       this->DoLoadConst(instr);
900       this->push(map);
901       this->DoLoadConst(instr);
902       this->DoGetItem({BINARY_SUBSCR, 0});
903     }
904     elements = {frame_.GetStacks().end() - size * kNumberTwo, frame_.GetStacks().end()};
905     popn(size * kNumberTwo);
906   } else {
907     return nullptr;
908   }
909 
910   // set(delete) element
911   if (value != nullptr) {
912     elements.push_back(key);
913     elements.push_back(value);
914   } else {
915     int index_of_key = -1;
916     for (int i = elements.size() - kNumberTwo; i >= 0 && index_of_key == -1; i -= kNumberTwo) {
917       bool find = elements[i]->GetVobj()->GetPyObject().equal(py::handle(index_object));
918       index_of_key = find ? i : -1;
919     }
920     if (index_of_key != -1) {
921       elements.erase(elements.begin() + index_of_key, elements.begin() + index_of_key + kNumberTwo);
922     } else if (!ignore_key_error) {
923       return nullptr;  // maybe key error
924     }
925   }
926 
927   // rebuild map
928   int size = elements.size() / kNumberTwo;
929   std::for_each(elements.begin(), elements.end(), [this](ValueNode *i) { this->push(i); });
930   DoBuildOp({BUILD_MAP, size});
931   return pop();
932 }
933 
ListIndexCompute(PyObject * index_object,Py_ssize_t size)934 std::vector<Py_ssize_t> ListIndexCompute(PyObject *index_object, Py_ssize_t size) {
935   if (PyIndex_Check(index_object)) {
936     Py_ssize_t index = PyNumber_AsSsize_t(index_object, PyExc_IndexError);
937     if (!PyErr_Occurred() && index > -size && index < size) {
938       index = index < 0 ? (index + size) : index;
939       return {index, index + 1, 1, 1};
940     }
941   } else if (PySlice_Check(index_object)) {
942     Py_ssize_t start;
943     Py_ssize_t stop;
944     Py_ssize_t step;
945     Py_ssize_t slice_length;
946     constexpr Py_ssize_t zero = 0;
947     if (0 == PySlice_GetIndicesEx(index_object, size, &start, &stop, &step, &slice_length)) {
948       slice_length = (start < 0 || stop < 0 || slice_length < 0) ? 0 : slice_length;
949       return {std::max(start, zero), std::max(stop, zero), step, slice_length};
950     }
951   }
952   if (!PyErr_Occurred()) {
953     return {};
954   }
955   throw py::error_already_set();
956 }
957 
958 template <typename T>
SetSlice(std::vector<T> * elements,const std::vector<Py_ssize_t> & computed_slice,std::vector<T> * new_elements=nullptr)959 static bool SetSlice(std::vector<T> *elements, const std::vector<Py_ssize_t> &computed_slice,
960                      std::vector<T> *new_elements = nullptr) {
961   constexpr int start = 0;
962   constexpr int stop = 1;
963   constexpr int step = 2;
964   constexpr int slice_length = 3;
965 
966   const auto &slice = computed_slice;
967   if (slice[step] == 1) {
968     elements->erase(elements->begin() + slice[start], elements->begin() + slice[stop]);
969     if (new_elements != nullptr) {
970       elements->insert(elements->begin() + slice[start], new_elements->begin(), new_elements->end());
971     }
972     return true;
973   }
974   if (new_elements != nullptr && new_elements->size() != static_cast<size_t>(slice[slice_length])) {
975     return false;
976   }
977   for (Py_ssize_t cur = slice[start], i = 0; i < slice[slice_length]; cur += slice[step], ++i) {
978     (*elements)[cur] = new_elements == nullptr ? nullptr : (*new_elements)[i];
979   }
980   if (new_elements == nullptr) {
981     elements->erase(std::remove(elements->begin(), elements->end(), nullptr), elements->end());
982   }
983   return true;
984 }
985 
TransformListSetItem(ValueNode * map,ValueNode * key,ValueNode * value)986 ValueNode *GraphBuilder::TransformListSetItem(ValueNode *map, ValueNode *key, ValueNode *value) {
987   PyObject *index_object = key->GetVobj()->GetPyObject().ptr();
988   if (index_object == nullptr || !key->IsConstantValue()) {
989     return nullptr;  // only supported constant key
990   }
991   PyObject *map_object = map->GetVobj()->GetPyObject().ptr();
992   std::vector<ValueNode *> elements;
993   if (map->GetOpcode() == BUILD_LIST) {
994     elements = map->getInputs();
995   } else if (UnpackElements(map)) {
996     Py_ssize_t size = PyList_GET_SIZE(map_object);
997     elements = {frame().GetStacks().end() - size, frame().GetStacks().end()};
998     popn(size);
999   } else {
1000     return nullptr;
1001   }
1002 
1003   // compute slice
1004   auto slice = ListIndexCompute(index_object, elements.size());
1005   if (slice.empty()) {
1006     return nullptr;
1007   }
1008   // set(delete) elements
1009   size_t stack_size = frame_.GetStacks().size();
1010   if (!PySlice_Check(index_object)) {
1011     auto iter = elements.begin() + slice[0];
1012     (void)(value == nullptr ? elements.erase(iter) : (*iter = value, iter));
1013   } else if (value == nullptr && SetSlice(&elements, slice)) {
1014     // delete success
1015   } else if (value != nullptr && UnpackElements(value)) {
1016     // unpack success
1017     stack_size = frame_.GetStacks().size() - stack_size;
1018     std::vector<ValueNode *> new_elements = {frame_.GetStacks().end() - stack_size, frame_.GetStacks().end()};
1019     popn(stack_size);
1020     if (!SetSlice(&elements, slice, &new_elements)) {
1021       return nullptr;
1022     }
1023     // set succuss
1024   } else {
1025     return nullptr;
1026   }
1027 
1028   std::for_each(elements.begin(), elements.end(), [this](ValueNode *i) { this->push(i); });
1029   DoBuildOp({BUILD_LIST, SizeToInt(elements.size())});
1030   return pop();
1031 }
1032 
DoSetItem(ValueNode * map,ValueNode * key,ValueNode * value)1033 bool GraphBuilder::DoSetItem(ValueNode *map, ValueNode *key, ValueNode *value) {
1034   // only support constant key
1035   if (!this->graph_->GuardValueNode(key)) {
1036     return false;
1037   }
1038   // erase side-effect
1039   ValueNode *side_effect_node = graph_->GetTracedNodes().back();
1040   graph_->GetTracedNodes().pop_back();
1041 
1042   // try to transform
1043   const auto &replace_map = graph_->GetSideEffect()->data()->modified_and_replaced_map();
1044   bool is_new_var = false;
1045   ValueNode *old_node = map;
1046   ValueNode *new_node = nullptr;
1047   AObject::Type type = map->GetVobj()->GetType();
1048   if (type == AObject::kTypeList) {
1049     is_new_var = map->GetOpcode() == BUILD_LIST && replace_map.find(map) == replace_map.end();
1050     new_node = TransformListSetItem(map, key, value);
1051   } else if (type == AObject::kTypeDict) {
1052     is_new_var = map->GetOpcode() == BUILD_MAP && replace_map.find(map) == replace_map.end();
1053     new_node = TransformDictSetItem(map, key, value, false);
1054   }
1055   // failed transform, restore side-effect
1056   if (new_node == nullptr) {
1057     graph_->GetTracedNodes().push_back(side_effect_node);
1058     return false;
1059   }
1060   bool is_referenced = false;
1061   ReplaceAll(old_node, new_node, &is_referenced);
1062   // check it is new variable and not escaped
1063   if (is_new_var && !is_referenced && map != value) {
1064     return true;
1065   }
1066   // restore and record
1067   this->graph_->GetTracedNodes().push_back(side_effect_node);
1068   this->graph_->GetSideEffect()->data()->RecordModifiedAndReplacedNode(old_node, new_node);
1069   this->graph_->GetSideEffect()->Record(side_effect_node);
1070   return true;
1071 }
1072 
DoItemAccess(const Instr & instr)1073 bool GraphBuilder::DoItemAccess(const Instr &instr) {
1074   int opcode = instr.op();
1075   if (opcode == BINARY_SUBSCR) {
1076     DoGetItem(instr);
1077   } else if (opcode == STORE_SUBSCR) {
1078     auto key = pop();
1079     auto map = pop();
1080     auto value = pop();
1081     NewValueNode(nullptr, instr, {value, map, key});
1082     DoSetItem(map, key, value);
1083   } else if (opcode == DELETE_SUBSCR) {
1084     auto key = pop();
1085     auto map = pop();
1086     NewValueNode(nullptr, instr, {map, key});
1087     DoSetItem(map, key, nullptr);
1088   } else {
1089     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
1090   }
1091   return true;
1092 }
1093 
DoStackOp(const Instr & instr)1094 bool GraphBuilder::DoStackOp(const Instr &instr) {
1095   int opcode = instr.op();
1096   int oparg = instr.arg();
1097   if (opcode == POP_TOP) {
1098     pop();
1099   } else if (opcode == ROT_TWO) {
1100     frame_.Rot(1);
1101   } else if (opcode == ROT_THREE) {
1102     frame_.Rot(2);
1103   } else if (opcode == ROT_FOUR) {
1104     frame_.Rot(3);
1105   } else if (opcode == ROT_N) {
1106     frame_.Rot(oparg - 1);
1107   } else if (opcode == DUP_TOP_TWO) {
1108     push(seek(1));
1109     push(seek(1));
1110   } else if (opcode == DUP_TOP) {
1111     push(seek(0));
1112   } else {
1113     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
1114   }
1115   return true;
1116 }
1117 
DoLoadConst(const Instr & instr)1118 bool GraphBuilder::DoLoadConst(const Instr &instr) {
1119   auto n = NewValueNode(AObject::Convert(instr.cnst()), instr, {});
1120   push(n);
1121   return true;
1122 }
1123 
DoListToTuple(const Instr & instr)1124 bool GraphBuilder::DoListToTuple(const Instr &instr) {
1125   ValueNode *list = pop();
1126   if (list->GetOpcode() == BUILD_LIST) {
1127     std::for_each(list->getInputs().begin(), list->getInputs().end(), [this](ValueNode *i) { this->push(i); });
1128     return DoBuildOp({BUILD_TUPLE, SizeToInt(list->getInputs().size())});
1129   }
1130   AObject *vo = list->GetVobj();
1131   if (vo && vo->GetType() == AObject::kTypeList) {
1132     vo = static_cast<AbstractList *>(vo)->ListToTuple();
1133   } else {
1134     vo = AObject::MakeAObject(AObject::kTypeAnyValue);
1135   }
1136   ValueNode *tuple = NewValueNode(vo, instr, {list});
1137   push(tuple);
1138   return true;
1139 }
1140 
DoGetIter(const Instr & instr)1141 bool GraphBuilder::DoGetIter(const Instr &instr) {
1142   auto obj = pop();
1143   auto o = obj->GetVobj();
1144   auto iter = NewValueNode(o ? o->GetIter() : AObject::MakeAObject(AObject::kTypeAnyValue), instr, {obj});
1145   push(iter);
1146   iter->marker_ = 0;
1147   return true;
1148 }
1149 
DoMakeFunction(const Instr & instr)1150 bool GraphBuilder::DoMakeFunction(const Instr &instr) {
1151   int oparg = instr.arg();
1152   // int cnt = __builtin_popcount(oparg & 0xf) + 2;
1153   int cnt = !!(oparg & 0x08) + !!(oparg & 0x04) + !!(oparg & 0x02) + !!(oparg & 0x01) + 2;
1154   std::vector<ValueNode *> p(frame_.GetStacks().end() - cnt, frame_.GetStacks().end());
1155   popn(cnt);
1156   AObject *f = AObject::MakeFunction(CollectObjects(p), graph_->GetGlobals(), oparg);
1157   ValueNode *func = NewValueNode(f, instr, p);
1158   push(func);
1159   current_block_->SetTrackResult(Block::kHasGlobalSideEffect);
1160   return true;
1161 }
1162 
InferUnary(ValueNode * node,const Instr & instr)1163 AObject *GraphBuilder::InferUnary(ValueNode *node, const Instr &instr) { return node->GetVobj()->Unary(instr.op()); }
1164 
DoUnary(const Instr & instr)1165 bool GraphBuilder::DoUnary(const Instr &instr) {
1166   ValueNode *node = pop();
1167 
1168   AObject *object_info = InferUnary(node, instr);
1169   ValueNode *new_node = NewValueNode(object_info, instr, {node});
1170   push(new_node);
1171   return true;
1172 }
1173 
DoIsOp(const Instr & instr)1174 bool GraphBuilder::DoIsOp(const Instr &instr) { return DoBinary(instr); }
1175 
InferBinary(ValueNode * left,ValueNode * right,const Instr & instr)1176 AObject *GraphBuilder::InferBinary(ValueNode *left, ValueNode *right, const Instr &instr) {
1177   AObject *object_info;
1178   if (instr.op() == IS_OP || instr.op() == CONTAINS_OP) {
1179     object_info = left->GetVobj()->Binary(right->GetVobj(), instr.op());
1180     PyObject *object = object_info != nullptr ? object_info->GetPyObject().ptr() : nullptr;
1181     if (object != nullptr) {
1182       object_info = AObject::Convert(py::bool_((object == Py_True) ^ instr.arg()));
1183     }
1184   } else if (Opcode(instr.op()).IsBinaryMath()) {
1185     if (left->IsConstantValue() && right->IsConstantValue()) {
1186       // compute real tensor value, not infer fake value
1187       AbstractObject *tensor = static_cast<AbstractObject *>(left->GetVobj());
1188       object_info = tensor->AbstractObject::Binary(right->GetVobj(), instr.op());
1189     } else {
1190       object_info = left->GetVobj()->Binary(right->GetVobj(), instr.op());
1191     }
1192   } else {
1193     return AObject::MakeAObject(AObject::kTypeAnyValue);
1194   }
1195   return object_info;
1196 }
1197 
DoBinary(const Instr & instr)1198 bool GraphBuilder::DoBinary(const Instr &instr) {
1199   ValueNode *right = pop();
1200   ValueNode *left = pop();
1201 
1202   AObject *object_info = InferBinary(left, right, instr);
1203   ValueNode *new_node = NewValueNode(object_info, instr, {left, right});
1204   push(new_node);
1205   return true;
1206 }
1207 
CheckTupleListMul(ValueNode * left,ValueNode * right)1208 static bool CheckTupleListMul(ValueNode *left, ValueNode *right) {
1209   bool special = left->GetOpcode() == BUILD_LIST || left->GetOpcode() == BUILD_TUPLE;
1210   if (!special && left->IsConstantValue()) {
1211     AObject::Type l_type = left->GetVobj()->GetType();
1212     special = l_type == AObject::kTypeTuple || l_type == AObject::kTypeList;
1213   }
1214   if (special && right->IsConstantValue()) {
1215     PyObject *mul = right->GetVobj()->GetPyObject().ptr();
1216     const int max = 2;
1217     return PyLong_Check(mul) && Py_ABS(Py_SIZE(mul)) < max;
1218   }
1219   return false;
1220 }
1221 
DoBinaryMul(const Instr & instr)1222 bool GraphBuilder::DoBinaryMul(const Instr &instr) {
1223   if (!CheckTupleListMul(seek(1), seek(0))) {
1224     return DoBinary(instr);
1225   }
1226 
1227   ValueNode *right = pop();
1228   ValueNode *left = pop();
1229   int l_op = left->GetVobj()->GetType() == AObject::kTypeTuple ? BUILD_TUPLE : BUILD_LIST;
1230 
1231   Py_ssize_t mul = PyLong_AsSsize_t(right->GetVobj()->GetPyObject().ptr());
1232   for (auto i = mul; i > 0; --i) {
1233     UnpackElements(left);
1234   }
1235   int oparg = left->getInputs().size() * (mul < 0 ? 0 : size_t(mul));
1236   DoBuildOp({l_op, oparg});
1237   return true;
1238 }
1239 
CheckTupleListAdd(ValueNode * left,ValueNode * right)1240 static bool CheckTupleListAdd(ValueNode *left, ValueNode *right) {
1241   // type must be same
1242   AObject::Type l_type = left->GetVobj()->GetType();
1243   AObject::Type r_type = right->GetVobj()->GetType();
1244   bool support = l_type == AObject::kTypeTuple || l_type == AObject::kTypeList;
1245   if (!support || l_type != r_type) {
1246     return false;
1247   }
1248   // only handle BUILD_TUPLE and BUILD_LIST
1249   int l_op = left->GetOpcode();
1250   int r_op = right->GetOpcode();
1251   bool special = l_op == BUILD_TUPLE || l_op == BUILD_LIST || l_op == LOAD_CONST;
1252   bool accept = r_op == BUILD_TUPLE || r_op == BUILD_LIST || r_op == LOAD_CONST;
1253   if (!special || !accept) {
1254     return false;
1255   }
1256   return true;
1257 }
1258 
DoInplaceAdd(const Instr & instr)1259 bool GraphBuilder::DoInplaceAdd(const Instr &instr) {
1260   AObject::Type l_type = seek(1)->GetVobj()->GetType();
1261   if (l_type == AObject::kTypeTuple) {
1262     return DoBinaryAdd(instr);
1263   }
1264   if (!CheckTupleListAdd(seek(1), seek(0))) {
1265     return DoBinary(instr);
1266   }
1267 
1268   ValueNode *right = pop();
1269   ValueNode *left = pop();
1270   int l_op = BUILD_LIST;
1271 
1272   int size = this->frame_.GetStacks().size();
1273   UnpackElements(left);
1274   UnpackElements(right);
1275   size = this->frame_.GetStacks().size() - size;
1276   DoBuildOp({l_op, size});
1277 
1278   ValueNode *new_node = pop();
1279   if (ReplaceAll(left, new_node)) {
1280     push(new_node);
1281     return true;
1282   }
1283   graph_->GetTracedNodes().pop_back();
1284   push(left);
1285   push(right);
1286   return DoBinary(instr);
1287 }
1288 
DoBinaryAdd(const Instr & instr)1289 bool GraphBuilder::DoBinaryAdd(const Instr &instr) {
1290   if (!CheckTupleListAdd(seek(1), seek(0))) {
1291     return DoBinary(instr);
1292   }
1293 
1294   ValueNode *right = pop();
1295   ValueNode *left = pop();
1296   int l_op = left->GetVobj()->GetType() == AObject::kTypeTuple ? BUILD_TUPLE : BUILD_LIST;
1297 
1298   int size = this->frame_.GetStacks().size();
1299   UnpackElements(left);
1300   UnpackElements(right);
1301   size = this->frame_.GetStacks().size() - size;
1302   DoBuildOp({l_op, size});
1303   return true;
1304 }
1305 
DoCompare(const Instr & instr)1306 bool GraphBuilder::DoCompare(const Instr &instr) {
1307   Opcode opcode(instr.op());
1308   int oparg = instr.arg();
1309   auto r = pop();
1310   auto l = pop();
1311 
1312   bool invert;
1313   AObject *o;
1314   if (oparg >= Py_LT && oparg <= Py_GE) {
1315     PyObject *left = l->GetVobj() ? l->GetVobj()->GetPyObject().ptr() : nullptr;
1316     PyObject *right = r->GetVobj() ? r->GetVobj()->GetPyObject().ptr() : nullptr;
1317     if (left && right) {
1318       if (CheckValueValid(l->GetVobj()) && CheckValueValid(r->GetVobj())) {
1319         o = AObject::Convert(PyObject_RichCompare(left, right, oparg));
1320         PyErr_Clear();
1321       } else if (l->GetVobj()->GetType() == AObject::kTypeTensor || r->GetVobj()->GetType() == AObject::kTypeTensor) {
1322         o = l->GetVobj()->GetType() == AObject::kTypeTensor ? l->GetVobj() : r->GetVobj();
1323         auto tensor_type = py::reinterpret_borrow<py::object>(GetMsTensorType());
1324         py::object dtype_bool = Utils::GetModuleAttr("mindspore.common.dtype", "bool_");
1325         auto result_tensor = tensor_type(o->GetPyObject(), dtype_bool);
1326         o = AObject::Convert(result_tensor);
1327       } else {
1328         o = AObject::MakeAObject(AObject::kTypeBool);
1329       }
1330     } else {
1331       o = AObject::MakeAObject(AObject::kTypeBool);
1332     }
1333   } else if (opcode.CheckIsOp(oparg, &invert)) {
1334     int res = AObject::BinaryIs(l->GetVobj(), r->GetVobj());
1335     o = res == -1 ? AObject::MakeAObject(AObject::kTypeBool) : AObject::Convert((res ^ invert) ? Py_True : Py_False);
1336   } else if (opcode.CheckContainsOp(oparg, &invert)) {
1337     int res = AObject::BinaryContains(l->GetVobj(), r->GetVobj());
1338     o = res == -1 ? AObject::MakeAObject(AObject::kTypeBool) : AObject::Convert((res ^ invert) ? Py_True : Py_False);
1339   } else {
1340     return false;
1341   }
1342 
1343   auto v = NewValueNode(o, instr, {l, r});
1344   push(v);
1345   return true;
1346 }
1347 
DoBuildOp(const Instr & instr)1348 bool GraphBuilder::DoBuildOp(const Instr &instr) {
1349   int opcode = instr.op();
1350   int oparg = instr.arg();
1351   int tmp_arg = oparg;
1352   tmp_arg += opcode == BUILD_CONST_KEY_MAP;
1353   tmp_arg += opcode == BUILD_MAP ? tmp_arg : 0;
1354   std::vector<ValueNode *> p(frame_.GetStacks().end() - tmp_arg, frame_.GetStacks().end());
1355   popn(tmp_arg);
1356 
1357   ValueNode *v;
1358   if (opcode == BUILD_CONST_KEY_MAP) {
1359     PyObject *keys = p.back()->GetVobj()->GetPyObject().ptr();
1360     MS_EXCEPTION_IF_CHECK_FAIL(keys && PyTuple_CheckExact(keys), "error bytecode BUILD_CONST_KEY_MAP");
1361     Py_ssize_t size = PyTuple_GET_SIZE(keys);
1362     MS_EXCEPTION_IF_CHECK_FAIL(size_t(size) + 1 == p.size(), "error args BUILD_CONST_KEY_MAP");
1363     std::vector<ValueNode *> build_inputs;
1364     for (Py_ssize_t i = 0; i < size; ++i) {
1365       PyObject *item = PyTuple_GET_ITEM(keys, i);
1366       build_inputs.push_back(NewValueNode(AObject::Convert(item), LOAD_CONST, -1));
1367       build_inputs.push_back(p[i]);
1368     }
1369     AObject *vo = AObject::BuildOperations(CollectObjects(build_inputs), BUILD_MAP);
1370     v = NewValueNode(vo, instr, build_inputs);
1371     v->SetOpcode(BUILD_MAP);
1372     v->SetOparg(size);
1373   } else {
1374     AObject *vo = AObject::BuildOperations(CollectObjects(p), opcode);
1375     v = NewValueNode(vo, instr, p);
1376   }
1377   push(v);
1378   return true;
1379 }
1380 
ReplaceMergeOp(int opcode,const std::vector<ValueNode * > & inputs)1381 ValueNode *GraphBuilder::ReplaceMergeOp(int opcode, const std::vector<ValueNode *> &inputs) {
1382   ValueNode *origin = inputs[0];
1383   ValueNode *arg = inputs[1];
1384   ValueNode *arg2 = inputs.size() > 2 ? inputs[2] : nullptr;
1385   if (origin->GetOpcode() != BUILD_LIST && origin->GetOpcode() != BUILD_MAP) {
1386     return nullptr;
1387   }
1388   std::vector<ValueNode *> build_inputs = origin->getInputs();
1389   int div = 2;
1390   if (opcode == LIST_APPEND) {
1391     build_inputs.push_back(arg);
1392     opcode = BUILD_LIST;
1393     div = 1;
1394   } else if (opcode == LIST_EXTEND) {
1395     if (arg->IsConstantValue()) {
1396       build_inputs = UnpackConstObject(arg->GetConstantInfo()->value());
1397     } else if (arg->GetOpcode() == BUILD_LIST || arg->GetOpcode() == BUILD_TUPLE) {
1398       build_inputs.insert(build_inputs.end(), arg->getInputs().begin(), arg->getInputs().end());
1399     } else {
1400       return nullptr;
1401     }
1402     opcode = BUILD_LIST;
1403     div = 1;
1404   } else if (opcode == DICT_MERGE || opcode == DICT_UPDATE) {
1405     if (arg->GetOpcode() != BUILD_MAP) {
1406       return nullptr;
1407     }
1408     build_inputs.insert(build_inputs.end(), arg->getInputs().begin(), arg->getInputs().end());
1409     opcode = BUILD_MAP;
1410   } else if (opcode == MAP_ADD) {
1411     build_inputs.push_back(arg);
1412     build_inputs.push_back(arg2);
1413     opcode = BUILD_MAP;
1414   } else {
1415     return nullptr;
1416   }
1417   std::for_each(build_inputs.begin(), build_inputs.end(), [this](ValueNode *i) { this->push(i); });
1418   int oparg = build_inputs.size() / div;
1419   DoBuildOp({opcode, oparg});
1420   return pop();
1421 }
1422 
DoMergeOp(const Instr & instr)1423 bool GraphBuilder::DoMergeOp(const Instr &instr) {
1424   int opcode = instr.op();
1425   int oparg = instr.arg();
1426   int pos = oparg + (opcode == MAP_ADD);
1427 
1428   int index = this->frame_.GetStacks().size() - 1 - pos;
1429   ValueNode *container = seek(pos);
1430   std::vector<ValueNode *> inputs = {container, pop()};
1431   if (opcode == MAP_ADD) {
1432     inputs.insert(inputs.begin() + 1, pop());
1433   }
1434 
1435   // DICT_MERGE only generated when unpack-call in python3.9, all keys must be string
1436   // NOTE: DICT_MERGE opcode requires that *(stack_pointer - oparg - 2) is a function if has duplicate key
1437   // ...
1438   ValueNode *new_node = ReplaceMergeOp(opcode, inputs);
1439   if (new_node != nullptr) {
1440     this->frame_.GetStacks()[index] = new_node;
1441     return true;
1442   }
1443 
1444   return false;
1445 }
1446 
DoFormatValue(const Instr & instr)1447 bool GraphBuilder::DoFormatValue(const Instr &instr) {
1448   int oparg = instr.arg();
1449   std::vector<ValueNode *> arg;
1450   if ((oparg & FVS_MASK) == FVS_HAVE_SPEC) {
1451     arg.push_back(pop());
1452   }
1453   arg.insert(arg.begin(), pop());
1454   auto vo = AObject::MakeAObject(AObject::kTypeString);
1455   auto v = NewValueNode(vo, instr, arg);
1456   push(v);
1457   return true;
1458 }
1459 
DoImport(const Instr & instr)1460 bool GraphBuilder::DoImport(const Instr &instr) {
1461   int opcode = instr.op();
1462   if (opcode == IMPORT_FROM) {
1463     // any object
1464     push(NewValueNode(AObject::MakeAObject(AObject::kTypeAnyValue), instr, {seek(0)}));
1465   } else if (opcode == IMPORT_STAR) {
1466     auto from = pop();
1467     NewValueNode(AObject::MakeAObject(AObject::kTypeAnyValue), instr, {from});
1468   } else if (opcode == IMPORT_NAME) {
1469     auto from_list = pop();
1470     auto level = pop();
1471     auto vo = AObject::MakeAObject(AObject::kTypeModule);
1472     auto v = NewValueNode(vo, instr, {level, from_list});
1473     push(v);
1474   } else {
1475     return false;
1476   }
1477   return true;
1478 }
1479 
DoByteCode(const Instr & instr)1480 bool GraphBuilder::DoByteCode(const Instr &instr) {
1481   if (current_block_->is_loop_head() && !graph_->Config().GetBoolConfig(GraphJitConfig::kLoopUnrolling)) {
1482     graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceLoop_Unsupported);
1483     return false;
1484   }
1485 
1486   auto func_iter = bytecode_meth_map_.find(instr.op());
1487   bool support = false;
1488   if (func_iter != bytecode_meth_map_.end()) {
1489     const auto func = func_iter->second;
1490     support = (this->*func)(instr);
1491   }
1492 
1493   const auto &nodes = graph_->GetTracedNodes();
1494   for (auto i = nodes.rbegin(); i != nodes.rend() && (*i)->GetBlock() == nullptr; ++i) {
1495     (*i)->SetBlock(current_block_);
1496   }
1497 
1498   if (instr.op() == RETURN_VALUE) {
1499     return false;
1500   }
1501 
1502   if (!support) {
1503     if (graph_->GetStopTraceBci() == -1) {
1504       graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceByteCode_Unsupported);
1505     }
1506     return false;
1507   }
1508 
1509   if (instr.extra_jump() == nullptr) {
1510     ++cur_bci_;
1511   } else {
1512     bool valid = (cur_bci_ == instr.bci() + 1) || cur_bci_ == instr.extra_jump()->bci();
1513     MS_EXCEPTION_IF_CHECK_FAIL(valid, "error jump target");
1514   }
1515   if (cur_bci_ < current_block_->begin_ci() || cur_bci_ >= current_block_->end_ci()) {
1516     current_block_ = graph_->GetCFG()->GetBlockByBci(cur_bci_);
1517   }
1518   return true;
1519 }
1520 
/**
 * Create the root graph builder for frame `f`.
 *
 * Creates the Graph for f's code object. It then seeds the abstract frame with:
 *  - a ParamNode for every bound argument slot;
 *  - a CellVarNode for every cell and free variable slot.
 */
GraphBuilder::GraphBuilder(const PyFrameObject *f)
    : root_(this), parent_(nullptr), graph_(nullptr), current_block_(nullptr) {
  PyCodeObject *co = f->f_code;
  // Positional + keyword-only arguments, plus one slot each for *args and **kwargs when present.
  int argc = co->co_argcount + co->co_kwonlyargcount;
  argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
  argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
  int ncells = PyTuple_GET_SIZE(co->co_cellvars);
  int nfrees = PyTuple_GET_SIZE(co->co_freevars);

  graph_ = NewGraph(co, f->f_globals);

  frame_.ResizeLocal(co->co_nlocals);
  frame_.ResizeClosure(ncells + nfrees);
  for (int i = 0; i < argc; i++) {
    // Argument slots without a value get no parameter node.
    if (f->f_localsplus[i] == nullptr) {
      continue;
    }
    auto vo = AObject::Convert(f->f_localsplus[i]);
    ParamNode *n = graph_->allocator().NewNode<ParamNode>(vo, i);
    n->SetName(PyUnicode_AsUTF8(PyTuple_GET_ITEM(co->co_varnames, i)));
    frame_.SetLocal(i, n);
    // Register the argument object with the side-effect tracker under its parameter node.
    graph_->GetSideEffect()->data()->Track(f->f_localsplus[i], n);
  }
  // Closure slots follow the co_nlocals local slots in f_localsplus: cell variables first, then free variables.
  for (int i = 0; i < ncells + nfrees; i++) {
    PyObject *cell = f->f_localsplus[co->co_nlocals + i];
    PyObject *cell_contents = PyCell_GET(cell);
    AbstractNode::Type t = i < ncells ? AbstractNode::CellVar : AbstractNode::FreeVar;
    CellVarNode *n = graph_->allocator().NewNode<CellVarNode>(t);
    n->SetGraph(graph_);
    n->SetVobj(AObject::Convert(cell));
    n->SetIndex(i);
    frame_.SetClosure(i, n);
    // A cell variable that is also an argument (co_cell2arg) must already hold the argument value.
    if (i < ncells && co->co_cell2arg != nullptr && co->co_cell2arg[i] != CO_CELL_NOT_AN_ARG) {
      MS_EXCEPTION_IF_NULL(cell_contents);
      n->SetFromParam(co->co_cell2arg[i]);
    }
    if (cell_contents == nullptr) {
      // Empty cell: mark the value as unbound.
      n->SetValue(&ValueNode::kUnboundLocal);
    } else {
      // Model the current cell content as a LOAD_DEREF of slot i and record it as the cell's first operation.
      ValueNode *param = NewValueNode(AObject::Convert(cell_contents), LOAD_DEREF, i);
      param->SetGraph(graph_);
      n->AddCellOper(param);
      n->SetValue(param);
    }
  }
}
1567 
CollectInlineInfo(CallNode * node,int depth)1568 void GraphBuilder::CollectInlineInfo(CallNode *node, int depth) {
1569   Graph *sub_graph = node->GetSubGraph();
1570   if (!sub_graph) {
1571     return;
1572   }
1573   std::string inline_name = "";
1574   int code_size = 0;
1575   if (sub_graph != nullptr && sub_graph->GetCodeObj() != nullptr) {
1576     inline_name = py::str(reinterpret_cast<PyObject *>(sub_graph->GetCodeObj())).cast<std::string>();
1577     code_size = SizeToInt((PyBytes_GET_SIZE(sub_graph->GetCodeObj()->co_code)) / sizeof(_Py_CODEUNIT));
1578   }
1579   std::string func_name = graph_->GetCodeName();
1580   std::string root_name = root_->GetGraph()->GetCodeName();
1581   JitCompileResults *jcr = getJitCompileResults(reinterpret_cast<PyObject *>(root_->GetGraph()->GetCodeObj()), false);
1582   if (jcr && jcr->tbs && !func_name.empty()) {
1583     jcr->tbs->PushInlineInfo(
1584       {func_name, inline_name, root_name, node->GetInlineReason(), code_size, depth, node->GetLineNo()});
1585   }
1586 }
1587 
HandleLoop()1588 void GraphBuilder::HandleLoop() {
1589   Block *loop_head = graph_->GetCFG()->GetBlockByBci(cur_bci_);
1590   if (!loop_head->is_loop_head()) {
1591     return;
1592   }
1593   /**
1594    * (chaiyouheng): before trace start, unrolling loop. avoid graph status is changed while trace loop
1595    *       just unrolling a small loop that call nn.CellList.
1596    *
1597    * LoopUnrolling loopUnrollingExe = LoopUnrolling(*graph_);
1598    * (void)loopUnrollingExe.ExecuteLoopUnroll(loop_head);
1599    */
1600 }
1601 
FindPyFunc(AObject * vobj)1602 py::object GraphBuilder::FindPyFunc(AObject *vobj) {
1603   if (!vobj) {
1604     return py::cast<py::object>(nullptr);
1605   }
1606 
1607   switch (vobj->GetType()) {
1608     case AObject::kTypeCell:
1609       vobj = vobj->GetAttr(ID_construct);
1610       break;
1611     case AObject::kTypeAnyValue:
1612       vobj = vobj->GetAttr(ID___call__);
1613       break;
1614     case AObject::kTypeType:
1615       vobj = vobj->GetAttr("__init__");
1616       break;
1617     case AObject::kTypeBoundMethod:
1618       vobj = vobj->GetAttr("__func__");
1619     default:
1620       break;
1621   }
1622   py::object func = vobj ? vobj->GetPyObject() : py::object();
1623 
1624   if (func.ptr() == nullptr) {
1625     PyErr_Clear();
1626     return py::cast<py::object>(nullptr);
1627   }
1628 
1629   if (PyMethod_Check(func.ptr())) {
1630     func = py::reinterpret_borrow<py::object>(PyMethod_GET_FUNCTION(func.ptr()));
1631   }
1632 
1633   if (PyFunction_Check(func.ptr())) {
1634     return func;
1635   }
1636   return py::cast<py::object>(nullptr);
1637 }
1638 
GetFuncInfo(ValueNode * func_node)1639 py::object GraphBuilder::GetFuncInfo(ValueNode *func_node) {
1640   AObject *vobj = func_node->GetVobj();
1641   if (vobj->GetType() == AObject::kTypeCFunction) {
1642     return py::object();
1643   }
1644   if (func_node->GetOpcode() == MAKE_FUNCTION) {
1645     return func_node->GetVobj()->GetPyObject();
1646   }
1647   return FindPyFunc(vobj);
1648 }
1649 
/**
 * Handle a call to a white-listed callable: a primitive, or a callable with a special infer function.
 *
 * @param call_node the call being traced.
 * @param callable the Python callable behind the call.
 * @return true if the call was handled here:
 *         - a primitive with primitive inference disabled, whose result is typed as an abstract Tensor;
 *         - a callable whose infer function was run.
 *         false if no infer function exists for `callable`.
 */
bool GraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
  const auto &conf = call_node->GetGraph()->Config();

  AObject::Type vobj_type = call_node->input(0)->GetVobj()->GetType();
  if (vobj_type == AObject::kTypeCell) {
    current_block_->SetTrackResult(Block::kTrackHasOpsPrimitive);
    // Allow inlining of functions from the module that defines this cell.
    std::string module_name = GetTopModule(callable);
    if (!module_name.empty()) {
      kPIJitConfigDefault.AddAllowedInlineModules(module_name);
    }
  }

  // Primitive inference is limited by kInferPrimitiveMax (0 means no limit). The counter is incremented on
  // every call while under the limit; kInferPrimitiveMask must also enable infer_primitive_func.
  bool infer_primitive = conf.GetBoolConfig(GraphJitConfig::kInferPrimitive);
  int max_infer = conf.getIntConfig(GraphJitConfig::kInferPrimitiveMax);
  if (max_infer != 0 && infer_func_count >= max_infer) {
    infer_primitive = false;
  } else {
    infer_func_count++;
  }
  infer_primitive &= (conf.getIntConfig(GraphJitConfig::kInferPrimitiveMask) & infer_primitive_func) != 0;
  if (!infer_primitive && vobj_type == AObject::kTypePrimitive) {
    // Without inference, treat the primitive's result as an opaque tensor supported by MindSpore.
    call_node->SetVobj(AObject::MakeAObject(AObject::kTypeTensor));
    call_node->SetInlineReason(InlineReason::kInlineGraphSupportedByMS);
    current_block_->SetTrackResult(Block::kTrackHasOpsPrimitive);
    return true;
  }

  InferFunc infer_func = FindInferFunc(callable);
  if (infer_func == nullptr) {
    return false;
  }

  // Run the infer function with a fresh sub-graph that shares the root graph's guard.
  call_node->SetInlineReason(InlineReason::kInlineUnknown);
  call_node->SetSubGraph(NewGraph(nullptr, nullptr));
  call_node->GetSubGraph()->SetGuard(root_->GetGraph()->GetGuard());
  infer_func(call_node, this);

  // The infer function may clear the sub-graph, meaning the call was specialized instead of inlined.
  InlineReason r;
  if (call_node->GetSubGraph() == nullptr) {
    r = InlineReason::kInlineFuncSpecialize;
  } else {
    MS_EXCEPTION_IF_NULL(call_node->GetSubGraph()->GetRetVal());
    r = InlineReason::kInline;
    // Replace the stack top with the sub-graph's return value.
    seek(0) = call_node->GetSubGraph()->GetRetVal();
  }
  // Keep a reason set by the infer function itself.
  if (call_node->GetInlineReason() == InlineReason::kInlineUnknown) {
    call_node->SetInlineReason(r);
  }
  return true;
}
1700 
UnsupportedCodeTypeCheck(PyCodeObject * co)1701 bool UnsupportedCodeTypeCheck(PyCodeObject *co) {
1702   if (co->co_flags & (CO_GENERATOR | CO_COROUTINE | CO_ASYNC_GENERATOR)) {
1703     MS_LOG(DEBUG) << "generator is unsupported";
1704     return true;
1705   }
1706   /**
1707    * skip super call
1708    * >>>def super_wrapper(self):
1709    * ...    __class__=type(self)
1710    * ...    def super_init(self):
1711    * ...        return super()
1712    * ...    return super_init(self)
1713    * >>>assert super(int, 1).__hash__() == super_wrapper(1).__hash__()
1714    */
1715   return false;
1716 }
1717 
ApplyInlinePolicy(CallNode * call_node)1718 bool ApplyInlinePolicy(CallNode *call_node) {
1719   Graph *g = call_node->GetSubGraph();
1720   if (g == nullptr || g->GetRetVal() == nullptr) {
1721     return false;
1722   }
1723 
1724   PyCodeObject *co = g->GetCodeObj();
1725   int ncells = PyTuple_GET_SIZE(co->co_cellvars);
1726   int nfrees = PyTuple_GET_SIZE(co->co_freevars);
1727 
1728   bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;
1729   if (is_make_func) {
1730     // inline MAKE_FUNCTION, need eliminate cell and free variable if the function is not dead local.
1731     return ncells == 0;
1732   }
1733 
1734   const auto &closures = g->GetFrame(0).GetClosures();
1735   if (std::any_of(closures.begin(), closures.begin() + ncells, [](auto n) { return !n->GetCellOper().empty(); })) {
1736     return false;
1737   }
1738   if (nfrees > 0) {
1739     // if inline, guard free variable
1740     return nfrees == 1 && std::string("__class__") == PyUnicode_AsUTF8(PyTuple_GET_ITEM(co->co_freevars, 0));
1741   }
1742   if (g->GetRetVal()->GetOpcode() == MAKE_FUNCTION) {
1743     return false;
1744   }
1745   for (auto i : g->GetTracedNodes()) {
1746     // check MAKE_FUNCTION is alive, it is incorrect that inline the function of different module with MAKE_FUNCTION
1747     auto begin = i->getInputs().begin();
1748     if (Opcode(i->GetOpcode()).IsCall() && static_cast<CallNode *>(i)->GetInlineReason() == InlineReason::kInline) {
1749       begin++;
1750     }
1751     if (std::any_of(begin, i->getInputs().end(), [](ValueNode *n) { return n->GetOpcode() == MAKE_FUNCTION; })) {
1752       return false;
1753     }
1754   }
1755   return true;
1756 }
1757 
CheckSupportCreateInstance(CallNode * call_node)1758 bool CheckSupportCreateInstance(CallNode *call_node) {
1759   /**
1760    * only support exactly type, sub-class not create
1761    */
1762   static const std::set<PyTypeObject *> support_create_instance_type = {
1763     &PyComplex_Type, &PyMap_Type,   &PyBaseObject_Type, &PyRange_Type, &PyZip_Type,    &PySlice_Type,
1764     &PyBool_Type,    &PyFloat_Type, &PyLong_Type,       &PyType_Type,  &PyMethod_Type,
1765   };
1766 
1767   AObject *cls_info = call_node->input(0)->GetVobj();
1768   PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(static_cast<AbstractType *>(cls_info)->GetPyObject().ptr());
1769   if (tp == nullptr) {
1770     return false;
1771   }
1772   if (support_create_instance_type.find(tp) != support_create_instance_type.end()) {
1773     return true;
1774   }
1775 
1776   /**
1777    * maybe has sideeffect, limit create
1778    */
1779   static const std::set<PyTypeObject *> limit_create_instance_type = {
1780     &PyList_Type, &PyTuple_Type, &PySet_Type, &PyFrozenSet_Type, &PyDict_Type, &PyUnicode_Type, &PyEnum_Type,
1781   };
1782   if (call_node->getInputs().size() != 2) {
1783     return false;
1784   }
1785   ValueNode *iterable_node = call_node->input(1);
1786   AObject *first_param = iterable_node->GetVobj();
1787   if (first_param == nullptr) {
1788     return false;
1789   }
1790 
1791   if (first_param->GetType() == AObject::kTypeAnyValue) {
1792     if (iterable_node->GetOpcode() != CALL_FUNCTION || call_node->bci() - 1 != iterable_node->bci()) {
1793       return false;
1794     }
1795     /**
1796      * just process this case:
1797      *    z = list(zip(list(x), list(y)))
1798      *    z = list(enumerate(x))
1799      */
1800     // this case, zip object and enumerate object is dead variable
1801   }
1802   return limit_create_instance_type.find(tp) != limit_create_instance_type.end();
1803 }
1804 
BuildSuperObject(PyCodeObject * co)1805 AObject *GraphBuilder::BuildSuperObject(PyCodeObject *co) {
1806   if (co->co_argcount == 0) {
1807     PyErr_SetString(PyExc_RuntimeError, "super(): no arguments");
1808     return nullptr;
1809   }
1810 
1811   Py_ssize_t i, n;
1812   // search self object
1813   PyObject *obj = SearchSelfPyObject(co).first;
1814   if (obj == NULL) {
1815     PyErr_SetString(PyExc_RuntimeError, "super(): arg[0] deleted");
1816     return nullptr;
1817   }
1818 
1819   if (co->co_freevars == NULL) {
1820     n = 0;
1821   } else {
1822     assert(PyTuple_Check(co->co_freevars));
1823     n = PyTuple_GET_SIZE(co->co_freevars);
1824   }
1825 
1826   PyTypeObject *type = NULL;
1827   for (i = 0; i < n; i++) {
1828     PyObject *name = PyTuple_GET_ITEM(co->co_freevars, i);
1829     assert(PyUnicode_Check(name));
1830     // check class id
1831     if (!strcmp("__class__", PyUnicode_AsUTF8(name))) {
1832       Py_ssize_t index = PyTuple_GET_SIZE(co->co_cellvars) + i;
1833       PyObject *cell = SetLocalPyObject(frame_.Closure(index));
1834       if (cell == NULL || !PyCell_Check(cell)) {
1835         PyErr_SetString(PyExc_RuntimeError, "super(): bad __class__ cell");
1836         return nullptr;
1837       }
1838       type = reinterpret_cast<PyTypeObject *>(PyCell_GET(cell));
1839       if (type == NULL) {
1840         PyErr_SetString(PyExc_RuntimeError, "super(): empty __class__ cell");
1841         return nullptr;
1842       }
1843       if (!PyType_Check(type)) {
1844         PyErr_Format(PyExc_RuntimeError, "super(): __class__ is not a tyep (%s)", Py_TYPE(type)->tp_name);
1845         return nullptr;
1846       }
1847       break;
1848     }
1849   }
1850   if (type == NULL) {
1851     PyErr_SetString(PyExc_RuntimeError, "super(): __class__ cell not found");
1852     return nullptr;
1853   }
1854 
1855   py::object py_obj = py::reinterpret_borrow<py::object>(obj);
1856   py::object py_type = py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(type));
1857   py::tuple tuple_obj(2);
1858   tuple_obj[0] = py_type;
1859   tuple_obj[1] = py_obj;
1860   PyObject *ret = PyObject_Call(reinterpret_cast<PyObject *>(&PySuper_Type), tuple_obj.ptr(), nullptr);
1861   AObject *super_obj = AObject::Convert(ret);
1862   Py_DECREF(ret);
1863   return super_obj;
1864 }
1865 
ClassInstantiationFold(CallNode * call_node,AObject::Type type)1866 bool GraphBuilder::ClassInstantiationFold(CallNode *call_node, AObject::Type type) {
1867   const auto &params = call_node->getInputs();
1868   int call_op = call_node->GetOpcode();
1869 
1870   // list, tuple, dict fold
1871   std::vector<ValueNode *> inputs;
1872   int new_op;
1873   int new_arg;
1874   if (type == AObject::kTypeTuple || type == AObject::kTypeList) {
1875     if (params.size() > 1) {
1876       int arg_op = params[1]->GetOpcode();
1877       if (call_op == CALL_FUNCTION && (arg_op == BUILD_TUPLE || arg_op == BUILD_LIST)) {
1878         inputs = params[1]->getInputs();
1879       } else {
1880         return false;
1881       }
1882     }
1883     new_op = type == AObject::kTypeTuple ? BUILD_TUPLE : BUILD_LIST;
1884     new_arg = inputs.size();
1885   } else if (type == AObject::kTypeDict) {
1886     if (params.size() > 1) {
1887       ValueNode *map_node;
1888       if (call_op == CALL_FUNCTION && params[1]->GetOpcode() == BUILD_MAP) {
1889         map_node = params[1];
1890       } else if (call_op == CALL_FUNCTION_EX && params.size() > 2 && params[2]->GetOpcode() == BUILD_MAP) {
1891         map_node = params[2];
1892       } else {
1893         return false;
1894       }
1895       inputs = map_node->getInputs();
1896     }
1897     new_op = BUILD_MAP;
1898     new_arg = inputs.size() / 2;
1899   } else {
1900     return false;
1901   }
1902 
1903   Graph *sub_graph = NewGraph(nullptr, nullptr);
1904   AObject *res = AObject::BuildOperations(CollectObjects(inputs), new_op);
1905   ValueNode *new_node = sub_graph->NewValueNode(res, new_op, new_arg, inputs);
1906   sub_graph->GetTracedNodes().push_back(new_node);
1907   sub_graph->SetRetVal(new_node);
1908 
1909   call_node->SetSubGraph(sub_graph);
1910   call_node->SetInlineReason(InlineReason::kInline);
1911   seek(0) = new_node;
1912   return true;
1913 }
1914 
LogGuardFailed(ValueNode * node,const GraphJitConfig & conf,const std::string & msg)1915 void LogGuardFailed(ValueNode *node, const GraphJitConfig &conf, const std::string &msg) {
1916   if (!conf.GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
1917     return;
1918   }
1919   auto tr = GetTrace(node, false, true, 0, -1);
1920   std::stringstream s;
1921   if (node->GetVobj() == nullptr || node->GetVobj()->GetPyObject().ptr() == nullptr) {
1922     s << "infer failed\n";
1923   } else {
1924     std::map<Trace *, size_t> cache;
1925     s << "trace:\n" << (tr ? tr->FormatString(&cache).c_str() : "trace failed") << "\n";
1926   }
1927   s << msg << " [" << node->ToString() << "]";
1928   GRAPH_JIT_LOG_F("%s", s.str().c_str());
1929 }
1930 
/**
 * Trace a call to a class object (`cls(...)`), i.e. instance creation.
 *
 * The call is tried in this order:
 *  1. fold list/tuple/dict construction from literal arguments (only when trace_flag() is false);
 *  2. build a real instance when the type passes CheckSupportCreateInstance or is "constant" (primitive,
 *     tensor, stub tensor or ms_class), or build the object for a zero-argument `super()` call;
 *  3. otherwise build an abstract instance from the argument abstractions.
 *
 * If a real instance was built, the call is constant, and its parameters can be guarded, the call node is
 * replaced in the traced nodes (and on the stack top) by a LOAD_CONST node.
 *
 * @return true if an instance (real or abstract) was produced.
 * @throws py::error_already_set if building the `super()` object raised a Python error.
 */
bool GraphBuilder::HandleCallClass(CallNode *call_node) {
  AObject *vobj = call_node->input(0)->GetVobj();
  if (!vobj || vobj->GetType() != AObject::kTypeType) {
    return false;
  }
  AbstractType *t = static_cast<AbstractType *>(vobj);
  AObject::Type type = t->GetTypeType();
  if (!trace_flag() && ClassInstantiationFold(call_node, type)) {
    return true;
  }

  const auto &params = call_node->getInputs();
  AObject *instance = nullptr;
  bool support_create_instance = CheckSupportCreateInstance(call_node);
  bool constant = type == AObject::kTypePrimitive || type == AObject::kTypeTensor || type == AObject::kTypeStubTensor ||
                  IsMsClass(t->GetPyObject().ptr());
  // create instance
  if (support_create_instance || constant) {
    // The result is also constant when every input (the class and all arguments) is a constant value.
    constant |= std::none_of(params.begin(), params.end(), [](ValueNode *i) { return !i->IsConstantValue(); });
    std::vector<py::object> args;
    std::transform(params.begin() + 1, params.end(), std::back_inserter(args), [](ValueNode *n) {
      AObject *i = n->GetVobj();
      return i ? i->GetPyObject() : py::object();
    });
    py::object res = t->BuildInstance(args, call_node->GetOpcode());
    instance = res.ptr() ? AObject::Convert(res) : nullptr;
  } else if (reinterpret_cast<PyTypeObject *>(vobj->GetPyObject().ptr()) == &PySuper_Type) {
    // take super ptr and compare with PySuper_Type
    instance = BuildSuperObject(graph_->GetCodeObj());
    // NOTE(review): the last traced node is removed before the error check below; presumably this is the super()
    // call node itself — confirm against the caller.
    this->graph_->GetTracedNodes().pop_back();
    if (PyErr_Occurred()) {
      throw py::error_already_set();
    }
  }

  if (constant && instance != nullptr && GuardConstCallNodeParam(call_node, call_node->GetGraph(), INT_MAX)) {
    // make instance is constant
    auto iter = this->graph_->GetTracedNodes().end() - 1;
    MS_EXCEPTION_IF_CHECK_FAIL(*iter == call_node, "CallNode must be last when build sub graph");
    *iter = NewValueNode(instance, LOAD_CONST, -1, {});
    seek(0) = *iter;
  } else if (!instance) {
    // create abstract instance
    instance = t->BuildAbstractInstance(CollectObjects({params.begin() + 1, params.end()}), call_node->GetOpcode());
  }
  call_node->SetVobj(instance);
  return instance != nullptr;
}
1979 
1980 // NOTE: must be copy __code__, copy.deepcopy do nothing for code object
CopyPyFunc(const py::object & o)1981 static py::object CopyPyFunc(const py::object &o) {
1982   MS_EXCEPTION_IF_CHECK_FAIL(PyFunction_Check(o.ptr()), "must be function");
1983   PyFunctionObject *func = reinterpret_cast<PyFunctionObject *>(o.ptr());
1984   PyCodeObject *code = reinterpret_cast<PyCodeObject *>(func->func_code);
1985   PyObject *new_name = PyUnicode_FromFormat("%s%U", kPIJitCopyFuncKey, code->co_name);
1986   PyCodeObject *new_code =
1987     PyCode_New(code->co_argcount, code->co_kwonlyargcount, code->co_nlocals, code->co_stacksize, code->co_flags,
1988                code->co_code, code->co_consts, code->co_names, code->co_varnames, code->co_freevars, code->co_cellvars,
1989                code->co_filename, code->co_name, code->co_firstlineno, GetCodeLineTable(code));
1990   if (new_code == nullptr || new_name == nullptr) {
1991     throw py::error_already_set();
1992   }
1993   PyObject *new_func = PyFunction_NewWithQualName(reinterpret_cast<PyObject *>(new_code), func->func_globals, new_name);
1994   PyFunctionObject *new_ff = reinterpret_cast<PyFunctionObject *>(new_func);
1995   REPLACE_PY_MEMBER(new_ff->func_closure, func->func_closure);
1996   REPLACE_PY_MEMBER(new_ff->func_defaults, func->func_defaults);
1997   REPLACE_PY_MEMBER(new_ff->func_kwdefaults, func->func_kwdefaults);
1998   REPLACE_PY_MEMBER(new_ff->func_annotations, func->func_annotations);
1999 
2000   Py_DECREF(new_name);
2001   Py_DECREF(new_code);
2002   return py::reinterpret_steal<py::object>(new_func);
2003 }
2004 
/**
 * Return the PIJit copy of Python function `func`, creating it on first use.
 *
 * The copy is cached as attribute kPIJitCopyFuncKey of `func`. A newly created copy is passed to
 * pi_jit_should_compile.
 *
 * @throws py::error_already_set if the copy cannot be created.
 */
py::object GetPIJitCopiedFunc(const py::object &func) {
  PyObject *res = PyObject_GetAttrString(func.ptr(), kPIJitCopyFuncKey);
  if (res != nullptr) {
    return py::reinterpret_steal<py::object>(res);
  }
  // Not cached yet: discard the error raised by the failed lookup.
  PyErr_Clear();
  py::object copy = CopyPyFunc(func);
  if (PyObject_SetAttrString(func.ptr(), kPIJitCopyFuncKey, copy.ptr()) != 0) {
    // Caching is best effort; do not leave a pending Python exception behind when the attribute cannot be set.
    PyErr_Clear();
  }
  (void)pi_jit_should_compile(copy, py::dict(), py::none());
  return copy;
}
2016 
GetSelfFromMethod(ValueNode * method)2017 ValueNode *GetSelfFromMethod(ValueNode *method) {
2018   if (method->GetOpcode() != LOAD_ATTR) {
2019     return nullptr;
2020   }
2021   ValueNode *self = method->input(0);
2022   /**
2023    * (chaiyouheng):
2024    * Check method is a generic attribute
2025    * descr = _PyType_Lookup(self->GetVobj()->GetTypeObject(), py::str(method->GetName()).ptr());
2026    * Check descr == nullptr || !PyFunction_Check(descr)
2027    */
2028   return self;
2029 }
2030 
ReplaceCall(CallNode * call_node,const py::object & old_func)2031 bool GraphBuilder::ReplaceCall(CallNode *call_node, const py::object &old_func) {
2032   if (call_node->GetOpcode() == CALL_FUNCTION_EX && call_node->input(1)->GetOpcode() != BUILD_TUPLE) {
2033     // dynamic length variable arguments, user-defined unpack sequence
2034     return false;
2035   }
2036   if (!graph_->GuardInlinedFunc(call_node)) {
2037     return false;
2038   }
2039   auto jcr = getJitCompileResults(old_func.ptr(), false);
2040   if (jcr != nullptr && jcr->stat != JitCompileResults::NEVER_COMPILE) {
2041     return true;
2042   }
2043 
2044   py::object new_func = GetPIJitCopiedFunc(old_func);
2045 
2046   auto &nodes = graph_->GetTracedNodes();
2047   MS_EXCEPTION_IF_CHECK_FAIL(nodes.back() == call_node, "CallNode must be last when build sub graph");
2048 
2049   ValueNode *self = nullptr;
2050   AObject::Type func_type = call_node->input(0)->GetVobj()->GetType();
2051   if (func_type == AObject::kTypeBoundMethod) {
2052     ValueNode *func_val = call_node->input(0);
2053     self = GetSelfFromMethod(func_val);
2054     if (self == nullptr) {
2055       ValueNode *node = NewValueNode(func_val->get_attr(GraphBuilder::ID___self__), LOAD_ATTR, -1, {func_val},
2056                                      GraphBuilder::ID___self__);
2057       node->SetGraph(call_node->GetGraph());
2058       nodes.insert(nodes.end() - 1, node);
2059       self = node;
2060     }
2061   } else if (func_type == AObject::kTypeCell || AObject::kTypeAnyValue) {
2062     self = call_node->input(0);
2063   } else if (func_type != AObject::kTypeFunction) {
2064     return false;
2065   }
2066 
2067   std::stringstream key;
2068   PyObject *func_name = reinterpret_cast<PyFunctionObject *>(new_func.ptr())->func_qualname;
2069   key << std::string(py::str(func_name)) << "." << new_func.ptr();
2070 
2071   // new func node
2072   ValueNode *func_node = this->NewValueNode(AObject::Convert(new_func), LOAD_CONST, -1, {});
2073   nodes.insert(nodes.end() - 1, func_node);
2074 
2075   // replace node
2076   call_node->getInputs()[0] = func_node;
2077   if (self == nullptr) {
2078     return true;
2079   }
2080 
2081   // append self to args
2082   if (call_node->GetOpcode() != CALL_FUNCTION_EX) {
2083     call_node->getInputs().insert(call_node->getInputs().begin() + 1, self);
2084     call_node->SetOparg(call_node->GetOparg() + 1);
2085     return true;
2086   }
2087 
2088   // append self to variable arguments
2089   ValueNode *args_node = call_node->input(1);
2090   std::vector<ValueNode *> inputs = args_node->getInputs();
2091   inputs.insert(inputs.begin(), self);
2092   AObject *args_info = AObject::BuildOperations(CollectObjects(inputs), BUILD_TUPLE);
2093 
2094   ValueNode *tuple = this->NewValueNode(args_info, BUILD_TUPLE, inputs.size(), inputs);
2095   tuple->set_bci(call_node->bci());
2096   tuple->SetLineNo(call_node->GetLineNo());
2097   nodes.insert(nodes.end() - 1, tuple);
2098   call_node->getInputs()[1] = tuple;
2099   return true;
2100 }
2101 
MindGraphBuilder(const PyFrameObject * f)2102 MindGraphBuilder::MindGraphBuilder(const PyFrameObject *f) : GraphBuilder(f) {
2103   std::vector<std::string> comments;
2104   auto location = std::make_shared<Location>(py::cast<std::string>(f->f_code->co_filename), f->f_code->co_firstlineno,
2105                                              0, f->f_code->co_firstlineno, 0, "", std::move(comments));
2106   MS_EXCEPTION_IF_NULL(location);
2107   TraceGuard trace_guard(location);
2108   fg_builder_ = std::make_shared<FuncGraphBuilder>(true);
2109   fg_builder_->SetGraphName(py::cast<std::string>(f->f_code->co_name) + "_" +
2110                             std::to_string(f->f_code->co_firstlineno));
2111   co_name_ = py::cast<std::string>(f->f_code->co_name);
2112 }
2113 
2114 namespace {
GetFuncGraphName(const py::object & func,const MindGraphBuilderPtr & subgraph)2115 std::string GetFuncGraphName(const py::object &func, const MindGraphBuilderPtr &subgraph) {
2116   auto func_str = py::cast<std::string>(py::str(func));
2117   std::vector<std::string> vec;
2118   std::istringstream iss(func_str);
2119   std::string str;
2120   while (iss >> str) {
2121     (void)vec.emplace_back(str);
2122   }
2123   if (vec.size() <= 1) {
2124     return "";
2125   }
2126   auto func_name = vec[1];
2127   std::replace(func_name.begin(), func_name.end(), '.', '_');
2128   return func_name + "_" + std::to_string(subgraph->GetGraph()->GetCodeObj()->co_firstlineno);
2129 }
2130 }  // namespace
2131 
/**
 * Trace `func` into the MindGraphBuilder `subgraph`.
 *
 * When the sub-graph result is not a constant, its FuncGraph is added as a node of this builder's FuncGraph.
 *
 * @param call_node the call being traced; its abstract value is updated from the sub-graph result.
 * @param depth not used by this override.
 * @param func the Python callable being traced.
 * @param subgraph builder for the callee; assumed to be a MindGraphBuilder (the dynamic_pointer_cast result is not
 *        null-checked).
 * @return the stop reason of the sub-graph trace, or:
 *         - kStopTraceFunc_ArgHandle_Unsupported when the inputs cannot be added;
 *         - kTrace_Fail when no FuncGraph is produced.
 */
StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
                                                const GraphBuilderPtr &subgraph) {
  auto sg = std::dynamic_pointer_cast<MindGraphBuilder>(subgraph);
  sg->FGBuilder()->AddPrevBuilder(FGBuilder());

  auto code = sg->GetGraph()->GetGuard();
  MS_EXCEPTION_IF_NULL(code);
  // NOTE(review): this Backup() is never rolled back or popped in this function, including on the early-return
  // paths below — confirm that is intended.
  code->GetGuard()->Backup();

  // For Python functions, the arguments are rebuilt against the function object (GetNewArgs).
  auto args = call_node->GetArgs();
  if (PyFunction_Check(func.ptr())) {
    args = GetNewArgs(call_node, AObject::Convert(func.ptr()));
  }

  MS_LOG(INFO) << "new subgraph->TraceRun: " << py::str(func);
  bool succ = sg->FGAddInputs(args);
  if (!succ) {
    MS_LOG(INFO) << "Add input fail for new subgraph->TraceRun: " << py::str(func);
    return StopTraceReason::kStopTraceFunc_ArgHandle_Unsupported;
  }
  auto reason = sg->TraceRun();
  MS_LOG(INFO) << "new subgraph->TraceRun end: " << py::str(func);

  call_node->SetSubGraph(sg->GetGraph());
  auto sub_ret = sg->GetGraph()->GetRetVal();
  if (sub_ret != nullptr) {
    // No concrete object or a constant result: just propagate the abstract value, no FuncGraph node is added.
    if (sub_ret->GetVobj()->GetPyObject().ptr() == nullptr ||
        CheckConstPyObject(sub_ret->GetVobj()->GetPyObject().ptr())) {
      call_node->SetVobj(sub_ret->GetVobj());
    } else {
      sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, sg));
      sg->FGAddOutput(false);
      if (sg->FGBuilder()->graph() == nullptr) {
        MS_LOG(INFO) << "subgraph trace null";
        return StopTraceReason::kTrace_Fail;
      } else {
        // Call the sub FuncGraph from this builder's FuncGraph, attributed to the call site's location.
        TraceGuard trace_guard(GetLocation(call_node));
        auto res = FGBuilder()->AddNode(sg->FGBuilder()->graph(), args);
        if (res.ptr()) {
          MS_LOG(INFO) << "add fg node suc: ";
          call_node->SetVobj(AObject::Convert(res));
        }
      }
    }
  }
  // Functions created by MAKE_FUNCTION are guarded after tracing; the guard result is not checked here.
  bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;
  if (is_make_func) {
    graph_->GuardInlinedFunc(call_node);
  }
  return reason;
}
2183 
/**
 * Build the sub-graph for `call_node` by tracing `subgraph`, then decide whether the call is inlined.
 *
 * Guard protocol: the sub-graph guard is backed up before tracing.
 *  - If the call is inlined, the backup is popped (the guards are kept).
 *  - Otherwise the guard is rolled back, and for non-MAKE_FUNCTION callees the call is rewritten by ReplaceCall.
 *
 * @param call_node the call being traced; receives the sub-graph, the result's abstract value and the inline reason.
 * @param depth not used by this implementation.
 * @param func the Python function being called (passed to ReplaceCall).
 * @param subgraph builder for the callee.
 * @return always kNonStopTrace; the return value of subgraph->TraceRun() is not used.
 */
StopTraceReason GraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
                                            const GraphBuilderPtr &subgraph) {
  InlineReason stat = InlineReason::kInline;
  bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;

  auto code = subgraph->GetGraph()->GetGuard();
  MS_EXCEPTION_IF_NULL(code);
  code->GetGuard()->Backup();

  MS_LOG(INFO) << "old subgraph->TraceRun";
  subgraph->TraceRun();

  call_node->SetSubGraph(subgraph->GetGraph());
  if (subgraph->GetGraph()->GetRetVal() != nullptr) {
    call_node->SetVobj(subgraph->GetGraph()->GetRetVal()->GetVobj());
  }
  // Unless generator expressions may be converted to tuples, a generator result disables inlining
  // (ApplyInlinePolicy refuses graphs without a return value).
  bool gen_to_tuple = subgraph->GetGraph()->Config().GetBoolConfig(GraphJitConfig::kEnableGeneratorExpressionToTuple);
  if (!gen_to_tuple && subgraph->GetGraph()->GetGeneratorResult() != nullptr) {
    subgraph->GetGraph()->SetRetVal(nullptr);
  }

  stat = ApplyInlinePolicy(call_node) ? stat : InlineReason::kInlinePolicyDisabled;
  if (stat != InlineReason::kInline) {
    code->GetGuard()->Rollback();
    if (!is_make_func) {
      /**
       * replace function call, inline or resume capture after break graph
       * exclude make function, because of function always a new function but code is constant
       **/
      stat = ReplaceCall(call_node, func) ? stat : InlineReason::kInlinePolicyDisabled;
    }
  } else {
    if (!is_make_func) {
      // exclude make function, because of function always a new function but code is constant
      stat = graph_->GuardInlinedFunc(call_node) ? stat : InlineReason::kInlinePolicyDisabled;
    }
    if (stat != InlineReason::kInline) {
      code->GetGuard()->Rollback();
    } else {
      code->GetGuard()->Pop();
    }
  }

  // if stat == InlineReason::kInline, guard free variable
  call_node->SetInlineReason(stat);
  return StopTraceReason::kNonStopTrace;
}
2232 
UnpackDynamicLengthDictByBytecode(std::vector<ValueNode * > * params,CallNode * call_node,ValueNode * dict_node)2233 bool GraphBuilder::UnpackDynamicLengthDictByBytecode(std::vector<ValueNode *> *params, CallNode *call_node,
2234                                                      ValueNode *dict_node) {
2235   // user defined mappings, dynamic length dictionary unpack
2236   if (dict_node->GetVobj()->GetType() != AObject::kTypeDict) {
2237     return false;
2238   }
2239   auto dict = static_cast<AbstractDict *>(dict_node->GetVobj());
2240   if (!dict->IsElementValid()) {
2241     return false;
2242   }
2243   /**
2244    * must be guard this dict length
2245    */
2246   py::dict py_dict = dict->GetPyObject();
2247   py::tuple keys(py_dict.size());
2248   PyObject *key;
2249   PyObject *value;
2250   Py_ssize_t pos = 0;
2251   Py_ssize_t cnt = 0;
2252   while (PyDict_Next(py_dict.ptr(), &pos, &key, &value)) {
2253     PyObject *py_key = key;
2254     MS_EXCEPTION_IF_CHECK_FAIL(PyUnicode_CheckExact(py_key), "key must be string");
2255     PyObject *py_value = value;
2256     ValueNode *index = NewValueNode(AObject::Convert(py_key), LOAD_CONST, -1, {});
2257     ValueNode *val = NewValueNode(AObject::Convert(py_value), BINARY_SUBSCR, 0, {dict_node, index});
2258     keys[cnt++] = py_key;
2259     params->push_back(val);
2260     call_node->AddParam(val);
2261   }
2262   ValueNode *const_keys = NewValueNode(AObject::Convert(keys), LOAD_CONST, -1, {});
2263   params->push_back(const_keys);
2264   return true;
2265 }
2266 
UnpackCallExDict(std::vector<ValueNode * > * params,CallNode * call_node)2267 bool GraphBuilder::UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node) {
2268   ValueNode *dict_node = params->back();
2269   params->clear();
2270   if (dict_node->GetOpcode() != BUILD_MAP) {
2271     return UnpackDynamicLengthDictByBytecode(params, call_node, dict_node);
2272   }
2273   if (dict_node->GetOparg() == 0) {
2274     return true;
2275   }
2276   py::tuple keys(dict_node->GetOparg());
2277   for (int i = 0; i < dict_node->GetOparg(); ++i) {
2278     AObject *k = dict_node->input(i * 2)->GetVobj();
2279     if (k->GetType() != AObject::kTypeString) {
2280       MS_LOG(DEBUG) << "for unpack-call, dict keys must be string";
2281       return false;
2282     }
2283     keys[i] = k->GetPyObject();
2284     params->push_back(dict_node->input((i << 1) + 1));
2285     MS_EXCEPTION_IF_CHECK_FAIL(keys[i].ptr(), "the keys of unpack-call must be a const string");
2286   }
2287   ValueNode *const_keys = this->NewValueNode(AObject::Convert(keys), LOAD_CONST, -1, {});
2288   params->push_back(const_keys);
2289   return true;
2290 }
2291 
UnpackDynamicLengthTupleByBytecode(std::vector<ValueNode * > * params,ValueNode * args_node,CallNode * call_node)2292 bool GraphBuilder::UnpackDynamicLengthTupleByBytecode(std::vector<ValueNode *> *params, ValueNode *args_node,
2293                                                       CallNode *call_node) {
2294   // user-defined sequence, dynamic length tuple unpack
2295   if (args_node->GetVobj() && args_node->GetVobj()->GetType() != AObject::kTypeTuple) {
2296     return false;
2297   }
2298   AbstractTuple *tuple = static_cast<AbstractTuple *>(args_node->GetVobj());
2299   if (!tuple->IsElementValid()) {
2300     return false;
2301   }
2302   /**
2303    * must be guard this tuple length
2304    */
2305   auto items = tuple->items();
2306   std::vector<ValueNode *> args;
2307   for (size_t i = 0; i < items.size(); i++) {
2308     ValueNode *idx_node = this->NewValueNode(AObject::Convert(py::int_(i)), LOAD_CONST, -1, {});
2309     auto value = this->NewValueNode(items[i], BINARY_SUBSCR, 0, {args_node, idx_node});
2310     args.push_back(value);
2311 
2312     call_node->AddParam(value);
2313   }
2314   params->insert(params->begin(), args.begin(), args.end());
2315   return true;
2316 }
2317 
2318 // unpack CALL_FUNCTION_EX parameters
2319 // should do this when bytecode analyze ? replace origin opcode
UnpackCallExParams(std::vector<ValueNode * > * params,int extra_local,bool * has_kw,CallNode * call_node)2320 bool GraphBuilder::UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw,
2321                                       CallNode *call_node) {
2322   bool has_dict = params->size() > 1;
2323   ValueNode *args_node = params->operator[](0);
2324   if (!has_dict) {
2325     params->clear();
2326   } else if (!UnpackCallExDict(params, call_node)) {
2327     return false;
2328   }
2329   *has_kw = params->size();
2330   if (args_node->GetOpcode() != BUILD_TUPLE) {
2331     return UnpackDynamicLengthTupleByBytecode(params, args_node, call_node);
2332   }
2333   params->insert(params->begin(), args_node->getInputs().begin(), args_node->getInputs().end());
2334   return true;
2335 }
2336 
PackKwParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame,std::vector<ValueNode * > * kwvargs)2337 bool GraphBuilder::PackKwParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame,
2338                                 std::vector<ValueNode *> *kwvargs) {
2339   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2340   AObject *keys_info = params->back()->GetVobj();
2341   if (params->back()->GetOpcode() != LOAD_CONST || keys_info->GetType() != AObject::kTypeTuple) {
2342     return false;  // other case
2343   }
2344 
2345   const int posonlyargcount = GetCodePositionOnlyArgCount(co);
2346   PyObject **vars = &PyTuple_GET_ITEM(co->co_varnames, 0);
2347   const int argc = co->co_argcount + co->co_kwonlyargcount;
2348   PyObject **kwnames = &PyTuple_GET_ITEM(keys_info->GetPyObject().ptr(), 0);
2349   const int k_cnt = PyTuple_GET_SIZE(keys_info->GetPyObject().ptr());
2350   // kwnames must be string
2351   MS_ASSERT(static_cast<AbstractTuple *>(keys_info)->GetElementType() == AObject::kTypeString);
2352   MS_EXCEPTION_IF_CHECK_FAIL(SizeToInt(params->size()) > k_cnt, "check param");
2353 
2354   int kw_2_p_cnt = 0;
2355 
2356   // for each kw argument
2357   for (int i = k_cnt - 1; i >= 0; --i) {
2358     PyObject *key = kwnames[i];
2359     // find position and kwonly argument for key
2360     int pos = std::find_if(vars, vars + argc, [&key](PyObject *k) { return !PyUnicode_Compare(key, k); }) - vars;
2361     if (pos < posonlyargcount) {
2362       MS_LOG(DEBUG) << "position only argument specified by key-word";
2363       return false;
2364     }
2365 
2366     ValueNode *v = *(params->end() - 1 - k_cnt + i);
2367     // if key is position arg, store it
2368     if (pos < argc) {
2369       frame->SetLocal(pos, v);
2370       kw_2_p_cnt++;
2371       continue;
2372     }
2373     ValueNode *k = NewValueNode(AObject::Convert(key), LOAD_CONST, -1, {});
2374 
2375     kwvargs->push_back(k);
2376     kwvargs->push_back(v);
2377   }
2378 
2379   params->resize(params->size() - 1 - k_cnt);
2380   if (!(co->co_flags & CO_VARKEYWORDS)) {
2381     return kw_2_p_cnt == k_cnt;  // if not equal, too many key-word arguments
2382   }
2383   return true;
2384 }
2385 
HandleKWParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)2386 bool GraphBuilder::HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) {
2387   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2388   std::vector<ValueNode *> kwvargs;
2389   if (!PackKwParams(func, params, frame, &kwvargs)) {
2390     // illegal arguments
2391     return false;
2392   }
2393 
2394   const int argc = co->co_argcount + co->co_kwonlyargcount;
2395   if (!(co->co_flags & CO_VARKEYWORDS)) {
2396     // kw_2_p_cnt == k_cnt, all kw arguments is positions arguments
2397     return true;
2398   }
2399 
2400   int kwvarg_loc = argc + ((co->co_flags & CO_VARARGS) ? 1 : 0);
2401   AObject *dict = AObject::BuildOperations(CollectObjects(kwvargs), BUILD_MAP);
2402   frame->SetLocal(kwvarg_loc, NewValueNode(dict, BUILD_MAP, kwvargs.size() / 2, kwvargs));
2403 
2404   static_cast<CallNode *>(seek(0))->AddParam(frame->Local(kwvarg_loc));
2405   return true;
2406 }
2407 
CheckAndSetDefaultParams(const py::object & func,FrameStates * frame,int position_argc)2408 bool GraphBuilder::CheckAndSetDefaultParams(const py::object &func, FrameStates *frame, int position_argc) {
2409   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2410   PyObject *defs = PyFunction_GET_DEFAULTS(func.ptr());
2411   PyObject *kwdefs = PyFunction_GET_KW_DEFAULTS(func.ptr());
2412 
2413   const int argc = co->co_argcount + co->co_kwonlyargcount;
2414   PyObject *vars = co->co_varnames;
2415 
2416   int defs_off = defs ? co->co_argcount - PyTuple_GET_SIZE(defs) : INT_MAX;
2417   for (int i = position_argc; i < argc; ++i) {
2418     if (frame->Local(i) != &ValueNode::kUnboundLocal) {
2419       continue;
2420     }
2421     PyObject *val;
2422     if (i < co->co_argcount) {
2423       val = i < defs_off ? nullptr : PyTuple_GET_ITEM(defs, i - defs_off);
2424     } else {
2425       val = kwdefs == nullptr ? nullptr : PyDict_GetItem(kwdefs, PyTuple_GET_ITEM(vars, i));
2426     }
2427     if (val == nullptr) {
2428       MS_LOG(DEBUG) << "no " << (i < defs_off ? "" : "kw-") << "default parameter error";
2429       return false;
2430     }
2431     auto vo = AObject::Convert(val);
2432     ValueNode *c = NewValueNode(vo, LOAD_CONST, -1, {});
2433     frame->SetLocal(i, c);
2434   }
2435   return true;
2436 }
2437 
GetBoundSelf(CallNode * call_node)2438 ValueNode *GetBoundSelf(CallNode *call_node) {
2439   ValueNode *func_val = call_node->input(0);
2440   AObject *vo = func_val->GetVobj();
2441   Graph *graph = call_node->GetGraph();
2442 
2443   ValueNode *self = nullptr;
2444   switch (vo->GetType()) {
2445     case AObject::kTypeBoundMethod: {
2446       self = GetSelfFromMethod(func_val);
2447       if (self == nullptr) {
2448         AObject *tmp = func_val->get_attr(GraphBuilder::ID___self__);
2449         ValueNode *node = graph->NewValueNode(tmp, LOAD_ATTR, -1, {func_val}, GraphBuilder::ID___self__);
2450         node->SetGraph(call_node->GetGraph());
2451         call_node->AddParam(node);
2452         self = node;
2453       }
2454       break;
2455     }
2456     case AObject::kTypeCell: /* fallthrough */
2457     case AObject::kTypeAnyValue:
2458       self = func_val;
2459       break;
2460     case AObject::kTypeCFunction:
2461     case AObject::kTypeFunction:
2462       break;
2463     default:
2464       MS_LOG(INTERNAL_EXCEPTION) << "unimplemented type " << vo->ToString();
2465       break;
2466   }
2467   return self;
2468 }
2469 
HandlePositionParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)2470 bool GraphBuilder::HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) {
2471   CallNode *call_node = reinterpret_cast<CallNode *>(seek(0));
2472   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2473   auto vobj = trace_flag() ? AObject::Convert(func.ptr()) : call_node->input(0)->GetVobj();
2474   AObject::Type callable_type = vobj->GetType();
2475 
2476   ValueNode *self = GetBoundSelf(call_node);
2477   if (self != nullptr) {
2478     params->insert(params->begin(), self);
2479   }
2480 
2481   const int argc = co->co_argcount;
2482   const int has_varg = (co->co_flags & CO_VARARGS) ? 1 : 0;
2483   const int has_kwvarg = (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
2484   const int varg_loc = argc + co->co_kwonlyargcount;
2485   const int kwvarg_loc = argc + co->co_kwonlyargcount + has_varg;
2486   int pargc = params->size();
2487   if (pargc > argc && !has_varg) {
2488     MS_LOG(DEBUG) << "too many parameters";
2489     return false;
2490   }
2491   bool append_self_to_varg = has_varg && self && callable_type == AObject::kTypeBoundMethod && argc == 0;
2492   if (append_self_to_varg) {  // self is in variable arguments
2493     MS_LOG(INFO) << "not implement append self to variable arguments, inline failed";
2494     return false;
2495   }
2496 
2497   if (has_kwvarg && frame->Local(kwvarg_loc) == &ValueNode::kUnboundLocal) {
2498     auto vo = AObject::Convert(py::dict());
2499     auto m = NewValueNode(vo, BUILD_MAP, 0, {});
2500     call_node->AddParam(m);
2501     frame->SetLocal(kwvarg_loc, m);
2502   }
2503 
2504   if (has_varg) {
2505     int vargc = pargc > argc ? pargc - argc : 0;
2506     std::vector<ValueNode *> vargs(params->end() - vargc, params->end());
2507     params->resize(params->size() - vargc);
2508 
2509     auto vo = AObject::BuildOperations(CollectObjects(vargs), BUILD_TUPLE);
2510     ValueNode *build_tuple = NewValueNode(vo, BUILD_TUPLE, vargc, vargs);
2511     call_node->AddParam(build_tuple);
2512     frame->SetLocal(varg_loc, build_tuple);
2513   }
2514 
2515   pargc = params->size();
2516   for (int i = pargc - 1; i >= 0; --i) {
2517     if (frame->Local(i) != &ValueNode::kUnboundLocal) {
2518       MS_LOG(DEBUG) << "duplicate key-word parameter error";
2519       return false;
2520     }
2521     frame->SetLocal(i, params->back());
2522     params->pop_back();
2523   }
2524   return CheckAndSetDefaultParams(func, frame, pargc);
2525 }
2526 
HandleCallParameters(const py::object & func_info,CallNode * call_node,FrameStates * frame)2527 bool GraphBuilder::HandleCallParameters(const py::object &func_info, CallNode *call_node, FrameStates *frame) {
2528   if (func_info.ptr() == nullptr) {
2529     MS_LOG(EXCEPTION) << "HandleCallParameters with empty func_info input.";
2530   }
2531   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func_info.ptr()));
2532   frame->ResizeLocal(co->co_nlocals);
2533 
2534   std::vector<ValueNode *> params(call_node->getInputs().begin() + 1, call_node->getInputs().end());
2535   int op = call_node->GetOpcode();
2536   bool has_kw = (op == CALL_FUNCTION_KW);
2537   if (op == CALL_FUNCTION_EX && !UnpackCallExParams(&params, co->co_nlocals, &has_kw, call_node)) {
2538     return false;  // ex_dict infer failed or user-defined sequence and map arguments
2539   }
2540   if (has_kw && !HandleKWParams(func_info, &params, frame)) {
2541     return false;
2542   }
2543   if (!HandlePositionParams(func_info, &params, frame)) {
2544     return false;
2545   }
2546 
2547   MS_EXCEPTION_IF_CHECK_FAIL(params.size() == 0, "check parameters handle");
2548 
2549   // after store all params
2550   // cell2arg
2551   const Py_ssize_t ncells = PyTuple_GET_SIZE(co->co_cellvars);
2552   const Py_ssize_t *c2a_arr = co->co_cell2arg;
2553   for (int i = 0; c2a_arr != nullptr && i < ncells; ++i) {
2554     if (c2a_arr[i] != CO_CELL_NOT_AN_ARG) {
2555       Py_ssize_t arg_index = c2a_arr[i];
2556       CellVarNode *cell_node = frame->Closure(i);
2557       ValueNode *arg_node = frame->Local(arg_index);
2558       /**
2559        * here not delete the local, continue with local same as closure
2560        * frame->SetLocal(arg_index, &ValueNode::kUnboundLocal);
2561        */
2562 
2563       PyObject *cell = cell_node->GetVobj()->GetPyObject().ptr();
2564       PyObject *cell_contents = arg_node->GetVobj() ? arg_node->GetVobj()->GetPyObject().inc_ref().ptr() : nullptr;
2565       MS_EXCEPTION_IF_CHECK_FAIL(cell && PyCell_Check(cell) && PyCell_GET(cell) == nullptr, "must be a empty closure");
2566 
2567       ValueNode *n = NewValueNode(nullptr, STORE_DEREF, i, {arg_node});
2568 
2569       cell_node->AddCellOper(n);
2570       cell_node->SetValue(arg_node);
2571       Py_XSETREF(PyCell_GET(cell), cell_contents);
2572       // cell variable is eliminate
2573       // call_node->AddParam(n);
2574     }
2575   }
2576   return true;
2577 }
2578 
2579 static void SetGradFuncInfo(mindspore::pijit::CallNode *call_node);
2580 
ResolveCallable(CallNode * call_node,StopTraceReason * stop_reason)2581 py::object GraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
2582   AObject *callable = call_node->input(0)->GetVobj();
2583   py::object callable_info;
2584   *stop_reason = StopTraceReason::kStopTraceInfer_Fail;
2585   call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
2586   if (!callable) {
2587     return callable_info;
2588   }
2589   callable_info = callable->GetPyObject();
2590   if (callable_info.ptr() == nullptr) {
2591     callable_info = py::cast<py::object>(reinterpret_cast<PyObject *>(callable->GetTypeObject()));
2592   }
2593 
2594   AObject::Type callable_type = callable->GetType();
2595   if (callable_info.ptr() == nullptr) {
2596     if (callable->TestMsFlag(AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc)) {
2597       SetGradFuncInfo(call_node);
2598       *stop_reason = StopTraceReason::kNonStopTrace;
2599     }
2600     return py::object();
2601   }
2602 
2603   *stop_reason = StopTraceReason::kNonStopTrace;
2604   if (callable_type == AObject::kTypeType) {
2605     call_node->SetInlineReason(InlineReason::kInlineFunc_ArgType_IsClass);
2606     HandleCallClass(call_node);
2607     if (static_cast<AbstractType *>(callable)->GetTypeType() == AObject::kTypeCell) {
2608       *stop_reason = StopTraceReason::kStopTraceInfer_Fail;
2609     }
2610     return py::object();
2611   }
2612 
2613   if (WhiteListFuncCheckAndInfer(call_node, callable_info)) {
2614     if (call_node->GetInlineReason() == InlineReason::kInlineFunc_Type_Unsupported) {
2615       *stop_reason = StopTraceReason::kStopTraceFunc_Type_Unsupported;
2616     }
2617     return py::object();
2618   }
2619 
2620   // find code object
2621   callable_info = GetFuncInfo(call_node->input(0));
2622   if (callable_info.ptr() == nullptr) {
2623     *stop_reason = StopTraceReason::kStopTraceFunc_Type_Unsupported;
2624     call_node->SetInlineReason(InlineReason::kInlineCFunction_Unsupported);
2625   }
2626   return callable_info;
2627 }
2628 
ResolveClosure(const py::object & func_info,ValueNode * callable_node,FrameStates * frame)2629 void GraphBuilder::ResolveClosure(const py::object &func_info, ValueNode *callable_node, FrameStates *frame) {
2630   if (func_info.ptr() == nullptr) {
2631     MS_LOG(INTERNAL_EXCEPTION) << "When resolving closure, get func_info failed.";
2632   }
2633   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func_info.ptr()));
2634   PyObject *closure = PyFunction_GET_CLOSURE(func_info.ptr());
2635 
2636   int ncells = PyTuple_GET_SIZE(co->co_cellvars);
2637   int nfrees = PyTuple_GET_SIZE(co->co_freevars);
2638   frame->ResizeClosure(ncells + nfrees);
2639   for (int i = 0; i < ncells; i++) {
2640     CellVarNode *n = graph_->allocator().NewNode<CellVarNode>(CellVarNode::CellVar);
2641     n->SetVobj(AObject::Convert(py::reinterpret_steal<py::object>(PyCell_New(nullptr))));
2642     frame->SetClosure(i, n);
2643   }
2644   // track free variable
2645   bool make_func = callable_node->GetOpcode() == MAKE_FUNCTION;
2646   for (int i = 0; i < nfrees; ++i) {
2647     CellVarNode *freevar = nullptr;
2648     if (make_func) {
2649       ValueNode *tuple = *(callable_node->getInputs().end() - 3);
2650       MS_EXCEPTION_IF_CHECK_FAIL(tuple->GetOpcode() == BUILD_TUPLE, "unknown closure source");
2651       freevar = reinterpret_cast<CellVarNode *>(tuple->input(i));
2652     } else if (closure) {
2653       auto v = PyTuple_GET_ITEM(closure, i);
2654       freevar = graph_->allocator().NewNode<CellVarNode>(CellVarNode::FreeVar);
2655       freevar->SetVobj(AObject::Convert(v));
2656 
2657       // if inline, guard free variable
2658       ValueNode *param = NewValueNode(AObject::Convert(PyCell_GET(v)), LOAD_DEREF, -1);
2659       param->SetGraph(graph_);
2660       freevar->SetValue(param);
2661     } else {
2662       MS_LOG(EXCEPTION) << "error no closure";
2663     }
2664     frame->SetClosure(ncells + i, freevar);
2665   }
2666 }
2667 
SetMixedPrecisionType(CallNode * call_node,FrameStates * frame)2668 void SetMixedPrecisionType(CallNode *call_node, FrameStates *frame) {
2669   auto func_node = call_node->input(0);
2670   if (func_node->GetVobj() && func_node->GetVobj()->GetType() == AbstractObjectBase::kTypeCell) {
2671     auto cell = py::cast<CellPtr>(func_node->GetVobj()->GetPyObject());
2672     auto mixed_type = cell->GetMixedPrecisionType();
2673     if (mixed_type != MixedPrecisionType::kNotSet) {
2674       for (size_t i = 0; i < frame->GetLocals().size(); i++) {
2675         if (frame->Local(i)->GetType() == AbstractNode::Param) {
2676           auto paramNode = reinterpret_cast<ParamNode *>(frame->Local(i));
2677           if (paramNode->GetVobj()->GetType() == AObject::kTypeTensor &&
2678               !paramNode->GetVobj()->GetPyObject().attr("dtype").is_none()) {
2679             auto src_dtype = paramNode->GetVobj()->GetPyObject().attr("dtype");
2680             bool is_cast = false;
2681             if (py::isinstance<Float>(src_dtype)) {
2682               auto float_nbits = py::cast<Float>(src_dtype).nbits();
2683               if (float_nbits == 64 || (float_nbits == 32 && mixed_type != kFP32) ||
2684                   (float_nbits == 16 && mixed_type != kFP16)) {
2685                 is_cast = true;
2686               }
2687             }
2688             if (py::isinstance<BFloat>(src_dtype) && mixed_type != kBF16) {
2689               is_cast = true;
2690             }
2691             if (!is_cast) {
2692               continue;
2693             }
2694             auto dst_dtype = Utils::MixedPrecisionTypeToDType(mixed_type);
2695             paramNode->SetMixedPrecisionType(dst_dtype);
2696           }
2697         }
2698       }
2699     }
2700   }
2701 }
2702 
HandleCall(int depth)2703 StopTraceReason GraphBuilder::HandleCall(int depth) {
2704   MS_EXCEPTION_IF_CHECK_FAIL(seek(0)->GetType() == ValueNode::Call, "must be call node");
2705   CallNode *call_node = reinterpret_cast<CallNode *>(seek(0));
2706   if (depth > root_->graph_->Config().getIntConfig(GraphJitConfig::kMaxInlineDepth)) {
2707     call_node->SetInlineReason(InlineReason::kInlineTooDeep);
2708     return StopTraceReason::kNonStopTrace;
2709   }
2710   StopTraceReason stop_reason = StopTraceReason::kNonStopTrace;
2711 
2712   py::object callable_info = ResolveCallable(call_node, &stop_reason);
2713   if (callable_info.ptr() == nullptr) {
2714     return stop_reason;
2715   }
2716   MS_EXCEPTION_IF_CHECK_FAIL(PyFunction_Check(callable_info.ptr()), "'ResolveCallable' must be return a function");
2717 
2718   // unsupported check
2719   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(callable_info.ptr()));
2720   PyObject *globals = PyFunction_GET_GLOBALS(callable_info.ptr());
2721   auto subgraph = GraphBuilder::Creator(this->root_ ? this->root_ : this, this, co, globals, trace_flag());
2722 
2723   // frame build
2724   FrameStates *frame = &(subgraph->frame_);
2725   ResolveClosure(callable_info, call_node->input(0), frame);
2726   if (!HandleCallParameters(callable_info, call_node, frame)) {
2727     call_node->SetInlineReason(InlineReason::kInlineFunc_ArgHandle_Unsupported);
2728     return StopTraceReason::kStopTraceFunc_ArgHandle_Unsupported;
2729   }
2730 
2731   SetMixedPrecisionType(call_node, frame);
2732   // build sub-graph
2733   stop_reason = BuildSubGraph(call_node, depth, callable_info, subgraph);
2734   CollectInlineInfo(call_node, depth);
2735 
2736   if (!trace_flag() && call_node->GetSubGraph() && call_node->GetInlineReason() == InlineReason::kInline) {
2737     MS_EXCEPTION_IF_NULL(call_node->GetSubGraph()->GetRetVal());
2738     seek(0) = call_node->GetSubGraph()->GetRetVal();
2739   }
2740   return stop_reason;
2741 }
2742 
GuardLoopSequence(Graph * graph,ValueNode * seq_node,Py_ssize_t seq_size)2743 static bool GuardLoopSequence(Graph *graph, ValueNode *seq_node, Py_ssize_t seq_size) {
2744   // guard length
2745   PyObject *seq = seq_node->GetVobj()->GetPyObject().ptr();
2746   if (seq != nullptr && seq_size == -1) {
2747     seq_size = PySequence_Size(seq);
2748   }
2749   if (seq == nullptr || seq_size == -1) {
2750     PyErr_Clear();
2751     return false;
2752   }
2753   if (!graph->GuardSequenceNodeLength(seq_node, seq_size)) {
2754     return false;
2755   }
2756   if (!graph->GuardType(seq_node)) {
2757     return false;
2758   }
2759   return true;
2760 }
2761 
GuardIterInputs(Graph * graph,ValueNode * seq_node,Py_ssize_t seq_size=-1)2762 bool GuardIterInputs(Graph *graph, ValueNode *seq_node, Py_ssize_t seq_size = -1) {
2763   PyObject *seq = seq_node->GetVobj()->GetPyObject().ptr();
2764   if (seq != nullptr && seq_size == -1) {
2765     seq_size = PySequence_Size(seq);
2766   }
2767   if (seq == nullptr || seq_size == -1) {
2768     PyErr_Clear();
2769     return false;
2770   }
2771   if (!graph->GuardSequenceNodeLength(seq_node, seq_size)) {
2772     return false;
2773   }
2774   auto input_nodes = seq_node->getInputs();
2775   for (size_t i = 1; i < input_nodes.size(); ++i) {
2776     ValueNode *input_node = input_nodes[i];
2777     if (input_node == nullptr) {
2778       return false;
2779     }
2780     TracePtr tr = graph->TraceValueNode(input_node);
2781     if (!(graph->GetGuard()->GetGuard()->GuardOn(tr, GuardLevel::GEqual))) {
2782       MS_LOG(INFO) << "Iterator guard fail: " << seq_node->ToString();
2783       return false;
2784     }
2785   }
2786   MS_LOG(INFO) << "Iterator guard success: " << seq_node->ToString();
2787   return true;
2788 }
2789 
TraceRunForIterSequence(int jump_bci,bool is_range_type)2790 bool GraphBuilder::TraceRunForIterSequence(int jump_bci, bool is_range_type) {
2791   // check for iter
2792   ValueNode *iter_node = seek(0);
2793   ValueNode *seq_node = iter_node->input(0);
2794   PyObject *seq = seq_node->GetVobj()->GetPyObject().ptr();
2795   if (seq == nullptr) {
2796     return false;  // infer failed
2797   }
2798   Py_ssize_t size = PySequence_Size(seq);
2799   if (size == -1) {
2800     PyErr_Clear();
2801     MS_LOG(DEBUG) << "FOR_ITER without __len__";
2802     return false;
2803   }
2804 
2805   int &index = iter_node->marker_;
2806   if (index == 0 && ((is_range_type && !GuardIterInputs(graph_, seq_node)) ||
2807                      (!is_range_type && !GuardLoopSequence(graph_, seq_node)))) {
2808     // loop start.
2809     return false;
2810   }
2811 
2812   if (index >= size) {
2813     pop();
2814     cur_bci_ = jump_bci;
2815     return true;
2816   }
2817 
2818   PyObject *item = PySequence_GetItem(seq, index);
2819   if (item == nullptr) {
2820     MS_LOG(ERROR) << "trace for iter got an error " << py::error_already_set().what();
2821     PyErr_Clear();
2822     return false;
2823   }
2824 
2825   py::object index_object = py::int_(index);
2826   ValueNode *index_node = NewValueNode(AObject::Convert(index_object), LOAD_CONST, -1, {});
2827   push(seq_node);
2828   push(index_node);
2829   DoItemAccess({BINARY_SUBSCR, 0});
2830   ValueNode *item_node = pop();
2831   Py_DECREF(item);
2832 
2833   index++;
2834   push(item_node);
2835   cur_bci_ = cur_bci_ + 1;
2836   return true;
2837 }
2838 
CheckForIterEnumerate(ValueNode * iter_node)2839 static bool CheckForIterEnumerate(ValueNode *iter_node) {
2840   ValueNode *enumerate_node = iter_node->input(0);
2841   if (enumerate_node->GetOpcode() != CALL_FUNCTION || iter_node->bci() - 1 != enumerate_node->bci()) {
2842     // enumerate object maybe alive, shouldn't reduce it
2843     return false;
2844   }
2845   PyObject *enumerate = enumerate_node->GetVobj()->GetPyObject().ptr();
2846   if (enumerate == nullptr) {
2847     return false;
2848   }
2849 
2850   MS_EXCEPTION_IF_NULL(iter_node->GetGraph());
2851 
2852   ValueNode *iterable_node = enumerate_node->input(1);
2853   PyObject *iterable = iterable_node->GetVobj()->GetPyObject().ptr();
2854   if (iterable == nullptr || !PySequence_Check(iterable) || !GuardLoopSequence(iter_node->GetGraph(), iterable_node)) {
2855     // just support sequence iteration
2856     return false;
2857   }
2858   return true;
2859 }
2860 
TraceRunForIterEnumerate(int jump_bci)2861 bool GraphBuilder::TraceRunForIterEnumerate(int jump_bci) {
2862   ValueNode *iter_node = seek(0);
2863   if (iter_node->marker_ == 0) {
2864     if (!CheckForIterEnumerate(iter_node)) {
2865       return false;
2866     }
2867     iter_node->marker_ = 1;
2868   }
2869   ValueNode *enumerate_node = iter_node->input(0);
2870   PyObject *enumerate = enumerate_node->GetVobj()->GetPyObject().ptr();
2871   ValueNode *iterable_node = enumerate_node->input(1);
2872 
2873   // reduce iterable object
2874   ValueNode *seq_node = iterable_node;
2875   PyObject *tuple = PyIter_Next(enumerate);
2876   if (tuple == nullptr) {
2877     if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_StopIteration)) {
2878       MS_LOG(ERROR) << "trace FOR_ITER got an error " << py::error_already_set().what();
2879       PyErr_Clear();
2880       return false;
2881     }
2882     PyErr_Clear();
2883     pop();
2884     cur_bci_ = jump_bci;
2885     return true;
2886   }
2887   PyObject *index = PyTuple_GET_ITEM(tuple, 0);
2888   PyObject *item = PyTuple_GET_ITEM(tuple, 1);
2889   ValueNode *index_node = NewValueNode(AObject::Convert(index), LOAD_CONST, -1, {});
2890   ValueNode *item_node = NewValueNode(AObject::Convert(item), BINARY_SUBSCR, 0, {seq_node, index_node});
2891   ValueNode *value_node = NewValueNode(AObject::Convert(tuple), BUILD_TUPLE, 2, {index_node, item_node});
2892   Py_DECREF(tuple);
2893   graph_->GetTracedNodes().push_back(item_node);
2894   graph_->GetTracedNodes().push_back(value_node);
2895 
2896   push(value_node);
2897   cur_bci_ = cur_bci_ + 1;
2898   return true;
2899 }
2900 
CheckForIterZip(ValueNode * iter_node)2901 static bool CheckForIterZip(ValueNode *iter_node) {
2902   ValueNode *zip_node = iter_node->input(0);
2903   if (zip_node->GetOpcode() != CALL_FUNCTION || iter_node->bci() - 1 != zip_node->bci()) {
2904     return false;
2905   }
2906   PyObject *zip = zip_node->GetVobj()->GetPyObject().ptr();
2907   if (zip == nullptr) {
2908     return false;
2909   }
2910   MS_EXCEPTION_IF_NULL(iter_node->GetGraph());
2911   Graph *graph = iter_node->GetGraph();
2912 
2913   std::vector<ValueNode *> iterable_nodes = {zip_node->getInputs().begin() + 1, zip_node->getInputs().end()};
2914   auto iter = std::find_if(iterable_nodes.begin(), iterable_nodes.end(), [&graph](ValueNode *iterable_node) {
2915     PyObject *iterable = iterable_node->GetVobj()->GetPyObject().ptr();
2916     return iterable == nullptr || !PySequence_Check(iterable) || !GuardLoopSequence(graph, iterable_node);
2917   });
2918   if (iter != iterable_nodes.end()) {
2919     return false;
2920   }
2921   return true;
2922 }
2923 
TraceRunForIterZip(int jump_bci)2924 bool GraphBuilder::TraceRunForIterZip(int jump_bci) {
2925   ValueNode *iter_node = seek(0);
2926   int *index = &iter_node->marker_;
2927   if ((*index) == 0) {
2928     if (!CheckForIterZip(iter_node)) {
2929       return false;
2930     }
2931   }
2932 
2933   ValueNode *zip_node = iter_node->input(0);
2934   PyObject *zip = zip_node->GetVobj()->GetPyObject().ptr();
2935   std::vector<ValueNode *> iterable_nodes = {zip_node->getInputs().begin() + 1, zip_node->getInputs().end()};
2936 
2937   // reduce iterable object
2938   PyObject *tuple = PyIter_Next(zip);
2939   py::object handle = py::reinterpret_steal<py::object>(tuple);
2940   if (handle.ptr() == nullptr) {
2941     if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_StopIteration)) {
2942       MS_LOG(ERROR) << "trace FOR_ITER got an error " << py::error_already_set().what();
2943       PyErr_Clear();
2944       return false;
2945     }
2946     PyErr_Clear();
2947     pop();
2948     cur_bci_ = jump_bci;
2949     return true;
2950   }
2951 
2952   std::vector<ValueNode *> inputs;
2953   for (size_t tuple_index = 0; tuple_index < iterable_nodes.size(); ++tuple_index) {
2954     PyObject *item = PyTuple_GET_ITEM(tuple, tuple_index);
2955     ValueNode *seq_node = iterable_nodes[tuple_index];
2956     ValueNode *index_node = NewValueNode(AObject::Convert(py::int_(*index)), LOAD_CONST, -1, {});
2957     ValueNode *item_node = NewValueNode(AObject::Convert(item), BINARY_SUBSCR, 0, {seq_node, index_node});
2958     inputs.push_back(item_node);
2959     graph_->GetTracedNodes().push_back(item_node);
2960   }
2961   ValueNode *value_node = NewValueNode(AObject::Convert(tuple), BUILD_TUPLE, inputs.size(), inputs);
2962   graph_->GetTracedNodes().push_back(value_node);
2963   push(value_node);
2964 
2965   (*index)++;
2966   cur_bci_ = cur_bci_ + 1;
2967   return true;
2968 }
2969 
IsRangeType(ValueNode * iter_node)2970 bool IsRangeType(ValueNode *iter_node) {
2971   if (iter_node->input(0)->GetOpcode() != CALL_FUNCTION) {
2972     return false;
2973   }
2974   auto vobj = iter_node->input(0)->input(0)->GetVobj();
2975   if (vobj == nullptr) {
2976     return false;
2977   }
2978   PyTypeObject *type = reinterpret_cast<PyTypeObject *>(static_cast<AbstractType *>(vobj)->GetPyObject().ptr());
2979   return type == &PyRange_Type;
2980 }
2981 
TraceRunForIter(const Instr & instr)2982 bool GraphBuilder::TraceRunForIter(const Instr &instr) {
2983   MS_EXCEPTION_IF_NULL(instr.extra_jump());
2984 
2985   // check for iter
2986   ValueNode *iter_node = seek(0);
2987   AObject *iterable = iter_node->getInputs().size() > 0 ? iter_node->input(0)->GetVobj() : nullptr;
2988   bool succ;
2989   if (iter_node->GetOpcode() != GET_ITER) {
2990     MS_LOG(DEBUG) << "FOR_ITER without GET_ITER";
2991     succ = false;
2992   } else if (iterable == nullptr) {
2993     succ = false;
2994   } else if (iterable->GetTypeObject() == &PyEnum_Type) {
2995     succ = TraceRunForIterEnumerate(instr.extra_jump()->bci());
2996   } else if (iterable->GetTypeObject() == &PyZip_Type) {
2997     succ = TraceRunForIterZip(instr.extra_jump()->bci());
2998   } else {
2999     succ = TraceRunForIterSequence(instr.extra_jump()->bci(), IsRangeType(iter_node));
3000   }
3001   if (!succ) {
3002     graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceLoop_Unsupported);
3003   }
3004   return succ;
3005 }
3006 
IsConstantBoolValue(ValueNode * node)3007 static bool IsConstantBoolValue(ValueNode *node) {
3008   const auto &cnst_info = node->GetConstantInfo();
3009   if (cnst_info == nullptr) {
3010     return false;
3011   }
3012   if (cnst_info->value().ptr() != nullptr) {
3013     return true;
3014   }
3015   PyTypeObject *tp = cnst_info->type();
3016   if (tp == nullptr) {
3017     return false;
3018   }
3019   static const std::set<PyTypeObject *> len_to_bool = {&PyTuple_Type, &PyList_Type, &PyDict_Type};
3020   if (len_to_bool.find(tp) != len_to_bool.end() && cnst_info->len() != -1) {
3021     return true;
3022   }
3023   return false;
3024 }
3025 
IsShapeOrDtypeRelatedNode(const ValueNode * node)3026 bool IsShapeOrDtypeRelatedNode(const ValueNode *node) {
3027   if (node->GetOpcode() == CALL_FUNCTION && node->input(0)->GetVobj()->GetType() == AObject ::kTypeCFunction &&
3028       node->input(0)->GetName() == "len") {
3029     node = node->input(1);
3030   }
3031   if (node->GetOpcode() == BINARY_SUBSCR) {
3032     node = node->input(0);
3033   }
3034   if (node->GetOpcode() == CALL_FUNCTION) {
3035     auto func_node = node->input(0);
3036     // prim
3037     if (py::isinstance<mindspore::PrimitivePyAdapter>(func_node->GetVobj()->GetPyObject())) {
3038       auto prime_name = py::cast<mindspore::PrimitivePyAdapterPtr>(func_node->GetVobj()->GetPyObject())->name();
3039       if (prime_name == "Shape" || prime_name == "DType" || prime_name == "Rank") {
3040         return true;
3041       }
3042     }
3043   } else if (node->GetOpcode() == LOAD_ATTR) {
3044     auto attr_name = node->GetName();
3045     if (attr_name == "dtype" || attr_name == "shape" || attr_name == "ndim" || attr_name == "size") {
3046       return true;
3047     }
3048   }
3049   return false;
3050 }
3051 
TryGuardEscape(ValueNode * cond_node)3052 bool TryGuardEscape(ValueNode *cond_node) {
3053   if (cond_node->GetOpcode() == COMPARE_OP &&
3054       std::any_of(cond_node->getInputs().begin(), cond_node->getInputs().end(),
3055                   [](const ValueNode *node) { return IsShapeOrDtypeRelatedNode(node); })) {
3056     return true;
3057   }
3058   if (cond_node->GetOpcode() == COMPARE_OP && cond_node->getInputs().size() == 2 &&
3059       cond_node->input(0)->GetOpcode() == BINARY_SUBSCR && cond_node->input(1)->GetOpcode() == BINARY_SUBSCR) {
3060     return true;
3061   }
3062   if (cond_node->GetOpcode() == CONTAINS_OP &&
3063       (IsShapeOrDtypeRelatedNode(cond_node->input(0)) || IsShapeOrDtypeRelatedNode(cond_node->input(1)))) {
3064     return true;
3065   }
3066   if (cond_node->GetOpcode() == CALL_FUNCTION &&
3067       std::all_of(cond_node->getInputs().begin() + 1, cond_node->getInputs().end(),
3068                   [](const ValueNode *node) { return IsShapeOrDtypeRelatedNode(node); })) {
3069     return true;
3070   }
3071   return false;
3072 }
3073 
IsSatisfyPruneLimit(int cond,Graph * graph_,ValueNode * cond_node)3074 bool IsSatisfyPruneLimit(int cond, Graph *graph_, ValueNode *cond_node) {
3075   if (cond == -1) {
3076     return false;
3077   }
3078   int limit_prune = graph_->Config().getIntConfig(GraphJitConfig::kMaxPruneCase);
3079   if (limit_prune >= 0 && limit_prune < graph_->GetPruneBranchCount()) {
3080     return false;
3081   }
3082   if (IsConstantBoolValue(cond_node)) {
3083     return true;
3084   }
3085   auto tr = graph_->TraceValueNode(cond_node);
3086   if (tr == nullptr) {
3087     if (graph_->Config().getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0) {
3088       PyObject *bool_value = cond_node->GetVobj()->GetPyObject().ptr();
3089       if ((bool_value == Py_True || bool_value == Py_False) && TryGuardEscape(cond_node)) {
3090         return true;
3091       }
3092     }
3093     return false;
3094   }
3095   PyObject *bool_value = cond_node->GetVobj()->GetPyObject().ptr();
3096   if (bool_value != Py_True && bool_value != Py_False) {
3097     bool strict = graph_->Config().GetBoolConfig(GraphJitConfig::kStrictTrace);
3098     auto bool_type = CreateOpTrace(reinterpret_cast<PyObject *>(&PyBool_Type), LOAD_CONST, -1, {}, "", "", strict);
3099     tr = CreateOpTrace(cond ? Py_True : Py_False, CALL_FUNCTION, 1, {bool_type, tr}, "", "", strict);
3100   } else {
3101     cond_node->SetConstantValue(true);
3102   }
3103   graph_->GetGuard()->GetGuard()->GuardOn(tr, GuardLevel::GId);
3104   return true;
3105 }
3106 
LogPrunBranch(ValueNode * cond,const Instr & instr,const GraphJitConfig & conf)3107 static void LogPrunBranch(ValueNode *cond, const Instr &instr, const GraphJitConfig &conf) {
3108   MS_LOG(DEBUG) << "trace run prune branch failed [" << cond->ToString() << "]";
3109   if (conf.GetBoolConfig(GraphJitConfig::kPrintGuard)) {
3110     GRAPH_JIT_LOG_F("Fail to prune bytecode [%s]!\n", instr.ToString().c_str());
3111   } else {
3112     MS_LOG(DEBUG) << "Fail to prune bytecode [" << instr.ToString() << "]!\n";
3113   }
3114 
3115   if (conf.GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
3116     if (CondIsTrue(cond) == -1) {
3117       GRAPH_JIT_LOG_F("infer failed\n");
3118     } else {
3119       auto tr = GetTrace(cond, false, true, 0, conf.getIntConfig(GraphJitConfig::kMaxTraceDepth));
3120       std::map<Trace *, size_t> cache;
3121       GRAPH_JIT_LOG_F("trace:\n%s\n", tr ? tr->FormatString(&cache).c_str() : "trace failed");
3122     }
3123     if (cond->GetGraph() == nullptr || cond->GetGraph()->GetCodeObj() == nullptr) {
3124       return;
3125     }
3126     GRAPH_JIT_LOG_F("if branch prune failed, condition [%s] at [%U : %d]", cond->ToString().c_str(),
3127                     cond->GetGraph()->GetCodeObj()->co_filename, cond->GetLineNo());
3128   }
3129 }
3130 
TraceRunControl(const Instr & instr)3131 bool GraphBuilder::TraceRunControl(const Instr &instr) {
3132   MS_EXCEPTION_IF_NULL(instr.extra_jump());
3133   Opcode opcode(instr.op());
3134   ValueNode *cond_node = nullptr;
3135   int cond = -1;
3136   int jump_to = -1;
3137   if (opcode == JUMP_FORWARD || opcode == JUMP_ABSOLUTE) {
3138     cur_bci_ = instr.extra_jump()->bci();
3139     return true;
3140   } else if (opcode == FOR_ITER) {
3141     return TraceRunForIter(instr);
3142   } else if (opcode == POP_JUMP_IF_FALSE || opcode == POP_JUMP_IF_TRUE) {
3143     cond_node = pop();
3144     cond = CondIsTrue(cond_node);
3145     jump_to = ((cond == 0) ^ (opcode == POP_JUMP_IF_TRUE)) ? instr.extra_jump()->bci() : cur_bci_ + 1;
3146   } else if (opcode == JUMP_IF_FALSE_OR_POP || opcode == JUMP_IF_TRUE_OR_POP) {
3147     cond_node = seek(0);
3148     cond = CondIsTrue(cond_node);
3149     bool jump = (cond == 0) ^ (opcode == JUMP_IF_TRUE_OR_POP);
3150     cond_node = jump ? seek(0) : pop();
3151     jump_to = jump ? instr.extra_jump()->bci() : cur_bci_ + 1;
3152   } else {
3153     graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceByteCode_Unsupported);
3154     return false;
3155   }
3156 
3157   // if branch
3158   if (!IsSatisfyPruneLimit(cond, graph_, cond_node)) {
3159     LogPrunBranch(cond_node, instr, graph_->Config());
3160     graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceIf_Unsupported);
3161     return false;
3162   }
3163   MS_EXCEPTION_IF_CHECK_FAIL(jump_to != -1, "error jump bci");
3164   cur_bci_ = jump_to;
3165   return true;
3166 }
3167 
EliminateCellAccess(Graph * g)3168 static void EliminateCellAccess(Graph *g) {
3169   PyCodeObject *co = g->GetCodeObj();
3170   int ncells = PyTuple_GET_SIZE(co->co_cellvars);
3171   if (ncells == 0) {
3172     return;
3173   }
3174   ValueNode *ret_node = g->GetRetVal();
3175   if (ret_node == nullptr) {
3176     return;
3177   }
3178   std::set<ValueNode *> escaped;
3179   auto CollectClosure = [&escaped](ValueNode *node) {
3180     if (node->GetOpcode() == MAKE_FUNCTION && (node->GetOparg() & 0x08)) {
3181       const auto &in = (*(node->getInputs().end() - 3))->getInputs();
3182       escaped.insert(in.begin(), in.end());
3183     }
3184   };
3185   for (auto i : g->GetTracedNodes()) {
3186     int op = i->GetOpcode();
3187     if (op == STORE_DEREF && i->GetOparg() < ncells) {
3188       // exclude STORE_DEREF
3189       continue;
3190     }
3191     auto begin = i->getInputs().begin();
3192     if (Opcode(op).IsCall() && static_cast<CallNode *>(i)->GetInlineReason() == InlineReason::kInline) {
3193       begin++;
3194     }
3195     std::for_each(begin, i->getInputs().end(), CollectClosure);
3196   }
3197   CollectClosure(ret_node);
3198   // collect STORE_DEREF with MAKE_FUNCTION ...
3199 
3200   const auto &closures = g->GetFrame(0).GetClosures();
3201   for (int i = 0; i < ncells; ++i) {
3202     if (escaped.find(closures[i]) != escaped.end()) {
3203       continue;
3204     }
3205     for (auto node : closures[i]->GetCellOper()) {
3206       if (node->GetOpcode() != STORE_DEREF) {
3207         // closure access before assign, raise UnboundLocalError
3208         return;
3209       }
3210       node->SetOpcode(LOAD_CONST);
3211       node->SetVobj(AObject::Convert(Py_None));
3212       node->ClearInputs();
3213     }
3214     closures[i]->GetCellOper().clear();
3215   }
3216 }
3217 
TraceRun()3218 StopTraceReason GraphBuilder::TraceRun() {
3219   current_block_ = graph_->GetCFG()->GetFirstBB();
3220   cur_bci_ = 0;
3221   const auto &instrs = graph_->GetCFG()->instr_pool();
3222   while (true) {
3223     this->graph_->SetFrame(cur_bci_, frame_);
3224     MS_EXCEPTION_IF_CHECK_FAIL(static_cast<size_t>(cur_bci_) < instrs.size(), "error control flow");
3225     MS_EXCEPTION_IF_CHECK_FAIL(instrs[cur_bci_]->bci() == cur_bci_, "check instruction bci");
3226     if (!DoByteCode(*instrs[cur_bci_])) {
3227       break;
3228     }
3229   }
3230   if (!trace_flag()) {
3231     EliminateCellAccess(this->graph_);
3232   }
3233   return graph_->GetStopTraceReason();
3234 }
3235 
3236 extern void AddConfigToGuard(const GraphJitConfig &c, OptGuardPtr guard);
3237 extern void AddGuardForParam(const PyFrameObject *f, OptGuardPtr guard, bool detach);
3238 
3239 /**
3240  * Generate a graph from callable, this function will actually create python frame
3241  */
GenerateRootGraph(const py::object & callable,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf)3242 static std::unique_ptr<GraphBuilder> GenerateRootGraph(const py::object &callable, const py::object &args,
3243                                                        const py::object &kwargs, const GraphJitConfig &conf) {
3244   PyFrameObject *frame = Utils::PrepareFrame(callable.ptr(), args.ptr(), kwargs.ptr());
3245   if (frame == nullptr) {
3246     PyErr_Clear();
3247     return nullptr;
3248   }
3249   auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(frame->f_code));
3250   *jcr->conf = conf;
3251   jcr->code = jcr->codehub->AddOptTarget(OptOption::CreateOptionByPoint(jcr));
3252 
3253   auto res = std::make_unique<GraphBuilder>(frame);
3254 
3255   auto code = res->GetGraph()->GetGuard();
3256   AddConfigToGuard(conf, code->GetGuard());
3257   AddGuardForParam(frame, code->GetGuard(), conf.GetBoolConfig(GraphJitConfig::kGuardDetachObject));
3258 
3259   Py_DECREF(frame);
3260   return res;
3261 }
3262 
3263 /**
3264  * build graph and infer func result
3265  * it used to infer mindspore function, maybe replace with mindspore func_graph to infer.
3266  */
InferFuncResult(const py::object & callable,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf,bool clear_guard)3267 AObject *InferFuncResult(const py::object &callable, const py::object &args, const py::object &kwargs,
3268                          const GraphJitConfig &conf, bool clear_guard) {
3269   auto g = GenerateRootGraph(callable, args, kwargs, conf);
3270   if (g == nullptr) {
3271     return nullptr;
3272   }
3273   g->TraceRun();
3274   if (conf.GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
3275     g->DumpDFG();
3276   }
3277   if (clear_guard) {
3278     Graph *graph = g->GetGraph();
3279     auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(graph->GetCodeObj()));
3280     jcr->codehub->DelOptTarget(OptOption::CreateOptionByPoint(jcr), graph->GetGuard());
3281   }
3282 
3283   ValueNode *res = g->GetGraph()->GetRetVal();
3284   if (res == nullptr) {
3285     return nullptr;
3286   }
3287   return res->GetVobj();
3288 }
3289 
InferFuncResult(const py::object & func,const std::vector<AObject * > & stack_args,int opcode,const GraphJitConfig & conf,bool clear_guard)3290 AObject *InferFuncResult(const py::object &func, const std::vector<AObject *> &stack_args, int opcode,
3291                          const GraphJitConfig &conf, bool clear_guard) {
3292   std::vector<py::object> args;
3293   std::transform(stack_args.begin(), stack_args.end(), std::back_inserter(args),
3294                  [](AObject *i) { return i ? i->GetPyObject() : py::object(); });
3295   auto pair = Utils::PackCallStackArgs(args, opcode);
3296   if (pair.first.ptr() == nullptr) {
3297     return nullptr;
3298   }
3299   return InferFuncResult(func, pair.first, pair.second, conf, clear_guard);
3300 }
3301 
InferFuncResult(const py::object & callable,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf)3302 AObject *InferFuncResult(const py::object &callable, const py::object &args, const py::object &kwargs,
3303                          const GraphJitConfig &conf) {
3304   return InferFuncResult(callable, args, kwargs, conf, true);
3305 }
3306 
GetGradSens(ValueNode * grad_node)3307 static bool GetGradSens(ValueNode *grad_node) {
3308   AObject *grad_object = grad_node->GetVobj();
3309   if (grad_object->GetPyObject().ptr() != nullptr) {
3310     return grad_object->GetAttr("sens_param")->GetPyObject().ptr() == Py_True;
3311   }
3312   bool sens_param = false;
3313   AObject *cls = grad_node->getInputs().size() > 0 ? grad_node->input(0)->GetVobj() : nullptr;
3314   if (!(Opcode(grad_node->GetOpcode()).IsCall() && cls != nullptr && cls->GetType() == AObject::kTypeType)) {
3315     return sens_param;
3316   }
3317   if (grad_node->GetOpcode() == CALL_FUNCTION && grad_node->getInputs().size() > 3) {
3318     AObject *tmp = grad_node->input(3)->GetVobj();
3319     sens_param = tmp ? tmp->GetPyObject().ptr() == Py_True : false;
3320   } else if (grad_node->GetOpcode() == CALL_FUNCTION_KW) {
3321     py::object kwnames = grad_node->getInputs().back()->GetVobj()->GetPyObject();
3322     PyObject **arr = &PyTuple_GET_ITEM(kwnames.ptr(), 0);
3323     Py_ssize_t size = PyTuple_GET_SIZE(kwnames.ptr());
3324     PyObject **iter = std::find_if(arr, arr + size, [](PyObject *k) {
3325       // find sens_param key
3326       return !PyUnicode_CompareWithASCIIString(k, "sens_param");
3327     });
3328     AObject *tmp = iter - arr != size ? grad_node->input(iter - arr)->GetVobj() : nullptr;
3329     sens_param = tmp ? tmp->GetPyObject().ptr() == Py_True : false;
3330   }
3331   return sens_param;
3332 }
3333 
SetGradFuncInfo(CallNode * call_node)3334 static void SetGradFuncInfo(CallNode *call_node) {
3335   const int flag = AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc;
3336   ValueNode *grad_func_node = call_node->input(0);
3337   if (grad_func_node->getInputs().size() < 2) {
3338     grad_func_node->GetVobj()->ClearMsFlag(flag);
3339     return;
3340   }
3341   ValueNode *grad_node = grad_func_node->input(0);
3342   ValueNode *deco_func_node = grad_func_node->input(1);
3343   AObject *grad_object = grad_node->GetVobj();
3344   AObject *deco_func = deco_func_node->GetVobj();
3345   bool sens_param = false;
3346   if (grad_func_node->GetVobj()->TestMsFlag(AObject::kMsFlagGradFunc) &&
3347       grad_object->GetType() == AObject::kTypeMetaFuncGraph) {
3348     sens_param = GetGradSens(grad_node);
3349   }
3350 
3351   HandleGradFuncCall(call_node, deco_func, sens_param);
3352 
3353   // guard forward net for grad
3354   if (grad_func_node->GetVobj()->TestMsFlag(flag) && !call_node->GetGraph()->GuardValueNode(deco_func_node)) {
3355     grad_func_node->GetVobj()->ClearMsFlag(flag);
3356   }
3357 }
3358 
DumpDFG()3359 void GraphBuilder::DumpDFG() { GRAPH_JIT_LOG_F("%s", graph_->ToString().c_str()); }
3360 
GetLocation(CallNode * call_node) const3361 LocationPtr MindGraphBuilder::GetLocation(CallNode *call_node) const {
3362   auto file_name = py::cast<std::string>(graph_->GetCodeObj()->co_filename);
3363   auto line_no = call_node->GetLineNo();
3364   std::vector<std::string> comments;
3365   return std::make_shared<Location>(file_name, line_no, 0, line_no, 0, "", std::move(comments));
3366 }
3367 
WhiteListFuncCheckAndInfer(CallNode * call_node,const py::object & callable)3368 bool MindGraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
3369   InferFunc infer_func = FindInferFunc(callable, trace_flag());
3370   if (infer_func != nullptr) {
3371     call_node->SetSubGraph(NewGraph(nullptr, nullptr));
3372     call_node->GetSubGraph()->SetGuard(root_->GetGraph()->GetGuard());
3373     bool has_sub_graph = infer_func(call_node, this);
3374     if (!has_sub_graph) {
3375       call_node->SetInlineReason(InlineReason::kInlineFuncSpecialize);
3376       MS_ASSERT(!call_node->GetSubGraph());  // check infer function
3377       return true;
3378     }
3379     call_node->SetInlineReason(InlineReason::kInline);
3380     ValueNode *ret_node = call_node->GetSubGraph()->GetRetVal();
3381     MS_EXCEPTION_IF_CHECK_FAIL(ret_node, "infer special function failed");
3382     seek(0) = ret_node;
3383     return true;
3384   }
3385   return false;
3386 }
3387 
FGAddInputs(const std::vector<py::object> & args)3388 bool MindGraphBuilder::FGAddInputs(const std::vector<py::object> &args) {
3389   // Add function graph inputs.
3390   for (size_t i = 0; i < args.size(); ++i) {
3391     auto obj = FGBuilder()->AddSubGraphInput(args[i]);
3392     if (obj.ptr() == nullptr) {
3393       MS_LOG(INFO) << "Add input fail for input: " << std::string(py::str(args[i]));
3394       return false;
3395     }
3396     MS_LOG(INFO) << "Add input success for input: " << std::string(py::str(args[i]));
3397   }
3398   return true;
3399 }
3400 
FGAddOutput(bool is_top_graph)3401 void MindGraphBuilder::FGAddOutput(bool is_top_graph) {
3402   if (auto ret = GetGraph()->GetRetVal()) {
3403     MS_LOG(INFO) << ret->GetVobj()->ToString();
3404     auto out = ret->GetVobj()->GetPyObject();
3405     MS_LOG(INFO) << "try add output: " << py::str(out) << " addr:" << out.ptr();
3406     if (FGBuilder()->AddOutput(out, is_top_graph)) {
3407       MS_LOG(INFO) << "add output succuss";
3408     } else {
3409       MS_LOG(INFO) << "add output fail";
3410     }
3411   }
3412 }
3413 
FGAddNode(CallNode * call_node,const py::object & callable_info,const std::vector<py::object> & args,StopTraceReason * stop_reason)3414 py::object MindGraphBuilder::FGAddNode(CallNode *call_node, const py::object &callable_info,
3415                                        const std::vector<py::object> &args, StopTraceReason *stop_reason) {
3416   MS_LOG(INFO) << "try add node: " << py::str(callable_info);
3417   TraceGuard trace_guard(GetLocation(call_node));
3418   auto res = FGBuilder()->AddNode(callable_info, args);
3419   if (res.ptr() == nullptr) {
3420     MS_LOG(INFO) << "add node fail";
3421     *stop_reason = StopTraceReason::kTrace_Fail;
3422   } else {
3423     MS_LOG(INFO) << "add node suc";
3424     auto node = AObject::Convert(res);
3425     MS_LOG(INFO) << py::str(node->GetPyObject());
3426     MS_LOG(INFO) << node->ToString();
3427     call_node->SetVobj(node);
3428     *stop_reason = StopTraceReason::kNonStopTrace;
3429   }
3430   return py::object();
3431 }
3432 
GetNewArgs(CallNode * call_node,AObject * vobj)3433 std::vector<py::object> MindGraphBuilder::GetNewArgs(CallNode *call_node, AObject *vobj) {
3434   std::vector<py::object> new_args;
3435   vobj = (vobj && vobj->GetType() != AObject::kTypePrimitive) ? vobj : call_node->input(0)->GetVobj();
3436   if (vobj->GetType() == AObject::kTypeCFunction) {
3437     MS_LOG(INFO) << "not support cfunction";
3438   }
3439   auto new_callable_info = FindPyFunc(vobj);
3440   FrameStates f;
3441   ResolveClosure(new_callable_info, call_node->input(0), &f);
3442 
3443   // Need to consider repeat add issue.
3444   if (!HandleCallParameters(new_callable_info, call_node, &f)) {
3445     MS_LOG(INFO) << "HandleCallParameters error" << std::endl;
3446   }
3447   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(new_callable_info.ptr()));
3448   int argc = co->co_argcount + co->co_kwonlyargcount;
3449   argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
3450   argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
3451   for (auto it = f.GetLocals().begin(); it != f.GetLocals().begin() + argc; it++) {
3452     std::set<AObject::Type> unsupported_parameter = {
3453       AObject::kTypeAnyValue,  AObject::kTypeFunction,      AObject::kTypeBoundMethod,
3454       AObject::kTypePrimitive, AObject::kTypeMetaFuncGraph, AObject::kTypeCell,
3455     };
3456     auto it_vobj = (*it)->GetVobj();
3457     if (it_vobj != nullptr) {
3458       auto pyobj = it_vobj->GetPyObject();
3459       if (pyobj.ptr() != nullptr) {
3460         if (unsupported_parameter.find(AbstractObjectBase::GetPyType(pyobj.ptr())) == unsupported_parameter.end()) {
3461           new_args.push_back(pyobj);
3462         }
3463       }
3464     }
3465   }
3466   return new_args;
3467 }
3468 
AllConstantArgs(const std::vector<py::object> & args,const py::object & callable_info,CallNode * call_node)3469 bool MindGraphBuilder::AllConstantArgs(const std::vector<py::object> &args, const py::object &callable_info,
3470                                        CallNode *call_node) {
3471   auto new_args = args;
3472   if (PyFunction_Check(callable_info.ptr())) {
3473     new_args = GetNewArgs(call_node);
3474   }
3475 
3476   return std::all_of(new_args.begin(), new_args.end(), [](const auto &arg) { return CheckConstPyObject(arg.ptr()); });
3477 }
3478 
ResolveCallable(CallNode * call_node,StopTraceReason * stop_reason)3479 py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
3480   AObject *callable = call_node->input(0)->GetVobj();
3481   py::object callable_info;
3482   *stop_reason = StopTraceReason::kStopTraceInfer_Fail;
3483   if (!callable) {
3484     return callable_info;
3485   }
3486   callable_info = callable->GetPyObject();
3487   py::object original_callable = callable_info;
3488   if (callable_info.ptr() == nullptr) {
3489     return py::object();
3490   }
3491   if (!FGBuilder()->ValidateCallableObject(callable_info)) {
3492     return py::object();
3493   }
3494   MS_LOG(INFO) << "trace_flag for: " << py::str(callable_info);
3495   auto args = call_node->GetArgs();
3496   if (FGBuilder()->CanConstantFoldFunc(callable_info) && AllConstantArgs(args, callable_info, call_node)) {
3497     MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info);
3498     JustCallAndSetRes(call_node);
3499     *stop_reason = StopTraceReason::kNonStopTrace;
3500     return py::object();
3501   }
3502   auto method = FGBuilder()->ConvertMethod(callable_info);
3503   if (method.ptr() != nullptr) {
3504     MS_LOG(INFO) << "convert method :" << py::str(callable_info) << " to " << py::str(method);
3505     callable_info = method;
3506     if (!PyFunction_Check(callable_info.ptr())) {  // prim getnewargs here, func getnewargs in subgraph
3507       args = GetNewArgs(call_node, AObject::Convert(callable_info.ptr()));
3508     }
3509   }
3510   auto func = FGBuilder()->ConvertFunction(callable_info);
3511   if (func.ptr() != nullptr) {
3512     MS_LOG(INFO) << "convert function:" << py::str(callable_info) << " to " << py::str(func);
3513     callable_info = func;
3514   }
3515   if (FGBuilder()->CheckCallable(callable_info)) {
3516     if (PyFunction_Check(callable_info.ptr())) {
3517       args = GetNewArgs(call_node);
3518     }
3519     return FGAddNode(call_node, callable_info, args, stop_reason);
3520   }
3521 
3522   py::object result = this->GraphBuilder::ResolveCallable(call_node, stop_reason);
3523   bool pijit_specialized = original_callable == callable_info             // not converted
3524                            || call_node->GetSubGraph() != nullptr         // pijit sub graph
3525                            || callable->GetType() == AObject::kTypeType;  // pijit class instantiation
3526   if (pijit_specialized) {
3527     return result;
3528   }
3529   MS_LOG(DEBUG) << "convert " << std::string(py::str(original_callable)) << " -> "
3530                 << std::string(py::str(callable_info));
3531   return FindPyFunc(AObject::Convert(callable_info));
3532 }
3533 
HandleCallClass(CallNode * call_node)3534 bool MindGraphBuilder::HandleCallClass(CallNode *call_node) {
3535   bool succ = GraphBuilder::HandleCallClass(call_node);
3536   if (!succ) {
3537     MS_LOG(INFO) << "Failed to handle call class";
3538     return false;
3539   } else if (call_node->GetVobj() != nullptr && call_node->GetVobj()->GetPyObject().ptr() != nullptr) {
3540     return FGBuilder()->AddLocalVariable(call_node->GetVobj()->GetPyObject());
3541   }
3542   return false;
3543 }
3544 
3545 // Fix dynamic shape tensor get shape issue.
3546 // Guard and Renormalize strategy should be refactored later.
HandleGetShapeOfDynamicLengthTensor(const py::object & object)3547 py::object MindGraphBuilder::HandleGetShapeOfDynamicLengthTensor(const py::object &object) {
3548   auto anf_node = fg_builder_->ReadLocalVariable(object);
3549   if (anf_node == nullptr || anf_node->abstract() == nullptr) {
3550     return py::object();
3551   }
3552   auto abs = anf_node->abstract();
3553   auto shape = abs->BuildShape();
3554   if (!shape->isa<abstract::TensorShape>()) {
3555     return py::object();
3556   }
3557   const auto &tensor_shape = shape->cast<abstract::TensorShapePtr>()->GetShapeVector();
3558   if (std::all_of(tensor_shape.begin(), tensor_shape.end(), [](auto e) { return e > 0; })) {
3559     return py::object();
3560   }
3561   std::vector<py::object> input_objects = {object};
3562   return fg_builder_->AddNode(prim::kPrimShape, input_objects);
3563 }
3564 
/**
 * Build the value node for an attribute load `target.<instr.name()>` and guard it.
 *
 * - `shape` of a tensor with a dynamic graph shape is answered by a Shape graph node; that node is returned
 *   without adding a guard.
 * - If FuncGraphBuilder accepts the attribute object directly (AddAttrPythonObject), the plain attribute node is
 *   used. Otherwise a GetAttr graph node is tried, falling back to the plain attribute node if that fails.
 * - Parameters (objects with `__parameter__` that are MetaTensors) get an identity (GId) guard.
 * - Other attributes are guarded as follows: int/float/bool/tuple/list/dict/primitive by value (GEqual);
 *   functions, bound methods and C functions not at all; everything else with GDeduce.
 *
 * NOTE(review): assumes attr_node->GetVobj() and target_node->GetVobj() are non-null; this is not checked here.
 */
ValueNode *MindGraphBuilder::HandleGetattr(ValueNode *target_node, const Instr &instr) {
  auto attr_node = NewValueNode(target_node->get_attr(instr.name()), instr, {target_node});
  MS_EXCEPTION_IF_NULL(attr_node);
  ValueNode *graph_attr_node = nullptr;
  auto attr_obj = attr_node->GetVobj()->GetPyObject();
  if (instr.name() == "shape") {
    // Dynamic-shape tensors: use a Shape graph node instead of the traced Python shape value.
    auto ret_object = HandleGetShapeOfDynamicLengthTensor(target_node->GetVobj()->GetPyObject());
    if (ret_object.ptr() != nullptr) {
      return NewValueNode(AObject::Convert(ret_object), instr, {target_node});
    }
  }
  // If the attr_obj can convert to anf node directly, return the origin attr node.
  if (fg_builder_->AddAttrPythonObject(attr_obj)) {
    graph_attr_node = attr_node;
  } else {
    // Otherwise express the access as a GetAttr(target, name) node in the function graph.
    std::vector<py::object> input_objects = {target_node->GetVobj()->GetPyObject(), py::str(instr.name())};
    auto graph_attr_obj = fg_builder_->AddNode(prim::kPrimGetAttr, input_objects);
    if (graph_attr_obj.ptr() == nullptr) {
      graph_attr_node = attr_node;
    } else {
      graph_attr_node = NewValueNode(AObject::Convert(graph_attr_obj), instr, {target_node});
    }
  }
  // Add Id guard for parameter, in case default value for parameter change in execution.
  if (attr_obj.ptr() != nullptr && py::hasattr(attr_obj, "__parameter__") &&
      py::isinstance<tensor::MetaTensor>(attr_obj)) {
    graph_->GuardValueNode(graph_attr_node, GuardLevel::GId);
    return graph_attr_node;
  }
  // Add Guard for getattr node. For scalar/list/tuple/primitive, need to guard value. Otherwise, guard type and shape.
  AObject::Type attr_type = graph_attr_node->GetVobj() ? graph_attr_node->GetVobj()->GetType() : AObject::kTypeAnyValue;
  static const std::vector<AObject::Type> const_type = {AObject::kTypeInt,      AObject::kTypeFloat, AObject::kTypeBool,
                                                        AObject::kTypeTuple,    AObject::kTypeList,  AObject::kTypeDict,
                                                        AObject::kTypePrimitive};
  // Need to check whether the guard is failed in the future.
  if (std::any_of(const_type.begin(), const_type.end(),
                  [attr_type](const AObject::Type type) { return attr_type == type; })) {
    graph_->GuardValueNode(graph_attr_node, GuardLevel::GEqual);
  } else if (attr_type != AObject::kTypeFunction && attr_type != AObject::kTypeBoundMethod &&
             attr_type != AObject::kTypeCFunction) {
    // Callables (functions, bound methods, C functions) are left unguarded here.
    graph_->GuardValueNode(graph_attr_node, GuardLevel::GDeduce);
  }
  return graph_attr_node;
}
3609 
HandleMultiOp(const Instr & instr,const std::vector<ValueNode * > & p,bool is_compare)3610 AObject *MindGraphBuilder::HandleMultiOp(const Instr &instr, const std::vector<ValueNode *> &p, bool is_compare) {
3611   int opcode = instr.op();
3612   int oparg = instr.arg();
3613   std::vector<py::object> input_obj;
3614   for (auto input : p) {
3615     if (input->GetVobj() == nullptr) {
3616       return AObject::MakeAObject(AObject::kTypeAnyValue);
3617     }
3618     (void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
3619   }
3620   const auto &op_name =
3621     is_compare ? pijit::GraphUtils::OpCompareArgToGraphName(oparg) : pijit::GraphUtils::OpCodeToGraphName(opcode);
3622   MS_LOG(DEBUG) << "operation name is " << op_name;
3623   if (op_name == "") {
3624     return AObject::MakeAObject(AObject::kTypeAnyValue);
3625   }
3626   auto node = fg_builder_->AddMultiNode(op_name, input_obj);
3627   if (node.ptr() == nullptr) {
3628     return AObject::MakeAObject(AObject::kTypeAnyValue);
3629   }
3630   return AObject::Convert(node);
3631 }
3632 
HandleBuildOp(const Instr & instr,const std::vector<ValueNode * > & p)3633 AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p) {
3634   auto opcode = instr.op();
3635   std::vector<py::object> input_obj;
3636   for (auto input : p) {
3637     if (input->GetVobj() == nullptr) {
3638       return AObject::MakeAObject(AObject::kTypeAnyValue);
3639     }
3640     (void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
3641   }
3642   auto primitive = pijit::GraphUtils::GetPrimitive(opcode);
3643   if (primitive == nullptr) {
3644     return AObject::MakeAObject(AObject::kTypeAnyValue);
3645   }
3646   if (primitive == prim::kPrimMakeDict) {
3647     if (opcode == BUILD_CONST_KEY_MAP) {
3648       MS_LOG(DEBUG) << "BUILD_CONST_KEY_MAP case, need to pack values.";
3649       std::vector<py::object> value_inputs;
3650       (void)std::transform(input_obj.begin(), input_obj.end() - 1, std::back_inserter(value_inputs),
3651                            [](const py::object &obj) { return obj; });
3652       auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_inputs);
3653       input_obj = {input_obj.back(), value_node};
3654     } else {
3655       MS_LOG(DEBUG) << "BUILD_KEY_MAP case, need to pack keys and values.";
3656       size_t input_len = input_obj.size();
3657       if (input_len % 2 != 0) {
3658         MS_LOG(INTERNAL_EXCEPTION) << "BUILD_KEY_MAP should have even input, but got: " << input_len;
3659       }
3660       std::vector<py::object> key_obj;
3661       std::vector<py::object> value_obj;
3662       for (size_t i = 0; i < input_len / 2; ++i) {
3663         key_obj.push_back(input_obj[2 * i]);
3664         value_obj.push_back(input_obj[2 * i + 1]);
3665       }
3666       auto key_node = fg_builder_->AddNode(prim::kPrimMakeTuple, key_obj);
3667       auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_obj);
3668       input_obj = {key_node, value_node};
3669     }
3670   }
3671   if (primitive == prim::kPrimMakeSlice) {
3672     constexpr size_t slice_without_step_len = 2;
3673     if (input_obj.size() == slice_without_step_len) {
3674       // Handle slice without step input scene, such as 0:2. MakeSlice can only handle slice with full inputs.
3675       (void)input_obj.emplace_back(py::int_(1));
3676     }
3677   }
3678   auto node = fg_builder_->AddNode(primitive, input_obj);
3679   return AObject::Convert(node);
3680 }
3681 
DoGetItem(const Instr & instr)3682 bool MindGraphBuilder::DoGetItem(const Instr &instr) {
3683   auto r = pop();
3684   auto l = pop();
3685   auto o = HandleMultiOp(instr, {l, r}, false);
3686   auto v = NewValueNode(o, instr, {l, r});
3687   push(v);
3688   return true;
3689 }
3690 
DoUnary(const Instr & instr)3691 bool MindGraphBuilder::DoUnary(const Instr &instr) {
3692   auto o = pop();
3693   auto r = HandleMultiOp(instr, {o}, false);
3694   auto v = NewValueNode(r, instr, {o});
3695   push(v);
3696   return true;
3697 }
3698 
DoBinary(const Instr & instr)3699 bool MindGraphBuilder::DoBinary(const Instr &instr) {
3700   auto r = pop();
3701   auto l = pop();
3702   auto o = HandleMultiOp(instr, {l, r}, false);
3703   auto v = NewValueNode(o, instr, {l, r});
3704   push(v);
3705   return true;
3706 }
3707 
DoBinaryMul(const Instr & instr)3708 bool MindGraphBuilder::DoBinaryMul(const Instr &instr) {
3709   auto r = pop();
3710   auto l = pop();
3711   auto o = HandleMultiOp(instr, {l, r}, false);
3712   auto v = NewValueNode(o, instr, {l, r});
3713   push(v);
3714   return true;
3715 }
3716 
DoCompare(const Instr & instr)3717 bool MindGraphBuilder::DoCompare(const Instr &instr) {
3718   auto r = pop();
3719   auto l = pop();
3720   auto o = HandleMultiOp(instr, {l, r}, true);
3721   auto v = NewValueNode(o, instr, {l, r});
3722   push(v);
3723   return true;
3724 }
3725 
DoBuildOp(const Instr & instr)3726 bool MindGraphBuilder::DoBuildOp(const Instr &instr) {
3727   int opcode = instr.op();
3728   int oparg = instr.arg();
3729   int tmp_arg = oparg;
3730   tmp_arg += opcode == BUILD_CONST_KEY_MAP;
3731   tmp_arg += opcode == BUILD_MAP ? tmp_arg : 0;
3732   std::vector<ValueNode *> p(frame_.GetStacks().end() - tmp_arg, frame_.GetStacks().end());
3733   auto o = HandleBuildOp(instr, p);
3734   popn(tmp_arg);
3735   auto v = NewValueNode(o, instr, p);
3736   push(v);
3737   return true;
3738 }
3739 
DoIsOp(const Instr & instr)3740 bool MindGraphBuilder::DoIsOp(const Instr &instr) { return GraphBuilder::DoBinary(instr); }
3741 
HandlePositionParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)3742 bool MindGraphBuilder::HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params,
3743                                             FrameStates *frame) {
3744   CallNode *call_node = reinterpret_cast<CallNode *>(seek(0));
3745   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
3746   auto vobj = trace_flag() ? AObject::Convert(func.ptr()) : call_node->input(0)->GetVobj();
3747   AObject::Type callable_type = vobj->GetType();
3748 
3749   ValueNode *self = GetBoundSelf(call_node);
3750   if (self != nullptr) {
3751     params->insert(params->begin(), self);
3752   }
3753 
3754   const int argc = co->co_argcount;
3755   const int has_varg = (co->co_flags & CO_VARARGS) ? 1 : 0;
3756   const int has_kwvarg = (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
3757   const int varg_loc = argc + co->co_kwonlyargcount;
3758   const int kwvarg_loc = argc + co->co_kwonlyargcount + has_varg;
3759   int pargc = params->size();
3760   if (pargc > argc && !has_varg) {
3761     MS_LOG(DEBUG) << "too many parameters";
3762     return false;
3763   }
3764   bool append_self_to_varg = has_varg && self && callable_type == AObject::kTypeBoundMethod && argc == 0;
3765   if (append_self_to_varg) {  // self is in variable arguments
3766     MS_LOG(INFO) << "not implement append self to variable arguments, inline failed";
3767     return false;
3768   }
3769 
3770   if (has_kwvarg && frame->Local(kwvarg_loc) == &ValueNode::kUnboundLocal) {
3771     auto vo = AObject::Convert(py::dict());
3772     auto m = NewValueNode(vo, BUILD_MAP, 0, {});
3773     call_node->AddParam(m);
3774     frame->SetLocal(kwvarg_loc, m);
3775   }
3776 
3777   if (has_varg) {
3778     int vargc = pargc > argc ? pargc - argc : 0;
3779     std::vector<ValueNode *> vargs(params->end() - vargc, params->end());
3780     params->resize(params->size() - vargc);
3781     std::for_each(vargs.begin(), vargs.end(), [this](ValueNode *i) { this->push(i); });
3782     DoBuildOp({BUILD_TUPLE, static_cast<int>(vargs.size())});
3783     ValueNode *build_tuple = pop();
3784     call_node->AddParam(build_tuple);
3785     frame->SetLocal(varg_loc, build_tuple);
3786   }
3787 
3788   pargc = params->size();
3789   for (int i = pargc - 1; i >= 0; --i) {
3790     if (frame->Local(i) != &ValueNode::kUnboundLocal) {
3791       MS_LOG(DEBUG) << "duplicate key-word parameter error";
3792       return false;
3793     }
3794     frame->SetLocal(i, params->back());
3795     params->pop_back();
3796   }
3797 
3798   return CheckAndSetDefaultParams(func, frame, pargc);
3799 }
3800 
UnpackCallExParams(std::vector<ValueNode * > * params,int extra_local,bool * has_kw,CallNode * call_node)3801 bool MindGraphBuilder::UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw,
3802                                           CallNode *call_node) {
3803   bool has_dict = params->size() > 1;
3804   ValueNode *args_node = params->operator[](0);
3805   if (!has_dict) {
3806     params->clear();
3807   } else if (!UnpackCallExDict(params, call_node)) {
3808     return false;
3809   }
3810   *has_kw = params->size();
3811 
3812   if (args_node->GetVobj() == nullptr) {
3813     return false;
3814   }
3815   py::object object = args_node->GetVobj()->GetPyObject();
3816   if (!py::isinstance<py::tuple>(object)) {
3817     return false;
3818   }
3819   size_t args_len = py::len(py::cast<py::tuple>(object));
3820   if (args_len == 0) {
3821     return true;
3822   }
3823 
3824   std::vector<ValueNode *> new_args_inputs;
3825   for (size_t i = 0; i < args_len; ++i) {
3826     Instr instr(BINARY_SUBSCR, 2);
3827     auto l = args_node;
3828     auto r = NewValueNode(AObject::Convert(py::int_(i)), LOAD_CONST, -1, {});
3829     auto o = HandleMultiOp(instr, {l, r}, false);
3830     new_args_inputs.push_back(NewValueNode(o, instr, {l, r}));
3831   }
3832 
3833   params->insert(params->begin(), new_args_inputs.begin(), new_args_inputs.end());
3834   return true;
3835 }
3836 
HandleKWParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)3837 bool MindGraphBuilder::HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) {
3838   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
3839   std::vector<ValueNode *> kwvargs;
3840   if (!PackKwParams(func, params, frame, &kwvargs)) {
3841     // illegal arguments
3842     return false;
3843   }
3844 
3845   const int argc = co->co_argcount + co->co_kwonlyargcount;
3846   if (!(co->co_flags & CO_VARKEYWORDS)) {
3847     // kw_2_p_cnt == k_cnt, all kw arguments is positions arguments
3848     return true;
3849   }
3850 
3851   int kwvarg_loc = argc + ((co->co_flags & CO_VARARGS) ? 1 : 0);
3852   std::for_each(kwvargs.begin(), kwvargs.end(), [this](ValueNode *i) { this->push(i); });
3853   DoBuildOp({BUILD_MAP, SizeToInt(kwvargs.size() / 2)});
3854   ValueNode *new_node = pop();
3855   frame->SetLocal(kwvarg_loc, new_node);
3856   graph_->GetTracedNodes().pop_back();
3857 
3858   static_cast<CallNode *>(seek(0))->AddParam(frame->Local(kwvarg_loc));
3859   return true;
3860 }
3861 
UnpackCallExDict(std::vector<ValueNode * > * params,CallNode * call_node)3862 bool MindGraphBuilder::UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node) {
3863   ValueNode *dict_node = params->back();
3864   params->clear();
3865 
3866   if (dict_node->GetVobj() == nullptr) {
3867     return false;
3868   }
3869 
3870   auto object = dict_node->GetVobj()->GetPyObject();
3871   if (!py::isinstance<py::dict>(object)) {
3872     return false;
3873   }
3874   auto dict_object = py::cast<py::dict>(object);
3875   Py_ssize_t dict_len = py::len(dict_object);
3876   if (dict_len == 0) {
3877     return true;
3878   }
3879 
3880   py::tuple keys(dict_len);
3881   size_t i = 0;
3882   for (const auto &pair : dict_object) {
3883     auto cur_key = pair.first;
3884     if (!py::isinstance<py::str>(cur_key)) {
3885       return false;
3886     }
3887     keys[i] = cur_key;
3888     Instr instr(BINARY_SUBSCR, 2);
3889     auto l = dict_node;
3890     auto r = NewValueNode(AObject::Convert(py::cast<py::str>(cur_key)), LOAD_CONST, -1, {});
3891     auto o = HandleMultiOp(instr, {l, r}, false);
3892     params->push_back(NewValueNode(o, instr, {l, r}));
3893     i++;
3894   }
3895 
3896   ValueNode *const_keys = this->NewValueNode(AObject::Convert(keys), LOAD_CONST, -1, {});
3897   params->push_back(const_keys);
3898   return true;
3899 }
3900 
DoItemAccess(const Instr & instr)3901 bool MindGraphBuilder::DoItemAccess(const Instr &instr) {
3902   int opcode = instr.op();
3903   bool res = false;
3904   if (opcode == BINARY_SUBSCR) {
3905     res = DoGetItem(instr);
3906   } else if (opcode == STORE_SUBSCR) {
3907     auto key = pop();
3908     auto map = pop();
3909     auto value = pop();
3910     NewValueNode(nullptr, instr, {value, map, key});
3911     res = DoSetItem(map, key, value);
3912   } else if (opcode == DELETE_SUBSCR) {
3913     auto key = pop();
3914     auto map = pop();
3915     NewValueNode(nullptr, instr, {map, key});
3916     res = DoSetItem(map, key, nullptr);
3917   } else {
3918     MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
3919   }
3920   return res;
3921 }
3922 }  // namespace pijit
3923 }  // namespace mindspore
3924