1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_capture/graph_build.h"
17 #include <algorithm>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <vector>
22 #include <utility>
23 #include <unordered_map>
24 #include <map>
25 #include "pipeline/jit/pi/common.h"
26 #include "pipeline/jit/pi/graph_capture/loop_unrolling.h"
27 #include "pipeline/jit/pi/graph_capture/special_func_infer.h"
28 #include "pipeline/jit/pi/graph_guard/infer.h"
29 #include "pipeline/jit/pi/external.h"
30 #include "pipeline/jit/pi/graph_build/func_graph_builder.h"
31 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
32 #include "include/common/debug/anf_ir_dump.h"
33 #include "pipeline/jit/pi/graph_compiler/utils.h"
34 #include "ops/sequence_ops.h"
35 #include "ops/framework_ops.h"
36 #include "ops/structure_ops.h"
37 #include "mindspore/core/ir/cell.h"
38 #include "pybind_api/ir/primitive_py.h"
39 #include "ops/auto_generate/gen_ops_primitive.h"
40
41 namespace mindspore {
42 namespace pijit {
// Forward declarations of helpers used below; their definitions are not part of this section.
extern TracePtr GetTrace(ValueNode *node, bool strict, bool print, int depth, int max_depth);

void LogGuardFailed(ValueNode *node, const GraphJitConfig &conf, const std::string &msg);
// File-local helper. seq_size defaults to -1 (presumably "size not supplied" -- confirm at the definition).
static bool GuardLoopSequence(Graph *graph, ValueNode *seq_node, Py_ssize_t seq_size = -1);

// Python attribute names referenced by the graph builder.
const char *GraphBuilder::ID___self__ = "__self__";
const char *GraphBuilder::ID___globals__ = "__globals__";
const char *GraphBuilder::ID___call__ = "__call__";
const char *GraphBuilder::ID_construct = "construct";

// Distinct bit values (1, 2, 4); presumably OR-ed together as flags at their use sites -- TODO confirm.
static const int infer_primitive_create = 1;
static const int infer_primitive_object = 2;
static const int infer_primitive_func = 4;
// Mutable file-level counter; plain int with no synchronization.
static int infer_func_count = 0;
// Marker string; its use is outside this section.
static constexpr const char *kPIJitCopyFuncKey = ".<pijit.copy>.";
58
// Dispatch table: bytecode opcode -> GraphBuilder member function that traces it.
// Every handler takes the instruction and returns whether tracing of that instruction succeeded.
// Opcodes absent from this table are presumably routed to DoOtherBytecode by the dispatcher (not visible here).
const std::unordered_map<int, bool (GraphBuilder::*)(const Instr &)> GraphBuilder::bytecode_meth_map_ = {
  // stack manipulation
  {POP_TOP, &GraphBuilder::DoStackOp},
  {ROT_TWO, &GraphBuilder::DoStackOp},
  {ROT_THREE, &GraphBuilder::DoStackOp},
  {ROT_FOUR, &GraphBuilder::DoStackOp},
  {DUP_TOP, &GraphBuilder::DoStackOp},
  {DUP_TOP_TWO, &GraphBuilder::DoStackOp},
  // no-ops
  {NOP, &GraphBuilder::DoNop},
  {EXTENDED_ARG, &GraphBuilder::DoNop},
  {RETURN_VALUE, &GraphBuilder::DoReturn},
  // unary operators
  {UNARY_POSITIVE, &GraphBuilder::DoUnary},
  {UNARY_NEGATIVE, &GraphBuilder::DoUnary},
  {UNARY_NOT, &GraphBuilder::DoUnary},
  {UNARY_INVERT, &GraphBuilder::DoUnary},
  // binary operators; multiply and add have dedicated handlers
  {BINARY_MATRIX_MULTIPLY, &GraphBuilder::DoBinary},
  {BINARY_MULTIPLY, &GraphBuilder::DoBinaryMul},
  {BINARY_MODULO, &GraphBuilder::DoBinary},
  {BINARY_POWER, &GraphBuilder::DoBinary},
  {BINARY_ADD, &GraphBuilder::DoBinaryAdd},
  {BINARY_SUBTRACT, &GraphBuilder::DoBinary},
  {BINARY_FLOOR_DIVIDE, &GraphBuilder::DoBinary},
  {BINARY_TRUE_DIVIDE, &GraphBuilder::DoBinary},
  {BINARY_LSHIFT, &GraphBuilder::DoBinary},
  {BINARY_RSHIFT, &GraphBuilder::DoBinary},
  {BINARY_AND, &GraphBuilder::DoBinary},
  {BINARY_XOR, &GraphBuilder::DoBinary},
  {BINARY_OR, &GraphBuilder::DoBinary},
  // in-place operators; only INPLACE_ADD has a dedicated handler
  // NOTE(review): INPLACE_MULTIPLY goes to DoBinary while BINARY_MULTIPLY goes to DoBinaryMul -- confirm intended.
  {INPLACE_MATRIX_MULTIPLY, &GraphBuilder::DoBinary},
  {INPLACE_MULTIPLY, &GraphBuilder::DoBinary},
  {INPLACE_MODULO, &GraphBuilder::DoBinary},
  {INPLACE_POWER, &GraphBuilder::DoBinary},
  {INPLACE_ADD, &GraphBuilder::DoInplaceAdd},
  {INPLACE_SUBTRACT, &GraphBuilder::DoBinary},
  {INPLACE_FLOOR_DIVIDE, &GraphBuilder::DoBinary},
  {INPLACE_TRUE_DIVIDE, &GraphBuilder::DoBinary},
  {INPLACE_LSHIFT, &GraphBuilder::DoBinary},
  {INPLACE_RSHIFT, &GraphBuilder::DoBinary},
  {INPLACE_AND, &GraphBuilder::DoBinary},
  {INPLACE_XOR, &GraphBuilder::DoBinary},
  {INPLACE_OR, &GraphBuilder::DoBinary},
  // identity / membership tests
  {IS_OP, &GraphBuilder::DoIsOp},
  {CONTAINS_OP, &GraphBuilder::DoIsOp},
  // container construction
  {BUILD_TUPLE, &GraphBuilder::DoBuildOp},
  {BUILD_LIST, &GraphBuilder::DoBuildOp},
  {BUILD_SET, &GraphBuilder::DoBuildOp},
  {BUILD_MAP, &GraphBuilder::DoBuildOp},
  {BUILD_SLICE, &GraphBuilder::DoBuildOp},
  {BUILD_CONST_KEY_MAP, &GraphBuilder::DoBuildOp},
  {BUILD_STRING, &GraphBuilder::DoBuildOp},
  // container update
  {LIST_APPEND, &GraphBuilder::DoMergeOp},
  {LIST_EXTEND, &GraphBuilder::DoMergeOp},
  {DICT_MERGE, &GraphBuilder::DoMergeOp},
  {DICT_UPDATE, &GraphBuilder::DoMergeOp},
  {SET_UPDATE, &GraphBuilder::DoMergeOp},
  {SET_ADD, &GraphBuilder::DoMergeOp},
  {MAP_ADD, &GraphBuilder::DoMergeOp},
  {COMPARE_OP, &GraphBuilder::DoCompare},
  {MAKE_FUNCTION, &GraphBuilder::DoMakeFunction},
  {FORMAT_VALUE, &GraphBuilder::DoFormatValue},
  {LIST_TO_TUPLE, &GraphBuilder::DoListToTuple},
  {LOAD_CONST, &GraphBuilder::DoLoadConst},
  // imports
  {IMPORT_STAR, &GraphBuilder::DoImport},
  {IMPORT_NAME, &GraphBuilder::DoImport},
  {IMPORT_FROM, &GraphBuilder::DoImport},
  // calls
  {CALL_FUNCTION, &GraphBuilder::DoCall},
  {CALL_FUNCTION_KW, &GraphBuilder::DoCall},
  {CALL_FUNCTION_EX, &GraphBuilder::DoCall},
  {CALL_METHOD, &GraphBuilder::DoCall},
  // unpacking
  {UNPACK_SEQUENCE, &GraphBuilder::DoUnpack},
  {UNPACK_EX, &GraphBuilder::DoUnpack},
  // subscript access
  {BINARY_SUBSCR, &GraphBuilder::DoItemAccess},
  {STORE_SUBSCR, &GraphBuilder::DoItemAccess},
  {DELETE_SUBSCR, &GraphBuilder::DoItemAccess},
  // global variable access
  {LOAD_GLOBAL, &GraphBuilder::DoGlobalAccess},
  {STORE_GLOBAL, &GraphBuilder::DoGlobalAccess},
  {DELETE_GLOBAL, &GraphBuilder::DoGlobalAccess},
  // attribute access
  {LOAD_METHOD, &GraphBuilder::DoAttrAccess},
  {LOAD_ATTR, &GraphBuilder::DoAttrAccess},
  {STORE_ATTR, &GraphBuilder::DoAttrAccess},
  {DELETE_ATTR, &GraphBuilder::DoAttrAccess},
  // closure cell access
  {LOAD_CLOSURE, &GraphBuilder::DoCellAccess},
  {LOAD_DEREF, &GraphBuilder::DoCellAccess},
  {STORE_DEREF, &GraphBuilder::DoCellAccess},
  {DELETE_DEREF, &GraphBuilder::DoCellAccess},
  // local variable access
  {LOAD_FAST, &GraphBuilder::DoLocalAccess},
  {STORE_FAST, &GraphBuilder::DoLocalAccess},
  {DELETE_FAST, &GraphBuilder::DoLocalAccess},
  // iteration and control flow
  {GET_ITER, &GraphBuilder::DoGetIter},
  {FOR_ITER, &GraphBuilder::TraceRunForIter},
  {POP_JUMP_IF_FALSE, &GraphBuilder::TraceRunControl},
  {POP_JUMP_IF_TRUE, &GraphBuilder::TraceRunControl},
  {JUMP_IF_FALSE_OR_POP, &GraphBuilder::TraceRunControl},
  {JUMP_IF_TRUE_OR_POP, &GraphBuilder::TraceRunControl},
  {JUMP_FORWARD, &GraphBuilder::TraceRunControl},
  {JUMP_ABSOLUTE, &GraphBuilder::TraceRunControl},
  {YIELD_VALUE, &GraphBuilder::DoYieldValue},
  // try / with blocks
  {POP_BLOCK, &GraphBuilder::DoException},
  {SETUP_WITH, &GraphBuilder::DoException},
  {SETUP_FINALLY, &GraphBuilder::DoException},
  {WITH_CLEANUP_START, &GraphBuilder::DoException},
  {WITH_CLEANUP_FINISH, &GraphBuilder::DoException},
  {END_FINALLY, &GraphBuilder::DoException},
  {SETUP_EXCEPT, &GraphBuilder::DoException},
};
163
DoOtherBytecode(const Instr & instr)164 bool GraphBuilder::DoOtherBytecode(const Instr &instr) {
165 MS_LOG(ERROR) << "TODO: resolve for instruction " << instr.ToString();
166 return false;
167 }
168
ReplaceAll(ValueNode * old_node,ValueNode * new_node,bool * is_referenced)169 bool GraphBuilder::ReplaceAll(ValueNode *old_node, ValueNode *new_node, bool *is_referenced) {
170 static const std::set<int> ref_op = {
171 BUILD_TUPLE, BUILD_LIST, BUILD_SET, BUILD_MAP, BUILD_CONST_KEY_MAP,
172 };
173
174 // check reference relationship
175 const auto &nodes = graph_->GetTracedNodes();
176 bool find = std::any_of(nodes.begin(), nodes.end(), [&old_node](ValueNode *node) {
177 if (Opcode(node->GetOpcode()).MayDelete() && ref_op.find(node->GetOpcode()) == ref_op.end()) {
178 return false;
179 }
180 const auto &args = node->getInputs();
181 return std::any_of(args.begin(), args.end(), [&old_node](ValueNode *i) { return i == old_node; });
182 });
183 if (is_referenced != nullptr) {
184 *is_referenced |= find;
185 } else if (find) {
186 return false;
187 }
188
189 if (parent_ != nullptr && !parent_->ReplaceAll(old_node, new_node, is_referenced)) {
190 return false;
191 }
192 // find id_map, replace all nodes......
193 const auto pred = [&old_node](ValueNode *i) { return i == old_node; };
194 std::replace_if(frame_.GetLocals().begin(), frame_.GetLocals().end(), pred, new_node);
195 std::replace_if(frame_.GetStacks().begin(), frame_.GetStacks().end(), pred, new_node);
196 std::for_each(frame_.GetClosures().begin(), frame_.GetClosures().end(), [&old_node, &new_node](CellVarNode *i) {
197 if (i->GetValue() == old_node) {
198 i->SetValue(new_node);
199 }
200 });
201 return true;
202 }
203
NewValueNode(AObject * o,int op,int arg,const std::vector<ValueNode * > & p,const std::string & name)204 ValueNode *GraphBuilder::NewValueNode(AObject *o, int op, int arg, const std::vector<ValueNode *> &p,
205 const std::string &name) {
206 ValueNode *v;
207 if (Opcode(op).IsCall()) {
208 v = graph_->NewCallNode(op, arg, p);
209 v->SetVobj(o);
210 } else {
211 v = graph_->NewValueNode(o, op, arg, p, name);
212 }
213 v->set_bci(cur_bci_);
214 return v;
215 }
216
NewValueNode(AObject * o,const Instr & i,const std::vector<ValueNode * > & p)217 ValueNode *GraphBuilder::NewValueNode(AObject *o, const Instr &i, const std::vector<ValueNode *> &p) {
218 ValueNode *v = NewValueNode(o, i.op(), i.arg(), p, i.name());
219 v->SetLineNo(i.line());
220 graph_->GetTracedNodes().push_back(v);
221 return v;
222 }
223
NewGraph(PyCodeObject * co,PyObject * globals)224 Graph *GraphBuilder::NewGraph(PyCodeObject *co, PyObject *globals) {
225 std::vector<Graph *> &graphs = (root_ != nullptr) ? root_->graph_pool_ : this->graph_pool_;
226 if ((root_ == nullptr || root_ == this) && graph_ == nullptr) {
227 JitCompileResults *jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co), false);
228 MS_EXCEPTION_IF_CHECK_FAIL(jcr && jcr->code != nullptr, "must be create guard code before trace start");
229 graphs.push_back(new Graph(co, globals, *jcr->conf));
230 graphs.back()->SetGuard(jcr->code);
231 // initialize side-effect handler, set unique data
232 graphs.back()->SetSideEffect(std::make_shared<SideEffect>());
233 graphs.back()->GetSideEffect()->set_data(std::make_shared<SideEffectData>());
234 } else {
235 graphs.push_back(new Graph(co, globals, root_->GetGraph()->Config()));
236 graphs.back()->SetGuard(root_->GetGraph()->GetGuard());
237 graphs.back()->SetSideEffect(root_->GetGraph()->GetSideEffect());
238 }
239 return graphs.back();
240 }
241
CheckValueValid(AObject * obj)242 static bool CheckValueValid(AObject *obj) {
243 if (obj->GetType() == AObject::kTypeTensor) {
244 AbstractTensor *tensor = static_cast<AbstractTensor *>(obj);
245 return tensor->IsStubTensor() || CheckTensorDataInitialized(obj->GetPyObject());
246 } else {
247 return true;
248 }
249 }
250
CondIsTrue(ValueNode * cond)251 int CondIsTrue(ValueNode *cond) {
252 // if cond is tensor attrs, infer tensor attrs
253 // if tensor is return node of cell, if tensor is return node of primitive
254 // if tensor is result of math operation(+-*/...)
255 AObject *cond_value = cond->GetVobj();
256 int ret = -1;
257 if (cond_value == nullptr || cond_value->GetPyObject().ptr() == nullptr) {
258 return ret;
259 }
260 py::object value = cond_value->GetPyObject();
261 if (CheckValueValid(cond_value)) {
262 ret = PyObject_IsTrue(value.ptr());
263 PyErr_Clear();
264 }
265 return ret;
266 }
267
CollectObjects(const std::vector<ValueNode * > & nodes)268 static std::vector<AObject *> CollectObjects(const std::vector<ValueNode *> &nodes) {
269 std::vector<AObject *> res;
270 std::transform(nodes.begin(), nodes.end(), std::back_inserter(res),
271 [](const ValueNode *node) { return node->GetVobj(); });
272 return res;
273 }
274
UnpackConstObject(const py::object & iterable)275 std::vector<ValueNode *> GraphBuilder::UnpackConstObject(const py::object &iterable) {
276 std::vector<ValueNode *> outputs;
277 std::transform(iterable.begin(), iterable.end(), std::back_inserter(outputs), [this](const auto &item) {
278 return this->NewValueNode(AObject::Convert(item.ptr()), LOAD_CONST, -1, {});
279 });
280 return outputs;
281 }
282
UnpackSequenceElements(ValueNode * node)283 bool GraphBuilder::UnpackSequenceElements(ValueNode *node) {
284 py::object seq = node->GetVobj()->GetPyObject();
285 if (seq.ptr() == nullptr || !PySequence_Check(seq.ptr()) || !GuardLoopSequence(this->graph_, node)) {
286 return false;
287 }
288
289 Py_ssize_t size = PySequence_Size(seq.ptr());
290 for (Py_ssize_t index = 0; index < size; ++index) {
291 push(node);
292 DoLoadConst({LOAD_CONST, -1, py::object(py::int_(index))});
293 DoItemAccess({BINARY_SUBSCR, 0});
294 }
295 return true;
296 }
297
UnpackElements(ValueNode * node)298 bool GraphBuilder::UnpackElements(ValueNode *node) {
299 int opcode = node->GetOpcode();
300 if (opcode == BUILD_LIST || opcode == BUILD_TUPLE) {
301 std::for_each(node->getInputs().begin(), node->getInputs().end(), [this](ValueNode *i) { this->push(i); });
302 } else if (node->IsConstantValue()) {
303 std::vector<ValueNode *> nodes = UnpackConstObject(node->GetVobj()->GetPyObject());
304 std::for_each(nodes.begin(), nodes.end(), [this](ValueNode *i) { this->push(i); });
305 } else {
306 return UnpackSequenceElements(node);
307 }
308 return true;
309 }
310
// Emits the results of UNPACK_SEQUENCE / UNPACK_EX in CPython stack order (first target ends on top).
// gen_item(i, -1) emits the single element at position i; gen_item(begin, end) emits the starred list of elements
// [begin, end).
//   cnt       -- targets before the star (all targets for UNPACK_SEQUENCE)
//   cnt_after -- targets after the star, or -1 for UNPACK_SEQUENCE
//   size      -- number of elements in the unpacked sequence
static void GenUnpackValue(const std::function<void(int, int)> &gen_item, int cnt, int cnt_after, Py_ssize_t size) {
  if (cnt_after != -1) {
    // trailing targets, last element first
    const int star_end = size - cnt_after;
    for (int pos = size - 1; pos >= star_end; --pos) {
      gen_item(pos, -1);
    }
    // the starred target collects everything between the leading and trailing targets
    gen_item(cnt, star_end);
  }
  // leading targets, in reverse so that element 0 ends on top of the stack
  for (int pos = cnt - 1; pos >= 0; --pos) {
    gen_item(pos, -1);
  }
}
323
GetUnpackSize(ValueNode * iterable,int cnt,int cnt_after)324 Py_ssize_t GetUnpackSize(ValueNode *iterable, int cnt, int cnt_after) {
325 int op = iterable->GetOpcode();
326 Py_ssize_t total_args = cnt + cnt_after + 1;
327 Py_ssize_t size;
328 if (op == BUILD_LIST || op == BUILD_TUPLE) {
329 size = iterable->getInputs().size();
330 } else {
331 AObject *seq = iterable->GetVobj();
332 PyObject *o = (seq == nullptr) ? nullptr : seq->GetPyObject().ptr();
333 size = (o == nullptr) ? -1 : PyObject_Size(o);
334 }
335 if (size == -1 || (cnt_after == -1 && cnt != size) || total_args > size + 1) {
336 PyErr_Clear();
337 return -1;
338 }
339 return size;
340 }
341
DoUnpack(const Instr & instr)342 bool GraphBuilder::DoUnpack(const Instr &instr) {
343 int opcode = instr.op();
344 int oparg = instr.arg();
345 int cnt = (opcode == UNPACK_EX) ? (oparg & 0xFF) : oparg;
346 int cnt_after = (opcode == UNPACK_EX) ? (oparg >> 8) : -1;
347 Py_ssize_t size = GetUnpackSize(seek(0), cnt, cnt_after);
348 if (size == -1) {
349 return false;
350 }
351 ValueNode *iterable = pop();
352
353 size_t elements_size = frame_.GetStacks().size();
354 int iterable_opcode = iterable->GetOpcode();
355 if (iterable_opcode == BUILD_LIST || iterable_opcode == BUILD_TUPLE) {
356 std::for_each(iterable->getInputs().begin(), iterable->getInputs().end(), [this](ValueNode *i) { this->push(i); });
357 } else if (iterable->IsConstantValue()) {
358 std::vector<ValueNode *> nodes = UnpackConstObject(iterable->GetVobj()->GetPyObject());
359 std::for_each(nodes.begin(), nodes.end(), [this](ValueNode *i) { this->push(i); });
360 } else {
361 for (Py_ssize_t index = 0; index < size; ++index) {
362 push(iterable);
363 DoLoadConst({LOAD_CONST, -1, py::object(py::int_(index))});
364 DoItemAccess({BINARY_SUBSCR, 0});
365 }
366 }
367 elements_size = frame_.GetStacks().size() - elements_size;
368 std::vector<ValueNode *> elements(frame_.GetStacks().end() - elements_size, frame_.GetStacks().end());
369 popn(elements_size);
370
371 auto gen_item = [this, &elements](int i, int j) {
372 if (j == -1) {
373 this->push(elements[i]);
374 return;
375 }
376 MS_EXCEPTION_IF_CHECK_FAIL(j >= i, "check UNPACK_EX oparg");
377 auto in_iter = elements.begin();
378 std::for_each(in_iter + i, in_iter + j, [this](ValueNode *i) { this->push(i); });
379 DoBuildOp({BUILD_LIST, j - i});
380 };
381 GenUnpackValue(gen_item, cnt, cnt_after, size);
382 return true;
383 }
384
DoCall(const Instr & instr)385 bool GraphBuilder::DoCall(const Instr &instr) {
386 Opcode opcode(instr.op());
387 int oparg = instr.arg();
388 int tmp_arg = oparg;
389 std::vector<ValueNode *> params;
390 if (opcode == CALL_FUNCTION_EX) {
391 tmp_arg = (tmp_arg & 0x01) + 1;
392 } else if (opcode == CALL_FUNCTION_KW) {
393 tmp_arg += 1;
394 }
395 MS_EXCEPTION_IF_CHECK_FAIL(opcode.IsCall(), "must be call");
396 params = {frame_.GetStacks().end() - tmp_arg - 1, frame_.GetStacks().end()};
397 opcode = (opcode == CALL_METHOD) ? CALL_FUNCTION : opcode;
398 popn(tmp_arg + 1);
399 push(NewValueNode(nullptr, opcode, oparg, params));
400
401 CallNode *call_node = static_cast<CallNode *>(seek(0));
402 call_node->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
403 call_node->SetLineNo(instr.line());
404 call_node->set_bci(instr.bci());
405 this->graph_->GetTracedNodes().push_back(call_node);
406
407 StopTraceReason r = HandleCall(0);
408 if (r != StopTraceReason::kNonStopTrace) {
409 graph_->StopTraceAt(cur_bci_, r);
410 return false;
411 }
412 return true;
413 }
414
DoNop(const Instr & instr)415 bool GraphBuilder::DoNop(const Instr &instr) { return true; }
NotImplementBytecode(const Instr & instr)416 bool GraphBuilder::NotImplementBytecode(const Instr &instr) { return false; }
417
DoYieldValue(const Instr & instr)418 bool GraphBuilder::DoYieldValue(const Instr &instr) {
419 ValueNode *result = graph_->GetGeneratorResult();
420 if (result == nullptr) {
421 result = NewValueNode(nullptr, BUILD_TUPLE, 0);
422 graph_->SetGeneratorResult(result);
423 }
424 ValueNode *value = seek(0);
425 result->AddInput(value);
426 return true;
427 }
428
DoReturn(const Instr & instr)429 bool GraphBuilder::DoReturn(const Instr &instr) {
430 graph_->SetRetVal(pop());
431 if (graph_->GetGeneratorResult() == nullptr) {
432 return true;
433 }
434 const auto &inputs = graph_->GetGeneratorResult()->getInputs();
435 std::for_each(inputs.begin(), inputs.end(), [this](ValueNode *i) { this->push(i); });
436 DoBuildOp({BUILD_TUPLE, SizeToInt(inputs.size())});
437 ValueNode *new_node = pop();
438 graph_->SetGeneratorResult(new_node);
439 graph_->SetRetVal(new_node);
440 return true;
441 }
442
GetCallFunctionNode(ValueNode * node,PyObject * dst_dtype)443 ValueNode *GraphBuilder::GetCallFunctionNode(ValueNode *node, PyObject *dst_dtype) {
444 py::object prim_cast = Utils::GetModuleAttr("mindspore.ops.functional", "cast", false, true);
445 ValueNode *prim_node = NewValueNode(AObject::Convert(prim_cast), LOAD_CONST, {});
446 ValueNode *dtype_node = NewValueNode(AObject::Convert(dst_dtype), LOAD_CONST, -1, {});
447 std::vector<ValueNode *> cast_args = {prim_node, node, dtype_node};
448 ValueNode *call_node = NewValueNode(nullptr, CALL_FUNCTION, cast_args.size() - 1, cast_args);
449 return call_node;
450 }
451
DoMixedPrecisionLocalAccess(const Instr & instr,ValueNode * node)452 bool GraphBuilder::DoMixedPrecisionLocalAccess(const Instr &instr, ValueNode *node) {
453 auto param_node = static_cast<ParamNode *>(node);
454 auto dst_dtype = param_node->GetMixedPrecisionType();
455 ValueNode *call_node = GetCallFunctionNode(node, dst_dtype);
456 push(call_node);
457 auto *call = static_cast<CallNode *>(call_node);
458 call->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
459 call->SetLineNo(instr.line());
460 call->set_bci(instr.bci());
461 StopTraceReason r = HandleCall(0);
462 if (r != StopTraceReason::kNonStopTrace) {
463 graph_->StopTraceAt(cur_bci_, r);
464 return false;
465 }
466 this->graph_->GetTracedNodes().push_back(call_node);
467 return true;
468 }
469
DoLocalAccess(const Instr & instr)470 bool GraphBuilder::DoLocalAccess(const Instr &instr) {
471 if (instr.op() == LOAD_FAST) {
472 auto local = getLocal(instr.arg());
473 if (local->GetType() == AbstractNode::Param && reinterpret_cast<ParamNode *>(local)->IsMixedPrecisionType()) {
474 // TODO(lvxudong): fix multi cast
475 DoMixedPrecisionLocalAccess(instr, local);
476 } else {
477 push(local);
478 }
479 } else if (instr.op() == STORE_FAST) {
480 setLocal(instr.arg(), pop());
481 } else if (instr.op() == DELETE_FAST) {
482 setLocal(instr.arg(), &ValueNode::kUnboundLocal);
483 } else {
484 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
485 }
486 return true;
487 }
488
DoCellAccess(const Instr & instr)489 bool GraphBuilder::DoCellAccess(const Instr &instr) {
490 int opcode = instr.op();
491 int oparg = instr.arg();
492 ValueNode *node;
493 ValueNode *value;
494 PyObject *cell = frame_.Closure(oparg)->GetVobj()->GetPyObject().ptr();
495 MS_EXCEPTION_IF_CHECK_FAIL(cell && PyCell_Check(cell), "must be a cell object");
496 if (opcode == LOAD_CLOSURE) {
497 push(frame_.Closure(oparg));
498 } else if (opcode == LOAD_DEREF) {
499 MS_EXCEPTION_IF_NULL(frame_.Closure(oparg)->GetValue());
500 push(frame_.Closure(oparg)->GetValue());
501 } else if (opcode == STORE_DEREF) {
502 value = pop();
503 bool is_same = value->GetOpcode() == LOAD_DEREF && frame_.Closure(oparg) == frame_.Closure(value->GetOparg());
504 if (!is_same) {
505 node = NewValueNode(nullptr, instr, {value});
506 graph_->GetSideEffect()->Record(node);
507 frame_.Closure(oparg)->SetValue(value);
508 frame_.Closure(oparg)->AddCellOper(node);
509 }
510 } else if (opcode == DELETE_DEREF) {
511 node = NewValueNode(nullptr, instr, {});
512 graph_->GetSideEffect()->Record(node);
513 frame_.Closure(oparg)->SetValue(&ValueNode::kUnboundLocal);
514 frame_.Closure(oparg)->AddCellOper(node);
515 } else {
516 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
517 }
518 return true;
519 }
520
521 // Parse byteCode -- SETUP_WITH
DoWith(const Instr & instr)522 bool GraphBuilder::DoWith(const Instr &instr) {
523 if (graph_->Config().GetBoolConfig(GraphJitConfig::kSkipException) || PyErr_Occurred()) {
524 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceSkip_Exception);
525 return false;
526 }
527 auto node = pop();
528 push(node);
529 DoAttrAccess({LOAD_ATTR, 0, "__exit__"});
530
531 push(node);
532 DoAttrAccess({LOAD_ATTR, 0, "__enter__"});
533
534 if (!DoCall({CALL_FUNCTION, 0})) {
535 MS_LOG(ERROR) << "function '__enter__' runs failed here, it should be successful!";
536 return false;
537 }
538 PushStack(TryBlock{SETUP_WITH, instr.extra_jump()->bci(), instr.bci(), false});
539 cur_bci_++;
540 return true;
541 }
542
DoException(const Instr & instr)543 bool GraphBuilder::DoException(const Instr &instr) {
544 #if (PY_MAJOR_VERSION == 3) && (PY_MINOR_VERSION == 8)
545 return false;
546 #else
547 int opCode = instr.op();
548 if (opCode == SETUP_WITH) {
549 return DoWith(instr);
550 } else if (opCode == POP_BLOCK) {
551 PopStack();
552 return true;
553 } else if (opCode == SETUP_FINALLY) {
554 /*
555 ByteCode like this in python3.9
556 0 SETUP_FINALLY xxx
557 1 SETUP_FINALLY xxx
558 the first SETUP_FINALLY points to finally block, the second points to exception block
559 */
560 if (graph_->Config().GetBoolConfig(GraphJitConfig::kSkipException)) {
561 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceSkip_Exception);
562 return false;
563 }
564 if (StackSize() == 0 || GetTryBlockStacks().back().type != SETUP_FINALLY) {
565 PushStack(TryBlock{SETUP_FINALLY, instr.extra_jump()->bci(), instr.bci(), true});
566 } else {
567 assert(StackSize() > 0 || GetTryBlockStacks().back().type == SETUP_FINALLY);
568 PushStack(TryBlock{SETUP_FINALLY, instr.extra_jump()->bci(), instr.bci(), false});
569 }
570 cur_bci_++;
571 return true;
572 } else if (opCode == WITH_CLEANUP_START) {
573 /* python3.7 only */
574 ValueNode *exc = seek(0);
575 ValueNode *exit_func = seek(1);
576 if (exc->GetVobj()->GetType() != AObject::kTypeNone) {
577 return false;
578 }
579 if (exit_func->GetName() != "__exit__") {
580 MS_LOG(ERROR) << "it should call function '__exit__' here!";
581 return false;
582 }
583 // run exit func
584 push(exc);
585 push(exc);
586 if (!DoCall({CALL_FUNCTION, 3})) {
587 MS_LOG(ERROR) << "function '__exit__' runs failed here, it should be successful!";
588 return false;
589 }
590 push(exc);
591 return true;
592 } else if (opCode == WITH_CLEANUP_FINISH) {
593 auto exc = pop();
594 (void)pop();
595 push(exc);
596 return true;
597 } else if (opCode == END_FINALLY) {
598 (void)pop();
599 return true;
600 } else if (opCode == SETUP_EXCEPT) {
601 if (graph_->Config().GetBoolConfig(GraphJitConfig::kSkipException)) {
602 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceSkip_Exception);
603 return false;
604 }
605 PushStack(TryBlock{SETUP_EXCEPT, instr.extra_jump()->bci(), instr.bci(), false});
606 cur_bci_++;
607 return true;
608 } else {
609 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
610 }
611 return false;
612 #endif
613 }
614
PeekStack(int p)615 TryBlock &GraphBuilder::PeekStack(int p) {
616 MS_ASSERT(tryBlockStacks_.size() > p);
617 return tryBlockStacks_[tryBlockStacks_.size() - p - 1];
618 }
619
PopStack()620 TryBlock &GraphBuilder::PopStack() {
621 MS_ASSERT(tryBlockStacks_.size() > 0);
622 auto &tb = tryBlockStacks_[tryBlockStacks_.size() - 1];
623 tryBlockStacks_.pop_back();
624 return tb;
625 }
626
DoGlobalAccess(const Instr & instr)627 bool GraphBuilder::DoGlobalAccess(const Instr &instr) {
628 int opcode = instr.op();
629 int oparg = instr.arg();
630 if (opcode == LOAD_GLOBAL) {
631 auto cache_result = graph_->GetSideEffect()->LoadGlobal(graph_->GetModuleName(), instr.name());
632 if (cache_result.is_deleted_value_) {
633 return false; // name error
634 } else if (cache_result.cache_value_ != nullptr) {
635 push(cache_result.cache_value_);
636 } else {
637 auto co = graph_->GetCodeObj();
638 PyObject *key = PyTuple_GET_ITEM(co->co_names, oparg);
639 // NOTE: will run __get__, __hash__ function
640 PyObject *obj = PyObject_GetItem(graph_->GetGlobals().ptr(), key);
641 if (obj == nullptr) {
642 PyErr_Clear();
643 obj = PyObject_GetItem(PyEval_GetBuiltins(), key);
644 if (obj == nullptr) {
645 PyErr_Clear();
646 }
647 }
648 py::object pyobj = py::reinterpret_steal<py::object>(obj);
649 auto n = NewValueNode(AObject::Convert(pyobj), instr, {});
650 n->SetName(PyUnicode_AsUTF8(key));
651 push(n);
652 }
653 } else if (opcode == STORE_GLOBAL) {
654 auto global_node = pop();
655 auto node = NewValueNode(nullptr, instr, {global_node});
656 graph_->GetSideEffect()->Record(node);
657 } else if (opcode == DELETE_GLOBAL) {
658 auto node = NewValueNode(nullptr, instr, {});
659 graph_->GetSideEffect()->Record(node);
660 } else {
661 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
662 }
663 return true;
664 }
665
HandleSuper(const Instr & instr,AObject * super)666 bool GraphBuilder::HandleSuper(const Instr &instr, AObject *super) {
667 if (super != nullptr && super->GetTypeObject() != &PySuper_Type) {
668 return false;
669 }
670 ValueNode *self_super = SearchSelfPyObject(graph_->GetCodeObj()).second;
671 if (self_super == nullptr) {
672 return false;
673 }
674 py::object method = super->GetPyObject().attr(instr.name().c_str());
675 if (!PyMethod_Check(method.ptr())) {
676 return false;
677 }
678
679 // method type object
680 auto mtype_obj = reinterpret_cast<PyObject *>(&PyMethod_Type);
681 DoLoadConst({LOAD_CONST, -1, py::cast<py::object>(mtype_obj)});
682
683 // function object
684 PyObject *m = PyMethod_GET_FUNCTION(method.ptr());
685 DoLoadConst({LOAD_CONST, -1, py::cast<py::object>(m)});
686
687 push(self_super);
688
689 // call method type
690 return DoCall({CALL_FUNCTION, 2});
691 }
692
SetLocalPyObject(ValueNode * node)693 PyObject *SetLocalPyObject(ValueNode *node) {
694 if (node == nullptr || node->GetVobj() == nullptr) {
695 return NULL;
696 } else {
697 return node->GetVobj()->GetPyObject().ptr();
698 }
699 }
700
SearchSelfPyObject(PyCodeObject * co)701 std::pair<PyObject *, ValueNode *> GraphBuilder::SearchSelfPyObject(PyCodeObject *co) {
702 if (co->co_argcount < 1) {
703 return {nullptr, nullptr};
704 }
705 std::pair<PyObject *, ValueNode *> obj_value;
706 ValueNode *value = frame_.Local(0);
707 // get self or son class, eg.super(Son, self)
708 PyObject *obj = SetLocalPyObject(frame_.Local(0));
709 Py_ssize_t i, n;
710 if (obj == NULL && co->co_cell2arg) {
711 // the first argument might be a cell
712 n = PyTuple_GET_SIZE(co->co_cellvars);
713 for (i = 0; i < n; i++) {
714 if (co->co_cell2arg[i] == 0) {
715 value = frame_.Closure(i)->GetValue();
716 obj = SetLocalPyObject(frame_.Closure(i));
717 break;
718 }
719 }
720 }
721 obj_value = std::make_pair(obj, value);
722 return obj_value;
723 }
724
HandleGetattr(ValueNode * target_node,const Instr & instr)725 ValueNode *GraphBuilder::HandleGetattr(ValueNode *target_node, const Instr &instr) {
726 return NewValueNode(target_node->get_attr(instr.name()), instr, {target_node});
727 }
728
DoMixedPrecisionAttrAccess(const Instr & instr,ValueNode * node,ValueNode * attr)729 ValueNode *GraphBuilder::DoMixedPrecisionAttrAccess(const Instr &instr, ValueNode *node, ValueNode *attr) {
730 if (node->GetVobj() == nullptr || node->GetVobj()->GetPyObject().ptr() == nullptr ||
731 node->GetVobj()->GetType() != AbstractObjectBase::kTypeCell) {
732 return nullptr;
733 }
734 auto cell = py::cast<CellPtr>(node->GetVobj()->GetPyObject());
735 auto mixed_type = cell->GetMixedPrecisionType();
736 if (mixed_type == kNotSet) {
737 return nullptr;
738 }
739 if (attr->GetVobj() == nullptr || attr->GetVobj()->GetPyObject().ptr() == nullptr) {
740 return nullptr;
741 }
742 if (attr->GetVobj()->GetType() == AObject::kTypeTensor && !attr->GetVobj()->GetPyObject().attr("dtype").is_none()) {
743 auto src_dtype = attr->GetVobj()->GetPyObject().attr("dtype");
744 bool is_cast = false;
745 if (py::isinstance<Float>(src_dtype)) {
746 auto float_nbits = py::cast<Float>(src_dtype).nbits();
747 if (float_nbits == 64 || (float_nbits == 32 && mixed_type != kFP32) ||
748 (float_nbits == 16 && mixed_type != kFP16)) {
749 is_cast = true;
750 }
751 }
752 if (py::isinstance<BFloat>(src_dtype) && mixed_type != kBF16) {
753 is_cast = true;
754 }
755 if (is_cast) {
756 auto dst_dtype = Utils::MixedPrecisionTypeToDType(mixed_type);
757 ValueNode *call_node = GetCallFunctionNode(attr, dst_dtype);
758 CallNode *call = static_cast<CallNode *>(call_node);
759 call->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
760 call->SetLineNo(instr.line());
761 call->set_bci(instr.bci());
762 push(call_node);
763 StopTraceReason r = HandleCall(0);
764 if (r != StopTraceReason::kNonStopTrace) {
765 graph_->StopTraceAt(cur_bci_, r);
766 return nullptr;
767 }
768 this->graph_->GetTracedNodes().push_back(call_node);
769 return pop();
770 }
771 }
772 return nullptr;
773 }
774
DoAttrAccess(const Instr & instr)775 bool GraphBuilder::DoAttrAccess(const Instr &instr) {
776 int opcode = instr.op();
777 if (opcode == LOAD_METHOD || opcode == LOAD_ATTR) {
778 auto o = pop();
779 if (HandleSuper(instr, o->GetVobj())) {
780 return true;
781 }
782 auto cache_result = graph_->GetSideEffect()->LoadAttr(o, instr.name());
783 if (cache_result.is_deleted_value_) { // attribute error
784 return false;
785 } else if (cache_result.cache_value_ != nullptr) {
786 push(cache_result.cache_value_);
787 } else {
788 push(HandleGetattr(o, instr));
789 auto attr = DoMixedPrecisionAttrAccess(instr, o, seek(0));
790 if (attr) {
791 seek(0) = attr;
792 }
793 }
794 } else if (opcode == STORE_ATTR) {
795 if (trace_flag() && parent_ != nullptr) {
796 return false;
797 }
798 auto o = pop();
799 auto v = pop();
800 auto node = NewValueNode(nullptr, instr, {v, o});
801 graph_->GetSideEffect()->Record(node);
802 } else if (opcode == DELETE_ATTR) {
803 auto o = pop();
804 auto node = NewValueNode(nullptr, instr, {o});
805 graph_->GetSideEffect()->Record(node);
806 } else {
807 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
808 }
809 return true;
810 }
811
812 // for unpack call optimize
TupleDictGetItem(ValueNode * container,ValueNode * index_node)813 static ValueNode *TupleDictGetItem(ValueNode *container, ValueNode *index_node) {
814 if (!index_node->IsConstantValue()) {
815 return nullptr;
816 }
817 PyObject *index_object = index_node->GetVobj()->GetPyObject().ptr();
818 int opcode = container->GetOpcode();
819 if ((opcode == BUILD_TUPLE || opcode == BUILD_LIST) && PyLong_Check(index_object)) {
820 Py_ssize_t index = PyLong_AsSsize_t(index_object);
821 Py_ssize_t size = container->getInputs().size();
822 if (index < -size || index >= size) {
823 return nullptr;
824 }
825 index = index < 0 ? (size + index) : index;
826 return container->input(index);
827 }
828 if (container->GetOpcode() == BUILD_MAP && PyUnicode_Check(index_object)) {
829 std::string k = PyUnicode_AsUTF8(index_object);
830 size_t element_count = container->GetOparg() << 1;
831 MS_EXCEPTION_IF_CHECK_FAIL(element_count == container->getInputs().size(), "check BUILD_MAP oparg");
832 for (int i = 0; i < container->GetOparg(); ++i) {
833 AObject *tmp = container->input(i * 2)->GetVobj();
834 PyObject *str = tmp ? tmp->GetPyObject().ptr() : nullptr;
835 if (str == nullptr || !PyUnicode_Check(str) || k != PyUnicode_AsUTF8(str)) {
836 continue;
837 }
838 return container->input((i << 1) + 1);
839 }
840 }
841 return nullptr;
842 }
843
DoGetItem(const Instr & instr)844 bool GraphBuilder::DoGetItem(const Instr &instr) {
845 constexpr const char *kNameGetItem = "__getitem__";
846 auto r = pop();
847 auto l = pop();
848 ValueNode *v = TupleDictGetItem(l, r);
849 if (v != nullptr) {
850 push(v);
851 return true;
852 }
853
854 AObject *container = l->GetVobj();
855 PyObject *op = container ? container->GetPyObject().ptr() : nullptr;
856 AObject *meth = nullptr;
857
858 bool call_getitem = op == nullptr || container->GetType() != AObject::kTypeAnyValue;
859 if (!call_getitem) {
860 call_getitem = PyDict_Check(op) || PyTuple_Check(op) || PyList_Check(op);
861 }
862 if (!call_getitem) {
863 meth = container->GetAttr(kNameGetItem);
864 PyObject *m = meth ? meth->GetPyObject().ptr() : nullptr;
865 call_getitem = m == nullptr || !PyMethod_Check(m) || !PyFunction_Check(PyMethod_GET_FUNCTION(m));
866 }
867 if (call_getitem) {
868 /**
869 * check safe callable of __getitem__ if user defined.
870 */
871 AObject *vo = l->binary_subscr(r);
872 v = NewValueNode(vo, instr, {l, r});
873 push(v);
874 return true;
875 }
876
877 push(l);
878 DoAttrAccess({LOAD_ATTR, 0, kNameGetItem});
879 push(r);
880 return DoCall({CALL_FUNCTION, 1});
881 }
882
TransformDictSetItem(ValueNode * map,ValueNode * key,ValueNode * value,bool ignore_key_error)883 ValueNode *GraphBuilder::TransformDictSetItem(ValueNode *map, ValueNode *key, ValueNode *value, bool ignore_key_error) {
884 PyObject *index_object = key->GetVobj()->GetPyObject().ptr();
885 if (index_object == nullptr || !key->IsConstantValue()) {
886 return nullptr; // only supported constant key
887 }
888 constexpr const int kNumberTwo = 2;
889 PyObject *map_object = map->GetVobj()->GetPyObject().ptr();
890 std::vector<ValueNode *> elements;
891 if (map->GetOpcode() == BUILD_MAP) {
892 elements = map->getInputs();
893 } else if (map_object != nullptr) {
894 auto keys = py::reinterpret_steal<py::object>(PyDict_Keys(map_object));
895 // guard dict keys, transform to const key map......
896 Py_ssize_t size = PyList_GET_SIZE(keys.ptr());
897 for (Py_ssize_t i = 0; i < size; ++i) {
898 Instr instr(LOAD_CONST, 0, py::reinterpret_borrow<py::object>(PyList_GET_ITEM(keys.ptr(), i)));
899 this->DoLoadConst(instr);
900 this->push(map);
901 this->DoLoadConst(instr);
902 this->DoGetItem({BINARY_SUBSCR, 0});
903 }
904 elements = {frame_.GetStacks().end() - size * kNumberTwo, frame_.GetStacks().end()};
905 popn(size * kNumberTwo);
906 } else {
907 return nullptr;
908 }
909
910 // set(delete) element
911 if (value != nullptr) {
912 elements.push_back(key);
913 elements.push_back(value);
914 } else {
915 int index_of_key = -1;
916 for (int i = elements.size() - kNumberTwo; i >= 0 && index_of_key == -1; i -= kNumberTwo) {
917 bool find = elements[i]->GetVobj()->GetPyObject().equal(py::handle(index_object));
918 index_of_key = find ? i : -1;
919 }
920 if (index_of_key != -1) {
921 elements.erase(elements.begin() + index_of_key, elements.begin() + index_of_key + kNumberTwo);
922 } else if (!ignore_key_error) {
923 return nullptr; // maybe key error
924 }
925 }
926
927 // rebuild map
928 int size = elements.size() / kNumberTwo;
929 std::for_each(elements.begin(), elements.end(), [this](ValueNode *i) { this->push(i); });
930 DoBuildOp({BUILD_MAP, size});
931 return pop();
932 }
933
ListIndexCompute(PyObject * index_object,Py_ssize_t size)934 std::vector<Py_ssize_t> ListIndexCompute(PyObject *index_object, Py_ssize_t size) {
935 if (PyIndex_Check(index_object)) {
936 Py_ssize_t index = PyNumber_AsSsize_t(index_object, PyExc_IndexError);
937 if (!PyErr_Occurred() && index > -size && index < size) {
938 index = index < 0 ? (index + size) : index;
939 return {index, index + 1, 1, 1};
940 }
941 } else if (PySlice_Check(index_object)) {
942 Py_ssize_t start;
943 Py_ssize_t stop;
944 Py_ssize_t step;
945 Py_ssize_t slice_length;
946 constexpr Py_ssize_t zero = 0;
947 if (0 == PySlice_GetIndicesEx(index_object, size, &start, &stop, &step, &slice_length)) {
948 slice_length = (start < 0 || stop < 0 || slice_length < 0) ? 0 : slice_length;
949 return {std::max(start, zero), std::max(stop, zero), step, slice_length};
950 }
951 }
952 if (!PyErr_Occurred()) {
953 return {};
954 }
955 throw py::error_already_set();
956 }
957
958 template <typename T>
SetSlice(std::vector<T> * elements,const std::vector<Py_ssize_t> & computed_slice,std::vector<T> * new_elements=nullptr)959 static bool SetSlice(std::vector<T> *elements, const std::vector<Py_ssize_t> &computed_slice,
960 std::vector<T> *new_elements = nullptr) {
961 constexpr int start = 0;
962 constexpr int stop = 1;
963 constexpr int step = 2;
964 constexpr int slice_length = 3;
965
966 const auto &slice = computed_slice;
967 if (slice[step] == 1) {
968 elements->erase(elements->begin() + slice[start], elements->begin() + slice[stop]);
969 if (new_elements != nullptr) {
970 elements->insert(elements->begin() + slice[start], new_elements->begin(), new_elements->end());
971 }
972 return true;
973 }
974 if (new_elements != nullptr && new_elements->size() != static_cast<size_t>(slice[slice_length])) {
975 return false;
976 }
977 for (Py_ssize_t cur = slice[start], i = 0; i < slice[slice_length]; cur += slice[step], ++i) {
978 (*elements)[cur] = new_elements == nullptr ? nullptr : (*new_elements)[i];
979 }
980 if (new_elements == nullptr) {
981 elements->erase(std::remove(elements->begin(), elements->end(), nullptr), elements->end());
982 }
983 return true;
984 }
985
TransformListSetItem(ValueNode * map,ValueNode * key,ValueNode * value)986 ValueNode *GraphBuilder::TransformListSetItem(ValueNode *map, ValueNode *key, ValueNode *value) {
987 PyObject *index_object = key->GetVobj()->GetPyObject().ptr();
988 if (index_object == nullptr || !key->IsConstantValue()) {
989 return nullptr; // only supported constant key
990 }
991 PyObject *map_object = map->GetVobj()->GetPyObject().ptr();
992 std::vector<ValueNode *> elements;
993 if (map->GetOpcode() == BUILD_LIST) {
994 elements = map->getInputs();
995 } else if (UnpackElements(map)) {
996 Py_ssize_t size = PyList_GET_SIZE(map_object);
997 elements = {frame().GetStacks().end() - size, frame().GetStacks().end()};
998 popn(size);
999 } else {
1000 return nullptr;
1001 }
1002
1003 // compute slice
1004 auto slice = ListIndexCompute(index_object, elements.size());
1005 if (slice.empty()) {
1006 return nullptr;
1007 }
1008 // set(delete) elements
1009 size_t stack_size = frame_.GetStacks().size();
1010 if (!PySlice_Check(index_object)) {
1011 auto iter = elements.begin() + slice[0];
1012 (void)(value == nullptr ? elements.erase(iter) : (*iter = value, iter));
1013 } else if (value == nullptr && SetSlice(&elements, slice)) {
1014 // delete success
1015 } else if (value != nullptr && UnpackElements(value)) {
1016 // unpack success
1017 stack_size = frame_.GetStacks().size() - stack_size;
1018 std::vector<ValueNode *> new_elements = {frame_.GetStacks().end() - stack_size, frame_.GetStacks().end()};
1019 popn(stack_size);
1020 if (!SetSlice(&elements, slice, &new_elements)) {
1021 return nullptr;
1022 }
1023 // set succuss
1024 } else {
1025 return nullptr;
1026 }
1027
1028 std::for_each(elements.begin(), elements.end(), [this](ValueNode *i) { this->push(i); });
1029 DoBuildOp({BUILD_LIST, SizeToInt(elements.size())});
1030 return pop();
1031 }
1032
DoSetItem(ValueNode * map,ValueNode * key,ValueNode * value)1033 bool GraphBuilder::DoSetItem(ValueNode *map, ValueNode *key, ValueNode *value) {
1034 // only support constant key
1035 if (!this->graph_->GuardValueNode(key)) {
1036 return false;
1037 }
1038 // erase side-effect
1039 ValueNode *side_effect_node = graph_->GetTracedNodes().back();
1040 graph_->GetTracedNodes().pop_back();
1041
1042 // try to transform
1043 const auto &replace_map = graph_->GetSideEffect()->data()->modified_and_replaced_map();
1044 bool is_new_var = false;
1045 ValueNode *old_node = map;
1046 ValueNode *new_node = nullptr;
1047 AObject::Type type = map->GetVobj()->GetType();
1048 if (type == AObject::kTypeList) {
1049 is_new_var = map->GetOpcode() == BUILD_LIST && replace_map.find(map) == replace_map.end();
1050 new_node = TransformListSetItem(map, key, value);
1051 } else if (type == AObject::kTypeDict) {
1052 is_new_var = map->GetOpcode() == BUILD_MAP && replace_map.find(map) == replace_map.end();
1053 new_node = TransformDictSetItem(map, key, value, false);
1054 }
1055 // failed transform, restore side-effect
1056 if (new_node == nullptr) {
1057 graph_->GetTracedNodes().push_back(side_effect_node);
1058 return false;
1059 }
1060 bool is_referenced = false;
1061 ReplaceAll(old_node, new_node, &is_referenced);
1062 // check it is new variable and not escaped
1063 if (is_new_var && !is_referenced && map != value) {
1064 return true;
1065 }
1066 // restore and record
1067 this->graph_->GetTracedNodes().push_back(side_effect_node);
1068 this->graph_->GetSideEffect()->data()->RecordModifiedAndReplacedNode(old_node, new_node);
1069 this->graph_->GetSideEffect()->Record(side_effect_node);
1070 return true;
1071 }
1072
DoItemAccess(const Instr & instr)1073 bool GraphBuilder::DoItemAccess(const Instr &instr) {
1074 int opcode = instr.op();
1075 if (opcode == BINARY_SUBSCR) {
1076 DoGetItem(instr);
1077 } else if (opcode == STORE_SUBSCR) {
1078 auto key = pop();
1079 auto map = pop();
1080 auto value = pop();
1081 NewValueNode(nullptr, instr, {value, map, key});
1082 DoSetItem(map, key, value);
1083 } else if (opcode == DELETE_SUBSCR) {
1084 auto key = pop();
1085 auto map = pop();
1086 NewValueNode(nullptr, instr, {map, key});
1087 DoSetItem(map, key, nullptr);
1088 } else {
1089 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
1090 }
1091 return true;
1092 }
1093
DoStackOp(const Instr & instr)1094 bool GraphBuilder::DoStackOp(const Instr &instr) {
1095 int opcode = instr.op();
1096 int oparg = instr.arg();
1097 if (opcode == POP_TOP) {
1098 pop();
1099 } else if (opcode == ROT_TWO) {
1100 frame_.Rot(1);
1101 } else if (opcode == ROT_THREE) {
1102 frame_.Rot(2);
1103 } else if (opcode == ROT_FOUR) {
1104 frame_.Rot(3);
1105 } else if (opcode == ROT_N) {
1106 frame_.Rot(oparg - 1);
1107 } else if (opcode == DUP_TOP_TWO) {
1108 push(seek(1));
1109 push(seek(1));
1110 } else if (opcode == DUP_TOP) {
1111 push(seek(0));
1112 } else {
1113 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
1114 }
1115 return true;
1116 }
1117
DoLoadConst(const Instr & instr)1118 bool GraphBuilder::DoLoadConst(const Instr &instr) {
1119 auto n = NewValueNode(AObject::Convert(instr.cnst()), instr, {});
1120 push(n);
1121 return true;
1122 }
1123
DoListToTuple(const Instr & instr)1124 bool GraphBuilder::DoListToTuple(const Instr &instr) {
1125 ValueNode *list = pop();
1126 if (list->GetOpcode() == BUILD_LIST) {
1127 std::for_each(list->getInputs().begin(), list->getInputs().end(), [this](ValueNode *i) { this->push(i); });
1128 return DoBuildOp({BUILD_TUPLE, SizeToInt(list->getInputs().size())});
1129 }
1130 AObject *vo = list->GetVobj();
1131 if (vo && vo->GetType() == AObject::kTypeList) {
1132 vo = static_cast<AbstractList *>(vo)->ListToTuple();
1133 } else {
1134 vo = AObject::MakeAObject(AObject::kTypeAnyValue);
1135 }
1136 ValueNode *tuple = NewValueNode(vo, instr, {list});
1137 push(tuple);
1138 return true;
1139 }
1140
DoGetIter(const Instr & instr)1141 bool GraphBuilder::DoGetIter(const Instr &instr) {
1142 auto obj = pop();
1143 auto o = obj->GetVobj();
1144 auto iter = NewValueNode(o ? o->GetIter() : AObject::MakeAObject(AObject::kTypeAnyValue), instr, {obj});
1145 push(iter);
1146 iter->marker_ = 0;
1147 return true;
1148 }
1149
DoMakeFunction(const Instr & instr)1150 bool GraphBuilder::DoMakeFunction(const Instr &instr) {
1151 int oparg = instr.arg();
1152 // int cnt = __builtin_popcount(oparg & 0xf) + 2;
1153 int cnt = !!(oparg & 0x08) + !!(oparg & 0x04) + !!(oparg & 0x02) + !!(oparg & 0x01) + 2;
1154 std::vector<ValueNode *> p(frame_.GetStacks().end() - cnt, frame_.GetStacks().end());
1155 popn(cnt);
1156 AObject *f = AObject::MakeFunction(CollectObjects(p), graph_->GetGlobals(), oparg);
1157 ValueNode *func = NewValueNode(f, instr, p);
1158 push(func);
1159 current_block_->SetTrackResult(Block::kHasGlobalSideEffect);
1160 return true;
1161 }
1162
InferUnary(ValueNode * node,const Instr & instr)1163 AObject *GraphBuilder::InferUnary(ValueNode *node, const Instr &instr) { return node->GetVobj()->Unary(instr.op()); }
1164
DoUnary(const Instr & instr)1165 bool GraphBuilder::DoUnary(const Instr &instr) {
1166 ValueNode *node = pop();
1167
1168 AObject *object_info = InferUnary(node, instr);
1169 ValueNode *new_node = NewValueNode(object_info, instr, {node});
1170 push(new_node);
1171 return true;
1172 }
1173
DoIsOp(const Instr & instr)1174 bool GraphBuilder::DoIsOp(const Instr &instr) { return DoBinary(instr); }
1175
InferBinary(ValueNode * left,ValueNode * right,const Instr & instr)1176 AObject *GraphBuilder::InferBinary(ValueNode *left, ValueNode *right, const Instr &instr) {
1177 AObject *object_info;
1178 if (instr.op() == IS_OP || instr.op() == CONTAINS_OP) {
1179 object_info = left->GetVobj()->Binary(right->GetVobj(), instr.op());
1180 PyObject *object = object_info != nullptr ? object_info->GetPyObject().ptr() : nullptr;
1181 if (object != nullptr) {
1182 object_info = AObject::Convert(py::bool_((object == Py_True) ^ instr.arg()));
1183 }
1184 } else if (Opcode(instr.op()).IsBinaryMath()) {
1185 if (left->IsConstantValue() && right->IsConstantValue()) {
1186 // compute real tensor value, not infer fake value
1187 AbstractObject *tensor = static_cast<AbstractObject *>(left->GetVobj());
1188 object_info = tensor->AbstractObject::Binary(right->GetVobj(), instr.op());
1189 } else {
1190 object_info = left->GetVobj()->Binary(right->GetVobj(), instr.op());
1191 }
1192 } else {
1193 return AObject::MakeAObject(AObject::kTypeAnyValue);
1194 }
1195 return object_info;
1196 }
1197
DoBinary(const Instr & instr)1198 bool GraphBuilder::DoBinary(const Instr &instr) {
1199 ValueNode *right = pop();
1200 ValueNode *left = pop();
1201
1202 AObject *object_info = InferBinary(left, right, instr);
1203 ValueNode *new_node = NewValueNode(object_info, instr, {left, right});
1204 push(new_node);
1205 return true;
1206 }
1207
CheckTupleListMul(ValueNode * left,ValueNode * right)1208 static bool CheckTupleListMul(ValueNode *left, ValueNode *right) {
1209 bool special = left->GetOpcode() == BUILD_LIST || left->GetOpcode() == BUILD_TUPLE;
1210 if (!special && left->IsConstantValue()) {
1211 AObject::Type l_type = left->GetVobj()->GetType();
1212 special = l_type == AObject::kTypeTuple || l_type == AObject::kTypeList;
1213 }
1214 if (special && right->IsConstantValue()) {
1215 PyObject *mul = right->GetVobj()->GetPyObject().ptr();
1216 const int max = 2;
1217 return PyLong_Check(mul) && Py_ABS(Py_SIZE(mul)) < max;
1218 }
1219 return false;
1220 }
1221
DoBinaryMul(const Instr & instr)1222 bool GraphBuilder::DoBinaryMul(const Instr &instr) {
1223 if (!CheckTupleListMul(seek(1), seek(0))) {
1224 return DoBinary(instr);
1225 }
1226
1227 ValueNode *right = pop();
1228 ValueNode *left = pop();
1229 int l_op = left->GetVobj()->GetType() == AObject::kTypeTuple ? BUILD_TUPLE : BUILD_LIST;
1230
1231 Py_ssize_t mul = PyLong_AsSsize_t(right->GetVobj()->GetPyObject().ptr());
1232 for (auto i = mul; i > 0; --i) {
1233 UnpackElements(left);
1234 }
1235 int oparg = left->getInputs().size() * (mul < 0 ? 0 : size_t(mul));
1236 DoBuildOp({l_op, oparg});
1237 return true;
1238 }
1239
CheckTupleListAdd(ValueNode * left,ValueNode * right)1240 static bool CheckTupleListAdd(ValueNode *left, ValueNode *right) {
1241 // type must be same
1242 AObject::Type l_type = left->GetVobj()->GetType();
1243 AObject::Type r_type = right->GetVobj()->GetType();
1244 bool support = l_type == AObject::kTypeTuple || l_type == AObject::kTypeList;
1245 if (!support || l_type != r_type) {
1246 return false;
1247 }
1248 // only handle BUILD_TUPLE and BUILD_LIST
1249 int l_op = left->GetOpcode();
1250 int r_op = right->GetOpcode();
1251 bool special = l_op == BUILD_TUPLE || l_op == BUILD_LIST || l_op == LOAD_CONST;
1252 bool accept = r_op == BUILD_TUPLE || r_op == BUILD_LIST || r_op == LOAD_CONST;
1253 if (!special || !accept) {
1254 return false;
1255 }
1256 return true;
1257 }
1258
DoInplaceAdd(const Instr & instr)1259 bool GraphBuilder::DoInplaceAdd(const Instr &instr) {
1260 AObject::Type l_type = seek(1)->GetVobj()->GetType();
1261 if (l_type == AObject::kTypeTuple) {
1262 return DoBinaryAdd(instr);
1263 }
1264 if (!CheckTupleListAdd(seek(1), seek(0))) {
1265 return DoBinary(instr);
1266 }
1267
1268 ValueNode *right = pop();
1269 ValueNode *left = pop();
1270 int l_op = BUILD_LIST;
1271
1272 int size = this->frame_.GetStacks().size();
1273 UnpackElements(left);
1274 UnpackElements(right);
1275 size = this->frame_.GetStacks().size() - size;
1276 DoBuildOp({l_op, size});
1277
1278 ValueNode *new_node = pop();
1279 if (ReplaceAll(left, new_node)) {
1280 push(new_node);
1281 return true;
1282 }
1283 graph_->GetTracedNodes().pop_back();
1284 push(left);
1285 push(right);
1286 return DoBinary(instr);
1287 }
1288
DoBinaryAdd(const Instr & instr)1289 bool GraphBuilder::DoBinaryAdd(const Instr &instr) {
1290 if (!CheckTupleListAdd(seek(1), seek(0))) {
1291 return DoBinary(instr);
1292 }
1293
1294 ValueNode *right = pop();
1295 ValueNode *left = pop();
1296 int l_op = left->GetVobj()->GetType() == AObject::kTypeTuple ? BUILD_TUPLE : BUILD_LIST;
1297
1298 int size = this->frame_.GetStacks().size();
1299 UnpackElements(left);
1300 UnpackElements(right);
1301 size = this->frame_.GetStacks().size() - size;
1302 DoBuildOp({l_op, size});
1303 return true;
1304 }
1305
DoCompare(const Instr & instr)1306 bool GraphBuilder::DoCompare(const Instr &instr) {
1307 Opcode opcode(instr.op());
1308 int oparg = instr.arg();
1309 auto r = pop();
1310 auto l = pop();
1311
1312 bool invert;
1313 AObject *o;
1314 if (oparg >= Py_LT && oparg <= Py_GE) {
1315 PyObject *left = l->GetVobj() ? l->GetVobj()->GetPyObject().ptr() : nullptr;
1316 PyObject *right = r->GetVobj() ? r->GetVobj()->GetPyObject().ptr() : nullptr;
1317 if (left && right) {
1318 if (CheckValueValid(l->GetVobj()) && CheckValueValid(r->GetVobj())) {
1319 o = AObject::Convert(PyObject_RichCompare(left, right, oparg));
1320 PyErr_Clear();
1321 } else if (l->GetVobj()->GetType() == AObject::kTypeTensor || r->GetVobj()->GetType() == AObject::kTypeTensor) {
1322 o = l->GetVobj()->GetType() == AObject::kTypeTensor ? l->GetVobj() : r->GetVobj();
1323 auto tensor_type = py::reinterpret_borrow<py::object>(GetMsTensorType());
1324 py::object dtype_bool = Utils::GetModuleAttr("mindspore.common.dtype", "bool_");
1325 auto result_tensor = tensor_type(o->GetPyObject(), dtype_bool);
1326 o = AObject::Convert(result_tensor);
1327 } else {
1328 o = AObject::MakeAObject(AObject::kTypeBool);
1329 }
1330 } else {
1331 o = AObject::MakeAObject(AObject::kTypeBool);
1332 }
1333 } else if (opcode.CheckIsOp(oparg, &invert)) {
1334 int res = AObject::BinaryIs(l->GetVobj(), r->GetVobj());
1335 o = res == -1 ? AObject::MakeAObject(AObject::kTypeBool) : AObject::Convert((res ^ invert) ? Py_True : Py_False);
1336 } else if (opcode.CheckContainsOp(oparg, &invert)) {
1337 int res = AObject::BinaryContains(l->GetVobj(), r->GetVobj());
1338 o = res == -1 ? AObject::MakeAObject(AObject::kTypeBool) : AObject::Convert((res ^ invert) ? Py_True : Py_False);
1339 } else {
1340 return false;
1341 }
1342
1343 auto v = NewValueNode(o, instr, {l, r});
1344 push(v);
1345 return true;
1346 }
1347
DoBuildOp(const Instr & instr)1348 bool GraphBuilder::DoBuildOp(const Instr &instr) {
1349 int opcode = instr.op();
1350 int oparg = instr.arg();
1351 int tmp_arg = oparg;
1352 tmp_arg += opcode == BUILD_CONST_KEY_MAP;
1353 tmp_arg += opcode == BUILD_MAP ? tmp_arg : 0;
1354 std::vector<ValueNode *> p(frame_.GetStacks().end() - tmp_arg, frame_.GetStacks().end());
1355 popn(tmp_arg);
1356
1357 ValueNode *v;
1358 if (opcode == BUILD_CONST_KEY_MAP) {
1359 PyObject *keys = p.back()->GetVobj()->GetPyObject().ptr();
1360 MS_EXCEPTION_IF_CHECK_FAIL(keys && PyTuple_CheckExact(keys), "error bytecode BUILD_CONST_KEY_MAP");
1361 Py_ssize_t size = PyTuple_GET_SIZE(keys);
1362 MS_EXCEPTION_IF_CHECK_FAIL(size_t(size) + 1 == p.size(), "error args BUILD_CONST_KEY_MAP");
1363 std::vector<ValueNode *> build_inputs;
1364 for (Py_ssize_t i = 0; i < size; ++i) {
1365 PyObject *item = PyTuple_GET_ITEM(keys, i);
1366 build_inputs.push_back(NewValueNode(AObject::Convert(item), LOAD_CONST, -1));
1367 build_inputs.push_back(p[i]);
1368 }
1369 AObject *vo = AObject::BuildOperations(CollectObjects(build_inputs), BUILD_MAP);
1370 v = NewValueNode(vo, instr, build_inputs);
1371 v->SetOpcode(BUILD_MAP);
1372 v->SetOparg(size);
1373 } else {
1374 AObject *vo = AObject::BuildOperations(CollectObjects(p), opcode);
1375 v = NewValueNode(vo, instr, p);
1376 }
1377 push(v);
1378 return true;
1379 }
1380
ReplaceMergeOp(int opcode,const std::vector<ValueNode * > & inputs)1381 ValueNode *GraphBuilder::ReplaceMergeOp(int opcode, const std::vector<ValueNode *> &inputs) {
1382 ValueNode *origin = inputs[0];
1383 ValueNode *arg = inputs[1];
1384 ValueNode *arg2 = inputs.size() > 2 ? inputs[2] : nullptr;
1385 if (origin->GetOpcode() != BUILD_LIST && origin->GetOpcode() != BUILD_MAP) {
1386 return nullptr;
1387 }
1388 std::vector<ValueNode *> build_inputs = origin->getInputs();
1389 int div = 2;
1390 if (opcode == LIST_APPEND) {
1391 build_inputs.push_back(arg);
1392 opcode = BUILD_LIST;
1393 div = 1;
1394 } else if (opcode == LIST_EXTEND) {
1395 if (arg->IsConstantValue()) {
1396 build_inputs = UnpackConstObject(arg->GetConstantInfo()->value());
1397 } else if (arg->GetOpcode() == BUILD_LIST || arg->GetOpcode() == BUILD_TUPLE) {
1398 build_inputs.insert(build_inputs.end(), arg->getInputs().begin(), arg->getInputs().end());
1399 } else {
1400 return nullptr;
1401 }
1402 opcode = BUILD_LIST;
1403 div = 1;
1404 } else if (opcode == DICT_MERGE || opcode == DICT_UPDATE) {
1405 if (arg->GetOpcode() != BUILD_MAP) {
1406 return nullptr;
1407 }
1408 build_inputs.insert(build_inputs.end(), arg->getInputs().begin(), arg->getInputs().end());
1409 opcode = BUILD_MAP;
1410 } else if (opcode == MAP_ADD) {
1411 build_inputs.push_back(arg);
1412 build_inputs.push_back(arg2);
1413 opcode = BUILD_MAP;
1414 } else {
1415 return nullptr;
1416 }
1417 std::for_each(build_inputs.begin(), build_inputs.end(), [this](ValueNode *i) { this->push(i); });
1418 int oparg = build_inputs.size() / div;
1419 DoBuildOp({opcode, oparg});
1420 return pop();
1421 }
1422
DoMergeOp(const Instr & instr)1423 bool GraphBuilder::DoMergeOp(const Instr &instr) {
1424 int opcode = instr.op();
1425 int oparg = instr.arg();
1426 int pos = oparg + (opcode == MAP_ADD);
1427
1428 int index = this->frame_.GetStacks().size() - 1 - pos;
1429 ValueNode *container = seek(pos);
1430 std::vector<ValueNode *> inputs = {container, pop()};
1431 if (opcode == MAP_ADD) {
1432 inputs.insert(inputs.begin() + 1, pop());
1433 }
1434
1435 // DICT_MERGE only generated when unpack-call in python3.9, all keys must be string
1436 // NOTE: DICT_MERGE opcode requires that *(stack_pointer - oparg - 2) is a function if has duplicate key
1437 // ...
1438 ValueNode *new_node = ReplaceMergeOp(opcode, inputs);
1439 if (new_node != nullptr) {
1440 this->frame_.GetStacks()[index] = new_node;
1441 return true;
1442 }
1443
1444 return false;
1445 }
1446
DoFormatValue(const Instr & instr)1447 bool GraphBuilder::DoFormatValue(const Instr &instr) {
1448 int oparg = instr.arg();
1449 std::vector<ValueNode *> arg;
1450 if ((oparg & FVS_MASK) == FVS_HAVE_SPEC) {
1451 arg.push_back(pop());
1452 }
1453 arg.insert(arg.begin(), pop());
1454 auto vo = AObject::MakeAObject(AObject::kTypeString);
1455 auto v = NewValueNode(vo, instr, arg);
1456 push(v);
1457 return true;
1458 }
1459
DoImport(const Instr & instr)1460 bool GraphBuilder::DoImport(const Instr &instr) {
1461 int opcode = instr.op();
1462 if (opcode == IMPORT_FROM) {
1463 // any object
1464 push(NewValueNode(AObject::MakeAObject(AObject::kTypeAnyValue), instr, {seek(0)}));
1465 } else if (opcode == IMPORT_STAR) {
1466 auto from = pop();
1467 NewValueNode(AObject::MakeAObject(AObject::kTypeAnyValue), instr, {from});
1468 } else if (opcode == IMPORT_NAME) {
1469 auto from_list = pop();
1470 auto level = pop();
1471 auto vo = AObject::MakeAObject(AObject::kTypeModule);
1472 auto v = NewValueNode(vo, instr, {level, from_list});
1473 push(v);
1474 } else {
1475 return false;
1476 }
1477 return true;
1478 }
1479
DoByteCode(const Instr & instr)1480 bool GraphBuilder::DoByteCode(const Instr &instr) {
1481 if (current_block_->is_loop_head() && !graph_->Config().GetBoolConfig(GraphJitConfig::kLoopUnrolling)) {
1482 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceLoop_Unsupported);
1483 return false;
1484 }
1485
1486 auto func_iter = bytecode_meth_map_.find(instr.op());
1487 bool support = false;
1488 if (func_iter != bytecode_meth_map_.end()) {
1489 const auto func = func_iter->second;
1490 support = (this->*func)(instr);
1491 }
1492
1493 const auto &nodes = graph_->GetTracedNodes();
1494 for (auto i = nodes.rbegin(); i != nodes.rend() && (*i)->GetBlock() == nullptr; ++i) {
1495 (*i)->SetBlock(current_block_);
1496 }
1497
1498 if (instr.op() == RETURN_VALUE) {
1499 return false;
1500 }
1501
1502 if (!support) {
1503 if (graph_->GetStopTraceBci() == -1) {
1504 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceByteCode_Unsupported);
1505 }
1506 return false;
1507 }
1508
1509 if (instr.extra_jump() == nullptr) {
1510 ++cur_bci_;
1511 } else {
1512 bool valid = (cur_bci_ == instr.bci() + 1) || cur_bci_ == instr.extra_jump()->bci();
1513 MS_EXCEPTION_IF_CHECK_FAIL(valid, "error jump target");
1514 }
1515 if (cur_bci_ < current_block_->begin_ci() || cur_bci_ >= current_block_->end_ci()) {
1516 current_block_ = graph_->GetCFG()->GetBlockByBci(cur_bci_);
1517 }
1518 return true;
1519 }
1520
GraphBuilder(const PyFrameObject * f)1521 GraphBuilder::GraphBuilder(const PyFrameObject *f)
1522 : root_(this), parent_(nullptr), graph_(nullptr), current_block_(nullptr) {
1523 PyCodeObject *co = f->f_code;
1524 int argc = co->co_argcount + co->co_kwonlyargcount;
1525 argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
1526 argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
1527 int ncells = PyTuple_GET_SIZE(co->co_cellvars);
1528 int nfrees = PyTuple_GET_SIZE(co->co_freevars);
1529
1530 graph_ = NewGraph(co, f->f_globals);
1531
1532 frame_.ResizeLocal(co->co_nlocals);
1533 frame_.ResizeClosure(ncells + nfrees);
1534 for (int i = 0; i < argc; i++) {
1535 if (f->f_localsplus[i] == nullptr) {
1536 continue;
1537 }
1538 auto vo = AObject::Convert(f->f_localsplus[i]);
1539 ParamNode *n = graph_->allocator().NewNode<ParamNode>(vo, i);
1540 n->SetName(PyUnicode_AsUTF8(PyTuple_GET_ITEM(co->co_varnames, i)));
1541 frame_.SetLocal(i, n);
1542 graph_->GetSideEffect()->data()->Track(f->f_localsplus[i], n);
1543 }
1544 for (int i = 0; i < ncells + nfrees; i++) {
1545 PyObject *cell = f->f_localsplus[co->co_nlocals + i];
1546 PyObject *cell_contents = PyCell_GET(cell);
1547 AbstractNode::Type t = i < ncells ? AbstractNode::CellVar : AbstractNode::FreeVar;
1548 CellVarNode *n = graph_->allocator().NewNode<CellVarNode>(t);
1549 n->SetGraph(graph_);
1550 n->SetVobj(AObject::Convert(cell));
1551 n->SetIndex(i);
1552 frame_.SetClosure(i, n);
1553 if (i < ncells && co->co_cell2arg != nullptr && co->co_cell2arg[i] != CO_CELL_NOT_AN_ARG) {
1554 MS_EXCEPTION_IF_NULL(cell_contents);
1555 n->SetFromParam(co->co_cell2arg[i]);
1556 }
1557 if (cell_contents == nullptr) {
1558 n->SetValue(&ValueNode::kUnboundLocal);
1559 } else {
1560 ValueNode *param = NewValueNode(AObject::Convert(cell_contents), LOAD_DEREF, i);
1561 param->SetGraph(graph_);
1562 n->AddCellOper(param);
1563 n->SetValue(param);
1564 }
1565 }
1566 }
1567
CollectInlineInfo(CallNode * node,int depth)1568 void GraphBuilder::CollectInlineInfo(CallNode *node, int depth) {
1569 Graph *sub_graph = node->GetSubGraph();
1570 if (!sub_graph) {
1571 return;
1572 }
1573 std::string inline_name = "";
1574 int code_size = 0;
1575 if (sub_graph != nullptr && sub_graph->GetCodeObj() != nullptr) {
1576 inline_name = py::str(reinterpret_cast<PyObject *>(sub_graph->GetCodeObj())).cast<std::string>();
1577 code_size = SizeToInt((PyBytes_GET_SIZE(sub_graph->GetCodeObj()->co_code)) / sizeof(_Py_CODEUNIT));
1578 }
1579 std::string func_name = graph_->GetCodeName();
1580 std::string root_name = root_->GetGraph()->GetCodeName();
1581 JitCompileResults *jcr = getJitCompileResults(reinterpret_cast<PyObject *>(root_->GetGraph()->GetCodeObj()), false);
1582 if (jcr && jcr->tbs && !func_name.empty()) {
1583 jcr->tbs->PushInlineInfo(
1584 {func_name, inline_name, root_name, node->GetInlineReason(), code_size, depth, node->GetLineNo()});
1585 }
1586 }
1587
HandleLoop()1588 void GraphBuilder::HandleLoop() {
1589 Block *loop_head = graph_->GetCFG()->GetBlockByBci(cur_bci_);
1590 if (!loop_head->is_loop_head()) {
1591 return;
1592 }
1593 /**
1594 * (chaiyouheng): before trace start, unrolling loop. avoid graph status is changed while trace loop
1595 * just unrolling a small loop that call nn.CellList.
1596 *
1597 * LoopUnrolling loopUnrollingExe = LoopUnrolling(*graph_);
1598 * (void)loopUnrollingExe.ExecuteLoopUnroll(loop_head);
1599 */
1600 }
1601
FindPyFunc(AObject * vobj)1602 py::object GraphBuilder::FindPyFunc(AObject *vobj) {
1603 if (!vobj) {
1604 return py::cast<py::object>(nullptr);
1605 }
1606
1607 switch (vobj->GetType()) {
1608 case AObject::kTypeCell:
1609 vobj = vobj->GetAttr(ID_construct);
1610 break;
1611 case AObject::kTypeAnyValue:
1612 vobj = vobj->GetAttr(ID___call__);
1613 break;
1614 case AObject::kTypeType:
1615 vobj = vobj->GetAttr("__init__");
1616 break;
1617 case AObject::kTypeBoundMethod:
1618 vobj = vobj->GetAttr("__func__");
1619 default:
1620 break;
1621 }
1622 py::object func = vobj ? vobj->GetPyObject() : py::object();
1623
1624 if (func.ptr() == nullptr) {
1625 PyErr_Clear();
1626 return py::cast<py::object>(nullptr);
1627 }
1628
1629 if (PyMethod_Check(func.ptr())) {
1630 func = py::reinterpret_borrow<py::object>(PyMethod_GET_FUNCTION(func.ptr()));
1631 }
1632
1633 if (PyFunction_Check(func.ptr())) {
1634 return func;
1635 }
1636 return py::cast<py::object>(nullptr);
1637 }
1638
GetFuncInfo(ValueNode * func_node)1639 py::object GraphBuilder::GetFuncInfo(ValueNode *func_node) {
1640 AObject *vobj = func_node->GetVobj();
1641 if (vobj->GetType() == AObject::kTypeCFunction) {
1642 return py::object();
1643 }
1644 if (func_node->GetOpcode() == MAKE_FUNCTION) {
1645 return func_node->GetVobj()->GetPyObject();
1646 }
1647 return FindPyFunc(vobj);
1648 }
1649
/**
 * Handle a call to a callable that has a dedicated infer routine (FindInferFunc).
 * Also handles a call to a primitive when primitive inference is switched off.
 *
 * Side effects:
 * - A Cell callee marks the current block with kTrackHasOpsPrimitive. Its top-level module
 *   (GetTopModule) is added to the allowed inline modules of kPIJitConfigDefault.
 * - infer_func_count is increased for every call that is not over the kInferPrimitiveMax limit.
 *
 * \param call_node the call being traced. Its vobj, sub-graph and inline reason may be updated.
 * \param callable the Python callable of the call, used to look up the infer routine.
 * \return true if the call was handled here.
 *         false if FindInferFunc has no routine for `callable` (and the primitive shortcut did not apply).
 */
