• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /**
3  * Copyright 2023 Huawei Technologies Co., Ltd
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #include "pipeline/jit/pi/graph_capture/special_func_infer.h"
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 #include "pipeline/jit/pi/common.h"
27 #include "pipeline/jit/pi/external.h"
28 #include "pipeline/jit/pi/graph_capture/graph_build.h"
29 #include "pipeline/jit/pi/graph_guard/infer.h"
30 #include "pipeline/jit/pi/graph_capture/side_effect.h"
31 
32 namespace mindspore {
33 namespace pijit {
// Helpers implemented outside this file (declared extern here).
// Presumably returns the node bound as `self` for a method-style call; used by InferTensorAsType below.
extern ValueNode *GetBoundSelf(CallNode *call_node);
// Presumably reports a guard failure for `node`; `msg` names the failing site (see GuardConstCallNodeParam).
extern void LogGuardFailed(ValueNode *node, const GraphJitConfig &conf, const std::string &msg);
// Infers the result of calling `func`; arguments are abstract stack values interpreted according to `opcode`.
extern AObject *InferFuncResult(const py::object &func, const std::vector<AObject *> &stack_args, int opcode,
                                const GraphJitConfig &conf, bool clear_guard);
// Infers the result of calling `func` with packed positional `args` and keyword `kwargs`.
extern AObject *InferFuncResult(const py::object &func, const py::object &args, const py::object &kwargs,
                                const GraphJitConfig &conf, bool clear_guard);

// Python module of the pijit function white list; HandleGradFuncCall loads "infer_after_grad" from it.
constexpr const char *kModuleName = "mindspore._extends.pijit.pijit_func_white_list";
// Presumably the function-map attribute of kModuleName; confirm at its use site.
constexpr const char *kFuncMapName = "_func_map";
// Name of the Python `__call__` slot method.
constexpr const char *kSlotCallName = "__call__";
// Input index of dict.pop's optional `default` argument (shifted by one for unbound method-descriptor calls).
constexpr const size_t kDictPopParamsNum = 2;
// NOTE(review): presumably the input count of a bound-method call node; confirm at its use site.
constexpr const size_t BoundMethodInputSize = 2;

// Declared static, so its definition lives later in this translation unit.
static bool CheckConstexpr(const py::object &func);
48 
49 template <AObject::Type type>
SetCallResType(CallNode * call_node,GraphBuilder * unused=nullptr)50 static bool SetCallResType(CallNode *call_node, GraphBuilder *unused = nullptr) {
51   call_node->SetVobj(AObject::MakeAObject(type));
52   call_node->SetSubGraph(nullptr);
53   return false;
54 }
55 
JustCallAndSetRes(CallNode * call_node,GraphBuilder * unused)56 bool JustCallAndSetRes(CallNode *call_node, GraphBuilder *unused) {
57   py::object func = call_node->input(0)->GetVobj()->GetPyObject();
58   if (func.ptr() == nullptr) {
59     return SetCallResType<AObject::kTypeAnyValue>(call_node);
60   }
61 
62   std::vector<py::object> args;
63   std::transform(call_node->getInputs().begin() + 1, call_node->getInputs().end(), std::back_inserter(args),
64                  [](ValueNode *n) { return n->GetVobj() ? n->GetVobj()->GetPyObject() : py::object(); });
65   auto pair = Utils::PackCallStackArgs(args, call_node->GetOpcode());
66   if (pair.first.ptr() == nullptr) {
67     return SetCallResType<AObject::kTypeAnyValue>(call_node);
68   }
69 
70   pi_jit_disable();
71   PyObject *value = PyObject_Call(func.ptr(), pair.first.ptr(), pair.second.ptr());
72   if (PyErr_Occurred()) {
73     MS_LOG(INFO) << "got an error " << py::error_already_set().what() << " at call the "
74                  << std::string(py::str(func.ptr()));
75     PyErr_Clear();
76   }
77   pi_jit_enable();
78 
79   call_node->SetVobj(AObject::Convert(value));
80   call_node->SetSubGraph(nullptr);
81   Py_XDECREF(value);
82   return false;
83 }
84 
CallNodeReturnConst(CallNode * call_node,Graph * sub_graph,AObject * value)85 static bool CallNodeReturnConst(CallNode *call_node, Graph *sub_graph, AObject *value) {
86   PyObject *cnst = value->GetPyObject().ptr();
87   MS_EXCEPTION_IF_NULL(cnst);
88 
89   ValueNode *ret_node = sub_graph->NewValueNode(value, LOAD_CONST, -1, {});
90   call_node->SetSubGraph(sub_graph);
91   ret_node->SetGraph(call_node->GetGraph());
92 
93   sub_graph->SetRetVal(ret_node);
94   call_node->SetInlineReason(InlineReason::kInline);
95   return true;
96 }
97 
GuardConstCallNodeParam(CallNode * call_node,Graph * sub_graph,int max_guard_depth)98 bool GuardConstCallNodeParam(CallNode *call_node, Graph *sub_graph, int max_guard_depth) {
99   std::vector<std::pair<TracePtr, GuardLevel>> traces;
100   for (auto i : call_node->getInputs()) {
101     if (i->IsConstantValue()) {
102       continue;
103     }
104     AObject::Type type = i->GetVobj() ? i->GetVobj()->GetType() : AObject::kTypeAnyValue;
105     if (type == AObject::kTypeAnyValue) {
106       return false;
107     }
108     TracePtr tr = sub_graph->TraceValueNode(i, max_guard_depth);
109     if (tr == nullptr) {
110       if (static_cast<size_t>(max_guard_depth) >= INT_MAX) {
111         LogGuardFailed(i, sub_graph->Config(), "GuardConstCannNodeParm failed");
112       }
113       return false;
114     }
115     GuardLevel level = GuardLevel::GEqual;
116     if (type == AObject::kTypeTensor) {
117       if (i->GetOpcode() == LOAD_GLOBAL) {
118         level = GuardLevel::GId;  // only guard global tensor
119       } else {
120         level = GuardLevel::GDeduce;
121       }
122     }
123     traces.push_back({tr, level});
124   }
125 
126   const auto &guard = sub_graph->GetGuard()->GetGuard();
127   guard->Backup();
128   for (const auto &i : traces) {
129     if (!guard->GuardOn(i.first, i.second)) {
130       guard->Rollback();
131       return false;
132     }
133   }
134   guard->Pop();
135   return true;
136 }
137 
InferConvertMap(CallNode * call_node,GraphBuilder * unused=nullptr)138 static bool InferConvertMap(CallNode *call_node, GraphBuilder *unused = nullptr) {
139   AObject *func_info = call_node->input(0)->GetVobj();
140   func_info->SetMsFlag(AObject::kMsFlagStandardFunc);
141   py::object func = func_info->GetPyObject();
142   py::object tmp = Utils::GetModuleAttr("mindspore._extends.parse.resources", "convert_object_map");
143   auto dict_obj = py::cast<py::dict>(tmp);
144   auto infer_obj = dict_obj[func];
145   AObject *res = nullptr;
146   call_node->SetSubGraph(nullptr);
147   SetCallResType<AObject::kTypeTensor>(call_node);
148   if (PyFunction_Check(infer_obj.ptr())) {
149     MS_LOG(DEBUG) << "infer function " << std::string(py::str(PyFunction_GET_CODE(infer_obj.ptr())));
150     int op = call_node->GetOpcode();
151     const auto &conf = call_node->GetGraph()->Config();
152     std::vector<AObject *> args;
153     std::transform(call_node->getInputs().begin() + 1, call_node->getInputs().end(), std::back_inserter(args),
154                    [](ValueNode *n) { return n->GetVobj(); });
155     res = InferFuncResult(func, {args.begin() + 1, args.end()}, op, conf, true);
156   } else if (IsPrimitiveType<true>(Py_TYPE(infer_obj.ptr()))) {
157     MS_LOG(DEBUG) << "infer primitive " << std::string(py::str(infer_obj));
158     std::vector<PyObject *> list;
159     bool infer_fail = false;
160     for (size_t i = 1; !infer_fail && i < call_node->getInputs().size(); i++) {
161       AObject *p = call_node->input(i)->GetVobj();
162       PyObject *o = p ? p->GetPyObject().ptr() : nullptr;
163       list.push_back(o);
164       infer_fail = o == nullptr;
165     }
166     if (infer_fail) {
167       return false;
168     }
169     auto inst = mindspore::pijit::InferEngine::GetInstance();
170     bool is_abstract = false;
171     PyObject *ret = inst->InferPrimitive(infer_obj.ptr(), list, &is_abstract);
172     if (ret == nullptr) {
173       return false;
174     }
175     AObject::Type type = AObject::GetPyType(ret);
176     res = is_abstract && type != AObject::kTypeTensor ? AObject::MakeAObject(type) : AObject::Convert(ret);
177     Py_DECREF(ret);
178   } else {
179     return false;
180   }
181   if (res) {
182     call_node->SetVobj(res);
183   }
184   return false;
185 }
186 
InferGetCachePrim(CallNode * n,GraphBuilder * unused=nullptr)187 static bool InferGetCachePrim(CallNode *n, GraphBuilder *unused = nullptr) {
188   // just return the first parameter of _get_cache_prim
189   Graph *g = n->GetSubGraph();
190   n->SetVobj(n->input(1)->GetVobj());
191   g->SetRetVal(n->input(1));
192   return true;
193 }
194 
InferRegistryGet(CallNode * call_node,GraphBuilder * unused=nullptr)195 static bool InferRegistryGet(CallNode *call_node, GraphBuilder *unused = nullptr) {
196   Graph *g = call_node->GetSubGraph();
197   JustCallAndSetRes(call_node);
198 
199   py::object func = call_node->GetVobj()->GetPyObject();
200   if (call_node->getInputs().back()->GetOpcode() == LOAD_CONST && func.ptr() != nullptr) {
201     return CallNodeReturnConst(call_node, g, call_node->GetVobj());
202   }
203   return false;
204 }
205 
InferPrimitive(CallNode * call_node,GraphBuilder * unused=nullptr)206 static bool InferPrimitive(CallNode *call_node, GraphBuilder *unused = nullptr) {
207   static const std::unordered_map<std::string, AObject::Type> not_ret_tensor_prim = {
208     {"Prim[_get_grad_op]<constexpr_prim=True>", AObject::kTypeMetaFuncGraph},
209     {"Prim[DType]", AObject::kTypeAnyValue},
210     {"Prim[Partial]<side_effect_propagate=1>", AObject::kTypeAnyValue},
211   };
212   Graph *sub_graph = call_node->GetSubGraph();
213   call_node->SetVobj(AObject::MakeAObject(AObject::kTypeTensor));
214   call_node->SetSubGraph(nullptr);
215   PyObject *prim = call_node->input(0)->GetVobj()->GetPyObject().ptr();
216   std::string prim_key = std::string(py::str(prim));
217   if (prim_key == "Prim[_get_grad_op]<constexpr_prim=True>") {
218     py::object grad_class = Utils::GetModuleAttr("mindspore._c_expression", "GradOperation_");
219     AbstractType *type = static_cast<AbstractType *>(AObject::Convert(grad_class));
220     AObject *res = type != nullptr ? type->BuildAbstractInstance({}, CALL_FUNCTION)
221                                    : AObject::MakeAObject(AObject::kTypeMetaFuncGraph);
222     call_node->SetVobj(res);
223     return false;
224   }
225 
226   auto iter = not_ret_tensor_prim.find(prim_key);
227   if (iter != not_ret_tensor_prim.end()) {
228     call_node->SetVobj(AObject::MakeAObject(iter->second));
229   } else {
230     call_node->SetVobj(AObject::MakeAObject(AObject::kTypeTensor));
231   }
232 
233   std::vector<PyObject *> list;
234   bool infer_fail = false;
235   for (size_t i = 1; !infer_fail && i < call_node->getInputs().size(); i++) {
236     AObject *p = call_node->input(i)->GetVobj();
237     if (p == nullptr) {
238       infer_fail = true;
239       break;
240     }
241     PyObject *o;
242     if (p->GetType() == AObject::kTypeTensor) {
243       o = static_cast<AbstractTensor *>(p)->GetTensor(true).ptr();
244     } else {
245       o = p->GetPyObject().ptr();
246     }
247     list.push_back(o);
248     infer_fail = o == nullptr;
249   }
250   if (infer_fail) {
251     return false;
252   }
253 
254   auto inst = mindspore::pijit::InferEngine::GetInstance();
255   bool is_abstract = false;
256   PyObject *ret;
257   try {
258     ret = inst->InferPrimitive(prim, list, &is_abstract);
259   } catch (std::exception &e) {
260     MS_LOG(ERROR) << "infer primitive failed. reason:";
261     MS_LOG(ERROR) << e.what();
262     ret = nullptr;
263   }
264   if (ret == nullptr) {
265     return false;
266   }
267 
268   AObject::Type type = AObject::GetPyType(ret);
269   AObject *type_info = is_abstract && type != AObject::kTypeTensor ? AObject::MakeAObject(type) : AObject::Convert(ret);
270   call_node->SetVobj(type_info);
271   Py_DECREF(ret);
272 
273   ConstantInfo::CollectPrimitiveConstantInfo(call_node);
274   if (call_node->IsConstantValue()) {
275     return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
276   }
277   return false;
278 }
279 
InferGradOperation(CallNode * call_node,AObject::MindsporeFlag f)280 static bool InferGradOperation(CallNode *call_node, AObject::MindsporeFlag f) {
281   call_node->SetSubGraph(nullptr);
282   AObject *grad_func = AObject::MakeAObject(AObject::kTypeFunction);
283   grad_func->SetMsFlag(f);
284   call_node->SetVobj(grad_func);
285   py::object func = GraphBuilder::FindPyFunc(call_node->input(1)->GetVobj());
286   if (func.ptr() == nullptr) {
287     return false;
288   }
289   (void)pi_jit_should_compile(func, py::dict(), py::none());
290   auto jcr = getJitCompileResults(PyFunction_GET_CODE(func.ptr()));
291   *jcr->conf = call_node->GetGraph()->Config();
292   return false;
293 }
294 
InferMetaFunc(CallNode * call_node,GraphBuilder * unused=nullptr)295 static bool InferMetaFunc(CallNode *call_node, GraphBuilder *unused = nullptr) {
296   call_node->SetSubGraph(nullptr);
297   const auto &vo = call_node->input(0)->GetVobj();
298   MS_EXCEPTION_IF_CHECK_FAIL(vo->GetType() != AObject::kTypeType, "class call is before ");
299   PyTypeObject *tp = vo->GetTypeObject();
300   if (IsGradOperationType<true>(tp)) {
301     // set grad flag
302     return InferGradOperation(call_node, AObject::MindsporeFlag::kMsFlagGradFunc);
303   } else if (IsVmapOperationType<true>(tp)) {
304     // set vmap flag
305     return InferGradOperation(call_node, AObject::MindsporeFlag::kMsFlagVmapFunc);
306   } else if (IsShardType<true>(tp)) {
307     // set shard flag
308     return InferGradOperation(call_node, AObject::MindsporeFlag::kMsFlagShardFunc);
309   }
310   return false;
311 }
312 
313 /**
314  * find first free variable in names from function
315  */
FindClosure(const py::object & o,const std::vector<std::string> & names,TracePtr * trace,bool strict,bool print)316 static py::object FindClosure(const py::object &o, const std::vector<std::string> &names, TracePtr *trace, bool strict,
317                               bool print) {
318   PyObject *func = o.ptr();
319   if (PyMethod_Check(func)) {
320     func = PyMethod_GET_FUNCTION(func);
321   }
322   if (!PyFunction_Check(func)) {
323     return py::object();
324   }
325   PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(func));
326   PyObject *closure = PyFunction_GET_CLOSURE(func);
327   Py_ssize_t i = PyTuple_GET_SIZE(co->co_freevars) - 1;
328   bool find = false;
329   for (; i >= 0 && !find; --i) {
330     std::string name = PyUnicode_AsUTF8(PyTuple_GET_ITEM(co->co_freevars, i));
331     find = std::find(names.begin(), names.end(), name) != names.end();
332   }
333   if (!find) {
334     return py::object();
335   }
336   Py_ssize_t idx = i + 1;
337   PyObject *cell = PyTuple_GET_ITEM(closure, idx);
338   PyObject *content = PyCell_GET(cell);
339   if (trace) {
340     TracePtr attr = CreateOpTrace(closure, LOAD_ATTR, 0, {*trace}, "", "__closure__", strict, print);
341     TracePtr cc = CreateOpTrace(cell, BINARY_SUBSCR, 0, {attr, std::make_shared<ConstTrace>(py::int_(idx).ptr(), -1)},
342                                 "", "", strict, print);
343     *trace = CreateOpTrace(content, LOAD_ATTR, 0, {cc}, "", "cell_contents", strict, print);
344   }
345   return py::cast<py::object>(content);
346 }
347 
348 /**
349  * get decorated function from 'after_grad'
350  * \param after_grad _Grad.__call__.<locals>.after_grad
351  * \return decorated object
352  */
GetGradDecorated(const py::object & after_grad,TracePtr * trace,bool strict,bool print)353 static py::object GetGradDecorated(const py::object &after_grad, TracePtr *trace, bool strict, bool print) {
354   MS_ASSERT(PyFunction_Check(after_grad.ptr()));
355   py::object decorated = FindClosure(after_grad, {"fn", "fn_"}, trace, strict, print);
356   MS_EXCEPTION_IF_CHECK_FAIL(decorated.ptr() != nullptr, "can't find decorated function 'fn' or 'fn_' from " +
357                                                            std::string(py::str(after_grad.ptr())));
358   if (!PyFunction_Check(decorated.ptr())) {
359     return decorated;
360   }
361   std::string decorated_name = PyUnicode_AsUTF8(reinterpret_cast<PyFunctionObject *>(decorated.ptr())->func_qualname);
362   if (decorated_name == "_Grad.__call__.<locals>.aux_fn") {
363     decorated = FindClosure(decorated, {"fn"}, trace, strict, print);
364     MS_EXCEPTION_IF_CHECK_FAIL(decorated.ptr() != nullptr, "can't find decorated function 'fn' from " + decorated_name);
365   }
366   return decorated;
367 }
368 
/**
 * Removes the sensitivity value that a GradOperation with sens_param=True expects from a call's arguments.
 *
 * If `kwargs` is a dict with the keyword "sens" (the keyword documented for GradOperation) or the legacy
 * "sens_param", that entry is deleted in place and `args` is returned unchanged. Otherwise the last positional
 * argument is treated as the sens value and dropped.
 * \param args positional arguments (tuple)
 * \param kwargs keyword arguments (dict), or a null object
 * \return the positional arguments without the sens value
 */
