1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_TRACE_H 17 #define MINDSPORE_PI_JIT_TRACE_H 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 #include <map> 23 #include <functional> 24 #include "pipeline/jit/pi/pydef.h" 25 #include "pybind11/pybind11.h" 26 #include "pipeline/jit/pi/graph_guard/info.h" 27 28 namespace py = pybind11; 29 30 namespace mindspore { 31 namespace pijit { 32 33 typedef enum _TraceType { 34 Unknown = 0, 35 Global, 36 Deref, 37 Closure, 38 BuiltIn, 39 Local, 40 Param, 41 Name, 42 ClassDeref, 43 Const, 44 Item, 45 Attr, 46 Type, 47 Operation, 48 Customized, 49 Unsupported, 50 } TraceType; 51 52 typedef struct _TraceContext { 53 PyObject *f_globals; 54 PyObject *f_builtins; 55 PyObject *f_locals; 56 PyObject *const *f_localsplus; 57 PyCodeObject *f_code; 58 std::map<size_t, PyObject *> *cache; 59 } TraceContext, *PTraceContext; 60 61 class Trace : public std::enable_shared_from_this<Trace> { 62 public: 63 Trace(PyObject *obj, std::shared_ptr<Trace> origin); 64 virtual ~Trace(); 65 virtual std::shared_ptr<Trace> GetOrigin(); 66 /// \brief Get the borrow reference for the object and call Py_INCREF/Py_DECREF by yourself. 67 /// \param[out] borrow reference for PyObject 68 virtual PyObject *GetObject(); 69 virtual TraceType GetTraceType(); 70 virtual TraceType GetOriginType(); 71 virtual void Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src); 72 virtual bool operator==(const Trace &trace); 73 virtual void Detach(); 74 /// \brief Get the reference for the object by Py_INCREF and call Py_DECREF by yourself. 75 /// \param[in] context for trace 76 /// \param[in] perf for performance of trace 77 /// \param[out] borrow reference for PyObject 78 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 79 virtual std::string ToString(bool include_param = true) = 0; 80 virtual std::string FormatString(std::map<Trace *, size_t> *cache); 81 virtual const InfoPack &Info() = 0; 82 virtual void Cache(PTraceContext context, PyObject *obj); 83 virtual bool IsConst() const; 84 virtual std::shared_ptr<Trace> Optimize(); 85 virtual std::shared_ptr<Trace> This(); 86 virtual void SetRelaxCount(int cnt); 87 virtual int GetRelaxCount() const; 88 virtual void EnableRelax(); 89 virtual bool RelaxEnabled() const; 90 virtual bool IsSpecialized() const; 91 virtual int GetDepth() const; 92 93 protected: 94 PyObject *obj_; 95 std::shared_ptr<Trace> origin_; 96 TraceType originType_; 97 TraceType curType_; 98 std::string strTrace_; 99 InfoPackPtr info_; 100 bool is_const_; 101 int relax_count_; 102 int relax_limit_; 103 bool is_specialized_; 104 int depth_; 105 }; 106 using TracePtr = std::shared_ptr<Trace>; 107 using TraceVector = std::vector<TracePtr>; 108 109 class RootTrace : public Trace { 110 public: 111 RootTrace(PyObject *obj, TraceType tt, int index = -1, std::string name = "", std::string module_name = ""); 112 virtual ~RootTrace() = default; 113 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 114 virtual std::string ToString(bool include_param = true); 115 virtual void GetParam(int *index, std::string *name, std::string *module_name); 116 virtual bool operator==(const Trace &trace); 117 virtual const InfoPack &Info(); 118 static bool Support(TraceType tt); 119 120 protected: 121 PyObject *RetrieveGlobal(PTraceContext context); 122 PyObject *RetrieveDeref(PTraceContext context); 123 PyObject *RetrieveClosure(PTraceContext context); 124 PyObject *RetrieveBuiltin(PTraceContext context); 125 PyObject *RetrieveLocal(PTraceContext context); 126 PyObject *RetrieveParam(PTraceContext context); 127 PyObject *RetrieveName(PTraceContext context); 128 PyObject *RetrieveClassDeref(PTraceContext context); 129 130 int idx_; 131 std::string name_; 132 std::string module_name_; 133 }; 134 using RootTracePtr = std::shared_ptr<RootTrace>; 135 136 class ItemTrace : public Trace { 137 public: 138 ItemTrace(PyObject *obj, TracePtr origin, TracePtr item); 139 virtual ~ItemTrace() = default; 140 virtual TracePtr GetItem(); 141 virtual void Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src); 142 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 143 virtual std::string ToString(bool include_param = true); 144 virtual bool operator==(const Trace &trace); 145 virtual void Detach(); 146 virtual const InfoPack &Info(); 147 virtual TracePtr Optimize(); 148 virtual void SetRelaxCount(int cnt); 149 static bool Support(TraceType tt); 150 151 protected: 152 TracePtr item_; 153 }; 154 using ItemTracePtr = std::shared_ptr<ItemTrace>; 155 156 class AttrTrace : public Trace { 157 public: 158 AttrTrace(PyObject *obj, TracePtr origin, std::string attr); 159 virtual ~AttrTrace() = default; 160 virtual std::string GetAttribute(); 161 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 162 virtual std::string ToString(bool include_param = true); 163 virtual bool operator==(const Trace &trace); 164 virtual const InfoPack &Info(); 165 virtual TracePtr Optimize(); 166 virtual void SetRelaxCount(int cnt); 167 static bool Support(TraceType tt); 168 169 protected: 170 std::string attr_; 171 }; 172 using AttrTracePtr = std::shared_ptr<AttrTrace>; 173 174 class ConstTrace : public Trace { 175 public: 176 ConstTrace(PyObject *obj, int index); 177 virtual ~ConstTrace() = default; 178 virtual int GetIndex(); 179 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 180 virtual std::string ToString(bool include_param = true); 181 virtual bool operator==(const Trace &trace); 182 virtual void Detach(); 183 virtual const InfoPack &Info(); 184 static bool Support(TraceType tt); 185 186 protected: 187 int index_; 188 }; 189 using ConstTracePtr = std::shared_ptr<ConstTrace>; 190 191 class TypeTrace : public Trace { 192 public: 193 TypeTrace(PyObject *obj, TracePtr origin); 194 virtual ~TypeTrace() = default; 195 virtual PyTypeObject *GetType(); 196 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 197 virtual std::string ToString(bool include_param = true); 198 virtual bool operator==(const Trace &trace); 199 virtual const InfoPack &Info(); 200 virtual void Detach(); 201 virtual TracePtr Optimize(); 202 virtual void SetRelaxCount(int cnt); 203 static bool Support(TraceType tt); 204 205 protected: 206 PyTypeObject *pType_; 207 }; 208 using TypeTracePtr = std::shared_ptr<TypeTrace>; 209 210 class OpTrace : public Trace { 211 public: 212 OpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, std::string name = ""); 213 virtual ~OpTrace() = default; 214 virtual int GetOpCode(); 215 virtual int GetOpArgs(); 216 virtual TracePtr GetParam(size_t idx); 217 virtual size_t GetParamCount(); 218 virtual std::string GetName(); 219 virtual void Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src); 220 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 221 virtual std::string ToString(bool include_param = true); 222 virtual bool operator==(const Trace &trace); 223 virtual void Detach(); 224 std::string FormatString(std::map<Trace *, size_t> *cache) override; 225 virtual const InfoPack &Info(); 226 virtual TracePtr Optimize(); 227 virtual void SetRelaxCount(int cnt); 228 static bool Support(TraceType tt); 229 230 protected: 231 virtual void CheckSpecialize(); 232 virtual TracePtr RemoveCastDuplicatePatternPass(); 233 virtual TracePtr RemovePrimOutIsTensorPass(); 234 virtual TracePtr RemoveEmptyTensorPass(); 235 virtual TracePtr RemoveCastPass(); 236 virtual void JudgeDTypeChangePass(); 237 virtual void JudgeDTypeScopePass(); 238 virtual void JudgeCodeChangePass(); 239 virtual void JudgeTrainFlagPass(); 240 virtual void JudgeCompareConstPass(); 241 virtual void JudgeContainsConstPass(); 242 virtual void JudgeInplaceAddConstPass(); 243 virtual void JudgeIsConstPass(); 244 virtual void JudgeBoundMethodPass(); 245 virtual void JudgeSubScrRandPass(); 246 virtual void JudgeDTypeTensorAttrPass(); 247 virtual void JudgeRelaxGuardFuncPass(); 248 249 protected: 250 int opcode_; 251 int opargs_; 252 TraceVector params_; 253 std::string name_; 254 }; 255 using OpTracePtr = std::shared_ptr<OpTrace>; 256 TracePtr CreateOpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, const std::string &module_name = "", 257 const std::string &name = "", bool strict = false, bool print = false); 258 259 /// \brief retrieve the PyObject with ref count plus 1 which will be minus outside 260 typedef std::function<PyObject *(PTraceContext context)> RetrieveFunc; 261 typedef std::function<std::string(bool)> ToStringFunc; 262 class CustomizedTrace : public Trace { 263 public: 264 CustomizedTrace(PyObject *obj, RetrieveFunc rfunc, ToStringFunc sfunc); 265 virtual ~CustomizedTrace() = default; 266 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 267 virtual std::string ToString(bool include_param = true); 268 virtual const InfoPack &Info(); 269 static bool Support(TraceType tt); 270 271 protected: 272 RetrieveFunc retrieve_; 273 ToStringFunc tostring_; 274 }; 275 using CustomizedTracePtr = std::shared_ptr<CustomizedTrace>; 276 277 class UnsupportedTrace : public Trace { 278 public: 279 UnsupportedTrace(PyObject *obj, TraceVector params, int op, int arg); 280 virtual ~UnsupportedTrace() = default; 281 virtual PyObject *Retrieve(PTraceContext context, bool perf = false); 282 virtual std::string ToString(bool include_param = true); 283 virtual TraceVector GetParams(); 284 virtual void Detach(); 285 std::string FormatString(std::map<Trace *, size_t> *cache) override; 286 virtual const InfoPack &Info(); 287 virtual void SetRelaxCount(int cnt); 288 static bool Support(TraceType tt); 289 290 protected: 291 TraceVector params_; 292 int op_; 293 int arg_; 294 }; 295 using UnsupportedTracePtr = std::shared_ptr<UnsupportedTrace>; 296 297 /// \brief Get the reference for the object by Py_INCREF and call Py_DECREF by yourself. 298 PyObject *GetObjectFromTrace(const PyFrameObject *frame, TracePtr trace, std::map<size_t, PyObject *> *cache = nullptr, 299 bool perf = false); 300 } // namespace pijit 301 } // namespace mindspore 302 303 #endif // MINDSPORE_PI_JIT_TRACE_H 304