• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONTROL_FLOW_PASS_H_
18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONTROL_FLOW_PASS_H_
19 #include <string>
20 #include <vector>
21 #include <unordered_map>
22 #include <deque>
23 #include <set>
24 #include "schema/inner/model_generated.h"
25 #include "tools/converter/converter_flags.h"
26 #include "backend/optimizer/common/pass.h"
27 
28 namespace mindspore::opt {
29 class ControlFlowPass : public Pass {
30  public:
ControlFlowPass()31   ControlFlowPass() : Pass("control_flow_pass") {}
32   ~ControlFlowPass() override = default;
33   bool Run(const FuncGraphPtr &fg) override;
34 
35  private:
36   void ReplaceNode(const FuncGraphPtr &fg, const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs);
37   void VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
38                                     const std::vector<AnfNodePtr> &remain_nodes,
39                                     std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg);
40   int SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node, std::set<AnfNodePtr> *visited_nodes,
41                  std::vector<AnfNodePtr> *remain_nodes);
42   size_t GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node);
43   void MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node, std::set<AnfNodePtr> *visited_nodes,
44                             std::vector<AnfNodePtr> *remain_nodes);
45   void BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes);
46   int CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
47                        const CNodePtr &aim_cnode, FuncGraphPtr *after_fg);
48 
49   // process while
50   int CreateWhileCondCallNode(
51     const FuncGraphPtr &fg, const CNodePtr &while_cnode, const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
52     CNodePtr *cond_partial_cnode, std::vector<AnfNodePtr> *cond_nodes_used_by_after_partial,
53     std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs);
54   int CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode, CNodePtr *body_partial_node);
55   int CreateWhileAfterPartialNode(
56     const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
57     const std::vector<AnfNodePtr> &cond_nodes_used_by_after_partial,
58     const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
59     CNodePtr *while_cnode, CNodePtr *after_partial_cnode);
60   int ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
61                      const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node);
62 
63   // process if
64   int CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
65                                         std::vector<AnfNodePtr> *then_partial_cnode_inputs);
66   int CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
67                           std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
68                           FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode);
69   int CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
70                               std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
71                               FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode);
72   int CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
73                               std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
74                               FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode);
75   int ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
76                   const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node);
77 
78   int ProcessControlOp(const FuncGraphPtr &fg);
79 
80   const size_t kCNodePrimIndex = 0;
81   const size_t kCNodeFirstInputIndex = 1;
82   const size_t kCNodeSecondInputIndex = 2;
83 
84   const size_t kGetItemInputSize = 3;
85   const size_t kPartialFirstInputSize = 2;
86 
87   const size_t kWhileMinInputSize = 3;
88   const size_t kWhileCondIndex = 1;
89   const size_t kWhileBodyIndex = 2;
90 
91   const size_t kIfMinInputSize = 4;
92   const size_t kIfThenIndex = 1;
93   const size_t kIfElseIndex = 2;
94   const size_t kIfCondIndex = 3;
95 
96   std::deque<FuncGraphPtr> to_process_q{};
97 };
98 }  // namespace mindspore::opt
99 #endif
100