• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONTROL_FLOW_PASS_H_
18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONTROL_FLOW_PASS_H_
19 #include <string>
20 #include <vector>
21 #include <unordered_map>
22 #include <deque>
23 #include <set>
24 #include "schema/inner/model_generated.h"
25 #include "include/backend/optimizer/pass.h"
26 
27 namespace mindspore::opt {
28 class ControlFlowPass : public Pass {
29  public:
ControlFlowPass()30   ControlFlowPass() : Pass("control_flow_pass") {}
31   ~ControlFlowPass() override = default;
32   bool Run(const FuncGraphPtr &fg) override;
33 
34  private:
35   void ReplaceNode(const FuncGraphPtr &fg, const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs);
36   void VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
37                                     const std::vector<AnfNodePtr> &remain_nodes,
38                                     std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg);
39   int SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node, std::set<AnfNodePtr> *visited_nodes,
40                  std::vector<AnfNodePtr> *remain_nodes);
41   size_t GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node);
42   void MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node, std::set<AnfNodePtr> *visited_nodes,
43                             std::vector<AnfNodePtr> *remain_nodes);
44   void BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes);
45   int CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
46                        const CNodePtr &aim_cnode, FuncGraphPtr *after_fg);
47 
48   // process while
49   int CreateWhileCondCallNode(
50     const FuncGraphPtr &fg, const CNodePtr &while_cnode, const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
51     CNodePtr *cond_partial_cnode, std::vector<AnfNodePtr> *cond_nodes_used_by_after_partial,
52     std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs);
53   int CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode, CNodePtr *body_partial_node);
54   int CreateWhileAfterPartialNode(
55     const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
56     const std::vector<AnfNodePtr> &cond_nodes_used_by_after_partial,
57     const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
58     const CNodePtr *while_cnode, CNodePtr *after_partial_cnode);
59   int ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
60                      const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node);
61 
62   // process if
63   int CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
64                                         std::vector<AnfNodePtr> *then_partial_cnode_inputs);
65   int CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
66                           std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, const CNodePtr &if_cnode,
67                           const FuncGraphPtr &after_fg, CNodePtr *then_partial_cnode);
68   int CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
69                               std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, const CNodePtr &if_cnode,
70                               const FuncGraphPtr &after_fg, CNodePtr *then_partial_cnode);
71   int CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
72                               std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, const CNodePtr &if_cnode,
73                               const FuncGraphPtr &after_fg, CNodePtr *else_partial_cnode);
74   int ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
75                   const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node);
76 
77   int ProcessControlOp(const FuncGraphPtr &fg);
78 
79   const size_t kCNodePrimIndex = 0;
80   const size_t kCNodeFirstInputIndex = 1;
81   const size_t kCNodeSecondInputIndex = 2;
82 
83   const size_t kGetItemInputSize = 3;
84   const size_t kPartialFirstInputSize = 2;
85 
86   const size_t kWhileMinInputSize = 3;
87   const size_t kWhileCondIndex = 1;
88   const size_t kWhileBodyIndex = 2;
89 
90   const size_t kIfMinInputSize = 4;
91   const size_t kIfThenIndex = 1;
92   const size_t kIfElseIndex = 2;
93   const size_t kIfCondIndex = 3;
94 
95   std::deque<FuncGraphPtr> to_process_q{};
96 };
97 }  // namespace mindspore::opt
98 #endif
99