static py::object DeleteGradSensArgs(const py::object &args, const py::object &kwargs) {
  static constexpr const char *kSensKeys[] = {"sens", "sens_param"};
  // sens param specified in kwargs
  if (kwargs.ptr() != nullptr && PyDict_Check(kwargs.ptr())) {
    for (const char *key : kSensKeys) {
      if (PyDict_DelItemString(kwargs.ptr(), key) == 0) {
        return args;
      }
      PyErr_Clear();  // KeyError: try the next spelling
    }
  }
  // sens param is the last position argument
  PyObject *new_arg = PyTuple_GetSlice(args.ptr(), 0, PyTuple_GET_SIZE(args.ptr()) - 1);
  return py::reinterpret_steal<py::object>(new_arg);
}
379 
InferGradFuncResult(const py::object & func,const py::object & args,const py::object & kwargs,const GraphJitConfig & conf)380 static AObject *InferGradFuncResult(const py::object &func, const py::object &args, const py::object &kwargs,
381                                     const GraphJitConfig &conf) {
382   auto jcr = getJitCompileResults(func.ptr());
383   *jcr->conf = conf;
384   return InferFuncResult(func, args, kwargs, conf, true);
385 }
386 
387 /**
388  * Use the function decorated by 'after_grad' and arguments of 'after_grad' when called to infer result.
389  * If the function has no unsupported operation, merge the guard of inferred graph to caller graph.
390  * else clear the mask of mindspore flag, avoid to capture this function call
391  */
HandleGradFuncCall(CallNode * call_node,AObject * decorated,bool sens_param,const py::object & after_grad)392 void HandleGradFuncCall(CallNode *call_node, AObject *decorated, bool sens_param, const py::object &after_grad) {
393   const int except_flag = AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc;
394   ValueNode *grad_func_node = call_node->input(0);
395   std::vector<py::object> stack_args;
396   py::object func;
397   py::object args;
398   py::object kwargs;
399 
400   // prepare parameters
401   bool param_ready = decorated->GetPyObject().ptr() != nullptr;
402   for (size_t i = 1; param_ready && i < call_node->getInputs().size(); ++i) {
403     AObject *tmp = call_node->input(i)->GetVobj();
404     stack_args.emplace_back(tmp != nullptr ? tmp->GetPyObject() : py::object());
405     param_ready = stack_args.back().ptr() != nullptr;
406   }
407   if (param_ready) {
408     auto pair = Utils::PackCallStackArgs(stack_args, call_node->GetOpcode());
409     args = pair.first;
410     kwargs = pair.second;
411     param_ready = pair.first.ptr() != nullptr;
412   }
413   if (!param_ready) {
414     call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
415     grad_func_node->GetVobj()->ClearMsFlag(except_flag);
416     return;
417   }
418   if (sens_param) {
419     args = DeleteGradSensArgs(args, kwargs);
420   }
421 
422   // get callable
423   if (decorated->GetType() != AObject::kTypeCell) {
424     MS_EXCEPTION_IF_CHECK_FAIL(decorated->GetType() == AObject::kTypeFunction, "check grad input");
425     func = decorated->GetPyObject();
426   } else {
427     // here get bound method.
428     func = decorated->GetAttr(GraphBuilder::ID_construct)->GetPyObject();
429   }
430 
431   AObject *res = InferGradFuncResult(func, args, kwargs, call_node->GetGraph()->Config());
432   if (res == nullptr || !res->IsMindSporeSupportedType()) {
433     call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
434     grad_func_node->GetVobj()->ClearMsFlag(except_flag);
435     return;
436   }
437   py::object infer_after_grad = Utils::GetModuleAttr(kModuleName, "infer_after_grad", true, true);
438   py::object result;
439   try {
440     result = infer_after_grad(after_grad, args, res->GetPyObject());
441   } catch (std::exception &e) {
442     MS_LOG(WARNING) << "Error while infer_after_grad, error:" << e.what();
443     PyErr_Clear();
444   }
445   if (result.ptr() != nullptr && result.ptr() != Py_None) {
446     call_node->SetVobj(AObject::Convert(result));
447   } else {
448     call_node->SetVobj(res);
449   }
450   call_node->SetInlineReason(InlineReason::kInlineGraphSupportedByMS);
451 }
452 
HandleGradFunc(CallNode * call_node,const py::object & after_grad,TracePtr * trace)453 static void HandleGradFunc(CallNode *call_node, const py::object &after_grad, TracePtr *trace) {
454   auto config = call_node->GetGraph()->Config();
455   bool strict = config.GetBoolConfig(GraphJitConfig::kStrictTrace);
456   bool print = config.GetBoolConfig(GraphJitConfig::kPrintGuard);
457   py::object decorated_func = GetGradDecorated(after_grad, trace, strict, print);
458   TracePtr ptr = *trace;
459   py::object grad = FindClosure(after_grad, {"grad_", "self"}, &ptr, strict, print);
460   MS_EXCEPTION_IF_CHECK_FAIL(grad.ptr() != nullptr,
461                              "can't find 'grad_' object from " + std::string(py::str(after_grad.ptr())));
462   bool sens_param = grad.attr("sens_param").ptr() == Py_True;
463   MS_LOG(DEBUG) << "infer function 'after_grad', has sens_param " << (sens_param ? "True" : "False");
464 
465   auto guard = call_node->GetGraph()->GetGuard()->GetGuard();
466   guard->GuardOn(*trace, mindspore::pijit::GuardLevel::GEqual);
467   if (config.GetBoolConfig(GraphJitConfig::kGuardDetachObject)) {
468     (*trace)->Detach();
469   }
470   call_node->SetSubGraph(nullptr);
471   HandleGradFuncCall(call_node, AObject::Convert(decorated_func), sens_param, after_grad);
472 }
473 
InferGradFunc(CallNode * call_node,GraphBuilder * unused=nullptr)474 static bool InferGradFunc(CallNode *call_node, GraphBuilder *unused = nullptr) {
475   AObject *vo = call_node->input(0)->GetVobj();
476   vo->SetMsFlag(AObject::kMsFlagGradFunc);
477   py::object after_grad = vo->GetPyObject();
478   TracePtr trace = call_node->GetGraph()->TraceValueNode(call_node->input(0));
479   if (trace == nullptr) {
480     vo->ClearMsFlag(AObject::kMsFlagGradFunc);
481     call_node->SetSubGraph(nullptr);
482     return false;
483   }
484   HandleGradFunc(call_node, after_grad, &trace);
485   return false;
486 }
487 
InferMSConstexpr(CallNode * call_node,GraphBuilder * unused=nullptr)488 static bool InferMSConstexpr(CallNode *call_node, GraphBuilder *unused = nullptr) {
489   Graph *g = call_node->GetSubGraph();
490   JustCallAndSetRes(call_node);
491 
492   py::object cnst = call_node->GetVobj()->GetPyObject();
493   if (cnst.ptr() == nullptr) {
494     return false;
495   }
496   bool is_constexpr = CheckConstexpr(call_node->input(0)->GetVobj()->GetPyObject());
497   constexpr int max_guard_depth = 2;
498   if (is_constexpr || GuardConstCallNodeParam(call_node, g, max_guard_depth)) {
499     return CallNodeReturnConst(call_node, g, call_node->GetVobj());
500   }
501   return false;
502 }
503 
GuardBuiltinFunc(CallNode * call_node)504 static bool GuardBuiltinFunc(CallNode *call_node) {
505   if (call_node->input(0)->GetVobj() == nullptr) {
506     return false;
507   }
508   PyObject *func = call_node->input(0)->GetVobj()->GetPyObject().ptr();
509   if (PyMethod_Check(func)) {
510     auto self = PyMethod_GET_SELF(func);
511     if (IsTensorType<true>(Py_TYPE(self)) && !CheckTensorDataInitialized(py::cast<py::object>(self))) {
512       // fake value
513       return false;
514     }
515   }
516   Graph *graph = call_node->GetGraph();
517   for (auto i : call_node->getInputs()) {
518     if (i->GetVobj() && i->GetVobj()->GetType() == AObject::kTypeTensor) {
519       AbstractTensor *tensor = static_cast<AbstractTensor *>(i->GetVobj());
520       if (!tensor->IsStubTensor() && !CheckTensorDataInitialized(tensor->GetPyObject())) {
521         // fake value
522         return false;
523       }
524     }
525   }
526   return graph->GuardValueNode(call_node);
527 }
528 
GuardIsInstance(CallNode * call_node)529 static bool GuardIsInstance(CallNode *call_node) {
530   Graph *graph = call_node->GetGraph();
531   const auto &cnst = call_node->input(1)->GetConstantInfo();
532   if (cnst != nullptr && cnst->type() != nullptr) {
533     constexpr int second_arg = 2;
534     auto success = graph->GuardValueNode(call_node->input(second_arg));
535     if (!success && (call_node->GetGraph()->Config().getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0)) {
536       TracePtr tr = graph->TraceValueNode(call_node->input(second_arg));
537       if (tr == nullptr) {
538         return true;
539       }
540     }
541     return success;
542   }
543   auto success = graph->GuardValueNode(call_node);
544   if (!success && (call_node->GetGraph()->Config().getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0)) {
545     TracePtr tr = graph->TraceValueNode(call_node);
546     if (tr == nullptr) {
547       return true;
548     }
549   }
550   return success;
551 }
552 
InferBuiltinFuncOrMethod(CallNode * call_node,GraphBuilder * unused=nullptr)553 bool InferBuiltinFuncOrMethod(CallNode *call_node, GraphBuilder *unused = nullptr) {
554   Graph *sub_graph = call_node->GetSubGraph();
555   (void)JustCallAndSetRes(call_node);
556   ConstantInfo::CollectBuiltinFuncConstantInfo(call_node);
557   if (call_node->IsConstantValue()) {
558     return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
559   }
560   if (call_node->GetVobj() == nullptr || call_node->GetVobj()->GetPyObject().ptr() == nullptr) {
561     return false;
562   }
563 
564   bool guard_success = false;
565   std::string name = GetFuncName(call_node->input(0)->GetVobj()->GetPyObject());
566   if (name == "isinstance") {
567     guard_success = GuardIsInstance(call_node);
568   } else {
569     guard_success = GuardBuiltinFunc(call_node);
570   }
571   if (guard_success) {
572     return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
573   }
574   return false;
575 }
576 
InferTensorAsType(CallNode * call_node,GraphBuilder * unused=nullptr)577 static bool InferTensorAsType(CallNode *call_node, GraphBuilder *unused = nullptr) {
578   ValueNode *self_node = GetBoundSelf(call_node);
579   bool is_not_method = call_node->input(0)->GetVobj()->GetType() != AObject::kTypeBoundMethod;
580   ValueNode *dtype_node = call_node->input(1 + is_not_method);
581 
582   Graph *sub_graph = call_node->GetSubGraph();
583 
584   py::object prim_cast = Utils::GetModuleAttr("mindspore.ops.functional", "cast", false, true);
585 
586   PyTypeObject *tp = Py_TYPE(prim_cast.ptr());
587   std::stringstream s;
588   s << (tp->tp_name ? tp->tp_name : "<unnamed>") << "<" << prim_cast.ptr() << ">";
589 
590   ValueNode *prim_node = sub_graph->NewValueNode(AObject::Convert(prim_cast), LOAD_CONST, -1, {});
591 
592   if (dtype_node->GetVobj()->GetType() == AObject::kTypeString &&
593       dtype_node->GetVobj()->GetPyObject().ptr() != nullptr) {
594     auto dtypeStr = py::cast<std::string>(dtype_node->GetVobj()->GetPyObject());
595     std::vector<std::string> under_line_dtype = {"bool", "int", "float", "list", "tuple"};
596     if (std::find(under_line_dtype.begin(), under_line_dtype.end(), dtypeStr) != under_line_dtype.end()) {
597       dtypeStr = dtypeStr + "_";
598     }
599     auto dtype_obj = Utils::GetModuleAttr("mindspore.common.dtype", dtypeStr, false, true);
600     if (dtype_obj.ptr() != nullptr) {
601       dtype_node = sub_graph->NewValueNode(AObject::Convert(dtype_obj), LOAD_CONST, -1, {});
602     }
603   }
604 
605   std::vector<ValueNode *> cast_args = {prim_node, self_node, dtype_node};
606   CallNode *ret_node = sub_graph->NewCallNode(CALL_FUNCTION, cast_args.size() - 1, cast_args);
607   ret_node->SetGraph(sub_graph);
608   (void)InferPrimitive(ret_node);
609 
610   sub_graph->GetTracedNodes().push_back(prim_node);
611   sub_graph->GetTracedNodes().push_back(ret_node);
612   sub_graph->SetRetVal(ret_node);
613 
614   call_node->SetSubGraph(sub_graph);
615   call_node->SetVobj(ret_node->GetVobj());
616   call_node->SetInlineReason(InlineReason::kInline);
617   return true;
618 }
619 
RecordSideEffectCallNode(Graph * graph,CallNode * call_node,SideEffect::Type type,bool trace_flag)620 static void RecordSideEffectCallNode(Graph *graph, CallNode *call_node, SideEffect::Type type, bool trace_flag) {
621   const auto &side_effect = graph->GetSideEffect();
622   ValueNode *side_effect_node;
623   if (trace_flag) {
624     side_effect_node = call_node;
625   } else {
626     side_effect_node = graph->NewCallNode(call_node->GetOpcode(), call_node->GetOparg(), call_node->getInputs());
627     side_effect_node->SetVobj(AObject::MakeAObject(AObject::kTypeAnyValue));
628     graph->GetTracedNodes().push_back(side_effect_node);
629   }
630   side_effect->Record(side_effect_node, type);
631 }
632 
/**
 * Infers a list append call by rewriting it in the parent frame as `new_list = [old_list[0], ..., new_element]`.
 *
 * Rewrite steps:
 * 1. The list (`self`) is unpacked onto the parent's stack and the new element is pushed.
 * 2. A BUILD_LIST is emitted.
 * 3. The resulting node replaces the old list node in the parent.
 * 4. The call itself is inlined as a sub-graph returning the constant None.
 *
 * The replacement and the call are recorded as a kListAppend side effect unless all three hold: the list was
 * created by BUILD_LIST and is not already in the modified-and-replaced map, ReplaceAll reports no reference,
 * and the appended element is not the list itself.
 * \param parent graph builder of the caller frame, whose stack is used for the rewrite
 * \return false if the call shape is unsupported or the list cannot be unpacked; true once the call is inlined
 */
