1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_guard/guard_utils.h"
17 #include <regex>
18 #include "pybind11/pybind11.h"
19 #include "pybind_api/ir/primitive_py.h"
20 #include "pybind_api/ir/cell_py.h"
21 #include "include/common/utils/convert_utils_py.h"
22 #include "pipeline/jit/pi/utils/utils.h"
23 #include "include/common/utils/stub_tensor.h"
24 #include "pipeline/jit/pi/graph_guard/strategy.h"
25 #include "pipeline/jit/pi/graph_guard/guard.h"
26 #include "pipeline/jit/pi/graph_guard/infer.h"
27
28 namespace mindspore {
29 namespace pijit {
30
// Cached Python str objects for attribute names; each is created lazily by the matching
// GetAttr*Str() getter below. No release of these objects is visible in this part of the file.
static PyObject *kPyAttrStub = nullptr;
static PyObject *kPyAttrTensor = nullptr;
static PyObject *kPyAttrReprCache = nullptr;
// Text used to create kPyAttrReprCache.
static const char kPyAttrReprCacheStr[] = "__repr_cache__";
// Method name "stub_sync"; its use is outside this view (presumably to synchronize stub tensors).
static const char kPyMethodStubSync[] = "stub_sync";
GetAttrStubStr()36 static PyObject *GetAttrStubStr() {
37 if (kPyAttrStub == nullptr) {
38 kPyAttrStub = PyUnicode_FromString(stub::PY_ATTR_STUB);
39 }
40 return kPyAttrStub;
41 }
GetAttrTensorStr()42 static PyObject *GetAttrTensorStr() {
43 if (kPyAttrTensor == nullptr) {
44 kPyAttrTensor = PyUnicode_FromString(stub::PY_ATTR_TENSOR);
45 }
46 return kPyAttrTensor;
47 }
GetAttrReprCacheStr()48 static PyObject *GetAttrReprCacheStr() {
49 if (kPyAttrReprCache == nullptr) {
50 kPyAttrReprCache = PyUnicode_FromString(kPyAttrReprCacheStr);
51 }
52 return kPyAttrReprCache;
53 }
54
GetObjectString(PyObject * objName)55 static std::string GetObjectString(PyObject *objName) {
56 std::string ret = "";
57 if (objName == NULL) {
58 return ret;
59 }
60 PyObject *pyName = PyUnicode_AsEncodedString(objName, "utf-8", NULL);
61 char *strName = PyBytes_AsString(pyName);
62 if (strName != nullptr) {
63 ret = strName;
64 }
65 Py_DECREF(pyName);
66 return ret;
67 }
68
// Helpers that build the "{name:value}" text fragments used by the ToString() overrides below.
// In each macro `#op` stringizes the argument expression, so the caller's variable name
// (for example `intVar_` or `list`) appears literally in the produced text.
// DESC: op is already a std::string.
#define DESC(op) (std::string("{") + std::string(#op) + std::string(":") + (op) + std::string("}"))
// DESC_STRING: op is a numeric value formatted with std::to_string.
#define DESC_STRING(op) (std::string("{") + std::string(#op) + std::string(":") + std::to_string(op) + std::string("}"))
// DESC_STRING_L: numeric op annotated with a length, producing "{op[l]:value}".
#define DESC_STRING_L(op, l) \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(l) + std::string("]") + std::string(":") + \
   std::to_string(op) + std::string("}"))  // NOLINT
// DESC_STRING_S: like DESC_STRING_L, but op is already a std::string.
#define DESC_STRING_S(op, l) \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(l) + std::string("]") + std::string(":") + \
   (op) + std::string("}"))  // NOLINT