bool GraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
  const auto &conf = call_node->GetGraph()->Config();

  AObject::Type vobj_type = call_node->input(0)->GetVobj()->GetType();
  if (vobj_type == AObject::kTypeCell) {
    current_block_->SetTrackResult(Block::kTrackHasOpsPrimitive);
    std::string module_name = GetTopModule(callable);
    if (!module_name.empty()) {
      kPIJitConfigDefault.AddAllowedInlineModules(module_name);
    }
  }

  // Primitive inference needs the config switch, room in the infer_func_count budget
  // (kInferPrimitiveMax == 0 means no limit) and the infer_primitive_func bit in kInferPrimitiveMask.
  bool infer_primitive = conf.GetBoolConfig(GraphJitConfig::kInferPrimitive);
  int max_infer = conf.getIntConfig(GraphJitConfig::kInferPrimitiveMax);
  if (max_infer != 0 && infer_func_count >= max_infer) {
    infer_primitive = false;
  } else {
    infer_func_count++;
  }
  infer_primitive &= (conf.getIntConfig(GraphJitConfig::kInferPrimitiveMask) & infer_primitive_func) != 0;
  if (!infer_primitive && vobj_type == AObject::kTypePrimitive) {
    // Without inference the primitive's result is modeled as an opaque tensor.
    call_node->SetVobj(AObject::MakeAObject(AObject::kTypeTensor));
    call_node->SetInlineReason(InlineReason::kInlineGraphSupportedByMS);
    current_block_->SetTrackResult(Block::kTrackHasOpsPrimitive);
    return true;
  }

  InferFunc infer_func = FindInferFunc(callable);
  if (infer_func == nullptr) {
    return false;
  }

  // Give the infer routine an empty sub-graph that shares the root graph's guard.
  // The routine may fill the sub-graph or drop it (checked below).
  call_node->SetInlineReason(InlineReason::kInlineUnknown);
  call_node->SetSubGraph(NewGraph(nullptr, nullptr));
  call_node->GetSubGraph()->SetGuard(root_->GetGraph()->GetGuard());
  infer_func(call_node, this);

  // No sub-graph left means the call was specialized. Otherwise the sub-graph's return value replaces
  // the value at seek(0), presumably the call result on the top of the stack.
  InlineReason r;
  if (call_node->GetSubGraph() == nullptr) {
    r = InlineReason::kInlineFuncSpecialize;
  } else {
    MS_EXCEPTION_IF_NULL(call_node->GetSubGraph()->GetRetVal());
    r = InlineReason::kInline;
    seek(0) = call_node->GetSubGraph()->GetRetVal();
  }
  // Keep an inline reason set by the infer routine itself.
  if (call_node->GetInlineReason() == InlineReason::kInlineUnknown) {
    call_node->SetInlineReason(r);
  }
  return true;
}
1700
UnsupportedCodeTypeCheck(PyCodeObject * co)1701 bool UnsupportedCodeTypeCheck(PyCodeObject *co) {
1702 if (co->co_flags & (CO_GENERATOR | CO_COROUTINE | CO_ASYNC_GENERATOR)) {
1703 MS_LOG(DEBUG) << "generator is unsupported";
1704 return true;
1705 }
1706 /**
1707 * skip super call
1708 * >>>def super_wrapper(self):
1709 * ... __class__=type(self)
1710 * ... def super_init(self):
1711 * ... return super()
1712 * ... return super_init(self)
1713 * >>>assert super(int, 1).__hash__() == super_wrapper(1).__hash__()
1714 */
1715 return false;
1716 }
1717
/**
 * Decide whether the traced sub-graph of `call_node` may be inlined into the caller graph.
 *
 * \param call_node a call whose sub-graph has already been traced.
 * \return false when there is no sub-graph or it has no return value.
 *         For a MAKE_FUNCTION callee: true only if the code has no cell variables.
 *         For any other callee, checked in this order:
 *         - false if any of the code's own cell variables has cell operations recorded;
 *         - if the code has free variables, the result is exactly "the only free variable is `__class__`",
 *           and the remaining checks are skipped;
 *         - false if the sub-graph returns a MAKE_FUNCTION node;
 *         - false if a traced node takes a MAKE_FUNCTION node as input (the callee input of calls already
 *           inlined with kInline is ignored);
 *         - true otherwise.
 */
