• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_ABSTRACT_OBJECT_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_ABSTRACT_OBJECT_H
18 
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 #include <memory>
24 #include "pybind11/pybind11.h"
25 #include "pipeline/jit/pi/utils/mempool.h"
26 #include "utils/convert_utils_base.h"
27 
28 namespace py = pybind11;
29 namespace mindspore {
30 namespace pijit {
31 
32 class AbstractObjectBase;
33 using AObject = AbstractObjectBase;
34 
35 class AbstractObjectBase {
36  private:
37   class Resource {
38    public:
Current()39     static Resource *Current() { return weak_this_.empty() ? nullptr : weak_this_.back(); }
40 
41    private:
42     static std::vector<Resource *> weak_this_;
43 
44    public:
45     Resource();
46     ~Resource();
Release()47     void Release() {}
pool()48     MemPool<AbstractObjectBase> *pool() { return &pool_; }
49 
50    private:
51     MemPool<AbstractObjectBase> pool_;
52   };
53 
54  public:
55   enum Type {
56 #define ABSTRACT_TYPE_DEF(unit) kType##unit,
57 #include "abstract_type_kind.def"
58 #undef ABSTRACT_TYPE_DEF
59   };
60   enum MindsporeFlag {
61 #define ABSTRACT_MS_FLAG_DEF(unit, bit) kMsFlag##unit = 1 << (bit),
62 #include "abstract_ms_flag.def"
63 #undef ABSTRACT_MS_FLAG_DEF
64   };
65   static_assert(static_cast<int>(kTypeSlice) + 8 == static_cast<int>(kTypeType));  // builtin type
66   static_assert(static_cast<int>(kTypeAnyValue) == 0);
67 
68   enum BoolCache {
69     kBoolFalse = 0,
70     kBoolTrue,
71     kBoolUnknown,
72   };
73 
74   // record PyObject and check self reference for list,tuple,dict
75   using RecMap = std::unordered_map<PyObject *, AObject *>;
76 
77   static bool trace_flag_;
78 
AbstractObjectBase(Type type)79   explicit AbstractObjectBase(Type type) : type_object_(nullptr), type_(type), ms_flag_(0) {}
~AbstractObjectBase()80   virtual ~AbstractObjectBase() {}
81 
SetTypeObject(PyTypeObject * tp)82   void SetTypeObject(PyTypeObject *tp) { type_object_ = tp; }
SetTraceFlag(bool trace_flag)83   static void SetTraceFlag(bool trace_flag) { trace_flag_ = trace_flag; }
GetTypeObject()84   PyTypeObject *GetTypeObject() const { return type_object_; }
GetType()85   Type GetType() const { return type_; }
86 
GetPyObject()87   virtual py::object GetPyObject() { return py::object(); }
88 
Binary(AObject * other,int op)89   virtual AObject *Binary(AObject *other, int op) { return MakeAObject(kTypeAnyValue); }
Unary(int op)90   virtual AObject *Unary(int op) const { return MakeAObject(kTypeAnyValue); }
GetIter()91   virtual AObject *GetIter() const { return MakeAObject(kTypeAnyValue); }
92 
93   virtual AObject *GetAttr(const std::string &name);
GetItem(AObject * key)94   virtual AObject *GetItem(AObject *key) { return MakeAObject(kTypeAnyValue); }
95 
96   // return false if has an python exception
SetAttr(const std::string & name,AObject * value)97   virtual bool SetAttr(const std::string &name, AObject *value) { return true; }
SetItem(AObject * key,AObject * value)98   virtual bool SetItem(AObject *key, AObject *value) { return true; }
DelItem(AObject * key)99   virtual bool DelItem(AObject *key) { return SetItem(key, nullptr); }
DelAttr(const std::string & name)100   virtual bool DelAttr(const std::string &name) { return SetAttr(name, nullptr); }
101   virtual bool IsMindSporeSupportedType();
102   virtual std::string ToString() const;
103 
SetMsFlag(unsigned flag)104   void SetMsFlag(unsigned flag) { ms_flag_ |= flag; }
ClearMsFlag(unsigned flag)105   void ClearMsFlag(unsigned flag) { ms_flag_ &= ~flag; }
HasMsFlag(unsigned flag)106   bool HasMsFlag(unsigned flag) { return ms_flag_ & flag; }
TestMsFlag(unsigned flag)107   bool TestMsFlag(unsigned flag) { return ms_flag_ & flag; }
108 
109   static Type GetPyType(PyObject *op);
110   static Type GetPyType(PyTypeObject *tp);
111   static Type GetMsType(PyTypeObject *tp);
Convert(const py::object & o)112   static AObject *Convert(const py::object &o) { return Convert(o.ptr()); }
Convert(PyObject * o)113   static AObject *Convert(PyObject *o) { return MakeAObject(GetPyType(o), o ? Py_TYPE(o) : nullptr, o); }
MakeAObject(Type real_type)114   static AObject *MakeAObject(Type real_type) { return MakeAObject(real_type, nullptr, nullptr); }
MakeResource()115   static auto MakeResource() { return Resource(); }
116 
117   static AObject *MakeFunction(const std::vector<AObject *> &args, const py::object &globals, int oparg);
118 
119   /**
120    * BUILD_SLICE,BUILD_STRING,BUILD_SET,BUILD_LIST,BUILD_TUPLE,BUILD_CONST_KEY_MAP,BUILD_MAP
121    * \return a new AbstractObject if success, else a empty AbstractObject
122    **/
123   static AObject *BuildOperations(const std::vector<AObject *> &args, int opcode);
124   static py::object BuildOperations(const std::vector<py::object> &args, int opcode);
125 
126   /**
127    * LIST_EXTEND,LIST_APPEND,DICT_MERGE,DICT_UPDATE,SET_UPDATE,SET_ADD,MAP_ADD
128    * \return container if success, else a empty AbstractObject
129    **/
130   static AObject *MergeOperations(AObject *container, std::vector<AObject *> args, int opcode);
131 
132   static int BinaryContains(AObject *l, AObject *r);
133   static int BinaryIs(AObject *l, AObject *r);
134 
135   static const char *GetTypeDesc(AObject::Type type);
136   static std::string ToString(PyObject *);
137 
138  protected:
139   static AObject *MakeAObject(Type type, PyTypeObject *tp, PyObject *op, RecMap *rec = nullptr);
140   PyTypeObject *type_object_;
141   const Type type_;
142   unsigned ms_flag_;
143 };
144 
145 class AbstractObject : public AbstractObjectBase {
146  public:
147   AbstractObject(Type type, const py::object &o);
~AbstractObject()148   virtual ~AbstractObject() {}
149 
GetPyObject()150   py::object GetPyObject() override { return value_; }
151 
152   AObject *Binary(AObject *other, int op) override;
153   AObject *Unary(int op) const override;
154   AObject *UnaryValue(int op) const;
155   AObject *GetIter() const override;
156   AObject *GetAttr(const std::string &name) override;
157   AObject *GetItem(AObject *key);
158   bool SetAttr(const std::string &n, AObject *v) override;
159 
160  protected:
161   py::object value_;
162   std::unordered_map<std::string, AObject *> attrs_;  // cache
163 };
164 
165 class AbstractType : public AbstractObject {
166  public:
AbstractType(py::object cls)167   explicit AbstractType(py::object cls)
168       : AbstractObject(kTypeType, cls), type_type_(GetPyType(reinterpret_cast<PyTypeObject *>(cls.ptr()))) {
169     this->SetTypeObject(&PyType_Type);
170   }
~AbstractType()171   virtual ~AbstractType() {}
ToString()172   std::string ToString() const override { return std::string(py::str(value_.ptr())); }
IsMindSporeSupportedType()173   bool IsMindSporeSupportedType() override { return false; }
174 
GetTypeType()175   Type GetTypeType() const { return type_type_; }
176   AObject *BuildAbstractInstance(const std::vector<AObject *> &args, int opcode);
177   py::object BuildInstance(const std::vector<py::object> &args, int opcode);
178 
179  private:
180   Type type_type_;
181 };
182 
183 class AbstractSequence : public AbstractObject {
184  public:
AbstractSequence(Type type,const py::object & o)185   explicit AbstractSequence(Type type, const py::object &o) : AbstractObject(type, o) {}
~AbstractSequence()186   virtual ~AbstractSequence() {}
187 
188   AObject *GetItem(AObject *key) override;
189   bool SetItem(AObject *key, AObject *value) override;
190 
GetPyObject()191   py::object GetPyObject() override { return write_cache_.size() ? py::object() : value_; }
192 
193  protected:
194   std::unordered_map<AObject *, AObject *> write_cache_;  // cache
195 };
196 
197 class AbstractTuple : public AbstractSequence {
198  public:
AbstractTuple(kTypeTuple,l,m)199   explicit AbstractTuple(const py::object &l, RecMap *m = nullptr) : AbstractTuple(kTypeTuple, l, m) {}
~AbstractTuple()200   virtual ~AbstractTuple() {}
items()201   auto &items() { return items_; }
size()202   Py_ssize_t size() const { return IsElementValid() ? items_.size() : -1; }
203 
GetPyObject()204   py::object GetPyObject() override { return value_; }
205   AObject *Binary(AObject *other, int op) override;
206   AObject *Unary(int op) const override;
207   AObject *GetAttr(const std::string &name) override;
SetAttr(const std::string & name,AObject *)208   bool SetAttr(const std::string &name, AObject *) override { return false; };
209   AObject *GetItem(AObject *k) override;
210   std::string ToString() const override;
211   bool IsMindSporeSupportedType() override;
212 
SetElementType(Type type)213   void SetElementType(Type type) { element_type_ = type; }
GetElementType()214   Type GetElementType() const { return element_type_; }
IsElementValid()215   bool IsElementValid() const { return element_valid_; }
MarkElementInValid()216   void MarkElementInValid() {
217     element_type_ = kTypeAnyValue;
218     element_valid_ = false;
219     modify_ = false;
220     value_ = py::object();
221     items_.clear();
222     write_cache_.clear();
223   }
begin()224   auto begin() const { return items_.begin(); }
end()225   auto end() const { return items_.end(); }
IsModify()226   bool IsModify() const { return modify_ || this->write_cache_.size() > 0; }
MarkModify()227   void MarkModify() { modify_ = true; }
228   bool Update(const std::vector<AObject *> &items);
229   bool Update();
230 
231  protected:
232   AbstractTuple(Type type, py::object list, RecMap *m);
233   std::vector<AObject *> items_;
234   BoolCache ms_support_;
235   Type element_type_;
236   bool element_valid_;
237   bool modify_;
238 };
239 
240 class AbstractList : public AbstractTuple {
241  public:
AbstractList(const py::object & l,RecMap * m)242   explicit AbstractList(const py::object &l, RecMap *m) : AbstractTuple(kTypeList, l, m) {}
~AbstractList()243   virtual ~AbstractList() {}
244 
245   py::object GetPyObject() override;
246   bool SetItem(AObject *k, AObject *v) override;
247 
248   bool ListAppend(AObject *item);
249   bool ListExtend(AObject *list);
250   AbstractTuple *ListToTuple();
251 };
252 
253 class AbstractDict : public AbstractSequence {
254  public:
AbstractDict(kTypeDict,dict,m)255   explicit AbstractDict(const py::object &dict, RecMap *m = nullptr) : AbstractDict(kTypeDict, dict, m) {}
~AbstractDict()256   virtual ~AbstractDict() {}
size()257   Py_ssize_t size() const { return IsElementValid() ? dict_.size() : -1; }
258 
259   std::string ToString() const override;
260   py::object GetPyObject() override;
261   AObject *Unary(int op) const override;
262   AObject *Binary(AObject *, int op) override;
263   AObject *GetAttr(const std::string &name) override;
SetAttr(const std::string & name,AObject *)264   bool SetAttr(const std::string &name, AObject *) override { return false; };
265   AObject *GetItem(AObject *key) override;
266   bool IsMindSporeSupportedType() override;
267 
KeyType()268   Type KeyType() const { return k_type_; }
ValueType()269   Type ValueType() const { return v_type_; }
270 
IsModify()271   bool IsModify() const { return modify_ || this->write_cache_.size() > 0; }
MarkModify()272   void MarkModify() { modify_ = true; }
273   bool DictMerge(AObject *o, int update = 0);
274   bool DictUpdate(AObject *o);
275   bool MapAdd(AObject *k, AObject *v);
IsElementValid()276   bool IsElementValid() const { return element_valid_; }
MarkElementInValid()277   void MarkElementInValid() {
278     k_type_ = kTypeAnyValue;
279     v_type_ = kTypeAnyValue;
280     element_valid_ = false;
281     modify_ = false;
282     value_ = py::object();
283     dict_.clear();
284     write_cache_.clear();
285   }
286   bool Update();
287 
288   class ValueIter {
289    public:
ValueIter(const AbstractDict * dict)290     explicit ValueIter(const AbstractDict *dict) : map_(dict->dict_.ptr()), pos_(0) { ++(*this); }
ValueIter()291     ValueIter() : map_(nullptr) {}
key()292     py::object key() { return py::cast<py::object>(key_); }
293     AObject *operator*() { return AbstractDict::ConvertValue(val_); }
294     bool operator!=(const ValueIter &o) { return map_ != nullptr; }
295     ValueIter &operator++() {
296       map_ = PyDict_Next(map_, &pos_, &key_, &val_) ? map_ : nullptr;
297       return *this;
298     }
299 
300    private:
301     PyObject *map_, *key_, *val_;
302     Py_ssize_t pos_;
303   };
begin()304   auto begin() const { return ValueIter(this); }
end()305   auto end() const { return ValueIter(); }
306 
ConvertValue(PyObject * i)307   static AObject *ConvertValue(PyObject *i) { return reinterpret_cast<AObject *>(PyLong_AsVoidPtr(i)); }
ConvertValue(AObject * i)308   static py::object ConvertValue(AObject *i) { return py::reinterpret_steal<py::object>(PyLong_FromVoidPtr(i)); }
309 
310  protected:
311   AbstractDict(Type type, py::object o, RecMap *m);
312   py::dict dict_;
313   Type k_type_;
314   Type v_type_;
315   bool element_valid_;
316   bool modify_;
317 };
318 
319 class AbstractTensor : public AbstractObject {
320  public:
321   static py::object Binary(int op, const py::object &, const py::object &);
322 
323  public:
324   AbstractTensor(const py::object &o, bool is_stub);
~AbstractTensor()325   virtual ~AbstractTensor() {}
326   AObject *Binary(AObject *, int op) override;
327   AObject *Unary(int op) const override;
328   AObject *GetAttr(const std::string &name) override;
329   std::string ToString() const override;
330 
SetItem(AObject * key,AObject * value)331   bool SetItem(AObject *key, AObject *value) override { return true; }
332   AObject *GetItem(AObject *key) override;
333   py::object GetTensor(bool sync);
334 
IsMindSporeSupportedType()335   bool IsMindSporeSupportedType() override { return true; }
IsStubTensor()336   bool IsStubTensor() const { return is_stub_; }
337 
338  private:
339   bool is_stub_;
340 };
341 }  // namespace pijit
342 }  // namespace mindspore
343 
344 #endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_ABSTRACT_OBJECT_H
345