// DESC_STRING_O: numeric member expression obj->op, labelled with the text of op.
#define DESC_STRING_O(obj, op) \
  (std::string("{") + std::string(#op) + std::string(":") + std::to_string(obj->op) + std::string("}"))
// DESC_TOSTRING: op is a pointer-like object with ToString(); prints "nil" when it is null.
#define DESC_TOSTRING(op) \
  (std::string("{") + std::string(#op) + std::string(":") + ((op == nullptr) ? std::string("nil") : op->ToString()) + \
   std::string("}"))  // NOLINT
// DESC_ITEM: unlabelled "{key:value}" pair of two pointer-like objects with ToString(); null prints "nil".
#define DESC_ITEM(opK, opV) \
  (std::string("{") + ((opK == nullptr) ? std::string("nil") : opK->ToString()) + std::string(":") + \
   ((opV == nullptr) ? std::string("nil") : opV->ToString()) + std::string("}"))  // NOLINT
// DESC_ITEM_V: unlabelled numeric value "{value}".
#define DESC_ITEM_V(op) (std::string("{") + std::to_string(op) + std::string("}"))
// DESC_ITEM_T: unlabelled pointer-like object with ToString(); null prints "nil".
#define DESC_ITEM_T(op) (std::string("{") + ((op == nullptr) ? std::string("nil") : op->ToString()) + std::string("}"))
// DESC_INDEX: element idx of a container of pointer-like objects, "{op[idx]:text}".
#define DESC_INDEX(op, idx) \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(idx) + std::string("]") + \
   std::string(":") + ((op[idx] == nullptr) ? std::string("nil") : op[idx]->ToString()) + std::string("}"))  // NOLINT
// DESC_INDEX_V: element idx of a numeric container, "{op[idx]:value}".
#define DESC_INDEX_V(op, idx) \
  (std::string("{") + std::string(#op) + std::string("[") + std::to_string(idx) + std::string("]") + \
   std::string(":") + std::to_string(op[idx]) + std::string("}"))  // NOLINT
// DESC_END: appends the base-class description; only usable inside ItemData member functions.
#define DESC_END ItemData::ToString()
94
// Kind of value captured by an ItemData guard snapshot.
// ItemData::Info() serializes the value as uint8_t into the InfoPack, so the numeric order of
// these enumerators is significant. Presumably new kinds should only be appended — confirm
// before reordering.
typedef enum _ItemType {
  PyNull = 0,  // no object; ItemData::ToString() prints "(null)" for it
  // Python built-in value kinds.
  PyLong,
  PyFloat,
  PyBool,
  PyBytes,
  PyStr,
  PyList,
  PyTuple,
  PySet,
  PyFrozenSet,
  PyDict,
  PyComplex,
  PySlice,
  PyFunction,
  PyMethod,
  PyInstanceMethod,
  PyType,
  PyNumpy,
  PyUnknown,  // ItemData::Info() omits the specialize/recurse header for this kind
  // MindSpore object kinds.
  TensorType,
  ParamInfo,
  MetaTensor,
  Tensor,
  MapTensor,
  RowTensor,
  COOTensor,
  CSRTensor,
  Tensordata,
  Primitive,
  Cell,
} ItemType;
127
128 class ItemData {
129 public:
ItemData(ItemType itemType,bool needSpecialize,int recurseDepth)130 ItemData(ItemType itemType, bool needSpecialize, int recurseDepth)
131 : tp_(itemType), specialized_(needSpecialize), recurseDepth_(recurseDepth), info_(nullptr) {}
132
133 virtual ~ItemData() = default;
134
operator ==(const ItemData & obj) const135 virtual bool operator==(const ItemData &obj) const { return obj.tp_ == tp_; }
136
ToString()137 virtual std::string ToString() {
138 if (tp_ == ItemType::PyNull) {
139 return "(null)";
140 } else {
141 return std::string("(type:") + std::to_string(SizeToInt(tp_)) + ",specialize:" + std::to_string(specialized_) +
142 ",recurse:" + std::to_string(recurseDepth_) + ")";
143 }
144 }
145
Info()146 virtual const InfoPack &Info() {
147 if (info_ == nullptr) {
148 InfoPack info;
149 info << uint8_t(tp_);
150 info.Begin();
151 if (tp_ != ItemType::PyNull && tp_ != ItemType::PyUnknown) {
152 info << specialized_ << recurseDepth_;
153 }
154 SubInfo(&info);
155 info.End();
156 info_ = std::make_shared<InfoPack>(info);
157 info_->Update();
158 }
159 return *info_;
160 }
161
GetItemType()162 virtual ItemType GetItemType() { return tp_; }
163
MatchDynamicShape(std::shared_ptr<ItemData> other)164 virtual bool MatchDynamicShape(std::shared_ptr<ItemData> other) { return false; }
165
166 protected:
SubInfo(InfoPack * info)167 virtual void SubInfo(InfoPack *info) {}
168 ItemType tp_;
169 bool specialized_;
170 int recurseDepth_;
171 InfoPackPtr info_;
172 };
// Shared handle to a guard snapshot; snapshots nest (containers hold their elements' snapshots).
using ItemDataPtr = std::shared_ptr<ItemData>;

// Builds the ItemData snapshot for obj. The definition is outside this view; presumably it
// dispatches on obj's Python type to one of the ItemData subclasses below — confirm.
// needSpecialize: whether the concrete value (not only its type) is captured.
// recurseDepth: the container snapshots below pass it through unchanged; when it is <= 0 and
// needSpecialize is false, they snapshot only Py_TYPE(element) instead of the element itself.
static ItemDataPtr CreateItem(PyObject *obj, bool needSpecialize = true, int recurseDepth = INT_MAX);
176
177 class IntData : public ItemData {
178 public:
IntData(PyObject * obj,bool needSpecialize,int recurseDepth)179 IntData(PyObject *obj, bool needSpecialize, int recurseDepth)
180 : ItemData(ItemType::PyLong, needSpecialize, recurseDepth) {
181 tp_ = ItemType::PyLong;
182 intVar_ = PyLong_AsLong(obj);
183 }
184
operator ==(const ItemData & obj) const185 bool operator==(const ItemData &obj) const override {
186 return ItemData::operator==(obj) && (!specialized_ || ((static_cast<const IntData &>(obj)).intVar_ == intVar_));
187 }
188
ToString()189 std::string ToString() override { return DESC_STRING(intVar_) + DESC_END; }
190
191 protected:
SubInfo(InfoPack * info)192 void SubInfo(InfoPack *info) override { (*info) << intVar_; }
193 int64_t intVar_;
194 };
195
196 class FloatData : public ItemData {
197 public:
FloatData(PyObject * obj,bool needSpecialize,int recurseDepth)198 FloatData(PyObject *obj, bool needSpecialize, int recurseDepth)
199 : ItemData(ItemType::PyFloat, needSpecialize, recurseDepth) {
200 floatVar_ = PyFloat_AsDouble(obj);
201 }
202
operator ==(const ItemData & obj) const203 bool operator==(const ItemData &obj) const override {
204 return ItemData::operator==(obj) && (!specialized_ || (static_cast<const FloatData &>(obj)).floatVar_ == floatVar_);
205 }
206
ToString()207 std::string ToString() override { return DESC_STRING(floatVar_) + DESC_END; }
208
209 protected:
SubInfo(InfoPack * info)210 void SubInfo(InfoPack *info) override { (*info) << floatVar_; }
211 double floatVar_;
212 };
213
214 class BoolData : public ItemData {
215 public:
BoolData(PyObject * obj,bool needSpecialize,int recurseDepth)216 BoolData(PyObject *obj, bool needSpecialize, int recurseDepth)
217 : ItemData(ItemType::PyBool, needSpecialize, recurseDepth) {
218 boolVar_ = (obj == Py_True);
219 }
220
operator ==(const ItemData & obj) const221 bool operator==(const ItemData &obj) const override {
222 return ItemData::operator==(obj) && (!specialized_ || (static_cast<const BoolData &>(obj)).boolVar_ == boolVar_);
223 }
224
ToString()225 std::string ToString() override { return DESC_STRING(boolVar_) + DESC_END; }
226
227 protected:
SubInfo(InfoPack * info)228 void SubInfo(InfoPack *info) override { (*info) << boolVar_; }
229 bool boolVar_;
230 };
231
232 class BytesData : public ItemData {
233 public:
BytesData(PyObject * obj,bool needSpecialize,int recurseDepth)234 BytesData(PyObject *obj, bool needSpecialize, int recurseDepth)
235 : ItemData(ItemType::PyBytes, needSpecialize, recurseDepth), len_(PyBytes_Size(obj)) {
236 if (needSpecialize) {
237 buf_ = std::make_unique<uint8_t[]>(len_);
238 if (buf_ != nullptr) {
239 char *pBuf = PyBytes_AS_STRING(reinterpret_cast<PyBytesObject *>(obj));
240 if (pBuf != nullptr) {
241 memcpy_s(buf_.get(), len_, reinterpret_cast<uint8_t *>(pBuf), len_);
242 } else {
243 buf_.release();
244 }
245 }
246 } else {
247 buf_.reset(nullptr);
248 }
249 }
250
~BytesData()251 ~BytesData() override { buf_.release(); }
252
operator ==(const ItemData & obj) const253 bool operator==(const ItemData &obj) const override {
254 if (ItemData::operator==(obj)) {
255 const BytesData &other = static_cast<const BytesData &>(obj);
256 return len_ == other.len_ &&
257 ((specialized_ && (len_ == 0 || (buf_ != nullptr && other.buf_ != nullptr &&
258 memcmp(buf_.get(), other.buf_.get(), len_) == 0))) ||
259 (!specialized_));
260 }
261 return false;
262 }
263
ToString()264 std::string ToString() override {
265 size_t bytes = (size_t)(buf_.get());
266 return DESC_STRING_L(bytes, len_) + DESC_END;
267 }
268
269 protected:
SubInfo(InfoPack * info)270 void SubInfo(InfoPack *info) override { (*info) << (uint64_t)len_ << reinterpret_cast<void *>(buf_.get()); }
271 Py_ssize_t len_;
272 std::unique_ptr<uint8_t[]> buf_;
273 };
274
275 class StringData : public ItemData {
276 public:
StringData(PyObject * obj,bool needSpecialize,int recurseDepth)277 StringData(PyObject *obj, bool needSpecialize, int recurseDepth)
278 : ItemData(ItemType::PyStr, needSpecialize, recurseDepth) {
279 if (needSpecialize) {
280 strVal_ = GetObjectString(obj);
281 }
282 }
283
operator ==(const ItemData & obj) const284 bool operator==(const ItemData &obj) const override {
285 return ItemData::operator==(obj) &&
286 ((specialized_ && (static_cast<const StringData &>(obj)).strVal_.compare(strVal_) == 0) || (!specialized_));
287 }
288
ToString()289 std::string ToString() override { return DESC(strVal_) + DESC_END; }
290
291 protected:
SubInfo(InfoPack * info)292 void SubInfo(InfoPack *info) override { (*info) << strVal_; }
293 std::string strVal_;
294 };
295
296 class ListData : public ItemData {
297 public:
298 ListData(PyObject *obj, bool needSpecialize, int recurseDepth);
299
operator ==(const ItemData & obj) const300 bool operator==(const ItemData &obj) const override {
301 if (ItemData::operator==(obj)) {
302 const ListData &list = static_cast<const ListData &>(obj);
303 if (list.listVar_.size() == listVar_.size()) {
304 return CompareList(list);
305 }
306 }
307 return false;
308 }
309
ToString()310 std::string ToString() override {
311 std::string ret;
312 for (auto it : listVar_) {
313 ret += DESC_ITEM_T(it);
314 }
315 switch (tp_) {
316 case ItemType::PyList: {
317 std::string list = ret;
318 ret = DESC_STRING_S(list, listVar_.size());
319 } break;
320 case ItemType::PyTuple: {
321 std::string tuple = ret;
322 ret = DESC_STRING_S(tuple, listVar_.size());
323 } break;
324 case ItemType::PySet: {
325 std::string set = ret;
326 ret = DESC_STRING_S(set, listVar_.size());
327 } break;
328 case ItemType::PyFrozenSet: {
329 std::string fronzen_set = ret;
330 ret = DESC_STRING_S(fronzen_set, listVar_.size());
331 } break;
332 default:
333 ret = "unknown";
334 break;
335 }
336 return ret + DESC_END;
337 }
338
339 protected:
SubInfo(InfoPack * info)340 void SubInfo(InfoPack *info) override {
341 (*info) << uint8_t(tp_);
342 (*info) << uint64_t(listVar_.size());
343 for (auto v : listVar_) {
344 (*info) << v->Info();
345 }
346 }
CompareList(const ListData & list) const347 bool CompareList(const ListData &list) const {
348 if (!inOrder_) {
349 std::vector<ItemDataPtr> listCpy = list.listVar_;
350 for (size_t i = 0, j; i < listVar_.size(); ++i) {
351 size_t lenList = listCpy.size();
352 for (j = 0; j < lenList; ++j) {
353 if (*(listCpy[j]) == *(listVar_[i])) {
354 listCpy.erase(listCpy.begin() + j);
355 break;
356 }
357 }
358 if (j == lenList) {
359 return false;
360 }
361 }
362 } else {
363 for (size_t i = 0; i < listVar_.size(); ++i) {
364 if (*(list.listVar_[i]) == *(listVar_[i])) {
365 continue;
366 } else {
367 return false;
368 }
369 }
370 }
371 return true;
372 }
InitList(PyObject * obj,bool needSpecialize,int recurseDepth)373 void InitList(PyObject *obj, bool needSpecialize, int recurseDepth) {
374 tp_ = ItemType::PyList;
375 for (Py_ssize_t i = 0; i < PyList_Size(obj); ++i) {
376 PyObject *item = PyList_GetItem(obj, i);
377 if (item != NULL) {
378 if (recurseDepth > 0 || needSpecialize) {
379 listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
380 } else {
381 listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
382 }
383 }
384 }
385 }
InitTuple(PyObject * obj,bool needSpecialize,int recurseDepth)386 void InitTuple(PyObject *obj, bool needSpecialize, int recurseDepth) {
387 tp_ = ItemType::PyTuple;
388 for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(obj); ++i) {
389 PyObject *item = PyTuple_GET_ITEM(obj, i);
390 if (item != NULL) {
391 if (recurseDepth > 0 || needSpecialize) {
392 listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
393 } else {
394 listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
395 }
396 }
397 }
398 }
InitSet(PyObject * obj,bool needSpecialize,int recurseDepth)399 void InitSet(PyObject *obj, bool needSpecialize, int recurseDepth) {
400 tp_ = ItemType::PySet;
401 Py_ssize_t pos = 0;
402 PyObject *item;
403 Py_hash_t hash;
404 while (_PySet_NextEntry(obj, &pos, &item, &hash)) {
405 if (recurseDepth > 0 || needSpecialize) {
406 listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
407 } else {
408 listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
409 }
410 }
411 inOrder_ = false;
412 }
InitFrozenSet(PyObject * obj,bool needSpecialize,int recurseDepth)413 void InitFrozenSet(PyObject *obj, bool needSpecialize, int recurseDepth) {
414 tp_ = ItemType::PyFrozenSet;
415 Py_ssize_t pos = 0;
416 PyObject *item;
417 Py_hash_t hash;
418 while (_PySet_NextEntry(obj, &pos, &item, &hash)) {
419 if (recurseDepth > 0 || needSpecialize) {
420 listVar_.push_back(CreateItem(item, needSpecialize, recurseDepth));
421 } else {
422 listVar_.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
423 }
424 }
425 inOrder_ = false;
426 }
427 std::vector<ItemDataPtr> listVar_;
428 bool inOrder_ = true;
429 };
430
431 class ComplexData : public ItemData {
432 public:
ComplexData(PyObject * obj,bool needSpecialize,int recurseDepth)433 ComplexData(PyObject *obj, bool needSpecialize, int recurseDepth)
434 : ItemData(ItemType::PyComplex, needSpecialize, recurseDepth) {
435 if (needSpecialize) {
436 complexVar_ = std::make_pair(PyComplex_RealAsDouble(obj), PyComplex_ImagAsDouble(obj));
437 }
438 }
439
operator ==(const ItemData & obj) const440 bool operator==(const ItemData &obj) const override {
441 return ItemData::operator==(obj) &&
442 (!specialized_ || (static_cast<const ComplexData &>(obj)).complexVar_ == complexVar_);
443 }
444
ToString()445 std::string ToString() override {
446 return "complex(" + std::to_string(complexVar_.first) + "," + std::to_string(complexVar_.second) + ")" + DESC_END;
447 }
448
449 protected:
SubInfo(InfoPack * info)450 void SubInfo(InfoPack *info) override { (*info) << complexVar_.first << complexVar_.second; }
451 std::pair<double, double> complexVar_;
452 };
453
454 class SliceData : public ItemData {
455 public:
SliceData(PyObject * obj,bool needSpecialize,int recurseDepth)456 SliceData(PyObject *obj, bool needSpecialize, int recurseDepth)
457 : ItemData(ItemType::PySlice, needSpecialize, recurseDepth) {
458 Py_ssize_t start = 0;
459 Py_ssize_t stop = 0;
460 Py_ssize_t step = 0;
461 if (needSpecialize) {
462 PySlice_Unpack(obj, &start, &stop, &step);
463 sliceVar_.push_back((int64_t)start);
464 sliceVar_.push_back((int64_t)stop);
465 sliceVar_.push_back((int64_t)step);
466 }
467 }
468
operator ==(const ItemData & obj) const469 bool operator==(const ItemData &obj) const override {
470 if (ItemData::operator==(obj)) {
471 const SliceData &other = static_cast<const SliceData &>(obj);
472 return (!specialized_ || (other.sliceVar_[0] == sliceVar_[0] && other.sliceVar_[1] == sliceVar_[1] &&
473 other.sliceVar_[2] == sliceVar_[2]));
474 }
475 return false;
476 }
477
ToString()478 std::string ToString() override {
479 std::string slice;
480 for (auto it : sliceVar_) {
481 slice += DESC_ITEM_V(it);
482 }
483 return DESC_STRING_S(slice, sliceVar_.size()) + DESC_END;
484 }
485
486 protected:
SubInfo(InfoPack * info)487 void SubInfo(InfoPack *info) override { (*info) << sliceVar_; }
488 std::vector<int64_t> sliceVar_;
489 };
490
// Which part of a dictionary a DictData snapshot describes.
typedef enum _DictType {
  DtDict = 0,  // a dict itself: keys and values are captured
  DtKeys,      // a dict keys() view: only keys are captured
  DtValues,    // a dict values() view: only values are captured
  DtItems,     // a dict items() view: keys and values are captured
} DictType;
497
498 class DictData : public ItemData {
499 public:
DictData(PyObject * obj,bool needSpecialize,int recurseDepth)500 DictData(PyObject *obj, bool needSpecialize, int recurseDepth)
501 : ItemData(ItemType::PyDict, needSpecialize, recurseDepth) {
502 if (PyDictKeys_Check(obj)) {
503 dt_ = DictType::DtKeys;
504 obj = PyObject_Vectorcall(reinterpret_cast<PyObject *>(&PyList_Type), &obj, 1, nullptr);
505 } else if (PyDictValues_Check(obj)) {
506 dt_ = DictType::DtValues;
507 obj = PyObject_Vectorcall(reinterpret_cast<PyObject *>(&PyList_Type), &obj, 1, nullptr);
508 } else if (PyDictItems_Check(obj)) {
509 dt_ = DictType::DtItems;
510 obj = PyObject_Vectorcall(reinterpret_cast<PyObject *>(&PyDict_Type), &obj, 1, nullptr);
511 } else {
512 dt_ = DictType::DtDict;
513 }
514 Py_ssize_t pos = 0;
515 PyObject *key;
516 PyObject *val;
517 if (dt_ == DictType::DtItems || dt_ == DictType::DtDict) {
518 while (PyDict_Next(obj, &pos, &key, &val)) {
519 ItemDataPtr k;
520 ItemDataPtr v;
521 if (recurseDepth > 0 || needSpecialize) {
522 k = CreateItem(key, needSpecialize, recurseDepth);
523 v = CreateItem(val, needSpecialize, recurseDepth);
524 } else {
525 k = CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(key)), false, false);
526 v = CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(val)), false, false);
527 }
528 listK_.push_back(k);
529 listV_.push_back(v);
530 }
531 } else {
532 std::vector<ItemDataPtr> &list = dt_ == DictType::DtKeys ? listK_ : listV_;
533 for (Py_ssize_t i = 0; i < PyList_Size(obj); ++i) {
534 PyObject *item = PyList_GetItem(obj, i);
535 if (recurseDepth > 0 || needSpecialize) {
536 list.push_back(CreateItem(item, needSpecialize, recurseDepth));
537 } else {
538 list.push_back(CreateItem(reinterpret_cast<PyObject *>(Py_TYPE(item)), false, false));
539 }
540 }
541 }
542 if (dt_ != DictType::DtDict) {
543 Py_DECREF(obj);
544 }
545 }
546
operator ==(const ItemData & obj) const547 bool operator==(const ItemData &obj) const override {
548 if (ItemData::operator==(obj)) {
549 const DictData &other = static_cast<const DictData &>(obj);
550 if (dt_ != other.dt_) {
551 return false;
552 }
553 if ((dt_ == DictType::DtValues || other.listK_.size() == listK_.size()) &&
554 (dt_ == DictType::DtKeys || other.listV_.size() == listV_.size())) {
555 return CompareKV(other);
556 }
557 }
558 return false;
559 }
560
ToString()561 std::string ToString() override {
562 std::string dict = DESC_STRING(dt_);
563 size_t listSize = 0;
564 if (dt_ == DictType::DtItems || dt_ == DictType::DtDict) {
565 listSize = listK_.size();
566 for (size_t i = 0; i < listSize; ++i) {
567 dict += DESC_ITEM(listK_[i], listV_[i]);
568 }
569 } else if (dt_ == DictType::DtKeys) {
570 listSize = listK_.size();
571 for (size_t i = 0; i < listSize; ++i) {
572 dict += DESC_ITEM_T(listK_[i]);
573 }
574 } else if (dt_ == DictType::DtValues) {
575 listSize = listV_.size();
576 for (size_t i = 0; i < listSize; ++i) {
577 dict += DESC_ITEM_T(listV_[i]);
578 }
579 }
580 return DESC_STRING_S(dict, listSize) + DESC_END;
581 }
582
583 protected:
SubInfo(InfoPack * info)584 void SubInfo(InfoPack *info) override {
585 (*info) << dt_;
586 (*info) << uint64_t(listK_.size());
587 for (auto i : listK_) {
588 (*info) << i->Info();
589 }
590 (*info) << uint64_t(listV_.size());
591 for (auto i : listV_) {
592 (*info) << i->Info();
593 }
594 }
CompareKV(const DictData & other) const595 bool CompareKV(const DictData &other) const {
596 std::vector<ItemDataPtr> listCpK = other.listK_;
597 std::vector<ItemDataPtr> listCpV = other.listV_;
598 size_t listSize = listK_.size();
599 if (listSize < listV_.size()) {
600 listSize = listV_.size();
601 }
602 for (size_t i = 0, j = 0; i < listSize; ++i) {
603 size_t cpListSize = dt_ == DictType::DtValues ? listCpV.size() : listCpK.size();
604 for (; j < cpListSize; ++j) {
605 if ((dt_ == DictType::DtValues || *(listK_[i]) == *(listCpK[j])) &&
606 (dt_ == DictType::DtKeys || *(listV_[i]) == *(listCpV[j]))) {
607 if (dt_ != DictType::DtValues) {
608 listCpK.erase(listCpK.begin() + j);
609 }
610 if (dt_ != DictType::DtKeys) {
611 listCpV.erase(listCpV.begin() + j);
612 }
613 break;
614 }
615 }
616 if (j == cpListSize) {
617 return false;
618 }
619 }
620 return true;
621 }
622 DictType dt_;
623 std::vector<ItemDataPtr> listK_;
624 std::vector<ItemDataPtr> listV_;
625 };
626
627 class FunctionData : public ItemData {
628 public:
FunctionData(PyObject * obj,bool needSpecialize,int recurseDepth)629 FunctionData(PyObject *obj, bool needSpecialize, int recurseDepth)
630 : ItemData(ItemType::PyFunction, needSpecialize, recurseDepth) {
631 if (needSpecialize || recurseDepth > 0) {
632 code_ = reinterpret_cast<PyCodeObject *>(PyFunction_GetCode(obj));
633 defaults_ = CreateItem(PyFunction_GetDefaults(obj), needSpecialize, recurseDepth);
634 kwdefaults_ = CreateItem(PyFunction_GetKwDefaults(obj), needSpecialize, recurseDepth);
635 closure_ = CreateItem(PyFunction_GetClosure(obj), needSpecialize, recurseDepth);
636 } else {
637 code_ = reinterpret_cast<PyCodeObject *>(PyFunction_GetCode(obj));
638 PyObject *temp = PyFunction_GetDefaults(obj);
639 defaults_ =
640 CreateItem((temp == NULL || temp == Py_None) ? Py_None : reinterpret_cast<PyObject *>(Py_TYPE(temp)), false, 0);
641 temp = PyFunction_GetKwDefaults(obj);
642 kwdefaults_ =
643 CreateItem((temp == NULL || temp == Py_None) ? Py_None : reinterpret_cast<PyObject *>(Py_TYPE(temp)), false, 0);
644 temp = PyFunction_GetClosure(obj);
645 closure_ =
646 CreateItem((temp == NULL || temp == Py_None) ? Py_None : reinterpret_cast<PyObject *>(Py_TYPE(temp)), false, 0);
647 }
648 }
649
operator ==(const ItemData & obj) const650 bool operator==(const ItemData &obj) const override {
651 if (ItemData::operator==(obj)) {
652 const FunctionData &other = static_cast<const FunctionData &>(obj);
653 return code_ == other.code_ && *defaults_ == *(other.defaults_) && *kwdefaults_ == *(other.kwdefaults_) &&
654 *closure_ == *(other.closure_);
655 }
656 return false;
657 }
658
ToString()659 std::string ToString() override {
660 std::string func = DESC_TOSTRING(defaults_) + DESC_TOSTRING(kwdefaults_) + DESC_TOSTRING(closure_);
661 return DESC(func) + DESC_END;
662 }
663
664 protected:
SubInfo(InfoPack * info)665 void SubInfo(InfoPack *info) override {
666 (*info) << (defaults_ != nullptr);
667 if (defaults_ != nullptr) {
668 (*info) << defaults_->Info();
669 }
670 (*info) << (kwdefaults_ != nullptr);
671 if (kwdefaults_ != nullptr) {
672 (*info) << kwdefaults_->Info();
673 }
674 (*info) << (closure_ != nullptr);
675 if (closure_ != nullptr) {
676 (*info) << closure_->Info();
677 }
678 }
679 PyCodeObject *code_;
680 ItemDataPtr defaults_;
681 ItemDataPtr kwdefaults_;
682 ItemDataPtr closure_;
683 };
684
685 class MethodData : public ItemData {
686 public:
MethodData(PyObject * obj,bool needSpecialize,int recurseDepth)687 MethodData(PyObject *obj, bool needSpecialize, int recurseDepth)
688 : ItemData(ItemType::PyMethod, needSpecialize, recurseDepth),
689 refFunc_(CreateItem(PyMethod_GET_FUNCTION(obj), needSpecialize, recurseDepth)),
690 refSelf_(CreateItem(PyMethod_GET_SELF(obj), needSpecialize, recurseDepth)) {}
691
operator ==(const ItemData & obj) const692 bool operator==(const ItemData &obj) const override {
693 if (ItemData::operator==(obj)) {
694 const MethodData &other = static_cast<const MethodData &>(obj);
695 return *refFunc_ == *(other.refFunc_) && *refSelf_ == *(other.refSelf_);
696 }
697 return false;
698 }
699
ToString()700 std::string ToString() override {
701 std::string method = DESC_TOSTRING(refFunc_) + DESC_TOSTRING(refSelf_);
702 return DESC(method) + DESC_END;
703 }
704
705 protected:
SubInfo(InfoPack * info)706 void SubInfo(InfoPack *info) override {
707 (*info) << (refFunc_ != nullptr);
708 if (refFunc_ != nullptr) {
709 (*info) << refFunc_->Info();
710 }
711 (*info) << (refSelf_ != nullptr);
712 if (refSelf_ != nullptr) {
713 (*info) << refSelf_->Info();
714 }
715 }
716 ItemDataPtr refFunc_;
717 ItemDataPtr refSelf_;
718 };
719
720 class InstanceMethodData : public ItemData {
721 public:
InstanceMethodData(PyObject * obj,bool needSpecialize,int recurseDepth)722 InstanceMethodData(PyObject *obj, bool needSpecialize, int recurseDepth)
723 : ItemData(ItemType::PyInstanceMethod, needSpecialize, recurseDepth),
724 refFunc_(CreateItem(PyInstanceMethod_GET_FUNCTION(obj), needSpecialize, recurseDepth)) {}
725
operator ==(const ItemData & obj) const726 bool operator==(const ItemData &obj) const override {
727 if (ItemData::operator==(obj)) {
728 const InstanceMethodData &other = static_cast<const InstanceMethodData &>(obj);
729 return *refFunc_ == *(other.refFunc_);
730 }
731 return false;
732 }
733
ToString()734 std::string ToString() override {
735 std::string instance_method = DESC_TOSTRING(refFunc_);
736 return DESC(instance_method) + DESC_END;
737 }
738
739 protected:
SubInfo(InfoPack * info)740 void SubInfo(InfoPack *info) override {
741 (*info) << (refFunc_ != nullptr);
742 if (refFunc_ != nullptr) {
743 (*info) << refFunc_->Info();
744 }
745 }
746 ItemDataPtr refFunc_;
747 };
748
749 class TypeData : public ItemData {
750 public:
TypeData(PyObject * obj,bool needSpecialize,int recurseDepth)751 TypeData(PyObject *obj, bool needSpecialize, int recurseDepth)
752 : ItemData(ItemType::PyType, needSpecialize, recurseDepth) {
753 refType_ = reinterpret_cast<PyTypeObject *>(obj);
754 is_adapter_tensor_type_ = false;
755 ambiguous_tensor_type_ = false;
756 }
757
set_ambiguous_tensor_type(bool value)758 void set_ambiguous_tensor_type(bool value) {
759 if (value && (IsTensorType<true>(refType_) || IsStubTensorType<true>(refType_))) {
760 ambiguous_tensor_type_ = true;
761 PyObject *obj = reinterpret_cast<PyObject *>(refType_);
762 py::object registry = Utils::GetModuleAttr("mindspore.common._register_for_adapter", "ms_adapter_registry");
763 is_adapter_tensor_type_ = registry.ptr() != nullptr && obj == py::getattr(registry, "tensor", nullptr).ptr();
764 }
765 }
766
operator ==(const ItemData & obj) const767 bool operator==(const ItemData &obj) const override {
768 if (ItemData::operator==(obj)) {
769 PyTypeObject *otherType = (static_cast<const TypeData &>(obj)).refType_;
770 bool ret = refType_ == otherType;
771 if (!ret) {
772 ret = PyType_IsSubtype(refType_, otherType) || PyType_IsSubtype(otherType, refType_);
773 }
774 // adapter tensor type must be check exactly
775 // if exactly type check failed, check ambiguous tensor type if necessary
776 if (!is_adapter_tensor_type_ && !ret && ambiguous_tensor_type_) {
777 ret = IsTensorType<true>(otherType) || IsStubTensorType<true>(otherType);
778 }
779 return ret;
780 }
781 return false;
782 }
783
ToString()784 std::string ToString() override {
785 std::string type = refType_->tp_name;
786 return DESC(type) + DESC_END;
787 }
788
789 protected:
SubInfo(InfoPack * info)790 void SubInfo(InfoPack *info) override { (*info) << refType_->tp_name; }
791 PyTypeObject *refType_;
792
793 // this flag is checked only if tensor type is ambiguous
794 bool is_adapter_tensor_type_;
795
796 // mix the tensor type.
797 // only set true if _c_expression.Tensor type, common.Tensor type, StubTensor type and all subtype of them
798 bool ambiguous_tensor_type_;
799 };
800
801 class NumpyData : public ItemData {
802 public:
NumpyData(PyObject * obj,bool needSpecialize,int recurseDepth)803 NumpyData(PyObject *obj, bool needSpecialize, int recurseDepth)
804 : ItemData(ItemType::PyNumpy, needSpecialize, recurseDepth) {
805 py::array arr = py::cast<py::array>(obj);
806 dtype_ = arr.dtype();
807 size_ = (uint64_t)arr.size();
808 itemsize_ = (uint64_t)arr.itemsize();
809 ndim_ = (int64_t)arr.ndim();
810 nbytes_ = (uint64_t)arr.nbytes();
811 for (ssize_t i = 0; i < ndim_; ++i) {
812 shape_.push_back((int64_t)arr.shape()[i]);
813 strides_.push_back((int64_t)arr.strides()[i]);
814 }
815 if (arr.data() != nullptr) {
816 if (needSpecialize) {
817 buf_ = std::make_unique<uint8_t[]>(nbytes_);
818 if (buf_ != NULL) {
819 memcpy_s(buf_.get(), nbytes_, reinterpret_cast<uint8_t *>(arr.mutable_data()), nbytes_);
820 }
821 } else {
822 buf_.reset(nullptr);
823 }
824 } else {
825 buf_.reset(nullptr);
826 }
827 }
828
~NumpyData()829 ~NumpyData() override { buf_.release(); }
830
operator ==(const ItemData & obj) const831 bool operator==(const ItemData &obj) const override {
832 if (ItemData::operator==(obj)) {
833 const NumpyData &other = static_cast<const NumpyData &>(obj);
834 return dtype_ == other.dtype_ && size_ == other.size_ && ndim_ == other.ndim_ && nbytes_ == other.nbytes_ &&
835 shape_ == other.shape_ && strides_ == other.strides_ &&
836 (!specialized_ ||
837 (buf_ != NULL && other.buf_ != NULL && memcmp(buf_.get(), other.buf_.get(), nbytes_) == 0));
838 }
839 return false;
840 }
841
ToString()842 std::string ToString() override {
843 std::string numpy;
844 char dtype_kind = dtype_.kind();
845 numpy +=
846 DESC_STRING(dtype_kind) + DESC_STRING(size_) + DESC_STRING(itemsize_) + DESC_STRING(ndim_) + DESC_STRING(nbytes_);
847 for (size_t i = 0; i < shape_.size(); ++i) {
848 numpy += DESC_INDEX_V(shape_, i) + DESC_INDEX_V(strides_, i);
849 }
850 return DESC(numpy) + DESC_END;
851 }
852
853 protected:
SubInfo(InfoPack * info)854 void SubInfo(InfoPack *info) override {
855 (*info) << dtype_.kind() << size_ << itemsize_ << ndim_ << nbytes_ << shape_ << strides_;
856 }
857 py::dtype dtype_;
858 uint64_t size_;
859 uint64_t itemsize_;
860 int64_t ndim_;
861 uint64_t nbytes_;
862 std::vector<int64_t> shape_;
863 std::vector<int64_t> strides_;
864 std::unique_ptr<uint8_t[]> buf_;
865 };
866
867 class TensorTypeData : public ItemData {
868 public:
TensorTypeData(PyObject * obj,bool needSpecialize,int recurseDepth)869 TensorTypeData(PyObject *obj, bool needSpecialize, int recurseDepth)
870 : ItemData(ItemType::TensorType, needSpecialize, recurseDepth) {
871 auto pyObj = py::cast<py::object>(obj);
872 tpp_ = pyObj.cast<mindspore::TypePtr>();
873 }
874
operator ==(const ItemData & obj) const875 bool operator==(const ItemData &obj) const override {
876 return ItemData::operator==(obj) && (!specialized_ || *((static_cast<const TensorTypeData &>(obj)).tpp_) == *tpp_);
877 }
878
ToString()879 std::string ToString() override {
880 std::string tensor_type = tpp_->ToString();
881 return DESC(tensor_type) + DESC_END;
882 }
883
884 protected:
SubInfo(InfoPack * info)885 void SubInfo(InfoPack *info) override { (*info) << tpp_; }
886 mindspore::TypePtr tpp_;
887 };
888
889 class ParamInfoData : public ItemData {
890 public:
ParamInfoData(PyObject * obj,bool needSpecialize,int recurseDepth)891 ParamInfoData(PyObject *obj, bool needSpecialize, int recurseDepth)
892 : ItemData(ItemType::ParamInfo, needSpecialize, recurseDepth) {
893 auto pyObj = py::cast<py::object>(obj);
894 auto ptr = pyObj.cast<mindspore::ParamInfoPtr>();
895 param_ = ptr->Clone();
896 }
897
operator ==(const ItemData & obj) const898 bool operator==(const ItemData &obj) const override {
899 if (ItemData::operator==(obj)) {
900 if (!specialized_) {
901 return true;
902 }
903 const ParamInfoData &other = static_cast<const ParamInfoData &>(obj);
904 return Equal(param_, other.param_);
905 }
906 return false;
907 }
908
Equal(ParamInfoPtr a,ParamInfoPtr b)909 static bool Equal(ParamInfoPtr a, ParamInfoPtr b) {
910 return a->requires_grad() == b->requires_grad() && a->comm_fusion() == b->comm_fusion() &&
911 a->parallel_optimizer() == b->parallel_optimizer() &&
912 a->parallel_optimizer_comm_recompute() == b->parallel_optimizer_comm_recompute() &&
913 a->parameter_shape() == b->parameter_shape() && a->use_persistent_storage() == b->use_persistent_storage() &&
914 a->cache_enable() == b->cache_enable() && a->param_strategy() == b->param_strategy() &&
915 a->cache_shape() == b->cache_shape() && a->requires_aggr() == b->requires_aggr();
916 }
917
ToString()918 std::string ToString() override {
919 std::string param_info = ToStringAttr(param_);
920 return DESC(param_info) + DESC_END;
921 }
922
ToStringAttr(mindspore::ParamInfoPtr p)923 static std::string ToStringAttr(mindspore::ParamInfoPtr p) {
924 if (p == nullptr) {
925 return "nil";
926 }
927 std::string param_name = p->name();
928 std::string ret = DESC(param_name) + DESC_STRING_O(p, requires_grad()) + DESC_STRING_O(p, comm_fusion()) +
929 DESC_STRING_O(p, parallel_optimizer()) + DESC_STRING_O(p, requires_aggr()) +
930 DESC_STRING_O(p, parallel_optimizer_comm_recompute()) +
931 DESC_STRING_O(p, use_persistent_storage()) + DESC_STRING_O(p, cache_enable());
932 auto parameter_shape = p->parameter_shape();
933 for (size_t i = 0; i < parameter_shape.size(); ++i) {
934 ret += DESC_INDEX_V(parameter_shape, i);
935 }
936 auto cache_shape = p->cache_shape();
937 for (size_t i = 0; i < cache_shape.size(); ++i) {
938 ret += DESC_INDEX_V(cache_shape, i);
939 }
940 auto param_strategy = p->param_strategy();
941 for (size_t i = 0; i < param_strategy.size(); ++i) {
942 ret += DESC_INDEX_V(param_strategy, i);
943 }
944 return ret;
945 }
946
SubInfo(InfoPack * info,mindspore::ParamInfoPtr p)947 static void SubInfo(InfoPack *info, mindspore::ParamInfoPtr p) {
948 if (p == nullptr) {
949 return;
950 }
951 (*info) << p->name() << p->requires_grad() << p->comm_fusion() << p->parallel_optimizer() << p->requires_aggr()
952 << p->parallel_optimizer_comm_recompute() << p->use_persistent_storage() << p->cache_enable()
953 << p->parameter_shape() << p->cache_shape() << p->param_strategy();
954 }
955
956 protected:
SubInfo(InfoPack * info)957 void SubInfo(InfoPack *info) override { SubInfo(info, param_); }
958 mindspore::ParamInfoPtr param_;
959 };
960
// Shape sentinels understood by CheckShape / IsDynamicDim:
//  - kDynamicDim anywhere in a shape makes that shape match any other shape (unknown rank);
//  - kDynamicShape at an axis makes that axis match any extent (unknown dimension size).
static constexpr int64_t kDynamicDim{-2};
static constexpr int64_t kDynamicShape{-1};
963
IsDynamicDim(const ShapeVector & shape)964 static bool IsDynamicDim(const ShapeVector &shape) {
965 return std::any_of(shape.begin(), shape.end(), [](ShapeValueDType dim) { return dim == kDynamicDim; });
966 }
967
CheckShape(const ShapeVector & a,const ShapeVector & b)968 static bool CheckShape(const ShapeVector &a, const ShapeVector &b) {
969 if (IsDynamicDim(a) || IsDynamicDim(b)) {
970 return true;
971 } else if (a.size() == b.size()) {
972 for (size_t idx = 0; idx < a.size(); idx++) {
973 if (a[idx] != kDynamicShape && b[idx] != kDynamicShape && a[idx] != b[idx]) {
974 return false;
975 }
976 }
977 return true;
978 } else {
979 return false;
980 }
981 }
982
983 class MetaTensorData : public ItemData {
984 public:
MetaTensorData(mindspore::tensor::MetaTensorPtr tensor_ptr,bool needSpecialize,int recurseDepth)985 MetaTensorData(mindspore::tensor::MetaTensorPtr tensor_ptr, bool needSpecialize, int recurseDepth)
986 : ItemData(ItemType::MetaTensor, needSpecialize, recurseDepth) {
987 StoreTensor(tensor_ptr);
988 }
989
MetaTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)990 MetaTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
991 : ItemData(ItemType::MetaTensor, needSpecialize, recurseDepth) {
992 mindspore::tensor::MetaTensorPtr tensor_ptr = nullptr;
993 PyObject *stubattr = GetAttrStubStr();
994 PyObject *stub = PyObject_HasAttr(obj, stubattr) ? PyObject_GetAttr(obj, stubattr) : nullptr;
995 if (stub != nullptr) {
996 if (stub != Py_None) {
997 is_stubtensor_ = true;
998 } else {
999 PyObject *tensorattr = GetAttrTensorStr();
1000 obj = PyObject_GetAttr(obj, tensorattr);
1001 tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1002 Py_DECREF(obj);
1003 }
1004 } else if (py::isinstance<mindspore::tensor::Tensor>(obj)) {
1005 tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1006 } else if (py::isinstance<mindspore::tensor::MapTensor>(obj)) {
1007 tensor_ptr = py::cast<mindspore::tensor::MapTensorPtr>(obj);
1008 } else {
1009 tensor_ptr = py::cast<mindspore::tensor::MetaTensorPtr>(obj);
1010 }
1011 if (tensor_ptr != nullptr) {
1012 StoreTensor(tensor_ptr);
1013 } else {
1014 auto ptr = py::cast<mindspore::stub::StubNodePtr>(stub);
1015 StoreStubTensor(ptr);
1016 }
1017 Py_XDECREF(stub);
1018 }
1019
operator ==(const ItemData & obj) const1020 bool operator==(const ItemData &obj) const override {
1021 if (ItemData::operator==(obj)) {
1022 const MetaTensorData &other = static_cast<const MetaTensorData &>(obj);
1023 bool ret;
1024 if (is_stubtensor_ || other.is_stubtensor_) {
1025 ret = CheckShape(shape_, other.shape_) && CheckDataType(other);
1026 } else {
1027 ret = tid_ == other.tid_ && CheckShape(shape_, other.shape_) && is_parameter_ == other.is_parameter_ &&
1028 CheckDataType(other);
1029 }
1030 if (ret) {
1031 if (is_parameter_ == true) {
1032 ret = ((param_ == nullptr && other.param_ == nullptr) ||
1033 (param_ != nullptr && other.param_ != nullptr && ParamInfoData::Equal(param_, other.param_)));
1034 }
1035 }
1036 return ret;
1037 }
1038 return false;
1039 }
1040
MakeTensor()1041 mindspore::tensor::TensorPtr MakeTensor() {
1042 return std::make_shared<mindspore::tensor::Tensor>(data_type_->type_id(), shape_);
1043 }
1044
IsDynamicShape() const1045 bool IsDynamicShape() const {
1046 return std::any_of(shape_.begin(), shape_.end(),
1047 [](ShapeValueDType dim) { return dim == kDynamicDim || dim == kDynamicShape; });
1048 }
1049
ToString()1050 std::string ToString() override {
1051 std::string meta_tensor = ToStringIntern();
1052 return DESC(meta_tensor) + DESC_END;
1053 }
1054
MatchDynamicShape(std::shared_ptr<ItemData> other)1055 bool MatchDynamicShape(std::shared_ptr<ItemData> other) override {
1056 auto type = other->GetItemType();
1057 if (type != ItemType::Tensor && type != ItemType::MetaTensor) {
1058 return false;
1059 }
1060 auto o = static_cast<MetaTensorData *>(other.get());
1061 if (!CheckDataType(*o) || specialized_ != false || o->specialized_ != false) {
1062 return false;
1063 }
1064 if (shape_.size() != o->shape_.size()) {
1065 shape_ = {kDynamicDim};
1066 } else {
1067 for (size_t idx = 0; idx < shape_.size(); ++idx) {
1068 if (shape_[idx] != kDynamicShape && shape_[idx] != o->shape_[idx]) {
1069 shape_[idx] = kDynamicShape;
1070 }
1071 }
1072 }
1073 return true;
1074 }
1075
1076 protected:
MetaTensorData(bool needSpecialize,int recurseDepth)1077 MetaTensorData(bool needSpecialize, int recurseDepth)
1078 : ItemData(ItemType::MetaTensor, needSpecialize, recurseDepth) {}
ToStringIntern()1079 virtual std::string ToStringIntern() {
1080 std::string param_desc = ParamInfoData::ToStringAttr(param_);
1081 std::string shape = "";
1082 for (size_t i = 0; i < shape_.size(); ++i) {
1083 shape += DESC_INDEX_V(shape_, i);
1084 }
1085 std::string is_stubtensor = is_stubtensor_ ? "true" : "false";
1086 return DESC_STRING(tid_) + DESC_TOSTRING(data_type_) + DESC_STRING(is_parameter_) + DESC(param_desc) + DESC(shape) +
1087 DESC(is_stubtensor);
1088 }
1089
CheckDataType(const MetaTensorData & other) const1090 bool CheckDataType(const MetaTensorData &other) const {
1091 return (data_type_ == nullptr && other.data_type_ == nullptr) ||
1092 (data_type_ != nullptr && other.data_type_ != nullptr && *data_type_ == *(other.data_type_));
1093 }
1094
StoreTensor(mindspore::tensor::MetaTensorPtr tensor_ptr)1095 void StoreTensor(mindspore::tensor::MetaTensorPtr tensor_ptr) {
1096 tid_ = tensor_ptr->data_type();
1097 shape_ = tensor_ptr->shape();
1098 data_type_ = tensor_ptr->Dtype();
1099 is_parameter_ = tensor_ptr->is_parameter();
1100 param_ = tensor_ptr->param_info();
1101 }
1102
StoreStubTensor(mindspore::stub::StubNodePtr stub_ptr)1103 void StoreStubTensor(mindspore::stub::StubNodePtr stub_ptr) {
1104 auto base = stub_ptr->ToAbstract();
1105 auto shape = base->BuildShape()->cast<abstract::ShapePtr>();
1106 if (shape && !shape->IsDynamic()) {
1107 shape_ = shape->shape();
1108 } else {
1109 shape_ = {};
1110 }
1111 auto dt = base->BuildType();
1112 if (dt->isa<mindspore::TensorType>()) {
1113 data_type_ = dt->cast<std::shared_ptr<mindspore::TensorType>>()->element();
1114 } else {
1115 data_type_ = dt;
1116 }
1117 }
1118
SubInfo(InfoPack * info)1119 void SubInfo(InfoPack *info) override {
1120 (*info) << uint8_t(tid_) << data_type_ << is_parameter_ << shape_ << is_stubtensor_;
1121 ParamInfoData::SubInfo(info, param_);
1122 }
1123
1124 mindspore::TypeId tid_ = TypeId::kTypeUnknown;
1125 ShapeVector shape_;
1126 TypePtr data_type_;
1127 bool is_parameter_ = false;
1128 bool is_stubtensor_ = false;
1129 mindspore::ParamInfoPtr param_;
1130 };
1131
1132 class TensorData : public MetaTensorData {
1133 public:
TensorData(mindspore::tensor::TensorPtr tensor_ptr,bool needSpecialize,int recurseDepth)1134 TensorData(mindspore::tensor::TensorPtr tensor_ptr, bool needSpecialize, int recurseDepth)
1135 : MetaTensorData(needSpecialize, recurseDepth) {
1136 tp_ = ItemType::Tensor;
1137 StoreTensor(tensor_ptr);
1138 }
1139
TensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1140 TensorData(PyObject *obj, bool needSpecialize, int recurseDepth) : MetaTensorData(needSpecialize, recurseDepth) {
1141 is_stubtensor_ = false;
1142 tp_ = ItemType::Tensor;
1143 mindspore::tensor::TensorPtr tensor_ptr = nullptr;
1144 PyObject *stubattr = GetAttrStubStr();
1145 PyObject *stub = PyObject_HasAttr(obj, stubattr) ? PyObject_GetAttr(obj, stubattr) : nullptr;
1146 if (stub != nullptr) {
1147 if (stub != Py_None) {
1148 specialized_ = false;
1149 }
1150 if (specialized_) {
1151 auto pyObj = python_adapter::CallPyObjMethod(py::cast<py::object>(obj), kPyMethodStubSync);
1152 tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(pyObj.ptr());
1153 } else {
1154 if (stub != Py_None) {
1155 is_stubtensor_ = true;
1156 } else {
1157 PyObject *tensorattr = GetAttrTensorStr();
1158 obj = PyObject_GetAttr(obj, tensorattr);
1159 tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1160 Py_DECREF(obj);
1161 }
1162 }
1163 } else if (py::isinstance<mindspore::tensor::Tensor>(obj)) {
1164 tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1165 } else if (py::isinstance<mindspore::tensor::MapTensor>(obj)) {
1166 tensor_ptr = py::cast<mindspore::tensor::MapTensorPtr>(obj);
1167 } else {
1168 tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(obj);
1169 }
1170 if (tensor_ptr != nullptr) {
1171 if (OptStrategy::MakeCalcStrategyByShape(tensor_ptr->shape()) != OptStrategy::CalcKind::kCalcValue) {
1172 specialized_ = false;
1173 }
1174 StoreTensor(tensor_ptr);
1175 } else {
1176 auto ptr = py::cast<mindspore::stub::StubNodePtr>(stub);
1177 StoreStubTensor(ptr);
1178 }
1179 Py_XDECREF(stub);
1180 }
1181
~TensorData()1182 ~TensorData() override { data_ptr_.release(); }
1183
IsBaseShapePtr(const TensorData & other) const1184 bool IsBaseShapePtr(const TensorData &other) const {
1185 return (other.base_shape_ptr_ == nullptr && base_shape_ptr_ == other.base_shape_ptr_) ||
1186 (base_shape_ptr_ != nullptr && other.base_shape_ptr_ != nullptr &&
1187 *(other.base_shape_ptr_) == *(base_shape_ptr_));
1188 }
1189
IsCastDtype(const TensorData & other) const1190 bool IsCastDtype(const TensorData &other) const {
1191 return (other.cast_dtype_ == nullptr && cast_dtype_ == nullptr) ||
1192 (other.cast_dtype_ != nullptr && cast_dtype_ != nullptr && *cast_dtype_ == *(other.cast_dtype_));
1193 }
1194
operator ==(const ItemData & obj) const1195 bool operator==(const ItemData &obj) const override {
1196 if (!ItemData::operator==(obj)) {
1197 return false;
1198 }
1199 bool ret = MetaTensorData::operator==(obj);
1200 const TensorData &other = static_cast<const TensorData &>(obj);
1201 if (is_stubtensor_ || other.is_stubtensor_) {
1202 return ret;
1203 }
1204 ret = ret && other.init_flag_ == init_flag_ && other.is_forward_output_ == is_forward_output_ &&
1205 other.graph_output_ == graph_output_ && other.specialized_ == specialized_ && IsBaseShapePtr(other) &&
1206 IsCastDtype(other) && other.compression_type_ == compression_type_ &&
1207 other.quant_params_.size() == quant_params_.size() && other.tensor_name_.compare(tensor_name_) == 0;
1208 if (!ret) {
1209 return ret;
1210 }
1211 for (size_t i = 0; i < quant_params_.size(); ++i) {
1212 if (*(quant_params_[i]) == *(other.quant_params_[i])) {
1213 continue;
1214 } else {
1215 return false;
1216 }
1217 }
1218 if (IsDynamicShape() || other.IsDynamicShape()) {
1219 return true;
1220 } else {
1221 return CheckData(other);
1222 }
1223 }
1224
ToString()1225 std::string ToString() override {
1226 std::string tensor = ToStringIntern();
1227 return DESC(tensor) + DESC_END;
1228 }
1229
1230 protected:
ToStringIntern()1231 std::string ToStringIntern() override {
1232 std::string ret = MetaTensorData::ToStringIntern();
1233 ret += DESC_STRING(is_forward_output_) + DESC_STRING(init_flag_) + DESC_STRING(graph_output_);
1234 ret +=
1235 DESC_TOSTRING(cast_dtype_) + DESC_TOSTRING(base_shape_ptr_) + DESC_STRING(compression_type_) + DESC(tensor_name_);
1236 for (size_t i = 0; i < quant_params_.size(); ++i) {
1237 ret += DESC_INDEX(quant_params_, i);
1238 }
1239 return ret;
1240 }
1241
CheckData(const TensorData & other) const1242 bool CheckData(const TensorData &other) const {
1243 bool ret;
1244 if (specialized_) {
1245 if (data_ptr_ == nullptr || other.data_ptr_ == nullptr) {
1246 ret = data_len_ == other.data_len_;
1247 } else if (data_len_ == other.data_len_) {
1248 ret = memcmp(data_ptr_.get(), other.data_ptr_.get(), data_len_) == 0;
1249 } else {
1250 ret = false;
1251 }
1252 } else {
1253 ret = data_len_ == other.data_len_;
1254 }
1255 return ret;
1256 }
1257
StoreTensor(mindspore::tensor::TensorPtr tensor_ptr)1258 void StoreTensor(mindspore::tensor::TensorPtr tensor_ptr) {
1259 MetaTensorData::StoreTensor(tensor_ptr);
1260 init_flag_ = tensor_ptr->is_init();
1261 is_forward_output_ = tensor_ptr->is_forward_output();
1262 id_ = tensor_ptr->id();
1263 graph_output_ = tensor_ptr->IsGraphOutput();
1264 base_shape_ptr_ = tensor_ptr->base_shape_ptr() == nullptr ? nullptr : tensor_ptr->base_shape_ptr()->Clone();
1265 cast_dtype_ = (tensor_ptr->cast_dtype() == nullptr) ? nullptr : tensor_ptr->cast_dtype()->Clone();
1266 compression_type_ = tensor_ptr->compression_type();
1267 const std::vector<std::shared_ptr<mindspore::QuantizationParam>> &qp = tensor_ptr->quant_params();
1268 tensor_name_ = tensor_ptr->name();
1269 for (auto quant : qp) {
1270 QuantizationParamPtr qptr = std::make_shared<mindspore::QuantizationParam>(quant->quant_algo_name());
1271 quant_params_.push_back(qptr);
1272 qptr->set_attrs(quant->attrs());
1273 }
1274 if (specialized_) {
1275 tensor_ptr->data_sync(true);
1276 auto data = tensor_ptr->data_ptr();
1277 data_len_ = size_t(data->nbytes());
1278 data_ptr_ = std::make_unique<uint8_t[]>(data_len_);
1279 if (data_ptr_ != nullptr) {
1280 memcpy_s(data_ptr_.get(), data_len_, reinterpret_cast<uint8_t *>(data->data()), data_len_);
1281 }
1282 } else {
1283 data_ptr_.reset(nullptr);
1284 data_len_ = size_t(tensor_ptr->data_ptr()->nbytes());
1285 }
1286 }
1287
SubInfo(InfoPack * info)1288 void SubInfo(InfoPack *info) override {
1289 MetaTensorData::SubInfo(info);
1290 (*info) << is_forward_output_ << init_flag_ << graph_output_ << cast_dtype_ << base_shape_ptr_
1291 << uint8_t(compression_type_) << tensor_name_;
1292 (*info) << uint64_t(quant_params_.size());
1293 for (auto qp : quant_params_) {
1294 (*info) << qp;
1295 }
1296 }
1297
1298 bool init_flag_;
1299 bool is_forward_output_;
1300 std::unique_ptr<uint8_t[]> data_ptr_;
1301 size_t data_len_;
1302 std::string id_;
1303 bool graph_output_;
1304 mindspore::abstract::BaseShapePtr base_shape_ptr_;
1305 mindspore::TypePtr cast_dtype_;
1306 mindspore::TensorCompressionType compression_type_;
1307 std::vector<QuantizationParamPtr> quant_params_;
1308 std::string tensor_name_;
1309 };
1310 using TensorDataPtr = std::shared_ptr<TensorData>;
1311
Equal(const TensorDataPtr & a,const TensorDataPtr & b,int recurseDepth)1312 static bool Equal(const TensorDataPtr &a, const TensorDataPtr &b, int recurseDepth) {
1313 if (recurseDepth > 0) {
1314 if (a == nullptr && b == nullptr) {
1315 return true;
1316 } else if (a != nullptr && b != nullptr) {
1317 return *a == *b;
1318 } else {
1319 return false;
1320 }
1321 } else {
1322 return (a == nullptr) == (b == nullptr);
1323 }
1324 }
1325
CreateTensorData(mindspore::tensor::TensorPtr tensor,bool needSpecialize,int recurseDepth)1326 static TensorDataPtr CreateTensorData(mindspore::tensor::TensorPtr tensor, bool needSpecialize, int recurseDepth) {
1327 if (recurseDepth > 0) {
1328 return (tensor == nullptr) ? nullptr : std::make_shared<TensorData>(tensor, needSpecialize, recurseDepth);
1329 } else {
1330 return nullptr;
1331 }
1332 }
1333
1334 class MapTensorData : public TensorData {
1335 public:
MapTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1336 MapTensorData(PyObject *obj, bool needSpecialize, int recurseDepth) : TensorData(obj, needSpecialize, recurseDepth) {
1337 tp_ = ItemType::MapTensor;
1338 needSpecialize = specialized_;
1339 auto pyObj = py::cast<py::object>(obj);
1340 auto tensor_ptr = pyObj.cast<mindspore::tensor::MapTensorPtr>();
1341 key_dtype_ = tensor_ptr->key_dtype();
1342 if (tensor_ptr->key_tensor() != nullptr) {
1343 key_shape_ = tensor_ptr->key_tensor()->shape();
1344 }
1345 default_value_ = tensor_ptr->default_value() == nullptr ? nullptr : tensor_ptr->default_value()->type()->Clone();
1346 permit_filter_value_ =
1347 tensor_ptr->permit_filter_value() == nullptr ? nullptr : tensor_ptr->permit_filter_value()->type()->Clone();
1348 evict_filter_value_ =
1349 tensor_ptr->evict_filter_value() == nullptr ? nullptr : tensor_ptr->evict_filter_value()->type()->Clone();
1350 value_shape_ = tensor_ptr->value_shape();
1351 key_tensor_ = CreateTensorData(tensor_ptr->key_tensor(), needSpecialize, recurseDepth);
1352 value_tensor_ = CreateTensorData(tensor_ptr->value_tensor(), needSpecialize, recurseDepth);
1353 status_tensor_ = CreateTensorData(tensor_ptr->status_tensor(), needSpecialize, recurseDepth);
1354 }
1355
IsPermitFilterValue(const MapTensorData & other) const1356 bool IsPermitFilterValue(const MapTensorData &other) const {
1357 return (other.default_value_ == nullptr && default_value_ == nullptr) ||
1358 (other.default_value_ != nullptr && default_value_ != nullptr && *default_value_ == *(other.default_value_));
1359 }
1360
IsDefaultValue(const MapTensorData & other) const1361 bool IsDefaultValue(const MapTensorData &other) const {
1362 return (other.default_value_ == nullptr && default_value_ == nullptr) ||
1363 (other.default_value_ != nullptr && default_value_ != nullptr && *default_value_ == *(other.default_value_));
1364 }
1365
IsEvictFilterValue(const MapTensorData & other) const1366 bool IsEvictFilterValue(const MapTensorData &other) const {
1367 return (other.evict_filter_value_ == nullptr && evict_filter_value_ == nullptr) ||
1368 (other.evict_filter_value_ != nullptr && evict_filter_value_ != nullptr &&
1369 *evict_filter_value_ == *(other.evict_filter_value_));
1370 }
1371
operator ==(const ItemData & obj) const1372 bool operator==(const ItemData &obj) const override {
1373 if (!ItemData::operator==(obj)) {
1374 return false;
1375 }
1376 const MapTensorData &other = dynamic_cast<const MapTensorData &>(obj);
1377 bool ret = TensorData::operator==(obj);
1378 return ret && other.key_dtype_ == key_dtype_ && other.key_shape_ == key_shape_ && IsDefaultValue(other) &&
1379 IsPermitFilterValue(other) && IsEvictFilterValue(other) && value_shape_ == other.value_shape_ &&
1380 Equal(key_tensor_, other.key_tensor_, recurseDepth_) &&
1381 Equal(value_tensor_, other.value_tensor_, recurseDepth_) &&
1382 Equal(status_tensor_, other.status_tensor_, recurseDepth_);
1383 }
1384
ToString()1385 std::string ToString() override {
1386 std::string map_tensor = ToStringIntern();
1387 return DESC(map_tensor) + DESC_END;
1388 }
1389
1390 protected:
ToStringIntern()1391 std::string ToStringIntern() override {
1392 return TensorData::ToStringIntern() + DESC_STRING(key_dtype_) + DESC_TOSTRING(default_value_) +
1393 DESC_TOSTRING(permit_filter_value_) + DESC_TOSTRING(evict_filter_value_) + DESC_TOSTRING(key_tensor_) +
1394 DESC_TOSTRING(value_tensor_) + DESC_TOSTRING(status_tensor_) + DESC_END;
1395 }
1396
SubInfo(InfoPack * info)1397 void SubInfo(InfoPack *info) override {
1398 TensorData::SubInfo(info);
1399 (*info) << key_dtype_ << default_value_ << permit_filter_value_ << evict_filter_value_ << value_shape_
1400 << key_tensor_->Info() << value_tensor_->Info() << status_tensor_->Info();
1401 }
1402
1403 mindspore::TypeId key_dtype_;
1404 ShapeVector key_shape_;
1405 TypePtr default_value_;
1406 TypePtr permit_filter_value_;
1407 TypePtr evict_filter_value_;
1408 ShapeVector value_shape_;
1409 TensorDataPtr key_tensor_;
1410 TensorDataPtr value_tensor_;
1411 TensorDataPtr status_tensor_;
1412 };
1413
1414 class RowTensorData : public ItemData {
1415 public:
RowTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1416 RowTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
1417 : ItemData(ItemType::RowTensor, needSpecialize, recurseDepth) {
1418 auto pyObj = py::cast<py::object>(obj);
1419 auto tensor_ptr = pyObj.cast<mindspore::tensor::RowTensorPtr>();
1420 data_type_ = tensor_ptr->data_type();
1421 shape_ = tensor_ptr->shape();
1422 indices_ = CreateTensorData(tensor_ptr->GetIndices(), needSpecialize, recurseDepth);
1423 values_ = CreateTensorData(tensor_ptr->GetValues(), needSpecialize, recurseDepth);
1424 }
1425
operator ==(const ItemData & obj) const1426 bool operator==(const ItemData &obj) const override {
1427 if (ItemData::operator==(obj)) {
1428 const RowTensorData &other = static_cast<const RowTensorData &>(obj);
1429 return other.data_type_ == data_type_ && other.shape_ == shape_ &&
1430 Equal(indices_, other.indices_, recurseDepth_) && Equal(values_, other.values_, recurseDepth_);
1431 }
1432 return false;
1433 }
1434
ToString()1435 std::string ToString() override {
1436 std::string row_tensor = DESC_TOSTRING(indices_) + DESC_TOSTRING(values_) + DESC_STRING(data_type_);
1437 return DESC(row_tensor) + DESC_END;
1438 }
1439
1440 protected:
SubInfo(InfoPack * info)1441 void SubInfo(InfoPack *info) override { (*info) << indices_->Info() << values_->Info() << data_type_ << shape_; }
1442 TensorDataPtr indices_;
1443 TensorDataPtr values_;
1444 mindspore::TypeId data_type_;
1445 ShapeVector shape_;
1446 };
1447
1448 class COOTensorData : public ItemData {
1449 public:
COOTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1450 COOTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
1451 : ItemData(ItemType::COOTensor, needSpecialize, recurseDepth) {
1452 auto pyObj = py::cast<py::object>(obj);
1453 auto tensor_ptr = pyObj.cast<mindspore::tensor::COOTensorPtr>();
1454 data_type_ = tensor_ptr->data_type();
1455 shape_ = tensor_ptr->shape();
1456 indices_ = CreateTensorData(tensor_ptr->GetIndices(), needSpecialize, recurseDepth);
1457 values_ = CreateTensorData(tensor_ptr->GetValues(), needSpecialize, recurseDepth);
1458 }
1459
operator ==(const ItemData & obj) const1460 bool operator==(const ItemData &obj) const override {
1461 if (ItemData::operator==(obj)) {
1462 const COOTensorData &other = static_cast<const COOTensorData &>(obj);
1463 return other.data_type_ == data_type_ && other.shape_ == shape_ &&
1464 Equal(indices_, other.indices_, recurseDepth_) && Equal(values_, other.values_, recurseDepth_);
1465 }
1466 return false;
1467 }
1468
ToString()1469 std::string ToString() override {
1470 std::string coo_tensor = DESC_TOSTRING(indices_) + DESC_TOSTRING(values_) + DESC_STRING(data_type_);
1471 return DESC(coo_tensor) + DESC_END;
1472 }
1473
1474 protected:
SubInfo(InfoPack * info)1475 void SubInfo(InfoPack *info) override { (*info) << indices_->Info() << values_->Info() << data_type_ << shape_; }
1476 TensorDataPtr indices_;
1477 TensorDataPtr values_;
1478 mindspore::TypeId data_type_;
1479 ShapeVector shape_;
1480 };
1481
1482 class CSRTensorData : public ItemData {
1483 public:
CSRTensorData(PyObject * obj,bool needSpecialize,int recurseDepth)1484 CSRTensorData(PyObject *obj, bool needSpecialize, int recurseDepth)
1485 : ItemData(ItemType::CSRTensor, needSpecialize, recurseDepth) {
1486 auto pyObj = py::cast<py::object>(obj);
1487 auto tensor_ptr = pyObj.cast<mindspore::tensor::CSRTensorPtr>();
1488 data_type_ = tensor_ptr->data_type();
1489 shape_ = tensor_ptr->shape();
1490 indices_ = CreateTensorData(tensor_ptr->GetIndices(), needSpecialize, recurseDepth);
1491 values_ = CreateTensorData(tensor_ptr->GetValues(), needSpecialize, recurseDepth);
1492 indptr_ = CreateTensorData(tensor_ptr->GetIndptr(), needSpecialize, recurseDepth);
1493 }
1494
operator ==(const ItemData & obj) const1495 bool operator==(const ItemData &obj) const override {
1496 if (ItemData::operator==(obj)) {
1497 const CSRTensorData &other = static_cast<const CSRTensorData &>(obj);
1498 return other.data_type_ == data_type_ && other.shape_ == shape_ &&
1499 Equal(indices_, other.indices_, recurseDepth_) && Equal(values_, other.values_, recurseDepth_) &&
1500 Equal(indptr_, other.indptr_, recurseDepth_);
1501 }
1502 return false;
1503 }
1504
ToString()1505 std::string ToString() override {
1506 std::string csr_tensor =
1507 DESC_TOSTRING(indices_) + DESC_TOSTRING(values_) + DESC_TOSTRING(indptr_) + DESC_STRING(data_type_);
1508 return DESC(csr_tensor) + DESC_END;
1509 }
1510
1511 protected:
SubInfo(InfoPack * info)1512 void SubInfo(InfoPack *info) override {
1513 (*info) << indices_->Info() << values_->Info() << indptr_->Info() << data_type_ << shape_;
1514 }
1515 TensorDataPtr indices_;
1516 TensorDataPtr values_;
1517 TensorDataPtr indptr_;
1518 mindspore::TypeId data_type_;
1519 ShapeVector shape_;
1520 };
1521
1522 class TensorDataData : public ItemData {
1523 public:
TensorDataData(PyObject * obj,bool needSpecialize,int recurseDepth)1524 TensorDataData(PyObject *obj, bool needSpecialize, int recurseDepth)
1525 : ItemData(ItemType::Tensordata, needSpecialize, recurseDepth) {
1526 auto pyObj = py::cast<py::object>(obj);
1527 auto data = pyObj.cast<mindspore::tensor::TensorDataPtr>();
1528 size_ = (uint64_t)data->size();
1529 itemsize_ = (uint64_t)data->itemsize();
1530 nbytes_ = (uint64_t)data->nbytes();
1531 ndim_ = (int64_t)data->ndim();
1532 if (specialized_) {
1533 data_ptr_ = std::make_unique<uint8_t[]>(nbytes_);
1534 if (data_ptr_ != nullptr) {
1535 memcpy_s(data_ptr_.get(), nbytes_, reinterpret_cast<uint8_t *>(data->data()), nbytes_);
1536 }
1537 } else {
1538 data_ptr_.reset(nullptr);
1539 }
1540 }
1541
~TensorDataData()1542 ~TensorDataData() override { data_ptr_.release(); }
1543
operator ==(const ItemData & obj) const1544 bool operator==(const ItemData &obj) const override {
1545 if (ItemData::operator==(obj)) {
1546 const TensorDataData &other = static_cast<const TensorDataData &>(obj);
1547 if (specialized_) {
1548 return data_ptr_ != nullptr && other.data_ptr_ != nullptr && nbytes_ == other.nbytes_ &&
1549 memcmp(data_ptr_.get(), other.data_ptr_.get(), nbytes_) == 0;
1550 } else {
1551 return size_ == other.size_ && itemsize_ == other.itemsize_ && nbytes_ == other.nbytes_ &&
1552 ndim_ == other.ndim_ && (data_ptr_ == nullptr) == (other.data_ptr_ == nullptr);
1553 }
1554 }
1555 return false;
1556 }
1557
ToString()1558 std::string ToString() override {
1559 std::string tensor_data = DESC_STRING(size_) + DESC_STRING(itemsize_) + DESC_STRING(nbytes_) + DESC_STRING(ndim_);
1560 return DESC(tensor_data) + DESC_END;
1561 }
1562
1563 protected:
SubInfo(InfoPack * info)1564 void SubInfo(InfoPack *info) override { (*info) << size_ << itemsize_ << nbytes_ << ndim_; }
1565 std::unique_ptr<uint8_t[]> data_ptr_;
1566 uint64_t size_;
1567 uint64_t itemsize_;
1568 uint64_t nbytes_;
1569 int64_t ndim_;
1570 };
1571
1572 class PrimitiveData : public ItemData {
1573 public:
PrimitiveData(PyObject * obj,bool needSpecialize,int recurseDepth)1574 PrimitiveData(PyObject *obj, bool needSpecialize, int recurseDepth)
1575 : ItemData(ItemType::Primitive, needSpecialize, recurseDepth) {
1576 auto pyObj = py::cast<py::object>(obj);
1577 auto data = pyObj.cast<PrimitivePyAdapterPtr>();
1578 py::dict pd = data->GetAttrDict();
1579 auto dct = pd.ptr();
1580 Py_ssize_t pos = 0;
1581 PyObject *key;
1582 PyObject *val;
1583 while (PyDict_Next(dct, &pos, &key, &val)) {
1584 ItemDataPtr k;
1585 ItemDataPtr v;
1586 if (recurseDepth > 0 || needSpecialize) {
1587 k = CreateItem(key, needSpecialize, recurseDepth);
1588 v = CreateItem(val, needSpecialize, recurseDepth);
1589 } else {
1590 k =
1591 CreateItem((key == NULL || key == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(key)), false, false);
1592 v =
1593 CreateItem((val == NULL || val == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(val)), false, false);
1594 }
1595 listK_.push_back(k);
1596 listV_.push_back(v);
1597 }
1598 }
1599
operator ==(const ItemData & obj) const1600 bool operator==(const ItemData &obj) const override {
1601 if (ItemData::operator==(obj)) {
1602 const PrimitiveData &other = static_cast<const PrimitiveData &>(obj);
1603 if (other.listK_.size() == listK_.size() && other.listV_.size() == listV_.size()) {
1604 for (size_t i = 0; i < listK_.size(); ++i) {
1605 if (*(listK_[i]) == *(other.listK_[i]) && *(listV_[i]) == *(other.listV_[i])) {
1606 continue;
1607 } else {
1608 return false;
1609 }
1610 }
1611 return true;
1612 }
1613 }
1614 return false;
1615 }
1616
ToString()1617 std::string ToString() override {
1618 std::string primitive;
1619 for (size_t i = 0; i < listK_.size(); ++i) {
1620 primitive += DESC_ITEM(listK_[i], listV_[i]);
1621 }
1622 return DESC(primitive) + DESC_END;
1623 }
1624
1625 protected:
SubInfo(InfoPack * info)1626 void SubInfo(InfoPack *info) override {
1627 (*info) << uint64_t(listK_.size());
1628 for (auto item : listK_) {
1629 (*info) << item->Info();
1630 }
1631 (*info) << uint64_t(listV_.size());
1632 for (auto item : listV_) {
1633 (*info) << item->Info();
1634 }
1635 }
1636 std::vector<ItemDataPtr> listK_;
1637 std::vector<ItemDataPtr> listV_;
1638 };
1639
1640 class CellData : public ItemData {
1641 public:
/**
 * Records every attribute of the cell's instance `__dict__` as a key/value pair of nested guard items.
 * Without recursion budget and without specialization only the Python types of keys and values are recorded.
 * Best effort: if the dict or its items cannot be read, recording stops or the entry is skipped, and the
 * Python error raised by the failed call is cleared rather than left pending.
 */
CellData(PyObject *obj, bool needSpecialize, int recurseDepth)
    : ItemData(ItemType::Cell, needSpecialize, recurseDepth) {
  auto pyObj = py::cast<py::object>(obj);
  // Presumably kept for its type check (pybind11 casts throw on mismatch); the pointer itself is unused.
  auto cell = pyObj.cast<mindspore::CellPtr>();
  (void)cell;
  // Owned references live in py::object so nothing leaks on early return or if CreateItem throws.
  py::object ns = py::reinterpret_steal<py::object>(PyObject_GetAttrString(obj, "__dict__"));
  if (!ns) {
    PyErr_Clear();
    return;
  }
  py::object items = py::reinterpret_steal<py::object>(PyMapping_Items(ns.ptr()));
  if (!items) {
    PyErr_Clear();
    return;
  }
  for (Py_ssize_t pos = 0; pos < PyList_GET_SIZE(items.ptr()); pos++) {
    py::object it = py::reinterpret_steal<py::object>(
      PySequence_Fast(PyList_GET_ITEM(items.ptr(), pos), "items() returned non-iterable"));
    if (!it) {
      PyErr_Clear();
      continue;
    }
    if (PySequence_Fast_GET_SIZE(it.ptr()) != 2) {
      continue;
    }
    PyObject *key = PySequence_Fast_GET_ITEM(it.ptr(), 0);
    PyObject *val = PySequence_Fast_GET_ITEM(it.ptr(), 1);
    ItemDataPtr k;
    ItemDataPtr v;
    if (recurseDepth > 0 || needSpecialize) {
      k = CreateItem(key, needSpecialize, recurseDepth);
      v = CreateItem(val, needSpecialize, recurseDepth);
    } else {
      k = CreateItem((key == NULL || key == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(key)), false, false);
      v = CreateItem((val == NULL || val == Py_None) ? NULL : reinterpret_cast<PyObject *>(Py_TYPE(val)), false, false);
    }
    listK_.push_back(k);
    listV_.push_back(v);
  }
}
1682
operator ==(const ItemData & obj) const1683 bool operator==(const ItemData &obj) const override {
1684 if (ItemData::operator==(obj)) {
1685 const CellData &other = static_cast<const CellData &>(obj);
1686 for (size_t i = 0; i < listK_.size(); ++i) {
1687 if (*(listK_[i]) == *(other.listK_[i]) && *(listV_[i]) == *(other.listV_[i])) {
1688 continue;
1689 } else {
1690 return false;
1691 }
1692 }
1693 return true;
1694 }
1695 return false;
1696 }
1697
ToString()1698 std::string ToString() override {
1699 std::string cell;
1700 for (size_t i = 0; i < listK_.size(); ++i) {
1701 cell += DESC_ITEM(listK_[i], listV_[i]);
1702 }
1703 return DESC(cell) + DESC_END;
1704 }
1705
1706 protected:
SubInfo(InfoPack * info)1707 void SubInfo(InfoPack *info) override {
1708 (*info) << uint64_t(listK_.size());
1709 for (auto item : listK_) {
1710 (*info) << item->Info();
1711 }
1712 (*info) << uint64_t(listV_.size());
1713 for (auto item : listV_) {
1714 (*info) << item->Info();
1715 }
1716 }
1717 std::vector<ItemDataPtr> listK_;
1718 std::vector<ItemDataPtr> listV_;
1719 };
1720
1721 class UnknownData : public ItemData {
1722 public:
UnknownData(PyObject * obj,bool needSpecialize,int recurseDepth)1723 UnknownData(PyObject *obj, bool needSpecialize, int recurseDepth)
1724 : ItemData(ItemType::PyUnknown, needSpecialize, recurseDepth) {
1725 refId_ = obj;
1726 }
1727
operator ==(const ItemData & obj) const1728 bool operator==(const ItemData &obj) const override {
1729 if (ItemData::operator==(obj)) {
1730 return refId_ == (static_cast<const UnknownData &>(obj)).refId_;
1731 }
1732 return false;
1733 }
1734
ToString()1735 std::string ToString() override {
1736 std::string ret = "unknown";
1737 return ret + ItemData::ToString();
1738 }
1739
1740 protected:
1741 PyObject *refId_;
1742 };
1743
ListData(PyObject * obj,bool needSpecialize,int recurseDepth)1744 ListData::ListData(PyObject *obj, bool needSpecialize, int recurseDepth)
1745 : ItemData(ItemType::PyList, needSpecialize, recurseDepth) {
1746 if (PyList_Check(obj)) {
1747 InitList(obj, needSpecialize, recurseDepth);
1748 } else if (PyTuple_Check(obj)) {
1749 InitTuple(obj, needSpecialize, recurseDepth);
1750 } else if (PySet_Check(obj)) {
1751 InitSet(obj, needSpecialize, recurseDepth);
1752 } else if (PyFrozenSet_Check(obj)) {
1753 InitFrozenSet(obj, needSpecialize, recurseDepth);
1754 }
1755 if (needSpecialize) {
1756 return; // value check exactly
1757 }
1758 for (const auto &data : listVar_) {
1759 if (data->GetItemType() == ItemType::PyType) {
1760 static_cast<TypeData *>(data.get())->set_ambiguous_tensor_type(true);
1761 }
1762 }
1763 }
1764
1765 using CheckPyObjectFunc = bool (*)(PyObject *obj);
1766 using CreatePyObjectFunc = ItemDataPtr (*)(PyObject *obj, bool need_specialize, int recurse_depth);
1767 template <typename T>
CreatePyData(PyObject * obj,bool need_specialize,int recurse_depth)1768 ItemDataPtr CreatePyData(PyObject *obj, bool need_specialize, int recurse_depth) {
1769 return std::make_shared<T>(obj, need_specialize, recurse_depth);
1770 }
1771 template <typename T>
CreateMutablePyData(PyObject * obj,bool need_specialize,int recurse_depth)1772 ItemDataPtr CreateMutablePyData(PyObject *obj, bool need_specialize, int recurse_depth) {
1773 return std::make_shared<T>(obj, false, recurse_depth);
1774 }
CheckMetaTensorObject(PyObject * obj)1775 static bool CheckMetaTensorObject(PyObject *obj) {
1776 return py::isinstance<mindspore::tensor::MetaTensor>(obj) || IsStubTensor(py::cast<py::object>(obj));
1777 }
CheckTensorObject(PyObject * obj)1778 static bool CheckTensorObject(PyObject *obj) { return py::isinstance<mindspore::tensor::Tensor>(obj); }
CheckDictKeyValueItemObject(PyObject * obj)1779 static bool CheckDictKeyValueItemObject(PyObject *obj) {
1780 return !!PyDict_Check(obj) || !!PyDictKeys_Check(obj) || !!PyDictValues_Check(obj) || !!PyDictItems_Check(obj);
1781 }
1782 static const std::vector<std::pair<CheckPyObjectFunc, CreatePyObjectFunc>> kFuncPyObjectConverter = {
__anonb719df630302() 1783 {[](PyObject *obj) -> bool { return PyLong_Check(obj) && !PyBool_Check(obj); }, CreatePyData<IntData>},
__anonb719df630402() 1784 {[](PyObject *obj) -> bool { return !!PyFloat_Check(obj); }, CreatePyData<FloatData>},
__anonb719df630502() 1785 {[](PyObject *obj) -> bool { return !!PyBool_Check(obj); }, CreatePyData<BoolData>},
__anonb719df630602() 1786 {[](PyObject *obj) -> bool { return !!PyBytes_Check(obj); }, CreatePyData<BytesData>},
__anonb719df630702() 1787 {[](PyObject *obj) -> bool { return !!PyUnicode_Check(obj); }, CreatePyData<StringData>},
__anonb719df630802() 1788 {[](PyObject *obj) -> bool { return !!PyList_Check(obj); }, CreatePyData<ListData>},
__anonb719df630902() 1789 {[](PyObject *obj) -> bool { return !!PyTuple_Check(obj); }, CreatePyData<ListData>},
__anonb719df630a02() 1790 {[](PyObject *obj) -> bool { return !!PySet_Check(obj); }, CreatePyData<ListData>},
__anonb719df630b02() 1791 {[](PyObject *obj) -> bool { return !!PyFrozenSet_Check(obj); }, CreatePyData<ListData>},
__anonb719df630c02() 1792 {[](PyObject *obj) -> bool { return CheckDictKeyValueItemObject(obj); }, CreatePyData<DictData>},
__anonb719df630d02() 1793 {[](PyObject *obj) -> bool { return !!PyComplex_Check(obj); }, CreatePyData<ComplexData>},
__anonb719df630e02() 1794 {[](PyObject *obj) -> bool { return !!PySlice_Check(obj); }, CreatePyData<SliceData>},
__anonb719df630f02() 1795 {[](PyObject *obj) -> bool { return !!PyFunction_Check(obj); }, CreatePyData<FunctionData>},
__anonb719df631002() 1796 {[](PyObject *obj) -> bool { return !!PyMethod_Check(obj); }, CreatePyData<MethodData>},
__anonb719df631102() 1797 {[](PyObject *obj) -> bool { return !!PyInstanceMethod_Check(obj); }, CreatePyData<InstanceMethodData>},
__anonb719df631202() 1798 {[](PyObject *obj) -> bool { return !!PyType_Check(obj); }, CreatePyData<TypeData>},
__anonb719df631302() 1799 {[](PyObject *obj) -> bool { return py::isinstance<py::array>(obj); }, CreatePyData<NumpyData>},
__anonb719df631402() 1800 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::Type>(obj); }, CreatePyData<TensorTypeData>},
__anonb719df631502() 1801 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::MapTensor>(obj); },
1802 CreatePyData<MapTensorData>},
__anonb719df631602() 1803 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::ParamInfo>(obj); }, CreatePyData<ParamInfoData>},
__anonb719df631702() 1804 {[](PyObject *obj) -> bool { return CheckMetaTensorObject(obj); }, CreatePyData<MetaTensorData>},
__anonb719df631802() 1805 {[](PyObject *obj) -> bool { return CheckTensorObject(obj); }, CreatePyData<TensorData>},
__anonb719df631902() 1806 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::TensorData>(obj); },
1807 CreatePyData<TensorDataData>},
__anonb719df631a02() 1808 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::PrimitivePyAdapter>(obj); },
1809 CreatePyData<PrimitiveData>},
__anonb719df631b02() 1810 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::Cell>(obj); }, CreatePyData<CellData>},
__anonb719df631c02() 1811 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::RowTensor>(obj); },
1812 CreateMutablePyData<RowTensorData>},
__anonb719df631d02() 1813 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::COOTensor>(obj); },
1814 CreateMutablePyData<COOTensorData>},
__anonb719df631e02() 1815 {[](PyObject *obj) -> bool { return py::isinstance<mindspore::tensor::CSRTensor>(obj); },
1816 CreateMutablePyData<CSRTensorData>},
1817 };
1818
CreateData(PyObject * obj,bool need_specialize,int recurse_depth)1819 static ItemDataPtr CreateData(PyObject *obj, bool need_specialize, int recurse_depth) {
1820 auto tar =
1821 std::find_if(kFuncPyObjectConverter.begin(), kFuncPyObjectConverter.end(),
1822 [obj](const std::pair<CheckPyObjectFunc, CreatePyObjectFunc> &func) { return func.first(obj); });
1823 if (tar != kFuncPyObjectConverter.end()) {
1824 return tar->second(obj, need_specialize, recurse_depth);
1825 } else {
1826 return std::make_shared<UnknownData>(obj, need_specialize, recurse_depth);
1827 }
1828 }
1829
CreateItem(PyObject * obj,bool need_specialize,int recurse_depth)1830 static ItemDataPtr CreateItem(PyObject *obj, bool need_specialize, int recurse_depth) {
1831 ReprRecursionScope scope(obj);
1832 if (scope.ReEnterOrError()) {
1833 return std::make_shared<ItemData>(ItemType::PyNull, need_specialize, recurse_depth);
1834 }
1835 if (recurse_depth < -1) {
1836 if (obj != NULL && obj != Py_None) {
1837 PyObject *py_type;
1838 py::object py_obj = py::reinterpret_borrow<py::object>(obj);
1839 if (IsStubTensor(py_obj)) {
1840 py_type = GetMsTensorType();
1841 } else {
1842 py_type = reinterpret_cast<PyObject *>(Py_TYPE(obj));
1843 }
1844 return std::make_shared<TypeData>(py_type, false, 0);
1845 } else {
1846 return std::make_shared<ItemData>(ItemType::PyNull, false, 0);
1847 }
1848 }
1849 recurse_depth -= 1;
1850 ItemDataPtr dp;
1851 if (obj != NULL && obj != Py_None) {
1852 dp = CreateData(obj, need_specialize, recurse_depth);
1853 } else {
1854 dp = std::make_shared<ItemData>(ItemType::PyNull, need_specialize, recurse_depth);
1855 }
1856 return dp;
1857 }
1858
GuardItem(TracePtr tt)1859 GuardItem::GuardItem(TracePtr tt) : var_(tt), type_(GIType::GTUnknown), info_(nullptr) {}
1860
Replace(TracePtr dst,TracePtr src)1861 void GuardItem::Replace(TracePtr dst, TracePtr src) {
1862 if (!var_) {
1863 return;
1864 }
1865 if (*var_ == *src) {
1866 var_ = dst;
1867 } else {
1868 var_->Replace(dst, src);
1869 }
1870 }
1871
Optimize()1872 GuardItemPtr GuardItem::Optimize() {
1873 auto trace = var_->Optimize();
1874 if (trace != nullptr) {
1875 var_ = trace;
1876 info_ = nullptr;
1877 Info();
1878 return shared_from_this();
1879 } else {
1880 return nullptr;
1881 }
1882 }
1883
GetTrace()1884 TracePtr GuardItem::GetTrace() { return var_; }
1885
operator ==(const GuardItem & obj) const1886 bool GuardItem::operator==(const GuardItem &obj) const { return type_ == obj.type_ && *var_ == *(obj.var_); }
1887
// Perf-log stages of one guard item check: first the object is retrieved, then it is compared.
static constexpr int kGuardItemRetrieveStage = 0;
static constexpr int kGuardItemCompareStage = 1;
// Number of stages listed above.
static constexpr int kGuardItemTotalStage = kGuardItemCompareStage + 1;
1891
GuardItemPerfStart(bool enable,int total)1892 static void GuardItemPerfStart(bool enable, int total) {
1893 if (enable) {
1894 OptGuardPerf::GetGuardPerf()->LogItemPerfStart(total);
1895 }
1896 }
1897
GuardItemPerfStage(bool enable,GuardItem * item,int stage)1898 static void GuardItemPerfStage(bool enable, GuardItem *item, int stage) {
1899 if (enable) {
1900 OptGuardPerf::GetGuardPerf()->LogItemPerfEnd(item, stage);
1901 }
1902 }
1903
1904 class EqGuard : public GuardItem {
1905 public:
EqGuard(TracePtr obj,bool needSpecialize,int recurseDepth)1906 EqGuard(TracePtr obj, bool needSpecialize, int recurseDepth)
1907 : GuardItem(obj),
1908 dp_(CreateItem(obj->GetObject(), needSpecialize, recurseDepth)),
1909 specialized_(needSpecialize),
1910 recurse_(recurseDepth) {
1911 type_ = GIType::GTEqual;
1912 last_ = obj->GetObject();
1913 }
1914
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)1915 virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
1916 if (var_->IsConst()) {
1917 return true;
1918 }
1919 if (!var_->IsSpecialized() && var_->GetRelaxCount() > 0 && !RootTrace::Support(var_->GetTraceType()) &&
1920 !specialized_) {
1921 // it just needs to guard inputs instead of leaf node
1922 return true;
1923 }
1924 GuardItemPerfStart(perf, kGuardItemTotalStage);
1925 PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
1926 GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
1927 bool ret = obj == last_ || Check(obj);
1928 GuardItemPerfStage(perf, this, kGuardItemCompareStage);
1929 if (obj != NULL) {
1930 Py_DECREF(obj);
1931 }
1932 return ret;
1933 }
1934
Check(PyObject * obj)1935 virtual bool Check(PyObject *obj) {
1936 ItemDataPtr other = CreateItem(obj, specialized_, recurse_);
1937 return *dp_ == *other;
1938 }
1939
ToString()1940 virtual std::string ToString() {
1941 if (strGuard_.size() > 0) {
1942 return strGuard_;
1943 }
1944 strGuard_ = var_->ToString() + "==" + dp_->ToString();
1945 strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
1946 return strGuard_;
1947 }
1948
Info()1949 virtual const InfoPack &Info() {
1950 if (info_ == nullptr) {
1951 InfoPack info;
1952 info << uint8_t(type_);
1953 info.Begin();
1954 info << var_->Info() << dp_->Info();
1955 info.End();
1956 info_ = std::make_shared<InfoPack>(info);
1957 info_->Update();
1958 }
1959 return *info_;
1960 }
1961
operator ==(const GuardItem & obj) const1962 bool operator==(const GuardItem &obj) const override {
1963 if (GuardItem::operator==(obj)) {
1964 auto other = static_cast<const EqGuard &>(obj);
1965 return specialized_ == other.specialized_ && recurse_ == other.recurse_ && *dp_ == *(other.dp_);
1966 }
1967 return false;
1968 }
1969
MatchDynamicShape(std::shared_ptr<GuardItem> other)1970 bool MatchDynamicShape(std::shared_ptr<GuardItem> other) override {
1971 var_->Detach();
1972 other->GetTrace()->Detach();
1973 if (other->GetType() != GIType::GTEqual || !(*var_ == *(other->GetTrace())) ||
1974 !dp_->MatchDynamicShape((static_cast<EqGuard *>(other.get()))->dp_)) {
1975 return false;
1976 } else {
1977 return true;
1978 }
1979 }
1980
ApplyDynamicShape(PyObject * obj)1981 PyObject *ApplyDynamicShape(PyObject *obj) override {
1982 auto type = dp_->GetItemType();
1983 if (type != ItemType::MetaTensor && type != ItemType::Tensor) {
1984 return nullptr;
1985 }
1986 auto item = (MetaTensorData &)(*dp_);
1987 if (item.IsDynamicShape()) {
1988 return py::cast(item.MakeTensor()).inc_ref().ptr();
1989 } else {
1990 return nullptr;
1991 }
1992 }
1993
1994 protected:
1995 ItemDataPtr dp_;
1996 bool specialized_;
1997 int recurse_;
1998 PyObject *last_;
1999 };
2000
2001 class TypeGuard : public GuardItem {
2002 public:
TypeGuard(TracePtr obj)2003 explicit TypeGuard(TracePtr obj) : GuardItem(obj) {
2004 type_ = GIType::GTType;
2005 is_const_ = false;
2006 if (obj->GetTraceType() == TraceType::Type) {
2007 refType_ = std::dynamic_pointer_cast<TypeTrace>(obj)->GetType();
2008 } else {
2009 refType_ = Py_TYPE(obj->GetObject());
2010 }
2011 if (obj->GetRelaxCount() > 0) {
2012 check_count_ = 0;
2013 } else {
2014 check_count_ = -1;
2015 }
2016 }
2017
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2018 virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2019 if (var_->IsConst() || is_const_) {
2020 return true;
2021 }
2022 GuardItemPerfStart(perf, kGuardItemTotalStage);
2023 PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2024 GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2025 bool ret = Check(obj);
2026 GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2027 if (var_->GetTraceType() != TraceType::Type && obj != NULL) {
2028 Py_DECREF(obj);
2029 }
2030 if (check_count_ >= 0) {
2031 if (!ret) {
2032 check_count_ = -1;
2033 } else {
2034 check_count_++;
2035 if (check_count_ > var_->GetRelaxCount()) {
2036 is_const_ = true;
2037 }
2038 }
2039 }
2040 return ret;
2041 }
2042
Check(PyObject * obj)2043 virtual bool Check(PyObject *obj) {
2044 if (obj == NULL) {
2045 return false;
2046 }
2047 PyTypeObject *tp;
2048 if (var_->GetTraceType() == TraceType::Type) {
2049 tp = reinterpret_cast<PyTypeObject *>(obj);
2050 } else {
2051 tp = Py_TYPE(obj);
2052 }
2053 if (tp != refType_) {
2054 return false;
2055 } else {
2056 return true;
2057 }
2058 }
2059
ToString()2060 std::string ToString() override {
2061 if (strGuard_.size() > 0) {
2062 return strGuard_;
2063 }
2064 if (var_->GetTraceType() == TraceType::Type) {
2065 strGuard_ = var_->ToString() + std::string("==") + refType_->tp_name;
2066 } else {
2067 strGuard_ = std::string("type(") + var_->ToString() + std::string(")==") + refType_->tp_name;
2068 }
2069 strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2070 return strGuard_;
2071 }
2072
Info()2073 virtual const InfoPack &Info() {
2074 if (info_ == nullptr) {
2075 InfoPack info;
2076 info << uint8_t(type_);
2077 info.Begin();
2078 info << var_->Info() << refType_->tp_name;
2079 info.End();
2080 info_ = std::make_shared<InfoPack>(info);
2081 info_->Update();
2082 }
2083 return *info_;
2084 }
2085
operator ==(const GuardItem & obj) const2086 bool operator==(const GuardItem &obj) const override {
2087 if (GuardItem::operator==(obj)) {
2088 return refType_ == (static_cast<const TypeGuard &>(obj)).refType_;
2089 }
2090 return false;
2091 }
2092
2093 protected:
2094 PyTypeObject *refType_;
2095 int check_count_;
2096 bool is_const_;
2097 };
2098
2099 class IdGuard : public GuardItem {
2100 public:
IdGuard(TracePtr obj)2101 explicit IdGuard(TracePtr obj) : GuardItem(obj) {
2102 type_ = GIType::GTId;
2103 refId_ = obj->GetObject();
2104 }
2105
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2106 virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2107 if (var_->IsConst()) {
2108 return true;
2109 }
2110 GuardItemPerfStart(perf, kGuardItemTotalStage);
2111 PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2112 GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2113 bool ret = Check(obj);
2114 GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2115 if (obj != NULL) {
2116 Py_DECREF(obj);
2117 }
2118 return ret;
2119 }
2120
Check(PyObject * obj)2121 virtual bool Check(PyObject *obj) {
2122 bool ret = false;
2123 if (obj == NULL) {
2124 return ret;
2125 }
2126 if (obj != refId_) {
2127 ret = false;
2128 } else {
2129 ret = true;
2130 }
2131 return ret;
2132 }
2133
ToString()2134 std::string ToString() override {
2135 if (strGuard_.size() > 0) {
2136 return strGuard_;
2137 }
2138 strGuard_ = std::string("id(") + var_->ToString() + std::string(")==") + std::to_string((size_t)refId_);
2139 strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2140 return strGuard_;
2141 }
2142
Info()2143 virtual const InfoPack &Info() {
2144 if (info_ == nullptr) {
2145 InfoPack info;
2146 info << uint8_t(type_);
2147 info.Begin();
2148 info << var_->Info() << reinterpret_cast<void *>(refId_);
2149 info.End();
2150 info_ = std::make_shared<InfoPack>(info);
2151 info_->Update();
2152 }
2153 return *info_;
2154 }
2155
operator ==(const GuardItem & obj) const2156 bool operator==(const GuardItem &obj) const override {
2157 if (GuardItem::operator==(obj)) {
2158 return refId_ == (static_cast<const IdGuard &>(obj)).refId_;
2159 }
2160 return false;
2161 }
2162
2163 protected:
2164 PyObject *refId_;
2165 };
2166
2167 class ReprGuard : public GuardItem {
2168 public:
ReprGuard(TracePtr obj)2169 explicit ReprGuard(TracePtr obj) : GuardItem(obj) {
2170 type_ = GIType::GTRepr;
2171 refRepr_ = PyObject_Repr(obj->GetObject());
2172 }
2173
~ReprGuard()2174 virtual ~ReprGuard() { Py_XDECREF(refRepr_); }
2175
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2176 virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2177 if (var_->IsConst()) {
2178 return true;
2179 }
2180 GuardItemPerfStart(perf, kGuardItemTotalStage);
2181 PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2182 GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2183 bool ret = Check(obj);
2184 GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2185 if (obj != nullptr) {
2186 Py_DECREF(obj);
2187 }
2188 return ret;
2189 }
2190
Check(PyObject * obj)2191 virtual bool Check(PyObject *obj) {
2192 bool ret = false;
2193 if (obj == nullptr) {
2194 return ret;
2195 }
2196 PyObject *repr_flag = GetAttrReprCacheStr();
2197 PyObject *repr;
2198 bool from_cache = false;
2199 if (PyObject_HasAttr(obj, repr_flag)) {
2200 from_cache = true;
2201 repr = PyObject_GetAttr(obj, repr_flag);
2202 } else {
2203 repr = PyObject_Repr(obj);
2204 PyObject_SetAttr(obj, repr_flag, repr);
2205 }
2206 if (PyUnicode_Compare(repr, refRepr_)) {
2207 // Inplace operation may change the object without clearing the cache of guard.
2208 if (from_cache && !PyUnicode_Compare(PyObject_Repr(obj), refRepr_)) {
2209 ret = true;
2210 } else {
2211 ret = false;
2212 }
2213 } else {
2214 ret = true;
2215 }
2216 Py_XDECREF(repr);
2217 return ret;
2218 }
2219
ToString()2220 std::string ToString() override {
2221 if (strGuard_.size() > 0) {
2222 return strGuard_;
2223 }
2224 strGuard_ = std::string(PyUnicode_AsUTF8(refRepr_));
2225 strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2226 return strGuard_;
2227 }
2228
operator ==(const GuardItem & obj) const2229 bool operator==(const GuardItem &obj) const override {
2230 if (GuardItem::operator==(obj)) {
2231 return refRepr_ == (static_cast<const ReprGuard &>(obj)).refRepr_;
2232 }
2233 return false;
2234 }
2235
2236 protected:
Info()2237 virtual const InfoPack &Info() {
2238 if (info_ == nullptr) {
2239 InfoPack info;
2240 info << uint8_t(type_);
2241 info.Begin();
2242 info << std::string(PyUnicode_AsUTF8(refRepr_));
2243 info.End();
2244 info_ = std::make_shared<InfoPack>(info);
2245 info_->Update();
2246 }
2247 return *info_;
2248 }
2249 PyObject *refRepr_;
2250 };
2251
2252 class AttrGuard : public GuardItem {
2253 public:
AttrGuard(TracePtr pObj)2254 explicit AttrGuard(TracePtr pObj) : GuardItem(pObj) {
2255 type_ = GIType::GTAttr;
2256 AttrTracePtr t = std::dynamic_pointer_cast<AttrTrace>(pObj);
2257 PyObject *obj = t->GetOrigin()->GetObject();
2258 nameAttr_ = t->GetAttribute();
2259 if (PyObject_HasAttrString(obj, nameAttr_.c_str()) != 0) {
2260 hasAttr_ = true;
2261 } else {
2262 hasAttr_ = false;
2263 bool is_dict = PyDict_CheckExact(obj);
2264 PyObject *itemName = PyUnicode_FromString(nameAttr_.c_str());
2265 PyObject *attr = NULL;
2266 if (is_dict) {
2267 attr = PyDict_GetItem(obj, itemName);
2268 if (attr != NULL) {
2269 Py_INCREF(attr);
2270 }
2271 } else if (PyMapping_Check(obj) || PySequence_Check(obj)) {
2272 attr = PyObject_GetItem(obj, itemName);
2273 }
2274 hasAttr_ = attr != NULL;
2275 Py_DECREF(itemName);
2276 if (attr != NULL) {
2277 Py_DECREF(attr);
2278 }
2279 }
2280 }
2281
2282 ~AttrGuard() = default;
2283
Check(const PyFrameObject * frame,std::map<size_t,PyObject * > * cache,bool perf)2284 virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache, bool perf) {
2285 if (var_->IsConst()) {
2286 return true;
2287 }
2288 GuardItemPerfStart(perf, kGuardItemTotalStage);
2289 PyObject *obj = GetObjectFromTrace(frame, var_, cache, perf);
2290 GuardItemPerfStage(perf, this, kGuardItemRetrieveStage);
2291 bool ret = CheckIntern(obj);
2292 GuardItemPerfStage(perf, this, kGuardItemCompareStage);
2293 if (obj != NULL) {
2294 Py_DECREF(obj);
2295 }
2296 return ret;
2297 }
2298
Check(PyObject * obj)2299 virtual bool Check(PyObject *obj) {
2300 bool ret;
2301 if (PyObject_HasAttrString(obj, nameAttr_.c_str()) != 0) {
2302 ret = hasAttr_;
2303 } else {
2304 bool is_dict = PyDict_CheckExact(obj);
2305 PyObject *itemName = PyUnicode_FromString(nameAttr_.c_str());
2306 PyObject *attr = NULL;
2307 if (is_dict) {
2308 attr = PyDict_GetItem(obj, itemName);
2309 if (attr != NULL) {
2310 Py_INCREF(attr);
2311 }
2312 } else if (PyMapping_Check(obj) || PySequence_Check(obj)) {
2313 attr = PyObject_GetItem(obj, itemName);
2314 }
2315 ret = CheckIntern(attr);
2316 Py_DECREF(itemName);
2317 if (attr != NULL) {
2318 Py_DECREF(attr);
2319 }
2320 }
2321 return ret;
2322 }
2323
CheckIntern(PyObject * obj)2324 virtual bool CheckIntern(PyObject *obj) {
2325 bool ret;
2326 if ((obj == NULL && !hasAttr_) || (obj != NULL && hasAttr_)) {
2327 ret = true;
2328 } else {
2329 ret = false;
2330 }
2331 return ret;
2332 }
2333
ToString()2334 virtual std::string ToString() {
2335 if (strGuard_.size() > 0) {
2336 return strGuard_;
2337 }
2338 strGuard_ = std::string("exist(") + var_->ToString() + std::string(".") + nameAttr_ +
2339 "==" + std::to_string(hasAttr_) + std::string(")");
2340 strGuard_ = std::regex_replace(strGuard_, std::regex("(\n)"), "");
2341 return strGuard_;
2342 }
2343
operator ==(const GuardItem & obj) const2344 bool operator==(const GuardItem &obj) const override {
2345 if (GuardItem::operator==(obj)) {
2346 return hasAttr_ == (static_cast<const AttrGuard &>(obj)).hasAttr_ &&
2347 nameAttr_ == (static_cast<const AttrGuard &>(obj)).nameAttr_;
2348 }
2349 return false;
2350 }
2351
2352 protected:
Info()2353 virtual const InfoPack &Info() {
2354 if (info_ == nullptr) {
2355 InfoPack info;
2356 info << uint8_t(type_);
2357 info.Begin();
2358 info << var_->Info() << nameAttr_ << hasAttr_;
2359 info.End();
2360 info_ = std::make_shared<InfoPack>(info);
2361 info_->Update();
2362 }
2363 return *info_;
2364 }
2365 bool hasAttr_;
2366 std::string nameAttr_;
2367 };
2368
GuardEqual(TracePtr obj,bool needSpecialize,int recurseDepth)2369 GuardItemPtr GuardEqual(TracePtr obj, bool needSpecialize, int recurseDepth) {
2370 return std::make_shared<EqGuard>(obj, needSpecialize, recurseDepth);
2371 }
2372
GuardType(TracePtr obj)2373 GuardItemPtr GuardType(TracePtr obj) { return std::make_shared<TypeGuard>(obj); }
2374
GuardId(TracePtr obj)2375 GuardItemPtr GuardId(TracePtr obj) {
2376 auto py_obj = obj->GetObject();
2377 auto pyObj = py::cast<py::object>(obj->GetObject());
2378 bool is_param = py::hasattr(pyObj, "__parameter__") && py::isinstance<tensor::MetaTensor>(pyObj);
2379 if (!is_param && (IsStubTensor(pyObj) || py::isinstance<mindspore::tensor::Tensor>(py_obj))) {
2380 return GuardEqual(obj, false, INT_MAX);
2381 } else {
2382 return std::make_shared<IdGuard>(obj);
2383 }
2384 }
2385
GuardRepr(TracePtr obj)2386 GuardItemPtr GuardRepr(TracePtr obj) { return std::make_shared<ReprGuard>(obj); }
2387
GuardAttr(TracePtr obj)2388 GuardItemPtr GuardAttr(TracePtr obj) {
2389 if (obj->GetTraceType() != TraceType::Attr) {
2390 return nullptr;
2391 } else {
2392 return std::make_shared<AttrGuard>(obj);
2393 }
2394 }
2395
IsPyObjectEqual(PyObject * src,PyObject * dst)2396 bool IsPyObjectEqual(PyObject *src, PyObject *dst) {
2397 if (src == dst) {
2398 return true;
2399 }
2400 ItemDataPtr src_item = CreateItem(src, true, INT_MAX);
2401 ItemDataPtr dst_item = CreateItem(dst, true, INT_MAX);
2402 return *src_item == *dst_item;
2403 }
2404
2405 static PyObject *g_ms_module = nullptr;
2406 static PyObject *g_ms_type = nullptr;
2407 static PyObject *g_tensor_type = nullptr;
2408
InitMsModule()2409 static bool InitMsModule() {
2410 if (g_ms_module == nullptr) {
2411 g_ms_module = PyImport_ImportModule("mindspore");
2412 }
2413 return g_ms_module != nullptr && g_ms_module != Py_None;
2414 }
2415
InitMsType()2416 static bool InitMsType() {
2417 if (g_ms_type == NULL) {
2418 g_ms_type = PyImport_ImportModule("mindspore.common.dtype");
2419 }
2420 return g_ms_type != NULL && g_ms_type != Py_None;
2421 }
2422
InitMsTensor()2423 static bool InitMsTensor() {
2424 if (g_tensor_type == nullptr && InitMsModule()) {
2425 g_tensor_type = PyObject_GetAttrString(g_ms_module, "Tensor");
2426 }
2427 return g_tensor_type != nullptr && g_tensor_type != Py_None && PyType_Check(g_tensor_type);
2428 }
2429
GetMsModule()2430 PyObject *GetMsModule() {
2431 if (InitMsModule()) {
2432 return g_ms_module;
2433 } else {
2434 return nullptr;
2435 }
2436 }
2437
GetMsType()2438 PyObject *GetMsType() {
2439 if (InitMsType()) {
2440 return g_ms_type;
2441 } else {
2442 return nullptr;
2443 }
2444 }
2445
GetMsTensorType()2446 PyObject *GetMsTensorType() {
2447 if (InitMsTensor()) {
2448 return g_tensor_type;
2449 } else {
2450 return nullptr;
2451 }
2452 }
2453
2454 } // namespace pijit
2455 } // namespace mindspore
2456