1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h"
18 #include <memory>
19 #include <algorithm>
20 #include "utils/utils.h"
21 #include "backend/optimizer/common/helper.h"
22 #include "base/core_ops.h"
23
24 namespace mindspore {
25 namespace opt {
IsRuleMatched(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv,std::vector<AnfNodePtr> * old_pattern_outputs) const26 bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
27 std::vector<AnfNodePtr> *old_pattern_outputs) const {
28 MS_EXCEPTION_IF_NULL(func_graph);
29 MS_EXCEPTION_IF_NULL(equiv);
30 auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
31 auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);
32 constexpr size_t kRealDiv0Size = 2;
33
34 auto manager = func_graph->manager();
35 MS_EXCEPTION_IF_NULL(manager);
36 auto &users = manager->node_users();
37 if (users.find(real_div0) == users.end() || users[real_div0].size() < kRealDiv0Size) {
38 return false;
39 }
40 AnfNodeIndexSet real_div0_outputs = users[real_div0];
41 auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
42 [&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
43 return node_index.first != real_div2 && node_index.second == 1 &&
44 MatchAnotherPattern(node_index.first, equiv);
45 });
46 if (iter == real_div0_outputs.end()) {
47 return false;
48 }
49
50 (*old_pattern_outputs).push_back(node);
51 (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
52 (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
53 (*old_pattern_outputs).push_back(iter->first);
54
55 return true;
56 }
57
CreateLambNextMVNode(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & old_pattern_outputs,const EquivPtr & equiv) const58 AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
59 const std::vector<AnfNodePtr> &old_pattern_outputs,
60 const EquivPtr &equiv) const {
61 MS_EXCEPTION_IF_NULL(func_graph);
62 auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
63 std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
64 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
65 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
66 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
67 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
68 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
69 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
70 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
71 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
72 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
73 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
74 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
75 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
76 lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
77 auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
78 MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
79
80 // Set abstract of new node
81 AbstractBasePtrList new_abstracts;
82 (void)std::transform(old_pattern_outputs.begin(), old_pattern_outputs.end(), std::back_inserter(new_abstracts),
83 [](const AnfNodePtr &out) { return out->abstract(); });
84 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_abstracts);
85 MS_EXCEPTION_IF_NULL(abstract_tuple);
86 lamb_next_mv_rule->set_abstract(abstract_tuple);
87
88 // Create tuple_getitem node for outputs
89 std::vector<AnfNodePtr> lamb_next_mv_rule_outputs;
90 CreateMultipleOutputsOfAnfNode(func_graph, lamb_next_mv_rule, kLambNextMVRuleOutputNum, &lamb_next_mv_rule_outputs);
91
92 auto manager = func_graph->manager();
93 MS_EXCEPTION_IF_NULL(manager);
94 (void)manager->Replace(old_pattern_outputs[kIndex1], lamb_next_mv_rule_outputs[kIndex1]);
95 (void)manager->Replace(old_pattern_outputs[kIndex2], lamb_next_mv_rule_outputs[kIndex2]);
96 (void)manager->Replace(old_pattern_outputs[kIndex3], lamb_next_mv_rule_outputs[kIndex3]);
97
98 return lamb_next_mv_rule_outputs[0];
99 }
100
IsShareNodes(const EquivPtr & equiv1,const EquivPtr & equiv2) const101 bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
102 return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
103 IsSameNode(equiv1, equiv2, add2_y_);
104 }
105
// Entry point of the pass for one matched node.
// Returns the replacement produced by CreateLambNextMVNode() when `node` passes the
// CheckSupportDataType(node, kFloatDataTypeSet) check and IsRuleMatched() succeeds; returns nullptr
// in every other case.
const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &equiv) const {
  if (CheckSupportDataType(node, kFloatDataTypeSet)) {
    // Filled by IsRuleMatched() with the nodes the fused operator has to replace.
    std::vector<AnfNodePtr> matched_outputs;
    if (IsRuleMatched(func_graph, node, equiv, &matched_outputs)) {
      return CreateLambNextMVNode(func_graph, matched_outputs, equiv);
    }
  }
  return nullptr;
}
118
DefinePattern() const119 const BaseRef LambNextMVRuleCond1::DefinePattern() const {
120 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
121
122 auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
123 auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
124 auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
125 auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
126 auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
127 auto add0 = VectorRef({add0_var_, mul0, mul1});
128 auto add1 = VectorRef({add1_var_, mul2, mul3});
129
130 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
131 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
132
133 auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
134 auto sqrt0 = VectorRef({prim_rsqrt, add2});
135 auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
136
137 return VectorRef({prim::kPrimAdd, mul4, real_div2});
138 }
139
DefineAnotherPattern() const140 BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
141 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
142 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
143 VarPtr Xs = std::make_shared<SeqVar>();
144 VarPtr Ys = std::make_shared<SeqVar>();
145 // Two patterns share: real_div0, real_div1, add2_y_
146 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
147 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
148
149 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
150 VectorRef add4 = VectorRef({prim::kPrimAdd, add2_y_, sqrt1});
151 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
152 return real_div4;
153 }
154
DefinePattern() const155 const BaseRef LambNextMVRuleCond2::DefinePattern() const {
156 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
157
158 auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
159 auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
160 auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
161 auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
162 auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
163 auto add0 = VectorRef({add0_var_, mul0, mul1});
164 auto add1 = VectorRef({add1_var_, mul2, mul3});
165
166 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
167 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
168
169 auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
170 auto sqrt0 = VectorRef({prim_rsqrt, add2});
171 auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
172
173 return VectorRef({prim::kPrimAdd, mul4, real_div2});
174 }
175
DefineAnotherPattern() const176 BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
177 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
178 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
179 VarPtr Xs = std::make_shared<SeqVar>();
180 VarPtr Ys = std::make_shared<SeqVar>();
181 // Two patterns share: real_div0, real_div1, add2_y_
182 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
183 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
184
185 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
186 VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
187 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
188 return real_div4;
189 }
190
DefinePattern() const191 const BaseRef LambNextMVRuleCond3::DefinePattern() const {
192 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
193
194 auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
195 auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
196 auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
197 auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_});
198 auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
199 auto add0 = VectorRef({add0_var_, mul0, mul1});
200 auto add1 = VectorRef({add1_var_, mul2, mul3});
201
202 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
203 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
204
205 auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
206 auto sqrt0 = VectorRef({prim_rsqrt, add2});
207 auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
208
209 return VectorRef({prim::kPrimAdd, mul4, real_div2});
210 }
211
DefineAnotherPattern() const212 BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
213 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
214 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
215 VarPtr Xs = std::make_shared<SeqVar>();
216 VarPtr Ys = std::make_shared<SeqVar>();
217 // Two patterns share: real_div0, real_div1, add2_y_
218 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
219 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
220
221 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
222 VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
223 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
224 return real_div4;
225 }
226
DefinePattern() const227 const BaseRef LambNextMVRuleCond4::DefinePattern() const {
228 const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
229
230 auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
231 auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
232 auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
233 auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
234 auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
235 auto add0 = VectorRef({add0_var_, mul0, mul1});
236 auto add1 = VectorRef({add1_var_, mul2, mul3});
237
238 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
239 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
240
241 auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
242 auto sqrt0 = VectorRef({prim_rsqrt, add2});
243 auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
244
245 return VectorRef({prim::kPrimAdd, real_div2, mul4});
246 }
247
DefineAnotherPattern() const248 BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
249 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
250 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
251 VarPtr Xs = std::make_shared<SeqVar>();
252 VarPtr Ys = std::make_shared<SeqVar>();
253 // Two patterns share: real_div0, real_div1, add2_y_
254 VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
255 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
256
257 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
258 VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
259 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
260 return real_div4;
261 }
262 } // namespace opt
263 } // namespace mindspore
264