1
2 /**
3 * Copyright 2023 Huawei Technologies Co., Ltd
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 #include "pipeline/jit/pi/graph_capture/special_func_infer.h"
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 #include "pipeline/jit/pi/common.h"
27 #include "pipeline/jit/pi/external.h"
28 #include "pipeline/jit/pi/graph_capture/graph_build.h"
29 #include "pipeline/jit/pi/graph_guard/infer.h"
30 #include "pipeline/jit/pi/graph_capture/side_effect.h"
31
32 namespace mindspore {
33 namespace pijit {
// Helpers defined in other pijit graph-capture translation units and used by the inference routines below.
// Returns the node bound as `self` of a method call (definition not visible here).
extern ValueNode *GetBoundSelf(CallNode *call_node);
// Reports a guard that could not be installed for `node`.
extern void LogGuardFailed(ValueNode *node, const GraphJitConfig &conf, const std::string &msg);
// Infers the result of calling `func`; this overload takes the abstract stack arguments plus the call opcode.
extern AObject *InferFuncResult(const py::object &func, const std::vector<AObject *> &stack_args, int opcode,
                                const GraphJitConfig &conf, bool clear_guard);
// Same as above, with the arguments already packed into Python positional args / kwargs.
extern AObject *InferFuncResult(const py::object &func, const py::object &args, const py::object &kwargs,
                                const GraphJitConfig &conf, bool clear_guard);

// Python module providing the PIJit white-list helpers (e.g. `infer_after_grad`, used by HandleGradFuncCall).
constexpr const char *kModuleName = "mindspore._extends.pijit.pijit_func_white_list";
// Attribute name of the function map in that module (not referenced in the code visible here).
constexpr const char *kFuncMapName = "_func_map";
// Name of the Python call slot (not referenced in the code visible here).
constexpr const char *kSlotCallName = "__call__";
// Input index of the optional `default` argument of `dict.pop(key, default)`; shifted by one when the call is
// made through the method descriptor (see InferDictPop).
constexpr const size_t kDictPopParamsNum = 2;
// Input count of a bound-method call taking a single argument: the bound method itself plus that argument.
constexpr const size_t BoundMethodInputSize = 2;

// Forward declaration: used by InferMSConstexpr before its definition.
static bool CheckConstexpr(const py::object &func);
48
49 template <AObject::Type type>
SetCallResType(CallNode * call_node,GraphBuilder * unused=nullptr)50 static bool SetCallResType(CallNode *call_node, GraphBuilder *unused = nullptr) {
51 call_node->SetVobj(AObject::MakeAObject(type));
52 call_node->SetSubGraph(nullptr);
53 return false;
54 }
55
JustCallAndSetRes(CallNode * call_node,GraphBuilder * unused)56 bool JustCallAndSetRes(CallNode *call_node, GraphBuilder *unused) {
57 py::object func = call_node->input(0)->GetVobj()->GetPyObject();
58 if (func.ptr() == nullptr) {
59 return SetCallResType<AObject::kTypeAnyValue>(call_node);
60 }
61
62 std::vector<py::object> args;
63 std::transform(call_node->getInputs().begin() + 1, call_node->getInputs().end(), std::back_inserter(args),
64 [](ValueNode *n) { return n->GetVobj() ? n->GetVobj()->GetPyObject() : py::object(); });
65 auto pair = Utils::PackCallStackArgs(args, call_node->GetOpcode());
66 if (pair.first.ptr() == nullptr) {
67 return SetCallResType<AObject::kTypeAnyValue>(call_node);
68 }
69
70 pi_jit_disable();
71 PyObject *value = PyObject_Call(func.ptr(), pair.first.ptr(), pair.second.ptr());
72 if (PyErr_Occurred()) {
73 MS_LOG(INFO) << "got an error " << py::error_already_set().what() << " at call the "
74 << std::string(py::str(func.ptr()));
75 PyErr_Clear();
76 }
77 pi_jit_enable();
78
79 call_node->SetVobj(AObject::Convert(value));
80 call_node->SetSubGraph(nullptr);
81 Py_XDECREF(value);
82 return false;
83 }
84
CallNodeReturnConst(CallNode * call_node,Graph * sub_graph,AObject * value)85 static bool CallNodeReturnConst(CallNode *call_node, Graph *sub_graph, AObject *value) {
86 PyObject *cnst = value->GetPyObject().ptr();
87 MS_EXCEPTION_IF_NULL(cnst);
88
89 ValueNode *ret_node = sub_graph->NewValueNode(value, LOAD_CONST, -1, {});
90 call_node->SetSubGraph(sub_graph);
91 ret_node->SetGraph(call_node->GetGraph());
92
93 sub_graph->SetRetVal(ret_node);
94 call_node->SetInlineReason(InlineReason::kInline);
95 return true;
96 }
97
GuardConstCallNodeParam(CallNode * call_node,Graph * sub_graph,int max_guard_depth)98 bool GuardConstCallNodeParam(CallNode *call_node, Graph *sub_graph, int max_guard_depth) {
99 std::vector<std::pair<TracePtr, GuardLevel>> traces;
100 for (auto i : call_node->getInputs()) {
101 if (i->IsConstantValue()) {
102 continue;
103 }
104 AObject::Type type = i->GetVobj() ? i->GetVobj()->GetType() : AObject::kTypeAnyValue;
105 if (type == AObject::kTypeAnyValue) {
106 return false;
107 }
108 TracePtr tr = sub_graph->TraceValueNode(i, max_guard_depth);
109 if (tr == nullptr) {
110 if (static_cast<size_t>(max_guard_depth) >= INT_MAX) {
111 LogGuardFailed(i, sub_graph->Config(), "GuardConstCannNodeParm failed");
112 }
113 return false;
114 }
115 GuardLevel level = GuardLevel::GEqual;
116 if (type == AObject::kTypeTensor) {
117 if (i->GetOpcode() == LOAD_GLOBAL) {
118 level = GuardLevel::GId; // only guard global tensor
119 } else {
120 level = GuardLevel::GDeduce;
121 }
122 }
123 traces.push_back({tr, level});
124 }
125
126 const auto &guard = sub_graph->GetGuard()->GetGuard();
127 guard->Backup();
128 for (const auto &i : traces) {
129 if (!guard->GuardOn(i.first, i.second)) {
130 guard->Rollback();
131 return false;
132 }
133 }
134 guard->Pop();
135 return true;
136 }
137
InferConvertMap(CallNode * call_node,GraphBuilder * unused=nullptr)138 static bool InferConvertMap(CallNode *call_node, GraphBuilder *unused = nullptr) {
139 AObject *func_info = call_node->input(0)->GetVobj();
140 func_info->SetMsFlag(AObject::kMsFlagStandardFunc);
141 py::object func = func_info->GetPyObject();
142 py::object tmp = Utils::GetModuleAttr("mindspore._extends.parse.resources", "convert_object_map");
143 auto dict_obj = py::cast<py::dict>(tmp);
144 auto infer_obj = dict_obj[func];
145 AObject *res = nullptr;
146 call_node->SetSubGraph(nullptr);
147 SetCallResType<AObject::kTypeTensor>(call_node);
148 if (PyFunction_Check(infer_obj.ptr())) {
149 MS_LOG(DEBUG) << "infer function " << std::string(py::str(PyFunction_GET_CODE(infer_obj.ptr())));
150 int op = call_node->GetOpcode();
151 const auto &conf = call_node->GetGraph()->Config();
152 std::vector<AObject *> args;
153 std::transform(call_node->getInputs().begin() + 1, call_node->getInputs().end(), std::back_inserter(args),
154 [](ValueNode *n) { return n->GetVobj(); });
155 res = InferFuncResult(func, {args.begin() + 1, args.end()}, op, conf, true);
156 } else if (IsPrimitiveType<true>(Py_TYPE(infer_obj.ptr()))) {
157 MS_LOG(DEBUG) << "infer primitive " << std::string(py::str(infer_obj));
158 std::vector<PyObject *> list;
159 bool infer_fail = false;
160 for (size_t i = 1; !infer_fail && i < call_node->getInputs().size(); i++) {
161 AObject *p = call_node->input(i)->GetVobj();
162 PyObject *o = p ? p->GetPyObject().ptr() : nullptr;
163 list.push_back(o);
164 infer_fail = o == nullptr;
165 }
166 if (infer_fail) {
167 return false;
168 }
169 auto inst = mindspore::pijit::InferEngine::GetInstance();
170 bool is_abstract = false;
171 PyObject *ret = inst->InferPrimitive(infer_obj.ptr(), list, &is_abstract);
172 if (ret == nullptr) {
173 return false;
174 }
175 AObject::Type type = AObject::GetPyType(ret);
176 res = is_abstract && type != AObject::kTypeTensor ? AObject::MakeAObject(type) : AObject::Convert(ret);
177 Py_DECREF(ret);
178 } else {
179 return false;
180 }
181 if (res) {
182 call_node->SetVobj(res);
183 }
184 return false;
185 }
186
InferGetCachePrim(CallNode * n,GraphBuilder * unused=nullptr)187 static bool InferGetCachePrim(CallNode *n, GraphBuilder *unused = nullptr) {
188 // just return the first parameter of _get_cache_prim
189 Graph *g = n->GetSubGraph();
190 n->SetVobj(n->input(1)->GetVobj());
191 g->SetRetVal(n->input(1));
192 return true;
193 }
194
InferRegistryGet(CallNode * call_node,GraphBuilder * unused=nullptr)195 static bool InferRegistryGet(CallNode *call_node, GraphBuilder *unused = nullptr) {
196 Graph *g = call_node->GetSubGraph();
197 JustCallAndSetRes(call_node);
198
199 py::object func = call_node->GetVobj()->GetPyObject();
200 if (call_node->getInputs().back()->GetOpcode() == LOAD_CONST && func.ptr() != nullptr) {
201 return CallNodeReturnConst(call_node, g, call_node->GetVobj());
202 }
203 return false;
204 }
205
InferPrimitive(CallNode * call_node,GraphBuilder * unused=nullptr)206 static bool InferPrimitive(CallNode *call_node, GraphBuilder *unused = nullptr) {
207 static const std::unordered_map<std::string, AObject::Type> not_ret_tensor_prim = {
208 {"Prim[_get_grad_op]<constexpr_prim=True>", AObject::kTypeMetaFuncGraph},
209 {"Prim[DType]", AObject::kTypeAnyValue},
210 {"Prim[Partial]<side_effect_propagate=1>", AObject::kTypeAnyValue},
211 };
212 Graph *sub_graph = call_node->GetSubGraph();
213 call_node->SetVobj(AObject::MakeAObject(AObject::kTypeTensor));
214 call_node->SetSubGraph(nullptr);
215 PyObject *prim = call_node->input(0)->GetVobj()->GetPyObject().ptr();
216 std::string prim_key = std::string(py::str(prim));
217 if (prim_key == "Prim[_get_grad_op]<constexpr_prim=True>") {
218 py::object grad_class = Utils::GetModuleAttr("mindspore._c_expression", "GradOperation_");
219 AbstractType *type = static_cast<AbstractType *>(AObject::Convert(grad_class));
220 AObject *res = type != nullptr ? type->BuildAbstractInstance({}, CALL_FUNCTION)
221 : AObject::MakeAObject(AObject::kTypeMetaFuncGraph);
222 call_node->SetVobj(res);
223 return false;
224 }
225
226 auto iter = not_ret_tensor_prim.find(prim_key);
227 if (iter != not_ret_tensor_prim.end()) {
228 call_node->SetVobj(AObject::MakeAObject(iter->second));
229 } else {
230 call_node->SetVobj(AObject::MakeAObject(AObject::kTypeTensor));
231 }
232
233 std::vector<PyObject *> list;
234 bool infer_fail = false;
235 for (size_t i = 1; !infer_fail && i < call_node->getInputs().size(); i++) {
236 AObject *p = call_node->input(i)->GetVobj();
237 if (p == nullptr) {
238 infer_fail = true;
239 break;
240 }
241 PyObject *o;
242 if (p->GetType() == AObject::kTypeTensor) {
243 o = static_cast<AbstractTensor *>(p)->GetTensor(true).ptr();
244 } else {
245 o = p->GetPyObject().ptr();
246 }
247 list.push_back(o);
248 infer_fail = o == nullptr;
249 }
250 if (infer_fail) {
251 return false;
252 }
253
254 auto inst = mindspore::pijit::InferEngine::GetInstance();
255 bool is_abstract = false;
256 PyObject *ret;
257 try {
258 ret = inst->InferPrimitive(prim, list, &is_abstract);
259 } catch (std::exception &e) {
260 MS_LOG(ERROR) << "infer primitive failed. reason:";
261 MS_LOG(ERROR) << e.what();
262 ret = nullptr;
263 }
264 if (ret == nullptr) {
265 return false;
266 }
267
268 AObject::Type type = AObject::GetPyType(ret);
269 AObject *type_info = is_abstract && type != AObject::kTypeTensor ? AObject::MakeAObject(type) : AObject::Convert(ret);
270 call_node->SetVobj(type_info);
271 Py_DECREF(ret);
272
273 ConstantInfo::CollectPrimitiveConstantInfo(call_node);
274 if (call_node->IsConstantValue()) {
275 return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
276 }
277 return false;
278 }
279
InferGradOperation(CallNode * call_node,AObject::MindsporeFlag f)280 static bool InferGradOperation(CallNode *call_node, AObject::MindsporeFlag f) {
281 call_node->SetSubGraph(nullptr);
282 AObject *grad_func = AObject::MakeAObject(AObject::kTypeFunction);
283 grad_func->SetMsFlag(f);
284 call_node->SetVobj(grad_func);
285 py::object func = GraphBuilder::FindPyFunc(call_node->input(1)->GetVobj());
286 if (func.ptr() == nullptr) {
287 return false;
288 }
289 (void)pi_jit_should_compile(func, py::dict(), py::none());
290 auto jcr = getJitCompileResults(PyFunction_GET_CODE(func.ptr()));
291 *jcr->conf = call_node->GetGraph()->Config();
292 return false;
293 }
294
InferMetaFunc(CallNode * call_node,GraphBuilder * unused=nullptr)295 static bool InferMetaFunc(CallNode *call_node, GraphBuilder *unused = nullptr) {
296 call_node->SetSubGraph(nullptr);
297 const auto &vo = call_node->input(0)->GetVobj();
298 MS_EXCEPTION_IF_CHECK_FAIL(vo->GetType() != AObject::kTypeType, "class call is before ");
299 PyTypeObject *tp = vo->GetTypeObject();
300 if (IsGradOperationType<true>(tp)) {
301 // set grad flag
302 return InferGradOperation(call_node, AObject::MindsporeFlag::kMsFlagGradFunc);
303 } else if (IsVmapOperationType<true>(tp)) {
304 // set vmap flag
305 return InferGradOperation(call_node, AObject::MindsporeFlag::kMsFlagVmapFunc);
306 } else if (IsShardType<true>(tp)) {
307 // set shard flag
308 return InferGradOperation(call_node, AObject::MindsporeFlag::kMsFlagShardFunc);
309 }
310 return false;
311 }
312
313 /**
314 * find first free variable in names from function
315 */
FindClosure(const py::object & o,const std::vector<std::string> & names,TracePtr * trace,bool strict,bool print)316 static py::object FindClosure(const py::object &o, const std::vector<std::string> &names, TracePtr *trace, bool strict,
317 bool print) {
318 PyObject *func = o.ptr();
319 if (PyMethod_Check(func)) {
320 func = PyMethod_GET_FUNCTION(func);
321 }
322 if (!PyFunction_Check(func)) {
323 return py::object();
324 }
325 PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func));
326 PyObject *closure = PyFunction_GET_CLOSURE(func);
327 Py_ssize_t i = PyTuple_GET_SIZE(co->co_freevars) - 1;
328 bool find = false;
329 for (; i >= 0 && !find; --i) {
330 std::string name = PyUnicode_AsUTF8(PyTuple_GET_ITEM(co->co_freevars, i));
331 find = std::find(names.begin(), names.end(), name) != names.end();
332 }
333 if (!find) {
334 return py::object();
335 }
336 Py_ssize_t idx = i + 1;
337 PyObject *cell = PyTuple_GET_ITEM(closure, idx);
338 PyObject *content = PyCell_GET(cell);
339 if (trace) {
340 TracePtr attr = CreateOpTrace(closure, LOAD_ATTR, 0, {*trace}, "", "__closure__", strict, print);
341 TracePtr cc = CreateOpTrace(cell, BINARY_SUBSCR, 0, {attr, std::make_shared<ConstTrace>(py::int_(idx).ptr(), -1)},
342 "", "", strict, print);
343 *trace = CreateOpTrace(content, LOAD_ATTR, 0, {cc}, "", "cell_contents", strict, print);
344 }
345 return py::cast<py::object>(content);
346 }
347
348 /**
349 * get decorated function from 'after_grad'
350 * \param after_grad _Grad.__call__.<locals>.after_grad
351 * \return decorated object
352 */
GetGradDecorated(const py::object & after_grad,TracePtr * trace,bool strict,bool print)353 static py::object GetGradDecorated(const py::object &after_grad, TracePtr *trace, bool strict, bool print) {
354 MS_ASSERT(PyFunction_Check(after_grad.ptr()));
355 py::object decorated = FindClosure(after_grad, {"fn", "fn_"}, trace, strict, print);
356 MS_EXCEPTION_IF_CHECK_FAIL(decorated.ptr() != nullptr, "can't find decorated function 'fn' or 'fn_' from " +
357 std::string(py::str(after_grad.ptr())));
358 if (!PyFunction_Check(decorated.ptr())) {
359 return decorated;
360 }
361 std::string decorated_name = PyUnicode_AsUTF8(reinterpret_cast<PyFunctionObject *>(decorated.ptr())->func_qualname);
362 if (decorated_name == "_Grad.__call__.<locals>.aux_fn") {
363 decorated = FindClosure(decorated, {"fn"}, trace, strict, print);
364 MS_EXCEPTION_IF_CHECK_FAIL(decorated.ptr() != nullptr, "can't find decorated function 'fn' from " + decorated_name);
365 }
366 return decorated;
367 }
368
/**
 * Drop the sens value from the arguments of a grad call.
 *
 * If `kwargs` contains "sens_param", it is deleted from `kwargs` in place and `args` is returned unchanged.
 * Otherwise the last positional argument is removed and a new args tuple is returned.
 */
