• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_TRACE_H
17 #define MINDSPORE_PI_JIT_TRACE_H
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include <map>
23 #include <functional>
24 #include "pipeline/jit/pi/pydef.h"
25 #include "pybind11/pybind11.h"
26 #include "pipeline/jit/pi/graph_guard/info.h"
27 
28 namespace py = pybind11;
29 
30 namespace mindspore {
31 namespace pijit {
32 
/// \brief Kind of a trace node.
/// The numeric values are written out explicitly; they are the same 0..15 sequence
/// that implicit declaration order produced.
typedef enum _TraceType {
  Unknown = 0,
  // Root kinds: RootTrace declares a Retrieve* helper for each of these.
  Global = 1,
  Deref = 2,
  Closure = 3,
  BuiltIn = 4,
  Local = 5,
  Param = 6,
  Name = 7,
  ClassDeref = 8,
  // Remaining kinds; presumably one per Trace subclass declared below
  // (ConstTrace, ItemTrace, AttrTrace, TypeTrace, OpTrace, CustomizedTrace,
  // UnsupportedTrace) — confirm against each class's Support().
  Const = 9,
  Item = 10,
  Attr = 11,
  Type = 12,
  Operation = 13,
  Customized = 14,
  Unsupported = 15,
} TraceType;
51 
52 typedef struct _TraceContext {
53   PyObject *f_globals;
54   PyObject *f_builtins;
55   PyObject *f_locals;
56   PyObject *const *f_localsplus;
57   PyCodeObject *f_code;
58   std::map<size_t, PyObject *> *cache;
59 } TraceContext, *PTraceContext;
60 
61 class Trace : public std::enable_shared_from_this<Trace> {
62  public:
63   Trace(PyObject *obj, std::shared_ptr<Trace> origin);
64   virtual ~Trace();
65   virtual std::shared_ptr<Trace> GetOrigin();
66   /// \brief Get the borrow reference for the object and call Py_INCREF/Py_DECREF by yourself.
67   /// \param[out] borrow reference for PyObject
68   virtual PyObject *GetObject();
69   virtual TraceType GetTraceType();
70   virtual TraceType GetOriginType();
71   virtual void Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src);
72   virtual bool operator==(const Trace &trace);
73   virtual void Detach();
74   /// \brief Get the reference for the object by Py_INCREF and call Py_DECREF by yourself.
75   /// \param[in] context for trace
76   /// \param[in] perf for performance of trace
77   /// \param[out] borrow reference for PyObject
78   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
79   virtual std::string ToString(bool include_param = true) = 0;
80   virtual std::string FormatString(std::map<Trace *, size_t> *cache);
81   virtual const InfoPack &Info() = 0;
82   virtual void Cache(PTraceContext context, PyObject *obj);
83   virtual bool IsConst() const;
84   virtual std::shared_ptr<Trace> Optimize();
85   virtual std::shared_ptr<Trace> This();
86   virtual void SetRelaxCount(int cnt);
87   virtual int GetRelaxCount() const;
88   virtual void EnableRelax();
89   virtual bool RelaxEnabled() const;
90   virtual bool IsSpecialized() const;
91   virtual int GetDepth() const;
92 
93  protected:
94   PyObject *obj_;
95   std::shared_ptr<Trace> origin_;
96   TraceType originType_;
97   TraceType curType_;
98   std::string strTrace_;
99   InfoPackPtr info_;
100   bool is_const_;
101   int relax_count_;
102   int relax_limit_;
103   bool is_specialized_;
104   int depth_;
105 };
106 using TracePtr = std::shared_ptr<Trace>;
107 using TraceVector = std::vector<TracePtr>;
108 
109 class RootTrace : public Trace {
110  public:
111   RootTrace(PyObject *obj, TraceType tt, int index = -1, std::string name = "", std::string module_name = "");
112   virtual ~RootTrace() = default;
113   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
114   virtual std::string ToString(bool include_param = true);
115   virtual void GetParam(int *index, std::string *name, std::string *module_name);
116   virtual bool operator==(const Trace &trace);
117   virtual const InfoPack &Info();
118   static bool Support(TraceType tt);
119 
120  protected:
121   PyObject *RetrieveGlobal(PTraceContext context);
122   PyObject *RetrieveDeref(PTraceContext context);
123   PyObject *RetrieveClosure(PTraceContext context);
124   PyObject *RetrieveBuiltin(PTraceContext context);
125   PyObject *RetrieveLocal(PTraceContext context);
126   PyObject *RetrieveParam(PTraceContext context);
127   PyObject *RetrieveName(PTraceContext context);
128   PyObject *RetrieveClassDeref(PTraceContext context);
129 
130   int idx_;
131   std::string name_;
132   std::string module_name_;
133 };
134 using RootTracePtr = std::shared_ptr<RootTrace>;
135 
136 class ItemTrace : public Trace {
137  public:
138   ItemTrace(PyObject *obj, TracePtr origin, TracePtr item);
139   virtual ~ItemTrace() = default;
140   virtual TracePtr GetItem();
141   virtual void Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src);
142   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
143   virtual std::string ToString(bool include_param = true);
144   virtual bool operator==(const Trace &trace);
145   virtual void Detach();
146   virtual const InfoPack &Info();
147   virtual TracePtr Optimize();
148   virtual void SetRelaxCount(int cnt);
149   static bool Support(TraceType tt);
150 
151  protected:
152   TracePtr item_;
153 };
154 using ItemTracePtr = std::shared_ptr<ItemTrace>;
155 
156 class AttrTrace : public Trace {
157  public:
158   AttrTrace(PyObject *obj, TracePtr origin, std::string attr);
159   virtual ~AttrTrace() = default;
160   virtual std::string GetAttribute();
161   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
162   virtual std::string ToString(bool include_param = true);
163   virtual bool operator==(const Trace &trace);
164   virtual const InfoPack &Info();
165   virtual TracePtr Optimize();
166   virtual void SetRelaxCount(int cnt);
167   static bool Support(TraceType tt);
168 
169  protected:
170   std::string attr_;
171 };
172 using AttrTracePtr = std::shared_ptr<AttrTrace>;
173 
174 class ConstTrace : public Trace {
175  public:
176   ConstTrace(PyObject *obj, int index);
177   virtual ~ConstTrace() = default;
178   virtual int GetIndex();
179   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
180   virtual std::string ToString(bool include_param = true);
181   virtual bool operator==(const Trace &trace);
182   virtual void Detach();
183   virtual const InfoPack &Info();
184   static bool Support(TraceType tt);
185 
186  protected:
187   int index_;
188 };
189 using ConstTracePtr = std::shared_ptr<ConstTrace>;
190 
191 class TypeTrace : public Trace {
192  public:
193   TypeTrace(PyObject *obj, TracePtr origin);
194   virtual ~TypeTrace() = default;
195   virtual PyTypeObject *GetType();
196   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
197   virtual std::string ToString(bool include_param = true);
198   virtual bool operator==(const Trace &trace);
199   virtual const InfoPack &Info();
200   virtual void Detach();
201   virtual TracePtr Optimize();
202   virtual void SetRelaxCount(int cnt);
203   static bool Support(TraceType tt);
204 
205  protected:
206   PyTypeObject *pType_;
207 };
208 using TypeTracePtr = std::shared_ptr<TypeTrace>;
209 
210 class OpTrace : public Trace {
211  public:
212   OpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, std::string name = "");
213   virtual ~OpTrace() = default;
214   virtual int GetOpCode();
215   virtual int GetOpArgs();
216   virtual TracePtr GetParam(size_t idx);
217   virtual size_t GetParamCount();
218   virtual std::string GetName();
219   virtual void Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src);
220   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
221   virtual std::string ToString(bool include_param = true);
222   virtual bool operator==(const Trace &trace);
223   virtual void Detach();
224   std::string FormatString(std::map<Trace *, size_t> *cache) override;
225   virtual const InfoPack &Info();
226   virtual TracePtr Optimize();
227   virtual void SetRelaxCount(int cnt);
228   static bool Support(TraceType tt);
229 
230  protected:
231   virtual void CheckSpecialize();
232   virtual TracePtr RemoveCastDuplicatePatternPass();
233   virtual TracePtr RemovePrimOutIsTensorPass();
234   virtual TracePtr RemoveEmptyTensorPass();
235   virtual TracePtr RemoveCastPass();
236   virtual void JudgeDTypeChangePass();
237   virtual void JudgeDTypeScopePass();
238   virtual void JudgeCodeChangePass();
239   virtual void JudgeTrainFlagPass();
240   virtual void JudgeCompareConstPass();
241   virtual void JudgeContainsConstPass();
242   virtual void JudgeInplaceAddConstPass();
243   virtual void JudgeIsConstPass();
244   virtual void JudgeBoundMethodPass();
245   virtual void JudgeSubScrRandPass();
246   virtual void JudgeDTypeTensorAttrPass();
247   virtual void JudgeRelaxGuardFuncPass();
248 
249  protected:
250   int opcode_;
251   int opargs_;
252   TraceVector params_;
253   std::string name_;
254 };
255 using OpTracePtr = std::shared_ptr<OpTrace>;
256 TracePtr CreateOpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, const std::string &module_name = "",
257                        const std::string &name = "", bool strict = false, bool print = false);
258 
259 /// \brief retrieve the PyObject with ref count plus 1 which will be minus outside
260 typedef std::function<PyObject *(PTraceContext context)> RetrieveFunc;
261 typedef std::function<std::string(bool)> ToStringFunc;
262 class CustomizedTrace : public Trace {
263  public:
264   CustomizedTrace(PyObject *obj, RetrieveFunc rfunc, ToStringFunc sfunc);
265   virtual ~CustomizedTrace() = default;
266   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
267   virtual std::string ToString(bool include_param = true);
268   virtual const InfoPack &Info();
269   static bool Support(TraceType tt);
270 
271  protected:
272   RetrieveFunc retrieve_;
273   ToStringFunc tostring_;
274 };
275 using CustomizedTracePtr = std::shared_ptr<CustomizedTrace>;
276 
277 class UnsupportedTrace : public Trace {
278  public:
279   UnsupportedTrace(PyObject *obj, TraceVector params, int op, int arg);
280   virtual ~UnsupportedTrace() = default;
281   virtual PyObject *Retrieve(PTraceContext context, bool perf = false);
282   virtual std::string ToString(bool include_param = true);
283   virtual TraceVector GetParams();
284   virtual void Detach();
285   std::string FormatString(std::map<Trace *, size_t> *cache) override;
286   virtual const InfoPack &Info();
287   virtual void SetRelaxCount(int cnt);
288   static bool Support(TraceType tt);
289 
290  protected:
291   TraceVector params_;
292   int op_;
293   int arg_;
294 };
295 using UnsupportedTracePtr = std::shared_ptr<UnsupportedTrace>;
296 
297 /// \brief Get the reference for the object by Py_INCREF and call Py_DECREF by yourself.
298 PyObject *GetObjectFromTrace(const PyFrameObject *frame, TracePtr trace, std::map<size_t, PyObject *> *cache = nullptr,
299                              bool perf = false);
300 }  // namespace pijit
301 }  // namespace mindspore
302 
303 #endif  // MINDSPORE_PI_JIT_TRACE_H
304