• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
19 
20 #include <vector>
21 #include <algorithm>
22 
23 #include "ir/func_graph.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/comparison_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "ir/func_graph_cloner.h"
28 #include "frontend/optimizer/optimizer_caller.h"
29 #include "ir/pattern_matcher.h"
30 #include "frontend/operator/ops.h"
31 #include "frontend/optimizer/irpass.h"
32 #include "pipeline/jit/ps/parse/resolve.h"
33 
34 namespace mindspore {
35 namespace opt {
36 namespace irpass {
37 // {prim::kPrimSwitch, true, X, Y}
38 // {prim::kPrimSwitch, false, X, Y}
39 class SwitchSimplify : public OptimizerCaller {
40  public:
operator()41   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
42     PatternNode<AnfNodePtr> cond;
43     PatternNode<AnfNodePtr> true_br;
44     PatternNode<AnfNodePtr> false_br;
45     auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
46       auto value_ptr = GetValueNode(cond.GetNode(node));
47       bool cond_value;
48       if (value_ptr->isa<BoolImm>()) {
49         cond_value = GetValue<bool>(value_ptr);
50       } else {
51         MS_LOG(EXCEPTION) << "The condition of branch must be a bool tensor value or a bool scalar value,"
52                           << " not support this condition value: " << value_ptr->ToString();
53       }
54 
55       MS_LOG(DEBUG) << "condition value: " << value_ptr->ToString() << ", cond: " << cond_value
56                     << ", node: " << node->DebugString();
57       AnfNodePtr branch_node;
58       if (cond_value) {
59         branch_node = true_br.GetNode(node);
60       } else {
61         branch_node = false_br.GetNode(node);
62       }
63       auto fg = GetValuePtr<FuncGraph>(branch_node);
64       if (fg != nullptr) {
65         MS_LOG(DEBUG) << "No recursive, " << fg->ToString();
66         fg->set_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE, true);
67       }
68       return branch_node;
69     };
70 
71     auto IsDeterminateCondition = [](const AnfNodePtr &node) -> bool { return IsValueNode<BoolImm>(node); };
72     MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
73                             cond.CheckFunc(IsDeterminateCondition, node));
74 
75     return nullptr;
76   }
77 };
78 
79 // {prim::kPrimLess, Value1, Value2}
80 // {prim::kPrimSwitch, Less, X, Y}
81 // {prim::kPrimGreater, Value1, Value2}
82 // {prim::kPrimSwitch, Greater, X, Y}
83 class CompareSwitchSimplify : public OptimizerCaller {
84  public:
operator()85   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
86     PatternNode<AnfNodePtr> cond;
87     PatternNode<AnfNodePtr> true_br;
88     PatternNode<AnfNodePtr> false_br;
89     auto CompareSwitchSimplifyLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
90       auto cnode = node->cast<CNodePtr>();
91       MS_EXCEPTION_IF_NULL(cnode);
92       auto compare_cnode = cnode->input(kIndex1)->cast<CNodePtr>();
93       MS_EXCEPTION_IF_NULL(compare_cnode);
94       auto cond_tensor1 = GetValue<tensor::TensorPtr>(GetValueNode(compare_cnode->input(kIndex1)));
95       auto cond_tensor2 = GetValue<tensor::TensorPtr>(GetValueNode(compare_cnode->input(kIndex2)));
96       auto cond_value1 = reinterpret_cast<float *>(cond_tensor1->data_c());
97       auto cond_value2 = reinterpret_cast<float *>(cond_tensor2->data_c());
98       bool flag = false;
99       if (IsPrimitiveCNode(compare_cnode, prim::kPrimLess) && (*cond_value1 < *cond_value2)) {
100         flag = true;
101       } else if (IsPrimitiveCNode(compare_cnode, prim::kPrimGreater) && (*cond_value1 > *cond_value2)) {
102         flag = true;
103       }
104       if (flag) {
105         return true_br.GetNode(node);
106       }
107       return false_br.GetNode(node);
108     };
109 
110     auto ConstantCompareLambda = [](const AnfNodePtr &node) -> bool {
111       if (!node->isa<CNode>()) {
112         return false;
113       }
114       auto cnode = node->cast<CNodePtr>();
115       if (!IsPrimitiveCNode(cnode, prim::kPrimLess) && !IsPrimitiveCNode(cnode, prim::kPrimGreater)) {
116         return false;
117       }
118       bool has_no_value =
119         std::any_of(cnode->inputs().begin() + kIndex1, cnode->inputs().end(), [](const AnfNodePtr &node) {
120           if (!IsValueNode<tensor::Tensor>(node)) {
121             return true;
122           }
123           auto value = GetValue<tensor::TensorPtr>(GetValueNode(node));
124           if (value->device_address() != nullptr) {
125             return true;
126           }
127           if (value->DataSize() > 1) {
128             return true;
129           }
130           auto type_id = value->Dtype()->type_id();
131           if (type_id != TypeId::kNumberTypeFloat32 && type_id != TypeId::kNumberTypeFloat) {
132             return true;
133           }
134           return false;
135         });
136       return !has_no_value;
137     };
138 
139     MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), CompareSwitchSimplifyLambda,
140                             cond.CheckFunc(ConstantCompareLambda, node));
141 
142     return nullptr;
143   }
144 };
145 
146 // {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} =>
147 // {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
148 class FloatTupleGetItemSwitch : public OptimizerCaller {
149  public:
operator()150   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
151     PatternNode<AnfNodePtr> cond;
152     PatternNode<AnfNodePtr> true_br;
153     PatternNode<AnfNodePtr> false_br;
154     PatternNode<AnfNodePtr> x;
155     MATCH_REPLACE_IF(node,
156                      PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
157                      PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
158                                 PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
159                      x.CheckFunc(IsVNode, node));
160     return nullptr;
161   }
162 };
163 
164 // {prim::kPrimEnvironGet, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
165 // {prim::kPrimSwitch, X1, {prim::kPrimEnvironGet, X2, X4, X5}, {prim::kPrimEnvironGet, X3, X4, X5}}
166 class FloatEnvironGetSwitch : public OptimizerCaller {
167  public:
operator()168   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
169     PatternNode<AnfNodePtr> cond;
170     PatternNode<AnfNodePtr> true_br;
171     PatternNode<AnfNodePtr> false_br;
172     PatternNode<AnfNodePtr> x;
173     PatternNode<AnfNodePtr> x2;
174     MATCH_REPLACE(node,
175                   PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
176                   PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvironGet, true_br, x, x2),
177                              PPrimitive(prim::kPrimEnvironGet, false_br, x, x2)));
178 
179     return nullptr;
180   }
181 };
182 
183 namespace internal {
184 FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
185 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
186 // block_nodes[0]: condition node
187 // block_nodes[1]: true branch node
188 // block_nodes[2]: false branch node
189 // branch_output_abs[0]: true branch abstract
190 // branch_output_abs[1]: false branch abstract
191 AnfNodePtr TransformMergeBranches(const std::vector<AnfNodePtr> &block_nodes,
192                                   const std::vector<AbstractBasePtr> &branch_output_abs,
193                                   const FuncGraphPtr &func_graph);
194 }  // namespace internal
195 
196 // {{prim::kPrimSwitch, X, G1, G2}, Xs}
197 class ConvertSwitchReplacement {
198  public:
199   ConvertSwitchReplacement() = default;
200   virtual ~ConvertSwitchReplacement() = default;
201 
operator()202   bool operator()(const FuncGraphPtr &root, const OptimizerPtr &) const {
203     auto manager = root->manager();
204     MS_EXCEPTION_IF_NULL(manager);
205     auto all_nodes = manager->all_nodes();
206 
207     bool change = false;
208     for (auto &node : all_nodes) {
209       if (CheckSwitchWrapNode(node)) {
210         TransformSwitchBranchReplace(node);
211         change = true;
212       }
213     }
214     return change;
215   }
216 
217  private:
218   // Determine whether there are graphs inside the branch graph.
219   bool CheckSwitchBranch(const AnfNodePtr &node) const;
220   // Determine whether node matches {{prim::kPrimSwitch, X, G1, G2}, Xs}.
221   bool CheckSwitchWrapNode(const AnfNodePtr &node) const;
222   // Replace switch branch.
223   void TransformSwitchBranchReplace(const AnfNodePtr &node) const;
224 };
225 
226 // {prim::kPrimSwitch, {prim::kPrimDepend, ValueNode, X}, G1, G2} ->
227 // {prim::kPrimDepend, {prim::kPrimSwitch, ValueNode, G1, G2}, X}
228 class ExchangeSwitchDependValue : public OptimizerCaller {
229  public:
operator()230   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
231     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
232       return nullptr;
233     }
234     ScopePtr scope = node->cast<CNodePtr>()->scope();
235     ScopeGuard scope_guard(scope);
236 
237     PatternNode<AnfNodePtr> cond;
238     PatternNode<AnfNodePtr> true_br;
239     PatternNode<AnfNodePtr> false_br;
240     PatternNode<AnfNodePtr> v;
241     PatternNode<AnfNodePtr> x;
242     MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSwitch, PPrimitive(prim::kPrimDepend, v, x), true_br, false_br),
243                      PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimSwitch, v, true_br, false_br), x),
244                      IsVNode(v.GetNode(node)));
245     return nullptr;
246   }
247 };
248 }  // namespace irpass
249 }  // namespace opt
250 }  // namespace mindspore
251 #endif  // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
252