1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <string> 22 #include "backend/optimizer/common/optimizer.h" 23 #include "utils/utils.h" 24 namespace mindspore { 25 namespace opt { 26 class AdamApplyOneWithDecayRule : public PatternProcessPass { 27 public: 28 explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) PatternProcessPass(name,multigraph)29 : PatternProcessPass(name, multigraph) { 30 input0_ = std::make_shared<Var>(); 31 input1_ = std::make_shared<Var>(); 32 input2_ = std::make_shared<Var>(); 33 input3_ = std::make_shared<Var>(); 34 input4_ = std::make_shared<Var>(); 35 mul0_x_ = std::make_shared<Var>(); 36 mul1_x_ = std::make_shared<Var>(); 37 mul2_x_ = std::make_shared<Var>(); 38 mul3_x_ = std::make_shared<Var>(); 39 mul4_x_ = std::make_shared<Var>(); 40 add2_y_ = std::make_shared<Var>(); 41 add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); 42 add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); 43 sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); 44 } 45 ~AdamApplyOneWithDecayRule() override = default; 46 const BaseRef DefinePattern() const override = 0; 47 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 48 49 protected: 50 std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv, const AnfNodePtr &final_node) const; 51 VarPtr input0_; 52 VarPtr input1_; 53 VarPtr input2_; 54 VarPtr input3_; 55 VarPtr input4_; 56 VarPtr mul0_x_; 57 VarPtr mul1_x_; 58 VarPtr mul2_x_; 59 VarPtr mul3_x_; 60 VarPtr mul4_x_; 61 VarPtr add2_y_; 62 VarPtr add0_var_; 63 VarPtr add1_var_; 64 VarPtr sub0_var_; 65 }; 66 67 class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { 68 public: 69 explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) 70 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} 71 72 ~AdamApplyOneWithDecayRuleCond1() override = default; 73 const BaseRef DefinePattern() const override; 74 }; 75 76 class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { 77 public: 78 explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) 79 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} 80 81 ~AdamApplyOneWithDecayRuleCond2() override = default; 82 const BaseRef DefinePattern() const override; 83 }; 84 85 class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { 86 public: 87 explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) 88 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} 89 90 ~AdamApplyOneWithDecayRuleCond3() override = default; 91 const BaseRef DefinePattern() const override; 92 }; 93 94 class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { 95 public: 96 explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) 97 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} 98 99 ~AdamApplyOneWithDecayRuleCond4() override = default; 100 const BaseRef DefinePattern() const override; 101 }; 102 103 class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { 104 public: 105 explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) 106 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} 107 108 ~AdamApplyOneWithDecayRuleCond5() override = default; 109 const BaseRef DefinePattern() const override; 110 }; 111 112 class AdamApplyOneWithDecayAssignRuleCond1 : public AdamApplyOneWithDecayRule { 113 public: 114 explicit AdamApplyOneWithDecayAssignRuleCond1(bool multigraph = true) 115 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond1", multigraph) {} 116 117 ~AdamApplyOneWithDecayAssignRuleCond1() override = default; 118 const BaseRef DefinePattern() const override; 119 }; 120 121 class AdamApplyOneWithDecayAssignRuleCond2 : public AdamApplyOneWithDecayRule { 122 public: 123 explicit AdamApplyOneWithDecayAssignRuleCond2(bool multigraph = true) 124 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond2", multigraph) {} 125 126 ~AdamApplyOneWithDecayAssignRuleCond2() override = default; 127 const BaseRef DefinePattern() const override; 128 }; 129 130 class AdamApplyOneWithDecayAssignRuleCond3 : public AdamApplyOneWithDecayRule { 131 public: 132 explicit AdamApplyOneWithDecayAssignRuleCond3(bool multigraph = true) 133 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond3", multigraph) {} 134 135 ~AdamApplyOneWithDecayAssignRuleCond3() override = default; 136 const BaseRef DefinePattern() const override; 137 }; 138 139 class AdamApplyOneWithDecayAssignRuleCond4 : public AdamApplyOneWithDecayRule { 140 public: 141 explicit AdamApplyOneWithDecayAssignRuleCond4(bool multigraph = true) 142 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond4", multigraph) {} 143 144 ~AdamApplyOneWithDecayAssignRuleCond4() override = default; 145 const BaseRef DefinePattern() const override; 146 }; 147 148 class AdamApplyOneWithDecayAssignRuleCond5 : public AdamApplyOneWithDecayRule { 149 public: 150 explicit AdamApplyOneWithDecayAssignRuleCond5(bool multigraph = true) 151 : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond5", multigraph) {} 152 153 ~AdamApplyOneWithDecayAssignRuleCond5() override = default; 154 const BaseRef DefinePattern() const override; 155 }; 156 } // namespace opt 157 } // namespace mindspore 158 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ 159