• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_guard/guard_utils.h"
17 #include <regex>
18 #include "pybind11/pybind11.h"
19 #include "pybind_api/ir/primitive_py.h"
20 #include "pybind_api/ir/cell_py.h"
21 #include "include/common/utils/convert_utils_py.h"
22 #include "pipeline/jit/pi/utils/utils.h"
23 #include "include/common/utils/stub_tensor.h"
24 #include "pipeline/jit/pi/graph_guard/strategy.h"
25 #include "pipeline/jit/pi/graph_guard/guard.h"
26 #include "pipeline/jit/pi/graph_guard/infer.h"
27 
28 namespace mindspore {
29 namespace pijit {
30 
31 static PyObject *kPyAttrStub = nullptr;
32 static PyObject *kPyAttrTensor = nullptr;
33 static PyObject *kPyAttrReprCache = nullptr;
34 static const char kPyAttrReprCacheStr[] = "__repr_cache__";
35 static const char kPyMethodStubSync[] = "stub_sync";
GetAttrStubStr()36 static PyObject *GetAttrStubStr() {
37   if (kPyAttrStub == nullptr) {
38     kPyAttrStub = PyUnicode_FromString(stub::PY_ATTR_STUB);
39   }
40   return kPyAttrStub;
41 }
GetAttrTensorStr()42 static PyObject *GetAttrTensorStr() {
43   if (kPyAttrTensor == nullptr) {
44     kPyAttrTensor = PyUnicode_FromString(stub::PY_ATTR_TENSOR);
45   }
46   return kPyAttrTensor;
47 }
GetAttrReprCacheStr()48 static PyObject *GetAttrReprCacheStr() {
49   if (kPyAttrReprCache == nullptr) {
50     kPyAttrReprCache = PyUnicode_FromString(kPyAttrReprCacheStr);
51   }
52   return kPyAttrReprCache;
53 }
54 
GetObjectString(PyObject * objName)55 static std::string GetObjectString(PyObject *objName) {
56   std::string ret = "";
57   if (objName == NULL) {
58     return ret;
59   }
60   PyObject *pyName = PyUnicode_AsEncodedString(objName, "utf-8", NULL);
61   char *strName = PyBytes_AsString(pyName);
62   if (strName != nullptr) {
63     ret = strName;
64   }
65   Py_DECREF(pyName);
66   return ret;
67 }
68 
// Building blocks for the "{label:value}" debug strings produced by the ItemData::ToString() overrides.
// `#op` turns the argument's source text into the label. Arguments are parenthesized wherever they are evaluated,
// so expressions such as `cond ? a : b` work as well as plain identifiers. The pointer-like arguments of
// DESC_TOSTRING / DESC_ITEM / DESC_ITEM_T / DESC_INDEX are evaluated more than once, so they must be free of
// side effects.
#define DESC(op) (std::string("{") + std::string(#op) + std::string(":") + (op) + std::string("}"))
#define DESC_STRING(op) (std::string("{") + std::string(#op) + std::string(":") + std::to_string(op) + std::string("}"))
#define DESC_STRING_L(op, l)                                                                                          \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(l) + std::string("]") + std::string(":") + \
   std::to_string(op) + std::string("}"))  // NOLINT
#define DESC_STRING_S(op, l)                                                                                          \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(l) + std::string("]") + std::string(":") + \
   (op) + std::string("}"))  // NOLINT
