1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
17 #include <algorithm>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 #include <memory>
22 #include "utils/log_adapter.h"
23 #include "pipeline/jit/pi/utils/utils.h"
24 #include "pipeline/jit/pi/pydef.h"
25 #include "pipeline/jit/pi/graph_guard/infer.h"
26 #include "pipeline/jit/pi/graph_compiler/utils.h"
27 #include "pipeline/jit/pi/graph_compiler/pi_ir/ctrl_flow.h"
28 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h"
29 #include "pipeline/jit/pi/graph_compiler/pi_ir/operation.h"
30 #include "pipeline/jit/pi/graph_compiler/pi_ir/value.h"
31 #include "pipeline/jit/ps/action.h"
32 #include "pipeline/jit/ps/parse/data_converter.h"
33 #include "mindspore/core/ops/math_ops.h"
34 #include "include/common/utils/convert_utils_py.h"
35
36 namespace mindspore {
37 namespace pijit {
// Constant 2 used by the dict-building helpers in this file (presumably one key plus one value per entry).
static const size_t DictStep = 2;

// FIND_MAP_CACHE(map, target): look `target` up in `map` and, if it is present, return the mapped value
// from the *enclosing* function. When the key is absent the macro does nothing and execution continues.
// The do/while(0) wrapper makes the expansion behave as a single statement.
#define FIND_MAP_CACHE(map, target) \
  do {                              \
    auto iter = (map).find(target); \
    if (iter != (map).end()) {      \
      return iter->second;          \
    }                               \
  } while (0)
46
// CHECK_PYTHON_EXCEPTION(check_res): discard any pending Python error after a C-API call.
// In DEBUG builds a pending error is additionally logged and printed, and `check_res` (the C-API
// result) is asserted to be nullptr, i.e. the call is expected to have reported failure by returning NULL.
// Both variants expand to exactly one statement, so the macro is safe inside an unbraced if/else
// (the previous DEBUG expansion was a bare `if {}` and could capture a following `else`).
#ifdef DEBUG
#define CHECK_PYTHON_EXCEPTION(check_res)          \
  do {                                             \
    if (PyErr_Occurred()) {                        \
      MS_LOG(DEBUG) << "has an python exception";  \
      MS_ASSERT((check_res) == nullptr);           \
      PyErr_Print();                               \
      PyErr_Clear();                               \
    }                                              \
  } while (0)
#else
#define CHECK_PYTHON_EXCEPTION(check_res) PyErr_Clear()
#endif
58
// Abstract types whose values a MindSpore graph can accept directly; queried by IsMindSporeSupportedType().
static const std::set<AObject::Type> kMsSupportedType = {
  AObject::kTypeInt,  AObject::kTypeBool,   AObject::kTypeFloat,
  AObject::kTypeNone, AObject::kTypeString, AObject::kTypeTensor,
};

// Non-owning registry of the active Resource. The Resource constructor requires it to be empty and the
// destructor requires exactly one entry, so at most one Resource is registered at a time.
std::vector<AbstractObjectBase::Resource *> AbstractObjectBase::Resource::weak_this_;
// Trace switch, off by default. NOTE(review): its readers are outside this chunk.
bool AbstractObjectBase::trace_flag_ = false;
67
Resource()68 AbstractObjectBase::Resource::Resource() : pool_(__FILE__, __LINE__, "AObject") {
69 MS_EXCEPTION_IF_CHECK_FAIL(weak_this_.empty(), "can't reentrant");
70 weak_this_.push_back(this);
71 }
~Resource()72 AbstractObjectBase::Resource::~Resource() {
73 MS_EXCEPTION_IF_CHECK_FAIL(weak_this_.size() == 1, "can't reentrant");
74 Release();
75 weak_this_.pop_back();
76 }
77
// exact equal check
// Maps a Python type object to its abstract type when the type matches exactly; subclasses are not
// matched here. The nullptr entry is never reached through GetPyType(PyTypeObject *), which returns
// kTypeAnyValue for nullptr before consulting this map.
static const std::unordered_map<PyTypeObject *, AObject::Type> exact_type_map = {
  {&PyFunction_Type, AObject::kTypeFunction},
  {&PyMethod_Type, AObject::kTypeBoundMethod},
  {&PyCode_Type, AObject::kTypeCodeObject},
  {&PySlice_Type, AObject::kTypeSlice},
  {&PySet_Type, AObject::kTypeSet},
  {&PyFrozenSet_Type, AObject::kTypeSet},
  {&PyBool_Type, AObject::kTypeBool},
  {&PyFloat_Type, AObject::kTypeFloat},
  {&PyLong_Type, AObject::kTypeInt},
  {&PyList_Type, AObject::kTypeList},
  {&PyTuple_Type, AObject::kTypeTuple},
  {&PyDict_Type, AObject::kTypeDict},
  {&PyDictValues_Type, AObject::kTypeDictValues},
  {&PyDictKeys_Type, AObject::kTypeDictKeys},
  {&PyDictItems_Type, AObject::kTypeDictItems},
  {&PyType_Type, AObject::kTypeType},
  {&PyUnicode_Type, AObject::kTypeString},
  {&PyModule_Type, AObject::kTypeModule},
  {&PyCFunction_Type, AObject::kTypeCFunction},
  {nullptr, AObject::kTypeAnyValue},
};

// Identity lookup for Python singleton objects.
// shouldn't add nullptr to this map: callers such as AObject::BinaryIs look up pointers that may be
// nullptr and rely on that lookup failing.
static const std::unordered_map<PyObject *, AObject::Type> const_object_type_map = {
  {Py_Ellipsis, AObject::kTypeEllipsis},
  {Py_None, AObject::kTypeNone},
  {Py_True, AObject::kTypeBool},
  {Py_False, AObject::kTypeBool},
};

// Base types tested in order with PyType_IsSubtype for types that the exact and fast checks do not resolve.
static const std::vector<std::pair<PyTypeObject *, AObject::Type>> sub_type_map = {
  {&PyModule_Type, AObject::kTypeModule}, {&PyCFunction_Type, AObject::kTypeCFunction}};

// Union of CPython's "fast subclass" tp_flags bits; GetPyType inspects (tp_flags & fast_type_mask) to
// classify subclasses of int/list/tuple/str/dict/type without walking the MRO.
constexpr size_t fast_type_mask = Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS | Py_TPFLAGS_TUPLE_SUBCLASS |
                                  Py_TPFLAGS_UNICODE_SUBCLASS | Py_TPFLAGS_DICT_SUBCLASS | Py_TPFLAGS_TYPE_SUBCLASS;
115
// Returns the enumerator name ("kType<Unit>") of `type`, or "unknown type" when `type` is not listed in
// abstract_type_kind.def. The X-macro presumably expands to one comparison per listed type.
const char *AbstractObjectBase::GetTypeDesc(AObject::Type type) {
#define ABSTRACT_TYPE_DEF(unit)       \
  if (type == AObject::kType##unit) { \
    return "kType" #unit;             \
  }
#include "abstract_type_kind.def"
#undef ABSTRACT_TYPE_DEF
  return "unknown type";
}
125
IsMindSporeSupportedType()126 bool AbstractObjectBase::IsMindSporeSupportedType() {
127 return kMsSupportedType.find(GetType()) != kMsSupportedType.end();
128 }
129
ToString(PyObject * op)130 std::string AbstractObjectBase::ToString(PyObject *op) {
131 if (op == nullptr) {
132 return "<NULL>";
133 }
134 ReprRecursionScope scope(op);
135 if (scope.ReEnter()) {
136 return "...";
137 }
138
139 py::object obj = py::cast<py::object>(op);
140 AObject::Type t = AObject::GetPyType(op);
141 std::stringstream s;
142 s << std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(op)))) << "{ ";
143 switch (t) {
144 case AObject::kTypeTensor:
145 case AObject::kTypeStubTensor: {
146 s << std::string(py::str(obj.attr("shape"))) << ", " << std::string(py::str(obj.attr("dtype")));
147 break;
148 }
149 case AObject::kTypeBoundMethod: {
150 s << std::string(py::str(PyMethod_GET_FUNCTION(op))) << " at " << ToString(PyMethod_GET_SELF(op));
151 break;
152 }
153 case AObject::kTypeNNCellList:
154 case AObject::kTypeList:
155 case AObject::kTypeTuple: {
156 s << (t == AObject::kTypeTuple ? "( " : "[ ");
157 for (auto i : py::iter(obj)) {
158 s << ToString(i.ptr()) << ", ";
159 }
160 s.seekp(-2, s.cur);
161 s << (t == AObject::kTypeTuple ? " )" : " ]");
162 break;
163 }
164 case AObject::kTypeDict: {
165 PyObject *key;
166 PyObject *val;
167 Py_ssize_t pos = 0;
168 s << "{ ";
169 while (PyDict_Next(op, &pos, &key, &val)) {
170 s << ToString(key) << ":" << ToString(val) << ", ";
171 }
172 s.seekp(-2, s.cur);
173 s << " }";
174 break;
175 }
176 case AObject::kTypeAnyValue:
177 case AObject::kTypeCell: {
178 s << " at " << op;
179 break;
180 }
181 default:
182 s << std::string(py::str(obj));
183 break;
184 }
185 s << " }";
186 return s.str();
187 }
188
ToString() const189 std::string AbstractObjectBase::ToString() const {
190 std::string s = " ";
191 #define ABSTRACT_MS_FLAG_DEF(unit, bit) s += ((ms_flag_ & kMsFlag##unit) ? #unit "|" : "");
192 #include "abstract_ms_flag.def"
193 #undef ABSTRACT_MS_FLAG_DEF
194 if (s.back() == '|') {
195 s.pop_back();
196 }
197 if (type_object_ != nullptr) {
198 s += std::string(py::str(reinterpret_cast<PyObject *>(type_object_)));
199 }
200 return GetTypeDesc(GetType()) + s;
201 }
202
GetPyType(PyTypeObject * tp)203 AbstractObjectBase::Type AbstractObjectBase::GetPyType(PyTypeObject *tp) {
204 if (tp == nullptr) {
205 return kTypeAnyValue;
206 }
207 FIND_MAP_CACHE(exact_type_map, tp);
208 // fast sub type check
209 // __builtin_clz(tp->tp_flags & fast_type_mask), or std::countl_zero
210 /**
211 * sub-class int, float, list, tuple, str, is mindspore unsupported
212 */
213 switch (tp->tp_flags & fast_type_mask) {
214 case Py_TPFLAGS_LONG_SUBCLASS:
215 case Py_TPFLAGS_LIST_SUBCLASS:
216 case Py_TPFLAGS_TUPLE_SUBCLASS:
217 case Py_TPFLAGS_UNICODE_SUBCLASS:
218 case Py_TPFLAGS_DICT_SUBCLASS:
219 return kTypeAnyValue;
220 case Py_TPFLAGS_TYPE_SUBCLASS:
221 return kTypeType;
222 default:
223 break;
224 }
225 // sub type check
226 for (auto &i : sub_type_map) {
227 if (PyType_IsSubtype(tp, i.first)) {
228 return i.second;
229 }
230 }
231 return GetMsType(tp);
232 }
233
GetPyType(PyObject * o)234 AbstractObjectBase::Type AbstractObjectBase::GetPyType(PyObject *o) {
235 if (o == nullptr) {
236 return kTypeAnyValue;
237 }
238 FIND_MAP_CACHE(const_object_type_map, o);
239 if (PyLong_Check(o)) {
240 return (Py_ABS(Py_SIZE(o)) > 2) ? kTypeAnyValue : kTypeInt;
241 }
242 return GetPyType(Py_TYPE(o));
243 }
244
GetMsType(PyTypeObject * tp)245 AbstractObjectBase::Type AbstractObjectBase::GetMsType(PyTypeObject *tp) {
246 static const std::vector<std::pair<bool (*)(PyTypeObject *), AObject::Type>> match_func = {
247 {IsStubTensorType<true>, kTypeStubTensor}, {IsTensorType<true>, kTypeTensor},
248 {IsCellListType<false>, kTypeNNCellList}, {IsCellType<true>, kTypeCell},
249 {IsPrimitiveType<true>, kTypePrimitive}, {IsMetaFuncGraphType<true>, kTypeMetaFuncGraph},
250 {IsMSDTypeType<true>, kTypeMSDType}, {IsPrimitiveFunctionType<true>, kTypePrimitiveFunction},
251 };
252 if (tp == nullptr) {
253 return kTypeAnyValue;
254 }
255 for (auto i : match_func) {
256 if (i.first(tp)) {
257 return i.second;
258 }
259 }
260 return kTypeAnyValue;
261 }
262
MakeAObject(AObject::Type type,PyTypeObject * tp,PyObject * o,RecMap * m)263 AObject *AbstractObjectBase::MakeAObject(AObject::Type type, PyTypeObject *tp, PyObject *o, RecMap *m) {
264 MS_EXCEPTION_IF_CHECK_FAIL(Resource::Current() != nullptr, "can't take resource");
265 MS_EXCEPTION_IF_CHECK_FAIL(tp == nullptr || o == nullptr || Py_TYPE(o) == tp, "check type match value");
266 py::object h = py::cast<py::object>(o);
267 AObject *res;
268 switch (type) {
269 case kTypeStubTensor:
270 case kTypeTensor:
271 res = Resource::Current()->pool()->New<AbstractTensor>(h, type == kTypeStubTensor);
272 break;
273 case kTypeType:
274 res = Resource::Current()->pool()->New<AbstractType>(h);
275 break;
276 case kTypeString:
277 res = Resource::Current()->pool()->New<AbstractSequence>(kTypeString, h);
278 break;
279 case kTypeNNCellList:
280 res = Resource::Current()->pool()->New<AbstractSequence>(kTypeNNCellList, h);
281 break;
282 case kTypeList:
283 res = Resource::Current()->pool()->New<AbstractList>(h, m);
284 break;
285 case kTypeTuple:
286 res = Resource::Current()->pool()->New<AbstractTuple>(h, m);
287 break;
288 case kTypeDict:
289 res = Resource::Current()->pool()->New<AbstractDict>(h, m);
290 break;
291 case kTypeAnyValue:
292 if (tp == nullptr) {
293 res = Resource::Current()->pool()->New<AbstractObjectBase>(kTypeAnyValue);
294 break;
295 }
296 /* fall-through */
297 default:
298 // known type
299 res = Resource::Current()->pool()->New<AbstractObject>(type, h);
300 break;
301 }
302 res->SetTypeObject(o == nullptr ? tp : Py_TYPE(o));
303 return res;
304 }
305
MakeFunction(const std::vector<AObject * > & args,const py::object & globals,int oparg)306 AObject *AbstractObjectBase::MakeFunction(const std::vector<AObject *> &args, const py::object &globals, int oparg) {
307 std::vector<py::object> pyarg;
308 std::transform(args.begin(), args.end(), std::back_inserter(pyarg), [](AObject *i) { return i->GetPyObject(); });
309 auto iter = pyarg.end() - 1;
310 PyObject *qualname = (*iter--).ptr();
311 PyObject *code = (*iter--).ptr();
312 py::object f_handle = py::reinterpret_steal<py::object>(PyFunction_NewWithQualName(code, globals.ptr(), qualname));
313 PyFunctionObject *func = reinterpret_cast<PyFunctionObject *>(f_handle.ptr());
314 MS_EXCEPTION_IF_CHECK_FAIL(func, "MAKE_FUNCTION failed");
315 if (IntToSize(oparg) & 0x08) {
316 func->func_closure = (*iter--).inc_ref().ptr();
317 Py_ssize_t nfrees = PyTuple_GET_SIZE(reinterpret_cast<PyCodeObject *>(code)->co_freevars);
318 bool is_valid = func->func_closure && nfrees == PyTuple_GET_SIZE(func->func_closure);
319 MS_EXCEPTION_IF_CHECK_FAIL(is_valid, "must be has python objects, and it is tuple of cell objects");
320 }
321 if (IntToSize(oparg) & 0x04) {
322 func->func_annotations = (*iter--).inc_ref().ptr();
323 MS_EXCEPTION_IF_CHECK_FAIL(func->func_annotations, "must be has python objects, and it is const key map");
324 }
325 if (IntToSize(oparg) & 0x02) {
326 func->func_kwdefaults = (*iter--).inc_ref().ptr();
327 MS_EXCEPTION_IF_CHECK_FAIL(func->func_kwdefaults, "must be has python objects, and it is const key map");
328 }
329 if (IntToSize(oparg) & 0x01) {
330 func->func_defaults = (*iter--).inc_ref().ptr();
331 MS_EXCEPTION_IF_CHECK_FAIL(func->func_defaults, "must be has python objects, and it is const tuple");
332 }
333 AObject *res = AObject::Convert(f_handle);
334 return res;
335 }
336
BuildOperations(const std::vector<py::object> & args,int opcode)337 py::object AbstractObjectBase::BuildOperations(const std::vector<py::object> &args, int opcode) {
338 PyObject *res = nullptr;
339 PyObject **tmp;
340 std::vector<PyObject *> arr;
341 if (opcode == BUILD_SLICE) {
342 res = PySlice_New(args[0].ptr(), args[1].ptr(), args.size() > 2 ? args[2].ptr() : nullptr);
343 } else if (opcode == BUILD_STRING) {
344 std::transform(args.begin(), args.end(), std::back_inserter(arr), [](const py::object &o) { return o.ptr(); });
345 res = _PyUnicode_JoinArray(py::str().ptr(), arr.data(), arr.size());
346 } else if (opcode == BUILD_SET) {
347 res = PySet_New(nullptr);
348 (void)std::find_if(args.begin(), args.end(), [&res](const py::object &i) { return PySet_Add(res, i.ptr()); });
349 } else if (opcode == BUILD_LIST) {
350 res = PyList_New(args.size());
351 tmp = &PyList_GET_ITEM(res, 0);
352 std::for_each(args.begin(), args.end(), [&tmp](const py::object &i) { return *(tmp++) = i.inc_ref().ptr(); });
353 } else if (opcode == BUILD_TUPLE) {
354 res = PyTuple_New(args.size());
355 tmp = &PyTuple_GET_ITEM(res, 0);
356 std::for_each(args.begin(), args.end(), [&tmp](const py::object &i) { return *(tmp++) = i.inc_ref().ptr(); });
357 } else if (opcode == BUILD_CONST_KEY_MAP) {
358 res = PyDict_New();
359 // must be tuple, here has a cast check
360 tmp = &PyTuple_GET_ITEM(args.back().ptr(), 0);
361 (void)std::find_if(args.begin(), args.end() - 1, [&res, &tmp](const py::object &i) {
362 return PyDict_SetItem(res, *(tmp++), i.ptr()); // break if err_ocurred
363 });
364 } else if (opcode == BUILD_MAP) {
365 res = PyDict_New();
366 for (size_t i = 0; !PyErr_Occurred() && i < args.size(); i += 2) {
367 PyDict_SetItem(res, args[i].ptr(), args[i + 1].ptr());
368 }
369 }
370 if (PyErr_Occurred()) {
371 Py_XDECREF(res);
372 MS_LOG(DEBUG) << "build operation failed: " << Opcode(opcode).name();
373 PyErr_Clear();
374 res = nullptr;
375 }
376 return py::reinterpret_steal<py::object>(res);
377 }
378
BuildOperations(const std::vector<AObject * > & inputs,int opcode)379 AObject *AbstractObjectBase::BuildOperations(const std::vector<AObject *> &inputs, int opcode) {
380 bool build_pyobject = true;
381 std::vector<py::object> args;
382 for (auto i = inputs.begin(); i != inputs.end() && build_pyobject; ++i) {
383 args.push_back(((*i) != nullptr) ? (*i)->GetPyObject() : py::object());
384 build_pyobject &= args.back().ptr() != nullptr;
385 }
386 if (build_pyobject) {
387 return Convert(BuildOperations(args, opcode));
388 }
389
390 AObject *res = nullptr;
391 PyObject *keys;
392 bool err = false;
393 if (opcode == BUILD_LIST || opcode == BUILD_TUPLE) {
394 res = MakeAObject(opcode == BUILD_LIST ? kTypeList : kTypeTuple);
395 static_cast<AbstractTuple *>(res)->Update(inputs);
396 } else if (opcode == BUILD_CONST_KEY_MAP) {
397 res = MakeAObject(kTypeDict);
398 keys = inputs.back()->GetPyObject().ptr();
399 err = static_cast<Py_ssize_t>(inputs.size() - 1) != PyTuple_GET_SIZE(keys);
400 for (Py_ssize_t i = IntToSize(inputs.size() - DictStep); !err && i >= 0; --i) {
401 err = !static_cast<AbstractDict *>(res)->MapAdd(Convert(PyTuple_GET_ITEM(keys, i)), inputs[i]);
402 }
403 } else if (opcode == BUILD_MAP) {
404 res = MakeAObject(kTypeDict);
405 for (size_t i = 0; !err && i < inputs.size(); i += 2) {
406 err = !static_cast<AbstractDict *>(res)->MapAdd(inputs[i], inputs[i + 1]);
407 }
408 } else if (opcode == BUILD_STRING) {
409 res = MakeAObject(kTypeString);
410 } else if (opcode == BUILD_SLICE) {
411 res = MakeAObject(kTypeSlice);
412 } else if (opcode == BUILD_SET) {
413 res = MakeAObject(kTypeSet);
414 } else {
415 err = true;
416 }
417 return err ? MakeAObject(kTypeAnyValue) : res;
418 }
419
MergeOperations(AObject * container,std::vector<AObject * > args,int opcode)420 AObject *AbstractObjectBase::MergeOperations(AObject *container, std::vector<AObject *> args, int opcode) {
421 Type type = container ? container->GetType() : kTypeAnyValue;
422 bool success = false;
423 if (opcode == LIST_EXTEND) {
424 success = type == kTypeList && (static_cast<AbstractList *>(container))->ListExtend(args[0]);
425 } else if (opcode == LIST_APPEND) {
426 success = type == kTypeList && (static_cast<AbstractList *>(container))->ListAppend(args[0]);
427 } else if (opcode == DICT_MERGE) {
428 success = type == kTypeDict && (static_cast<AbstractDict *>(container))->DictMerge(args[0]);
429 } else if (opcode == DICT_UPDATE) {
430 success = type == kTypeDict && (static_cast<AbstractDict *>(container))->DictUpdate(args[0]);
431 } else if (opcode == MAP_ADD) {
432 success = type == kTypeDict && (static_cast<AbstractDict *>(container))->MapAdd(args[0], args[1]);
433 } else if (opcode == SET_UPDATE || opcode == SET_ADD) {
434 success = true;
435 container = MakeAObject(kTypeSet);
436 }
437 if (!success) {
438 return MakeAObject(kTypeAnyValue);
439 }
440 return container;
441 }
442
AbstractObject(Type type,const py::object & o)443 AbstractObject::AbstractObject(Type type, const py::object &o) : AbstractObjectBase(type), value_(o) {
444 // cache attr
445 (void)GetAttr("__ms_mutable__");
446 }
447
GetIter() const448 AObject *AbstractObject::GetIter() const {
449 if (this->GetType() == kTypeAnyValue || value_.ptr() == nullptr) {
450 return MakeAObject(kTypeAnyValue);
451 }
452 PyObject *iter = PyObject_GetIter(value_.ptr());
453 CHECK_PYTHON_EXCEPTION(iter);
454 AObject *res = Convert(iter);
455 Py_XDECREF(iter);
456 return res;
457 }
458
GetAttr(const std::string & name)459 AObject *AbstractObjectBase::GetAttr(const std::string &name) {
460 PyTypeObject *tp = type_object_;
461 if (tp == nullptr) {
462 return MakeAObject(kTypeAnyValue);
463 }
464 py::str name_obj(name);
465 PyObject *attr_obj = PyObject_GetAttr(reinterpret_cast<PyObject *>(tp), name_obj.ptr());
466 if (attr_obj == nullptr) {
467 PyErr_Clear();
468 return MakeAObject(kTypeAnyValue);
469 }
470 AObject *attr = AObject::Convert(attr_obj);
471 Py_DECREF(attr_obj);
472
473 // look up mro, borrowed
474 PyObject *descr = _PyType_Lookup(tp, name_obj.ptr());
475 if (descr) {
476 // check @staticmethod and @classmethod
477 if (Py_IS_TYPE(descr, &PyStaticMethod_Type) || Py_IS_TYPE(descr, &PyClassMethod_Type)) {
478 // attr not modify
479 } else if (PyFunction_Check(descr)) {
480 MS_EXCEPTION_IF_CHECK_FAIL(attr_obj == descr, "unknown user defined descriptor");
481 PyObject *meth = PyMethod_New(descr, Py_None);
482 AObject *m = AObject::Convert(meth);
483 Py_DECREF(meth);
484 m->SetAttr("__self__", this);
485 m->SetAttr("__func__", attr);
486 attr = m;
487 } else {
488 // other type
489 attr = MakeAObject(kTypeAnyValue);
490 }
491 }
492 return attr;
493 }
494
GetAttr(const std::string & name)495 AObject *AbstractObject::GetAttr(const std::string &name) {
496 FIND_MAP_CACHE(attrs_, name);
497 AObject *res = nullptr;
498 if (value_.ptr() != nullptr) {
499 PRINT_IF_HAS_USER_DEFINED_HOOK(value_.ptr(), __getattr__);
500 PRINT_IF_HAS_USER_DEFINED_HOOK(value_.ptr(), __getattribute__);
501 #ifdef DEBUG
502 PyObject *tmp = PyObject_GetAttrString(reinterpret_cast<PyObject *>(Py_TYPE(value_.ptr())), name.c_str());
503 if (tmp) { // is user defined descriptor ?
504 PRINT_IF_HAS_USER_DEFINED_HOOK(tmp, __get__);
505 } else {
506 PyErr_Clear();
507 }
508 Py_XDECREF(tmp);
509 #endif
510 PyObject *attr = PyObject_GetAttrString(value_.ptr(), name.c_str());
511 CHECK_PYTHON_EXCEPTION(attr);
512 res = Convert(attr);
513 Py_XDECREF(attr);
514 } else {
515 res = this->AbstractObjectBase::GetAttr(name);
516 }
517 attrs_[name] = res;
518 return res;
519 }
520
SetAttr(const std::string & n,AObject * v)521 bool AbstractObject::SetAttr(const std::string &n, AObject *v) {
522 attrs_[n] = v ? v : MakeAObject(kTypeAnyValue);
523 return true;
524 }
525
GetItem(AObject * k)526 AObject *AbstractSequence::GetItem(AObject *k) {
527 auto iter = write_cache_.find(k);
528 if (iter != write_cache_.end()) {
529 return iter->second == nullptr ? MakeAObject(kTypeAnyValue) : iter->second;
530 }
531 return this->AbstractObject::GetItem(k);
532 }
533
GetItem(AObject * k)534 AObject *AbstractObject::GetItem(AObject *k) {
535 PyObject *s = this->GetPyObject().ptr();
536 PyObject *i = k ? k->GetPyObject().ptr() : nullptr;
537 PyObject *t = nullptr;
538 if (s != nullptr && i != nullptr && k->GetType() != kTypeAnyValue) {
539 t = PyObject_GetItem(s, i);
540 CHECK_PYTHON_EXCEPTION(t);
541 }
542 AObject *res = Convert(t);
543 Py_XDECREF(t);
544 return res;
545 }
546
SetItem(AObject * k,AObject * v)547 bool AbstractSequence::SetItem(AObject *k, AObject *v) {
548 if (this->type_ == kTypeString || this->type_ == kTypeTuple) {
549 return false;
550 }
551 write_cache_[k] = v ? v : MakeAObject(kTypeAnyValue);
552 return true;
553 }
554
UnaryValue(int op) const555 AObject *AbstractObject::UnaryValue(int op) const {
556 PyObject *res = nullptr;
557 if (op == UNARY_POSITIVE) {
558 res = PyNumber_Positive(value_.ptr());
559 } else if (op == UNARY_NEGATIVE) {
560 res = PyNumber_Negative(value_.ptr());
561 } else if (op == UNARY_INVERT) {
562 res = PyNumber_Invert(value_.ptr());
563 } else if (op == UNARY_NOT) {
564 int err = PyObject_IsTrue(value_.ptr());
565 res = err > 0 ? Py_False : (err == 0 ? Py_True : nullptr);
566 }
567 CHECK_PYTHON_EXCEPTION(res);
568 AObject *ret = Convert(res);
569 Py_XDECREF(res);
570 return ret;
571 }
572
Unary(int op) const573 AObject *AbstractObject::Unary(int op) const {
574 if (this->GetType() == kTypeAnyValue) {
575 return MakeAObject(kTypeAnyValue);
576 }
577 if (value_.ptr() != nullptr) {
578 return UnaryValue(op);
579 }
580 Type res_type = kTypeAnyValue;
581 Type type = this->GetType();
582 if (op == UNARY_POSITIVE || op == UNARY_NEGATIVE || op == UNARY_INVERT) {
583 if (type == kTypeBool || type == kTypeInt) {
584 res_type = kTypeInt;
585 } else if (type == kTypeFloat) {
586 res_type = kTypeFloat;
587 }
588 } else if (op == UNARY_NOT) {
589 bool is_num = type == kTypeBool || type == kTypeInt || type == kTypeFloat;
590 if (is_num || type == kTypeList || type == kTypeTuple || type == kTypeDict) {
591 res_type = kTypeBool;
592 }
593 }
594 return MakeAObject(res_type);
595 }
596
BinaryPow(PyObject * base,PyObject * exp)597 static PyObject *BinaryPow(PyObject *base, PyObject *exp) { return PyNumber_Power(base, exp, Py_None); }
InplacePow(PyObject * base,PyObject * exp)598 static PyObject *InplacePow(PyObject *base, PyObject *exp) { return PyNumber_InPlacePower(base, exp, Py_None); }
599
BinaryIntOp(AObject::Type l,AObject::Type r)600 static AObject::Type BinaryIntOp(AObject::Type l, AObject::Type r) {
601 AObject::Type type = AObject::kTypeAnyValue;
602 switch (l) {
603 case AObject::kTypeInt:
604 case AObject::kTypeBool:
605 if (r == AObject::kTypeInt || r == AObject::kTypeBool) {
606 type = AObject::kTypeInt;
607 }
608 break;
609 default:
610 break;
611 }
612 return type;
613 }
614
615 // operator '&', '^', '|'
NumberLogic(AObject::Type l,AObject::Type r)616 static AObject::Type NumberLogic(AObject::Type l, AObject::Type r) {
617 AObject::Type type = AObject::kTypeAnyValue;
618 if (l == AObject::kTypeBool) {
619 if (r == AObject::kTypeInt || r == AObject::kTypeBool) {
620 type = r;
621 }
622 } else {
623 type = BinaryIntOp(l, r);
624 }
625 return type;
626 }
627
628 // operator '+', '-', '*', '/', '%', '**', '//'
NumberArithmetic(AObject::Type l,AObject::Type r)629 static AObject::Type NumberArithmetic(AObject::Type l, AObject::Type r) {
630 AObject::Type type = AObject::kTypeAnyValue;
631 if (l == AObject::kTypeFloat || r == AObject::kTypeFloat) {
632 if (l == AObject::kTypeInt || l == AObject::kTypeBool || r == AObject::kTypeInt || r == AObject::kTypeBool) {
633 type = AObject::kTypeFloat;
634 }
635 } else {
636 type = BinaryIntOp(l, r);
637 }
638 return type;
639 }
640
BinaryAdd(AObject::Type l,AObject::Type r)641 static AObject::Type BinaryAdd(AObject::Type l, AObject::Type r) {
642 AObject::Type type = AObject::kTypeAnyValue;
643 switch (l) {
644 case AObject::kTypeTuple:
645 case AObject::kTypeList:
646 case AObject::kTypeString:
647 if (r == l) {
648 type = l;
649 }
650 break;
651 default:
652 type = NumberArithmetic(l, r);
653 break;
654 }
655 return type;
656 }
657
BinaryInferDefault(AObject::Type,AObject::Type)658 static AObject::Type BinaryInferDefault(AObject::Type, AObject::Type) { return AObject::kTypeAnyValue; }
659
IsSameType(PyObject * a,PyObject * b)660 static bool IsSameType(PyObject *a, PyObject *b) {
661 if (a != nullptr && b != nullptr && PyType_Check(a) && PyType_Check(b)) {
662 return a == b;
663 }
664 return false;
665 }
666
BinaryIs(AObject * l,AObject * r)667 int AObject::BinaryIs(AObject *l, AObject *r) {
668 PyObject *a = l ? l->GetPyObject().ptr() : nullptr;
669 PyObject *b = r ? r->GetPyObject().ptr() : nullptr;
670 const auto &map = const_object_type_map;
671 bool const_a = map.find(a) != map.end();
672 bool const_b = map.find(b) != map.end();
673 // all is const object
674 if (const_a && const_b) {
675 return a == b;
676 }
677 if (IsSameType(a, b)) {
678 return true;
679 }
680 // a const object and a known object
681 if ((const_a && b) || (const_b && a)) {
682 return false;
683 }
684 // a const object and a unknown object, but known it's type
685 if (const_a && r != nullptr && r->GetType() != AObject::kTypeAnyValue && r->GetType() != AObject::kTypeBool) {
686 MS_EXCEPTION_IF_CHECK_FAIL(!const_b, "shouldn't reach here");
687 return false;
688 }
689 if (const_b && l != nullptr && l->GetType() != AObject::kTypeAnyValue && l->GetType() != AObject::kTypeBool) {
690 MS_EXCEPTION_IF_CHECK_FAIL(!const_a, "shouldn't reach here");
691 return false;
692 }
693 return -1;
694 }
695
BinaryContains(AObject * l,AObject * r)696 int AObject::BinaryContains(AObject *l, AObject *r) {
697 PyObject *o = l->GetPyObject().ptr();
698 PyObject *c = r->GetPyObject().ptr();
699 if (c == nullptr || o == nullptr || r->GetType() == AObject::kTypeAnyValue) {
700 return -1;
701 }
702 int res = PySequence_Contains(c, o);
703 CHECK_PYTHON_EXCEPTION(res < 0 ? nullptr : Py_True);
704 return res;
705 }
706
BinaryIs(AObject * l,AObject * r)707 AObject *BinaryIs(AObject *l, AObject *r) {
708 int res = AObject::BinaryIs(l, r);
709 return res == -1 ? AObject::MakeAObject(AObject::kTypeBool) : AObject::Convert(res ? Py_True : Py_False);
710 }
711
BinaryContains(AObject * l,AObject * r)712 AObject *BinaryContains(AObject *l, AObject *r) {
713 int res = AObject::BinaryContains(l, r);
714 return res == -1 ? AObject::MakeAObject(AObject::kTypeBool) : AObject::Convert(res ? Py_True : Py_False);
715 }
716
717 using InferBinaryFunc = AObject *(*)(AObject *, AObject *);
718 using InferBinaryTypeFunc = AObject::Type (*)(AObject::Type, AObject::Type);
719
720 template <binaryfunc pyfunc, InferBinaryTypeFunc type_infer>
InferBinary(AObject * a,AObject * b)721 AObject *InferBinary(AObject *a, AObject *b) {
722 PyObject *l = a->GetPyObject().ptr();
723 PyObject *r = b->GetPyObject().ptr();
724 if (l == nullptr || r == nullptr) {
725 return AObject::MakeAObject(type_infer(a->GetType(), b->GetType()));
726 }
727 if (a->GetType() == AObject::kTypeAnyValue || b->GetType() == AObject::kTypeAnyValue) {
728 return AObject::MakeAObject(AObject::kTypeAnyValue);
729 }
730 PyObject *o = pyfunc(l, r);
731 CHECK_PYTHON_EXCEPTION(o);
732 AObject *res = AObject::Convert(o);
733 Py_XDECREF(o);
734 return res;
735 }
736
737 // the inplace binary operations of known type don't modify original python object
738 // list, tuple, dict already override binary
739 static std::unordered_map<int, InferBinaryFunc> infer_binary_func = {
740 {BINARY_MATRIX_MULTIPLY, InferBinary<PyNumber_MatrixMultiply, BinaryInferDefault>}, // '@'
741 {INPLACE_MATRIX_MULTIPLY, InferBinary<PyNumber_InPlaceMatrixMultiply, BinaryInferDefault>}, // '@='
742 {BINARY_POWER, InferBinary<BinaryPow, NumberArithmetic>}, // '**'
743 {INPLACE_POWER, InferBinary<InplacePow, NumberArithmetic>}, // '**='
744 {BINARY_MULTIPLY, InferBinary<PyNumber_Multiply, NumberArithmetic>}, // '*'
745 {INPLACE_MULTIPLY, InferBinary<PyNumber_InPlaceMultiply, NumberArithmetic>}, // '*='
746 {BINARY_MODULO, InferBinary<PyNumber_Remainder, NumberArithmetic>}, // '%'
747 {INPLACE_MODULO, InferBinary<PyNumber_InPlaceRemainder, NumberArithmetic>}, // '%='
748 {BINARY_ADD, InferBinary<PyNumber_Add, BinaryAdd>},
749 {INPLACE_ADD, InferBinary<PyNumber_InPlaceAdd, BinaryAdd>},
750 {BINARY_SUBTRACT, InferBinary<PyNumber_Subtract, NumberArithmetic>},
751 {INPLACE_SUBTRACT, InferBinary<PyNumber_InPlaceSubtract, NumberArithmetic>},
752 {BINARY_FLOOR_DIVIDE, InferBinary<PyNumber_FloorDivide, NumberArithmetic>}, // '//'
753 {INPLACE_FLOOR_DIVIDE, InferBinary<PyNumber_InPlaceFloorDivide, NumberArithmetic>}, // '//='
754 {BINARY_TRUE_DIVIDE, InferBinary<PyNumber_TrueDivide, NumberArithmetic>},
755 {INPLACE_TRUE_DIVIDE, InferBinary<PyNumber_InPlaceTrueDivide, NumberArithmetic>},
756 {BINARY_LSHIFT, InferBinary<PyNumber_Lshift, BinaryIntOp>},
757 {INPLACE_LSHIFT, InferBinary<PyNumber_InPlaceLshift, BinaryIntOp>},
758 {BINARY_RSHIFT, InferBinary<PyNumber_Rshift, BinaryIntOp>},
759 {INPLACE_RSHIFT, InferBinary<PyNumber_InPlaceRshift, BinaryIntOp>},
760 {BINARY_AND, InferBinary<PyNumber_And, NumberLogic>},
761 {INPLACE_AND, InferBinary<PyNumber_InPlaceAnd, NumberLogic>},
762 {BINARY_XOR, InferBinary<PyNumber_Xor, NumberLogic>},
763 {INPLACE_XOR, InferBinary<PyNumber_InPlaceXor, NumberLogic>},
764 {BINARY_OR, InferBinary<PyNumber_Or, NumberLogic>},
765 {INPLACE_OR, InferBinary<PyNumber_InPlaceOr, NumberLogic>},
766 {CONTAINS_OP, BinaryContains},
767 {IS_OP, BinaryIs}};
768
Binary(AObject * other,int op)769 AObject *AbstractObject::Binary(AObject *other, int op) {
770 if (other == nullptr) {
771 return MakeAObject(kTypeAnyValue);
772 }
773 auto iter = infer_binary_func.find(op);
774 return iter == infer_binary_func.end() ? MakeAObject(kTypeAnyValue) : iter->second(this, other);
775 }
776
BuildAbstractInstance(const std::vector<AObject * > & args,int opcode)777 AObject *AbstractType::BuildAbstractInstance(const std::vector<AObject *> &args, int opcode) {
778 Type type = kTypeAnyValue;
779 PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(value_.ptr());
780 AbstractTuple *res;
781 switch (type_type_) {
782 case kTypeList:
783 case kTypeTuple:
784 res = static_cast<AbstractTuple *>(MakeAObject(type_type_));
785 if (tp != &PyTuple_Type) {
786 return res;
787 }
788 if (args.size() == 0) {
789 res->Update(args);
790 return res;
791 }
792 if (args[0] && (args[0]->GetType() == kTypeTuple || args[0]->GetType() == kTypeList)) {
793 res->Update(static_cast<AbstractTuple *>(args[0])->items());
794 return res;
795 }
796 return res;
797 case kTypeBool:
798 if (args.size() == 0) {
799 return Convert(Py_False);
800 }
801 type = args[0] ? args[0]->GetType() : kTypeAnyValue;
802 if (type == kTypeList || type == kTypeTuple) {
803 AbstractTuple *tmp = static_cast<AbstractTuple *>(args[0]);
804 return tmp->IsElementValid() ? Convert(tmp->size() ? Py_True : Py_False) : MakeAObject(kTypeBool);
805 }
806 if (type == kTypeDict) {
807 AbstractDict *tmp = static_cast<AbstractDict *>(args[0]);
808 return tmp->IsElementValid() ? Convert(tmp->size() ? Py_True : Py_False) : MakeAObject(kTypeBool);
809 }
810 break;
811 default:
812 break;
813 }
814 return MakeAObject(type_type_, tp, nullptr);
815 }
816
817 // this function call object without error
BuildInstance(const std::vector<py::object> & args,int opcode)818 py::object AbstractType::BuildInstance(const std::vector<py::object> &args, int opcode) {
819 if (value_.ptr() == nullptr) {
820 MS_LOG(DEBUG) << "create instance failed, unknown class";
821 return py::object();
822 }
823 auto pair = Utils::PackCallStackArgs(args, opcode, true);
824 if (pair.first.ptr() == nullptr) {
825 MS_LOG(DEBUG) << "create instance failed, unknown opcode or arguments";
826 return py::object();
827 }
828 PyObject *const *vector_args = &PyTuple_GET_ITEM(pair.first.ptr(), 0);
829 Py_ssize_t kw_cnt = pair.second.ptr() == nullptr ? 0 : PyTuple_GET_SIZE(pair.second.ptr());
830 Py_ssize_t nargs = PyTuple_GET_SIZE(pair.first.ptr());
831 PyObject *inst = PyObject_Vectorcall(value_.ptr(), vector_args, nargs - kw_cnt, pair.second.ptr());
832 CHECK_PYTHON_EXCEPTION(inst);
833 return py::reinterpret_steal<py::object>(inst);
834 }
835
Binary(AObject * o,int op)836 AObject *AbstractTuple::Binary(AObject *o, int op) {
837 // generic binary
838 PyObject *r_obj = o ? o->GetPyObject().ptr() : nullptr;
839 if (op == IS_OP) {
840 bool cnst = const_object_type_map.find(r_obj) != const_object_type_map.end();
841 return cnst ? Convert(Py_False) : MakeAObject(kTypeBool);
842 }
843 if (op == CONTAINS_OP) {
844 return infer_binary_func[CONTAINS_OP](this, o);
845 }
846 // tuple binary
847 if (o == nullptr || this->GetType() != o->GetType()) {
848 if (this->GetType() == kTypeList && op == BINARY_MULTIPLY && (o != nullptr) && o->GetType() == kTypeInt) {
849 AbstractTuple *ret = static_cast<AbstractTuple *>(MakeAObject(this->GetType()));
850 std::vector<AObject *> temp;
851 int res = PyLong_AsLong(o->GetPyObject().ptr());
852 for (int i = 0; i < res; i++) {
853 std::copy(items_.begin(), items_.end(), std::back_inserter(temp));
854 }
855 ret->Update(std::move(temp));
856 return ret;
857 }
858 return MakeAObject(kTypeAnyValue);
859 }
860 AbstractTuple *r_list = static_cast<AbstractTuple *>(o);
861 if (!this->IsElementValid() || !r_list->IsElementValid()) {
862 return MakeAObject(kTypeAnyValue);
863 }
864 if (op == BINARY_ADD || (this->GetType() == kTypeTuple && op == INPLACE_ADD)) {
865 AbstractTuple *ret = static_cast<AbstractTuple *>(MakeAObject(this->GetType()));
866 std::vector<AObject *> temp;
867 std::copy(items_.begin(), items_.end(), std::back_inserter(temp));
868 std::copy(r_list->items_.begin(), r_list->items_.end(), std::back_inserter(temp));
869 ret->Update(std::move(temp));
870 return ret;
871 }
872 if (op == INPLACE_ADD) {
873 std::copy(r_list->items_.begin(), r_list->items_.end(), std::back_inserter(this->items_));
874 MarkModify();
875 return this;
876 }
877 // binary mul, inplace mul
878 return MakeAObject(kTypeAnyValue);
879 }
880
Unary(int op) const881 AObject *AbstractTuple::Unary(int op) const {
882 if (op != UNARY_NOT || !this->IsElementValid()) {
883 return MakeAObject(kTypeAnyValue);
884 }
885 return Convert(this->size() > 0 ? Py_True : Py_False);
886 }
887
// RECURSION_CONVERT: shared body of the AbstractTuple / AbstractDict constructors. It converts every
// element of the python container `seq` (a name from the enclosing scope) to an AObject.
//   iter_expr - loop header that visits the elements
//   get_expr  - statement run at the start of each iteration (the tuple constructor loads `item` here)
//   set_expr  - statement that stores the converted element `aobject` into the abstract container
//   item      - name of the PyObject* variable that holds the current element
// `rec` (also from the enclosing scope) maps python containers already being converted to their
// AObject. When it is nullptr, a local map `holder` is used instead. `seq` is registered first, so
// an element that is a container already in the map reuses that AObject rather than being converted
// again. Nested list/tuple/dict elements go through MakeAObject with the same map; every other
// element goes through Convert().
// Do not put comments inside the macro body: a `//` comment would swallow the line continuation.
888 #define RECURSION_CONVERT(iter_expr, get_expr, set_expr, item) \
889 RecMap holder; \
890 if (rec == nullptr) { \
891 rec = &holder; \
892 } \
893 (*rec)[seq.ptr()] = this; \
894 AObject *aobject = nullptr; \
895 PyObject *item = nullptr; \
896 iter_expr { \
897 get_expr; \
898 auto iter = rec->find(item); \
899 if (iter != rec->end()) { \
900 aobject = iter->second; \
901 } else { \
902 Type t = GetPyType(item); \
903 if (t == kTypeList || t == kTypeTuple || t == kTypeDict) { \
904 PyTypeObject *tp = Py_TYPE(item); \
905 aobject = MakeAObject(t, tp, item, rec); \
906 } else { \
907 aobject = Convert(item); \
908 } \
909 } \
910 set_expr; \
911 }
912
AbstractTuple(Type type,py::object seq,RecMap * rec)913 AbstractTuple::AbstractTuple(Type type, py::object seq, RecMap *rec)
914 : AbstractSequence(type, seq),
915 items_(),
916 ms_support_(kBoolUnknown),
917 element_type_(kTypeAnyValue),
918 element_valid_(false),
919 modify_(false) {
920 type_object_ = (type == kTypeList) ? &PyList_Type : &PyTuple_Type;
921 if (!seq.ptr()) {
922 return;
923 }
924 element_valid_ = true;
925 MS_EXCEPTION_IF_CHECK_FAIL(GetPyType(seq.ptr()) == type, std::string("convert ") + GetTypeDesc(type) + " but got " +
926 GetTypeDesc(GetPyType(seq.ptr())));
927 PyObject *o = seq.ptr();
928 Py_ssize_t siz = Py_SIZE(seq.ptr());
929 items_.resize(siz);
930
931 #define ITER_EXPR for (int i = 0; i < siz; ++i)
932 #define GET_EXPR (item = (type == kTypeList) ? PyList_GET_ITEM(o, i) : PyTuple_GET_ITEM(o, i))
933 #define SET_EXPR items_[i] = aobject
934 RECURSION_CONVERT(ITER_EXPR, GET_EXPR, SET_EXPR, item);
935 #undef ITER_EXPR
936 #undef GET_EXPR
937 #undef SET_EXPR
938
939 // copy it
940 Update();
941 }
942
AbstractDict(Type type,py::object seq,RecMap * rec)943 AbstractDict::AbstractDict(Type type, py::object seq, RecMap *rec)
944 : AbstractSequence(type, seq),
945 dict_(),
946 k_type_(kTypeAnyValue),
947 v_type_(kTypeAnyValue),
948 element_valid_(false),
949 modify_(false) {
950 type_object_ = &PyDict_Type;
951 if (!seq.ptr()) {
952 return;
953 }
954 element_valid_ = true;
955 MS_EXCEPTION_IF_CHECK_FAIL(GetPyType(seq.ptr()) == type, std::string("convert ") + GetTypeDesc(type) + ", but got " +
956 GetTypeDesc(GetPyType(seq.ptr())));
957 PyObject *m = dict_.ptr();
958 PyObject *k;
959 Py_ssize_t p = 0;
960
961 #define ITER_EXPR while (PyDict_Next(seq.ptr(), &p, &k, &item))
962 #define GET_EXPR PRINT_IF_HAS_USER_DEFINED_HOOK(k, __hash__)
963 #define SET_EXPR PyDict_SetItem(m, k, ConvertValue(aobject).ptr())
964 RECURSION_CONVERT(ITER_EXPR, GET_EXPR, SET_EXPR, item);
965 #undef ITER_EXPR
966 #undef SET_EXPR
967 #undef GET_EXPR
968
969 // copy it
970 Update();
971 }
972
973 #undef RECURSION_CONVERT
974
IsMindSporeSupportedType()975 bool AbstractTuple::IsMindSporeSupportedType() {
976 if (ms_support_ != kBoolUnknown) {
977 return ms_support_ == kBoolTrue;
978 }
979 ms_support_ = kBoolFalse;
980 if (kMsSupportedType.find(element_type_) != kMsSupportedType.end()) {
981 ms_support_ = kBoolTrue;
982 return true;
983 }
984 if (!this->IsElementValid()) {
985 return false;
986 }
987 for (auto i : *this) {
988 if (!i) {
989 return false;
990 }
991 if (!i->IsMindSporeSupportedType()) {
992 return false;
993 }
994 }
995 ms_support_ = kBoolTrue;
996 return true;
997 }
998
IsMindSporeSupportedType()999 bool AbstractDict::IsMindSporeSupportedType() {
1000 if (kMsSupportedType.find(k_type_) != kMsSupportedType.end() &&
1001 kMsSupportedType.find(v_type_) != kMsSupportedType.end()) {
1002 return true;
1003 }
1004 if (this->IsElementValid()) {
1005 for (auto i : *this) {
1006 if (!i) {
1007 return false;
1008 }
1009 Type t = i->GetType();
1010 if (t == kTypeList || t == kTypeTuple || t == kTypeDict) {
1011 // check self reference object
1012 return false;
1013 }
1014 if (!i->IsMindSporeSupportedType()) {
1015 return false;
1016 }
1017 }
1018 return true;
1019 }
1020 return false;
1021 }
1022
ToString() const1023 std::string AbstractTuple::ToString() const {
1024 std::stringstream s;
1025 s << this->AObject::ToString() << "<" << GetTypeDesc(element_type_) << ">";
1026 if (this->IsElementValid()) {
1027 s << " size:" << this->size();
1028 } else {
1029 s << "<NoSize>";
1030 }
1031 return s.str();
1032 }
1033
ToString() const1034 std::string AbstractDict::ToString() const {
1035 std::stringstream s;
1036 s << this->AObject::ToString() << '<' << GetTypeDesc(k_type_) << ',' << GetTypeDesc(v_type_) << '>';
1037 if (this->IsElementValid()) {
1038 s << " size:" << size();
1039 } else {
1040 s << "<NoSize>";
1041 }
1042 return s.str();
1043 }
1044
1045 /**
1046 * cast to Py_ssize_t, call hook __index__ by PyNumber_AsSsize_t
1047 * \return -1 if key error, out of bound, overflow to cast Py_ssize_t
1048 */
GetTupleIndex(AObject * k,Py_ssize_t size)1049 static Py_ssize_t GetTupleIndex(AObject *k, Py_ssize_t size) {
1050 Py_ssize_t index = PyLong_AsSsize_t(k->GetPyObject().ptr());
1051 if (PyErr_Occurred()) {
1052 PyErr_Clear();
1053 return -1;
1054 }
1055 if (index < -size || index >= size) {
1056 return -1;
1057 }
1058 index = index < 0 ? (size + index) : index;
1059 return index;
1060 }
1061
SetItem(AObject * k,AObject * v)1062 bool AbstractList::SetItem(AObject *k, AObject *v) {
1063 MarkModify();
1064 if (k == nullptr || k->GetType() == AObject::kTypeAnyValue || k->GetPyObject().ptr() == nullptr) {
1065 // user defined index or unknown key
1066 this->AbstractSequence::SetItem(k, v);
1067 return true;
1068 }
1069 if (!IsElementValid()) {
1070 return true;
1071 }
1072 Py_ssize_t index = GetTupleIndex(k, this->size());
1073 if (index == -1) {
1074 MarkElementInValid();
1075 return false;
1076 }
1077 items_[index] = v;
1078 element_type_ = v->GetType() == element_type_ ? element_type_ : kTypeAnyValue;
1079 return true;
1080 }
1081
GetItem(AObject * k)1082 AObject *AbstractTuple::GetItem(AObject *k) {
1083 if (k == nullptr || k->GetType() == AObject::kTypeAnyValue || k->GetPyObject().ptr() == nullptr) {
1084 // user defined index or unknown key
1085 return this->AbstractSequence::GetItem(k);
1086 }
1087 if (!IsElementValid()) {
1088 return AObject::MakeAObject(element_type_);
1089 }
1090 if (k->GetType() == AObject::kTypeSlice) {
1091 if (this->GetPyObject().ptr() != nullptr) {
1092 return this->AbstractSequence::GetItem(k);
1093 }
1094 AObject *resultTuple = AObject::MakeAObject(this->type_);
1095 PyObject *slicePyObject = k->GetPyObject().ptr();
1096 Py_ssize_t start;
1097 Py_ssize_t stop;
1098 Py_ssize_t step;
1099 if (PySlice_Unpack(slicePyObject, &start, &stop, &step) < 0) {
1100 return AObject::MakeAObject(kTypeAnyValue);
1101 }
1102 if (start >= stop) {
1103 return resultTuple;
1104 }
1105 Py_ssize_t sliceLength = PySlice_AdjustIndices(this->items().size(), &start, &stop, step);
1106 AbstractTuple *resultTuplePtr = static_cast<AbstractTuple *>(resultTuple);
1107 if (start == 0 && step == 1 && sliceLength == this->size()) {
1108 return this;
1109 }
1110 if (step > 1) {
1111 int cursor = 0;
1112 std::vector<AObject *> itemsVector;
1113 for (cursor = 0; cursor < stop; cursor += step) {
1114 itemsVector.push_back(this->items()[cursor]);
1115 }
1116 resultTuplePtr->Update(itemsVector);
1117 return resultTuplePtr;
1118 }
1119 return AObject::MakeAObject(kTypeAnyValue);
1120 }
1121 Py_ssize_t index = GetTupleIndex(k, this->size());
1122 if (index == -1) {
1123 return AObject::MakeAObject(kTypeAnyValue);
1124 }
1125 return items_[index];
1126 }
1127
1128 #undef GET_INDEX
1129
GetAttr(const std::string & name)1130 AObject *AbstractTuple::GetAttr(const std::string &name) {
1131 py::object list = (type_ == kTypeList) ? (py::object)py::list() : py::tuple();
1132 PyObject *attr = PyObject_GetAttrString(list.ptr(), name.c_str());
1133 CHECK_PYTHON_EXCEPTION(attr);
1134 if (attr == nullptr) {
1135 FIND_MAP_CACHE(attrs_, name);
1136 }
1137 AObject *res = Convert(attr);
1138 Py_XDECREF(attr);
1139 return res;
1140 }
1141
Unary(int op) const1142 AObject *AbstractDict::Unary(int op) const {
1143 if (op != UNARY_NOT || !this->IsElementValid()) {
1144 return MakeAObject(kTypeAnyValue);
1145 }
1146 return Convert(this->size() ? Py_True : Py_False);
1147 }
1148
Binary(AObject * other,int op)1149 AObject *AbstractDict::Binary(AObject *other, int op) {
1150 if (op == IS_OP) {
1151 PyObject *b = other ? other->GetPyObject().ptr() : nullptr;
1152 bool cnst = const_object_type_map.find(b) != const_object_type_map.end();
1153 return cnst ? Convert(Py_False) : MakeAObject(kTypeBool);
1154 }
1155 if (op == CONTAINS_OP && other != nullptr) {
1156 return infer_binary_func[CONTAINS_OP](this, other);
1157 }
1158 return MakeAObject(kTypeAnyValue);
1159 }
1160
GetAttr(const std::string & name)1161 AObject *AbstractDict::GetAttr(const std::string &name) {
1162 if (value_.ptr() == nullptr) {
1163 return AObject::MakeAObject(kTypeAnyValue);
1164 }
1165 PyObject *attr = PyObject_GetAttrString(value_.ptr(), name.c_str());
1166 CHECK_PYTHON_EXCEPTION(attr);
1167 AObject *res = Convert(attr);
1168 Py_XDECREF(attr);
1169 return res;
1170 }
1171
GetItem(AObject * k)1172 AObject *AbstractDict::GetItem(AObject *k) {
1173 auto iter = this->write_cache_.find(k);
1174 if (iter != this->write_cache_.end()) {
1175 return iter->second == nullptr ? MakeAObject(kTypeAnyValue) : iter->second;
1176 }
1177 if (!IsElementValid()) {
1178 return MakeAObject(v_type_);
1179 }
1180 PyObject *key = k ? k->GetPyObject().ptr() : nullptr;
1181 if (key == nullptr) {
1182 return MakeAObject(kTypeAnyValue);
1183 }
1184 PRINT_IF_HAS_USER_DEFINED_HOOK(key, __hash__);
1185 PyObject *item = PyDict_GetItem(dict_.ptr(), key);
1186 return item == nullptr ? MakeAObject(kTypeAnyValue) : ConvertValue(item);
1187 }
1188
DictMerge(AObject * o,int update)1189 bool AbstractDict::DictMerge(AObject *o, int update) {
1190 MarkModify();
1191 if (!IsElementValid()) {
1192 return true;
1193 }
1194 if (o == nullptr || o->GetType() != kTypeDict) {
1195 MarkElementInValid();
1196 return true;
1197 }
1198 AbstractDict *d = static_cast<AbstractDict *>(o);
1199 if (!d->IsElementValid() || PyDict_Merge(dict_.ptr(), d->dict_.ptr(), update)) {
1200 MarkElementInValid();
1201 CHECK_PYTHON_EXCEPTION(nullptr);
1202 // unknown user defined dict merge, assume it success
1203 }
1204 if (size() == 0) {
1205 this->k_type_ = d->k_type_;
1206 this->v_type_ = d->v_type_;
1207 } else {
1208 this->k_type_ = this->k_type_ == d->k_type_ ? this->k_type_ : kTypeAnyValue;
1209 this->v_type_ = this->v_type_ == d->v_type_ ? this->v_type_ : kTypeAnyValue;
1210 }
1211 return true;
1212 }
1213
DictUpdate(AObject * o)1214 bool AbstractDict::DictUpdate(AObject *o) { return DictMerge(o, 1); }
1215
MapAdd(AObject * k,AObject * v)1216 bool AbstractDict::MapAdd(AObject *k, AObject *v) {
1217 if (v == nullptr) {
1218 MarkElementInValid();
1219 return true; // assume it success
1220 }
1221 if (size() == 0) {
1222 this->k_type_ = k->GetType();
1223 this->v_type_ = v->GetType();
1224 } else {
1225 this->k_type_ = this->k_type_ == k->GetType() ? this->k_type_ : kTypeAnyValue;
1226 this->v_type_ = this->v_type_ == v->GetType() ? this->v_type_ : kTypeAnyValue;
1227 }
1228 return SetItem(k, v);
1229 }
1230
ListAppend(AObject * item)1231 bool AbstractList::ListAppend(AObject *item) {
1232 MarkModify();
1233 if (!IsElementValid()) {
1234 return true;
1235 }
1236 if (size() == 0) {
1237 this->element_type_ = item->GetType();
1238 } else if (this->element_type_ != item->GetType()) {
1239 this->element_type_ = kTypeAnyValue;
1240 }
1241 items_.push_back(item);
1242 return true;
1243 }
1244
ListExtend(AObject * l)1245 bool AbstractList::ListExtend(AObject *l) {
1246 MarkModify();
1247 if (!IsElementValid()) {
1248 return true;
1249 }
1250 if (l == nullptr || (l->GetType() != kTypeTuple && l->GetType() != kTypeList)) {
1251 MarkElementInValid();
1252 return true;
1253 }
1254 AbstractTuple *i = static_cast<AbstractTuple *>(l);
1255 if (!i->IsElementValid()) {
1256 MarkElementInValid();
1257 return true;
1258 }
1259 if (size() == 0) {
1260 this->element_type_ = i->GetElementType();
1261 } else {
1262 this->element_type_ = this->GetElementType() == i->GetElementType() ? this->GetElementType() : kTypeAnyValue;
1263 }
1264 std::copy(i->items().begin(), i->items().end(), std::back_inserter(items_));
1265 return true;
1266 }
1267
ListToTuple()1268 AbstractTuple *AbstractList::ListToTuple() {
1269 AbstractTuple *res = static_cast<AbstractTuple *>(MakeAObject(kTypeTuple));
1270 if (!IsElementValid()) {
1271 return res;
1272 }
1273 res->SetElementType(this->element_type_);
1274 res->Update(this->items_);
1275 return res;
1276 }
1277
Update(const std::vector<AObject * > & item)1278 bool AbstractTuple::Update(const std::vector<AObject *> &item) {
1279 this->element_valid_ = true;
1280 this->items_ = item;
1281 if (this->items_.size() != 0 && items_[0] != nullptr) {
1282 this->element_type_ = items_[0]->GetType();
1283 bool any = item.end() != std::find_if(item.begin(), item.end(), [this](AObject *i) {
1284 return i ? i->GetType() != this->element_type_ : true;
1285 });
1286 this->element_type_ = any ? kTypeAnyValue : this->element_type_;
1287 }
1288 return Update();
1289 }
1290
Update()1291 bool AbstractTuple::Update() {
1292 if (!this->IsElementValid()) {
1293 return false;
1294 }
1295 if (trace_flag_) {
1296 return true;
1297 }
1298 this->element_type_ = kTypeAnyValue;
1299 // copy it
1300 PyObject *c = (this->type_ == kTypeTuple) ? PyTuple_New(items_.size()) : PyList_New(items_.size());
1301 modify_ = false;
1302 value_ = py::reinterpret_steal<py::object>(c);
1303 for (size_t i = 0; i < items_.size(); i++) {
1304 py::object item = (items_[i] != nullptr) ? items_[i]->GetPyObject() : py::object();
1305 if (item.ptr() == nullptr) {
1306 value_ = py::object();
1307 return false;
1308 }
1309 if (this->type_ == kTypeTuple) {
1310 PyTuple_SET_ITEM(c, i, item.inc_ref().ptr());
1311 } else {
1312 PyList_SET_ITEM(c, i, item.inc_ref().ptr());
1313 }
1314 if (i == 0) {
1315 this->element_type_ = items_[i]->GetType();
1316 } else {
1317 this->element_type_ = this->element_type_ == items_[i]->GetType() ? this->element_type_ : kTypeAnyValue;
1318 }
1319 }
1320 return true;
1321 }
1322
GetPyObject()1323 py::object AbstractList::GetPyObject() {
1324 if (this->write_cache_.size()) {
1325 // see SetItem, can't update unknown value to list
1326 return py::object();
1327 }
1328 if (modify_ && !Update()) {
1329 return py::object();
1330 }
1331 return value_;
1332 }
1333
Update()1334 bool AbstractDict::Update() {
1335 if (trace_flag_) {
1336 return true;
1337 }
1338 value_ = py::object();
1339 for (auto i : this->write_cache_) {
1340 PyObject *key = i.first == nullptr ? nullptr : i.first->GetPyObject().ptr();
1341 if (key == nullptr || -1 == PyDict_SetItem(dict_.ptr(), key, ConvertValue(i.second).ptr())) {
1342 MarkElementInValid();
1343 PyErr_Clear();
1344 return false;
1345 }
1346 }
1347 this->write_cache_.clear();
1348 // copy it
1349 value_ = py::dict();
1350 PyObject *k;
1351 PyObject *v;
1352 Py_ssize_t p = 0;
1353 bool init_element_type = false;
1354 while (PyDict_Next(dict_.ptr(), &p, &k, &v)) {
1355 AObject *i = ConvertValue(v);
1356 py::object item = i != nullptr ? i->GetPyObject() : py::object();
1357 if (item.ptr() == nullptr) {
1358 value_ = py::object();
1359 break;
1360 }
1361 PyDict_SetItem(value_.ptr(), k, item.ptr());
1362 if (init_element_type) {
1363 k_type_ = k_type_ == GetPyType(k) ? k_type_ : kTypeAnyValue;
1364 v_type_ = v_type_ == i->GetType() ? v_type_ : kTypeAnyValue;
1365 } else {
1366 k_type_ = GetPyType(k);
1367 v_type_ = i->GetType();
1368 init_element_type = true;
1369 }
1370 }
1371 return true;
1372 }
1373
GetPyObject()1374 py::object AbstractDict::GetPyObject() {
1375 if (!IsElementValid()) {
1376 return py::object();
1377 }
1378 if (!IsModify()) {
1379 return value_;
1380 }
1381 Update();
1382 return value_;
1383 }
1384
GetTensor(bool sync)1385 py::object AbstractTensor::GetTensor(bool sync) {
1386 if (!is_stub_ || !sync) {
1387 return value_;
1388 }
1389 std::string attr_key = "tensor";
1390 auto iter = attrs_.find(attr_key);
1391 if (iter != attrs_.end()) {
1392 return iter->second->GetPyObject();
1393 }
1394 PyObject *res = PyObject_GetAttrString(value_.ptr(), attr_key.c_str());
1395 if (res != nullptr && res != Py_None) {
1396 attrs_[attr_key] = AObject::Convert(res);
1397 return py::reinterpret_steal<py::object>(res);
1398 }
1399 if (res == nullptr) {
1400 PyErr_Clear();
1401 } else {
1402 Py_DECREF(res);
1403 }
1404 PyObject *meth = PyObject_GetAttrString(value_.ptr(), "stub_sync");
1405 MS_EXCEPTION_IF_CHECK_FAIL(meth && PyMethod_Check(meth), "check value");
1406 res = PyObject_Call(meth, py::tuple().ptr(), nullptr);
1407 Py_DECREF(meth);
1408 CHECK_PYTHON_EXCEPTION(res);
1409 attrs_[attr_key] = AObject::Convert(res);
1410 return py::reinterpret_steal<py::object>(res);
1411 }
1412
PyObjectToAbstract(const py::object & arg)1413 AbstractBasePtr PyObjectToAbstract(const py::object &arg) {
1414 ValuePtr converted = nullptr;
1415 bool success;
1416 if (IsStubTensor(arg)) {
1417 success = mindspore::parse::ConvertStubData(arg, &converted);
1418 } else {
1419 success = mindspore::parse::ConvertData(arg, &converted);
1420 }
1421 if (!success) {
1422 MS_LOG(EXCEPTION) << "Fail to convert the object: " << py::str(arg);
1423 }
1424 auto res = GraphUtils::ArgsToAbstract(arg, converted, false);
1425 if (res->isa<mindspore::abstract::AbstractTensor>()) {
1426 bool check = CheckAdapterTensor(arg);
1427 dyn_cast_ptr<mindspore::abstract::AbstractTensor>(res)->set_is_adapter(check);
1428 }
1429 return res;
1430 }
1431
TensorInferBinarySupport(int opcode)1432 bool TensorInferBinarySupport(int opcode) {
1433 static const std::set<int> support_op = {
1434 BINARY_POWER, BINARY_MULTIPLY, BINARY_MODULO, BINARY_ADD,
1435 BINARY_SUBTRACT, BINARY_SUBSCR, BINARY_FLOOR_DIVIDE, BINARY_TRUE_DIVIDE,
1436 INPLACE_FLOOR_DIVIDE, INPLACE_TRUE_DIVIDE, INPLACE_ADD, INPLACE_SUBTRACT,
1437 INPLACE_MULTIPLY, INPLACE_MODULO, BINARY_LSHIFT, BINARY_RSHIFT,
1438 BINARY_AND, BINARY_XOR, BINARY_OR, INPLACE_POWER,
1439 INPLACE_LSHIFT, INPLACE_RSHIFT, INPLACE_AND, INPLACE_XOR,
1440 INPLACE_OR,
1441 };
1442
1443 return support_op.find(opcode) != support_op.end();
1444 }
1445
InferWithMetaFunc(const AbstractBasePtr & left,const AbstractBasePtr & right,int opcode)1446 mindspore::abstract::AbstractTensorPtr InferWithMetaFunc(const AbstractBasePtr &left, const AbstractBasePtr &right,
1447 int opcode) {
1448 auto func = GraphUtils::GetPrimOrMetaFuncGraph(opcode);
1449 auto res = mindspore::pipeline::AbstractAnalyze(GetValueNode(func), {left, right});
1450 return dyn_cast<mindspore::abstract::AbstractTensor>(res.eval_result->abstract());
1451 }
1452
InferWithPrim(const AbstractBasePtr & left,const AbstractBasePtr & right,int opcode)1453 mindspore::abstract::AbstractTensorPtr InferWithPrim(const AbstractBasePtr &left, const AbstractBasePtr &right,
1454 int opcode) {
1455 static std::unordered_map<int, PrimitivePtr> prim_func = {{BINARY_ADD, prim::kPrimAdd},
1456 {BINARY_SUBTRACT, prim::kPrimSub},
1457 {BINARY_MULTIPLY, prim::kPrimMul},
1458 {BINARY_TRUE_DIVIDE, prim::kPrimDiv},
1459 {BINARY_FLOOR_DIVIDE, prim::kPrimFloorDiv}};
1460
1461 auto left_dtype_ptr = dyn_cast_ptr<mindspore::abstract::AbstractTensor>(left)->element()->BuildType();
1462 MS_EXCEPTION_IF_NULL(left_dtype_ptr);
1463 auto right_dtype_ptr = dyn_cast_ptr<mindspore::abstract::AbstractTensor>(right)->element()->BuildType();
1464 MS_EXCEPTION_IF_NULL(right_dtype_ptr);
1465 if (left_dtype_ptr->type_id() != right_dtype_ptr->type_id() || prim_func.find(opcode) == prim_func.end()) {
1466 return InferWithMetaFunc(left, right, opcode);
1467 }
1468
1469 auto func = prim_func.find(opcode)->second;
1470 auto infer_res = mindspore::abstract::TryInferAbstract(func, {left, right});
1471 if (infer_res.has_value()) {
1472 MS_EXCEPTION_IF_NULL(infer_res.value());
1473 return dyn_cast<mindspore::abstract::AbstractTensor>(infer_res.value());
1474 } else {
1475 return nullptr;
1476 }
1477 }
1478
TensorInferBinary(const AbstractBasePtr & left,const AbstractBasePtr & right,int opcode)1479 py::object TensorInferBinary(const AbstractBasePtr &left, const AbstractBasePtr &right, int opcode) {
1480 mindspore::abstract::AbstractTensorPtr abs;
1481 if (right->isa<mindspore::abstract::AbstractTensor>()) {
1482 abs = InferWithPrim(left, right, opcode);
1483 } else if (right->isa<mindspore::abstract::AbstractScalar>()) {
1484 auto new_right = std::make_shared<mindspore::abstract::AbstractTensor>(right);
1485 abs = InferWithPrim(left, new_right, opcode);
1486 } else {
1487 abs = InferWithMetaFunc(left, right, opcode);
1488 }
1489 MS_EXCEPTION_IF_NULL(abs);
1490 auto dtype_ptr = abs->element()->BuildType();
1491 MS_EXCEPTION_IF_NULL(dtype_ptr);
1492 auto shape_ptr = abs->BuildShape();
1493 MS_EXCEPTION_IF_NULL(shape_ptr);
1494 auto shape = shape_ptr->cast<mindspore::abstract::ShapePtr>()->shape();
1495 auto dtype = dtype_ptr->type_id();
1496 auto tensor = std::make_shared<mindspore::tensor::Tensor>(dtype, shape);
1497 return py::cast(tensor);
1498 }
1499
// Infer `l_tensor <op> r_tensor` from concrete python objects and return a tensor of the inferred result.
// It is converted to an adapter tensor when the left operand is one, otherwise to a MindSpore tensor.
py::object AbstractTensor::Binary(int op, const py::object &l_tensor, const py::object &r_tensor) {
  const AbstractBasePtr lhs = PyObjectToAbstract(l_tensor);
  const AbstractBasePtr rhs = PyObjectToAbstract(r_tensor);
  py::object inferred = TensorInferBinary(lhs, rhs, op);
  if (!CheckAdapterTensor(l_tensor)) {
    return ConvertToMsTensor(inferred);
  }
  return ConvertToAdapterTensor(inferred);
}
1511
Binary(AObject * other,int op)1512 AObject *AbstractTensor::Binary(AObject *other, int op) {
1513 if (op == IS_OP) {
1514 PyTypeObject *b = other ? other->GetTypeObject() : nullptr;
1515 PyTypeObject *a = GetTypeObject();
1516 return a != b && b != nullptr ? Convert(Py_False) : MakeAObject(kTypeBool);
1517 }
1518
1519 if (other == nullptr || GetPyObject().ptr() == nullptr || !TensorInferBinarySupport(op)) {
1520 return MakeAObject(kTypeTensor);
1521 }
1522
1523 AbstractBasePtr left = PyObjectToAbstract(this->GetPyObject());
1524 AbstractBasePtr right;
1525 if (other->GetPyObject().ptr() == nullptr) {
1526 // if other is scalar with empty value, then transfer to AbstractScalar
1527 // else return any value
1528 switch (other->GetType()) {
1529 case kTypeBool:
1530 right = std::make_shared<mindspore::abstract::AbstractScalar>(kValueAny, kBool);
1531 break;
1532 case kTypeInt:
1533 right = std::make_shared<mindspore::abstract::AbstractScalar>(kValueAny, kInt32);
1534 break;
1535 case kTypeFloat:
1536 right = std::make_shared<mindspore::abstract::AbstractScalar>(kValueAny, kFloat32);
1537 break;
1538 default:
1539 return MakeAObject(kTypeAnyValue);
1540 }
1541 } else {
1542 right = PyObjectToAbstract(other->GetPyObject());
1543 }
1544 auto res = TensorInferBinary(left, right, op);
1545 if (CheckAdapterTensor(value_)) {
1546 res = ConvertToAdapterTensor(res);
1547 } else {
1548 res = ConvertToMsTensor(res);
1549 }
1550 return Convert(res);
1551 }
1552
GetItem(AObject * key)1553 AObject *AbstractTensor::GetItem(AObject *key) {
1554 PyObject *s = value_.ptr();
1555 PyObject *i = key ? key->GetPyObject().ptr() : nullptr;
1556 PyObject *t = nullptr;
1557 if (s != nullptr && i != nullptr) {
1558 // avoid Tensor as index and Tensor data sync
1559 t = PyObject_GetItem(s, i);
1560 CHECK_PYTHON_EXCEPTION(t);
1561 } else {
1562 return MakeAObject(kTypeAnyValue);
1563 }
1564 py::object res = py::reinterpret_steal<py::object>(t);
1565 if (CheckAdapterTensor(value_)) {
1566 res = ConvertToAdapterTensor(res);
1567 } else {
1568 res = ConvertToMsTensor(res);
1569 }
1570 return Convert(res);
1571 }
1572
Unary(int op) const1573 AObject *AbstractTensor::Unary(int op) const {
1574 if (this->value_.ptr() != nullptr) {
1575 return this->AbstractObject::UnaryValue(op);
1576 }
1577 if (op == UNARY_POSITIVE) {
1578 return const_cast<AbstractTensor *>(this);
1579 } else if (op == UNARY_NEGATIVE || op == UNARY_INVERT) {
1580 AbstractTensor *res = static_cast<AbstractTensor *>(MakeAObject(kTypeTensor));
1581 auto it = attrs_.find("shape");
1582 if (it != attrs_.end()) {
1583 res->attrs_["shape"] = it->second;
1584 }
1585 it = attrs_.find("dtype");
1586 if (it != attrs_.end()) {
1587 res->attrs_["dtype"] = it->second;
1588 }
1589 return res;
1590 } else if (op == UNARY_NOT) {
1591 auto it = attrs_.find("shape");
1592 if (it == attrs_.end() || it->second == nullptr) {
1593 return MakeAObject(kTypeTensor);
1594 }
1595 AObject *shape_info = it->second;
1596 PyObject *shape = shape_info->GetPyObject().ptr();
1597 Py_ssize_t ndim = PyTuple_GET_SIZE(shape);
1598 if (ndim == 0 || (ndim == 1 && PyLong_AS_LONG(PyTuple_GET_ITEM(shape, 0))) == 1) {
1599 return MakeAObject(kTypeBool);
1600 }
1601 return MakeAObject(kTypeAnyValue);
1602 }
1603 return MakeAObject(kTypeAnyValue);
1604 }
1605
1606 static const std::unordered_map<std::string, AObject::Type> tensor_attr_type = {
1607 // py Tensor property
1608 {"shape", AObject::kTypeTuple},
1609 {"dtype", AObject::kTypeMSDType},
1610 {"size", AObject::kTypeInt},
1611 {"itemsize", AObject::kTypeInt},
1612 {"nbytes", AObject::kTypeInt},
1613 {"strides", AObject::kTypeTuple},
1614 {"ndim", AObject::kTypeInt},
1615 {"has_init", AObject::kTypeBool},
1616 {"H", AObject::kTypeTensor},
1617 {"mH", AObject::kTypeTensor},
1618 {"T", AObject::kTypeTensor},
1619 {"mT", AObject::kTypeTensor},
1620 // cpp Tensor property
1621 {"_shape", AObject::kTypeTuple},
1622 {"_dtype", AObject::kTypeMSDType},
1623 {"_size", AObject::kTypeInt},
1624 {"_itemsize", AObject::kTypeInt},
1625 {"_nbytes", AObject::kTypeInt},
1626 {"_strides", AObject::kTypeTuple},
1627 {"init_flag", AObject::kTypeBool},
1628 {"adapter_flag", AObject::kTypeBool},
1629 {"param_info", AObject::kTypeAnyValue},
1630 };
1631
1632 // return an uninitialized python tensor
GetUninitializedTensor()1633 static PyObject *GetUninitializedTensor() {
1634 static PyObject *tensor = nullptr;
1635 if (tensor != nullptr) {
1636 return tensor;
1637 }
1638 py::object py_cls = Utils::GetModuleAttr("mindspore", "Tensor", false, true);
1639 py::object cpp_cls = Utils::GetModuleAttr("mindspore._c_expression", "Tensor", false, true);
1640 py::object dtype = Utils::GetModuleAttr("mindspore", "int32", false, true);
1641 py::tuple shape;
1642 tensor = py_cls(cpp_cls(dtype, shape)).inc_ref().ptr();
1643 return tensor;
1644 }
1645
AbstractTensor(const py::object & o,bool is_stub)1646 AbstractTensor::AbstractTensor(const py::object &o, bool is_stub) : AbstractObject(kTypeTensor, o), is_stub_(is_stub) {}
1647
GetAttr(const std::string & name)1648 AObject *AbstractTensor::GetAttr(const std::string &name) {
1649 if (value_.ptr()) {
1650 return this->AbstractObject::GetAttr(name);
1651 }
1652
1653 PyObject *tmp = GetUninitializedTensor();
1654 if (type_object_ != Py_TYPE(tmp)) {
1655 // tensor subclass or StubTensor and it's subclass
1656 // generic attribute
1657 AObject *attr = this->AbstractObjectBase::GetAttr(name);
1658 attrs_[name] = attr;
1659 return attr;
1660 }
1661 // get attribute for exact mindspore.Tensor,
1662 // not MetaTensor, not mindspore._c_expression.Tensor, not StubTensor
1663
1664 // known @property attribute
1665 auto iter = tensor_attr_type.find(name);
1666 if (iter != tensor_attr_type.end()) {
1667 AObject *attr = MakeAObject(iter->second);
1668 if (iter->second == kTypeTuple) {
1669 static_cast<AbstractTuple *>(attr)->SetElementType(kTypeInt);
1670 }
1671 attrs_[name] = attr;
1672 return attr;
1673 }
1674
1675 // know function attribute
1676 PyObject *op = PyObject_GetAttrString(tmp, name.c_str());
1677 AObject *attr = Convert(op);
1678 if (op == nullptr) {
1679 PyErr_Clear();
1680 } else {
1681 Py_DECREF(op);
1682 }
1683 if (attr->GetType() == kTypeBoundMethod) {
1684 attr->SetAttr("__self__", this);
1685 Py_INCREF(Py_None);
1686 Py_SETREF(PyMethod_GET_SELF(op), Py_None);
1687 } else {
1688 // not initialized attribute is not accept
1689 attr = MakeAObject(kTypeAnyValue);
1690 }
1691 attrs_[name] = attr;
1692 return attr;
1693 }
1694
ToString() const1695 std::string AbstractTensor::ToString() const {
1696 std::stringstream s;
1697 py::object dtype;
1698 py::object shape;
1699 std::stringstream extra_info;
1700 if (value_.ptr()) {
1701 dtype = value_.attr("dtype");
1702 shape = value_.attr("shape");
1703 if (is_stub_) {
1704 extra_info << "stub_tensor ";
1705 }
1706 extra_info << "init=" << (CheckTensorDataInitialized(value_) ? "True" : "False");
1707 }
1708 s << this->AbstractObjectBase::ToString() << '\'' << std::string(py::str(dtype.ptr())) << ','
1709 << std::string(py::str(shape.ptr())) << "' " << extra_info.str() << ' ';
1710 return s.str();
1711 }
1712
1713 } // namespace pijit
1714 } // namespace mindspore
1715