• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/control_flow_pass.h"
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ops/switch.h"
26 #include "ops/fusion/partial_fusion.h"
27 #include "include/errorcode.h"
28 #include "tools/optimizer/common/gllo_utils.h"
29 #include "src/common/log_adapter.h"
30 #include "tools/common/node_util.h"
31 #include "nnacl/op_base.h"
32 #include "include/registry/converter_context.h"
33 
34 namespace mindspore::opt {
ReplaceNode(const FuncGraphPtr & fg,const std::unordered_map<AnfNodePtr,AnfNodePtr> & replace_pairs)35 void ControlFlowPass::ReplaceNode(const FuncGraphPtr &fg,
36                                   const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs) {
37   for (auto &node : fg->nodes()) {
38     if (!utils::isa<CNodePtr>(node)) {
39       continue;
40     }
41     auto cnode = node->cast<CNodePtr>();
42     MS_ASSERT(cnode != nullptr);
43     auto new_inputs = cnode->inputs();
44     for (auto &input : new_inputs) {
45       if (replace_pairs.find(input) == replace_pairs.end()) {
46         continue;
47       }
48       input = replace_pairs.at(input);
49     }
50     cnode->set_inputs(new_inputs);
51   }
52 }
53 
VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg)54 void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
55                                                    const std::vector<AnfNodePtr> &remain_nodes,
56                                                    std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
57   std::deque<AnfNodePtr> nodes{};
58   std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
59   std::set<AnfNodePtr> remain_nodes_set{};
60   nodes.assign(remain_nodes.begin(), remain_nodes.end());
61   while (!nodes.empty()) {
62     auto node = nodes.front();
63     nodes.pop_front();
64     remain_nodes_set.insert(node);
65     if (!utils::isa<CNodePtr>(node)) {
66       continue;
67     }
68     auto cnode = node->cast<CNodePtr>();
69     MS_ASSERT(cnode != nullptr);
70     for (auto &input : cnode->inputs()) {
71       if (visited_nodes.find(input) != visited_nodes.end() &&
72           visited_nodes_used_by_after_fg_set.find(input) == visited_nodes_used_by_after_fg_set.end()) {
73         visited_nodes_used_by_after_fg->push_back(input);
74         visited_nodes_used_by_after_fg_set.insert(input);
75       }
76     }
77   }
78 }
79 
GetItemVisitedNums(const std::set<AnfNodePtr> & visited_nodes,const AnfNodePtr & tuple_node)80 size_t ControlFlowPass::GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node) {
81   size_t count = 0;
82   for (auto &node : visited_nodes) {
83     if (!utils::isa<CNodePtr>(node)) {
84       continue;
85     }
86     if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
87       continue;
88     }
89     auto get_item_cnode = node->cast<CNodePtr>();
90     MS_ASSERT(get_item_cnode != nullptr);
91     if (get_item_cnode->inputs()[kCNodeFirstInputIndex] == tuple_node) {
92       count++;
93     }
94   }
95   return count;
96 }
97 
MoveGetItemToVisited(const size_t & need_size,const AnfNodePtr & tuple_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)98 void ControlFlowPass::MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node,
99                                            std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
100   size_t i = 0;
101   for (auto it = remain_nodes->begin(); it != remain_nodes->end();) {
102     if (!utils::isa<CNodePtr>(*it)) {
103       ++it;
104       continue;
105     }
106     if (!CheckPrimitiveType(*it, prim::kPrimTupleGetItem)) {
107       ++it;
108       continue;
109     }
110     auto get_item_cnode = (*it)->cast<CNodePtr>();
111     MS_ASSERT(get_item_cnode != nullptr);
112     if (get_item_cnode->inputs()[kCNodeFirstInputIndex] != tuple_node) {
113       ++it;
114       continue;
115     }
116     i++;
117     visited_nodes->insert(*it);
118     it = remain_nodes->erase(it);
119     if (need_size == i) {
120       return;
121     }
122   }
123   MS_LOG(INFO) << tuple_node->fullname_with_scope() << " not found enough get item, size: " << need_size - i;
124 }
125 
BindGetItemNodes(std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)126 void ControlFlowPass::BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
127   std::deque<AnfNodePtr> multi_output_nodes{};
128   for (auto &node : *visited_nodes) {
129     if (!utils::isa<CNodePtr>(node)) {
130       continue;
131     }
132     if (utils::isa<abstract::AbstractTuple>(node->abstract())) {
133       multi_output_nodes.push_back(node);
134     }
135   }
136 
137   while (!multi_output_nodes.empty()) {
138     auto cur_node = multi_output_nodes.front();
139     multi_output_nodes.pop_front();
140     size_t total_getitem_size = cur_node->abstract()->cast<abstract::AbstractTuplePtr>()->size();
141     size_t visited_getitem_size = GetItemVisitedNums(*visited_nodes, cur_node);
142     if (total_getitem_size == visited_getitem_size) {
143       continue;
144     }
145 
146     size_t need_getitem_size = total_getitem_size - visited_getitem_size;
147     MoveGetItemToVisited(need_getitem_size, cur_node, visited_nodes, remain_nodes);
148   }
149 }
150 
SplitGraph(const FuncGraphPtr & fg,AnfNodePtr * control_flow_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)151 int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node,
152                                 std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
153   auto inputs = fg->get_inputs();
154 
155   // notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
156   auto node_list = TopoSort(fg->get_return());
157   for (auto &node : node_list) {
158     MS_ASSERT(node != nullptr);
159     if (utils::isa<CNodePtr>(node) &&
160         (CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
161       *control_flow_node = node;
162       break;
163     }
164   }
165 
166   std::deque<AnfNodePtr> q;
167   visited_nodes->insert(inputs.begin(), inputs.end());
168   q.push_back(*control_flow_node);
169   while (!q.empty()) {
170     auto node = q.front();
171     q.pop_front();
172     if (!utils::isa<CNodePtr>(node)) {
173       continue;
174     }
175     visited_nodes->insert(node);
176     auto cnode = utils::cast<CNodePtr>(node);
177     MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed");
178     for (size_t i = 0; i < cnode->size(); i++) {
179       auto input = cnode->input(i);
180       if (visited_nodes->find(input) == visited_nodes->end()) {
181         q.push_back(input);
182       }
183     }
184   }
185 
186   for (auto &node : node_list) {
187     if (visited_nodes->find(node) == visited_nodes->end()) {
188       remain_nodes->push_back(node);
189     }
190   }
191   visited_nodes->erase(*control_flow_node);
192 
193   BindGetItemNodes(visited_nodes, remain_nodes);
194 
195   return RET_SUCCESS;
196 }
197 
CreateAfterGraph(const FuncGraphPtr & main_fg,const std::vector<AnfNodePtr> & remain_nodes,const CNodePtr & aim_cnode,FuncGraphPtr * after_fg)198 int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
199                                       const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) {
200   *after_fg = std::make_shared<FuncGraph>();
201   MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
202   auto manager = main_fg->manager();
203   MS_ASSERT(manager != nullptr);
204   manager->AddFuncGraph(*after_fg);
205   (*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
206   (*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
207   (*after_fg)->set_manager(main_fg->manager());
208 
209   for (auto &cur_node : remain_nodes) {
210     if (cur_node->isa<ValueNode>()) {
211       continue;
212     }
213     if (cur_node == main_fg->get_return()) {
214       continue;
215     }
216     (*after_fg)->AddNode(cur_node);
217     if (!utils::isa<ValueNodePtr>(cur_node)) {
218       cur_node->set_func_graph(*after_fg);
219     }
220     if (cur_node == main_fg->output()) {
221       (*after_fg)->set_output(cur_node, false);
222     }
223     main_fg->DropNode(cur_node);
224   }
225   return RET_SUCCESS;
226 }
227 
CreateWhileCondCallNode(const FuncGraphPtr & fg,const CNodePtr & while_cnode,const std::vector<AnfNodePtr> & visited_nodes_used_by_after_fg,CNodePtr * cond_call_cnode,std::vector<AnfNodePtr> * cond_nodes_used_by_after_partial,std::unordered_map<AnfNodePtr,AnfNodePtr> * visited_nodes_and_cond_fg_inputs_replace_pairs)228 int ControlFlowPass::CreateWhileCondCallNode(
229   const FuncGraphPtr &fg, const CNodePtr &while_cnode, const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
230   CNodePtr *cond_call_cnode, std::vector<AnfNodePtr> *cond_nodes_used_by_after_partial,
231   std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs) {
232   auto cond_vnode = while_cnode->input(kWhileCondIndex);
233   MS_CHECK_TRUE_MSG(cond_vnode != nullptr, lite::RET_NULL_PTR, "cnode is nullptr");
234   auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
235   if (cond_fg == nullptr) {
236     MS_LOG(ERROR) << "Get value as func graph failed.";
237     return RET_FAILED;
238   }
239 
240   // create after partial node
241   ValueNodePtr cond_partial_anf_primitive = lite::GetPartialFusionPrim();
242   if (cond_partial_anf_primitive == nullptr) {
243     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
244     return RET_FAILED;
245   }
246 
247   std::vector<AnfNodePtr> cond_partial_cnode_inputs{cond_partial_anf_primitive, cond_vnode};
248   cond_partial_cnode_inputs.insert(cond_partial_cnode_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
249                                    while_cnode->inputs().end());
250 
251   auto origin_cond_fg_inputs = cond_fg->get_inputs();
252   for (auto &item : visited_nodes_used_by_after_fg) {
253     bool found = false;
254     size_t input_index = 0;
255     for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
256       if (cond_partial_cnode_inputs[i] == item) {
257         found = true;
258         input_index = i - kPartialFirstInputSize;
259         break;
260       }
261     }
262 
263     if (found) {
264       (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = origin_cond_fg_inputs.at(input_index);
265       cond_nodes_used_by_after_partial->push_back(origin_cond_fg_inputs.at(input_index));
266       continue;
267     }
268 
269     // set after fg inputs to cond_partial_cnode inputs
270     cond_partial_cnode_inputs.push_back(item);
271     auto new_parameter = cond_fg->add_parameter();
272     MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
273     new_parameter->set_name(item->fullname_with_scope() + "_cond_fg_parameter");
274     new_parameter->set_abstract(item->abstract());
275     (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = new_parameter;
276     cond_nodes_used_by_after_partial->push_back(new_parameter);
277   }
278 
279   auto cond_partial_cnode = fg->NewCNode(cond_partial_cnode_inputs);
280   MS_CHECK_TRUE_MSG(cond_partial_cnode != nullptr, lite::RET_NULL_PTR, "cond_partial_cnode is nullptr");
281   cond_partial_cnode->set_fullname_with_scope("partial_" + cond_fg->get_attr("graph_name")->ToString());
282 
283   // insert call node
284   std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
285   *cond_call_cnode = fg->NewCNode(call_node_inputs);
286   MS_CHECK_TRUE_MSG(*cond_call_cnode != nullptr, lite::RET_NULL_PTR, "new cnode is nullptr");
287   (*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
288 
289   return RET_SUCCESS;
290 }
291 
CreateWhileBodyPartialNode(const FuncGraphPtr & cond_fg,const CNodePtr & while_cnode,CNodePtr * body_partial_node)292 int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
293                                                 CNodePtr *body_partial_node) {
294   auto body_vnode = while_cnode->input(kWhileBodyIndex);
295   MS_CHECK_TRUE_MSG(body_vnode != nullptr, RET_FAILED, "body_vnode is nullptr");
296   auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
297   if (body_fg == nullptr) {
298     MS_LOG(ERROR) << "Get value as func_graph failed.";
299     return RET_FAILED;
300   }
301 
302   ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
303   if (partial_anf_primitive == nullptr) {
304     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
305     return RET_FAILED;
306   }
307 
308   std::vector<AnfNodePtr> body_partial_node_inputs{partial_anf_primitive, body_vnode};
309   // set body inputs to body partial inputs
310   auto cond_fg_inputs = cond_fg->get_inputs();
311   body_partial_node_inputs.insert(body_partial_node_inputs.end(), cond_fg_inputs.begin(), cond_fg_inputs.end());
312   *body_partial_node = cond_fg->NewCNode(body_partial_node_inputs);
313   MS_CHECK_TRUE_MSG(*body_partial_node != nullptr, RET_FAILED, "new cnode is nullptr");
314   (*body_partial_node)->set_fullname_with_scope("CNode_" + body_fg->get_attr("graph_name")->ToString());
315 
316   // add after inputs for body fg to call cond fg
317   auto body_fg_inputs = body_fg->get_inputs();
318   auto origin_body_fg_inputs_size = body_fg_inputs.size();
319   for (size_t i = origin_body_fg_inputs_size; i < cond_fg_inputs.size(); ++i) {
320     if (!utils::isa<ParameterPtr>(cond_fg_inputs[i])) {
321       MS_LOG(ERROR) << "fg is not right.";
322       return RET_FAILED;
323     }
324     auto new_parameter = body_fg->add_parameter();
325     MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
326     new_parameter->set_name(cond_fg_inputs[i]->fullname_with_scope() + "_body_fg_parameter");
327     new_parameter->set_abstract(cond_fg_inputs[i]->abstract());
328   }
329 
330   // call the cond fg
331   ValueNodePtr cond_partial_anf_primitive = lite::GetPartialFusionPrim();
332   if (cond_partial_anf_primitive == nullptr) {
333     MS_LOG(ERROR) << "`new cond_partial_anf_primitive failed.";
334     return RET_FAILED;
335   }
336   auto cond_partial_vnode = NewValueNode(cond_fg);
337   MS_CHECK_TRUE_MSG(cond_partial_vnode != nullptr, lite::RET_NULL_PTR, "cond_partial_vnode is nullptr");
338   std::vector<AnfNodePtr> cond_partial_inputs{cond_partial_anf_primitive, cond_partial_vnode};
339   // set body fg output
340   auto body_output = body_fg->output()->cast<CNodePtr>();
341   MS_ASSERT(body_output != nullptr);
342   if (CheckPrimitiveType(body_output, prim::kPrimMakeTuple)) {
343     for (size_t i = 1; i < body_output->size(); ++i) {
344       cond_partial_inputs.push_back(body_output->input(i));
345     }
346     body_fg->DropNode(body_output);
347   } else {
348     cond_partial_inputs.push_back(body_output);
349   }
350 
351   body_fg_inputs = body_fg->get_inputs();
352   for (size_t i = origin_body_fg_inputs_size; i < body_fg_inputs.size(); ++i) {
353     cond_partial_inputs.push_back(body_fg_inputs[i]);
354   }
355 
356   auto cond_partial_cnode = body_fg->NewCNode(cond_partial_inputs);
357   MS_CHECK_TRUE_MSG(cond_partial_cnode != nullptr, lite::RET_NULL_PTR, "cond_partial_cnode != nullptr");
358   cond_partial_cnode->set_fullname_with_scope(body_fg->get_attr("graph_name")->ToString() + "_call_cond_fg");
359 
360   // insert call node
361   std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
362   auto cond_call_cnode = body_fg->NewCNode(call_node_inputs);
363   MS_CHECK_TRUE_MSG(cond_call_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
364   cond_call_cnode->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
365   body_fg->set_output(cond_call_cnode);
366 
367   to_process_q.push_back(body_fg);
368   return RET_SUCCESS;
369 }
370 
CreateWhileAfterPartialNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & cond_fg,const std::vector<AnfNodePtr> & remain_nodes,const std::vector<AnfNodePtr> & cond_nodes_used_by_after_partial,const std::unordered_map<AnfNodePtr,AnfNodePtr> & visited_nodes_and_cond_fg_inputs_replace_pairs,const CNodePtr * while_cnode,CNodePtr * after_partial_cnode)371 int ControlFlowPass::CreateWhileAfterPartialNode(
372   const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
373   const std::vector<AnfNodePtr> &cond_nodes_used_by_after_partial,
374   const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
375   const CNodePtr *while_cnode, CNodePtr *after_partial_cnode) {
376   // create after_fg
377   FuncGraphPtr after_fg = nullptr;
378   if (CreateAfterGraph(main_fg, remain_nodes, *while_cnode, &after_fg) != RET_SUCCESS) {
379     MS_LOG(ERROR) << "CreateAfterGraph failed.";
380     return RET_FAILED;
381   }
382 
383   auto after_value_node = NewValueNode(after_fg);
384   MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "after_value_node is nullptr");
385   ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
386   if (partial_anf_primitive == nullptr) {
387     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
388     return RET_FAILED;
389   }
390 
391   std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_inputs_and_after_fg_inputs_replace_pairs{};
392   std::vector<AnfNodePtr> after_partial_cnode_inputs{partial_anf_primitive, after_value_node};
393   auto cond_fg_inputs = cond_fg->get_inputs();
394   for (const auto &node : after_fg->nodes()) {
395     if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
396       continue;
397     }
398     auto get_tuple_item_cnode = node->cast<CNodePtr>();
399     MS_ASSERT(get_tuple_item_cnode != nullptr);
400     MS_ASSERT(get_tuple_item_cnode->size() == kGetItemInputSize);
401     if (get_tuple_item_cnode->input(kCNodeFirstInputIndex) != *while_cnode) {
402       continue;
403     }
404     auto index_vnode = get_tuple_item_cnode->inputs().at(kCNodeSecondInputIndex);
405     if (!utils::isa<ValueNode>(index_vnode)) {
406       MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
407       return RET_FAILED;
408     }
409     auto value_node = utils::cast<ValueNodePtr>(index_vnode);
410     MS_ASSERT(value_node != nullptr);
411 
412     auto input_index = value_node->value()->type()->number_type() == kNumberTypeInt64
413                          ? GetValue<int64_t>(value_node->value())
414                          : GetValue<int>(value_node->value());
415 
416     after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index));
417     auto new_parameter = after_fg->add_parameter();
418     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
419     new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
420     new_parameter->set_abstract(node->abstract());
421     after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
422   }
423 
424   std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_after_fg_replace_pair{};
425   for (auto &input : cond_nodes_used_by_after_partial) {
426     after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
427     auto new_parameter = after_fg->add_parameter();
428     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
429     new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
430     new_parameter->set_abstract(input->abstract());
431     visited_nodes_after_fg_replace_pair[visited_nodes_and_cond_fg_inputs_replace_pairs.at(input)] = new_parameter;
432   }
433 
434   ReplaceNode(after_fg, visited_nodes_and_cond_fg_inputs_replace_pairs);
435   ReplaceNode(after_fg, after_partial_inputs_and_after_fg_inputs_replace_pairs);
436   ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
437   *after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
438   MS_CHECK_TRUE_MSG(*after_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
439   (*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
440   return RET_SUCCESS;
441 }
442 
ProcessWhileOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & while_node)443 int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
444                                     const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node) {
445   if (while_node == nullptr) {
446     MS_LOG(INFO) << "not found while, no need to process.";
447     return RET_SUCCESS;
448   }
449 
450   auto while_cnode = while_node->cast<CNodePtr>();
451   MS_ASSERT(while_cnode != nullptr);
452   if (while_cnode->size() < kWhileMinInputSize) {
453     MS_LOG(ERROR) << "while input is not right.";
454     return RET_FAILED;
455   }
456 
457   std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
458   VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
459 
460   CNodePtr cond_call_cnode = nullptr;
461   std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_cond_fg_inputs_replace_pairs{};
462   std::vector<AnfNodePtr> cond_nodes_used_by_after_partial{};
463   int ret = CreateWhileCondCallNode(fg, while_cnode, visited_nodes_used_by_after_fg, &cond_call_cnode,
464                                     &cond_nodes_used_by_after_partial, &visited_nodes_and_cond_fg_inputs_replace_pairs);
465   if (ret != RET_SUCCESS) {
466     MS_LOG(ERROR) << "while create cond call cnode failed, ret: " << ret;
467     return ret;
468   }
469 
470   auto cond_fg_cnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>();
471   MS_ASSERT(cond_fg_cnode != nullptr);
472   AnfNodePtr cond_fg_vnode = cond_fg_cnode->input(kCNodeFirstInputIndex);
473   MS_ASSERT(cond_fg_vnode != nullptr);
474   auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
475   MS_CHECK_TRUE_MSG(cond_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
476 
477   CNodePtr body_partial_node = nullptr;
478   ret = CreateWhileBodyPartialNode(cond_fg, while_cnode, &body_partial_node);
479   if (ret != RET_SUCCESS) {
480     MS_LOG(ERROR) << "while create body partial cnode failed, ret: " << ret;
481     return ret;
482   }
483 
484   CNodePtr after_partial_cnode = nullptr;
485   ret = CreateWhileAfterPartialNode(fg, cond_fg, remain_nodes, visited_nodes_used_by_after_fg,
486                                     visited_nodes_and_cond_fg_inputs_replace_pairs, &while_cnode, &after_partial_cnode);
487   if (ret != RET_SUCCESS) {
488     MS_LOG(ERROR) << "while create after partial cnode failed, ret: " << ret;
489     return ret;
490   }
491 
492   // create switch cnode
493   ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
494   if (switch_anf_primitive == nullptr) {
495     MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
496     return lite::RET_ERROR;
497   }
498 
499   // insert switch node
500   std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
501                                                 after_partial_cnode};
502   auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
503   MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed");
504   switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
505 
506   // insert call node
507   std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
508   auto call_node = cond_fg->NewCNode(call_node_inputs);
509   MS_CHECK_TRUE_MSG(call_node != nullptr, lite::RET_NULL_PTR, "call_node is nullptr");
510   call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
511   cond_fg->set_output(call_node);
512 
513   fg->DropNode(while_cnode);
514   fg->set_output(cond_call_cnode);
515 
516   auto after_cnode = after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>();
517   MS_ASSERT(after_cnode != nullptr);
518   auto after_fg = after_cnode->value()->cast<FuncGraphPtr>();
519   if (after_fg == nullptr) {
520     MS_LOG(ERROR) << "after_fg is nullptr.";
521     return RET_FAILED;
522   }
523   to_process_q.push_back(cond_fg);
524   to_process_q.push_back(after_fg);
525   return RET_SUCCESS;
526 }
527 
CreateIfPartialNodeExternalInputs(const CNodePtr & if_cnode,const FuncGraphPtr & partial_fg,std::vector<AnfNodePtr> * then_partial_cnode_inputs)528 int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
529                                                        std::vector<AnfNodePtr> *then_partial_cnode_inputs) {
530   auto if_inputs = if_cnode->inputs();
531   auto fg_name_attr = partial_fg->get_attr("graph_name");
532   MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
533   auto partial_fg_name = fg_name_attr->ToString();
534   std::vector<AnfNodePtr> if_external_inputs{};
535   if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end());
536   auto origin_then_fg_inputs = partial_fg->get_inputs();
537   if (if_external_inputs.size() < origin_then_fg_inputs.size()) {
538     MS_LOG(ERROR) << "graph is not right.";
539     return RET_FAILED;
540   } else if (if_external_inputs.size() == origin_then_fg_inputs.size()) {
541     then_partial_cnode_inputs->insert(then_partial_cnode_inputs->end(), if_external_inputs.begin(),
542                                       if_external_inputs.end());
543     return RET_SUCCESS;
544   } else {
545     for (auto &fg_input : origin_then_fg_inputs) {
546       auto fg_input_name = fg_input->fullname_with_scope();
547       auto pos = partial_fg_name.size() + sizeof("_input_");
548       auto pos2 = fg_input_name.find('_', pos);
549       auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1);
550       auto partial_idx = 0;
551       try {
552         partial_idx = std::stoi(idx_str);
553       } catch (const std::exception &e) {
554         MS_LOG(ERROR) << "Get index failed: " << e.what();
555         return RET_FAILED;
556       }
557       then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx));
558     }
559   }
560   return RET_SUCCESS;
561 }
562 
CreateIfPartialNode(const FuncGraphPtr & fg,const size_t & index,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,const CNodePtr & if_cnode,const FuncGraphPtr & after_fg,CNodePtr * then_partial_cnode)563 int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
564                                          std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
565                                          const CNodePtr &if_cnode, const FuncGraphPtr &after_fg,
566                                          CNodePtr *then_partial_cnode) {
567   auto then_vnode = if_cnode->input(index);
568   MS_ASSERT(then_vnode != nullptr);
569   auto then_fg = GetValueNode<std::shared_ptr<FuncGraph>>(then_vnode);
570   MS_CHECK_TRUE_MSG(then_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
571 
572   // create then partial node
573   ValueNodePtr then_partial_anf_primitive = lite::GetPartialFusionPrim();
574   MS_CHECK_TRUE_MSG(then_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
575   std::vector<AnfNodePtr> then_partial_cnode_inputs{then_partial_anf_primitive, then_vnode};
576   if (CreateIfPartialNodeExternalInputs(if_cnode, then_fg, &then_partial_cnode_inputs) != RET_SUCCESS) {
577     MS_LOG(ERROR) << "CreateIfPartialNodeExternalInputs failed.";
578     return RET_FAILED;
579   }
580   std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_after_partial_inputs_replace_pairs{};
581   std::vector<AnfNodePtr> then_nodes_used_by_after_partial{};
582   // set fg inputs to then_partial_cnode inputs
583   auto origin_then_fg_inputs = then_fg->get_inputs();
584   for (auto &item : *visited_nodes_used_by_after_fg) {
585     bool found = false;
586     size_t input_index = 0;
587     for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
588       if (then_partial_cnode_inputs[i] == item) {
589         found = true;
590         input_index = i - kPartialFirstInputSize;
591         break;
592       }
593     }
594     if (found) {
595       visited_nodes_and_after_partial_inputs_replace_pairs[item] = origin_then_fg_inputs.at(input_index);
596       then_nodes_used_by_after_partial.push_back(origin_then_fg_inputs.at(input_index));
597       continue;
598     }
599 
600     // set after fg inputs to cond_partial_cnode inputs
601     then_partial_cnode_inputs.push_back(item);
602     auto new_parameter = then_fg->add_parameter();
603     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter is nullptr");
604     if (index == kIfThenIndex) {
605       new_parameter->set_name(item->fullname_with_scope() + "_then_fg_parameter");
606     } else {
607       new_parameter->set_name(item->fullname_with_scope() + "_else_fg_parameter");
608     }
609     new_parameter->set_abstract(item->abstract());
610     visited_nodes_and_after_partial_inputs_replace_pairs[item] = new_parameter;
611     then_nodes_used_by_after_partial.push_back(new_parameter);
612   }
613   *then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
614   MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
615   auto fg_name_attr = then_fg->get_attr("graph_name");
616   MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
617   auto then_fg_name = fg_name_attr->ToString();
618   (*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
619 
620   // create after partial node
621   ValueNodePtr after_partial_anf_primitive = lite::GetPartialFusionPrim();
622   MS_CHECK_TRUE_MSG(after_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
623   auto after_value_node = NewValueNode(after_fg);
624   MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "NewValueNode failed.");
625   // make the right after partial input
626   std::vector<AnfNodePtr> after_partial_cnode_inputs{after_partial_anf_primitive, after_value_node};
627   if (!CheckPrimitiveType(then_fg->output(), prim::kPrimMakeTuple)) {
628     after_partial_cnode_inputs.push_back(then_fg->output());
629   } else {
630     auto then_fg_output = then_fg->output()->cast<CNodePtr>();
631     MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed");
632     for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->size(); ++i) {
633       after_partial_cnode_inputs.push_back(then_fg_output->input(i));
634     }
635     then_fg->DropNode(then_fg_output);
636   }
637   size_t if_output_size = after_partial_cnode_inputs.size() - kCNodeSecondInputIndex;
638 
639   // add after fg inputs to partial node
640   std::copy(then_nodes_used_by_after_partial.begin(), then_nodes_used_by_after_partial.end(),
641             std::back_inserter(after_partial_cnode_inputs));
642   // insert partial node
643   auto after_partial_cnode = then_fg->NewCNode(after_partial_cnode_inputs);
644   MS_CHECK_TRUE_MSG(after_partial_cnode != nullptr, RET_FAILED, "NewCNode failed");
645   auto after_fg_name = after_fg->get_attr("graph_name")->ToString();
646   after_partial_cnode->set_fullname_with_scope("partial_" + after_fg_name);
647 
648   // insert call node
649   std::vector<AnfNodePtr> call_node_inputs{after_partial_cnode};
650   auto call_node = then_fg->NewCNode(call_node_inputs);
651   MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
652   call_node->set_fullname_with_scope("call_" + after_partial_cnode->fullname_with_scope());
653   then_fg->set_output(call_node);
654   to_process_q.push_back(then_fg);
655   ReplaceNode(after_fg, visited_nodes_and_after_partial_inputs_replace_pairs);
656 
657   // check the inputs of after fg
658   auto after_fg_inputs_size = after_fg->get_inputs().size();
659   if (after_fg_inputs_size == after_partial_cnode_inputs.size() - kPartialFirstInputSize) {
660     return RET_SUCCESS;
661   }
662 
663   // make the inputs of the after fg
664   std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_after_fg_replace_pairs{};
665   for (size_t i = kPartialFirstInputSize; i < after_partial_cnode_inputs.size(); ++i) {
666     auto &input = after_partial_cnode_inputs[i];
667     auto new_parameter = after_fg->add_parameter();
668     MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "add_parameter failed");
669     new_parameter->set_name(std::to_string(i - kPartialFirstInputSize) + "_" + input->fullname_with_scope());
670     new_parameter->set_abstract(input->abstract());
671     if (i < kPartialFirstInputSize + if_output_size) {
672       after_partial_after_fg_replace_pairs[if_cnode] = new_parameter;
673     } else {
674       after_partial_after_fg_replace_pairs[input] = new_parameter;
675     }
676   }
677   ReplaceNode(after_fg, after_partial_after_fg_replace_pairs);
678 
679   return RET_SUCCESS;
680 }
681 
CreateIfElsePartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,const CNodePtr & if_cnode,const FuncGraphPtr & after_fg,CNodePtr * else_partial_cnode)682 int ControlFlowPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
683                                              std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
684                                              const CNodePtr &if_cnode, const FuncGraphPtr &after_fg,
685                                              CNodePtr *else_partial_cnode) {
686   return CreateIfPartialNode(main_fg, kIfElseIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
687                              else_partial_cnode);
688 }
689 
CreateIfThenPartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,const CNodePtr & if_cnode,const FuncGraphPtr & after_fg,CNodePtr * then_partial_cnode)690 int ControlFlowPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
691                                              std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
692                                              const CNodePtr &if_cnode, const FuncGraphPtr &after_fg,
693                                              CNodePtr *then_partial_cnode) {
694   return CreateIfPartialNode(main_fg, kIfThenIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
695                              then_partial_cnode);
696 }
697 
ProcessIfOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & if_node)698 int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
699                                  const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node) {
700   if (if_node == nullptr) {
701     MS_LOG(INFO) << "not found if, no need to process.";
702     return RET_SUCCESS;
703   }
704 
705   auto if_cnode = if_node->cast<CNodePtr>();
706   MS_ASSERT(if_cnode != nullptr);
707   if (if_cnode->size() < kIfMinInputSize) {
708     MS_LOG(ERROR) << "if input is not right.";
709     return RET_FAILED;
710   }
711 
712   // create after_fg
713   FuncGraphPtr after_fg = nullptr;
714   if (CreateAfterGraph(fg, remain_nodes, if_cnode, &after_fg) != RET_SUCCESS) {
715     MS_LOG(ERROR) << "CreateAfterGraph failed.";
716     return RET_FAILED;
717   }
718 
719   // get fg input which is not used by after_parts
720   std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
721   VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
722 
723   CNodePtr then_partial_cnode = nullptr;
724   int ret = CreateIfThenPartialNode(fg, &visited_nodes_used_by_after_fg, if_cnode, after_fg, &then_partial_cnode);
725   if (ret != RET_SUCCESS) {
726     MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret;
727     return ret;
728   }
729 
730   CNodePtr else_partial_cnode = nullptr;
731   ret = CreateIfElsePartialNode(fg, &visited_nodes_used_by_after_fg, if_cnode, after_fg, &else_partial_cnode);
732   if (ret != RET_SUCCESS) {
733     MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret;
734     return ret;
735   }
736 
737   // create switch cnode
738   ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
739   if (switch_anf_primitive == nullptr) {
740     MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
741     return RET_FAILED;
742   }
743 
744   //  insert switch node
745   std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), then_partial_cnode,
746                                                 else_partial_cnode};
747   auto switch_cnode = fg->NewCNode(switch_node_inputs);
748   MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_FAILED, "NewCNode failed");
749   switch_cnode->set_fullname_with_scope("if-Switch-" + fg->get_attr("graph_name")->ToString());
750 
751   // insert call node
752   std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
753   auto call_node = fg->NewCNode(call_node_inputs);
754   MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
755   call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
756   fg->DropNode(if_cnode);
757   fg->set_output(call_node, true);
758 
759   to_process_q.push_back(after_fg);
760   return RET_SUCCESS;
761 }
762 
ProcessControlOp(const FuncGraphPtr & fg)763 int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
764   if (fg == nullptr) {
765     MS_LOG(ERROR) << "fg is nullptr.";
766     return RET_FAILED;
767   }
768 
769   AnfNodePtr control_flow_node = nullptr;
770   std::vector<AnfNodePtr> remain_nodes{};
771   std::set<AnfNodePtr> visited_nodes{};
772   int ret = SplitGraph(fg, &control_flow_node, &visited_nodes, &remain_nodes);
773   if (ret != RET_SUCCESS) {
774     MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
775     return ret;
776   }
777 
778   if (control_flow_node == nullptr) {
779     MS_LOG(INFO) << "not found control flow op, no need to process.";
780     return RET_SUCCESS;
781   }
782 
783   if (CheckPrimitiveType(control_flow_node, prim::kPrimWhile)) {
784     ret = ProcessWhileOp(fg, visited_nodes, remain_nodes, control_flow_node);
785     if (ret != RET_SUCCESS) {
786       MS_LOG(ERROR) << "ProcessWhileOp failed.";
787       return ret;
788     }
789   }
790 
791   if (CheckPrimitiveType(control_flow_node, prim::kPrimIf)) {
792     ret = ProcessIfOp(fg, visited_nodes, remain_nodes, control_flow_node);
793     if (ret != RET_SUCCESS) {
794       MS_LOG(ERROR) << "ProcessIfOp failed.";
795       return ret;
796     }
797   }
798   return RET_SUCCESS;
799 }
800 
Run(const FuncGraphPtr & fg)801 bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
802   MS_ASSERT(fg != nullptr);
803   to_process_q.push_back(fg);
804   while (!to_process_q.empty()) {
805     auto cur_fg = to_process_q.front();
806     auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString();
807     int ret = ProcessControlOp(cur_fg);
808     if (ret != RET_SUCCESS) {
809       MS_LOG(ERROR) << "ProcessControlOp for graph: " << cur_fg_name << " failed.";
810       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
811       return false;
812     }
813     to_process_q.pop_front();
814   }
815   return true;
816 }
817 }  // namespace mindspore::opt
818