static py::object DeleteGradSensArgs(const py::object &args, const py::object &kwargs) {
  const bool removed_from_kwargs =
    kwargs.ptr() != nullptr && PyDict_DelItemString(kwargs.ptr(), "sens_param") != -1;
  if (removed_from_kwargs) {
    return args;
  }
  // A failed delete leaves a Python error pending; discard it.
  PyErr_Clear();
  const Py_ssize_t keep = PyTuple_GET_SIZE(args.ptr()) - 1;
  return py::reinterpret_steal<py::object>(PyTuple_GetSlice(args.ptr(), 0, keep));
}
379
InferGradFuncResult(const py::object & func,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf)380 static AObject *InferGradFuncResult(const py::object &func, const py::object &args, const py::object &kwargs,
381 const GraphJitConfig &conf) {
382 auto jcr = getJitCompileResults(func.ptr());
383 *jcr->conf = conf;
384 return InferFuncResult(func, args, kwargs, conf, true);
385 }
386
/**
 * Use the function decorated by 'after_grad' and arguments of 'after_grad' when called to infer result.
 * If the function has no unsupported operation, merge the guard of inferred graph to caller graph.
 * else clear the mask of mindspore flag, avoid to capture this function call
 *
 * Every argument must have a concrete Python value. The sens argument is removed when `sens_param` is set
 * (see DeleteGradSensArgs). If inference fails or yields a type MindSpore does not support, the call is marked
 * kInlineInfer_Fail and the grad/shard/vmap flags are cleared from the callee. Otherwise the node's result comes
 * from the Python helper `infer_after_grad` when it returns a non-None value, else from the inferred result, and
 * the call is marked kInlineGraphSupportedByMS.
 *
 * \param call_node  the call of 'after_grad'
 * \param decorated  abstract object of the decorated function or Cell
 * \param sens_param whether the grad operation was created with sens_param
 * \param after_grad the 'after_grad' function object
 */
