• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
19 
20 #include <vector>
21 #include <algorithm>
22 
23 #include "ir/func_graph.h"
24 #include "ir/func_graph_cloner.h"
25 #include "frontend/optimizer/optimizer_caller.h"
26 #include "ir/pattern_matcher.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/optimizer/irpass.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace irpass {
33 // {prim::kPrimSwitch, true, X, Y}
34 // {prim::kPrimSwitch, false, X, Y}
35 class SwitchSimplify : public OptimizerCaller {
36  public:
operator()37   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
38     PatternNode<AnfNodePtr> cond, true_br, false_br;
39     auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
40       auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
41       if (cond_value_) {
42         return true_br.GetNode(node);
43       }
44       return false_br.GetNode(node);
45     };
46 
47     MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
48                             cond.CheckFunc(IsValueNode<BoolImm>, node));
49 
50     return nullptr;
51   }
52 };
53 
54 // {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} =>
55 // {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
56 class FloatTupleGetItemSwitch : public OptimizerCaller {
57  public:
operator()58   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
59     PatternNode<AnfNodePtr> cond, true_br, false_br, x;
60     MATCH_REPLACE_IF(node,
61                      PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
62                      PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
63                                 PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
64                      x.CheckFunc(IsVNode, node));
65     return nullptr;
66   }
67 };
68 
69 // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
70 // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
71 class FloatEnvGetItemSwitch : public OptimizerCaller {
72  public:
operator()73   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
74     PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
75     MATCH_REPLACE(node,
76                   PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
77                   PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
78                              PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)));
79 
80     return nullptr;
81   }
82 };
83 
84 namespace internal {
85 FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
86 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
87 // block_nodes[0]: condition node
88 // block_nodes[1]: true branch node
89 // block_nodes[2]: false branch node
90 // branch_output_abs[0]: true branch abstract
91 // branch_output_abs[1]: false branch abstract
92 AnfNodePtr TransformMergeBranches(const std::vector<AnfNodePtr> &block_nodes,
93                                   const std::vector<AbstractBasePtr> &branch_output_abs,
94                                   const FuncGraphPtr &func_graph);
95 }  // namespace internal
96 
97 // {{prim::kPrimSwitch, X, G1, G2}, Xs}
98 class ConvertSwitchReplacement {
99  public:
100   ConvertSwitchReplacement() = default;
101   virtual ~ConvertSwitchReplacement() = default;
102 
operator()103   bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
104     AnfNodePtr ret = root->get_return();
105     MS_EXCEPTION_IF_NULL(ret);
106     std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
107 
108     bool change = false;
109     for (auto &node : all_nodes) {
110       if (CheckSwitchWrapNode(node)) {
111         TransformSwitchBranchReplace(node);
112         change = true;
113       }
114     }
115     return change;
116   }
117 
118  private:
119   // Determine whether there are graphs inside the branch graph.
120   bool CheckSwitchBranch(const AnfNodePtr &node);
121   // Determine whether node matches {{prim::kPrimSwitch, X, G1, G2}, Xs}.
122   bool CheckSwitchWrapNode(const AnfNodePtr &node);
123   // Replace switch branch.
124   void TransformSwitchBranchReplace(const AnfNodePtr &node);
125 };
126 
127 // {prim::kPrimSwitch, {prim::kPrimDepend, ValueNode, X}, G1, G2} ->
128 // {prim::kPrimDepend, {prim::kPrimSwitch, ValueNode, G1, G2}, X}
129 class ExchangeSwitchDependValue : public OptimizerCaller {
130  public:
operator()131   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
132     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
133       return nullptr;
134     }
135     ScopePtr scope = node->cast<CNodePtr>()->scope();
136     ScopeGuard scope_guard(scope);
137 
138     PatternNode<AnfNodePtr> cond, true_br, false_br, v, x;
139     MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSwitch, PPrimitive(prim::kPrimDepend, v, x), true_br, false_br),
140                      PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimSwitch, v, true_br, false_br), x),
141                      IsVNode(v.GetNode(node)));
142     return nullptr;
143   }
144 };
145 }  // namespace irpass
146 }  // namespace opt
147 }  // namespace mindspore
148 #endif  // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
149