1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/parse/data_converter.h"
20 #include <utility>
21 #include <unordered_map>
22 #include <algorithm>
23 #include "mindspore/core/ops/structure_ops.h"
24 #include "pipeline/jit/ps/parse/resolve.h"
25 #include "pipeline/jit/ps/pipeline.h"
26 #include "frontend/operator/ops.h"
27 #include "frontend/operator/composite/composite.h"
28 #include "ir/func_graph_cloner.h"
29 #include "ir/cell.h"
30 #include "ir/dtype.h"
31 #include "utils/symbolic.h"
32 #include "utils/ms_context.h"
33 #include "include/common/fallback.h"
34 #include "include/common/utils/utils.h"
35 #include "include/common/utils/convert_utils_py.h"
36 #include "include/common/utils/primfunc_utils.h"
37 #include "frontend/operator/composite/multitype_funcgraph.h"
38
39 namespace mindspore {
40 namespace parse {
41 namespace {
42 struct PyDataToValueRegister {
PyDataToValueRegistermindspore::parse::__anon6d4124920111::PyDataToValueRegister43 PyDataToValueRegister() noexcept {
44 python_adapter::PyAdapterCallback::SetPyDataToValueHandler(data_converter::PyDataToValue);
45 }
46 } callback_register;
47 } // namespace
48 using Tensor = mindspore::tensor::Tensor;
49 using TensorPtr = mindspore::tensor::TensorPtr;
50 using BaseTensor = mindspore::tensor::BaseTensor;
51 using BaseTensorPtr = mindspore::tensor::BaseTensorPtr;
52 using MetaTensor = mindspore::tensor::MetaTensor;
53 using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
54 using CSRTensor = mindspore::tensor::CSRTensor;
55 using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
56 using COOTensor = mindspore::tensor::COOTensor;
57 using COOTensorPtr = mindspore::tensor::COOTensorPtr;
58 using MapTensor = mindspore::tensor::MapTensor;
59 using MapTensorPtr = mindspore::tensor::MapTensorPtr;
60
61 using InstanceCheckFunc = std::function<bool(const py::object &)>;
62 using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &, const ValuePtrList &)>;
63 static constexpr int kBit8 = 8;
64 static constexpr int kBit16 = 16;
65 static constexpr int kBit32 = 32;
66 static constexpr int kBit64 = 64;
67
68 class DataConvertFunc {
69 public:
DataConvertFunc(InstanceConvertFunc convert_func)70 explicit DataConvertFunc(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
71
72 virtual ~DataConvertFunc() = default;
73
74 virtual bool Matched(const py::object &obj) = 0;
75
ConvertPyObject(const py::object & obj,bool use_sig,const TypePtr & dtype,const ValuePtrList & args_value_list={})76 ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype,
77 const ValuePtrList &args_value_list = {}) {
78 if (convert_func_ == nullptr) {
79 MS_LOG(INTERNAL_EXCEPTION) << "convert func is null";
80 }
81 return convert_func_(obj, use_sig, dtype, args_value_list);
82 }
83
84 private:
85 InstanceConvertFunc convert_func_ = nullptr;
86 };
87
88 using DataConvertFuncPtr = std::shared_ptr<DataConvertFunc>;
89
90 using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
91 using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
92 using ArgsObjTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
93 using ArgsObjArgsValueConvertFunc = std::function<ValuePtr(const py::object &, const ValuePtrList &)>;
94
95 // Convert the data according to instance type
96 template <typename T>
97 class ByTypeDataConvertFunc : public DataConvertFunc {
98 public:
ByTypeDataConvertFunc(const InstanceConvertFunc & convert_func)99 explicit ByTypeDataConvertFunc(const InstanceConvertFunc &convert_func)
100 : DataConvertFunc(convert_func), check_func_(py::isinstance<T>) {}
101
ByTypeDataConvertFunc(const ValuePtr & converted_type)102 explicit ByTypeDataConvertFunc(const ValuePtr &converted_type)
103 : DataConvertFunc([converted_type](const py::object &, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
104 return converted_type;
105 }),
106 check_func_(py::isinstance<T>) {}
107
ByTypeDataConvertFunc(const ArgsObjConvertFunc & convert_func)108 explicit ByTypeDataConvertFunc(const ArgsObjConvertFunc &convert_func)
109 : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
110 return convert_func(obj);
111 }),
112 check_func_(py::isinstance<T>) {}
113
ByTypeDataConvertFunc(const ArgsObjSigConvertFunc & convert_func)114 explicit ByTypeDataConvertFunc(const ArgsObjSigConvertFunc &convert_func)
115 : DataConvertFunc([convert_func](const py::object &obj, bool use_sig, const TypePtr &,
116 const ValuePtrList &) -> ValuePtr { return convert_func(obj, use_sig); }),
117 check_func_(py::isinstance<T>) {}
118
ByTypeDataConvertFunc(const ArgsObjTypeConvertFunc & convert_func)119 explicit ByTypeDataConvertFunc(const ArgsObjTypeConvertFunc &convert_func)
120 : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &dtype,
121 const ValuePtrList &) -> ValuePtr { return convert_func(obj, dtype); }),
122 check_func_(py::isinstance<T>) {}
123
ByTypeDataConvertFunc(const ArgsObjArgsValueConvertFunc & convert_func)124 explicit ByTypeDataConvertFunc(const ArgsObjArgsValueConvertFunc &convert_func)
125 : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &,
126 const ValuePtrList &args_value_list) -> ValuePtr {
127 return convert_func(obj, args_value_list);
128 }),
129 check_func_(py::isinstance<T>) {}
130
131 ~ByTypeDataConvertFunc() override = default;
132
Matched(const py::object & obj)133 bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }
134
135 private:
136 InstanceCheckFunc check_func_ = nullptr;
137 };
138
139 // Convert the data according to object attribute.
140 class ByAttrDataConvertFunc : public DataConvertFunc {
141 public:
ByAttrDataConvertFunc(const ArgsObjConvertFunc & convert_func,const std::string & attr_name,const std::string & cell_list_from_top="")142 ByAttrDataConvertFunc(const ArgsObjConvertFunc &convert_func, const std::string &attr_name,
143 const std::string &cell_list_from_top = "")
144 : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
145 return convert_func(obj);
146 }),
147 attr_name_(attr_name),
148 cell_list_from_top_(cell_list_from_top) {}
149
ByAttrDataConvertFunc(const ArgsObjSigConvertFunc & convert_func,const std::string & attr_name,const std::string & cell_list_from_top="")150 ByAttrDataConvertFunc(const ArgsObjSigConvertFunc &convert_func, const std::string &attr_name,
151 const std::string &cell_list_from_top = "")
152 : DataConvertFunc([convert_func](const py::object &obj, bool use_sig, const TypePtr &,
153 const ValuePtrList &) -> ValuePtr { return convert_func(obj, use_sig); }),
154 attr_name_(attr_name),
155 cell_list_from_top_(cell_list_from_top) {}
156
157 ~ByAttrDataConvertFunc() override = default;
158
Matched(const py::object & obj)159 bool Matched(const py::object &obj) override {
160 return py::hasattr(obj, attr_name_.c_str()) && !py::hasattr(obj, cell_list_from_top_.c_str());
161 }
162
163 private:
164 std::string attr_name_;
165 std::string cell_list_from_top_;
166 };
167
168 // Convert the data according to match function.
169 class ByFuncDataConvertFunc : public DataConvertFunc {
170 public:
ByFuncDataConvertFunc(const InstanceCheckFunc & match_func,const ArgsObjConvertFunc & convert_func)171 ByFuncDataConvertFunc(const InstanceCheckFunc &match_func, const ArgsObjConvertFunc &convert_func)
172 : DataConvertFunc([convert_func](const py::object &obj, bool, const TypePtr &, const ValuePtrList &) -> ValuePtr {
173 return convert_func(obj);
174 }),
175 match_func_(match_func) {}
176
ByFuncDataConvertFunc(const InstanceCheckFunc & match_func,const ArgsObjSigConvertFunc & convert_func)177 ByFuncDataConvertFunc(const InstanceCheckFunc &match_func, const ArgsObjSigConvertFunc &convert_func)
178 : DataConvertFunc([convert_func](const py::object &obj, bool use_sig, const TypePtr &,
179 const ValuePtrList &) -> ValuePtr { return convert_func(obj, use_sig); }),
180 match_func_(match_func) {}
181
182 ~ByFuncDataConvertFunc() override = default;
183
Matched(const py::object & obj)184 bool Matched(const py::object &obj) override { return match_func_ != nullptr ? match_func_(obj) : false; }
185
186 private:
187 InstanceCheckFunc match_func_ = nullptr;
188 };
189
ConvertToBpropCut(const py::object & obj)190 FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
191 std::vector<std::string> results = data_converter::GetObjKey(obj);
192 std::string obj_key = results[0];
193 py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
194
195 auto bprop_graph = std::make_shared<FuncGraph>();
196 std::vector<AnfNodePtr> outputs;
197
198 auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
199 fake_bprop->AddBackwardHookFn(0, bprop_func);
200 (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
201 outputs.push_back(NewValueNode(fake_bprop));
202
203 py::object code_obj = py::getattr(bprop_func, "__code__");
204 // Three parameters self, out and dout need to be excluded
205 constexpr auto kBpropExcludeParamNum = 3;
206 size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - kBpropExcludeParamNum;
207 for (size_t i = 0; i < inputs_num; ++i) {
208 auto param = bprop_graph->add_parameter();
209 outputs.push_back(param);
210 }
211 auto p1 = bprop_graph->add_parameter();
212 auto p2 = bprop_graph->add_parameter();
213 outputs.push_back(p1);
214 outputs.push_back(p2);
215
216 bprop_graph->set_output(bprop_graph->NewCNode(std::move(outputs)));
217 data_converter::SetObjGraphValue(obj_key, bprop_graph);
218 return bprop_graph;
219 }
220
221 namespace {
ConvertTuple(const py::object & obj,bool use_signature)222 ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
223 MS_LOG(DEBUG) << "Converting python tuple";
224 auto tuple = obj.cast<py::tuple>();
225 std::vector<ValuePtr> value_list;
226 for (size_t it = 0; it < tuple.size(); ++it) {
227 ValuePtr out = nullptr;
228 bool success = ConvertData(tuple[it], &out, use_signature);
229 if (!success) {
230 return nullptr;
231 }
232 value_list.push_back(out);
233 }
234 auto res = std::make_shared<ValueTuple>(value_list);
235 return res;
236 }
237
IsNamedTuple(const py::object & obj)238 bool IsNamedTuple(const py::object &obj) { return py::hasattr(obj, "_fields") && py::isinstance<py::tuple>(obj); }
239
ConvertNamedTuple(const py::object & obj,bool use_signature)240 ValuePtr ConvertNamedTuple(const py::object &obj, bool use_signature) {
241 MS_LOG(DEBUG) << "Converting python NamedTuple";
242 if (!py::hasattr(obj, "_asdict")) {
243 return nullptr;
244 }
245 auto asdict_fn = obj.attr("_asdict");
246 auto asdict_obj = asdict_fn();
247 auto dict_values = asdict_obj.cast<py::dict>();
248 std::vector<ValuePtr> keys;
249 std::vector<ValuePtr> values;
250 for (auto item : dict_values) {
251 ValuePtr key = nullptr;
252 ValuePtr value = nullptr;
253 bool success = ConvertData(py::cast<py::object>(item.first), &key, use_signature) &&
254 ConvertData(py::cast<py::object>(item.second), &value, use_signature);
255 if (!success) {
256 return nullptr;
257 }
258 MS_LOG(DEBUG) << key->ToString() << ", " << value->ToString();
259 keys.push_back(key);
260 values.push_back(value);
261 }
262 auto obj_name = obj.attr("__class__").attr("__name__");
263 std::string sub_class_name = py::str(obj_name).cast<std::string>();
264 return std::make_shared<ValueNamedTuple>(sub_class_name, keys, values);
265 }
266
ConvertStubTuple(const py::object & obj,bool use_signature)267 ValuePtr ConvertStubTuple(const py::object &obj, bool use_signature) {
268 MS_LOG(DEBUG) << "Converting python tuple";
269 auto tuple = obj.cast<py::tuple>();
270 std::vector<ValuePtr> value_list;
271 for (size_t it = 0; it < tuple.size(); ++it) {
272 ValuePtr out = nullptr;
273 bool success = ConvertStubData(tuple[it], &out, use_signature);
274 if (!success) {
275 return nullptr;
276 }
277 value_list.push_back(out);
278 }
279 return std::make_shared<ValueTuple>(value_list);
280 }
281
ConvertList(const py::object & obj,bool use_signature)282 ValuePtr ConvertList(const py::object &obj, bool use_signature) {
283 MS_LOG(DEBUG) << "Converting python list";
284 PyRecursionScope scope(obj);
285
286 auto list = obj.cast<py::list>();
287 std::vector<ValuePtr> value_list;
288 for (size_t it = 0; it < list.size(); ++it) {
289 ValuePtr out = nullptr;
290 bool success = ConvertData(list[it], &out, use_signature);
291 if (!success) {
292 return nullptr;
293 }
294 value_list.push_back(out);
295 }
296 auto res = std::make_shared<ValueList>(value_list);
297 return res;
298 }
299
ConvertStubList(const py::object & obj,bool use_signature)300 ValuePtr ConvertStubList(const py::object &obj, bool use_signature) {
301 MS_LOG(DEBUG) << "Converting python list";
302 PyRecursionScope scope(obj);
303
304 auto list = obj.cast<py::list>();
305 std::vector<ValuePtr> value_list;
306 for (size_t it = 0; it < list.size(); ++it) {
307 ValuePtr out = nullptr;
308 bool success = ConvertStubData(list[it], &out, use_signature);
309 if (!success) {
310 return nullptr;
311 }
312 value_list.push_back(out);
313 }
314 return std::make_shared<ValueList>(value_list);
315 }
316
ConvertCellList(const py::object & obj,bool use_signature)317 ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
318 MS_LOG(DEBUG) << "Converting cell list";
319 PyRecursionScope scope(obj);
320
321 py::sequence list = obj;
322 std::vector<ValuePtr> value_list;
323
324 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
325 bool is_celllist = py::cast<bool>(python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CELL_LIST, obj));
326 for (const auto &element : list) {
327 // An element will directly convert to InterpretedObject if:
328 // 1. The container is not a cell list object.
329 // 2. The element should be single cell (cell with no __cell_as_list__ attr).
330 bool to_interpret = !is_celllist && py::isinstance<Cell>(element);
331 if (to_interpret) {
332 value_list.push_back(std::make_shared<parse::InterpretedObject>(element));
333 continue;
334 }
335 ValuePtr out = nullptr;
336 bool success = ConvertData(element, &out, use_signature);
337 if (!success) {
338 return nullptr;
339 }
340 value_list.push_back(out);
341 }
342 return std::make_shared<ValueTuple>(value_list);
343 }
344
ConvertDict(const py::object & obj,bool use_signature)345 ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
346 MS_LOG(DEBUG) << "Converting python dict";
347 PyRecursionScope scope(obj);
348
349 auto dict_values = obj.cast<py::dict>();
350 std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
351 for (auto item : dict_values) {
352 ValuePtr key = nullptr;
353 ValuePtr value = nullptr;
354 bool success = ConvertData(py::cast<py::object>(item.first), &key, use_signature) &&
355 ConvertData(py::cast<py::object>(item.second), &value, use_signature);
356 if (!success) {
357 return nullptr;
358 }
359 (void)key_values.emplace_back(key, value);
360 }
361 auto res = std::make_shared<ValueDictionary>(key_values);
362 res->set_user_data<py::object>("origin_object", std::make_shared<py::object>(obj));
363 return res;
364 }
365
ConvertModuleNameSpace(const py::object & obj)366 ValuePtr ConvertModuleNameSpace(const py::object &obj) {
367 MS_LOG(DEBUG) << "Converting python module";
368 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
369 py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
370 auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, module_namespace, obj);
371 MS_LOG(DEBUG) << "name_space: " << converted->ToString();
372 return converted;
373 }
374
ConvertMsClass(const py::object & obj)375 ValuePtr ConvertMsClass(const py::object &obj) {
376 MS_LOG(DEBUG) << "Converting ms class";
377 // Convert class instance decorated with jit_class.
378 if (py::hasattr(obj, PYTHON_PARSE_METHOD)) {
379 MS_LOG(DEBUG) << "Convert obj to func graph.";
380 FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
381 if (func_graph == nullptr) {
382 MS_LOG(ERROR) << "Parse resolve function error.";
383 return nullptr;
384 }
385 PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
386 func_graph->set_python_obj(python_obj);
387 return func_graph;
388 }
389 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
390 py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj);
391 auto cls_name = py::cast<std::string>(name);
392 return std::make_shared<MsClassObject>(obj, cls_name);
393 }
394
ConvertPrimitiveClassType(const py::object & obj)395 ValuePtr ConvertPrimitiveClassType(const py::object &obj) {
396 // need check the primitive is class type or instance
397 auto obj_type = data_converter::GetObjType(obj);
398 if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
399 auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
400 // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
401 return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
402 }
403 return nullptr;
404 }
405
ConvertPrimitive(const py::object & obj,bool use_signature=false)406 ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
407 MS_LOG(DEBUG) << "Converting primitive object " << use_signature;
408
409 auto class_type = ConvertPrimitiveClassType(obj);
410 if (class_type != nullptr) {
411 return class_type;
412 }
413 py::object adapter_obj = obj;
414 if (py::hasattr(obj, "__setattr_flag__")) {
415 if (py::hasattr(obj, "_clone")) {
416 auto clone_fn = obj.attr("_clone");
417 adapter_obj = clone_fn();
418 }
419 }
420 auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
421 MS_EXCEPTION_IF_NULL(prim_adapter);
422 auto primitive = prim_adapter->attached_primitive();
423 if (primitive == nullptr) {
424 primitive = std::make_shared<PrimitivePy>(adapter_obj);
425 prim_adapter->set_attached_primitive(primitive);
426 }
427
428 if (use_signature) {
429 return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
430 }
431 return primitive;
432 }
433
ConvertPrimitiveFunction(const py::object & obj)434 ValuePtr ConvertPrimitiveFunction(const py::object &obj) {
435 MS_LOG(DEBUG) << "Converting primitive function";
436 auto class_type = ConvertPrimitiveClassType(obj);
437 if (class_type != nullptr) {
438 return class_type;
439 }
440 auto prim_func_adapter = obj.cast<PrimitiveFunctionAdapterPtr>();
441 MS_EXCEPTION_IF_NULL(prim_func_adapter);
442 auto cpp_primitive_func = prim_func_adapter->attached_primitive_function();
443 if (cpp_primitive_func == nullptr) {
444 auto prim_name = py::getattr(obj, "name").cast<std::string>();
445 return std::make_shared<prim::DoTransPrimitiveFunction>(std::make_shared<Primitive>(prim_name));
446 }
447 return cpp_primitive_func;
448 }
449
ConvertMetaFuncGraph(const py::object & obj,bool use_signature=false)450 ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
451 MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
452 auto meta = obj.cast<MetaFuncGraphPtr>();
453 if (meta == nullptr) {
454 MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
455 return nullptr;
456 }
457 auto multi = meta->cast<prim::MultitypeFuncGraphPtr>();
458 if (multi != nullptr) {
459 multi->set_meta_obj(obj);
460 }
461 if (use_signature) {
462 return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
463 }
464 return meta;
465 }
466
ConvertFuncGraph(const py::object & obj)467 ValuePtr ConvertFuncGraph(const py::object &obj) {
468 MS_LOG(DEBUG) << "Converting FuncGraph object";
469 auto func_graph = obj.cast<FuncGraphPtr>();
470 if (func_graph == nullptr) {
471 MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
472 return nullptr;
473 }
474 func_graph->set_attr("is_load", MakeValue(true));
475 return func_graph;
476 }
477
ConvertSlice(const py::object & obj)478 ValuePtr ConvertSlice(const py::object &obj) {
479 MS_LOG(DEBUG) << "Converting slice object";
480
481 auto convert_func = [obj](const std::string &attr) -> ValuePtr {
482 auto py_attr = py::getattr(obj, attr.c_str());
483 if (py::isinstance<py::none>(py_attr)) {
484 return kNone;
485 }
486 if (py::isinstance<py::int_>(py_attr)) {
487 auto value = py::cast<int64_t>(py_attr);
488 return MakeValue(value);
489 }
490 if (py::isinstance<Tensor>(py_attr)) {
491 return py::cast<TensorPtr>(py_attr);
492 }
493 if (IsStubTensor(py_attr)) {
494 return ConvertStubTensor(py_attr);
495 }
496 MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj)
497 << " should be int or Tensor with Int type but got " << py::str(py_attr);
498 };
499 ValuePtr start = convert_func(kSliceStart);
500 ValuePtr stop = convert_func(kSliceStop);
501 ValuePtr step = convert_func(kSliceStep);
502 return std::make_shared<ValueSlice>(start, stop, step);
503 }
504
ConvertCellObjToFuncGraph(const py::object & obj,const ValuePtrList & args_value_list)505 ValuePtr ConvertCellObjToFuncGraph(const py::object &obj, const ValuePtrList &args_value_list) {
506 FuncGraphPtr func_graph = ConvertToFuncGraph(obj, args_value_list);
507 if (func_graph == nullptr) {
508 MS_LOG(ERROR) << "Parse resolve function error.";
509 return nullptr;
510 }
511 // if the cell object has specified bprop, it has user-defined bprop function parse and record it
512 if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
513 bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
514 FuncGraphPtr bprop_graph =
515 enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, {}, PYTHON_MOD_GET_BPROP_METHOD);
516 if (bprop_graph != nullptr) {
517 (void)func_graph->transforms().emplace(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph));
518 (void)bprop_graph->transforms().emplace("primal", FuncGraphTransform(func_graph));
519 func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
520 func_graph->set_flag(FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP, true);
521 }
522 }
523 if (py::hasattr(obj, STAGE_NAME)) {
524 auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
525 func_graph->set_stage(stage);
526 }
527 if (py::hasattr(obj, SEGMENT_NAME)) {
528 auto segment = py::cast<int>(py::getattr(obj, SEGMENT_NAME));
529 func_graph->set_segment(segment);
530 }
531 auto cell = py::cast<CellPtr>(obj);
532 if (cell != nullptr && cell->HasAttr(kAttrRandomOpSnapShot)) {
533 auto value = cell->GetAttr(kAttrRandomOpSnapShot);
534 MS_EXCEPTION_IF_NULL(value);
535 func_graph->set_attr(kAttrRandomOpSnapShot, value);
536 }
537 return func_graph;
538 }
539
ConvertConstantNumpyNumber(const py::object & obj,ResolveType obj_type)540 ValuePtr ConvertConstantNumpyNumber(const py::object &obj, ResolveType obj_type) {
541 if (obj_type == RESOLVE_TYPE_NUMPY_INT_NUMBER) {
542 MS_LOG(INFO) << "Convert constant numpy int64_t number:" << (std::string)py::str(obj);
543 return MakeValue(py::cast<int64_t>(obj));
544 }
545 if (obj_type == RESOLVE_TYPE_NUMPY_FLOAT_NUMBER) {
546 MS_LOG(INFO) << "Convert constant numpy float number::" << (std::string)py::str(obj);
547 return MakeValue(py::cast<float>(obj));
548 }
549 if (obj_type == RESOLVE_TYPE_NUMPY_BOOL_NUMBER) {
550 MS_LOG(INFO) << "Convert constant numpy bool_ number::" << (std::string)py::str(obj);
551 return MakeValue(py::cast<bool>(obj));
552 }
553
554 MS_LOG(ERROR) << "Convert numpy number type is invalid, obj: " << py::str(obj);
555 return nullptr;
556 }
557
CheckJITForbiddenAPI(const py::object & obj)558 void CheckJITForbiddenAPI(const py::object &obj) {
559 auto module = python_adapter::GetPyModule(PYTHON_MOD_MODULE);
560 py::object res = python_adapter::CallPyModFn(module, PYTHON_MOD_GET_MODULE_AND_NAME_INFO, obj);
561 if (!py::isinstance<py::none>(res)) {
562 auto obj_info = py::cast<py::list>(res);
563 auto obj_module = py::cast<std::string>(obj_info[0]);
564 auto obj_name = py::cast<std::string>(obj_info[1]);
565 auto obj_type = py::cast<std::string>(obj_info[2]);
566 std::ostringstream oss;
567 oss << "Failed to compile in GRAPH_MODE because the " << obj_type << " '" << obj_module << "." << obj_name
568 << "' is not supported in 'construct' or function with @jit decorator. "
569 << "Try to use the " << obj_type << " '" << obj_module << "." << obj_name << "' externally "
570 << "such as initialized in the method '__init__' before assigning"
571 << ".\nFor more details, please refer to "
572 << "https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n";
573 // Check if the API is decoratored by @jit_forbidden_register.
574 bool is_jit_forbidden_register = data_converter::IsJITForbiddenAPI(obj);
575 if (is_jit_forbidden_register) {
576 MS_LOG(EXCEPTION) << oss.str();
577 }
578 // Check if the API's module is in the JIT forbidden module set.
579 bool is_jit_forbidden_module =
580 py::cast<bool>(python_adapter::CallPyModFn(module, PYTHON_MOD_IS_JIT_FORBIDDEN_MODULE, obj_info[0]));
581 if (is_jit_forbidden_module) {
582 MS_LOG(EXCEPTION) << oss.str();
583 }
584 }
585 }
586
ConvertOtherObj(const py::object & obj,bool forbid_reuse=false)587 ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
588 auto obj_type = data_converter::GetObjType(obj);
589 MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
590 if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
591 // Check JIT forbidden API
592 CheckJITForbiddenAPI(obj);
593 MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
594 std::string desc = py::str(obj);
595 // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
596 return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
597 }
598 if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD ||
599 (obj_type == RESOLVE_TYPE_CLASS_INSTANCE && py::hasattr(obj, PYTHON_PARSE_METHOD))) {
600 if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
601 // Check JIT forbidden API
602 CheckJITForbiddenAPI(obj);
603 // Check if the function is from a third-party library.
604 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
605 bool is_third_party_function =
606 python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_FROM_THIRD_PARTY_LIBRARY, obj).cast<bool>();
607 if (is_third_party_function) {
608 MS_LOG(DEBUG) << "Converting the function from third-party library: " << py::str(obj);
609 return std::make_shared<InterpretedObject>(obj);
610 }
611 }
612 MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
613 FuncGraphPtr func_graph = ConvertToFuncGraph(obj, {}, PYTHON_MOD_GET_PARSE_METHOD, forbid_reuse);
614 if (func_graph == nullptr) {
615 MS_LOG(ERROR) << "Parse resolve function error.";
616 return nullptr;
617 }
618 return func_graph;
619 }
620 if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
621 MS_LOG(INTERNAL_EXCEPTION) << "Fail to convert class instance: " << py::str(obj);
622 }
623 // Start RESOLVE_TYPE_INVALID.
624 if (obj_type == RESOLVE_TYPE_NUMPY_INT_NUMBER || obj_type == RESOLVE_TYPE_NUMPY_FLOAT_NUMBER ||
625 obj_type == RESOLVE_TYPE_NUMPY_BOOL_NUMBER) {
626 return ConvertConstantNumpyNumber(obj, obj_type);
627 }
628 auto res = std::make_shared<InterpretedObject>(obj);
629 MS_EXCEPTION_IF_NULL(res);
630 MS_LOG(DEBUG) << "Get interpreted object: " << res->ToString();
631 return res;
632 }
633
634 template <typename T>
ConvertNumberWithType(const T & obj,const TypePtr & dtype)635 ValuePtr ConvertNumberWithType(const T &obj, const TypePtr &dtype) {
636 ValuePtr data = nullptr;
637 auto int_dypte = dyn_cast<Int>(dtype);
638 if (int_dypte != nullptr) {
639 switch (int_dypte->nbits()) {
640 case kBit8:
641 data = std::make_shared<Int8Imm>(obj);
642 break;
643 case kBit16:
644 data = std::make_shared<Int16Imm>(obj);
645 break;
646 case kBit32:
647 data = std::make_shared<Int32Imm>(obj);
648 break;
649 case kBit64:
650 data = std::make_shared<Int64Imm>(obj);
651 break;
652 default:
653 data = std::make_shared<Int64Imm>(obj);
654 }
655 return data;
656 }
657
658 auto uint_dypte = dyn_cast<UInt>(dtype);
659 if (uint_dypte != nullptr) {
660 switch (uint_dypte->nbits()) {
661 case kBit8:
662 data = std::make_shared<UInt8Imm>(obj);
663 break;
664 case kBit16:
665 data = std::make_shared<UInt16Imm>(obj);
666 break;
667 case kBit32:
668 data = std::make_shared<UInt32Imm>(obj);
669 break;
670 case kBit64:
671 data = std::make_shared<UInt64Imm>(obj);
672 break;
673 default:
674 data = std::make_shared<UInt32Imm>(obj);
675 }
676 return data;
677 }
678
679 auto float_dypte = dyn_cast<Float>(dtype);
680 if (float_dypte != nullptr) {
681 switch (float_dypte->nbits()) {
682 case kBit32:
683 data = std::make_shared<FP32Imm>(obj);
684 break;
685 case kBit64:
686 data = std::make_shared<FP64Imm>(obj);
687 break;
688 default:
689 data = std::make_shared<FP32Imm>(obj);
690 }
691 return data;
692 }
693 return nullptr;
694 }
695
ConvertIntegerWithType(const py::object & obj,const TypePtr & dtype=nullptr)696 ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
697 auto obj_int64 = py::cast<int64_t>(obj);
698 // The mutable _Bool class inherits from int, because base class 'bool' is a marked final.
699 if (py::hasattr(obj, "__ms_mutable_bool__")) {
700 bool obj_bool = obj_int64 != 0;
701 return std::make_shared<BoolImm>(obj_bool);
702 }
703 if (dtype == nullptr) {
704 return std::make_shared<Int64Imm>(obj_int64);
705 }
706 return ConvertNumberWithType<int64_t>(obj_int64, dtype);
707 }
708
ConvertFloatWithType(const py::object & obj,const TypePtr & dtype=nullptr)709 ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
710 auto obj_float32 = py::cast<pyfloat>(obj);
711 if (dtype == nullptr) {
712 auto obj_double = py::cast<double>(obj);
713 auto ret = std::make_shared<FP32Imm>(obj_float32);
714 ret->set_prim_value(obj_double);
715 return ret;
716 }
717 return ConvertNumberWithType<pyfloat>(obj_float32, dtype);
718 }
719
ConvertNameSpace(const py::object & obj,bool use_signature)720 ValuePtr ConvertNameSpace(const py::object &obj, bool use_signature) {
721 MS_LOG(DEBUG) << "Converting python NameSpace";
722 auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
723 MS_LOG(DEBUG) << "name_space: " << res->ToString();
724 return res;
725 }
726
727 template <typename T, typename U>
PyCast(const py::object & obj)728 ValuePtr PyCast(const py::object &obj) {
729 return std::make_shared<T>(py::cast<U>(obj));
730 }
731
732 template <typename T>
ObjCast(const py::object & obj)733 ValuePtr ObjCast(const py::object &obj) {
734 return obj.cast<T>();
735 }
736
GetDataConvertFuncs()737 static const std::vector<DataConvertFuncPtr> &GetDataConvertFuncs() {
738 // Convert data by python object type.
739 static const std::vector<DataConvertFuncPtr> data_convert_funcs{
740 // AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
741 std::make_shared<ByFuncDataConvertFunc>(IsStubTensor, ConvertStubTensor),
742 std::make_shared<ByFuncDataConvertFunc>(IsNamedTuple, ConvertNamedTuple),
743 std::make_shared<ByTypeDataConvertFunc<Tensor>>(ObjCast<TensorPtr>),
744 std::make_shared<ByAttrDataConvertFunc>(ConvertMsClass, PYTHON_MS_CLASS),
745 std::make_shared<ByTypeDataConvertFunc<BaseTensor>>(ObjCast<BaseTensorPtr>),
746 std::make_shared<ByTypeDataConvertFunc<py::tuple>>(ConvertTuple),
747 std::make_shared<ByTypeDataConvertFunc<py::list>>(ConvertList),
748 std::make_shared<ByTypeDataConvertFunc<py::bool_>>(PyCast<BoolImm, bool>),
749 std::make_shared<ByTypeDataConvertFunc<py::int_>>(ConvertIntegerWithType),
750 std::make_shared<ByTypeDataConvertFunc<py::float_>>(ConvertFloatWithType),
751 std::make_shared<ByTypeDataConvertFunc<py::str>>(PyCast<StringImm, string>),
752 std::make_shared<ByTypeDataConvertFunc<py::none>>(kNone),
753 std::make_shared<ByTypeDataConvertFunc<MetaTensor>>(ObjCast<MetaTensorPtr>),
754 std::make_shared<ByTypeDataConvertFunc<CSRTensor>>(ObjCast<CSRTensorPtr>),
755 std::make_shared<ByTypeDataConvertFunc<COOTensor>>(ObjCast<COOTensorPtr>),
756 std::make_shared<ByTypeDataConvertFunc<MapTensor>>(ObjCast<MapTensorPtr>),
757 std::make_shared<ByTypeDataConvertFunc<py::ellipsis>>(kEllipsis),
758 std::make_shared<ByTypeDataConvertFunc<py::module>>(ConvertModuleNameSpace),
759 std::make_shared<ByTypeDataConvertFunc<Type>>(ObjCast<TypePtr>),
760 std::make_shared<ByTypeDataConvertFunc<UMonad>>(ObjCast<UMonadPtr>),
761 std::make_shared<ByTypeDataConvertFunc<IOMonad>>(ObjCast<IOMonadPtr>),
762 std::make_shared<ByAttrDataConvertFunc>(ConvertNameSpace, PYTHON_CLASS_MEMBER_NAMESPACE),
763 std::make_shared<ByTypeDataConvertFunc<py::dict>>(ConvertDict),
764 std::make_shared<ByAttrDataConvertFunc>(ConvertDict, PYTHON_CELL_AS_DICT),
765 std::make_shared<ByTypeDataConvertFunc<py::slice>>(ConvertSlice),
766 std::make_shared<ByAttrDataConvertFunc>(ConvertCellList, PYTHON_CELL_AS_LIST, PYTHON_CELL_LIST_FROM_TOP),
767 std::make_shared<ByTypeDataConvertFunc<Cell>>(ConvertCellObjToFuncGraph),
768 std::make_shared<ByAttrDataConvertFunc>(ConvertPrimitive, PYTHON_PRIMITIVE_FLAG),
769 std::make_shared<ByAttrDataConvertFunc>(ConvertPrimitiveFunction, PYTHON_PRIMITIVE_FUNCTION_FLAG),
770 std::make_shared<ByTypeDataConvertFunc<MetaFuncGraph>>(ConvertMetaFuncGraph),
771 std::make_shared<ByTypeDataConvertFunc<FuncGraph>>(ConvertFuncGraph),
772 };
773 return data_convert_funcs;
774 }
775
GetStubDataConvertFuncs()776 static const std::vector<DataConvertFuncPtr> &GetStubDataConvertFuncs() {
777 // Convert data by python object type.
778 static const std::vector<DataConvertFuncPtr> data_convert_funcs{
779 std::make_shared<ByFuncDataConvertFunc>([](const py::object &obj) -> bool { return IsStubTensor(obj); },
780 PyStubNodeCast),
781 std::make_shared<ByTypeDataConvertFunc<py::tuple>>(ConvertStubTuple),
782 std::make_shared<ByTypeDataConvertFunc<py::list>>(ConvertStubList),
783 };
784 return data_convert_funcs;
785 }
786
RemoveRecomputeScope(const FuncGraphPtr & func_graph)787 void RemoveRecomputeScope(const FuncGraphPtr &func_graph) {
788 MS_EXCEPTION_IF_NULL(func_graph);
789 auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple);
790
791 for (const auto &node : nodes) {
792 MS_EXCEPTION_IF_NULL(node);
793 if (!node->isa<CNode>()) {
794 continue;
795 }
796 const auto &origin_scope_name = node->scope()->name();
797 if (origin_scope_name.compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0) {
798 auto remove_recompute_scope = origin_scope_name.substr(strlen(kAttrRecompute) + 1);
799 node->set_scope(std::make_shared<Scope>(remove_recompute_scope));
800 }
801 }
802 }
803 } // namespace
804
ConvertData(const py::object & obj,ValuePtr * data,bool use_signature,const TypePtr & dtype,bool forbid_reuse)805 bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature, const TypePtr &dtype, bool forbid_reuse) {
806 // Check parameter valid
807 if (data == nullptr) {
808 MS_LOG(ERROR) << "The value pointer should not be null.";
809 return false;
810 }
811 ValuePtr converted = nullptr;
812 bool matched = false;
813 const auto &converters = GetDataConvertFuncs();
814 for (auto &converter : converters) {
815 if (converter->Matched(obj)) {
816 converted = converter->ConvertPyObject(obj, use_signature, dtype);
817 matched = true;
818 break;
819 }
820 }
821 if (!matched) {
822 converted = ConvertOtherObj(obj, forbid_reuse);
823 }
824 *data = converted;
825 return converted != nullptr;
826 }
827
ConvertStubData(const py::object & obj,ValuePtr * data,bool use_signature,const TypePtr & dtype,bool forbid_reuse)828 bool ConvertStubData(const py::object &obj, ValuePtr *data, bool use_signature, const TypePtr &dtype,
829 bool forbid_reuse) {
830 if (data == nullptr) {
831 MS_LOG(ERROR) << "The value pointer should not be null.";
832 return false;
833 }
834 ValuePtr converted = nullptr;
835 const auto &convert_funcs = GetStubDataConvertFuncs();
836 for (auto &convert_func : convert_funcs) {
837 if (convert_func->Matched(obj)) {
838 converted = convert_func->ConvertPyObject(obj, use_signature, dtype);
839 *data = converted;
840 return converted != nullptr;
841 }
842 }
843 return ConvertData(obj, data, use_signature, dtype, forbid_reuse);
844 }
845
// Turn `base_graph` into a lazy-inline reusing graph in place and return it.
// It is tagged with an increasing parse-order counter and its debug name gets a "CR_" prefix.
FuncGraphPtr MakeReusingGraph(const FuncGraphPtr &base_graph) {
  static int order = 0;
  ++order;
  base_graph->set_attr(FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER, MakeValue(order));
  const auto &debug_info = base_graph->debug_info();
  debug_info->set_name("CR_" + debug_info->name());
  MS_LOG(INFO) << "Lazy inline reusing graph: " << base_graph->ToString()
               << ", args: " << base_graph->parameters().size() << ", parse order: " << order;
  return base_graph;
}
854
MakeCellFuncGraph(const py::object & obj,const std::string & obj_id,const FuncGraphPtr & reusing_graph)855 FuncGraphPtr MakeCellFuncGraph(const py::object &obj, const std::string &obj_id, const FuncGraphPtr &reusing_graph) {
856 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
857 // Normalize the name.
858 auto function_name = obj_id;
859 std::replace(function_name.begin(), function_name.end(), '.', '_');
860 std::replace(function_name.begin(), function_name.end(), '<', '_');
861 std::replace(function_name.begin(), function_name.end(), '>', '_');
862 func_graph->debug_info()->set_name(function_name);
863 PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
864 func_graph->set_python_obj(python_obj);
865 func_graph->set_flag(FUNC_GRAPH_FLAG_PROXY_GRAPH, true);
866 std::vector<AnfNodePtr> new_node_inputs;
867 new_node_inputs.push_back(NewValueNode(reusing_graph));
868 for (const auto &origin_param : reusing_graph->parameters()) {
869 auto param = func_graph->add_parameter();
870 param->set_debug_info(origin_param->debug_info());
871 new_node_inputs.push_back(param);
872 }
873 AnfNodePtr out = func_graph->NewCNodeInOrder(new_node_inputs);
874 func_graph->set_output(out);
875 MS_LOG(INFO) << "Lazy inline cell: " << func_graph->ToString() << ", args: " << func_graph->parameters().size();
876 return func_graph;
877 }
878
ProcessLazyInline(const py::object & obj,const ValuePtrList & args_value_list,const std::string & python_mod_get_parse_method,const std::string & obj_id,const std::string & obj_key)879 FuncGraphPtr ProcessLazyInline(const py::object &obj, const ValuePtrList &args_value_list,
880 const std::string &python_mod_get_parse_method, const std::string &obj_id,
881 const std::string &obj_key) {
882 ValuePtr key_value = nullptr;
883 FuncGraphPtr reusing_graph = nullptr;
884 bool is_key_cache = data_converter::GetObjectValue(obj_key, &key_value);
885 if (is_key_cache && key_value != nullptr && key_value->isa<FuncGraph>()) {
886 MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_key;
887 reusing_graph = key_value->cast<FuncGraphPtr>();
888 } else {
889 auto base_graph = ParsePythonCode(obj, python_mod_get_parse_method, args_value_list);
890 if (base_graph == nullptr) {
891 MS_LOG(ERROR) << "Parse resolve function error.";
892 return nullptr;
893 }
894 if (Parser::GetTopFuncGraph() == base_graph) {
895 return base_graph;
896 }
897 PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
898 base_graph->set_python_obj(python_obj);
899 MS_LOG(DEBUG) << "Parse reusing function: " << reusing_graph->ToString();
900 reusing_graph = MakeReusingGraph(base_graph);
901 data_converter::CacheObjectValue(obj_key, reusing_graph);
902 }
903 // Let the original cell graph call the reusable graph.
904 auto func_graph = MakeCellFuncGraph(obj, obj_id, reusing_graph);
905 MS_LOG(DEBUG) << func_graph->ToString() << " calls " << reusing_graph->ToString();
906 return func_graph;
907 }
908
909 // Convert data to graph
ConvertToFuncGraph(const py::object & obj,const ValuePtrList & args_value_list,const std::string & python_mod_get_parse_method,bool forbid_reuse)910 FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const ValuePtrList &args_value_list,
911 const std::string &python_mod_get_parse_method, bool forbid_reuse) {
912 std::vector<std::string> results = data_converter::GetObjKey(obj);
913 std::string obj_id = results[0] + python_mod_get_parse_method;
914 std::string obj_key = results[1];
915 FuncGraphPtr func_graph = nullptr;
916 ValuePtr value = nullptr;
917 bool is_debug = MsContext::GetInstance()->get_param<int>(MS_CTX_DEBUG_LEVEL) == kLevelDebug;
918 bool is_cache = data_converter::GetObjectValue(obj_id, &value);
919 if (!is_debug && is_cache && value != nullptr && value->isa<FuncGraph>()) {
920 func_graph = value->cast<FuncGraphPtr>();
921 if (!func_graph->dropped()) {
922 bool has_forbid_reuse_attr = py::hasattr(obj, PYTHON_FUNCTION_FORBID_REUSE);
923 if (forbid_reuse || has_forbid_reuse_attr) {
924 return BasicClone(func_graph);
925 }
926 return func_graph;
927 }
928 }
929 if (obj_key.find("lazy_inline") != obj_key.npos) {
930 func_graph = ProcessLazyInline(obj, args_value_list, python_mod_get_parse_method, results[0], obj_key);
931 if (func_graph == nullptr) {
932 return nullptr;
933 }
934 } else {
935 func_graph = ParsePythonCode(obj, python_mod_get_parse_method, args_value_list);
936 if (func_graph == nullptr) {
937 MS_LOG(ERROR) << "Parse resolve function error.";
938 return nullptr;
939 }
940 }
941
942 data_converter::CacheObjectValue(obj_id, func_graph);
943 if (!obj_key.empty() && python_mod_get_parse_method == PYTHON_MOD_GET_PARSE_METHOD) {
944 data_converter::SetObjGraphValue(obj_key, func_graph);
945 }
946
947 PyObjectWrapperPtr python_obj = std::make_shared<PyObjectWrapper>(obj, "graph python obj");
948 func_graph->set_python_obj(python_obj);
949
950 if (forbid_reuse) {
951 // The function may be set recomputed in parse.
952 if (!data_converter::IsCellInstance(obj)) {
953 RemoveRecomputeScope(func_graph);
954 }
955 // Return the clone graph because the graph may be set recomputed later.
956 return BasicClone(func_graph);
957 }
958
959 return func_graph;
960 }
961
GetArgDefaultValue(const std::string & prim_name,const std::string & arg_name)962 ValuePtr GetArgDefaultValue(const std::string &prim_name, const std::string &arg_name) {
963 py::module mod = py::module::import(PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE);
964 if (!py::hasattr(mod, PYTHON_MOD_PRIMITIVE_OP_DEFAULT_VALUE_DICT)) {
965 MS_LOG(INTERNAL_EXCEPTION) << "Can not found " << PYTHON_MOD_PRIMITIVE_OP_DEFAULT_VALUE_DICT << "in "
966 << PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE << ".";
967 }
968 py::dict op_default_dict = mod.attr(PYTHON_MOD_PRIMITIVE_OP_DEFAULT_VALUE_DICT);
969 if (!op_default_dict.contains(py::str(prim_name))) {
970 return nullptr;
971 }
972 py::dict prim_default_dict = op_default_dict[py::str(prim_name)];
973 if (!prim_default_dict.contains(py::str(arg_name))) {
974 return nullptr;
975 }
976 auto default_value = prim_default_dict[py::str(arg_name)];
977 ValuePtr converted_ret = nullptr;
978 bool converted = ConvertData(default_value, &converted_ret);
979 if (!converted) {
980 const std::string &default_name = py::str(default_value);
981 MS_EXCEPTION(ValueError) << "For Operator[" << prim_name << "], '" << default_name
982 << "' is not supported as the default value for '" << arg_name << "'.";
983 }
984 return converted_ret;
985 }
986
987 namespace data_converter {
988 static mindspore::HashMap<std::string, ValuePtr> object_map_;
989
990 static mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
991
SetObjGraphValue(const std::string & obj_key,const FuncGraphPtr & data)992 void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
993 object_graphs_map_[obj_key].push_back(data);
994 MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
995 }
996
GetObjGraphs()997 const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
998 MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
999 return object_graphs_map_;
1000 }
1001
CacheObjectValue(const std::string & obj_key,const ValuePtr & data)1002 void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
1003
GetObjectValue(const std::string & obj_key,ValuePtr * const data)1004 bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
1005 if (object_map_.count(obj_key) != 0) {
1006 *data = object_map_[obj_key];
1007 return true;
1008 }
1009 return false;
1010 }
1011
GetObjKey(const py::object & obj)1012 std::vector<std::string> GetObjKey(const py::object &obj) {
1013 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1014 py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
1015 if (obj_tuple.size() != 2) {
1016 MS_LOG(INTERNAL_EXCEPTION) << "The function of \'get_obj_key()\' must return 2 elements";
1017 }
1018 return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
1019 }
1020
1021 // Get obj detail type
GetObjType(const py::object & obj)1022 ResolveType GetObjType(const py::object &obj) {
1023 try {
1024 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1025 auto obj_type = ResolveType(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
1026 return obj_type;
1027 } catch (const py::error_already_set &ex) {
1028 MS_LOG(ERROR) << "Meet a exception from Python when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
1029 std::rethrow_exception(std::current_exception());
1030 } catch (const py::type_error &ex) {
1031 MS_LOG(ERROR) << "Meet a exception when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
1032 std::rethrow_exception(std::current_exception());
1033 }
1034 }
1035
1036 // Get class instance detail type.
GetClassInstanceType(const py::object & obj)1037 ClassInstanceType GetClassInstanceType(const py::object &obj) {
1038 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1039 auto class_type =
1040 ClassInstanceType(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
1041 return class_type;
1042 }
1043
1044 // Check if the object is Cell instance.
IsCellInstance(const py::object & obj)1045 bool IsCellInstance(const py::object &obj) {
1046 auto class_type = GetClassInstanceType(obj);
1047 return class_type == CLASS_INSTANCE_TYPE_CELL;
1048 }
1049
1050 // Check if the object is Numpy Array instance.
IsNumpyArrayInstance(const py::object & obj)1051 bool IsNumpyArrayInstance(const py::object &obj) {
1052 auto class_type = GetClassInstanceType(obj);
1053 return class_type == CLASS_INSTANCE_TYPE_NUMPY_ARRAY;
1054 }
1055
1056 // Check if the object is MsClass instance.
IsMsClassInstance(const py::object & obj)1057 bool IsMsClassInstance(const py::object &obj) { return py::hasattr(obj, PYTHON_MS_CLASS); }
1058
1059 // Check if the object is jit forbidden api.
IsJITForbiddenAPI(const py::object & obj)1060 bool IsJITForbiddenAPI(const py::object &obj) { return py::hasattr(obj, PYTHON_JIT_FORBIDDEN); }
1061
1062 // Check if the object is class type.
IsClassType(const py::object & obj)1063 bool IsClassType(const py::object &obj) {
1064 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1065 return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast<bool>();
1066 }
1067
1068 // Create the python class instance.
CreatePythonObject(const py::object & type,const py::tuple & args_kwargs)1069 py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
1070 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1071 // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
1072 return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
1073 : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
1074 }
1075
1076 // Call the python script string.
CallPythonScript(const py::object & script,const py::tuple & args_kwargs)1077 py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
1078 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1079 // `args_kwargs` is a tuple(dict(global), dict(local)).
1080 return python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
1081 }
1082
1083 // Get the ids of python script string.
GetPythonScriptIdAttrs(const py::object & script)1084 py::set GetPythonScriptIdAttrs(const py::object &script) {
1085 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1086 return python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_SCRIPT_ID_ATTRS, script);
1087 }
1088
PyDataToValue(const py::object & obj)1089 ValuePtr PyDataToValue(const py::object &obj) {
1090 py::object to_convert = obj;
1091 ValuePtr value = nullptr;
1092 (void)ConvertData(to_convert, &value);
1093 return value;
1094 }
1095
PyDataToStubNode(const py::object & obj)1096 ValuePtr PyDataToStubNode(const py::object &obj) {
1097 py::object to_convert = obj;
1098 ValuePtr value = nullptr;
1099 (void)ConvertStubData(to_convert, &value);
1100 return value;
1101 }
1102
ClearObjectCache()1103 void ClearObjectCache() {
1104 object_map_.clear();
1105 object_graphs_map_.clear();
1106 }
1107 } // namespace data_converter
1108
ConvertData(const py::object & obj)1109 ValuePtr DataConverter::ConvertData(const py::object &obj) {
1110 const auto &convert_funcs = GetDataConvertFuncs();
1111 for (auto &convert_func : convert_funcs) {
1112 if (convert_func->Matched(obj)) {
1113 return convert_func->ConvertPyObject(obj, use_signature_, dtype_, args_value_list_);
1114 }
1115 }
1116 return ConvertOtherObj(obj, forbid_reuse_);
1117 }
1118
ConvertPythonFloatToScalarValue(double value)1119 inline ValuePtr ConvertPythonFloatToScalarValue(double value) {
1120 auto ret = std::make_shared<FP32Imm>(static_cast<float>(value));
1121 ret->set_prim_value(value);
1122 return ret;
1123 }
1124
ConvertBool(const py::object & obj)1125 ValuePtr ConvertBool(const py::object &obj) {
1126 if (!py::isinstance<py::bool_>(obj)) {
1127 return nullptr;
1128 }
1129 return PyCast<BoolImm, bool>(obj);
1130 }
1131
ConvertInt(const py::object & obj)1132 ValuePtr ConvertInt(const py::object &obj) {
1133 // bool is also an instance of py::int_
1134 if (py::isinstance<py::bool_>(obj) || !py::isinstance<py::int_>(obj)) {
1135 return nullptr;
1136 }
1137 return ConvertIntegerWithType(obj);
1138 }
1139
ConvertFloat(const py::object & obj)1140 ValuePtr ConvertFloat(const py::object &obj) {
1141 if (!py::isinstance<py::float_>(obj)) {
1142 return nullptr;
1143 }
1144 return ConvertFloatWithType(obj);
1145 }
1146
ConvertNumber(const py::object & obj)1147 ValuePtr ConvertNumber(const py::object &obj) {
1148 if (py::isinstance<py::bool_>(obj)) {
1149 return PyCast<BoolImm, bool>(obj);
1150 }
1151
1152 if (py::isinstance<py::int_>(obj)) {
1153 return ConvertIntegerWithType(obj);
1154 }
1155
1156 if (py::isinstance<py::float_>(obj)) {
1157 return ConvertFloatWithType(obj);
1158 }
1159
1160 return nullptr;
1161 }
1162
ConvertTensor(const py::object & obj)1163 ValuePtr ConvertTensor(const py::object &obj) {
1164 if (IsStubTensor(obj)) {
1165 return PyStubNodeCast(obj);
1166 }
1167
1168 if (!py::isinstance<mindspore::tensor::Tensor>(obj)) {
1169 return nullptr;
1170 }
1171
1172 return ObjCast<TensorPtr>(obj);
1173 }
1174
ConvertTensorValue(const py::object & obj)1175 TensorPtr ConvertTensorValue(const py::object &obj) {
1176 // The difference between the new ConvertTensorValue function and the existing ConvertTensor is:
1177 // If the obj a StubNode, it must be called the WaitValue to convert to a Tensor.
1178 if (IsStubTensor(obj)) {
1179 auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB);
1180 auto stub = py_stub.cast<stub::StubNodePtr>();
1181 if (stub == nullptr) {
1182 return py::getattr(obj, stub::PY_ATTR_TENSOR).cast<tensor::TensorPtr>();
1183 }
1184 auto value = stub->WaitValue();
1185 auto tensor = value->cast<TensorPtr>();
1186 if (tensor == nullptr) {
1187 // BaseTensor should convert to Tensor for Graph mode
1188 auto base_tensor = value->cast<BaseTensorPtr>();
1189 auto real_tensor = std::make_shared<Tensor>(*base_tensor);
1190 stub->SetValue(real_tensor);
1191 return real_tensor;
1192 }
1193 return tensor;
1194 }
1195 if (!py::isinstance<mindspore::tensor::Tensor>(obj)) {
1196 return nullptr;
1197 }
1198 return obj.cast<TensorPtr>();
1199 }
1200
GetTensorDataPtr(const tensor::TensorPtr & tensor)1201 static inline void *GetTensorDataPtr(const tensor::TensorPtr &tensor) {
1202 MS_EXCEPTION_IF_NULL(tensor);
1203 const auto &device_address = tensor->device_address();
1204 if (device_address != nullptr) {
1205 // Before get data, sync form device address should be performed first
1206 tensor->data_sync();
1207 }
1208 return tensor->data_c();
1209 }
1210
ConvertStr(const py::object & obj)1211 ValuePtr ConvertStr(const py::object &obj) {
1212 if (!py::isinstance<py::str>(obj)) {
1213 return nullptr;
1214 }
1215 return PyCast<StringImm, string>(obj);
1216 }
1217
ConvertAny(const py::object & obj)1218 ValuePtr ConvertAny(const py::object &obj) { return parse::data_converter::PyDataToStubNode(obj); }
1219
ConvertDtype(const py::object & obj)1220 ValuePtr ConvertDtype(const py::object &obj) {
1221 if (!py::isinstance<mindspore::Type>(obj)) {
1222 MS_LOG(EXCEPTION) << "Get arg is not mindspore type " << py::str(obj);
1223 }
1224 return obj.cast<TypePtr>();
1225 }
1226
1227 template <typename TS, typename TD, OpDefConvertFunc func>
ConvertSequence(const py::object & obj)1228 ValuePtr ConvertSequence(const py::object &obj) {
1229 if (!py::isinstance<TS>(obj)) {
1230 return nullptr;
1231 }
1232 auto seq = obj.cast<TS>();
1233 std::vector<ValuePtr> value_list;
1234 for (size_t it = 0; it < seq.size(); ++it) {
1235 auto out = func(seq[it]);
1236 if (out == nullptr) {
1237 return nullptr;
1238 }
1239 value_list.emplace_back(out);
1240 }
1241 return std::make_shared<TD>(value_list);
1242 }
1243
1244 template <typename T, OpDefConvertFunc func>
ConvertSingleElementToSequence(const py::object & obj)1245 ValuePtr ConvertSingleElementToSequence(const py::object &obj) {
1246 auto value = func(obj);
1247 if (value == nullptr) {
1248 return nullptr;
1249 }
1250 std::vector<ValuePtr> value_list{value};
1251 return std::make_shared<T>(std::move(value_list));
1252 }
1253
1254 template <typename T1, typename T2>
ConvertSingleElementToTensor(const py::object & obj)1255 ValuePtr ConvertSingleElementToTensor(const py::object &obj) {
1256 if (!py::isinstance<T1>(obj)) {
1257 return nullptr;
1258 }
1259
1260 auto v = py::cast<T2>(obj);
1261 return std::make_shared<tensor::Tensor>(v);
1262 }
1263
ConvertNumberToTensor(const py::object & obj)1264 ValuePtr ConvertNumberToTensor(const py::object &obj) {
1265 if (py::isinstance<py::bool_>(obj)) {
1266 auto v = py::cast<bool>(obj);
1267 return std::make_shared<tensor::Tensor>(v);
1268 }
1269
1270 if (py::isinstance<py::int_>(obj)) {
1271 auto v = py::cast<int64_t>(obj);
1272 return std::make_shared<tensor::Tensor>(v);
1273 }
1274
1275 if (py::isinstance<py::float_>(obj)) {
1276 auto v = py::cast<pyfloat>(obj);
1277 return std::make_shared<tensor::Tensor>(v);
1278 }
1279
1280 return nullptr;
1281 }
1282
1283 template <typename TS, typename TSE, typename TDE>
ConvertSequenceToTensor(const py::object & obj)1284 ValuePtr ConvertSequenceToTensor(const py::object &obj) {
1285 if (!py::isinstance<TS>(obj)) {
1286 return nullptr;
1287 }
1288
1289 auto seq = obj.cast<TS>();
1290 if (seq.size() == 0) {
1291 return nullptr;
1292 }
1293
1294 std::vector<TDE> value_list;
1295 for (size_t it = 0; it < seq.size(); ++it) {
1296 if (!py::isinstance<TSE>(seq[it])) {
1297 return nullptr;
1298 }
1299
1300 auto value = py::cast<TDE>(seq[it]);
1301 value_list.emplace_back(value);
1302 }
1303
1304 return std::make_shared<tensor::Tensor>(value_list);
1305 }
1306
1307 template <typename TS>
ConvertSequenceBoolToTensor(const py::object & obj)1308 ValuePtr ConvertSequenceBoolToTensor(const py::object &obj) {
1309 if (!py::isinstance<TS>(obj)) {
1310 return nullptr;
1311 }
1312
1313 auto seq = obj.cast<TS>();
1314 if (seq.size() == 0) {
1315 return nullptr;
1316 }
1317
1318 auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeBool, ShapeVector({static_cast<int64_t>(seq.size())}));
1319 auto data = static_cast<bool *>(tensor->data_c());
1320 for (size_t it = 0; it < seq.size(); ++it) {
1321 if (!py::isinstance<py::bool_>(seq[it])) {
1322 return nullptr;
1323 }
1324
1325 auto value = py::cast<bool>(seq[it]);
1326 data[it] = value;
1327 }
1328
1329 return tensor;
1330 }
1331
1332 template <typename TD, typename TDE, typename IMMTYPE, TypeId tid>
ConvertTensorToSequence(const py::object & obj)1333 ValuePtr ConvertTensorToSequence(const py::object &obj) {
1334 auto tensor = ConvertTensorValue(obj);
1335 if (tensor == nullptr) {
1336 MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1337 return nullptr;
1338 }
1339
1340 auto data_type = tensor->data_type();
1341 // Since the dst object type is only, once the src object is validated as Tensor, the other converting errors should
1342 // be thrown. There is no other paths for this case to run successfully.
1343 if (data_type != tid) {
1344 MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << "to Sequence with type "
1345 << TypeIdToString(tid) << ".";
1346 return nullptr;
1347 }
1348
1349 auto shape = tensor->shape();
1350 if (shape.size() > 1) {
1351 MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1352 << shape;
1353 return nullptr;
1354 }
1355
1356 auto data = static_cast<TDE *>(GetTensorDataPtr(tensor));
1357 auto size = tensor->DataSize();
1358 std::vector<ValuePtr> value_list;
1359 for (size_t i = 0; i < size; i++) {
1360 value_list.emplace_back(std::make_shared<IMMTYPE>(data[i]));
1361 }
1362 return std::make_shared<TD>(value_list);
1363 }
1364
1365 template <typename TD>
ConvertTensorToSequenceInt(const py::object & obj)1366 ValuePtr ConvertTensorToSequenceInt(const py::object &obj) {
1367 auto tensor = ConvertTensorValue(obj);
1368 if (tensor == nullptr) {
1369 MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1370 return nullptr;
1371 }
1372
1373 auto shape = tensor->shape();
1374 if (shape.size() > 1) {
1375 MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1376 << shape;
1377 return nullptr;
1378 }
1379
1380 auto data_type = tensor->data_type();
1381 if (data_type != kNumberTypeInt64 && data_type != kNumberTypeInt32) {
1382 MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << "to Int Sequence.";
1383 return nullptr;
1384 }
1385 auto size = tensor->DataSize();
1386 std::vector<ValuePtr> value_list;
1387 if (data_type == kNumberTypeInt64) {
1388 auto data = static_cast<int64_t *>(GetTensorDataPtr(tensor));
1389 std::transform(data, data + size, std::back_inserter(value_list),
1390 [](int64_t num) { return std::make_shared<Int64Imm>(num); });
1391 } else {
1392 auto data = static_cast<int32_t *>(GetTensorDataPtr(tensor));
1393 std::transform(data, data + size, std::back_inserter(value_list),
1394 [](int32_t num) { return std::make_shared<Int64Imm>(num); });
1395 }
1396 return std::make_shared<TD>(value_list);
1397 }
1398
1399 template <typename TD>
ConvertTensorToSequenceFloat(const py::object & obj)1400 ValuePtr ConvertTensorToSequenceFloat(const py::object &obj) {
1401 auto float_tensor = ConvertTensorValue(obj);
1402 if (float_tensor == nullptr) {
1403 MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1404 return nullptr;
1405 }
1406
1407 auto data_type = float_tensor->data_type();
1408 if (data_type != kNumberTypeFloat64) {
1409 MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << "to Float64 Sequence.";
1410 return nullptr;
1411 }
1412
1413 auto shape = float_tensor->shape();
1414 if (shape.size() > 1) {
1415 MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1416 << shape;
1417 return nullptr;
1418 }
1419
1420 auto data = static_cast<double *>(GetTensorDataPtr(float_tensor));
1421 auto size = float_tensor->DataSize();
1422 std::vector<ValuePtr> value_list(size);
1423 for (size_t i = 0; i < size; i++) {
1424 value_list.emplace_back(ConvertPythonFloatToScalarValue(data[i]));
1425 }
1426
1427 return std::make_shared<TD>(value_list);
1428 }
1429
1430 template <typename TD>
ConvertTensorToSequenceAny(const py::object & obj)1431 ValuePtr ConvertTensorToSequenceAny(const py::object &obj) {
1432 auto tensor = ConvertTensorValue(obj);
1433 if (tensor == nullptr) {
1434 MS_LOG(INFO) << "Can not convert python object with type [" << obj.get_type() << "] to Tensor.";
1435 return nullptr;
1436 }
1437
1438 auto shape = tensor->shape();
1439 if (shape.size() > 1) {
1440 MS_LOG(ERROR) << "Only support converting 1-D Tensor or scalar Tensor to sequence. But got the shape of Tensor: "
1441 << shape;
1442 return nullptr;
1443 }
1444
1445 auto data_type = tensor->data_type();
1446 auto size = tensor->DataSize();
1447 std::vector<ValuePtr> value_list(size);
1448 if (data_type == kNumberTypeInt64) {
1449 auto data = static_cast<int64_t *>(GetTensorDataPtr(tensor));
1450 for (size_t i = 0; i < size; i++) {
1451 value_list.emplace_back(std::make_shared<Int64Imm>(data[i]));
1452 }
1453 } else if (data_type == kNumberTypeFloat64) {
1454 auto data = static_cast<double *>(GetTensorDataPtr(tensor));
1455 for (size_t i = 0; i < size; i++) {
1456 value_list.emplace_back(ConvertPythonFloatToScalarValue(data[i]));
1457 }
1458 } else if (data_type == kNumberTypeBool) {
1459 auto data = static_cast<bool *>(GetTensorDataPtr(tensor));
1460 for (size_t i = 0; i < size; i++) {
1461 value_list.emplace_back(std::make_shared<BoolImm>(data[i]));
1462 }
1463 } else {
1464 MS_LOG(ERROR) << "Can not convert Tensor with type " << TypeIdToString(data_type) << " to sequence.";
1465 return nullptr;
1466 }
1467
1468 return std::make_shared<TD>(value_list);
1469 }
1470
ConvertTensorToInt(const py::object & obj)1471 ValuePtr ConvertTensorToInt(const py::object &obj) {
1472 auto tensor = ConvertTensorValue(obj);
1473 if (tensor == nullptr) {
1474 return nullptr;
1475 }
1476 if (tensor->DataSize() != 1) {
1477 MS_LOG(ERROR) << "Can only convert tensor with one element to int, but got " << tensor->ToString();
1478 return nullptr;
1479 }
1480 if (tensor->data_type() == kNumberTypeInt64) {
1481 return std::make_shared<Int64Imm>(static_cast<int64_t *>(GetTensorDataPtr(tensor))[0]);
1482 } else if (tensor->data_type() == kNumberTypeInt32) {
1483 return std::make_shared<Int64Imm>(static_cast<int32_t *>(GetTensorDataPtr(tensor))[0]);
1484 } else {
1485 MS_LOG(ERROR) << "Can not convert " << tensor->ToString() << " to int";
1486 return nullptr;
1487 }
1488 }
1489
ConvertTensorToFloat(const py::object & obj)1490 ValuePtr ConvertTensorToFloat(const py::object &obj) {
1491 auto tensor = ConvertTensorValue(obj);
1492 if (tensor == nullptr) {
1493 return nullptr;
1494 }
1495 if (tensor->DataSize() != 1) {
1496 MS_LOG(ERROR) << "Can only convert tensor with one element to float, but got " << tensor->ToString();
1497 return nullptr;
1498 }
1499 if (tensor->data_type() != kNumberTypeFloat64) {
1500 MS_LOG(ERROR) << "Can not convert " << tensor->ToString() << " to float";
1501 return nullptr;
1502 }
1503 return ConvertPythonFloatToScalarValue(static_cast<double *>(GetTensorDataPtr(tensor))[0]);
1504 }
1505
ConvertTensorToBool(const py::object & obj)1506 ValuePtr ConvertTensorToBool(const py::object &obj) {
1507 auto tensor = ConvertTensorValue(obj);
1508 if (tensor == nullptr) {
1509 return nullptr;
1510 }
1511 if (tensor->data_type() != kNumberTypeBool) {
1512 MS_LOG(ERROR) << "Can not convert " << tensor->ToString() << " to bool";
1513 return nullptr;
1514 }
1515 return std::make_shared<BoolImm>(static_cast<bool *>(GetTensorDataPtr(tensor))[0]);
1516 }
1517
ConvertTensorToNumber(const py::object & obj)1518 ValuePtr ConvertTensorToNumber(const py::object &obj) {
1519 auto tensor = ConvertTensorValue(obj);
1520 if (tensor == nullptr) {
1521 return nullptr;
1522 }
1523 if (tensor->DataSize() != 1) {
1524 MS_EXCEPTION(ValueError) << "Can only convert tensor with one element to number, but got " << tensor->ToString();
1525 }
1526
1527 switch (tensor->data_type()) {
1528 case kNumberTypeBool:
1529 return std::make_shared<BoolImm>(static_cast<bool *>(GetTensorDataPtr(tensor))[0]);
1530 case kNumberTypeInt64:
1531 return std::make_shared<Int64Imm>(static_cast<int64_t *>(GetTensorDataPtr(tensor))[0]);
1532 case kNumberTypeInt32:
1533 return std::make_shared<Int32Imm>(static_cast<int32_t *>(GetTensorDataPtr(tensor))[0]);
1534 case kNumberTypeFloat64:
1535 return ConvertPythonFloatToScalarValue(static_cast<double *>(GetTensorDataPtr(tensor))[0]);
1536 case kNumberTypeFloat32:
1537 return ConvertPythonFloatToScalarValue(static_cast<float *>(GetTensorDataPtr(tensor))[0]);
1538 default:
1539 MS_EXCEPTION(TypeError) << "Can not convert " << tensor->ToString() << " to number";
1540 }
1541 }
1542
1543 static const std::unordered_map<int32_t, OpDefConvertFunc> kConverters = {
1544 // convert functions without type_cast
1545 {(int32_t)mindspore::ops::DT_BOOL, ConvertBool},
1546 {(int32_t)mindspore::ops::DT_INT, ConvertInt},
1547 {(int32_t)mindspore::ops::DT_FLOAT, ConvertFloat},
1548 {(int32_t)mindspore::ops::DT_NUMBER, ConvertNumber},
1549 {(int32_t)mindspore::ops::DT_TENSOR, ConvertTensor},
1550 {(int32_t)mindspore::ops::DT_STR, ConvertStr},
1551 {(int32_t)mindspore::ops::DT_ANY, ConvertAny},
1552 {(int32_t)mindspore::ops::DT_TYPE, ConvertDtype},
1553 {(int32_t)mindspore::ops::DT_TUPLE_BOOL, ConvertSequence<py::tuple, ValueTuple, ConvertBool>},
1554 {(int32_t)mindspore::ops::DT_TUPLE_INT, ConvertSequence<py::tuple, ValueTuple, ConvertInt>},
1555 {(int32_t)mindspore::ops::DT_TUPLE_FLOAT, ConvertSequence<py::tuple, ValueTuple, ConvertFloat>},
1556 {(int32_t)mindspore::ops::DT_TUPLE_NUMBER, ConvertSequence<py::tuple, ValueTuple, ConvertNumber>},
1557 {(int32_t)mindspore::ops::DT_TUPLE_TENSOR, ConvertSequence<py::tuple, ValueTuple, ConvertTensor>},
1558 {(int32_t)mindspore::ops::DT_TUPLE_STR, ConvertSequence<py::tuple, ValueTuple, ConvertStr>},
1559 {(int32_t)mindspore::ops::DT_TUPLE_ANY, ConvertSequence<py::tuple, ValueTuple, ConvertAny>},
1560 {(int32_t)mindspore::ops::DT_LIST_BOOL, ConvertSequence<py::list, ValueList, ConvertBool>},
1561 {(int32_t)mindspore::ops::DT_LIST_INT, ConvertSequence<py::list, ValueList, ConvertInt>},
1562 {(int32_t)mindspore::ops::DT_LIST_FLOAT, ConvertSequence<py::list, ValueList, ConvertFloat>},
1563 {(int32_t)mindspore::ops::DT_LIST_NUMBER, ConvertSequence<py::list, ValueList, ConvertNumber>},
1564 {(int32_t)mindspore::ops::DT_LIST_TENSOR, ConvertSequence<py::list, ValueList, ConvertTensor>},
1565 {(int32_t)mindspore::ops::DT_LIST_STR, ConvertSequence<py::list, ValueList, ConvertStr>},
1566 {(int32_t)mindspore::ops::DT_LIST_ANY, ConvertSequence<py::list, ValueList, ConvertAny>},
1567
1568 // TypeCast1: convert single element to sequence
1569 {CombineTypesForTypeCast(mindspore::ops::DT_NUMBER, mindspore::ops::DT_TUPLE_INT),
1570 ConvertSingleElementToSequence<ValueTuple, ConvertNumber>},
1571 {CombineTypesForTypeCast(mindspore::ops::DT_NUMBER, mindspore::ops::DT_LIST_INT),
1572 ConvertSingleElementToSequence<ValueList, ConvertNumber>},
1573 {CombineTypesForTypeCast(mindspore::ops::DT_INT, mindspore::ops::DT_TUPLE_INT),
1574 ConvertSingleElementToSequence<ValueTuple, ConvertInt>},
1575 {CombineTypesForTypeCast(mindspore::ops::DT_INT, mindspore::ops::DT_LIST_INT),
1576 ConvertSingleElementToSequence<ValueList, ConvertInt>},
1577 {CombineTypesForTypeCast(mindspore::ops::DT_FLOAT, mindspore::ops::DT_TUPLE_INT),
1578 ConvertSingleElementToSequence<ValueTuple, ConvertFloat>},
1579 {CombineTypesForTypeCast(mindspore::ops::DT_FLOAT, mindspore::ops::DT_LIST_INT),
1580 ConvertSingleElementToSequence<ValueList, ConvertFloat>},
1581 {CombineTypesForTypeCast(mindspore::ops::DT_BOOL, mindspore::ops::DT_TUPLE_INT),
1582 ConvertSingleElementToSequence<ValueTuple, ConvertBool>},
1583 {CombineTypesForTypeCast(mindspore::ops::DT_BOOL, mindspore::ops::DT_LIST_INT),
1584 ConvertSingleElementToSequence<ValueList, ConvertBool>},
1585 {CombineTypesForTypeCast(mindspore::ops::DT_ANY, mindspore::ops::DT_TUPLE_ANY),
1586 ConvertSingleElementToSequence<ValueTuple, ConvertAny>},
1587 {CombineTypesForTypeCast(mindspore::ops::DT_ANY, mindspore::ops::DT_LIST_ANY),
1588 ConvertSingleElementToSequence<ValueList, ConvertAny>},
1589
1590 // TypeCast2: convert sequence to sequence, such as py::tuple to ValueList
1591 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_INT, mindspore::ops::DT_LIST_INT),
1592 ConvertSequence<py::tuple, ValueList, ConvertInt>},
1593 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_FLOAT, mindspore::ops::DT_LIST_FLOAT),
1594 ConvertSequence<py::tuple, ValueList, ConvertFloat>},
1595 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_BOOL, mindspore::ops::DT_LIST_BOOL),
1596 ConvertSequence<py::tuple, ValueList, ConvertBool>},
1597 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_ANY, mindspore::ops::DT_LIST_ANY),
1598 ConvertSequence<py::tuple, ValueList, ConvertAny>},
1599 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_TENSOR, mindspore::ops::DT_LIST_TENSOR),
1600 ConvertSequence<py::tuple, ValueList, ConvertTensor>},
1601
1602 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_INT, mindspore::ops::DT_TUPLE_INT),
1603 ConvertSequence<py::list, ValueTuple, ConvertInt>},
1604 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_FLOAT, mindspore::ops::DT_TUPLE_FLOAT),
1605 ConvertSequence<py::list, ValueTuple, ConvertFloat>},
1606 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_BOOL, mindspore::ops::DT_TUPLE_BOOL),
1607 ConvertSequence<py::list, ValueTuple, ConvertBool>},
1608 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_ANY, mindspore::ops::DT_TUPLE_ANY),
1609 ConvertSequence<py::list, ValueTuple, ConvertAny>},
1610 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_TENSOR, mindspore::ops::DT_TUPLE_TENSOR),
1611 ConvertSequence<py::list, ValueTuple, ConvertAny>},
1612
1613 // TypeCast3: convert single element to Tensor
1614 {CombineTypesForTypeCast(mindspore::ops::DT_INT, mindspore::ops::DT_TENSOR),
1615 ConvertSingleElementToTensor<py::int_, pyint>},
1616 {CombineTypesForTypeCast(mindspore::ops::DT_FLOAT, mindspore::ops::DT_TENSOR),
1617 ConvertSingleElementToTensor<py::float_, pyfloat>},
1618 {CombineTypesForTypeCast(mindspore::ops::DT_BOOL, mindspore::ops::DT_TENSOR),
1619 ConvertSingleElementToTensor<py::bool_, bool>},
1620 {CombineTypesForTypeCast(mindspore::ops::DT_NUMBER, mindspore::ops::DT_TENSOR), ConvertNumberToTensor},
1621
1622 // TypeCast4: convert between sequence and tensor
1623 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_INT, mindspore::ops::DT_TENSOR),
1624 ConvertSequenceToTensor<py::tuple, py::int_, pyint>},
1625 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_FLOAT, mindspore::ops::DT_TENSOR),
1626 ConvertSequenceToTensor<py::tuple, py::float_, pyfloat>},
1627 {CombineTypesForTypeCast(mindspore::ops::DT_TUPLE_BOOL, mindspore::ops::DT_TENSOR),
1628 ConvertSequenceBoolToTensor<py::tuple>},
1629 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_INT, mindspore::ops::DT_TENSOR),
1630 ConvertSequenceToTensor<py::list, py::int_, pyint>},
1631 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_FLOAT, mindspore::ops::DT_TENSOR),
1632 ConvertSequenceToTensor<py::list, py::float_, pyfloat>},
1633 {CombineTypesForTypeCast(mindspore::ops::DT_LIST_BOOL, mindspore::ops::DT_TENSOR),
1634 ConvertSequenceBoolToTensor<py::list>},
1635
1636 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_INT),
1637 ConvertTensorToSequenceInt<ValueTuple>},
1638 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_FLOAT),
1639 ConvertTensorToSequenceFloat<ValueTuple>},
1640 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_BOOL),
1641 ConvertTensorToSequence<ValueTuple, bool, BoolImm, kNumberTypeBool>},
1642 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_TUPLE_BOOL),
1643 ConvertTensorToSequenceAny<ValueTuple>},
1644
1645 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_INT),
1646 ConvertTensorToSequenceInt<ValueList>},
1647 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_FLOAT),
1648 ConvertTensorToSequenceFloat<ValueList>},
1649 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_BOOL),
1650 ConvertTensorToSequence<ValueList, bool, BoolImm, kNumberTypeBool>},
1651 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_LIST_BOOL),
1652 ConvertTensorToSequenceAny<ValueList>},
1653
1654 // TypeCast5: convert tensor to single element
1655 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_INT), ConvertTensorToInt},
1656 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_FLOAT), ConvertTensorToFloat},
1657 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_BOOL), ConvertTensorToBool},
1658 {CombineTypesForTypeCast(mindspore::ops::DT_TENSOR, mindspore::ops::DT_NUMBER), ConvertTensorToNumber},
1659 };
1660
GetConverterByType(int32_t dtype)1661 OpDefConvertFunc GetConverterByType(int32_t dtype) {
1662 auto it = kConverters.find(dtype);
1663 if (it == kConverters.end()) {
1664 if ((dtype >> kTypeShiftBits) == 0) {
1665 MS_LOG(EXCEPTION) << "Can not find converter for dtype[" << ops::EnumToString(static_cast<ops::OP_DTYPE>(dtype))
1666 << "].";
1667 } else {
1668 MS_LOG(EXCEPTION) << "Can not find converter for src_type["
1669 << ops::EnumToString(static_cast<ops::OP_DTYPE>(dtype >> kTypeShiftBits)) << "] and dst_type["
1670 << ops::EnumToString(static_cast<ops::OP_DTYPE>(dtype & kDstMask)) << "].";
1671 }
1672 }
1673
1674 return it->second;
1675 }
1676 } // namespace parse
1677 } // namespace mindspore
1678