1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_UTILS_H 17 #define MINDSPORE_PI_JIT_UTILS_H 18 19 #include <string> 20 #include <vector> 21 #include <utility> 22 #include <map> 23 #include <chrono> 24 #include "pybind11/pybind11.h" 25 #include "pipeline/jit/pi/pydef.h" 26 #include "mindspore/core/ir/cell.h" 27 28 namespace mindspore { 29 namespace pijit { 30 31 namespace py = pybind11; 32 33 constexpr auto kTwo = 2; 34 constexpr auto kThree = 3; 35 constexpr auto kFive = 5; 36 37 enum StopTraceReason : uint8_t { 38 #define STOP_TRACE_REASON_KIND(kind, description) k##kind, 39 #include "stop_trace_reason.def" 40 #undef STOP_TRACE_REASON_KIND 41 }; 42 43 std::string GetStopTraceReasonDesc(StopTraceReason res); 44 45 enum InlineReason : uint8_t { 46 #define INLINE_REASON_KIND(kind, description) k##kind, 47 #include "inline_reason.def" 48 #undef INLINE_REASON_KIND 49 }; 50 51 std::string GetInlineReasonDesc(InlineReason res); 52 53 enum LoopUnrollingReason : uint8_t { 54 #define LOOP_UNROLLING_REASON_KIND(kind, description) k##kind, 55 #include "loop_unrolling_reason.def" 56 #undef LOOP_UNROLLING_REASON_KIND 57 }; 58 59 std::string GetLoopUnrollingReasonDesc(LoopUnrollingReason res); 60 61 class Utils { 62 public: 63 Utils() = default; 64 ~Utils() = default; 65 66 static std::string GetPyName(PyObject *obj); 67 68 static PyFrameObject *PrepareFrame(PyObject *callable, PyObject *args, PyObject *kwargs); 69 70 // find a object from specified module. default not import, not throw. 71 static py::object GetModuleAttr(const std::string &mod_name, const std::string &attr_name, bool _import = false, 72 bool _throw = false); 73 74 // if has a python exception, log it and return the exception information 75 static std::string ReportPythonException(); 76 77 /** 78 * Pack stack arguments to PyObject by opcode 79 * 80 * \param args stack arguments, the layout match opcode. 81 * \param callop CALL_FUNCTION/CALL_METHOD/CALL_FUNCTION_KW/CALL_FUNCTION_EX. 82 * \param ret_vector_args if true, return a tuple arguments with names tuple. 83 * default, return a tuple arguments with a dict arguments. 84 * if failed, pair.first is empty. 85 * \return a pair of arguments for object call 86 */ 87 static std::pair<py::object, py::object> PackCallStackArgs(const std::vector<py::object> &args, int opcode, 88 bool ret_vector_args = false); 89 90 // alias python 'print(func); import dis; dis.dis(func)' 91 static void DisFuncObject(PyObject *); 92 // alias python 'print(...)' 93 static void PyBuiltinPrint(PyObject *); 94 95 static PyObject *MixedPrecisionTypeToDType(MixedPrecisionType mixedPrecisionType); 96 }; 97 98 #define GRAPH_JIT_LOG_F PY_PRINT_F 99 100 #define PY_PRINT_F(fmt, ...) \ 101 do { \ 102 PyObject *_pystr; \ 103 if (fmt[strlen(fmt) - 1] == '\n') { \ 104 std::string _fstr = fmt; \ 105 _fstr[_fstr.size() - 1] = ' '; \ 106 _pystr = PyUnicode_FromFormat(_fstr.c_str(), ##__VA_ARGS__); \ 107 } else { \ 108 _pystr = PyUnicode_FromFormat(fmt, ##__VA_ARGS__); \ 109 } \ 110 Utils::PyBuiltinPrint(_pystr); \ 111 Py_DECREF(_pystr); \ 112 } while (0) 113 114 #define REPLACE_PY_MEMBER(member, o) \ 115 do { \ 116 PyObject *py_replace_tmp = (member); \ 117 Py_XINCREF(o); \ 118 (member) = (o); \ 119 Py_XDECREF(py_replace_tmp); \ 120 } while (0) 121 122 #ifdef DEBUG 123 #define PRINT_IF_HAS_USER_DEFINED_HOOK(op, hook) \ 124 do { \ 125 static const char *slot_key_##hook = #hook; \ 126 PyObject *attr_##hook = PyObject_GetAttrString(op, slot_key_##hook); \ 127 if (attr_##hook && (PyMethod_Check(attr_##hook) || PyFunction_Check(attr_##hook))) { \ 128 PY_PRINT_F("%A has hook " #hook, PyType_Check(op) ? op : (PyObject *)Py_TYPE(op)); \ 129 } else { \ 130 PyErr_Clear(); \ 131 } \ 132 Py_XDECREF(attr_##hook); \ 133 } while (0) 134 #else 135 #define PRINT_IF_HAS_USER_DEFINED_HOOK(op, hook) 136 #endif 137 class ReprRecursionScope { 138 public: ReprRecursionScope(PyObject * v)139 explicit ReprRecursionScope(PyObject *v) : v_(v), stat_(v == nullptr ? -1 : Py_ReprEnter(v)) {} ~ReprRecursionScope()140 ~ReprRecursionScope() { 141 if (stat_ == 0) { 142 Py_ReprLeave(v_); 143 } 144 } ErrExist()145 bool ErrExist() { return stat_ < 0; } ReEnter()146 bool ReEnter() { return stat_ > 0; } ReEnterOrError()147 bool ReEnterOrError() { return ReEnter() || ErrExist(); } 148 149 private: 150 PyObject *v_; 151 int stat_; 152 }; 153 154 bool HasMutableOrConstAttr(PyObject *obj); 155 bool IsMutableObj(const py::object &obj); 156 bool CheckMutableOrNonConstAttr(PyObject *obj); 157 bool HasDynamicLength(PyObject *obj); 158 bool CheckDynamicLength(PyObject *obj); 159 bool CheckScalar(PyObject *obj); 160 bool CheckContainer(PyObject *obj); 161 bool IsTensorPyObject(PyObject *obj); 162 bool IsMsClass(PyObject *obj); 163 bool IsNumpyObject(PyObject *obj); 164 const char *GetFuncName(const py::object &handle); 165 166 bool CheckAdapterTensor(const py::object &tensor); 167 py::object ConvertToMsTensor(const py::object &tensor); 168 py::object ConvertToAdapterTensor(const py::object &tensor); 169 170 std::string GetTopModule(const py::object &o); 171 py::object GetPyCodeObject(const py::object &any, bool exact_func = false); 172 size_t DeviceAvailableMemSize(); 173 bool CheckConstPyObject(PyObject *cnst); 174 175 class TimeRecorder { 176 public: 177 using RecorderType = const char *; 178 static constexpr double scale = std::nano::den; 179 180 class TimeData { 181 public: 182 struct Data { 183 uint64_t count; 184 uint64_t nano; 185 }; 186 TimeData() = default; 187 ~TimeData(); 188 std::string ToString(); 189 190 std::map<RecorderType, Data> data_; 191 }; 192 193 explicit TimeRecorder(const RecorderType &descr, bool record = true); 194 ~TimeRecorder(); 195 196 private: 197 static TimeData *Data(); 198 199 RecorderType descr_; 200 std::chrono::steady_clock::time_point start_; 201 bool record_; 202 }; 203 204 class RefTracker { 205 public: 206 ~RefTracker(); 207 bool Track(PyObject *obj, const std::string &descr); 208 static RefTracker *GetInstance(); 209 210 private: 211 RefTracker(); 212 static PyObject *UnTrack(PyObject *ref, PyObject *); 213 std::map<void *, std::pair<PyObject *, PyObject *>> tracked_; 214 PyMethodDef mdef_; 215 }; 216 217 } // namespace pijit 218 } // namespace mindspore 219 220 #endif // MINDSPORE_PI_JIT_UTILS_H 221