• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/optimizer/graph/control_flow_pass.h"
17 #include <vector>
18 #include <memory>
19 #include <algorithm>
20 #include "ops/switch.h"
21 #include "ops/fusion/partial_fusion.h"
22 #include "include/errorcode.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "src/common/log_adapter.h"
25 #include "tools/common/node_util.h"
26 #include "nnacl/op_base.h"
27 
28 namespace mindspore::opt {
ReplaceNode(const FuncGraphPtr & fg,const std::unordered_map<AnfNodePtr,AnfNodePtr> & replace_pairs)29 void ControlFlowPass::ReplaceNode(const FuncGraphPtr &fg,
30                                   const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs) {
31   for (auto &node : fg->nodes()) {
32     if (!utils::isa<CNodePtr>(node)) {
33       continue;
34     }
35     auto cnode = node->cast<CNodePtr>();
36     MS_ASSERT(cnode != nullptr);
37     auto new_inputs = cnode->inputs();
38     for (auto &input : new_inputs) {
39       if (replace_pairs.find(input) == replace_pairs.end()) {
40         continue;
41       }
42       input = replace_pairs.at(input);
43     }
44     cnode->set_inputs(new_inputs);
45   }
46 }
47 
VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg)48 void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
49                                                    const std::vector<AnfNodePtr> &remain_nodes,
50                                                    std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
51   std::deque<AnfNodePtr> nodes{};
52   std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
53   std::set<AnfNodePtr> remain_nodes_set{};
54   nodes.assign(remain_nodes.begin(), remain_nodes.end());
55   while (!nodes.empty()) {
56     auto node = nodes.front();
57     nodes.pop_front();
58     remain_nodes_set.insert(node);
59     if (!utils::isa<CNodePtr>(node)) {
60       continue;
61     }
62     auto cnode = node->cast<CNodePtr>();
63     MS_ASSERT(cnode != nullptr);
64     for (auto &input : cnode->inputs()) {
65       if (visited_nodes.find(input) != visited_nodes.end() &&
66           visited_nodes_used_by_after_fg_set.find(input) == visited_nodes_used_by_after_fg_set.end()) {
67         visited_nodes_used_by_after_fg->push_back(input);
68         visited_nodes_used_by_after_fg_set.insert(input);
69       }
70     }
71   }
72 }
73 
GetItemVisitedNums(const std::set<AnfNodePtr> & visited_nodes,const AnfNodePtr & tuple_node)74 size_t ControlFlowPass::GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node) {
75   size_t count = 0;
76   for (auto &node : visited_nodes) {
77     if (!utils::isa<CNodePtr>(node)) {
78       continue;
79     }
80     if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
81       continue;
82     }
83     auto get_item_cnode = node->cast<CNodePtr>();
84     MS_ASSERT(get_item_cnode != nullptr);
85     if (get_item_cnode->inputs()[kCNodeFirstInputIndex] == tuple_node) {
86       count++;
87     }
88   }
89   return count;
90 }
91 
MoveGetItemToVisited(const size_t & need_size,const AnfNodePtr & tuple_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)92 void ControlFlowPass::MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node,
93                                            std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
94   size_t i = 0;
95   for (auto it = remain_nodes->begin(); it != remain_nodes->end();) {
96     if (!utils::isa<CNodePtr>(*it)) {
97       ++it;
98       continue;
99     }
100     if (!CheckPrimitiveType(*it, prim::kPrimTupleGetItem)) {
101       ++it;
102       continue;
103     }
104     auto get_item_cnode = (*it)->cast<CNodePtr>();
105     MS_ASSERT(get_item_cnode != nullptr);
106     if (get_item_cnode->inputs()[kCNodeFirstInputIndex] != tuple_node) {
107       ++it;
108       continue;
109     }
110     i++;
111     visited_nodes->insert(*it);
112     it = remain_nodes->erase(it);
113     if (need_size == i) {
114       return;
115     }
116   }
117   MS_LOG(INFO) << tuple_node->fullname_with_scope() << " not found enough get item, size: " << need_size - i;
118 }
119 
BindGetItemNodes(std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)120 void ControlFlowPass::BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
121   std::deque<AnfNodePtr> multi_output_nodes{};
122   for (auto &node : *visited_nodes) {
123     if (!utils::isa<CNodePtr>(node)) {
124       continue;
125     }
126     if (utils::isa<abstract::AbstractTuple>(node->abstract())) {
127       multi_output_nodes.push_back(node);
128     }
129   }
130 
131   while (!multi_output_nodes.empty()) {
132     auto cur_node = multi_output_nodes.front();
133     multi_output_nodes.pop_front();
134     size_t total_getitem_size = cur_node->abstract()->cast<abstract::AbstractTuplePtr>()->size();
135     size_t visited_getitem_size = GetItemVisitedNums(*visited_nodes, cur_node);
136     if (total_getitem_size == visited_getitem_size) {
137       continue;
138     }
139 
140     size_t need_getitem_size = total_getitem_size - visited_getitem_size;
141     MoveGetItemToVisited(need_getitem_size, cur_node, visited_nodes, remain_nodes);
142   }
143 }
144 
SplitGraph(const FuncGraphPtr & fg,AnfNodePtr * control_flow_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)145 int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node,
146                                 std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
147   auto inputs = fg->get_inputs();
148 
149   // notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
150   auto node_list = TopoSort(fg->get_return());
151   for (auto &node : node_list) {
152     MS_ASSERT(node != nullptr);
153     if (utils::isa<CNodePtr>(node) &&
154         (CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
155       *control_flow_node = node;
156       break;
157     }
158   }
159 
160   std::deque<AnfNodePtr> q;
161   visited_nodes->insert(inputs.begin(), inputs.end());
162   q.push_back(*control_flow_node);
163   while (!q.empty()) {
164     auto node = q.front();
165     q.pop_front();
166     if (!utils::isa<CNodePtr>(node)) {
167       continue;
168     }
169     visited_nodes->insert(node);
170     auto cnode = utils::cast<CNodePtr>(node);
171     MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed");
172     for (size_t i = 0; i < cnode->inputs().size(); i++) {
173       auto input = cnode->input(i);
174       if (visited_nodes->find(input) == visited_nodes->end()) {
175         q.push_back(input);
176       }
177     }
178   }
179 
180   for (auto &node : node_list) {
181     if (visited_nodes->find(node) == visited_nodes->end()) {
182       remain_nodes->push_back(node);
183     }
184   }
185   visited_nodes->erase(*control_flow_node);
186 
187   BindGetItemNodes(visited_nodes, remain_nodes);
188 
189   return RET_SUCCESS;
190 }
191 
CreateAfterGraph(const FuncGraphPtr & main_fg,const std::vector<AnfNodePtr> & remain_nodes,const CNodePtr & aim_cnode,FuncGraphPtr * after_fg)192 int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
193                                       const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) {
194   *after_fg = std::make_shared<FuncGraph>();
195   MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
196   auto manager = main_fg->manager();
197   MS_ASSERT(manager != nullptr);
198   manager->AddFuncGraph(*after_fg);
199   (*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
200   (*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
201   (*after_fg)->set_manager(main_fg->manager());
202 
203   for (auto &cur_node : remain_nodes) {
204     if (cur_node->isa<ValueNode>()) {
205       continue;
206     }
207     if (cur_node == main_fg->get_return()) {
208       continue;
209     }
210     (*after_fg)->AddNode(cur_node);
211     if (!utils::isa<ValueNodePtr>(cur_node)) {
212       cur_node->set_func_graph(*after_fg);
213     }
214     if (cur_node == main_fg->output()) {
215       (*after_fg)->set_output(cur_node, false);
216     }
217     main_fg->DropNode(cur_node);
218   }
219   return RET_SUCCESS;
220 }
221 
CreateWhileCondCallNode(const FuncGraphPtr & fg,const CNodePtr & while_cnode,const std::vector<AnfNodePtr> & visited_nodes_used_by_after_fg,CNodePtr * cond_call_cnode,std::vector<AnfNodePtr> * cond_nodes_used_by_after_partial,std::unordered_map<AnfNodePtr,AnfNodePtr> * visited_nodes_and_cond_fg_inputs_replace_pairs)222 int ControlFlowPass::CreateWhileCondCallNode(
223   const FuncGraphPtr &fg, const CNodePtr &while_cnode, const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
224   CNodePtr *cond_call_cnode, std::vector<AnfNodePtr> *cond_nodes_used_by_after_partial,
225   std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs) {
226   auto cond_vnode = while_cnode->input(kWhileCondIndex);
227   MS_CHECK_TRUE_MSG(cond_vnode != nullptr, lite::RET_NULL_PTR, "cnode is nullptr");
228   auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
229   if (cond_fg == nullptr) {
230     MS_LOG(ERROR) << "Get value as func graph failed.";
231     return RET_FAILED;
232   }
233 
234   // create after partial node
235   ValueNodePtr cond_partial_anf_primitive = lite::GetPartialFusionPrim();
236   if (cond_partial_anf_primitive == nullptr) {
237     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
238     return RET_FAILED;
239   }
240 
241   std::vector<AnfNodePtr> cond_partial_cnode_inputs{cond_partial_anf_primitive, cond_vnode};
242   cond_partial_cnode_inputs.insert(cond_partial_cnode_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
243                                    while_cnode->inputs().end());
244 
245   auto origin_cond_fg_inputs = cond_fg->get_inputs();
246   for (auto &item : visited_nodes_used_by_after_fg) {
247     bool found = false;
248     size_t input_index = 0;
249     for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
250       if (cond_partial_cnode_inputs[i] == item) {
251         found = true;
252         input_index = i - kPartialFirstInputSize;
253         break;
254       }
255     }
256 
257     if (found) {
258       (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = origin_cond_fg_inputs.at(input_index);
259       cond_nodes_used_by_after_partial->push_back(origin_cond_fg_inputs.at(input_index));
260       continue;
261     }
262 
263     // set after fg inputs to cond_partial_cnode inputs
264     cond_partial_cnode_inputs.push_back(item);
265     auto new_parameter = cond_fg->add_parameter();
266     MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
267     new_parameter->set_name(item->fullname_with_scope() + "_cond_fg_parameter");
268     new_parameter->set_abstract(item->abstract());
269     (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = new_parameter;
270     cond_nodes_used_by_after_partial->push_back(new_parameter);
271   }
272 
273   auto cond_partial_cnode = fg->NewCNode(cond_partial_cnode_inputs);
274   MS_CHECK_TRUE_MSG(cond_partial_cnode != nullptr, lite::RET_NULL_PTR, "cond_partial_cnode is nullptr");
275   cond_partial_cnode->set_fullname_with_scope("partial_" + cond_fg->get_attr("graph_name")->ToString());
276 
277   // insert call node
278   std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
279   *cond_call_cnode = fg->NewCNode(call_node_inputs);
280   MS_CHECK_TRUE_MSG(*cond_call_cnode != nullptr, lite::RET_NULL_PTR, "new cnode is nullptr");
281   (*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
282 
283   return RET_SUCCESS;
284 }
285 
CreateWhileBodyPartialNode(const FuncGraphPtr & cond_fg,const CNodePtr & while_cnode,CNodePtr * body_partial_node)286 int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
287                                                 CNodePtr *body_partial_node) {
288   auto body_vnode = while_cnode->input(kWhileBodyIndex);
289   MS_CHECK_TRUE_MSG(body_vnode != nullptr, RET_FAILED, "body_vnode is nullptr");
290   auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
291   if (body_fg == nullptr) {
292     MS_LOG(ERROR) << "Get value as func_graph failed.";
293     return RET_FAILED;
294   }
295 
296   ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
297   if (partial_anf_primitive == nullptr) {
298     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
299     return RET_FAILED;
300   }
301 
302   std::vector<AnfNodePtr> body_partial_node_inputs{partial_anf_primitive, body_vnode};
303   // set body inputs to body partial inputs
304   auto cond_fg_inputs = cond_fg->get_inputs();
305   body_partial_node_inputs.insert(body_partial_node_inputs.end(), cond_fg_inputs.begin(), cond_fg_inputs.end());
306   *body_partial_node = cond_fg->NewCNode(body_partial_node_inputs);
307   MS_CHECK_TRUE_MSG(*body_partial_node != nullptr, RET_FAILED, "new cnode is nullptr");
308   (*body_partial_node)->set_fullname_with_scope("CNode_" + body_fg->get_attr("graph_name")->ToString());
309 
310   // add after inputs for body fg to call cond fg
311   auto body_fg_inputs = body_fg->get_inputs();
312   auto origin_body_fg_inputs_size = body_fg_inputs.size();
313   for (size_t i = origin_body_fg_inputs_size; i < cond_fg_inputs.size(); ++i) {
314     if (!utils::isa<ParameterPtr>(cond_fg_inputs[i])) {
315       MS_LOG(ERROR) << "fg is not right.";
316       return RET_FAILED;
317     }
318     auto new_parameter = body_fg->add_parameter();
319     MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
320     new_parameter->set_name(cond_fg_inputs[i]->fullname_with_scope() + "_body_fg_parameter");
321     new_parameter->set_abstract(cond_fg_inputs[i]->abstract());
322   }
323 
324   // call the cond fg
325   auto cond_partial_vnode = NewValueNode(cond_fg);
326   MS_CHECK_TRUE_MSG(cond_partial_vnode != nullptr, lite::RET_NULL_PTR, "cond_partial_vnode is nullptr");
327   std::vector<AnfNodePtr> cond_call_cnode_inputs{cond_partial_vnode};
328   // set body fg output
329   auto body_output = body_fg->output()->cast<CNodePtr>();
330   MS_ASSERT(body_output != nullptr);
331   if (CheckPrimitiveType(body_output, prim::kPrimMakeTuple)) {
332     for (size_t i = 1; i < body_output->inputs().size(); ++i) {
333       cond_call_cnode_inputs.push_back(body_output->input(i));
334     }
335     body_fg->DropNode(body_output);
336   } else {
337     cond_call_cnode_inputs.push_back(body_output);
338   }
339 
340   body_fg_inputs = body_fg->get_inputs();
341   for (size_t i = origin_body_fg_inputs_size; i < body_fg_inputs.size(); ++i) {
342     cond_call_cnode_inputs.push_back(body_fg_inputs[i]);
343   }
344 
345   auto cond_call_cnode = body_fg->NewCNode(cond_call_cnode_inputs);
346   MS_CHECK_TRUE_MSG(cond_call_cnode != nullptr, lite::RET_NULL_PTR, "cond_call_cnode != nullptr");
347   cond_call_cnode->set_fullname_with_scope(body_fg->get_attr("graph_name")->ToString() + "_call_cond_fg");
348   body_fg->set_output(cond_call_cnode);
349 
350   to_process_q.push_back(body_fg);
351   return RET_SUCCESS;
352 }
353 
CreateWhileAfterPartialNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & cond_fg,const std::vector<AnfNodePtr> & remain_nodes,const std::vector<AnfNodePtr> & cond_nodes_used_by_after_partial,const std::unordered_map<AnfNodePtr,AnfNodePtr> & visited_nodes_and_cond_fg_inputs_replace_pairs,CNodePtr * while_cnode,CNodePtr * after_partial_cnode)354 int ControlFlowPass::CreateWhileAfterPartialNode(
355   const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
356   const std::vector<AnfNodePtr> &cond_nodes_used_by_after_partial,
357   const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
358   CNodePtr *while_cnode, CNodePtr *after_partial_cnode) {
359   // create after_fg
360   FuncGraphPtr after_fg = nullptr;
361   if (CreateAfterGraph(main_fg, remain_nodes, *while_cnode, &after_fg) != RET_SUCCESS) {
362     MS_LOG(ERROR) << "CreateAfterGraph failed.";
363     return RET_FAILED;
364   }
365 
366   auto after_value_node = NewValueNode(after_fg);
367   MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "after_value_node is nullptr");
368   ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
369   if (partial_anf_primitive == nullptr) {
370     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
371     return RET_FAILED;
372   }
373 
374   std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_inputs_and_after_fg_inputs_replace_pairs{};
375   std::vector<AnfNodePtr> after_partial_cnode_inputs{partial_anf_primitive, after_value_node};
376   auto cond_fg_inputs = cond_fg->get_inputs();
377   for (const auto &node : after_fg->nodes()) {
378     if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
379       continue;
380     }
381     auto get_tuple_item_cnode = node->cast<CNodePtr>();
382     MS_ASSERT(get_tuple_item_cnode != nullptr);
383     MS_ASSERT(get_tuple_item_cnode->inputs().size() == kGetItemInputSize);
384     if (get_tuple_item_cnode->input(kCNodeFirstInputIndex) != *while_cnode) {
385       continue;
386     }
387     auto index_vnode = get_tuple_item_cnode->inputs().at(kCNodeSecondInputIndex);
388     if (!utils::isa<ValueNode>(index_vnode)) {
389       MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
390       return RET_FAILED;
391     }
392     auto value_node = utils::cast<ValueNodePtr>(index_vnode);
393     MS_ASSERT(value_node != nullptr);
394 
395     auto input_index = value_node->value()->type()->number_type() == kNumberTypeInt64
396                          ? GetValue<int64_t>(value_node->value())
397                          : GetValue<int>(value_node->value());
398 
399     after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index));
400     auto new_parameter = after_fg->add_parameter();
401     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
402     new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
403     new_parameter->set_abstract(node->abstract());
404     after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
405   }
406 
407   std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_after_fg_replace_pair{};
408   for (auto &input : cond_nodes_used_by_after_partial) {
409     after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
410     auto new_parameter = after_fg->add_parameter();
411     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
412     new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
413     new_parameter->set_abstract(input->abstract());
414     visited_nodes_after_fg_replace_pair[visited_nodes_and_cond_fg_inputs_replace_pairs.at(input)] = new_parameter;
415   }
416 
417   ReplaceNode(after_fg, visited_nodes_and_cond_fg_inputs_replace_pairs);
418   ReplaceNode(after_fg, after_partial_inputs_and_after_fg_inputs_replace_pairs);
419   ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
420   *after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
421   MS_CHECK_TRUE_MSG(*after_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
422   (*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
423   return RET_SUCCESS;
424 }
425 
ProcessWhileOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & while_node)426 int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
427                                     const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node) {
428   if (while_node == nullptr) {
429     MS_LOG(INFO) << "not found while, no need to process.";
430     return RET_SUCCESS;
431   }
432 
433   auto while_cnode = while_node->cast<CNodePtr>();
434   MS_ASSERT(while_cnode != nullptr);
435   if (while_cnode->inputs().size() < kWhileMinInputSize) {
436     MS_LOG(ERROR) << "while input is not right.";
437     return RET_FAILED;
438   }
439 
440   std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
441   VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
442 
443   CNodePtr cond_call_cnode = nullptr;
444   std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_cond_fg_inputs_replace_pairs{};
445   std::vector<AnfNodePtr> cond_nodes_used_by_after_partial{};
446   int ret = CreateWhileCondCallNode(fg, while_cnode, visited_nodes_used_by_after_fg, &cond_call_cnode,
447                                     &cond_nodes_used_by_after_partial, &visited_nodes_and_cond_fg_inputs_replace_pairs);
448   if (ret != RET_SUCCESS) {
449     MS_LOG(ERROR) << "while create cond call cnode failed, ret: " << ret;
450     return ret;
451   }
452 
453   auto cond_fg_cnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>();
454   MS_ASSERT(cond_fg_cnode != nullptr);
455   AnfNodePtr cond_fg_vnode = cond_fg_cnode->input(kCNodeFirstInputIndex);
456   MS_ASSERT(cond_fg_vnode != nullptr);
457   auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
458   MS_CHECK_TRUE_MSG(cond_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
459 
460   CNodePtr body_partial_node = nullptr;
461   ret = CreateWhileBodyPartialNode(cond_fg, while_cnode, &body_partial_node);
462   if (ret != RET_SUCCESS) {
463     MS_LOG(ERROR) << "while create body partial cnode failed, ret: " << ret;
464     return ret;
465   }
466 
467   CNodePtr after_partial_cnode = nullptr;
468   ret = CreateWhileAfterPartialNode(fg, cond_fg, remain_nodes, visited_nodes_used_by_after_fg,
469                                     visited_nodes_and_cond_fg_inputs_replace_pairs, &while_cnode, &after_partial_cnode);
470   if (ret != RET_SUCCESS) {
471     MS_LOG(ERROR) << "while create after partial cnode failed, ret: " << ret;
472     return ret;
473   }
474 
475   // create switch cnode
476   ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
477   if (switch_anf_primitive == nullptr) {
478     MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
479     return lite::RET_ERROR;
480   }
481 
482   // insert switch node
483   std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
484                                                 after_partial_cnode};
485   auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
486   MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed");
487   switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
488 
489   // insert call node
490   std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
491   auto call_node = cond_fg->NewCNode(call_node_inputs);
492   MS_CHECK_TRUE_MSG(call_node != nullptr, lite::RET_NULL_PTR, "call_node is nullptr");
493   call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
494   cond_fg->set_output(call_node);
495 
496   fg->DropNode(while_cnode);
497   fg->set_output(cond_call_cnode);
498 
499   auto after_cnode = after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>();
500   MS_ASSERT(after_cnode != nullptr);
501   auto after_fg = after_cnode->value()->cast<FuncGraphPtr>();
502   if (after_fg == nullptr) {
503     MS_LOG(ERROR) << "after_fg is nullptr.";
504     return RET_FAILED;
505   }
506   to_process_q.push_back(cond_fg);
507   to_process_q.push_back(after_fg);
508   return RET_SUCCESS;
509 }
510 
CreateIfPartialNodeExternalInputs(const CNodePtr & if_cnode,const FuncGraphPtr & partial_fg,std::vector<AnfNodePtr> * then_partial_cnode_inputs)511 int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
512                                                        std::vector<AnfNodePtr> *then_partial_cnode_inputs) {
513   auto if_inputs = if_cnode->inputs();
514   auto fg_name_attr = partial_fg->get_attr("graph_name");
515   MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
516   auto partial_fg_name = fg_name_attr->ToString();
517   std::vector<AnfNodePtr> if_external_inputs{};
518   if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end());
519   auto origin_then_fg_inputs = partial_fg->get_inputs();
520   if (if_external_inputs.size() < origin_then_fg_inputs.size()) {
521     MS_LOG(ERROR) << "graph is not right.";
522     return RET_FAILED;
523   } else if (if_external_inputs.size() == origin_then_fg_inputs.size()) {
524     then_partial_cnode_inputs->insert(then_partial_cnode_inputs->end(), if_external_inputs.begin(),
525                                       if_external_inputs.end());
526     return RET_SUCCESS;
527   } else {
528     for (auto &fg_input : origin_then_fg_inputs) {
529       auto fg_input_name = fg_input->fullname_with_scope();
530       auto pos = partial_fg_name.size() + sizeof("_input_");
531       auto pos2 = fg_input_name.find('_', pos);
532       auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1);
533       auto partial_idx = 0;
534       try {
535         partial_idx = std::stoi(idx_str);
536       } catch (const std::exception &e) {
537         MS_LOG(ERROR) << "Get index failed: " << e.what();
538         return RET_FAILED;
539       }
540       then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx));
541     }
542   }
543   return RET_SUCCESS;
544 }
545 
CreateIfPartialNode(const FuncGraphPtr & fg,const size_t & index,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,CNodePtr * if_cnode,FuncGraphPtr * after_fg,CNodePtr * then_partial_cnode)546 int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
547                                          std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg, CNodePtr *if_cnode,
548                                          FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
549   auto then_vnode = (*if_cnode)->input(index);
550   MS_ASSERT(then_vnode != nullptr);
551   auto then_fg = GetValueNode<std::shared_ptr<FuncGraph>>(then_vnode);
552   MS_CHECK_TRUE_MSG(then_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
553 
554   // create then partial node
555   ValueNodePtr then_partial_anf_primitive = lite::GetPartialFusionPrim();
556   MS_CHECK_TRUE_MSG(then_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
557   std::vector<AnfNodePtr> then_partial_cnode_inputs{then_partial_anf_primitive, then_vnode};
558   if (CreateIfPartialNodeExternalInputs(*if_cnode, then_fg, &then_partial_cnode_inputs) != RET_SUCCESS) {
559     MS_LOG(ERROR) << "CreateIfPartialNodeExternalInputs failed.";
560     return RET_FAILED;
561   }
562   std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_after_partial_inputs_replace_pairs{};
563   std::vector<AnfNodePtr> then_nodes_used_by_after_partial{};
564   // set fg inputs to then_partial_cnode inputs
565   auto origin_then_fg_inputs = then_fg->get_inputs();
566   for (auto &item : *visited_nodes_used_by_after_fg) {
567     bool found = false;
568     size_t input_index = 0;
569     for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
570       if (then_partial_cnode_inputs[i] == item) {
571         found = true;
572         input_index = i - kPartialFirstInputSize;
573         break;
574       }
575     }
576     if (found) {
577       visited_nodes_and_after_partial_inputs_replace_pairs[item] = origin_then_fg_inputs.at(input_index);
578       then_nodes_used_by_after_partial.push_back(origin_then_fg_inputs.at(input_index));
579       continue;
580     }
581 
582     // set after fg inputs to cond_partial_cnode inputs
583     then_partial_cnode_inputs.push_back(item);
584     auto new_parameter = then_fg->add_parameter();
585     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter is nullptr");
586     if (index == kIfThenIndex) {
587       new_parameter->set_name(item->fullname_with_scope() + "_then_fg_parameter");
588     } else {
589       new_parameter->set_name(item->fullname_with_scope() + "_else_fg_parameter");
590     }
591     new_parameter->set_abstract(item->abstract());
592     visited_nodes_and_after_partial_inputs_replace_pairs[item] = new_parameter;
593     then_nodes_used_by_after_partial.push_back(new_parameter);
594   }
595   *then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
596   MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
597   auto fg_name_attr = then_fg->get_attr("graph_name");
598   MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
599   auto then_fg_name = fg_name_attr->ToString();
600   (*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
601 
602   // create after partial node
603   ValueNodePtr after_partial_anf_primitive = lite::GetPartialFusionPrim();
604   MS_CHECK_TRUE_MSG(after_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
605   auto after_value_node = NewValueNode(*after_fg);
606   MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "NewValueNode failed.");
607   // make the right after partial input
608   std::vector<AnfNodePtr> after_partial_cnode_inputs{after_partial_anf_primitive, after_value_node};
609   if (!CheckPrimitiveType(then_fg->output(), prim::kPrimMakeTuple)) {
610     after_partial_cnode_inputs.push_back(then_fg->output());
611   } else {
612     auto then_fg_output = then_fg->output()->cast<CNodePtr>();
613     MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed");
614     for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) {
615       after_partial_cnode_inputs.push_back(then_fg_output->input(i));
616     }
617     then_fg->DropNode(then_fg_output);
618   }
619   size_t if_output_size = after_partial_cnode_inputs.size() - kCNodeSecondInputIndex;
620 
621   // add after fg inputs to partial node
622   std::copy(then_nodes_used_by_after_partial.begin(), then_nodes_used_by_after_partial.end(),
623             std::back_inserter(after_partial_cnode_inputs));
624   // insert partial node
625   auto after_partial_cnode = then_fg->NewCNode(after_partial_cnode_inputs);
626   MS_CHECK_TRUE_MSG(after_partial_cnode != nullptr, RET_FAILED, "NewCNode failed");
627   auto after_fg_name = (*after_fg)->get_attr("graph_name")->ToString();
628   after_partial_cnode->set_fullname_with_scope("partial_" + after_fg_name);
629 
630   // insert call node
631   std::vector<AnfNodePtr> call_node_inputs{after_partial_cnode};
632   auto call_node = then_fg->NewCNode(call_node_inputs);
633   MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
634   call_node->set_fullname_with_scope("call_" + after_partial_cnode->fullname_with_scope());
635   then_fg->set_output(call_node);
636   to_process_q.push_back(then_fg);
637   ReplaceNode(*after_fg, visited_nodes_and_after_partial_inputs_replace_pairs);
638 
639   // check the inputs of after fg
640   auto after_fg_inputs_size = (*after_fg)->get_inputs().size();
641   if (after_fg_inputs_size == after_partial_cnode_inputs.size() - kPartialFirstInputSize) {
642     MS_LOG(INFO) << "not need add after fg input parameters.";
643     return RET_SUCCESS;
644   }
645 
646   // make the inputs of the after fg
647   std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_after_fg_replace_pairs{};
648   for (size_t i = kPartialFirstInputSize; i < after_partial_cnode_inputs.size(); ++i) {
649     auto &input = after_partial_cnode_inputs[i];
650     auto new_parameter = (*after_fg)->add_parameter();
651     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "add_parameter failed");
652     new_parameter->set_name(std::to_string(i - kPartialFirstInputSize) + "_" + input->fullname_with_scope());
653     new_parameter->set_abstract(input->abstract());
654     if (i < kPartialFirstInputSize + if_output_size) {
655       after_partial_after_fg_replace_pairs[*if_cnode] = new_parameter;
656     } else {
657       after_partial_after_fg_replace_pairs[input] = new_parameter;
658     }
659   }
660   ReplaceNode(*after_fg, after_partial_after_fg_replace_pairs);
661 
662   return RET_SUCCESS;
663 }
664 
CreateIfElsePartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,CNodePtr * if_cnode,FuncGraphPtr * after_fg,CNodePtr * else_partial_cnode)665 int ControlFlowPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
666                                              std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
667                                              CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode) {
668   return CreateIfPartialNode(main_fg, kIfElseIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
669                              else_partial_cnode);
670 }
671 
CreateIfThenPartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,CNodePtr * if_cnode,FuncGraphPtr * after_fg,CNodePtr * then_partial_cnode)672 int ControlFlowPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
673                                              std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
674                                              CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
675   return CreateIfPartialNode(main_fg, kIfThenIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
676                              then_partial_cnode);
677 }
678 
ProcessIfOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & if_node)679 int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
680                                  const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node) {
681   if (if_node == nullptr) {
682     MS_LOG(INFO) << "not found if, no need to process.";
683     return RET_SUCCESS;
684   }
685 
686   auto if_cnode = if_node->cast<CNodePtr>();
687   MS_ASSERT(if_cnode != nullptr);
688   if (if_cnode->inputs().size() < kIfMinInputSize) {
689     MS_LOG(ERROR) << "if input is not right.";
690     return RET_FAILED;
691   }
692 
693   // create after_fg
694   FuncGraphPtr after_fg = nullptr;
695   if (CreateAfterGraph(fg, remain_nodes, if_cnode, &after_fg) != RET_SUCCESS) {
696     MS_LOG(ERROR) << "CreateAfterGraph failed.";
697     return RET_FAILED;
698   }
699 
700   // get fg input which is not used by after_parts
701   std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
702   VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
703 
704   CNodePtr then_partial_cnode = nullptr;
705   int ret = CreateIfThenPartialNode(fg, &visited_nodes_used_by_after_fg, &if_cnode, &after_fg, &then_partial_cnode);
706   if (ret != RET_SUCCESS) {
707     MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret;
708     return ret;
709   }
710 
711   CNodePtr else_partial_cnode = nullptr;
712   ret = CreateIfElsePartialNode(fg, &visited_nodes_used_by_after_fg, &if_cnode, &after_fg, &else_partial_cnode);
713   if (ret != RET_SUCCESS) {
714     MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret;
715     return ret;
716   }
717 
718   // create switch cnode
719   ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
720   if (switch_anf_primitive == nullptr) {
721     MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
722     return RET_FAILED;
723   }
724 
725   //  insert switch node
726   std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), then_partial_cnode,
727                                                 else_partial_cnode};
728   auto switch_cnode = fg->NewCNode(switch_node_inputs);
729   MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_FAILED, "NewCNode failed");
730   switch_cnode->set_fullname_with_scope("if-Switch-" + fg->get_attr("graph_name")->ToString());
731 
732   // insert call node
733   std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
734   auto call_node = fg->NewCNode(call_node_inputs);
735   MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
736   call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
737   fg->DropNode(if_cnode);
738   fg->set_output(call_node, true);
739 
740   to_process_q.push_back(after_fg);
741   return RET_SUCCESS;
742 }
743 
ProcessControlOp(const FuncGraphPtr & fg)744 int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
745   if (fg == nullptr) {
746     MS_LOG(ERROR) << "fg is nullptr.";
747     return RET_FAILED;
748   }
749 
750   AnfNodePtr control_flow_node = nullptr;
751   std::vector<AnfNodePtr> remain_nodes{};
752   std::set<AnfNodePtr> visited_nodes{};
753   int ret = SplitGraph(fg, &control_flow_node, &visited_nodes, &remain_nodes);
754   if (ret != RET_SUCCESS) {
755     MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
756     return ret;
757   }
758 
759   if (control_flow_node == nullptr) {
760     MS_LOG(INFO) << "not found control flow op, no need to process.";
761     return RET_SUCCESS;
762   }
763 
764   if (CheckPrimitiveType(control_flow_node, prim::kPrimWhile)) {
765     ret = ProcessWhileOp(fg, visited_nodes, remain_nodes, control_flow_node);
766     if (ret != RET_SUCCESS) {
767       MS_LOG(ERROR) << "ProcessWhileOp failed.";
768       return ret;
769     }
770   }
771 
772   if (CheckPrimitiveType(control_flow_node, prim::kPrimIf)) {
773     ret = ProcessIfOp(fg, visited_nodes, remain_nodes, control_flow_node);
774     if (ret != RET_SUCCESS) {
775       MS_LOG(ERROR) << "ProcessIfOp failed.";
776       return ret;
777     }
778   }
779   return RET_SUCCESS;
780 }
781 
Run(const FuncGraphPtr & fg)782 bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
783   MS_ASSERT(fg != nullptr);
784   to_process_q.push_back(fg);
785   while (!to_process_q.empty()) {
786     auto cur_fg = to_process_q.front();
787     auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString();
788     int ret = ProcessControlOp(cur_fg);
789     if (ret != RET_SUCCESS) {
790       MS_LOG(ERROR) << "ProcessControlOp for graph: " << cur_fg_name << " failed.";
791       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
792       return false;
793     }
794     to_process_q.pop_front();
795   }
796   return true;
797 }
798 }  // namespace mindspore::opt
799