• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/parse/data_converter.h"
20 #include <unordered_map>
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <vector>
25 #include "pipeline/jit/parse/resolve.h"
26 #include "pipeline/jit/parse/python_adapter.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/operator/composite/composite.h"
29 #include "ir/func_graph_cloner.h"
30 #include "ir/cell.h"
31 #include "utils/symbolic.h"
32 #include "utils/ms_context.h"
33 
34 namespace mindspore {
35 namespace parse {
36 using Tensor = mindspore::tensor::Tensor;
37 using TensorPtr = mindspore::tensor::TensorPtr;
38 using MetaTensor = mindspore::tensor::MetaTensor;
39 using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
40 
41 using InstanceCheckFunc = std::function<bool(const py::object &)>;
42 using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
43 static constexpr int kBit8 = 8;
44 static constexpr int kBit16 = 16;
45 static constexpr int kBit32 = 32;
46 static constexpr int kBit64 = 64;
47 class DataConverter {
48  public:
DataConverter(InstanceConvertFunc convert_func)49   explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
50   virtual ~DataConverter() = default;
51   virtual bool Matched(const py::object &obj) = 0;
ConvertPyObject(const py::object & obj,bool use_sig,const TypePtr & dtype)52   virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) {
53     if (convert_func_ == nullptr) {
54       MS_LOG(EXCEPTION) << "convert func is null";
55     }
56     return convert_func_(obj, use_sig, dtype);
57   }
58 
59  private:
60   InstanceConvertFunc convert_func_ = nullptr;
61 };
62 using DataConverterPtr = std::shared_ptr<DataConverter>;
63 
64 using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
65 using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
66 using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
67 
68 // Convert the data according instance type
69 template <typename T>
70 class ByTypeDataConverter : public DataConverter {
71  public:
ByTypeDataConverter(const InstanceConvertFunc & convert_func)72   explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func)
73       : DataConverter(convert_func), check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ValuePtr & converted_type)74   explicit ByTypeDataConverter(const ValuePtr &converted_type)
75       : DataConverter(
76           [converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }),
77         check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ArgsObjConvertFunc & convert_func)78   explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func)
79       : DataConverter(
80           [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
81         check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ArgsObjSigConvertFunc & convert_func)82   explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func)
83       : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
84           return convert_func(obj, use_sig);
85         }),
86         check_func_(py::isinstance<T>) {}
ByTypeDataConverter(const ArgsOjbTypeConvertFunc & convert_func)87   explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func)
88       : DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr {
89           return convert_func(obj, dtype);
90         }),
91         check_func_(py::isinstance<T>) {}
92   ~ByTypeDataConverter() override = default;
93 
Matched(const py::object & obj)94   bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }
95 
96  private:
97   InstanceCheckFunc check_func_ = nullptr;
98 };
99 
100 // Convert the data according object attribute.
101 class ByAttrDataConverter : public DataConverter {
102  public:
ByAttrDataConverter(const char * attr_name,const ArgsObjConvertFunc & convert_func)103   ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func)
104       : DataConverter(
105           [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
106         attr_name_(attr_name) {}
ByAttrDataConverter(const char * attr_name,const ArgsObjSigConvertFunc & convert_func)107   ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func)
108       : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
109           return convert_func(obj, use_sig);
110         }),
111         attr_name_(attr_name) {}
112   ~ByAttrDataConverter() override = default;
Matched(const py::object & obj)113   bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); }
114 
115  private:
116   const char *attr_name_ = nullptr;
117 };
118 
ConvertToBpropCut(const py::object & obj)119 FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
120   std::vector<std::string> results = data_converter::GetObjKey(obj);
121   std::string obj_key = results[0];
122   py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
123 
124   auto bprop_graph = std::make_shared<FuncGraph>();
125   std::vector<AnfNodePtr> outputs;
126 
127   auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
128   fake_bprop->set_hook(bprop_func);
129   (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
130   outputs.push_back(NewValueNode(fake_bprop));
131 
132   py::object code_obj = py::getattr(bprop_func, "__code__");
133   // Three parameters self, out and dout need to be excluded
134   constexpr auto kBpropExcludeParamNum = 3;
135   size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - kBpropExcludeParamNum;
136   for (size_t i = 0; i < inputs_num; ++i) {
137     auto param = bprop_graph->add_parameter();
138     outputs.push_back(param);
139   }
140   auto p1 = bprop_graph->add_parameter();
141   auto p2 = bprop_graph->add_parameter();
142   outputs.push_back(p1);
143   outputs.push_back(p2);
144 
145   bprop_graph->set_output(bprop_graph->NewCNode(outputs));
146   data_converter::SetObjGraphValue(obj_key, bprop_graph);
147   return bprop_graph;
148 }
149 
150 namespace {
ConvertTuple(const py::object & obj,bool use_signature)151 ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
152   MS_LOG(DEBUG) << "Converting python tuple";
153   auto tuple = obj.cast<py::tuple>();
154   std::vector<ValuePtr> value_list;
155   for (size_t it = 0; it < tuple.size(); ++it) {
156     ValuePtr out = nullptr;
157     bool success = ConvertData(tuple[it], &out, use_signature);
158     if (!success) {
159       return nullptr;
160     }
161     value_list.push_back(out);
162   }
163   return std::make_shared<ValueTuple>(value_list);
164 }
165 
ConvertList(const py::object & obj,bool use_signature)166 ValuePtr ConvertList(const py::object &obj, bool use_signature) {
167   MS_LOG(DEBUG) << "Converting python list";
168 
169   auto list = obj.cast<py::list>();
170   std::vector<ValuePtr> value_list;
171   for (size_t it = 0; it < list.size(); ++it) {
172     ValuePtr out = nullptr;
173     bool success = ConvertData(list[it], &out, use_signature);
174     if (!success) {
175       return nullptr;
176     }
177     value_list.push_back(out);
178   }
179   return std::make_shared<ValueList>(value_list);
180 }
181 
ConvertCellList(const py::object & obj,bool use_signature)182 ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
183   MS_LOG(DEBUG) << "Converting cell list";
184   py::sequence list = obj;
185   std::vector<ValuePtr> value_list;
186   for (size_t it = 0; it < list.size(); ++it) {
187     ValuePtr out = nullptr;
188     bool success = ConvertData(list[it], &out, use_signature);
189     if (!success) {
190       return nullptr;
191     }
192     value_list.push_back(out);
193   }
194   return std::make_shared<ValueTuple>(value_list);
195 }
196 
ConvertDict(const py::object & obj,bool use_signature)197 ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
198   MS_LOG(DEBUG) << "Converting python dict";
199 
200   auto dict_values = obj.cast<py::dict>();
201   std::vector<std::pair<std::string, ValuePtr>> key_values;
202   for (auto item : dict_values) {
203     if (!py::isinstance<py::str>(item.first)) {
204       MS_LOG(ERROR) << "The key of dict is only support str.";
205       return nullptr;
206     }
207     std::string key = py::str(item.first);
208     ValuePtr out = nullptr;
209     bool success = ConvertData(dict_values[item.first], &out, use_signature);
210     if (!success) {
211       return nullptr;
212     }
213     key_values.emplace_back(key, out);
214   }
215   return std::make_shared<ValueDictionary>(key_values);
216 }
217 
ConvertModuleNameSpace(const py::object & obj)218 ValuePtr ConvertModuleNameSpace(const py::object &obj) {
219   MS_LOG(DEBUG) << "Converting python module";
220   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
221   py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
222   auto converted =
223     std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
224   MS_LOG(DEBUG) << "name_space: " << converted->ToString();
225   return converted;
226 }
227 
ConvertDataClass(const py::object & obj)228 ValuePtr ConvertDataClass(const py::object &obj) {
229   MS_LOG(DEBUG) << "Converting dataclass";
230   // Maybe the obj is dataclass define
231   auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
232   // desc has format "<class xxxx>", strip the '<' and '>' by offset 1
233   auto converted = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
234   return converted;
235 }
236 
ConvertPrimitive(const py::object & obj,bool use_signature=false)237 ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
238   MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
239 
240   // need check the primitive is class type or instance
241   auto obj_type = data_converter::GetObjType(obj);
242   if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
243     auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
244     // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
245     return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
246   }
247   py::object adapter_obj = obj;
248   if (py::hasattr(obj, "__setattr_flag__")) {
249     if (py::hasattr(obj, "_clone")) {
250       auto clone_fn = obj.attr("_clone");
251       adapter_obj = clone_fn();
252     }
253   }
254   auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
255   MS_EXCEPTION_IF_NULL(prim_adapter);
256   auto primitive = prim_adapter->attached_primitive();
257   if (primitive == nullptr) {
258     primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter);
259     prim_adapter->set_attached_primitive(primitive);
260   }
261 
262   if (use_signature) {
263     return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
264   }
265   return primitive;
266 }
267 
ConvertMetaFuncGraph(const py::object & obj,bool use_signature=false)268 ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
269   MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
270   auto meta = obj.cast<MetaFuncGraphPtr>();
271   if (meta == nullptr) {
272     MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
273     return nullptr;
274   }
275   if (use_signature) {
276     return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
277   }
278   return meta;
279 }
280 
ConvertFuncGraph(const py::object & obj)281 ValuePtr ConvertFuncGraph(const py::object &obj) {
282   MS_LOG(DEBUG) << "Converting FuncGraph object";
283   auto func_graph = obj.cast<FuncGraphPtr>();
284   if (func_graph == nullptr) {
285     MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
286     return nullptr;
287   }
288   auto new_fg = BasicClone(func_graph);
289   new_fg->set_attr("is_load", MakeValue(true));
290   return new_fg;
291 }
292 
ConvertSlice(const py::object & obj)293 ValuePtr ConvertSlice(const py::object &obj) {
294   MS_LOG(DEBUG) << "Converting slice object";
295 
296   auto convert_func = [obj](const std::string &attr) -> ValuePtr {
297     auto py_attr = py::getattr(obj, attr.c_str());
298     if (py::isinstance<py::none>(py_attr)) {
299       return kNone;
300     }
301     if (py::isinstance<py::int_>(py_attr)) {
302       auto value = py::cast<int64_t>(py_attr);
303       return MakeValue(value);
304     }
305     MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj) << " should be int but got "
306                       << py::str(py_attr);
307   };
308   ValuePtr start = convert_func("start");
309   ValuePtr stop = convert_func("stop");
310   ValuePtr step = convert_func("step");
311   return std::make_shared<ValueSlice>(start, stop, step);
312 }
313 
ConvertCellObjToFuncGraph(const py::object & obj)314 ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
315   FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
316   if (func_graph == nullptr) {
317     MS_LOG(ERROR) << "Parse resolve function error.";
318     return nullptr;
319   }
320   // if the cell object has specified bprop, it has user-defined bprop function parse and record it
321   if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
322     bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
323     FuncGraphPtr bprop_graph =
324       enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
325     if (bprop_graph != nullptr) {
326       (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
327       (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
328       func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
329     }
330   }
331   if (py::hasattr(obj, STAGE_NAME)) {
332     auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
333     func_graph->set_stage(stage);
334   }
335   return func_graph;
336 }
337 
ConvertOtherObj(const py::object & obj)338 ValuePtr ConvertOtherObj(const py::object &obj) {
339   auto obj_type = data_converter::GetObjType(obj);
340   MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
341   if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
342     MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
343     std::string desc = py::str(obj);
344     // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
345     return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
346   }
347   if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
348     MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
349     FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
350     if (func_graph == nullptr) {
351       MS_LOG(ERROR) << "Parse resolve function error.";
352       return nullptr;
353     }
354     return func_graph;
355   }
356   if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
357     // Create the namespace for common class instance
358     // When the obj is Cell, default parse the 'construct'
359     py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
360     py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
361     auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
362     MS_LOG(DEBUG) << "name_space: " << res->ToString();
363     return res;
364   }
365   MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
366   return nullptr;
367 }
368 
369 template <typename T>
ConvertNumberWithType(const T & obj,const TypePtr & dtype)370 ValuePtr ConvertNumberWithType(const T &obj, const TypePtr &dtype) {
371   ValuePtr data = nullptr;
372   auto int_dypte = dyn_cast<Int>(dtype);
373   if (int_dypte != nullptr) {
374     switch (int_dypte->nbits()) {
375       case kBit8:
376         data = std::make_shared<Int8Imm>(obj);
377         break;
378       case kBit16:
379         data = std::make_shared<Int16Imm>(obj);
380         break;
381       case kBit32:
382         data = std::make_shared<Int32Imm>(obj);
383         break;
384       case kBit64:
385         data = std::make_shared<Int64Imm>(obj);
386         break;
387       default:
388         data = std::make_shared<Int64Imm>(obj);
389     }
390     return data;
391   }
392 
393   auto uint_dypte = dyn_cast<UInt>(dtype);
394   if (uint_dypte != nullptr) {
395     switch (uint_dypte->nbits()) {
396       case kBit8:
397         data = std::make_shared<UInt8Imm>(obj);
398         break;
399       case kBit16:
400         data = std::make_shared<UInt16Imm>(obj);
401         break;
402       case kBit32:
403         data = std::make_shared<UInt32Imm>(obj);
404         break;
405       case kBit64:
406         data = std::make_shared<UInt64Imm>(obj);
407         break;
408       default:
409         data = std::make_shared<UInt32Imm>(obj);
410     }
411     return data;
412   }
413 
414   auto float_dypte = dyn_cast<Float>(dtype);
415   if (float_dypte != nullptr) {
416     switch (float_dypte->nbits()) {
417       case kBit32:
418         data = std::make_shared<FP32Imm>(obj);
419         break;
420       case kBit64:
421         data = std::make_shared<FP64Imm>(obj);
422         break;
423       default:
424         data = std::make_shared<FP32Imm>(obj);
425     }
426     return data;
427   }
428   return nullptr;
429 }
430 
ConvertIntegerWithType(const py::object & obj,const TypePtr & dtype=nullptr)431 ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
432   auto obj_int64 = py::cast<int64_t>(obj);
433   if (dtype == nullptr) {
434     return std::make_shared<Int64Imm>(obj_int64);
435   }
436   return ConvertNumberWithType<int64_t>(obj_int64, dtype);
437 }
438 
ConvertFloatWithType(const py::object & obj,const TypePtr & dtype=nullptr)439 ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
440   auto obj_float64 = py::cast<float>(obj);
441   if (dtype == nullptr) {
442     return std::make_shared<FP32Imm>(obj_float64);
443   }
444   return ConvertNumberWithType<float>(obj_float64, dtype);
445 }
446 
447 template <typename T, typename U>
PyCast(const py::object & obj)448 ValuePtr PyCast(const py::object &obj) {
449   return std::make_shared<T>(py::cast<U>(obj));
450 }
451 
452 template <typename T>
ObjCast(const py::object & obj)453 ValuePtr ObjCast(const py::object &obj) {
454   return obj.cast<T>();
455 }
456 
GetDataConverters()457 std::vector<DataConverterPtr> GetDataConverters() {
458   static std::vector<DataConverterPtr> data_converters = {
459     // Convert data by python object type.
460     std::make_shared<ByTypeDataConverter<py::none>>(kNone),
461     std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
462     std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
463     std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
464     std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace),
465     std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
466     std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
467     std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
468     std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
469     std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
470     std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
471     std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
472     std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
473                                           [](const py::object &obj) -> ValuePtr {
474                                             auto res =
475                                               std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
476                                             MS_LOG(DEBUG) << "name_space: " << res->ToString();
477                                             return res;
478                                           }),
479     std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
480     std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
481     std::make_shared<ByTypeDataConverter<py::dict>>(ConvertDict),
482     std::make_shared<ByTypeDataConverter<py::slice>>(ConvertSlice),
483     std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
484     std::make_shared<ByAttrDataConverter>(PYTHON_CELL_AS_LIST, ConvertCellList),
485     std::make_shared<ByTypeDataConverter<Cell>>(ConvertCellObjToFuncGraph),
486     std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
487     std::make_shared<ByAttrDataConverter>(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive),
488     std::make_shared<ByTypeDataConverter<MetaFuncGraph>>(ConvertMetaFuncGraph),
489     std::make_shared<ByTypeDataConverter<FuncGraph>>(ConvertFuncGraph),
490   };
491   return data_converters;
492 }
493 }  // namespace
494 
ConvertData(const py::object & obj,ValuePtr * const data,bool use_signature,const TypePtr & dtype)495 bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
496   // Check parameter valid
497   if (data == nullptr) {
498     MS_LOG(ERROR) << "Data is null pointer";
499     return false;
500   }
501   ValuePtr converted = nullptr;
502   bool matched = false;
503   auto converters = GetDataConverters();
504   for (auto &converter : converters) {
505     if (converter->Matched(obj)) {
506       converted = converter->ConvertPyObject(obj, use_signature, dtype);
507       matched = true;
508       break;
509     }
510   }
511   if (!matched) {
512     converted = ConvertOtherObj(obj);
513   }
514   *data = converted;
515   return converted != nullptr;
516 }
517 
518 // Convert data to graph
ConvertToFuncGraph(const py::object & obj,const std::string & python_mod_get_parse_method)519 FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
520   std::vector<std::string> results = data_converter::GetObjKey(obj);
521   std::string obj_id = results[0] + python_mod_get_parse_method;
522   std::string obj_key = results[1];
523   FuncGraphPtr func_graph = nullptr;
524   ValuePtr value = nullptr;
525   bool is_cache = data_converter::GetObjectValue(obj_id, &value);
526   if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
527     MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
528     func_graph = value->cast<FuncGraphPtr>();
529     if (!func_graph->dropped()) {
530       return func_graph;
531     }
532   }
533 
534   func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
535   if (func_graph == nullptr) {
536     MS_LOG(ERROR) << "Parse resolve function error.";
537     return nullptr;
538   }
539 
540   data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
541   data_converter::CacheObjectValue(obj_id, func_graph);
542   if (!obj_key.empty()) {
543     MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
544     data_converter::SetObjGraphValue(obj_key, func_graph);
545   }
546 
547   return func_graph;
548 }
549 namespace data_converter {
550 static std::unordered_map<std::string, ValuePtr> object_map_;
551 
552 static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
553 
SetObjGraphValue(const std::string & obj_key,const FuncGraphPtr & data)554 void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
555   object_graphs_map_[obj_key].push_back(data);
556   MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
557 }
558 
GetObjGraphs()559 const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
560   MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
561   return object_graphs_map_;
562 }
563 
CacheObjectValue(const std::string & obj_key,const ValuePtr & data)564 void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
GetObjectValue(const std::string & obj_key,ValuePtr * const data)565 bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
566   if (object_map_.count(obj_key)) {
567     *data = object_map_[obj_key];
568     return true;
569   }
570   return false;
571 }
GetObjKey(const py::object & obj)572 std::vector<std::string> GetObjKey(const py::object &obj) {
573   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
574   py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
575   if (obj_tuple.size() != 2) {
576     MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
577   }
578   return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
579 }
580 
581 // Get obj detail type
GetObjType(const py::object & obj)582 ResolveTypeDef GetObjType(const py::object &obj) {
583   try {
584     py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
585     auto obj_type =
586       ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
587     return obj_type;
588   } catch (const py::error_already_set &ex) {
589     MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
590     std::rethrow_exception(std::current_exception());
591   } catch (const py::type_error &ex) {
592     MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
593     std::rethrow_exception(std::current_exception());
594   }
595 }
596 
597 // Get class instance detail type.
GetClassInstanceType(const py::object & obj)598 ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
599   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
600   auto class_type =
601     ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
602   return class_type;
603 }
604 
605 // Check the object is Cell Instance.
IsCellInstance(const py::object & obj)606 bool IsCellInstance(const py::object &obj) {
607   auto class_type = GetClassInstanceType(obj);
608   bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
609   return isCell;
610 }
611 
612 // Create the python class instance.
CreatePythonObject(const py::object & type,const py::tuple & args_kwargs)613 py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
614   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
615   // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
616   return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
617                              : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
618 }
619 
620 // Call the python script string.
CallPythonScript(const py::object & script,const py::tuple & args_kwargs)621 py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
622   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
623   // `args_kwargs` is a tuple(dict(global), dict(local)).
624   return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script)
625                              : python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
626 }
627 
628 // Generate an appropriate name and set to graph debuginfo,
629 // character <> can not used in the dot file, so change to another symbol.
MakeProperNameToFuncGraph(const FuncGraphPtr & func_graph,std::string name)630 void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
631   MS_EXCEPTION_IF_NULL(func_graph);
632   MS_EXCEPTION_IF_NULL(func_graph->debug_info());
633   // Set detail name info of function
634   std::ostringstream oss;
635   for (size_t i = 0; i < name.size(); i++) {
636     if (name[i] == '<') {
637       oss << "「";
638     } else if (name[i] == '>') {
639       oss << "」";
640     } else {
641       oss << name[i];
642     }
643   }
644   func_graph->debug_info()->set_full_name(oss.str());
645 }
646 
PyDataToValue(const py::object & obj)647 ValuePtr PyDataToValue(const py::object &obj) {
648   py::object to_convert = obj;
649   ValuePtr value = nullptr;
650   (void)ConvertData(to_convert, &value);
651   return value;
652 }
653 
ClearObjectCache()654 void ClearObjectCache() {
655   object_map_.clear();
656   object_graphs_map_.clear();
657 }
658 }  // namespace data_converter
659 
660 static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
661 
662 // Parse dataclass to mindspore Class type
ParseDataClass(const py::object & cls_obj)663 ClassPtr ParseDataClass(const py::object &cls_obj) {
664   std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
665   std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
666   std::string cls = cls_module + "." + cls_name;
667   auto iterator = g_dataClassToClass.find(cls);
668   if (iterator != g_dataClassToClass.end()) {
669     return iterator->second;
670   }
671 
672   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
673   ClassAttrVector attributes;
674   py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
675   for (auto &item : names) {
676     auto type_value = item.second.cast<TypePtr>();
677     MS_EXCEPTION_IF_NULL(type_value);
678     MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
679     attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
680   }
681 
682   std::unordered_map<std::string, ValuePtr> methods_map;
683   py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
684   for (auto &item : methods) {
685     auto fun_name = item.first.cast<std::string>();
686     auto obj = py::cast<py::object>(item.second);
687     std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
688     methods_map[fun_name] = method_obj;
689   }
690 
691   std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
692   // static Variable for cache
693   // cppcheck-suppress unreadVariable
694   g_dataClassToClass[cls] = me_class;
695 
696   return me_class;
697 }
698 
CleanDataClassToClassMap()699 void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
700 }  // namespace parse
701 }  // namespace mindspore
702