1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ 19 20 #include <vector> 21 #include <algorithm> 22 23 #include "ir/func_graph.h" 24 #include "ir/func_graph_cloner.h" 25 #include "frontend/optimizer/optimizer_caller.h" 26 #include "ir/pattern_matcher.h" 27 #include "frontend/operator/ops.h" 28 #include "frontend/optimizer/irpass.h" 29 30 namespace mindspore { 31 namespace opt { 32 namespace irpass { 33 // {prim::kPrimSwitch, true, X, Y} 34 // {prim::kPrimSwitch, false, X, Y} 35 class SwitchSimplify : public OptimizerCaller { 36 public: operator()37 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 38 PatternNode<AnfNodePtr> cond, true_br, false_br; 39 auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { 40 auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node))); 41 if (cond_value_) { 42 return true_br.GetNode(node); 43 } 44 return false_br.GetNode(node); 45 }; 46 47 MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, 48 cond.CheckFunc(IsValueNode<BoolImm>, node)); 49 50 return nullptr; 51 } 52 }; 53 54 // {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} => 55 // {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} 56 class FloatTupleGetItemSwitch : public OptimizerCaller { 57 public: operator()58 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 59 PatternNode<AnfNodePtr> cond, true_br, false_br, x; 60 MATCH_REPLACE_IF(node, 61 PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), 62 PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), 63 PPrimitive(prim::kPrimTupleGetItem, false_br, x)), 64 x.CheckFunc(IsVNode, node)); 65 return nullptr; 66 } 67 }; 68 69 // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => 70 // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} 71 class FloatEnvGetItemSwitch : public OptimizerCaller { 72 public: operator()73 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 74 PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2; 75 MATCH_REPLACE(node, 76 PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), 77 PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), 78 PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); 79 80 return nullptr; 81 } 82 }; 83 84 namespace internal { 85 FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); 86 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); 87 // block_nodes[0]: condition node 88 // block_nodes[1]: true branch node 89 // block_nodes[2]: false branch node 90 // branch_output_abs[0]: true branch abstract 91 // branch_output_abs[1]: false branch abstract 92 AnfNodePtr TransformMergeBranches(const std::vector<AnfNodePtr> &block_nodes, 93 const std::vector<AbstractBasePtr> &branch_output_abs, 94 const FuncGraphPtr &func_graph); 95 } // namespace internal 96 97 // {{prim::kPrimSwitch, X, G1, G2}, Xs} 98 class ConvertSwitchReplacement { 99 public: 100 ConvertSwitchReplacement() = default; 101 virtual ~ConvertSwitchReplacement() = default; 102 operator()103 bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { 104 AnfNodePtr ret = root->get_return(); 105 MS_EXCEPTION_IF_NULL(ret); 106 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); 107 108 bool change = false; 109 for (auto &node : all_nodes) { 110 if (CheckSwitchWrapNode(node)) { 111 TransformSwitchBranchReplace(node); 112 change = true; 113 } 114 } 115 return change; 116 } 117 118 private: 119 // Determine whether there are graphs inside the branch graph. 120 bool CheckSwitchBranch(const AnfNodePtr &node); 121 // Determine whether node matches {{prim::kPrimSwitch, X, G1, G2}, Xs}. 122 bool CheckSwitchWrapNode(const AnfNodePtr &node); 123 // Replace switch branch. 124 void TransformSwitchBranchReplace(const AnfNodePtr &node); 125 }; 126 127 // {prim::kPrimSwitch, {prim::kPrimDepend, ValueNode, X}, G1, G2} -> 128 // {prim::kPrimDepend, {prim::kPrimSwitch, ValueNode, G1, G2}, X} 129 class ExchangeSwitchDependValue : public OptimizerCaller { 130 public: operator()131 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 132 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 133 return nullptr; 134 } 135 ScopePtr scope = node->cast<CNodePtr>()->scope(); 136 ScopeGuard scope_guard(scope); 137 138 PatternNode<AnfNodePtr> cond, true_br, false_br, v, x; 139 MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSwitch, PPrimitive(prim::kPrimDepend, v, x), true_br, false_br), 140 PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimSwitch, v, true_br, false_br), x), 141 IsVNode(v.GetNode(node))); 142 return nullptr; 143 } 144 }; 145 } // namespace irpass 146 } // namespace opt 147 } // namespace mindspore 148 #endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ 149