#define DESC_STRING_O(obj, op) \
  (std::string("{") + std::string(#op) + std::string(":") + std::to_string((obj)->op) + std::string("}"))
#define DESC_TOSTRING(op)                                   \
  (std::string("{") + std::string(#op) + std::string(":") + \
   (((op) == nullptr) ? std::string("nil") : (op)->ToString()) + std::string("}"))  // NOLINT
#define DESC_ITEM(opK, opV)                                                                              \
  (std::string("{") + (((opK) == nullptr) ? std::string("nil") : (opK)->ToString()) + std::string(":") + \
   (((opV) == nullptr) ? std::string("nil") : (opV)->ToString()) + std::string("}"))  // NOLINT
#define DESC_ITEM_V(op) (std::string("{") + std::to_string(op) + std::string("}"))
#define DESC_ITEM_T(op) \
  (std::string("{") + (((op) == nullptr) ? std::string("nil") : (op)->ToString()) + std::string("}"))
#define DESC_INDEX(op, idx)                                                                          \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(idx) + std::string("]") + \
   std::string(":") + (((op)[idx] == nullptr) ? std::string("nil") : (op)[idx]->ToString()) +      \
   std::string("}"))  // NOLINT
#define DESC_INDEX_V(op, idx)                                                                        \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(idx) + std::string("]") + \
   std::string(":") + std::to_string((op)[idx]) + std::string("}"))  // NOLINT
#define DESC_END ItemData::ToString()
94 
// Kind tag stored in every ItemData. ItemData::Info() serializes this value (as uint8_t), so the numbering is
// spelled out explicitly; changing a value changes the serialized info.
typedef enum _ItemType {
  // plain Python objects
  PyNull = 0,
  PyLong = 1,
  PyFloat = 2,
  PyBool = 3,
  PyBytes = 4,
  PyStr = 5,
  PyList = 6,
  PyTuple = 7,
  PySet = 8,
  PyFrozenSet = 9,
  PyDict = 10,
  PyComplex = 11,
  PySlice = 12,
  PyFunction = 13,
  PyMethod = 14,
  PyInstanceMethod = 15,
  PyType = 16,
  PyNumpy = 17,
  PyUnknown = 18,
  // MindSpore objects
  TensorType = 19,
  ParamInfo = 20,
  MetaTensor = 21,
  Tensor = 22,
  MapTensor = 23,
  RowTensor = 24,
  COOTensor = 25,
  CSRTensor = 26,
  Tensordata = 27,
  Primitive = 28,
  Cell = 29,
} ItemType;
127 
128 class ItemData {
129  public:
ItemData(ItemType itemType,bool needSpecialize,int recurseDepth)130   ItemData(ItemType itemType, bool needSpecialize, int recurseDepth)
131       : tp_(itemType), specialized_(needSpecialize), recurseDepth_(recurseDepth), info_(nullptr) {}
132 
133   virtual ~ItemData() = default;
134 
operator ==(const ItemData & obj) const135   virtual bool operator==(const ItemData &obj) const { return obj.tp_ == tp_; }
136 
ToString()137   virtual std::string ToString() {
138     if (tp_ == ItemType::PyNull) {
139       return "(null)";
140     } else {
141       return std::string("(type:") + std::to_string(SizeToInt(tp_)) + ",specialize:" + std::to_string(specialized_) +
142              ",recurse:" + std::to_string(recurseDepth_) + ")";
143     }
144   }
145 
Info()146   virtual const InfoPack &Info() {
147     if (info_ == nullptr) {
148       InfoPack info;
149       info << uint8_t(tp_);
150       info.Begin();
151       if (tp_ != ItemType::PyNull && tp_ != ItemType::PyUnknown) {
152         info << specialized_ << recurseDepth_;
153       }
154       SubInfo(&info);
155       info.End();
156       info_ = std::make_shared<InfoPack>(info);
157       info_->Update();
158     }
159     return *info_;
160   }
161 
GetItemType()162   virtual ItemType GetItemType() { return tp_; }
163 
MatchDynamicShape(std::shared_ptr<ItemData> other)164   virtual bool MatchDynamicShape(std::shared_ptr<ItemData> other) { return false; }
165 
166  protected:
SubInfo(InfoPack * info)167   virtual void SubInfo(InfoPack *info) {}
168   ItemType tp_;
169   bool specialized_;
170   int recurseDepth_;
171   InfoPackPtr info_;
172 };
173 using ItemDataPtr = std::shared_ptr<ItemData>;
174 
175 static ItemDataPtr CreateItem(PyObject *obj, bool needSpecialize = true, int recurseDepth = INT_MAX);
176 
177 class IntData : public ItemData {
178  public:
IntData(PyObject * obj,bool needSpecialize,int recurseDepth)179   IntData(PyObject *obj, bool needSpecialize, int recurseDepth)
180       : ItemData(ItemType::PyLong, needSpecialize, recurseDepth) {
181     tp_ = ItemType::PyLong;
182     intVar_ = PyLong_AsLong(obj);
183   }
184 
operator ==(const ItemData & obj) const185   bool operator==(const ItemData &obj) const override {
186     return ItemData::operator==(obj) && (!specialized_ || ((static_cast<const IntData &>(obj)).intVar_ == intVar_));
187   }
188 
ToString()189   std::string ToString() override { return DESC_STRING(intVar_) + DESC_END; }
190 
191  protected:
SubInfo(InfoPack * info)192   void SubInfo(InfoPack *info) override { (*info) << intVar_; }
193   int64_t intVar_;
194 };
195 
196 class FloatData : public ItemData {
197  public:
FloatData(PyObject * obj,bool needSpecialize,int recurseDepth)198   FloatData(PyObject *obj, bool needSpecialize, int recurseDepth)
199       : ItemData(ItemType::PyFloat, needSpecialize, recurseDepth) {
200     floatVar_ = PyFloat_AsDouble(obj);
201   }
202 
operator ==(const ItemData & obj) const203   bool operator==(const ItemData &obj) const override {
204     return ItemData::operator==(obj) && (!specialized_ || (static_cast<const FloatData &>(obj)).floatVar_ == floatVar_);
205   }
206 
ToString()207   std::string ToString() override { return DESC_STRING(floatVar_) + DESC_END; }
208 
209  protected:
SubInfo(InfoPack * info)210   void SubInfo(InfoPack *info) override { (*info) << floatVar_; }
211   double floatVar_;
212 };
213 
214 class BoolData : public ItemData {
215  public:
BoolData(PyObject * obj,bool needSpecialize,int recurseDepth)216   BoolData(PyObject *obj, bool needSpecialize, int recurseDepth)
217       : ItemData(ItemType::PyBool, needSpecialize, recurseDepth) {
218     boolVar_ = (obj == Py_True);
219   }
220 
operator ==(const ItemData & obj) const221   bool operator==(const ItemData &obj) const override {
222     return ItemData::operator==(obj) && (!specialized_ || (static_cast<const BoolData &>(obj)).boolVar_ == boolVar_);
223   }
224 
ToString()225   std::string ToString() override { return DESC_STRING(boolVar_) + DESC_END; }
226 
227  protected:
SubInfo(InfoPack * info)228   void SubInfo(InfoPack *info) override { (*info) << boolVar_; }
229   bool boolVar_;
230 };
231 
232 class BytesData : public ItemData {
233  public:
BytesData(PyObject * obj,bool needSpecialize,int recurseDepth)234   BytesData(PyObject *obj, bool needSpecialize, int recurseDepth)
235       : ItemData(ItemType::PyBytes, needSpecialize, recurseDepth), len_(PyBytes_Size(obj)) {
236     if (needSpecialize) {
237       buf_ = std::make_unique<uint8_t[]>(len_);
238       if (buf_ != nullptr) {
239         char *pBuf = PyBytes_AS_STRING(reinterpret_cast<PyBytesObject *>(obj));
240         if (pBuf != nullptr) {
241           memcpy_s(buf_.get(), len_, reinterpret_cast<uint8_t *>(pBuf), len_);
242         } else {
243           buf_.release();
244         }
245       }
246     } else {
247       buf_.reset(nullptr);
248     }
249   }
250 
~BytesData()251   ~BytesData() override { buf_.release(); }
252 
operator ==(const ItemData & obj) const253   bool operator==(const ItemData &obj) const override {
254     if (ItemData::operator==(obj)) {
255       const BytesData &other = static_cast<const BytesData &>(obj);
256       return len_ == other.len_ &&
257              ((specialized_ && (len_ == 0 || (buf_ != nullptr && other.buf_ != nullptr &&
258                                               memcmp(buf_.get(), other.buf_.get(), len_) == 0))) ||
259               (!specialized_));
260     }
261     return false;
262   }
263 
ToString()264   std::string ToString() override {
265     size_t bytes = (size_t)(buf_.get());
266     return DESC_STRING_L(bytes, len_) + DESC_END;
267   }
268 
269  protected:
SubInfo(InfoPack * info)270   void SubInfo(InfoPack *info) override { (*info) << (uint64_t)len_ << reinterpret_cast<void *>(buf_.get()); }
271   Py_ssize_t len_;
272   std::unique_ptr<uint8_t[]> buf_;
273 };
274 
275 class StringData : public ItemData {
276  public:
StringData(PyObject * obj,bool needSpecialize,int recurseDepth)277   StringData(PyObject *obj, bool needSpecialize, int recurseDepth)
278       : ItemData(ItemType::PyStr, needSpecialize, recurseDepth) {
279     if (needSpecialize) {
280       strVal_ = GetObjectString(obj);
281     }
282   }
283 
operator ==(const ItemData & obj) const284   bool operator==(const ItemData &obj) const override {
285     return ItemData::operator==(obj) &&
286            ((specialized_ && (static_cast<const StringData &>(obj)).strVal_.compare(strVal_) == 0) || (!specialized_));
287   }
288 
ToString()289   std::string ToString() override { return DESC(strVal_) + DESC_END; }
290 
291  protected:
SubInfo(InfoPack * info)292   void SubInfo(InfoPack *info) override { (*info) << strVal_; }
293   std::string strVal_;
294 };
295 
296 class ListData : public ItemData {
297  public:
298   ListData(PyObject *obj, bool needSpecialize, int recurseDepth);
299 
operator ==(const ItemData & obj) const300   bool operator==(const ItemData &obj) const override {
301     if (ItemData::operator==(obj)) {
302       const ListData &list = static_cast<const ListData &>(obj);
303       if (list.listVar_.size() == listVar_.size()) {
304         return CompareList(list);
305       }
306     }
307     return false;
308   }
309 
ToString()310   std::string ToString() override {
311     std::string ret;
312     for (auto it : listVar_) {
313       ret += DESC_ITEM_T(it);
314     }
315     switch (tp_) {
316       case ItemType::PyList: {
317         std::string list = ret;
318         ret = DESC_STRING_S(list, listVar_.size());
319       } break;
320       case ItemType::PyTuple: {
321         std::string tuple = ret;
322         ret = DESC_STRING_S(tuple, listVar_.size());
323       } break;
324       case ItemType::PySet: {
325         std::string set = ret;
326         ret = DESC_STRING_S(set, listVar_.size());
327       } break;
328       case ItemType::PyFrozenSet: {
329         std::string fronzen_set = ret;
330         ret = DESC_STRING_S(fronzen_set, listVar_.size());
331       } break;
332       default:
333         ret = "unknown";
334         break;
335     }
336     return ret + DESC_END;
337   }
338 
339  protected:
SubInfo(InfoPack * info)340   void SubInfo(InfoPack *info) override {
341     (*info) << uint8_t(tp_);
342     (*info) << uint64_t(listVar_.size());
343     for (auto v : listVar_) {
344       (*info) << v->Info();
345     }
346   }
CompareList(const ListData & list) const347   bool CompareList(const ListData &list) const {
348     if (!inOrder_) {
349       std::vector<ItemDataPtr> listCpy = list.listVar_;
350       for (size_t i = 0, j; i < listVar_.size(); ++i) {
351         size_t lenList = listCpy.size();
352         for (j = 0; j < lenList; ++j) {
353           if (*(listCpy[j]) == *(listVar_[i])) {
354             listCpy.erase(listCpy.begin() + j);
355             break;
356           }
357         }
358         if (j == lenList) {
359           return false;
360         }
361       }
362     } else {
363       for (size_t i = 0; i < listVar_.size(); ++i) {
364         if (*(list.listVar_[i]) == *(listVar_[i])) {
365           continue;
366         } else {
367           return false;
368         }
369       }
370     }
371     return true;
372   }
InitList(PyObject * obj,bool needSpecialize,int recurseDepth)373   void InitList(PyObject *obj, bool needSpecialize, int recurseDepth) {
374     tp_ = ItemType::PyList;
375     for (Py_ssize_t i = 0; i < PyList_Size(obj); ++i) {
376       PyObject *item = PyList_GetItem(obj, i);
377       if (item != NULL) {
378         if (recurseDepth > 0 || needSpecialize) {
379           listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
380         } else {
381           listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
382         }
383       }
384     }
385   }
InitTuple(PyObject * obj,bool needSpecialize,int recurseDepth)386   void InitTuple(PyObject *obj, bool needSpecialize, int recurseDepth) {
387     tp_ = ItemType::PyTuple;
388     for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(obj); ++i) {
389       PyObject *item = PyTuple_GET_ITEM(obj, i);
390       if (item != NULL) {
391         if (recurseDepth > 0 || needSpecialize) {
392           listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
393         } else {
394           listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
395         }
396       }
397     }
398   }
InitSet(PyObject * obj,bool needSpecialize,int recurseDepth)399   void InitSet(PyObject *obj, bool needSpecialize, int recurseDepth) {
400     tp_ = ItemType::PySet;
401     Py_ssize_t pos = 0;
402     PyObject *item;
403     Py_hash_t hash;
404     while (_PySet_NextEntry(obj, &pos, &item, &hash)) {
405       if (recurseDepth > 0 || needSpecialize) {
406         listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
407       } else {
408         listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
409       }
410     }
411     inOrder_ = false;
412   }
InitFrozenSet(PyObject * obj,bool needSpecialize,int recurseDepth)413   void InitFrozenSet(PyObject *obj, bool needSpecialize, int recurseDepth) {
414     tp_ = ItemType::PyFrozenSet;
415     Py_ssize_t pos = 0;
416     PyObject *item;
417     Py_hash_t hash;
418     while (_PySet_NextEntry(obj, &pos, &item, &hash)) {
419       if (recurseDepth > 0 || needSpecialize) {
420         listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
421       } else {
422         listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
423       }
424     }
425     inOrder_ = false;
426   }
427   std::vector<ItemDataPtr> listVar_;
428   bool inOrder_ = true;
429 };
430 
431 class ComplexData : public ItemData {
432  public:
ComplexData(PyObject * obj,bool needSpecialize,int recurseDepth)433   ComplexData(PyObject *obj, bool needSpecialize, int recurseDepth)
434       : ItemData(ItemType::PyComplex, needSpecialize, recurseDepth) {
435     if (needSpecialize) {
436       complexVar_ = std::make_pair(PyComplex_RealAsDouble(obj), PyComplex_ImagAsDouble(obj));
437     }
438   }
439 
operator ==(const ItemData & obj) const440   bool operator==(const ItemData &obj) const override {
441     return ItemData::operator==(obj) &&
442            (!specialized_ || (static_cast<const ComplexData &>(obj)).complexVar_ == complexVar_);
443   }
444 
ToString()445   std::string ToString() override {
446     return "complex(" + std::to_string(complexVar_.first) + "," + std::to_string(complexVar_.second) + ")" + DESC_END;
447   }
448 
449  protected:
SubInfo(InfoPack * info)450   void SubInfo(InfoPack *info) override { (*info) << complexVar_.first << complexVar_.second; }
451   std::pair<double, double> complexVar_;
452 };
453 
454 class SliceData : public ItemData {
455  public:
SliceData(PyObject * obj,bool needSpecialize,int recurseDepth)456   SliceData(PyObject *obj, bool needSpecialize, int recurseDepth)
457       : ItemData(ItemType::PySlice, needSpecialize, recurseDepth) {
458     Py_ssize_t start = 0;
459     Py_ssize_t stop = 0;
460     Py_ssize_t step = 0;
461     if (needSpecialize) {
462       PySlice_Unpack(obj, &start, &stop, &step);
463       sliceVar_.push_back((int64_t)start);
464       sliceVar_.push_back((int64_t)stop);
465       sliceVar_.push_back((int64_t)step);
466     }
467   }
468 
operator ==(const ItemData & obj) const469   bool operator==(const ItemData &obj) const override {
470     if (ItemData::operator==(obj)) {
471       const SliceData &other = static_cast<const SliceData &>(obj);
472       return (!specialized_ || (other.sliceVar_[0] == sliceVar_[0] && other.sliceVar_[1] == sliceVar_[1] &&
473                                 other.sliceVar_[2] == sliceVar_[2]));
474     }
475     return false;
476   }
477 
ToString()478   std::string ToString() override {
479     std::string slice;
480     for (auto it : sliceVar_) {
481       slice += DESC_ITEM_V(it);
482     }
483     return DESC_STRING_S(slice, sliceVar_.size()) + DESC_END;
484   }
485 
486  protected:
SubInfo(InfoPack * info)487   void SubInfo(InfoPack *info) override { (*info) << sliceVar_; }
488   std::vector<int64_t> sliceVar_;
489 };
490 
// Which object a DictData snapshot was taken from: the dict itself or one of its views.
typedef enum _DictType {
  DtDict = 0,    // a dict object
  DtKeys = 1,    // dict.keys() view
  DtValues = 2,  // dict.values() view
  DtItems = 3,   // dict.items() view
} DictType;
497 
498 class DictData : public ItemData {
499  public:
DictData(PyObject * obj,bool needSpecialize,int recurseDepth)500   DictData(PyObject *obj, bool needSpecialize, int recurseDepth)
501       : ItemData(ItemType::PyDict, needSpecialize, recurseDepth) {
502     if (PyDictKeys_Check(obj)) {
503       dt_ = DictType::DtKeys;
504       obj = PyObject_Vectorcall(reinterpret_cast<PyObject *>(&PyList_Type), &obj, 1, nullptr);
505     } else if (PyDictValues_Check(obj)) {
506       dt_ = DictType::DtValues;
507       obj = PyObject_Vectorcall(reinterpret_cast<PyObject *>(&PyList_Type), &obj, 1, nullptr);
508     } else if (PyDictItems_Check(obj)) {
509       dt_ = DictType::DtItems;
510       obj = PyObject_Vectorcall(reinterpret_cast<PyObject *>(&PyDict_Type), &obj, 1, nullptr);
511     } else {
512       dt_ = DictType::DtDict;
513     }
514     Py_ssize_t pos = 0;
515     PyObject *key;
516     PyObject *val;
517     if (dt_ == DictType::DtItems || dt_ == DictType::DtDict) {
518       while (PyDict_Next(obj, &pos, &key, &val)) {
519         ItemDataPtr k;
520         ItemDataPtr v;
521         if (recurseDepth > 0 || needSpecialize) {
522           k = CreateItem(key, needSpecialize, recurseDepth);
523           v = CreateItem(val, needSpecialize, recurseDepth);
524         } else {
525           k = CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(key)), false, false);
526           v = CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(val)), false, false);
527         }
528         listK_.push_back(k);
529         listV_.push_back(v);
530       }
531     } else {
532       std::vector<ItemDataPtr> &list = dt_ == DictType::DtKeys ? listK_ : listV_;
533       for (Py_ssize_t i = 0; i < PyList_Size(obj); ++i) {
534         PyObject *item = PyList_GetItem(obj, i);
535         if (recurseDepth > 0 || needSpecialize) {
536           list.push_back(CreateItem(item, needSpecialize, recurseDepth));
537         } else {
538           list.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
539         }
540       }
541     }
542     if (dt_ != DictType::DtDict) {
543       Py_DECREF(obj);
544     }
545   }
546 
operator ==(const ItemData & obj) const547   bool operator==(const ItemData &obj) const override {
548     if (ItemData::operator==(obj)) {
549       const DictData &other = static_cast<const DictData &>(obj);
550       if (dt_ != other.dt_) {
551         return false;
552       }
553       if ((dt_ == DictType::DtValues || other.listK_.size() == listK_.size()) &&
554           (dt_ == DictType::DtKeys || other.listV_.size() == listV_.size())) {
555         return CompareKV(other);
556       }
557     }
558     return false;
559   }
560 
ToString()561   std::string ToString() override {
562     std::string dict = DESC_STRING(dt_);
563     size_t listSize = 0;
564     if (dt_ == DictType::DtItems || dt_ == DictType::DtDict) {
565       listSize = listK_.size();
566       for (size_t i = 0; i < listSize; ++i) {
567         dict += DESC_ITEM(listK_[i], listV_[i]);
568       }
569     } else if (dt_ == DictType::DtKeys) {
570       listSize = listK_.size();
571       for (size_t i = 0; i < listSize; ++i) {
572         dict += DESC_ITEM_T(listK_[i]);
573       }
574     } else if (dt_ == DictType::DtValues) {
575       listSize = listV_.size();
576       for (size_t i = 0; i < listSize; ++i) {
577         dict += DESC_ITEM_T(listV_[i]);
578       }
579     }
580     return DESC_STRING_S(dict, listSize) + DESC_END;
581   }
582 
583  protected:
SubInfo(InfoPack * info)584   void SubInfo(InfoPack *info) override {
585     (*info) << dt_;
586     (*info) << uint64_t(listK_.size());
587     for (auto i : listK_) {
588       (*info) << i->Info();
589     }
590     (*info) << uint64_t(listV_.size());
591     for (auto i : listV_) {
592       (*info) << i->Info();
593     }
594   }
CompareKV(const DictData & other) const595   bool CompareKV(const DictData &other) const {
596     std::vector<ItemDataPtr> listCpK = other.listK_;
597     std::vector<ItemDataPtr> listCpV = other.listV_;
598     size_t listSize = listK_.size();
599     if (listSize < listV_.size()) {
600       listSize = listV_.size();
601     }
602     for (size_t i = 0, j = 0; i < listSize; ++i) {
603       size_t cpListSize = dt_ == DictType::DtValues ? listCpV.size() : listCpK.size();
604       for (; j < cpListSize; ++j) {
605         if ((dt_ == DictType::DtValues || *(listK_[i]) == *(listCpK[j])) &&
606             (dt_ == DictType::DtKeys || *(listV_[i]) == *(listCpV[j]))) {
607           if (dt_ != DictType::DtValues) {
608             listCpK.erase(listCpK.begin() + j);
609           }
610           if (dt_ != DictType::DtKeys) {
611             listCpV.erase(listCpV.begin() + j);
612           }
613           break;
614         }
615       }
616       if (j == cpListSize) {
617         return false;
618       }
619     }
620     return true;
621   }
622   DictType dt_;
623   std::vector<ItemDataPtr> listK_;
624   std::vector<ItemDataPtr> listV_;
625 };
626 
627 class FunctionData : public ItemData {
628  public:
FunctionData(PyObject * obj,bool needSpecialize,int recurseDepth)629   FunctionData(PyObject *obj, bool needSpecialize, int recurseDepth)
630       : ItemData(ItemType::PyFunction, needSpecialize, recurseDepth) {
631     if (needSpecialize || recurseDepth > 0) {
632       code_ = reinterpret_cast<PyCodeObject *>(PyFunction_GetCode(obj));
633       defaults_ = CreateItem(PyFunction_GetDefaults(obj), needSpecialize, recurseDepth);
634       kwdefaults_ = CreateItem(PyFunction_GetKwDefaults(obj), needSpecialize, recurseDepth);
635       closure_ = CreateItem(PyFunction_GetClosure(obj), needSpecialize, recurseDepth);
636     } else {
637       code_ = reinterpret_cast<PyCodeObject *>(PyFunction_GetCode(obj));
638       PyObject *temp = PyFunction_GetDefaults(obj);
639       defaults_ =
640         CreateItem((temp == NULL || temp == Py_None) ? Py_None : reinterpret_cast<PyObject *>(Py_TYPE(temp)), false, 0);
641       temp = PyFunction_GetKwDefaults(obj);
642       kwdefaults_ =
643         CreateItem((temp == NULL || temp == Py_None) ? Py_None : reinterpret_cast<PyObject *>(Py_TYPE(temp)), false, 0);
644       temp = PyFunction_GetClosure(obj);
645       closure_ =
646         CreateItem((temp == NULL || temp == Py_None) ? Py_None : reinterpret_cast<PyObject *>(Py_TYPE(temp)), false, 0);
647     }
648   }
649 
operator ==(const ItemData & obj) const650   bool operator==(const ItemData &obj) const override {
651     if (ItemData::operator==(obj)) {
652       const FunctionData &other = static_cast<const FunctionData &>(obj);
653       return code_ == other.code_ && *defaults_ == *(other.defaults_) && *kwdefaults_ == *(other.kwdefaults_) &&
654              *closure_ == *(other.closure_);
655     }
656     return false;
657   }
658 
ToString()659   std::string ToString() override {
660     std::string func = DESC_TOSTRING(defaults_) + DESC_TOSTRING(kwdefaults_) + DESC_TOSTRING(closure_);
661     return DESC(func) + DESC_END;
662   }
663 
664  protected:
SubInfo(InfoPack * info)665   void SubInfo(InfoPack *info) override {
666     (*info) << (defaults_ != nullptr);
667     if (defaults_ != nullptr) {
668       (*info) << defaults_->Info();
669     }
670     (*info) << (kwdefaults_ != nullptr);
671     if (kwdefaults_ != nullptr) {
672       (*info) << kwdefaults_->Info();
673     }
674     (*info) << (closure_ != nullptr);
675     if (closure_ != nullptr) {
676       (*info) << closure_->Info();
677     }
678   }
679   PyCodeObject *code_;
680   ItemDataPtr defaults_;
681   ItemDataPtr kwdefaults_;
682   ItemDataPtr closure_;
683 };
684 
685 class MethodData : public ItemData {
686  public:
MethodData(PyObject * obj,bool needSpecialize,int recurseDepth)687   MethodData(PyObject *obj, bool needSpecialize, int recurseDepth)
688       : ItemData(ItemType::PyMethod, needSpecialize, recurseDepth),
689         refFunc_(CreateItem(PyMethod_GET_FUNCTION(obj), needSpecialize, recurseDepth)),
690         refSelf_(CreateItem(PyMethod_GET_SELF(obj), needSpecialize, recurseDepth)) {}
691 
operator ==(const ItemData & obj) const692   bool operator==(const ItemData &obj) const override {
693     if (ItemData::operator==(obj)) {
694       const MethodData &other = static_cast<const MethodData &>(obj);
695       return *refFunc_ == *(other.refFunc_) && *refSelf_ == *(other.refSelf_);
696     }
697     return false;
698   }
699 
ToString()700   std::string ToString() override {
701     std::string method = DESC_TOSTRING(refFunc_) + DESC_TOSTRING(refSelf_);
702     return DESC(method) + DESC_END;
703   }
704 
705  protected:
SubInfo(InfoPack * info)706   void SubInfo(InfoPack *info) override {
707     (*info) << (refFunc_ != nullptr);
708     if (refFunc_ != nullptr) {
709       (*info) << refFunc_->Info();
710     }
711     (*info) << (refSelf_ != nullptr);
712     if (refSelf_ != nullptr) {
713       (*info) << refSelf_->Info();
714     }
715   }
716   ItemDataPtr refFunc_;
717   ItemDataPtr refSelf_;
718 };
719 
720 class InstanceMethodData : public ItemData {
721  public:
InstanceMethodData(PyObject * obj,bool needSpecialize,int recurseDepth)722   InstanceMethodData(PyObject *obj, bool needSpecialize, int recurseDepth)
723       : ItemData(ItemType::PyInstanceMethod, needSpecialize, recurseDepth),
724         refFunc_(CreateItem(PyInstanceMethod_GET_FUNCTION(obj), needSpecialize, recurseDepth)) {}
725 
operator ==(const ItemData & obj) const726   bool operator==(const ItemData &obj) const override {
727     if (ItemData::operator==(obj)) {
728       const InstanceMethodData &other = static_cast<const InstanceMethodData &>(obj);
729       return *refFunc_ == *(other.refFunc_);
730     }
731     return false;
732   }
733 
ToString()734   std::string ToString() override {
735     std::string instance_method = DESC_TOSTRING(refFunc_);
736     return DESC(instance_method) + DESC_END;
737   }
738 
739  protected:
SubInfo(InfoPack * info)740   void SubInfo(InfoPack *info) override {
741     (*info) << (refFunc_ != nullptr);
742     if (refFunc_ != nullptr) {
743       (*info) << refFunc_->Info();
744     }
745   }
746   ItemDataPtr refFunc_;
747 };
748 
749 class TypeData : public ItemData {
750  public:
TypeData(PyObject * obj,bool needSpecialize,int recurseDepth)751   TypeData(PyObject *obj, bool needSpecialize, int recurseDepth)
752       : ItemData(ItemType::PyType, needSpecialize, recurseDepth) {
753     refType_ = reinterpret_cast<PyTypeObject *>(obj);
754     is_adapter_tensor_type_ = false;
755     ambiguous_tensor_type_ = false;
756   }
757 
set_ambiguous_tensor_type(bool value)758   void set_ambiguous_tensor_type(bool value) {
759     if (value && (IsTensorType<true>(refType_) || IsStubTensorType<true>(refType_))) {
760       ambiguous_tensor_type_ = true;
761       PyObject *obj = reinterpret_cast<PyObject *>(refType_);
762       py::object registry = Utils::GetModuleAttr("mindspore.common._register_for_adapter", "ms_adapter_registry");
763       is_adapter_tensor_type_ = registry.ptr() != nullptr && obj == py::getattr(registry, "tensor", nullptr).ptr();
764     }
765   }
766 
operator ==(const ItemData & obj) const767   bool operator==(const ItemData &obj) const override {
768     if (ItemData::operator==(obj)) {
769       PyTypeObject *otherType = (static_cast<const TypeData &>(obj)).refType_;
770       bool ret = refType_ == otherType;
771       if (!ret) {
772         ret = PyType_IsSubtype(refType_, otherType) || PyType_IsSubtype(otherType, refType_);
773       }
774       // adapter tensor type must be check exactly
775       // if exactly type check failed, check ambiguous tensor type if necessary
776       if (!is_adapter_tensor_type_ && !ret && ambiguous_tensor_type_) {
777         ret = IsTensorType<true>(otherType) || IsStubTensorType<true>(otherType);
778       }
779       return ret;
780     }
781     return false;
782   }
783 
ToString()784   std::string ToString() override {
785     std::string type = refType_->tp_name;
786     return DESC(type) + DESC_END;
787   }
788 
789  protected:
SubInfo(InfoPack * info)790   void SubInfo(InfoPack *info) override { (*info) << refType_->tp_name; }
791   PyTypeObject *refType_;
792 
793   // this flag is checked only if tensor type is ambiguous
794   bool is_adapter_tensor_type_;
795 
796   // mix the tensor type.
797   // only set true if _c_expression.Tensor type, common.Tensor type, StubTensor type and all subtype of them
798   bool ambiguous_tensor_type_;
799 };
800 
801 class NumpyData : public ItemData {
802  public:
NumpyData(PyObject * obj,bool needSpecialize,int recurseDepth)803   NumpyData(PyObject *obj, bool needSpecialize, int recurseDepth)
804       : ItemData(ItemType::PyNumpy, needSpecialize, recurseDepth) {
805     py::array arr = py::cast<py::array>(obj);
806     dtype_ = arr.dtype();
807     size_ = (uint64_t)arr.size();
808     itemsize_ = (uint64_t)arr.itemsize();
809     ndim_ = (int64_t)arr.ndim();
810     nbytes_ = (uint64_t)arr.nbytes();
811     for (ssize_t i = 0; i < ndim_; ++i) {
812       shape_.push_back((int64_t)arr.shape()[i]);
813       strides_.push_back((int64_t)arr.strides()[i]);
814     }
815     if (arr.data() != nullptr) {
816       if (needSpecialize) {
817         buf_ = std::make_unique<uint8_t[]>(nbytes_);
818         if (buf_ != NULL) {
819           memcpy_s(buf_.get(), nbytes_, reinterpret_cast<uint8_t *>(arr.mutable_data()), nbytes_);
820         }
821       } else {
822         buf_.reset(nullptr);
823       }
824     } else {
825       buf_.reset(nullptr);
826     }
827   }
828 
~NumpyData()829   ~NumpyData() override { buf_.release(); }
830 
operator ==(const ItemData & obj) const831   bool operator==(const ItemData &obj) const override {
832     if (ItemData::operator==(obj)) {
833       const NumpyData &other = static_cast<const NumpyData &>(obj);
834       return dtype_ == other.dtype_ && size_ == other.size_ && ndim_ == other.ndim_ && nbytes_ == other.nbytes_ &&
835              shape_ == other.shape_ && strides_ == other.strides_ &&
836              (!specialized_ ||
837               (buf_ != NULL && other.buf_ != NULL && memcmp(buf_.get(), other.buf_.get(), nbytes_) == 0));
838     }
839     return false;
840   }
841 
ToString()842   std::string ToString() override {
843     std::string numpy;
844     char dtype_kind = dtype_.kind();
845     numpy +=
846       DESC_STRING(dtype_kind) + DESC_STRING(size_) + DESC_STRING(itemsize_) + DESC_STRING(ndim_) + DESC_STRING(nbytes_);
847     for (size_t i = 0; i < shape_.size(); ++i) {
848       numpy += DESC_INDEX_V(shape_, i) + DESC_INDEX_V(strides_, i);
849     }
850     return DESC(numpy) + DESC_END;
851   }
852 
853  protected:
SubInfo(InfoPack * info)854   void SubInfo(InfoPack *info) override {
855     (*info) << dtype_.kind() << size_ << itemsize_ << ndim_ << nbytes_ << shape_ << strides_;
856   }
857   py::dtype dtype_;
858   uint64_t size_;
859   uint64_t itemsize_;
860   int64_t ndim_;
861   uint64_t nbytes_;
862   std::vector<int64_t> shape_;
863   std::vector<int64_t> strides_;
864   std::unique_ptr<uint8_t[]> buf_;
865 };
866 
867 class TensorTypeData : public ItemData {
868  public:
TensorTypeData(PyObject * obj,bool needSpecialize,int recurseDepth)869   TensorTypeData(PyObject *obj, bool needSpecialize, int recurseDepth)
870       : ItemData(ItemType::TensorType, needSpecialize, recurseDepth) {
871     auto pyObj = py::cast<py::object>(obj);
872     tpp_ = pyObj.cast<mindspore::TypePtr>();
873   }
874 
operator ==(const ItemData & obj) const875   bool operator==(const ItemData &obj) const override {
876     return ItemData::operator==(obj) && (!specialized_ || *((static_cast<const TensorTypeData &>(obj)).tpp_) == *tpp_);
877   }
878 
ToString()879   std::string ToString() override {
880     std::string tensor_type = tpp_->ToString();
881     return DESC(tensor_type) + DESC_END;
882   }
883 
884  protected:
SubInfo(InfoPack * info)885   void SubInfo(InfoPack *info) override { (*info) << tpp_; }
886   mindspore::TypePtr tpp_;
887 };
888 
889 class ParamInfoData : public ItemData {
890  public:
ParamInfoData(PyObject * obj,bool needSpecialize,int recurseDepth)891   ParamInfoData(PyObject *obj, bool needSpecialize, int recurseDepth)
892       : ItemData(ItemType::ParamInfo, needSpecialize, recurseDepth) {
893     auto pyObj = py::cast<py::object>(obj);
894     auto ptr = pyObj.cast<mindspore::ParamInfoPtr>();
895     param_ = ptr->Clone();
896   }
897 
operator ==(const ItemData & obj) const898   bool operator==(const ItemData &obj) const override {
899     if (ItemData::operator==(obj)) {
900       if (!specialized_) {
901         return true;
902       }
903       const ParamInfoData &other = static_cast<const ParamInfoData &>(obj);
904       return Equal(param_, other.param_);
905     }
906     return false;
907   }
908 
Equal(ParamInfoPtr a,ParamInfoPtr b)909   static bool Equal(ParamInfoPtr a, ParamInfoPtr b) {
910     return a->requires_grad() == b->requires_grad() && a->comm_fusion() == b->comm_fusion() &&
911            a->parallel_optimizer() == b->parallel_optimizer() &&
912            a->parallel_optimizer_comm_recompute() == b->parallel_optimizer_comm_recompute() &&
913            a->parameter_shape() == b->parameter_shape() && a->use_persistent_storage() == b->use_persistent_storage() &&
914            a->cache_enable() == b->cache_enable() && a->param_strategy() == b->param_strategy() &&
915            a->cache_shape() == b->cache_shape() && a->requires_aggr() == b->requires_aggr();
916   }
917 
ToString()918   std::string ToString() override {
919     std::string param_info = ToStringAttr(param_);
920     return DESC(param_info) + DESC_END;
921   }
922 
ToStringAttr(mindspore::ParamInfoPtr p)923   static std::string ToStringAttr(mindspore::ParamInfoPtr p) {
924     if (p == nullptr) {
925       return "nil";
926     }
927     std::string param_name = p->name();
928     std::string ret = DESC(param_name) + DESC_STRING_O(p, requires_grad()) + DESC_STRING_O(p, comm_fusion()) +
929                       DESC_STRING_O(p, parallel_optimizer()) + DESC_STRING_O(p, requires_aggr()) +
930                       DESC_STRING_O(p, parallel_optimizer_comm_recompute()) +
931                       DESC_STRING_O(p, use_persistent_storage()) + DESC_STRING_O(p, cache_enable());
932     auto parameter_shape = p->parameter_shape();
933     for (size_t i = 0; i < parameter_shape.size(); ++i) {
934       ret += DESC_INDEX_V(parameter_shape, i);
935     }
936     auto cache_shape = p->cache_shape();
937     for (size_t i = 0; i < cache_shape.size(); ++i) {
938       ret += DESC_INDEX_V(cache_shape, i);
939     }
940     auto param_strategy = p->param_strategy();
941     for (size_t i = 0; i < param_strategy.size(); ++i) {
942       ret += DESC_INDEX_V(param_strategy, i);
943     }
944     return ret;
945   }
946 
SubInfo(InfoPack * info,mindspore::ParamInfoPtr p)947   static void SubInfo(InfoPack *info, mindspore::ParamInfoPtr p) {
948     if (p == nullptr) {
949       return;
950     }
951     (*info) << p->name() << p->requires_grad() << p->comm_fusion() << p->parallel_optimizer() << p->requires_aggr()
952             << p->parallel_optimizer_comm_recompute() << p->use_persistent_storage() << p->cache_enable()
953             << p->parameter_shape() << p->cache_shape() << p->param_strategy();
954   }
955 
956  protected:
SubInfo(InfoPack * info)957   void SubInfo(InfoPack *info) override { SubInfo(info, param_); }
958   mindspore::ParamInfoPtr param_;
959 };
960 
// Sentinel values used inside guard shapes:
//   kDynamicDim   - the shape's rank is unknown; any shape is accepted (see CheckShape).
//   kDynamicShape - a single dimension whose extent is unknown; any extent is accepted there.
static constexpr int64_t kDynamicDim{-2};
static constexpr int64_t kDynamicShape{-1};
963 
IsDynamicDim(const ShapeVector & shape)964 static bool IsDynamicDim(const ShapeVector &shape) {
965   return std::any_of(shape.begin(), shape.end(), [](ShapeValueDType dim) { return dim == kDynamicDim; });
966 }
967 
CheckShape(const ShapeVector & a,const ShapeVector & b)968 static bool CheckShape(const ShapeVector &a, const ShapeVector &b) {
969   if (IsDynamicDim(a) || IsDynamicDim(b)) {
970     return true;
971   } else if (a.size() == b.size()) {
972     for (size_t idx = 0; idx < a.size(); idx++) {
973       if (a[idx] != kDynamicShape && b[idx] != kDynamicShape && a[idx] != b[idx]) {
974         return false;
975       }
976     }
977     return true;
978   } else {
979     return false;
980   }
981 }
982 
983 class MetaTensorData : public ItemData {
984  public:
MetaTensorData(mindspore::tensor::MetaTensorPtr tensor_ptr,bool needSpecialize,int recurseDepth)985   MetaTensorData(mindspore::tensor::MetaTensorPtr tensor_ptr, bool needSpecialize, int recurseDepth)
986       : ItemData(ItemType::MetaTensor, needSpecialize, recurseDepth) {
987     StoreTensor(tensor_ptr);
988   }
989 
MetaTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)990   MetaTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
991       : ItemData(ItemType::MetaTensor, needSpecialize, recurseDepth) {
992     mindspore::tensor::MetaTensorPtr tensor_ptr = nullptr;
993     PyObject *stubattr = GetAttrStubStr();
994     PyObject *stub = PyObject_HasAttr(obj, stubattr) ? PyObject_GetAttr(obj, stubattr) : nullptr;
995     if (stub != nullptr) {
996       if (stub != Py_None) {
997         is_stubtensor_ = true;
998       } else {
999         PyObject *tensorattr = GetAttrTensorStr();
1000         obj = PyObject_GetAttr(obj, tensorattr);
1001         tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1002         Py_DECREF(obj);
1003       }
1004     } else if (py::isinstance<mindspore::tensor::Tensor>(obj)) {
1005       tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1006     } else if (py::isinstance<mindspore::tensor::MapTensor>(obj)) {
1007       tensor_ptr = py::cast<mindspore::tensor::MapTensorPtr>(obj);
1008     } else {
1009       tensor_ptr = py::cast<mindspore::tensor::MetaTensorPtr>(obj);
1010     }
1011     if (tensor_ptr != nullptr) {
1012       StoreTensor(tensor_ptr);
1013     } else {
1014       auto ptr = py::cast<mindspore::stub::StubNodePtr>(stub);
1015       StoreStubTensor(ptr);
1016     }
1017     Py_XDECREF(stub);
1018   }
1019 
operator ==(const ItemData & obj) const1020   bool operator==(const ItemData &obj) const override {
1021     if (ItemData::operator==(obj)) {
1022       const MetaTensorData &other = static_cast<const MetaTensorData &>(obj);
1023       bool ret;
1024       if (is_stubtensor_ || other.is_stubtensor_) {
1025         ret = CheckShape(shape_, other.shape_) && CheckDataType(other);
1026       } else {
1027         ret = tid_ == other.tid_ && CheckShape(shape_, other.shape_) && is_parameter_ == other.is_parameter_ &&
1028               CheckDataType(other);
1029       }
1030       if (ret) {
1031         if (is_parameter_ == true) {
1032           ret = ((param_ == nullptr && other.param_ == nullptr) ||
1033                  (param_ != nullptr && other.param_ != nullptr && ParamInfoData::Equal(param_, other.param_)));
1034         }
1035       }
1036       return ret;
1037     }
1038     return false;
1039   }
1040 
MakeTensor()1041   mindspore::tensor::TensorPtr MakeTensor() {
1042     return std::make_shared<mindspore::tensor::Tensor>(data_type_->type_id(), shape_);
1043   }
1044 
IsDynamicShape() const1045   bool IsDynamicShape() const {
1046     return std::any_of(shape_.begin(), shape_.end(),
1047                        [](ShapeValueDType dim) { return dim == kDynamicDim || dim == kDynamicShape; });
1048   }
1049 
ToString()1050   std::string ToString() override {
1051     std::string meta_tensor = ToStringIntern();
1052     return DESC(meta_tensor) + DESC_END;
1053   }
1054 
MatchDynamicShape(std::shared_ptr<ItemData> other)1055   bool MatchDynamicShape(std::shared_ptr<ItemData> other) override {
1056     auto type = other->GetItemType();
1057     if (type != ItemType::Tensor && type != ItemType::MetaTensor) {
1058       return false;
1059     }
1060     auto o = static_cast<MetaTensorData *>(other.get());
1061     if (!CheckDataType(*o) || specialized_ != false || o->specialized_ != false) {
1062       return false;
1063     }
1064     if (shape_.size() != o->shape_.size()) {
1065       shape_ = {kDynamicDim};
1066     } else {
1067       for (size_t idx = 0; idx < shape_.size(); ++idx) {
1068         if (shape_[idx] != kDynamicShape && shape_[idx] != o->shape_[idx]) {
1069           shape_[idx] = kDynamicShape;
1070         }
1071       }
1072     }
1073     return true;
1074   }
1075 
1076  protected:
MetaTensorData(bool needSpecialize,int recurseDepth)1077   MetaTensorData(bool needSpecialize, int recurseDepth)
1078       : ItemData(ItemType::MetaTensor, needSpecialize, recurseDepth) {}
ToStringIntern()1079   virtual std::string ToStringIntern() {
1080     std::string param_desc = ParamInfoData::ToStringAttr(param_);
1081     std::string shape = "";
1082     for (size_t i = 0; i < shape_.size(); ++i) {
1083       shape += DESC_INDEX_V(shape_, i);
1084     }
1085     std::string is_stubtensor = is_stubtensor_ ? "true" : "false";
1086     return DESC_STRING(tid_) + DESC_TOSTRING(data_type_) + DESC_STRING(is_parameter_) + DESC(param_desc) + DESC(shape) +
1087            DESC(is_stubtensor);
1088   }
1089 
CheckDataType(const MetaTensorData & other) const1090   bool CheckDataType(const MetaTensorData &other) const {
1091     return (data_type_ == nullptr && other.data_type_ == nullptr) ||
1092            (data_type_ != nullptr && other.data_type_ != nullptr && *data_type_ == *(other.data_type_));
1093   }
1094 
StoreTensor(mindspore::tensor::MetaTensorPtr tensor_ptr)1095   void StoreTensor(mindspore::tensor::MetaTensorPtr tensor_ptr) {
1096     tid_ = tensor_ptr->data_type();
1097     shape_ = tensor_ptr->shape();
1098     data_type_ = tensor_ptr->Dtype();
1099     is_parameter_ = tensor_ptr->is_parameter();
1100     param_ = tensor_ptr->param_info();
1101   }
1102 
StoreStubTensor(mindspore::stub::StubNodePtr stub_ptr)1103   void StoreStubTensor(mindspore::stub::StubNodePtr stub_ptr) {
1104     auto base = stub_ptr->ToAbstract();
1105     auto shape = base->BuildShape()->cast<abstract::ShapePtr>();
1106     if (shape && !shape->IsDynamic()) {
1107       shape_ = shape->shape();
1108     } else {
1109       shape_ = {};
1110     }
1111     auto dt = base->BuildType();
1112     if (dt->isa<mindspore::TensorType>()) {
1113       data_type_ = dt->cast<std::shared_ptr<mindspore::TensorType>>()->element();
1114     } else {
1115       data_type_ = dt;
1116     }
1117   }
1118 
SubInfo(InfoPack * info)1119   void SubInfo(InfoPack *info) override {
1120     (*info) << uint8_t(tid_) << data_type_ << is_parameter_ << shape_ << is_stubtensor_;
1121     ParamInfoData::SubInfo(info, param_);
1122   }
1123 
1124   mindspore::TypeId tid_ = TypeId::kTypeUnknown;
1125   ShapeVector shape_;
1126   TypePtr data_type_;
1127   bool is_parameter_ = false;
1128   bool is_stubtensor_ = false;
1129   mindspore::ParamInfoPtr param_;
1130 };
1131 
1132 class TensorData : public MetaTensorData {
1133  public:
TensorData(mindspore::tensor::TensorPtr tensor_ptr,bool needSpecialize,int recurseDepth)1134   TensorData(mindspore::tensor::TensorPtr tensor_ptr, bool needSpecialize, int recurseDepth)
1135       : MetaTensorData(needSpecialize, recurseDepth) {
1136     tp_ = ItemType::Tensor;
1137     StoreTensor(tensor_ptr);
1138   }
1139 
TensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1140   TensorData(PyObject *obj, bool needSpecialize, int recurseDepth) : MetaTensorData(needSpecialize, recurseDepth) {
1141     is_stubtensor_ = false;
1142     tp_ = ItemType::Tensor;
1143     mindspore::tensor::TensorPtr tensor_ptr = nullptr;
1144     PyObject *stubattr = GetAttrStubStr();
1145     PyObject *stub = PyObject_HasAttr(obj, stubattr) ? PyObject_GetAttr(obj, stubattr) : nullptr;
1146     if (stub != nullptr) {
1147       if (stub != Py_None) {
1148         specialized_ = false;
1149       }
1150       if (specialized_) {
1151         auto pyObj = python_adapter::CallPyObjMethod(py::cast<py::object>(obj), kPyMethodStubSync);
1152         tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(pyObj.ptr());
1153       } else {
1154         if (stub != Py_None) {
1155           is_stubtensor_ = true;
1156         } else {
1157           PyObject *tensorattr = GetAttrTensorStr();
1158           obj = PyObject_GetAttr(obj, tensorattr);
1159           tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1160           Py_DECREF(obj);
1161         }
1162       }
1163     } else if (py::isinstance<mindspore::tensor::Tensor>(obj)) {
1164       tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1165     } else if (py::isinstance<mindspore::tensor::MapTensor>(obj)) {
1166       tensor_ptr = py::cast<mindspore::tensor::MapTensorPtr>(obj);
1167     } else {
1168       tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1169     }
1170     if (tensor_ptr != nullptr) {
1171       if (OptStrategy::MakeCalcStrategyByShape(tensor_ptr->shape()) != OptStrategy::CalcKind::kCalcValue) {
1172         specialized_ = false;
1173       }
1174       StoreTensor(tensor_ptr);
1175     } else {
1176       auto ptr = py::cast<mindspore::stub::StubNodePtr>(stub);
1177       StoreStubTensor(ptr);
1178     }
1179     Py_XDECREF(stub);
1180   }
1181 
~TensorData()1182   ~TensorData() override { data_ptr_.release(); }
1183 
IsBaseShapePtr(const TensorData & other) const1184   bool IsBaseShapePtr(const TensorData &other) const {
1185     return (other.base_shape_ptr_ == nullptr && base_shape_ptr_ == other.base_shape_ptr_) ||
1186            (base_shape_ptr_ != nullptr && other.base_shape_ptr_ != nullptr &&
1187             *(other.base_shape_ptr_) == *(base_shape_ptr_));
1188   }
1189 
IsCastDtype(const TensorData & other) const1190   bool IsCastDtype(const TensorData &other) const {
1191     return (other.cast_dtype_ == nullptr && cast_dtype_ == nullptr) ||
1192            (other.cast_dtype_ != nullptr && cast_dtype_ != nullptr && *cast_dtype_ == *(other.cast_dtype_));
1193   }
1194 
operator ==(const ItemData & obj) const1195   bool operator==(const ItemData &obj) const override {
1196     if (!ItemData::operator==(obj)) {
1197       return false;
1198     }
1199     bool ret = MetaTensorData::operator==(obj);
1200     const TensorData &other = static_cast<const TensorData &>(obj);
1201     if (is_stubtensor_ || other.is_stubtensor_) {
1202       return ret;
1203     }
1204     ret = ret && other.init_flag_ == init_flag_ && other.is_forward_output_ == is_forward_output_ &&
1205           other.graph_output_ == graph_output_ && other.specialized_ == specialized_ && IsBaseShapePtr(other) &&
1206           IsCastDtype(other) && other.compression_type_ == compression_type_ &&
1207           other.quant_params_.size() == quant_params_.size() && other.tensor_name_.compare(tensor_name_) == 0;
1208     if (!ret) {
1209       return ret;
1210     }
1211     for (size_t i = 0; i < quant_params_.size(); ++i) {
1212       if (*(quant_params_[i]) == *(other.quant_params_[i])) {
1213         continue;
1214       } else {
1215         return false;
1216       }
1217     }
1218     if (IsDynamicShape() || other.IsDynamicShape()) {
1219       return true;
1220     } else {
1221       return CheckData(other);
1222     }
1223   }
1224 
ToString()1225   std::string ToString() override {
1226     std::string tensor = ToStringIntern();
1227     return DESC(tensor) + DESC_END;
1228   }
1229 
1230  protected:
ToStringIntern()1231   std::string ToStringIntern() override {
1232     std::string ret = MetaTensorData::ToStringIntern();
1233     ret += DESC_STRING(is_forward_output_) + DESC_STRING(init_flag_) + DESC_STRING(graph_output_);
1234     ret +=
1235       DESC_TOSTRING(cast_dtype_) + DESC_TOSTRING(base_shape_ptr_) + DESC_STRING(compression_type_) + DESC(tensor_name_);
1236     for (size_t i = 0; i < quant_params_.size(); ++i) {
1237       ret += DESC_INDEX(quant_params_, i);
1238     }
1239     return ret;
1240   }
1241 
CheckData(const TensorData & other) const1242   bool CheckData(const TensorData &other) const {
1243     bool ret;
1244     if (specialized_) {
1245       if (data_ptr_ == nullptr || other.data_ptr_ == nullptr) {
1246         ret = data_len_ == other.data_len_;
1247       } else if (data_len_ == other.data_len_) {
1248         ret = memcmp(data_ptr_.get(), other.data_ptr_.get(), data_len_) == 0;
1249       } else {
1250         ret = false;
1251       }
1252     } else {
1253       ret = data_len_ == other.data_len_;
1254     }
1255     return ret;
1256   }
1257 
StoreTensor(mindspore::tensor::TensorPtr tensor_ptr)1258   void StoreTensor(mindspore::tensor::TensorPtr tensor_ptr) {
1259     MetaTensorData::StoreTensor(tensor_ptr);
1260     init_flag_ = tensor_ptr->is_init();
1261     is_forward_output_ = tensor_ptr->is_forward_output();
1262     id_ = tensor_ptr->id();
1263     graph_output_ = tensor_ptr->IsGraphOutput();
1264     base_shape_ptr_ = tensor_ptr->base_shape_ptr() == nullptr ? nullptr : tensor_ptr->base_shape_ptr()->Clone();
1265     cast_dtype_ = (tensor_ptr->cast_dtype() == nullptr) ? nullptr : tensor_ptr->cast_dtype()->Clone();
1266     compression_type_ = tensor_ptr->compression_type();
1267     const std::vector<std::shared_ptr<mindspore::QuantizationParam>> &qp = tensor_ptr->quant_params();
1268     tensor_name_ = tensor_ptr->name();
1269     for (auto quant : qp) {
1270       QuantizationParamPtr qptr = std::make_shared<mindspore::QuantizationParam>(quant->quant_algo_name());
1271       quant_params_.push_back(qptr);
1272       qptr->set_attrs(quant->attrs());
1273     }
1274     if (specialized_) {
1275       tensor_ptr->data_sync(true);
1276       auto data = tensor_ptr->data_ptr();
1277       data_len_ = size_t(data->nbytes());
1278       data_ptr_ = std::make_unique<uint8_t[]>(data_len_);
1279       if (data_ptr_ != nullptr) {
1280         memcpy_s(data_ptr_.get(), data_len_, reinterpret_cast<uint8_t *>(data->data()), data_len_);
1281       }
1282     } else {
1283       data_ptr_.reset(nullptr);
1284       data_len_ = size_t(tensor_ptr->data_ptr()->nbytes());
1285     }
1286   }
1287 
SubInfo(InfoPack * info)1288   void SubInfo(InfoPack *info) override {
1289     MetaTensorData::SubInfo(info);
1290     (*info) << is_forward_output_ << init_flag_ << graph_output_ << cast_dtype_ << base_shape_ptr_
1291             << uint8_t(compression_type_) << tensor_name_;
1292     (*info) << uint64_t(quant_params_.size());
1293     for (auto qp : quant_params_) {
1294       (*info) << qp;
1295     }
1296   }
1297 
1298   bool init_flag_;
1299   bool is_forward_output_;
1300   std::unique_ptr<uint8_t[]> data_ptr_;
1301   size_t data_len_;
1302   std::string id_;
1303   bool graph_output_;
1304   mindspore::abstract::BaseShapePtr base_shape_ptr_;
1305   mindspore::TypePtr cast_dtype_;
1306   mindspore::TensorCompressionType compression_type_;
1307   std::vector<QuantizationParamPtr> quant_params_;
1308   std::string tensor_name_;
1309 };
1310 using TensorDataPtr = std::shared_ptr<TensorData>;
1311 
Equal(const TensorDataPtr & a,const TensorDataPtr & b,int recurseDepth)1312 static bool Equal(const TensorDataPtr &a, const TensorDataPtr &b, int recurseDepth) {
1313   if (recurseDepth > 0) {
1314     if (a == nullptr && b == nullptr) {
1315       return true;
1316     } else if (a != nullptr && b != nullptr) {
1317       return *a == *b;
1318     } else {
1319       return false;
1320     }
1321   } else {
1322     return (a == nullptr) == (b == nullptr);
1323   }
1324 }
1325 
CreateTensorData(mindspore::tensor::TensorPtr tensor,bool needSpecialize,int recurseDepth)1326 static TensorDataPtr CreateTensorData(mindspore::tensor::TensorPtr tensor, bool needSpecialize, int recurseDepth) {
1327   if (recurseDepth > 0) {
1328     return (tensor == nullptr) ? nullptr : std::make_shared<TensorData>(tensor, needSpecialize, recurseDepth);
1329   } else {
1330     return nullptr;
1331   }
1332 }
1333 
1334 class MapTensorData : public TensorData {
1335  public:
MapTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1336   MapTensorData(PyObject *obj, bool needSpecialize, int recurseDepth) : TensorData(obj, needSpecialize, recurseDepth) {
1337     tp_ = ItemType::MapTensor;
1338     needSpecialize = specialized_;
1339     auto pyObj = py::cast<py::object>(obj);
1340     auto tensor_ptr = pyObj.cast<mindspore::tensor::MapTensorPtr>();
1341     key_dtype_ = tensor_ptr->key_dtype();
1342     if (tensor_ptr->key_tensor() != nullptr) {
1343       key_shape_ = tensor_ptr->key_tensor()->shape();
1344     }
1345     default_value_ = tensor_ptr->default_value() == nullptr ? nullptr : tensor_ptr->default_value()->type()->Clone();
1346     permit_filter_value_ =
1347       tensor_ptr->permit_filter_value() == nullptr ? nullptr : tensor_ptr->permit_filter_value()->type()->Clone();
1348     evict_filter_value_ =
1349       tensor_ptr->evict_filter_value() == nullptr ? nullptr : tensor_ptr->evict_filter_value()->type()->Clone();
1350     value_shape_ = tensor_ptr->value_shape();
1351     key_tensor_ = CreateTensorData(tensor_ptr->key_tensor(), needSpecialize, recurseDepth);
1352     value_tensor_ = CreateTensorData(tensor_ptr->value_tensor(), needSpecialize, recurseDepth);
1353     status_tensor_ = CreateTensorData(tensor_ptr->status_tensor(), needSpecialize, recurseDepth);
1354   }
1355 
IsPermitFilterValue(const MapTensorData & other) const1356   bool IsPermitFilterValue(const MapTensorData &other) const {
1357     return (other.default_value_ == nullptr && default_value_ == nullptr) ||
1358            (other.default_value_ != nullptr && default_value_ != nullptr && *default_value_ == *(other.default_value_));
1359   }
1360 
IsDefaultValue(const MapTensorData & other) const1361   bool IsDefaultValue(const MapTensorData &other) const {
1362     return (other.default_value_ == nullptr && default_value_ == nullptr) ||
1363            (other.default_value_ != nullptr && default_value_ != nullptr && *default_value_ == *(other.default_value_));
1364   }
1365 
IsEvictFilterValue(const MapTensorData & other) const1366   bool IsEvictFilterValue(const MapTensorData &other) const {
1367     return (other.evict_filter_value_ == nullptr && evict_filter_value_ == nullptr) ||
1368            (other.evict_filter_value_ != nullptr && evict_filter_value_ != nullptr &&
1369             *evict_filter_value_ == *(other.evict_filter_value_));
1370   }
1371 
operator ==(const ItemData & obj) const1372   bool operator==(const ItemData &obj) const override {
1373     if (!ItemData::operator==(obj)) {
1374       return false;
1375     }
1376     const MapTensorData &other = dynamic_cast<const MapTensorData &>(obj);
1377     bool ret = TensorData::operator==(obj);
1378     return ret && other.key_dtype_ == key_dtype_ && other.key_shape_ == key_shape_ && IsDefaultValue(other) &&
1379            IsPermitFilterValue(other) && IsEvictFilterValue(other) && value_shape_ == other.value_shape_ &&
1380            Equal(key_tensor_, other.key_tensor_, recurseDepth_) &&
1381            Equal(value_tensor_, other.value_tensor_, recurseDepth_) &&
1382            Equal(status_tensor_, other.status_tensor_, recurseDepth_);
1383   }
1384 
ToString()1385   std::string ToString() override {
1386     std::string map_tensor = ToStringIntern();
1387     return DESC(map_tensor) + DESC_END;
1388   }
1389 
1390  protected:
ToStringIntern()1391   std::string ToStringIntern() override {
1392     return TensorData::ToStringIntern() + DESC_STRING(key_dtype_) + DESC_TOSTRING(default_value_) +
1393            DESC_TOSTRING(permit_filter_value_) + DESC_TOSTRING(evict_filter_value_) + DESC_TOSTRING(key_tensor_) +
1394            DESC_TOSTRING(value_tensor_) + DESC_TOSTRING(status_tensor_) + DESC_END;
1395   }
1396 
SubInfo(InfoPack * info)1397   void SubInfo(InfoPack *info) override {
1398     TensorData::SubInfo(info);
1399     (*info) << key_dtype_ << default_value_ << permit_filter_value_ << evict_filter_value_ << value_shape_
1400             << key_tensor_->Info() << value_tensor_->Info() << status_tensor_->Info();
1401   }
1402 
1403   mindspore::TypeId key_dtype_;
1404   ShapeVector key_shape_;
1405   TypePtr default_value_;
1406   TypePtr permit_filter_value_;
1407   TypePtr evict_filter_value_;
1408   ShapeVector value_shape_;
1409   TensorDataPtr key_tensor_;
1410   TensorDataPtr value_tensor_;
1411   TensorDataPtr status_tensor_;
1412 };
1413 
1414 class RowTensorData : public ItemData {
1415  public:
RowTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1416   RowTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
1417       : ItemData(ItemType::RowTensor, needSpecialize, recurseDepth) {
1418     auto pyObj = py::cast<py::object>(obj);
1419     auto tensor_ptr = pyObj.cast<mindspore::tensor::RowTensorPtr>();
1420     data_type_ = tensor_ptr->data_type();
1421     shape_ = tensor_ptr->shape();
1422     indices_ = CreateTensorData(tensor_ptr->GetIndices(), needSpecialize, recurseDepth);
1423     values_ = CreateTensorData(tensor_ptr->GetValues(), needSpecialize, recurseDepth);
1424   }
1425 
operator ==(const ItemData & obj) const1426   bool operator==(const ItemData &obj) const override {
1427     if (ItemData::operator==(obj)) {
1428       const RowTensorData &other = static_cast<const RowTensorData &>(obj);
1429       return other.data_type_ == data_type_ && other.shape_ == shape_ &&
1430              Equal(indices_, other.indices_, recurseDepth_) && Equal(values_, other.values_, recurseDepth_);
1431     }
1432     return false;
1433   }
1434 
ToString()1435   std::string ToString() override {
1436     std::string row_tensor = DESC_TOSTRING(indices_) + DESC_TOSTRING(values_) + DESC_STRING(data_type_);
1437     return DESC(row_tensor) + DESC_END;
1438   }
1439 
1440  protected:
SubInfo(InfoPack * info)1441   void SubInfo(InfoPack *info) override { (*info) << indices_->Info() << values_->Info() << data_type_ << shape_; }
1442   TensorDataPtr indices_;
1443   TensorDataPtr values_;
1444   mindspore::TypeId data_type_;
1445   ShapeVector shape_;
1446 };
1447 
1448 class COOTensorData : public ItemData {
1449  public:
COOTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1450   COOTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
1451       : ItemData(ItemType::COOTensor, needSpecialize, recurseDepth) {
1452     auto pyObj = py::cast<py::object>(obj);
1453     auto tensor_ptr = pyObj.cast<mindspore::tensor::COOTensorPtr>();
1454     data_type_ = tensor_ptr->data_type();
1455     shape_ = tensor_ptr->shape();
1456     indices_ = CreateTensorData(tensor_ptr->GetIndices(), needSpecialize, recurseDepth);
1457     values_ = CreateTensorData(tensor_ptr->GetValues(), needSpecialize, recurseDepth);
1458   }
1459 
operator ==(const ItemData & obj) const1460   bool operator==(const ItemData &obj) const override {
1461     if (ItemData::operator==(obj)) {
1462       const COOTensorData &other = static_cast<const COOTensorData &>(obj);
1463       return other.data_type_ == data_type_ && other.shape_ == shape_ &&
1464              Equal(indices_, other.indices_, recurseDepth_) && Equal(values_, other.values_, recurseDepth_);
1465     }
1466     return false;
1467   }
1468 
ToString()1469   std::string ToString() override {
1470     std::string coo_tensor = DESC_TOSTRING(indices_) + DESC_TOSTRING(values_) + DESC_STRING(data_type_);
1471     return DESC(coo_tensor) + DESC_END;
1472   }
1473 
1474  protected:
SubInfo(InfoPack * info)1475   void SubInfo(InfoPack *info) override { (*info) << indices_->Info() << values_->Info() << data_type_ << shape_; }
1476   TensorDataPtr indices_;
1477   TensorDataPtr values_;
1478   mindspore::TypeId data_type_;
1479   ShapeVector shape_;
1480 };
1481 
1482 class CSRTensorData : public ItemData {
1483  public:
CSRTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1484   CSRTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
1485       : ItemData(ItemType::CSRTensor, needSpecialize, recurseDepth) {
1486     auto pyObj = py::cast<py::object>(obj);
1487     auto tensor_ptr = pyObj.cast<mindspore::tensor::CSRTensorPtr>();
1488     data_type_ = tensor_ptr->data_type();
1489     shape_ = tensor_ptr->shape();
1490     indices_ = CreateTensorData(tensor_ptr->GetIndices(), needSpecialize, recurseDepth);
1491     values_ = CreateTensorData(tensor_ptr->GetValues(), needSpecialize, recurseDepth);
1492     indptr_ = CreateTensorData(tensor_ptr->GetIndptr(), needSpecialize, recurseDepth);
1493   }
1494 
operator ==(const ItemData & obj) const1495   bool operator==(const ItemData &obj) const override {
1496     if (ItemData::operator==(obj)) {
1497       const CSRTensorData &other = static_cast<const CSRTensorData &>(obj);
1498       return other.data_type_ == data_type_ && other.shape_ == shape_ &&
1499              Equal(indices_, other.indices_, recurseDepth_) && Equal(values_, other.values_, recurseDepth_) &&
1500              Equal(indptr_, other.indptr_, recurseDepth_);
1501     }
1502     return false;
1503   }
1504 
ToString()1505   std::string ToString() override {
1506     std::string csr_tensor =
1507       DESC_TOSTRING(indices_) + DESC_TOSTRING(values_) + DESC_TOSTRING(indptr_) + DESC_STRING(data_type_);
1508     return DESC(csr_tensor) + DESC_END;
1509   }
1510 
1511  protected:
SubInfo(InfoPack * info)1512   void SubInfo(InfoPack *info) override {
1513     (*info) << indices_->Info() << values_->Info() << indptr_->Info() << data_type_ << shape_;
1514   }
1515   TensorDataPtr indices_;
1516   TensorDataPtr values_;
1517   TensorDataPtr indptr_;
1518   mindspore::TypeId data_type_;
1519   ShapeVector shape_;
1520 };
1521 
1522 class TensorDataData : public ItemData {
1523  public:
TensorDataData(PyObject * obj,bool needSpecialize,int recurseDepth)1524   TensorDataData(PyObject *obj, bool needSpecialize, int recurseDepth)
1525       : ItemData(ItemType::Tensordata, needSpecialize, recurseDepth) {
1526     auto pyObj = py::cast<py::object>(obj);
1527     auto data = pyObj.cast<mindspore::tensor::TensorDataPtr>();
1528     size_ = (uint64_t)data->size();
1529     itemsize_ = (uint64_t)data->itemsize();
1530     nbytes_ = (uint64_t)data->nbytes();
1531     ndim_ = (int64_t)data->ndim();
1532     if (specialized_) {
1533       data_ptr_ = std::make_unique<uint8_t[]>(nbytes_);
1534       if (data_ptr_ != nullptr) {
1535         memcpy_s(data_ptr_.get(), nbytes_, reinterpret_cast<uint8_t *>(data->data()), nbytes_);
1536       }
1537     } else {
1538       data_ptr_.reset(nullptr);
1539     }
1540   }
1541 
~TensorDataData()1542   ~TensorDataData() override { data_ptr_.release(); }
1543 
operator ==(const ItemData & obj) const1544   bool operator==(const ItemData &obj) const override {
1545     if (ItemData::operator==(obj)) {
1546       const TensorDataData &other = static_cast<const TensorDataData &>(obj);
1547       if (specialized_) {
1548         return data_ptr_ != nullptr && other.data_ptr_ != nullptr && nbytes_ == other.nbytes_ &&
1549                memcmp(data_ptr_.get(), other.data_ptr_.get(), nbytes_) == 0;
1550       } else {
1551         return size_ == other.size_ && itemsize_ == other.itemsize_ && nbytes_ == other.nbytes_ &&
1552                ndim_ == other.ndim_ && (data_ptr_ == nullptr) == (other.data_ptr_ == nullptr);
1553       }
1554     }
1555     return false;
1556   }
1557 
ToString()1558   std::string ToString() override {
1559     std::string tensor_data = DESC_STRING(size_) + DESC_STRING(itemsize_) + DESC_STRING(nbytes_) + DESC_STRING(ndim_);
1560     return DESC(tensor_data) + DESC_END;
1561   }
1562 
1563  protected:
SubInfo(InfoPack * info)1564   void SubInfo(InfoPack *info) override { (*info) << size_ << itemsize_ << nbytes_ << ndim_; }
1565   std::unique_ptr<uint8_t[]> data_ptr_;
1566   uint64_t size_;
1567   uint64_t itemsize_;
1568   uint64_t nbytes_;
1569   int64_t ndim_;
1570 };
1571 
1572 class PrimitiveData : public ItemData {
1573  public:
PrimitiveData(PyObject * obj,bool needSpecialize,int recurseDepth)1574   PrimitiveData(PyObject *obj, bool needSpecialize, int recurseDepth)
1575       : ItemData(ItemType::Primitive, needSpecialize, recurseDepth) {
1576     auto pyObj = py::cast<py::object>(obj);
1577     auto data = pyObj.cast<PrimitivePyAdapterPtr>();
1578     py::dict pd = data->GetAttrDict();
1579     auto dct = pd.ptr();
1580     Py_ssize_t pos = 0;
1581     PyObject *key;
1582     PyObject *val;
1583     while (PyDict_Next(dct, &pos, &key, &val)) {
1584       ItemDataPtr k;
1585       ItemDataPtr v;
1586       if (recurseDepth > 0 || needSpecialize) {
1587         k = CreateItem(key, needSpecialize, recurseDepth);
1588         v = CreateItem(val, needSpecialize, recurseDepth);
1589       } else {
1590         k =
1591           CreateItem((key == NULL || key == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(key)), false, false);
1592         v =
1593           CreateItem((val == NULL || val == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(val)), false, false);
1594       }
1595       listK_.push_back(k);
1596       listV_.push_back(v);
1597     }
1598   }
1599 
operator ==(const ItemData & obj) const1600   bool operator==(const ItemData &obj) const override {
1601     if (ItemData::operator==(obj)) {
1602       const PrimitiveData &other = static_cast<const PrimitiveData &>(obj);
1603       if (other.listK_.size() == listK_.size() && other.listV_.size() == listV_.size()) {
1604         for (size_t i = 0; i < listK_.size(); ++i) {
1605           if (*(listK_[i]) == *(other.listK_[i]) && *(listV_[i]) == *(other.listV_[i])) {
1606             continue;
1607           } else {
1608             return false;
1609           }
1610         }
1611         return true;
1612       }
1613     }
1614     return false;
1615   }
1616 
ToString()1617   std::string ToString() override {
1618     std::string primitive;
1619     for (size_t i = 0; i < listK_.size(); ++i) {
1620       primitive += DESC_ITEM(listK_[i], listV_[i]);
1621     }
1622     return DESC(primitive) + DESC_END;
1623   }
1624 
1625  protected:
SubInfo(InfoPack * info)1626   void SubInfo(InfoPack *info) override {
1627     (*info) << uint64_t(listK_.size());
1628     for (auto item : listK_) {
1629       (*info) << item->Info();
1630     }
1631     (*info) << uint64_t(listV_.size());
1632     for (auto item : listV_) {
1633       (*info) << item->Info();
1634     }
1635   }
1636   std::vector<ItemDataPtr> listK_;
1637   std::vector<ItemDataPtr> listV_;
1638 };
1639 
1640 class CellData : public ItemData {
1641  public:
CellData(PyObject * obj,bool needSpecialize,int recurseDepth)1642   CellData(PyObject *obj, bool needSpecialize, int recurseDepth)
1643       : ItemData(ItemType::Cell, needSpecialize, recurseDepth) {
1644     auto pyObj = py::cast<py::object>(obj);
1645     auto cell = pyObj.cast<mindspore::CellPtr>();
1646     PyObject *ns = PyObject_GetAttrString(obj, "__dict__");
1647     if (!ns) {
1648       return;
1649     }
1650     PyObject *items = PyMapping_Items(ns);
1651     if (!items) {
1652       return;
1653     }
1654     for (Py_ssize_t pos = 0; pos < PyList_GET_SIZE(items); pos++) {
1655       PyObject *it = PySequence_Fast(PyList_GET_ITEM(items, pos), "items() returned non-iterable");
1656       if (!it || PySequence_Fast_GET_SIZE(it) != 2) {
1657         if (it) {
1658           Py_DECREF(it);
1659         }
1660         continue;
1661       }
1662       PyObject *key = PySequence_Fast_GET_ITEM(it, 0);
1663       PyObject *val = PySequence_Fast_GET_ITEM(it, 1);
1664       ItemDataPtr k;
1665       ItemDataPtr v;
1666       if (recurseDepth > 0 || needSpecialize) {
1667         k = CreateItem(key, needSpecialize, recurseDepth);
1668         v = CreateItem(val, needSpecialize, recurseDepth);
1669       } else {
1670         k =
1671           CreateItem((key == NULL || key == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(key)), false, false);
1672         v =
1673           CreateItem((val == NULL || val == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(val)), false, false);
1674       }
1675       listK_.push_back(k);
1676       listV_.push_back(v);
1677       Py_DECREF(it);
1678     }
1679     Py_DECREF(items);
1680     Py_DECREF(ns);
1681   }
1682 
operator ==(const ItemData & obj) const1683   bool operator==(const ItemData &obj) const override {
1684     if (ItemData::operator==(obj)) {
1685       const CellData &other = static_cast<const CellData &>(obj);
1686       for (size_t i = 0; i < listK_.size(); ++i) {
1687         if (*(listK_[i]) == *(other.listK_[i]) && *(listV_[i]) == *(other.listV_[i])) {
1688           continue;
1689         } else {
1690           return false;
1691         }
1692       }
1693       return true;
1694     }
1695     return false;
1696   }
1697 
ToString()1698   std::string ToString() override {
1699     std::string cell;
1700     for (size_t i = 0; i < listK_.size(); ++i) {
1701       cell += DESC_ITEM(listK_[i], listV_[i]);
1702     }
1703     return DESC(cell) + DESC_END;
1704   }
1705 
1706  protected:
SubInfo(InfoPack * info)1707   void SubInfo(InfoPack *info) override {
1708     (*info) << uint64_t(listK_.size());
1709     for (auto item : listK_) {
1710       (*info) << item->Info();
1711     }
1712     (*info) << uint64_t(listV_.size());
1713     for (auto item : listV_) {
1714       (*info) << item->Info();
1715     }
1716   }
1717   std::vector<ItemDataPtr> listK_;
1718   std::vector<ItemDataPtr> listV_;
1719 };
1720 
1721 class UnknownData : public ItemData {
1722  public:
UnknownData(PyObject * obj,bool needSpecialize,int recurseDepth)1723   UnknownData(PyObject *obj, bool needSpecialize, int recurseDepth)
1724       : ItemData(ItemType::PyUnknown, needSpecialize, recurseDepth) {
1725     refId_ = obj;
1726   }
1727 
operator ==(const ItemData & obj) const1728   bool operator==(const ItemData &obj) const override {
1729     if (ItemData::operator==(obj)) {
1730       return refId_ == (static_cast<const UnknownData &>(obj)).refId_;
1731     }
1732     return false;
1733   }
1734 
ToString()1735   std::string ToString() override {
1736     std::string ret = "unknown";
1737     return ret + ItemData::ToString();
1738   }
1739 
1740  protected:
1741   PyObject *refId_;
1742 };
1743 
ListData(PyObject * obj,bool needSpecialize,int recurseDepth)1744 ListData::ListData(PyObject *obj, bool needSpecialize, int recurseDepth)
1745     : ItemData(ItemType::PyList, needSpecialize, recurseDepth) {
1746   if (PyList_Check(obj)) {
1747     InitList(obj, needSpecialize, recurseDepth);
1748   } else if (PyTuple_Check(obj)) {
1749     InitTuple(obj, needSpecialize, recurseDepth);
1750   } else if (PySet_Check(obj)) {
1751     InitSet(obj, needSpecialize, recurseDepth);
1752   } else if (PyFrozenSet_Check(obj)) {
1753     InitFrozenSet(obj, needSpecialize, recurseDepth);
1754   }
1755   if (needSpecialize) {
1756     return;  // value check exactly
1757   }
1758   for (const auto &data : listVar_) {
1759     if (data->GetItemType() == ItemType::PyType) {
1760       static_cast<TypeData *>(data.get())->set_ambiguous_tensor_type(true);
1761     }
1762   }
1763 }
1764 
1765 using CheckPyObjectFunc = bool (*)(PyObject *obj);
1766 using CreatePyObjectFunc = ItemDataPtr (*)(PyObject *obj, bool need_specialize, int recurse_depth);
1767 template <typename T>
CreatePyData(PyObject * obj,bool need_specialize,int recurse_depth)1768 ItemDataPtr CreatePyData(PyObject *obj, bool need_specialize, int recurse_depth) {
1769   return std::make_shared<T>(obj, need_specialize, recurse_depth);
1770 }
1771 template <typename T>
CreateMutablePyData(PyObject * obj,bool need_specialize,int recurse_depth)1772 ItemDataPtr CreateMutablePyData(PyObject *obj, bool need_specialize, int recurse_depth) {
1773   return std::make_shared<T>(obj, false, recurse_depth);
1774 }
CheckMetaTensorObject(PyObject * obj)1775 static bool CheckMetaTensorObject(PyObject *obj) {
1776   return py::isinstance<mindspore::tensor::MetaTensor>(obj) || IsStubTensor(py::cast<py::object>(obj));
1777 }
CheckTensorObject(PyObject * obj)1778 static bool CheckTensorObject(PyObject *obj) { return py::isinstance<mindspore::tensor::Tensor>(obj); }
CheckDictKeyValueItemObject(PyObject * obj)1779 static bool CheckDictKeyValueItemObject(PyObject *obj) {
1780   return !!PyDict_Check(obj) || !!PyDictKeys_Check(obj) || !!PyDictValues_Check(obj) || !!PyDictItems_Check(obj);
1781 }
1782 static const std::vector<std::pair<CheckPyObjectFunc, CreatePyObjectFunc>> kFuncPyObjectConverter = {
__anonb719df630302() 1783   {[](PyObject *obj) -> bool { return PyLong_Check(obj) && !PyBool_Check(obj); }, CreatePyData<IntData>},
__anonb719df630402() 1784   {[](PyObject *obj) -> bool { return !!PyFloat_Check(obj); }, CreatePyData<FloatData>},
__anonb719df630502() 1785   {[](PyObject *obj) -> bool { return !!PyBool_Check(obj); }, CreatePyData<BoolData>},
__anonb719df630602() 1786   {[](PyObject *obj) -> bool { return !!PyBytes_Check(obj); }, CreatePyData<BytesData>},
__anonb719df630702() 1787   {[](PyObject *obj) -> bool { return !!PyUnicode_Check(obj); }, CreatePyData<StringData>},
__anonb719df630802() 1788   {[](PyObject *obj) -> bool { return !!PyList_Check(obj); }, CreatePyData<ListData>},
__anonb719df630902() 1789   {[](PyObject *obj) -> bool { return !!PyTuple_Check(obj); }, CreatePyData<ListData>},
__anonb719df630a02() 1790   {[](PyObject *obj) -> bool { return !!PySet_Check(obj); }, CreatePyData<ListData>},
__anonb719df630b02() 1791   {[](PyObject *obj) -> bool { return !!PyFrozenSet_Check(obj); }, CreatePyData<ListData>},
__anonb719df630c02() 1792   {[](PyObject *obj) -> bool { return CheckDictKeyValueItemObject(obj); }, CreatePyData<DictData>},
__anonb719df630d02() 1793   {[](PyObject *obj) -> bool { return !!PyComplex_Check(obj); }, CreatePyData<ComplexData>},
__anonb719df630e02() 1794   {[](PyObject *obj) -> bool { return !!PySlice_Check(obj); }, CreatePyData<SliceData>},
__anonb719df630f02() 1795   {[](PyObject *obj) -> bool { return !!PyFunction_Check(obj); }, CreatePyData<FunctionData>},
__anonb719df631002() 1796   {[](PyObject *obj) -> bool { return !!PyMethod_Check(obj); }, CreatePyData<MethodData>},
__anonb719df631102() 1797   {[](PyObject *obj) -> bool { return !!PyInstanceMethod_Check(obj); }, CreatePyData<InstanceMethodData>},
__anonb719df631202() 1798   {[](PyObject *obj) -> bool { return !!PyType_Check(obj); }, CreatePyData<TypeData>},
__anonb719df631302() 1799   {[](PyObject *obj) -> bool { return py::isinstance<py::array>(obj); }, CreatePyData<NumpyData>},
__anonb719df631402() 1800   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::Type>(obj); }, CreatePyData<TensorTypeData>},
__anonb719df631502() 1801   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::MapTensor>(obj); },
1802    CreatePyData<MapTensorData>},
__anonb719df631602() 1803   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::ParamInfo>(obj); }, CreatePyData<ParamInfoData>},
__anonb719df631702() 1804   {[](PyObject *obj) -> bool { return CheckMetaTensorObject(obj); }, CreatePyData<MetaTensorData>},
__anonb719df631802() 1805   {[](PyObject *obj) -> bool { return CheckTensorObject(obj); }, CreatePyData<TensorData>},
__anonb719df631902() 1806   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::TensorData>(obj); },
1807    CreatePyData<TensorDataData>},
__anonb719df631a02() 1808   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::PrimitivePyAdapter>(obj); },
1809    CreatePyData<PrimitiveData>},
__anonb719df631b02() 1810   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::Cell>(obj); }, CreatePyData<CellData>},
__anonb719df631c02() 1811   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::RowTensor>(obj); },
1812    CreateMutablePyData<RowTensorData>},
__anonb719df631d02() 1813   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::COOTensor>(obj); },
1814    CreateMutablePyData<COOTensorData>},
__anonb719df631e02() 1815   {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::CSRTensor>(obj); },
1816    CreateMutablePyData<CSRTensorData>},
1817 };
1818 
CreateData(PyObject * obj,bool need_specialize,int recurse_depth)1819 static ItemDataPtr CreateData(PyObject *obj, bool need_specialize, int recurse_depth) {
1820   auto tar =
1821     std::find_if(kFuncPyObjectConverter.begin(), kFuncPyObjectConverter.end(),
1822                  [obj](const std::pair<CheckPyObjectFunc, CreatePyObjectFunc> &func) { return func.first(obj); });
1823   if (tar != kFuncPyObjectConverter.end()) {
1824     return tar->second(obj, need_specialize, recurse_depth);
1825   } else {
1826     return std::make_shared<UnknownData>(obj, need_specialize, recurse_depth);
1827   }
1828 }
1829 
CreateItem(PyObject * obj,bool need_specialize,int recurse_depth)1830 static ItemDataPtr CreateItem(PyObject *obj, bool need_specialize, int recurse_depth) {
1831   ReprRecursionScope scope(obj);
1832   if (scope.ReEnterOrError()) {
1833     return std::make_shared<ItemData>(ItemType::PyNull, need_specialize, recurse_depth);
1834   }
1835   if (recurse_depth < -1) {
1836     if (obj != NULL && obj != Py_None) {
1837       PyObject *py_type;
1838       py::object py_obj = py::reinterpret_borrow<py::object>(obj);
1839       if (IsStubTensor(py_obj)) {
1840         py_type = GetMsTensorType();
1841       } else {
1842         py_type = reinterpret_cast<PyObject *>(Py_TYPE(obj));
1843       }
1844       return std::make_shared<TypeData>(py_type, false, 0);
1845     } else {
1846       return std::make_shared<ItemData>(ItemType::PyNull, false, 0);
1847     }
1848   }
1849   recurse_depth -= 1;
1850   ItemDataPtr dp;
1851   if (obj != NULL && obj != Py_None) {
1852     dp = CreateData(obj, need_specialize, recurse_depth);
1853   } else {
1854     dp = std::make_shared<ItemData>(ItemType::PyNull, need_specialize, recurse_depth);
1855   }
1856   return dp;
1857 }
1858 
GuardItem(TracePtr tt)1859 GuardItem::GuardItem(TracePtr tt) : var_(tt), type_(GIType::GTUnknown), info_(nullptr) {}
1860 
Replace(TracePtr dst,TracePtr src)1861 void GuardItem::Replace(TracePtr dst, TracePtr src) {
1862   if (!var_) {
1863     return;
1864   }
1865   if (*var_ == *src) {
1866     var_ = dst;
1867   } else {
1868     var_->Replace(dst, src);
1869   }
1870 }
1871 
Optimize()1872 GuardItemPtr GuardItem::Optimize() {
1873   auto trace = var_->Optimize();
1874   if (trace != nullptr) {
1875     var_ = trace;
1876     info_ = nullptr;
1877     Info();
1878     return shared_from_this();
1879   } else {
1880     return nullptr;
1881   }
1882 }
1883 
GetTrace()1884 TracePtr GuardItem::GetTrace() { return var_; }
1885 
operator ==(const GuardItem & obj) const1886 bool GuardItem::operator==(const GuardItem &obj) const { return type_ == obj.type_ && *var_ == *(obj.var_); }
1887 
1888 static constexpr int kGuardItemTotalStage = 2;
1889 static constexpr int kGuardItemRetrieveStage = 0;
1890 static constexpr int kGuardItemCompareStage = 1;
1891 
GuardItemPerfStart(bool enable,int total)1892 static void GuardItemPerfStart(bool enable, int total) {
1893   if (enable) {
1894     OptGuardPerf::GetGuardPerf()->LogItemPerfStart(total);
1895   }
1896 }
1897 
GuardItemPerfStage(bool enable,GuardItem * item,int stage)1898 static void GuardItemPerfStage(bool enable, GuardItem *item, int stage) {
1899   if (enable) {
1900     OptGuardPerf::GetGuardPerf()->LogItemPerfEnd(item, stage);
1901   }
1902 }
1903 
1904 class EqGuard : public GuardItem {
1905  public:
EqGuard(TracePtr obj,bool needSpecialize,int recurseDepth)1906   EqGuard(TracePtr obj, bool needSpecialize, int recurseDepth)
1907       : GuardItem(obj),
1908         dp_(CreateItem(obj->GetObject(), needSpecialize, recurseDepth)),
1909         specialized_(needSpecialize),
1910         recurse_(recurseDepth) {
1911     type_ = GIType::GTEqual;
1912     last_ = obj->GetObject();
1913   }
1914 
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)1915   virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
1916     if (var_->IsConst()) {
1917       return true;
1918     }
1919     if (!var_->IsSpecialized() && var_->GetRelaxCount() > 0 && !RootTrace::Support(var_->GetTraceType()) &&
1920         !specialized_) {
1921       // it just needs to guard inputs instead of leaf node
1922       return true;
1923     }
1924     GuardItemPerfStart(perf, kGuardItemTotalStage);
1925     PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
1926     GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
1927     bool ret = obj == last_ || Check(obj);
1928     GuardItemPerfStage(perf, this, kGuardItemCompareStage);
1929     if (obj != NULL) {
1930       Py_DECREF(obj);
1931     }
1932     return ret;
1933   }
1934 
Check(PyObject * obj)1935   virtual bool Check(PyObject *obj) {
1936     ItemDataPtr other = CreateItem(obj, specialized_, recurse_);
1937     return *dp_ == *other;
1938   }
1939 
ToString()1940   virtual std::string ToString() {
1941     if (strGuard_.size() > 0) {
1942       return strGuard_;
1943     }
1944     strGuard_ = var_->ToString() + "==" + dp_->ToString();
1945     strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
1946     return strGuard_;
1947   }
1948 
Info()1949   virtual const InfoPack &Info() {
1950     if (info_ == nullptr) {
1951       InfoPack info;
1952       info << uint8_t(type_);
1953       info.Begin();
1954       info << var_->Info() << dp_->Info();
1955       info.End();
1956       info_ = std::make_shared<InfoPack>(info);
1957       info_->Update();
1958     }
1959     return *info_;
1960   }
1961 
operator ==(const GuardItem & obj) const1962   bool operator==(const GuardItem &obj) const override {
1963     if (GuardItem::operator==(obj)) {
1964       auto other = static_cast<const EqGuard &>(obj);
1965       return specialized_ == other.specialized_ && recurse_ == other.recurse_ && *dp_ == *(other.dp_);
1966     }
1967     return false;
1968   }
1969 
MatchDynamicShape(std::shared_ptr<GuardItem> other)1970   bool MatchDynamicShape(std::shared_ptr<GuardItem> other) override {
1971     var_->Detach();
1972     other->GetTrace()->Detach();
1973     if (other->GetType() != GIType::GTEqual || !(*var_ == *(other->GetTrace())) ||
1974         !dp_->MatchDynamicShape((static_cast<EqGuard *>(other.get()))->dp_)) {
1975       return false;
1976     } else {
1977       return true;
1978     }
1979   }
1980 
ApplyDynamicShape(PyObject * obj)1981   PyObject *ApplyDynamicShape(PyObject *obj) override {
1982     auto type = dp_->GetItemType();
1983     if (type != ItemType::MetaTensor && type != ItemType::Tensor) {
1984       return nullptr;
1985     }
1986     auto item = (MetaTensorData &)(*dp_);
1987     if (item.IsDynamicShape()) {
1988       return py::cast(item.MakeTensor()).inc_ref().ptr();
1989     } else {
1990       return nullptr;
1991     }
1992   }
1993 
1994  protected:
1995   ItemDataPtr dp_;
1996   bool specialized_;
1997   int recurse_;
1998   PyObject *last_;
1999 };
2000 
2001 class TypeGuard : public GuardItem {
2002  public:
TypeGuard(TracePtr obj)2003   explicit TypeGuard(TracePtr obj) : GuardItem(obj) {
2004     type_ = GIType::GTType;
2005     is_const_ = false;
2006     if (obj->GetTraceType() == TraceType::Type) {
2007       refType_ = std::dynamic_pointer_cast<TypeTrace>(obj)->GetType();
2008     } else {
2009       refType_ = Py_TYPE(obj->GetObject());
2010     }
2011     if (obj->GetRelaxCount() > 0) {
2012       check_count_ = 0;
2013     } else {
2014       check_count_ = -1;
2015     }
2016   }
2017 
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2018   virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2019     if (var_->IsConst() || is_const_) {
2020       return true;
2021     }
2022     GuardItemPerfStart(perf, kGuardItemTotalStage);
2023     PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2024     GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2025     bool ret = Check(obj);
2026     GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2027     if (var_->GetTraceType() != TraceType::Type && obj != NULL) {
2028       Py_DECREF(obj);
2029     }
2030     if (check_count_ >= 0) {
2031       if (!ret) {
2032         check_count_ = -1;
2033       } else {
2034         check_count_++;
2035         if (check_count_ > var_->GetRelaxCount()) {
2036           is_const_ = true;
2037         }
2038       }
2039     }
2040     return ret;
2041   }
2042 
Check(PyObject * obj)2043   virtual bool Check(PyObject *obj) {
2044     if (obj == NULL) {
2045       return false;
2046     }
2047     PyTypeObject *tp;
2048     if (var_->GetTraceType() == TraceType::Type) {
2049       tp = reinterpret_cast<PyTypeObject *>(obj);
2050     } else {
2051       tp = Py_TYPE(obj);
2052     }
2053     if (tp != refType_) {
2054       return false;
2055     } else {
2056       return true;
2057     }
2058   }
2059 
ToString()2060   std::string ToString() override {
2061     if (strGuard_.size() > 0) {
2062       return strGuard_;
2063     }
2064     if (var_->GetTraceType() == TraceType::Type) {
2065       strGuard_ = var_->ToString() + std::string("==") + refType_->tp_name;
2066     } else {
2067       strGuard_ = std::string("type(") + var_->ToString() + std::string(")==") + refType_->tp_name;
2068     }
2069     strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2070     return strGuard_;
2071   }
2072 
Info()2073   virtual const InfoPack &Info() {
2074     if (info_ == nullptr) {
2075       InfoPack info;
2076       info << uint8_t(type_);
2077       info.Begin();
2078       info << var_->Info() << refType_->tp_name;
2079       info.End();
2080       info_ = std::make_shared<InfoPack>(info);
2081       info_->Update();
2082     }
2083     return *info_;
2084   }
2085 
operator ==(const GuardItem & obj) const2086   bool operator==(const GuardItem &obj) const override {
2087     if (GuardItem::operator==(obj)) {
2088       return refType_ == (static_cast<const TypeGuard &>(obj)).refType_;
2089     }
2090     return false;
2091   }
2092 
2093  protected:
2094   PyTypeObject *refType_;
2095   int check_count_;
2096   bool is_const_;
2097 };
2098 
2099 class IdGuard : public GuardItem {
2100  public:
IdGuard(TracePtr obj)2101   explicit IdGuard(TracePtr obj) : GuardItem(obj) {
2102     type_ = GIType::GTId;
2103     refId_ = obj->GetObject();
2104   }
2105 
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2106   virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2107     if (var_->IsConst()) {
2108       return true;
2109     }
2110     GuardItemPerfStart(perf, kGuardItemTotalStage);
2111     PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2112     GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2113     bool ret = Check(obj);
2114     GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2115     if (obj != NULL) {
2116       Py_DECREF(obj);
2117     }
2118     return ret;
2119   }
2120 
Check(PyObject * obj)2121   virtual bool Check(PyObject *obj) {
2122     bool ret = false;
2123     if (obj == NULL) {
2124       return ret;
2125     }
2126     if (obj != refId_) {
2127       ret = false;
2128     } else {
2129       ret = true;
2130     }
2131     return ret;
2132   }
2133 
ToString()2134   std::string ToString() override {
2135     if (strGuard_.size() > 0) {
2136       return strGuard_;
2137     }
2138     strGuard_ = std::string("id(") + var_->ToString() + std::string(")==") + std::to_string((size_t)refId_);
2139     strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2140     return strGuard_;
2141   }
2142 
Info()2143   virtual const InfoPack &Info() {
2144     if (info_ == nullptr) {
2145       InfoPack info;
2146       info << uint8_t(type_);
2147       info.Begin();
2148       info << var_->Info() << reinterpret_cast<void *>(refId_);
2149       info.End();
2150       info_ = std::make_shared<InfoPack>(info);
2151       info_->Update();
2152     }
2153     return *info_;
2154   }
2155 
operator ==(const GuardItem & obj) const2156   bool operator==(const GuardItem &obj) const override {
2157     if (GuardItem::operator==(obj)) {
2158       return refId_ == (static_cast<const IdGuard &>(obj)).refId_;
2159     }
2160     return false;
2161   }
2162 
2163  protected:
2164   PyObject *refId_;
2165 };
2166 
2167 class ReprGuard : public GuardItem {
2168  public:
ReprGuard(TracePtr obj)2169   explicit ReprGuard(TracePtr obj) : GuardItem(obj) {
2170     type_ = GIType::GTRepr;
2171     refRepr_ = PyObject_Repr(obj->GetObject());
2172   }
2173 
~ReprGuard()2174   virtual ~ReprGuard() { Py_XDECREF(refRepr_); }
2175 
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2176   virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2177     if (var_->IsConst()) {
2178       return true;
2179     }
2180     GuardItemPerfStart(perf, kGuardItemTotalStage);
2181     PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2182     GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2183     bool ret = Check(obj);
2184     GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2185     if (obj != nullptr) {
2186       Py_DECREF(obj);
2187     }
2188     return ret;
2189   }
2190 
Check(PyObject * obj)2191   virtual bool Check(PyObject *obj) {
2192     bool ret = false;
2193     if (obj == nullptr) {
2194       return ret;
2195     }
2196     PyObject *repr_flag = GetAttrReprCacheStr();
2197     PyObject *repr;
2198     bool from_cache = false;
2199     if (PyObject_HasAttr(obj, repr_flag)) {
2200       from_cache = true;
2201       repr = PyObject_GetAttr(obj, repr_flag);
2202     } else {
2203       repr = PyObject_Repr(obj);
2204       PyObject_SetAttr(obj, repr_flag, repr);
2205     }
2206     if (PyUnicode_Compare(repr, refRepr_)) {
2207       // Inplace operation may change the object without clearing the cache of guard.
2208       if (from_cache && !PyUnicode_Compare(PyObject_Repr(obj), refRepr_)) {
2209         ret = true;
2210       } else {
2211         ret = false;
2212       }
2213     } else {
2214       ret = true;
2215     }
2216     Py_XDECREF(repr);
2217     return ret;
2218   }
2219 
ToString()2220   std::string ToString() override {
2221     if (strGuard_.size() > 0) {
2222       return strGuard_;
2223     }
2224     strGuard_ = std::string(PyUnicode_AsUTF8(refRepr_));
2225     strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2226     return strGuard_;
2227   }
2228 
operator ==(const GuardItem & obj) const2229   bool operator==(const GuardItem &obj) const override {
2230     if (GuardItem::operator==(obj)) {
2231       return refRepr_ == (static_cast<const ReprGuard &>(obj)).refRepr_;
2232     }
2233     return false;
2234   }
2235 
2236  protected:
Info()2237   virtual const InfoPack &Info() {
2238     if (info_ == nullptr) {
2239       InfoPack info;
2240       info << uint8_t(type_);
2241       info.Begin();
2242       info << std::string(PyUnicode_AsUTF8(refRepr_));
2243       info.End();
2244       info_ = std::make_shared<InfoPack>(info);
2245       info_->Update();
2246     }
2247     return *info_;
2248   }
2249   PyObject *refRepr_;
2250 };
2251 
2252 class AttrGuard : public GuardItem {
2253  public:
AttrGuard(TracePtr pObj)2254   explicit AttrGuard(TracePtr pObj) : GuardItem(pObj) {
2255     type_ = GIType::GTAttr;
2256     AttrTracePtr t = std::dynamic_pointer_cast<AttrTrace>(pObj);
2257     PyObject *obj = t->GetOrigin()->GetObject();
2258     nameAttr_ = t->GetAttribute();
2259     if (PyObject_HasAttrString(obj, nameAttr_.c_str()) != 0) {
2260       hasAttr_ = true;
2261     } else {
2262       hasAttr_ = false;
2263       bool is_dict = PyDict_CheckExact(obj);
2264       PyObject *itemName = PyUnicode_FromString(nameAttr_.c_str());
2265       PyObject *attr = NULL;
2266       if (is_dict) {
2267         attr = PyDict_GetItem(obj, itemName);
2268         if (attr != NULL) {
2269           Py_INCREF(attr);
2270         }
2271       } else if (PyMapping_Check(obj) || PySequence_Check(obj)) {
2272         attr = PyObject_GetItem(obj, itemName);
2273       }
2274       hasAttr_ = attr != NULL;
2275       Py_DECREF(itemName);
2276       if (attr != NULL) {
2277         Py_DECREF(attr);
2278       }
2279     }
2280   }
2281 
2282   ~AttrGuard() = default;
2283 
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2284   virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2285     if (var_->IsConst()) {
2286       return true;
2287     }
2288     GuardItemPerfStart(perf, kGuardItemTotalStage);
2289     PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2290     GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2291     bool ret = CheckIntern(obj);
2292     GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2293     if (obj != NULL) {
2294       Py_DECREF(obj);
2295     }
2296     return ret;
2297   }
2298 
Check(PyObject * obj)2299   virtual bool Check(PyObject *obj) {
2300     bool ret;
2301     if (PyObject_HasAttrString(obj, nameAttr_.c_str()) != 0) {
2302       ret = hasAttr_;
2303     } else {
2304       bool is_dict = PyDict_CheckExact(obj);
2305       PyObject *itemName = PyUnicode_FromString(nameAttr_.c_str());
2306       PyObject *attr = NULL;
2307       if (is_dict) {
2308         attr = PyDict_GetItem(obj, itemName);
2309         if (attr != NULL) {
2310           Py_INCREF(attr);
2311         }
2312       } else if (PyMapping_Check(obj) || PySequence_Check(obj)) {
2313         attr = PyObject_GetItem(obj, itemName);
2314       }
2315       ret = CheckIntern(attr);
2316       Py_DECREF(itemName);
2317       if (attr != NULL) {
2318         Py_DECREF(attr);
2319       }
2320     }
2321     return ret;
2322   }
2323 
CheckIntern(PyObject * obj)2324   virtual bool CheckIntern(PyObject *obj) {
2325     bool ret;
2326     if ((obj == NULL && !hasAttr_) || (obj != NULL && hasAttr_)) {
2327       ret = true;
2328     } else {
2329       ret = false;
2330     }
2331     return ret;
2332   }
2333 
ToString()2334   virtual std::string ToString() {
2335     if (strGuard_.size() > 0) {
2336       return strGuard_;
2337     }
2338     strGuard_ = std::string("exist(") + var_->ToString() + std::string(".") + nameAttr_ +
2339                 "==" + std::to_string(hasAttr_) + std::string(")");
2340     strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2341     return strGuard_;
2342   }
2343 
operator ==(const GuardItem & obj) const2344   bool operator==(const GuardItem &obj) const override {
2345     if (GuardItem::operator==(obj)) {
2346       return hasAttr_ == (static_cast<const AttrGuard &>(obj)).hasAttr_ &&
2347              nameAttr_ == (static_cast<const AttrGuard &>(obj)).nameAttr_;
2348     }
2349     return false;
2350   }
2351 
2352  protected:
Info()2353   virtual const InfoPack &Info() {
2354     if (info_ == nullptr) {
2355       InfoPack info;
2356       info << uint8_t(type_);
2357       info.Begin();
2358       info << var_->Info() << nameAttr_ << hasAttr_;
2359       info.End();
2360       info_ = std::make_shared<InfoPack>(info);
2361       info_->Update();
2362     }
2363     return *info_;
2364   }
2365   bool hasAttr_;
2366   std::string nameAttr_;
2367 };
2368 
GuardEqual(TracePtr obj,bool needSpecialize,int recurseDepth)2369 GuardItemPtr GuardEqual(TracePtr obj, bool needSpecialize, int recurseDepth) {
2370   return std::make_shared<EqGuard>(obj, needSpecialize, recurseDepth);
2371 }
2372 
GuardType(TracePtr obj)2373 GuardItemPtr GuardType(TracePtr obj) { return std::make_shared<TypeGuard>(obj); }
2374 
GuardId(TracePtr obj)2375 GuardItemPtr GuardId(TracePtr obj) {
2376   auto py_obj = obj->GetObject();
2377   auto pyObj = py::cast<py::object>(obj->GetObject());
2378   bool is_param = py::hasattr(pyObj, "__parameter__") && py::isinstance<tensor::MetaTensor>(pyObj);
2379   if (!is_param && (IsStubTensor(pyObj) || py::isinstance<mindspore::tensor::Tensor>(py_obj))) {
2380     return GuardEqual(obj, false, INT_MAX);
2381   } else {
2382     return std::make_shared<IdGuard>(obj);
2383   }
2384 }
2385 
GuardRepr(TracePtr obj)2386 GuardItemPtr GuardRepr(TracePtr obj) { return std::make_shared<ReprGuard>(obj); }
2387 
GuardAttr(TracePtr obj)2388 GuardItemPtr GuardAttr(TracePtr obj) {
2389   if (obj->GetTraceType() != TraceType::Attr) {
2390     return nullptr;
2391   } else {
2392     return std::make_shared<AttrGuard>(obj);
2393   }
2394 }
2395 
IsPyObjectEqual(PyObject * src,PyObject * dst)2396 bool IsPyObjectEqual(PyObject *src, PyObject *dst) {
2397   if (src == dst) {
2398     return true;
2399   }
2400   ItemDataPtr src_item = CreateItem(src, true, INT_MAX);
2401   ItemDataPtr dst_item = CreateItem(dst, true, INT_MAX);
2402   return *src_item == *dst_item;
2403 }
2404 
2405 static PyObject *g_ms_module = nullptr;
2406 static PyObject *g_ms_type = nullptr;
2407 static PyObject *g_tensor_type = nullptr;
2408 
InitMsModule()2409 static bool InitMsModule() {
2410   if (g_ms_module == nullptr) {
2411     g_ms_module = PyImport_ImportModule("mindspore");
2412   }
2413   return g_ms_module != nullptr && g_ms_module != Py_None;
2414 }
2415 
InitMsType()2416 static bool InitMsType() {
2417   if (g_ms_type == NULL) {
2418     g_ms_type = PyImport_ImportModule("mindspore.common.dtype");
2419   }
2420   return g_ms_type != NULL && g_ms_type != Py_None;
2421 }
2422 
InitMsTensor()2423 static bool InitMsTensor() {
2424   if (g_tensor_type == nullptr && InitMsModule()) {
2425     g_tensor_type = PyObject_GetAttrString(g_ms_module, "Tensor");
2426   }
2427   return g_tensor_type != nullptr && g_tensor_type != Py_None && PyType_Check(g_tensor_type);
2428 }
2429 
GetMsModule()2430 PyObject *GetMsModule() {
2431   if (InitMsModule()) {
2432     return g_ms_module;
2433   } else {
2434     return nullptr;
2435   }
2436 }
2437 
GetMsType()2438 PyObject *GetMsType() {
2439   if (InitMsType()) {
2440     return g_ms_type;
2441   } else {
2442     return nullptr;
2443   }
2444 }
2445 
GetMsTensorType()2446 PyObject *GetMsTensorType() {
2447   if (InitMsTensor()) {
2448     return g_tensor_type;
2449   } else {
2450     return nullptr;
2451   }
2452 }
2453 
2454 }  // namespace pijit
2455 }  // namespace mindspore
2456