• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_guard/infer.h"
17 #include <map>
18 #include <string>
19 #include <functional>
20 #include <unordered_set>
21 #include <utility>
22 #include "base/base.h"
23 #include "abstract/ops/primitive_infer_map.h"
24 #include "ops/auto_generate/gen_ops_primitive.h"
25 #include "pybind_api/ir/primitive_py.h"
26 #include "include/common/utils/convert_utils_py.h"
27 #include "include/common/utils/stub_tensor.h"
28 #include "ir/anf.h"
29 #include "utils/flags.h"
30 #include "pipeline/jit/pi/utils/utils.h"
31 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
32 #include "frontend/operator/composite/composite.h"
33 #include "ir/cell.h"
34 #include "pipeline/jit/ps/resource.h"
35 #include "pipeline/jit/pi/pydef.h"
36 #include "pipeline/jit/pi/graph_guard/guard_utils.h"
37 #include "pipeline/jit/ps/parse/data_converter.h"
38 #include "pipeline/jit/ps/action.h"
39 #include "pipeline/jit/pi/graph_build/func_graph_builder.h"
40 
41 namespace mindspore {
// Declaration only: parse::ConvertData is defined in another translation unit.
// In this file it is called by convertData() to turn a non-stub Python argument into a ValuePtr.
42 namespace parse {
43 extern bool ConvertData(const py::object &obj, mindspore::ValuePtr *data, bool use_signature,
44                         const mindspore::TypePtr &dtype, bool forbid_reuse);
45 }
46 
// Declarations only: both functions are defined in other translation units.
// ToAbstract is used by ChangeAbstractArgList() to build the abstract of each converted argument.
47 namespace abstract {
48 extern mindspore::abstract::AbstractBasePtr ToAbstract(const mindspore::ValuePtr &value,
49                                                        const mindspore::abstract::AnalysisContextPtr &context,
50                                                        const mindspore::abstract::AnfNodeConfigPtr &conf);
51 extern std::optional<StandardPrimitiveImplReg> GetPrimitiveInferImpl(const PrimitivePtr &primitive);
52 }  // namespace abstract
53 
54 namespace pijit {
55 
// Process-wide InferEngine instance; created on first use by InferEngine::GetInstance().
56 static InferEnginePtr g_pInferEngine = nullptr;
// Argument count ChangeAbstractArgList() requires before applying its scalar-to-tensor cast for "Div".
57 constexpr const int ArgsSizeTwo = 2;
58 
59 template <>
IsPrimitiveFunctionType(PyTypeObject * tp)60 bool IsPrimitiveFunctionType<true>(PyTypeObject *tp) {
61   return IsPybindType<mindspore::PrimitiveFunctionAdapter, true>(tp);
62 }
63 
GetInstance()64 InferEnginePtr InferEngine::GetInstance() {
65   if (g_pInferEngine == nullptr) {
66     g_pInferEngine = std::shared_ptr<InferEngine>(new InferEngine());
67   }
68   if (g_pInferEngine->Init()) {
69     return g_pInferEngine;
70   } else {
71     return nullptr;
72   }
73 }
74 
InferEngine()75 InferEngine::InferEngine() {}
76 
Init()77 bool InferEngine::Init() {
78   if (!bInit_) {
79     bInit_ = GetMsTensorType() != nullptr;
80   }
81   return bInit_;
82 }
83 
Deinit()84 bool InferEngine::Deinit() {
85   if (bInit_) {
86     bInit_ = false;
87   }
88   return bInit_;
89 }
90 
// Maps a MindSpore number TypeId to an attribute name.
// InferDType() reads that attribute from the object returned by GetMsType() to produce its result.
// Several ids share one name (e.g. kNumberTypeInt and kNumberTypeInt4 both map to "int_").
91 static std::map<mindspore::TypeId, std::string> g_type2attr = {
92   {mindspore::kNumberTypeBool, "bool_"},          {mindspore::kNumberTypeInt, "int_"},
93   {mindspore::kNumberTypeInt4, "int_"},           {mindspore::kNumberTypeInt8, "int8"},
94   {mindspore::kNumberTypeInt16, "int16"},         {mindspore::kNumberTypeInt32, "int32"},
95   {mindspore::kNumberTypeInt64, "int64"},         {mindspore::kNumberTypeUInt, "uint"},
96   {mindspore::kNumberTypeUInt8, "uint8"},         {mindspore::kNumberTypeUInt16, "uint16"},
97   {mindspore::kNumberTypeUInt32, "uint32"},       {mindspore::kNumberTypeUInt64, "uint64"},
98   {mindspore::kNumberTypeFloat, "float_"},        {mindspore::kNumberTypeFloat16, "float16"},
99   {mindspore::kNumberTypeFloat32, "float32"},     {mindspore::kNumberTypeFloat64, "float64"},
100   {mindspore::kNumberTypeDouble, "float64"},      {mindspore::kNumberTypeComplex, "complex128"},
101   {mindspore::kNumberTypeComplex64, "complex64"}, {mindspore::kNumberTypeComplex128, "complex128"},
102 };
103 
// Forward declaration: CreateTuple() and CreateList() below call this function for each element
// before its definition appears.
104 static py::object MakeObjectFromAbstract(const mindspore::abstract::BaseShapePtr &base_shape,
105                                          const mindspore::TypePtr &type, bool *is_abstract);
106 
CreateMetaTensor(const ShapeVector & shape,const mindspore::TypePtr & type)107 static py::object CreateMetaTensor(const ShapeVector &shape, const mindspore::TypePtr &type) {
108   mindspore::TypePtr dtype;
109   if (type->isa<mindspore::TensorType>()) {
110     dtype = type->cast<mindspore::TensorTypePtr>()->element();
111   } else {
112     dtype = type;
113   }
114   /**
115    * NOTE: here create a lazy initialized tensor, avoid allocate data
116    */
117   auto tensor = std::make_shared<mindspore::tensor::Tensor>(dtype->type_id(), shape);
118   py::object pytensor = py::reinterpret_borrow<py::object>(GetMsTensorType());
119   return pytensor(py::cast(tensor));
120 }
121 
CreateMetaTensor(const mindspore::abstract::ShapePtr & shape,const mindspore::TypePtr & type)122 static py::object CreateMetaTensor(const mindspore::abstract::ShapePtr &shape, const mindspore::TypePtr &type) {
123   return CreateMetaTensor(shape->shape(), type);
124 }
125 
CreateScalar(const mindspore::TypePtr & type)126 static py::object CreateScalar(const mindspore::TypePtr &type) {
127   static std::map<mindspore::TypeId, py::object> ms_type2py_type_map = {
128     {mindspore::kNumberTypeBool, py::bool_()},
129     {mindspore::kNumberTypeInt, py::int_()},
130     {mindspore::kNumberTypeInt4, py::int_()},
131     {mindspore::kNumberTypeInt8, py::int_()},
132     {mindspore::kNumberTypeInt16, py::int_()},
133     {mindspore::kNumberTypeInt32, py::int_()},
134     {mindspore::kNumberTypeInt64, py::int_()},
135     {mindspore::kNumberTypeUInt, py::int_()},
136     {mindspore::kNumberTypeUInt8, py::int_()},
137     {mindspore::kNumberTypeUInt16, py::int_()},
138     {mindspore::kNumberTypeUInt32, py::int_()},
139     {mindspore::kNumberTypeUInt64, py::int_()},
140     {mindspore::kNumberTypeFloat, py::float_()},
141     {mindspore::kNumberTypeFloat16, py::float_()},
142     {mindspore::kNumberTypeFloat32, py::float_()},
143     {mindspore::kNumberTypeFloat64, py::float_()},
144     {mindspore::kNumberTypeDouble, py::float_()},
145     {mindspore::kNumberTypeComplex, py::reinterpret_steal<py::object>(PyComplex_FromDoubles(0.0, 0.0))},
146     {mindspore::kNumberTypeComplex64, py::reinterpret_steal<py::object>(PyComplex_FromDoubles(0.0, 0.0))},
147     {mindspore::kNumberTypeComplex128, py::reinterpret_steal<py::object>(PyComplex_FromDoubles(0.0, 0.0))},
148   };
149   auto it = ms_type2py_type_map.find(type->type_id());
150   if (it != ms_type2py_type_map.cend()) {
151     return it->second;
152   } else {
153     return py::cast<py::object>(nullptr);
154   }
155 }
156 
CreateTuple(const mindspore::abstract::BaseShapePtr & base_shape,const mindspore::TypePtr & type,bool * is_abstract)157 static py::object CreateTuple(const mindspore::abstract::BaseShapePtr &base_shape, const mindspore::TypePtr &type,
158                               bool *is_abstract) {
159   bool dynamic;
160   mindspore::abstract::SequenceShapePtr shape_tuple;
161   size_t elem_count = 0;
162   auto type_tuple = type->cast_ptr<mindspore::Tuple>();
163   if (base_shape->isa<mindspore::abstract::DynamicSequenceShape>()) {
164     dynamic = true;
165     elem_count = type_tuple->elements().size();
166   } else {
167     dynamic = false;
168     shape_tuple = base_shape->cast<mindspore::abstract::TupleShapePtr>();
169     elem_count = shape_tuple->size();
170   }
171   py::tuple tuple = py::tuple(elem_count);
172   for (size_t it = 0; it < elem_count; ++it) {
173     bool is_abstract_obj = false;
174     auto tensor_it =
175       MakeObjectFromAbstract(dynamic ? base_shape : (*shape_tuple)[it], type_tuple->elements()[it], &is_abstract_obj);
176     Py_INCREF(tensor_it.ptr());
177     PyTuple_SetItem(tuple.ptr(), it, tensor_it.ptr());
178     *is_abstract |= is_abstract_obj;
179   }
180   return tuple;
181 }
182 
CreateList(const mindspore::abstract::BaseShapePtr & base_shape,const mindspore::TypePtr & type,bool * is_abstract)183 static py::object CreateList(const mindspore::abstract::BaseShapePtr &base_shape, const mindspore::TypePtr &type,
184                              bool *is_abstract) {
185   bool dynamic;
186   mindspore::abstract::SequenceShapePtr shape_list;
187   size_t elem_count = 0;
188   auto type_list = type->cast_ptr<mindspore::List>();
189   if (base_shape->isa<mindspore::abstract::DynamicSequenceShape>()) {
190     dynamic = true;
191     elem_count = type_list->elements().size();
192   } else {
193     dynamic = false;
194     shape_list = base_shape->cast<mindspore::abstract::ListShapePtr>();
195     elem_count = shape_list->size();
196   }
197   py::list list = py::list(elem_count);
198   for (size_t it = 0; it < elem_count; ++it) {
199     bool is_abstract_obj = false;
200     auto tensor_it =
201       MakeObjectFromAbstract(dynamic ? base_shape : (*shape_list)[it], type_list->elements()[it], &is_abstract_obj);
202     Py_INCREF(tensor_it.ptr());
203     PyList_SetItem(list.ptr(), it, tensor_it.ptr());
204     *is_abstract |= is_abstract_obj;
205   }
206   return list;
207 }
208 
MakeObjectFromAbstract(const mindspore::abstract::BaseShapePtr & base_shape,const mindspore::TypePtr & type,bool * is_abstract)209 static py::object MakeObjectFromAbstract(const mindspore::abstract::BaseShapePtr &base_shape,
210                                          const mindspore::TypePtr &type, bool *is_abstract) {
211   *is_abstract = false;
212   if (base_shape->isa<mindspore::abstract::Shape>()) {
213     return CreateMetaTensor(base_shape->cast<mindspore::abstract::ShapePtr>(), type);
214   } else if (base_shape->isa<mindspore::abstract::NoShape>() && type->isa<mindspore::Number>()) {
215     *is_abstract = true;
216     return CreateScalar(type);
217   } else if (base_shape->isa<mindspore::abstract::TupleShape>() && type->isa<mindspore::Tuple>()) {
218     return CreateTuple(base_shape, type, is_abstract);
219   } else if (base_shape->isa<mindspore::abstract::ListShape>() && type->isa<mindspore::List>()) {
220     return CreateList(base_shape, type, is_abstract);
221   } else if (base_shape->isa<mindspore::abstract::NoShape>() && type->isa<mindspore::TypeNone>()) {
222     // AbstractNone indicates there is no output for this CNode node.
223     return py::cast<py::object>(Py_None);
224   } else if (type->isa<mindspore::Monad>()) {
225     // Return monad abstract if it is monad type.
226     return py::cast<py::object>(nullptr);
227   } else if (base_shape->isa<mindspore::abstract::DynamicSequenceShape>()) {
228     *is_abstract = true;
229     if (type->isa<mindspore::Tuple>()) {
230       return CreateTuple(base_shape, type, is_abstract);
231     } else if (type->isa<mindspore::List>()) {
232       return CreateList(base_shape, type, is_abstract);
233     } else if (type->isa<mindspore::TensorType>()) {
234       return CreateMetaTensor({-2}, type);
235     } else if (type->isa<mindspore::Number>()) {
236       return CreateScalar(type);
237     } else {
238       MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << " or type. "
239                         << type->ToString();
240       return py::cast<py::object>(nullptr);
241     }
242   } else {
243     MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << " or type. "
244                       << type->ToString();
245     return py::cast<py::object>(nullptr);
246   }
247 }
248 
// Creates a Python placeholder object from the shape/dtype pair returned by a Python-side infer.
//
//   shape is list/tuple, dtype is a Type  -> scalar when the shape is empty and dtype is not a
//                                            TensorType (*is_abstract = true), otherwise a tensor
//   shape tuple + dtype tuple             -> tuple built element-wise
//   shape list  + dtype list              -> list built element-wise
//   both None                             -> None
//   dtype is a MonadType                  -> null py::object
// Anything else raises through MS_LOG(EXCEPTION).
// Element-wise construction stops as soon as one element is reported abstract, so the returned
// container may then be only partially filled.
static py::object MakeObjectFromPyObject(const py::object &shape_obj, const py::object &type_obj, bool *is_abstract) {
  *is_abstract = false;
  const bool shape_is_list = py::isinstance<py::list>(shape_obj);
  const bool shape_is_tuple = py::isinstance<py::tuple>(shape_obj);
  const bool type_is_ms_type = py::isinstance<mindspore::Type>(type_obj);

  if ((shape_is_list || shape_is_tuple) && type_is_ms_type) {
    auto dims = shape_obj.cast<ShapeVector>();
    auto dtype = type_obj.cast<mindspore::TypePtr>();
    if (dims.empty() && (!dtype->isa<TensorType>())) {
      *is_abstract = true;
      return CreateScalar(dtype);
    }
    return CreateMetaTensor(dims, dtype);
  }
  if (shape_is_tuple && py::isinstance<py::tuple>(type_obj)) {
    auto types = type_obj.cast<py::tuple>();
    auto shapes = shape_obj.cast<py::tuple>();
    py::tuple result(types.size());
    for (size_t idx = 0; idx < types.size(); ++idx) {
      if (*is_abstract) {
        break;
      }
      py::object element = MakeObjectFromPyObject(shapes[idx], types[idx], is_abstract);
      result[idx] = element;
    }
    return result;
  }
  if (shape_is_list && py::isinstance<py::list>(type_obj)) {
    auto types = type_obj.cast<py::list>();
    auto shapes = shape_obj.cast<py::list>();
    py::list result;
    for (size_t idx = 0; idx < types.size(); ++idx) {
      if (*is_abstract) {
        break;
      }
      py::object element = MakeObjectFromPyObject(shapes[idx], types[idx], is_abstract);
      result.append(element);
    }
    return result;
  }
  if (shape_obj.is_none() && type_obj.is_none()) {
    return py::cast<py::object>(Py_None);
  }
  if (type_is_ms_type && type_obj.cast<mindspore::Type *>()->isa<mindspore::MonadType>()) {
    return py::cast<py::object>(nullptr);
  }
  MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << py::str(type_obj);
}
287 
HasTensor(py::object obj)288 static bool HasTensor(py::object obj) {
289   if (obj.ptr() == nullptr) {
290     return false;
291   }
292 
293   ReprRecursionScope scope(obj.ptr());
294   if (scope.ReEnterOrError()) {
295     return false;
296   }
297   if (py::isinstance<mindspore::tensor::MetaTensor>(obj)) {
298     return true;
299   } else if (py::isinstance<py::list>(obj)) {
300     auto list_obj = py::cast<py::list>(obj);
301     if (std::any_of(list_obj.begin(), list_obj.end(),
302                     [](const auto &e) { return HasTensor(py::cast<py::object>(e)); })) {
303       return true;
304     }
305   } else if (py::isinstance<py::tuple>(obj)) {
306     auto tuple_obj = py::cast<py::tuple>(obj);
307     if (std::any_of(tuple_obj.begin(), tuple_obj.end(),
308                     [](const auto &e) { return HasTensor(py::cast<py::object>(e)); })) {
309       return true;
310     }
311   } else if (py::isinstance<py::dict>(obj)) {
312     auto dict_obj = py::cast<py::dict>(obj);
313     if (std::any_of(dict_obj.begin(), dict_obj.end(), [](const auto &e) {
314           return HasTensor(py::cast<py::object>(e.first)) || HasTensor(py::cast<py::object>(e.second));
315         })) {
316       return true;
317     }
318   }
319   return false;
320 }
321 
// Converts a Type value into an int64 value holding its type id.
// Values that are not a mindspore::Type are returned unchanged. `value` must not be null.
ValuePtr DtypeToEnum(const ValuePtr &value) {
  if (value->isa<mindspore::Type>()) {
    const auto type_id = value->cast<TypePtr>()->type_id();
    return MakeValue<int64_t>(type_id);
  }
  return value;
}
329 
330 using ArgHandlerFunc = std::function<ValuePtr(const ValuePtr &)>;
331 
GetOppArgHandlerFunc(const std::string & arg_handler)332 ArgHandlerFunc GetOppArgHandlerFunc(const std::string &arg_handler) {
333   static const std::unordered_map<std::string, ArgHandlerFunc> opp_arg_handler_funcs = {
334     {"dtype_to_type_id", DtypeToEnum},
335   };
336   if (opp_arg_handler_funcs.find(arg_handler) != opp_arg_handler_funcs.end()) {
337     return opp_arg_handler_funcs.at(arg_handler);
338   } else {
339     return nullptr;
340   }
341 }
342 
ConvertArgByArgHandler(mindspore::ValuePtr value,ops::OpDef * op_def,size_t i)343 mindspore::ValuePtr ConvertArgByArgHandler(mindspore::ValuePtr value, ops::OpDef *op_def, size_t i) {
344   if (op_def != nullptr && value != nullptr) {
345     auto opp_arg_handler_func = GetOppArgHandlerFunc(op_def->args_[i].arg_handler_);
346     if (opp_arg_handler_func != nullptr) {
347       return opp_arg_handler_func(value);
348     }
349   }
350   return value;
351 }
352 
ConvertArgByCastDtype(py::object arg,ops::OpInputArg op_arg)353 mindspore::ValuePtr ConvertArgByCastDtype(py::object arg, ops::OpInputArg op_arg) {
354   mindspore::ValuePtr value = nullptr;
355   parse::OpDefConvertFunc convert_func = parse::GetConverterByType(static_cast<int32_t>(op_arg.arg_dtype_));
356   MS_EXCEPTION_IF_NULL(convert_func);
357   value = convert_func(arg);
358   if (value != nullptr) {
359     return value;
360   }
361   if (!op_arg.cast_dtype_.empty()) {
362     for (auto cast_dtype : op_arg.cast_dtype_) {
363       convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
364       MS_EXCEPTION_IF_NULL(convert_func);
365       auto val = convert_func(arg);
366       if (val != nullptr) {
367         return val;
368       }
369     }
370   }
371   return value;
372 }
373 
convertData(py::object param_obj,bool is_stub,ops::OpDef * op_def,size_t i)374 mindspore::ValuePtr convertData(py::object param_obj, bool is_stub, ops::OpDef *op_def, size_t i) {
375   mindspore::ValuePtr converted = nullptr;
376   if (op_def != nullptr) {
377     if (op_def->args_.size() <= i) {
378       MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument by dtype, args[" << i
379                         << "]: " << py::str(param_obj);
380       return nullptr;
381     }
382     converted = ConvertArgByCastDtype(param_obj, op_def->args_[i]);
383   }
384   if (converted) {
385     return converted;
386   }
387   if (is_stub) {
388     if (!mindspore::parse::ConvertStubData(param_obj, &converted, false, nullptr, false)) {
389       MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(param_obj);
390       return nullptr;
391     }
392   } else {
393     if (!mindspore::parse::ConvertData(param_obj, &converted, false, nullptr, false)) {
394       MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(param_obj);
395       return nullptr;
396     }
397   }
398   return converted;
399 }
400 
ChangeAbstractArgList(PrimitivePtr prim,std::vector<PyObject * > args,bool * has_tensor,int * monad_count)401 static AbstractBasePtrList ChangeAbstractArgList(PrimitivePtr prim, std::vector<PyObject *> args, bool *has_tensor,
402                                                  int *monad_count) {
403   std::vector<std::string> prim_cast_ops = {"Div"};
404   py::object handle;
405   if (std::find(prim_cast_ops.begin(), prim_cast_ops.end(), prim->name()) != prim_cast_ops.end() &&
406       args.size() == ArgsSizeTwo) {
407     auto tensor_type = py::reinterpret_borrow<py::object>(GetMsTensorType());
408     if (py::isinstance<mindspore::tensor::Tensor>(args[0]) && CheckScalar(args[1])) {
409       py::object dtype = py::reinterpret_borrow<py::object>(args[0]).attr("dtype");
410       py::object arg1 = py::reinterpret_borrow<py::object>(args[1]);
411       handle = tensor_type(arg1, dtype);
412       args[1] = handle.ptr();
413     } else if (CheckScalar(args[0]) && py::isinstance<mindspore::tensor::Tensor>(args[1])) {
414       py::object dtype = py::reinterpret_borrow<py::object>(args[1]).attr("dtype");
415       py::object arg0 = py::reinterpret_borrow<py::object>(args[0]);
416       handle = tensor_type(arg0, dtype);
417       args[0] = handle.ptr();
418     }
419   }
420   auto op_def = mindspore::ops::GetOpDef(prim->name());
421   AbstractBasePtrList list;
422   for (size_t i = 0; i < args.size(); ++i) {
423     mindspore::ValuePtr converted = nullptr;
424     py::object param_obj = py::reinterpret_borrow<py::object>(args[i]);
425     bool is_stub = false;
426     if (IsStubTensor(param_obj)) {
427       is_stub = true;
428     } else if (py::isinstance<mindspore::Monad>(param_obj)) {
429       *monad_count = *monad_count + 1;
430     }
431     *has_tensor = HasTensor(param_obj);
432     converted = convertData(param_obj, is_stub, op_def, i);
433     converted = ConvertArgByArgHandler(converted, op_def, i);
434     auto arg = mindspore::abstract::ToAbstract(converted, nullptr, nullptr);
435     list.push_back(arg);
436   }
437   return list;
438 }
439 
GeneratePrimitiveArgs(PrimitivePtr prim,std::vector<PyObject * > * list,PyObject * py_primitive)440 void GeneratePrimitiveArgs(PrimitivePtr prim, std::vector<PyObject *> *list, PyObject *py_primitive) {
441   auto op_def = mindspore::ops::GetOpDef(prim->name());
442   if (op_def == nullptr) {
443     return;
444   }
445   std::vector<ops::OpInputArg> op_call_args;
446   std::vector<ops::OpInputArg> op_init_args;
447   auto op_args = op_def->args_;
448   for (const auto &op_arg : op_args) {
449     if (op_arg.as_init_arg_) {
450       op_init_args.emplace_back(op_arg);
451     } else {
452       op_call_args.emplace_back(op_arg);
453     }
454   }
455   size_t args_size = list->size();
456   if (args_size < op_call_args.size()) {
457     for (size_t i = args_size; i < op_call_args.size(); i++) {
458       auto default_value = parse::GetArgDefaultValue(prim->name(), op_call_args[i].arg_name_);
459       if (default_value == nullptr) {
460         continue;
461       }
462       auto arg_value = ValueToPyData(default_value);
463       list->push_back(arg_value.ptr());
464     }
465   }
466   auto obj = py_primitive;
467   for (const auto &op_arg : op_init_args) {
468     auto arg_name = common::SafeCStr(op_arg.arg_name_);
469     if (py::hasattr(obj, arg_name)) {
470       py::object arg_value = py::getattr(obj, arg_name);
471       if (arg_value.ptr() == nullptr) {
472         continue;
473       }
474       list->push_back(arg_value.ptr());
475     }
476   }
477 }
478 
// Replaces raw C++ tensor objects by instances of the Python tensor type, recursing into
// tuples, lists and dict values. Containers are modified in place and returned as-is;
// any other object is returned unchanged.
ConvertCppTensor(const py::object & any)479 static py::object ConvertCppTensor(const py::object &any) {
480   PyObject *op = any.ptr();
481   PyTypeObject *cpp_tensor_type = GetPybindType<mindspore::tensor::Tensor>();
482 
// Exact type match only: the object is wrapped by calling the type returned by GetMsTensorType().
483   if (Py_IS_TYPE(op, cpp_tensor_type)) {
484     py::object tp = py::reinterpret_borrow<py::object>(GetMsTensorType());
485     return tp(any);
486   }
487 
488   if (PyTuple_Check(op) || PyList_Check(op)) {
489     for (Py_ssize_t i = 0; i < Py_SIZE(op); ++i) {
// The slot is rewritten directly: inc_ref() gives the slot its own reference to the converted
// item, and Py_SETREF drops the reference the slot previously held.
490       PyObject **item = PyTuple_Check(op) ? &PyTuple_GET_ITEM(op, i) : &PyList_GET_ITEM(op, i);
491       PyObject *new_item = ConvertCppTensor(py::cast<py::object>(*item)).inc_ref().ptr();
492       Py_SETREF(*item, new_item);
493     }
494     return any;
495   }
496 
497   if (PyDict_Check(op)) {
498     Py_ssize_t pos = 0;
499     PyObject *key;
500     PyObject *value;
// Only values of existing keys are replaced while iterating; keys are left untouched.
501     while (PyDict_Next(op, &pos, &key, &value)) {
502       py::object new_value = ConvertCppTensor(py::cast<py::object>(value));
503       PyDict_SetItem(op, key, new_value.ptr());
504     }
505     return any;
506   }
507   return any;
508 }
509 
510 // return new reference
InferPrimitive(PyObject * primitive,const std::vector<PyObject * > & args,bool * is_abstract)511 PyObject *InferEngine::InferPrimitive(PyObject *primitive, const std::vector<PyObject *> &args, bool *is_abstract) {
512   if (!SupportInfer(primitive)) {
513     return nullptr;
514   }
515   int monad_count = 0;
516   bool has_tensor = false;
517   std::vector<PyObject *> arglist = args;
518   bool isPrimitiveFunction = py::hasattr(primitive, PYTHON_PRIMITIVE_FUNCTION_FLAG);
519   py::object adapter_obj = py::reinterpret_borrow<py::object>(primitive);
520   mindspore::PrimitivePtr prim;
521   if (isPrimitiveFunction) {
522     PrimitiveFunctionAdapterPtr prim_func_adapter = adapter_obj.cast<PrimitiveFunctionAdapterPtr>();
523     MS_EXCEPTION_IF_NULL(prim_func_adapter);
524     PrimitivePtr cpp_primitive_func = prim_func_adapter->attached_primitive_function();
525     if (cpp_primitive_func == nullptr) {
526       std::string prim_name = py::getattr(primitive, "name").cast<std::string>();
527       prim = std::make_shared<Primitive>(prim_name);
528     } else {
529       prim = cpp_primitive_func;
530     }
531   } else {
532     mindspore::PrimitivePyAdapterPtr prim_adapter = adapter_obj.cast<mindspore::PrimitivePyAdapterPtr>();
533     mindspore::PrimitivePyPtr primitive_py = prim_adapter->attached_primitive();
534     if (primitive_py == nullptr) {
535       primitive_py = std::make_shared<mindspore::PrimitivePy>(adapter_obj);
536       prim_adapter->set_attached_primitive(primitive_py);
537     }
538     prim = primitive_py;
539   }
540 
541   PyObject *special_type = InferSpecialPrimitive(primitive, arglist);
542   if (special_type != nullptr) {
543     return special_type;
544   }
545   GeneratePrimitiveArgs(prim, &arglist, primitive);
546   AbstractBasePtrList list = ChangeAbstractArgList(prim, arglist, &has_tensor, &monad_count);
547 
548   *is_abstract = false;
549   std::optional<AbstractBasePtr> opt_res = mindspore::abstract::TryInferAbstract(prim, list);
550   if (opt_res.has_value()) {
551     auto abs = opt_res.value();
552     py::object pyObj;
553     if (abs != nullptr) {
554       pyObj = FuncGraphBuilder::ConvertToPyObj(abs);
555       if (pyObj.ptr() == nullptr) {
556         pyObj = MakeObjectFromAbstract(abs->BuildShape(), abs->BuildType(), is_abstract);
557       }
558       if (pyObj.ptr() != nullptr) {
559         pyObj = ConvertCppTensor(pyObj);
560       }
561     }
562     return pyObj.inc_ref().ptr();
563   } else if (primitive) {
564     if (py::hasattr(primitive, PY_PRIM_METHOD_INFER)) {
565       size_t list_count = arglist.size() - size_t(monad_count);
566       py::tuple py_vals(list_count);
567       for (size_t i = 0; i < list_count; ++i) {
568         py_vals[i] = py::reinterpret_borrow<py::object>(arglist[i]);
569       }
570       auto infer_func = adapter_obj.attr(PY_PRIM_METHOD_INFER);
571       py::dict output = infer_func(*py_vals);
572       if (output[ATTR_VALUE].is_none()) {
573         auto ret = MakeObjectFromPyObject(output[ATTR_SHAPE], output[ATTR_DTYPE], is_abstract);
574         Py_INCREF(ret.ptr());
575         return ret.ptr();
576       } else {
577         Py_INCREF(output[ATTR_VALUE].ptr());
578         return output[ATTR_VALUE].ptr();
579       }
580     } else if (!has_tensor && py::hasattr(primitive, PY_PRIM_METHOD_INFER_VALUE)) {
581       // Tensor maybe uninitialized, avoid infer value and allocate data.
582       // because tensor has no data when doing inference for type, infer_value will crash!
583       py::tuple py_vals(arglist.size());
584       for (size_t i = 0; i < arglist.size(); ++i) {
585         py_vals[i] = py::reinterpret_borrow<py::object>(arglist[i]);
586       }
587       auto infer_value = adapter_obj.attr(PY_PRIM_METHOD_INFER_VALUE);
588       auto output = infer_value(*py_vals);
589       Py_INCREF(output.ptr());
590       return output.ptr();
591     }
592     return nullptr;
593   }
594   return nullptr;
595 }
596 
GetShapeForStubTensor(PyObject * stubtensor)597 static ShapeVector GetShapeForStubTensor(PyObject *stubtensor) {
598   ShapeVector shape;
599   auto stub = PyObject_GetAttrString(stubtensor, "stub");
600   if (stub != nullptr && stub != Py_None) {
601     auto ptr = py::cast<mindspore::stub::StubNodePtr>(stub);
602     auto base = ptr->ToAbstract();
603     auto shape_ptr = base->BuildShape()->cast<abstract::ShapePtr>();
604     if (shape_ptr && !shape_ptr->IsDynamic()) {
605       shape = shape_ptr->shape();
606     }
607     Py_DECREF(stub);
608   } else {
609     auto ptr = PyObject_GetAttrString(stubtensor, "tensor");
610     auto tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(ptr);
611     shape = tensor_ptr->shape();
612     Py_DECREF(ptr);
613   }
614   return shape;
615 }
616 
GetDTypeForStubTensor(PyObject * stubtensor)617 static TypePtr GetDTypeForStubTensor(PyObject *stubtensor) {
618   TypePtr dtype;
619   auto stub = PyObject_GetAttrString(stubtensor, "stub");
620   if (stub != nullptr && stub != Py_None) {
621     auto ptr = py::cast<mindspore::stub::StubNodePtr>(stub);
622     auto base = ptr->ToAbstract();
623     auto dt = base->BuildType();
624     if (dt->isa<mindspore::TensorType>()) {
625       dtype = dt->cast<std::shared_ptr<mindspore::TensorType>>()->element();
626     } else {
627       dtype = dt;
628     }
629     Py_DECREF(stub);
630   } else {
631     auto ptr = PyObject_GetAttrString(stubtensor, "tensor");
632     auto tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(ptr);
633     dtype = tensor_ptr->Dtype();
634     Py_DECREF(ptr);
635   }
636   return dtype;
637 }
638 
InferShape(PyObject *,const std::vector<PyObject * > & args)639 static PyObject *InferShape(PyObject *, const std::vector<PyObject *> &args) {
640   PyObject *arg = args[0];
641   ShapeVector shape;
642   if (IsStubTensor(arg)) {
643     shape = GetShapeForStubTensor(arg);
644   } else {
645     auto pyObj = py::cast<py::object>(arg);
646     auto tensor_ptr = pyObj.cast<mindspore::tensor::MetaTensorPtr>();
647     shape = tensor_ptr->shape();
648   }
649   PyObject *tuple = PyTuple_New(shape.size());
650   for (size_t it = 0; it < shape.size(); ++it) {
651     py::int_ ss(shape[it]);
652     Py_INCREF(ss.ptr());
653     PyTuple_SetItem(tuple, it, ss.ptr());
654   }
655   return tuple;
656 }
657 
InferDType(PyObject *,const std::vector<PyObject * > & args)658 static PyObject *InferDType(PyObject *, const std::vector<PyObject *> &args) {
659   PyObject *arg = args[0];
660   mindspore::TypePtr dtype;
661   if (IsStubTensor(arg)) {
662     dtype = GetDTypeForStubTensor(arg);
663   } else {
664     auto pyObj = py::cast<py::object>(arg);
665     auto tensor_ptr = pyObj.cast<mindspore::tensor::MetaTensorPtr>();
666     dtype = tensor_ptr->Dtype();
667   }
668   PyObject *type = nullptr;
669   if (g_type2attr.find(dtype->type_id()) != g_type2attr.end()) {
670     type = PyObject_GetAttrString(GetMsType(), g_type2attr[dtype->type_id()].c_str());
671   } else {
672     MS_LOG(EXCEPTION) << "Cannot find suitable type for " << dtype->ToString();
673     return nullptr;
674   }
675   return type;
676 }
677 
InferRank(PyObject *,const std::vector<PyObject * > & args)678 static PyObject *InferRank(PyObject *, const std::vector<PyObject *> &args) {
679   PyObject *arg = args[0];
680   ShapeVector shape;
681   if (IsStubTensor(arg)) {
682     shape = GetShapeForStubTensor(arg);
683   } else {
684     auto pyObj = py::cast<py::object>(arg);
685     auto tensor_ptr = pyObj.cast<mindspore::tensor::MetaTensorPtr>();
686     shape = tensor_ptr->shape();
687   }
688   return PyLong_FromSize_t(shape.size());
689 }
690 
InferSize(PyObject *,const std::vector<PyObject * > & args)691 static PyObject *InferSize(PyObject *, const std::vector<PyObject *> &args) {
692   PyObject *arg = args[0];
693   ShapeVector shape;
694   if (IsStubTensor(arg)) {
695     shape = GetShapeForStubTensor(arg);
696   } else {
697     auto pyObj = py::cast<py::object>(arg);
698     auto tensor_ptr = pyObj.cast<mindspore::tensor::MetaTensorPtr>();
699     shape = tensor_ptr->shape();
700   }
701   size_t elements = 1;
702   for (size_t i = 0; i < shape.size(); i++) {
703     elements *= size_t(shape[i]);
704   }
705   return PyLong_FromSize_t(elements);
706 }
707 
GetSpecialPrimitiveInferFunc()708 const SpecialPrimitiveInferFuncMap &GetSpecialPrimitiveInferFunc() {
709   constexpr const auto CallValue = [](PyObject *prim, const std::vector<PyObject *> &args) {
710     PyObject *res = PyObject_Vectorcall(prim, args.data(), args.size(), nullptr);
711     PyErr_Clear();
712     return res;
713   };
714   constexpr const auto CToMSTensor = [](PyObject *prim, const std::vector<PyObject *> &args) {
715     return args.size() == 0 ? nullptr : ConvertToMsTensor(py::cast<py::object>(args[0])).inc_ref().ptr();
716   };
717   constexpr const auto CToAdapterTensor = [](PyObject *prim, const std::vector<PyObject *> &args) {
718     return args.size() == 0 ? nullptr : ConvertToAdapterTensor(py::cast<py::object>(args[0])).inc_ref().ptr();
719   };
720   static const SpecialPrimitiveInferFuncMap specialize = {
721     {"Size", InferSize},
722     {"Rank", InferRank},
723     {"DType", InferDType},
724     {"Shape", InferShape},
725     {"TileSize", CallValue},
726     {"ListToTensor", CallValue},
727     {"TupleToTensor", CallValue},
728     {"ScalarToTensor", CallValue},
729     {"make_range", CallValue},
730     {"ConvertToMsTensor", CToMSTensor},
731     {"ConvertToAdapterTensor", CToAdapterTensor},
732     {"IsShapeUnKnown", [](PyObject *, const std::vector<PyObject *> &) { Py_RETURN_FALSE; }},
733   };
734   return specialize;
735 }
736 
InferSpecialPrimitive(PyObject * primitive,const std::vector<PyObject * > & arglist)737 PyObject *InferEngine::InferSpecialPrimitive(PyObject *primitive, const std::vector<PyObject *> &arglist) {
738   std::string name = py::cast<py::object>(primitive).attr("name").cast<std::string>();
739   auto iter = GetSpecialPrimitiveInferFunc().find(name);
740   if (iter != GetSpecialPrimitiveInferFunc().end()) {
741     return iter->second(primitive, arglist);
742   }
743   return nullptr;
744 }
745 
SupportInfer(PyObject * primitive)746 bool InferEngine::SupportInfer(PyObject *primitive) {
747   if (!Init()) {
748     return false;
749   }
750   bool isPrimitiveFunction = py::hasattr(primitive, PYTHON_PRIMITIVE_FUNCTION_FLAG);
751   py::object adapter_obj = py::reinterpret_borrow<py::object>(primitive);
752 
753   mindspore::PrimitivePtr prim;
754   if (isPrimitiveFunction) {
755     PrimitiveFunctionAdapterPtr prim_func_adapter = adapter_obj.cast<PrimitiveFunctionAdapterPtr>();
756     MS_EXCEPTION_IF_NULL(prim_func_adapter);
757     PrimitivePtr cpp_primitive_func = prim_func_adapter->attached_primitive_function();
758     if (cpp_primitive_func == nullptr) {
759       std::string prim_name = py::getattr(primitive, "name").cast<std::string>();
760       prim = std::make_shared<Primitive>(prim_name);
761     } else {
762       prim = cpp_primitive_func;
763     }
764   } else {
765     mindspore::PrimitivePyAdapterPtr prim_adapter = adapter_obj.cast<mindspore::PrimitivePyAdapterPtr>();
766     mindspore::PrimitivePyPtr primitive_py = prim_adapter->attached_primitive();
767     if (primitive_py == nullptr) {
768       primitive_py = std::make_shared<mindspore::PrimitivePy>(adapter_obj);
769       prim_adapter->set_attached_primitive(primitive_py);
770     }
771     prim = primitive_py;
772   }
773 
774   auto eval_impl = mindspore::abstract::GetPrimitiveInferImpl(prim);
775   auto op_name = prim->name();
776   if (eval_impl != std::nullopt && eval_impl->Get().get() != nullptr) {
777     return true;
778   }
779   auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
780   auto op_def = ops::GetOpDef(op_name);
781   if (frontend_func_impl != nullptr || op_def != nullptr) {
782     return true;
783   }
784   if (GetSpecialPrimitiveInferFunc().find(prim->name()) != GetSpecialPrimitiveInferFunc().end()) {
785     return true;
786   }
787   return false;
788 }
789 
CheckType(const char * mod_name,const char * type_name,bool check_sub_type,PyTypeObject * tp)790 static bool CheckType(const char *mod_name, const char *type_name, bool check_sub_type, PyTypeObject *tp) {
791   py::object cls = Utils::GetModuleAttr(mod_name, type_name);
792   MS_EXCEPTION_IF_CHECK_FAIL(PyType_Check(cls.ptr()), "must be type");
793   bool check_res = reinterpret_cast<PyObject *>(tp) == cls.ptr();
794   if (!check_res && (check_sub_type)) {
795     check_res |= (PyType_IsSubtype(tp, reinterpret_cast<PyTypeObject *>(cls.ptr())) != 0);
796   }
797   return check_res;
798 }
799 
800 // sub-type check
801 template <>
IsGradOperationType(PyTypeObject * tp)802 bool IsGradOperationType<true>(PyTypeObject *tp) {
803   return IsPybindType<mindspore::prim::GradOperation, true>(tp);
804 }
805 template <>
IsVmapOperationType(PyTypeObject * tp)806 bool IsVmapOperationType<true>(PyTypeObject *tp) {
807   return IsPybindType<mindspore::prim::VmapOperation, true>(tp);
808 }
809 template <>
IsShardType(PyTypeObject * tp)810 bool IsShardType<true>(PyTypeObject *tp) {
811   return IsPybindType<mindspore::prim::Shard, true>(tp);
812 }
813 template <>
IsStubTensorType(PyTypeObject * tp)814 bool IsStubTensorType<true>(PyTypeObject *tp) {
815   return CheckType("mindspore.common._stub_tensor", "StubTensor", true, tp);
816 }
817 template <>
IsTensorType(PyTypeObject * tp)818 bool IsTensorType<true>(PyTypeObject *tp) {
819   return IsPybindType<mindspore::tensor::MetaTensor, true>(tp);
820 }
821 template <>
IsCellType(PyTypeObject * tp)822 bool IsCellType<true>(PyTypeObject *tp) {
823   return IsPybindType<mindspore::Cell, true>(tp);
824 }
825 template <>
IsPrimitiveType(PyTypeObject * tp)826 bool IsPrimitiveType<true>(PyTypeObject *tp) {
827   return IsPybindType<mindspore::PrimitivePyAdapter, true>(tp);
828 }
829 template <>
IsMetaFuncGraphType(PyTypeObject * tp)830 bool IsMetaFuncGraphType<true>(PyTypeObject *tp) {
831   return IsPybindType<mindspore::MetaFuncGraph, true>(tp);
832 }
833 template <>
IsMSDTypeType(PyTypeObject * tp)834 bool IsMSDTypeType<true>(PyTypeObject *tp) {
835   return IsPybindType<mindspore::Type, true>(tp);
836 }
837 // exact type check
838 template <>
IsCellListType(PyTypeObject * tp)839 bool IsCellListType<false>(PyTypeObject *tp) {
840   return CheckType("mindspore.nn", "CellList", false, tp);
841 }
842 
CheckTensorDataInitialized(const py::object & py_tensor)843 bool CheckTensorDataInitialized(const py::object &py_tensor) {
844   if (py::isinstance<mindspore::tensor::Tensor>(py_tensor)) {
845     auto tensor = py_tensor.cast<mindspore::tensor::TensorPtr>();
846     return tensor->data().const_data() != nullptr;
847   }
848   return false;
849 }
850 
FindTensorName(const std::string & name)851 bool FindTensorName(const std::string &name) {
852   const auto &meth = pipeline::GetMethodMap().find(kObjectTypeTensorType)->second;
853   if (meth.find(name) != meth.end()) {
854     return true;
855   }
856   const auto &attr = pipeline::GetAttrMap().find(kObjectTypeTensorType)->second;
857   if (attr.find(name) != attr.end()) {
858     return true;
859   }
860   if (name == "device") {
861     return true;
862   }
863   return false;
864 }
865 
PyToAbs(py::handle handle)866 static AbstractBasePtr PyToAbs(py::handle handle) {
867   py::object input = py::cast<py::object>(handle);
868   ValuePtr value_ptr;
869   if (!parse::ConvertStubData(input, &value_ptr) || value_ptr == nullptr) {
870     MS_LOG(ERROR) << "can't convert argument to value ptr [" << std::string(py::str(input)) << "]";
871     return nullptr;
872   }
873   return value_ptr->ToAbstract();
874 }
875 
MakeArgumentsAbstract(py::object callable_object,py::object args,py::object key_words)876 static std::unique_ptr<AbstractBasePtrList> MakeArgumentsAbstract(py::object callable_object, py::object args,
877                                                                   py::object key_words) {
878   // for cell construct
879   auto callable_type = Py_TYPE(callable_object.ptr());
880   if (IsCellType<true>(callable_type)) {
881     callable_object = callable_object.attr("construct");
882   }
883   py::object signature = py::module::import("inspect").attr("signature")(callable_object).attr("bind");
884   py::object bind_args = py::reinterpret_steal<py::object>(PyObject_Call(signature.ptr(), args.ptr(), key_words.ptr()));
885   (void)bind_args.attr("apply_defaults")();
886   args = py::tuple(bind_args.attr("args"));
887   key_words = py::dict(bind_args.attr("kwargs"));
888 
889   AbstractBasePtrList list;
890   for (auto value : args) {
891     auto abs = PyToAbs(value);
892     if (abs == nullptr) {
893       return nullptr;
894     }
895     list.push_back(abs);
896   }
897   if (key_words.ptr() == nullptr) {
898     return std::make_unique<AbstractBasePtrList>(std::move(list));
899   }
900 
901   PyObject *key;
902   PyObject *value;
903   Py_ssize_t pos = 0;
904   while (PyDict_Next(key_words.ptr(), &pos, &key, &value)) {
905     auto abs = PyToAbs(value);
906     if (abs == nullptr) {
907       return nullptr;
908     }
909     list.push_back(std::make_shared<abstract::AbstractKeywordArg>(PyUnicode_AsUTF8(key), abs));
910   }
911   return std::make_unique<AbstractBasePtrList>(std::move(list));
912 }
913 
// Evaluates a MindSpore API callable on abstract inputs and converts the inferred result to Python.
// Steps: convert the callable to a value, build abstract arguments from `args` / `key_words`,
// evaluate (EvalOnePrim for a Primitive, AbstractAnalyzeWithResourceClean on broadened inputs for an
// abstract function), then convert the resulting abstract to a Python object.
// Every failure path logs an error and returns an empty py::object.
py::object EvalMSAPIValue(const py::object &ms_api, const py::object &args, const py::object &key_words) {
  ValuePtr callable_value;
  if (!parse::ConvertData(ms_api, &callable_value) || callable_value == nullptr) {
    MS_LOG(ERROR) << "can't convert callable object to value ptr [" << std::string(py::str(ms_api)) << "]";
    return py::object();
  }

  auto abstract_args = MakeArgumentsAbstract(ms_api, args, key_words);
  if (abstract_args == nullptr) {
    return py::object();
  }
  AbstractBasePtrList inputs_abs_list = std::move(*abstract_args);

  AbstractBasePtr inferred;
  if (callable_value->isa<Primitive>()) {
    auto prim_eval = abstract::EvalOnePrim(callable_value->cast<PrimitivePtr>(), inputs_abs_list);
    if (prim_eval != nullptr) {
      inferred = prim_eval->abstract();
    }
  } else if (callable_value->ToAbstract()->isa<abstract::AbstractFunction>()) {
    // Broaden every input before running the analysis.
    for (auto &input_abs : inputs_abs_list) {
      input_abs = input_abs->Broaden();
    }
    try {
      auto analysis = pipeline::AbstractAnalyzeWithResourceClean(callable_value, inputs_abs_list);
      if (analysis.eval_result != nullptr) {
        inferred = analysis.eval_result->abstract();
      }
    } catch (const std::exception &ex) {
      // Analysis failure is not fatal here; `inferred` stays null and is reported below.
      MS_LOG(ERROR) << "AbstractAnalyze failed for [" << callable_value->ToString() << "], error:" << ex.what();
    }
  }

  if (inferred == nullptr) {
    MS_LOG(ERROR) << "eval callable object failed [" << std::string(py::str(ms_api)) << "]";
    return py::object();
  }
  py::object py_result = FuncGraphBuilder::ConvertToPyObj(inferred);
  if (py_result.ptr() == nullptr) {
    MS_LOG(ERROR) << "can't convert AbstractBasePtr to PyObject [" << inferred->ToString() << "]";
    return py::object();
  }
  return ConvertCppTensor(py_result);
}
954 
955 }  // namespace pijit
956 }  // namespace mindspore
957