• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_
18 
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include "backend/optimizer/common/optimizer.h"
23 #include "utils/utils.h"
24 namespace mindspore {
25 namespace opt {
26 class AdamApplyOneWithDecayRule : public PatternProcessPass {
27  public:
28   explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true)
PatternProcessPass(name,multigraph)29       : PatternProcessPass(name, multigraph) {
30     input0_ = std::make_shared<Var>();
31     input1_ = std::make_shared<Var>();
32     input2_ = std::make_shared<Var>();
33     input3_ = std::make_shared<Var>();
34     input4_ = std::make_shared<Var>();
35     mul0_x_ = std::make_shared<Var>();
36     mul1_x_ = std::make_shared<Var>();
37     mul2_x_ = std::make_shared<Var>();
38     mul3_x_ = std::make_shared<Var>();
39     mul4_x_ = std::make_shared<Var>();
40     add2_y_ = std::make_shared<Var>();
41     add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
42     add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
43     sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
44   }
45   ~AdamApplyOneWithDecayRule() override = default;
46   const BaseRef DefinePattern() const override = 0;
47   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
48 
49  protected:
50   std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv, const AnfNodePtr &final_node) const;
51   VarPtr input0_;
52   VarPtr input1_;
53   VarPtr input2_;
54   VarPtr input3_;
55   VarPtr input4_;
56   VarPtr mul0_x_;
57   VarPtr mul1_x_;
58   VarPtr mul2_x_;
59   VarPtr mul3_x_;
60   VarPtr mul4_x_;
61   VarPtr add2_y_;
62   VarPtr add0_var_;
63   VarPtr add1_var_;
64   VarPtr sub0_var_;
65 };
66 
67 class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule {
68  public:
69   explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true)
70       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {}
71 
72   ~AdamApplyOneWithDecayRuleCond1() override = default;
73   const BaseRef DefinePattern() const override;
74 };
75 
76 class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule {
77  public:
78   explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true)
79       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {}
80 
81   ~AdamApplyOneWithDecayRuleCond2() override = default;
82   const BaseRef DefinePattern() const override;
83 };
84 
85 class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule {
86  public:
87   explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true)
88       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {}
89 
90   ~AdamApplyOneWithDecayRuleCond3() override = default;
91   const BaseRef DefinePattern() const override;
92 };
93 
94 class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule {
95  public:
96   explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true)
97       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {}
98 
99   ~AdamApplyOneWithDecayRuleCond4() override = default;
100   const BaseRef DefinePattern() const override;
101 };
102 
103 class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule {
104  public:
105   explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true)
106       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {}
107 
108   ~AdamApplyOneWithDecayRuleCond5() override = default;
109   const BaseRef DefinePattern() const override;
110 };
111 
112 class AdamApplyOneWithDecayAssignRuleCond1 : public AdamApplyOneWithDecayRule {
113  public:
114   explicit AdamApplyOneWithDecayAssignRuleCond1(bool multigraph = true)
115       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond1", multigraph) {}
116 
117   ~AdamApplyOneWithDecayAssignRuleCond1() override = default;
118   const BaseRef DefinePattern() const override;
119 };
120 
121 class AdamApplyOneWithDecayAssignRuleCond2 : public AdamApplyOneWithDecayRule {
122  public:
123   explicit AdamApplyOneWithDecayAssignRuleCond2(bool multigraph = true)
124       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond2", multigraph) {}
125 
126   ~AdamApplyOneWithDecayAssignRuleCond2() override = default;
127   const BaseRef DefinePattern() const override;
128 };
129 
130 class AdamApplyOneWithDecayAssignRuleCond3 : public AdamApplyOneWithDecayRule {
131  public:
132   explicit AdamApplyOneWithDecayAssignRuleCond3(bool multigraph = true)
133       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond3", multigraph) {}
134 
135   ~AdamApplyOneWithDecayAssignRuleCond3() override = default;
136   const BaseRef DefinePattern() const override;
137 };
138 
139 class AdamApplyOneWithDecayAssignRuleCond4 : public AdamApplyOneWithDecayRule {
140  public:
141   explicit AdamApplyOneWithDecayAssignRuleCond4(bool multigraph = true)
142       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond4", multigraph) {}
143 
144   ~AdamApplyOneWithDecayAssignRuleCond4() override = default;
145   const BaseRef DefinePattern() const override;
146 };
147 
148 class AdamApplyOneWithDecayAssignRuleCond5 : public AdamApplyOneWithDecayRule {
149  public:
150   explicit AdamApplyOneWithDecayAssignRuleCond5(bool multigraph = true)
151       : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond5", multigraph) {}
152 
153   ~AdamApplyOneWithDecayAssignRuleCond5() override = default;
154   const BaseRef DefinePattern() const override;
155 };
156 }  // namespace opt
157 }  // namespace mindspore
158 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_
159