void HandleGradFuncCall(CallNode *call_node, AObject *decorated, bool sens_param, const py::object &after_grad) {
  // Flags cleared from the callee when this call cannot be handled.
  const int except_flag = AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc;
  ValueNode *grad_func_node = call_node->input(0);
  std::vector<py::object> stack_args;
  py::object func;
  py::object args;
  py::object kwargs;

  // prepare parameters
  // Stop at the first argument without a concrete Python value.
  bool param_ready = decorated->GetPyObject().ptr() != nullptr;
  for (size_t i = 1; param_ready && i < call_node->getInputs().size(); ++i) {
    AObject *tmp = call_node->input(i)->GetVobj();
    stack_args.emplace_back(tmp != nullptr ? tmp->GetPyObject() : py::object());
    param_ready = stack_args.back().ptr() != nullptr;
  }
  if (param_ready) {
    auto pair = Utils::PackCallStackArgs(stack_args, call_node->GetOpcode());
    args = pair.first;
    kwargs = pair.second;
    param_ready = pair.first.ptr() != nullptr;
  }
  if (!param_ready) {
    call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
    grad_func_node->GetVobj()->ClearMsFlag(except_flag);
    return;
  }
  if (sens_param) {
    args = DeleteGradSensArgs(args, kwargs);
  }

  // get callable
  if (decorated->GetType() != AObject::kTypeCell) {
    MS_EXCEPTION_IF_CHECK_FAIL(decorated->GetType() == AObject::kTypeFunction, "check grad input");
    func = decorated->GetPyObject();
  } else {
    // here get bound method.
    // NOTE(review): GetAttr is assumed to return non-null for a Cell's GraphBuilder::ID_construct attribute.
    func = decorated->GetAttr(GraphBuilder::ID_construct)->GetPyObject();
  }

  AObject *res = InferGradFuncResult(func, args, kwargs, call_node->GetGraph()->Config());
  if (res == nullptr || !res->IsMindSporeSupportedType()) {
    call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
    grad_func_node->GetVobj()->ClearMsFlag(except_flag);
    return;
  }
  py::object infer_after_grad = Utils::GetModuleAttr(kModuleName, "infer_after_grad", true, true);
  py::object result;
  try {
    // NOTE(review): only positional args are forwarded to infer_after_grad; kwargs are not passed.
    result = infer_after_grad(after_grad, args, res->GetPyObject());
  } catch (std::exception &e) {
    MS_LOG(WARNING) << "Error while infer_after_grad, error:" << e.what();
    PyErr_Clear();
  }
  // Fall back to the inferred result when the helper failed or returned None.
  if (result.ptr() != nullptr && result.ptr() != Py_None) {
    call_node->SetVobj(AObject::Convert(result));
  } else {
    call_node->SetVobj(res);
  }
  call_node->SetInlineReason(InlineReason::kInlineGraphSupportedByMS);
}
452
HandleGradFunc(CallNode * call_node,const py::object & after_grad,TracePtr * trace)453 static void HandleGradFunc(CallNode *call_node, const py::object &after_grad, TracePtr *trace) {
454 auto config = call_node->GetGraph()->Config();
455 bool strict = config.GetBoolConfig(GraphJitConfig::kStrictTrace);
456 bool print = config.GetBoolConfig(GraphJitConfig::kPrintGuard);
457 py::object decorated_func = GetGradDecorated(after_grad, trace, strict, print);
458 TracePtr ptr = *trace;
459 py::object grad = FindClosure(after_grad, {"grad_", "self"}, &ptr, strict, print);
460 MS_EXCEPTION_IF_CHECK_FAIL(grad.ptr() != nullptr,
461 "can't find 'grad_' object from " + std::string(py::str(after_grad.ptr())));
462 bool sens_param = grad.attr("sens_param").ptr() == Py_True;
463 MS_LOG(DEBUG) << "infer function 'after_grad', has sens_param " << (sens_param ? "True" : "False");
464
465 auto guard = call_node->GetGraph()->GetGuard()->GetGuard();
466 guard->GuardOn(*trace, mindspore::pijit::GuardLevel::GEqual);
467 if (config.GetBoolConfig(GraphJitConfig::kGuardDetachObject)) {
468 (*trace)->Detach();
469 }
470 call_node->SetSubGraph(nullptr);
471 HandleGradFuncCall(call_node, AObject::Convert(decorated_func), sens_param, after_grad);
472 }
473
InferGradFunc(CallNode * call_node,GraphBuilder * unused=nullptr)474 static bool InferGradFunc(CallNode *call_node, GraphBuilder *unused = nullptr) {
475 AObject *vo = call_node->input(0)->GetVobj();
476 vo->SetMsFlag(AObject::kMsFlagGradFunc);
477 py::object after_grad = vo->GetPyObject();
478 TracePtr trace = call_node->GetGraph()->TraceValueNode(call_node->input(0));
479 if (trace == nullptr) {
480 vo->ClearMsFlag(AObject::kMsFlagGradFunc);
481 call_node->SetSubGraph(nullptr);
482 return false;
483 }
484 HandleGradFunc(call_node, after_grad, &trace);
485 return false;
486 }
487
InferMSConstexpr(CallNode * call_node,GraphBuilder * unused=nullptr)488 static bool InferMSConstexpr(CallNode *call_node, GraphBuilder *unused = nullptr) {
489 Graph *g = call_node->GetSubGraph();
490 JustCallAndSetRes(call_node);
491
492 py::object cnst = call_node->GetVobj()->GetPyObject();
493 if (cnst.ptr() == nullptr) {
494 return false;
495 }
496 bool is_constexpr = CheckConstexpr(call_node->input(0)->GetVobj()->GetPyObject());
497 constexpr int max_guard_depth = 2;
498 if (is_constexpr || GuardConstCallNodeParam(call_node, g, max_guard_depth)) {
499 return CallNodeReturnConst(call_node, g, call_node->GetVobj());
500 }
501 return false;
502 }
503
GuardBuiltinFunc(CallNode * call_node)504 static bool GuardBuiltinFunc(CallNode *call_node) {
505 if (call_node->input(0)->GetVobj() == nullptr) {
506 return false;
507 }
508 PyObject *func = call_node->input(0)->GetVobj()->GetPyObject().ptr();
509 if (PyMethod_Check(func)) {
510 auto self = PyMethod_GET_SELF(func);
511 if (IsTensorType<true>(Py_TYPE(self)) && !CheckTensorDataInitialized(py::cast<py::object>(self))) {
512 // fake value
513 return false;
514 }
515 }
516 Graph *graph = call_node->GetGraph();
517 for (auto i : call_node->getInputs()) {
518 if (i->GetVobj() && i->GetVobj()->GetType() == AObject::kTypeTensor) {
519 AbstractTensor *tensor = static_cast<AbstractTensor *>(i->GetVobj());
520 if (!tensor->IsStubTensor() && !CheckTensorDataInitialized(tensor->GetPyObject())) {
521 // fake value
522 return false;
523 }
524 }
525 }
526 return graph->GuardValueNode(call_node);
527 }
528
GuardIsInstance(CallNode * call_node)529 static bool GuardIsInstance(CallNode *call_node) {
530 Graph *graph = call_node->GetGraph();
531 const auto &cnst = call_node->input(1)->GetConstantInfo();
532 if (cnst != nullptr && cnst->type() != nullptr) {
533 constexpr int second_arg = 2;
534 auto success = graph->GuardValueNode(call_node->input(second_arg));
535 if (!success && (call_node->GetGraph()->Config().getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0)) {
536 TracePtr tr = graph->TraceValueNode(call_node->input(second_arg));
537 if (tr == nullptr) {
538 return true;
539 }
540 }
541 return success;
542 }
543 auto success = graph->GuardValueNode(call_node);
544 if (!success && (call_node->GetGraph()->Config().getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0)) {
545 TracePtr tr = graph->TraceValueNode(call_node);
546 if (tr == nullptr) {
547 return true;
548 }
549 }
550 return success;
551 }
552
InferBuiltinFuncOrMethod(CallNode * call_node,GraphBuilder * unused=nullptr)553 bool InferBuiltinFuncOrMethod(CallNode *call_node, GraphBuilder *unused = nullptr) {
554 Graph *sub_graph = call_node->GetSubGraph();
555 (void)JustCallAndSetRes(call_node);
556 ConstantInfo::CollectBuiltinFuncConstantInfo(call_node);
557 if (call_node->IsConstantValue()) {
558 return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
559 }
560 if (call_node->GetVobj() == nullptr || call_node->GetVobj()->GetPyObject().ptr() == nullptr) {
561 return false;
562 }
563
564 bool guard_success = false;
565 std::string name = GetFuncName(call_node->input(0)->GetVobj()->GetPyObject());
566 if (name == "isinstance") {
567 guard_success = GuardIsInstance(call_node);
568 } else {
569 guard_success = GuardBuiltinFunc(call_node);
570 }
571 if (guard_success) {
572 return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
573 }
574 return false;
575 }
576
InferTensorAsType(CallNode * call_node,GraphBuilder * unused=nullptr)577 static bool InferTensorAsType(CallNode *call_node, GraphBuilder *unused = nullptr) {
578 ValueNode *self_node = GetBoundSelf(call_node);
579 bool is_not_method = call_node->input(0)->GetVobj()->GetType() != AObject::kTypeBoundMethod;
580 ValueNode *dtype_node = call_node->input(1 + is_not_method);
581
582 Graph *sub_graph = call_node->GetSubGraph();
583
584 py::object prim_cast = Utils::GetModuleAttr("mindspore.ops.functional", "cast", false, true);
585
586 PyTypeObject *tp = Py_TYPE(prim_cast.ptr());
587 std::stringstream s;
588 s << (tp->tp_name ? tp->tp_name : "<unnamed>") << "<" << prim_cast.ptr() << ">";
589
590 ValueNode *prim_node = sub_graph->NewValueNode(AObject::Convert(prim_cast), LOAD_CONST, -1, {});
591
592 if (dtype_node->GetVobj()->GetType() == AObject::kTypeString &&
593 dtype_node->GetVobj()->GetPyObject().ptr() != nullptr) {
594 auto dtypeStr = py::cast<std::string>(dtype_node->GetVobj()->GetPyObject());
595 std::vector<std::string> under_line_dtype = {"bool", "int", "float", "list", "tuple"};
596 if (std::find(under_line_dtype.begin(), under_line_dtype.end(), dtypeStr) != under_line_dtype.end()) {
597 dtypeStr = dtypeStr + "_";
598 }
599 auto dtype_obj = Utils::GetModuleAttr("mindspore.common.dtype", dtypeStr, false, true);
600 if (dtype_obj.ptr() != nullptr) {
601 dtype_node = sub_graph->NewValueNode(AObject::Convert(dtype_obj), LOAD_CONST, -1, {});
602 }
603 }
604
605 std::vector<ValueNode *> cast_args = {prim_node, self_node, dtype_node};
606 CallNode *ret_node = sub_graph->NewCallNode(CALL_FUNCTION, cast_args.size() - 1, cast_args);
607 ret_node->SetGraph(sub_graph);
608 (void)InferPrimitive(ret_node);
609
610 sub_graph->GetTracedNodes().push_back(prim_node);
611 sub_graph->GetTracedNodes().push_back(ret_node);
612 sub_graph->SetRetVal(ret_node);
613
614 call_node->SetSubGraph(sub_graph);
615 call_node->SetVobj(ret_node->GetVobj());
616 call_node->SetInlineReason(InlineReason::kInline);
617 return true;
618 }
619
RecordSideEffectCallNode(Graph * graph,CallNode * call_node,SideEffect::Type type,bool trace_flag)620 static void RecordSideEffectCallNode(Graph *graph, CallNode *call_node, SideEffect::Type type, bool trace_flag) {
621 const auto &side_effect = graph->GetSideEffect();
622 ValueNode *side_effect_node;
623 if (trace_flag) {
624 side_effect_node = call_node;
625 } else {
626 side_effect_node = graph->NewCallNode(call_node->GetOpcode(), call_node->GetOparg(), call_node->getInputs());
627 side_effect_node->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
628 graph->GetTracedNodes().push_back(side_effect_node);
629 }
630 side_effect->Record(side_effect_node, type);
631 }
632
/**
 * Handle `list.append(x)` by rebuilding the list instead of mutating it.
 *
 * The elements of the old list are unpacked on the parent's stack, `x` is pushed, and a BUILD_LIST creates the
 * new list node that replaces the old one in the parent frame. The call itself is inlined as a sub-graph that
 * returns None. A side effect is recorded unless the list is a fresh BUILD_LIST that is not referenced elsewhere
 * and is not appended to itself.
 *
 * \param call_node the append call
 * \param parent    builder of the calling graph; its frame is modified
 * \return true if the call was handled, false if `self` is not found or cannot be unpacked
 */
