• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/ffn_fusion.h"
19 #include <vector>
20 #include <unordered_map>
21 #include "tools/optimizer/common/gllo_utils.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "mindspore/core/ops/custom.h"
24 #include "ops/f_f_n.h"
25 #include "nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace opt {
namespace {
// Number of FFN pattern variants handled by this pass (const-folding and dynamic-dims).
constexpr int kStructureNum = 2;
// Expected constant operands of the matched erf-based GELU decomposition:
//   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// CheckPattern() compares the values captured by the pattern against these.
constexpr int DIV1_Y = 2;              // hidden dim is halved to find the GEGLU split point
constexpr int MUL2_Y = 2;
constexpr float DIFF_THRESHOLD = 0.0001;  // tolerance for float constant comparison
constexpr float DIV2_Y = 1.41421;      // sqrt(2) in erf(x / sqrt(2))
constexpr float ADD3_Y = 1.0;          // the "+ 1" in (1 + erf(...))
constexpr float MUL4_y = 0.5;          // the leading 0.5; NOTE(review): lower-case 'y' is inconsistent with MUL4_Y style
constexpr auto kFFNFusion = "FFN_Fusion";
constexpr auto kFFNPatternForConstFolding = "FFNPatternForConstFolding";
constexpr auto kFFNPatternForDynamicDims = "FFNPatternForDynamicDims";
}  // namespace
41 
CreateFFNFusionNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv,int index) const42 CNodePtr FFNFusion::CreateFFNFusionNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
43                                         int index) const {
44   auto ffn_fusion_prim = std::make_shared<ops::FFN>();
45   MS_CHECK_TRUE_RET(ffn_fusion_prim != nullptr, nullptr);
46 
47   ffn_fusion_prim->AddAttr("activation", api::MakeValue("geglu"));
48   // "inner_precise" must be 1 when "activation" is "geglu"
49   ffn_fusion_prim->AddAttr("inner_precise", api::MakeValue(1));
50 
51   auto ffn_fusion_prim_c = ffn_fusion_prim->GetPrim();
52   MS_CHECK_TRUE_RET(ffn_fusion_prim_c != nullptr, nullptr);
53   auto input = (*equiv)[input_[index]];
54   MS_CHECK_TRUE_RET(input != nullptr, nullptr);
55   auto input_node = utils::cast<AnfNodePtr>(input);
56   MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
57   auto param1 = utils::cast<AnfNodePtr>((*equiv)[matmul1_b_[index]]);
58   MS_CHECK_TRUE_RET(param1 != nullptr, nullptr);
59   auto param2 = utils::cast<AnfNodePtr>((*equiv)[add1_x_[index]]);
60   MS_CHECK_TRUE_RET(param2 != nullptr, nullptr);
61   auto param3 = utils::cast<AnfNodePtr>((*equiv)[matmul2_b_[index]]);
62   MS_CHECK_TRUE_RET(param3 != nullptr, nullptr);
63 
64   auto none_value_node = NewValueNode(std::make_shared<None>());
65   none_value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
66 
67   auto ffn_fusion_cnode =
68     func_graph->NewCNode(ffn_fusion_prim_c, {input_node, param1, param3, none_value_node, param2});
69   MS_CHECK_TRUE_RET(ffn_fusion_cnode != nullptr, nullptr);
70   ffn_fusion_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_ffn_fusion");
71   if (node->abstract() != nullptr) {
72     ffn_fusion_cnode->set_abstract(node->abstract()->Clone());
73   }
74   return ffn_fusion_cnode;
75 }
76 
Init() const77 bool FFNFusion::Init() const {
78   for (int i = 0; i < kMaxPatternNum; i++) {
79     input_[i] = std::make_shared<Var>();
80     MS_CHECK_TRUE_RET(input_[i] != nullptr, false);
81     div2_y_[i] = std::make_shared<Var>();
82     MS_CHECK_TRUE_RET(div2_y_[i] != nullptr, false);
83     add3_y_[i] = std::make_shared<Var>();
84     MS_CHECK_TRUE_RET(add3_y_[i] != nullptr, false);
85     mul4_y_[i] = std::make_shared<Var>();
86     MS_CHECK_TRUE_RET(mul4_y_[i] != nullptr, false);
87     matmul1_b_[i] = std::make_shared<Var>();
88     MS_CHECK_TRUE_RET(matmul1_b_[i] != nullptr, false);
89     add1_x_[i] = std::make_shared<Var>();
90     MS_CHECK_TRUE_RET(add1_x_[i] != nullptr, false);
91     matmul2_b_[i] = std::make_shared<Var>();
92     MS_CHECK_TRUE_RET(matmul2_b_[i] != nullptr, false);
93   }
94   gather_y_ = std::make_shared<Var>();
95   MS_CHECK_TRUE_RET(gather_y_ != nullptr, false);
96   add2_y_ = std::make_shared<Var>();
97   MS_CHECK_TRUE_RET(add2_y_ != nullptr, false);
98   div1_y_ = std::make_shared<Var>();
99   MS_CHECK_TRUE_RET(div1_y_ != nullptr, false);
100   mul1_y_ = std::make_shared<Var>();
101   MS_CHECK_TRUE_RET(mul1_y_ != nullptr, false);
102   mul2_y_ = std::make_shared<Var>();
103   MS_CHECK_TRUE_RET(mul2_y_ != nullptr, false);
104   return true;
105 }
106 
// Pattern for the dynamic-dims GEGLU feed-forward block. Shape of the matched
// sub-graph (rooted at the returned second MatMul):
//   MatMul1 -> Add1
//     Add1 -> Shape -> Gather -> Add2 -> Div1   (computes the channel split point at runtime)
//     Add1 -> StridedSlice1 (first half) and StridedSlice2 (second half)
//     StridedSlice2 -> Div2(/sqrt(2)) -> Erf -> Add3(+1) -> Mul3 -> Mul4(*0.5)  (erf-based GELU)
//     Mul5 = StridedSlice1 * GELU(StridedSlice2) -> MatMul2
const VectorRef FFNFusion::DefineFFNPatternForDynamicDims() const {
  MS_LOG(INFO) << "start define FFN fusion pattern for dynamic dims.";
  // Free placeholders for parameter inputs (Gather index, slice begin/end/strides).
  const size_t param_num = 6;
  std::vector<CondVarPtr> params(param_num);
  for (size_t i = 0; i < params.size(); ++i) {
    params[i] = std::make_shared<CondVar>(IsParamNode);
    MS_CHECK_TRUE_RET(params[i] != nullptr, {});
  }
  size_t index = 0;
  // First projection: x @ w1 + b1.
  auto is_matmul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  VectorRef matmul1_ref({is_matmul1, input_[kDynamicDims], matmul1_b_[kDynamicDims]});
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref({is_add1, add1_x_[kDynamicDims], matmul1_ref});
  // Runtime split-point computation: (shape[gather_y_] + add2_y_) / div1_y_.
  auto is_shape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimShape>);
  MS_CHECK_TRUE_RET(is_shape != nullptr, {});
  VectorRef shape_ref({is_shape, add1_ref});
  auto is_gather = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimGather>);
  MS_CHECK_TRUE_RET(is_gather != nullptr, {});
  VectorRef gather_ref({is_gather, shape_ref, gather_y_, params[index++]});
  auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  VectorRef add2_ref({is_add2, gather_ref, add2_y_});
  auto is_div1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
  MS_CHECK_TRUE_RET(is_div1 != nullptr, {});
  VectorRef div1_ref({is_div1, add2_ref, div1_y_});
  auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
  VectorRef mul1_ref({is_mul1, div1_ref, mul1_y_});
  // First half: slice [0, split) of Add1's output.
  auto is_stridedslice1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
  MS_CHECK_TRUE_RET(is_stridedslice1 != nullptr, {});
  VectorRef stridedslice1_ref(
    {is_stridedslice1, add1_ref, params[index++], mul1_ref, params[index++], params[index++]});
  auto is_mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul2 != nullptr, {});
  VectorRef mul2_ref({is_mul2, div1_ref, mul2_y_});
  // Second half: slice [split, 2*split) of Add1's output.
  auto is_stridedslice2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
  MS_CHECK_TRUE_RET(is_stridedslice2 != nullptr, {});
  VectorRef stridedslice2_ref({is_stridedslice2, add1_ref, mul1_ref, mul2_ref, params[index++], params[index++]});
  // Erf-based GELU applied to the second half: 0.5 * x * (1 + erf(x / sqrt(2))).
  auto is_div2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
  MS_CHECK_TRUE_RET(is_div2 != nullptr, {});
  VectorRef div2_ref({is_div2, stridedslice2_ref, div2_y_[kDynamicDims]});
  auto is_erf = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>);
  MS_CHECK_TRUE_RET(is_erf != nullptr, {});
  VectorRef erf_ref({is_erf, div2_ref});
  auto is_add3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add3 != nullptr, {});
  VectorRef add3_ref({is_add3, erf_ref, add3_y_[kDynamicDims]});
  auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
  VectorRef mul3_ref({is_mul3, stridedslice2_ref, add3_ref});
  auto is_mul4 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul4 != nullptr, {});
  VectorRef mul4_ref({is_mul4, mul3_ref, mul4_y_[kDynamicDims]});
  // Gate: first half * GELU(second half), then the second projection.
  auto is_mul5 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul5 != nullptr, {});
  VectorRef mul5_ref({is_mul5, stridedslice1_ref, mul4_ref});
  auto is_matmul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  VectorRef matmul2_ref({is_matmul2, mul5_ref, matmul2_b_[kDynamicDims]});
  return matmul2_ref;
}
170 
// Pattern for the const-folded GEGLU feed-forward block. Identical to the
// dynamic-dims pattern except the Shape/Gather/Add/Div split-point computation
// has been constant-folded away: both StridedSlice nodes take plain parameter
// inputs for begin/end/strides.
const VectorRef FFNFusion::DefineFFNPatternForConstFolding() const {
  MS_LOG(INFO) << "start define FFN fusion pattern for const folding.";
  // Free placeholders for the 2 x 4 constant StridedSlice inputs.
  const size_t param_num = 8;
  std::vector<CondVarPtr> params(param_num);
  for (size_t i = 0; i < params.size(); ++i) {
    params[i] = std::make_shared<CondVar>(IsParamNode);
    MS_CHECK_TRUE_RET(params[i] != nullptr, {});
  }
  size_t index = 0;
  // First projection: x @ w1 + b1.
  auto is_matmul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  VectorRef matmul1_ref({is_matmul1, input_[kConstFold], matmul1_b_[kConstFold]});
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref({is_add1, add1_x_[kConstFold], matmul1_ref});

  // Two constant-index slices splitting Add1's output into halves.
  auto is_stridedslice1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
  MS_CHECK_TRUE_RET(is_stridedslice1 != nullptr, {});
  VectorRef stridedslice1_ref(
    {is_stridedslice1, add1_ref, params[index++], params[index++], params[index++], params[index++]});

  auto is_stridedslice2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
  MS_CHECK_TRUE_RET(is_stridedslice2 != nullptr, {});
  VectorRef stridedslice2_ref(
    {is_stridedslice2, add1_ref, params[index++], params[index++], params[index++], params[index++]});

  // Erf-based GELU on the second half: 0.5 * x * (1 + erf(x / sqrt(2))).
  auto is_div2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
  MS_CHECK_TRUE_RET(is_div2 != nullptr, {});
  VectorRef div2_ref({is_div2, stridedslice2_ref, div2_y_[kConstFold]});
  auto is_erf = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>);
  MS_CHECK_TRUE_RET(is_erf != nullptr, {});
  VectorRef erf_ref({is_erf, div2_ref});
  auto is_add3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add3 != nullptr, {});
  VectorRef add3_ref({is_add3, erf_ref, add3_y_[kConstFold]});
  auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
  VectorRef mul3_ref({is_mul3, stridedslice2_ref, add3_ref});
  auto is_mul4 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul4 != nullptr, {});
  VectorRef mul4_ref({is_mul4, mul3_ref, mul4_y_[kConstFold]});
  // Gate: first half * GELU(second half), then the second projection.
  auto is_mul5 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul5 != nullptr, {});
  VectorRef mul5_ref({is_mul5, stridedslice1_ref, mul4_ref});
  auto is_matmul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  VectorRef matmul2_ref({is_matmul2, mul5_ref, matmul2_b_[kConstFold]});
  return matmul2_ref;
}
220 
DefinePatterns() const221 std::unordered_map<std::string, VectorRef> FFNFusion::DefinePatterns() const {
222   MS_LOG(INFO) << "start define FFN fusion patterns.";
223   if (!Init()) {
224     MS_LOG(ERROR) << "DefinePatterns Init Failed.";
225     return {};
226   }
227   std::unordered_map<std::string, VectorRef> patterns;
228   patterns[kFFNPatternForConstFolding] = DefineFFNPatternForConstFolding();
229   patterns[kFFNPatternForDynamicDims] = DefineFFNPatternForDynamicDims();
230   return patterns;
231 }
232 
// Verify that the constants captured by the matched pattern are exactly those
// of the erf-based GELU decomposition: 0.5 * x * (1 + erf(x / sqrt(2))).
// NOTE(review): the negative / INT_MIN comparisons suggest
// GetFloatParameterValue and GetIntParameterValue return a negative value /
// INT_MIN when the captured node is not a usable constant -- confirm in gllo_utils.
bool FFNFusion::CheckPattern(const std::string &pattern_name, const EquivPtr &equiv) const {
  int index = pattern_name == kFFNPatternForDynamicDims ? kDynamicDims : kConstFold;

  // Divisor of the erf input must be ~sqrt(2).
  float div2_y = GetFloatParameterValue(equiv, div2_y_[index]);
  if (div2_y < 0 || fabs(div2_y - DIV2_Y) > DIFF_THRESHOLD) {
    return false;
  }
  // The "+ 1" term.
  float add3_y = GetFloatParameterValue(equiv, add3_y_[index]);
  if (add3_y < 0 || fabs(add3_y - ADD3_Y) > DIFF_THRESHOLD) {
    return false;
  }
  // The leading 0.5 factor.
  float mul4_y = GetFloatParameterValue(equiv, mul4_y_[index]);
  if (mul4_y < 0 || fabs(mul4_y - MUL4_y) > DIFF_THRESHOLD) {
    return false;
  }
  // If the pattern is for const folding, the shape-computation nodes checked
  // below do not exist in the match, so there is nothing more to verify.
  if (pattern_name == kFFNPatternForConstFolding) {
    return true;
  }
  // Dynamic-dims only: constants of the runtime split-point computation.
  int gather_index = GetIntParameterValue(equiv, gather_y_);
  if (gather_index == INT_MIN) {
    return false;
  }
  int add2_y = GetIntParameterValue(equiv, add2_y_);
  if (add2_y != 1) {
    return false;
  }
  int div1_y = GetIntParameterValue(equiv, div1_y_);
  if (div1_y != DIV1_Y) {
    return false;
  }
  int mul1_y = GetIntParameterValue(equiv, mul1_y_);
  if (mul1_y != 1) {
    return false;
  }
  int mul2_y = GetIntParameterValue(equiv, mul2_y_);
  if (mul2_y != MUL2_Y) {
    return false;
  }
  return true;
}
274 
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const275 AnfNodePtr FFNFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
276                               const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
277   MS_LOG(INFO) << "do fusion, pattern name: " << pattern_name;
278   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
279     MS_LOG(ERROR) << "function graph, node or equiv is nullptr.";
280     return nullptr;
281   }
282   if (!utils::isa<CNodePtr>(node)) {
283     MS_LOG(ERROR) << "this node is not cnode, node name: " << node->fullname_with_scope();
284     return nullptr;
285   }
286   if (IsMarkedTrainOp(utils::cast<CNodePtr>(node))) {
287     MS_LOG(ERROR) << "node is train op, can not fusion.";
288     return nullptr;
289   }
290   if (!CheckPattern(pattern_name, equiv)) {
291     MS_LOG(ERROR) << "CheckPattern failed.";
292     return nullptr;
293   }
294   int index = pattern_name == kFFNPatternForDynamicDims ? kDynamicDims : kConstFold;
295   auto cnode = CreateFFNFusionNode(func_graph, node, equiv, index);
296   if (cnode == nullptr) {
297     MS_LOG(INFO) << "new FFN node failed.";
298     return nullptr;
299   }
300   MS_LOG(INFO) << "FFN fusion success, fusion node name: " << cnode->fullname_with_scope();
301   return cnode;
302 }
303 }  // namespace opt
304 }  // namespace mindspore
305