• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/ps/parse/resolve.h"
18 
19 #include <utility>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <algorithm>
24 #include <unordered_map>
25 
26 #include "mindspore/core/ops/structure_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "ir/param_info.h"
30 #include "ir/value.h"
31 #include "ir/map_tensor.h"
32 #include "pipeline/jit/ps/fallback.h"
33 #include "pipeline/jit/ps/parse/data_converter.h"
34 #include "pipeline/jit/ps/parse/parse.h"
35 #include "include/common/utils/python_adapter.h"
36 #include "include/common/utils/parallel_context.h"
37 #include "utils/any.h"
38 #include "frontend/operator/ops.h"
39 #include "frontend/optimizer/opt.h"
40 #include "frontend/optimizer/irpass.h"
41 #include "frontend/optimizer/irpass/symbol_resolver.h"
42 #include "include/common/fallback.h"
43 #include "include/common/debug/anf_dump_utils.h"
44 #include "utils/log_adapter.h"
45 
46 namespace mindspore {
47 namespace parse {
48 static std::unordered_map<std::string, std::string> param_obj_ids;  // param_name : obj_id
CleanParameterNameCache()49 void CleanParameterNameCache() {
50   MS_LOG(DEBUG) << "Clean parameter name cache.";
51   param_obj_ids.clear();
52 }
53 namespace {
// Return a copy of `str` where each '<' is replaced by U+23A1 (⎡) and each '>' by U+23A6 (⎦).
// All other characters are copied through unchanged.
std::string ReplaceSpecialChar(const std::string &str) {
  std::string replaced;
  replaced.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
      case '<':
        replaced += "\u23A1";
        break;
      case '>':
        replaced += "\u23A6";
        break;
      default:
        replaced += ch;
        break;
    }
  }
  return replaced;
}
69 
70 struct AnfDumpHandlerRegister {
AnfDumpHandlerRegistermindspore::parse::__anoncfc1baa10111::AnfDumpHandlerRegister71   AnfDumpHandlerRegister() {
72     AnfDumpHandler::SetValueNodeStrHandler([](const std::shared_ptr<ValueNode> &node) -> std::string {
73       if (node == nullptr) {
74         return "";
75       }
76       if (IsValueNode<MetaFuncGraph>(node)) {
77         return node->value()->cast<MetaFuncGraphPtr>()->name();
78       } else if (IsValueNode<parse::NameSpace>(node)) {
79         return node->value()->cast<parse::NameSpacePtr>()->name();
80       } else if (IsValueNode<parse::Symbol>(node)) {
81         return ReplaceSpecialChar(node->value()->cast<parse::SymbolPtr>()->name());
82       }
83       return "";
84     });
85   }
86 } callback_register;
87 }  // namespace
88 
InterpretedObject(const py::object & obj)89 InterpretedObject::InterpretedObject(const py::object &obj) : PyObjectWrapper(obj) {
90   std::stringstream buf;
91   auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
92   buf << "PythonObject(type: " << std::string(py::str(type_str)) << ", value: " << std::string(py::str(obj)) << ")";
93   this->set_name(buf.str());
94 }
95 
ToAbstract()96 abstract::AbstractBasePtr MsClassObject::ToAbstract() {
97   py::gil_scoped_acquire acquire;
98   bool is_class_type = parse::data_converter::IsClassType(obj());
99   if (is_class_type) {
100     // Class type as func, such as Net(x, y)
101     auto abs_class = std::make_shared<abstract::AbstractClass>(shared_from_base<MsClassObject>());
102     AbstractBasePtrList args_abs_list = {abs_class};
103     auto func = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
104     auto res_val = std::make_shared<abstract::PartialAbstractClosure>(func, args_abs_list);
105     res_val->set_value_desc(ToString());
106     return res_val;
107   } else {
108     // Class instance as func, such as net(x, y)
109     return std::make_shared<abstract::AbstractClass>(shared_from_base<MsClassObject>());
110   }
111 }
112 
IsSupportedCreateInstanceType(const py::object & obj)113 static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
114   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
115   auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
116   if (!py::isinstance<py::bool_>(res)) {
117     MS_LOG(ERROR) << "Expect a bool type, but got " << py::str(res);
118     return false;
119   }
120   return res.cast<bool>();
121 }
122 
ToAbstract()123 abstract::AbstractBasePtr ClassType::ToAbstract() {
124   py::gil_scoped_acquire acquire;
125   auto abs_scalar =
126     std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
127 
128   if (!IsSupportedCreateInstanceType(obj())) {
129     return abs_scalar;
130   }
131   AbstractBasePtrList args_abs_list = {abs_scalar};
132 
133   auto func = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
134   auto res_val = std::make_shared<abstract::PartialAbstractClosure>(func, args_abs_list);
135   res_val->set_value_desc(ToString());
136   return res_val;
137 }
138 
139 using tensor::MapTensorPtr;
140 // Get parameter value from a python parameter object.
141 // If it is a map parameter, return the map tensor value in it,
142 // otherwise, return parameter itself as a meta tensor value.
GetParameterValue(const py::object & param_obj)143 ValuePtr GetParameterValue(const py::object &param_obj) {
144   constexpr char attr_map_tensor[] = "_map_tensor";
145   constexpr char attr_param_info[] = "param_info";
146   if (py::hasattr(param_obj, attr_map_tensor)) {
147     auto map_tensor = py::cast<MapTensorPtr>(python_adapter::GetPyObjAttr(param_obj, attr_map_tensor));
148     MS_EXCEPTION_IF_NULL(map_tensor);
149     auto param_info = py::cast<ParamInfoPtr>(python_adapter::GetPyObjAttr(param_obj, attr_param_info));
150     MS_EXCEPTION_IF_NULL(param_info);
151     map_tensor->set_param_info(param_info);
152     return map_tensor;
153   }
154   return py::cast<tensor::MetaTensorPtr>(param_obj);
155 }
156 
157 namespace {
GetPyObjId(const py::object & obj)158 std::string GetPyObjId(const py::object &obj) {
159   py::object out = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
160   if (py::isinstance<py::none>(out)) {
161     MS_LOG(INTERNAL_EXCEPTION) << "Get pyobj failed";
162   }
163   return out.cast<std::string>();
164 }
165 
ClearCNodeAbstract(const FuncGraphPtr & func_graph)166 void ClearCNodeAbstract(const FuncGraphPtr &func_graph) {
167   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
168   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
169   for (const auto &node : nodes) {
170     if (node == nullptr || node->isa<Parameter>()) {
171       continue;
172     }
173     auto primitive = GetCNodePrimitive(node);
174     if (primitive != nullptr) {
175       auto is_load = primitive->GetAttr("is_load");
176       if (abstract::GetPrimEvaluator(primitive, nullptr) == nullptr && is_load != nullptr && GetValue<bool>(is_load)) {
177         MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
178         continue;
179       }
180     }
181     auto prev_inferred = node->abstract();
182     // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
183     if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
184       // Reset tuple/list abstract use flags.
185       if (enable_eliminate_unused_element && prev_inferred != nullptr &&
186           prev_inferred->isa<abstract::AbstractSequence>()) {
187         SetSequenceNodeElementsUseFlags(node, nullptr);
188       }
189       node->set_abstract(nullptr);
190       MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
191     }
192   }
193 }
194 
ConvertLoadedGraph(const FuncGraphPtr & func_graph,const ValuePtr & value)195 void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
196   if (!value->isa<FuncGraph>()) {
197     return;
198   }
199   auto resolved_graph = value->cast<FuncGraphPtr>();
200   MS_EXCEPTION_IF_NULL(resolved_graph);
201   if (!resolved_graph->has_attr("is_load")) {
202     return;
203   }
204   auto top_graph = Parser::GetTopFuncGraph();
205   std::vector<AnfNodePtr> input_params;
206   auto resolved_graph_count = resolved_graph->fv_param_count();
207   std::vector<ParameterPtr> drop_node_list;
208   for (auto const &param : resolved_graph->parameters()) {
209     auto param_ptr = dyn_cast<Parameter>(param);
210     MS_EXCEPTION_IF_NULL(param_ptr);
211     if (param_ptr->has_default()) {
212       param_ptr->set_func_graph(top_graph);
213       func_graph->add_parameter_obj_node(param_ptr);
214       // Update top_graph
215       top_graph->add_parameter(param_ptr);
216       size_t fv_param_count = top_graph->fv_param_count();
217       top_graph->set_fv_param_count(++fv_param_count);
218       (void)drop_node_list.emplace_back(param_ptr);
219       resolved_graph->set_fv_param_count(--resolved_graph_count);
220     } else {
221       input_params.push_back(param_ptr);
222     }
223   }
224   for (const auto &param_ptr : drop_node_list) {
225     resolved_graph->DropNode(param_ptr);
226   }
227   resolved_graph->set_parameters(input_params);
228   ClearCNodeAbstract(resolved_graph);
229 }
230 
HasConstArgAttr(const py::object & obj)231 bool HasConstArgAttr(const py::object &obj) {
232   constexpr char const_arg_attr[] = "const_arg";
233   return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
234 }
235 
HasMutableAttr(const py::object & obj)236 bool HasMutableAttr(const py::object &obj) {
237   constexpr char mutable_attr[] = "__ms_mutable__";
238   return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
239 }
240 
HasVariableLenAttr(const py::object & obj)241 bool HasVariableLenAttr(const py::object &obj) {
242   constexpr char variable_len_attr[] = "__ms_dynamic_len__";
243   return py::hasattr(obj, variable_len_attr) && py::cast<bool>(py::getattr(obj, variable_len_attr));
244 }
245 
// When `convert_result` is an InterpretedObject and `origin_node` is not flagged as an interpreted local value,
// build a fallback node for it in `func_graph`:
//   * PyExecute if the object has already been converted;
//   * PyInterpret otherwise.
// Returns nullptr when no conversion applies.
AnfNodePtr ConvertInterpretedObjForResolve(const AnfNodePtr &origin_node, const ValuePtr &convert_result,
                                           const FuncGraphPtr &func_graph) {
  if (!convert_result->isa<InterpretedObject>() || origin_node->has_user_data("__py_interpret_local_value_flag__")) {
    return nullptr;
  }
  constexpr auto recursive_level = 2;
  MS_LOG(DEBUG) << "Convert InterpretedObj for resolve, node: " << origin_node->DebugString(recursive_level);
  auto interpreted = dyn_cast<InterpretedObject>(convert_result);
  const auto &key = interpreted->name();
  if (!interpreted->has_converted()) {
    return fallback::ConvertPyObjectToPyInterpret(func_graph, key, interpreted->obj(), origin_node, true);
  }
  return fallback::ConvertPyObjectToPyExecute(func_graph, key, interpreted->obj(), origin_node, true);
}
260 
ConvertObjectToNode(const AnfNodePtr & origin_node,const py::object & obj,const FuncGraphPtr & func_graph,bool is_element_obj)261 AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph,
262                                bool is_element_obj) {
263   // When the cell is set recomputed, it should not use old scope from cache.
264   MS_EXCEPTION_IF_NULL(origin_node);
265   auto origin_cnode = dyn_cast<CNode>(origin_node);
266   MS_EXCEPTION_IF_NULL(origin_cnode);
267   bool is_resolve = IsPrimitiveCNode(origin_node, prim::kPrimResolve);
268   auto scope = origin_node->scope();
269   bool has_recompute_scope =
270     (scope != nullptr && scope->name().compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0);
271   ValuePtr convert_result = nullptr;
272   constexpr auto resolve_with_args_inputs_size = 4;
273   MS_LOG(DEBUG) << "origin_cnode: " << origin_cnode->DebugString();
274   if (is_resolve && origin_cnode->size() == resolve_with_args_inputs_size) {  // (resolve, namespace, symbol, arguments)
275     constexpr auto args_input_pos = 3;
276     auto args_node = origin_cnode->input(args_input_pos);
277     auto args_value = GetValueNode<ValueTuplePtr>(args_node);
278     MS_EXCEPTION_IF_NULL(args_value);
279     parse::DataConverter data_converter(args_value->value(), python_adapter::UseSignatureInResolve());
280     convert_result = data_converter.ConvertData(obj);
281     if (convert_result == nullptr) {
282       MS_LOG(INTERNAL_EXCEPTION) << "Convert error with Python object: " << std::string(py::str(obj));
283     }
284   } else {  // (resolve/getattr, namespace, symbol, optional[getattr])
285     bool converted =
286       ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
287     if (!converted) {
288       MS_LOG(ERROR) << "Convert data failed";
289       return nullptr;
290     }
291   }
292 
293   // If obj is an element, do not convert InterpretedObj.
294   if (!is_element_obj) {
295     AnfNodePtr interpreted_output = ConvertInterpretedObjForResolve(origin_node, convert_result, func_graph);
296     if (interpreted_output != nullptr) {
297       return interpreted_output;
298     }
299   }
300 
301   if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
302     UpdateRecomputeScope(convert_result->cast<FuncGraphPtr>());
303   }
304   ConvertLoadedGraph(func_graph, convert_result);
305   AnfNodePtr output = NewValueNode(convert_result);
306   if (convert_result->isa<tensor::Tensor>()) {
307     output = GetMixedPrecisionCastHelp(func_graph, output);
308     if (HasConstArgAttr(obj)) {
309       MS_LOG(WARNING) << "The tensor " << convert_result->ToString()
310                       << " which is not used for network input argument should not be set const.";
311     }
312   }
313   if (HasMutableAttr(obj)) {
314     auto dynamic_len = HasVariableLenAttr(obj);
315     output = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimMutable), output, NewValueNode(dynamic_len)});
316   }
317   return output;
318 }
319 
TransformFuncValueNode(const FuncGraphManagerPtr & manager,const FuncGraphPtr & func_graph,const ValuePtr & value)320 AnfNodePtr TransformFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
321                                   const ValuePtr &value) {
322   MS_EXCEPTION_IF_NULL(value);
323   if (value->isa<FuncGraph>()) {
324     auto fg = value->cast<FuncGraphPtr>();
325     manager->AddFuncGraph(fg);
326     return NewValueNode(fg);
327   }
328   if (value->isa<Primitive>()) {
329     return NewValueNode(value);
330   }
331   // (1) The CellList or CellDict will be parsed as value_sequence or value_dict of const graph in it,
332   // So if there is graph in list, try to replace the node with make_tuple or make_dict of graph value node.
333   // We do this because the graph manager won't investigate the graph inside value_sequence or value_dict,
334   // change the vector of graph to be make_tuple or make_dict of graph value node.
335   // (2) the primitive value_tuple or value_sequence or value_dict may encounter to abstract error, make it all
336   // independent nodes.
337   if (value->isa<ValueSequence>()) {
338     std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
339     bool is_all_func = true;
340     auto value_sequence = value->cast<ValueSequencePtr>();
341     if (value_sequence->size() == 0) {
342       return nullptr;
343     }
344     for (auto &elem : value_sequence->value()) {
345       auto node = TransformFuncValueNode(manager, func_graph, elem);
346       if (node == nullptr) {
347         is_all_func = false;
348       }
349       (void)inputs.emplace_back(node);
350     }
351     if (is_all_func) {
352       return func_graph->NewCNode(std::move(inputs));
353     }
354     return nullptr;
355   }
356   if (value->isa<ValueDictionary>()) {
357     std::vector<AnfNodePtr> keys{NewValueNode(prim::kPrimMakeTuple)};
358     std::vector<AnfNodePtr> values{NewValueNode(prim::kPrimMakeTuple)};
359     bool is_all_func = true;
360     for (auto &elem : value->cast<ValueDictionaryPtr>()->value()) {
361       (void)keys.emplace_back(NewValueNode(elem.first));
362       auto node = TransformFuncValueNode(manager, func_graph, elem.second);
363       if (node == nullptr) {
364         is_all_func = false;
365       }
366       (void)values.emplace_back(node);
367     }
368     if (is_all_func) {
369       return func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(keys)),
370                                    func_graph->NewCNode(std::move(values))});
371     }
372     return nullptr;
373   }
374 
375   return nullptr;
376 }
377 
378 // Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
ResolveObjectAndAddToManager(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & node)379 AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
380                                         const AnfNodePtr &node) {
381   MS_EXCEPTION_IF_NULL(node);
382   ScopeGuard scope_guard(node->scope());
383   AnfNodePtr resolved_node = nullptr;
384   bool success = ResolveObjectToNode(node, obj, &resolved_node);
385   if (!success) {
386     MS_LOG(INTERNAL_EXCEPTION) << "Parse Resolve covert failed.";
387   }
388   if (IsValueNode<FuncGraph>(resolved_node)) {
389     auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
390     auto fg = node->func_graph();
391     MS_EXCEPTION_IF_NULL(fg);
392     // If it's the sub func graph resolved in a reserved func graph.
393     if (fg->reserved()) {
394       new_fg->set_reserved(true);
395     }
396     manager->AddFuncGraph(new_fg);
397   }
398 
399   // If the constant node is constant of vector of graph, add graph to manager.
400   if (IsValueNode<ValueSequence>(resolved_node) || IsValueNode<ValueDictionary>(resolved_node)) {
401     auto value = resolved_node->cast<ValueNodePtr>()->value();
402     auto new_node = TransformFuncValueNode(manager, node->func_graph(), value);
403     if (new_node != nullptr) {
404       resolved_node = new_node;
405     }
406   }
407   fallback::SetPyObjectToNode(resolved_node, obj);
408   return resolved_node;
409 }
410 
IsParameterObject(const py::object & obj)411 bool IsParameterObject(const py::object &obj) {
412   return py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj);
413 }
414 
ContainsParameter(const py::object & obj)415 bool ContainsParameter(const py::object &obj) {
416   if (IsParameterObject(obj) || py::hasattr(obj, "__parameter_tuple__")) {
417     return true;
418   }
419   if ((py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) && py::len(obj) != 0) {
420     // NamedTuple
421     if (py::hasattr(obj, "_fields")) {
422       return false;
423     }
424     auto tuple = obj.cast<py::tuple>();
425     for (size_t i = 0; i < tuple.size(); ++i) {
426       if (ContainsParameter(tuple[i])) {
427         return true;
428       }
429     }
430   } else if (py::isinstance<py::dict>(obj)) {
431     auto dict = obj.cast<py::dict>();
432     for (auto item : dict) {
433       auto item_value = py::cast<py::object>(item.second);
434       if (ContainsParameter(item_value)) {
435         return true;
436       }
437     }
438   }
439   return false;
440 }
441 }  // namespace
442 
ResolveObjectToNode(const AnfNodePtr & origin_node,const py::object & obj,AnfNodePtr * const node,bool is_element_obj)443 bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, AnfNodePtr *const node,
444                          bool is_element_obj) {
445   MS_EXCEPTION_IF_NULL(origin_node);
446   auto func_graph = origin_node->func_graph();
447   MS_EXCEPTION_IF_NULL(func_graph);
448   if (!ContainsParameter(obj)) {
449     auto output = ConvertObjectToNode(origin_node, obj, func_graph, is_element_obj);
450     if (output == nullptr) {
451       return false;
452     }
453     *node = output;
454     return true;
455   }
456   if (IsParameterObject(obj)) {
457     auto param = ResolveParameterObj(func_graph, obj);
458     if (param == nullptr) {
459       MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
460       return false;
461     }
462     MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
463     *node = param;
464     return true;
465   }
466   if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj) || py::hasattr(obj, "__parameter_tuple__")) {
467     bool all_parameter_sequence = true;
468     std::vector<AnfNodePtr> args;
469     auto tuple = obj.cast<py::tuple>();
470     for (size_t i = 0; i < tuple.size(); ++i) {
471       if (!IsParameterObject(tuple[i])) {
472         all_parameter_sequence = false;
473       }
474       AnfNodePtr out = nullptr;
475       bool success = ResolveObjectToNode(origin_node, tuple[i], &out, true);
476       if (!success) {
477         MS_LOG(ERROR) << "Resolve object to node failed";
478         return false;
479       }
480       args.push_back(out);
481     }
482     // Convert [param1, param2, ..., paramN] to tuple.
483     bool need_convert_to_tuple = !is_element_obj && all_parameter_sequence && py::isinstance<py::list>(obj);
484     if (py::isinstance<py::tuple>(obj) || py::hasattr(obj, "__parameter_tuple__") || need_convert_to_tuple) {
485       (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
486     } else {
487       (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeList));
488     }
489     // The ParameterTuple/tuple/list will not be added in order list,
490     // since we don't want to deal with its RefTensor elements during auto_monad procedure.
491     *node = NewCNode(std::move(args), func_graph);
492     return true;
493   }
494   if (py::isinstance<py::dict>(obj)) {
495     auto dict = obj.cast<py::dict>();
496     std::vector<AnfNodePtr> keys_tuple{NewValueNode(prim::kPrimMakeTuple)};
497     std::vector<AnfNodePtr> values_tuple{NewValueNode(prim::kPrimMakeTuple)};
498     for (auto item : dict) {
499       AnfNodePtr key = nullptr;
500       AnfNodePtr value = nullptr;
501       bool success = ResolveObjectToNode(origin_node, py::cast<py::object>(item.first), &key, true) &&
502                      ResolveObjectToNode(origin_node, py::cast<py::object>(item.second), &value, true);
503       if (!success) {
504         MS_LOG(ERROR) << "Resolve object to node failed";
505         return false;
506       }
507       (void)keys_tuple.emplace_back(key);
508       (void)values_tuple.emplace_back(value);
509     }
510     *node = func_graph->NewCNode(
511       {NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(keys_tuple), func_graph->NewCNode(values_tuple)});
512     return true;
513   }
514   MS_EXCEPTION(TypeError) << "The Parameter in obj '" << py::str(obj) << "' with nested structure is not supported."
515                           << "\nCurrently only single Parameter, ParameterTuple or Parameters in tuple/list/dict "
516                              "are supported. Or do you want to use Tensor instead?";
517 }
518 
GetNamespaceAndSymbol(const AnfNodePtr & node)519 std::pair<NameSpacePtr, SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
520   MS_EXCEPTION_IF_NULL(node);
521   if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
522     auto resolve_cnode = node->cast<CNodePtr>();
523     constexpr size_t namespace_index = 1;
524     auto namespace_node = resolve_cnode->input(namespace_index);
525     constexpr size_t symbol_index = 2;
526     auto symbol_node = resolve_cnode->input(symbol_index);
527     if (!IsValueNode<NameSpace>(namespace_node) || !IsValueNode<Symbol>(symbol_node)) {
528       MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
529                         << ", symbol: " << symbol_node->ToString();
530     }
531     // Deal with the case of GetAttr from a class member,
532     // and avoid the case of GetAttr from self (the result of ParseSuper).
533     auto name_space = GetValueNode<NameSpacePtr>(namespace_node);
534     auto symbol = GetValueNode<SymbolPtr>(symbol_node);
535     return {name_space, symbol};
536   }
537   constexpr auto recursive_level = 2;
538   MS_LOG(INTERNAL_EXCEPTION) << "It's not prim::Resolve CNode, node: " << node->DebugString(recursive_level);
539 }
540 
GetSymbolObject(const NameSpacePtr & name_space,const SymbolPtr & symbol,const AnfNodePtr & node)541 py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) {
542   MS_EXCEPTION_IF_NULL(node);
543   if (node->func_graph() == nullptr) {
544     MS_LOG(INTERNAL_EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
545   }
546   if (name_space->module() == RESOLVE_NAMESPACE_NAME_ENTRY) {
547     return name_space->module_obj();
548   } else if (name_space->module() == RESOLVE_NAMESPACE_NAME_CLASS_OBJECT) {
549     MS_LOG(DEBUG) << "namespace: " << py::str(name_space->namespace_obj()) << ", symbol: " << symbol;
550     return name_space->namespace_obj();
551   }
552   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
553   auto &obj = name_space->namespace_obj();
554   if (py::isinstance<py::none>(obj)) {
555     MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
556   }
557   const auto &res =
558     python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol->symbol()));
559   MS_LOG(DEBUG) << "namespace: " << py::str(obj) << ", symbol: " << symbol << ", result: " << py::str(res);
560   return res;
561 }
562 
ResolveSymbol(const FuncGraphManagerPtr & manager,const NameSpacePtr & name_space,const SymbolPtr & symbol,const AnfNodePtr & node)563 AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
564                          const AnfNodePtr &node) {
565   MS_EXCEPTION_IF_NULL(node);
566   if (manager == nullptr) {
567     MS_LOG(INTERNAL_EXCEPTION) << "Manager is nullptr.";
568   }
569   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString()
570                 << ", loc: " << trace::GetDebugInfoStr(node->debug_info());
571   TraceGuard trace_guard(std::make_shared<TraceResolve>(trace::GetSourceCodeDebugInfo(node->debug_info())));
572   auto obj = GetSymbolObject(name_space, symbol, node);
573   AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
574   if (IsValueNode<NameSpace>(resolved_node) && !py::isinstance<py::none>(name_space->module_obj())) {
575     auto name_value = GetValueNode(resolved_node);
576     auto nameptr = name_value->cast<NameSpacePtr>();
577     nameptr->set_module_obj(name_space->module_obj());
578   }
579   fallback::SetPyObjectToNode(resolved_node, obj);
580   // Update top graph debug info with user top graph's
581   if (name_space->module() == RESOLVE_NAMESPACE_NAME_ENTRY && IsValueNode<FuncGraph>(resolved_node)) {
582     auto user_top_fg = GetValueNode<FuncGraphPtr>(resolved_node);
583     MS_EXCEPTION_IF_NULL(user_top_fg);
584     auto top_fg = node->func_graph();
585     MS_EXCEPTION_IF_NULL(top_fg);
586     top_fg->set_debug_info(user_top_fg->debug_info());
587     top_fg->return_node()->set_debug_info(user_top_fg->return_node()->debug_info());
588     MS_LOG(DEBUG) << "Update top graph's and node's debug infos with user top graph's. top_fg: " << top_fg->ToString()
589                   << ", user_top_fg: " << user_top_fg->ToString();
590     top_fg->set_attrs(user_top_fg->attrs());
591     // Update top graph parameters' name
592     auto top_params = top_fg->parameters();
593     auto resolve_params = user_top_fg->parameters();
594     auto top_arg_size = top_fg->GetPositionalArgsCount();
595     auto user_arg_size = user_top_fg->GetPositionalArgsCount();
596     if (top_arg_size > user_arg_size) {
597       MS_LOG(INFO) << "Top graph's parameter size: " << top_arg_size
598                    << " should not be greater than resolved func_graph's parameter size: " << user_arg_size;
599     } else {
600       for (int i = 0; i < top_arg_size; i++) {
601         auto param_ptr = top_params[i]->cast<ParameterPtr>();
602         MS_EXCEPTION_IF_NULL(param_ptr);
603         auto user_param_ptr = resolve_params[i]->cast<ParameterPtr>();
604         MS_EXCEPTION_IF_NULL(user_param_ptr);
605         param_ptr->set_debug_info(user_param_ptr->debug_info());
606         param_ptr->set_name(user_param_ptr->name());
607       }
608       MS_LOG(DEBUG) << "Update top graph's parameters debug info with user top graph's parameters";
609     }
610   }
611   return resolved_node;
612 }
613 
CreateResolveNode(const py::object & obj,const AnfNodePtr & attr,const AnfNodePtr & get_attr_node)614 AnfNodePtr CreateResolveNode(const py::object &obj, const AnfNodePtr &attr, const AnfNodePtr &get_attr_node) {
615   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
616   py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
617   auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj, obj);
618   auto attr_string = GetValuePtr<StringImm>(attr);
619   MS_EXCEPTION_IF_NULL(attr_string);
620   const std::string &attr_as_string = attr_string->value();
621   auto new_symbol = std::make_shared<Symbol>(attr_as_string);
622   MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
623 
624   auto fg = get_attr_node->func_graph();
625   MS_EXCEPTION_IF_NULL(fg);
626   AnfNodePtr resolved_node =
627     fg->NewCNode({NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)});
628   resolved_node->set_debug_info(get_attr_node->debug_info());
629   fg->ReplaceInOrder(get_attr_node, resolved_node);
630   return resolved_node;
631 }
632 
633 // Resolve Cell GetAttr operation.
ResolveCellWithAttr(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & resolve_node,const AnfNodePtr & attr,const AnfNodePtr & get_attr_node)634 AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
635                                const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
636                                const AnfNodePtr &get_attr_node) {
637   MS_EXCEPTION_IF_NULL(resolve_node);
638   MS_EXCEPTION_IF_NULL(attr);
639   MS_EXCEPTION_IF_NULL(manager);
640   MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", attr: " << attr->ToString();
641   if (IsValueNode<StringImm>(attr)) {
642     const auto &attr_name = GetValue<std::string>(GetValueNode(attr));
643     py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
644     bool is_property =
645       (python_adapter::CallPyModFn(mod, parse::PYTHON_PARSE_CHECK_ATTR_IS_PROPERTY, obj, attr_name)).cast<bool>();
646     if (is_property) {
647       auto get_attr_cnode = get_attr_node->cast<CNodePtr>();
648       AnfNodePtr node = get_attr_cnode->input(1);
649       auto cur_func = get_attr_node->func_graph();
650       auto call_func_node = parse::TransPropertyToFunc(cur_func, node, obj, attr_name);
651       MS_LOG(DEBUG) << "call_func_node:" << call_func_node->DebugString();
652       return call_func_node;
653     }
654   }
655   TraceGuard trace_guard(std::make_shared<TraceResolve>(get_attr_node->debug_info()));
656   if (!data_converter::IsCellInstance(obj)) {
657     AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, resolve_node);
658     AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
659     auto cur_func = get_attr_node->func_graph();
660     MS_EXCEPTION_IF_NULL(cur_func);
661     AnfNodePtr res_node = cur_func->NewCNode(std::move(inputs));
662     res_node->set_debug_info(get_attr_node->debug_info());
663     cur_func->ReplaceInOrder(get_attr_node, res_node);
664     return res_node;
665   }
666 
667   constexpr auto tensors_queue_attr = "__is_tensors_queue__";
668   if (py::hasattr(obj, tensors_queue_attr) && IsValueNode<StringImm>(attr)) {
669     const auto &attr_name = GetValue<std::string>(GetValueNode(attr));
670     constexpr auto pop_attr = "pop";
671     if (attr_name == pop_attr) {
672       constexpr auto graph_pop_attr = "__graph_pop__";
673       MS_LOG(DEBUG) << "Replace " << pop_attr << " to " << graph_pop_attr << " for " << py::str(obj);
674       return CreateResolveNode(obj, NewValueNode(graph_pop_attr), get_attr_node);
675     }
676   }
677   return CreateResolveNode(obj, attr, get_attr_node);
678 }
679 
680 // Get attribute or method from ms_class obj or cell obj.
ResolveClassObjectWithAttr(const py::object & cls_obj,const AnfNodePtr & attr,const AnfNodePtr & get_attr_node)681 AnfNodePtr ResolveClassObjectWithAttr(const py::object &cls_obj, const AnfNodePtr &attr,
682                                       const AnfNodePtr &get_attr_node) {
683   MS_EXCEPTION_IF_NULL(get_attr_node);
684   MS_LOG(DEBUG) << "Resolve ms_class obj (" << py::str(cls_obj) << ") with attr " << attr->ToString() << ".";
685   TraceGuard trace_guard(std::make_shared<TraceResolve>(get_attr_node->debug_info()));
686   return CreateResolveNode(cls_obj, attr, get_attr_node);
687 }
688 
ResolveSequenceWithAttr(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & resolve_node,const AnfNodePtr & attr,const CNodePtr & get_attr_node)689 AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
690                                    const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
691                                    const CNodePtr &get_attr_node) {
692   MS_EXCEPTION_IF_NULL(get_attr_node);
693   std::vector<AnfNodePtr> inputs;
694   inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
695   auto sequence = obj.cast<py::sequence>();
696   // Incorporate if all elements of the sequence are Cell instances or MsClass instances.
697   size_t count_cell = 0;
698   size_t count_msclass = 0;
699   size_t sequence_size = sequence.size();
700   for (size_t i = 0; i < sequence_size; ++i) {
701     if (data_converter::IsCellInstance(sequence[i])) {
702       ++count_cell;
703     } else if (data_converter::IsMsClassInstance(sequence[i])) {
704       ++count_msclass;
705     }
706   }
707   if (count_cell == sequence_size) {
708     // Resolve Cell instances.
709     for (size_t i = 0; i < sequence_size; ++i) {
710       auto res = ResolveCellWithAttr(manager, sequence[i], resolve_node, attr, get_attr_node);
711       (void)inputs.emplace_back(res);
712     }
713   } else if (count_msclass == sequence_size) {
714     // Resolve MsClass instances.
715     for (size_t i = 0; i < sequence_size; ++i) {
716       auto res = ResolveClassObjectWithAttr(sequence[i], attr, get_attr_node);
717       (void)inputs.emplace_back(res);
718     }
719   } else {
720     return nullptr;
721   }
722 
723   constexpr auto prim_index = 0;
724   constexpr auto index_index = 2;
725   auto fg = get_attr_node->func_graph();
726   MS_EXCEPTION_IF_NULL(fg);
727   auto make_tuple_node = fg->NewCNodeInOrder(inputs);
728   return fg->NewCNodeInOrder({get_attr_node->input(prim_index), make_tuple_node, get_attr_node->input(index_index)});
729 }
730 
// Resolve {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}.
//
// The symbol's Python object is looked up in its namespace. If the namespace is the class-member namespace, or
// the object is a Cell instance, the attribute is resolved with ResolveCellWithAttr. The Python object is then
// attached to the result as user data under the key "__getattr__".
//
// @param manager        Manager forwarded to ResolveCellWithAttr.
// @param object_node    The inner Resolve node.
// @param attr_node      The attribute node.
// @param get_attr_node  The GetAttr node being resolved.
// @return The resolved node, or nullptr when the symbol is named "namespace", or the object is neither a class
//         member nor a Cell instance.
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
                                 const AnfNodePtr &attr_node, const AnfNodePtr &get_attr_node) {
  auto [name_space, symbol] = GetNamespaceAndSymbol(object_node);
  MS_EXCEPTION_IF_NULL(name_space);
  MS_EXCEPTION_IF_NULL(symbol);
  // Attribute reads on the symbol named "namespace" are not resolved here.
  constexpr std::string_view parse_super_name = "namespace";
  if (symbol->symbol() == parse_super_name) {
    return nullptr;
  }
  const auto &module_name = name_space->module();
  auto symbol_obj = GetSymbolObject(name_space, symbol, get_attr_node);
  const bool is_class_member = module_name == RESOLVE_NAMESPACE_NAME_CLASS_MEMBER;
  if (!is_class_member && !data_converter::IsCellInstance(symbol_obj)) {
    return nullptr;
  }
  auto res = ResolveCellWithAttr(manager, symbol_obj, object_node, attr_node, get_attr_node);
  // res is dereferenced to attach user data; fail loudly instead of crashing on a null result.
  MS_EXCEPTION_IF_NULL(res);
  res->set_user_data<py::object>("__getattr__", std::make_shared<py::object>(symbol_obj));
  return res;
}
750 
751 // Get python object with index from a list or the whole list if the index is not fixed.
GetObjectFromSequence(const NameSpacePtr & name_space,const SymbolPtr & symbol,const AnfNodePtr & node,const AnfNodePtr & index_node)752 py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
753                                  const AnfNodePtr &index_node) {
754   MS_EXCEPTION_IF_NULL(node);
755   TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
756   py::object obj = GetSymbolObject(name_space, symbol, node);
757   // If obj is nn.CellList, convert it to sequence.
758   py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
759   bool is_cell_list = py::hasattr(obj, PYTHON_CELL_AS_LIST);
760   if (is_cell_list) {
761     obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE, obj);
762   }
763   if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
764     return py::none();
765   }
766 
767   MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
768   auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
769   if (imm_value == nullptr) {
770     MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
771                   << ", index_node: " << index_node->DebugString();
772     // Index is not fixed, return the whole list.
773     return obj;
774   }
775   // It index is a value node, get the item of index directly.
776   py::object item_obj =
777     python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_ITEM_FROM_SEQUENCE, obj, py::int_(imm_value->value()));
778   return item_obj;
779 }
780 
IsResolveNodeWithGetItem(const AnfNodePtr & node)781 bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
782   // Check if the node matches: {prim::kPrim::Resolve, ..., 'getitem'}.
783   if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
784     constexpr size_t symbol_index = 2;
785     constexpr auto getitem_symbol = "getitem";
786     auto cnode = node->cast<CNodePtr>();
787     auto symbol = GetValueNode<SymbolPtr>(cnode->input(symbol_index));
788     return symbol->symbol() == getitem_symbol;
789   }
790   return false;
791 }
792 
IsGetItemCNode(const AnfNodePtr & node)793 bool IsGetItemCNode(const AnfNodePtr &node) {
794   if (!node->isa<CNode>()) {
795     return false;
796   }
797   auto cnode = node->cast<CNodePtr>();
798   constexpr size_t getitem_inputs_size = 3;
799   if (cnode->size() != getitem_inputs_size) {
800     return false;
801   }
802   constexpr auto prim_index = 0;
803   return IsResolveNodeWithGetItem(cnode->input(prim_index));
804 }
805 
namespace {
// Resolve `attr_node` on the object selected by a getitem whose data input is the Resolve node `sequence_node`:
// a sequence goes through ResolveSequenceWithAttr, any other non-None object through ResolveCellWithAttr.
// Returns nullptr when GetObjectFromSequence yields None.
AnfNodePtr ResolveGetItemObjectWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &sequence_node,
                                        const AnfNodePtr &index_node, const AnfNodePtr &attr_node,
                                        const CNodePtr &getitem_cnode, const AnfNodePtr &node) {
  auto [name_space, symbol] = GetNamespaceAndSymbol(sequence_node);
  auto obj = GetObjectFromSequence(name_space, symbol, sequence_node, index_node);
  if (py::isinstance<py::none>(obj)) {
    return nullptr;
  }
  if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    return ResolveSequenceWithAttr(manager, obj, sequence_node, attr_node, getitem_cnode);
  }
  return ResolveCellWithAttr(manager, obj, sequence_node, attr_node, node);
}
}  // namespace

