1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/ffn_fusion.h"
#include <climits>
#include <cmath>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/custom.h"
#include "ops/f_f_n.h"
#include "nnacl/op_base.h"
26
27 namespace mindspore {
28 namespace opt {
namespace {
// NOTE(review): kStructureNum is not referenced by the code in this file; kept for compatibility.
constexpr int kStructureNum = 2;

// Integer operands expected by CheckPattern in the dynamic-dims pattern (Div and second Mul of the split-size chain).
constexpr int DIV1_Y = 2;
constexpr int MUL2_Y = 2;

// Float operands of the erf-based GELU matched by both patterns:
//   Mul(Mul(x, Add(Erf(Div(x, DIV2_Y)), ADD3_Y)), MUL4_y), i.e. 0.5 * x * (1 + erf(x / sqrt(2))).
constexpr float DIV2_Y = 1.41421;
constexpr float ADD3_Y = 1.0;
constexpr float MUL4_y = 0.5;
// Absolute tolerance used when comparing the float operands above.
constexpr float DIFF_THRESHOLD = 0.0001;

// Pass name and the keys under which the two patterns are registered.
constexpr const char *kFFNFusion = "FFN_Fusion";
constexpr const char *kFFNPatternForConstFolding = "FFNPatternForConstFolding";
constexpr const char *kFFNPatternForDynamicDims = "FFNPatternForDynamicDims";
}  // namespace
41
CreateFFNFusionNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv,int index) const42 CNodePtr FFNFusion::CreateFFNFusionNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
43 int index) const {
44 auto ffn_fusion_prim = std::make_shared<ops::FFN>();
45 MS_CHECK_TRUE_RET(ffn_fusion_prim != nullptr, nullptr);
46
47 ffn_fusion_prim->AddAttr("activation", api::MakeValue("geglu"));
48 // "inner_precise" must be 1 when "activation" is "geglu"
49 ffn_fusion_prim->AddAttr("inner_precise", api::MakeValue(1));
50
51 auto ffn_fusion_prim_c = ffn_fusion_prim->GetPrim();
52 MS_CHECK_TRUE_RET(ffn_fusion_prim_c != nullptr, nullptr);
53 auto input = (*equiv)[input_[index]];
54 MS_CHECK_TRUE_RET(input != nullptr, nullptr);
55 auto input_node = utils::cast<AnfNodePtr>(input);
56 MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
57 auto param1 = utils::cast<AnfNodePtr>((*equiv)[matmul1_b_[index]]);
58 MS_CHECK_TRUE_RET(param1 != nullptr, nullptr);
59 auto param2 = utils::cast<AnfNodePtr>((*equiv)[add1_x_[index]]);
60 MS_CHECK_TRUE_RET(param2 != nullptr, nullptr);
61 auto param3 = utils::cast<AnfNodePtr>((*equiv)[matmul2_b_[index]]);
62 MS_CHECK_TRUE_RET(param3 != nullptr, nullptr);
63
64 auto none_value_node = NewValueNode(std::make_shared<None>());
65 none_value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
66
67 auto ffn_fusion_cnode =
68 func_graph->NewCNode(ffn_fusion_prim_c, {input_node, param1, param3, none_value_node, param2});
69 MS_CHECK_TRUE_RET(ffn_fusion_cnode != nullptr, nullptr);
70 ffn_fusion_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_ffn_fusion");
71 if (node->abstract() != nullptr) {
72 ffn_fusion_cnode->set_abstract(node->abstract()->Clone());
73 }
74 return ffn_fusion_cnode;
75 }
76
Init() const77 bool FFNFusion::Init() const {
78 for (int i = 0; i < kMaxPatternNum; i++) {
79 input_[i] = std::make_shared<Var>();
80 MS_CHECK_TRUE_RET(input_[i] != nullptr, false);
81 div2_y_[i] = std::make_shared<Var>();
82 MS_CHECK_TRUE_RET(div2_y_[i] != nullptr, false);
83 add3_y_[i] = std::make_shared<Var>();
84 MS_CHECK_TRUE_RET(add3_y_[i] != nullptr, false);
85 mul4_y_[i] = std::make_shared<Var>();
86 MS_CHECK_TRUE_RET(mul4_y_[i] != nullptr, false);
87 matmul1_b_[i] = std::make_shared<Var>();
88 MS_CHECK_TRUE_RET(matmul1_b_[i] != nullptr, false);
89 add1_x_[i] = std::make_shared<Var>();
90 MS_CHECK_TRUE_RET(add1_x_[i] != nullptr, false);
91 matmul2_b_[i] = std::make_shared<Var>();
92 MS_CHECK_TRUE_RET(matmul2_b_[i] != nullptr, false);
93 }
94 gather_y_ = std::make_shared<Var>();
95 MS_CHECK_TRUE_RET(gather_y_ != nullptr, false);
96 add2_y_ = std::make_shared<Var>();
97 MS_CHECK_TRUE_RET(add2_y_ != nullptr, false);
98 div1_y_ = std::make_shared<Var>();
99 MS_CHECK_TRUE_RET(div1_y_ != nullptr, false);
100 mul1_y_ = std::make_shared<Var>();
101 MS_CHECK_TRUE_RET(mul1_y_ != nullptr, false);
102 mul2_y_ = std::make_shared<Var>();
103 MS_CHECK_TRUE_RET(mul2_y_ != nullptr, false);
104 return true;
105 }
106
DefineFFNPatternForDynamicDims() const107 const VectorRef FFNFusion::DefineFFNPatternForDynamicDims() const {
108 MS_LOG(INFO) << "start define FFN fusion pattern for dynamic dims.";
109 const size_t param_num = 6;
110 std::vector<CondVarPtr> params(param_num);
111 for (size_t i = 0; i < params.size(); ++i) {
112 params[i] = std::make_shared<CondVar>(IsParamNode);
113 MS_CHECK_TRUE_RET(params[i] != nullptr, {});
114 }
115 size_t index = 0;
116 auto is_matmul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
117 MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
118 VectorRef matmul1_ref({is_matmul1, input_[kDynamicDims], matmul1_b_[kDynamicDims]});
119 auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
120 MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
121 VectorRef add1_ref({is_add1, add1_x_[kDynamicDims], matmul1_ref});
122 auto is_shape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimShape>);
123 MS_CHECK_TRUE_RET(is_shape != nullptr, {});
124 VectorRef shape_ref({is_shape, add1_ref});
125 auto is_gather = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimGather>);
126 MS_CHECK_TRUE_RET(is_gather != nullptr, {});
127 VectorRef gather_ref({is_gather, shape_ref, gather_y_, params[index++]});
128 auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
129 MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
130 VectorRef add2_ref({is_add2, gather_ref, add2_y_});
131 auto is_div1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
132 MS_CHECK_TRUE_RET(is_div1 != nullptr, {});
133 VectorRef div1_ref({is_div1, add2_ref, div1_y_});
134 auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
135 MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
136 VectorRef mul1_ref({is_mul1, div1_ref, mul1_y_});
137 auto is_stridedslice1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
138 MS_CHECK_TRUE_RET(is_stridedslice1 != nullptr, {});
139 VectorRef stridedslice1_ref(
140 {is_stridedslice1, add1_ref, params[index++], mul1_ref, params[index++], params[index++]});
141 auto is_mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
142 MS_CHECK_TRUE_RET(is_mul2 != nullptr, {});
143 VectorRef mul2_ref({is_mul2, div1_ref, mul2_y_});
144 auto is_stridedslice2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
145 MS_CHECK_TRUE_RET(is_stridedslice2 != nullptr, {});
146 VectorRef stridedslice2_ref({is_stridedslice2, add1_ref, mul1_ref, mul2_ref, params[index++], params[index++]});
147 auto is_div2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
148 MS_CHECK_TRUE_RET(is_div2 != nullptr, {});
149 VectorRef div2_ref({is_div2, stridedslice2_ref, div2_y_[kDynamicDims]});
150 auto is_erf = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>);
151 MS_CHECK_TRUE_RET(is_erf != nullptr, {});
152 VectorRef erf_ref({is_erf, div2_ref});
153 auto is_add3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
154 MS_CHECK_TRUE_RET(is_add3 != nullptr, {});
155 VectorRef add3_ref({is_add3, erf_ref, add3_y_[kDynamicDims]});
156 auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
157 MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
158 VectorRef mul3_ref({is_mul3, stridedslice2_ref, add3_ref});
159 auto is_mul4 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
160 MS_CHECK_TRUE_RET(is_mul4 != nullptr, {});
161 VectorRef mul4_ref({is_mul4, mul3_ref, mul4_y_[kDynamicDims]});
162 auto is_mul5 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
163 MS_CHECK_TRUE_RET(is_mul5 != nullptr, {});
164 VectorRef mul5_ref({is_mul5, stridedslice1_ref, mul4_ref});
165 auto is_matmul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
166 MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
167 VectorRef matmul2_ref({is_matmul2, mul5_ref, matmul2_b_[kDynamicDims]});
168 return matmul2_ref;
169 }
170
DefineFFNPatternForConstFolding() const171 const VectorRef FFNFusion::DefineFFNPatternForConstFolding() const {
172 MS_LOG(INFO) << "start define FFN fusion pattern for const folding.";
173 const size_t param_num = 8;
174 std::vector<CondVarPtr> params(param_num);
175 for (size_t i = 0; i < params.size(); ++i) {
176 params[i] = std::make_shared<CondVar>(IsParamNode);
177 MS_CHECK_TRUE_RET(params[i] != nullptr, {});
178 }
179 size_t index = 0;
180 auto is_matmul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
181 MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
182 VectorRef matmul1_ref({is_matmul1, input_[kConstFold], matmul1_b_[kConstFold]});
183 auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
184 MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
185 VectorRef add1_ref({is_add1, add1_x_[kConstFold], matmul1_ref});
186
187 auto is_stridedslice1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
188 MS_CHECK_TRUE_RET(is_stridedslice1 != nullptr, {});
189 VectorRef stridedslice1_ref(
190 {is_stridedslice1, add1_ref, params[index++], params[index++], params[index++], params[index++]});
191
192 auto is_stridedslice2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStridedSlice>);
193 MS_CHECK_TRUE_RET(is_stridedslice2 != nullptr, {});
194 VectorRef stridedslice2_ref(
195 {is_stridedslice2, add1_ref, params[index++], params[index++], params[index++], params[index++]});
196
197 auto is_div2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
198 MS_CHECK_TRUE_RET(is_div2 != nullptr, {});
199 VectorRef div2_ref({is_div2, stridedslice2_ref, div2_y_[kConstFold]});
200 auto is_erf = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>);
201 MS_CHECK_TRUE_RET(is_erf != nullptr, {});
202 VectorRef erf_ref({is_erf, div2_ref});
203 auto is_add3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
204 MS_CHECK_TRUE_RET(is_add3 != nullptr, {});
205 VectorRef add3_ref({is_add3, erf_ref, add3_y_[kConstFold]});
206 auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
207 MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
208 VectorRef mul3_ref({is_mul3, stridedslice2_ref, add3_ref});
209 auto is_mul4 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
210 MS_CHECK_TRUE_RET(is_mul4 != nullptr, {});
211 VectorRef mul4_ref({is_mul4, mul3_ref, mul4_y_[kConstFold]});
212 auto is_mul5 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
213 MS_CHECK_TRUE_RET(is_mul5 != nullptr, {});
214 VectorRef mul5_ref({is_mul5, stridedslice1_ref, mul4_ref});
215 auto is_matmul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
216 MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
217 VectorRef matmul2_ref({is_matmul2, mul5_ref, matmul2_b_[kConstFold]});
218 return matmul2_ref;
219 }
220
DefinePatterns() const221 std::unordered_map<std::string, VectorRef> FFNFusion::DefinePatterns() const {
222 MS_LOG(INFO) << "start define FFN fusion patterns.";
223 if (!Init()) {
224 MS_LOG(ERROR) << "DefinePatterns Init Failed.";
225 return {};
226 }
227 std::unordered_map<std::string, VectorRef> patterns;
228 patterns[kFFNPatternForConstFolding] = DefineFFNPatternForConstFolding();
229 patterns[kFFNPatternForDynamicDims] = DefineFFNPatternForDynamicDims();
230 return patterns;
231 }
232
CheckPattern(const std::string & pattern_name,const EquivPtr & equiv) const233 bool FFNFusion::CheckPattern(const std::string &pattern_name, const EquivPtr &equiv) const {
234 int index = pattern_name == kFFNPatternForDynamicDims ? kDynamicDims : kConstFold;
235
236 float div2_y = GetFloatParameterValue(equiv, div2_y_[index]);
237 if (div2_y < 0 || fabs(div2_y - DIV2_Y) > DIFF_THRESHOLD) {
238 return false;
239 }
240 float add3_y = GetFloatParameterValue(equiv, add3_y_[index]);
241 if (add3_y < 0 || fabs(add3_y - ADD3_Y) > DIFF_THRESHOLD) {
242 return false;
243 }
244 float mul4_y = GetFloatParameterValue(equiv, mul4_y_[index]);
245 if (mul4_y < 0 || fabs(mul4_y - MUL4_y) > DIFF_THRESHOLD) {
246 return false;
247 }
248 if (pattern_name == kFFNPatternForConstFolding) {
249 return true;
250 }
251 // if pattern is for const folding, there are no nodes below, so no need checking.
252 int gather_index = GetIntParameterValue(equiv, gather_y_);
253 if (gather_index == INT_MIN) {
254 return false;
255 }
256 int add2_y = GetIntParameterValue(equiv, add2_y_);
257 if (add2_y != 1) {
258 return false;
259 }
260 int div1_y = GetIntParameterValue(equiv, div1_y_);
261 if (div1_y != DIV1_Y) {
262 return false;
263 }
264 int mul1_y = GetIntParameterValue(equiv, mul1_y_);
265 if (mul1_y != 1) {
266 return false;
267 }
268 int mul2_y = GetIntParameterValue(equiv, mul2_y_);
269 if (mul2_y != MUL2_Y) {
270 return false;
271 }
272 return true;
273 }
274
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const275 AnfNodePtr FFNFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
276 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
277 MS_LOG(INFO) << "do fusion, pattern name: " << pattern_name;
278 if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
279 MS_LOG(ERROR) << "function graph, node or equiv is nullptr.";
280 return nullptr;
281 }
282 if (!utils::isa<CNodePtr>(node)) {
283 MS_LOG(ERROR) << "this node is not cnode, node name: " << node->fullname_with_scope();
284 return nullptr;
285 }
286 if (IsMarkedTrainOp(utils::cast<CNodePtr>(node))) {
287 MS_LOG(ERROR) << "node is train op, can not fusion.";
288 return nullptr;
289 }
290 if (!CheckPattern(pattern_name, equiv)) {
291 MS_LOG(ERROR) << "CheckPattern failed.";
292 return nullptr;
293 }
294 int index = pattern_name == kFFNPatternForDynamicDims ? kDynamicDims : kConstFold;
295 auto cnode = CreateFFNFusionNode(func_graph, node, equiv, index);
296 if (cnode == nullptr) {
297 MS_LOG(INFO) << "new FFN node failed.";
298 return nullptr;
299 }
300 MS_LOG(INFO) << "FFN fusion success, fusion node name: " << cnode->fullname_with_scope();
301 return cnode;
302 }
303 } // namespace opt
304 } // namespace mindspore
305