1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
17 #include "backend/session/anf_runtime_algorithm.h"
18 #include "frontend/optimizer/opt.h"
19 #include "utils/trace_base.h"
20 namespace mindspore {
21 namespace opt {
/**
 * Gives the fused LambNextMVWithDecay node a 4-element tuple abstract, splits it into per-output nodes and
 * redirects the users of the nodes it replaces.
 *
 * Output layout of the fused node (index: node whose abstract is reused and which the output stands in for):
 *   0: add3, 1: node bound to add0_var_, 2: node bound to add1_var_, 3: add5.
 * Outputs 0..2 are substituted in the graph through the func_graph manager here; output 3 is returned to the
 * caller instead of being substituted.
 *
 * @param func_graph Graph owning the nodes; must have a manager.
 * @param new_node The freshly created fused CNode.
 * @param add3 Root of the companion (DefineAnotherPattern) match.
 * @param add5 Root of the main (DefinePattern) match.
 * @param equiv Variable bindings of the match (used to look up add0 and add1).
 * @return The output at index kIndex3 (abstract of add5).
 * Throws when an argument, a bound node or the manager is null, or when the number of created outputs differs
 * from kLambNextMVWithDecayOutputNum.
 */
AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph,
                                                                 const AnfNodePtr &new_node, const AnfNodePtr &add3,
                                                                 const AnfNodePtr &add5, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(new_node);
  MS_EXCEPTION_IF_NULL(add3);
  MS_EXCEPTION_IF_NULL(add5);
  MS_EXCEPTION_IF_NULL(equiv);
  auto add0 = GetAnfNodeByVar(equiv, add0_var_);
  MS_EXCEPTION_IF_NULL(add0);
  auto add1 = GetAnfNodeByVar(equiv, add1_var_);
  MS_EXCEPTION_IF_NULL(add1);

  // Set abstract of new node; the element order here defines the output order of the fused node.
  AbstractBasePtrList new_node_list;
  new_node_list.push_back(add3->abstract());
  new_node_list.push_back(add0->abstract());
  new_node_list.push_back(add1->abstract());
  new_node_list.push_back(add5->abstract());
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  new_node->set_abstract(abstract_tuple);
  // Create tuple_getitem node for outputs
  std::vector<AnfNodePtr> new_node_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextMVWithDecayOutputNum, &new_node_outputs);
  // The accesses below use fixed indices up to kIndex3; fail loudly instead of reading past the end.
  if (new_node_outputs.size() != kLambNextMVWithDecayOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be "
                      << kLambNextMVWithDecayOutputNum << ", but got " << new_node_outputs.size() << "."
                      << " trace: " << trace::DumpSourceLines(new_node);
  }
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(add3, new_node_outputs[kIndex0]);
  (void)manager->Replace(add0, new_node_outputs[kIndex1]);
  (void)manager->Replace(add1, new_node_outputs[kIndex2]);
  return new_node_outputs[kIndex3];
}
54
/**
 * Builds the fused LambNextMVWithDecay CNode from the nodes bound in `equiv` and hands it to
 * GetLambNextMVWithDecayOutput for abstract setup and graph rewiring.
 *
 * Input order of the new CNode:
 *   [primitive, input_vars_[0 .. kLambNextMVWithDecayInputNum),
 *    constant_mul_input_vars_[0 .. kLambNextMVWithDecayConstantMulInputNum), constant_add2_y_]
 *
 * @param func_graph Graph in which the new CNode is created.
 * @param add3 Root of the companion (DefineAnotherPattern) match.
 * @param add5 Root of the main (DefinePattern) match.
 * @param equiv Variable bindings of the match.
 * @return Whatever GetLambNextMVWithDecayOutput returns for the new node.
 * Throws when func_graph, add3, equiv or any bound input resolves to null.
 */
AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph,
                                                                  const AnfNodePtr &add3, const AnfNodePtr &add5,
                                                                  const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(add3);
  MS_EXCEPTION_IF_NULL(equiv);
  auto prim = std::make_shared<Primitive>(kLambNextMVWithDecayOpName);
  std::vector<AnfNodePtr> new_node_inputs{NewValueNode(prim)};
  // Primitive + tensor inputs + Mul constants + the constant_add2_y_ operand.
  new_node_inputs.reserve(kLambNextMVWithDecayInputNum + kLambNextMVWithDecayConstantMulInputNum + 2);

  // Tensor inputs, in input_vars_ order.
  for (size_t idx = 0; idx < kLambNextMVWithDecayInputNum; ++idx) {
    auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_vars_[idx]]);
    MS_EXCEPTION_IF_NULL(input_node);
    new_node_inputs.emplace_back(input_node);
  }

  // Constant Mul operands, in constant_mul_input_vars_ order.
  for (size_t idx = 0; idx < kLambNextMVWithDecayConstantMulInputNum; ++idx) {
    auto constant_mul_input_node = utils::cast<AnfNodePtr>((*equiv)[constant_mul_input_vars_[idx]]);
    MS_EXCEPTION_IF_NULL(constant_mul_input_node);
    new_node_inputs.emplace_back(constant_mul_input_node);
  }

  // The constant shared by the Add feeding Sqrt/Rsqrt in both patterns comes last.
  auto constant_add2_y_node = utils::cast<AnfNodePtr>((*equiv)[constant_add2_y_]);
  MS_EXCEPTION_IF_NULL(constant_add2_y_node);
  new_node_inputs.emplace_back(constant_add2_y_node);

  auto new_node = func_graph->NewCNode(new_node_inputs);
  return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
}
80
IsShareNodes(const EquivPtr & equiv1,const EquivPtr & equiv2) const81 bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
82 return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
83 IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
84 }
85
/**
 * Fuses a matched LambNextMVWithDecay sub-graph.
 *
 * `node` is the root (add5) matched by DefinePattern(). The companion add3 sub-graph (DefineAnotherPattern())
 * is searched for among the other users of the node bound to mul4_var_. When one matches, both sub-graphs are
 * replaced by a single LambNextMVWithDecay node.
 *
 * @param func_graph Graph containing `node`; must have a manager.
 * @param node Root of the DefinePattern() match.
 * @param equiv Variable bindings of that match.
 * @return The replacement for `node`, or nullptr if the data type is unsupported or no companion add3 matches.
 * Throws when the manager has no users entry for mul4, or when a required pointer is null.
 */