static bool InferListAppend(CallNode *call_node, GraphBuilder *parent) {
  call_node->SetSubGraph(nullptr);

  // check is supported type and get arguments
  bool is_method_descriptor = false;
  ValueNode *self = GetSelfFromListAppendCall(call_node, &is_method_descriptor);
  if (self == nullptr) {
    return false;
  }
  ValueNode *new_element = call_node->input(1 + is_method_descriptor);

  // transform to "new_list = [old_list[0], old_list[1]..., new_element]"
  // stack depth before unpacking; the difference computed below is the element count of the new list
  int size = parent->frame().GetStacks().size();
  if (!parent->UnpackElements(self)) {
    return false;
  }
  parent->push(new_element);
  size = parent->frame().GetStacks().size() - size;
  parent->DoBuildOp({BUILD_LIST, size});
  auto new_node = parent->pop();
  auto old_node = self;

  // constant fold and set node info: the append call itself evaluates to None
  auto builder = GraphBuilder::Creator(parent->root(), parent, nullptr, nullptr, parent->trace_flag());
  Graph *sub_graph = builder->GetGraph();
  builder->DoLoadConst({LOAD_CONST, 0, py::object(py::none())});
  builder->DoReturn({RETURN_VALUE, 0});

  call_node->SetSubGraph(sub_graph);
  call_node->SetVobj(sub_graph->GetRetVal()->GetVobj());
  call_node->SetInlineReason(InlineReason::kInline);

  // update frame status and record side-effect
  bool is_referenced = false;
  parent->ReplaceAll(old_node, new_node, &is_referenced);
  const auto &replace_map = parent->GetGraph()->GetSideEffect()->data()->modified_and_replaced_map();
  bool is_new_var = self->GetOpcode() == BUILD_LIST && replace_map.find(self) == replace_map.end();
  if (!is_new_var || is_referenced || self == new_element) {
    parent->GetGraph()->GetSideEffect()->data()->RecordModifiedAndReplacedNode(old_node, new_node);
    RecordSideEffectCallNode(parent->GetGraph(), call_node, SideEffect::kListAppend, parent->trace_flag());
  }
  return true;
}
676 
InferDictPop(CallNode * call_node,GraphBuilder * parent)677 static bool InferDictPop(CallNode *call_node, GraphBuilder *parent) {
678   call_node->SetSubGraph(nullptr);
679 
680   bool is_method_descriptor = false;
681   ValueNode *self = GetSelfFromListAppendCall(call_node, &is_method_descriptor);
682   if (self == nullptr) {
683     return false;
684   }
685   // guard dict key and convert to constant key map
686   if (!parent->GetGraph()->GuardValueNode(self)) {
687     return false;
688   }
689 
690   ValueNode *dict_node = self;
691   ValueNode *key_node = call_node->input(1 + is_method_descriptor);
692   ValueNode *default_node = call_node->getInputs().size() > (kDictPopParamsNum + is_method_descriptor)
693                               ? call_node->input(kDictPopParamsNum + is_method_descriptor)
694                               : nullptr;
695   // get key from dict
696   py::object dict = dict_node->GetVobj()->GetPyObject();
697   py::object key = key_node->GetVobj()->GetPyObject();
698   MS_EXCEPTION_IF_CHECK_FAIL(PyDict_Check(dict.ptr()), "for dict.pop, first parameter must be a dict");
699   py::object value = py::reinterpret_borrow<py::object>(PyDict_GetItem(dict.ptr(), key.ptr()));
700   if (value.ptr() == nullptr) {
701     if (default_node == nullptr) {
702       return false;  // key error
703     }
704     value = default_node->GetVobj()->GetPyObject();
705   }
706 
707   // transform to "new_map = {key:old_map[key]...}"
708   ValueNode *old_node = dict_node;
709   ValueNode *new_node = parent->TransformDictSetItem(dict_node, key_node, nullptr, default_node != nullptr);
710   if (new_node == nullptr) {
711     return false;
712   }
713 
714   // constant fold and set node info
715   auto builder = GraphBuilder::Creator(parent->root(), parent, nullptr, nullptr, parent->trace_flag());
716   Graph *sub_graph = builder->GetGraph();
717   builder->DoLoadConst({LOAD_CONST, 0, value});
718   builder->DoReturn({RETURN_VALUE, 0});
719 
720   call_node->SetSubGraph(sub_graph);
721   call_node->SetVobj(sub_graph->GetRetVal()->GetVobj());
722   call_node->SetInlineReason(InlineReason::kInline);
723 
724   // update frame status and record side-effect
725   bool is_referenced = false;
726   parent->ReplaceAll(old_node, new_node, &is_referenced);
727   const auto &replace_map = parent->GetGraph()->GetSideEffect()->data()->modified_and_replaced_map();
728   bool is_new_var = self->GetOpcode() == BUILD_MAP && replace_map.find(self) == replace_map.end();
729   if (!is_new_var || is_referenced) {
730     parent->GetGraph()->GetSideEffect()->data()->RecordModifiedAndReplacedNode(old_node, new_node);
731     RecordSideEffectCallNode(parent->GetGraph(), call_node, SideEffect::kDictPop, parent->trace_flag());
732   }
733   return true;
734 }
735 
SetForbiddenFuncInfo(CallNode * call_node,GraphBuilder * unused=nullptr)736 static bool SetForbiddenFuncInfo(CallNode *call_node, GraphBuilder *unused = nullptr) {
737   SetCallResType<AObject::kTypeAnyValue>(call_node);
738   call_node->SetInlineReason(InlineReason::kInlineFunc_Type_Unsupported);
739   return false;
740 }
741 
742 template <bool force_ms_api>
InferMsApiFunc(CallNode * call_node,GraphBuilder * unused=nullptr)743 bool InferMsApiFunc(CallNode *call_node, GraphBuilder *unused = nullptr) {
744   Graph *sub_graph = call_node->GetSubGraph();
745   SetCallResType<AObject::kTypeAnyValue>(call_node);
746   if (call_node->input(0)->GetVobj() == nullptr || call_node->input(0)->GetVobj()->GetPyObject().ptr() == nullptr) {
747     return false;
748   }
749 
750   py::object callable_object = call_node->input(0)->GetVobj()->GetPyObject();
751   std::vector<py::object> args;
752   std::transform(call_node->getInputs().begin() + 1, call_node->getInputs().end(), std::back_inserter(args),
753                  [](ValueNode *n) { return n->GetVobj() ? n->GetVobj()->GetPyObject() : py::object(); });
754   auto pair = Utils::PackCallStackArgs(args, call_node->GetOpcode());
755   if (pair.first.ptr() == nullptr) {
756     return false;
757   }
758   PyTypeObject *callable_type = Py_TYPE(callable_object.ptr());
759 
760   AObject *info;
761 
762   bool enable_func_graph_eval = force_ms_api || kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kEnableMsApiInfer);
763   if (enable_func_graph_eval) {
764     py::object res = EvalMSAPIValue(callable_object, pair.first, pair.second);
765     info = AObject::Convert(res);
766   } else if (IsPrimitiveType<true>(callable_type) || IsPrimitiveFunctionType<true>(callable_type)) {
767     call_node->SetSubGraph(sub_graph);
768     return InferPrimitive(call_node);
769   } else {
770     info = InferFuncResult(callable_object, pair.first, pair.second, call_node->GetGraph()->Config(), true);
771   }
772 
773   call_node->SetVobj(info);
774   if (info->GetPyObject().ptr() != nullptr) {
775     call_node->input(0)->GetVobj()->SetMsFlag(AObject::kMsFlagStandardFunc);
776   }
777   if (call_node->IsConstantValue()) {
778     return CallNodeReturnConst(call_node, sub_graph, call_node->GetVobj());
779   }
780   return false;
781 }
782 
InferMappingGet(CallNode * call_node,GraphBuilder * unused=nullptr)783 bool InferMappingGet(CallNode *call_node, GraphBuilder *unused = nullptr) {
784   if (call_node->getInputs().size() == BoundMethodInputSize &&
785       call_node->input(0)->GetVobj()->GetType() == AbstractObjectBase::kTypeBoundMethod) {
786     auto func_node = call_node->input(0);
787     auto self = func_node->input(0);
788     auto param_node = call_node->input(1);
789     if (self->IsConstantValue() && param_node->IsConstantValue()) {
790       Graph *g = call_node->GetSubGraph();
791       JustCallAndSetRes(call_node);
792       return CallNodeReturnConst(call_node, g, call_node->GetVobj());
793     }
794   }
795   SetCallResType<AObject::kTypeAnyValue>(call_node);
796   call_node->SetInlineReason(InlineReason::kInlineFunc_Type_Unsupported);
797   return false;
798 }
799 
/**
 * Keys identifying callables that have a special inference routine.
 *
 * The numeric values form a contract with the Python white list: GetFuncKeyMap() reads them as plain integers
 * from the `_func_map` values, casts them to FuncKey and rejects anything >= FUNC_KEY_COUNT. They are therefore
 * spelled out explicitly and must not be renumbered. Presumably the Python side mirrors these numbers; confirm
 * before changing them.
 */
