1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
17 #include "backend/optimizer/common/helper.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 #include "utils/trace_base.h"
20 namespace mindspore {
21 namespace opt {
DefinePattern() const22 const BaseRef AdamApplyOneFusion::DefinePattern() const {
23 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
24 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
25 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
26 VectorRef mul3 =
27 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
28 VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
29 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
30 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
31 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
32 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
33 return VectorRef(
34 {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
35 }
36
DefinePattern() const37 const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
38 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
39 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
40 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
41 VectorRef mul3 =
42 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
43 VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
44 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
45 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
46 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
47 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
48 return VectorRef(
49 {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
50 }
51
DefinePattern() const52 const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
53 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
54 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
55 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
56 VectorRef mul3 =
57 VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[kIndex0]}), mul_x_input_vars_[kIndex3]});
58 VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
59 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
60 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
61 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
62 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
63 return VectorRef(
64 {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
65 }
66
DefinePattern() const67 const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
68 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
69 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
70 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
71 VectorRef mul3 =
72 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
73 VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
74 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
75 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
76 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
77 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
78 return VectorRef(
79 {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
80 }
81
DefinePattern() const82 const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
83 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
84 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
85 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
86 VectorRef mul3 =
87 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
88 VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
89 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
90 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
91 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
92 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
93 return VectorRef(
94 {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
95 }
96
DefinePattern() const97 const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
98 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
99 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
100 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
101 VectorRef mul3 =
102 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
103 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
104 VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
105 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
106 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
107 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
108 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
109 VectorRef sub0 =
110 VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
111 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
112 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
113 VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
114 VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
115 VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
116 return VectorRef({prim::kPrimDepend, depend1, assign2});
117 }
118
DefinePattern() const119 const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
120 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
121 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
122 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
123 VectorRef mul3 =
124 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
125 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
126 VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
127 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
128 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
129 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
130 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
131 VectorRef sub0 =
132 VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
133 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
134 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
135 VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
136 VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
137 VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
138 return VectorRef({prim::kPrimDepend, depend1, assign2});
139 }
140
DefinePattern() const141 const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
142 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
143 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
144 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
145 VectorRef mul3 =
146 VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[kIndex0]}), mul_x_input_vars_[kIndex3]});
147 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
148 VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
149 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
150 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
151 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
152 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
153 VectorRef sub0 =
154 VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
155 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
156 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
157 VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
158 VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
159 VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
160 return VectorRef({prim::kPrimDepend, depend1, assign2});
161 }
162
DefinePattern() const163 const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
164 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
165 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
166 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
167 VectorRef mul3 =
168 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
169 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
170 VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
171 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
172 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
173 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
174 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
175 VectorRef sub0 =
176 VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
177 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
178 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
179 VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
180 VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
181 VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
182 return VectorRef({prim::kPrimDepend, depend1, assign2});
183 }
184
DefinePattern() const185 const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
186 const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
187 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
188 VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
189 VectorRef mul3 =
190 VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
191 VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
192 VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
193 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
194 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
195 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
196 VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
197 VectorRef sub0 =
198 VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
199 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
200 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
201 VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
202 VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
203 VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
204 return VectorRef({prim::kPrimDepend, depend1, assign2});
205 }
206
// Builds the fused CNode for a matched pattern.
// The primitive is kAdamApplyOneAssignOpName when final_node is a Depend (the root produced by the *Assign*
// patterns) and kAdamApplyOneOpName otherwise.
// CNode inputs, in order: the primitive, every input_vars_ binding, every mul_x_input_vars_ binding, then the
// add2_y_ binding.
// Params: func_graph owns the new node; equiv holds the matcher's var -> node bindings; final_node is the matched root.
// Returns the new (not yet wired-in) CNode. Raises via MS_EXCEPTION_IF_NULL if an argument or any binding is null
// or missing.
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                      const AnfNodePtr &final_node) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(final_node);
  PrimitivePtr prim = nullptr;
  if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) {
    prim = std::make_shared<Primitive>(kAdamApplyOneAssignOpName);
  } else {
    prim = std::make_shared<Primitive>(kAdamApplyOneOpName);
  }

  // Look bindings up with find() rather than operator[], which would insert an empty entry into the matcher's
  // equiv map for an unbound var. A missing var yields nullptr and is rejected by the null checks below.
  auto lookup_bound_node = [&equiv](const auto &var) -> AnfNodePtr {
    auto iter = equiv->find(var);
    if (iter == equiv->end()) {
      return nullptr;
    }
    return utils::cast<AnfNodePtr>(iter->second);
  };

  std::vector<AnfNodePtr> new_node_inputs;
  // The primitive, all input vars, all mul_x vars, and add2_y.
  new_node_inputs.reserve(input_vars_.size() + mul_x_input_vars_.size() + 2);
  new_node_inputs.push_back(NewValueNode(prim));
  for (const auto &input_var : input_vars_) {
    auto input_node = lookup_bound_node(input_var);
    MS_EXCEPTION_IF_NULL(input_node);
    new_node_inputs.push_back(input_node);
  }
  for (const auto &mul_x_input_var : mul_x_input_vars_) {
    auto mul_x_input_node = lookup_bound_node(mul_x_input_var);
    MS_EXCEPTION_IF_NULL(mul_x_input_node);
    new_node_inputs.push_back(mul_x_input_node);
  }
  auto add2_y_node = lookup_bound_node(add2_y_);
  MS_EXCEPTION_IF_NULL(add2_y_node);
  new_node_inputs.push_back(add2_y_node);
  return func_graph->NewCNode(new_node_inputs);
}
234
// Replaces a matched AdamApplyOne / AdamApplyOneAssign subgraph with a single fused node.
// For the *Assign* patterns the matched root `node` is a Depend, so the Sub result is taken from the sub0_var_
// binding; otherwise `node` itself is the Sub.
// Returns nullptr (no rewrite) when sub0's data type is not in kFloatDataTypeSet. Otherwise it:
//   - builds the fused node with a 3-element tuple abstract (add1, add0, sub0);
//   - replaces add1 and add0 with tuple outputs 0 and 1 through the graph manager;
//   - returns output 2.
// Raises via MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL on missing bindings, null arguments,
// or an unexpected output count.
const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  // equiv is dereferenced below before CreateAdamApplyOneNode gets a chance to check it.
  MS_EXCEPTION_IF_NULL(equiv);

  // Fetch the node bound to a pattern var; the matcher is expected to have recorded every var used here.
  auto get_matched_node = [&equiv, &node](const auto &var, const char *var_name) -> AnfNodePtr {
    auto iter = equiv->find(var);
    if (iter == equiv->end()) {
      MS_LOG(EXCEPTION) << "The equiv map is expected to contains the " << var_name << " var after matched."
                        << " trace: " << trace::DumpSourceLines(node);
    }
    return utils::cast<AnfNodePtr>(iter->second);
  };

  AnfNodePtr sub0 = node;
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
    sub0 = get_matched_node(sub0_var_, "sub0");
  }
  MS_EXCEPTION_IF_NULL(sub0);
  if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) {
    return nullptr;
  }
  auto new_node = CreateAdamApplyOneNode(func_graph, equiv, node);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_scope(sub0->scope());

  // Set abstract of new node: the outputs correspond to add1, add0, and sub0, in that order.
  auto add0 = get_matched_node(add0_var_, "add0");
  auto add1 = get_matched_node(add1_var_, "add1");
  MS_EXCEPTION_IF_NULL(add0);
  MS_EXCEPTION_IF_NULL(add1);
  AbstractBasePtrList new_node_abstract_list{add1->abstract(), add0->abstract(), sub0->abstract()};
  new_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_node_abstract_list));

  // Create tuple_getitem node for outputs
  std::vector<AnfNodePtr> new_node_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, new_node, kAdamApplyOneOutputNum, &new_node_outputs);
  if (new_node_outputs.size() != kAdamApplyOneOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be "
                      << kAdamApplyOneOutputNum << " trace: " << trace::DumpSourceLines(node);
  }
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(add1, new_node_outputs[kIndex0]);
  (void)manager->Replace(add0, new_node_outputs[kIndex1]);
  return new_node_outputs[kIndex2];
}
288 } // namespace opt
289 } // namespace mindspore
290