1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_ABSTRACT_OBJECT_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_ABSTRACT_OBJECT_H 18 19 #include <set> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 #include <memory> 24 #include "pybind11/pybind11.h" 25 #include "pipeline/jit/pi/utils/mempool.h" 26 #include "utils/convert_utils_base.h" 27 28 namespace py = pybind11; 29 namespace mindspore { 30 namespace pijit { 31 32 class AbstractObjectBase; 33 using AObject = AbstractObjectBase; 34 35 class AbstractObjectBase { 36 private: 37 class Resource { 38 public: Current()39 static Resource *Current() { return weak_this_.empty() ? nullptr : weak_this_.back(); } 40 41 private: 42 static std::vector<Resource *> weak_this_; 43 44 public: 45 Resource(); 46 ~Resource(); Release()47 void Release() {} pool()48 MemPool<AbstractObjectBase> *pool() { return &pool_; } 49 50 private: 51 MemPool<AbstractObjectBase> pool_; 52 }; 53 54 public: 55 enum Type { 56 #define ABSTRACT_TYPE_DEF(unit) kType##unit, 57 #include "abstract_type_kind.def" 58 #undef ABSTRACT_TYPE_DEF 59 }; 60 enum MindsporeFlag { 61 #define ABSTRACT_MS_FLAG_DEF(unit, bit) kMsFlag##unit = 1 << (bit), 62 #include "abstract_ms_flag.def" 63 #undef ABSTRACT_MS_FLAG_DEF 64 }; 65 static_assert(static_cast<int>(kTypeSlice) + 8 == static_cast<int>(kTypeType)); // builtin type 66 static_assert(static_cast<int>(kTypeAnyValue) == 0); 67 68 enum BoolCache { 69 kBoolFalse = 0, 70 kBoolTrue, 71 kBoolUnknown, 72 }; 73 74 // record PyObject and check self reference for list,tuple,dict 75 using RecMap = std::unordered_map<PyObject *, AObject *>; 76 77 static bool trace_flag_; 78 AbstractObjectBase(Type type)79 explicit AbstractObjectBase(Type type) : type_object_(nullptr), type_(type), ms_flag_(0) {} ~AbstractObjectBase()80 virtual ~AbstractObjectBase() {} 81 SetTypeObject(PyTypeObject * tp)82 void SetTypeObject(PyTypeObject *tp) { type_object_ = tp; } SetTraceFlag(bool trace_flag)83 static void SetTraceFlag(bool trace_flag) { trace_flag_ = trace_flag; } GetTypeObject()84 PyTypeObject *GetTypeObject() const { return type_object_; } GetType()85 Type GetType() const { return type_; } 86 GetPyObject()87 virtual py::object GetPyObject() { return py::object(); } 88 Binary(AObject * other,int op)89 virtual AObject *Binary(AObject *other, int op) { return MakeAObject(kTypeAnyValue); } Unary(int op)90 virtual AObject *Unary(int op) const { return MakeAObject(kTypeAnyValue); } GetIter()91 virtual AObject *GetIter() const { return MakeAObject(kTypeAnyValue); } 92 93 virtual AObject *GetAttr(const std::string &name); GetItem(AObject * key)94 virtual AObject *GetItem(AObject *key) { return MakeAObject(kTypeAnyValue); } 95 96 // return false if has an python exception SetAttr(const std::string & name,AObject * value)97 virtual bool SetAttr(const std::string &name, AObject *value) { return true; } SetItem(AObject * key,AObject * value)98 virtual bool SetItem(AObject *key, AObject *value) { return true; } DelItem(AObject * key)99 virtual bool DelItem(AObject *key) { return SetItem(key, nullptr); } DelAttr(const std::string & name)100 virtual bool DelAttr(const std::string &name) { return SetAttr(name, nullptr); } 101 virtual bool IsMindSporeSupportedType(); 102 virtual std::string ToString() const; 103 SetMsFlag(unsigned flag)104 void SetMsFlag(unsigned flag) { ms_flag_ |= flag; } ClearMsFlag(unsigned flag)105 void ClearMsFlag(unsigned flag) { ms_flag_ &= ~flag; } HasMsFlag(unsigned flag)106 bool HasMsFlag(unsigned flag) { return ms_flag_ & flag; } TestMsFlag(unsigned flag)107 bool TestMsFlag(unsigned flag) { return ms_flag_ & flag; } 108 109 static Type GetPyType(PyObject *op); 110 static Type GetPyType(PyTypeObject *tp); 111 static Type GetMsType(PyTypeObject *tp); Convert(const py::object & o)112 static AObject *Convert(const py::object &o) { return Convert(o.ptr()); } Convert(PyObject * o)113 static AObject *Convert(PyObject *o) { return MakeAObject(GetPyType(o), o ? Py_TYPE(o) : nullptr, o); } MakeAObject(Type real_type)114 static AObject *MakeAObject(Type real_type) { return MakeAObject(real_type, nullptr, nullptr); } MakeResource()115 static auto MakeResource() { return Resource(); } 116 117 static AObject *MakeFunction(const std::vector<AObject *> &args, const py::object &globals, int oparg); 118 119 /** 120 * BUILD_SLICE,BUILD_STRING,BUILD_SET,BUILD_LIST,BUILD_TUPLE,BUILD_CONST_KEY_MAP,BUILD_MAP 121 * \return a new AbstractObject if success, else a empty AbstractObject 122 **/ 123 static AObject *BuildOperations(const std::vector<AObject *> &args, int opcode); 124 static py::object BuildOperations(const std::vector<py::object> &args, int opcode); 125 126 /** 127 * LIST_EXTEND,LIST_APPEND,DICT_MERGE,DICT_UPDATE,SET_UPDATE,SET_ADD,MAP_ADD 128 * \return container if success, else a empty AbstractObject 129 **/ 130 static AObject *MergeOperations(AObject *container, std::vector<AObject *> args, int opcode); 131 132 static int BinaryContains(AObject *l, AObject *r); 133 static int BinaryIs(AObject *l, AObject *r); 134 135 static const char *GetTypeDesc(AObject::Type type); 136 static std::string ToString(PyObject *); 137 138 protected: 139 static AObject *MakeAObject(Type type, PyTypeObject *tp, PyObject *op, RecMap *rec = nullptr); 140 PyTypeObject *type_object_; 141 const Type type_; 142 unsigned ms_flag_; 143 }; 144 145 class AbstractObject : public AbstractObjectBase { 146 public: 147 AbstractObject(Type type, const py::object &o); ~AbstractObject()148 virtual ~AbstractObject() {} 149 GetPyObject()150 py::object GetPyObject() override { return value_; } 151 152 AObject *Binary(AObject *other, int op) override; 153 AObject *Unary(int op) const override; 154 AObject *UnaryValue(int op) const; 155 AObject *GetIter() const override; 156 AObject *GetAttr(const std::string &name) override; 157 AObject *GetItem(AObject *key); 158 bool SetAttr(const std::string &n, AObject *v) override; 159 160 protected: 161 py::object value_; 162 std::unordered_map<std::string, AObject *> attrs_; // cache 163 }; 164 165 class AbstractType : public AbstractObject { 166 public: AbstractType(py::object cls)167 explicit AbstractType(py::object cls) 168 : AbstractObject(kTypeType, cls), type_type_(GetPyType(reinterpret_cast<PyTypeObject *>(cls.ptr()))) { 169 this->SetTypeObject(&PyType_Type); 170 } ~AbstractType()171 virtual ~AbstractType() {} ToString()172 std::string ToString() const override { return std::string(py::str(value_.ptr())); } IsMindSporeSupportedType()173 bool IsMindSporeSupportedType() override { return false; } 174 GetTypeType()175 Type GetTypeType() const { return type_type_; } 176 AObject *BuildAbstractInstance(const std::vector<AObject *> &args, int opcode); 177 py::object BuildInstance(const std::vector<py::object> &args, int opcode); 178 179 private: 180 Type type_type_; 181 }; 182 183 class AbstractSequence : public AbstractObject { 184 public: AbstractSequence(Type type,const py::object & o)185 explicit AbstractSequence(Type type, const py::object &o) : AbstractObject(type, o) {} ~AbstractSequence()186 virtual ~AbstractSequence() {} 187 188 AObject *GetItem(AObject *key) override; 189 bool SetItem(AObject *key, AObject *value) override; 190 GetPyObject()191 py::object GetPyObject() override { return write_cache_.size() ? py::object() : value_; } 192 193 protected: 194 std::unordered_map<AObject *, AObject *> write_cache_; // cache 195 }; 196 197 class AbstractTuple : public AbstractSequence { 198 public: AbstractTuple(kTypeTuple,l,m)199 explicit AbstractTuple(const py::object &l, RecMap *m = nullptr) : AbstractTuple(kTypeTuple, l, m) {} ~AbstractTuple()200 virtual ~AbstractTuple() {} items()201 auto &items() { return items_; } size()202 Py_ssize_t size() const { return IsElementValid() ? items_.size() : -1; } 203 GetPyObject()204 py::object GetPyObject() override { return value_; } 205 AObject *Binary(AObject *other, int op) override; 206 AObject *Unary(int op) const override; 207 AObject *GetAttr(const std::string &name) override; SetAttr(const std::string & name,AObject *)208 bool SetAttr(const std::string &name, AObject *) override { return false; }; 209 AObject *GetItem(AObject *k) override; 210 std::string ToString() const override; 211 bool IsMindSporeSupportedType() override; 212 SetElementType(Type type)213 void SetElementType(Type type) { element_type_ = type; } GetElementType()214 Type GetElementType() const { return element_type_; } IsElementValid()215 bool IsElementValid() const { return element_valid_; } MarkElementInValid()216 void MarkElementInValid() { 217 element_type_ = kTypeAnyValue; 218 element_valid_ = false; 219 modify_ = false; 220 value_ = py::object(); 221 items_.clear(); 222 write_cache_.clear(); 223 } begin()224 auto begin() const { return items_.begin(); } end()225 auto end() const { return items_.end(); } IsModify()226 bool IsModify() const { return modify_ || this->write_cache_.size() > 0; } MarkModify()227 void MarkModify() { modify_ = true; } 228 bool Update(const std::vector<AObject *> &items); 229 bool Update(); 230 231 protected: 232 AbstractTuple(Type type, py::object list, RecMap *m); 233 std::vector<AObject *> items_; 234 BoolCache ms_support_; 235 Type element_type_; 236 bool element_valid_; 237 bool modify_; 238 }; 239 240 class AbstractList : public AbstractTuple { 241 public: AbstractList(const py::object & l,RecMap * m)242 explicit AbstractList(const py::object &l, RecMap *m) : AbstractTuple(kTypeList, l, m) {} ~AbstractList()243 virtual ~AbstractList() {} 244 245 py::object GetPyObject() override; 246 bool SetItem(AObject *k, AObject *v) override; 247 248 bool ListAppend(AObject *item); 249 bool ListExtend(AObject *list); 250 AbstractTuple *ListToTuple(); 251 }; 252 253 class AbstractDict : public AbstractSequence { 254 public: AbstractDict(kTypeDict,dict,m)255 explicit AbstractDict(const py::object &dict, RecMap *m = nullptr) : AbstractDict(kTypeDict, dict, m) {} ~AbstractDict()256 virtual ~AbstractDict() {} size()257 Py_ssize_t size() const { return IsElementValid() ? dict_.size() : -1; } 258 259 std::string ToString() const override; 260 py::object GetPyObject() override; 261 AObject *Unary(int op) const override; 262 AObject *Binary(AObject *, int op) override; 263 AObject *GetAttr(const std::string &name) override; SetAttr(const std::string & name,AObject *)264 bool SetAttr(const std::string &name, AObject *) override { return false; }; 265 AObject *GetItem(AObject *key) override; 266 bool IsMindSporeSupportedType() override; 267 KeyType()268 Type KeyType() const { return k_type_; } ValueType()269 Type ValueType() const { return v_type_; } 270 IsModify()271 bool IsModify() const { return modify_ || this->write_cache_.size() > 0; } MarkModify()272 void MarkModify() { modify_ = true; } 273 bool DictMerge(AObject *o, int update = 0); 274 bool DictUpdate(AObject *o); 275 bool MapAdd(AObject *k, AObject *v); IsElementValid()276 bool IsElementValid() const { return element_valid_; } MarkElementInValid()277 void MarkElementInValid() { 278 k_type_ = kTypeAnyValue; 279 v_type_ = kTypeAnyValue; 280 element_valid_ = false; 281 modify_ = false; 282 value_ = py::object(); 283 dict_.clear(); 284 write_cache_.clear(); 285 } 286 bool Update(); 287 288 class ValueIter { 289 public: ValueIter(const AbstractDict * dict)290 explicit ValueIter(const AbstractDict *dict) : map_(dict->dict_.ptr()), pos_(0) { ++(*this); } ValueIter()291 ValueIter() : map_(nullptr) {} key()292 py::object key() { return py::cast<py::object>(key_); } 293 AObject *operator*() { return AbstractDict::ConvertValue(val_); } 294 bool operator!=(const ValueIter &o) { return map_ != nullptr; } 295 ValueIter &operator++() { 296 map_ = PyDict_Next(map_, &pos_, &key_, &val_) ? map_ : nullptr; 297 return *this; 298 } 299 300 private: 301 PyObject *map_, *key_, *val_; 302 Py_ssize_t pos_; 303 }; begin()304 auto begin() const { return ValueIter(this); } end()305 auto end() const { return ValueIter(); } 306 ConvertValue(PyObject * i)307 static AObject *ConvertValue(PyObject *i) { return reinterpret_cast<AObject *>(PyLong_AsVoidPtr(i)); } ConvertValue(AObject * i)308 static py::object ConvertValue(AObject *i) { return py::reinterpret_steal<py::object>(PyLong_FromVoidPtr(i)); } 309 310 protected: 311 AbstractDict(Type type, py::object o, RecMap *m); 312 py::dict dict_; 313 Type k_type_; 314 Type v_type_; 315 bool element_valid_; 316 bool modify_; 317 }; 318 319 class AbstractTensor : public AbstractObject { 320 public: 321 static py::object Binary(int op, const py::object &, const py::object &); 322 323 public: 324 AbstractTensor(const py::object &o, bool is_stub); ~AbstractTensor()325 virtual ~AbstractTensor() {} 326 AObject *Binary(AObject *, int op) override; 327 AObject *Unary(int op) const override; 328 AObject *GetAttr(const std::string &name) override; 329 std::string ToString() const override; 330 SetItem(AObject * key,AObject * value)331 bool SetItem(AObject *key, AObject *value) override { return true; } 332 AObject *GetItem(AObject *key) override; 333 py::object GetTensor(bool sync); 334 IsMindSporeSupportedType()335 bool IsMindSporeSupportedType() override { return true; } IsStubTensor()336 bool IsStubTensor() const { return is_stub_; } 337 338 private: 339 bool is_stub_; 340 }; 341 } // namespace pijit 342 } // namespace mindspore 343 344 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_ABSTRACT_OBJECT_H 345