static bool InferListAppend(CallNode *call_node, GraphBuilder *parent) {
  call_node->SetSubGraph(nullptr);

  // check is supported type and get arguments
  bool is_method_descriptor = false;
  ValueNode *self = GetSelfFromListAppendCall(call_node, &is_method_descriptor);
  if (self == nullptr) {
    return false;
  }
  // Called through the method descriptor, the element is one input further right.
  ValueNode *new_element = call_node->input(1 + is_method_descriptor);

  // transform to "new_list = [old_list[0], old_list[1]..., new_element]"
  // The stack size difference gives the number of elements pushed for BUILD_LIST.
  int size = parent->frame().GetStacks().size();
  // NOTE(review): if UnpackElements fails, confirm it leaves the parent stack unchanged.
  if (!parent->UnpackElements(self)) {
    return false;
  }
  parent->push(new_element);
  size = parent->frame().GetStacks().size() - size;
  parent->DoBuildOp({BUILD_LIST, size});
  auto new_node = parent->pop();
  auto old_node = self;

  // constant fold and set node info
  // list.append returns None.
  auto builder = GraphBuilder::Creator(parent->root(), parent, nullptr, nullptr, parent->trace_flag());
  Graph *sub_graph = builder->GetGraph();
  builder->DoLoadConst({LOAD_CONST, 0, py::object(py::none())});
  builder->DoReturn({RETURN_VALUE, 0});

  call_node->SetSubGraph(sub_graph);
  call_node->SetVobj(sub_graph->GetRetVal()->GetVobj());
  call_node->SetInlineReason(InlineReason::kInline);

  // update frame status and record side-effect
  bool is_referenced = false;
  parent->ReplaceAll(old_node, new_node, &is_referenced);
  const auto &replace_map = parent->GetGraph()->GetSideEffect()->data()->modified_and_replaced_map();
  // A list built in this graph and never replaced before is a local temporary.
  bool is_new_var = self->GetOpcode() == BUILD_LIST && replace_map.find(self) == replace_map.end();
  if (!is_new_var || is_referenced || self == new_element) {
    parent->GetGraph()->GetSideEffect()->data()->RecordModifiedAndReplacedNode(old_node, new_node);
    RecordSideEffectCallNode(parent->GetGraph(), call_node, SideEffect::kListAppend, parent->trace_flag());
  }
  return true;
}
676
InferDictPop(CallNode * call_node,GraphBuilder * parent)677 static bool InferDictPop(CallNode *call_node, GraphBuilder *parent) {
678 call_node->SetSubGraph(nullptr);
679
680 bool is_method_descriptor = false;
681 ValueNode *self = GetSelfFromListAppendCall(call_node, &is_method_descriptor);
682 if (self == nullptr) {
683 return false;
684 }
685 // guard dict key and convert to constant key map
686 if (!parent->GetGraph()->GuardValueNode(self)) {
687 return false;
688 }
689
690 ValueNode *dict_node = self;
691 ValueNode *key_node = call_node->input(1 + is_method_descriptor);
692 ValueNode *default_node = call_node->getInputs().size() > (kDictPopParamsNum + is_method_descriptor)
693 ? call_node->input(kDictPopParamsNum + is_method_descriptor)
694 : nullptr;
695 // get key from dict
696 py::object dict = dict_node->GetVobj()->GetPyObject();
697 py::object key = key_node->GetVobj()->GetPyObject();
698 MS_EXCEPTION_IF_CHECK_FAIL(PyDict_Check(dict.ptr()), "for dict.pop, first parameter must be a dict");
699 py::object value = py::reinterpret_borrow<py::object>(PyDict_GetItem(dict.ptr(), key.ptr()));
700 if (value.ptr() == nullptr) {
701 if (default_node == nullptr) {
702 return false; // key error
703 }
704 value = default_node->GetVobj()->GetPyObject();
705 }
706
707 // transform to "new_map = {key:old_map[key]...}"
708 ValueNode *old_node = dict_node;
709 ValueNode *new_node = parent->TransformDictSetItem(dict_node, key_node, nullptr, default_node != nullptr);
710 if (new_node == nullptr) {
711 return false;
712 }
713
714 // constant fold and set node info
715 auto builder = GraphBuilder::Creator(parent->root(), parent, nullptr, nullptr, parent->trace_flag());
716 Graph *sub_graph = builder->GetGraph();
717 builder->DoLoadConst({LOAD_CONST, 0, value});
718 builder->DoReturn({RETURN_VALUE, 0});
719
720 call_node->SetSubGraph(sub_graph);
721 call_node->SetVobj(sub_graph->GetRetVal()->GetVobj());
722 call_node->SetInlineReason(InlineReason::kInline);
723
724 // update frame status and record side-effect
725 bool is_referenced = false;
726 parent->ReplaceAll(old_node, new_node, &is_referenced);
727 const auto &replace_map = parent->GetGraph()->GetSideEffect()->data()->modified_and_replaced_map();
728 bool is_new_var = self->GetOpcode() == BUILD_MAP && replace_map.find(self) == replace_map.end();
729 if (!is_new_var || is_referenced) {
730 parent->GetGraph()->GetSideEffect()->data()->RecordModifiedAndReplacedNode(old_node, new_node);
731 RecordSideEffectCallNode(parent->GetGraph(), call_node, SideEffect::kDictPop, parent->trace_flag());
732 }
733 return true;
734 }
735
SetForbiddenFuncInfo(CallNode * call_node,GraphBuilder * unused=nullptr)736 static bool SetForbiddenFuncInfo(CallNode *call_node, GraphBuilder *unused = nullptr) {
737 SetCallResType<AObject::kTypeAnyValue>(call_node);
738 call_node->SetInlineReason(InlineReason::kInlineFunc_Type_Unsupported);
739 return false;
740 }
741
742 template <bool force_ms_api>
InferMsApiFunc(CallNode * call_node,GraphBuilder * unused=nullptr)743 bool InferMsApiFunc(CallNode *call_node, GraphBuilder *unused = nullptr) {
744 Graph *sub_graph = call_node->GetSubGraph();
745 SetCallResType<AObject::kTypeAnyValue>(call_node);
746 if (call_node->input(0)->GetVobj() == nullptr || call_node->input(0)->GetVobj()->GetPyObject().ptr() == nullptr) {
747 return false;
748 }
749
750 py::object callable_object = call_node->input(0)->GetVobj()->GetPyObject();
751 std::vector<py::object> args;
752 std::transform(call_node->getInputs().begin() + 1, call_node->getInputs().end(), std::back_inserter(args),
753 [](ValueNode *n) { return n->GetVobj() ? n->GetVobj()->GetPyObject() : py::object(); });
754 auto pair = Utils::PackCallStackArgs(args, call_node->GetOpcode());
755 if (pair.first.ptr() == nullptr) {
756 return false;
757 }
758 PyTypeObject *callable_type = Py_TYPE(callable_object.ptr());
759
760 AObject *info;
761
762 bool enable_func_graph_eval = force_ms_api || kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kEnableMsApiInfer);
763 if (enable_func_graph_eval) {
764 py::object res = EvalMSAPIValue(callable_object, pair.first, pair.second);
765 info = AObject::Convert(res);
766 } else if (IsPrimitiveType<true>(callable_type) || IsPrimitiveFunctionType<true>(callable_type)) {
767 call_node->SetSubGraph(sub_graph);
768 return InferPrimitive(call_node);
769 } else {
770 info = InferFuncResult(callable_object, pair.first, pair.second, call_node->GetGraph()->Config(), true);
771 }
772
773 call_node->SetVobj(info);
774 if (info->GetPyObject().ptr() != nullptr) {
775 call_node->input(0)->GetVobj()->SetMsFlag(AObject::kMsFlagStandardFunc);
776 }
777 if (call_node->IsConstantValue()) {
778 return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
779 }
780 return false;
781 }
782
InferMappingGet(CallNode * call_node,GraphBuilder * unused=nullptr)783 bool InferMappingGet(CallNode *call_node, GraphBuilder *unused = nullptr) {
784 if (call_node->getInputs().size() == BoundMethodInputSize &&
785 call_node->input(0)->GetVobj()->GetType() == AbstractObjectBase::kTypeBoundMethod) {
786 auto func_node = call_node->input(0);
787 auto self = func_node->input(0);
788 auto param_node = call_node->input(1);
789 if (self->IsConstantValue() && param_node->IsConstantValue()) {
790 Graph *g = call_node->GetSubGraph();
791 JustCallAndSetRes(call_node);
792 return CallNodeReturnConst(call_node, g, call_node->GetVobj());
793 }
794 }
795 SetCallResType<AObject::kTypeAnyValue>(call_node);
796 call_node->SetInlineReason(InlineReason::kInlineFunc_Type_Unsupported);
797 return false;
798 }
799
// Keys of the special-function white list. The numeric values form a contract with Python: GetFuncKeyMap reads
// `_func_map` from mindspore._extends.pijit.pijit_func_white_list and casts each int value to FuncKey, rejecting
// values >= FUNC_KEY_COUNT. The values are spelled out so that nobody renumbers an existing entry by accident.
enum FuncKey {
  FUNC_KEY_EMPTY = 0,                 // ""
  FUNC_KEY_PIJIT_CONSTEXPR = 1,       // "pijit.constexpr"
  FUNC_KEY_PIJIT_FORBIDDEN = 2,       // "pijit.forbidden"
  FUNC_KEY_BUILTIN_FUNC = 3,          // "builtin.func"
  FUNC_KEY_LIST_APPEND = 4,           // "list.append"
  FUNC_KEY_DICT_POP = 5,              // "dict.pop"
  FUNC_KEY_PRIMITIVE = 6,             // "mindspore._c_expression.Primitive_"
  FUNC_KEY_META_FUNCG_RAPH = 7,       // "mindspore._c_expression.MetaFuncGraph_"
  FUNC_KEY_PSJIT_CODE = 8,            // "mindspore.common.api.jit.<locals>.staging_specialize"
  FUNC_KEY_CONSTEXPR = 9,             // "mindspore.ops.primitive.constexpr"
  FUNC_KEY_PRIMEXPR = 10,             // "mindspore.ops.primitive._primexpr"
  FUNC_KEY_GET_CACHE_PRIM = 11,       // "mindspore.ops._primitive_cache._get_cache_prim"
  FUNC_KEY_REGISTRY_GET = 12,         // "mindspore.common._register_for_tensor.Registry.get"
  FUNC_KEY_TENSOR_ASTYPE = 13,        // "mindspore.common.tensor.Tensor.astype"
  FUNC_KEY_GRAD_OPERATIONS_CODE = 14, // "mindspore.ops.composite.base._Grad.__call__.<locals>.after_grad"
  FUNC_KEY_PSJIT_CONVERTMAP = 15,     // "mindspore._extends.parse.resources.convert_object_map"
  FUNC_KEY_GRAPH_CELL = 16,           // "mindspore.nn.cell.GraphCell"
  FUNC_KEY_MS_API = 17,               // mindspore api
  FUNC_KEY_MAPPING_GET = 18,          // mapping get
  FUNC_KEY_COUNT = 19,                // number of valid keys; not a key itself
};
// Forward declaration: FindInferFunc (below) needs FindFuncKey, which is defined after the key finders.
static FuncKey FindFuncKey(const py::object &callable);
823
824 static const std::unordered_map<FuncKey, InferFunc> infer_func_map = {
825 {FUNC_KEY_PIJIT_CONSTEXPR, JustCallAndSetRes},
826 {FUNC_KEY_PIJIT_FORBIDDEN, SetForbiddenFuncInfo},
827 {FUNC_KEY_BUILTIN_FUNC, InferBuiltinFuncOrMethod},
828 {FUNC_KEY_LIST_APPEND, InferListAppend},
829 {FUNC_KEY_DICT_POP, InferDictPop},
830 {FUNC_KEY_PRIMITIVE, InferPrimitive},
831 {FUNC_KEY_META_FUNCG_RAPH, InferMetaFunc},
832 {FUNC_KEY_PSJIT_CODE, InferMsApiFunc<true>},
833 {FUNC_KEY_CONSTEXPR, InferMSConstexpr},
834 {FUNC_KEY_PRIMEXPR, InferMSConstexpr},
835 {FUNC_KEY_GET_CACHE_PRIM, InferGetCachePrim},
836 {FUNC_KEY_REGISTRY_GET, InferRegistryGet},
837 {FUNC_KEY_TENSOR_ASTYPE, InferTensorAsType},
838 {FUNC_KEY_GRAD_OPERATIONS_CODE, InferGradFunc},
839 {FUNC_KEY_PSJIT_CONVERTMAP, InferConvertMap},
840 {FUNC_KEY_GRAPH_CELL, SetCallResType<AObject::kTypeTensor>},
841 {FUNC_KEY_MS_API, InferMsApiFunc<false>},
842 {FUNC_KEY_MAPPING_GET, InferMappingGet},
843 };
844
845 static const std::unordered_map<FuncKey, InferFunc> mind_infer_func_map = {
846 {FUNC_KEY_PIJIT_CONSTEXPR, JustCallAndSetRes}, {FUNC_KEY_PIJIT_FORBIDDEN, SetForbiddenFuncInfo},
847 {FUNC_KEY_LIST_APPEND, InferListAppend}, {FUNC_KEY_DICT_POP, InferDictPop},
848 {FUNC_KEY_BUILTIN_FUNC, InferBuiltinFuncOrMethod}, {FUNC_KEY_PSJIT_CODE, SetCallResType<AObject::kTypeTensor>},
849 {FUNC_KEY_GET_CACHE_PRIM, InferGetCachePrim}, {FUNC_KEY_REGISTRY_GET, InferRegistryGet},
850 };
851
FindInferFunc(const py::object & callable,bool trace_flag)852 InferFunc FindInferFunc(const py::object &callable, bool trace_flag) {
853 FuncKey k = FindFuncKey(callable);
854 const auto &map = trace_flag ? mind_infer_func_map : infer_func_map;
855 auto iter = map.find(k);
856 if (iter != map.end()) {
857 return iter->second;
858 }
859 return nullptr;
860 }
861
GetFuncKeyMap()862 static const std::unordered_map<size_t, FuncKey> &GetFuncKeyMap() {
863 static std::unordered_map<size_t, FuncKey> map = {};
864 if (!map.empty()) {
865 return map;
866 }
867 py::object func_map = Utils::GetModuleAttr(kModuleName, kFuncMapName, true, true);
868 MS_EXCEPTION_IF_CHECK_FAIL(PyDict_CheckExact(func_map.ptr()), "white list func map must be 'dict[int, str]'");
869 PyObject *key;
870 PyObject *value;
871 Py_ssize_t pos = 0;
872 while (PyDict_Next(func_map.ptr(), &pos, &key, &value)) {
873 MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(key), "white list func map key must be 'int'");
874 MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(value), "white list func map value must be 'int'");
875 size_t k = (PyLong_AsSize_t(value));
876 MS_EXCEPTION_IF_CHECK_FAIL(k < FUNC_KEY_COUNT, "white list func map got error FuncKey " + std::to_string(k));
877 map[PyLong_AsSize_t(key)] = static_cast<FuncKey>(k);
878 }
879 return map;
880 }
881
KeyFinderFuncId(const py::object & callable)882 static FuncKey KeyFinderFuncId(const py::object &callable) {
883 auto iter = GetFuncKeyMap().find(FunctionId(callable));
884 return iter != GetFuncKeyMap().end() ? iter->second : FUNC_KEY_EMPTY;
885 }
886
KeyFinderFuncCodeId(const py::object & callable)887 static FuncKey KeyFinderFuncCodeId(const py::object &callable) {
888 PyObject *func = callable.ptr();
889 py::object handle;
890 if (IsCellType<true>(Py_TYPE(func))) {
891 handle = callable.attr("construct");
892 func = handle.ptr();
893 }
894 if (PyMethod_Check(func)) {
895 func = PyMethod_GET_FUNCTION(func);
896 }
897 if (PyFunction_Check(func)) {
898 func = PyFunction_GET_CODE(func);
899 }
900 if (!PyCode_Check(func)) {
901 return FUNC_KEY_EMPTY;
902 }
903 auto iter = GetFuncKeyMap().find(reinterpret_cast<size_t>(func));
904 return iter != GetFuncKeyMap().end() ? iter->second : FUNC_KEY_EMPTY;
905 }
906
KeyFinderPrimitive(const py::object & callable)907 static FuncKey KeyFinderPrimitive(const py::object &callable) {
908 PyTypeObject *type_object = Py_TYPE(callable.ptr());
909 bool convert_to_prim = IsPrimitiveType<true>(type_object) || IsPrimitiveFunctionType<true>(type_object);
910 if (!convert_to_prim) {
911 return FUNC_KEY_EMPTY;
912 }
913 py::object func = py::getattr(reinterpret_cast<PyObject *>(type_object), kSlotCallName, nullptr);
914 size_t id;
915 if (func.ptr() == nullptr) {
916 // primitive not defined slot __call__, use it self as id
917 id = reinterpret_cast<size_t>(callable.ptr());
918 } else if (PyFunction_Check(func.ptr())) {
919 // primitive defined python function __call__
920 id = reinterpret_cast<size_t>(PyFunction_GET_CODE(func.ptr()));
921 } else {
922 // primitive defined cpp function __call__
923 id = FunctionId(func);
924 }
925 // first, find map to check special primitive.
926 auto iter = GetFuncKeyMap().find(id);
927 return iter != GetFuncKeyMap().end() ? iter->second : FUNC_KEY_PRIMITIVE;
928 }
929
KeyFinderMetaFunc(const py::object & callable)930 static FuncKey KeyFinderMetaFunc(const py::object &callable) {
931 PyTypeObject *type_object = reinterpret_cast<PyTypeObject *>(callable.ptr());
932 type_object = PyType_CheckExact(type_object) ? type_object : Py_TYPE(type_object);
933 return IsMetaFuncGraphType<true>(type_object) ? FUNC_KEY_META_FUNCG_RAPH : FUNC_KEY_EMPTY;
934 }
935
KeyFinderGraphCell(const py::object & callable)936 static FuncKey KeyFinderGraphCell(const py::object &callable) {
937 static size_t id = 0;
938 if (id == 0) {
939 py::object type = Utils::GetModuleAttr("mindspore.nn.cell", "GraphCell", false, true);
940 id = reinterpret_cast<size_t>(type.ptr());
941 }
942 PyTypeObject *type_object = reinterpret_cast<PyTypeObject *>(callable.ptr());
943 type_object = PyType_CheckExact(type_object) ? type_object : Py_TYPE(type_object);
944 size_t cur_id = reinterpret_cast<size_t>(type_object);
945 return cur_id == id ? FUNC_KEY_GRAPH_CELL : FUNC_KEY_EMPTY;
946 }
947
KeyFinderSkipModule(const py::object & callable)948 static FuncKey KeyFinderSkipModule(const py::object &callable) {
949 const auto &modules = kPIJitConfigDefault.allowed_inline_modules();
950 std::string mod = GetTopModule(callable);
951 if (modules.find(mod) != modules.end()) {
952 return FUNC_KEY_EMPTY;
953 }
954
955 PyObject *func_info = callable.ptr();
956 if (PyMethod_Check(func_info)) {
957 func_info = PyMethod_GET_FUNCTION(func_info);
958 }
959 if (!PyFunction_Check(func_info) && !PyCFunction_Check(func_info) && !PyType_Check(func_info)) {
960 func_info = reinterpret_cast<PyObject *>(Py_TYPE(func_info));
961 }
962 MS_LOG(DEBUG) << "func " << std::string(py::str(func_info)) << " is forbidden to analyze, module is " << mod;
963 return FUNC_KEY_PIJIT_FORBIDDEN;
964 }
965
FindFuncKey(const py::object & callable)966 static FuncKey FindFuncKey(const py::object &callable) {
967 if (callable.ptr() == nullptr || !PyCallable_Check(callable.ptr())) {
968 return FUNC_KEY_EMPTY;
969 }
970 std::vector<FuncKey (*)(const py::object &callable)> finders = {
971 KeyFinderFuncId, KeyFinderFuncCodeId, KeyFinderPrimitive,
972 KeyFinderMetaFunc, KeyFinderGraphCell, KeyFinderSkipModule, // must be last for check modules
973 };
974 FuncKey res = FUNC_KEY_EMPTY;
975 for (auto iter = finders.begin(), end = finders.end(); iter != end && res == FUNC_KEY_EMPTY; ++iter) {
976 res = (*iter)(callable);
977 }
978 return res;
979 }
980
CheckJitConstexpr(const py::object & func)981 bool CheckJitConstexpr(const py::object &func) {
982 if (func.ptr() == nullptr) {
983 return false;
984 }
985 FuncKey k = KeyFinderFuncId(func);
986 return k == FUNC_KEY_PIJIT_CONSTEXPR;
987 }
988
CheckConstexpr(const py::object & func)989 static bool CheckConstexpr(const py::object &func) { return KeyFinderPrimitive(func) == FUNC_KEY_CONSTEXPR; }
990
CheckMSConstexpr(const py::object & func)991 bool CheckMSConstexpr(const py::object &func) {
992 if (func.ptr() == nullptr) {
993 return false;
994 }
995 FuncKey k = KeyFinderPrimitive(func);
996 return k == FUNC_KEY_CONSTEXPR || k == FUNC_KEY_PRIMEXPR;
997 }
998
CheckBuiltinFuncOrMethod(const py::object & func)999 bool CheckBuiltinFuncOrMethod(const py::object &func) {
1000 if (func.ptr() == nullptr) {
1001 return false;
1002 }
1003 FuncKey k = KeyFinderFuncId(func);
1004 return k == FUNC_KEY_BUILTIN_FUNC;
1005 }
1006
1007 } // namespace pijit
1008 } // namespace mindspore
1009