1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/parser/tf/functionalize_cond.h"
18 #include <algorithm>
19 #include <memory>
20 #include <deque>
21 #include <unordered_set>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "include/errorcode.h"
25 #include "tools/converter/ops/ops_def.h"
26 #include "nnacl/op_base.h"
27 #include "src/common/log_util.h"
28 #include "ops/return.h"
29 #include "tools/lite_exporter/fetch_content.h"
30
31 namespace mindspore::opt {
GetSwitchBranchType(const CNodePtr & switch_cnode,BranchType * branch_type)32 STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
33 MS_ASSERT(switch_cnode != nullptr);
34 MS_ASSERT(branch_type != nullptr);
35 auto manager = fg_->manager();
36 if (manager == nullptr) {
37 MS_LOG(ERROR) << "manager is nullptr";
38 return RET_ERROR;
39 }
40 auto node_users = manager->node_users()[switch_cnode];
41 if (node_users.size() != 1) { // only one output of switch is referenced in cond
42 MS_LOG(ERROR) << "switch's node users is not correct";
43 return RET_ERROR;
44 }
45 auto node_user = node_users.front();
46 auto tuple_get_item = node_user.first;
47 if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
48 MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
49 return RET_ERROR;
50 }
51 auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
52 auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
53 if (idx == 0) {
54 *branch_type = kElseBranch;
55 } else if (idx == 1) {
56 *branch_type = kThenBranch;
57 } else {
58 MS_LOG(ERROR) << "wrong tuple_get_item index";
59 return RET_ERROR;
60 }
61 return RET_OK;
62 }
63
CheckBranchIsEffective(const CNodePtr & switch_cnode,BranchType branch_type)64 void FunctionalizeCond::CheckBranchIsEffective(const CNodePtr &switch_cnode, BranchType branch_type) {
65 MS_ASSERT(switch_cnode != nullptr);
66 MS_ASSERT(is_effective != nullptr);
67 if (switch_cnode->size() < C3NUM) {
68 return;
69 }
70 auto cond_node = switch_cnode->input(kInputIndexTwo);
71 if (!utils::isa<Parameter>(cond_node)) {
72 return;
73 }
74 auto cond_pnode = cond_node->cast<ParameterPtr>();
75 lite::DataInfo data_info;
76 if (FetchFromDefaultParam(cond_pnode, converter::FmkType::kFmkTypeTf, &data_info, false) != lite::RET_OK) {
77 return;
78 }
79 if (data_info.data_ptr_ == nullptr || data_info.data_type_ != kNumberTypeBool) {
80 return;
81 }
82 bool cond = *(static_cast<bool *>(data_info.data_ptr_));
83 if (branch_type == kThenBranch) {
84 then_switch_ = switch_cnode;
85 then_is_effective_ = cond;
86 }
87 if (branch_type == kElseBranch) {
88 else_switch_ = switch_cnode;
89 else_is_effective_ = !cond;
90 }
91 }
92
BranchSubGraphAddNodes(const FuncGraphPtr & graph,const AnfNodePtr & root_node,BranchType branch_type)93 STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
94 BranchType branch_type) {
95 CHECK_NULL_RETURN(graph);
96 CHECK_NULL_RETURN(root_node);
97 std::deque<AnfNodePtr> q;
98 std::unordered_set<AnfNodePtr> vis;
99 q.push_back(root_node);
100 while (!q.empty()) {
101 auto node = q.front();
102 CHECK_NULL_RETURN(node);
103 q.pop_front();
104 vis.insert(node);
105 if (FunctionalizeControlOpPass::IsSwitch(node)) {
106 auto cnode = utils::cast<CNodePtr>(node);
107 BranchType this_type;
108 if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
109 MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
110 return RET_ERROR;
111 }
112 CheckBranchIsEffective(cnode, branch_type);
113 continue;
114 }
115 if (utils::isa<ParameterPtr>(node)) {
116 graph->add_parameter(node->cast<ParameterPtr>());
117 }
118 graph->AddNode(node);
119 if (!utils::isa<ValueNodePtr>(node)) {
120 node->set_func_graph(graph);
121 }
122 if (utils::isa<CNodePtr>(node)) {
123 auto cnode = utils::cast<CNodePtr>(node);
124 for (size_t i = 1; i < cnode->size(); i++) {
125 auto inputi = cnode->input(i);
126 if (vis.find(inputi) == vis.end()) {
127 q.push_back(cnode->input(i));
128 }
129 }
130 }
131 }
132 return RET_OK;
133 }
134
PosInInputNodes(const CNodePtr & node)135 int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
136 auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node);
137 if (index == input_nodes_.end()) {
138 input_nodes_.push_back(node);
139 return input_nodes_.size() - 1;
140 }
141 return index - input_nodes_.begin();
142 }
143
IdentifySubgraphInput(const FuncGraphPtr & graph,std::string graph_name)144 STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
145 CHECK_NULL_RETURN(graph);
146 std::vector<AnfNodePtr> nodes_need_drop{};
147 for (auto &cnode : graph->GetOrderedCnodes()) {
148 for (auto &input_node : cnode->inputs()) {
149 if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
150 CHECK_NULL_RETURN(input_node);
151 auto switch_node = input_node->cast<CNodePtr>();
152 CHECK_NULL_RETURN(switch_node);
153 auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
154 auto pos = PosInInputNodes(switch_input);
155 nodes_need_drop.push_back(cnode);
156 pred_nodes_.push_back(switch_node->input(kInputIndexTwo));
157 // set parameter
158 auto parameter = graph->add_parameter();
159 CHECK_NULL_RETURN(parameter);
160 parameter->set_abstract(cnode->abstract());
161 // hardcode for subgraph input name
162 parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");
163
164 // replace switch
165 auto manager = fg_->manager();
166 CHECK_NULL_RETURN(manager);
167 auto node_users = manager->node_users()[cnode];
168 for (auto &node_user : node_users) {
169 if (graph->nodes().contains(node_user.first)) {
170 manager->SetEdge(node_user.first, node_user.second, parameter);
171 }
172 }
173 }
174 }
175 }
176 return RET_OK;
177 }
178
CreateBranchGraph(const AnfNodePtr & node,std::string name,BranchType branch_type)179 FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
180 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
181 auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::kFmkTypeTf);
182 if (graph == nullptr) {
183 MS_LOG(ERROR) << "new graph Partial Node return nullptr";
184 return nullptr;
185 }
186 graph->set_manager(fg_->manager());
187 auto status = BranchSubGraphAddNodes(graph, node, branch_type);
188 if (status != RET_OK) {
189 return nullptr;
190 }
191
192 if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
193 auto return_prim_ptr = std::make_shared<ops::Return>();
194 if (return_prim_ptr == nullptr) {
195 MS_LOG(ERROR) << "GetReturnPrim return nullptr";
196 return nullptr;
197 }
198 auto return_prim_c = return_prim_ptr->GetPrim();
199 MS_CHECK_TRUE_RET(return_prim_c != nullptr, nullptr);
200 auto value_node = NewValueNode(return_prim_c);
201 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
202 std::vector<AnfNodePtr> op_inputs{value_node, node}; // If subgraph only has one output tensor
203 auto return_cnode = graph->NewCNode(op_inputs);
204 MS_CHECK_TRUE_RET(return_cnode != nullptr, nullptr);
205 return_cnode->set_fullname_with_scope(name + "-return");
206 return_cnode->set_func_graph(graph);
207 graph->set_return(return_cnode);
208 auto graph_output = graph->output();
209 MS_CHECK_TRUE_RET(graph_output != nullptr, nullptr);
210 auto graph_output_cnode = graph_output->cast<CNodePtr>();
211 MS_CHECK_TRUE_RET(graph_output_cnode != nullptr, nullptr);
212 graph_output_cnode->set_fullname_with_scope(name + "_output_0_cnode");
213 }
214 return graph;
215 }
216
CreateNewIf(const FuncGraphPtr & else_branch,const FuncGraphPtr & then_branch)217 CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
218 MS_CHECK_TRUE_RET(else_branch != nullptr, nullptr);
219 MS_CHECK_TRUE_RET(then_branch != nullptr, nullptr);
220
221 auto if_primc = std::make_shared<mindspore::lite::If>();
222 if (if_primc == nullptr) {
223 MS_LOG(ERROR) << "new if_primitive failed";
224 return nullptr;
225 }
226 auto if_value_node = NewValueNode(if_primc);
227 if (if_value_node == nullptr) {
228 MS_LOG(ERROR) << "new if_value_node failed";
229 return nullptr;
230 }
231 auto then_value_node = NewValueNode(then_branch);
232 MS_CHECK_TRUE_RET(then_value_node != nullptr, nullptr);
233 auto else_value_node = NewValueNode(else_branch);
234 MS_CHECK_TRUE_RET(else_value_node != nullptr, nullptr);
235 std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
236 std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
237 return fg_->NewCNode(if_op_inputs);
238 }
239
VerifyPredictNode()240 STATUS FunctionalizeCond::VerifyPredictNode() {
241 if (pred_nodes_.empty()) {
242 return RET_ERROR;
243 }
244 for (size_t i = 1; i < pred_nodes_.size(); ++i) {
245 if (pred_nodes_[i] != pred_nodes_[0]) {
246 return RET_ERROR;
247 }
248 }
249 pred_node_ = pred_nodes_[0];
250 return RET_OK;
251 }
252
DegenerateNonControlFlow(const FuncGraphPtr & else_graph,const FuncGraphPtr & then_graph)253 STATUS FunctionalizeCond::DegenerateNonControlFlow(const FuncGraphPtr &else_graph, const FuncGraphPtr &then_graph) {
254 MS_ASSERT(else_graph != nullptr && then_graph != nullptr);
255 std::vector<AnfNodePtr> nodes;
256 auto else_nodes = else_graph->nodes();
257 nodes.insert(nodes.end(), else_nodes.begin(), else_nodes.end());
258 auto then_nodes = then_graph->nodes();
259 nodes.insert(nodes.end(), then_nodes.begin(), then_nodes.end());
260 for (auto &node : nodes) {
261 MS_CHECK_TRUE_MSG(node != nullptr, lite::RET_ERROR, "find a node is a nullptr.");
262 if (!utils::isa<ValueNode>(node)) {
263 node->set_func_graph(fg_);
264 }
265 }
266 auto manager = fg_->manager();
267 MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "manager must be not a nullptr.");
268 CNodePtr switch_op{nullptr};
269 int merge_input_index = 1;
270 if (then_is_effective_ && !else_is_effective_) {
271 switch_op = then_switch_;
272 merge_input_index = kInputIndexTwo;
273 } else if (else_is_effective_ && !then_is_effective_) {
274 switch_op = else_switch_;
275 } else {
276 return lite::RET_ERROR;
277 }
278 MS_CHECK_TRUE_MSG(switch_op != nullptr, lite::RET_NULL_PTR, "switch node is a nullptr.");
279 MS_CHECK_TRUE_MSG(switch_op->size() >= kInputSizeThree, lite::RET_ERROR, "switch's inputs-size is invalid.");
280 auto node_users = manager->node_users()[switch_op];
281 for (auto &node_user : node_users) {
282 auto post_node = node_user.first;
283 if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
284 MS_LOG(ERROR) << "switch's post-node must be TupleGetItem.";
285 return lite::RET_ERROR;
286 }
287 if (!manager->Replace(post_node, switch_op->input(1))) {
288 MS_LOG(ERROR) << "Manager: Replace unused switch-node failed.";
289 return lite::RET_ERROR;
290 }
291 }
292 if (!manager->Replace(merge_node_, merge_node_->input(merge_input_index))) {
293 MS_LOG(ERROR) << "Manager: Replace unused merge-node failed.";
294 return lite::RET_ERROR;
295 }
296 return lite::RET_OK;
297 }
298
Process()299 STATUS FunctionalizeCond::Process() {
300 if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->size() != kInputSizeThree) {
301 MS_LOG(ERROR) << "fg or merge is not correct";
302 return RET_ERROR;
303 }
304
305 then_is_effective_ = true;
306 then_switch_ = nullptr;
307 else_is_effective_ = true;
308 else_switch_ = nullptr;
309 auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
310 auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
311
312 auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
313 if (else_branch == nullptr) {
314 MS_LOG(ERROR) << "create else branch failed";
315 return RET_ERROR;
316 }
317 auto then_branch = CreateBranchGraph(merge_node_->input(kInputIndexTwo), then_branch_name, kThenBranch);
318 if (then_branch == nullptr) {
319 MS_LOG(ERROR) << "create then branch failed";
320 return RET_ERROR;
321 }
322 if (else_is_effective_ ^ then_is_effective_) {
323 auto status = DegenerateNonControlFlow(else_branch, then_branch);
324 if (status != lite::RET_OK) {
325 MS_LOG(ERROR) << "Degenerate to non-control-flow failed.";
326 }
327 return status;
328 }
329
330 auto status = IdentifySubgraphInput(else_branch, else_branch_name);
331 if (status != RET_OK) {
332 return status;
333 }
334 status = IdentifySubgraphInput(then_branch, then_branch_name);
335 if (status != RET_OK) {
336 return status;
337 }
338
339 status = VerifyPredictNode();
340 if (status != RET_OK) {
341 return status;
342 }
343
344 auto if_node = CreateNewIf(else_branch, then_branch);
345 if (if_node == nullptr) {
346 MS_LOG(ERROR) << "create if node error";
347 return RET_ERROR;
348 }
349 if_node->set_abstract(merge_node_->abstract()->Clone());
350 auto manager = fg_->manager();
351 CHECK_NULL_RETURN(manager);
352 auto node_users = manager->node_users()[merge_node_];
353 for (auto &node_user : node_users) {
354 manager->SetEdge(node_user.first, node_user.second, if_node);
355 }
356 return RET_OK;
357 }
358 } // namespace mindspore::opt
359