• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_UTILS_H
17 #define MINDSPORE_PI_JIT_UTILS_H
18 
19 #include <string>
20 #include <vector>
21 #include <utility>
22 #include <map>
23 #include <chrono>
24 #include "pybind11/pybind11.h"
25 #include "pipeline/jit/pi/pydef.h"
26 #include "mindspore/core/ir/cell.h"
27 
28 namespace mindspore {
29 namespace pijit {
30 
31 namespace py = pybind11;
32 
// Small named integer constants shared by the PI JIT sources (all of type int).
constexpr int kTwo = 2;
constexpr int kThree = 3;
constexpr int kFive = 5;
36 
37 enum StopTraceReason : uint8_t {
38 #define STOP_TRACE_REASON_KIND(kind, description) k##kind,
39 #include "stop_trace_reason.def"
40 #undef STOP_TRACE_REASON_KIND
41 };
42 
43 std::string GetStopTraceReasonDesc(StopTraceReason res);
44 
45 enum InlineReason : uint8_t {
46 #define INLINE_REASON_KIND(kind, description) k##kind,
47 #include "inline_reason.def"
48 #undef INLINE_REASON_KIND
49 };
50 
51 std::string GetInlineReasonDesc(InlineReason res);
52 
53 enum LoopUnrollingReason : uint8_t {
54 #define LOOP_UNROLLING_REASON_KIND(kind, description) k##kind,
55 #include "loop_unrolling_reason.def"
56 #undef LOOP_UNROLLING_REASON_KIND
57 };
58 
59 std::string GetLoopUnrollingReasonDesc(LoopUnrollingReason res);
60 
61 class Utils {
62  public:
63   Utils() = default;
64   ~Utils() = default;
65 
66   static std::string GetPyName(PyObject *obj);
67 
68   static PyFrameObject *PrepareFrame(PyObject *callable, PyObject *args, PyObject *kwargs);
69 
70   // find a object from specified module. default not import, not throw.
71   static py::object GetModuleAttr(const std::string &mod_name, const std::string &attr_name, bool _import = false,
72                                   bool _throw = false);
73 
74   // if has a python exception, log it and return the exception information
75   static std::string ReportPythonException();
76 
77   /**
78    * Pack stack arguments to PyObject by opcode
79    *
80    * \param args stack arguments, the layout match opcode.
81    * \param callop CALL_FUNCTION/CALL_METHOD/CALL_FUNCTION_KW/CALL_FUNCTION_EX.
82    * \param ret_vector_args if true, return a tuple arguments with names tuple.
83    *                        default, return a tuple arguments with a dict arguments.
84    *                        if failed, pair.first is empty.
85    * \return a pair of arguments for object call
86    */
87   static std::pair<py::object, py::object> PackCallStackArgs(const std::vector<py::object> &args, int opcode,
88                                                              bool ret_vector_args = false);
89 
90   // alias python 'print(func); import dis; dis.dis(func)'
91   static void DisFuncObject(PyObject *);
92   // alias python 'print(...)'
93   static void PyBuiltinPrint(PyObject *);
94 
95   static PyObject *MixedPrecisionTypeToDType(MixedPrecisionType mixedPrecisionType);
96 };
97 
// Alias used by graph-JIT logging call sites; expands exactly like PY_PRINT_F.
#define GRAPH_JIT_LOG_F PY_PRINT_F

/**
 * Formats a message with PyUnicode_FromFormat and prints it through Utils::PyBuiltinPrint
 * (the equivalent of Python `print(...)`).
 *
 * A single trailing '\n' in `fmt` is replaced by a space, because `print` already ends the line.
 * `fmt` must be a C string; it is evaluated exactly once, and an empty format string is handled
 * safely. If formatting fails, the pending Python error is cleared and nothing is printed: a
 * diagnostic print must neither crash on a NULL result nor leave an exception pending.
 */
#define PY_PRINT_F(fmt, ...)                                                     \
  do {                                                                           \
    const char *py_print_fmt = (fmt);                                            \
    const size_t py_print_len = strlen(py_print_fmt);                            \
    PyObject *py_print_str = nullptr;                                            \
    if (py_print_len > 0 && py_print_fmt[py_print_len - 1] == '\n') {            \
      std::string py_print_copy(py_print_fmt, py_print_len - 1);                 \
      py_print_copy.push_back(' ');                                              \
      py_print_str = PyUnicode_FromFormat(py_print_copy.c_str(), ##__VA_ARGS__); \
    } else {                                                                     \
      py_print_str = PyUnicode_FromFormat(py_print_fmt, ##__VA_ARGS__);          \
    }                                                                            \
    if (py_print_str == nullptr) {                                               \
      PyErr_Clear();                                                             \
    } else {                                                                     \
      Utils::PyBuiltinPrint(py_print_str);                                       \
      Py_DECREF(py_print_str);                                                   \
    }                                                                            \
  } while (0)
113 
/**
 * Stores `o` into the owning PyObject* lvalue `member`: takes a new reference to `o` and releases
 * the value previously held. The new reference is taken before the old one is dropped, so
 * replacing a member with the object it already holds is safe. Either value may be nullptr.
 * `o` is evaluated exactly once (after `member` is read), so expressions with side effects are safe.
 */
#define REPLACE_PY_MEMBER(member, o)     \
  do {                                   \
    PyObject *py_replace_old = (member); \
    PyObject *py_replace_new = (o);      \
    Py_XINCREF(py_replace_new);          \
    (member) = py_replace_new;           \
    Py_XDECREF(py_replace_old);          \
  } while (0)
121 
#ifdef DEBUG
/**
 * Debug-only diagnostic: looks up attribute `hook` on `op` and, when it is a Python method or
 * function, prints "<type> has hook <hook>" (using `op` itself when it is a type, otherwise its
 * type). A failed lookup is treated as "no hook" and its Python error is cleared.
 * `op` is evaluated exactly once. In non-DEBUG builds the macro expands to nothing.
 */
#define PRINT_IF_HAS_USER_DEFINED_HOOK(op, hook)                                                  \
  do {                                                                                            \
    PyObject *pijit_hook_owner = (op);                                                            \
    PyObject *pijit_hook_attr = PyObject_GetAttrString(pijit_hook_owner, #hook);                  \
    if (pijit_hook_attr == nullptr) {                                                             \
      /* lookup failed: not an error for this check, drop the pending exception */               \
      PyErr_Clear();                                                                              \
    } else if (PyMethod_Check(pijit_hook_attr) || PyFunction_Check(pijit_hook_attr)) {            \
      PyObject *pijit_hook_type = PyType_Check(pijit_hook_owner)                                  \
                                    ? pijit_hook_owner                                            \
                                    : reinterpret_cast<PyObject *>(Py_TYPE(pijit_hook_owner));    \
      PY_PRINT_F("%A has hook " #hook, pijit_hook_type);                                          \
    }                                                                                             \
    Py_XDECREF(pijit_hook_attr);                                                                  \
  } while (0)
#else
#define PRINT_IF_HAS_USER_DEFINED_HOOK(op, hook)
#endif
137 class ReprRecursionScope {
138  public:
ReprRecursionScope(PyObject * v)139   explicit ReprRecursionScope(PyObject *v) : v_(v), stat_(v == nullptr ? -1 : Py_ReprEnter(v)) {}
~ReprRecursionScope()140   ~ReprRecursionScope() {
141     if (stat_ == 0) {
142       Py_ReprLeave(v_);
143     }
144   }
ErrExist()145   bool ErrExist() { return stat_ < 0; }
ReEnter()146   bool ReEnter() { return stat_ > 0; }
ReEnterOrError()147   bool ReEnterOrError() { return ReEnter() || ErrExist(); }
148 
149  private:
150   PyObject *v_;
151   int stat_;
152 };
153 
154 bool HasMutableOrConstAttr(PyObject *obj);
155 bool IsMutableObj(const py::object &obj);
156 bool CheckMutableOrNonConstAttr(PyObject *obj);
157 bool HasDynamicLength(PyObject *obj);
158 bool CheckDynamicLength(PyObject *obj);
159 bool CheckScalar(PyObject *obj);
160 bool CheckContainer(PyObject *obj);
161 bool IsTensorPyObject(PyObject *obj);
162 bool IsMsClass(PyObject *obj);
163 bool IsNumpyObject(PyObject *obj);
164 const char *GetFuncName(const py::object &handle);
165 
166 bool CheckAdapterTensor(const py::object &tensor);
167 py::object ConvertToMsTensor(const py::object &tensor);
168 py::object ConvertToAdapterTensor(const py::object &tensor);
169 
170 std::string GetTopModule(const py::object &o);
171 py::object GetPyCodeObject(const py::object &any, bool exact_func = false);
172 size_t DeviceAvailableMemSize();
173 bool CheckConstPyObject(PyObject *cnst);
174 
/**
 * Timer object keyed by a C-string description.
 *
 * Construction captures `descr` and a `record` flag; `start_` is a steady_clock time point.
 * NOTE(review): the destructor presumably adds the elapsed time to the shared TimeData when
 * `record` is true — confirm in the implementation.
 */
class TimeRecorder {
 public:
  using RecorderType = const char *;
  // Ticks per second of std::nano (1e9).
  static constexpr double scale = std::nano::den;

  // Aggregated timings, one entry per recorder description.
  class TimeData {
   public:
    struct Data {
      // Zero-initialized so a default-constructed entry is a valid empty accumulator.
      uint64_t count = 0;
      uint64_t nano = 0;
    };
    TimeData() = default;
    ~TimeData();
    std::string ToString();

    // NOTE(review): keyed by pointer value, not string content — two distinct buffers with the
    // same text get separate entries.
    std::map<RecorderType, Data> data_;
  };

  explicit TimeRecorder(const RecorderType &descr, bool record = true);
  ~TimeRecorder();

 private:
  static TimeData *Data();

  RecorderType descr_;
  std::chrono::steady_clock::time_point start_;
  bool record_;
};
203 
204 class RefTracker {
205  public:
206   ~RefTracker();
207   bool Track(PyObject *obj, const std::string &descr);
208   static RefTracker *GetInstance();
209 
210  private:
211   RefTracker();
212   static PyObject *UnTrack(PyObject *ref, PyObject *);
213   std::map<void *, std::pair<PyObject *, PyObject *>> tracked_;
214   PyMethodDef mdef_;
215 };
216 
217 }  // namespace pijit
218 }  // namespace mindspore
219 
220 #endif  // MINDSPORE_PI_JIT_UTILS_H
221