bool ApplyInlinePolicy(CallNode *call_node) {
  Graph *g = call_node->GetSubGraph();
  if (g == nullptr || g->GetRetVal() == nullptr) {
    return false;
  }

  PyCodeObject *co = g->GetCodeObj();
  int ncells = PyTuple_GET_SIZE(co->co_cellvars);
  int nfrees = PyTuple_GET_SIZE(co->co_freevars);

  bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;
  if (is_make_func) {
    // inline MAKE_FUNCTION, need eliminate cell and free variable if the function is not dead local.
    return ncells == 0;
  }

  // The first ncells closure slots hold the code's own cell variables (free variables follow them).
  const auto &closures = g->GetFrame(0).GetClosures();
  if (std::any_of(closures.begin(), closures.begin() + ncells, [](auto n) { return !n->GetCellOper().empty(); })) {
    return false;
  }
  if (nfrees > 0) {
    // if inline, guard free variable
    // NOTE(review): PyUnicode_AsUTF8 returns nullptr on failure, which would then be compared as a C string.
    return nfrees == 1 && std::string("__class__") == PyUnicode_AsUTF8(PyTuple_GET_ITEM(co->co_freevars, 0));
  }
  if (g->GetRetVal()->GetOpcode() == MAKE_FUNCTION) {
    return false;
  }
  for (auto i : g->GetTracedNodes()) {
    // check MAKE_FUNCTION is alive, it is incorrect that inline the function of different module with MAKE_FUNCTION
    auto begin = i->getInputs().begin();
    // For a call that was itself inlined, skip its callee input (input 0).
    if (Opcode(i->GetOpcode()).IsCall() && static_cast<CallNode *>(i)->GetInlineReason() == InlineReason::kInline) {
      begin++;
    }
    if (std::any_of(begin, i->getInputs().end(), [](ValueNode *n) { return n->GetOpcode() == MAKE_FUNCTION; })) {
      return false;
    }
  }
  return true;
}
1757
/**
 * Check whether a class call `cls(...)` may be evaluated at trace time to create a real instance.
 *
 * Exact types in `support_create_instance_type` are always accepted.
 *
 * Types in `limit_create_instance_type` are accepted only when called with exactly one argument that has
 * abstract info. If that argument's type is unknown (kTypeAnyValue), it must be a CALL_FUNCTION result
 * produced at the immediately preceding bci.
 *
 * \param call_node the class call. Its callee info is cast to AbstractType without a check, so it is assumed to be
 *        a kTypeType object (HandleCallClass checks this before calling).
 * \return true if an instance may be created.
 */
