1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "pipeline/jit/pi/pi_jit_config.h"
#include <exception>
#include <string>
#include <unordered_map>
#include "utils/log_adapter.h"
#include "pipeline/jit/pi/external.h"
#include "pipeline/jit/pi/utils/utils.h"
#include "pipeline/jit/pi/pydef.h"
23
24 namespace mindspore {
25 namespace pijit {
26
// Process-wide default PIJit configuration. Configurations built from a Python dict start as a copy of it,
// and update_pijit_default_config replaces it.
GraphJitConfig kPIJitConfigDefault;

// Initial value of the MAX_TRACE_DEPTH option in the default configuration.
constexpr int kDefaultMaxTraceDepth = 16;

// Python module holding the PIJit function white lists.
constexpr const char *kModuleName = "mindspore._extends.pijit.pijit_func_white_list";
// Attribute of kModuleName used as the function map for "pijit_forbidden" / "pijit_constexpr".
constexpr const char *kFuncMapName = "_func_map";
// Attribute of kModuleName used as the function map for "relax_guard_func".
constexpr const char *kGuardFuncMapName = "guard_func_map";
34
35 static const std::unordered_map<std::string, bool (GraphJitConfig::*)(PyObject *)> key_map = {
36 {"auto_jit_func_filter", &GraphJitConfig::SetAutoJitFilter},
37 {"auto_jit_cell", &GraphJitConfig::SetBool<GraphJitConfig::kAutoJitCell>},
38 {"auto_grad", &GraphJitConfig::SetBool<GraphJitConfig::kAutoGrad>},
39 {"compile_by_trace", &GraphJitConfig::SetBool<GraphJitConfig::kTraceFlag>},
40 {"print_after_all", &GraphJitConfig::SetBool<GraphJitConfig::kPrintAfterAll>},
41 {"print_tb", &GraphJitConfig::SetBool<GraphJitConfig::kPrintTraceback>},
42 {"print_bb", &GraphJitConfig::SetBool<GraphJitConfig::kPrintBB>},
43 {"print_cfg", &GraphJitConfig::SetBool<GraphJitConfig::kPrintCFG>},
44 {"interpret_captured_code", &GraphJitConfig::SetBool<GraphJitConfig::kInterpretCapturedCode>},
45 {"compile_without_capture", &GraphJitConfig::SetBool<GraphJitConfig::kCompileWithoutCapture>},
46 {"compile_with_try", &GraphJitConfig::SetBool<GraphJitConfig::kCompileWithTry>},
47 {"specialize_scalar", &GraphJitConfig::SetBool<GraphJitConfig::kGuardSpecializeScalar>},
48 {"specialize_container", &GraphJitConfig::SetBool<GraphJitConfig::kGuardSpecializeContainer>},
49 {"specialize_tensor", &GraphJitConfig::SetBool<GraphJitConfig::kGuardSpecializeTensor>},
50 {"guard_detach_object", &GraphJitConfig::SetBool<GraphJitConfig::kGuardDetachObject>},
51 {"print_guard", &GraphJitConfig::SetBool<GraphJitConfig::kPrintGuard>},
52 {"reuse_graph", &GraphJitConfig::SetBool<GraphJitConfig::kReuseGraph>},
53 {"print_reuse_graph", &GraphJitConfig::SetBool<GraphJitConfig::kPrintReuseGraph>},
54 {"auto_clean_cache", &GraphJitConfig::SetBool<GraphJitConfig::kAutoCleanCache>},
55 {"prune_case", &GraphJitConfig::SetBool<GraphJitConfig::kPruneCase>},
56 {"loop_unrolling", &GraphJitConfig::SetBool<GraphJitConfig::kLoopUnrolling>},
57 {"infer_only", &GraphJitConfig::SetBool<GraphJitConfig::kInferOnly>},
58 {"infer_primitive", &GraphJitConfig::SetBool<GraphJitConfig::kInferPrimitive>},
59 {"strict_trace", &GraphJitConfig::SetBool<GraphJitConfig::kStrictTrace>},
60 {"perf_statistics", &GraphJitConfig::SetBool<GraphJitConfig::kPerfStatistics>},
61 {"LOG_GRAPH_BREAK", &GraphJitConfig::SetBool<GraphJitConfig::kLogGraphBreak>},
62 {"LOG_PERF", &GraphJitConfig::SetBool<GraphJitConfig::kLogPerf>},
63 {"LOG_GUARD_PERF", &GraphJitConfig::SetBool<GraphJitConfig::kLogGuardPerf>},
64 {"enable_dynamic_shape", &GraphJitConfig::SetBool<GraphJitConfig::kEnableDynamicShape>},
65 {"test_graph_ir", &GraphJitConfig::SetBool<GraphJitConfig::kTestGraphIR>},
66 {"kFeatureBreakAtInlinedFunction", &GraphJitConfig::SetBool<GraphJitConfig::kFeatureBreakAtInlinedFunction>},
67 {"kEnableEliminateUnusedOperation", &GraphJitConfig::SetBool<GraphJitConfig::kEnableEliminateUnusedOperation>},
68 {"kEnableGeneratorExpressionToTuple", &GraphJitConfig::SetBool<GraphJitConfig::kEnableGeneratorExpressionToTuple>},
69 // kEnableOptimizeForAttrItem
70 {"MAX_INLINE_DEPTH", &GraphJitConfig::SetInt<GraphJitConfig::kMaxInlineDepth>},
71 {"MAX_TRACE_DEPTH", &GraphJitConfig::SetInt<GraphJitConfig::kMaxTraceDepth>},
72 {"MAX_PRUNE_CASE", &GraphJitConfig::SetInt<GraphJitConfig::kMaxPruneCase>},
73 {"MAX_LOOP_UNROLLING", &GraphJitConfig::SetInt<GraphJitConfig::kMaxLoopUnrolling>},
74 {"INFER_PRIMITIVE_MASK", &GraphJitConfig::SetInt<GraphJitConfig::kInferPrimitiveMask>},
75 {"INFER_PRIMITIVE_MAX", &GraphJitConfig::SetInt<GraphJitConfig::kInferPrimitiveMax>},
76 {"STATIC_GRAPH_BYTECODE_MIN", &GraphJitConfig::SetInt<GraphJitConfig::kStaticGraphBytecodeMin>},
77 {"PERF_STATISTICS_COUNT", &GraphJitConfig::SetInt<GraphJitConfig::kPerfStatisticsCount>},
78 {"PERF_STATISTICS_SCALE_10000X", &GraphJitConfig::SetInt<GraphJitConfig::kPerfStatisticsScale10000x>},
79 {"limit_graph_size", &GraphJitConfig::SetInt<GraphJitConfig::kLimitGraphSize>},
80 {"limit_graph_count", &GraphJitConfig::SetInt<GraphJitConfig::kLimitGraphCount>},
81 {"relax_guard_count", &GraphJitConfig::SetInt<GraphJitConfig::kGuardRelaxCount>},
82 {"allowed_inline_modules", &GraphJitConfig::AddAllowedInlineModules},
83 {"pijit_forbidden", &GraphJitConfig::AddJitForbidden},
84 {"pijit_constexpr", &GraphJitConfig::AddJitConstexpr},
85 {"relax_guard_func", &GraphJitConfig::AddJitRelaxGuard},
86 {"jit_level", &GraphJitConfig::AddJitLevel},
87 };
88
GraphJitConfig()89 GraphJitConfig::GraphJitConfig() {
90 bool_conf[kAutoJitCell - kBoolConf] = false;
91 bool_conf[kAutoGrad - kBoolConf] = false;
92 bool_conf[kPrintAfterAll - kBoolConf] = false;
93 bool_conf[kTraceFlag - kBoolConf] = true;
94 bool_conf[kPrintTraceback - kBoolConf] = false;
95 bool_conf[kPrintBB - kBoolConf] = false;
96 bool_conf[kPrintCFG - kBoolConf] = false;
97 bool_conf[kInterpretCapturedCode - kBoolConf] = false;
98 bool_conf[kCompileWithoutCapture - kBoolConf] = false;
99 bool_conf[kCompileWithTry - kBoolConf] = true;
100 bool_conf[kGuardSpecializeScalar - kBoolConf] = true;
101 bool_conf[kGuardSpecializeContainer - kBoolConf] = false;
102 bool_conf[kGuardSpecializeTensor - kBoolConf] = false;
103 bool_conf[kGuardDetachObject - kBoolConf] = false;
104 bool_conf[kPrintGuard - kBoolConf] = false;
105 bool_conf[kReuseGraph - kBoolConf] = false;
106 bool_conf[kPrintReuseGraph - kBoolConf] = false;
107 bool_conf[kAutoCleanCache - kBoolConf] = false;
108 bool_conf[kPruneCase - kBoolConf] = true;
109 bool_conf[kLoopUnrolling - kBoolConf] = true;
110 bool_conf[kSkipException - kBoolConf] = false;
111 bool_conf[kInferOnly - kBoolConf] = true;
112 bool_conf[kInferPrimitive - kBoolConf] = true;
113 bool_conf[kStrictTrace - kBoolConf] = true;
114 bool_conf[kPerfStatistics - kBoolConf] = false;
115 bool_conf[kLogGraphBreak - kBoolConf] = false;
116 bool_conf[kLogPerf - kBoolConf] = false;
117 bool_conf[kLogGuardPerf - kBoolConf] = false;
118 bool_conf[kTestGraphIR - kBoolConf] = false;
119 bool_conf[kEnableGeneratorExpressionToTuple - kBoolConf] = true;
120 bool_conf[kEnableDynamicShape - kBoolConf] = false;
121 bool_conf[kEnableMsApiInfer - kBoolConf] = false;
122
123 /*'EnableOptimizeForAttrItem' options must be ensure that multiple calls of the
124 *__getattr__, __getitem__ function of the user-defined object do not affect the correctness.
125 */
126 bool_conf[kEnableOptimizeForAttrItem - kBoolConf] = true;
127 bool_conf[kEnableEliminateUnusedOperation - kBoolConf] = false;
128 bool_conf[kFeatureBreakAtInlinedFunction - kBoolConf] = true;
129
130 int_conf[kMaxInlineDepth - kIntConf] = 8;
131 int_conf[kMaxTraceDepth - kIntConf] = kDefaultMaxTraceDepth;
132 int_conf[kMaxPruneCase - kIntConf] = -1;
133 int_conf[kMaxLoopUnrolling - kIntConf] = 100;
134 int_conf[kInferPrimitiveMask - kIntConf] = 7;
135 int_conf[kInferPrimitiveMax - kIntConf] = 0;
136 int_conf[kStaticGraphBytecodeMin - kIntConf] = 0;
137 int_conf[kPerfStatisticsCount - kIntConf] = 1;
138 int_conf[kPerfStatisticsScale10000x - kIntConf] = 1000;
139 int_conf[kLimitGraphSize - kIntConf] = 0;
140 int_conf[kLimitGraphCount - kIntConf] = 0;
141 int_conf[kGuardRelaxCount - kIntConf] = 0;
142
143 allowed_inline_modules_.insert("mindspore");
144
145 jit_level = "O0";
146 }
147
GetObjectsMap()148 static py::object GetObjectsMap() {
149 py::str mod_name("mindspore");
150 py::str key_name("<pijit.registry>");
151 // can't import module while the module is deallocated
152 py::object ms = py::reinterpret_steal<py::object>(PyImport_GetModule(mod_name.ptr()));
153 if (ms.ptr() == nullptr || !PyModule_Check(ms.ptr())) {
154 return py::object();
155 }
156 PyObject *registry = PyObject_GetAttr(ms.ptr(), key_name.ptr());
157 if (registry != nullptr) {
158 MS_EXCEPTION_IF_CHECK_FAIL(PyDict_CheckExact(registry), "got duplicate attribute for <pijit.registry>");
159 return py::reinterpret_steal<py::object>(registry);
160 }
161 PyErr_Clear();
162
163 // just set once, module reload will not rewrite attribute.
164 static bool init = false;
165 if (init) {
166 return py::object();
167 }
168 init = true;
169 registry = PyDict_New();
170 PyObject_SetAttr(ms.ptr(), key_name.ptr(), registry);
171 return py::reinterpret_steal<py::object>(registry);
172 }
173
AddToFuncMap(PyObject * list,const std::string & map_name,const std::string & key)174 static bool AddToFuncMap(PyObject *list, const std::string &map_name, const std::string &key) {
175 py::object func_map = Utils::GetModuleAttr(kModuleName, map_name, true, true);
176 py::object key_object = Utils::GetModuleAttr(kModuleName, key, true, true);
177 for (const py::handle &i : py::iter(list)) {
178 if (!PyCallable_Check(i.ptr())) {
179 return false;
180 }
181 py::int_ id = FunctionId(py::reinterpret_borrow<py::object>(i));
182 PyDict_SetItem(func_map.ptr(), id.ptr(), key_object.ptr());
183 }
184 return true;
185 }
186
AddJitForbidden(PyObject * list)187 bool GraphJitConfig::AddJitForbidden(PyObject *list) {
188 return AddToFuncMap(list, kFuncMapName, "FUNC_KEY_PIJIT_FORBIDDEN");
189 }
190
AddJitLevel(PyObject * str)191 bool GraphJitConfig::AddJitLevel(PyObject *str) {
192 if (py::isinstance<py::str>(str)) {
193 py::str jit_level_obj = py::cast<py::str>(str);
194 auto jit_level_str = py::cast<std::string>(jit_level_obj);
195 if (jit_level_str != "O0" && jit_level_str != "O1" && jit_level_str != "O2") {
196 return false;
197 }
198 jit_level = jit_level_str;
199 return true;
200 }
201 return false;
202 }
203
getJitLevel() const204 std::string GraphJitConfig::getJitLevel() const { return jit_level; }
205
AddJitConstexpr(PyObject * list)206 bool GraphJitConfig::AddJitConstexpr(PyObject *list) {
207 return AddToFuncMap(list, kFuncMapName, "FUNC_KEY_PIJIT_CONSTEXPR");
208 }
209
AddJitRelaxGuard(PyObject * list)210 bool GraphJitConfig::AddJitRelaxGuard(PyObject *list) {
211 return AddToFuncMap(list, kGuardFuncMapName, "GUARD_KEY_RELAX_FUNC");
212 }
213
AddAllowedInlineModules(PyObject * list)214 bool GraphJitConfig::AddAllowedInlineModules(PyObject *list) {
215 py::object l = py::reinterpret_borrow<py::object>(list);
216 for (const auto &i : py::iter(l)) {
217 const char *name = nullptr;
218 if (PyUnicode_Check(i.ptr())) {
219 name = PyUnicode_AsUTF8(i.ptr());
220 } else if (PyModule_Check(i.ptr())) {
221 name = PyModule_GetName(i.ptr());
222 } else {
223 continue;
224 }
225 if (name == nullptr) {
226 PyErr_Clear();
227 continue;
228 }
229 AddAllowedInlineModules(name);
230 }
231 return true;
232 }
233
AddAllowedInlineModules(const std::string & module_name)234 void GraphJitConfig::AddAllowedInlineModules(const std::string &module_name) {
235 kPIJitConfigDefault.allowed_inline_modules_.insert(module_name);
236 }
237
SetAutoJitFilter(PyObject * callable)238 bool GraphJitConfig::SetAutoJitFilter(PyObject *callable) {
239 if (!PyCallable_Check(callable)) {
240 MS_LOG(WARNING) << "PIJit option 'auto_jit_func_filter' only accept callable, but got "
241 << std::string(py::str(callable));
242 return false;
243 }
244 py::object map = GetObjectsMap();
245 if (map.ptr() == nullptr) {
246 return false;
247 }
248 (void)SetBool<kAutoJit>(Py_True);
249 PyDict_SetItemString(map.ptr(), "<auto jit filter>", callable);
250 return true;
251 }
252
ShouldAutoJit(PyFrameObject * f)253 bool GraphJitConfig::ShouldAutoJit(PyFrameObject *f) {
254 if (!GetBoolConfig(kAutoJit)) {
255 return false;
256 }
257 py::object map = GetObjectsMap();
258 if (map.ptr() == nullptr) {
259 // mindspore module is unload
260 (void)SetBool<kAutoJit>(Py_False);
261 return false;
262 }
263 PyObject *filter = PyDict_GetItemString(map.ptr(), "<auto jit filter>");
264 if (filter == nullptr) {
265 (void)SetBool<kAutoJit>(Py_False);
266 return false;
267 }
268 PyObject *arg = reinterpret_cast<PyObject *>(f);
269 PyObject *res = PyObject_Vectorcall(filter, &arg, 1, nullptr);
270 if (PyErr_Occurred()) {
271 MS_LOG(ERROR) << "***" << py::error_already_set().what() << "*** at " << std::string(py::str(filter)) << " ignored";
272 PyErr_Clear();
273 (void)SetBool<kAutoJit>(Py_False);
274 return false;
275 }
276 Py_DECREF(res);
277 return res == Py_True;
278 }
279
GraphJitConfig(const py::object & c)280 GraphJitConfig::GraphJitConfig(const py::object &c) {
281 *this = kPIJitConfigDefault;
282 (void)c.cast<py::dict>();
283 PyObject *key;
284 PyObject *value;
285 Py_ssize_t pos = 0;
286 while (PyDict_Next(c.ptr(), &pos, &key, &value)) {
287 if (PyUnicode_Check(key)) {
288 const char *k = PyUnicode_AsUTF8(key);
289 auto iter = key_map.find(k);
290 if (iter != key_map.end() && (this->*(iter->second))(value)) {
291 continue;
292 }
293 }
294 MS_LOG(WARNING) << "unknown PIJit options: " << std::string(py::str(key)) << ":" << std::string(py::str(value));
295 }
296 }
297
ReplaceMethod(const py::object & cls,PyMethodDef * mdef,const char * save_name,bool enable)298 static void ReplaceMethod(const py::object &cls, PyMethodDef *mdef, const char *save_name, bool enable) {
299 py::object func = cls.attr(mdef->ml_name);
300 bool is_hook = false;
301 if (Py_IS_TYPE(func.ptr(), &PyMethodDescr_Type)) {
302 is_hook = reinterpret_cast<PyMethodDescrObject *>(func.ptr())->d_method->ml_meth == mdef->ml_meth;
303 }
304 if (enable && !is_hook) {
305 PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(cls.ptr());
306 py::object hook = py::reinterpret_steal<py::object>(PyDescr_NewMethod(tp, mdef));
307 cls.attr(mdef->ml_name) = hook;
308 cls.attr(save_name) = func;
309 }
310 if (!enable && is_hook) {
311 cls.attr(mdef->ml_name) = cls.attr(save_name);
312 py::delattr(cls, save_name);
313 }
314 }
315
ApplyAutoJitCell()316 void GraphJitConfig::ApplyAutoJitCell() {
317 static constexpr const char *name = "__call__";
318 static constexpr const char *save_name = "_old__call__";
319 static const PyCFunctionWithKeywords CellForward = [](PyObject *self, PyObject *vargs, PyObject *kwargs) {
320 PyObject *construct = PyObject_GetAttrString(self, "construct");
321 py::object handle = py::reinterpret_steal<py::object>(construct);
322 if (construct != nullptr) {
323 (void)pi_jit_should_compile(handle, py::dict(), py::none());
324 } else {
325 PyErr_Clear();
326 }
327
328 PyObject *func = PyObject_GetAttrString(self, save_name);
329 PyObject *ret = PyObject_Call(func, vargs, kwargs);
330 Py_DECREF(func);
331 return ret;
332 };
333 static PyMethodDef mdef = {name, reinterpret_cast<PyCFunction>(CellForward), METH_VARARGS | METH_KEYWORDS, "Hook"};
334
335 bool enable = kPIJitConfigDefault.GetBoolConfig(GraphJitConfig::kAutoJitCell);
336 py::object cls = Utils::GetModuleAttr("mindspore.nn", "Cell", false, false);
337 ReplaceMethod(cls, &mdef, save_name, enable);
338 }
339
340 } // namespace pijit
341
update_pijit_default_config(const py::kwargs & conf)342 void update_pijit_default_config(const py::kwargs &conf) {
343 mindspore::pijit::kPIJitConfigDefault = mindspore::pijit::GraphJitConfig(conf);
344 mindspore::pijit::GraphJitConfig::ApplyAutoJitCell();
345 }
346
347 } // namespace mindspore
348