• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
17 #include "backend/session/anf_runtime_algorithm.h"
18 #include "frontend/optimizer/opt.h"
19 #include "utils/trace_base.h"
20 namespace mindspore {
21 namespace opt {
// Wires the freshly created LambNextMVWithDecay CNode into the graph.
//
// Sets a 4-element AbstractTuple on `new_node` whose elements are, in order, the abstracts of add3, add0, add1 and
// add5. It then creates one output node per element. Users of add3, add0 and add1 are redirected (through the graph
// manager) to outputs 0, 1 and 2. Output 3, which stands in for add5, is returned to the caller.
//
// @param func_graph graph that owns all nodes; must have a manager.
// @param new_node   the fused CNode created by CreateLambNextMVWithDecayNode.
// @param add3       root of the sibling pattern (DefineAnotherPattern) that is being fused.
// @param add5       root of the matched main pattern (DefinePattern).
// @param equiv      match bindings; must bind add0_var_ and add1_var_.
// @return the output node that replaces add5.
// Raises (MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) on null inputs, missing bindings, or an unexpected number of
// created outputs.
AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph,
                                                                 const AnfNodePtr &new_node, const AnfNodePtr &add3,
                                                                 const AnfNodePtr &add5, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(new_node);
  MS_EXCEPTION_IF_NULL(add3);
  MS_EXCEPTION_IF_NULL(add5);
  MS_EXCEPTION_IF_NULL(equiv);
  auto add0 = GetAnfNodeByVar(equiv, add0_var_);
  MS_EXCEPTION_IF_NULL(add0);
  auto add1 = GetAnfNodeByVar(equiv, add1_var_);
  MS_EXCEPTION_IF_NULL(add1);

  // Output order of the fused node: add3, add0, add1, add5 (the indices used below rely on this order).
  AbstractBasePtrList new_node_list = {add3->abstract(), add0->abstract(), add1->abstract(), add5->abstract()};
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  new_node->set_abstract(abstract_tuple);

  // Create tuple_getitem nodes for the outputs.
  std::vector<AnfNodePtr> new_node_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextMVWithDecayOutputNum, &new_node_outputs);
  // Every index up to kIndex3 is accessed unconditionally below, so fail loudly instead of reading out of bounds.
  if (new_node_outputs.size() != kLambNextMVWithDecayOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be "
                      << kLambNextMVWithDecayOutputNum << ", but got " << new_node_outputs.size()
                      << " trace: " << trace::DumpSourceLines(add5);
  }

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(add3, new_node_outputs[kIndex0]);
  (void)manager->Replace(add0, new_node_outputs[kIndex1]);
  (void)manager->Replace(add1, new_node_outputs[kIndex2]);
  // add5 is not replaced here; its stand-in is handed back to the caller.
  return new_node_outputs[kIndex3];
}
54 
// Builds the fused LambNextMVWithDecay CNode from the pattern bindings and wires its outputs into the graph.
//
// Input layout of the new CNode:
//   [ValueNode(prim), input_vars_[0..kLambNextMVWithDecayInputNum),
//    constant_mul_input_vars_[0..kLambNextMVWithDecayConstantMulInputNum), constant_add2_y_]
//
// @param func_graph graph in which the new node is created.
// @param add3       root of the sibling (DefineAnotherPattern) match.
// @param add5       root of the main (DefinePattern) match.
// @param equiv      match bindings for every input / constant variable.
// @return the output node that replaces add5 (see GetLambNextMVWithDecayOutput).
// Raises via MS_EXCEPTION_IF_NULL when an argument is null or a required variable is not bound to a node.
AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph,
                                                                  const AnfNodePtr &add3, const AnfNodePtr &add5,
                                                                  const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(add3);
  MS_EXCEPTION_IF_NULL(add5);
  MS_EXCEPTION_IF_NULL(equiv);
  // Create the new node with all of its inputs.
  auto prim = std::make_shared<Primitive>(kLambNextMVWithDecayOpName);
  std::vector<AnfNodePtr> new_node_inputs;
  // +1 for the primitive value node, +1 for constant_add2_y_.
  new_node_inputs.reserve(kLambNextMVWithDecayInputNum + kLambNextMVWithDecayConstantMulInputNum + 2);
  new_node_inputs.push_back(NewValueNode(prim));
  // Bindings are read through GetAnfNodeByVar, the lookup helper used elsewhere in this pass, instead of
  // (*equiv)[var]; with map operator[] semantics the latter would default-insert an entry for an unbound var.
  for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
    auto input_node = GetAnfNodeByVar(equiv, input_vars_[i]);
    MS_EXCEPTION_IF_NULL(input_node);
    new_node_inputs.push_back(input_node);
  }
  for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) {
    auto constant_mul_input_node = GetAnfNodeByVar(equiv, constant_mul_input_vars_[i]);
    MS_EXCEPTION_IF_NULL(constant_mul_input_node);
    new_node_inputs.push_back(constant_mul_input_node);
  }
  auto constant_add2_y_node = GetAnfNodeByVar(equiv, constant_add2_y_);
  MS_EXCEPTION_IF_NULL(constant_add2_y_node);
  new_node_inputs.push_back(constant_add2_y_node);
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
}
80 
IsShareNodes(const EquivPtr & equiv1,const EquivPtr & equiv2) const81 bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
82   return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
83          IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
84 }
85 
// Handles a match of the main (add5) pattern.
//
// The add3 sibling pattern (DefineAnotherPattern) hangs off the same mul4 node. This function therefore scans the
// other users of mul4 for one that matches it. If one is found, both sub-graphs are fused into a single
// LambNextMVWithDecay node.
//
// @return the replacement for `node` (add5), or nullptr when the data type is unsupported or no sibling matches.
// Raises via MS_LOG(EXCEPTION) when mul4 has no entry in the manager's node-user map.
const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
    return nullptr;
  }
  AnfNodePtr shared_mul4 = GetAnfNodeByVar(equiv, mul4_var_);
  MS_EXCEPTION_IF_NULL(shared_mul4);

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &users_of = manager->node_users();
  auto mul4_entry = users_of.find(shared_mul4);
  if (mul4_entry == users_of.end()) {
    MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input."
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // Deliberately a copy: the fusion below calls manager->Replace(), which may update node_users(). Working on a
  // snapshot keeps the iterator (and the add3 reference passed down) valid for the whole call.
  const AnfNodeIndexSet mul4_users = mul4_entry->second;

  // A sibling is any other user of mul4 that matches the add3 pattern.
  auto is_sibling_add3 = [&node, &equiv, this](const std::pair<AnfNodePtr, int> &user) -> bool {
    if (user.first == node) {
      return false;
    }
    return MatchAnotherPattern(user.first, equiv);
  };
  auto add3_it = std::find_if(mul4_users.begin(), mul4_users.end(), is_sibling_add3);
  if (add3_it == mul4_users.end()) {
    return nullptr;
  }
  return CreateLambNextMVWithDecayNode(func_graph, add3_it->first, node, equiv);
}
112 
DefineAnotherPattern() const113 BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
114   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
115   MS_EXCEPTION_IF_NULL(prim_rsqrt);
116   VarPtr Xs = std::make_shared<SeqVar>();
117   VarPtr Ys = std::make_shared<SeqVar>();
118   VarPtr Zs = std::make_shared<SeqVar>();
119   MS_EXCEPTION_IF_NULL(Xs);
120   MS_EXCEPTION_IF_NULL(Ys);
121   MS_EXCEPTION_IF_NULL(Zs);
122   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
123   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
124   VectorRef mul4 = VectorRef({mul4_var_, Zs});
125 
126   VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
127   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
128   VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
129   VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
130   return add3;
131 }
132 
DefinePattern() const133 const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
134   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
135   MS_EXCEPTION_IF_NULL(prim_sqrt);
136   const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
137   MS_EXCEPTION_IF_NULL(prim_deal_div);
138   VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[kIndex1], constant_mul_input_vars_[kIndex2]});
139   VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[kIndex0], constant_mul_input_vars_[kIndex3]});
140   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
141   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
142   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
143   VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
144   VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[kIndex4], constant_mul_input_vars_[kIndex0]});
145   VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[kIndex3], constant_mul_input_vars_[kIndex1]});
146   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
147   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
148   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
149   VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[kIndex4], input_vars_[kIndex6]});
150   VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
151   return add5;
152 }
153 
DefineAnotherPattern() const154 BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
155   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
156   MS_EXCEPTION_IF_NULL(prim_rsqrt);
157   VarPtr Xs = std::make_shared<SeqVar>();
158   VarPtr Ys = std::make_shared<SeqVar>();
159   VarPtr Zs = std::make_shared<SeqVar>();
160   MS_EXCEPTION_IF_NULL(Xs);
161   MS_EXCEPTION_IF_NULL(Ys);
162   MS_EXCEPTION_IF_NULL(Zs);
163   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
164   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
165   VectorRef mul4 = VectorRef({mul4_var_, Zs});
166 
167   VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
168   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
169   VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
170   VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
171   return add3;
172 }
173 
DefinePattern() const174 const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
175   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
176   MS_EXCEPTION_IF_NULL(prim_sqrt);
177   const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
178   MS_EXCEPTION_IF_NULL(prim_deal_div);
179   VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex2], input_vars_[kIndex1]});
180   VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex3], input_vars_[kIndex0]});
181   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
182   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
183   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
184   VectorRef add4 = VectorRef({prim::kPrimAdd, constant_add2_y_, sqrt1});
185   VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex0], input_vars_[kIndex4]});
186   VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex1], input_vars_[kIndex3]});
187   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
188   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
189   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
190   VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[kIndex4], input_vars_[kIndex6]});
191   VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
192   return add5;
193 }
194 
DefineAnotherPattern() const195 BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
196   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
197   MS_EXCEPTION_IF_NULL(prim_rsqrt);
198   VarPtr Xs = std::make_shared<SeqVar>();
199   VarPtr Ys = std::make_shared<SeqVar>();
200   VarPtr Zs = std::make_shared<SeqVar>();
201   MS_EXCEPTION_IF_NULL(Xs);
202   MS_EXCEPTION_IF_NULL(Ys);
203   MS_EXCEPTION_IF_NULL(Zs);
204   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
205   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
206   VectorRef mul4 = VectorRef({mul4_var_, Zs});
207 
208   VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
209   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
210   VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
211   VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
212   return add3;
213 }
214 
DefinePattern() const215 const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
216   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
217   MS_EXCEPTION_IF_NULL(prim_sqrt);
218   const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
219   MS_EXCEPTION_IF_NULL(prim_deal_div);
220   VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[kIndex1], constant_mul_input_vars_[kIndex2]});
221   VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex3], input_vars_[kIndex0]});
222   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
223   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
224   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
225   VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
226   VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[kIndex4], constant_mul_input_vars_[kIndex0]});
227   VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[kIndex3], constant_mul_input_vars_[kIndex1]});
228   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
229   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
230   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
231   VectorRef mul4 = VectorRef({mul4_var_, input_vars_[kIndex6], constant_mul_input_vars_[kIndex4]});
232   VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
233   return add5;
234 }
235 
DefineAnotherPattern() const236 BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
237   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
238   MS_EXCEPTION_IF_NULL(prim_rsqrt);
239   VarPtr Xs = std::make_shared<SeqVar>();
240   VarPtr Ys = std::make_shared<SeqVar>();
241   VarPtr Zs = std::make_shared<SeqVar>();
242   MS_EXCEPTION_IF_NULL(Xs);
243   MS_EXCEPTION_IF_NULL(Ys);
244   MS_EXCEPTION_IF_NULL(Zs);
245   // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
246   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
247   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
248   VectorRef mul4 = VectorRef({mul4_var_, Zs});
249 
250   VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
251   VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
252   VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
253   VectorRef add3 = VectorRef({prim::kPrimAdd, real_div2, mul4});
254   return add3;
255 }
256 
DefinePattern() const257 const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
258   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
259   MS_EXCEPTION_IF_NULL(prim_sqrt);
260   const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
261   MS_EXCEPTION_IF_NULL(prim_deal_div);
262   VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex2], input_vars_[kIndex1]});
263   VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex3], input_vars_[kIndex0]});
264   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
265   VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[kIndex2]});
266   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
267   VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
268   VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex0], input_vars_[kIndex4]});
269   VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[kIndex1], input_vars_[kIndex3]});
270   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
271   VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[kIndex5]});
272   VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
273   VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[kIndex4], input_vars_[kIndex6]});
274   VectorRef add5 = VectorRef({prim::kPrimAdd, real_div4, mul4});
275   return add5;
276 }
277 }  // namespace opt
278 }  // namespace mindspore
279