bool CheckSupportCreateInstance(CallNode *call_node) {
  /**
   * only support exactly type, sub-class not create
   */
  static const std::set<PyTypeObject *> support_create_instance_type = {
    &PyComplex_Type, &PyMap_Type,   &PyBaseObject_Type, &PyRange_Type, &PyZip_Type,    &PySlice_Type,
    &PyBool_Type,    &PyFloat_Type, &PyLong_Type,       &PyType_Type,  &PyMethod_Type,
  };

  AObject *cls_info = call_node->input(0)->GetVobj();
  PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(static_cast<AbstractType *>(cls_info)->GetPyObject().ptr());
  if (tp == nullptr) {
    return false;
  }
  if (support_create_instance_type.find(tp) != support_create_instance_type.end()) {
    return true;
  }

  /**
   * maybe has sideeffect, limit create
   */
  static const std::set<PyTypeObject *> limit_create_instance_type = {
    &PyList_Type, &PyTuple_Type, &PySet_Type, &PyFrozenSet_Type, &PyDict_Type, &PyUnicode_Type, &PyEnum_Type,
  };
  // Only the single-argument form (callee + one argument) is considered.
  if (call_node->getInputs().size() != 2) {
    return false;
  }
  ValueNode *iterable_node = call_node->input(1);
  AObject *first_param = iterable_node->GetVobj();
  if (first_param == nullptr) {
    return false;
  }

  if (first_param->GetType() == AObject::kTypeAnyValue) {
    if (iterable_node->GetOpcode() != CALL_FUNCTION || call_node->bci() - 1 != iterable_node->bci()) {
      return false;
    }
    /**
     * just process this case:
     * z = list(zip(list(x), list(y)))
     * z = list(enumerate(x))
     */
    // this case, zip object and enumerate object is dead variable
  }
  return limit_create_instance_type.find(tp) != limit_create_instance_type.end();
}
1804
BuildSuperObject(PyCodeObject * co)1805 AObject *GraphBuilder::BuildSuperObject(PyCodeObject *co) {
1806 if (co->co_argcount == 0) {
1807 PyErr_SetString(PyExc_RuntimeError, "super(): no arguments");
1808 return nullptr;
1809 }
1810
1811 Py_ssize_t i, n;
1812 // search self object
1813 PyObject *obj = SearchSelfPyObject(co).first;
1814 if (obj == NULL) {
1815 PyErr_SetString(PyExc_RuntimeError, "super(): arg[0] deleted");
1816 return nullptr;
1817 }
1818
1819 if (co->co_freevars == NULL) {
1820 n = 0;
1821 } else {
1822 assert(PyTuple_Check(co->co_freevars));
1823 n = PyTuple_GET_SIZE(co->co_freevars);
1824 }
1825
1826 PyTypeObject *type = NULL;
1827 for (i = 0; i < n; i++) {
1828 PyObject *name = PyTuple_GET_ITEM(co->co_freevars, i);
1829 assert(PyUnicode_Check(name));
1830 // check class id
1831 if (!strcmp("__class__", PyUnicode_AsUTF8(name))) {
1832 Py_ssize_t index = PyTuple_GET_SIZE(co->co_cellvars) + i;
1833 PyObject *cell = SetLocalPyObject(frame_.Closure(index));
1834 if (cell == NULL || !PyCell_Check(cell)) {
1835 PyErr_SetString(PyExc_RuntimeError, "super(): bad __class__ cell");
1836 return nullptr;
1837 }
1838 type = reinterpret_cast<PyTypeObject *>(PyCell_GET(cell));
1839 if (type == NULL) {
1840 PyErr_SetString(PyExc_RuntimeError, "super(): empty __class__ cell");
1841 return nullptr;
1842 }
1843 if (!PyType_Check(type)) {
1844 PyErr_Format(PyExc_RuntimeError, "super(): __class__ is not a tyep (%s)", Py_TYPE(type)->tp_name);
1845 return nullptr;
1846 }
1847 break;
1848 }
1849 }
1850 if (type == NULL) {
1851 PyErr_SetString(PyExc_RuntimeError, "super(): __class__ cell not found");
1852 return nullptr;
1853 }
1854
1855 py::object py_obj = py::reinterpret_borrow<py::object>(obj);
1856 py::object py_type = py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(type));
1857 py::tuple tuple_obj(2);
1858 tuple_obj[0] = py_type;
1859 tuple_obj[1] = py_obj;
1860 PyObject *ret = PyObject_Call(reinterpret_cast<PyObject *>(&PySuper_Type), tuple_obj.ptr(), nullptr);
1861 AObject *super_obj = AObject::Convert(ret);
1862 Py_DECREF(ret);
1863 return super_obj;
1864 }
1865
ClassInstantiationFold(CallNode * call_node,AObject::Type type)1866 bool GraphBuilder::ClassInstantiationFold(CallNode *call_node, AObject::Type type) {
1867 const auto ¶ms = call_node->getInputs();
1868 int call_op = call_node->GetOpcode();
1869
1870 // list, tuple, dict fold
1871 std::vector<ValueNode *> inputs;
1872 int new_op;
1873 int new_arg;
1874 if (type == AObject::kTypeTuple || type == AObject::kTypeList) {
1875 if (params.size() > 1) {
1876 int arg_op = params[1]->GetOpcode();
1877 if (call_op == CALL_FUNCTION && (arg_op == BUILD_TUPLE || arg_op == BUILD_LIST)) {
1878 inputs = params[1]->getInputs();
1879 } else {
1880 return false;
1881 }
1882 }
1883 new_op = type == AObject::kTypeTuple ? BUILD_TUPLE : BUILD_LIST;
1884 new_arg = inputs.size();
1885 } else if (type == AObject::kTypeDict) {
1886 if (params.size() > 1) {
1887 ValueNode *map_node;
1888 if (call_op == CALL_FUNCTION && params[1]->GetOpcode() == BUILD_MAP) {
1889 map_node = params[1];
1890 } else if (call_op == CALL_FUNCTION_EX && params.size() > 2 && params[2]->GetOpcode() == BUILD_MAP) {
1891 map_node = params[2];
1892 } else {
1893 return false;
1894 }
1895 inputs = map_node->getInputs();
1896 }
1897 new_op = BUILD_MAP;
1898 new_arg = inputs.size() / 2;
1899 } else {
1900 return false;
1901 }
1902
1903 Graph *sub_graph = NewGraph(nullptr, nullptr);
1904 AObject *res = AObject::BuildOperations(CollectObjects(inputs), new_op);
1905 ValueNode *new_node = sub_graph->NewValueNode(res, new_op, new_arg, inputs);
1906 sub_graph->GetTracedNodes().push_back(new_node);
1907 sub_graph->SetRetVal(new_node);
1908
1909 call_node->SetSubGraph(sub_graph);
1910 call_node->SetInlineReason(InlineReason::kInline);
1911 seek(0) = new_node;
1912 return true;
1913 }
1914
LogGuardFailed(ValueNode * node,const GraphJitConfig & conf,const std::string & msg)1915 void LogGuardFailed(ValueNode *node, const GraphJitConfig &conf, const std::string &msg) {
1916 if (!conf.GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
1917 return;
1918 }
1919 auto tr = GetTrace(node, false, true, 0, -1);
1920 std::stringstream s;
1921 if (node->GetVobj() == nullptr || node->GetVobj()->GetPyObject().ptr() == nullptr) {
1922 s << "infer failed\n";
1923 } else {
1924 std::map<Trace *, size_t> cache;
1925 s << "trace:\n" << (tr ? tr->FormatString(&cache).c_str() : "trace failed") << "\n";
1926 }
1927 s << msg << " [" << node->ToString() << "]";
1928 GRAPH_JIT_LOG_F("%s", s.str().c_str());
1929 }
1930
HandleCallClass(CallNode * call_node)1931 bool GraphBuilder::HandleCallClass(CallNode *call_node) {
1932 AObject *vobj = call_node->input(0)->GetVobj();
1933 if (!vobj || vobj->GetType() != AObject::kTypeType) {
1934 return false;
1935 }
1936 AbstractType *t = static_cast<AbstractType *>(vobj);
1937 AObject::Type type = t->GetTypeType();
1938 if (!trace_flag() && ClassInstantiationFold(call_node, type)) {
1939 return true;
1940 }
1941
1942 const auto ¶ms = call_node->getInputs();
1943 AObject *instance = nullptr;
1944 bool support_create_instance = CheckSupportCreateInstance(call_node);
1945 bool constant = type == AObject::kTypePrimitive || type == AObject::kTypeTensor || type == AObject::kTypeStubTensor ||
1946 IsMsClass(t->GetPyObject().ptr());
1947 // create instance
1948 if (support_create_instance || constant) {
1949 constant |= std::none_of(params.begin(), params.end(), [](ValueNode *i) { return !i->IsConstantValue(); });
1950 std::vector<py::object> args;
1951 std::transform(params.begin() + 1, params.end(), std::back_inserter(args), [](ValueNode *n) {
1952 AObject *i = n->GetVobj();
1953 return i ? i->GetPyObject() : py::object();
1954 });
1955 py::object res = t->BuildInstance(args, call_node->GetOpcode());
1956 instance = res.ptr() ? AObject::Convert(res) : nullptr;
1957 } else if (reinterpret_cast<PyTypeObject *>(vobj->GetPyObject().ptr()) == &PySuper_Type) {
1958 // take super ptr and compare with PySuper_Type
1959 instance = BuildSuperObject(graph_->GetCodeObj());
1960 this->graph_->GetTracedNodes().pop_back();
1961 if (PyErr_Occurred()) {
1962 throw py::error_already_set();
1963 }
1964 }
1965
1966 if (constant && instance != nullptr && GuardConstCallNodeParam(call_node, call_node->GetGraph(), INT_MAX)) {
1967 // make instance is constant
1968 auto iter = this->graph_->GetTracedNodes().end() - 1;
1969 MS_EXCEPTION_IF_CHECK_FAIL(*iter == call_node, "CallNode must be last when build sub graph");
1970 *iter = NewValueNode(instance, LOAD_CONST, -1, {});
1971 seek(0) = *iter;
1972 } else if (!instance) {
1973 // create abstract instance
1974 instance = t->BuildAbstractInstance(CollectObjects({params.begin() + 1, params.end()}), call_node->GetOpcode());
1975 }
1976 call_node->SetVobj(instance);
1977 return instance != nullptr;
1978 }
1979
1980 // NOTE: must be copy __code__, copy.deepcopy do nothing for code object
/**
 * Create a new Python function that runs a freshly built code object with the fields of `o`'s code.
 *
 * The new function:
 * - has the qualified name kPIJitCopyFuncKey followed by the original co_name;
 * - uses `o`'s globals;
 * - takes its closure, defaults, kw-defaults and annotations from `o` via REPLACE_PY_MEMBER.
 *
 * NOTE(review): PyCode_New has no posonlyargcount argument (Python 3.8+), so positional-only parameters of `o`
 * are presumably not preserved in the copy; confirm.
 *
 * \param o a Python function object (checked).
 * \return the new function.
 * \throws py::error_already_set if any Python object cannot be created.
 */
