1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tools/converter/parser/tf/functionalize_while.h"
#include "tools/converter/parser/tf/functionalize_cond.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
25
26 namespace mindspore::opt {
27
NewFuncGraph(const std::string & subgraph_name,const FmkType & fmk_type)28 FuncGraphPtr FunctionalizeControlOpPass::NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type) {
29 auto fg = std::make_shared<FuncGraph>();
30 if (fg == nullptr) {
31 MS_LOG(ERROR) << "new func_graph failed.";
32 return nullptr;
33 }
34 fg->set_attr("graph_name", MakeValue(subgraph_name));
35 fg->set_attr("fmk", MakeValue(static_cast<int>(fmk_type)));
36 return fg;
37 }
38
NodeClusterName(const AnfNodePtr & node)39 std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) {
40 std::string cluster_name{};
41 if (node == nullptr) {
42 MS_LOG(ERROR) << "node is nullptr.";
43 return cluster_name;
44 }
45 if (!utils::isa<CNodePtr>(node)) {
46 MS_LOG(ERROR) << "node is invalid.";
47 return cluster_name;
48 }
49 // tf node name use '/' split node name
50 auto cnode = utils::cast<CNodePtr>(node);
51 std::string word_in_name = "while/";
52 size_t pos = cnode->fullname_with_scope().rfind(word_in_name);
53 if (pos != std::string::npos) {
54 cluster_name = cnode->fullname_with_scope().substr(0, pos + word_in_name.size());
55 } else {
56 cluster_name = cnode->fullname_with_scope();
57 }
58 return cluster_name;
59 }
60
InitNodeClusters(const FuncGraphPtr & func_graph)61 void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) {
62 MS_CHECK_TRUE_RET_VOID(func_graph != nullptr);
63 for (auto &node : func_graph->nodes()) {
64 if (!utils::isa<CNodePtr>(node)) {
65 continue;
66 }
67 auto cluster_name = NodeClusterName(node);
68 auto cluster_pos = WhichCluster(cluster_name);
69 if (cluster_pos == node_clusters_.size()) {
70 std::vector<AnfNodePtr> node_list{node};
71 node_clusters_.emplace_back(std::make_pair(cluster_name, node_list));
72 } else {
73 node_clusters_[cluster_pos].second.push_back(node);
74 }
75 }
76 // sort node_clusters_
77 std::sort(node_clusters_.begin(), node_clusters_.end(),
78 [](const std::pair<std::string, std::vector<AnfNodePtr>> &a,
79 const std::pair<std::string, std::vector<AnfNodePtr>> &b) {
80 if (a.first.size() != b.first.size()) {
81 return a.first.size() > b.first.size();
82 } else {
83 return a.first > b.first;
84 }
85 });
86 }
87
WhichCluster(const std::string & cluster_name)88 size_t FunctionalizeControlOpPass::WhichCluster(const std::string &cluster_name) {
89 size_t pos = node_clusters_.size();
90 for (size_t i = 0; i < pos; ++i) {
91 if (node_clusters_[i].first == cluster_name) {
92 return i;
93 }
94 }
95 return pos;
96 }
97
BuildWhileSubgraph(const FuncGraphPtr & func_graph)98 STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_graph) {
99 CHECK_NULL_RETURN(func_graph);
100 int ret = RET_OK;
101 for (auto &node_cluster : node_clusters_) {
102 for (auto &node : node_cluster.second) {
103 if (IsLoopCond(node)) {
104 loop_cond_nodes_.push_back(node->cast<CNodePtr>());
105 FunctionalizeWhile fw(node_cluster.second, node->cast<CNodePtr>(), func_graph);
106 ret = fw.Process();
107 if (ret != RET_OK) {
108 MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret;
109 return ret;
110 }
111 break;
112 }
113 }
114 }
115 return ret;
116 }
117
BuildIfSubgraph(const FuncGraphPtr & func_graph)118 STATUS FunctionalizeControlOpPass::BuildIfSubgraph(const FuncGraphPtr &func_graph) {
119 CHECK_NULL_RETURN(func_graph);
120 int ret = RET_OK;
121 auto nodes = func_graph->nodes();
122 for (auto &node : nodes) {
123 CHECK_NULL_RETURN(node);
124 if (!IsMerge(node)) {
125 continue;
126 }
127 auto cnode = utils::cast<CNodePtr>(node);
128 FunctionalizeCond fc(func_graph, cnode);
129 ret = fc.Process();
130 if (ret != RET_OK) {
131 MS_LOG(ERROR) << "run functionalize cond failed, ret: " << ret;
132 return ret;
133 }
134 }
135
136 return ret;
137 }
138
Run(const FuncGraphPtr & func_graph)139 bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
140 if (func_graph == nullptr) {
141 MS_LOG(ERROR) << "func_graph is nullptr, build while subgraph failed.";
142 return false;
143 }
144 // use name to find the frame
145 InitNodeClusters(func_graph);
146 if (BuildWhileSubgraph(func_graph) != RET_OK) {
147 MS_LOG(ERROR) << "build while subgraph failed.";
148 return false;
149 }
150 if (BuildIfSubgraph(func_graph) != RET_OK) {
151 MS_LOG(ERROR) << "build while subgraph failed.";
152 return false;
153 }
154 return true;
155 }
156
// Finds the node that `node` "belongs to", as defined by aim_func. Results, in order of precedence:
//   1. node itself, if aim_func(node) holds;
//   2. otherwise the first direct input satisfying aim_func (filter_func is NOT consulted here);
//   3. otherwise the first node reached by a breadth-first walk over transitive inputs that
//      satisfies aim_func and, when filter_func is non-null, filter_func.
// Returns nullptr (with a warning) if nothing matches or the matching node is not a CNode.
// Returns nullptr (with an error) if node is null.
// Each input is expanded at most once. Without this, the walk re-expands shared inputs
// (exponential on diamond-shaped graphs) and never terminates on cyclic inputs.
CNodePtr FunctionalizeControlOpPass::BelongToWhichNode(const CNodePtr &node, const AimFunc &aim_func,
                                                       const FilterFunc &filter_func) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is null,search node belong to which node failed.";
    return nullptr;
  }
  if (aim_func(node)) {
    return node;
  }
  CNodePtr aim_node = nullptr;
  std::deque<AnfNodePtr> todo;
  for (const auto &input_node : node->inputs()) {
    if (aim_func(input_node)) {
      aim_node = utils::cast<CNodePtr>(input_node);
      // Empty the queue so the breadth-first search below is skipped.
      todo.clear();
      break;
    }
    todo.push_back(input_node);
  }

  // Nodes already taken off the queue.
  // Skipping repeats keeps the same first-visit order as a plain BFS while bounding the work.
  std::unordered_set<AnfNodePtr> visited;
  while (!todo.empty()) {
    AnfNodePtr todo_node = todo.front();
    todo.pop_front();
    if (!visited.insert(todo_node).second) {
      continue;
    }
    if (aim_func(todo_node) && (filter_func == nullptr || filter_func(todo_node))) {
      aim_node = utils::cast<CNodePtr>(todo_node);
      break;
    }
    if (utils::isa<CNodePtr>(todo_node)) {
      auto cnode = utils::cast<CNodePtr>(todo_node);
      for (size_t i = 0; i < cnode->size(); ++i) {
        todo.push_back(cnode->input(i));
      }
    }
  }
  if (aim_node == nullptr) {
    MS_LOG(WARNING) << "not found belonging enter node.";
    return nullptr;
  }

  return aim_node;
}
202 } // namespace mindspore::opt
203