// Resolve an attribute read on an element of a sequence:
//   {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
//   {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
// In the second form the inner GetAttr is first resolved with ResolveSymbolWithAttr, and only a result that is
// itself a Resolve node is processed further.
//
// @param manager       Manager forwarded to the resolve helpers.
// @param getitem_node  The getitem CNode.
// @param attr_node     The attribute node of the outer GetAttr.
// @param node          The node being replaced (the outer GetAttr, per the patterns above); passed as
//                      get_attr_node to the helpers.
// @return The resolved node, or nullptr when the pattern cannot be resolved here.
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
                                  const AnfNodePtr &attr_node, const AnfNodePtr &node) {
  constexpr size_t data_index = 1;
  constexpr size_t index_index = 2;
  MS_EXCEPTION_IF_NULL(getitem_node);
  auto getitem_cnode = getitem_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(getitem_cnode);
  auto data_node = getitem_cnode->input(data_index);
  auto index_node = getitem_cnode->input(index_index);
  if (IsPrimitiveCNode(data_node, prim::kPrimResolve)) {
    return ResolveGetItemObjectWithAttr(manager, data_node, index_node, attr_node, getitem_cnode, node);
  }
  if (IsPrimitiveCNode(data_node, prim::kPrimGetAttr)) {
    auto getattr_cnode = data_node->cast<CNodePtr>();
    auto resolve_node = getattr_cnode->input(data_index);
    auto member_node = getattr_cnode->input(index_index);
    if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
      // Check if the result is a new resolve node.
      auto item_node = ResolveSymbolWithAttr(manager, resolve_node, member_node, node);
      if (IsPrimitiveCNode(item_node, prim::kPrimResolve)) {
        return ResolveGetItemObjectWithAttr(manager, item_node, index_node, attr_node, getitem_cnode, node);
      }
    }
  }
  return nullptr;
}
848 
// Build a SetAttr CNode whose target is the Python object behind a {prim::kPrimResolve, namespace, symbol} node.
//
// The target symbol's Python object is wrapped in an InterpretedObject and converted into a node with
// ConvertInterpretedObjForResolve. That node becomes the object input of
// {prim::kPrimSetAttr, object, attr_node, value_node}, which is created with NewCNodeInOrder in target_node's
// func graph.
//
// @param target_node  The Resolve node naming the object being assigned to.
// @param attr_node    The attribute node.
// @param value_node   The value being assigned.
// @return The new SetAttr CNode.
AnfNodePtr ResolveInterpretedObjectOfSetAttr(const AnfNodePtr &target_node, const AnfNodePtr &attr_node,
                                             const AnfNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(target_node);
  auto [name_space, symbol] = GetNamespaceAndSymbol(target_node);
  MS_EXCEPTION_IF_NULL(name_space);
  MS_EXCEPTION_IF_NULL(symbol);
  auto fg = target_node->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto symbol_obj = GetSymbolObject(name_space, symbol, target_node);
  // std::make_shared never returns nullptr, so no null check is needed.
  auto interpreted_obj = std::make_shared<InterpretedObject>(symbol_obj);
  MS_LOG(DEBUG) << "Created a interpreted object: " << interpreted_obj->ToString();
  const auto &resolve_node = ConvertInterpretedObjForResolve(target_node, interpreted_obj, fg);

  AnfNodePtrList inputs = {NewValueNode(prim::kPrimSetAttr), resolve_node, attr_node, value_node};
  return fg->NewCNodeInOrder(std::move(inputs));
}
863 
864 namespace {
GetOptResolvePasses(const opt::irpass::ResolveIRPassLib & irpass)865 opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
866   // For resolve and getattr primitive.
867   opt::OptPassGroupMap map({
868     {"resolve",
869      {
870        irpass.resolver_,
871      }},
872   });
873   return map;
874 }
875 }  // namespace
876 
ResolveFuncGraph(const FuncGraphPtr & func_graph,const pipeline::ResourceBasePtr & res,bool use_profile)877 bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) {
878   if (func_graph == nullptr || res == nullptr) {
879     MS_LOG(ERROR) << "func_graph or resource is null";
880     return false;
881   }
882   opt::irpass::ResolveIRPassLib irpass;
883   opt::OptimizerPtr opt_resolve =
884     opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
885 
886   (void)python_adapter::set_python_scoped();
887 
888   MS_EXCEPTION_IF_NULL(opt_resolve);
889   (void)opt_resolve->step(func_graph, use_profile);
890   return true;
891 }
892 
ResolveAll(const FuncGraphManagerPtr & manager)893 bool ResolveAll(const FuncGraphManagerPtr &manager) {
894   if (manager == nullptr) {
895     MS_LOG(ERROR) << "func graph manager is null";
896     return false;
897   }
898 
899   if (manager->roots().size() > 1) {
900     MS_LOG(WARNING)
901       << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs"
902          "called from root graph, so it's not necessary to pass all graphs as roots. "
903          "Please ensure your usage.";
904   }
905   // Should not use pipeline::Resource as Resource::Clean will clean some
906   // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
907   // fail as valid scope has been cleaned.
908   auto res = std::make_shared<pipeline::ResourceBase>();
909   res->set_manager(manager);
910 
911   auto roots = manager->roots();
912   for (const auto &fg : roots) {
913     bool ret = ResolveFuncGraph(fg, res, false);
914     if (!ret) {
915       MS_EXCEPTION_IF_NULL(fg);
916       MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed";
917     }
918   }
919   return true;
920 }
921 
922 // If any mixed precision flag add a cast node after the parameter node.
923 // argument obj should be python Parameter object
924 // it will be converted to Parameter node here
ResolveParameterObj(const FuncGraphPtr & func_graph,const py::object & obj)925 AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
926   MS_EXCEPTION_IF_NULL(func_graph);
927 
928   // Parameter object should not be none
929   if (py::isinstance<py::none>(obj)) {
930     MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
931   }
932 
933   if (!py::hasattr(obj, "name")) {
934     MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
935   }
936 
937   // Get the parameter name from parameter object
938   auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
939   if (py::isinstance<py::none>(name_attr)) {
940     MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
941   }
942   auto obj_id = GetPyObjId(obj);
943   auto param_name = py::cast<std::string>(name_attr);
944   auto top_func_graph = Parser::GetTopFuncGraph();
945   // If the parameter node has been created , return it.
946   ParameterPtr para_node = nullptr;
947   for (auto const &param : top_func_graph->parameters()) {
948     auto param_node = dyn_cast<Parameter>(param);
949     if (param_node != nullptr && param_node->name() == param_name) {
950       if (param_node->is_top_graph_param()) {
951         // If the name of the input of construct is same as the parameters,
952         // add suffix to the name of the input of construct.
953         string suffix_name = param_node->name() + "_$";
954         param_node->set_name(suffix_name);
955         param_node->debug_info()->set_name(suffix_name);
956         MS_LOG(DEBUG) << "Add suffix to the name of the input of construct " << func_graph->ToString()
957                       << ", input: " << param_node->DebugString();
958       } else {
959         // Exist two parameter object which name is the same.
960         auto iter = param_obj_ids.find(param_name);
961         if (iter != param_obj_ids.end() && iter->second != obj_id) {
962           MS_LOG(EXCEPTION)
963             << "The parameter " << param_node->DebugString() << " , its name '" << param_name
964             << "' already exists. Please set a unique name for the parameter."
965             << "\nFor more details with the name of parameter, please refer to "
966             << "https://mindspore.cn/search?inputValue=Please%20set%20a%20unique%20name%20for%20the%20parameter";
967         }
968         para_node = param_node;
969         MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
970                       << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
971         break;
972       }
973     }
974   }
975   if (para_node == nullptr) {
976     auto value = GetParameterValue(obj);
977     para_node = top_func_graph->AddFvParameter(param_name, value);
978     param_obj_ids[param_name] = obj_id;
979     MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
980                   << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
981     auto context = parallel::ParallelContext::GetInstance();
982     if (context != nullptr && para_node->has_default()) {
983       auto param_abs = pipeline::GetDefaultValueAbstract(para_node);
984       context->ParallelParameterContextRestoreShape(top_func_graph, para_node, param_abs);
985       para_node->set_abstract(param_abs);
986     }
987   }
988   func_graph->add_parameter_obj_node(para_node);
989   return para_node;
990 }
991 }  // namespace parse
992 }  // namespace mindspore
993