1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ 19 20 #include <vector> 21 #include <algorithm> 22 23 #include "ir/func_graph.h" 24 #include "mindspore/core/ops/sequence_ops.h" 25 #include "mindspore/core/ops/comparison_ops.h" 26 #include "mindspore/core/ops/framework_ops.h" 27 #include "ir/func_graph_cloner.h" 28 #include "frontend/optimizer/optimizer_caller.h" 29 #include "ir/pattern_matcher.h" 30 #include "frontend/operator/ops.h" 31 #include "frontend/optimizer/irpass.h" 32 #include "pipeline/jit/ps/parse/resolve.h" 33 34 namespace mindspore { 35 namespace opt { 36 namespace irpass { 37 // {prim::kPrimSwitch, true, X, Y} 38 // {prim::kPrimSwitch, false, X, Y} 39 class SwitchSimplify : public OptimizerCaller { 40 public: operator()41 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 42 PatternNode<AnfNodePtr> cond; 43 PatternNode<AnfNodePtr> true_br; 44 PatternNode<AnfNodePtr> false_br; 45 auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { 46 auto value_ptr = GetValueNode(cond.GetNode(node)); 47 bool cond_value; 48 if (value_ptr->isa<BoolImm>()) { 49 cond_value = GetValue<bool>(value_ptr); 50 } else { 51 MS_LOG(EXCEPTION) << "The condition of branch must be a bool tensor value or a bool scalar value," 52 << " not support this condition value: " << value_ptr->ToString(); 53 } 54 55 MS_LOG(DEBUG) << "condition value: " << value_ptr->ToString() << ", cond: " << cond_value 56 << ", node: " << node->DebugString(); 57 AnfNodePtr branch_node; 58 if (cond_value) { 59 branch_node = true_br.GetNode(node); 60 } else { 61 branch_node = false_br.GetNode(node); 62 } 63 auto fg = GetValuePtr<FuncGraph>(branch_node); 64 if (fg != nullptr) { 65 MS_LOG(DEBUG) << "No recursive, " << fg->ToString(); 66 fg->set_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE, true); 67 } 68 return branch_node; 69 }; 70 71 auto IsDeterminateCondition = [](const AnfNodePtr &node) -> bool { return IsValueNode<BoolImm>(node); }; 72 MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, 73 cond.CheckFunc(IsDeterminateCondition, node)); 74 75 return nullptr; 76 } 77 }; 78 79 // {prim::kPrimLess, Value1, Value2} 80 // {prim::kPrimSwitch, Less, X, Y} 81 // {prim::kPrimGreater, Value1, Value2} 82 // {prim::kPrimSwitch, Greater, X, Y} 83 class CompareSwitchSimplify : public OptimizerCaller { 84 public: operator()85 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 86 PatternNode<AnfNodePtr> cond; 87 PatternNode<AnfNodePtr> true_br; 88 PatternNode<AnfNodePtr> false_br; 89 auto CompareSwitchSimplifyLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { 90 auto cnode = node->cast<CNodePtr>(); 91 MS_EXCEPTION_IF_NULL(cnode); 92 auto compare_cnode = cnode->input(kIndex1)->cast<CNodePtr>(); 93 MS_EXCEPTION_IF_NULL(compare_cnode); 94 auto cond_tensor1 = GetValue<tensor::TensorPtr>(GetValueNode(compare_cnode->input(kIndex1))); 95 auto cond_tensor2 = GetValue<tensor::TensorPtr>(GetValueNode(compare_cnode->input(kIndex2))); 96 auto cond_value1 = reinterpret_cast<float *>(cond_tensor1->data_c()); 97 auto cond_value2 = reinterpret_cast<float *>(cond_tensor2->data_c()); 98 bool flag = false; 99 if (IsPrimitiveCNode(compare_cnode, prim::kPrimLess) && (*cond_value1 < *cond_value2)) { 100 flag = true; 101 } else if (IsPrimitiveCNode(compare_cnode, prim::kPrimGreater) && (*cond_value1 > *cond_value2)) { 102 flag = true; 103 } 104 if (flag) { 105 return true_br.GetNode(node); 106 } 107 return false_br.GetNode(node); 108 }; 109 110 auto ConstantCompareLambda = [](const AnfNodePtr &node) -> bool { 111 if (!node->isa<CNode>()) { 112 return false; 113 } 114 auto cnode = node->cast<CNodePtr>(); 115 if (!IsPrimitiveCNode(cnode, prim::kPrimLess) && !IsPrimitiveCNode(cnode, prim::kPrimGreater)) { 116 return false; 117 } 118 bool has_no_value = 119 std::any_of(cnode->inputs().begin() + kIndex1, cnode->inputs().end(), [](const AnfNodePtr &node) { 120 if (!IsValueNode<tensor::Tensor>(node)) { 121 return true; 122 } 123 auto value = GetValue<tensor::TensorPtr>(GetValueNode(node)); 124 if (value->device_address() != nullptr) { 125 return true; 126 } 127 if (value->DataSize() > 1) { 128 return true; 129 } 130 auto type_id = value->Dtype()->type_id(); 131 if (type_id != TypeId::kNumberTypeFloat32 && type_id != TypeId::kNumberTypeFloat) { 132 return true; 133 } 134 return false; 135 }); 136 return !has_no_value; 137 }; 138 139 MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), CompareSwitchSimplifyLambda, 140 cond.CheckFunc(ConstantCompareLambda, node)); 141 142 return nullptr; 143 } 144 }; 145 146 // {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} => 147 // {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} 148 class FloatTupleGetItemSwitch : public OptimizerCaller { 149 public: operator()150 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 151 PatternNode<AnfNodePtr> cond; 152 PatternNode<AnfNodePtr> true_br; 153 PatternNode<AnfNodePtr> false_br; 154 PatternNode<AnfNodePtr> x; 155 MATCH_REPLACE_IF(node, 156 PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), 157 PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), 158 PPrimitive(prim::kPrimTupleGetItem, false_br, x)), 159 x.CheckFunc(IsVNode, node)); 160 return nullptr; 161 } 162 }; 163 164 // {prim::kPrimEnvironGet, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => 165 // {prim::kPrimSwitch, X1, {prim::kPrimEnvironGet, X2, X4, X5}, {prim::kPrimEnvironGet, X3, X4, X5}} 166 class FloatEnvironGetSwitch : public OptimizerCaller { 167 public: operator()168 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 169 PatternNode<AnfNodePtr> cond; 170 PatternNode<AnfNodePtr> true_br; 171 PatternNode<AnfNodePtr> false_br; 172 PatternNode<AnfNodePtr> x; 173 PatternNode<AnfNodePtr> x2; 174 MATCH_REPLACE(node, 175 PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), 176 PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvironGet, true_br, x, x2), 177 PPrimitive(prim::kPrimEnvironGet, false_br, x, x2))); 178 179 return nullptr; 180 } 181 }; 182 183 namespace internal { 184 FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); 185 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); 186 // block_nodes[0]: condition node 187 // block_nodes[1]: true branch node 188 // block_nodes[2]: false branch node 189 // branch_output_abs[0]: true branch abstract 190 // branch_output_abs[1]: false branch abstract 191 AnfNodePtr TransformMergeBranches(const std::vector<AnfNodePtr> &block_nodes, 192 const std::vector<AbstractBasePtr> &branch_output_abs, 193 const FuncGraphPtr &func_graph); 194 } // namespace internal 195 196 // {{prim::kPrimSwitch, X, G1, G2}, Xs} 197 class ConvertSwitchReplacement { 198 public: 199 ConvertSwitchReplacement() = default; 200 virtual ~ConvertSwitchReplacement() = default; 201 operator()202 bool operator()(const FuncGraphPtr &root, const OptimizerPtr &) const { 203 auto manager = root->manager(); 204 MS_EXCEPTION_IF_NULL(manager); 205 auto all_nodes = manager->all_nodes(); 206 207 bool change = false; 208 for (auto &node : all_nodes) { 209 if (CheckSwitchWrapNode(node)) { 210 TransformSwitchBranchReplace(node); 211 change = true; 212 } 213 } 214 return change; 215 } 216 217 private: 218 // Determine whether there are graphs inside the branch graph. 219 bool CheckSwitchBranch(const AnfNodePtr &node) const; 220 // Determine whether node matches {{prim::kPrimSwitch, X, G1, G2}, Xs}. 221 bool CheckSwitchWrapNode(const AnfNodePtr &node) const; 222 // Replace switch branch. 223 void TransformSwitchBranchReplace(const AnfNodePtr &node) const; 224 }; 225 226 // {prim::kPrimSwitch, {prim::kPrimDepend, ValueNode, X}, G1, G2} -> 227 // {prim::kPrimDepend, {prim::kPrimSwitch, ValueNode, G1, G2}, X} 228 class ExchangeSwitchDependValue : public OptimizerCaller { 229 public: operator()230 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 231 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 232 return nullptr; 233 } 234 ScopePtr scope = node->cast<CNodePtr>()->scope(); 235 ScopeGuard scope_guard(scope); 236 237 PatternNode<AnfNodePtr> cond; 238 PatternNode<AnfNodePtr> true_br; 239 PatternNode<AnfNodePtr> false_br; 240 PatternNode<AnfNodePtr> v; 241 PatternNode<AnfNodePtr> x; 242 MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSwitch, PPrimitive(prim::kPrimDepend, v, x), true_br, false_br), 243 PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimSwitch, v, true_br, false_br), x), 244 IsVNode(v.GetNode(node))); 245 return nullptr; 246 } 247 }; 248 } // namespace irpass 249 } // namespace opt 250 } // namespace mindspore 251 #endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ 252