/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
17 #include "backend/optimizer/common/helper.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 #include "utils/trace_base.h"
20 namespace mindspore {
21 namespace opt {
DefinePattern() const22 const BaseRef AdamApplyOneFusion::DefinePattern() const {
23   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
24   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
25   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
26   VectorRef mul3 =
27     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
28   VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
29   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
30   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
31   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
32   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
33   return VectorRef(
34     {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
35 }
36 
DefinePattern() const37 const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
38   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
39   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
40   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
41   VectorRef mul3 =
42     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
43   VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
44   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
45   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
46   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
47   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
48   return VectorRef(
49     {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
50 }
51 
DefinePattern() const52 const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
53   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
54   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
55   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
56   VectorRef mul3 =
57     VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[kIndex0]}), mul_x_input_vars_[kIndex3]});
58   VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
59   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
60   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
61   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
62   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
63   return VectorRef(
64     {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
65 }
66 
DefinePattern() const67 const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
68   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
69   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
70   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
71   VectorRef mul3 =
72     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
73   VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
74   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
75   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
76   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
77   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
78   return VectorRef(
79     {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
80 }
81 
DefinePattern() const82 const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
83   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
84   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
85   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
86   VectorRef mul3 =
87     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
88   VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
89   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
90   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
91   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
92   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
93   return VectorRef(
94     {prim::kPrimSub, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
95 }
96 
DefinePattern() const97 const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
98   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
99   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
100   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
101   VectorRef mul3 =
102     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
103   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
104   VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
105   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
106   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
107   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
108   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
109   VectorRef sub0 =
110     VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
111   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
112   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
113   VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
114   VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
115   VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
116   return VectorRef({prim::kPrimDepend, depend1, assign2});
117 }
118 
DefinePattern() const119 const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
120   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
121   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
122   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
123   VectorRef mul3 =
124     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
125   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
126   VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
127   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
128   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
129   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
130   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
131   VectorRef sub0 =
132     VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, input_vars_[kIndex4], true_div0})});
133   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
134   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
135   VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
136   VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
137   VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
138   return VectorRef({prim::kPrimDepend, depend1, assign2});
139 }
140 
DefinePattern() const141 const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
142   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
143   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
144   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
145   VectorRef mul3 =
146     VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[kIndex0]}), mul_x_input_vars_[kIndex3]});
147   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
148   VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
149   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
150   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
151   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
152   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
153   VectorRef sub0 =
154     VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
155   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
156   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
157   VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
158   VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
159   VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
160   return VectorRef({prim::kPrimDepend, depend1, assign2});
161 }
162 
DefinePattern() const163 const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
164   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
165   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
166   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
167   VectorRef mul3 =
168     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
169   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
170   VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
171   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
172   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
173   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
174   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
175   VectorRef sub0 =
176     VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
177   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
178   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
179   VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
180   VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
181   VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
182   return VectorRef({prim::kPrimDepend, depend1, assign2});
183 }
184 
DefinePattern() const185 const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
186   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
187   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
188   VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex2], input_vars_[kIndex1]});
189   VectorRef mul3 =
190     VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex3], VectorRef({prim::kPrimSquare, input_vars_[kIndex0]})});
191   VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
192   VectorRef sqrt0 = VectorRef({prim_sqrt, add1});
193   VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex1], input_vars_[kIndex0]});
194   VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[kIndex0], input_vars_[kIndex2]});
195   VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
196   VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
197   VectorRef sub0 =
198     VectorRef({sub0_var_, input_vars_[kIndex3], VectorRef({prim::kPrimMul, true_div0, input_vars_[kIndex4]})});
199   VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[kIndex3], sub0});
200   VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
201   VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[kIndex2], add0});
202   VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
203   VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[kIndex1], add1});
204   return VectorRef({prim::kPrimDepend, depend1, assign2});
205 }
206 
// Creates the fused CNode in func_graph from the nodes bound in equiv.
//
// The primitive is built from kAdamApplyOneAssignOpName when final_node is a Depend node, and from
// kAdamApplyOneOpName otherwise. The CNode inputs are, in order: the primitive value node, the
// nodes bound to every entry of input_vars_, the nodes bound to every entry of mul_x_input_vars_,
// and the node bound to add2_y_.
//
// Raises (via MS_EXCEPTION_IF_NULL) when func_graph, equiv, or any looked-up node is null.
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                      const AnfNodePtr &final_node) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);

  const bool is_assign = AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend);
  PrimitivePtr prim = std::make_shared<Primitive>(is_assign ? kAdamApplyOneAssignOpName : kAdamApplyOneOpName);

  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)};
  for (const auto &var : input_vars_) {
    auto input_node = utils::cast<AnfNodePtr>((*equiv)[var]);
    MS_EXCEPTION_IF_NULL(input_node);
    new_node_inputs.emplace_back(input_node);
  }
  for (const auto &var : mul_x_input_vars_) {
    auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[var]);
    MS_EXCEPTION_IF_NULL(mul_x_input_node);
    new_node_inputs.emplace_back(mul_x_input_node);
  }
  auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
  MS_EXCEPTION_IF_NULL(add2_y_node);
  new_node_inputs.emplace_back(add2_y_node);

  return func_graph->NewCNode(new_node_inputs);
}
234 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv) const235 const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
236                                              const EquivPtr &equiv) const {
237   MS_EXCEPTION_IF_NULL(func_graph);
238   auto sub0 = node;
239   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
240     auto iter_sub0 = (*equiv).find(sub0_var_);
241     if (iter_sub0 == (*equiv).end()) {
242       MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched."
243                         << " trace: " << trace::DumpSourceLines(node);
244     }
245     sub0 = utils::cast<AnfNodePtr>(iter_sub0->second);
246   }
247   MS_EXCEPTION_IF_NULL(sub0);
248   if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) {
249     return nullptr;
250   }
251   auto new_node = CreateAdamApplyOneNode(func_graph, equiv, node);
252   MS_EXCEPTION_IF_NULL(new_node);
253   new_node->set_scope(sub0->scope());
254   // Set abstract of new node
255   AbstractBasePtrList new_node_abstract_list;
256   auto iter_add0 = (*equiv).find(add0_var_);
257   if (iter_add0 == (*equiv).end()) {
258     MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."
259                       << " trace: " << trace::DumpSourceLines(node);
260   }
261   auto iter_add1 = (*equiv).find(add1_var_);
262   if (iter_add1 == (*equiv).end()) {
263     MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."
264                       << " trace: " << trace::DumpSourceLines(node);
265   }
266   auto add0 = utils::cast<AnfNodePtr>(iter_add0->second);
267   MS_EXCEPTION_IF_NULL(add0);
268   auto add1 = utils::cast<AnfNodePtr>(iter_add1->second);
269   MS_EXCEPTION_IF_NULL(add1);
270   new_node_abstract_list.push_back(add1->abstract());
271   new_node_abstract_list.push_back(add0->abstract());
272   new_node_abstract_list.push_back(sub0->abstract());
273   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list);
274   new_node->set_abstract(abstract_tuple);
275   // Create tuple_getitem node for outputs
276   std::vector<AnfNodePtr> new_node_outputs;
277   CreateMultipleOutputsOfAnfNode(func_graph, new_node, kAdamApplyOneOutputNum, &new_node_outputs);
278   if (new_node_outputs.size() != kAdamApplyOneOutputNum) {
279     MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be "
280                       << kAdamApplyOneOutputNum << " trace: " << trace::DumpSourceLines(node);
281   }
282   auto manager = func_graph->manager();
283   MS_EXCEPTION_IF_NULL(manager);
284   (void)manager->Replace(add1, new_node_outputs[kIndex0]);
285   (void)manager->Replace(add0, new_node_outputs[kIndex1]);
286   return new_node_outputs[kIndex2];
287 }
288 }  // namespace opt
289 }  // namespace mindspore
290