• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/parser/tf/functionalize_cond.h"
18 #include <algorithm>
19 #include <memory>
20 #include <deque>
21 #include <unordered_set>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "include/errorcode.h"
25 #include "tools/converter/ops/ops_def.h"
26 #include "nnacl/op_base.h"
27 #include "src/common/log_util.h"
28 #include "ops/return.h"
29 #include "tools/lite_exporter/fetch_content.h"
30 
31 namespace mindspore::opt {
GetSwitchBranchType(const CNodePtr & switch_cnode,BranchType * branch_type)32 STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
33   MS_ASSERT(switch_cnode != nullptr);
34   MS_ASSERT(branch_type != nullptr);
35   auto manager = fg_->manager();
36   if (manager == nullptr) {
37     MS_LOG(ERROR) << "manager is nullptr";
38     return RET_ERROR;
39   }
40   auto node_users = manager->node_users()[switch_cnode];
41   if (node_users.size() != 1) {  // only one output of switch is referenced in cond
42     MS_LOG(ERROR) << "switch's node users is not correct";
43     return RET_ERROR;
44   }
45   auto node_user = node_users.front();
46   auto tuple_get_item = node_user.first;
47   if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
48     MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
49     return RET_ERROR;
50   }
51   auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
52   auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
53   if (idx == 0) {
54     *branch_type = kElseBranch;
55   } else if (idx == 1) {
56     *branch_type = kThenBranch;
57   } else {
58     MS_LOG(ERROR) << "wrong tuple_get_item index";
59     return RET_ERROR;
60   }
61   return RET_OK;
62 }
63 
CheckBranchIsEffective(const CNodePtr & switch_cnode,BranchType branch_type)64 void FunctionalizeCond::CheckBranchIsEffective(const CNodePtr &switch_cnode, BranchType branch_type) {
65   MS_ASSERT(switch_cnode != nullptr);
66   MS_ASSERT(is_effective != nullptr);
67   if (switch_cnode->size() < C3NUM) {
68     return;
69   }
70   auto cond_node = switch_cnode->input(kInputIndexTwo);
71   if (!utils::isa<Parameter>(cond_node)) {
72     return;
73   }
74   auto cond_pnode = cond_node->cast<ParameterPtr>();
75   lite::DataInfo data_info;
76   if (FetchFromDefaultParam(cond_pnode, converter::FmkType::kFmkTypeTf, &data_info, false) != lite::RET_OK) {
77     return;
78   }
79   if (data_info.data_ptr_ == nullptr || data_info.data_type_ != kNumberTypeBool) {
80     return;
81   }
82   bool cond = *(static_cast<bool *>(data_info.data_ptr_));
83   if (branch_type == kThenBranch) {
84     then_switch_ = switch_cnode;
85     then_is_effective_ = cond;
86   }
87   if (branch_type == kElseBranch) {
88     else_switch_ = switch_cnode;
89     else_is_effective_ = !cond;
90   }
91 }
92 
BranchSubGraphAddNodes(const FuncGraphPtr & graph,const AnfNodePtr & root_node,BranchType branch_type)93 STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
94                                                  BranchType branch_type) {
95   CHECK_NULL_RETURN(graph);
96   CHECK_NULL_RETURN(root_node);
97   std::deque<AnfNodePtr> q;
98   std::unordered_set<AnfNodePtr> vis;
99   q.push_back(root_node);
100   while (!q.empty()) {
101     auto node = q.front();
102     CHECK_NULL_RETURN(node);
103     q.pop_front();
104     vis.insert(node);
105     if (FunctionalizeControlOpPass::IsSwitch(node)) {
106       auto cnode = utils::cast<CNodePtr>(node);
107       BranchType this_type;
108       if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
109         MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
110         return RET_ERROR;
111       }
112       CheckBranchIsEffective(cnode, branch_type);
113       continue;
114     }
115     if (utils::isa<ParameterPtr>(node)) {
116       graph->add_parameter(node->cast<ParameterPtr>());
117     }
118     graph->AddNode(node);
119     if (!utils::isa<ValueNodePtr>(node)) {
120       node->set_func_graph(graph);
121     }
122     if (utils::isa<CNodePtr>(node)) {
123       auto cnode = utils::cast<CNodePtr>(node);
124       for (size_t i = 1; i < cnode->size(); i++) {
125         auto inputi = cnode->input(i);
126         if (vis.find(inputi) == vis.end()) {
127           q.push_back(cnode->input(i));
128         }
129       }
130     }
131   }
132   return RET_OK;
133 }
134 
PosInInputNodes(const CNodePtr & node)135 int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
136   auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node);
137   if (index == input_nodes_.end()) {
138     input_nodes_.push_back(node);
139     return input_nodes_.size() - 1;
140   }
141   return index - input_nodes_.begin();
142 }
143 
IdentifySubgraphInput(const FuncGraphPtr & graph,std::string graph_name)144 STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
145   CHECK_NULL_RETURN(graph);
146   std::vector<AnfNodePtr> nodes_need_drop{};
147   for (auto &cnode : graph->GetOrderedCnodes()) {
148     for (auto &input_node : cnode->inputs()) {
149       if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
150         CHECK_NULL_RETURN(input_node);
151         auto switch_node = input_node->cast<CNodePtr>();
152         CHECK_NULL_RETURN(switch_node);
153         auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
154         auto pos = PosInInputNodes(switch_input);
155         nodes_need_drop.push_back(cnode);
156         pred_nodes_.push_back(switch_node->input(kInputIndexTwo));
157         // set parameter
158         auto parameter = graph->add_parameter();
159         CHECK_NULL_RETURN(parameter);
160         parameter->set_abstract(cnode->abstract());
161         // hardcode for subgraph input name
162         parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");
163 
164         // replace switch
165         auto manager = fg_->manager();
166         CHECK_NULL_RETURN(manager);
167         auto node_users = manager->node_users()[cnode];
168         for (auto &node_user : node_users) {
169           if (graph->nodes().contains(node_user.first)) {
170             manager->SetEdge(node_user.first, node_user.second, parameter);
171           }
172         }
173       }
174     }
175   }
176   return RET_OK;
177 }
178 
CreateBranchGraph(const AnfNodePtr & node,std::string name,BranchType branch_type)179 FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
180   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
181   auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::kFmkTypeTf);
182   if (graph == nullptr) {
183     MS_LOG(ERROR) << "new graph Partial Node return nullptr";
184     return nullptr;
185   }
186   graph->set_manager(fg_->manager());
187   auto status = BranchSubGraphAddNodes(graph, node, branch_type);
188   if (status != RET_OK) {
189     return nullptr;
190   }
191 
192   if (!CheckPrimitiveType(node, prim::kPrimSwitch)) {  // graph is not empty
193     auto return_prim_ptr = std::make_shared<ops::Return>();
194     if (return_prim_ptr == nullptr) {
195       MS_LOG(ERROR) << "GetReturnPrim return nullptr";
196       return nullptr;
197     }
198     auto return_prim_c = return_prim_ptr->GetPrim();
199     MS_CHECK_TRUE_RET(return_prim_c != nullptr, nullptr);
200     auto value_node = NewValueNode(return_prim_c);
201     MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
202     std::vector<AnfNodePtr> op_inputs{value_node, node};  // If subgraph only has one output tensor
203     auto return_cnode = graph->NewCNode(op_inputs);
204     MS_CHECK_TRUE_RET(return_cnode != nullptr, nullptr);
205     return_cnode->set_fullname_with_scope(name + "-return");
206     return_cnode->set_func_graph(graph);
207     graph->set_return(return_cnode);
208     auto graph_output = graph->output();
209     MS_CHECK_TRUE_RET(graph_output != nullptr, nullptr);
210     auto graph_output_cnode = graph_output->cast<CNodePtr>();
211     MS_CHECK_TRUE_RET(graph_output_cnode != nullptr, nullptr);
212     graph_output_cnode->set_fullname_with_scope(name + "_output_0_cnode");
213   }
214   return graph;
215 }
216 
CreateNewIf(const FuncGraphPtr & else_branch,const FuncGraphPtr & then_branch)217 CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
218   MS_CHECK_TRUE_RET(else_branch != nullptr, nullptr);
219   MS_CHECK_TRUE_RET(then_branch != nullptr, nullptr);
220 
221   auto if_primc = std::make_shared<mindspore::lite::If>();
222   if (if_primc == nullptr) {
223     MS_LOG(ERROR) << "new if_primitive failed";
224     return nullptr;
225   }
226   auto if_value_node = NewValueNode(if_primc);
227   if (if_value_node == nullptr) {
228     MS_LOG(ERROR) << "new if_value_node failed";
229     return nullptr;
230   }
231   auto then_value_node = NewValueNode(then_branch);
232   MS_CHECK_TRUE_RET(then_value_node != nullptr, nullptr);
233   auto else_value_node = NewValueNode(else_branch);
234   MS_CHECK_TRUE_RET(else_value_node != nullptr, nullptr);
235   std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
236   std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
237   return fg_->NewCNode(if_op_inputs);
238 }
239 
VerifyPredictNode()240 STATUS FunctionalizeCond::VerifyPredictNode() {
241   if (pred_nodes_.empty()) {
242     return RET_ERROR;
243   }
244   for (size_t i = 1; i < pred_nodes_.size(); ++i) {
245     if (pred_nodes_[i] != pred_nodes_[0]) {
246       return RET_ERROR;
247     }
248   }
249   pred_node_ = pred_nodes_[0];
250   return RET_OK;
251 }
252 
DegenerateNonControlFlow(const FuncGraphPtr & else_graph,const FuncGraphPtr & then_graph)253 STATUS FunctionalizeCond::DegenerateNonControlFlow(const FuncGraphPtr &else_graph, const FuncGraphPtr &then_graph) {
254   MS_ASSERT(else_graph != nullptr && then_graph != nullptr);
255   std::vector<AnfNodePtr> nodes;
256   auto else_nodes = else_graph->nodes();
257   nodes.insert(nodes.end(), else_nodes.begin(), else_nodes.end());
258   auto then_nodes = then_graph->nodes();
259   nodes.insert(nodes.end(), then_nodes.begin(), then_nodes.end());
260   for (auto &node : nodes) {
261     MS_CHECK_TRUE_MSG(node != nullptr, lite::RET_ERROR, "find a node is a nullptr.");
262     if (!utils::isa<ValueNode>(node)) {
263       node->set_func_graph(fg_);
264     }
265   }
266   auto manager = fg_->manager();
267   MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "manager must be not a nullptr.");
268   CNodePtr switch_op{nullptr};
269   int merge_input_index = 1;
270   if (then_is_effective_ && !else_is_effective_) {
271     switch_op = then_switch_;
272     merge_input_index = kInputIndexTwo;
273   } else if (else_is_effective_ && !then_is_effective_) {
274     switch_op = else_switch_;
275   } else {
276     return lite::RET_ERROR;
277   }
278   MS_CHECK_TRUE_MSG(switch_op != nullptr, lite::RET_NULL_PTR, "switch node is a nullptr.");
279   MS_CHECK_TRUE_MSG(switch_op->size() >= kInputSizeThree, lite::RET_ERROR, "switch's inputs-size is invalid.");
280   auto node_users = manager->node_users()[switch_op];
281   for (auto &node_user : node_users) {
282     auto post_node = node_user.first;
283     if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
284       MS_LOG(ERROR) << "switch's post-node must be TupleGetItem.";
285       return lite::RET_ERROR;
286     }
287     if (!manager->Replace(post_node, switch_op->input(1))) {
288       MS_LOG(ERROR) << "Manager: Replace unused switch-node failed.";
289       return lite::RET_ERROR;
290     }
291   }
292   if (!manager->Replace(merge_node_, merge_node_->input(merge_input_index))) {
293     MS_LOG(ERROR) << "Manager: Replace unused merge-node failed.";
294     return lite::RET_ERROR;
295   }
296   return lite::RET_OK;
297 }
298 
Process()299 STATUS FunctionalizeCond::Process() {
300   if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->size() != kInputSizeThree) {
301     MS_LOG(ERROR) << "fg or merge is not correct";
302     return RET_ERROR;
303   }
304 
305   then_is_effective_ = true;
306   then_switch_ = nullptr;
307   else_is_effective_ = true;
308   else_switch_ = nullptr;
309   auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
310   auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
311 
312   auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
313   if (else_branch == nullptr) {
314     MS_LOG(ERROR) << "create else branch failed";
315     return RET_ERROR;
316   }
317   auto then_branch = CreateBranchGraph(merge_node_->input(kInputIndexTwo), then_branch_name, kThenBranch);
318   if (then_branch == nullptr) {
319     MS_LOG(ERROR) << "create then branch failed";
320     return RET_ERROR;
321   }
322   if (else_is_effective_ ^ then_is_effective_) {
323     auto status = DegenerateNonControlFlow(else_branch, then_branch);
324     if (status != lite::RET_OK) {
325       MS_LOG(ERROR) << "Degenerate to non-control-flow failed.";
326     }
327     return status;
328   }
329 
330   auto status = IdentifySubgraphInput(else_branch, else_branch_name);
331   if (status != RET_OK) {
332     return status;
333   }
334   status = IdentifySubgraphInput(then_branch, then_branch_name);
335   if (status != RET_OK) {
336     return status;
337   }
338 
339   status = VerifyPredictNode();
340   if (status != RET_OK) {
341     return status;
342   }
343 
344   auto if_node = CreateNewIf(else_branch, then_branch);
345   if (if_node == nullptr) {
346     MS_LOG(ERROR) << "create if node error";
347     return RET_ERROR;
348   }
349   if_node->set_abstract(merge_node_->abstract()->Clone());
350   auto manager = fg_->manager();
351   CHECK_NULL_RETURN(manager);
352   auto node_users = manager->node_users()[merge_node_];
353   for (auto &node_user : node_users) {
354     manager->SetEdge(node_user.first, node_user.second, if_node);
355   }
356   return RET_OK;
357 }
358 }  // namespace mindspore::opt
359