1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ 19 20 #include <vector> 21 #include <string> 22 #include <utility> 23 #include <unordered_map> 24 #include <memory> 25 #include "ir/anf.h" 26 #include "backend/optimizer/common/pattern_engine.h" 27 #include "backend/optimizer/common/helper.h" 28 #include "backend/optimizer/common/optimizer.h" 29 30 namespace mindspore { 31 namespace opt { 32 class LambNextMVRule : public MultipleOutputPatternProcessPass { 33 public: 34 explicit LambNextMVRule(const std::string &name = "", bool multigraph = true) MultipleOutputPatternProcessPass(name,multigraph)35 : MultipleOutputPatternProcessPass(name, multigraph) { 36 input0_ = std::make_shared<Var>(); 37 input1_ = std::make_shared<Var>(); 38 input2_ = std::make_shared<Var>(); 39 input3_ = std::make_shared<Var>(); 40 input4_ = std::make_shared<Var>(); 41 input5_ = std::make_shared<Var>(); 42 input6_ = std::make_shared<Var>(); 43 mul0_x_ = std::make_shared<Var>(); 44 mul1_sub_ = std::make_shared<Var>(); 45 mul2_x_ = std::make_shared<Var>(); 46 mul3_sub1_ = std::make_shared<Var>(); 47 mul4_x_ = std::make_shared<Var>(); 48 add2_y_ = std::make_shared<Var>(); 49 real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); 50 real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); 51 real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); 52 add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); 53 add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); 54 } 55 ~LambNextMVRule() override = default; 56 const BaseRef DefinePattern() const override = 0; 57 BaseRef DefineAnotherPattern() const override = 0; 58 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 59 bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; 60 61 protected: 62 bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, 63 std::vector<AnfNodePtr> *old_pattern_outputs) const; 64 AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &old_pattern_outputs, 65 const EquivPtr &equiv) const; 66 67 VarPtr input0_; 68 VarPtr input1_; 69 VarPtr input2_; 70 VarPtr input3_; 71 VarPtr input4_; 72 VarPtr input5_; 73 VarPtr input6_; 74 VarPtr mul0_x_; 75 VarPtr mul1_sub_; 76 VarPtr mul2_x_; 77 VarPtr mul3_sub1_; 78 VarPtr mul4_x_; 79 VarPtr add2_y_; 80 // nodes which two patterns share, and add2_y_ also. 81 VarPtr real_div0_var_; 82 VarPtr real_div1_var_; 83 // part of output nodes 84 VarPtr add0_var_; 85 VarPtr add1_var_; 86 // other node 87 VarPtr real_div2_var_; 88 }; 89 90 class LambNextMVRuleCond1 : public LambNextMVRule { 91 public: 92 explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {} 93 94 ~LambNextMVRuleCond1() override = default; 95 const BaseRef DefinePattern() const override; 96 BaseRef DefineAnotherPattern() const override; 97 }; 98 99 class LambNextMVRuleCond2 : public LambNextMVRule { 100 public: 101 explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {} 102 103 ~LambNextMVRuleCond2() override = default; 104 const BaseRef DefinePattern() const override; 105 BaseRef DefineAnotherPattern() const override; 106 }; 107 108 class LambNextMVRuleCond3 : public LambNextMVRule { 109 public: 110 explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {} 111 112 ~LambNextMVRuleCond3() override = default; 113 const BaseRef DefinePattern() const override; 114 BaseRef DefineAnotherPattern() const override; 115 }; 116 117 class LambNextMVRuleCond4 : public LambNextMVRule { 118 public: 119 explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {} 120 121 ~LambNextMVRuleCond4() override = default; 122 const BaseRef DefinePattern() const override; 123 BaseRef DefineAnotherPattern() const override; 124 }; 125 } // namespace opt 126 } // namespace mindspore 127 128 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ 129