• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h"
18 #include <memory>
19 #include <algorithm>
20 #include "utils/utils.h"
21 #include "backend/optimizer/common/helper.h"
22 #include "base/core_ops.h"
23 
24 namespace mindspore {
25 namespace opt {
IsRuleMatched(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv,std::vector<AnfNodePtr> * old_pattern_outputs) const26 bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
27                                    std::vector<AnfNodePtr> *old_pattern_outputs) const {
28   MS_EXCEPTION_IF_NULL(func_graph);
29   MS_EXCEPTION_IF_NULL(equiv);
30   auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
31   auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);
32   constexpr size_t kRealDiv0Size = 2;
33 
34   auto manager = func_graph->manager();
35   MS_EXCEPTION_IF_NULL(manager);
36   auto &users = manager->node_users();
37   if (users.find(real_div0) == users.end() || users[real_div0].size() < kRealDiv0Size) {
38     return false;
39   }
40   AnfNodeIndexSet real_div0_outputs = users[real_div0];
41   auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
42                            [&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
43                              return node_index.first != real_div2 && node_index.second == 1 &&
44                                     MatchAnotherPattern(node_index.first, equiv);
45                            });
46   if (iter == real_div0_outputs.end()) {
47     return false;
48   }
49 
50   (*old_pattern_outputs).push_back(node);
51   (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
52   (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
53   (*old_pattern_outputs).push_back(iter->first);
54 
55   return true;
56 }
57 
CreateLambNextMVNode(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & old_pattern_outputs,const EquivPtr & equiv) const58 AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
59                                                 const std::vector<AnfNodePtr> &old_pattern_outputs,
60                                                 const EquivPtr &equiv) const {
61   MS_EXCEPTION_IF_NULL(func_graph);
62   auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
63   std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
64   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
65   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
66   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
67   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
68   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
69   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
70   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
71   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
72   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
73   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
74   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
75   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
76   lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
77   auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
78   MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
79 
80   // Set abstract of new node
81   AbstractBasePtrList new_abstracts;
82   (void)std::transform(old_pattern_outputs.begin(), old_pattern_outputs.end(), std::back_inserter(new_abstracts),
83                        [](const AnfNodePtr &out) { return out->abstract(); });
84   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_abstracts);
85   MS_EXCEPTION_IF_NULL(abstract_tuple);
86   lamb_next_mv_rule->set_abstract(abstract_tuple);
87 
88   // Create tuple_getitem node for outputs
89   std::vector<AnfNodePtr> lamb_next_mv_rule_outputs;
90   CreateMultipleOutputsOfAnfNode(func_graph, lamb_next_mv_rule, kLambNextMVRuleOutputNum, &lamb_next_mv_rule_outputs);
91 
92   auto manager = func_graph->manager();
93   MS_EXCEPTION_IF_NULL(manager);
94   (void)manager->Replace(old_pattern_outputs[kIndex1], lamb_next_mv_rule_outputs[kIndex1]);
95   (void)manager->Replace(old_pattern_outputs[kIndex2], lamb_next_mv_rule_outputs[kIndex2]);
96   (void)manager->Replace(old_pattern_outputs[kIndex3], lamb_next_mv_rule_outputs[kIndex3]);
97 
98   return lamb_next_mv_rule_outputs[0];
99 }
100 
IsShareNodes(const EquivPtr & equiv1,const EquivPtr & equiv2) const101 bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
102   return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
103          IsSameNode(equiv1, equiv2, add2_y_);
104 }
105 
// Entry point of the pass for a node matched by DefinePattern().
// Returns nullptr when `node` fails the kFloatDataTypeSet data-type check or when IsRuleMatched
// rejects the surrounding graph; otherwise returns the result of CreateLambNextMVNode.
const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &equiv) const {
  std::vector<AnfNodePtr> matched_outputs;
  // Short-circuit keeps the original order: the data-type check runs before any graph matching.
  const bool can_fuse =
    CheckSupportDataType(node, kFloatDataTypeSet) && IsRuleMatched(func_graph, node, equiv, &matched_outputs);
  if (!can_fuse) {
    return nullptr;
  }
  return CreateLambNextMVNode(func_graph, matched_outputs, equiv);
}
118 
DefinePattern() const119 const BaseRef LambNextMVRuleCond1::DefinePattern() const {
120   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
121 
122   auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
123   auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
124   auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
125   auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
126   auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
127   auto add0 = VectorRef({add0_var_, mul0, mul1});
128   auto add1 = VectorRef({add1_var_, mul2, mul3});
129 
130   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
131   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
132 
133   auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
134   auto sqrt0 = VectorRef({prim_rsqrt, add2});
135   auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
136 
137   return VectorRef({prim::kPrimAdd, mul4, real_div2});
138 }
139 
DefineAnotherPattern() const140 BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
141   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
142   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
143   VarPtr Xs = std::make_shared<SeqVar>();
144   VarPtr Ys = std::make_shared<SeqVar>();
145   // Two patterns share: real_div0, real_div1, add2_y_
146   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
147   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
148 
149   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
150   VectorRef add4 = VectorRef({prim::kPrimAdd, add2_y_, sqrt1});
151   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
152   return real_div4;
153 }
154 
DefinePattern() const155 const BaseRef LambNextMVRuleCond2::DefinePattern() const {
156   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
157 
158   auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
159   auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
160   auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
161   auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
162   auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
163   auto add0 = VectorRef({add0_var_, mul0, mul1});
164   auto add1 = VectorRef({add1_var_, mul2, mul3});
165 
166   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
167   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
168 
169   auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
170   auto sqrt0 = VectorRef({prim_rsqrt, add2});
171   auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
172 
173   return VectorRef({prim::kPrimAdd, mul4, real_div2});
174 }
175 
DefineAnotherPattern() const176 BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
177   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
178   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
179   VarPtr Xs = std::make_shared<SeqVar>();
180   VarPtr Ys = std::make_shared<SeqVar>();
181   // Two patterns share: real_div0, real_div1, add2_y_
182   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
183   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
184 
185   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
186   VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
187   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
188   return real_div4;
189 }
190 
DefinePattern() const191 const BaseRef LambNextMVRuleCond3::DefinePattern() const {
192   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
193 
194   auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
195   auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
196   auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
197   auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_});
198   auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
199   auto add0 = VectorRef({add0_var_, mul0, mul1});
200   auto add1 = VectorRef({add1_var_, mul2, mul3});
201 
202   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
203   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
204 
205   auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
206   auto sqrt0 = VectorRef({prim_rsqrt, add2});
207   auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
208 
209   return VectorRef({prim::kPrimAdd, mul4, real_div2});
210 }
211 
DefineAnotherPattern() const212 BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
213   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
214   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
215   VarPtr Xs = std::make_shared<SeqVar>();
216   VarPtr Ys = std::make_shared<SeqVar>();
217   // Two patterns share: real_div0, real_div1, add2_y_
218   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
219   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
220 
221   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
222   VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
223   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
224   return real_div4;
225 }
226 
DefinePattern() const227 const BaseRef LambNextMVRuleCond4::DefinePattern() const {
228   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
229 
230   auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
231   auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
232   auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
233   auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
234   auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
235   auto add0 = VectorRef({add0_var_, mul0, mul1});
236   auto add1 = VectorRef({add1_var_, mul2, mul3});
237 
238   auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
239   auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
240 
241   auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
242   auto sqrt0 = VectorRef({prim_rsqrt, add2});
243   auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
244 
245   return VectorRef({prim::kPrimAdd, real_div2, mul4});
246 }
247 
DefineAnotherPattern() const248 BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
249   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
250   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
251   VarPtr Xs = std::make_shared<SeqVar>();
252   VarPtr Ys = std::make_shared<SeqVar>();
253   // Two patterns share: real_div0, real_div1, add2_y_
254   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
255   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
256 
257   VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
258   VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
259   VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
260   return real_div4;
261 }
262 }  // namespace opt
263 }  // namespace mindspore
264