• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/pi_jit_config.h"
17 #include <string>
18 #include <unordered_map>
19 #include "utils/log_adapter.h"
20 #include "pipeline/jit/pi/external.h"
21 #include "pipeline/jit/pi/utils/utils.h"
22 #include "pipeline/jit/pi/pydef.h"
23 
24 namespace mindspore {
25 namespace pijit {
26 
27 GraphJitConfig kPIJitConfigDefault;
28 
29 constexpr int kDefaultMaxTraceDepth = 16;
30 
31 constexpr const char *kModuleName = "mindspore._extends.pijit.pijit_func_white_list";
32 constexpr const char *kFuncMapName = "_func_map";
33 constexpr const char *kGuardFuncMapName = "guard_func_map";
34 
35 static const std::unordered_map<std::string, bool (GraphJitConfig::*)(PyObject *)> key_map = {
36   {"auto_jit_func_filter", &GraphJitConfig::SetAutoJitFilter},
37   {"auto_jit_cell", &GraphJitConfig::SetBool<GraphJitConfig::kAutoJitCell>},
38   {"auto_grad", &GraphJitConfig::SetBool<GraphJitConfig::kAutoGrad>},
39   {"compile_by_trace", &GraphJitConfig::SetBool<GraphJitConfig::kTraceFlag>},
40   {"print_after_all", &GraphJitConfig::SetBool<GraphJitConfig::kPrintAfterAll>},
41   {"print_tb", &GraphJitConfig::SetBool<GraphJitConfig::kPrintTraceback>},
42   {"print_bb", &GraphJitConfig::SetBool<GraphJitConfig::kPrintBB>},
43   {"print_cfg", &GraphJitConfig::SetBool<GraphJitConfig::kPrintCFG>},
44   {"interpret_captured_code", &GraphJitConfig::SetBool<GraphJitConfig::kInterpretCapturedCode>},
45   {"compile_without_capture", &GraphJitConfig::SetBool<GraphJitConfig::kCompileWithoutCapture>},
46   {"compile_with_try", &GraphJitConfig::SetBool<GraphJitConfig::kCompileWithTry>},
47   {"specialize_scalar", &GraphJitConfig::SetBool<GraphJitConfig::kGuardSpecializeScalar>},
48   {"specialize_container", &GraphJitConfig::SetBool<GraphJitConfig::kGuardSpecializeContainer>},
49   {"specialize_tensor", &GraphJitConfig::SetBool<GraphJitConfig::kGuardSpecializeTensor>},
50   {"guard_detach_object", &GraphJitConfig::SetBool<GraphJitConfig::kGuardDetachObject>},
51   {"print_guard", &GraphJitConfig::SetBool<GraphJitConfig::kPrintGuard>},
52   {"reuse_graph", &GraphJitConfig::SetBool<GraphJitConfig::kReuseGraph>},
53   {"print_reuse_graph", &GraphJitConfig::SetBool<GraphJitConfig::kPrintReuseGraph>},
54   {"auto_clean_cache", &GraphJitConfig::SetBool<GraphJitConfig::kAutoCleanCache>},
55   {"prune_case", &GraphJitConfig::SetBool<GraphJitConfig::kPruneCase>},
56   {"loop_unrolling", &GraphJitConfig::SetBool<GraphJitConfig::kLoopUnrolling>},
57   {"infer_only", &GraphJitConfig::SetBool<GraphJitConfig::kInferOnly>},
58   {"infer_primitive", &GraphJitConfig::SetBool<GraphJitConfig::kInferPrimitive>},
59   {"strict_trace", &GraphJitConfig::SetBool<GraphJitConfig::kStrictTrace>},
60   {"perf_statistics", &GraphJitConfig::SetBool<GraphJitConfig::kPerfStatistics>},
61   {"LOG_GRAPH_BREAK", &GraphJitConfig::SetBool<GraphJitConfig::kLogGraphBreak>},
62   {"LOG_PERF", &GraphJitConfig::SetBool<GraphJitConfig::kLogPerf>},
63   {"LOG_GUARD_PERF", &GraphJitConfig::SetBool<GraphJitConfig::kLogGuardPerf>},
64   {"enable_dynamic_shape", &GraphJitConfig::SetBool<GraphJitConfig::kEnableDynamicShape>},
65   {"test_graph_ir", &GraphJitConfig::SetBool<GraphJitConfig::kTestGraphIR>},
66   {"kFeatureBreakAtInlinedFunction", &GraphJitConfig::SetBool<GraphJitConfig::kFeatureBreakAtInlinedFunction>},
67   {"kEnableEliminateUnusedOperation", &GraphJitConfig::SetBool<GraphJitConfig::kEnableEliminateUnusedOperation>},
68   {"kEnableGeneratorExpressionToTuple", &GraphJitConfig::SetBool<GraphJitConfig::kEnableGeneratorExpressionToTuple>},
69   // kEnableOptimizeForAttrItem
70   {"MAX_INLINE_DEPTH", &GraphJitConfig::SetInt<GraphJitConfig::kMaxInlineDepth>},
71   {"MAX_TRACE_DEPTH", &GraphJitConfig::SetInt<GraphJitConfig::kMaxTraceDepth>},
72   {"MAX_PRUNE_CASE", &GraphJitConfig::SetInt<GraphJitConfig::kMaxPruneCase>},
73   {"MAX_LOOP_UNROLLING", &GraphJitConfig::SetInt<GraphJitConfig::kMaxLoopUnrolling>},
74   {"INFER_PRIMITIVE_MASK", &GraphJitConfig::SetInt<GraphJitConfig::kInferPrimitiveMask>},
75   {"INFER_PRIMITIVE_MAX", &GraphJitConfig::SetInt<GraphJitConfig::kInferPrimitiveMax>},
76   {"STATIC_GRAPH_BYTECODE_MIN", &GraphJitConfig::SetInt<GraphJitConfig::kStaticGraphBytecodeMin>},
77   {"PERF_STATISTICS_COUNT", &GraphJitConfig::SetInt<GraphJitConfig::kPerfStatisticsCount>},
78   {"PERF_STATISTICS_SCALE_10000X", &GraphJitConfig::SetInt<GraphJitConfig::kPerfStatisticsScale10000x>},
79   {"limit_graph_size", &GraphJitConfig::SetInt<GraphJitConfig::kLimitGraphSize>},
80   {"limit_graph_count", &GraphJitConfig::SetInt<GraphJitConfig::kLimitGraphCount>},
81   {"relax_guard_count", &GraphJitConfig::SetInt<GraphJitConfig::kGuardRelaxCount>},
82   {"allowed_inline_modules", &GraphJitConfig::AddAllowedInlineModules},
83   {"pijit_forbidden", &GraphJitConfig::AddJitForbidden},
84   {"pijit_constexpr", &GraphJitConfig::AddJitConstexpr},
85   {"relax_guard_func", &GraphJitConfig::AddJitRelaxGuard},
86   {"jit_level", &GraphJitConfig::AddJitLevel},
87 };
88 
GraphJitConfig()89 GraphJitConfig::GraphJitConfig() {
90   bool_conf[kAutoJitCell - kBoolConf] = false;
91   bool_conf[kAutoGrad - kBoolConf] = false;
92   bool_conf[kPrintAfterAll - kBoolConf] = false;
93   bool_conf[kTraceFlag - kBoolConf] = true;
94   bool_conf[kPrintTraceback - kBoolConf] = false;
95   bool_conf[kPrintBB - kBoolConf] = false;
96   bool_conf[kPrintCFG - kBoolConf] = false;
97   bool_conf[kInterpretCapturedCode - kBoolConf] = false;
98   bool_conf[kCompileWithoutCapture - kBoolConf] = false;
99   bool_conf[kCompileWithTry - kBoolConf] = true;
100   bool_conf[kGuardSpecializeScalar - kBoolConf] = true;
101   bool_conf[kGuardSpecializeContainer - kBoolConf] = false;
102   bool_conf[kGuardSpecializeTensor - kBoolConf] = false;
103   bool_conf[kGuardDetachObject - kBoolConf] = false;
104   bool_conf[kPrintGuard - kBoolConf] = false;
105   bool_conf[kReuseGraph - kBoolConf] = false;
106   bool_conf[kPrintReuseGraph - kBoolConf] = false;
107   bool_conf[kAutoCleanCache - kBoolConf] = false;
108   bool_conf[kPruneCase - kBoolConf] = true;
109   bool_conf[kLoopUnrolling - kBoolConf] = true;
110   bool_conf[kSkipException - kBoolConf] = false;
111   bool_conf[kInferOnly - kBoolConf] = true;
112   bool_conf[kInferPrimitive - kBoolConf] = true;
113   bool_conf[kStrictTrace - kBoolConf] = true;
114   bool_conf[kPerfStatistics - kBoolConf] = false;
115   bool_conf[kLogGraphBreak - kBoolConf] = false;
116   bool_conf[kLogPerf - kBoolConf] = false;
117   bool_conf[kLogGuardPerf - kBoolConf] = false;
118   bool_conf[kTestGraphIR - kBoolConf] = false;
119   bool_conf[kEnableGeneratorExpressionToTuple - kBoolConf] = true;
120   bool_conf[kEnableDynamicShape - kBoolConf] = false;
121   bool_conf[kEnableMsApiInfer - kBoolConf] = false;
122 
123   /*'EnableOptimizeForAttrItem' options must be ensure that multiple calls of the
124    *__getattr__, __getitem__ function of the user-defined object do not affect the correctness.
125    */
126   bool_conf[kEnableOptimizeForAttrItem - kBoolConf] = true;
127   bool_conf[kEnableEliminateUnusedOperation - kBoolConf] = false;
128   bool_conf[kFeatureBreakAtInlinedFunction - kBoolConf] = true;
129 
130   int_conf[kMaxInlineDepth - kIntConf] = 8;
131   int_conf[kMaxTraceDepth - kIntConf] = kDefaultMaxTraceDepth;
132   int_conf[kMaxPruneCase - kIntConf] = -1;
133   int_conf[kMaxLoopUnrolling - kIntConf] = 100;
134   int_conf[kInferPrimitiveMask - kIntConf] = 7;
135   int_conf[kInferPrimitiveMax - kIntConf] = 0;
136   int_conf[kStaticGraphBytecodeMin - kIntConf] = 0;
137   int_conf[kPerfStatisticsCount - kIntConf] = 1;
138   int_conf[kPerfStatisticsScale10000x - kIntConf] = 1000;
139   int_conf[kLimitGraphSize - kIntConf] = 0;
140   int_conf[kLimitGraphCount - kIntConf] = 0;
141   int_conf[kGuardRelaxCount - kIntConf] = 0;
142 
143   allowed_inline_modules_.insert("mindspore");
144 
145   jit_level = "O0";
146 }
147 
GetObjectsMap()148 static py::object GetObjectsMap() {
149   py::str mod_name("mindspore");
150   py::str key_name("<pijit.registry>");
151   // can't import module while the module is deallocated
152   py::object ms = py::reinterpret_steal<py::object>(PyImport_GetModule(mod_name.ptr()));
153   if (ms.ptr() == nullptr || !PyModule_Check(ms.ptr())) {
154     return py::object();
155   }
156   PyObject *registry = PyObject_GetAttr(ms.ptr(), key_name.ptr());
157   if (registry != nullptr) {
158     MS_EXCEPTION_IF_CHECK_FAIL(PyDict_CheckExact(registry), "got duplicate attribute for <pijit.registry>");
159     return py::reinterpret_steal<py::object>(registry);
160   }
161   PyErr_Clear();
162 
163   // just set once, module reload will not rewrite attribute.
164   static bool init = false;
165   if (init) {
166     return py::object();
167   }
168   init = true;
169   registry = PyDict_New();
170   PyObject_SetAttr(ms.ptr(), key_name.ptr(), registry);
171   return py::reinterpret_steal<py::object>(registry);
172 }
173 
AddToFuncMap(PyObject * list,const std::string & map_name,const std::string & key)174 static bool AddToFuncMap(PyObject *list, const std::string &map_name, const std::string &key) {
175   py::object func_map = Utils::GetModuleAttr(kModuleName, map_name, true, true);
176   py::object key_object = Utils::GetModuleAttr(kModuleName, key, true, true);
177   for (const py::handle &i : py::iter(list)) {
178     if (!PyCallable_Check(i.ptr())) {
179       return false;
180     }
181     py::int_ id = FunctionId(py::reinterpret_borrow<py::object>(i));
182     PyDict_SetItem(func_map.ptr(), id.ptr(), key_object.ptr());
183   }
184   return true;
185 }
186 
AddJitForbidden(PyObject * list)187 bool GraphJitConfig::AddJitForbidden(PyObject *list) {
188   return AddToFuncMap(list, kFuncMapName, "FUNC_KEY_PIJIT_FORBIDDEN");
189 }
190 
AddJitLevel(PyObject * str)191 bool GraphJitConfig::AddJitLevel(PyObject *str) {
192   if (py::isinstance<py::str>(str)) {
193     py::str jit_level_obj = py::cast<py::str>(str);
194     auto jit_level_str = py::cast<std::string>(jit_level_obj);
195     if (jit_level_str != "O0" && jit_level_str != "O1" && jit_level_str != "O2") {
196       return false;
197     }
198     jit_level = jit_level_str;
199     return true;
200   }
201   return false;
202 }
203 
getJitLevel() const204 std::string GraphJitConfig::getJitLevel() const { return jit_level; }
205 
AddJitConstexpr(PyObject * list)206 bool GraphJitConfig::AddJitConstexpr(PyObject *list) {
207   return AddToFuncMap(list, kFuncMapName, "FUNC_KEY_PIJIT_CONSTEXPR");
208 }
209 
AddJitRelaxGuard(PyObject * list)210 bool GraphJitConfig::AddJitRelaxGuard(PyObject *list) {
211   return AddToFuncMap(list, kGuardFuncMapName, "GUARD_KEY_RELAX_FUNC");
212 }
213 
AddAllowedInlineModules(PyObject * list)214 bool GraphJitConfig::AddAllowedInlineModules(PyObject *list) {
215   py::object l = py::reinterpret_borrow<py::object>(list);
216   for (const auto &i : py::iter(l)) {
217     const char *name = nullptr;
218     if (PyUnicode_Check(i.ptr())) {
219       name = PyUnicode_AsUTF8(i.ptr());
220     } else if (PyModule_Check(i.ptr())) {
221       name = PyModule_GetName(i.ptr());
222     } else {
223       continue;
224     }
225     if (name == nullptr) {
226       PyErr_Clear();
227       continue;
228     }
229     AddAllowedInlineModules(name);
230   }
231   return true;
232 }
233 
AddAllowedInlineModules(const std::string & module_name)234 void GraphJitConfig::AddAllowedInlineModules(const std::string &module_name) {
235   kPIJitConfigDefault.allowed_inline_modules_.insert(module_name);
236 }
237 
SetAutoJitFilter(PyObject * callable)238 bool GraphJitConfig::SetAutoJitFilter(PyObject *callable) {
239   if (!PyCallable_Check(callable)) {
240     MS_LOG(WARNING) << "PIJit option 'auto_jit_func_filter' only accept callable, but got "
241                     << std::string(py::str(callable));
242     return false;
243   }
244   py::object map = GetObjectsMap();
245   if (map.ptr() == nullptr) {
246     return false;
247   }
248   (void)SetBool<kAutoJit>(Py_True);
249   PyDict_SetItemString(map.ptr(), "<auto jit filter>", callable);
250   return true;
251 }
252 
ShouldAutoJit(PyFrameObject * f)253 bool GraphJitConfig::ShouldAutoJit(PyFrameObject *f) {
254   if (!GetBoolConfig(kAutoJit)) {
255     return false;
256   }
257   py::object map = GetObjectsMap();
258   if (map.ptr() == nullptr) {
259     // mindspore module is unload
260     (void)SetBool<kAutoJit>(Py_False);
261     return false;
262   }
263   PyObject *filter = PyDict_GetItemString(map.ptr(), "<auto jit filter>");
264   if (filter == nullptr) {
265     (void)SetBool<kAutoJit>(Py_False);
266     return false;
267   }
268   PyObject *arg = reinterpret_cast<PyObject *>(f);
269   PyObject *res = PyObject_Vectorcall(filter, &arg, 1, nullptr);
270   if (PyErr_Occurred()) {
271     MS_LOG(ERROR) << "***" << py::error_already_set().what() << "*** at " << std::string(py::str(filter)) << " ignored";
272     PyErr_Clear();
273     (void)SetBool<kAutoJit>(Py_False);
274     return false;
275   }
276   Py_DECREF(res);
277   return res == Py_True;
278 }
279 
GraphJitConfig(const py::object & c)280 GraphJitConfig::GraphJitConfig(const py::object &c) {
281   *this = kPIJitConfigDefault;
282   (void)c.cast<py::dict>();
283   PyObject *key;
284   PyObject *value;
285   Py_ssize_t pos = 0;
286   while (PyDict_Next(c.ptr(), &pos, &key, &value)) {
287     if (PyUnicode_Check(key)) {
288       const char *k = PyUnicode_AsUTF8(key);
289       auto iter = key_map.find(k);
290       if (iter != key_map.end() && (this->*(iter->second))(value)) {
291         continue;
292       }
293     }
294     MS_LOG(WARNING) << "unknown PIJit options: " << std::string(py::str(key)) << ":" << std::string(py::str(value));
295   }
296 }
297 
ReplaceMethod(const py::object & cls,PyMethodDef * mdef,const char * save_name,bool enable)298 static void ReplaceMethod(const py::object &cls, PyMethodDef *mdef, const char *save_name, bool enable) {
299   py::object func = cls.attr(mdef->ml_name);
300   bool is_hook = false;
301   if (Py_IS_TYPE(func.ptr(), &PyMethodDescr_Type)) {
302     is_hook = reinterpret_cast<PyMethodDescrObject *>(func.ptr())->d_method->ml_meth == mdef->ml_meth;
303   }
304   if (enable && !is_hook) {
305     PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(cls.ptr());
306     py::object hook = py::reinterpret_steal<py::object>(PyDescr_NewMethod(tp, mdef));
307     cls.attr(mdef->ml_name) = hook;
308     cls.attr(save_name) = func;
309   }
310   if (!enable && is_hook) {
311     cls.attr(mdef->ml_name) = cls.attr(save_name);
312     py::delattr(cls, save_name);
313   }
314 }
315 
ApplyAutoJitCell()316 void GraphJitConfig::ApplyAutoJitCell() {
317   static constexpr const char *name = "__call__";
318   static constexpr const char *save_name = "_old__call__";
319   static const PyCFunctionWithKeywords CellForward = [](PyObject *self, PyObject *vargs, PyObject *kwargs) {
320     PyObject *construct = PyObject_GetAttrString(self, "construct");
321     py::object handle = py::reinterpret_steal<py::object>(construct);
322     if (construct != nullptr) {
323       (void)pi_jit_should_compile(handle, py::dict(), py::none());
324     } else {
325       PyErr_Clear();
326     }
327 
328     PyObject *func = PyObject_GetAttrString(self, save_name);
329     PyObject *ret = PyObject_Call(func, vargs, kwargs);
330     Py_DECREF(func);
331     return ret;
332   };
333   static PyMethodDef mdef = {name, reinterpret_cast<PyCFunction>(CellForward), METH_VARARGS | METH_KEYWORDS, "Hook"};
334 
335   bool enable = kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kAutoJitCell);
336   py::object cls = Utils::GetModuleAttr("mindspore.nn", "Cell", false, false);
337   ReplaceMethod(cls, &mdef, save_name, enable);
338 }
339 
340 }  // namespace pijit
341 
update_pijit_default_config(const py::kwargs & conf)342 void update_pijit_default_config(const py::kwargs &conf) {
343   mindspore::pijit::kPIJitConfigDefault = mindspore::pijit::GraphJitConfig(conf);
344   mindspore::pijit::GraphJitConfig::ApplyAutoJitCell();
345 }
346 
347 }  // namespace mindspore
348