1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/ps/load_mindir.h"
17
18 #include <string>
19 #include <set>
20 #include <memory>
21 #include <algorithm>
22
23 #include "utils/log_adapter.h"
24 #include "abstract/abstract_value.h"
25 #include "pipeline/jit/ps/parse/parse_base.h"
26 #include "utils/check_convert_utils.h"
27 #include "load_mindir/infer_mindir.h"
28
29 namespace mindspore {
30 namespace pipeline {
InferMindIR(const ResourcePtr & resource)31 bool InferMindIR(const ResourcePtr &resource) {
32 MS_EXCEPTION_IF_NULL(resource);
33 const auto &root = resource->func_graph();
34 InferFuncGraphLoaded(root);
35 return true;
36 }
37
ArgsNeededToConvert(const PrimitivePtr & prim)38 std::vector<AnfNodePtr> ArgsNeededToConvert(const PrimitivePtr &prim) {
39 auto op_def = mindspore::ops::GetOpDef(prim->name());
40 std::vector<AnfNodePtr> prim_init_arg_nodes;
41 MS_EXCEPTION_IF_NULL(op_def);
42 // Get init args.
43 for (const auto &op_arg : op_def->args_) {
44 if (op_arg.as_init_arg_) {
45 auto arg_name = op_arg.arg_name_;
46 ValuePtr attr;
47 // "data_format" is renamed as "format" for some operator.
48 if (CheckAndConvertUtils::CheckPrimAttrConverted(prim->name()) && arg_name == "data_format" &&
49 prim->HasAttr("format")) {
50 attr = prim->GetAttr("format");
51 } else if (!prim->HasAttr(arg_name)) {
52 attr = parse::GetArgDefaultValue(prim->name(), arg_name);
53 if (attr == nullptr) {
54 MS_LOG(EXCEPTION) << "Cannot find attribute: " << arg_name << " from primitive :" << prim->name();
55 }
56 } else {
57 attr = prim->GetAttr(arg_name);
58 }
59 (void)prim_init_arg_nodes.emplace_back(NewValueNode(attr));
60 }
61 }
62 return prim_init_arg_nodes;
63 }
64
ModifyOneCNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)65 void ModifyOneCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
66 MS_EXCEPTION_IF_NULL(cnode);
67 auto &inputs = cnode->inputs();
68 if (IsValueNode<Primitive>(inputs[0])) {
69 auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
70 if (mindspore::ops::IsPrimitiveFunction(prim->name())) {
71 // Append Primitive arguments to the inputs.
72 std::vector<AnfNodePtr> prim_init_arg_nodes = ArgsNeededToConvert(prim);
73 // Get call args.
74 AnfNodePtrList prim_call_arg_nodes(inputs.begin() + 1, inputs.end());
75 // Create new node.
76 auto new_prim = std::make_shared<Primitive>(*prim);
77 AnfNodePtrList input_nodes{NewValueNode(new_prim)};
78 (void)std::copy(prim_call_arg_nodes.cbegin(), prim_call_arg_nodes.cend(), std::back_inserter(input_nodes));
79 (void)std::copy(prim_init_arg_nodes.cbegin(), prim_init_arg_nodes.cend(), std::back_inserter(input_nodes));
80 auto new_cnode = func_graph->NewCNodeInOrder(input_nodes);
81 MS_LOG(DEBUG) << "Convert primitive args: " << prim->name() << ". node: " << cnode->DebugString()
82 << ", new_node: " << new_cnode->DebugString();
83 auto manager = func_graph->manager();
84 if (manager == nullptr) {
85 manager = MakeManager();
86 manager->AddFuncGraph(func_graph, true);
87 }
88 (void)manager->Replace(cnode, new_cnode);
89 }
90 }
91 }
92
ModifyOneFuncGraph(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * func_graph_set,std::set<FuncGraphPtr> * func_graph_modified)93 void ModifyOneFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *func_graph_set,
94 std::set<FuncGraphPtr> *func_graph_modified) {
95 MS_LOG(DEBUG) << "Start modifying: " << func_graph->ToString();
96 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
97 for (const AnfNodePtr &node : nodes) {
98 MS_EXCEPTION_IF_NULL(node);
99 if (!node->isa<CNode>()) {
100 continue;
101 }
102 auto cnode = node->cast<CNodePtr>();
103 ModifyOneCNode(func_graph, cnode);
104 auto &inputs = cnode->inputs();
105 for (size_t i = 0; i < inputs.size(); ++i) {
106 if (IsValueNode<FuncGraph>(inputs[i])) {
107 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(inputs[i]);
108 if ((*func_graph_set).find(fg) == (*func_graph_set).end() &&
109 (*func_graph_modified).find(fg) == (*func_graph_modified).end()) {
110 (void)(*func_graph_set).insert(fg);
111 }
112 }
113 }
114 }
115 }
116
ModifyGraphs(const FuncGraphPtr & func_graph)117 void ModifyGraphs(const FuncGraphPtr &func_graph) {
118 std::set<FuncGraphPtr> func_graph_set{};
119 std::set<FuncGraphPtr> func_graph_modified{};
120 func_graph_set.insert(func_graph);
121 // Check every node in every graph to find nodes needed to convert.
122 while (!func_graph_set.empty()) {
123 FuncGraphPtr fg = *func_graph_set.cbegin();
124 if (!func_graph->has_flag("generated_from_mindir_with_prim_func")) {
125 ModifyOneFuncGraph(fg, &func_graph_set, &func_graph_modified);
126 }
127 (void)func_graph_set.erase(fg);
128 (void)func_graph_modified.insert(fg);
129 }
130 }
131
ModifyGraphGeneratedByMindIR(const ResourcePtr & resource)132 bool ModifyGraphGeneratedByMindIR(const ResourcePtr &resource) {
133 MS_EXCEPTION_IF_NULL(resource);
134 const auto &func_graph = resource->func_graph();
135 ModifyGraphs(func_graph);
136 return true;
137 }
138 } // namespace pipeline
139 } // namespace mindspore
140