static py::object CopyPyFunc(const py::object &o) {
  MS_EXCEPTION_IF_CHECK_FAIL(PyFunction_Check(o.ptr()), "must be function");
  PyFunctionObject *func = reinterpret_cast<PyFunctionObject *>(o.ptr());
  PyCodeObject *code = reinterpret_cast<PyCodeObject *>(func->func_code);

  // Each new reference is owned by a py::object, so every failure path below releases what was already created.
  py::object new_name =
    py::reinterpret_steal<py::object>(PyUnicode_FromFormat("%s%U", kPIJitCopyFuncKey, code->co_name));
  if (new_name.ptr() == nullptr) {
    throw py::error_already_set();
  }
  py::object new_code = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(
    PyCode_New(code->co_argcount, code->co_kwonlyargcount, code->co_nlocals, code->co_stacksize, code->co_flags,
               code->co_code, code->co_consts, code->co_names, code->co_varnames, code->co_freevars, code->co_cellvars,
               code->co_filename, code->co_name, code->co_firstlineno, GetCodeLineTable(code))));
  if (new_code.ptr() == nullptr) {
    throw py::error_already_set();
  }
  py::object new_func =
    py::reinterpret_steal<py::object>(PyFunction_NewWithQualName(new_code.ptr(), func->func_globals, new_name.ptr()));
  if (new_func.ptr() == nullptr) {
    throw py::error_already_set();
  }
  PyFunctionObject *new_ff = reinterpret_cast<PyFunctionObject *>(new_func.ptr());
  REPLACE_PY_MEMBER(new_ff->func_closure, func->func_closure);
  REPLACE_PY_MEMBER(new_ff->func_defaults, func->func_defaults);
  REPLACE_PY_MEMBER(new_ff->func_kwdefaults, func->func_kwdefaults);
  REPLACE_PY_MEMBER(new_ff->func_annotations, func->func_annotations);
  return new_func;
}
2004
/**
 * Get the PIJit copy of `func`, creating it on first use.
 *
 * The copy is cached on `func` under the attribute kPIJitCopyFuncKey. A newly created copy is passed to
 * pi_jit_should_compile.
 *
 * \param func a Python function.
 * \return the cached or newly created copy.
 */
py::object GetPIJitCopiedFunc(const py::object &func) {
  PyObject *res = PyObject_GetAttrString(func.ptr(), kPIJitCopyFuncKey);
  if (res != nullptr) {
    return py::reinterpret_steal<py::object>(res);
  }
  // attribute missing: drop the AttributeError and build the copy
  PyErr_Clear();
  py::object copy = CopyPyFunc(func);
  // Caching is best-effort. If the attribute cannot be stored, clear the Python error so it does not leak into
  // later, unrelated Python API calls (the copy is simply rebuilt next time).
  if (PyObject_SetAttrString(func.ptr(), kPIJitCopyFuncKey, copy.ptr()) != 0) {
    PyErr_Clear();
  }
  (void)pi_jit_should_compile(copy, py::dict(), py::none());
  return copy;
}
2016
GetSelfFromMethod(ValueNode * method)2017 ValueNode *GetSelfFromMethod(ValueNode *method) {
2018 if (method->GetOpcode() != LOAD_ATTR) {
2019 return nullptr;
2020 }
2021 ValueNode *self = method->input(0);
2022 /**
2023 * (chaiyouheng):
2024 * Check method is a generic attribute
2025 * descr = _PyType_Lookup(self->GetVobj()->GetTypeObject(), py::str(method->GetName()).ptr());
2026 * Check descr == nullptr || !PyFunction_Check(descr)
2027 */
2028 return self;
2029 }
2030
ReplaceCall(CallNode * call_node,const py::object & old_func)2031 bool GraphBuilder::ReplaceCall(CallNode *call_node, const py::object &old_func) {
2032 if (call_node->GetOpcode() == CALL_FUNCTION_EX && call_node->input(1)->GetOpcode() != BUILD_TUPLE) {
2033 // dynamic length variable arguments, user-defined unpack sequence
2034 return false;
2035 }
2036 if (!graph_->GuardInlinedFunc(call_node)) {
2037 return false;
2038 }
2039 auto jcr = getJitCompileResults(old_func.ptr(), false);
2040 if (jcr != nullptr && jcr->stat != JitCompileResults::NEVER_COMPILE) {
2041 return true;
2042 }
2043
2044 py::object new_func = GetPIJitCopiedFunc(old_func);
2045
2046 auto &nodes = graph_->GetTracedNodes();
2047 MS_EXCEPTION_IF_CHECK_FAIL(nodes.back() == call_node, "CallNode must be last when build sub graph");
2048
2049 ValueNode *self = nullptr;
2050 AObject::Type func_type = call_node->input(0)->GetVobj()->GetType();
2051 if (func_type == AObject::kTypeBoundMethod) {
2052 ValueNode *func_val = call_node->input(0);
2053 self = GetSelfFromMethod(func_val);
2054 if (self == nullptr) {
2055 ValueNode *node = NewValueNode(func_val->get_attr(GraphBuilder::ID___self__), LOAD_ATTR, -1, {func_val},
2056 GraphBuilder::ID___self__);
2057 node->SetGraph(call_node->GetGraph());
2058 nodes.insert(nodes.end() - 1, node);
2059 self = node;
2060 }
2061 } else if (func_type == AObject::kTypeCell || AObject::kTypeAnyValue) {
2062 self = call_node->input(0);
2063 } else if (func_type != AObject::kTypeFunction) {
2064 return false;
2065 }
2066
2067 std::stringstream key;
2068 PyObject *func_name = reinterpret_cast<PyFunctionObject *>(new_func.ptr())->func_qualname;
2069 key << std::string(py::str(func_name)) << "." << new_func.ptr();
2070
2071 // new func node
2072 ValueNode *func_node = this->NewValueNode(AObject::Convert(new_func), LOAD_CONST, -1, {});
2073 nodes.insert(nodes.end() - 1, func_node);
2074
2075 // replace node
2076 call_node->getInputs()[0] = func_node;
2077 if (self == nullptr) {
2078 return true;
2079 }
2080
2081 // append self to args
2082 if (call_node->GetOpcode() != CALL_FUNCTION_EX) {
2083 call_node->getInputs().insert(call_node->getInputs().begin() + 1, self);
2084 call_node->SetOparg(call_node->GetOparg() + 1);
2085 return true;
2086 }
2087
2088 // append self to variable arguments
2089 ValueNode *args_node = call_node->input(1);
2090 std::vector<ValueNode *> inputs = args_node->getInputs();
2091 inputs.insert(inputs.begin(), self);
2092 AObject *args_info = AObject::BuildOperations(CollectObjects(inputs), BUILD_TUPLE);
2093
2094 ValueNode *tuple = this->NewValueNode(args_info, BUILD_TUPLE, inputs.size(), inputs);
2095 tuple->set_bci(call_node->bci());
2096 tuple->SetLineNo(call_node->GetLineNo());
2097 nodes.insert(nodes.end() - 1, tuple);
2098 call_node->getInputs()[1] = tuple;
2099 return true;
2100 }
2101
MindGraphBuilder(const PyFrameObject * f)2102 MindGraphBuilder::MindGraphBuilder(const PyFrameObject *f) : GraphBuilder(f) {
2103 std::vector<std::string> comments;
2104 auto location = std::make_shared<Location>(py::cast<std::string>(f->f_code->co_filename), f->f_code->co_firstlineno,
2105 0, f->f_code->co_firstlineno, 0, "", std::move(comments));
2106 MS_EXCEPTION_IF_NULL(location);
2107 TraceGuard trace_guard(location);
2108 fg_builder_ = std::make_shared<FuncGraphBuilder>(true);
2109 fg_builder_->SetGraphName(py::cast<std::string>(f->f_code->co_name) + "_" +
2110 std::to_string(f->f_code->co_firstlineno));
2111 co_name_ = py::cast<std::string>(f->f_code->co_name);
2112 }
2113
2114 namespace {
GetFuncGraphName(const py::object & func,const MindGraphBuilderPtr & subgraph)2115 std::string GetFuncGraphName(const py::object &func, const MindGraphBuilderPtr &subgraph) {
2116 auto func_str = py::cast<std::string>(py::str(func));
2117 std::vector<std::string> vec;
2118 std::istringstream iss(func_str);
2119 std::string str;
2120 while (iss >> str) {
2121 (void)vec.emplace_back(str);
2122 }
2123 if (vec.size() <= 1) {
2124 return "";
2125 }
2126 auto func_name = vec[1];
2127 std::replace(func_name.begin(), func_name.end(), '.', '_');
2128 return func_name + "_" + std::to_string(subgraph->GetGraph()->GetCodeObj()->co_firstlineno);
2129 }
2130 } // namespace
2131
BuildSubGraph(CallNode * call_node,int depth,const py::object & func,const GraphBuilderPtr & subgraph)2132 StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
2133 const GraphBuilderPtr &subgraph) {
2134 auto sg = std::dynamic_pointer_cast<MindGraphBuilder>(subgraph);
2135 sg->FGBuilder()->AddPrevBuilder(FGBuilder());
2136
2137 auto code = sg->GetGraph()->GetGuard();
2138 MS_EXCEPTION_IF_NULL(code);
2139 code->GetGuard()->Backup();
2140
2141 auto args = call_node->GetArgs();
2142 if (PyFunction_Check(func.ptr())) {
2143 args = GetNewArgs(call_node, AObject::Convert(func.ptr()));
2144 }
2145
2146 MS_LOG(INFO) << "new subgraph->TraceRun: " << py::str(func);
2147 bool succ = sg->FGAddInputs(args);
2148 if (!succ) {
2149 MS_LOG(INFO) << "Add input fail for new subgraph->TraceRun: " << py::str(func);
2150 return StopTraceReason::kStopTraceFunc_ArgHandle_Unsupported;
2151 }
2152 auto reason = sg->TraceRun();
2153 MS_LOG(INFO) << "new subgraph->TraceRun end: " << py::str(func);
2154
2155 call_node->SetSubGraph(sg->GetGraph());
2156 auto sub_ret = sg->GetGraph()->GetRetVal();
2157 if (sub_ret != nullptr) {
2158 if (sub_ret->GetVobj()->GetPyObject().ptr() == nullptr ||
2159 CheckConstPyObject(sub_ret->GetVobj()->GetPyObject().ptr())) {
2160 call_node->SetVobj(sub_ret->GetVobj());
2161 } else {
2162 sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, sg));
2163 sg->FGAddOutput(false);
2164 if (sg->FGBuilder()->graph() == nullptr) {
2165 MS_LOG(INFO) << "subgraph trace null";
2166 return StopTraceReason::kTrace_Fail;
2167 } else {
2168 TraceGuard trace_guard(GetLocation(call_node));
2169 auto res = FGBuilder()->AddNode(sg->FGBuilder()->graph(), args);
2170 if (res.ptr()) {
2171 MS_LOG(INFO) << "add fg node suc: ";
2172 call_node->SetVobj(AObject::Convert(res));
2173 }
2174 }
2175 }
2176 }
2177 bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;
2178 if (is_make_func) {
2179 graph_->GuardInlinedFunc(call_node);
2180 }
2181 return reason;
2182 }
2183
2184 // build sub-graph
/**
 * Trace `subgraph` as the callee of `call_node` and decide whether the callee is inlined.
 *
 * The callee's guard is backed up before tracing. The sub-graph and its return info are attached to `call_node`.
 * Unless kEnableGeneratorExpressionToTuple is set, a sub-graph with a generator result gets its return value
 * cleared, which makes ApplyInlinePolicy reject it.
 *
 * - Policy rejected: the guard is rolled back. For a non-MAKE_FUNCTION callee, ReplaceCall is also run for its
 *   rewriting side effect; the reason stays kInlinePolicyDisabled whatever it returns.
 * - Policy accepted: a non-MAKE_FUNCTION callee must also pass GuardInlinedFunc. On failure the guard is rolled
 *   back; on success the backup is popped.
 *
 * \param depth unused here.
 * \param func the Python function being called (passed to ReplaceCall).
 * \return always StopTraceReason::kNonStopTrace; the outcome is stored with call_node->SetInlineReason.
 */
StopTraceReason GraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
                                            const GraphBuilderPtr &subgraph) {
  InlineReason stat = InlineReason::kInline;
  bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;

  auto code = subgraph->GetGraph()->GetGuard();
  MS_EXCEPTION_IF_NULL(code);
  code->GetGuard()->Backup();

  MS_LOG(INFO) << "old subgraph->TraceRun";
  // the stop reason of the callee trace is not used here
  subgraph->TraceRun();

  call_node->SetSubGraph(subgraph->GetGraph());
  if (subgraph->GetGraph()->GetRetVal() != nullptr) {
    call_node->SetVobj(subgraph->GetGraph()->GetRetVal()->GetVobj());
  }
  bool gen_to_tuple = subgraph->GetGraph()->Config().GetBoolConfig(GraphJitConfig::kEnableGeneratorExpressionToTuple);
  if (!gen_to_tuple && subgraph->GetGraph()->GetGeneratorResult() != nullptr) {
    subgraph->GetGraph()->SetRetVal(nullptr);
  }

  stat = ApplyInlinePolicy(call_node) ? stat : InlineReason::kInlinePolicyDisabled;
  if (stat != InlineReason::kInline) {
    code->GetGuard()->Rollback();
    if (!is_make_func) {
      /**
       * replace function call, inline or resume capture after break graph
       * exclude make function, because of function always a new function but code is constant
       **/
      stat = ReplaceCall(call_node, func) ? stat : InlineReason::kInlinePolicyDisabled;
    }
  } else {
    if (!is_make_func) {
      // exclude make function, because of function always a new function but code is constant
      stat = graph_->GuardInlinedFunc(call_node) ? stat : InlineReason::kInlinePolicyDisabled;
    }
    if (stat != InlineReason::kInline) {
      code->GetGuard()->Rollback();
    } else {
      code->GetGuard()->Pop();
    }
  }

  // if stat == InlineReason::kInline, guard free variable
  call_node->SetInlineReason(stat);
  return StopTraceReason::kNonStopTrace;
}
2232
UnpackDynamicLengthDictByBytecode(std::vector<ValueNode * > * params,CallNode * call_node,ValueNode * dict_node)2233 bool GraphBuilder::UnpackDynamicLengthDictByBytecode(std::vector<ValueNode *> *params, CallNode *call_node,
2234 ValueNode *dict_node) {
2235 // user defined mappings, dynamic length dictionary unpack
2236 if (dict_node->GetVobj()->GetType() != AObject::kTypeDict) {
2237 return false;
2238 }
2239 auto dict = static_cast<AbstractDict *>(dict_node->GetVobj());
2240 if (!dict->IsElementValid()) {
2241 return false;
2242 }
2243 /**
2244 * must be guard this dict length
2245 */
2246 py::dict py_dict = dict->GetPyObject();
2247 py::tuple keys(py_dict.size());
2248 PyObject *key;
2249 PyObject *value;
2250 Py_ssize_t pos = 0;
2251 Py_ssize_t cnt = 0;
2252 while (PyDict_Next(py_dict.ptr(), &pos, &key, &value)) {
2253 PyObject *py_key = key;
2254 MS_EXCEPTION_IF_CHECK_FAIL(PyUnicode_CheckExact(py_key), "key must be string");
2255 PyObject *py_value = value;
2256 ValueNode *index = NewValueNode(AObject::Convert(py_key), LOAD_CONST, -1, {});
2257 ValueNode *val = NewValueNode(AObject::Convert(py_value), BINARY_SUBSCR, 0, {dict_node, index});
2258 keys[cnt++] = py_key;
2259 params->push_back(val);
2260 call_node->AddParam(val);
2261 }
2262 ValueNode *const_keys = NewValueNode(AObject::Convert(keys), LOAD_CONST, -1, {});
2263 params->push_back(const_keys);
2264 return true;
2265 }
2266
UnpackCallExDict(std::vector<ValueNode * > * params,CallNode * call_node)2267 bool GraphBuilder::UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node) {
2268 ValueNode *dict_node = params->back();
2269 params->clear();
2270 if (dict_node->GetOpcode() != BUILD_MAP) {
2271 return UnpackDynamicLengthDictByBytecode(params, call_node, dict_node);
2272 }
2273 if (dict_node->GetOparg() == 0) {
2274 return true;
2275 }
2276 py::tuple keys(dict_node->GetOparg());
2277 for (int i = 0; i < dict_node->GetOparg(); ++i) {
2278 AObject *k = dict_node->input(i * 2)->GetVobj();
2279 if (k->GetType() != AObject::kTypeString) {
2280 MS_LOG(DEBUG) << "for unpack-call, dict keys must be string";
2281 return false;
2282 }
2283 keys[i] = k->GetPyObject();
2284 params->push_back(dict_node->input((i << 1) + 1));
2285 MS_EXCEPTION_IF_CHECK_FAIL(keys[i].ptr(), "the keys of unpack-call must be a const string");
2286 }
2287 ValueNode *const_keys = this->NewValueNode(AObject::Convert(keys), LOAD_CONST, -1, {});
2288 params->push_back(const_keys);
2289 return true;
2290 }
2291
UnpackDynamicLengthTupleByBytecode(std::vector<ValueNode * > * params,ValueNode * args_node,CallNode * call_node)2292 bool GraphBuilder::UnpackDynamicLengthTupleByBytecode(std::vector<ValueNode *> *params, ValueNode *args_node,
2293 CallNode *call_node) {
2294 // user-defined sequence, dynamic length tuple unpack
2295 if (args_node->GetVobj() && args_node->GetVobj()->GetType() != AObject::kTypeTuple) {
2296 return false;
2297 }
2298 AbstractTuple *tuple = static_cast<AbstractTuple *>(args_node->GetVobj());
2299 if (!tuple->IsElementValid()) {
2300 return false;
2301 }
2302 /**
2303 * must be guard this tuple length
2304 */
2305 auto items = tuple->items();
2306 std::vector<ValueNode *> args;
2307 for (size_t i = 0; i < items.size(); i++) {
2308 ValueNode *idx_node = this->NewValueNode(AObject::Convert(py::int_(i)), LOAD_CONST, -1, {});
2309 auto value = this->NewValueNode(items[i], BINARY_SUBSCR, 0, {args_node, idx_node});
2310 args.push_back(value);
2311
2312 call_node->AddParam(value);
2313 }
2314 params->insert(params->begin(), args.begin(), args.end());
2315 return true;
2316 }
2317
2318 // unpack CALL_FUNCTION_EX parameters
2319 // should do this when bytecode analyze ? replace origin opcode
UnpackCallExParams(std::vector<ValueNode * > * params,int extra_local,bool * has_kw,CallNode * call_node)2320 bool GraphBuilder::UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw,
2321 CallNode *call_node) {
2322 bool has_dict = params->size() > 1;
2323 ValueNode *args_node = params->operator[](0);
2324 if (!has_dict) {
2325 params->clear();
2326 } else if (!UnpackCallExDict(params, call_node)) {
2327 return false;
2328 }
2329 *has_kw = params->size();
2330 if (args_node->GetOpcode() != BUILD_TUPLE) {
2331 return UnpackDynamicLengthTupleByBytecode(params, args_node, call_node);
2332 }
2333 params->insert(params->begin(), args_node->getInputs().begin(), args_node->getInputs().end());
2334 return true;
2335 }
2336
/**
 * Bind the keyword arguments of a call to the parameters of `func`.
 *
 * `params` must end with a LOAD_CONST tuple of keyword names, preceded by one value node per name.
 * Names are processed from last to first:
 * - a name matching a positional or keyword-only parameter stores its value in `frame` at that parameter's index;
 * - any other name appends a (name node, value node) pair to `kwvargs`, so pairs appear in reverse keyword order.
 *
 * After the loop, the names tuple and all keyword values are removed from `params`.
 *
 * NOTE(review): a keyword naming a parameter that was already filled positionally is overwritten here without a
 * "multiple values" error; confirm that callers check this.
 *
 * \return false in these cases:
 *         - the last param is not a constant tuple;
 *         - a positional-only parameter is named by keyword;
 *         - `func` has no **kwargs and some name does not match a parameter.
 */
