1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/ps/parse/resolve.h"
18
19 #include <utility>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <algorithm>
24 #include <unordered_map>
25
26 #include "mindspore/core/ops/structure_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "ir/param_info.h"
30 #include "ir/value.h"
31 #include "ir/map_tensor.h"
32 #include "pipeline/jit/ps/fallback.h"
33 #include "pipeline/jit/ps/parse/data_converter.h"
34 #include "pipeline/jit/ps/parse/parse.h"
35 #include "include/common/utils/python_adapter.h"
36 #include "include/common/utils/parallel_context.h"
37 #include "utils/any.h"
38 #include "frontend/operator/ops.h"
39 #include "frontend/optimizer/opt.h"
40 #include "frontend/optimizer/irpass.h"
41 #include "frontend/optimizer/irpass/symbol_resolver.h"
42 #include "include/common/fallback.h"
43 #include "include/common/debug/anf_dump_utils.h"
44 #include "utils/log_adapter.h"
45
46 namespace mindspore {
47 namespace parse {
48 static std::unordered_map<std::string, std::string> param_obj_ids; // param_name : obj_id
CleanParameterNameCache()49 void CleanParameterNameCache() {
50 MS_LOG(DEBUG) << "Clean parameter name cache.";
51 param_obj_ids.clear();
52 }
53 namespace {
// Returns a copy of `str` in which every '<' becomes U+23A1 (⎡) and every '>' becomes U+23A6 (⎦);
// all other characters are copied unchanged. Used when rendering Symbol names for ANF dumps.
std::string ReplaceSpecialChar(const std::string &str) {
  std::string replaced;
  replaced.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
      case '<':
        replaced += "\u23A1";  // ⎡
        break;
      case '>':
        replaced += "\u23A6";  // ⎦
        break;
      default:
        replaced += ch;
        break;
    }
  }
  return replaced;
}
69
70 struct AnfDumpHandlerRegister {
AnfDumpHandlerRegistermindspore::parse::__anon944f56ff0111::AnfDumpHandlerRegister71 AnfDumpHandlerRegister() {
72 AnfDumpHandler::SetValueNodeStrHandler([](const std::shared_ptr<ValueNode> &node) -> std::string {
73 if (node == nullptr) {
74 return "";
75 }
76 if (IsValueNode<MetaFuncGraph>(node)) {
77 return node->value()->cast<MetaFuncGraphPtr>()->name();
78 } else if (IsValueNode<parse::NameSpace>(node)) {
79 return node->value()->cast<parse::NameSpacePtr>()->name();
80 } else if (IsValueNode<parse::Symbol>(node)) {
81 return ReplaceSpecialChar(node->value()->cast<parse::SymbolPtr>()->name());
82 }
83 return "";
84 });
85 }
86 } callback_register;
87 } // namespace
88
InterpretedObject(const py::object & obj)89 InterpretedObject::InterpretedObject(const py::object &obj) : PyObjectWrapper(obj) {
90 std::stringstream buf;
91 auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
92 buf << "PythonObject(type: " << std::string(py::str(type_str)) << ", value: " << std::string(py::str(obj)) << ")";
93 this->set_name(buf.str());
94 }
95
ToAbstract()96 abstract::AbstractBasePtr MsClassObject::ToAbstract() {
97 py::gil_scoped_acquire acquire;
98 bool is_class_type = parse::data_converter::IsClassType(obj());
99 if (is_class_type) {
100 // Class type as func, such as Net(x, y)
101 auto abs_class = std::make_shared<abstract::AbstractClass>(shared_from_base<MsClassObject>());
102 AbstractBasePtrList args_abs_list = {abs_class};
103 auto func = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
104 auto res_val = std::make_shared<abstract::PartialAbstractClosure>(func, args_abs_list);
105 res_val->set_value_desc(ToString());
106 return res_val;
107 } else {
108 // Class instance as func, such as net(x, y)
109 return std::make_shared<abstract::AbstractClass>(shared_from_base<MsClassObject>());
110 }
111 }
112
IsSupportedCreateInstanceType(const py::object & obj)113 static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
114 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
115 auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
116 if (!py::isinstance<py::bool_>(res)) {
117 MS_LOG(ERROR) << "Expect a bool type, but got " << py::str(res);
118 return false;
119 }
120 return res.cast<bool>();
121 }
122
ToAbstract()123 abstract::AbstractBasePtr ClassType::ToAbstract() {
124 py::gil_scoped_acquire acquire;
125 auto abs_scalar =
126 std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
127
128 if (!IsSupportedCreateInstanceType(obj())) {
129 return abs_scalar;
130 }
131 AbstractBasePtrList args_abs_list = {abs_scalar};
132
133 auto func = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
134 auto res_val = std::make_shared<abstract::PartialAbstractClosure>(func, args_abs_list);
135 res_val->set_value_desc(ToString());
136 return res_val;
137 }
138
139 using tensor::MapTensorPtr;
140 // Get parameter value from a python parameter object.
141 // If it is a map parameter, return the map tensor value in it,
142 // otherwise, return parameter itself as a meta tensor value.
GetParameterValue(const py::object & param_obj)143 ValuePtr GetParameterValue(const py::object ¶m_obj) {
144 constexpr char attr_map_tensor[] = "_map_tensor";
145 constexpr char attr_param_info[] = "param_info";
146 if (py::hasattr(param_obj, attr_map_tensor)) {
147 auto map_tensor = py::cast<MapTensorPtr>(python_adapter::GetPyObjAttr(param_obj, attr_map_tensor));
148 MS_EXCEPTION_IF_NULL(map_tensor);
149 auto param_info = py::cast<ParamInfoPtr>(python_adapter::GetPyObjAttr(param_obj, attr_param_info));
150 MS_EXCEPTION_IF_NULL(param_info);
151 map_tensor->set_param_info(param_info);
152 return map_tensor;
153 }
154 return py::cast<tensor::MetaTensorPtr>(param_obj);
155 }
156
157 namespace {
GetPyObjId(const py::object & obj)158 std::string GetPyObjId(const py::object &obj) {
159 py::object out = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
160 if (py::isinstance<py::none>(out)) {
161 MS_LOG(INTERNAL_EXCEPTION) << "Get pyobj failed";
162 }
163 return out.cast<std::string>();
164 }
165
ClearCNodeAbstract(const FuncGraphPtr & func_graph)166 void ClearCNodeAbstract(const FuncGraphPtr &func_graph) {
167 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
168 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
169 for (const auto &node : nodes) {
170 if (node == nullptr || node->isa<Parameter>()) {
171 continue;
172 }
173 auto primitive = GetCNodePrimitive(node);
174 if (primitive != nullptr) {
175 auto is_load = primitive->GetAttr("is_load");
176 if (abstract::GetPrimEvaluator(primitive, nullptr) == nullptr && is_load != nullptr && GetValue<bool>(is_load)) {
177 MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
178 continue;
179 }
180 }
181 auto prev_inferred = node->abstract();
182 // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
183 if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
184 // Reset tuple/list abstract use flags.
185 if (enable_eliminate_unused_element && prev_inferred != nullptr &&
186 prev_inferred->isa<abstract::AbstractSequence>()) {
187 SetSequenceNodeElementsUseFlags(node, nullptr);
188 }
189 node->set_abstract(nullptr);
190 MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
191 }
192 }
193 }
194
ConvertLoadedGraph(const FuncGraphPtr & func_graph,const ValuePtr & value)195 void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
196 if (!value->isa<FuncGraph>()) {
197 return;
198 }
199 auto resolved_graph = value->cast<FuncGraphPtr>();
200 MS_EXCEPTION_IF_NULL(resolved_graph);
201 if (!resolved_graph->has_attr("is_load")) {
202 return;
203 }
204 auto top_graph = Parser::GetTopFuncGraph();
205 std::vector<AnfNodePtr> input_params;
206 auto resolved_graph_count = resolved_graph->fv_param_count();
207 std::vector<ParameterPtr> drop_node_list;
208 for (auto const ¶m : resolved_graph->parameters()) {
209 auto param_ptr = dyn_cast<Parameter>(param);
210 MS_EXCEPTION_IF_NULL(param_ptr);
211 if (param_ptr->has_default()) {
212 param_ptr->set_func_graph(top_graph);
213 func_graph->add_parameter_obj_node(param_ptr);
214 // Update top_graph
215 top_graph->add_parameter(param_ptr);
216 size_t fv_param_count = top_graph->fv_param_count();
217 top_graph->set_fv_param_count(++fv_param_count);
218 (void)drop_node_list.emplace_back(param_ptr);
219 resolved_graph->set_fv_param_count(--resolved_graph_count);
220 } else {
221 input_params.push_back(param_ptr);
222 }
223 }
224 for (const auto ¶m_ptr : drop_node_list) {
225 resolved_graph->DropNode(param_ptr);
226 }
227 resolved_graph->set_parameters(input_params);
228 ClearCNodeAbstract(resolved_graph);
229 }
230
HasConstArgAttr(const py::object & obj)231 bool HasConstArgAttr(const py::object &obj) {
232 constexpr char const_arg_attr[] = "const_arg";
233 return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
234 }
235
HasMutableAttr(const py::object & obj)236 bool HasMutableAttr(const py::object &obj) {
237 constexpr char mutable_attr[] = "__ms_mutable__";
238 return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
239 }
240
HasVariableLenAttr(const py::object & obj)241 bool HasVariableLenAttr(const py::object &obj) {
242 constexpr char variable_len_attr[] = "__ms_dynamic_len__";
243 return py::hasattr(obj, variable_len_attr) && py::cast<bool>(py::getattr(obj, variable_len_attr));
244 }
245
// Applies only when `convert_result` is an InterpretedObject and `origin_node` does not carry the
// "__py_interpret_local_value_flag__" user data. In that case it builds, in `func_graph`, a PyExecute node if the
// object has already been converted, or a PyInterpret node otherwise, keyed by the object's name.
// Returns nullptr when the conversion does not apply.
AnfNodePtr ConvertInterpretedObjForResolve(const AnfNodePtr &origin_node, const ValuePtr &convert_result,
                                           const FuncGraphPtr &func_graph) {
  if (!convert_result->isa<InterpretedObject>() || origin_node->has_user_data("__py_interpret_local_value_flag__")) {
    return nullptr;
  }
  constexpr auto recursive_level = 2;
  MS_LOG(DEBUG) << "Convert InterpretedObj for resolve, node: " << origin_node->DebugString(recursive_level);
  const auto interpreted = dyn_cast<InterpretedObject>(convert_result);
  const auto &key = interpreted->name();
  if (!interpreted->has_converted()) {
    return fallback::ConvertPyObjectToPyInterpret(func_graph, key, interpreted->obj(), origin_node, true);
  }
  return fallback::ConvertPyObjectToPyExecute(func_graph, key, interpreted->obj(), origin_node, true);
}
260
// Converts the Python object `obj`, resolved at the CNode `origin_node`, into a node of `func_graph`.
//  - For {prim::kPrimResolve, namespace, symbol, args} (4 inputs) the object is converted by a DataConverter built
//    from the args value tuple; a failed conversion raises an internal exception.
//  - Otherwise ConvertData is used; a failed conversion is logged and nullptr is returned.
// Unless `is_element_obj` is set, an InterpretedObject result is turned into a PyExecute/PyInterpret node and returned.
// Otherwise the value becomes a value node:
//  - a FuncGraph under a recompute scope gets its scope updated;
//  - "is_load" graphs are handled by ConvertLoadedGraph;
//  - tensors go through the mixed-precision cast helper;
//  - objects flagged `__ms_mutable__` are wrapped in a Mutable CNode.
AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph,
                               bool is_element_obj) {
  MS_EXCEPTION_IF_NULL(origin_node);
  const auto origin_cnode = dyn_cast<CNode>(origin_node);
  MS_EXCEPTION_IF_NULL(origin_cnode);
  const bool is_resolve = IsPrimitiveCNode(origin_node, prim::kPrimResolve);
  // When the cell is set recomputed, it should not use old scope from cache; this flag is handed to ConvertData.
  const auto scope = origin_node->scope();
  const bool has_recompute_scope =
    (scope != nullptr && scope->name().compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0);
  ValuePtr convert_result = nullptr;
  constexpr auto resolve_with_args_inputs_size = 4;
  MS_LOG(DEBUG) << "origin_cnode: " << origin_cnode->DebugString();
  const bool resolve_with_args = is_resolve && origin_cnode->size() == resolve_with_args_inputs_size;
  if (!resolve_with_args) {
    // (resolve/getattr, namespace, symbol, optional[getattr])
    const bool converted =
      ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
    if (!converted) {
      MS_LOG(ERROR) << "Convert data failed";
      return nullptr;
    }
  } else {
    // (resolve, namespace, symbol, arguments)
    constexpr auto args_input_pos = 3;
    const auto args_value = GetValueNode<ValueTuplePtr>(origin_cnode->input(args_input_pos));
    MS_EXCEPTION_IF_NULL(args_value);
    parse::DataConverter data_converter(args_value->value(), python_adapter::UseSignatureInResolve());
    convert_result = data_converter.ConvertData(obj);
    if (convert_result == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Convert error with Python object: " << std::string(py::str(obj));
    }
  }

  // Element objects (inside a container) skip the InterpretedObject conversion and become plain value nodes.
  if (!is_element_obj) {
    auto interpreted_output = ConvertInterpretedObjForResolve(origin_node, convert_result, func_graph);
    if (interpreted_output != nullptr) {
      return interpreted_output;
    }
  }

  if (has_recompute_scope && convert_result->isa<FuncGraph>()) {
    UpdateRecomputeScope(convert_result->cast<FuncGraphPtr>());
  }
  ConvertLoadedGraph(func_graph, convert_result);
  AnfNodePtr output = NewValueNode(convert_result);
  if (convert_result->isa<tensor::Tensor>()) {
    output = GetMixedPrecisionCastHelp(func_graph, output);
    if (HasConstArgAttr(obj)) {
      MS_LOG(WARNING) << "The tensor " << convert_result->ToString()
                      << " which is not used for network input argument should not be set const.";
    }
  }
  if (!HasMutableAttr(obj)) {
    return output;
  }
  const bool dynamic_len = HasVariableLenAttr(obj);
  return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimMutable), output, NewValueNode(dynamic_len)});
}
319
TransformFuncValueNode(const FuncGraphManagerPtr & manager,const FuncGraphPtr & func_graph,const ValuePtr & value)320 AnfNodePtr TransformFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
321 const ValuePtr &value) {
322 MS_EXCEPTION_IF_NULL(value);
323 if (value->isa<FuncGraph>()) {
324 auto fg = value->cast<FuncGraphPtr>();
325 manager->AddFuncGraph(fg);
326 return NewValueNode(fg);
327 }
328 if (value->isa<Primitive>()) {
329 return NewValueNode(value);
330 }
331 // (1) The CellList or CellDict will be parsed as value_sequence or value_dict of const graph in it,
332 // So if there is graph in list, try to replace the node with make_tuple or make_dict of graph value node.
333 // We do this because the graph manager won't investigate the graph inside value_sequence or value_dict,
334 // change the vector of graph to be make_tuple or make_dict of graph value node.
335 // (2) the primitive value_tuple or value_sequence or value_dict may encounter to abstract error, make it all
336 // independent nodes.
337 if (value->isa<ValueSequence>()) {
338 std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
339 bool is_all_func = true;
340 auto value_sequence = value->cast<ValueSequencePtr>();
341 if (value_sequence->size() == 0) {
342 return nullptr;
343 }
344 for (auto &elem : value_sequence->value()) {
345 auto node = TransformFuncValueNode(manager, func_graph, elem);
346 if (node == nullptr) {
347 is_all_func = false;
348 }
349 (void)inputs.emplace_back(node);
350 }
351 if (is_all_func) {
352 return func_graph->NewCNode(std::move(inputs));
353 }
354 return nullptr;
355 }
356 if (value->isa<ValueDictionary>()) {
357 std::vector<AnfNodePtr> keys{NewValueNode(prim::kPrimMakeTuple)};
358 std::vector<AnfNodePtr> values{NewValueNode(prim::kPrimMakeTuple)};
359 bool is_all_func = true;
360 for (auto &elem : value->cast<ValueDictionaryPtr>()->value()) {
361 (void)keys.emplace_back(NewValueNode(elem.first));
362 auto node = TransformFuncValueNode(manager, func_graph, elem.second);
363 if (node == nullptr) {
364 is_all_func = false;
365 }
366 (void)values.emplace_back(node);
367 }
368 if (is_all_func) {
369 return func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(keys)),
370 func_graph->NewCNode(std::move(values))});
371 }
372 return nullptr;
373 }
374
375 return nullptr;
376 }
377
378 // Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
ResolveObjectAndAddToManager(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & node)379 AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
380 const AnfNodePtr &node) {
381 MS_EXCEPTION_IF_NULL(node);
382 ScopeGuard scope_guard(node->scope());
383 AnfNodePtr resolved_node = nullptr;
384 bool success = ResolveObjectToNode(node, obj, &resolved_node);
385 if (!success) {
386 MS_LOG(INTERNAL_EXCEPTION) << "Parse Resolve covert failed.";
387 }
388 if (IsValueNode<FuncGraph>(resolved_node)) {
389 auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
390 auto fg = node->func_graph();
391 MS_EXCEPTION_IF_NULL(fg);
392 // If it's the sub func graph resolved in a reserved func graph.
393 if (fg->reserved()) {
394 new_fg->set_reserved(true);
395 }
396 manager->AddFuncGraph(new_fg);
397 }
398
399 // If the constant node is constant of vector of graph, add graph to manager.
400 if (IsValueNode<ValueSequence>(resolved_node) || IsValueNode<ValueDictionary>(resolved_node)) {
401 auto value = resolved_node->cast<ValueNodePtr>()->value();
402 auto new_node = TransformFuncValueNode(manager, node->func_graph(), value);
403 if (new_node != nullptr) {
404 resolved_node = new_node;
405 }
406 }
407 fallback::SetPyObjectToNode(resolved_node, obj);
408 return resolved_node;
409 }
410
IsParameterObject(const py::object & obj)411 bool IsParameterObject(const py::object &obj) {
412 return py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj);
413 }
414
ContainsParameter(const py::object & obj)415 bool ContainsParameter(const py::object &obj) {
416 if (IsParameterObject(obj) || py::hasattr(obj, "__parameter_tuple__")) {
417 return true;
418 }
419 if ((py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) && py::len(obj) != 0) {
420 // NamedTuple
421 if (py::hasattr(obj, "_fields")) {
422 return false;
423 }
424 auto tuple = obj.cast<py::tuple>();
425 for (size_t i = 0; i < tuple.size(); ++i) {
426 if (ContainsParameter(tuple[i])) {
427 return true;
428 }
429 }
430 } else if (py::isinstance<py::dict>(obj)) {
431 auto dict = obj.cast<py::dict>();
432 for (auto item : dict) {
433 auto item_value = py::cast<py::object>(item.second);
434 if (ContainsParameter(item_value)) {
435 return true;
436 }
437 }
438 }
439 return false;
440 }
441 } // namespace
442
ResolveObjectToNode(const AnfNodePtr & origin_node,const py::object & obj,AnfNodePtr * const node,bool is_element_obj)443 bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, AnfNodePtr *const node,
444 bool is_element_obj) {
445 MS_EXCEPTION_IF_NULL(origin_node);
446 auto func_graph = origin_node->func_graph();
447 MS_EXCEPTION_IF_NULL(func_graph);
448 if (!ContainsParameter(obj)) {
449 auto output = ConvertObjectToNode(origin_node, obj, func_graph, is_element_obj);
450 if (output == nullptr) {
451 return false;
452 }
453 *node = output;
454 return true;
455 }
456 if (IsParameterObject(obj)) {
457 auto param = ResolveParameterObj(func_graph, obj);
458 if (param == nullptr) {
459 MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
460 return false;
461 }
462 MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
463 *node = param;
464 return true;
465 }
466 if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj) || py::hasattr(obj, "__parameter_tuple__")) {
467 bool all_parameter_sequence = true;
468 std::vector<AnfNodePtr> args;
469 auto tuple = obj.cast<py::tuple>();
470 for (size_t i = 0; i < tuple.size(); ++i) {
471 if (!IsParameterObject(tuple[i])) {
472 all_parameter_sequence = false;
473 }
474 AnfNodePtr out = nullptr;
475 bool success = ResolveObjectToNode(origin_node, tuple[i], &out, true);
476 if (!success) {
477 MS_LOG(ERROR) << "Resolve object to node failed";
478 return false;
479 }
480 args.push_back(out);
481 }
482 // Convert [param1, param2, ..., paramN] to tuple.
483 bool need_convert_to_tuple = !is_element_obj && all_parameter_sequence && py::isinstance<py::list>(obj);
484 if (py::isinstance<py::tuple>(obj) || py::hasattr(obj, "__parameter_tuple__") || need_convert_to_tuple) {
485 (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
486 } else {
487 (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeList));
488 }
489 // The ParameterTuple/tuple/list will not be added in order list,
490 // since we don't want to deal with its RefTensor elements during auto_monad procedure.
491 *node = NewCNode(std::move(args), func_graph);
492 return true;
493 }
494 if (py::isinstance<py::dict>(obj)) {
495 auto dict = obj.cast<py::dict>();
496 std::vector<AnfNodePtr> keys_tuple{NewValueNode(prim::kPrimMakeTuple)};
497 std::vector<AnfNodePtr> values_tuple{NewValueNode(prim::kPrimMakeTuple)};
498 for (auto item : dict) {
499 AnfNodePtr key = nullptr;
500 AnfNodePtr value = nullptr;
501 bool success = ResolveObjectToNode(origin_node, py::cast<py::object>(item.first), &key, true) &&
502 ResolveObjectToNode(origin_node, py::cast<py::object>(item.second), &value, true);
503 if (!success) {
504 MS_LOG(ERROR) << "Resolve object to node failed";
505 return false;
506 }
507 (void)keys_tuple.emplace_back(key);
508 (void)values_tuple.emplace_back(value);
509 }
510 *node = func_graph->NewCNode(
511 {NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(keys_tuple), func_graph->NewCNode(values_tuple)});
512 return true;
513 }
514 MS_EXCEPTION(TypeError) << "The Parameter in obj '" << py::str(obj) << "' with nested structure is not supported."
515 << "\nCurrently only single Parameter, ParameterTuple or Parameters in tuple/list/dict "
516 "are supported. Or do you want to use Tensor instead?";
517 }
518
GetNamespaceAndSymbol(const AnfNodePtr & node)519 std::pair<NameSpacePtr, SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
520 MS_EXCEPTION_IF_NULL(node);
521 if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
522 auto resolve_cnode = node->cast<CNodePtr>();
523 constexpr size_t namespace_index = 1;
524 auto namespace_node = resolve_cnode->input(namespace_index);
525 constexpr size_t symbol_index = 2;
526 auto symbol_node = resolve_cnode->input(symbol_index);
527 if (!IsValueNode<NameSpace>(namespace_node) || !IsValueNode<Symbol>(symbol_node)) {
528 MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
529 << ", symbol: " << symbol_node->ToString();
530 }
531 // Deal with the case of GetAttr from a class member,
532 // and avoid the case of GetAttr from self (the result of ParseSuper).
533 auto name_space = GetValueNode<NameSpacePtr>(namespace_node);
534 auto symbol = GetValueNode<SymbolPtr>(symbol_node);
535 return {name_space, symbol};
536 }
537 constexpr auto recursive_level = 2;
538 MS_LOG(INTERNAL_EXCEPTION) << "It's not prim::Resolve CNode, node: " << node->DebugString(recursive_level);
539 }
540
GetSymbolObject(const NameSpacePtr & name_space,const SymbolPtr & symbol,const AnfNodePtr & node)541 py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) {
542 MS_EXCEPTION_IF_NULL(node);
543 if (node->func_graph() == nullptr) {
544 MS_LOG(INTERNAL_EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
545 }
546 if (name_space->module() == RESOLVE_NAMESPACE_NAME_ENTRY) {
547 return name_space->module_obj();
548 } else if (name_space->module() == RESOLVE_NAMESPACE_NAME_CLASS_OBJECT) {
549 MS_LOG(DEBUG) << "namespace: " << py::str(name_space->namespace_obj()) << ", symbol: " << symbol;
550 return name_space->namespace_obj();
551 }
552 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
553 auto &obj = name_space->namespace_obj();
554 if (py::isinstance<py::none>(obj)) {
555 MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
556 }
557 const auto &res =
558 python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol->symbol()));
559 MS_LOG(DEBUG) << "namespace: " << py::str(obj) << ", symbol: " << symbol << ", result: " << py::str(res);
560 return res;
561 }
562
// Resolves `symbol` in `name_space` at `node` into a node, adding any func graphs it yields to `manager`.
//  - If the result is a NameSpace value node and `name_space` has a non-None module object, that module object is
//    copied onto the result.
//  - When the entry namespace resolves to a FuncGraph (the user's top graph), `node`'s graph (the top graph) takes over
//    that graph's debug info, return-node debug info and attributes.
//  - In that case, if the top graph has no more positional args than the user graph, its positional parameters also
//    take the user parameters' debug info and names.
// Raises an internal exception if `manager` is null.
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
                         const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (manager == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Manager is nullptr.";
  }
  MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString()
                << ", loc: " << trace::GetDebugInfoStr(node->debug_info());
  TraceGuard trace_guard(std::make_shared<TraceResolve>(trace::GetSourceCodeDebugInfo(node->debug_info())));
  auto obj = GetSymbolObject(name_space, symbol, node);
  AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
  if (IsValueNode<NameSpace>(resolved_node) && !py::isinstance<py::none>(name_space->module_obj())) {
    GetValueNode<NameSpacePtr>(resolved_node)->set_module_obj(name_space->module_obj());
  }
  fallback::SetPyObjectToNode(resolved_node, obj);

  // Only the user's top graph, resolved from the entry namespace, is mirrored onto the top graph.
  if (name_space->module() != RESOLVE_NAMESPACE_NAME_ENTRY || !IsValueNode<FuncGraph>(resolved_node)) {
    return resolved_node;
  }
  auto user_top_fg = GetValueNode<FuncGraphPtr>(resolved_node);
  MS_EXCEPTION_IF_NULL(user_top_fg);
  auto top_fg = node->func_graph();
  MS_EXCEPTION_IF_NULL(top_fg);
  top_fg->set_debug_info(user_top_fg->debug_info());
  top_fg->return_node()->set_debug_info(user_top_fg->return_node()->debug_info());
  MS_LOG(DEBUG) << "Update top graph's and node's debug infos with user top graph's. top_fg: " << top_fg->ToString()
                << ", user_top_fg: " << user_top_fg->ToString();
  top_fg->set_attrs(user_top_fg->attrs());

  // Update top graph parameters' name
  const auto top_params = top_fg->parameters();
  const auto user_params = user_top_fg->parameters();
  const auto top_arg_size = top_fg->GetPositionalArgsCount();
  const auto user_arg_size = user_top_fg->GetPositionalArgsCount();
  if (top_arg_size > user_arg_size) {
    MS_LOG(INFO) << "Top graph's parameter size: " << top_arg_size
                 << " should not be greater than resolved func_graph's parameter size: " << user_arg_size;
    return resolved_node;
  }
  for (int idx = 0; idx < top_arg_size; ++idx) {
    const auto top_param = top_params[idx]->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(top_param);
    const auto user_param = user_params[idx]->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(user_param);
    top_param->set_debug_info(user_param->debug_info());
    top_param->set_name(user_param->name());
  }
  MS_LOG(DEBUG) << "Update top graph's parameters debug info with user top graph's parameters";
  return resolved_node;
}
613
CreateResolveNode(const py::object & obj,const AnfNodePtr & attr,const AnfNodePtr & get_attr_node)614 AnfNodePtr CreateResolveNode(const py::object &obj, const AnfNodePtr &attr, const AnfNodePtr &get_attr_node) {
615 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
616 py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
617 auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj, obj);
618 auto attr_string = GetValuePtr<StringImm>(attr);
619 MS_EXCEPTION_IF_NULL(attr_string);
620 const std::string &attr_as_string = attr_string->value();
621 auto new_symbol = std::make_shared<Symbol>(attr_as_string);
622 MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
623
624 auto fg = get_attr_node->func_graph();
625 MS_EXCEPTION_IF_NULL(fg);
626 AnfNodePtr resolved_node =
627 fg->NewCNode({NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)});
628 resolved_node->set_debug_info(get_attr_node->debug_info());
629 fg->ReplaceInOrder(get_attr_node, resolved_node);
630 return resolved_node;
631 }
632
633 // Resolve Cell GetAttr operation.
ResolveCellWithAttr(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & resolve_node,const AnfNodePtr & attr,const AnfNodePtr & get_attr_node)634 AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
635 const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
636 const AnfNodePtr &get_attr_node) {
637 MS_EXCEPTION_IF_NULL(resolve_node);
638 MS_EXCEPTION_IF_NULL(attr);
639 MS_EXCEPTION_IF_NULL(manager);
640 MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", attr: " << attr->ToString();
641 if (IsValueNode<StringImm>(attr)) {
642 const auto &attr_name = GetValue<std::string>(GetValueNode(attr));
643 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
644 bool is_property =
645 (python_adapter::CallPyModFn(mod, parse::PYTHON_PARSE_CHECK_ATTR_IS_PROPERTY, obj, attr_name)).cast<bool>();
646 if (is_property) {
647 auto get_attr_cnode = get_attr_node->cast<CNodePtr>();
648 AnfNodePtr node = get_attr_cnode->input(1);
649 auto cur_func = get_attr_node->func_graph();
650 auto call_func_node = parse::TransPropertyToFunc(cur_func, node, obj, attr_name);
651 MS_LOG(DEBUG) << "call_func_node:" << call_func_node->DebugString();
652 return call_func_node;
653 }
654 }
655 TraceGuard trace_guard(std::make_shared<TraceResolve>(get_attr_node->debug_info()));
656 if (!data_converter::IsCellInstance(obj)) {
657 AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, resolve_node);
658 AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
659 auto cur_func = get_attr_node->func_graph();
660 MS_EXCEPTION_IF_NULL(cur_func);
661 AnfNodePtr res_node = cur_func->NewCNode(std::move(inputs));
662 res_node->set_debug_info(get_attr_node->debug_info());
663 cur_func->ReplaceInOrder(get_attr_node, res_node);
664 return res_node;
665 }
666
667 constexpr auto tensors_queue_attr = "__is_tensors_queue__";
668 if (py::hasattr(obj, tensors_queue_attr) && IsValueNode<StringImm>(attr)) {
669 const auto &attr_name = GetValue<std::string>(GetValueNode(attr));
670 constexpr auto pop_attr = "pop";
671 if (attr_name == pop_attr) {
672 constexpr auto graph_pop_attr = "__graph_pop__";
673 MS_LOG(DEBUG) << "Replace " << pop_attr << " to " << graph_pop_attr << " for " << py::str(obj);
674 return CreateResolveNode(obj, NewValueNode(graph_pop_attr), get_attr_node);
675 }
676 }
677 return CreateResolveNode(obj, attr, get_attr_node);
678 }
679
680 // Get attribute or method from ms_class obj or cell obj.
ResolveClassObjectWithAttr(const py::object & cls_obj,const AnfNodePtr & attr,const AnfNodePtr & get_attr_node)681 AnfNodePtr ResolveClassObjectWithAttr(const py::object &cls_obj, const AnfNodePtr &attr,
682 const AnfNodePtr &get_attr_node) {
683 MS_EXCEPTION_IF_NULL(get_attr_node);
684 MS_LOG(DEBUG) << "Resolve ms_class obj (" << py::str(cls_obj) << ") with attr " << attr->ToString() << ".";
685 TraceGuard trace_guard(std::make_shared<TraceResolve>(get_attr_node->debug_info()));
686 return CreateResolveNode(cls_obj, attr, get_attr_node);
687 }
688
ResolveSequenceWithAttr(const FuncGraphManagerPtr & manager,const py::object & obj,const AnfNodePtr & resolve_node,const AnfNodePtr & attr,const CNodePtr & get_attr_node)689 AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
690 const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
691 const CNodePtr &get_attr_node) {
692 MS_EXCEPTION_IF_NULL(get_attr_node);
693 std::vector<AnfNodePtr> inputs;
694 inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
695 auto sequence = obj.cast<py::sequence>();
696 // Incorporate if all elements of the sequence are Cell instances or MsClass instances.
697 size_t count_cell = 0;
698 size_t count_msclass = 0;
699 size_t sequence_size = sequence.size();
700 for (size_t i = 0; i < sequence_size; ++i) {
701 if (data_converter::IsCellInstance(sequence[i])) {
702 ++count_cell;
703 } else if (data_converter::IsMsClassInstance(sequence[i])) {
704 ++count_msclass;
705 }
706 }
707 if (count_cell == sequence_size) {
708 // Resolve Cell instances.
709 for (size_t i = 0; i < sequence_size; ++i) {
710 auto res = ResolveCellWithAttr(manager, sequence[i], resolve_node, attr, get_attr_node);
711 (void)inputs.emplace_back(res);
712 }
713 } else if (count_msclass == sequence_size) {
714 // Resolve MsClass instances.
715 for (size_t i = 0; i < sequence_size; ++i) {
716 auto res = ResolveClassObjectWithAttr(sequence[i], attr, get_attr_node);
717 (void)inputs.emplace_back(res);
718 }
719 } else {
720 return nullptr;
721 }
722
723 constexpr auto prim_index = 0;
724 constexpr auto index_index = 2;
725 auto fg = get_attr_node->func_graph();
726 MS_EXCEPTION_IF_NULL(fg);
727 auto make_tuple_node = fg->NewCNodeInOrder(inputs);
728 return fg->NewCNodeInOrder({get_attr_node->input(prim_index), make_tuple_node, get_attr_node->input(index_index)});
729 }
730
// Resolve {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}.
// When the namespace is the class-member namespace or the symbol's Python object is a Cell instance, the attribute
// is resolved with ResolveCellWithAttr, and the symbol object is attached to the result as "__getattr__" user data.
// Returns nullptr for the "namespace" symbol and for every other object kind.
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
                                 const AnfNodePtr &attr_node, const AnfNodePtr &get_attr_node) {
  auto [name_space, symbol] = GetNamespaceAndSymbol(object_node);
  MS_EXCEPTION_IF_NULL(name_space);
  MS_EXCEPTION_IF_NULL(symbol);
  constexpr std::string_view kNamespaceSymbol = "namespace";
  if (symbol->symbol() == kNamespaceSymbol) {
    return nullptr;
  }
  const auto &ns_module = name_space->module();
  auto target_obj = GetSymbolObject(name_space, symbol, get_attr_node);
  if (ns_module != RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && !data_converter::IsCellInstance(target_obj)) {
    return nullptr;
  }
  auto resolved = ResolveCellWithAttr(manager, target_obj, object_node, attr_node, get_attr_node);
  resolved->set_user_data<py::object>("__getattr__", std::make_shared<py::object>(target_obj));
  return resolved;
}
750
751 // Get python object with index from a list or the whole list if the index is not fixed.
GetObjectFromSequence(const NameSpacePtr & name_space,const SymbolPtr & symbol,const AnfNodePtr & node,const AnfNodePtr & index_node)752 py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
753 const AnfNodePtr &index_node) {
754 MS_EXCEPTION_IF_NULL(node);
755 TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
756 py::object obj = GetSymbolObject(name_space, symbol, node);
757 // If obj is nn.CellList, convert it to sequence.
758 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
759 bool is_cell_list = py::hasattr(obj, PYTHON_CELL_AS_LIST);
760 if (is_cell_list) {
761 obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE, obj);
762 }
763 if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
764 return py::none();
765 }
766
767 MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
768 auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
769 if (imm_value == nullptr) {
770 MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
771 << ", index_node: " << index_node->DebugString();
772 // Index is not fixed, return the whole list.
773 return obj;
774 }
775 // It index is a value node, get the item of index directly.
776 py::object item_obj =
777 python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_ITEM_FROM_SEQUENCE, obj, py::int_(imm_value->value()));
778 return item_obj;
779 }
780
IsResolveNodeWithGetItem(const AnfNodePtr & node)781 bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
782 // Check if the node matches: {prim::kPrim::Resolve, ..., 'getitem'}.
783 if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
784 constexpr size_t symbol_index = 2;
785 constexpr auto getitem_symbol = "getitem";
786 auto cnode = node->cast<CNodePtr>();
787 auto symbol = GetValueNode<SymbolPtr>(cnode->input(symbol_index));
788 return symbol->symbol() == getitem_symbol;
789 }
790 return false;
791 }
792
IsGetItemCNode(const AnfNodePtr & node)793 bool IsGetItemCNode(const AnfNodePtr &node) {
794 if (!node->isa<CNode>()) {
795 return false;
796 }
797 auto cnode = node->cast<CNodePtr>();
798 constexpr size_t getitem_inputs_size = 3;
799 if (cnode->size() != getitem_inputs_size) {
800 return false;
801 }
802 constexpr auto prim_index = 0;
803 return IsResolveNodeWithGetItem(cnode->input(prim_index));
804 }
805
// Resolve an attribute accessed on an item of a sequence. Handled patterns:
//   {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
//   {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
// @param getitem_node The getitem CNode {getitem, data, index}.
// @param attr_node The attribute being accessed.
// @param node The outer getattr node.
// @return The resolved node, or nullptr when the pattern does not apply or the object cannot be resolved.
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
                                  const AnfNodePtr &attr_node, const AnfNodePtr &node) {
  constexpr auto data_index = 1;
  constexpr auto index_index = 2;
  MS_EXCEPTION_IF_NULL(getitem_node);
  auto getitem_cnode = getitem_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(getitem_cnode);
  auto data_node = getitem_cnode->input(data_index);
  auto index_node = getitem_cnode->input(index_index);

  // Shared tail of both patterns: fetch the Python object of `resolve_node`, select the item at `index_node`
  // (or keep the whole sequence when the index is not constant), then resolve `attr_node` on it.
  auto resolve_item_with_attr = [&](const AnfNodePtr &resolve_node) -> AnfNodePtr {
    auto [name_space, symbol] = GetNamespaceAndSymbol(resolve_node);
    auto obj = GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
    if (py::isinstance<py::none>(obj)) {
      return nullptr;
    }
    if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
      return ResolveSequenceWithAttr(manager, obj, resolve_node, attr_node, getitem_cnode);
    }
    return ResolveCellWithAttr(manager, obj, resolve_node, attr_node, node);
  };

  if (IsPrimitiveCNode(data_node, prim::kPrimResolve)) {
    return resolve_item_with_attr(data_node);
  }
  if (!IsPrimitiveCNode(data_node, prim::kPrimGetAttr)) {
    return nullptr;
  }
  auto getattr_cnode = data_node->cast<CNodePtr>();
  auto resolve_node = getattr_cnode->input(data_index);
  if (!IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
    return nullptr;
  }
  auto member_node = getattr_cnode->input(index_index);
  auto item_node = ResolveSymbolWithAttr(manager, resolve_node, member_node, node);
  // Only continue when resolving the member produced a new resolve node.
  if (!IsPrimitiveCNode(item_node, prim::kPrimResolve)) {
    return nullptr;
  }
  return resolve_item_with_attr(item_node);
}
848
// Build {prim::kPrimSetAttr, resolve_node, attr_node, value_node} in the graph that owns `target_node`.
// resolve_node is produced by ConvertInterpretedObjForResolve from an InterpretedObject wrapping the Python object
// that the resolve node `target_node` refers to.
AnfNodePtr ResolveInterpretedObjectOfSetAttr(const AnfNodePtr &target_node, const AnfNodePtr &attr_node,
                                             const AnfNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(target_node);
  auto [name_space, symbol] = GetNamespaceAndSymbol(target_node);
  MS_EXCEPTION_IF_NULL(name_space);
  MS_EXCEPTION_IF_NULL(symbol);
  auto symbol_obj = GetSymbolObject(name_space, symbol, target_node);
  // std::make_shared throws on failure and never returns null, so no null check is required here.
  auto interpreted_obj = std::make_shared<InterpretedObject>(symbol_obj);
  MS_LOG(DEBUG) << "Created a interpreted object: " << interpreted_obj->ToString();
  // The owning graph is used twice below; make sure it exists before building nodes in it.
  auto func_graph = target_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &resolve_node = ConvertInterpretedObjForResolve(target_node, interpreted_obj, func_graph);

  AnfNodePtrList inputs = {NewValueNode(prim::kPrimSetAttr), resolve_node, attr_node, value_node};
  return func_graph->NewCNodeInOrder(std::move(inputs));
}
863
864 namespace {
GetOptResolvePasses(const opt::irpass::ResolveIRPassLib & irpass)865 opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
866 // For resolve and getattr primitive.
867 opt::OptPassGroupMap map({
868 {"resolve",
869 {
870 irpass.resolver_,
871 }},
872 });
873 return map;
874 }
875 } // namespace
876
ResolveFuncGraph(const FuncGraphPtr & func_graph,const pipeline::ResourceBasePtr & res,bool use_profile)877 bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) {
878 if (func_graph == nullptr || res == nullptr) {
879 MS_LOG(ERROR) << "func_graph or resource is null";
880 return false;
881 }
882 opt::irpass::ResolveIRPassLib irpass;
883 opt::OptimizerPtr opt_resolve =
884 opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
885
886 (void)python_adapter::set_python_scoped();
887
888 MS_EXCEPTION_IF_NULL(opt_resolve);
889 (void)opt_resolve->step(func_graph, use_profile);
890 return true;
891 }
892
ResolveAll(const FuncGraphManagerPtr & manager)893 bool ResolveAll(const FuncGraphManagerPtr &manager) {
894 if (manager == nullptr) {
895 MS_LOG(ERROR) << "func graph manager is null";
896 return false;
897 }
898
899 if (manager->roots().size() > 1) {
900 MS_LOG(WARNING)
901 << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs"
902 "called from root graph, so it's not necessary to pass all graphs as roots. "
903 "Please ensure your usage.";
904 }
905 // Should not use pipeline::Resource as Resource::Clean will clean some
906 // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
907 // fail as valid scope has been cleaned.
908 auto res = std::make_shared<pipeline::ResourceBase>();
909 res->set_manager(manager);
910
911 auto roots = manager->roots();
912 for (const auto &fg : roots) {
913 bool ret = ResolveFuncGraph(fg, res, false);
914 if (!ret) {
915 MS_EXCEPTION_IF_NULL(fg);
916 MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed";
917 }
918 }
919 return true;
920 }
921
922 // If any mixed precision flag add a cast node after the parameter node.
923 // argument obj should be python Parameter object
924 // it will be converted to Parameter node here
ResolveParameterObj(const FuncGraphPtr & func_graph,const py::object & obj)925 AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
926 MS_EXCEPTION_IF_NULL(func_graph);
927
928 // Parameter object should not be none
929 if (py::isinstance<py::none>(obj)) {
930 MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
931 }
932
933 if (!py::hasattr(obj, "name")) {
934 MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
935 }
936
937 // Get the parameter name from parameter object
938 auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
939 if (py::isinstance<py::none>(name_attr)) {
940 MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
941 }
942 auto obj_id = GetPyObjId(obj);
943 auto param_name = py::cast<std::string>(name_attr);
944 auto top_func_graph = Parser::GetTopFuncGraph();
945 // If the parameter node has been created , return it.
946 ParameterPtr para_node = nullptr;
947 for (auto const ¶m : top_func_graph->parameters()) {
948 auto param_node = dyn_cast<Parameter>(param);
949 if (param_node != nullptr && param_node->name() == param_name) {
950 if (param_node->is_top_graph_param()) {
951 // If the name of the input of construct is same as the parameters,
952 // add suffix to the name of the input of construct.
953 string suffix_name = param_node->name() + "_$";
954 param_node->set_name(suffix_name);
955 param_node->debug_info()->set_name(suffix_name);
956 MS_LOG(DEBUG) << "Add suffix to the name of the input of construct " << func_graph->ToString()
957 << ", input: " << param_node->DebugString();
958 } else {
959 // Exist two parameter object which name is the same.
960 auto iter = param_obj_ids.find(param_name);
961 if (iter != param_obj_ids.end() && iter->second != obj_id) {
962 MS_LOG(EXCEPTION)
963 << "The parameter " << param_node->DebugString() << " , its name '" << param_name
964 << "' already exists. Please set a unique name for the parameter."
965 << "\nFor more details with the name of parameter, please refer to "
966 << "https://mindspore.cn/search?inputValue=Please%20set%20a%20unique%20name%20for%20the%20parameter";
967 }
968 para_node = param_node;
969 MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
970 << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
971 break;
972 }
973 }
974 }
975 if (para_node == nullptr) {
976 auto value = GetParameterValue(obj);
977 para_node = top_func_graph->AddFvParameter(param_name, value);
978 param_obj_ids[param_name] = obj_id;
979 MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
980 << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
981 auto context = parallel::ParallelContext::GetInstance();
982 if (context != nullptr && para_node->has_default()) {
983 auto param_abs = pipeline::GetDefaultValueAbstract(para_node);
984 context->ParallelParameterContextRestoreShape(top_func_graph, para_node, param_abs);
985 para_node->set_abstract(param_abs);
986 }
987 }
988 func_graph->add_parameter_obj_node(para_node);
989 return para_node;
990 }
991 } // namespace parse
992 } // namespace mindspore
993