• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include <algorithm>
#include <deque>
#include <unordered_set>
#include "tools/converter/parser/tf/functionalize_while.h"
#include "tools/converter/parser/tf/functionalize_cond.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
25 
26 namespace mindspore::opt {
27 
NewFuncGraph(const std::string & subgraph_name,const FmkType & fmk_type)28 FuncGraphPtr FunctionalizeControlOpPass::NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type) {
29   auto fg = std::make_shared<FuncGraph>();
30   if (fg == nullptr) {
31     MS_LOG(ERROR) << "new func_graph failed.";
32     return nullptr;
33   }
34   fg->set_attr("graph_name", MakeValue(subgraph_name));
35   fg->set_attr("fmk", MakeValue(static_cast<int>(fmk_type)));
36   return fg;
37 }
38 
NodeClusterName(const AnfNodePtr & node)39 std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) {
40   std::string cluster_name{};
41   if (node == nullptr) {
42     MS_LOG(ERROR) << "node is nullptr.";
43     return cluster_name;
44   }
45   if (!utils::isa<CNodePtr>(node)) {
46     MS_LOG(ERROR) << "node is invalid.";
47     return cluster_name;
48   }
49   // tf node name use '/' split node name
50   auto cnode = utils::cast<CNodePtr>(node);
51   std::string word_in_name = "while/";
52   size_t pos = cnode->fullname_with_scope().rfind(word_in_name);
53   if (pos != std::string::npos) {
54     cluster_name = cnode->fullname_with_scope().substr(0, pos + word_in_name.size());
55   } else {
56     cluster_name = cnode->fullname_with_scope();
57   }
58   return cluster_name;
59 }
60 
InitNodeClusters(const FuncGraphPtr & func_graph)61 void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) {
62   MS_CHECK_TRUE_RET_VOID(func_graph != nullptr);
63   for (auto &node : func_graph->nodes()) {
64     if (!utils::isa<CNodePtr>(node)) {
65       continue;
66     }
67     auto cluster_name = NodeClusterName(node);
68     auto cluster_pos = WhichCluster(cluster_name);
69     if (cluster_pos == node_clusters_.size()) {
70       std::vector<AnfNodePtr> node_list{node};
71       node_clusters_.emplace_back(std::make_pair(cluster_name, node_list));
72     } else {
73       node_clusters_[cluster_pos].second.push_back(node);
74     }
75   }
76   // sort node_clusters_
77   std::sort(node_clusters_.begin(), node_clusters_.end(),
78             [](const std::pair<std::string, std::vector<AnfNodePtr>> &a,
79                const std::pair<std::string, std::vector<AnfNodePtr>> &b) {
80               if (a.first.size() != b.first.size()) {
81                 return a.first.size() > b.first.size();
82               } else {
83                 return a.first > b.first;
84               }
85             });
86 }
87 
WhichCluster(const std::string & cluster_name)88 size_t FunctionalizeControlOpPass::WhichCluster(const std::string &cluster_name) {
89   size_t pos = node_clusters_.size();
90   for (size_t i = 0; i < pos; ++i) {
91     if (node_clusters_[i].first == cluster_name) {
92       return i;
93     }
94   }
95   return pos;
96 }
97 
BuildWhileSubgraph(const FuncGraphPtr & func_graph)98 STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_graph) {
99   CHECK_NULL_RETURN(func_graph);
100   int ret = RET_OK;
101   for (auto &node_cluster : node_clusters_) {
102     for (auto &node : node_cluster.second) {
103       if (IsLoopCond(node)) {
104         loop_cond_nodes_.push_back(node->cast<CNodePtr>());
105         FunctionalizeWhile fw(node_cluster.second, node->cast<CNodePtr>(), func_graph);
106         ret = fw.Process();
107         if (ret != RET_OK) {
108           MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret;
109           return ret;
110         }
111         break;
112       }
113     }
114   }
115   return ret;
116 }
117 
BuildIfSubgraph(const FuncGraphPtr & func_graph)118 STATUS FunctionalizeControlOpPass::BuildIfSubgraph(const FuncGraphPtr &func_graph) {
119   CHECK_NULL_RETURN(func_graph);
120   int ret = RET_OK;
121   auto nodes = func_graph->nodes();
122   for (auto &node : nodes) {
123     CHECK_NULL_RETURN(node);
124     if (!IsMerge(node)) {
125       continue;
126     }
127     auto cnode = utils::cast<CNodePtr>(node);
128     FunctionalizeCond fc(func_graph, cnode);
129     ret = fc.Process();
130     if (ret != RET_OK) {
131       MS_LOG(ERROR) << "run functionalize cond failed, ret: " << ret;
132       return ret;
133     }
134   }
135 
136   return ret;
137 }
138 
Run(const FuncGraphPtr & func_graph)139 bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
140   if (func_graph == nullptr) {
141     MS_LOG(ERROR) << "func_graph is nullptr, build while subgraph failed.";
142     return false;
143   }
144   // use name to find the frame
145   InitNodeClusters(func_graph);
146   if (BuildWhileSubgraph(func_graph) != RET_OK) {
147     MS_LOG(ERROR) << "build while subgraph failed.";
148     return false;
149   }
150   if (BuildIfSubgraph(func_graph) != RET_OK) {
151     MS_LOG(ERROR) << "build while subgraph failed.";
152     return false;
153   }
154   return true;
155 }
156 
// Searches `node` and its upstream inputs, breadth-first, for the nearest node
// accepted by `aim_func`.
//
// Resolution order:
//   1. `node` itself, if aim_func(node) holds.
//   2. The first direct input of `node` for which aim_func holds.
//      NOTE(review): direct inputs are accepted without consulting
//      filter_func, unlike deeper nodes - confirm this is intended.
//   3. The first node in breadth-first order over the remaining upstream
//      inputs for which aim_func holds and filter_func (when non-null) also
//      holds. Nodes rejected by filter_func are still expanded.
//
// Each node is queued at most once, so input chains that share nodes or form
// a cycle are traversed in bounded time. Skipping repeats does not change
// which node is found, because a repeated node was already tested and
// expanded at its first occurrence.
//
// Returns the matching node, or nullptr when `node` is null (error logged) or
// nothing matches (warning logged).
CNodePtr FunctionalizeControlOpPass::BelongToWhichNode(const CNodePtr &node, const AimFunc &aim_func,
                                                       const FilterFunc &filter_func) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is null,search node belong to which node failed.";
    return nullptr;
  }
  if (aim_func(node)) {
    return node;
  }
  CNodePtr aim_node = nullptr;
  std::deque<AnfNodePtr> todo;
  // Every node that has ever been queued; the start node already failed aim_func.
  std::unordered_set<AnfNodePtr> visited;
  visited.insert(node);
  for (auto &input_node : node->inputs()) {
    if (aim_func(input_node)) {
      aim_node = utils::cast<CNodePtr>(input_node);
      // A direct hit ends the search: drop anything queued so far.
      todo.clear();
      break;
    }
    if (visited.insert(input_node).second) {
      todo.push_back(input_node);
    }
  }

  while (!todo.empty()) {
    AnfNodePtr todo_node = todo.front();
    todo.pop_front();
    if (aim_func(todo_node) && (filter_func == nullptr || filter_func(todo_node))) {
      aim_node = utils::cast<CNodePtr>(todo_node);
      break;
    }
    if (utils::isa<CNodePtr>(todo_node)) {
      auto cnode = utils::cast<CNodePtr>(todo_node);
      for (size_t i = 0; i < cnode->size(); i++) {
        AnfNodePtr upstream = cnode->input(i);
        // insert() reports whether the node is new; only new nodes are queued.
        if (visited.insert(upstream).second) {
          todo.push_back(upstream);
        }
      }
    }
  }
  if (aim_node == nullptr) {
    MS_LOG(WARNING) << "not found belonging enter node.";
    return nullptr;
  }

  return aim_node;
}
202 }  // namespace mindspore::opt
203