bool GraphBuilder::PackKwParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame,
                                std::vector<ValueNode *> *kwvargs) {
  PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
  AObject *keys_info = params->back()->GetVobj();
  if (params->back()->GetOpcode() != LOAD_CONST || keys_info->GetType() != AObject::kTypeTuple) {
    return false;  // other case
  }

  const int posonlyargcount = GetCodePositionOnlyArgCount(co);
  PyObject **vars = &PyTuple_GET_ITEM(co->co_varnames, 0);
  const int argc = co->co_argcount + co->co_kwonlyargcount;
  PyObject **kwnames = &PyTuple_GET_ITEM(keys_info->GetPyObject().ptr(), 0);
  const int k_cnt = PyTuple_GET_SIZE(keys_info->GetPyObject().ptr());
  // kwnames must be string
  MS_ASSERT(static_cast<AbstractTuple *>(keys_info)->GetElementType() == AObject::kTypeString);
  MS_EXCEPTION_IF_CHECK_FAIL(SizeToInt(params->size()) > k_cnt, "check param");

  // number of keyword arguments bound to named parameters
  int kw_2_p_cnt = 0;

  // for each kw argument
  for (int i = k_cnt - 1; i >= 0; --i) {
    PyObject *key = kwnames[i];
    // find position and kwonly argument for key; pos == argc when the name is not a parameter
    int pos = std::find_if(vars, vars + argc, [&key](PyObject *k) { return !PyUnicode_Compare(key, k); }) - vars;
    if (pos < posonlyargcount) {
      MS_LOG(DEBUG) << "position only argument specified by key-word";
      return false;
    }

    // values sit just before the names tuple, in the same order as the names
    ValueNode *v = *(params->end() - 1 - k_cnt + i);
    // if key is position arg, store it
    if (pos < argc) {
      frame->SetLocal(pos, v);
      kw_2_p_cnt++;
      continue;
    }
    ValueNode *k = NewValueNode(AObject::Convert(key), LOAD_CONST, -1, {});

    kwvargs->push_back(k);
    kwvargs->push_back(v);
  }

  params->resize(params->size() - 1 - k_cnt);
  if (!(co->co_flags & CO_VARKEYWORDS)) {
    return kw_2_p_cnt == k_cnt;  // if not equal, too many key-word arguments
  }
  return true;
}
2385
HandleKWParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)2386 bool GraphBuilder::HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) {
2387 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2388 std::vector<ValueNode *> kwvargs;
2389 if (!PackKwParams(func, params, frame, &kwvargs)) {
2390 // illegal arguments
2391 return false;
2392 }
2393
2394 const int argc = co->co_argcount + co->co_kwonlyargcount;
2395 if (!(co->co_flags & CO_VARKEYWORDS)) {
2396 // kw_2_p_cnt == k_cnt, all kw arguments is positions arguments
2397 return true;
2398 }
2399
2400 int kwvarg_loc = argc + ((co->co_flags & CO_VARARGS) ? 1 : 0);
2401 AObject *dict = AObject::BuildOperations(CollectObjects(kwvargs), BUILD_MAP);
2402 frame->SetLocal(kwvarg_loc, NewValueNode(dict, BUILD_MAP, kwvargs.size() / 2, kwvargs));
2403
2404 static_cast<CallNode *>(seek(0))->AddParam(frame->Local(kwvarg_loc));
2405 return true;
2406 }
2407
CheckAndSetDefaultParams(const py::object & func,FrameStates * frame,int position_argc)2408 bool GraphBuilder::CheckAndSetDefaultParams(const py::object &func, FrameStates *frame, int position_argc) {
2409 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2410 PyObject *defs = PyFunction_GET_DEFAULTS(func.ptr());
2411 PyObject *kwdefs = PyFunction_GET_KW_DEFAULTS(func.ptr());
2412
2413 const int argc = co->co_argcount + co->co_kwonlyargcount;
2414 PyObject *vars = co->co_varnames;
2415
2416 int defs_off = defs ? co->co_argcount - PyTuple_GET_SIZE(defs) : INT_MAX;
2417 for (int i = position_argc; i < argc; ++i) {
2418 if (frame->Local(i) != &ValueNode::kUnboundLocal) {
2419 continue;
2420 }
2421 PyObject *val;
2422 if (i < co->co_argcount) {
2423 val = i < defs_off ? nullptr : PyTuple_GET_ITEM(defs, i - defs_off);
2424 } else {
2425 val = kwdefs == nullptr ? nullptr : PyDict_GetItem(kwdefs, PyTuple_GET_ITEM(vars, i));
2426 }
2427 if (val == nullptr) {
2428 MS_LOG(DEBUG) << "no " << (i < defs_off ? "" : "kw-") << "default parameter error";
2429 return false;
2430 }
2431 auto vo = AObject::Convert(val);
2432 ValueNode *c = NewValueNode(vo, LOAD_CONST, -1, {});
2433 frame->SetLocal(i, c);
2434 }
2435 return true;
2436 }
2437
GetBoundSelf(CallNode * call_node)2438 ValueNode *GetBoundSelf(CallNode *call_node) {
2439 ValueNode *func_val = call_node->input(0);
2440 AObject *vo = func_val->GetVobj();
2441 Graph *graph = call_node->GetGraph();
2442
2443 ValueNode *self = nullptr;
2444 switch (vo->GetType()) {
2445 case AObject::kTypeBoundMethod: {
2446 self = GetSelfFromMethod(func_val);
2447 if (self == nullptr) {
2448 AObject *tmp = func_val->get_attr(GraphBuilder::ID___self__);
2449 ValueNode *node = graph->NewValueNode(tmp, LOAD_ATTR, -1, {func_val}, GraphBuilder::ID___self__);
2450 node->SetGraph(call_node->GetGraph());
2451 call_node->AddParam(node);
2452 self = node;
2453 }
2454 break;
2455 }
2456 case AObject::kTypeCell: /* fallthrough */
2457 case AObject::kTypeAnyValue:
2458 self = func_val;
2459 break;
2460 case AObject::kTypeCFunction:
2461 case AObject::kTypeFunction:
2462 break;
2463 default:
2464 MS_LOG(INTERNAL_EXCEPTION) << "unimplemented type " << vo->ToString();
2465 break;
2466 }
2467 return self;
2468 }
2469
HandlePositionParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)2470 bool GraphBuilder::HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) {
2471 CallNode *call_node = reinterpret_cast<CallNode *>(seek(0));
2472 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
2473 auto vobj = trace_flag() ? AObject::Convert(func.ptr()) : call_node->input(0)->GetVobj();
2474 AObject::Type callable_type = vobj->GetType();
2475
2476 ValueNode *self = GetBoundSelf(call_node);
2477 if (self != nullptr) {
2478 params->insert(params->begin(), self);
2479 }
2480
2481 const int argc = co->co_argcount;
2482 const int has_varg = (co->co_flags & CO_VARARGS) ? 1 : 0;
2483 const int has_kwvarg = (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
2484 const int varg_loc = argc + co->co_kwonlyargcount;
2485 const int kwvarg_loc = argc + co->co_kwonlyargcount + has_varg;
2486 int pargc = params->size();
2487 if (pargc > argc && !has_varg) {
2488 MS_LOG(DEBUG) << "too many parameters";
2489 return false;
2490 }
2491 bool append_self_to_varg = has_varg && self && callable_type == AObject::kTypeBoundMethod && argc == 0;
2492 if (append_self_to_varg) { // self is in variable arguments
2493 MS_LOG(INFO) << "not implement append self to variable arguments, inline failed";
2494 return false;
2495 }
2496
2497 if (has_kwvarg && frame->Local(kwvarg_loc) == &ValueNode::kUnboundLocal) {
2498 auto vo = AObject::Convert(py::dict());
2499 auto m = NewValueNode(vo, BUILD_MAP, 0, {});
2500 call_node->AddParam(m);
2501 frame->SetLocal(kwvarg_loc, m);
2502 }
2503
2504 if (has_varg) {
2505 int vargc = pargc > argc ? pargc - argc : 0;
2506 std::vector<ValueNode *> vargs(params->end() - vargc, params->end());
2507 params->resize(params->size() - vargc);
2508
2509 auto vo = AObject::BuildOperations(CollectObjects(vargs), BUILD_TUPLE);
2510 ValueNode *build_tuple = NewValueNode(vo, BUILD_TUPLE, vargc, vargs);
2511 call_node->AddParam(build_tuple);
2512 frame->SetLocal(varg_loc, build_tuple);
2513 }
2514
2515 pargc = params->size();
2516 for (int i = pargc - 1; i >= 0; --i) {
2517 if (frame->Local(i) != &ValueNode::kUnboundLocal) {
2518 MS_LOG(DEBUG) << "duplicate key-word parameter error";
2519 return false;
2520 }
2521 frame->SetLocal(i, params->back());
2522 params->pop_back();
2523 }
2524 return CheckAndSetDefaultParams(func, frame, pargc);
2525 }
2526
HandleCallParameters(const py::object & func_info,CallNode * call_node,FrameStates * frame)2527 bool GraphBuilder::HandleCallParameters(const py::object &func_info, CallNode *call_node, FrameStates *frame) {
2528 if (func_info.ptr() == nullptr) {
2529 MS_LOG(EXCEPTION) << "HandleCallParameters with empty func_info input.";
2530 }
2531 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func_info.ptr()));
2532 frame->ResizeLocal(co->co_nlocals);
2533
2534 std::vector<ValueNode *> params(call_node->getInputs().begin() + 1, call_node->getInputs().end());
2535 int op = call_node->GetOpcode();
2536 bool has_kw = (op == CALL_FUNCTION_KW);
2537 if (op == CALL_FUNCTION_EX && !UnpackCallExParams(¶ms, co->co_nlocals, &has_kw, call_node)) {
2538 return false; // ex_dict infer failed or user-defined sequence and map arguments
2539 }
2540 if (has_kw && !HandleKWParams(func_info, ¶ms, frame)) {
2541 return false;
2542 }
2543 if (!HandlePositionParams(func_info, ¶ms, frame)) {
2544 return false;
2545 }
2546
2547 MS_EXCEPTION_IF_CHECK_FAIL(params.size() == 0, "check parameters handle");
2548
2549 // after store all params
2550 // cell2arg
2551 const Py_ssize_t ncells = PyTuple_GET_SIZE(co->co_cellvars);
2552 const Py_ssize_t *c2a_arr = co->co_cell2arg;
2553 for (int i = 0; c2a_arr != nullptr && i < ncells; ++i) {
2554 if (c2a_arr[i] != CO_CELL_NOT_AN_ARG) {
2555 Py_ssize_t arg_index = c2a_arr[i];
2556 CellVarNode *cell_node = frame->Closure(i);
2557 ValueNode *arg_node = frame->Local(arg_index);
2558 /**
2559 * here not delete the local, continue with local same as closure
2560 * frame->SetLocal(arg_index, &ValueNode::kUnboundLocal);
2561 */
2562
2563 PyObject *cell = cell_node->GetVobj()->GetPyObject().ptr();
2564 PyObject *cell_contents = arg_node->GetVobj() ? arg_node->GetVobj()->GetPyObject().inc_ref().ptr() : nullptr;
2565 MS_EXCEPTION_IF_CHECK_FAIL(cell && PyCell_Check(cell) && PyCell_GET(cell) == nullptr, "must be a empty closure");
2566
2567 ValueNode *n = NewValueNode(nullptr, STORE_DEREF, i, {arg_node});
2568
2569 cell_node->AddCellOper(n);
2570 cell_node->SetValue(arg_node);
2571 Py_XSETREF(PyCell_GET(cell), cell_contents);
2572 // cell variable is eliminate
2573 // call_node->AddParam(n);
2574 }
2575 }
2576 return true;
2577 }
2578
2579 static void SetGradFuncInfo(mindspore::pijit::CallNode *call_node);
2580
ResolveCallable(CallNode * call_node,StopTraceReason * stop_reason)2581 py::object GraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
2582 AObject *callable = call_node->input(0)->GetVobj();
2583 py::object callable_info;
2584 *stop_reason = StopTraceReason::kStopTraceInfer_Fail;
2585 call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
2586 if (!callable) {
2587 return callable_info;
2588 }
2589 callable_info = callable->GetPyObject();
2590 if (callable_info.ptr() == nullptr) {
2591 callable_info = py::cast<py::object>(reinterpret_cast<PyObject *>(callable->GetTypeObject()));
2592 }
2593
2594 AObject::Type callable_type = callable->GetType();
2595 if (callable_info.ptr() == nullptr) {
2596 if (callable->TestMsFlag(AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc)) {
2597 SetGradFuncInfo(call_node);
2598 *stop_reason = StopTraceReason::kNonStopTrace;
2599 }
2600 return py::object();
2601 }
2602
2603 *stop_reason = StopTraceReason::kNonStopTrace;
2604 if (callable_type == AObject::kTypeType) {
2605 call_node->SetInlineReason(InlineReason::kInlineFunc_ArgType_IsClass);
2606 HandleCallClass(call_node);
2607 if (static_cast<AbstractType *>(callable)->GetTypeType() == AObject::kTypeCell) {
2608 *stop_reason = StopTraceReason::kStopTraceInfer_Fail;
2609 }
2610 return py::object();
2611 }
2612
2613 if (WhiteListFuncCheckAndInfer(call_node, callable_info)) {
2614 if (call_node->GetInlineReason() == InlineReason::kInlineFunc_Type_Unsupported) {
2615 *stop_reason = StopTraceReason::kStopTraceFunc_Type_Unsupported;
2616 }
2617 return py::object();
2618 }
2619
2620 // find code object
2621 callable_info = GetFuncInfo(call_node->input(0));
2622 if (callable_info.ptr() == nullptr) {
2623 *stop_reason = StopTraceReason::kStopTraceFunc_Type_Unsupported;
2624 call_node->SetInlineReason(InlineReason::kInlineCFunction_Unsupported);
2625 }
2626 return callable_info;
2627 }
2628
ResolveClosure(const py::object & func_info,ValueNode * callable_node,FrameStates * frame)2629 void GraphBuilder::ResolveClosure(const py::object &func_info, ValueNode *callable_node, FrameStates *frame) {
2630 if (func_info.ptr() == nullptr) {
2631 MS_LOG(INTERNAL_EXCEPTION) << "When resolving closure, get func_info failed.";
2632 }
2633 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func_info.ptr()));
2634 PyObject *closure = PyFunction_GET_CLOSURE(func_info.ptr());
2635
2636 int ncells = PyTuple_GET_SIZE(co->co_cellvars);
2637 int nfrees = PyTuple_GET_SIZE(co->co_freevars);
2638 frame->ResizeClosure(ncells + nfrees);
2639 for (int i = 0; i < ncells; i++) {
2640 CellVarNode *n = graph_->allocator().NewNode<CellVarNode>(CellVarNode::CellVar);
2641 n->SetVobj(AObject::Convert(py::reinterpret_steal<py::object>(PyCell_New(nullptr))));
2642 frame->SetClosure(i, n);
2643 }
2644 // track free variable
2645 bool make_func = callable_node->GetOpcode() == MAKE_FUNCTION;
2646 for (int i = 0; i < nfrees; ++i) {
2647 CellVarNode *freevar = nullptr;
2648 if (make_func) {
2649 ValueNode *tuple = *(callable_node->getInputs().end() - 3);
2650 MS_EXCEPTION_IF_CHECK_FAIL(tuple->GetOpcode() == BUILD_TUPLE, "unknown closure source");
2651 freevar = reinterpret_cast<CellVarNode *>(tuple->input(i));
2652 } else if (closure) {
2653 auto v = PyTuple_GET_ITEM(closure, i);
2654 freevar = graph_->allocator().NewNode<CellVarNode>(CellVarNode::FreeVar);
2655 freevar->SetVobj(AObject::Convert(v));
2656
2657 // if inline, guard free variable
2658 ValueNode *param = NewValueNode(AObject::Convert(PyCell_GET(v)), LOAD_DEREF, -1);
2659 param->SetGraph(graph_);
2660 freevar->SetValue(param);
2661 } else {
2662 MS_LOG(EXCEPTION) << "error no closure";
2663 }
2664 frame->SetClosure(ncells + i, freevar);
2665 }
2666 }
2667
SetMixedPrecisionType(CallNode * call_node,FrameStates * frame)2668 void SetMixedPrecisionType(CallNode *call_node, FrameStates *frame) {
2669 auto func_node = call_node->input(0);
2670 if (func_node->GetVobj() && func_node->GetVobj()->GetType() == AbstractObjectBase::kTypeCell) {
2671 auto cell = py::cast<CellPtr>(func_node->GetVobj()->GetPyObject());
2672 auto mixed_type = cell->GetMixedPrecisionType();
2673 if (mixed_type != MixedPrecisionType::kNotSet) {
2674 for (size_t i = 0; i < frame->GetLocals().size(); i++) {
2675 if (frame->Local(i)->GetType() == AbstractNode::Param) {
2676 auto paramNode = reinterpret_cast<ParamNode *>(frame->Local(i));
2677 if (paramNode->GetVobj()->GetType() == AObject::kTypeTensor &&
2678 !paramNode->GetVobj()->GetPyObject().attr("dtype").is_none()) {
2679 auto src_dtype = paramNode->GetVobj()->GetPyObject().attr("dtype");
2680 bool is_cast = false;
2681 if (py::isinstance<Float>(src_dtype)) {
2682 auto float_nbits = py::cast<Float>(src_dtype).nbits();
2683 if (float_nbits == 64 || (float_nbits == 32 && mixed_type != kFP32) ||
2684 (float_nbits == 16 && mixed_type != kFP16)) {
2685 is_cast = true;
2686 }
2687 }
2688 if (py::isinstance<BFloat>(src_dtype) && mixed_type != kBF16) {
2689 is_cast = true;
2690 }
2691 if (!is_cast) {
2692 continue;
2693 }
2694 auto dst_dtype = Utils::MixedPrecisionTypeToDType(mixed_type);
2695 paramNode->SetMixedPrecisionType(dst_dtype);
2696 }
2697 }
2698 }
2699 }
2700 }
2701 }
2702
HandleCall(int depth)2703 StopTraceReason GraphBuilder::HandleCall(int depth) {
2704 MS_EXCEPTION_IF_CHECK_FAIL(seek(0)->GetType() == ValueNode::Call, "must be call node");
2705 CallNode *call_node = reinterpret_cast<CallNode *>(seek(0));
2706 if (depth > root_->graph_->Config().getIntConfig(GraphJitConfig::kMaxInlineDepth)) {
2707 call_node->SetInlineReason(InlineReason::kInlineTooDeep);
2708 return StopTraceReason::kNonStopTrace;
2709 }
2710 StopTraceReason stop_reason = StopTraceReason::kNonStopTrace;
2711
2712 py::object callable_info = ResolveCallable(call_node, &stop_reason);
2713 if (callable_info.ptr() == nullptr) {
2714 return stop_reason;
2715 }
2716 MS_EXCEPTION_IF_CHECK_FAIL(PyFunction_Check(callable_info.ptr()), "'ResolveCallable' must be return a function");
2717
2718 // unsupported check
2719 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(callable_info.ptr()));
2720 PyObject *globals = PyFunction_GET_GLOBALS(callable_info.ptr());
2721 auto subgraph = GraphBuilder::Creator(this->root_ ? this->root_ : this, this, co, globals, trace_flag());
2722
2723 // frame build
2724 FrameStates *frame = &(subgraph->frame_);
2725 ResolveClosure(callable_info, call_node->input(0), frame);
2726 if (!HandleCallParameters(callable_info, call_node, frame)) {
2727 call_node->SetInlineReason(InlineReason::kInlineFunc_ArgHandle_Unsupported);
2728 return StopTraceReason::kStopTraceFunc_ArgHandle_Unsupported;
2729 }
2730
2731 SetMixedPrecisionType(call_node, frame);
2732 // build sub-graph
2733 stop_reason = BuildSubGraph(call_node, depth, callable_info, subgraph);
2734 CollectInlineInfo(call_node, depth);
2735
2736 if (!trace_flag() && call_node->GetSubGraph() && call_node->GetInlineReason() == InlineReason::kInline) {
2737 MS_EXCEPTION_IF_NULL(call_node->GetSubGraph()->GetRetVal());
2738 seek(0) = call_node->GetSubGraph()->GetRetVal();
2739 }
2740 return stop_reason;
2741 }
2742
GuardLoopSequence(Graph * graph,ValueNode * seq_node,Py_ssize_t seq_size)2743 static bool GuardLoopSequence(Graph *graph, ValueNode *seq_node, Py_ssize_t seq_size) {
2744 // guard length
2745 PyObject *seq = seq_node->GetVobj()->GetPyObject().ptr();
2746 if (seq != nullptr && seq_size == -1) {
2747 seq_size = PySequence_Size(seq);
2748 }
2749 if (seq == nullptr || seq_size == -1) {
2750 PyErr_Clear();
2751 return false;
2752 }
2753 if (!graph->GuardSequenceNodeLength(seq_node, seq_size)) {
2754 return false;
2755 }
2756 if (!graph->GuardType(seq_node)) {
2757 return false;
2758 }
2759 return true;
2760 }
2761
GuardIterInputs(Graph * graph,ValueNode * seq_node,Py_ssize_t seq_size=-1)2762 bool GuardIterInputs(Graph *graph, ValueNode *seq_node, Py_ssize_t seq_size = -1) {
2763 PyObject *seq = seq_node->GetVobj()->GetPyObject().ptr();
2764 if (seq != nullptr && seq_size == -1) {
2765 seq_size = PySequence_Size(seq);
2766 }
2767 if (seq == nullptr || seq_size == -1) {
2768 PyErr_Clear();
2769 return false;
2770 }
2771 if (!graph->GuardSequenceNodeLength(seq_node, seq_size)) {
2772 return false;
2773 }
2774 auto input_nodes = seq_node->getInputs();
2775 for (size_t i = 1; i < input_nodes.size(); ++i) {
2776 ValueNode *input_node = input_nodes[i];
2777 if (input_node == nullptr) {
2778 return false;
2779 }
2780 TracePtr tr = graph->TraceValueNode(input_node);
2781 if (!(graph->GetGuard()->GetGuard()->GuardOn(tr, GuardLevel::GEqual))) {
2782 MS_LOG(INFO) << "Iterator guard fail: " << seq_node->ToString();
2783 return false;
2784 }
2785 }
2786 MS_LOG(INFO) << "Iterator guard success: " << seq_node->ToString();
2787 return true;
2788 }
2789
TraceRunForIterSequence(int jump_bci,bool is_range_type)2790 bool GraphBuilder::TraceRunForIterSequence(int jump_bci, bool is_range_type) {
2791 // check for iter
2792 ValueNode *iter_node = seek(0);
2793 ValueNode *seq_node = iter_node->input(0);
2794 PyObject *seq = seq_node->GetVobj()->GetPyObject().ptr();
2795 if (seq == nullptr) {
2796 return false; // infer failed
2797 }
2798 Py_ssize_t size = PySequence_Size(seq);
2799 if (size == -1) {
2800 PyErr_Clear();
2801 MS_LOG(DEBUG) << "FOR_ITER without __len__";
2802 return false;
2803 }
2804
2805 int &index = iter_node->marker_;
2806 if (index == 0 && ((is_range_type && !GuardIterInputs(graph_, seq_node)) ||
2807 (!is_range_type && !GuardLoopSequence(graph_, seq_node)))) {
2808 // loop start.
2809 return false;
2810 }
2811
2812 if (index >= size) {
2813 pop();
2814 cur_bci_ = jump_bci;
2815 return true;
2816 }
2817
2818 PyObject *item = PySequence_GetItem(seq, index);
2819 if (item == nullptr) {
2820 MS_LOG(ERROR) << "trace for iter got an error " << py::error_already_set().what();
2821 PyErr_Clear();
2822 return false;
2823 }
2824
2825 py::object index_object = py::int_(index);
2826 ValueNode *index_node = NewValueNode(AObject::Convert(index_object), LOAD_CONST, -1, {});
2827 push(seq_node);
2828 push(index_node);
2829 DoItemAccess({BINARY_SUBSCR, 0});
2830 ValueNode *item_node = pop();
2831 Py_DECREF(item);
2832
2833 index++;
2834 push(item_node);
2835 cur_bci_ = cur_bci_ + 1;
2836 return true;
2837 }
2838
CheckForIterEnumerate(ValueNode * iter_node)2839 static bool CheckForIterEnumerate(ValueNode *iter_node) {
2840 ValueNode *enumerate_node = iter_node->input(0);
2841 if (enumerate_node->GetOpcode() != CALL_FUNCTION || iter_node->bci() - 1 != enumerate_node->bci()) {
2842 // enumerate object maybe alive, shouldn't reduce it
2843 return false;
2844 }
2845 PyObject *enumerate = enumerate_node->GetVobj()->GetPyObject().ptr();
2846 if (enumerate == nullptr) {
2847 return false;
2848 }
2849
2850 MS_EXCEPTION_IF_NULL(iter_node->GetGraph());
2851
2852 ValueNode *iterable_node = enumerate_node->input(1);
2853 PyObject *iterable = iterable_node->GetVobj()->GetPyObject().ptr();
2854 if (iterable == nullptr || !PySequence_Check(iterable) || !GuardLoopSequence(iter_node->GetGraph(), iterable_node)) {
2855 // just support sequence iteration
2856 return false;
2857 }
2858 return true;
2859 }
2860
TraceRunForIterEnumerate(int jump_bci)2861 bool GraphBuilder::TraceRunForIterEnumerate(int jump_bci) {
2862 ValueNode *iter_node = seek(0);
2863 if (iter_node->marker_ == 0) {
2864 if (!CheckForIterEnumerate(iter_node)) {
2865 return false;
2866 }
2867 iter_node->marker_ = 1;
2868 }
2869 ValueNode *enumerate_node = iter_node->input(0);
2870 PyObject *enumerate = enumerate_node->GetVobj()->GetPyObject().ptr();
2871 ValueNode *iterable_node = enumerate_node->input(1);
2872
2873 // reduce iterable object
2874 ValueNode *seq_node = iterable_node;
2875 PyObject *tuple = PyIter_Next(enumerate);
2876 if (tuple == nullptr) {
2877 if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_StopIteration)) {
2878 MS_LOG(ERROR) << "trace FOR_ITER got an error " << py::error_already_set().what();
2879 PyErr_Clear();
2880 return false;
2881 }
2882 PyErr_Clear();
2883 pop();
2884 cur_bci_ = jump_bci;
2885 return true;
2886 }
2887 PyObject *index = PyTuple_GET_ITEM(tuple, 0);
2888 PyObject *item = PyTuple_GET_ITEM(tuple, 1);
2889 ValueNode *index_node = NewValueNode(AObject::Convert(index), LOAD_CONST, -1, {});
2890 ValueNode *item_node = NewValueNode(AObject::Convert(item), BINARY_SUBSCR, 0, {seq_node, index_node});
2891 ValueNode *value_node = NewValueNode(AObject::Convert(tuple), BUILD_TUPLE, 2, {index_node, item_node});
2892 Py_DECREF(tuple);
2893 graph_->GetTracedNodes().push_back(item_node);
2894 graph_->GetTracedNodes().push_back(value_node);
2895
2896 push(value_node);
2897 cur_bci_ = cur_bci_ + 1;
2898 return true;
2899 }
2900
CheckForIterZip(ValueNode * iter_node)2901 static bool CheckForIterZip(ValueNode *iter_node) {
2902 ValueNode *zip_node = iter_node->input(0);
2903 if (zip_node->GetOpcode() != CALL_FUNCTION || iter_node->bci() - 1 != zip_node->bci()) {
2904 return false;
2905 }
2906 PyObject *zip = zip_node->GetVobj()->GetPyObject().ptr();
2907 if (zip == nullptr) {
2908 return false;
2909 }
2910 MS_EXCEPTION_IF_NULL(iter_node->GetGraph());
2911 Graph *graph = iter_node->GetGraph();
2912
2913 std::vector<ValueNode *> iterable_nodes = {zip_node->getInputs().begin() + 1, zip_node->getInputs().end()};
2914 auto iter = std::find_if(iterable_nodes.begin(), iterable_nodes.end(), [&graph](ValueNode *iterable_node) {
2915 PyObject *iterable = iterable_node->GetVobj()->GetPyObject().ptr();
2916 return iterable == nullptr || !PySequence_Check(iterable) || !GuardLoopSequence(graph, iterable_node);
2917 });
2918 if (iter != iterable_nodes.end()) {
2919 return false;
2920 }
2921 return true;
2922 }
2923
TraceRunForIterZip(int jump_bci)2924 bool GraphBuilder::TraceRunForIterZip(int jump_bci) {
2925 ValueNode *iter_node = seek(0);
2926 int *index = &iter_node->marker_;
2927 if ((*index) == 0) {
2928 if (!CheckForIterZip(iter_node)) {
2929 return false;
2930 }
2931 }
2932
2933 ValueNode *zip_node = iter_node->input(0);
2934 PyObject *zip = zip_node->GetVobj()->GetPyObject().ptr();
2935 std::vector<ValueNode *> iterable_nodes = {zip_node->getInputs().begin() + 1, zip_node->getInputs().end()};
2936
2937 // reduce iterable object
2938 PyObject *tuple = PyIter_Next(zip);
2939 py::object handle = py::reinterpret_steal<py::object>(tuple);
2940 if (handle.ptr() == nullptr) {
2941 if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_StopIteration)) {
2942 MS_LOG(ERROR) << "trace FOR_ITER got an error " << py::error_already_set().what();
2943 PyErr_Clear();
2944 return false;
2945 }
2946 PyErr_Clear();
2947 pop();
2948 cur_bci_ = jump_bci;
2949 return true;
2950 }
2951
2952 std::vector<ValueNode *> inputs;
2953 for (size_t tuple_index = 0; tuple_index < iterable_nodes.size(); ++tuple_index) {
2954 PyObject *item = PyTuple_GET_ITEM(tuple, tuple_index);
2955 ValueNode *seq_node = iterable_nodes[tuple_index];
2956 ValueNode *index_node = NewValueNode(AObject::Convert(py::int_(*index)), LOAD_CONST, -1, {});
2957 ValueNode *item_node = NewValueNode(AObject::Convert(item), BINARY_SUBSCR, 0, {seq_node, index_node});
2958 inputs.push_back(item_node);
2959 graph_->GetTracedNodes().push_back(item_node);
2960 }
2961 ValueNode *value_node = NewValueNode(AObject::Convert(tuple), BUILD_TUPLE, inputs.size(), inputs);
2962 graph_->GetTracedNodes().push_back(value_node);
2963 push(value_node);
2964
2965 (*index)++;
2966 cur_bci_ = cur_bci_ + 1;
2967 return true;
2968 }
2969
IsRangeType(ValueNode * iter_node)2970 bool IsRangeType(ValueNode *iter_node) {
2971 if (iter_node->input(0)->GetOpcode() != CALL_FUNCTION) {
2972 return false;
2973 }
2974 auto vobj = iter_node->input(0)->input(0)->GetVobj();
2975 if (vobj == nullptr) {
2976 return false;
2977 }
2978 PyTypeObject *type = reinterpret_cast<PyTypeObject *>(static_cast<AbstractType *>(vobj)->GetPyObject().ptr());
2979 return type == &PyRange_Type;
2980 }
2981
TraceRunForIter(const Instr & instr)2982 bool GraphBuilder::TraceRunForIter(const Instr &instr) {
2983 MS_EXCEPTION_IF_NULL(instr.extra_jump());
2984
2985 // check for iter
2986 ValueNode *iter_node = seek(0);
2987 AObject *iterable = iter_node->getInputs().size() > 0 ? iter_node->input(0)->GetVobj() : nullptr;
2988 bool succ;
2989 if (iter_node->GetOpcode() != GET_ITER) {
2990 MS_LOG(DEBUG) << "FOR_ITER without GET_ITER";
2991 succ = false;
2992 } else if (iterable == nullptr) {
2993 succ = false;
2994 } else if (iterable->GetTypeObject() == &PyEnum_Type) {
2995 succ = TraceRunForIterEnumerate(instr.extra_jump()->bci());
2996 } else if (iterable->GetTypeObject() == &PyZip_Type) {
2997 succ = TraceRunForIterZip(instr.extra_jump()->bci());
2998 } else {
2999 succ = TraceRunForIterSequence(instr.extra_jump()->bci(), IsRangeType(iter_node));
3000 }
3001 if (!succ) {
3002 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceLoop_Unsupported);
3003 }
3004 return succ;
3005 }
3006
IsConstantBoolValue(ValueNode * node)3007 static bool IsConstantBoolValue(ValueNode *node) {
3008 const auto &cnst_info = node->GetConstantInfo();
3009 if (cnst_info == nullptr) {
3010 return false;
3011 }
3012 if (cnst_info->value().ptr() != nullptr) {
3013 return true;
3014 }
3015 PyTypeObject *tp = cnst_info->type();
3016 if (tp == nullptr) {
3017 return false;
3018 }
3019 static const std::set<PyTypeObject *> len_to_bool = {&PyTuple_Type, &PyList_Type, &PyDict_Type};
3020 if (len_to_bool.find(tp) != len_to_bool.end() && cnst_info->len() != -1) {
3021 return true;
3022 }
3023 return false;
3024 }
3025
IsShapeOrDtypeRelatedNode(const ValueNode * node)3026 bool IsShapeOrDtypeRelatedNode(const ValueNode *node) {
3027 if (node->GetOpcode() == CALL_FUNCTION && node->input(0)->GetVobj()->GetType() == AObject ::kTypeCFunction &&
3028 node->input(0)->GetName() == "len") {
3029 node = node->input(1);
3030 }
3031 if (node->GetOpcode() == BINARY_SUBSCR) {
3032 node = node->input(0);
3033 }
3034 if (node->GetOpcode() == CALL_FUNCTION) {
3035 auto func_node = node->input(0);
3036 // prim
3037 if (py::isinstance<mindspore::PrimitivePyAdapter>(func_node->GetVobj()->GetPyObject())) {
3038 auto prime_name = py::cast<mindspore::PrimitivePyAdapterPtr>(func_node->GetVobj()->GetPyObject())->name();
3039 if (prime_name == "Shape" || prime_name == "DType" || prime_name == "Rank") {
3040 return true;
3041 }
3042 }
3043 } else if (node->GetOpcode() == LOAD_ATTR) {
3044 auto attr_name = node->GetName();
3045 if (attr_name == "dtype" || attr_name == "shape" || attr_name == "ndim" || attr_name == "size") {
3046 return true;
3047 }
3048 }
3049 return false;
3050 }
3051
TryGuardEscape(ValueNode * cond_node)3052 bool TryGuardEscape(ValueNode *cond_node) {
3053 if (cond_node->GetOpcode() == COMPARE_OP &&
3054 std::any_of(cond_node->getInputs().begin(), cond_node->getInputs().end(),
3055 [](const ValueNode *node) { return IsShapeOrDtypeRelatedNode(node); })) {
3056 return true;
3057 }
3058 if (cond_node->GetOpcode() == COMPARE_OP && cond_node->getInputs().size() == 2 &&
3059 cond_node->input(0)->GetOpcode() == BINARY_SUBSCR && cond_node->input(1)->GetOpcode() == BINARY_SUBSCR) {
3060 return true;
3061 }
3062 if (cond_node->GetOpcode() == CONTAINS_OP &&
3063 (IsShapeOrDtypeRelatedNode(cond_node->input(0)) || IsShapeOrDtypeRelatedNode(cond_node->input(1)))) {
3064 return true;
3065 }
3066 if (cond_node->GetOpcode() == CALL_FUNCTION &&
3067 std::all_of(cond_node->getInputs().begin() + 1, cond_node->getInputs().end(),
3068 [](const ValueNode *node) { return IsShapeOrDtypeRelatedNode(node); })) {
3069 return true;
3070 }
3071 return false;
3072 }
3073
IsSatisfyPruneLimit(int cond,Graph * graph_,ValueNode * cond_node)3074 bool IsSatisfyPruneLimit(int cond, Graph *graph_, ValueNode *cond_node) {
3075 if (cond == -1) {
3076 return false;
3077 }
3078 int limit_prune = graph_->Config().getIntConfig(GraphJitConfig::kMaxPruneCase);
3079 if (limit_prune >= 0 && limit_prune < graph_->GetPruneBranchCount()) {
3080 return false;
3081 }
3082 if (IsConstantBoolValue(cond_node)) {
3083 return true;
3084 }
3085 auto tr = graph_->TraceValueNode(cond_node);
3086 if (tr == nullptr) {
3087 if (graph_->Config().getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0) {
3088 PyObject *bool_value = cond_node->GetVobj()->GetPyObject().ptr();
3089 if ((bool_value == Py_True || bool_value == Py_False) && TryGuardEscape(cond_node)) {
3090 return true;
3091 }
3092 }
3093 return false;
3094 }
3095 PyObject *bool_value = cond_node->GetVobj()->GetPyObject().ptr();
3096 if (bool_value != Py_True && bool_value != Py_False) {
3097 bool strict = graph_->Config().GetBoolConfig(GraphJitConfig::kStrictTrace);
3098 auto bool_type = CreateOpTrace(reinterpret_cast<PyObject *>(&PyBool_Type), LOAD_CONST, -1, {}, "", "", strict);
3099 tr = CreateOpTrace(cond ? Py_True : Py_False, CALL_FUNCTION, 1, {bool_type, tr}, "", "", strict);
3100 } else {
3101 cond_node->SetConstantValue(true);
3102 }
3103 graph_->GetGuard()->GetGuard()->GuardOn(tr, GuardLevel::GId);
3104 return true;
3105 }
3106
LogPrunBranch(ValueNode * cond,const Instr & instr,const GraphJitConfig & conf)3107 static void LogPrunBranch(ValueNode *cond, const Instr &instr, const GraphJitConfig &conf) {
3108 MS_LOG(DEBUG) << "trace run prune branch failed [" << cond->ToString() << "]";
3109 if (conf.GetBoolConfig(GraphJitConfig::kPrintGuard)) {
3110 GRAPH_JIT_LOG_F("Fail to prune bytecode [%s]!\n", instr.ToString().c_str());
3111 } else {
3112 MS_LOG(DEBUG) << "Fail to prune bytecode [" << instr.ToString() << "]!\n";
3113 }
3114
3115 if (conf.GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
3116 if (CondIsTrue(cond) == -1) {
3117 GRAPH_JIT_LOG_F("infer failed\n");
3118 } else {
3119 auto tr = GetTrace(cond, false, true, 0, conf.getIntConfig(GraphJitConfig::kMaxTraceDepth));
3120 std::map<Trace *, size_t> cache;
3121 GRAPH_JIT_LOG_F("trace:\n%s\n", tr ? tr->FormatString(&cache).c_str() : "trace failed");
3122 }
3123 if (cond->GetGraph() == nullptr || cond->GetGraph()->GetCodeObj() == nullptr) {
3124 return;
3125 }
3126 GRAPH_JIT_LOG_F("if branch prune failed, condition [%s] at [%U : %d]", cond->ToString().c_str(),
3127 cond->GetGraph()->GetCodeObj()->co_filename, cond->GetLineNo());
3128 }
3129 }
3130
TraceRunControl(const Instr & instr)3131 bool GraphBuilder::TraceRunControl(const Instr &instr) {
3132 MS_EXCEPTION_IF_NULL(instr.extra_jump());
3133 Opcode opcode(instr.op());
3134 ValueNode *cond_node = nullptr;
3135 int cond = -1;
3136 int jump_to = -1;
3137 if (opcode == JUMP_FORWARD || opcode == JUMP_ABSOLUTE) {
3138 cur_bci_ = instr.extra_jump()->bci();
3139 return true;
3140 } else if (opcode == FOR_ITER) {
3141 return TraceRunForIter(instr);
3142 } else if (opcode == POP_JUMP_IF_FALSE || opcode == POP_JUMP_IF_TRUE) {
3143 cond_node = pop();
3144 cond = CondIsTrue(cond_node);
3145 jump_to = ((cond == 0) ^ (opcode == POP_JUMP_IF_TRUE)) ? instr.extra_jump()->bci() : cur_bci_ + 1;
3146 } else if (opcode == JUMP_IF_FALSE_OR_POP || opcode == JUMP_IF_TRUE_OR_POP) {
3147 cond_node = seek(0);
3148 cond = CondIsTrue(cond_node);
3149 bool jump = (cond == 0) ^ (opcode == JUMP_IF_TRUE_OR_POP);
3150 cond_node = jump ? seek(0) : pop();
3151 jump_to = jump ? instr.extra_jump()->bci() : cur_bci_ + 1;
3152 } else {
3153 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceByteCode_Unsupported);
3154 return false;
3155 }
3156
3157 // if branch
3158 if (!IsSatisfyPruneLimit(cond, graph_, cond_node)) {
3159 LogPrunBranch(cond_node, instr, graph_->Config());
3160 graph_->StopTraceAt(cur_bci_, StopTraceReason::kStopTraceIf_Unsupported);
3161 return false;
3162 }
3163 MS_EXCEPTION_IF_CHECK_FAIL(jump_to != -1, "error jump bci");
3164 cur_bci_ = jump_to;
3165 return true;
3166 }
3167
EliminateCellAccess(Graph * g)3168 static void EliminateCellAccess(Graph *g) {
3169 PyCodeObject *co = g->GetCodeObj();
3170 int ncells = PyTuple_GET_SIZE(co->co_cellvars);
3171 if (ncells == 0) {
3172 return;
3173 }
3174 ValueNode *ret_node = g->GetRetVal();
3175 if (ret_node == nullptr) {
3176 return;
3177 }
3178 std::set<ValueNode *> escaped;
3179 auto CollectClosure = [&escaped](ValueNode *node) {
3180 if (node->GetOpcode() == MAKE_FUNCTION && (node->GetOparg() & 0x08)) {
3181 const auto &in = (*(node->getInputs().end() - 3))->getInputs();
3182 escaped.insert(in.begin(), in.end());
3183 }
3184 };
3185 for (auto i : g->GetTracedNodes()) {
3186 int op = i->GetOpcode();
3187 if (op == STORE_DEREF && i->GetOparg() < ncells) {
3188 // exclude STORE_DEREF
3189 continue;
3190 }
3191 auto begin = i->getInputs().begin();
3192 if (Opcode(op).IsCall() && static_cast<CallNode *>(i)->GetInlineReason() == InlineReason::kInline) {
3193 begin++;
3194 }
3195 std::for_each(begin, i->getInputs().end(), CollectClosure);
3196 }
3197 CollectClosure(ret_node);
3198 // collect STORE_DEREF with MAKE_FUNCTION ...
3199
3200 const auto &closures = g->GetFrame(0).GetClosures();
3201 for (int i = 0; i < ncells; ++i) {
3202 if (escaped.find(closures[i]) != escaped.end()) {
3203 continue;
3204 }
3205 for (auto node : closures[i]->GetCellOper()) {
3206 if (node->GetOpcode() != STORE_DEREF) {
3207 // closure access before assign, raise UnboundLocalError
3208 return;
3209 }
3210 node->SetOpcode(LOAD_CONST);
3211 node->SetVobj(AObject::Convert(Py_None));
3212 node->ClearInputs();
3213 }
3214 closures[i]->GetCellOper().clear();
3215 }
3216 }
3217
TraceRun()3218 StopTraceReason GraphBuilder::TraceRun() {
3219 current_block_ = graph_->GetCFG()->GetFirstBB();
3220 cur_bci_ = 0;
3221 const auto &instrs = graph_->GetCFG()->instr_pool();
3222 while (true) {
3223 this->graph_->SetFrame(cur_bci_, frame_);
3224 MS_EXCEPTION_IF_CHECK_FAIL(static_cast<size_t>(cur_bci_) < instrs.size(), "error control flow");
3225 MS_EXCEPTION_IF_CHECK_FAIL(instrs[cur_bci_]->bci() == cur_bci_, "check instruction bci");
3226 if (!DoByteCode(*instrs[cur_bci_])) {
3227 break;
3228 }
3229 }
3230 if (!trace_flag()) {
3231 EliminateCellAccess(this->graph_);
3232 }
3233 return graph_->GetStopTraceReason();
3234 }
3235
3236 extern void AddConfigToGuard(const GraphJitConfig &c, OptGuardPtr guard);
3237 extern void AddGuardForParam(const PyFrameObject *f, OptGuardPtr guard, bool detach);
3238
3239 /**
3240 * Generate a graph from callable, this function will actually create python frame
3241 */
GenerateRootGraph(const py::object & callable,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf)3242 static std::unique_ptr<GraphBuilder> GenerateRootGraph(const py::object &callable, const py::object &args,
3243 const py::object &kwargs, const GraphJitConfig &conf) {
3244 PyFrameObject *frame = Utils::PrepareFrame(callable.ptr(), args.ptr(), kwargs.ptr());
3245 if (frame == nullptr) {
3246 PyErr_Clear();
3247 return nullptr;
3248 }
3249 auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(frame->f_code));
3250 *jcr->conf = conf;
3251 jcr->code = jcr->codehub->AddOptTarget(OptOption::CreateOptionByPoint(jcr));
3252
3253 auto res = std::make_unique<GraphBuilder>(frame);
3254
3255 auto code = res->GetGraph()->GetGuard();
3256 AddConfigToGuard(conf, code->GetGuard());
3257 AddGuardForParam(frame, code->GetGuard(), conf.GetBoolConfig(GraphJitConfig::kGuardDetachObject));
3258
3259 Py_DECREF(frame);
3260 return res;
3261 }
3262
3263 /**
3264 * build graph and infer func result
3265 * it used to infer mindspore function, maybe replace with mindspore func_graph to infer.
3266 */
InferFuncResult(const py::object & callable,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf,bool clear_guard)3267 AObject *InferFuncResult(const py::object &callable, const py::object &args, const py::object &kwargs,
3268 const GraphJitConfig &conf, bool clear_guard) {
3269 auto g = GenerateRootGraph(callable, args, kwargs, conf);
3270 if (g == nullptr) {
3271 return nullptr;
3272 }
3273 g->TraceRun();
3274 if (conf.GetBoolConfig(GraphJitConfig::kPrintAfterAll)) {
3275 g->DumpDFG();
3276 }
3277 if (clear_guard) {
3278 Graph *graph = g->GetGraph();
3279 auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(graph->GetCodeObj()));
3280 jcr->codehub->DelOptTarget(OptOption::CreateOptionByPoint(jcr), graph->GetGuard());
3281 }
3282
3283 ValueNode *res = g->GetGraph()->GetRetVal();
3284 if (res == nullptr) {
3285 return nullptr;
3286 }
3287 return res->GetVobj();
3288 }
3289
InferFuncResult(const py::object & func,const std::vector<AObject * > & stack_args,int opcode,const GraphJitConfig & conf,bool clear_guard)3290 AObject *InferFuncResult(const py::object &func, const std::vector<AObject *> &stack_args, int opcode,
3291 const GraphJitConfig &conf, bool clear_guard) {
3292 std::vector<py::object> args;
3293 std::transform(stack_args.begin(), stack_args.end(), std::back_inserter(args),
3294 [](AObject *i) { return i ? i->GetPyObject() : py::object(); });
3295 auto pair = Utils::PackCallStackArgs(args, opcode);
3296 if (pair.first.ptr() == nullptr) {
3297 return nullptr;
3298 }
3299 return InferFuncResult(func, pair.first, pair.second, conf, clear_guard);
3300 }
3301
InferFuncResult(const py::object & callable,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf)3302 AObject *InferFuncResult(const py::object &callable, const py::object &args, const py::object &kwargs,
3303 const GraphJitConfig &conf) {
3304 return InferFuncResult(callable, args, kwargs, conf, true);
3305 }
3306
GetGradSens(ValueNode * grad_node)3307 static bool GetGradSens(ValueNode *grad_node) {
3308 AObject *grad_object = grad_node->GetVobj();
3309 if (grad_object->GetPyObject().ptr() != nullptr) {
3310 return grad_object->GetAttr("sens_param")->GetPyObject().ptr() == Py_True;
3311 }
3312 bool sens_param = false;
3313 AObject *cls = grad_node->getInputs().size() > 0 ? grad_node->input(0)->GetVobj() : nullptr;
3314 if (!(Opcode(grad_node->GetOpcode()).IsCall() && cls != nullptr && cls->GetType() == AObject::kTypeType)) {
3315 return sens_param;
3316 }
3317 if (grad_node->GetOpcode() == CALL_FUNCTION && grad_node->getInputs().size() > 3) {
3318 AObject *tmp = grad_node->input(3)->GetVobj();
3319 sens_param = tmp ? tmp->GetPyObject().ptr() == Py_True : false;
3320 } else if (grad_node->GetOpcode() == CALL_FUNCTION_KW) {
3321 py::object kwnames = grad_node->getInputs().back()->GetVobj()->GetPyObject();
3322 PyObject **arr = &PyTuple_GET_ITEM(kwnames.ptr(), 0);
3323 Py_ssize_t size = PyTuple_GET_SIZE(kwnames.ptr());
3324 PyObject **iter = std::find_if(arr, arr + size, [](PyObject *k) {
3325 // find sens_param key
3326 return !PyUnicode_CompareWithASCIIString(k, "sens_param");
3327 });
3328 AObject *tmp = iter - arr != size ? grad_node->input(iter - arr)->GetVobj() : nullptr;
3329 sens_param = tmp ? tmp->GetPyObject().ptr() == Py_True : false;
3330 }
3331 return sens_param;
3332 }
3333
SetGradFuncInfo(CallNode * call_node)3334 static void SetGradFuncInfo(CallNode *call_node) {
3335 const int flag = AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc;
3336 ValueNode *grad_func_node = call_node->input(0);
3337 if (grad_func_node->getInputs().size() < 2) {
3338 grad_func_node->GetVobj()->ClearMsFlag(flag);
3339 return;
3340 }
3341 ValueNode *grad_node = grad_func_node->input(0);
3342 ValueNode *deco_func_node = grad_func_node->input(1);
3343 AObject *grad_object = grad_node->GetVobj();
3344 AObject *deco_func = deco_func_node->GetVobj();
3345 bool sens_param = false;
3346 if (grad_func_node->GetVobj()->TestMsFlag(AObject::kMsFlagGradFunc) &&
3347 grad_object->GetType() == AObject::kTypeMetaFuncGraph) {
3348 sens_param = GetGradSens(grad_node);
3349 }
3350
3351 HandleGradFuncCall(call_node, deco_func, sens_param);
3352
3353 // guard forward net for grad
3354 if (grad_func_node->GetVobj()->TestMsFlag(flag) && !call_node->GetGraph()->GuardValueNode(deco_func_node)) {
3355 grad_func_node->GetVobj()->ClearMsFlag(flag);
3356 }
3357 }
3358
DumpDFG()3359 void GraphBuilder::DumpDFG() { GRAPH_JIT_LOG_F("%s", graph_->ToString().c_str()); }
3360
GetLocation(CallNode * call_node) const3361 LocationPtr MindGraphBuilder::GetLocation(CallNode *call_node) const {
3362 auto file_name = py::cast<std::string>(graph_->GetCodeObj()->co_filename);
3363 auto line_no = call_node->GetLineNo();
3364 std::vector<std::string> comments;
3365 return std::make_shared<Location>(file_name, line_no, 0, line_no, 0, "", std::move(comments));
3366 }
3367
WhiteListFuncCheckAndInfer(CallNode * call_node,const py::object & callable)3368 bool MindGraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
3369 InferFunc infer_func = FindInferFunc(callable, trace_flag());
3370 if (infer_func != nullptr) {
3371 call_node->SetSubGraph(NewGraph(nullptr, nullptr));
3372 call_node->GetSubGraph()->SetGuard(root_->GetGraph()->GetGuard());
3373 bool has_sub_graph = infer_func(call_node, this);
3374 if (!has_sub_graph) {
3375 call_node->SetInlineReason(InlineReason::kInlineFuncSpecialize);
3376 MS_ASSERT(!call_node->GetSubGraph()); // check infer function
3377 return true;
3378 }
3379 call_node->SetInlineReason(InlineReason::kInline);
3380 ValueNode *ret_node = call_node->GetSubGraph()->GetRetVal();
3381 MS_EXCEPTION_IF_CHECK_FAIL(ret_node, "infer special function failed");
3382 seek(0) = ret_node;
3383 return true;
3384 }
3385 return false;
3386 }
3387
FGAddInputs(const std::vector<py::object> & args)3388 bool MindGraphBuilder::FGAddInputs(const std::vector<py::object> &args) {
3389 // Add function graph inputs.
3390 for (size_t i = 0; i < args.size(); ++i) {
3391 auto obj = FGBuilder()->AddSubGraphInput(args[i]);
3392 if (obj.ptr() == nullptr) {
3393 MS_LOG(INFO) << "Add input fail for input: " << std::string(py::str(args[i]));
3394 return false;
3395 }
3396 MS_LOG(INFO) << "Add input success for input: " << std::string(py::str(args[i]));
3397 }
3398 return true;
3399 }
3400
FGAddOutput(bool is_top_graph)3401 void MindGraphBuilder::FGAddOutput(bool is_top_graph) {
3402 if (auto ret = GetGraph()->GetRetVal()) {
3403 MS_LOG(INFO) << ret->GetVobj()->ToString();
3404 auto out = ret->GetVobj()->GetPyObject();
3405 MS_LOG(INFO) << "try add output: " << py::str(out) << " addr:" << out.ptr();
3406 if (FGBuilder()->AddOutput(out, is_top_graph)) {
3407 MS_LOG(INFO) << "add output succuss";
3408 } else {
3409 MS_LOG(INFO) << "add output fail";
3410 }
3411 }
3412 }
3413
FGAddNode(CallNode * call_node,const py::object & callable_info,const std::vector<py::object> & args,StopTraceReason * stop_reason)3414 py::object MindGraphBuilder::FGAddNode(CallNode *call_node, const py::object &callable_info,
3415 const std::vector<py::object> &args, StopTraceReason *stop_reason) {
3416 MS_LOG(INFO) << "try add node: " << py::str(callable_info);
3417 TraceGuard trace_guard(GetLocation(call_node));
3418 auto res = FGBuilder()->AddNode(callable_info, args);
3419 if (res.ptr() == nullptr) {
3420 MS_LOG(INFO) << "add node fail";
3421 *stop_reason = StopTraceReason::kTrace_Fail;
3422 } else {
3423 MS_LOG(INFO) << "add node suc";
3424 auto node = AObject::Convert(res);
3425 MS_LOG(INFO) << py::str(node->GetPyObject());
3426 MS_LOG(INFO) << node->ToString();
3427 call_node->SetVobj(node);
3428 *stop_reason = StopTraceReason::kNonStopTrace;
3429 }
3430 return py::object();
3431 }
3432
GetNewArgs(CallNode * call_node,AObject * vobj)3433 std::vector<py::object> MindGraphBuilder::GetNewArgs(CallNode *call_node, AObject *vobj) {
3434 std::vector<py::object> new_args;
3435 vobj = (vobj && vobj->GetType() != AObject::kTypePrimitive) ? vobj : call_node->input(0)->GetVobj();
3436 if (vobj->GetType() == AObject::kTypeCFunction) {
3437 MS_LOG(INFO) << "not support cfunction";
3438 }
3439 auto new_callable_info = FindPyFunc(vobj);
3440 FrameStates f;
3441 ResolveClosure(new_callable_info, call_node->input(0), &f);
3442
3443 // Need to consider repeat add issue.
3444 if (!HandleCallParameters(new_callable_info, call_node, &f)) {
3445 MS_LOG(INFO) << "HandleCallParameters error" << std::endl;
3446 }
3447 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(new_callable_info.ptr()));
3448 int argc = co->co_argcount + co->co_kwonlyargcount;
3449 argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
3450 argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
3451 for (auto it = f.GetLocals().begin(); it != f.GetLocals().begin() + argc; it++) {
3452 std::set<AObject::Type> unsupported_parameter = {
3453 AObject::kTypeAnyValue, AObject::kTypeFunction, AObject::kTypeBoundMethod,
3454 AObject::kTypePrimitive, AObject::kTypeMetaFuncGraph, AObject::kTypeCell,
3455 };
3456 auto it_vobj = (*it)->GetVobj();
3457 if (it_vobj != nullptr) {
3458 auto pyobj = it_vobj->GetPyObject();
3459 if (pyobj.ptr() != nullptr) {
3460 if (unsupported_parameter.find(AbstractObjectBase::GetPyType(pyobj.ptr())) == unsupported_parameter.end()) {
3461 new_args.push_back(pyobj);
3462 }
3463 }
3464 }
3465 }
3466 return new_args;
3467 }
3468
AllConstantArgs(const std::vector<py::object> & args,const py::object & callable_info,CallNode * call_node)3469 bool MindGraphBuilder::AllConstantArgs(const std::vector<py::object> &args, const py::object &callable_info,
3470 CallNode *call_node) {
3471 auto new_args = args;
3472 if (PyFunction_Check(callable_info.ptr())) {
3473 new_args = GetNewArgs(call_node);
3474 }
3475
3476 return std::all_of(new_args.begin(), new_args.end(), [](const auto &arg) { return CheckConstPyObject(arg.ptr()); });
3477 }
3478
ResolveCallable(CallNode * call_node,StopTraceReason * stop_reason)3479 py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
3480 AObject *callable = call_node->input(0)->GetVobj();
3481 py::object callable_info;
3482 *stop_reason = StopTraceReason::kStopTraceInfer_Fail;
3483 if (!callable) {
3484 return callable_info;
3485 }
3486 callable_info = callable->GetPyObject();
3487 py::object original_callable = callable_info;
3488 if (callable_info.ptr() == nullptr) {
3489 return py::object();
3490 }
3491 if (!FGBuilder()->ValidateCallableObject(callable_info)) {
3492 return py::object();
3493 }
3494 MS_LOG(INFO) << "trace_flag for: " << py::str(callable_info);
3495 auto args = call_node->GetArgs();
3496 if (FGBuilder()->CanConstantFoldFunc(callable_info) && AllConstantArgs(args, callable_info, call_node)) {
3497 MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info);
3498 JustCallAndSetRes(call_node);
3499 *stop_reason = StopTraceReason::kNonStopTrace;
3500 return py::object();
3501 }
3502 auto method = FGBuilder()->ConvertMethod(callable_info);
3503 if (method.ptr() != nullptr) {
3504 MS_LOG(INFO) << "convert method :" << py::str(callable_info) << " to " << py::str(method);
3505 callable_info = method;
3506 if (!PyFunction_Check(callable_info.ptr())) { // prim getnewargs here, func getnewargs in subgraph
3507 args = GetNewArgs(call_node, AObject::Convert(callable_info.ptr()));
3508 }
3509 }
3510 auto func = FGBuilder()->ConvertFunction(callable_info);
3511 if (func.ptr() != nullptr) {
3512 MS_LOG(INFO) << "convert function:" << py::str(callable_info) << " to " << py::str(func);
3513 callable_info = func;
3514 }
3515 if (FGBuilder()->CheckCallable(callable_info)) {
3516 if (PyFunction_Check(callable_info.ptr())) {
3517 args = GetNewArgs(call_node);
3518 }
3519 return FGAddNode(call_node, callable_info, args, stop_reason);
3520 }
3521
3522 py::object result = this->GraphBuilder::ResolveCallable(call_node, stop_reason);
3523 bool pijit_specialized = original_callable == callable_info // not converted
3524 || call_node->GetSubGraph() != nullptr // pijit sub graph
3525 || callable->GetType() == AObject::kTypeType; // pijit class instantiation
3526 if (pijit_specialized) {
3527 return result;
3528 }
3529 MS_LOG(DEBUG) << "convert " << std::string(py::str(original_callable)) << " -> "
3530 << std::string(py::str(callable_info));
3531 return FindPyFunc(AObject::Convert(callable_info));
3532 }
3533
HandleCallClass(CallNode * call_node)3534 bool MindGraphBuilder::HandleCallClass(CallNode *call_node) {
3535 bool succ = GraphBuilder::HandleCallClass(call_node);
3536 if (!succ) {
3537 MS_LOG(INFO) << "Failed to handle call class";
3538 return false;
3539 } else if (call_node->GetVobj() != nullptr && call_node->GetVobj()->GetPyObject().ptr() != nullptr) {
3540 return FGBuilder()->AddLocalVariable(call_node->GetVobj()->GetPyObject());
3541 }
3542 return false;
3543 }
3544
3545 // Fix dynamic shape tensor get shape issue.
3546 // Guard and Renormalize strategy should be refactored later.
HandleGetShapeOfDynamicLengthTensor(const py::object & object)3547 py::object MindGraphBuilder::HandleGetShapeOfDynamicLengthTensor(const py::object &object) {
3548 auto anf_node = fg_builder_->ReadLocalVariable(object);
3549 if (anf_node == nullptr || anf_node->abstract() == nullptr) {
3550 return py::object();
3551 }
3552 auto abs = anf_node->abstract();
3553 auto shape = abs->BuildShape();
3554 if (!shape->isa<abstract::TensorShape>()) {
3555 return py::object();
3556 }
3557 const auto &tensor_shape = shape->cast<abstract::TensorShapePtr>()->GetShapeVector();
3558 if (std::all_of(tensor_shape.begin(), tensor_shape.end(), [](auto e) { return e > 0; })) {
3559 return py::object();
3560 }
3561 std::vector<py::object> input_objects = {object};
3562 return fg_builder_->AddNode(prim::kPrimShape, input_objects);
3563 }
3564
HandleGetattr(ValueNode * target_node,const Instr & instr)3565 ValueNode *MindGraphBuilder::HandleGetattr(ValueNode *target_node, const Instr &instr) {
3566 auto attr_node = NewValueNode(target_node->get_attr(instr.name()), instr, {target_node});
3567 MS_EXCEPTION_IF_NULL(attr_node);
3568 ValueNode *graph_attr_node = nullptr;
3569 auto attr_obj = attr_node->GetVobj()->GetPyObject();
3570 if (instr.name() == "shape") {
3571 auto ret_object = HandleGetShapeOfDynamicLengthTensor(target_node->GetVobj()->GetPyObject());
3572 if (ret_object.ptr() != nullptr) {
3573 return NewValueNode(AObject::Convert(ret_object), instr, {target_node});
3574 }
3575 }
3576 // If the attr_obj can convert to anf node directly, return the origin attr node.
3577 if (fg_builder_->AddAttrPythonObject(attr_obj)) {
3578 graph_attr_node = attr_node;
3579 } else {
3580 std::vector<py::object> input_objects = {target_node->GetVobj()->GetPyObject(), py::str(instr.name())};
3581 auto graph_attr_obj = fg_builder_->AddNode(prim::kPrimGetAttr, input_objects);
3582 if (graph_attr_obj.ptr() == nullptr) {
3583 graph_attr_node = attr_node;
3584 } else {
3585 graph_attr_node = NewValueNode(AObject::Convert(graph_attr_obj), instr, {target_node});
3586 }
3587 }
3588 // Add Id guard for parameter, in case default value for parameter change in execution.
3589 if (attr_obj.ptr() != nullptr && py::hasattr(attr_obj, "__parameter__") &&
3590 py::isinstance<tensor::MetaTensor>(attr_obj)) {
3591 graph_->GuardValueNode(graph_attr_node, GuardLevel::GId);
3592 return graph_attr_node;
3593 }
3594 // Add Guard for getattr node. For scalar/list/tuple/primitive, need to guard value. Otherwise, guard type and shape.
3595 AObject::Type attr_type = graph_attr_node->GetVobj() ? graph_attr_node->GetVobj()->GetType() : AObject::kTypeAnyValue;
3596 static const std::vector<AObject::Type> const_type = {AObject::kTypeInt, AObject::kTypeFloat, AObject::kTypeBool,
3597 AObject::kTypeTuple, AObject::kTypeList, AObject::kTypeDict,
3598 AObject::kTypePrimitive};
3599 // Need to check whether the guard is failed in the future.
3600 if (std::any_of(const_type.begin(), const_type.end(),
3601 [attr_type](const AObject::Type type) { return attr_type == type; })) {
3602 graph_->GuardValueNode(graph_attr_node, GuardLevel::GEqual);
3603 } else if (attr_type != AObject::kTypeFunction && attr_type != AObject::kTypeBoundMethod &&
3604 attr_type != AObject::kTypeCFunction) {
3605 graph_->GuardValueNode(graph_attr_node, GuardLevel::GDeduce);
3606 }
3607 return graph_attr_node;
3608 }
3609
HandleMultiOp(const Instr & instr,const std::vector<ValueNode * > & p,bool is_compare)3610 AObject *MindGraphBuilder::HandleMultiOp(const Instr &instr, const std::vector<ValueNode *> &p, bool is_compare) {
3611 int opcode = instr.op();
3612 int oparg = instr.arg();
3613 std::vector<py::object> input_obj;
3614 for (auto input : p) {
3615 if (input->GetVobj() == nullptr) {
3616 return AObject::MakeAObject(AObject::kTypeAnyValue);
3617 }
3618 (void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
3619 }
3620 const auto &op_name =
3621 is_compare ? pijit::GraphUtils::OpCompareArgToGraphName(oparg) : pijit::GraphUtils::OpCodeToGraphName(opcode);
3622 MS_LOG(DEBUG) << "operation name is " << op_name;
3623 if (op_name == "") {
3624 return AObject::MakeAObject(AObject::kTypeAnyValue);
3625 }
3626 auto node = fg_builder_->AddMultiNode(op_name, input_obj);
3627 if (node.ptr() == nullptr) {
3628 return AObject::MakeAObject(AObject::kTypeAnyValue);
3629 }
3630 return AObject::Convert(node);
3631 }
3632
HandleBuildOp(const Instr & instr,const std::vector<ValueNode * > & p)3633 AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p) {
3634 auto opcode = instr.op();
3635 std::vector<py::object> input_obj;
3636 for (auto input : p) {
3637 if (input->GetVobj() == nullptr) {
3638 return AObject::MakeAObject(AObject::kTypeAnyValue);
3639 }
3640 (void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
3641 }
3642 auto primitive = pijit::GraphUtils::GetPrimitive(opcode);
3643 if (primitive == nullptr) {
3644 return AObject::MakeAObject(AObject::kTypeAnyValue);
3645 }
3646 if (primitive == prim::kPrimMakeDict) {
3647 if (opcode == BUILD_CONST_KEY_MAP) {
3648 MS_LOG(DEBUG) << "BUILD_CONST_KEY_MAP case, need to pack values.";
3649 std::vector<py::object> value_inputs;
3650 (void)std::transform(input_obj.begin(), input_obj.end() - 1, std::back_inserter(value_inputs),
3651 [](const py::object &obj) { return obj; });
3652 auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_inputs);
3653 input_obj = {input_obj.back(), value_node};
3654 } else {
3655 MS_LOG(DEBUG) << "BUILD_KEY_MAP case, need to pack keys and values.";
3656 size_t input_len = input_obj.size();
3657 if (input_len % 2 != 0) {
3658 MS_LOG(INTERNAL_EXCEPTION) << "BUILD_KEY_MAP should have even input, but got: " << input_len;
3659 }
3660 std::vector<py::object> key_obj;
3661 std::vector<py::object> value_obj;
3662 for (size_t i = 0; i < input_len / 2; ++i) {
3663 key_obj.push_back(input_obj[2 * i]);
3664 value_obj.push_back(input_obj[2 * i + 1]);
3665 }
3666 auto key_node = fg_builder_->AddNode(prim::kPrimMakeTuple, key_obj);
3667 auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_obj);
3668 input_obj = {key_node, value_node};
3669 }
3670 }
3671 if (primitive == prim::kPrimMakeSlice) {
3672 constexpr size_t slice_without_step_len = 2;
3673 if (input_obj.size() == slice_without_step_len) {
3674 // Handle slice without step input scene, such as 0:2. MakeSlice can only handle slice with full inputs.
3675 (void)input_obj.emplace_back(py::int_(1));
3676 }
3677 }
3678 auto node = fg_builder_->AddNode(primitive, input_obj);
3679 return AObject::Convert(node);
3680 }
3681
DoGetItem(const Instr & instr)3682 bool MindGraphBuilder::DoGetItem(const Instr &instr) {
3683 auto r = pop();
3684 auto l = pop();
3685 auto o = HandleMultiOp(instr, {l, r}, false);
3686 auto v = NewValueNode(o, instr, {l, r});
3687 push(v);
3688 return true;
3689 }
3690
DoUnary(const Instr & instr)3691 bool MindGraphBuilder::DoUnary(const Instr &instr) {
3692 auto o = pop();
3693 auto r = HandleMultiOp(instr, {o}, false);
3694 auto v = NewValueNode(r, instr, {o});
3695 push(v);
3696 return true;
3697 }
3698
DoBinary(const Instr & instr)3699 bool MindGraphBuilder::DoBinary(const Instr &instr) {
3700 auto r = pop();
3701 auto l = pop();
3702 auto o = HandleMultiOp(instr, {l, r}, false);
3703 auto v = NewValueNode(o, instr, {l, r});
3704 push(v);
3705 return true;
3706 }
3707
DoBinaryMul(const Instr & instr)3708 bool MindGraphBuilder::DoBinaryMul(const Instr &instr) {
3709 auto r = pop();
3710 auto l = pop();
3711 auto o = HandleMultiOp(instr, {l, r}, false);
3712 auto v = NewValueNode(o, instr, {l, r});
3713 push(v);
3714 return true;
3715 }
3716
DoCompare(const Instr & instr)3717 bool MindGraphBuilder::DoCompare(const Instr &instr) {
3718 auto r = pop();
3719 auto l = pop();
3720 auto o = HandleMultiOp(instr, {l, r}, true);
3721 auto v = NewValueNode(o, instr, {l, r});
3722 push(v);
3723 return true;
3724 }
3725
DoBuildOp(const Instr & instr)3726 bool MindGraphBuilder::DoBuildOp(const Instr &instr) {
3727 int opcode = instr.op();
3728 int oparg = instr.arg();
3729 int tmp_arg = oparg;
3730 tmp_arg += opcode == BUILD_CONST_KEY_MAP;
3731 tmp_arg += opcode == BUILD_MAP ? tmp_arg : 0;
3732 std::vector<ValueNode *> p(frame_.GetStacks().end() - tmp_arg, frame_.GetStacks().end());
3733 auto o = HandleBuildOp(instr, p);
3734 popn(tmp_arg);
3735 auto v = NewValueNode(o, instr, p);
3736 push(v);
3737 return true;
3738 }
3739
DoIsOp(const Instr & instr)3740 bool MindGraphBuilder::DoIsOp(const Instr &instr) { return GraphBuilder::DoBinary(instr); }
3741
HandlePositionParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)3742 bool MindGraphBuilder::HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params,
3743 FrameStates *frame) {
3744 CallNode *call_node = reinterpret_cast<CallNode *>(seek(0));
3745 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
3746 auto vobj = trace_flag() ? AObject::Convert(func.ptr()) : call_node->input(0)->GetVobj();
3747 AObject::Type callable_type = vobj->GetType();
3748
3749 ValueNode *self = GetBoundSelf(call_node);
3750 if (self != nullptr) {
3751 params->insert(params->begin(), self);
3752 }
3753
3754 const int argc = co->co_argcount;
3755 const int has_varg = (co->co_flags & CO_VARARGS) ? 1 : 0;
3756 const int has_kwvarg = (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
3757 const int varg_loc = argc + co->co_kwonlyargcount;
3758 const int kwvarg_loc = argc + co->co_kwonlyargcount + has_varg;
3759 int pargc = params->size();
3760 if (pargc > argc && !has_varg) {
3761 MS_LOG(DEBUG) << "too many parameters";
3762 return false;
3763 }
3764 bool append_self_to_varg = has_varg && self && callable_type == AObject::kTypeBoundMethod && argc == 0;
3765 if (append_self_to_varg) { // self is in variable arguments
3766 MS_LOG(INFO) << "not implement append self to variable arguments, inline failed";
3767 return false;
3768 }
3769
3770 if (has_kwvarg && frame->Local(kwvarg_loc) == &ValueNode::kUnboundLocal) {
3771 auto vo = AObject::Convert(py::dict());
3772 auto m = NewValueNode(vo, BUILD_MAP, 0, {});
3773 call_node->AddParam(m);
3774 frame->SetLocal(kwvarg_loc, m);
3775 }
3776
3777 if (has_varg) {
3778 int vargc = pargc > argc ? pargc - argc : 0;
3779 std::vector<ValueNode *> vargs(params->end() - vargc, params->end());
3780 params->resize(params->size() - vargc);
3781 std::for_each(vargs.begin(), vargs.end(), [this](ValueNode *i) { this->push(i); });
3782 DoBuildOp({BUILD_TUPLE, static_cast<int>(vargs.size())});
3783 ValueNode *build_tuple = pop();
3784 call_node->AddParam(build_tuple);
3785 frame->SetLocal(varg_loc, build_tuple);
3786 }
3787
3788 pargc = params->size();
3789 for (int i = pargc - 1; i >= 0; --i) {
3790 if (frame->Local(i) != &ValueNode::kUnboundLocal) {
3791 MS_LOG(DEBUG) << "duplicate key-word parameter error";
3792 return false;
3793 }
3794 frame->SetLocal(i, params->back());
3795 params->pop_back();
3796 }
3797
3798 return CheckAndSetDefaultParams(func, frame, pargc);
3799 }
3800
UnpackCallExParams(std::vector<ValueNode * > * params,int extra_local,bool * has_kw,CallNode * call_node)3801 bool MindGraphBuilder::UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw,
3802 CallNode *call_node) {
3803 bool has_dict = params->size() > 1;
3804 ValueNode *args_node = params->operator[](0);
3805 if (!has_dict) {
3806 params->clear();
3807 } else if (!UnpackCallExDict(params, call_node)) {
3808 return false;
3809 }
3810 *has_kw = params->size();
3811
3812 if (args_node->GetVobj() == nullptr) {
3813 return false;
3814 }
3815 py::object object = args_node->GetVobj()->GetPyObject();
3816 if (!py::isinstance<py::tuple>(object)) {
3817 return false;
3818 }
3819 size_t args_len = py::len(py::cast<py::tuple>(object));
3820 if (args_len == 0) {
3821 return true;
3822 }
3823
3824 std::vector<ValueNode *> new_args_inputs;
3825 for (size_t i = 0; i < args_len; ++i) {
3826 Instr instr(BINARY_SUBSCR, 2);
3827 auto l = args_node;
3828 auto r = NewValueNode(AObject::Convert(py::int_(i)), LOAD_CONST, -1, {});
3829 auto o = HandleMultiOp(instr, {l, r}, false);
3830 new_args_inputs.push_back(NewValueNode(o, instr, {l, r}));
3831 }
3832
3833 params->insert(params->begin(), new_args_inputs.begin(), new_args_inputs.end());
3834 return true;
3835 }
3836
HandleKWParams(const py::object & func,std::vector<ValueNode * > * params,FrameStates * frame)3837 bool MindGraphBuilder::HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) {
3838 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func.ptr()));
3839 std::vector<ValueNode *> kwvargs;
3840 if (!PackKwParams(func, params, frame, &kwvargs)) {
3841 // illegal arguments
3842 return false;
3843 }
3844
3845 const int argc = co->co_argcount + co->co_kwonlyargcount;
3846 if (!(co->co_flags & CO_VARKEYWORDS)) {
3847 // kw_2_p_cnt == k_cnt, all kw arguments is positions arguments
3848 return true;
3849 }
3850
3851 int kwvarg_loc = argc + ((co->co_flags & CO_VARARGS) ? 1 : 0);
3852 std::for_each(kwvargs.begin(), kwvargs.end(), [this](ValueNode *i) { this->push(i); });
3853 DoBuildOp({BUILD_MAP, SizeToInt(kwvargs.size() / 2)});
3854 ValueNode *new_node = pop();
3855 frame->SetLocal(kwvarg_loc, new_node);
3856 graph_->GetTracedNodes().pop_back();
3857
3858 static_cast<CallNode *>(seek(0))->AddParam(frame->Local(kwvarg_loc));
3859 return true;
3860 }
3861
UnpackCallExDict(std::vector<ValueNode * > * params,CallNode * call_node)3862 bool MindGraphBuilder::UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node) {
3863 ValueNode *dict_node = params->back();
3864 params->clear();
3865
3866 if (dict_node->GetVobj() == nullptr) {
3867 return false;
3868 }
3869
3870 auto object = dict_node->GetVobj()->GetPyObject();
3871 if (!py::isinstance<py::dict>(object)) {
3872 return false;
3873 }
3874 auto dict_object = py::cast<py::dict>(object);
3875 Py_ssize_t dict_len = py::len(dict_object);
3876 if (dict_len == 0) {
3877 return true;
3878 }
3879
3880 py::tuple keys(dict_len);
3881 size_t i = 0;
3882 for (const auto &pair : dict_object) {
3883 auto cur_key = pair.first;
3884 if (!py::isinstance<py::str>(cur_key)) {
3885 return false;
3886 }
3887 keys[i] = cur_key;
3888 Instr instr(BINARY_SUBSCR, 2);
3889 auto l = dict_node;
3890 auto r = NewValueNode(AObject::Convert(py::cast<py::str>(cur_key)), LOAD_CONST, -1, {});
3891 auto o = HandleMultiOp(instr, {l, r}, false);
3892 params->push_back(NewValueNode(o, instr, {l, r}));
3893 i++;
3894 }
3895
3896 ValueNode *const_keys = this->NewValueNode(AObject::Convert(keys), LOAD_CONST, -1, {});
3897 params->push_back(const_keys);
3898 return true;
3899 }
3900
DoItemAccess(const Instr & instr)3901 bool MindGraphBuilder::DoItemAccess(const Instr &instr) {
3902 int opcode = instr.op();
3903 bool res = false;
3904 if (opcode == BINARY_SUBSCR) {
3905 res = DoGetItem(instr);
3906 } else if (opcode == STORE_SUBSCR) {
3907 auto key = pop();
3908 auto map = pop();
3909 auto value = pop();
3910 NewValueNode(nullptr, instr, {value, map, key});
3911 res = DoSetItem(map, key, value);
3912 } else if (opcode == DELETE_SUBSCR) {
3913 auto key = pop();
3914 auto map = pop();
3915 NewValueNode(nullptr, instr, {map, key});
3916 res = DoSetItem(map, key, nullptr);
3917 } else {
3918 MS_LOG(INTERNAL_EXCEPTION) << "parser got an error instruction " << instr.ToString();
3919 }
3920 return res;
3921 }
3922 } // namespace pijit
3923 } // namespace mindspore
3924