1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/parse/data_converter.h"
20 #include <unordered_map>
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <vector>
25 #include "pipeline/jit/parse/resolve.h"
26 #include "pipeline/jit/parse/python_adapter.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/operator/composite/composite.h"
29 #include "ir/func_graph_cloner.h"
30 #include "ir/cell.h"
31 #include "utils/symbolic.h"
32 #include "utils/ms_context.h"
33
34 namespace mindspore {
35 namespace parse {
// Short local names for the tensor types referenced by the converters in this file.
using Tensor = mindspore::tensor::Tensor;
using TensorPtr = mindspore::tensor::TensorPtr;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;

// Predicate that decides whether a converter applies to a given Python object.
using InstanceCheckFunc = std::function<bool(const py::object &)>;
// Full conversion callback: (python object, use-signature flag, requested dtype) -> converted value.
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
// Bit widths compared against nbits() when selecting an immediate value type.
static constexpr int kBit8 = 8;
static constexpr int kBit16 = 16;
static constexpr int kBit32 = 32;
static constexpr int kBit64 = 64;
47 class DataConverter {
48 public:
DataConverter(InstanceConvertFunc convert_func)49 explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
50 virtual ~DataConverter() = default;
51 virtual bool Matched(const py::object &obj) = 0;
ConvertPyObject(const py::object & obj,bool use_sig,const TypePtr & dtype)52 virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) {
53 if (convert_func_ == nullptr) {
54 MS_LOG(EXCEPTION) << "convert func is null";
55 }
56 return convert_func_(obj, use_sig, dtype);
57 }
58
59 private:
60 InstanceConvertFunc convert_func_ = nullptr;
61 };
// Shared handle to any converter.
using DataConverterPtr = std::shared_ptr<DataConverter>;

// Narrower callback shapes accepted by the concrete converters:
// object only.
using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
// object plus the use-signature flag.
using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
// object plus the requested dtype.
using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
67
68 // Convert the data according instance type
69 template <typename T>
70 class ByTypeDataConverter : public DataConverter {
71 public:
ByTypeDataConverter(const InstanceConvertFunc & convert_func)72 explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func)
73 : DataConverter(convert_func), check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ValuePtr & converted_type)74 explicit ByTypeDataConverter(const ValuePtr &converted_type)
75 : DataConverter(
76 [converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }),
77 check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ArgsObjConvertFunc & convert_func)78 explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func)
79 : DataConverter(
80 [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
81 check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ArgsObjSigConvertFunc & convert_func)82 explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func)
83 : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
84 return convert_func(obj, use_sig);
85 }),
86 check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ArgsOjbTypeConvertFunc & convert_func)87 explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func)
88 : DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr {
89 return convert_func(obj, dtype);
90 }),
91 check_func_(py::isinstance<T>) {}
92 ~ByTypeDataConverter() override = default;
93
Matched(const py::object & obj)94 bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }
95
96 private:
97 InstanceCheckFunc check_func_ = nullptr;
98 };
99
100 // Convert the data according object attribute.
101 class ByAttrDataConverter : public DataConverter {
102 public:
ByAttrDataConverter(const char * attr_name,const ArgsObjConvertFunc & convert_func)103 ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func)
104 : DataConverter(
105 [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
106 attr_name_(attr_name) {}
ByAttrDataConverter(const char * attr_name,const ArgsObjSigConvertFunc & convert_func)107 ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func)
108 : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
109 return convert_func(obj, use_sig);
110 }),
111 attr_name_(attr_name) {}
112 ~ByAttrDataConverter() override = default;
Matched(const py::object & obj)113 bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); }
114
115 private:
116 const char *attr_name_ = nullptr;
117 };
118
ConvertToBpropCut(const py::object & obj)119 FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
120 std::vector<std::string> results = data_converter::GetObjKey(obj);
121 std::string obj_key = results[0];
122 py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
123
124 auto bprop_graph = std::make_shared<FuncGraph>();
125 std::vector<AnfNodePtr> outputs;
126
127 auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
128 fake_bprop->set_hook(bprop_func);
129 (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
130 outputs.push_back(NewValueNode(fake_bprop));
131
132 py::object code_obj = py::getattr(bprop_func, "__code__");
133 // Three parameters self, out and dout need to be excluded
134 constexpr auto kBpropExcludeParamNum = 3;
135 size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - kBpropExcludeParamNum;
136 for (size_t i = 0; i < inputs_num; ++i) {
137 auto param = bprop_graph->add_parameter();
138 outputs.push_back(param);
139 }
140 auto p1 = bprop_graph->add_parameter();
141 auto p2 = bprop_graph->add_parameter();
142 outputs.push_back(p1);
143 outputs.push_back(p2);
144
145 bprop_graph->set_output(bprop_graph->NewCNode(outputs));
146 data_converter::SetObjGraphValue(obj_key, bprop_graph);
147 return bprop_graph;
148 }
149
150 namespace {
ConvertTuple(const py::object & obj,bool use_signature)151 ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
152 MS_LOG(DEBUG) << "Converting python tuple";
153 auto tuple = obj.cast<py::tuple>();
154 std::vector<ValuePtr> value_list;
155 for (size_t it = 0; it < tuple.size(); ++it) {
156 ValuePtr out = nullptr;
157 bool success = ConvertData(tuple[it], &out, use_signature);
158 if (!success) {
159 return nullptr;
160 }
161 value_list.push_back(out);
162 }
163 return std::make_shared<ValueTuple>(value_list);
164 }
165
ConvertList(const py::object & obj,bool use_signature)166 ValuePtr ConvertList(const py::object &obj, bool use_signature) {
167 MS_LOG(DEBUG) << "Converting python list";
168
169 auto list = obj.cast<py::list>();
170 std::vector<ValuePtr> value_list;
171 for (size_t it = 0; it < list.size(); ++it) {
172 ValuePtr out = nullptr;
173 bool success = ConvertData(list[it], &out, use_signature);
174 if (!success) {
175 return nullptr;
176 }
177 value_list.push_back(out);
178 }
179 return std::make_shared<ValueList>(value_list);
180 }
181
ConvertCellList(const py::object & obj,bool use_signature)182 ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
183 MS_LOG(DEBUG) << "Converting cell list";
184 py::sequence list = obj;
185 std::vector<ValuePtr> value_list;
186 for (size_t it = 0; it < list.size(); ++it) {
187 ValuePtr out = nullptr;
188 bool success = ConvertData(list[it], &out, use_signature);
189 if (!success) {
190 return nullptr;
191 }
192 value_list.push_back(out);
193 }
194 return std::make_shared<ValueTuple>(value_list);
195 }
196
ConvertDict(const py::object & obj,bool use_signature)197 ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
198 MS_LOG(DEBUG) << "Converting python dict";
199
200 auto dict_values = obj.cast<py::dict>();
201 std::vector<std::pair<std::string, ValuePtr>> key_values;
202 for (auto item : dict_values) {
203 if (!py::isinstance<py::str>(item.first)) {
204 MS_LOG(ERROR) << "The key of dict is only support str.";
205 return nullptr;
206 }
207 std::string key = py::str(item.first);
208 ValuePtr out = nullptr;
209 bool success = ConvertData(dict_values[item.first], &out, use_signature);
210 if (!success) {
211 return nullptr;
212 }
213 key_values.emplace_back(key, out);
214 }
215 return std::make_shared<ValueDictionary>(key_values);
216 }
217
ConvertModuleNameSpace(const py::object & obj)218 ValuePtr ConvertModuleNameSpace(const py::object &obj) {
219 MS_LOG(DEBUG) << "Converting python module";
220 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
221 py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
222 auto converted =
223 std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
224 MS_LOG(DEBUG) << "name_space: " << converted->ToString();
225 return converted;
226 }
227
ConvertDataClass(const py::object & obj)228 ValuePtr ConvertDataClass(const py::object &obj) {
229 MS_LOG(DEBUG) << "Converting dataclass";
230 // Maybe the obj is dataclass define
231 auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
232 // desc has format "<class xxxx>", strip the '<' and '>' by offset 1
233 auto converted = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
234 return converted;
235 }
236
ConvertPrimitive(const py::object & obj,bool use_signature=false)237 ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
238 MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
239
240 // need check the primitive is class type or instance
241 auto obj_type = data_converter::GetObjType(obj);
242 if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
243 auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
244 // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
245 return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
246 }
247 py::object adapter_obj = obj;
248 if (py::hasattr(obj, "__setattr_flag__")) {
249 if (py::hasattr(obj, "_clone")) {
250 auto clone_fn = obj.attr("_clone");
251 adapter_obj = clone_fn();
252 }
253 }
254 auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
255 MS_EXCEPTION_IF_NULL(prim_adapter);
256 auto primitive = prim_adapter->attached_primitive();
257 if (primitive == nullptr) {
258 primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter);
259 prim_adapter->set_attached_primitive(primitive);
260 }
261
262 if (use_signature) {
263 return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
264 }
265 return primitive;
266 }
267
ConvertMetaFuncGraph(const py::object & obj,bool use_signature=false)268 ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
269 MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
270 auto meta = obj.cast<MetaFuncGraphPtr>();
271 if (meta == nullptr) {
272 MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
273 return nullptr;
274 }
275 if (use_signature) {
276 return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
277 }
278 return meta;
279 }
280
ConvertFuncGraph(const py::object & obj)281 ValuePtr ConvertFuncGraph(const py::object &obj) {
282 MS_LOG(DEBUG) << "Converting FuncGraph object";
283 auto func_graph = obj.cast<FuncGraphPtr>();
284 if (func_graph == nullptr) {
285 MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
286 return nullptr;
287 }
288 auto new_fg = BasicClone(func_graph);
289 new_fg->set_attr("is_load", MakeValue(true));
290 return new_fg;
291 }
292
ConvertSlice(const py::object & obj)293 ValuePtr ConvertSlice(const py::object &obj) {
294 MS_LOG(DEBUG) << "Converting slice object";
295
296 auto convert_func = [obj](const std::string &attr) -> ValuePtr {
297 auto py_attr = py::getattr(obj, attr.c_str());
298 if (py::isinstance<py::none>(py_attr)) {
299 return kNone;
300 }
301 if (py::isinstance<py::int_>(py_attr)) {
302 auto value = py::cast<int64_t>(py_attr);
303 return MakeValue(value);
304 }
305 MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj) << " should be int but got "
306 << py::str(py_attr);
307 };
308 ValuePtr start = convert_func("start");
309 ValuePtr stop = convert_func("stop");
310 ValuePtr step = convert_func("step");
311 return std::make_shared<ValueSlice>(start, stop, step);
312 }
313
ConvertCellObjToFuncGraph(const py::object & obj)314 ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
315 FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
316 if (func_graph == nullptr) {
317 MS_LOG(ERROR) << "Parse resolve function error.";
318 return nullptr;
319 }
320 // if the cell object has specified bprop, it has user-defined bprop function parse and record it
321 if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
322 bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
323 FuncGraphPtr bprop_graph =
324 enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
325 if (bprop_graph != nullptr) {
326 (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
327 (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
328 func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
329 }
330 }
331 if (py::hasattr(obj, STAGE_NAME)) {
332 auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
333 func_graph->set_stage(stage);
334 }
335 return func_graph;
336 }
337
ConvertOtherObj(const py::object & obj)338 ValuePtr ConvertOtherObj(const py::object &obj) {
339 auto obj_type = data_converter::GetObjType(obj);
340 MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
341 if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
342 MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
343 std::string desc = py::str(obj);
344 // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
345 return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
346 }
347 if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
348 MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
349 FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
350 if (func_graph == nullptr) {
351 MS_LOG(ERROR) << "Parse resolve function error.";
352 return nullptr;
353 }
354 return func_graph;
355 }
356 if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
357 // Create the namespace for common class instance
358 // When the obj is Cell, default parse the 'construct'
359 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
360 py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
361 auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
362 MS_LOG(DEBUG) << "name_space: " << res->ToString();
363 return res;
364 }
365 MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
366 return nullptr;
367 }
368
369 template <typename T>
ConvertNumberWithType(const T & obj,const TypePtr & dtype)370 ValuePtr ConvertNumberWithType(const T &obj, const TypePtr &dtype) {
371 ValuePtr data = nullptr;
372 auto int_dypte = dyn_cast<Int>(dtype);
373 if (int_dypte != nullptr) {
374 switch (int_dypte->nbits()) {
375 case kBit8:
376 data = std::make_shared<Int8Imm>(obj);
377 break;
378 case kBit16:
379 data = std::make_shared<Int16Imm>(obj);
380 break;
381 case kBit32:
382 data = std::make_shared<Int32Imm>(obj);
383 break;
384 case kBit64:
385 data = std::make_shared<Int64Imm>(obj);
386 break;
387 default:
388 data = std::make_shared<Int64Imm>(obj);
389 }
390 return data;
391 }
392
393 auto uint_dypte = dyn_cast<UInt>(dtype);
394 if (uint_dypte != nullptr) {
395 switch (uint_dypte->nbits()) {
396 case kBit8:
397 data = std::make_shared<UInt8Imm>(obj);
398 break;
399 case kBit16:
400 data = std::make_shared<UInt16Imm>(obj);
401 break;
402 case kBit32:
403 data = std::make_shared<UInt32Imm>(obj);
404 break;
405 case kBit64:
406 data = std::make_shared<UInt64Imm>(obj);
407 break;
408 default:
409 data = std::make_shared<UInt32Imm>(obj);
410 }
411 return data;
412 }
413
414 auto float_dypte = dyn_cast<Float>(dtype);
415 if (float_dypte != nullptr) {
416 switch (float_dypte->nbits()) {
417 case kBit32:
418 data = std::make_shared<FP32Imm>(obj);
419 break;
420 case kBit64:
421 data = std::make_shared<FP64Imm>(obj);
422 break;
423 default:
424 data = std::make_shared<FP32Imm>(obj);
425 }
426 return data;
427 }
428 return nullptr;
429 }
430
ConvertIntegerWithType(const py::object & obj,const TypePtr & dtype=nullptr)431 ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
432 auto obj_int64 = py::cast<int64_t>(obj);
433 if (dtype == nullptr) {
434 return std::make_shared<Int64Imm>(obj_int64);
435 }
436 return ConvertNumberWithType<int64_t>(obj_int64, dtype);
437 }
438
ConvertFloatWithType(const py::object & obj,const TypePtr & dtype=nullptr)439 ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
440 auto obj_float64 = py::cast<float>(obj);
441 if (dtype == nullptr) {
442 return std::make_shared<FP32Imm>(obj_float64);
443 }
444 return ConvertNumberWithType<float>(obj_float64, dtype);
445 }
446
447 template <typename T, typename U>
PyCast(const py::object & obj)448 ValuePtr PyCast(const py::object &obj) {
449 return std::make_shared<T>(py::cast<U>(obj));
450 }
451
452 template <typename T>
ObjCast(const py::object & obj)453 ValuePtr ObjCast(const py::object &obj) {
454 return obj.cast<T>();
455 }
456
GetDataConverters()457 std::vector<DataConverterPtr> GetDataConverters() {
458 static std::vector<DataConverterPtr> data_converters = {
459 // Convert data by python object type.
460 std::make_shared<ByTypeDataConverter<py::none>>(kNone),
461 std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
462 std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
463 std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
464 std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace),
465 std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
466 std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
467 std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
468 std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
469 std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
470 std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
471 std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
472 std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
473 [](const py::object &obj) -> ValuePtr {
474 auto res =
475 std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
476 MS_LOG(DEBUG) << "name_space: " << res->ToString();
477 return res;
478 }),
479 std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
480 std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
481 std::make_shared<ByTypeDataConverter<py::dict>>(ConvertDict),
482 std::make_shared<ByTypeDataConverter<py::slice>>(ConvertSlice),
483 std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
484 std::make_shared<ByAttrDataConverter>(PYTHON_CELL_AS_LIST, ConvertCellList),
485 std::make_shared<ByTypeDataConverter<Cell>>(ConvertCellObjToFuncGraph),
486 std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
487 std::make_shared<ByAttrDataConverter>(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive),
488 std::make_shared<ByTypeDataConverter<MetaFuncGraph>>(ConvertMetaFuncGraph),
489 std::make_shared<ByTypeDataConverter<FuncGraph>>(ConvertFuncGraph),
490 };
491 return data_converters;
492 }
493 } // namespace
494
ConvertData(const py::object & obj,ValuePtr * const data,bool use_signature,const TypePtr & dtype)495 bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
496 // Check parameter valid
497 if (data == nullptr) {
498 MS_LOG(ERROR) << "Data is null pointer";
499 return false;
500 }
501 ValuePtr converted = nullptr;
502 bool matched = false;
503 auto converters = GetDataConverters();
504 for (auto &converter : converters) {
505 if (converter->Matched(obj)) {
506 converted = converter->ConvertPyObject(obj, use_signature, dtype);
507 matched = true;
508 break;
509 }
510 }
511 if (!matched) {
512 converted = ConvertOtherObj(obj);
513 }
514 *data = converted;
515 return converted != nullptr;
516 }
517
518 // Convert data to graph
ConvertToFuncGraph(const py::object & obj,const std::string & python_mod_get_parse_method)519 FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
520 std::vector<std::string> results = data_converter::GetObjKey(obj);
521 std::string obj_id = results[0] + python_mod_get_parse_method;
522 std::string obj_key = results[1];
523 FuncGraphPtr func_graph = nullptr;
524 ValuePtr value = nullptr;
525 bool is_cache = data_converter::GetObjectValue(obj_id, &value);
526 if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
527 MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
528 func_graph = value->cast<FuncGraphPtr>();
529 if (!func_graph->dropped()) {
530 return func_graph;
531 }
532 }
533
534 func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
535 if (func_graph == nullptr) {
536 MS_LOG(ERROR) << "Parse resolve function error.";
537 return nullptr;
538 }
539
540 data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
541 data_converter::CacheObjectValue(obj_id, func_graph);
542 if (!obj_key.empty()) {
543 MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
544 data_converter::SetObjGraphValue(obj_key, func_graph);
545 }
546
547 return func_graph;
548 }
549 namespace data_converter {
// Cache of converted values, keyed by object id string.
static std::unordered_map<std::string, ValuePtr> object_map_;

// Graphs recorded per object key; one key may accumulate several graphs.
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
553
SetObjGraphValue(const std::string & obj_key,const FuncGraphPtr & data)554 void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
555 object_graphs_map_[obj_key].push_back(data);
556 MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
557 }
558
GetObjGraphs()559 const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
560 MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
561 return object_graphs_map_;
562 }
563
CacheObjectValue(const std::string & obj_key,const ValuePtr & data)564 void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
GetObjectValue(const std::string & obj_key,ValuePtr * const data)565 bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
566 if (object_map_.count(obj_key)) {
567 *data = object_map_[obj_key];
568 return true;
569 }
570 return false;
571 }
GetObjKey(const py::object & obj)572 std::vector<std::string> GetObjKey(const py::object &obj) {
573 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
574 py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
575 if (obj_tuple.size() != 2) {
576 MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
577 }
578 return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
579 }
580
581 // Get obj detail type
GetObjType(const py::object & obj)582 ResolveTypeDef GetObjType(const py::object &obj) {
583 try {
584 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
585 auto obj_type =
586 ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
587 return obj_type;
588 } catch (const py::error_already_set &ex) {
589 MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
590 std::rethrow_exception(std::current_exception());
591 } catch (const py::type_error &ex) {
592 MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
593 std::rethrow_exception(std::current_exception());
594 }
595 }
596
597 // Get class instance detail type.
GetClassInstanceType(const py::object & obj)598 ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
599 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
600 auto class_type =
601 ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
602 return class_type;
603 }
604
605 // Check the object is Cell Instance.
IsCellInstance(const py::object & obj)606 bool IsCellInstance(const py::object &obj) {
607 auto class_type = GetClassInstanceType(obj);
608 bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
609 return isCell;
610 }
611
612 // Create the python class instance.
CreatePythonObject(const py::object & type,const py::tuple & args_kwargs)613 py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
614 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
615 // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
616 return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
617 : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
618 }
619
620 // Call the python script string.
CallPythonScript(const py::object & script,const py::tuple & args_kwargs)621 py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
622 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
623 // `args_kwargs` is a tuple(dict(global), dict(local)).
624 return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script)
625 : python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
626 }
627
628 // Generate an appropriate name and set to graph debuginfo,
629 // character <> can not used in the dot file, so change to another symbol.
MakeProperNameToFuncGraph(const FuncGraphPtr & func_graph,std::string name)630 void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
631 MS_EXCEPTION_IF_NULL(func_graph);
632 MS_EXCEPTION_IF_NULL(func_graph->debug_info());
633 // Set detail name info of function
634 std::ostringstream oss;
635 for (size_t i = 0; i < name.size(); i++) {
636 if (name[i] == '<') {
637 oss << "「";
638 } else if (name[i] == '>') {
639 oss << "」";
640 } else {
641 oss << name[i];
642 }
643 }
644 func_graph->debug_info()->set_full_name(oss.str());
645 }
646
PyDataToValue(const py::object & obj)647 ValuePtr PyDataToValue(const py::object &obj) {
648 py::object to_convert = obj;
649 ValuePtr value = nullptr;
650 (void)ConvertData(to_convert, &value);
651 return value;
652 }
653
ClearObjectCache()654 void ClearObjectCache() {
655 object_map_.clear();
656 object_graphs_map_.clear();
657 }
658 } // namespace data_converter
659
// Cache of parsed dataclasses, keyed by "<module>.<class name>".
static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
661
662 // Parse dataclass to mindspore Class type
ParseDataClass(const py::object & cls_obj)663 ClassPtr ParseDataClass(const py::object &cls_obj) {
664 std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
665 std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
666 std::string cls = cls_module + "." + cls_name;
667 auto iterator = g_dataClassToClass.find(cls);
668 if (iterator != g_dataClassToClass.end()) {
669 return iterator->second;
670 }
671
672 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
673 ClassAttrVector attributes;
674 py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
675 for (auto &item : names) {
676 auto type_value = item.second.cast<TypePtr>();
677 MS_EXCEPTION_IF_NULL(type_value);
678 MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
679 attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
680 }
681
682 std::unordered_map<std::string, ValuePtr> methods_map;
683 py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
684 for (auto &item : methods) {
685 auto fun_name = item.first.cast<std::string>();
686 auto obj = py::cast<py::object>(item.second);
687 std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
688 methods_map[fun_name] = method_obj;
689 }
690
691 std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
692 // static Variable for cache
693 // cppcheck-suppress unreadVariable
694 g_dataClassToClass[cls] = me_class;
695
696 return me_class;
697 }
698
CleanDataClassToClassMap()699 void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
700 } // namespace parse
701 } // namespace mindspore
702