• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/parse/data_converter.h"
20 #include <utility>
21 #include <unordered_map>
22 #include <algorithm>
23 #include "mindspore/core/ops/structure_ops.h"
24 #include "pipeline/jit/ps/parse/resolve.h"
25 #include "pipeline/jit/ps/pipeline.h"
26 #include "frontend/operator/ops.h"
27 #include "frontend/operator/composite/composite.h"
28 #include "ir/func_graph_cloner.h"
29 #include "ir/cell.h"
30 #include "ir/dtype.h"
31 #include "utils/symbolic.h"
32 #include "utils/ms_context.h"
33 #include "include/common/fallback.h"
34 #include "include/common/utils/utils.h"
35 #include "include/common/utils/convert_utils_py.h"
36 #include "include/common/utils/primfunc_utils.h"
37 #include "frontend/operator/composite/multitype_funcgraph.h"
38 
39 namespace mindspore {
40 namespace parse {
41 namespace {
42 struct PyDataToValueRegister {
PyDataToValueRegistermindspore::parse::__anon6d4124920111::PyDataToValueRegister43   PyDataToValueRegister() noexcept {
44     python_adapter::PyAdapterCallback::SetPyDataToValueHandler(data_converter::PyDataToValue);
45   }
46 } callback_register;
47 }  // namespace
48 using Tensor = mindspore::tensor::Tensor;
49 using TensorPtr = mindspore::tensor::TensorPtr;
50 using BaseTensor = mindspore::tensor::BaseTensor;
51 using BaseTensorPtr = mindspore::tensor::BaseTensorPtr;
52 using MetaTensor = mindspore::tensor::MetaTensor;
53 using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
54 using CSRTensor = mindspore::tensor::CSRTensor;
55 using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
56 using COOTensor = mindspore::tensor::COOTensor;
57 using COOTensorPtr = mindspore::tensor::COOTensorPtr;
58 using MapTensor = mindspore::tensor::MapTensor;
59 using MapTensorPtr = mindspore::tensor::MapTensorPtr;
60 
61 using InstanceCheckFunc = std::function<bool(const py::object &)>;
62 using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &, const ValuePtrList &)>;
63 static constexpr int kBit8 = 8;
64 static constexpr int kBit16 = 16;
65 static constexpr int kBit32 = 32;
66 static constexpr int kBit64 = 64;
67 
68 class DataConvertFunc {
69  public:
DataConvertFunc(InstanceConvertFunc convert_func)70   explicit DataConvertFunc(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
71 
72   virtual ~DataConvertFunc() = default;
73 
74   virtual bool Matched(const py::object &obj) = 0;
75 
ConvertPyObject(const py::object & obj,bool use_sig,const TypePtr & dtype,const ValuePtrList & args_value_list={})76   ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype,
77                            const ValuePtrList &args_value_list = {}) {
78     if (convert_func_ == nullptr) {
79       MS_LOG(INTERNAL_EXCEPTION) << "convert func is null";
80     }
81     return convert_func_(obj, use_sig, dtype, args_value_list);
82   }
83 
84  private:
85   InstanceConvertFunc convert_func_ = nullptr;
86 };
87 
88 using DataConvertFuncPtr = std::shared_ptr<DataConvertFunc>;
89 
90 using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
91 using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
92 using ArgsObjTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
93 using ArgsObjArgsValueConvertFunc = std::function<ValuePtr(const py::object &, const ValuePtrList &)>;
94 
95 // Convert the data according to instance type
96 template <typename T>
97 class ByTypeDataConvertFunc : public DataConvertFunc {
98  public:
ByTypeDataConvertFunc(const InstanceConvertFunc & convert_func)99   explicit ByTypeDataConvertFunc(const InstanceConvertFunc &convert_func)
100       : DataConvertFunc(convert_func), check_func_(py::isinstance<T>) {}
101 
ByTypeDataConvertFunc(const ValuePtr & converted_type)102   explicit ByTypeDataConvertFunc(const ValuePtr &converted_type)
103       : DataConvertFunc([converted_type](const py::object &, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
104           return converted_type;
105         }),
106         check_func_(py::isinstance<T>) {}
107 
ByTypeDataConvertFunc(const ArgsObjConvertFunc & convert_func)108   explicit ByTypeDataConvertFunc(const ArgsObjConvertFunc &convert_func)
109       : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
110           return convert_func(obj);
111         }),
112         check_func_(py::isinstance<T>) {}
113 
ByTypeDataConvertFunc(const ArgsObjSigConvertFunc & convert_func)114   explicit ByTypeDataConvertFunc(const ArgsObjSigConvertFunc &convert_func)
115       : DataConvertFunc([convert_func](const py::object &obj, bool use_sig, const TypePtr &,
116                                        const ValuePtrList &) -> ValuePtr { return convert_func(obj, use_sig); }),
117         check_func_(py::isinstance<T>) {}
118 
ByTypeDataConvertFunc(const ArgsObjTypeConvertFunc & convert_func)119   explicit ByTypeDataConvertFunc(const ArgsObjTypeConvertFunc &convert_func)
120       : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &dtype,
121                                        const ValuePtrList &) -> ValuePtr { return convert_func(obj, dtype); }),
122         check_func_(py::isinstance<T>) {}
123 
ByTypeDataConvertFunc(const ArgsObjArgsValueConvertFunc & convert_func)124   explicit ByTypeDataConvertFunc(const ArgsObjArgsValueConvertFunc &convert_func)
125       : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &,
126                                        const ValuePtrList &args_value_list) -> ValuePtr {
127           return convert_func(obj, args_value_list);
128         }),
129         check_func_(py::isinstance<T>) {}
130 
131   ~ByTypeDataConvertFunc() override = default;
132 
Matched(const py::object & obj)133   bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }
134 
135  private:
136   InstanceCheckFunc check_func_ = nullptr;
137 };
138 
139 // Convert the data according to object attribute.
140 class ByAttrDataConvertFunc : public DataConvertFunc {
141  public:
ByAttrDataConvertFunc(const ArgsObjConvertFunc & convert_func,const std::string & attr_name,const std::string & cell_list_from_top="")142   ByAttrDataConvertFunc(const ArgsObjConvertFunc &convert_func, const std::string &attr_name,
143                         const std::string &cell_list_from_top = "")
144       : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
145           return convert_func(obj);
146         }),
147         attr_name_(attr_name),
148         cell_list_from_top_(cell_list_from_top) {}
149 
ByAttrDataConvertFunc(const ArgsObjSigConvertFunc & convert_func,const std::string & attr_name,const std::string & cell_list_from_top="")150   ByAttrDataConvertFunc(const ArgsObjSigConvertFunc &convert_func, const std::string &attr_name,
151                         const std::string &cell_list_from_top = "")
152       : DataConvertFunc([convert_func](const py::object &obj, bool use_sig, const TypePtr &,
153                                        const ValuePtrList &) -> ValuePtr { return convert_func(obj, use_sig); }),
154         attr_name_(attr_name),
155         cell_list_from_top_(cell_list_from_top) {}
156 
157   ~ByAttrDataConvertFunc() override = default;
158 
Matched(const py::object & obj)159   bool Matched(const py::object &obj) override {
160     return py::hasattr(obj, attr_name_.c_str()) && !py::hasattr(obj, cell_list_from_top_.c_str());
161   }
162 
163  private:
164   std::string attr_name_;
165   std::string cell_list_from_top_;
166 };
167 
168 // Convert the data according to match function.
169 class ByFuncDataConvertFunc : public DataConvertFunc {
170  public:
ByFuncDataConvertFunc(const InstanceCheckFunc & match_func,const ArgsObjConvertFunc & convert_func)171   ByFuncDataConvertFunc(const InstanceCheckFunc &match_func, const ArgsObjConvertFunc &convert_func)
172       : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
173           return convert_func(obj);
174         }),
175         match_func_(match_func) {}
176 
ByFuncDataConvertFunc(const InstanceCheckFunc & match_func,const ArgsObjSigConvertFunc & convert_func)177   ByFuncDataConvertFunc(const InstanceCheckFunc &match_func, const ArgsObjSigConvertFunc &convert_func)
178       : DataConvertFunc([convert_func](const py::object &obj, bool use_sig, const TypePtr &,
179                                        const ValuePtrList &) -> ValuePtr { return convert_func(obj, use_sig); }),
180         match_func_(match_func) {}
181 
182   ~ByFuncDataConvertFunc() override = default;
183 
Matched(const py::object & obj)184   bool Matched(const py::object &obj) override { return match_func_ != nullptr ? match_func_(obj) : false; }
185 
186  private:
187   InstanceCheckFunc match_func_ = nullptr;
188 };
189 
ConvertToBpropCut(const py::object & obj)190 FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
191   std::vector<std::string> results = data_converter::GetObjKey(obj);
192   std::string obj_key = results[0];
193   py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
194 
195   auto bprop_graph = std::make_shared<FuncGraph>();
196   std::vector<AnfNodePtr> outputs;
197 
198   auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
199   fake_bprop->AddBackwardHookFn(0, bprop_func);
200   (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
201   outputs.push_back(NewValueNode(fake_bprop));
202 
203   py::object code_obj = py::getattr(bprop_func, "__code__");
204   // Three parameters self, out and dout need to be excluded
205   constexpr auto kBpropExcludeParamNum = 3;
206   size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - kBpropExcludeParamNum;
207   for (size_t i = 0; i < inputs_num; ++i) {
208     auto param = bprop_graph->add_parameter();
209     outputs.push_back(param);
210   }
211   auto p1 = bprop_graph->add_parameter();
212   auto p2 = bprop_graph->add_parameter();
213   outputs.push_back(p1);
214   outputs.push_back(p2);
215 
216   bprop_graph->set_output(bprop_graph->NewCNode(std::move(outputs)));
217   data_converter::SetObjGraphValue(obj_key, bprop_graph);
218   return bprop_graph;
219 }
220 
221 namespace {
ConvertTuple(const py::object & obj,bool use_signature)222 ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
223   MS_LOG(DEBUG) << "Converting python tuple";
224   auto tuple = obj.cast<py::tuple>();
225   std::vector<ValuePtr> value_list;
226   for (size_t it = 0; it < tuple.size(); ++it) {
227     ValuePtr out = nullptr;
228     bool success = ConvertData(tuple[it], &out, use_signature);
229     if (!success) {
230       return nullptr;
231     }
232     value_list.push_back(out);
233   }
234   auto res = std::make_shared<ValueTuple>(value_list);
235   return res;
236 }
237 
IsNamedTuple(const py::object & obj)238 bool IsNamedTuple(const py::object &obj) { return py::hasattr(obj, "_fields") && py::isinstance<py::tuple>(obj); }
239 
ConvertNamedTuple(const py::object & obj,bool use_signature)240 ValuePtr ConvertNamedTuple(const py::object &obj, bool use_signature) {
241   MS_LOG(DEBUG) << "Converting python NamedTuple";
242   if (!py::hasattr(obj, "_asdict")) {
243     return nullptr;
244   }
245   auto asdict_fn = obj.attr("_asdict");
246   auto asdict_obj = asdict_fn();
247   auto dict_values = asdict_obj.cast<py::dict>();
248   std::vector<ValuePtr> keys;
249   std::vector<ValuePtr> values;
250   for (auto item : dict_values) {
251     ValuePtr key = nullptr;
252     ValuePtr value = nullptr;
253     bool success = ConvertData(py::cast<py::object>(item.first), &key, use_signature) &&
254                    ConvertData(py::cast<py::object>(item.second), &value, use_signature);
255     if (!success) {
256       return nullptr;
257     }
258     MS_LOG(DEBUG) << key->ToString() << ", " << value->ToString();
259     keys.push_back(key);
260     values.push_back(value);
261   }
262   auto obj_name = obj.attr("__class__").attr("__name__");
263   std::string sub_class_name = py::str(obj_name).cast<std::string>();
264   return std::make_shared<ValueNamedTuple>(sub_class_name, keys, values);
265 }
266 
ConvertStubTuple(const py::object & obj,bool use_signature)267 ValuePtr ConvertStubTuple(const py::object &obj, bool use_signature) {
268   MS_LOG(DEBUG) << "Converting python tuple";
269   auto tuple = obj.cast<py::tuple>();
270   std::vector<ValuePtr> value_list;
271   for (size_t it = 0; it < tuple.size(); ++it) {
272     ValuePtr out = nullptr;
273     bool success = ConvertStubData(tuple[it], &out, use_signature);
274     if (!success) {
275       return nullptr;
276     }
277     value_list.push_back(out);
278   }
279   return std::make_shared<ValueTuple>(value_list);
280 }
281 
ConvertList(const py::object & obj,bool use_signature)282 ValuePtr ConvertList(const py::object &obj, bool use_signature) {
283   MS_LOG(DEBUG) << "Converting python list";
284   PyRecursionScope scope(obj);
285 
286   auto list = obj.cast<py::list>();
287   std::vector<ValuePtr> value_list;
288   for (size_t it = 0; it < list.size(); ++it) {
289     ValuePtr out = nullptr;
290     bool success = ConvertData(list[it], &out, use_signature);
291     if (!success) {
292       return nullptr;
293     }
294     value_list.push_back(out);
295   }
296   auto res = std::make_shared<ValueList>(value_list);
297   return res;
298 }
299 
ConvertStubList(const py::object & obj,bool use_signature)300 ValuePtr ConvertStubList(const py::object &obj, bool use_signature) {
301   MS_LOG(DEBUG) << "Converting python list";
302   PyRecursionScope scope(obj);
303 
304   auto list = obj.cast<py::list>();
305   std::vector<ValuePtr> value_list;
306   for (size_t it = 0; it < list.size(); ++it) {
307     ValuePtr out = nullptr;
308     bool success = ConvertStubData(list[it], &out, use_signature);
309     if (!success) {
310       return nullptr;
311     }
312     value_list.push_back(out);
313   }
314   return std::make_shared<ValueList>(value_list);
315 }
316 
ConvertCellList(const py::object & obj,bool use_signature)317 ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
318   MS_LOG(DEBUG) << "Converting cell list";
319   PyRecursionScope scope(obj);
320 
321   py::sequence list = obj;
322   std::vector<ValuePtr> value_list;
323 
324   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
325   bool is_celllist = py::cast<bool>(python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CELL_LIST, obj));
326   for (const auto &element : list) {
327     // An element will directly convert to InterpretedObject if:
328     //   1. The container is not a cell list object.
329     //   2. The element should be single cell (cell with no __cell_as_list__ attr).
330     bool to_interpret = !is_celllist && py::isinstance<Cell>(element);
331     if (to_interpret) {
332       value_list.push_back(std::make_shared<parse::InterpretedObject>(element));
333       continue;
334     }
335     ValuePtr out = nullptr;
336     bool success = ConvertData(element, &out, use_signature);
337     if (!success) {
338       return nullptr;
339     }
340     value_list.push_back(out);
341   }
342   return std::make_shared<ValueTuple>(value_list);
343 }
344 
ConvertDict(const py::object & obj,bool use_signature)345 ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
346   MS_LOG(DEBUG) << "Converting python dict";
347   PyRecursionScope scope(obj);
348 
349   auto dict_values = obj.cast<py::dict>();
350   std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
351   for (auto item : dict_values) {
352     ValuePtr key = nullptr;
353     ValuePtr value = nullptr;
354     bool success = ConvertData(py::cast<py::object>(item.first), &key, use_signature) &&
355                    ConvertData(py::cast<py::object>(item.second), &value, use_signature);
356     if (!success) {
357       return nullptr;
358     }
359     (void)key_values.emplace_back(key, value);
360   }
361   auto res = std::make_shared<ValueDictionary>(key_values);
362   res->set_user_data<py::object>("origin_object", std::make_shared<py::object>(obj));
363   return res;
364 }
365 
ConvertModuleNameSpace(const py::object & obj)366 ValuePtr ConvertModuleNameSpace(const py::object &obj) {
367   MS_LOG(DEBUG) << "Converting python module";
368   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
369   py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
370   auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, module_namespace, obj);
371   MS_LOG(DEBUG) << "name_space: " << converted->ToString();
372   return converted;
373 }
374 
ConvertMsClass(const py::object & obj)375 ValuePtr ConvertMsClass(const py::object &obj) {
376   MS_LOG(DEBUG) << "Converting ms class";
377   // Convert class instance decorated with jit_class.
378   if (py::hasattr(obj, PYTHON_PARSE_METHOD)) {
379     MS_LOG(DEBUG) << "Convert obj to func graph.";
380     FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
381     if (func_graph == nullptr) {
382       MS_LOG(ERROR) << "Parse resolve function error.";
383       return nullptr;
384     }
385     PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
386     func_graph->set_python_obj(python_obj);
387     return func_graph;
388   }
389   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
390   py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj);
391   auto cls_name = py::cast<std::string>(name);
392   return std::make_shared<MsClassObject>(obj, cls_name);
393 }
394 
ConvertPrimitiveClassType(const py::object & obj)395 ValuePtr ConvertPrimitiveClassType(const py::object &obj) {
396   // need check the primitive is class type or instance
397   auto obj_type = data_converter::GetObjType(obj);
398   if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
399     auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
400     // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
401     return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
402   }
403   return nullptr;
404 }
405 
ConvertPrimitive(const py::object & obj,bool use_signature=false)406 ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
407   MS_LOG(DEBUG) << "Converting primitive object " << use_signature;
408 
409   auto class_type = ConvertPrimitiveClassType(obj);
410   if (class_type != nullptr) {
411     return class_type;
412   }
413   py::object adapter_obj = obj;
414   if (py::hasattr(obj, "__setattr_flag__")) {
415     if (py::hasattr(obj, "_clone")) {
416       auto clone_fn = obj.attr("_clone");
417       adapter_obj = clone_fn();
418     }
419   }
420   auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
421   MS_EXCEPTION_IF_NULL(prim_adapter);
422   auto primitive = prim_adapter->attached_primitive();
423   if (primitive == nullptr) {
424     primitive = std::make_shared<PrimitivePy>(adapter_obj);
425     prim_adapter->set_attached_primitive(primitive);
426   }
427 
428   if (use_signature) {
429     return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
430   }
431   return primitive;
432 }
433 
ConvertPrimitiveFunction(const py::object & obj)434 ValuePtr ConvertPrimitiveFunction(const py::object &obj) {
435   MS_LOG(DEBUG) << "Converting primitive function";
436   auto class_type = ConvertPrimitiveClassType(obj);
437   if (class_type != nullptr) {
438     return class_type;
439   }
440   auto prim_func_adapter = obj.cast<PrimitiveFunctionAdapterPtr>();
441   MS_EXCEPTION_IF_NULL(prim_func_adapter);
442   auto cpp_primitive_func = prim_func_adapter->attached_primitive_function();
443   if (cpp_primitive_func == nullptr) {
444     auto prim_name = py::getattr(obj, "name").cast<std::string>();
445     return std::make_shared<prim::DoTransPrimitiveFunction>(std::make_shared<Primitive>(prim_name));
446   }
447   return cpp_primitive_func;
448 }
449 
ConvertMetaFuncGraph(const py::object & obj,bool use_signature=false)450 ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
451   MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
452   auto meta = obj.cast<MetaFuncGraphPtr>();
453   if (meta == nullptr) {
454     MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
455     return nullptr;
456   }
457   auto multi = meta->cast<prim::MultitypeFuncGraphPtr>();
458   if (multi != nullptr) {
459     multi->set_meta_obj(obj);
460   }
461   if (use_signature) {
462     return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
463   }
464   return meta;
465 }
466 
ConvertFuncGraph(const py::object & obj)467 ValuePtr ConvertFuncGraph(const py::object &obj) {
468   MS_LOG(DEBUG) << "Converting FuncGraph object";
469   auto func_graph = obj.cast<FuncGraphPtr>();
470   if (func_graph == nullptr) {
471     MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
472     return nullptr;
473   }
474   func_graph->set_attr("is_load", MakeValue(true));
475   return func_graph;
476 }
477 
ConvertSlice(const py::object & obj)478 ValuePtr ConvertSlice(const py::object &obj) {
479   MS_LOG(DEBUG) << "Converting slice object";
480 
481   auto convert_func = [obj](const std::string &attr) -> ValuePtr {
482     auto py_attr = py::getattr(obj, attr.c_str());
483     if (py::isinstance<py::none>(py_attr)) {
484       return kNone;
485     }
486     if (py::isinstance<py::int_>(py_attr)) {
487       auto value = py::cast<int64_t>(py_attr);
488       return MakeValue(value);
489     }
490     if (py::isinstance<Tensor>(py_attr)) {
491       return py::cast<TensorPtr>(py_attr);
492     }
493     if (IsStubTensor(py_attr)) {
494       return ConvertStubTensor(py_attr);
495     }
496     MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj)
497                       << " should be int or Tensor with Int type but got " << py::str(py_attr);
498   };
499   ValuePtr start = convert_func(kSliceStart);
500   ValuePtr stop = convert_func(kSliceStop);
501   ValuePtr step = convert_func(kSliceStep);
502   return std::make_shared<ValueSlice>(start, stop, step);
503 }
504 
ConvertCellObjToFuncGraph(const py::object & obj,const ValuePtrList & args_value_list)505 ValuePtr ConvertCellObjToFuncGraph(const py::object &obj, const ValuePtrList &args_value_list) {
506   FuncGraphPtr func_graph = ConvertToFuncGraph(obj, args_value_list);
507   if (func_graph == nullptr) {
508     MS_LOG(ERROR) << "Parse resolve function error.";
509     return nullptr;
510   }
511   // if the cell object has specified bprop, it has user-defined bprop function parse and record it
512   if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
513     bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
514     FuncGraphPtr bprop_graph =
515       enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, {}, PYTHON_MOD_GET_BPROP_METHOD);
516     if (bprop_graph != nullptr) {
517       (void)func_graph->transforms().emplace(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph));
518       (void)bprop_graph->transforms().emplace("primal", FuncGraphTransform(func_graph));
519       func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
520       func_graph->set_flag(FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP, true);
521     }
522   }
523   if (py::hasattr(obj, STAGE_NAME)) {
524     auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
525     func_graph->set_stage(stage);
526   }
527   if (py::hasattr(obj, SEGMENT_NAME)) {
528     auto segment = py::cast<int>(py::getattr(obj, SEGMENT_NAME));
529     func_graph->set_segment(segment);
530   }
531   auto cell = py::cast<CellPtr>(obj);
532   if (cell != nullptr && cell->HasAttr(kAttrRandomOpSnapShot)) {
533     auto value = cell->GetAttr(kAttrRandomOpSnapShot);
534     MS_EXCEPTION_IF_NULL(value);
535     func_graph->set_attr(kAttrRandomOpSnapShot, value);
536   }
537   return func_graph;
538 }
539 
ConvertConstantNumpyNumber(const py::object & obj,ResolveType obj_type)540 ValuePtr ConvertConstantNumpyNumber(const py::object &obj, ResolveType obj_type) {
541   if (obj_type == RESOLVE_TYPE_NUMPY_INT_NUMBER) {
542     MS_LOG(INFO) << "Convert constant numpy int64_t number:" << (std::string)py::str(obj);
543     return MakeValue(py::cast<int64_t>(obj));
544   }
545   if (obj_type == RESOLVE_TYPE_NUMPY_FLOAT_NUMBER) {
546     MS_LOG(INFO) << "Convert constant numpy float number::" << (std::string)py::str(obj);
547     return MakeValue(py::cast<float>(obj));
548   }
549   if (obj_type == RESOLVE_TYPE_NUMPY_BOOL_NUMBER) {
550     MS_LOG(INFO) << "Convert constant numpy bool_ number::" << (std::string)py::str(obj);
551     return MakeValue(py::cast<bool>(obj));
552   }
553 
554   MS_LOG(ERROR) << "Convert numpy number type is invalid, obj: " << py::str(obj);
555   return nullptr;
556 }
557 
CheckJITForbiddenAPI(const py::object & obj)558 void CheckJITForbiddenAPI(const py::object &obj) {
559   auto module = python_adapter::GetPyModule(PYTHON_MOD_MODULE);
560   py::object res = python_adapter::CallPyModFn(module, PYTHON_MOD_GET_MODULE_AND_NAME_INFO, obj);
561   if (!py::isinstance<py::none>(res)) {
562     auto obj_info = py::cast<py::list>(res);
563     auto obj_module = py::cast<std::string>(obj_info[0]);
564     auto obj_name = py::cast<std::string>(obj_info[1]);
565     auto obj_type = py::cast<std::string>(obj_info[2]);
566     std::ostringstream oss;
567     oss << "Failed to compile in GRAPH_MODE because the " << obj_type << " '" << obj_module << "." << obj_name
568         << "' is not supported in 'construct' or function with @jit decorator. "
569         << "Try to use the " << obj_type << " '" << obj_module << "." << obj_name << "' externally "
570         << "such as initialized in the method '__init__' before assigning"
571         << ".\nFor more details, please refer to "
572         << "https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n";
573     // Check if the API is decoratored by @jit_forbidden_register.
574     bool is_jit_forbidden_register = data_converter::IsJITForbiddenAPI(obj);
575     if (is_jit_forbidden_register) {
576       MS_LOG(EXCEPTION) << oss.str();
577     }
578     // Check if the API's module is in the JIT forbidden module set.
579     bool is_jit_forbidden_module =
580       py::cast<bool>(python_adapter::CallPyModFn(module, PYTHON_MOD_IS_JIT_FORBIDDEN_MODULE, obj_info[0]));
581     if (is_jit_forbidden_module) {
582       MS_LOG(EXCEPTION) << oss.str();
583     }
584   }
585 }
586 
ConvertOtherObj(const py::object & obj,bool forbid_reuse=false)587 ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
588   auto obj_type = data_converter::GetObjType(obj);
589   MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
590   if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
591     // Check JIT forbidden API
592     CheckJITForbiddenAPI(obj);
593     MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
594     std::string desc = py::str(obj);
595     // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
596     return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
597   }
598   if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD ||
599       (obj_type == RESOLVE_TYPE_CLASS_INSTANCE && py::hasattr(obj, PYTHON_PARSE_METHOD))) {
600     if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
601       // Check JIT forbidden API
602       CheckJITForbiddenAPI(obj);
603       // Check if the function is from a third-party library.
604       py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
605       bool is_third_party_function =
606         python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_FROM_THIRD_PARTY_LIBRARY, obj).cast<bool>();
607       if (is_third_party_function) {
608         MS_LOG(DEBUG) << "Converting the function from third-party library: " << py::str(obj);
609         return std::make_shared<InterpretedObject>(obj);
610       }
611     }
612     MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
613     FuncGraphPtr func_graph = ConvertToFuncGraph(obj, {}, PYTHON_MOD_GET_PARSE_METHOD, forbid_reuse);
614     if (func_graph == nullptr) {
615       MS_LOG(ERROR) << "Parse resolve function error.";
616       return nullptr;
617     }
618     return func_graph;
619   }
620   if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
621     MS_LOG(INTERNAL_EXCEPTION) << "Fail to convert class instance: " << py::str(obj);
622   }
623   // Start RESOLVE_TYPE_INVALID.
624   if (obj_type == RESOLVE_TYPE_NUMPY_INT_NUMBER || obj_type == RESOLVE_TYPE_NUMPY_FLOAT_NUMBER ||
625       obj_type == RESOLVE_TYPE_NUMPY_BOOL_NUMBER) {
626     return ConvertConstantNumpyNumber(obj, obj_type);
627   }
628   auto res = std::make_shared<InterpretedObject>(obj);
629   MS_EXCEPTION_IF_NULL(res);
630   MS_LOG(DEBUG) << "Get interpreted object: " << res->ToString();
631   return res;
632 }
633 
634 template <typename T>
ConvertNumberWithType(const T & obj,const TypePtr & dtype)635 ValuePtr ConvertNumberWithType(const T &obj, const TypePtr &dtype) {
636   ValuePtr data = nullptr;
637   auto int_dypte = dyn_cast<Int>(dtype);
638   if (int_dypte != nullptr) {
639     switch (int_dypte->nbits()) {
640       case kBit8:
641         data = std::make_shared<Int8Imm>(obj);
642         break;
643       case kBit16:
644         data = std::make_shared<Int16Imm>(obj);
645         break;
646       case kBit32:
647         data = std::make_shared<Int32Imm>(obj);
648         break;
649       case kBit64:
650         data = std::make_shared<Int64Imm>(obj);
651         break;
652       default:
653         data = std::make_shared<Int64Imm>(obj);
654     }
655     return data;
656   }
657 
658   auto uint_dypte = dyn_cast<UInt>(dtype);
659   if (uint_dypte != nullptr) {
660     switch (uint_dypte->nbits()) {
661       case kBit8:
662         data = std::make_shared<UInt8Imm>(obj);
663         break;
664       case kBit16:
665         data = std::make_shared<UInt16Imm>(obj);
666         break;
667       case kBit32:
668         data = std::make_shared<UInt32Imm>(obj);
669         break;
670       case kBit64:
671         data = std::make_shared<UInt64Imm>(obj);
672         break;
673       default:
674         data = std::make_shared<UInt32Imm>(obj);
675     }
676     return data;
677   }
678 
679   auto float_dypte = dyn_cast<Float>(dtype);
680   if (float_dypte != nullptr) {
681     switch (float_dypte->nbits()) {
682       case kBit32:
683         data = std::make_shared<FP32Imm>(obj);
684         break;
685       case kBit64:
686         data = std::make_shared<FP64Imm>(obj);
687         break;
688       default:
689         data = std::make_shared<FP32Imm>(obj);
690     }
691     return data;
692   }
693   return nullptr;
694 }
695 
ConvertIntegerWithType(const py::object & obj,const TypePtr & dtype=nullptr)696 ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
697   auto obj_int64 = py::cast<int64_t>(obj);
698   // The mutable _Bool class inherits from int, because base class 'bool' is a marked final.
699   if (py::hasattr(obj, "__ms_mutable_bool__")) {
700     bool obj_bool = obj_int64 != 0;
701     return std::make_shared<BoolImm>(obj_bool);
702   }
703   if (dtype == nullptr) {
704     return std::make_shared<Int64Imm>(obj_int64);
705   }
706   return ConvertNumberWithType<int64_t>(obj_int64, dtype);
707 }
708 
ConvertFloatWithType(const py::object & obj,const TypePtr & dtype=nullptr)709 ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
710   auto obj_float32 = py::cast<pyfloat>(obj);
711   if (dtype == nullptr) {
712     auto obj_double = py::cast<double>(obj);
713     auto ret = std::make_shared<FP32Imm>(obj_float32);
714     ret->set_prim_value(obj_double);
715     return ret;
716   }
717   return ConvertNumberWithType<pyfloat>(obj_float32, dtype);
718 }
719 
ConvertNameSpace(const py::object & obj,bool use_signature)720 ValuePtr ConvertNameSpace(const py::object &obj, bool use_signature) {
721   MS_LOG(DEBUG) << "Converting python NameSpace";
722   auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
723   MS_LOG(DEBUG) << "name_space: " << res->ToString();
724   return res;
725 }
726 
727 template <typename T, typename U>
PyCast(const py::object & obj)728 ValuePtr PyCast(const py::object &obj) {
729   return std::make_shared<T>(py::cast<U>(obj));
730 }
731 
732 template <typename T>
ObjCast(const py::object & obj)733 ValuePtr ObjCast(const py::object &obj) {
734   return obj.cast<T>();
735 }
736 
// Returns the ordered converter table shared by ConvertData and DataConverter::ConvertData.
// Order matters: both callers stop at the first converter whose Matched() accepts the object, so more specific
// checks must stay ahead of more general ones.
static const std::vector<DataConvertFuncPtr> &GetDataConvertFuncs() {
  // Convert data by python object type.
  static const std::vector<DataConvertFuncPtr> data_convert_funcs{
    // Predicate-based checks (stub tensors, named tuples) run before the type-based Tensor/tuple entries below.
    // NOTE(review): an earlier comment here said AdapterTensor must precede Tensor because it inherits from
    // Tensor, but no AdapterTensor entry is visible in this list. Confirm which entry that constraint refers to.
    std::make_shared<ByFuncDataConvertFunc>(IsStubTensor, ConvertStubTensor),
    // Presumably named tuples are also py::tuple instances, hence this precedes the py::tuple entry — confirm.
    std::make_shared<ByFuncDataConvertFunc>(IsNamedTuple, ConvertNamedTuple),
    std::make_shared<ByTypeDataConvertFunc<Tensor>>(ObjCast<TensorPtr>),
    std::make_shared<ByAttrDataConvertFunc>(ConvertMsClass, PYTHON_MS_CLASS),
    // NOTE(review): presumably Tensor derives from BaseTensor, so the Tensor entry must stay above this one.
    std::make_shared<ByTypeDataConvertFunc<BaseTensor>>(ObjCast<BaseTensorPtr>),
    std::make_shared<ByTypeDataConvertFunc<py::tuple>>(ConvertTuple),
    std::make_shared<ByTypeDataConvertFunc<py::list>>(ConvertList),
    // bool precedes int: a Python bool is also an instance of py::int_.
    std::make_shared<ByTypeDataConvertFunc<py::bool_>>(PyCast<BoolImm, bool>),
    std::make_shared<ByTypeDataConvertFunc<py::int_>>(ConvertIntegerWithType),
    std::make_shared<ByTypeDataConvertFunc<py::float_>>(ConvertFloatWithType),
    std::make_shared<ByTypeDataConvertFunc<py::str>>(PyCast<StringImm, string>),
    std::make_shared<ByTypeDataConvertFunc<py::none>>(kNone),
    std::make_shared<ByTypeDataConvertFunc<MetaTensor>>(ObjCast<MetaTensorPtr>),
    std::make_shared<ByTypeDataConvertFunc<CSRTensor>>(ObjCast<CSRTensorPtr>),
    std::make_shared<ByTypeDataConvertFunc<COOTensor>>(ObjCast<COOTensorPtr>),
    std::make_shared<ByTypeDataConvertFunc<MapTensor>>(ObjCast<MapTensorPtr>),
    std::make_shared<ByTypeDataConvertFunc<py::ellipsis>>(kEllipsis),
    std::make_shared<ByTypeDataConvertFunc<py::module>>(ConvertModuleNameSpace),
    std::make_shared<ByTypeDataConvertFunc<Type>>(ObjCast<TypePtr>),
    std::make_shared<ByTypeDataConvertFunc<UMonad>>(ObjCast<UMonadPtr>),
    std::make_shared<ByTypeDataConvertFunc<IOMonad>>(ObjCast<IOMonadPtr>),
    std::make_shared<ByAttrDataConvertFunc>(ConvertNameSpace, PYTHON_CLASS_MEMBER_NAMESPACE),
    std::make_shared<ByTypeDataConvertFunc<py::dict>>(ConvertDict),
    std::make_shared<ByAttrDataConvertFunc>(ConvertDict, PYTHON_CELL_AS_DICT),
    std::make_shared<ByTypeDataConvertFunc<py::slice>>(ConvertSlice),
    std::make_shared<ByAttrDataConvertFunc>(ConvertCellList, PYTHON_CELL_AS_LIST, PYTHON_CELL_LIST_FROM_TOP),
    std::make_shared<ByTypeDataConvertFunc<Cell>>(ConvertCellObjToFuncGraph),
    std::make_shared<ByAttrDataConvertFunc>(ConvertPrimitive, PYTHON_PRIMITIVE_FLAG),
    std::make_shared<ByAttrDataConvertFunc>(ConvertPrimitiveFunction, PYTHON_PRIMITIVE_FUNCTION_FLAG),
    std::make_shared<ByTypeDataConvertFunc<MetaFuncGraph>>(ConvertMetaFuncGraph),
    std::make_shared<ByTypeDataConvertFunc<FuncGraph>>(ConvertFuncGraph),
  };
  return data_convert_funcs;
}
775 
GetStubDataConvertFuncs()776 static const std::vector<DataConvertFuncPtr> &GetStubDataConvertFuncs() {
777   // Convert data by python object type.
778   static const std::vector<DataConvertFuncPtr> data_convert_funcs{
779     std::make_shared<ByFuncDataConvertFunc>([](const py::object &obj) -> bool { return IsStubTensor(obj); },
780                                             PyStubNodeCast),
781     std::make_shared<ByTypeDataConvertFunc<py::tuple>>(ConvertStubTuple),
782     std::make_shared<ByTypeDataConvertFunc<py::list>>(ConvertStubList),
783   };
784   return data_convert_funcs;
785 }
786 
RemoveRecomputeScope(const FuncGraphPtr & func_graph)787 void RemoveRecomputeScope(const FuncGraphPtr &func_graph) {
788   MS_EXCEPTION_IF_NULL(func_graph);
789   auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple);
790 
791   for (const auto &node : nodes) {
792     MS_EXCEPTION_IF_NULL(node);
793     if (!node->isa<CNode>()) {
794       continue;
795     }
796     const auto &origin_scope_name = node->scope()->name();
797     if (origin_scope_name.compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0) {
798       auto remove_recompute_scope = origin_scope_name.substr(strlen(kAttrRecompute) + 1);
799       node->set_scope(std::make_shared<Scope>(remove_recompute_scope));
800     }
801   }
802 }
803 }  // namespace
804 
ConvertData(const py::object & obj,ValuePtr * data,bool use_signature,const TypePtr & dtype,bool forbid_reuse)805 bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature, const TypePtr &dtype, bool forbid_reuse) {
806   // Check parameter valid
807   if (data == nullptr) {
808     MS_LOG(ERROR) << "The value pointer should not be null.";
809     return false;
810   }
811   ValuePtr converted = nullptr;
812   bool matched = false;
813   const auto &converters = GetDataConvertFuncs();
814   for (auto &converter : converters) {
815     if (converter->Matched(obj)) {
816       converted = converter->ConvertPyObject(obj, use_signature, dtype);
817       matched = true;
818       break;
819     }
820   }
821   if (!matched) {
822     converted = ConvertOtherObj(obj, forbid_reuse);
823   }
824   *data = converted;
825   return converted != nullptr;
826 }
827 
ConvertStubData(const py::object & obj,ValuePtr * data,bool use_signature,const TypePtr & dtype,bool forbid_reuse)828 bool ConvertStubData(const py::object &obj, ValuePtr *data, bool use_signature, const TypePtr &dtype,
829                      bool forbid_reuse) {
830   if (data == nullptr) {
831     MS_LOG(ERROR) << "The value pointer should not be null.";
832     return false;
833   }
834   ValuePtr converted = nullptr;
835   const auto &convert_funcs = GetStubDataConvertFuncs();
836   for (auto &convert_func : convert_funcs) {
837     if (convert_func->Matched(obj)) {
838       converted = convert_func->ConvertPyObject(obj, use_signature, dtype);
839       *data = converted;
840       return converted != nullptr;
841     }
842   }
843   return ConvertData(obj, data, use_signature, dtype, forbid_reuse);
844 }
845 
// Turns `base_graph` into a lazy-inline reusing graph in place and returns the same object.
// Each call stamps the next value of a function-local counter as FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER
// and prefixes the debug name with "CR_".
// NOTE(review): the counter is a plain static int with no synchronization.
FuncGraphPtr MakeReusingGraph(const FuncGraphPtr &base_graph) {
  static int order = 0;
  ++order;
  base_graph->set_attr(FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER, MakeValue(order));
  const std::string reusing_name = "CR_" + base_graph->debug_info()->name();
  base_graph->debug_info()->set_name(reusing_name);
  MS_LOG(INFO) << "Lazy inline reusing graph: " << base_graph->ToString()
               << ", args: " << base_graph->parameters().size() << ", parse order: " << order;
  return base_graph;
}
854 
MakeCellFuncGraph(const py::object & obj,const std::string & obj_id,const FuncGraphPtr & reusing_graph)855 FuncGraphPtr MakeCellFuncGraph(const py::object &obj, const std::string &obj_id, const FuncGraphPtr &reusing_graph) {
856   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
857   // Normalize the name.
858   auto function_name = obj_id;
859   std::replace(function_name.begin(), function_name.end(), '.', '_');
860   std::replace(function_name.begin(), function_name.end(), '<', '_');
861   std::replace(function_name.begin(), function_name.end(), '>', '_');
862   func_graph->debug_info()->set_name(function_name);
863   PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
864   func_graph->set_python_obj(python_obj);
865   func_graph->set_flag(FUNC_GRAPH_FLAG_PROXY_GRAPH, true);
866   std::vector<AnfNodePtr> new_node_inputs;
867   new_node_inputs.push_back(NewValueNode(reusing_graph));
868   for (const auto &origin_param : reusing_graph->parameters()) {
869     auto param = func_graph->add_parameter();
870     param->set_debug_info(origin_param->debug_info());
871     new_node_inputs.push_back(param);
872   }
873   AnfNodePtr out = func_graph->NewCNodeInOrder(new_node_inputs);
874   func_graph->set_output(out);
875   MS_LOG(INFO) << "Lazy inline cell: " << func_graph->ToString() << ", args: " << func_graph->parameters().size();
876   return func_graph;
877 }
878 
ProcessLazyInline(const py::object & obj,const ValuePtrList & args_value_list,const std::string & python_mod_get_parse_method,const std::string & obj_id,const std::string & obj_key)879 FuncGraphPtr ProcessLazyInline(const py::object &obj, const ValuePtrList &args_value_list,
880                                const std::string &python_mod_get_parse_method, const std::string &obj_id,
881                                const std::string &obj_key) {
882   ValuePtr key_value = nullptr;
883   FuncGraphPtr reusing_graph = nullptr;
884   bool is_key_cache = data_converter::GetObjectValue(obj_key, &key_value);
885   if (is_key_cache && key_value != nullptr && key_value->isa<FuncGraph>()) {
886     MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_key;
887     reusing_graph = key_value->cast<FuncGraphPtr>();
888   } else {
889     auto base_graph = ParsePythonCode(obj, python_mod_get_parse_method, args_value_list);
890     if (base_graph == nullptr) {
891       MS_LOG(ERROR) << "Parse resolve function error.";
892       return nullptr;
893     }
894     if (Parser::GetTopFuncGraph() == base_graph) {
895       return base_graph;
896     }
897     PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
898     base_graph->set_python_obj(python_obj);
899     MS_LOG(DEBUG) << "Parse reusing function: " << reusing_graph->ToString();
900     reusing_graph = MakeReusingGraph(base_graph);
901     data_converter::CacheObjectValue(obj_key, reusing_graph);
902   }
903   // Let the original cell graph call the reusable graph.
904   auto func_graph = MakeCellFuncGraph(obj, obj_id, reusing_graph);
905   MS_LOG(DEBUG) << func_graph->ToString() << " calls " << reusing_graph->ToString();
906   return func_graph;
907 }
908 
909 // Convert data to graph
ConvertToFuncGraph(const py::object & obj,const ValuePtrList & args_value_list,const std::string & python_mod_get_parse_method,bool forbid_reuse)910 FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const ValuePtrList &args_value_list,
911                                 const std::string &python_mod_get_parse_method, bool forbid_reuse) {
912   std::vector<std::string> results = data_converter::GetObjKey(obj);
913   std::string obj_id = results[0] + python_mod_get_parse_method;
914   std::string obj_key = results[1];
915   FuncGraphPtr func_graph = nullptr;
916   ValuePtr value = nullptr;
917   bool is_debug = MsContext::GetInstance()->get_param<int>(MS_CTX_DEBUG_LEVEL) == kLevelDebug;
918   bool is_cache = data_converter::GetObjectValue(obj_id, &value);
919   if (!is_debug && is_cache && value != nullptr && value->isa<FuncGraph>()) {
920     func_graph = value->cast<FuncGraphPtr>();
921     if (!func_graph->dropped()) {
922       bool has_forbid_reuse_attr = py::hasattr(obj, PYTHON_FUNCTION_FORBID_REUSE);
923       if (forbid_reuse || has_forbid_reuse_attr) {
924         return BasicClone(func_graph);
925       }
926       return func_graph;
927     }
928   }
929   if (obj_key.find("lazy_inline") != obj_key.npos) {
930     func_graph = ProcessLazyInline(obj, args_value_list, python_mod_get_parse_method, results[0], obj_key);
931     if (func_graph == nullptr) {
932       return nullptr;
933     }
934   } else {
935     func_graph = ParsePythonCode(obj, python_mod_get_parse_method, args_value_list);
936     if (func_graph == nullptr) {
937       MS_LOG(ERROR) << "Parse resolve function error.";
938       return nullptr;
939     }
940   }
941 
942   data_converter::CacheObjectValue(obj_id, func_graph);
943   if (!obj_key.empty() && python_mod_get_parse_method == PYTHON_MOD_GET_PARSE_METHOD) {
944     data_converter::SetObjGraphValue(obj_key, func_graph);
945   }
946 
947   PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
948   func_graph->set_python_obj(python_obj);
949 
950   if (forbid_reuse) {
951     // The function may be set recomputed in parse.
952     if (!data_converter::IsCellInstance(obj)) {
953       RemoveRecomputeScope(func_graph);
954     }
955     // Return the clone graph because the graph may be set recomputed later.
956     return BasicClone(func_graph);
957   }
958 
959   return func_graph;
960 }
961 
GetArgDefaultValue(const std::string & prim_name,const std::string & arg_name)962 ValuePtr GetArgDefaultValue(const std::string &prim_name, const std::string &arg_name) {
963   py::module mod = py::module::import(PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE);
964   if (!py::hasattr(mod, PYTHON_MOD_PRIMITIVE_OP_DEFAULT_VALUE_DICT)) {
965     MS_LOG(INTERNAL_EXCEPTION) << "Can not found " << PYTHON_MOD_PRIMITIVE_OP_DEFAULT_VALUE_DICT << "in "
966                                << PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE << ".";
967   }
968   py::dict op_default_dict = mod.attr(PYTHON_MOD_PRIMITIVE_OP_DEFAULT_VALUE_DICT);
969   if (!op_default_dict.contains(py::str(prim_name))) {
970     return nullptr;
971   }
972   py::dict prim_default_dict = op_default_dict[py::str(prim_name)];
973   if (!prim_default_dict.contains(py::str(arg_name))) {
974     return nullptr;
975   }
976   auto default_value = prim_default_dict[py::str(arg_name)];
977   ValuePtr converted_ret = nullptr;
978   bool converted = ConvertData(default_value, &converted_ret);
979   if (!converted) {
980     const std::string &default_name = py::str(default_value);
981     MS_EXCEPTION(ValueError) << "For Operator[" << prim_name << "], '" << default_name
982                              << "' is not supported as the default value for '" << arg_name << "'.";
983   }
984   return converted_ret;
985 }
986 
987 namespace data_converter {
988 static mindspore::HashMap<std::string, ValuePtr> object_map_;
989 
990 static mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
991 
SetObjGraphValue(const std::string & obj_key,const FuncGraphPtr & data)992 void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
993   object_graphs_map_[obj_key].push_back(data);
994   MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
995 }
996 
GetObjGraphs()997 const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
998   MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
999   return object_graphs_map_;
1000 }
1001 
CacheObjectValue(const std::string & obj_key,const ValuePtr & data)1002 void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
1003 
GetObjectValue(const std::string & obj_key,ValuePtr * const data)1004 bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
1005   if (object_map_.count(obj_key) != 0) {
1006     *data = object_map_[obj_key];
1007     return true;
1008   }
1009   return false;
1010 }
1011 
GetObjKey(const py::object & obj)1012 std::vector<std::string> GetObjKey(const py::object &obj) {
1013   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1014   py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
1015   if (obj_tuple.size() != 2) {
1016     MS_LOG(INTERNAL_EXCEPTION) << "The function of \'get_obj_key()\' must return 2 elements";
1017   }
1018   return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
1019 }
1020 
1021 // Get obj detail type
GetObjType(const py::object & obj)1022 ResolveType GetObjType(const py::object &obj) {
1023   try {
1024     py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1025     auto obj_type = ResolveType(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
1026     return obj_type;
1027   } catch (const py::error_already_set &ex) {
1028     MS_LOG(ERROR) << "Meet a exception from Python when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
1029     std::rethrow_exception(std::current_exception());
1030   } catch (const py::type_error &ex) {
1031     MS_LOG(ERROR) << "Meet a exception when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
1032     std::rethrow_exception(std::current_exception());
1033   }
1034 }
1035 
1036 // Get class instance detail type.
GetClassInstanceType(const py::object & obj)1037 ClassInstanceType GetClassInstanceType(const py::object &obj) {
1038   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1039   auto class_type =
1040     ClassInstanceType(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
1041   return class_type;
1042 }
1043 
1044 // Check if the object is Cell instance.
IsCellInstance(const py::object & obj)1045 bool IsCellInstance(const py::object &obj) {
1046   auto class_type = GetClassInstanceType(obj);
1047   return class_type == CLASS_INSTANCE_TYPE_CELL;
1048 }
1049 
1050 // Check if the object is Numpy Array instance.
IsNumpyArrayInstance(const py::object & obj)1051 bool IsNumpyArrayInstance(const py::object &obj) {
1052   auto class_type = GetClassInstanceType(obj);
1053   return class_type == CLASS_INSTANCE_TYPE_NUMPY_ARRAY;
1054 }
1055 
1056 // Check if the object is MsClass instance.
IsMsClassInstance(const py::object & obj)1057 bool IsMsClassInstance(const py::object &obj) { return py::hasattr(obj, PYTHON_MS_CLASS); }
1058 
1059 // Check if the object is jit forbidden api.
IsJITForbiddenAPI(const py::object & obj)1060 bool IsJITForbiddenAPI(const py::object &obj) { return py::hasattr(obj, PYTHON_JIT_FORBIDDEN); }
1061 
1062 // Check if the object is class type.
IsClassType(const py::object & obj)1063 bool IsClassType(const py::object &obj) {
1064   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1065   return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast<bool>();
1066 }
1067 
1068 // Create the python class instance.
CreatePythonObject(const py::object & type,const py::tuple & args_kwargs)1069 py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
1070   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1071   // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
1072   return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
1073                              : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
1074 }
1075 
1076 // Call the python script string.
CallPythonScript(const py::object & script,const py::tuple & args_kwargs)1077 py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
1078   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1079   // `args_kwargs` is a tuple(dict(global), dict(local)).
1080   return python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
1081 }
1082 
1083 // Get the ids of python script string.
GetPythonScriptIdAttrs(const py::object & script)1084 py::set GetPythonScriptIdAttrs(const py::object &script) {
1085   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1086   return python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_SCRIPT_ID_ATTRS, script);
1087 }
1088 
PyDataToValue(const py::object & obj)1089 ValuePtr PyDataToValue(const py::object &obj) {
1090   py::object to_convert = obj;
1091   ValuePtr value = nullptr;
1092   (void)ConvertData(to_convert, &value);
1093   return value;
1094 }
1095 
PyDataToStubNode(const py::object & obj)1096 ValuePtr PyDataToStubNode(const py::object &obj) {
1097   py::object to_convert = obj;
1098   ValuePtr value = nullptr;
1099   (void)ConvertStubData(to_convert, &value);
1100   return value;
1101 }
1102 
ClearObjectCache()1103 void ClearObjectCache() {
1104   object_map_.clear();
1105   object_graphs_map_.clear();
1106 }
1107 }  // namespace data_converter
1108 
ConvertData(const py::object & obj)1109 ValuePtr DataConverter::ConvertData(const py::object &obj) {
1110   const auto &convert_funcs = GetDataConvertFuncs();
1111   for (auto &convert_func : convert_funcs) {
1112     if (convert_func->Matched(obj)) {
1113       return convert_func->ConvertPyObject(obj, use_signature_, dtype_, args_value_list_);
1114     }
1115   }
1116   return ConvertOtherObj(obj, forbid_reuse_);
1117 }
1118 
ConvertPythonFloatToScalarValue(double value)1119 inline ValuePtr ConvertPythonFloatToScalarValue(double value) {
1120   auto ret = std::make_shared<FP32Imm>(static_cast<float>(value));
1121   ret->set_prim_value(value);
1122   return ret;
1123 }
1124 
ConvertBool(const py::object & obj)1125 ValuePtr ConvertBool(const py::object &obj) {
1126   if (!py::isinstance<py::bool_>(obj)) {
1127     return nullptr;
1128   }
1129   return PyCast<BoolImm, bool>(obj);
1130 }
1131 
ConvertInt(const py::object & obj)1132 ValuePtr ConvertInt(const py::object &obj) {
1133   // bool is also an instance of py::int_
1134   if (py::isinstance<py::bool_>(obj) || !py::isinstance<py::int_>(obj)) {
1135     return nullptr;
1136   }
1137   return ConvertIntegerWithType(obj);
1138 }
1139 
ConvertFloat(const py::object & obj)1140 ValuePtr ConvertFloat(const py::object &obj) {
1141   if (!py::isinstance<py::float_>(obj)) {
1142     return nullptr;
1143   }
1144   return ConvertFloatWithType(obj);
1145 }
1146 
ConvertNumber(const py::object & obj)1147 ValuePtr ConvertNumber(const py::object &obj) {
1148   if (py::isinstance<py::bool_>(obj)) {
1149     return PyCast<BoolImm, bool>(obj);
1150   }
1151 
1152   if (py::isinstance<py::int_>(obj)) {
1153     return ConvertIntegerWithType(obj);
1154   }
1155 
1156   if (py::isinstance<py::float_>(obj)) {
1157     return ConvertFloatWithType(obj);
1158   }
1159 
1160   return nullptr;
1161 }
1162 
ConvertTensor(const py::object & obj)1163 ValuePtr ConvertTensor(const py::object &obj) {
1164   if (IsStubTensor(obj)) {
1165     return PyStubNodeCast(obj);
1166   }
1167 
1168   if (!py::isinstance<mindspore::tensor::Tensor>(obj)) {
1169     return nullptr;
1170   }
1171 
1172   return ObjCast<TensorPtr>(obj);
1173 }
1174 
ConvertTensorValue(const py::object & obj)1175 TensorPtr ConvertTensorValue(const py::object &obj) {
1176   // The difference between the new ConvertTensorValue function and the existing ConvertTensor is:
1177   // If the obj a StubNode, it must be called the WaitValue to convert to a Tensor.
1178   if (IsStubTensor(obj)) {
1179     auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB);
1180     auto stub = py_stub.cast<stub::StubNodePtr>();
1181     if (stub == nullptr) {
1182       return py::getattr(obj, stub::PY_ATTR_TENSOR).cast<tensor::TensorPtr>();
1183     }
1184     auto value = stub->WaitValue();
1185     auto tensor = value->cast<TensorPtr>();
1186     if (tensor == nullptr) {
1187       // BaseTensor should convert to Tensor for Graph mode
1188       auto base_tensor = value->cast<BaseTensorPtr>();
1189       auto real_tensor = std::make_shared<Tensor>(*base_tensor);
1190       stub->SetValue(real_tensor);
1191       return real_tensor;
1192     }
1193     return tensor;
1194   }
1195   if (!py::isinstance<mindspore::tensor::Tensor>(obj)) {
1196     return nullptr;
1197   }
1198   return obj.cast<TensorPtr>();
1199 }
1200 
GetTensorDataPtr(const tensor::TensorPtr & tensor)1201 static inline void *GetTensorDataPtr(const tensor::TensorPtr &tensor) {
1202   MS_EXCEPTION_IF_NULL(tensor);
1203   const auto &device_address = tensor->device_address();
1204   if (device_address != nullptr) {
1205     // Before get data, sync form device address should be performed first
1206     tensor->data_sync();
1207   }
1208   return tensor->data_c();
1209 }
1210 
ConvertStr(const py::object & obj)1211 ValuePtr ConvertStr(const py::object &obj) {
1212   if (!py::isinstance<py::str>(obj)) {
1213     return nullptr;
1214   }
1215   return PyCast<StringImm, string>(obj);
1216 }
1217 
ConvertAny(const py::object & obj)1218 ValuePtr ConvertAny(const py::object &obj) { return parse::data_converter::PyDataToStubNode(obj); }
1219 
ConvertDtype(const py::object & obj)1220 ValuePtr ConvertDtype(const py::object &obj) {
1221   if (!py::isinstance<mindspore::Type>(obj)) {
1222     MS_LOG(EXCEPTION) << "Get arg is not mindspore type " << py::str(obj);
1223   }
1224   return obj.cast<TypePtr>();
1225 }
1226 
1227 template <typename TS, typename TD, OpDefConvertFunc func>
ConvertSequence(const py::object & obj)1228 ValuePtr ConvertSequence(const py::object &obj) {
1229   if (!py::isinstance<TS>(obj)) {
1230     return nullptr;
1231   }
1232   auto seq = obj.cast<TS>();
1233   std::vector<ValuePtr> value_list;
1234   for (size_t it = 0; it < seq.size(); ++it) {
1235     auto out = func(seq[it]);
1236     if (out == nullptr) {
1237       return nullptr;
1238     }
1239     value_list.emplace_back(out);
1240   }
1241   return std::make_shared<TD>(value_list);
1242 }
1243 
1244 template <typename T, OpDefConvertFunc func>
ConvertSingleElementToSequence(const py::object & obj)1245 ValuePtr ConvertSingleElementToSequence(const py::object &obj) {
1246   auto value = func(obj);
1247   if (value == nullptr) {
1248     return nullptr;
1249   }
1250   std::vector<ValuePtr> value_list{value};
1251   return std::make_shared<T>(std::move(value_list));
1252 }
1253 
1254 template <typename T1, typename T2>
ConvertSingleElementToTensor(const py::object & obj)1255 ValuePtr ConvertSingleElementToTensor(const py::object &obj) {
1256   if (!py::isinstance<T1>(obj)) {
1257     return nullptr;
1258   }
1259 
1260   auto v = py::cast<T2>(obj);
1261   return std::make_shared<tensor::Tensor>(v);
1262 }
1263 
ConvertNumberToTensor(const py::object & obj)1264 ValuePtr ConvertNumberToTensor(const py::object &obj) {
1265   if (py::isinstance<py::bool_>(obj)) {
1266     auto v = py::cast<bool>(obj);
1267     return std::make_shared<tensor::Tensor>(v);
1268   }
1269 
1270   if (py::isinstance<py::int_>(obj)) {
1271     auto v = py::cast<int64_t>(obj);
1272     return std::make_shared<tensor::Tensor>(v);
1273   }
1274 
1275   if (py::isinstance<py::float_>(obj)) {
1276     auto v = py::cast<pyfloat>(obj);
1277     return std::make_shared<tensor::Tensor>(v);
1278   }
1279 
1280   return nullptr;
1281 }
1282 
1283 template <typename TS, typename TSE, typename TDE>
ConvertSequenceToTensor(const py::object & obj)1284 ValuePtr ConvertSequenceToTensor(const py::object &obj) {
1285   if (!py::isinstance<TS>(obj)) {
1286     return nullptr;
1287   }
1288 
1289   auto seq = obj.cast<TS>();
1290   if (seq.size() == 0) {
1291     return nullptr;
1292   }
1293 
1294   std::vector<TDE> value_list;
1295   for (size_t it = 0; it < seq.size(); ++it) {
1296     if (!py::isinstance<TSE>(seq[it])) {
1297       return nullptr;
1298     }
1299 
1300     auto value = py::cast<TDE>(seq[it]);
1301     value_list.emplace_back(value);
1302   }
1303 
1304   return std::make_shared<tensor::Tensor>(value_list);
1305 }
1306 
1307 template <typename TS>
ConvertSequenceBoolToTensor(const py::object & obj)1308 ValuePtr ConvertSequenceBoolToTensor(const py::object &obj) {
1309   if (!py::isinstance<TS>(obj)) {
1310     return nullptr;
1311   }
1312 
1313   auto seq = obj.cast<TS>();
1314   if (seq.size() == 0) {
1315     return nullptr;
1316   }
1317 
1318   auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeBool, ShapeVector({static_cast<int64_t>(seq.size())}));
1319   auto data = static_cast<bool *>(tensor->data_c());
1320   for (size_t it = 0; it < seq.size(); ++it) {
1321     if (!py::isinstance<py::bool_>(seq[it])) {
1322       return nullptr;
1323     }
1324 
1325     auto value = py::cast<bool>(seq[it]);
1326     data[it] = value;
1327   }
1328 
1329   return tensor;
1330 }
1331 
1332 template <typename TD, typename TDE, typename IMMTYPE, TypeId tid>
ConvertTensorToSequence(const py::object & obj)1333 ValuePtr ConvertTensorToSequence(const py::object &obj) {
1334   auto tensor = ConvertTensorValue(obj);
1335   if (tensor == nullptr) {
1336     MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1337     return nullptr;
1338   }
1339 
1340   auto data_type = tensor->data_type();
1341   // Since the dst object type is only, once the src object is validated as Tensor, the other converting errors should
1342   // be thrown. There is no other paths for this case to run successfully.
1343   if (data_type != tid) {
1344     MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << "to Sequence with type "
1345                   << TypeIdToString(tid) << ".";
1346     return nullptr;
1347   }
1348 
1349   auto shape = tensor->shape();
1350   if (shape.size() > 1) {
1351     MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1352                   << shape;
1353     return nullptr;
1354   }
1355 
1356   auto data = static_cast<TDE *>(GetTensorDataPtr(tensor));
1357   auto size = tensor->DataSize();
1358   std::vector<ValuePtr> value_list;
1359   for (size_t i = 0; i < size; i++) {
1360     value_list.emplace_back(std::make_shared<IMMTYPE>(data[i]));
1361   }
1362   return std::make_shared<TD>(value_list);
1363 }
1364 
1365 template <typename TD>
ConvertTensorToSequenceInt(const py::object & obj)1366 ValuePtr ConvertTensorToSequenceInt(const py::object &obj) {
1367   auto tensor = ConvertTensorValue(obj);
1368   if (tensor == nullptr) {
1369     MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1370     return nullptr;
1371   }
1372 
1373   auto shape = tensor->shape();
1374   if (shape.size() > 1) {
1375     MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1376                   << shape;
1377     return nullptr;
1378   }
1379 
1380   auto data_type = tensor->data_type();
1381   if (data_type != kNumberTypeInt64 && data_type != kNumberTypeInt32) {
1382     MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << "to Int Sequence.";
1383     return nullptr;
1384   }
1385   auto size = tensor->DataSize();
1386   std::vector<ValuePtr> value_list;
1387   if (data_type == kNumberTypeInt64) {
1388     auto data = static_cast<int64_t *>(GetTensorDataPtr(tensor));
1389     std::transform(data, data + size, std::back_inserter(value_list),
1390                    [](int64_t num) { return std::make_shared<Int64Imm>(num); });
1391   } else {
1392     auto data = static_cast<int32_t *>(GetTensorDataPtr(tensor));
1393     std::transform(data, data + size, std::back_inserter(value_list),
1394                    [](int32_t num) { return std::make_shared<Int64Imm>(num); });
1395   }
1396   return std::make_shared<TD>(value_list);
1397 }
1398 
1399 template <typename TD>
ConvertTensorToSequenceFloat(const py::object & obj)1400 ValuePtr ConvertTensorToSequenceFloat(const py::object &obj) {
1401   auto float_tensor = ConvertTensorValue(obj);
1402   if (float_tensor == nullptr) {
1403     MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1404     return nullptr;
1405   }
1406 
1407   auto data_type = float_tensor->data_type();
1408   if (data_type != kNumberTypeFloat64) {
1409     MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << "to Float64 Sequence.";
1410     return nullptr;
1411   }
1412 
1413   auto shape = float_tensor->shape();
1414   if (shape.size() > 1) {
1415     MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1416                   << shape;
1417     return nullptr;
1418   }
1419 
1420   auto data = static_cast<double *>(GetTensorDataPtr(float_tensor));
1421   auto size = float_tensor->DataSize();
1422   std::vector<ValuePtr> value_list(size);
1423   for (size_t i = 0; i < size; i++) {
1424     value_list.emplace_back(ConvertPythonFloatToScalarValue(data[i]));
1425   }
1426 
1427   return std::make_shared<TD>(value_list);
1428 }
1429 
1430 template <typename TD>
ConvertTensorToSequenceAny(const py::object & obj)1431 ValuePtr ConvertTensorToSequenceAny(const py::object &obj) {
1432   auto tensor = ConvertTensorValue(obj);
1433   if (tensor == nullptr) {
1434     MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1435     return nullptr;
1436   }
1437 
1438   auto shape = tensor->shape();
1439   if (shape.size() > 1) {
1440     MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1441                   << shape;
1442     return nullptr;
1443   }
1444 
1445   auto data_type = tensor->data_type();
1446   auto size = tensor->DataSize();
1447   std::vector<ValuePtr> value_list(size);
1448   if (data_type == kNumberTypeInt64) {
1449     auto data = static_cast<int64_t *>(GetTensorDataPtr(tensor));
1450     for (size_t i = 0; i < size; i++) {
1451       value_list.emplace_back(std::make_shared<Int64Imm>(data[i]));
1452     }
1453   } else if (data_type == kNumberTypeFloat64) {
1454     auto data = static_cast<double *>(GetTensorDataPtr(tensor));
1455     for (size_t i = 0; i < size; i++) {
1456       value_list.emplace_back(ConvertPythonFloatToScalarValue(data[i]));
1457     }
1458   } else if (data_type == kNumberTypeBool) {
1459     auto data = static_cast<bool *>(GetTensorDataPtr(tensor));
1460     for (size_t i = 0; i < size; i++) {
1461       value_list.emplace_back(std::make_shared<BoolImm>(data[i]));
1462     }
1463   } else {
1464     MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << " to sequence.";
1465     return nullptr;
1466   }
1467 
1468   return std::make_shared<TD>(value_list);
1469 }
1470 
ConvertTensorToInt(const py::object & obj)1471 ValuePtr ConvertTensorToInt(const py::object &obj) {
1472   auto tensor = ConvertTensorValue(obj);
1473   if (tensor == nullptr) {
1474     return nullptr;
1475   }
1476   if (tensor->DataSize() != 1) {
1477     MS_LOG(ERROR) << "Can only convert tensor with one element to int, but got " << tensor->ToString();
1478     return nullptr;
1479   }
1480   if (tensor->data_type() == kNumberTypeInt64) {
1481     return std::make_shared<Int64Imm>(static_cast<int64_t *>(GetTensorDataPtr(tensor))[0]);
1482   } else if (tensor->data_type() == kNumberTypeInt32) {
1483     return std::make_shared<Int64Imm>(static_cast<int32_t *>(GetTensorDataPtr(tensor))[0]);
1484   } else {
1485     MS_LOG(ERROR) << "Can not convert " << tensor->ToString() << " to int";
1486     return nullptr;
1487   }
1488 }
1489 
ConvertTensorToFloat(const py::object & obj)1490 ValuePtr ConvertTensorToFloat(const py::object &obj) {
1491   auto tensor = ConvertTensorValue(obj);
1492   if (tensor == nullptr) {
1493     return nullptr;
1494   }
1495   if (tensor->DataSize() != 1) {
1496     MS_LOG(ERROR) << "Can only convert tensor with one element to float, but got " << tensor->ToString();
1497     return nullptr;
1498   }
1499   if (tensor->data_type() != kNumberTypeFloat64) {
1500     MS_LOG(ERROR) << "Can not convert " << tensor->ToString() << " to float";
1501     return nullptr;
1502   }
1503   return ConvertPythonFloatToScalarValue(static_cast<double *>(GetTensorDataPtr(tensor))[0]);
1504 }
1505 
ConvertTensorToBool(const py::object & obj)1506 ValuePtr ConvertTensorToBool(const py::object &obj) {
1507   auto tensor = ConvertTensorValue(obj);
1508   if (tensor == nullptr) {
1509     return nullptr;
1510   }
1511   if (tensor->data_type() != kNumberTypeBool) {
1512     MS_LOG(ERROR) << "Can not convert " << tensor->ToString() << " to bool";
1513     return nullptr;
1514   }
1515   return std::make_shared<BoolImm>(static_cast<bool *>(GetTensorDataPtr(tensor))[0]);
1516 }
1517 
ConvertTensorToNumber(const py::object & obj)1518 ValuePtr ConvertTensorToNumber(const py::object &obj) {
1519   auto tensor = ConvertTensorValue(obj);
1520   if (tensor == nullptr) {
1521     return nullptr;
1522   }
1523   if (tensor->DataSize() != 1) {
1524     MS_EXCEPTION(ValueError) << "Can only convert tensor with one element to number, but got " << tensor->ToString();
1525   }
1526 
1527   switch (tensor->data_type()) {
1528     case kNumberTypeBool:
1529       return std::make_shared<BoolImm>(static_cast<bool *>(GetTensorDataPtr(tensor))[0]);
1530     case kNumberTypeInt64:
1531       return std::make_shared<Int64Imm>(static_cast<int64_t *>(GetTensorDataPtr(tensor))[0]);
1532     case kNumberTypeInt32:
1533       return std::make_shared<Int32Imm>(static_cast<int32_t *>(GetTensorDataPtr(tensor))[0]);
1534     case kNumberTypeFloat64:
1535       return ConvertPythonFloatToScalarValue(static_cast<double *>(GetTensorDataPtr(tensor))[0]);
1536     case kNumberTypeFloat32:
1537       return ConvertPythonFloatToScalarValue(static_cast<float *>(GetTensorDataPtr(tensor))[0]);
1538     default:
1539       MS_EXCEPTION(TypeError) << "Can not convert " << tensor->ToString() << " to number";
1540   }
1541 }
1542 
1543 static const std::unordered_map<int32_t, OpDefConvertFunc> kConverters = {
1544   // convert functions without type_cast
1545   {(int32_t)mindspore::ops::DT_BOOL, ConvertBool},
1546   {(int32_t)mindspore::ops::DT_INT, ConvertInt},
1547   {(int32_t)mindspore::ops::DT_FLOAT, ConvertFloat},
1548   {(int32_t)mindspore::ops::DT_NUMBER, ConvertNumber},
1549   {(int32_t)mindspore::ops::DT_TENSOR, ConvertTensor},
1550   {(int32_t)mindspore::ops::DT_STR, ConvertStr},
1551   {(int32_t)mindspore::ops::DT_ANY, ConvertAny},
1552   {(int32_t)mindspore::ops::DT_TYPE, ConvertDtype},
1553   {(int32_t)mindspore::ops::DT_TUPLE_BOOL, ConvertSequence<py::tuple, ValueTuple, ConvertBool>},
1554   {(int32_t)mindspore::ops::DT_TUPLE_INT, ConvertSequence<py::tuple, ValueTuple, ConvertInt>},
1555   {(int32_t)mindspore::ops::DT_TUPLE_FLOAT, ConvertSequence<py::tuple, ValueTuple, ConvertFloat>},
1556   {(int32_t)mindspore::ops::DT_TUPLE_NUMBER, ConvertSequence<py::tuple, ValueTuple, ConvertNumber>},
1557   {(int32_t)mindspore::ops::DT_TUPLE_TENSOR, ConvertSequence<py::tuple, ValueTuple, ConvertTensor>},
1558   {(int32_t)mindspore::ops::DT_TUPLE_STR, ConvertSequence<py::tuple, ValueTuple, ConvertStr>},
1559   {(int32_t)mindspore::ops::DT_TUPLE_ANY, ConvertSequence<py::tuple, ValueTuple, ConvertAny>},
1560   {(int32_t)mindspore::ops::DT_LIST_BOOL, ConvertSequence<py::list, ValueList, ConvertBool>},
1561   {(int32_t)mindspore::ops::DT_LIST_INT, ConvertSequence<py::list, ValueList, ConvertInt>},
1562   {(int32_t)mindspore::ops::DT_LIST_FLOAT, ConvertSequence<py::list, ValueList, ConvertFloat>},
1563   {(int32_t)mindspore::ops::DT_LIST_NUMBER, ConvertSequence<py::list, ValueList, ConvertNumber>},
1564   {(int32_t)mindspore::ops::DT_LIST_TENSOR, ConvertSequence<py::list, ValueList, ConvertTensor>},
1565   {(int32_t)mindspore::ops::DT_LIST_STR, ConvertSequence<py::list, ValueList, ConvertStr>},
1566   {(int32_t)mindspore::ops::DT_LIST_ANY, ConvertSequence<py::list, ValueList, ConvertAny>},
1567 
1568   // TypeCast1: convert single element to sequence
1569   {CombineTypesForTypeCast(mindspore::ops::DT_NUMBER, mindspore::ops::DT_TUPLE_INT),
1570    ConvertSingleElementToSequence<ValueTuple, ConvertNumber>},
1571   {CombineTypesForTypeCast(mindspore::ops::DT_NUMBER, mindspore::ops::DT_LIST_INT),
1572    ConvertSingleElementToSequence<ValueList, ConvertNumber>},
1573   {CombineTypesForTypeCast(mindspore::ops::DT_INT, mindspore::ops::DT_TUPLE_INT),
1574    ConvertSingleElementToSequence<ValueTuple, ConvertInt>},
1575   {CombineTypesForTypeCast(mindspore::ops::DT_INT, mindspore::ops::DT_LIST_INT),
1576    ConvertSingleElementToSequence<ValueList, ConvertInt>},
1577   {CombineTypesForTypeCast(mindspore::ops::DT_FLOAT, mindspore::ops::DT_TUPLE_INT),
1578    ConvertSingleElementToSequence<ValueTuple, ConvertFloat>},
1579   {CombineTypesForTypeCast(mindspore::ops::DT_FLOAT, mindspore::ops::DT_LIST_INT),
1580    ConvertSingleElementToSequence<ValueList, ConvertFloat>},
1581   {CombineTypesForTypeCast(mindspore::ops::DT_BOOL, mindspore::ops::DT_TUPLE_INT),
1582    ConvertSingleElementToSequence<ValueTuple, ConvertBool>},
1583   {CombineTypesForTypeCast(mindspore::ops::DT_BOOL, mindspore::ops::DT_LIST_INT),
1584    ConvertSingleElementToSequence<ValueList, ConvertBool>},
1585   {CombineTypesForTypeCast(mindspore::ops::DT_ANY, mindspore::ops::DT_TUPLE_ANY),
1586    ConvertSingleElementToSequence<ValueTuple, ConvertAny>},
1587   {CombineTypesForTypeCast(mindspore::ops::DT_ANY, mindspore::ops::DT_LIST_ANY),
1588    ConvertSingleElementToSequence<ValueList, ConvertAny>},
1589 
1590   // TypeCast2: convert sequence to sequence, such as py::tuple to ValueList
1591   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_INT, mindspore::ops::DT_LIST_INT),
1592    ConvertSequence<py::tuple, ValueList, ConvertInt>},
1593   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_FLOAT, mindspore::ops::DT_LIST_FLOAT),
1594    ConvertSequence<py::tuple, ValueList, ConvertFloat>},
1595   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_BOOL, mindspore::ops::DT_LIST_BOOL),
1596    ConvertSequence<py::tuple, ValueList, ConvertBool>},
1597   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_ANY, mindspore::ops::DT_LIST_ANY),
1598    ConvertSequence<py::tuple, ValueList, ConvertAny>},
1599   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_TENSOR, mindspore::ops::DT_LIST_TENSOR),
1600    ConvertSequence<py::tuple, ValueList, ConvertTensor>},
1601 
1602   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_INT, mindspore::ops::DT_TUPLE_INT),
1603    ConvertSequence<py::list, ValueTuple, ConvertInt>},
1604   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_FLOAT, mindspore::ops::DT_TUPLE_FLOAT),
1605    ConvertSequence<py::list, ValueTuple, ConvertFloat>},
1606   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_BOOL, mindspore::ops::DT_TUPLE_BOOL),
1607    ConvertSequence<py::list, ValueTuple, ConvertBool>},
1608   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_ANY, mindspore::ops::DT_TUPLE_ANY),
1609    ConvertSequence<py::list, ValueTuple, ConvertAny>},
1610   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_TENSOR, mindspore::ops::DT_TUPLE_TENSOR),
1611    ConvertSequence<py::list, ValueTuple, ConvertAny>},
1612 
1613   // TypeCast3: convert single element to Tensor
1614   {CombineTypesForTypeCast(mindspore::ops::DT_INT, mindspore::ops::DT_TENSOR),
1615    ConvertSingleElementToTensor<py::int_, pyint>},
1616   {CombineTypesForTypeCast(mindspore::ops::DT_FLOAT, mindspore::ops::DT_TENSOR),
1617    ConvertSingleElementToTensor<py::float_, pyfloat>},
1618   {CombineTypesForTypeCast(mindspore::ops::DT_BOOL, mindspore::ops::DT_TENSOR),
1619    ConvertSingleElementToTensor<py::bool_, bool>},
1620   {CombineTypesForTypeCast(mindspore::ops::DT_NUMBER, mindspore::ops::DT_TENSOR), ConvertNumberToTensor},
1621 
1622   // TypeCast4: convert between sequence and tensor
1623   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_INT, mindspore::ops::DT_TENSOR),
1624    ConvertSequenceToTensor<py::tuple, py::int_, pyint>},
1625   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_FLOAT, mindspore::ops::DT_TENSOR),
1626    ConvertSequenceToTensor<py::tuple, py::float_, pyfloat>},
1627   {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_BOOL, mindspore::ops::DT_TENSOR),
1628    ConvertSequenceBoolToTensor<py::tuple>},
1629   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_INT, mindspore::ops::DT_TENSOR),
1630    ConvertSequenceToTensor<py::list, py::int_, pyint>},
1631   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_FLOAT, mindspore::ops::DT_TENSOR),
1632    ConvertSequenceToTensor<py::list, py::float_, pyfloat>},
1633   {CombineTypesForTypeCast(mindspore::ops::DT_LIST_BOOL, mindspore::ops::DT_TENSOR),
1634    ConvertSequenceBoolToTensor<py::list>},
1635 
1636   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_INT),
1637    ConvertTensorToSequenceInt<ValueTuple>},
1638   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_FLOAT),
1639    ConvertTensorToSequenceFloat<ValueTuple>},
1640   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_BOOL),
1641    ConvertTensorToSequence<ValueTuple, bool, BoolImm, kNumberTypeBool>},
1642   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_BOOL),
1643    ConvertTensorToSequenceAny<ValueTuple>},
1644 
1645   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_INT),
1646    ConvertTensorToSequenceInt<ValueList>},
1647   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_FLOAT),
1648    ConvertTensorToSequenceFloat<ValueList>},
1649   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_BOOL),
1650    ConvertTensorToSequence<ValueList, bool, BoolImm, kNumberTypeBool>},
1651   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_BOOL),
1652    ConvertTensorToSequenceAny<ValueList>},
1653 
1654   // TypeCast5: convert tensor to single element
1655   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_INT), ConvertTensorToInt},
1656   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_FLOAT), ConvertTensorToFloat},
1657   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_BOOL), ConvertTensorToBool},
1658   {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_NUMBER), ConvertTensorToNumber},
1659 };
1660 
GetConverterByType(int32_t dtype)1661 OpDefConvertFunc GetConverterByType(int32_t dtype) {
1662   auto it = kConverters.find(dtype);
1663   if (it == kConverters.end()) {
1664     if ((dtype >> kTypeShiftBits) == 0) {
1665       MS_LOG(EXCEPTION) << "Can not find converter for dtype[" << ops::EnumToString(static_cast<ops::OP_DTYPE>(dtype))
1666                         << "].";
1667     } else {
1668       MS_LOG(EXCEPTION) << "Can not find converter for src_type["
1669                         << ops::EnumToString(static_cast<ops::OP_DTYPE>(dtype >> kTypeShiftBits)) << "] and dst_type["
1670                         << ops::EnumToString(static_cast<ops::OP_DTYPE>(dtype & kDstMask)) << "].";
1671     }
1672   }
1673 
1674   return it->second;
1675 }
1676 }  // namespace parse
1677 }  // namespace mindspore
1678