enum FuncKey {
  FUNC_KEY_EMPTY = 0,                 // ""
  FUNC_KEY_PIJIT_CONSTEXPR = 1,       // "pijit.constexpr"
  FUNC_KEY_PIJIT_FORBIDDEN = 2,       // "pijit.forbidden"
  FUNC_KEY_BUILTIN_FUNC = 3,          // "builtin.func"
  FUNC_KEY_LIST_APPEND = 4,           // "list.append"
  FUNC_KEY_DICT_POP = 5,              // "dict.pop"
  FUNC_KEY_PRIMITIVE = 6,             // "mindspore._c_expression.Primitive_"
  FUNC_KEY_META_FUNCG_RAPH = 7,       // "mindspore._c_expression.MetaFuncGraph_"
  FUNC_KEY_PSJIT_CODE = 8,            // "mindspore.common.api.jit.<locals>.staging_specialize"
  FUNC_KEY_CONSTEXPR = 9,             // "mindspore.ops.primitive.constexpr"
  FUNC_KEY_PRIMEXPR = 10,             // "mindspore.ops.primitive._primexpr"
  FUNC_KEY_GET_CACHE_PRIM = 11,       // "mindspore.ops._primitive_cache._get_cache_prim"
  FUNC_KEY_REGISTRY_GET = 12,         // "mindspore.common._register_for_tensor.Registry.get"
  FUNC_KEY_TENSOR_ASTYPE = 13,        // "mindspore.common.tensor.Tensor.astype"
  FUNC_KEY_GRAD_OPERATIONS_CODE = 14, // "mindspore.ops.composite.base._Grad.__call__.<locals>.after_grad"
  FUNC_KEY_PSJIT_CONVERTMAP = 15,     // "mindspore._extends.parse.resources.convert_object_map"
  FUNC_KEY_GRAPH_CELL = 16,           // "mindspore.nn.cell.GraphCell"
  FUNC_KEY_MS_API = 17,               // mindspore api
  FUNC_KEY_MAPPING_GET = 18,          // mapping get
  FUNC_KEY_COUNT = 19,                // number of keys; not a valid key itself
};
822 static FuncKey FindFuncKey(const py::object &callable);
823 
824 static const std::unordered_map<FuncKey, InferFunc> infer_func_map = {
825   {FUNC_KEY_PIJIT_CONSTEXPR, JustCallAndSetRes},
826   {FUNC_KEY_PIJIT_FORBIDDEN, SetForbiddenFuncInfo},
827   {FUNC_KEY_BUILTIN_FUNC, InferBuiltinFuncOrMethod},
828   {FUNC_KEY_LIST_APPEND, InferListAppend},
829   {FUNC_KEY_DICT_POP, InferDictPop},
830   {FUNC_KEY_PRIMITIVE, InferPrimitive},
831   {FUNC_KEY_META_FUNCG_RAPH, InferMetaFunc},
832   {FUNC_KEY_PSJIT_CODE, InferMsApiFunc<true>},
833   {FUNC_KEY_CONSTEXPR, InferMSConstexpr},
834   {FUNC_KEY_PRIMEXPR, InferMSConstexpr},
835   {FUNC_KEY_GET_CACHE_PRIM, InferGetCachePrim},
836   {FUNC_KEY_REGISTRY_GET, InferRegistryGet},
837   {FUNC_KEY_TENSOR_ASTYPE, InferTensorAsType},
838   {FUNC_KEY_GRAD_OPERATIONS_CODE, InferGradFunc},
839   {FUNC_KEY_PSJIT_CONVERTMAP, InferConvertMap},
840   {FUNC_KEY_GRAPH_CELL, SetCallResType<AObject::kTypeTensor>},
841   {FUNC_KEY_MS_API, InferMsApiFunc<false>},
842   {FUNC_KEY_MAPPING_GET, InferMappingGet},
843 };
844 
845 static const std::unordered_map<FuncKey, InferFunc> mind_infer_func_map = {
846   {FUNC_KEY_PIJIT_CONSTEXPR, JustCallAndSetRes},     {FUNC_KEY_PIJIT_FORBIDDEN, SetForbiddenFuncInfo},
847   {FUNC_KEY_LIST_APPEND, InferListAppend},           {FUNC_KEY_DICT_POP, InferDictPop},
848   {FUNC_KEY_BUILTIN_FUNC, InferBuiltinFuncOrMethod}, {FUNC_KEY_PSJIT_CODE, SetCallResType<AObject::kTypeTensor>},
849   {FUNC_KEY_GET_CACHE_PRIM, InferGetCachePrim},      {FUNC_KEY_REGISTRY_GET, InferRegistryGet},
850 };
851 
FindInferFunc(const py::object & callable,bool trace_flag)852 InferFunc FindInferFunc(const py::object &callable, bool trace_flag) {
853   FuncKey k = FindFuncKey(callable);
854   const auto &map = trace_flag ? mind_infer_func_map : infer_func_map;
855   auto iter = map.find(k);
856   if (iter != map.end()) {
857     return iter->second;
858   }
859   return nullptr;
860 }
861 
GetFuncKeyMap()862 static const std::unordered_map<size_t, FuncKey> &GetFuncKeyMap() {
863   static std::unordered_map<size_t, FuncKey> map = {};
864   if (!map.empty()) {
865     return map;
866   }
867   py::object func_map = Utils::GetModuleAttr(kModuleName, kFuncMapName, true, true);
868   MS_EXCEPTION_IF_CHECK_FAIL(PyDict_CheckExact(func_map.ptr()), "white list func map must be 'dict[int, str]'");
869   PyObject *key;
870   PyObject *value;
871   Py_ssize_t pos = 0;
872   while (PyDict_Next(func_map.ptr(), &pos, &key, &value)) {
873     MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(key), "white list func map key must be 'int'");
874     MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(value), "white list func map value must be 'int'");
875     size_t k = (PyLong_AsSize_t(value));
876     MS_EXCEPTION_IF_CHECK_FAIL(k < FUNC_KEY_COUNT, "white list func map got error FuncKey " + std::to_string(k));
877     map[PyLong_AsSize_t(key)] = static_cast<FuncKey>(k);
878   }
879   return map;
880 }
881 
KeyFinderFuncId(const py::object & callable)882 static FuncKey KeyFinderFuncId(const py::object &callable) {
883   auto iter = GetFuncKeyMap().find(FunctionId(callable));
884   return iter != GetFuncKeyMap().end() ? iter->second : FUNC_KEY_EMPTY;
885 }
886 
KeyFinderFuncCodeId(const py::object & callable)887 static FuncKey KeyFinderFuncCodeId(const py::object &callable) {
888   PyObject *func = callable.ptr();
889   py::object handle;
890   if (IsCellType<true>(Py_TYPE(func))) {
891     handle = callable.attr("construct");
892     func = handle.ptr();
893   }
894   if (PyMethod_Check(func)) {
895     func = PyMethod_GET_FUNCTION(func);
896   }
897   if (PyFunction_Check(func)) {
898     func = PyFunction_GET_CODE(func);
899   }
900   if (!PyCode_Check(func)) {
901     return FUNC_KEY_EMPTY;
902   }
903   auto iter = GetFuncKeyMap().find(reinterpret_cast<size_t>(func));
904   return iter != GetFuncKeyMap().end() ? iter->second : FUNC_KEY_EMPTY;
905 }
906 
KeyFinderPrimitive(const py::object & callable)907 static FuncKey KeyFinderPrimitive(const py::object &callable) {
908   PyTypeObject *type_object = Py_TYPE(callable.ptr());
909   bool convert_to_prim = IsPrimitiveType<true>(type_object) || IsPrimitiveFunctionType<true>(type_object);
910   if (!convert_to_prim) {
911     return FUNC_KEY_EMPTY;
912   }
913   py::object func = py::getattr(reinterpret_cast<PyObject *>(type_object), kSlotCallName, nullptr);
914   size_t id;
915   if (func.ptr() == nullptr) {
916     // primitive not defined slot __call__, use it self as id
917     id = reinterpret_cast<size_t>(callable.ptr());
918   } else if (PyFunction_Check(func.ptr())) {
919     // primitive defined python function __call__
920     id = reinterpret_cast<size_t>(PyFunction_GET_CODE(func.ptr()));
921   } else {
922     // primitive defined cpp function __call__
923     id = FunctionId(func);
924   }
925   // first, find map to check special primitive.
926   auto iter = GetFuncKeyMap().find(id);
927   return iter != GetFuncKeyMap().end() ? iter->second : FUNC_KEY_PRIMITIVE;
928 }
929 
KeyFinderMetaFunc(const py::object & callable)930 static FuncKey KeyFinderMetaFunc(const py::object &callable) {
931   PyTypeObject *type_object = reinterpret_cast<PyTypeObject *>(callable.ptr());
932   type_object = PyType_CheckExact(type_object) ? type_object : Py_TYPE(type_object);
933   return IsMetaFuncGraphType<true>(type_object) ? FUNC_KEY_META_FUNCG_RAPH : FUNC_KEY_EMPTY;
934 }
935 
KeyFinderGraphCell(const py::object & callable)936 static FuncKey KeyFinderGraphCell(const py::object &callable) {
937   static size_t id = 0;
938   if (id == 0) {
939     py::object type = Utils::GetModuleAttr("mindspore.nn.cell", "GraphCell", false, true);
940     id = reinterpret_cast<size_t>(type.ptr());
941   }
942   PyTypeObject *type_object = reinterpret_cast<PyTypeObject *>(callable.ptr());
943   type_object = PyType_CheckExact(type_object) ? type_object : Py_TYPE(type_object);
944   size_t cur_id = reinterpret_cast<size_t>(type_object);
945   return cur_id == id ? FUNC_KEY_GRAPH_CELL : FUNC_KEY_EMPTY;
946 }
947 
KeyFinderSkipModule(const py::object & callable)948 static FuncKey KeyFinderSkipModule(const py::object &callable) {
949   const auto &modules = kPIJitConfigDefault.allowed_inline_modules();
950   std::string mod = GetTopModule(callable);
951   if (modules.find(mod) != modules.end()) {
952     return FUNC_KEY_EMPTY;
953   }
954 
955   PyObject *func_info = callable.ptr();
956   if (PyMethod_Check(func_info)) {
957     func_info = PyMethod_GET_FUNCTION(func_info);
958   }
959   if (!PyFunction_Check(func_info) && !PyCFunction_Check(func_info) && !PyType_Check(func_info)) {
960     func_info = reinterpret_cast<PyObject *>(Py_TYPE(func_info));
961   }
962   MS_LOG(DEBUG) << "func " << std::string(py::str(func_info)) << " is forbidden to analyze, module is " << mod;
963   return FUNC_KEY_PIJIT_FORBIDDEN;
964 }
965 
FindFuncKey(const py::object & callable)966 static FuncKey FindFuncKey(const py::object &callable) {
967   if (callable.ptr() == nullptr || !PyCallable_Check(callable.ptr())) {
968     return FUNC_KEY_EMPTY;
969   }
970   std::vector<FuncKey (*)(const py::object &callable)> finders = {
971     KeyFinderFuncId,   KeyFinderFuncCodeId, KeyFinderPrimitive,
972     KeyFinderMetaFunc, KeyFinderGraphCell,  KeyFinderSkipModule,  // must be last for check modules
973   };
974   FuncKey res = FUNC_KEY_EMPTY;
975   for (auto iter = finders.begin(), end = finders.end(); iter != end && res == FUNC_KEY_EMPTY; ++iter) {
976     res = (*iter)(callable);
977   }
978   return res;
979 }
980 
CheckJitConstexpr(const py::object & func)981 bool CheckJitConstexpr(const py::object &func) {
982   if (func.ptr() == nullptr) {
983     return false;
984   }
985   FuncKey k = KeyFinderFuncId(func);
986   return k == FUNC_KEY_PIJIT_CONSTEXPR;
987 }
988 
CheckConstexpr(const py::object & func)989 static bool CheckConstexpr(const py::object &func) { return KeyFinderPrimitive(func) == FUNC_KEY_CONSTEXPR; }
990 
CheckMSConstexpr(const py::object & func)991 bool CheckMSConstexpr(const py::object &func) {
992   if (func.ptr() == nullptr) {
993     return false;
994   }
995   FuncKey k = KeyFinderPrimitive(func);
996   return k == FUNC_KEY_CONSTEXPR || k == FUNC_KEY_PRIMEXPR;
997 }
998 
CheckBuiltinFuncOrMethod(const py::object & func)999 bool CheckBuiltinFuncOrMethod(const py::object &func) {
1000   if (func.ptr() == nullptr) {
1001     return false;
1002   }
1003   FuncKey k = KeyFinderFuncId(func);
1004   return k == FUNC_KEY_BUILTIN_FUNC;
1005 }
1006 
1007 }  // namespace pijit
1008 }  // namespace mindspore
1009