const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(equiv);
  AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
  MS_EXCEPTION_IF_NULL(mul4);
  // Get add3 and match the add3 pattern
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Single lookup: operator[] after find() would hash twice (and insert on a miss).
  auto &node_users = manager->node_users();
  auto mul4_users_iter = node_users.find(mul4);
  if (mul4_users_iter == node_users.end()) {
    MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input."
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // Deliberately a copy. `iter->first` is passed by reference into CreateLambNextMVWithDecayNode, which calls
  // manager->Replace(); that may modify the manager's user map, so do not alias it here.
  AnfNodeIndexSet mul4_outputs = mul4_users_iter->second;
  auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(),
                           [&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
                             return node_index.first != node && MatchAnotherPattern(node_index.first, equiv);
                           });
  if (iter == mul4_outputs.end()) {
    return nullptr;
  }
  return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv);
}
112
DefineAnotherPattern() const113 BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
114 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
115 MS_EXCEPTION_IF_NULL(prim_rsqrt);
116 VarPtr Xs = std::make_shared<SeqVar>();
117 VarPtr Ys = std::make_shared<SeqVar>();
118 VarPtr Zs = std::make_shared<SeqVar>();
119 MS_EXCEPTION_IF_NULL(Xs);
120 MS_EXCEPTION_IF_NULL(Ys);
121 MS_EXCEPTION_IF_NULL(Zs);
122 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
123 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
124 VectorRef mul4 = VectorRef({mul4_var_, Zs});
125
126 VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
127 VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
128 VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
129 VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
130 return add3;
131 }
132
DefinePattern() const133 const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
134 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
135 MS_EXCEPTION_IF_NULL(prim_sqrt);
136 const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
137 MS_EXCEPTION_IF_NULL(prim_deal_div);
138 VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[kIndex1], constant_mul_input_vars_[kIndex2]});
139 VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[kIndex0], constant_mul_input_vars_[kIndex3]});
140 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
141 VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
142 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
143 VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
144 VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[kIndex4], constant_mul_input_vars_[kIndex0]});
145 VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[kIndex3], constant_mul_input_vars_[kIndex1]});
146 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
147 VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
148 VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
149 VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[kIndex4], input_vars_[kIndex6]});
150 VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
151 return add5;
152 }
153
DefineAnotherPattern() const154 BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
155 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
156 MS_EXCEPTION_IF_NULL(prim_rsqrt);
157 VarPtr Xs = std::make_shared<SeqVar>();
158 VarPtr Ys = std::make_shared<SeqVar>();
159 VarPtr Zs = std::make_shared<SeqVar>();
160 MS_EXCEPTION_IF_NULL(Xs);
161 MS_EXCEPTION_IF_NULL(Ys);
162 MS_EXCEPTION_IF_NULL(Zs);
163 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
164 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
165 VectorRef mul4 = VectorRef({mul4_var_, Zs});
166
167 VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
168 VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
169 VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
170 VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
171 return add3;
172 }
173
DefinePattern() const174 const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
175 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
176 MS_EXCEPTION_IF_NULL(prim_sqrt);
177 const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
178 MS_EXCEPTION_IF_NULL(prim_deal_div);
179 VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex2], input_vars_[kIndex1]});
180 VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex3], input_vars_[kIndex0]});
181 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
182 VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
183 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
184 VectorRef add4 = VectorRef({prim::kPrimAdd, constant_add2_y_, sqrt1});
185 VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex0], input_vars_[kIndex4]});
186 VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex1], input_vars_[kIndex3]});
187 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
188 VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
189 VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
190 VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[kIndex4], input_vars_[kIndex6]});
191 VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
192 return add5;
193 }
194
DefineAnotherPattern() const195 BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
196 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
197 MS_EXCEPTION_IF_NULL(prim_rsqrt);
198 VarPtr Xs = std::make_shared<SeqVar>();
199 VarPtr Ys = std::make_shared<SeqVar>();
200 VarPtr Zs = std::make_shared<SeqVar>();
201 MS_EXCEPTION_IF_NULL(Xs);
202 MS_EXCEPTION_IF_NULL(Ys);
203 MS_EXCEPTION_IF_NULL(Zs);
204 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
205 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
206 VectorRef mul4 = VectorRef({mul4_var_, Zs});
207
208 VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
209 VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
210 VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
211 VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
212 return add3;
213 }
214
DefinePattern() const215 const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
216 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
217 MS_EXCEPTION_IF_NULL(prim_sqrt);
218 const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
219 MS_EXCEPTION_IF_NULL(prim_deal_div);
220 VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[kIndex1], constant_mul_input_vars_[kIndex2]});
221 VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex3], input_vars_[kIndex0]});
222 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
223 VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
224 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
225 VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
226 VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[kIndex4], constant_mul_input_vars_[kIndex0]});
227 VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[kIndex3], constant_mul_input_vars_[kIndex1]});
228 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
229 VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
230 VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
231 VectorRef mul4 = VectorRef({mul4_var_, input_vars_[kIndex6], constant_mul_input_vars_[kIndex4]});
232 VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
233 return add5;
234 }
235
DefineAnotherPattern() const236 BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
237 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
238 MS_EXCEPTION_IF_NULL(prim_rsqrt);
239 VarPtr Xs = std::make_shared<SeqVar>();
240 VarPtr Ys = std::make_shared<SeqVar>();
241 VarPtr Zs = std::make_shared<SeqVar>();
242 MS_EXCEPTION_IF_NULL(Xs);
243 MS_EXCEPTION_IF_NULL(Ys);
244 MS_EXCEPTION_IF_NULL(Zs);
245 // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
246 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
247 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
248 VectorRef mul4 = VectorRef({mul4_var_, Zs});
249
250 VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
251 VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
252 VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
253 VectorRef add3 = VectorRef({prim::kPrimAdd, real_div2, mul4});
254 return add3;
255 }
256
DefinePattern() const257 const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
258 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
259 MS_EXCEPTION_IF_NULL(prim_sqrt);
260 const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
261 MS_EXCEPTION_IF_NULL(prim_deal_div);
262 VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex2], input_vars_[kIndex1]});
263 VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex3], input_vars_[kIndex0]});
264 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
265 VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
266 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
267 VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
268 VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex0], input_vars_[kIndex4]});
269 VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex1], input_vars_[kIndex3]});
270 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
271 VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
272 VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
273 VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[kIndex4], input_vars_[kIndex6]});
274 VectorRef add5 = VectorRef({prim::kPrimAdd, real_div4, mul4});
275 return add5;
276 }
277 } // namespace opt
278 } // namespace mindspore
279